Enzyme main
Loading...
Searching...
No Matches
ImpulseOps.cpp
Go to the documentation of this file.
1//===- ImpulseOps.cpp - Impulse dialect ops ----------------------*- C++
2//-*-===//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/Builders.h"
13
14using namespace mlir;
15using namespace mlir::impulse;
16
17//===----------------------------------------------------------------------===//
18// SampleOp
19//===----------------------------------------------------------------------===//
20
21LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
22 auto global =
23 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
24 if (!global)
25 return emitOpError("'")
26 << getFn() << "' does not reference a valid global funcOp";
27
28 if (getLogpdfAttr()) {
29 auto global = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
30 *this, getLogpdfAttr());
31 if (!global)
32 return emitOpError("'")
33 << getLogpdf().value() << "' does not reference a valid global "
34 << "funcOp";
35 }
36
37 return success();
38}
39
40//===----------------------------------------------------------------------===//
41// GenerateOp
42//===----------------------------------------------------------------------===//
43
44LogicalResult GenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
45 auto global =
46 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
47 if (!global)
48 return emitOpError("'")
49 << getFn() << "' does not reference a valid global funcOp";
50
51 return success();
52}
53
54//===----------------------------------------------------------------------===//
55// SimulateOp
56//===----------------------------------------------------------------------===//
57
58LogicalResult SimulateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
59 auto global =
60 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
61 if (!global)
62 return emitOpError("'")
63 << getFn() << "' does not reference a valid global funcOp";
64
65 return success();
66}
67
68//===----------------------------------------------------------------------===//
69// RegenerateOp
70//===----------------------------------------------------------------------===//
71
72LogicalResult
73RegenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
74 auto global =
75 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
76 if (!global)
77 return emitOpError("'")
78 << getFn() << "' does not reference a valid global funcOp";
79
80 return success();
81}
82
83//===----------------------------------------------------------------------===//
84// MHOp
85//===----------------------------------------------------------------------===//
86
87LogicalResult MHOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
88 auto global =
89 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
90 if (!global)
91 return emitOpError("'")
92 << getFn() << "' does not reference a valid global funcOp";
93
94 return success();
95}
96
97//===----------------------------------------------------------------------===//
98// InferOp
99//===----------------------------------------------------------------------===//
100
101LogicalResult InferOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
102 if (auto fnAttr = getFnAttr()) {
103 auto global =
104 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, fnAttr);
105 if (!global)
106 return emitOpError("'")
107 << getFn().value() << "' does not reference a valid global funcOp";
108 }
109
110 if (auto logpdfAttr = getLogpdfFnAttr()) {
111 auto global =
112 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, logpdfAttr);
113 if (!global)
114 return emitOpError("'") << logpdfAttr.getValue()
115 << "' does not reference a valid global funcOp";
116 }
117
118 return success();
119}
120
121LogicalResult InferOp::verify() {
122 bool hasHMC = getHmcConfig().has_value();
123 bool hasNUTS = getNutsConfig().has_value();
124
125 if (hasHMC + hasNUTS != 1) {
126 return emitOpError(
127 "Exactly one of hmc_config or nuts_config must be specified");
128 }
129
130 if (!getFnAttr() && !getLogpdfFnAttr()) {
131 return emitOpError("one of `fn` or `logpdf_fn` must be specified");
132 }
133
134 if (getFnAttr() && getLogpdfFnAttr()) {
135 return emitOpError("specifying both `fn` and `logpdf_fn` is unsupported");
136 }
137
138 if (getLogpdfFnAttr() && !getInitialPosition()) {
139 return emitOpError(
140 "custom logpdf mode requires `initial_position` to be provided");
141 }
142
143 return success();
144}