3#include "mlir/CAPI/IR.h"
11 mlir::impulse::RngDistribution rngDist;
14 rngDist = mlir::impulse::RngDistribution::UNIFORM;
17 rngDist = mlir::impulse::RngDistribution::NORMAL;
20 rngDist = mlir::impulse::RngDistribution::MULTINORMAL;
23 return wrap(mlir::impulse::RngDistributionAttr::get(unwrap(ctx), rngDist));
27 bool hasLowerBound,
double lowerBound,
28 bool hasUpperBound,
double upperBound) {
29 auto *mlirCtx = unwrap(ctx);
31 mlir::impulse::SupportKind supportKind;
34 supportKind = mlir::impulse::SupportKind::REAL;
37 supportKind = mlir::impulse::SupportKind::POSITIVE;
40 supportKind = mlir::impulse::SupportKind::UNIT_INTERVAL;
43 supportKind = mlir::impulse::SupportKind::INTERVAL;
46 supportKind = mlir::impulse::SupportKind::GREATER_THAN;
49 supportKind = mlir::impulse::SupportKind::LESS_THAN;
53 mlir::FloatAttr lowerAttr;
56 mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), lowerBound);
58 mlir::FloatAttr upperAttr;
61 mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), upperBound);
63 return wrap(mlir::impulse::SupportAttr::get(mlirCtx, supportKind, lowerAttr,
68 bool adaptStepSize,
bool adaptMassMatrix) {
69 auto *mlirCtx = unwrap(ctx);
70 auto trajectoryLengthAttr =
71 mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), trajectoryLength);
73 return wrap(mlir::impulse::HMCConfigAttr::get(
74 mlirCtx, trajectoryLengthAttr, adaptStepSize, adaptMassMatrix));
78 bool hasMaxDeltaEnergy,
79 double maxDeltaEnergy,
bool adaptStepSize,
80 bool adaptMassMatrix) {
81 auto *mlirCtx = unwrap(ctx);
83 mlir::FloatAttr maxDeltaEnergyAttr;
84 if (hasMaxDeltaEnergy)
86 mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), maxDeltaEnergy);
88 return wrap(mlir::impulse::NUTSConfigAttr::get(
89 mlirCtx, maxTreeDepth, maxDeltaEnergyAttr, adaptStepSize,
94 return wrap(mlir::impulse::SymbolAttr::get(unwrap(ctx), ptr));
MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, bool hasLowerBound, double lowerBound, bool hasUpperBound, double upperBound)
MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx, EnzymeRngDistribution dist)
MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr)
MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth, bool hasMaxDeltaEnergy, double maxDeltaEnergy, bool adaptStepSize, bool adaptMassMatrix)
MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength, bool adaptStepSize, bool adaptMassMatrix)
@ EnzymeRngDistribution_MultiNormal
@ EnzymeRngDistribution_Uniform
@ EnzymeRngDistribution_Normal
@ EnzymeSupportKind_UnitInterval
@ EnzymeSupportKind_LessThan
@ EnzymeSupportKind_Interval
@ EnzymeSupportKind_GreaterThan
@ EnzymeSupportKind_Positive