11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/Builders.h"
// Symbol-use verifier for SampleOp.
// - Always resolves the `fn` attribute via
//   symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr()).
// - Resolves the `logpdf` attribute only when getLogpdfAttr() is set.
// Failures are reported with emitOpError as
// "'<name>' does not reference a valid global ...".
// NOTE(review): this excerpt is garbled and will not compile as shown.
// Original line numbers are fused onto the start of several lines.
// Original lines 22, 24, 27 and 31 are absent; presumably they are the
// `auto global =` binding, the `if (!global)` guards and the closing `}` of
// the logpdf branch. The function's tail (message continuation, presumably
// `return success();` and `}`) is also missing. Restore from the full file.
21LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
23 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*
this, getFnAttr());
// NOTE(review): original line 24 (presumably `if (!global)`) is missing, so
// as shown the error return below is unconditional.
25 return emitOpError(
"'")
26 << getFn() <<
"' does not reference a valid global funcOp";
// The logpdf function is optional: it is only checked when the attribute is
// present.
// NOTE(review): if the outer lookup binds `global`, the inner `auto global`
// below shadows it; consider a distinct name when restoring.
28 if (getLogpdfAttr()) {
29 auto global = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
30 *
this, getLogpdfAttr());
// NOTE(review): original line 31 (presumably `if (!global)`) is missing.
32 return emitOpError(
"'")
33 << getLogpdf().value() <<
"' does not reference a valid global "
// NOTE(review): the string literal above is cut off after "valid global ".
// Its continuation (presumably "funcOp";), the closing braces and the final
// success return are not present in this excerpt.
// Symbol-use verifier for GenerateOp.
// Resolves the `fn` attribute with
// symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr()) and
// reports "'<fn>' does not reference a valid global funcOp" via emitOpError.
// NOTE(review): garbled excerpt. Original line numbers (44, 46, 48, 49) are
// fused onto the line starts. Original lines 45 and 47 (presumably
// `auto global =` and `if (!global)`) are missing, as is the function's tail
// (presumably `return success();` and `}`). Restore from the full file.
44LogicalResult GenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
46 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*
this, getFnAttr());
// NOTE(review): guard line (original 47) missing; as shown the error return
// below is unconditional.
48 return emitOpError(
"'")
49 << getFn() <<
"' does not reference a valid global funcOp";
// Symbol-use verifier for SimulateOp.
// Resolves the `fn` attribute with
// symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr()) and
// reports "'<fn>' does not reference a valid global funcOp" via emitOpError.
// NOTE(review): garbled excerpt. Original line numbers (58, 60, 62, 63) are
// fused onto the line starts. Original lines 59 and 61 (presumably
// `auto global =` and `if (!global)`) are missing, as is the function's tail
// (presumably `return success();` and `}`). Restore from the full file.
58LogicalResult SimulateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
60 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*
this, getFnAttr());
// NOTE(review): guard line (original 61) missing; as shown the error return
// below is unconditional.
62 return emitOpError(
"'")
63 << getFn() <<
"' does not reference a valid global funcOp";
// Symbol-use verifier for RegenerateOp.
// Resolves the `fn` attribute with
// symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr()) and
// reports "'<fn>' does not reference a valid global funcOp" via emitOpError.
// NOTE(review): garbled excerpt.
// - The return-type line (original 72, presumably `LogicalResult`, split off
//   by line wrapping) is missing, so the signature below has no return type.
// - Original lines 74 and 76 (presumably `auto global =` and `if (!global)`)
//   are missing.
// - The function's tail (presumably `return success();` and `}`) is missing.
// Restore from the full file.
73RegenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
75 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*
this, getFnAttr());
// NOTE(review): guard line (original 76) missing; as shown the error return
// below is unconditional.
77 return emitOpError(
"'")
78 << getFn() <<
"' does not reference a valid global funcOp";
// Symbol-use verifier for MHOp.
// Resolves the `fn` attribute with
// symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr()) and
// reports "'<fn>' does not reference a valid global funcOp" via emitOpError.
// NOTE(review): garbled excerpt. Original line numbers (87, 89, 91, 92) are
// fused onto the line starts. Original lines 88 and 90 (presumably
// `auto global =` and `if (!global)`) are missing, as is the function's tail
// (presumably `return success();` and `}`). Restore from the full file.
87LogicalResult MHOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
89 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*
this, getFnAttr());
// NOTE(review): guard line (original 90) missing; as shown the error return
// below is unconditional.
91 return emitOpError(
"'")
92 << getFn() <<
"' does not reference a valid global funcOp";
// Symbol-use verifier for InferOp.
// Both `fn` and `logpdf_fn` are optional here. Each one is resolved with
// symbolTable.lookupNearestSymbolFrom<func::FuncOp> only when its attribute is
// set (getFnAttr() / getLogpdfFnAttr()). A failed lookup is reported with
// emitOpError as "'<name>' does not reference a valid global funcOp".
// The "exactly one of fn / logpdf_fn" rule is enforced separately in
// InferOp::verify, not here.
// NOTE(review): garbled excerpt. Original line numbers are fused onto the line
// starts. Original lines 103, 105, 108-109, 111, 113 and 116 onward are
// missing: presumably the `auto global =` bindings, the `if (!global)` guards,
// the closing braces and the final success return. Restore from the full file.
// NOTE(review): the fn error prints getFn().value() while the logpdf_fn error
// prints logpdfAttr.getValue(). These are presumably equivalent symbol names;
// consider one accessor style for both when restoring.
101LogicalResult InferOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
102 if (
auto fnAttr = getFnAttr()) {
104 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*
this, fnAttr);
// NOTE(review): guard line (original 105) missing.
106 return emitOpError(
"'")
107 << getFn().value() <<
"' does not reference a valid global funcOp";
// Custom-logpdf mode: check the user-supplied logpdf function when present.
110 if (
auto logpdfAttr = getLogpdfFnAttr()) {
112 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*
this, logpdfAttr);
// NOTE(review): guard line (original 113) missing.
114 return emitOpError(
"'") << logpdfAttr.getValue()
115 <<
"' does not reference a valid global funcOp";
121LogicalResult InferOp::verify() {
122 bool hasHMC = getHmcConfig().has_value();
123 bool hasNUTS = getNutsConfig().has_value();
125 if (hasHMC + hasNUTS != 1) {
127 "Exactly one of hmc_config or nuts_config must be specified");
130 if (!getFnAttr() && !getLogpdfFnAttr()) {
131 return emitOpError(
"one of `fn` or `logpdf_fn` must be specified");
134 if (getFnAttr() && getLogpdfFnAttr()) {
135 return emitOpError(
"specifying both `fn` and `logpdf_fn` is unsupported");
138 if (getLogpdfFnAttr() && !getInitialPosition()) {
140 "custom logpdf mode requires `initial_position` to be provided");