Enzyme main
Loading...
Searching...
No Matches
TransformUtils.cpp
Go to the documentation of this file.
1//===- TransformUtils.cpp - Constraint transforms for HMC ------* C++ -*-===//
2//
3// This file implements constraint transforms for HMC inference.
4//
5// Reference:
6// https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py
7//
8//===----------------------------------------------------------------------===//
9
10#include "TransformUtils.h"
11
13#include "Dialect/Ops.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Math/IR/Math.h"
16
17#include <cmath>
18
19using namespace mlir;
20using namespace mlir::enzyme;
21using namespace mlir::enzyme::transforms;
22
23Value transforms::createLogit(OpBuilder &builder, Location loc, Value x) {
24 auto xType = cast<RankedTensorType>(x.getType());
25 auto elemType = xType.getElementType();
26
27 auto oneConst = arith::ConstantOp::create(
28 builder, loc, xType,
29 DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, 1.0)));
30 auto oneMinusX = arith::SubFOp::create(builder, loc, oneConst, x);
31 auto logX = math::LogOp::create(builder, loc, x);
32 auto logOneMinusX = math::LogOp::create(builder, loc, oneMinusX);
33 return arith::SubFOp::create(builder, loc, logX, logOneMinusX);
34}
35
36Value transforms::createLogSigmoid(OpBuilder &builder, Location loc, Value x) {
37 auto xType = cast<RankedTensorType>(x.getType());
38 auto elemType = xType.getElementType();
39
40 // log_sigmoid(x) = -softplus(-x) = -log_add_exp(-x, 0)
41 auto negX = arith::NegFOp::create(builder, loc, x);
42 auto zeroConst = arith::ConstantOp::create(
43 builder, loc, xType,
44 DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, 0.0)));
45 auto softplusNegX =
46 impulse::LogAddExpOp::create(builder, loc, xType, negX, zeroConst);
47 return arith::NegFOp::create(builder, loc, softplusNegX);
48}
49
50int64_t transforms::getUnconstrainedSize(int64_t constrainedSize,
51 impulse::SupportKind kind) {
52 switch (kind) {
53 case impulse::SupportKind::REAL:
54 case impulse::SupportKind::POSITIVE:
55 case impulse::SupportKind::UNIT_INTERVAL:
56 case impulse::SupportKind::INTERVAL:
57 case impulse::SupportKind::GREATER_THAN:
58 case impulse::SupportKind::LESS_THAN:
59 return constrainedSize;
60 }
61 llvm_unreachable("Unknown SupportKind");
62}
63
64int64_t transforms::getConstrainedSize(int64_t unconstrainedSize,
65 impulse::SupportKind kind) {
66 switch (kind) {
67 case impulse::SupportKind::REAL:
68 case impulse::SupportKind::POSITIVE:
69 case impulse::SupportKind::UNIT_INTERVAL:
70 case impulse::SupportKind::INTERVAL:
71 case impulse::SupportKind::GREATER_THAN:
72 case impulse::SupportKind::LESS_THAN:
73 return unconstrainedSize;
74 }
75 llvm_unreachable("Unknown SupportKind");
76}
77
78Value transforms::unconstrain(OpBuilder &builder, Location loc,
79 Value constrained, impulse::SupportAttr support) {
80 auto kind = support.getKind();
81 auto xType = cast<RankedTensorType>(constrained.getType());
82 auto elemType = xType.getElementType();
83
84 switch (kind) {
85 case impulse::SupportKind::REAL:
86 // Identity
87 return constrained;
88 case impulse::SupportKind::POSITIVE:
89 return math::LogOp::create(builder, loc, constrained);
90 case impulse::SupportKind::UNIT_INTERVAL:
91 return createLogit(builder, loc, constrained);
92 case impulse::SupportKind::INTERVAL: {
93 // z = logit((x - a) / (b - a))
94 auto lowerAttr = support.getLowerBound();
95 auto upperAttr = support.getUpperBound();
96 if (!lowerAttr || !upperAttr) {
97 llvm_unreachable("INTERVAL support requires lower and upper bounds");
98 }
99 double lower = lowerAttr.getValueAsDouble();
100 double upper = upperAttr.getValueAsDouble();
101
102 auto lowerConst = arith::ConstantOp::create(
103 builder, loc, xType,
104 DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, lower)));
105 auto scaleConst = arith::ConstantOp::create(
106 builder, loc, xType,
107 DenseElementsAttr::get(xType,
108 builder.getFloatAttr(elemType, upper - lower)));
109
110 auto shifted = arith::SubFOp::create(builder, loc, constrained, lowerConst);
111 auto normalized = arith::DivFOp::create(builder, loc, shifted, scaleConst);
112 return createLogit(builder, loc, normalized);
113 }
114 case impulse::SupportKind::GREATER_THAN: {
115 // z = log(x - lower)
116 auto lowerAttr = support.getLowerBound();
117 if (!lowerAttr) {
118 llvm_unreachable("GREATER_THAN support requires lower bound");
119 }
120 double lower = lowerAttr.getValueAsDouble();
121
122 auto lowerConst = arith::ConstantOp::create(
123 builder, loc, xType,
124 DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, lower)));
125 auto shifted = arith::SubFOp::create(builder, loc, constrained, lowerConst);
126 return math::LogOp::create(builder, loc, shifted);
127 }
128 case impulse::SupportKind::LESS_THAN: {
129 // z = log(upper - x)
130 auto upperAttr = support.getUpperBound();
131 if (!upperAttr) {
132 llvm_unreachable("LESS_THAN support requires upper bound");
133 }
134 double upper = upperAttr.getValueAsDouble();
135
136 auto upperConst = arith::ConstantOp::create(
137 builder, loc, xType,
138 DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, upper)));
139 auto shifted = arith::SubFOp::create(builder, loc, upperConst, constrained);
140 return math::LogOp::create(builder, loc, shifted);
141 }
142 }
143 llvm_unreachable("Unknown SupportKind");
144}
145
146Value transforms::constrain(OpBuilder &builder, Location loc,
147 Value unconstrained, impulse::SupportAttr support) {
148 auto kind = support.getKind();
149 auto zType = cast<RankedTensorType>(unconstrained.getType());
150 auto elemType = zType.getElementType();
151
152 switch (kind) {
153 case impulse::SupportKind::REAL:
154 // Identity
155 return unconstrained;
156 case impulse::SupportKind::POSITIVE:
157 return math::ExpOp::create(builder, loc, unconstrained);
158 case impulse::SupportKind::UNIT_INTERVAL:
159 // x = sigmoid(z)
160 return impulse::LogisticOp::create(builder, loc, unconstrained.getType(),
161 unconstrained);
162 case impulse::SupportKind::INTERVAL: {
163 // x = a + (b - a) * sigmoid(z)
164 auto lowerAttr = support.getLowerBound();
165 auto upperAttr = support.getUpperBound();
166 if (!lowerAttr || !upperAttr) {
167 llvm_unreachable("INTERVAL support requires lower and upper bounds");
168 }
169 double lower = lowerAttr.getValueAsDouble();
170 double upper = upperAttr.getValueAsDouble();
171
172 auto lowerConst = arith::ConstantOp::create(
173 builder, loc, zType,
174 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, lower)));
175 auto scaleConst = arith::ConstantOp::create(
176 builder, loc, zType,
177 DenseElementsAttr::get(zType,
178 builder.getFloatAttr(elemType, upper - lower)));
179
180 auto sigmoid = impulse::LogisticOp::create(
181 builder, loc, unconstrained.getType(), unconstrained);
182 auto scaled = arith::MulFOp::create(builder, loc, scaleConst, sigmoid);
183 return arith::AddFOp::create(builder, loc, lowerConst, scaled);
184 }
185 case impulse::SupportKind::GREATER_THAN: {
186 // x = lower + exp(z)
187 auto lowerAttr = support.getLowerBound();
188 if (!lowerAttr) {
189 llvm_unreachable("GREATER_THAN support requires lower bound");
190 }
191 double lower = lowerAttr.getValueAsDouble();
192
193 auto lowerConst = arith::ConstantOp::create(
194 builder, loc, zType,
195 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, lower)));
196 auto expZ = math::ExpOp::create(builder, loc, unconstrained);
197 return arith::AddFOp::create(builder, loc, lowerConst, expZ);
198 }
199 case impulse::SupportKind::LESS_THAN: {
200 // x = upper - exp(z)
201 auto upperAttr = support.getUpperBound();
202 if (!upperAttr) {
203 llvm_unreachable("LESS_THAN support requires upper bound");
204 }
205 double upper = upperAttr.getValueAsDouble();
206
207 auto upperConst = arith::ConstantOp::create(
208 builder, loc, zType,
209 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, upper)));
210 auto expZ = math::ExpOp::create(builder, loc, unconstrained);
211 return arith::SubFOp::create(builder, loc, upperConst, expZ);
212 }
213 }
214
215 llvm_unreachable("Unknown SupportKind");
216}
217
218Value transforms::logAbsDetJacobian(OpBuilder &builder, Location loc,
219 Value unconstrained,
220 impulse::SupportAttr support) {
221 auto kind = support.getKind();
222 auto zType = cast<RankedTensorType>(unconstrained.getType());
223 auto elemType = zType.getElementType();
224 auto scalarType = RankedTensorType::get({}, elemType);
225
226 switch (kind) {
227 case impulse::SupportKind::REAL: {
228 // Identity: log|det(I)| = 0
229 return arith::ConstantOp::create(
230 builder, loc, scalarType,
231 DenseElementsAttr::get(scalarType,
232 builder.getFloatAttr(elemType, 0.0)));
233 }
234 case impulse::SupportKind::POSITIVE: {
235 // x = exp(z), dx/dz = exp(z)
236 // log|det(J)| = sum(log|dx_i/dz_i|) = sum(z)
237 auto ones = arith::ConstantOp::create(
238 builder, loc, zType,
239 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0)));
240 return impulse::DotOp::create(
241 builder, loc, scalarType, unconstrained, ones,
242 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
243 builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0}));
244 }
245 case impulse::SupportKind::UNIT_INTERVAL: {
246 // x = sigmoid(z), dx/dz = sigmoid(z) * (1 - sigmoid(z))
247 // log|det(J)| = sum(log(sigmoid(z)) + log(1 - sigmoid(z)))
248 // = sum(log_sigmoid(z) + log_sigmoid(-z))
249 auto logSigZ = createLogSigmoid(builder, loc, unconstrained);
250 auto negZ = arith::NegFOp::create(builder, loc, unconstrained);
251 auto logSigNegZ = createLogSigmoid(builder, loc, negZ);
252 auto logProduct = arith::AddFOp::create(builder, loc, logSigZ, logSigNegZ);
253 auto ones = arith::ConstantOp::create(
254 builder, loc, zType,
255 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0)));
256 return impulse::DotOp::create(
257 builder, loc, scalarType, logProduct, ones,
258 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
259 builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0}));
260 }
261 case impulse::SupportKind::INTERVAL: {
262 // log|det(J)| = sum(log_sigmoid(z) + log(1 - sigmoid(z))) + n*log(scale)
263 auto lowerAttr = support.getLowerBound();
264 auto upperAttr = support.getUpperBound();
265 if (!lowerAttr || !upperAttr) {
266 llvm_unreachable("INTERVAL support requires lower and upper bounds");
267 }
268 double scale = upperAttr.getValueAsDouble() - lowerAttr.getValueAsDouble();
269
270 auto logSigZ = createLogSigmoid(builder, loc, unconstrained);
271 auto negZ = arith::NegFOp::create(builder, loc, unconstrained);
272 auto logSigNegZ = createLogSigmoid(builder, loc, negZ);
273 auto logProduct = arith::AddFOp::create(builder, loc, logSigZ, logSigNegZ);
274 auto ones = arith::ConstantOp::create(
275 builder, loc, zType,
276 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0)));
277 auto sumLogProduct = impulse::DotOp::create(
278 builder, loc, scalarType, logProduct, ones,
279 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
280 builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0}));
281
282 int64_t n = zType.getNumElements();
283 double logScaleTerm = n * std::log(scale);
284 auto logScaleConst = arith::ConstantOp::create(
285 builder, loc, scalarType,
286 DenseElementsAttr::get(scalarType,
287 builder.getFloatAttr(elemType, logScaleTerm)));
288 return arith::AddFOp::create(builder, loc, sumLogProduct, logScaleConst);
289 }
290 case impulse::SupportKind::GREATER_THAN:
291 case impulse::SupportKind::LESS_THAN: {
292 // log|det(J)| = sum(z)
293 auto ones = arith::ConstantOp::create(
294 builder, loc, zType,
295 DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0)));
296 return impulse::DotOp::create(
297 builder, loc, scalarType, unconstrained, ones,
298 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
299 builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0}));
300 }
301 }
302
303 llvm_unreachable("Unknown SupportKind");
304}
Value logAbsDetJacobian(OpBuilder &builder, Location loc, Value unconstrained, impulse::SupportAttr support)
Compute log |det J| of the transform from unconstrained to constrained.
Value unconstrain(OpBuilder &builder, Location loc, Value constrained, impulse::SupportAttr support)
Transform from constrained to unconstrained space.
int64_t getConstrainedSize(int64_t unconstrainedSize, impulse::SupportKind kind)
Get the constrained size given an unconstrained size and support kind.
Value constrain(OpBuilder &builder, Location loc, Value unconstrained, impulse::SupportAttr support)
Transform from unconstrained to constrained space.
int64_t getUnconstrainedSize(int64_t constrainedSize, impulse::SupportKind kind)
Get the unconstrained size given a constrained size and support kind.
Value createLogSigmoid(OpBuilder &builder, Location loc, Value x)
Value createLogit(OpBuilder &builder, Location loc, Value x)