Enzyme main
Loading...
Searching...
No Matches
EnzymeMLIR.cpp
Go to the documentation of this file.
1#include "EnzymeMLIR.h"
2
3#include "mlir/CAPI/IR.h"
4
5#include "Dialect/Dialect.h"
7#include "Dialect/Ops.h"
8
9MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx,
11 mlir::impulse::RngDistribution rngDist;
12 switch (dist) {
14 rngDist = mlir::impulse::RngDistribution::UNIFORM;
15 break;
17 rngDist = mlir::impulse::RngDistribution::NORMAL;
18 break;
20 rngDist = mlir::impulse::RngDistribution::MULTINORMAL;
21 break;
22 }
23 return wrap(mlir::impulse::RngDistributionAttr::get(unwrap(ctx), rngDist));
24}
25
26MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind,
27 bool hasLowerBound, double lowerBound,
28 bool hasUpperBound, double upperBound) {
29 auto *mlirCtx = unwrap(ctx);
30
31 mlir::impulse::SupportKind supportKind;
32 switch (kind) {
34 supportKind = mlir::impulse::SupportKind::REAL;
35 break;
37 supportKind = mlir::impulse::SupportKind::POSITIVE;
38 break;
40 supportKind = mlir::impulse::SupportKind::UNIT_INTERVAL;
41 break;
43 supportKind = mlir::impulse::SupportKind::INTERVAL;
44 break;
46 supportKind = mlir::impulse::SupportKind::GREATER_THAN;
47 break;
49 supportKind = mlir::impulse::SupportKind::LESS_THAN;
50 break;
51 }
52
53 mlir::FloatAttr lowerAttr;
54 if (hasLowerBound)
55 lowerAttr =
56 mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), lowerBound);
57
58 mlir::FloatAttr upperAttr;
59 if (hasUpperBound)
60 upperAttr =
61 mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), upperBound);
62
63 return wrap(mlir::impulse::SupportAttr::get(mlirCtx, supportKind, lowerAttr,
64 upperAttr));
65}
66
67MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength,
68 bool adaptStepSize, bool adaptMassMatrix) {
69 auto *mlirCtx = unwrap(ctx);
70 auto trajectoryLengthAttr =
71 mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), trajectoryLength);
72
73 return wrap(mlir::impulse::HMCConfigAttr::get(
74 mlirCtx, trajectoryLengthAttr, adaptStepSize, adaptMassMatrix));
75}
76
77MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth,
78 bool hasMaxDeltaEnergy,
79 double maxDeltaEnergy, bool adaptStepSize,
80 bool adaptMassMatrix) {
81 auto *mlirCtx = unwrap(ctx);
82
83 mlir::FloatAttr maxDeltaEnergyAttr;
84 if (hasMaxDeltaEnergy)
85 maxDeltaEnergyAttr =
86 mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), maxDeltaEnergy);
87
88 return wrap(mlir::impulse::NUTSConfigAttr::get(
89 mlirCtx, maxTreeDepth, maxDeltaEnergyAttr, adaptStepSize,
90 adaptMassMatrix));
91}
92
93MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr) {
94 return wrap(mlir::impulse::SymbolAttr::get(unwrap(ctx), ptr));
95}
MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, bool hasLowerBound, double lowerBound, bool hasUpperBound, double upperBound)
MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx, EnzymeRngDistribution dist)
Definition EnzymeMLIR.cpp:9
MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr)
MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth, bool hasMaxDeltaEnergy, double maxDeltaEnergy, bool adaptStepSize, bool adaptMassMatrix)
MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength, bool adaptStepSize, bool adaptMassMatrix)
EnzymeRngDistribution
Definition EnzymeMLIR.h:18
@ EnzymeRngDistribution_MultiNormal
Definition EnzymeMLIR.h:21
@ EnzymeRngDistribution_Uniform
Definition EnzymeMLIR.h:19
@ EnzymeRngDistribution_Normal
Definition EnzymeMLIR.h:20
EnzymeSupportKind
Definition EnzymeMLIR.h:27
@ EnzymeSupportKind_Real
Definition EnzymeMLIR.h:28
@ EnzymeSupportKind_UnitInterval
Definition EnzymeMLIR.h:30
@ EnzymeSupportKind_LessThan
Definition EnzymeMLIR.h:33
@ EnzymeSupportKind_Interval
Definition EnzymeMLIR.h:31
@ EnzymeSupportKind_GreaterThan
Definition EnzymeMLIR.h:32
@ EnzymeSupportKind_Positive
Definition EnzymeMLIR.h:29