Enzyme main
Loading...
Searching...
No Matches
HMCUtils.cpp
Go to the documentation of this file.
1//===- HMCUtils.cpp - Utilities for HMC/NUTS inference -------* C++ -*-===//
2//
3// This file implements utility functions for Hamiltonian Monte Carlo (HMC) and
4// No-U-Turn Sampler (NUTS) implementations.
5//
6// Reference:
7// https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/hmc_util.py
8//
9//===----------------------------------------------------------------------===//
10
11#include "HMCUtils.h"
12
14#include "Dialect/Ops.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/Math/IR/Math.h"
18
19#include <cmath>
20#include <limits>
21
22using namespace mlir;
23using namespace mlir::enzyme;
24using namespace mlir::impulse;
25
34
36 assert(values.size() == 18 && "Expected 18 NUTSTreeState fields");
37 return {.q_left = values[0],
38 .p_left = values[1],
39 .grad_left = values[2],
40 .q_right = values[3],
41 .p_right = values[4],
42 .grad_right = values[5],
43 .q_proposal = values[6],
44 .grad_proposal = values[7],
45 .U_proposal = values[8],
46 .H_proposal = values[9],
47 .depth = values[10],
48 .weight = values[11],
49 .turning = values[12],
50 .diverging = values[13],
51 .sum_accept_probs = values[14],
52 .num_proposals = values[15],
53 .p_sum = values[16],
54 .rng = values[17]};
55}
56
57SmallVector<Type> NUTSTreeState::getTypes() const {
58 SmallVector<Type> types;
59 for (auto val : toValues())
60 types.push_back(val.getType());
61 return types;
62}
63
64Value impulse::conditionalDump(OpBuilder &builder, Location loc, Value value,
65 StringRef label, bool debugDump) {
66 if (debugDump) {
67 return enzyme::DumpOp::create(builder, loc, value.getType(), value,
68 builder.getStringAttr(label))
69 .getOutput();
70 }
71 return value;
72}
73
74/// Creates a 2D identity matrix of the specified type.
75static Value createIdentityMatrix(OpBuilder &builder, Location loc,
76 RankedTensorType matrixType) {
77 assert(matrixType.getRank() == 2 && "Expected 2D tensor type");
78 assert(matrixType.getShape()[0] == matrixType.getShape()[1] &&
79 "Expected square matrix type");
80
81 int64_t size = matrixType.getShape()[0];
82 auto elemType = matrixType.getElementType();
83
84 SmallVector<Attribute> values;
85 values.reserve(size * size);
86 for (int64_t i = 0; i < size; ++i) {
87 for (int64_t j = 0; j < size; ++j) {
88 double val = (i == j) ? 1.0 : 0.0;
89 values.push_back(builder.getFloatAttr(elemType, val));
90 }
91 }
92
93 auto denseAttr = DenseElementsAttr::get(matrixType, values);
94 return arith::ConstantOp::create(builder, loc, matrixType, denseAttr);
95}
96
97/// Creates a permutation matrix of size n x n.
98static Value createPermutationMatrix(OpBuilder &builder, Location loc,
99 RankedTensorType matrixType) {
100 assert(matrixType.getRank() == 2 && "Expected 2D tensor type");
101 assert(matrixType.getShape()[0] == matrixType.getShape()[1] &&
102 "Expected square matrix type");
103
104 int64_t size = matrixType.getShape()[0];
105 auto elemType = matrixType.getElementType();
106
107 SmallVector<Attribute> values;
108 values.reserve(size * size);
109 for (int64_t i = 0; i < size; ++i) {
110 for (int64_t j = 0; j < size; ++j) {
111 double val = (j == size - 1 - i) ? 1.0 : 0.0;
112 values.push_back(builder.getFloatAttr(elemType, val));
113 }
114 }
115
116 auto denseAttr = DenseElementsAttr::get(matrixType, values);
117 return arith::ConstantOp::create(builder, loc, matrixType, denseAttr);
118}
119
120/// Computes A[::-1, ::-1] using permutation matrix through P @ A @ P.
121static Value reverseRowsAndColumns(OpBuilder &builder, Location loc,
122 Value matrix) {
123 auto matrixType = cast<RankedTensorType>(matrix.getType());
124 auto P = createPermutationMatrix(builder, loc, matrixType);
125
126 // PA = P @ A
127 auto PA = impulse::DotOp::create(
128 builder, loc, matrixType, P, matrix, builder.getDenseI64ArrayAttr({}),
129 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({1}),
130 builder.getDenseI64ArrayAttr({0}));
131
132 // PAP = PA @ P
133 return impulse::DotOp::create(
134 builder, loc, matrixType, PA, P, builder.getDenseI64ArrayAttr({}),
135 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({1}),
136 builder.getDenseI64ArrayAttr({0}));
137}
138
139Value impulse::applyInverseMassMatrix(OpBuilder &builder, Location loc,
140 Value invMass, Value momentum,
141 RankedTensorType positionType) {
142 if (!invMass) {
143 return momentum;
144 }
145
146 auto invMassType = cast<RankedTensorType>(invMass.getType());
147
148 if (invMassType.getRank() == 1) {
149 // Diagonal: element-wise
150 auto diagMass =
151 impulse::ReshapeOp::create(builder, loc, positionType, invMass);
152 return arith::MulFOp::create(builder, loc, diagMass, momentum);
153 } else if (invMassType.getRank() == 2) {
154 // Dense: v = invMass @ p
155 return impulse::DotOp::create(
156 builder, loc, positionType, momentum, invMass,
157 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
158 builder.getDenseI64ArrayAttr({1}), builder.getDenseI64ArrayAttr({0}));
159 }
160
161 emitError(loc, "ProbProg: Provided invMass must have rank 1 or 2, got rank " +
162 std::to_string(invMassType.getRank()));
163 return nullptr;
164}
165
166Value impulse::computeKineticEnergy(OpBuilder &builder, Location loc,
167 Value momentum, Value invMass,
168 RankedTensorType positionType) {
169 auto elemType = positionType.getElementType();
170 auto scalarType = RankedTensorType::get({}, elemType);
171
172 auto halfConst = arith::ConstantOp::create(
173 builder, loc, scalarType,
174 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.5)));
175
176 // v = M^-1 @ p
177 auto v =
178 applyInverseMassMatrix(builder, loc, invMass, momentum, positionType);
179
180 // K = 0.5 * p^T @ v
181 // For 2D tensors [1, N], contract over both dimensions to get scalar
182 auto pDotV = impulse::DotOp::create(
183 builder, loc, scalarType, momentum, v, builder.getDenseI64ArrayAttr({}),
184 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({0, 1}),
185 builder.getDenseI64ArrayAttr({0, 1}));
186
187 return arith::MulFOp::create(builder, loc, halfConst, pDotV);
188}
189
190Value impulse::computeMassMatrixSqrt(OpBuilder &builder, Location loc,
191 Value invMass,
192 RankedTensorType positionType) {
193 if (!invMass) {
194 return Value();
195 }
196
197 auto invMassType = cast<RankedTensorType>(invMass.getType());
198 auto elemType = invMassType.getElementType();
199
200 if (invMassType.getRank() == 1) {
201 // Diagonal: mass_matrix_sqrt = 1/sqrt(invMass)
202 auto sqrtInvMass = math::SqrtOp::create(builder, loc, invMass);
203 auto onesVector = arith::ConstantOp::create(
204 builder, loc, invMassType,
205 DenseElementsAttr::get(invMassType,
206 builder.getFloatAttr(elemType, 1.0)));
207 return arith::DivFOp::create(builder, loc, onesVector, sqrtInvMass);
208 } else {
209 // Dense: mass_matrix_sqrt = M^{1/2}
210 // TODO: improve
211 // Reference:
212 // https://github.com/pyro-ppl/numpyro/blob/6a9cb9a530fe53897edb6c472368e58965b034e4/numpyro/infer/hmc_util.py#L499
213 auto reversedInvMass = reverseRowsAndColumns(builder, loc, invMass);
214 auto L_reversed =
215 impulse::CholeskyOp::create(builder, loc, invMassType, reversedInvMass,
216 /*lower=*/builder.getBoolAttr(true));
217 auto massMatrixSqrtInvT = reverseRowsAndColumns(builder, loc, L_reversed);
218 auto identityMatrix = createIdentityMatrix(builder, loc, invMassType);
219 auto massMatrixSqrt = impulse::TriangularSolveOp::create(
220 builder, loc, invMassType, massMatrixSqrtInvT, identityMatrix,
221 /*left_side=*/builder.getBoolAttr(true),
222 /*lower=*/builder.getBoolAttr(false),
223 /*unit_diagonal=*/builder.getBoolAttr(false),
224 /*transpose_a=*/
225 impulse::TransposeAttr::get(builder.getContext(),
226 impulse::Transpose::TRANSPOSE));
227
228 return massMatrixSqrt;
229 }
230}
231
232std::pair<Value, Value>
233impulse::sampleMomentum(OpBuilder &builder, Location loc, Value rng,
234 Value invMass, Value massMatrixSqrt,
235 RankedTensorType positionType, bool debugDump) {
236 auto elemType = positionType.getElementType();
237 auto scalarType = RankedTensorType::get({}, elemType);
238
239 auto zeroConst = arith::ConstantOp::create(
240 builder, loc, scalarType,
241 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
242 auto oneConst = arith::ConstantOp::create(
243 builder, loc, scalarType,
244 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
245
246 rng = conditionalDump(builder, loc, rng, "sampleMomentum: input rng state",
247 debugDump);
248
249 // Sample eps ~ N(0, I)
250 auto randomOp = impulse::RandomOp::create(
251 builder, loc, TypeRange{rng.getType(), positionType}, rng, zeroConst,
252 oneConst,
253 impulse::RngDistributionAttr::get(builder.getContext(),
254 impulse::RngDistribution::NORMAL));
255
256 auto rngOut = randomOp.getOutputRngState();
257 auto eps = randomOp.getResult();
258
259 if (!massMatrixSqrt) {
260 return {eps, rngOut};
261 }
262
263 auto massMatrixSqrtType = cast<RankedTensorType>(massMatrixSqrt.getType());
264
265 if (massMatrixSqrtType.getRank() == 1) {
266 // Diagonal: p = massMatrixSqrt * eps (element-wise)
267 auto diagSqrt =
268 impulse::ReshapeOp::create(builder, loc, positionType, massMatrixSqrt);
269 auto p = arith::MulFOp::create(builder, loc, diagSqrt, eps);
270 return {p, rngOut};
271 } else {
272 // Dense: p = massMatrixSqrt @ eps
273 auto p = impulse::DotOp::create(
274 builder, loc, positionType, eps, massMatrixSqrt,
275 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
276 builder.getDenseI64ArrayAttr({1}), builder.getDenseI64ArrayAttr({1}));
277 return {p, rngOut};
278 }
279}
280
281static Value scatterPositionToTrace(OpBuilder &builder, Location loc,
282 Value position2d, Value fullTrace,
283 const HMCContext &ctx) {
284 auto elemType =
285 cast<RankedTensorType>(ctx.stepSize.getType()).getElementType();
286 auto traceType = RankedTensorType::get({1, ctx.getFullTraceSize()}, elemType);
287 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
288 auto c0 = arith::ConstantOp::create(
289 builder, loc, i64TensorType,
290 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
291
292 Value result = fullTrace;
293 for (const auto &info : ctx.supports) {
294 auto sliceType = RankedTensorType::get({1, info.size}, elemType);
295 auto posOffset = arith::ConstantOp::create(
296 builder, loc, i64TensorType,
297 DenseElementsAttr::get(i64TensorType,
298 builder.getI64IntegerAttr(info.offset)));
299 SmallVector<Value> extractIndices{c0, posOffset};
300 auto slice = impulse::DynamicSliceOp::create(
301 builder, loc, sliceType, position2d, extractIndices,
302 builder.getDenseI64ArrayAttr({1, info.size}));
303
304 auto traceOffset = arith::ConstantOp::create(
305 builder, loc, i64TensorType,
306 DenseElementsAttr::get(i64TensorType,
307 builder.getI64IntegerAttr(info.traceOffset)));
308 SmallVector<Value> updateIndices{c0, traceOffset};
309 result = impulse::DynamicUpdateSliceOp::create(
310 builder, loc, traceType, result, slice, updateIndices);
311 }
312 return result;
313}
314
315static Value gatherPositionFromTrace(OpBuilder &builder, Location loc,
316 Value fullTrace, const HMCContext &ctx) {
317 auto elemType =
318 cast<RankedTensorType>(ctx.stepSize.getType()).getElementType();
319 auto positionType2d = RankedTensorType::get({1, ctx.positionSize}, elemType);
320 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
321 auto c0 = arith::ConstantOp::create(
322 builder, loc, i64TensorType,
323 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
324
325 auto zeroConst = arith::ConstantOp::create(
326 builder, loc, positionType2d,
327 DenseElementsAttr::get(positionType2d,
328 builder.getFloatAttr(elemType, 0.0)));
329 Value result = zeroConst;
330 for (const auto &info : ctx.supports) {
331 auto sliceType = RankedTensorType::get({1, info.size}, elemType);
332 auto traceOffset = arith::ConstantOp::create(
333 builder, loc, i64TensorType,
334 DenseElementsAttr::get(i64TensorType,
335 builder.getI64IntegerAttr(info.traceOffset)));
336 SmallVector<Value> extractIndices{c0, traceOffset};
337 auto slice = impulse::DynamicSliceOp::create(
338 builder, loc, sliceType, fullTrace, extractIndices,
339 builder.getDenseI64ArrayAttr({1, info.size}));
340
341 auto posOffset = arith::ConstantOp::create(
342 builder, loc, i64TensorType,
343 DenseElementsAttr::get(i64TensorType,
344 builder.getI64IntegerAttr(info.offset)));
345 SmallVector<Value> updateIndices{c0, posOffset};
346 result = impulse::DynamicUpdateSliceOp::create(
347 builder, loc, positionType2d, result, slice, updateIndices);
348 }
349 return result;
350}
351
353 Location loc,
354 Value position, Value rng,
355 const HMCContext &ctx) {
356 auto positionType = ctx.getPositionType();
357 auto scalarType = ctx.getScalarType();
358 auto elemType = ctx.getElementType();
359
360 auto gradSeed = arith::ConstantOp::create(
361 builder, loc, scalarType,
362 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
363
364 bool isCustomLogpdf = ctx.hasCustomLogpdf();
365 auto flatType = RankedTensorType::get({ctx.positionSize}, elemType);
366 Value autodiffPosition = position;
367 auto autodiffPositionType = positionType;
368 auto autodiffGradType = positionType;
369 if (isCustomLogpdf) {
370 autodiffPosition =
371 impulse::ReshapeOp::create(builder, loc, flatType, position);
372 autodiffPositionType = flatType;
373 autodiffGradType = flatType;
374 }
375
376 SmallVector<Value> autodiffInputs{autodiffPosition, gradSeed};
377 SmallVector<NamedAttribute> adAttrs{
378 builder.getNamedAttr(
379 "activity",
380 builder.getArrayAttr({enzyme::ActivityAttr::get(
381 builder.getContext(), enzyme::Activity::enzyme_active)})),
382 builder.getNamedAttr(
383 "ret_activity",
384 builder.getArrayAttr(
385 {enzyme::ActivityAttr::get(builder.getContext(),
386 enzyme::Activity::enzyme_active),
387 enzyme::ActivityAttr::get(builder.getContext(),
388 enzyme::Activity::enzyme_const)})),
389 };
390 if (ctx.autodiffAttrs) {
391 for (auto attr : ctx.autodiffAttrs)
392 adAttrs.push_back(attr);
393 }
394 auto autodiffOp = enzyme::AutoDiffRegionOp::create(
395 builder, loc, TypeRange{scalarType, rng.getType(), autodiffGradType},
396 autodiffInputs, adAttrs);
397
398 Block *autodiffBlock = builder.createBlock(&autodiffOp.getBody());
399 autodiffBlock->addArgument(autodiffPositionType, loc);
400
401 builder.setInsertionPointToStart(autodiffBlock);
402 Value qArg = autodiffBlock->getArgument(0);
403
404 if (isCustomLogpdf) {
405 SmallVector<Value> callArgs;
406 callArgs.push_back(qArg);
407 callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
408 auto callOp = func::CallOp::create(builder, loc, ctx.logpdfFn,
409 TypeRange{scalarType}, callArgs);
410 Value U = arith::NegFOp::create(builder, loc, callOp.getResult(0));
411 // TODO(#2695): handle hybrid case
412 enzyme::YieldOp::create(builder, loc, {U, rng});
413 } else {
414 auto traceType =
415 RankedTensorType::get({1, ctx.getFullTraceSize()}, elemType);
416
417 Value q_constrained = constrainPosition(builder, loc, qArg, ctx.supports);
418 Value fullTrace = scatterPositionToTrace(builder, loc, q_constrained,
419 ctx.originalTrace, ctx);
420
421 SmallVector<Value> generateInputs;
422 generateInputs.push_back(rng);
423 generateInputs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
424
425 SmallVector<Type> generateResultTypes;
426 generateResultTypes.push_back(traceType);
427 generateResultTypes.push_back(scalarType);
428 generateResultTypes.append(ctx.fnResultTypes.begin(),
429 ctx.fnResultTypes.end());
430
431 auto generateOp = impulse::GenerateOp::create(
432 builder, loc, generateResultTypes, ctx.fn, generateInputs, fullTrace,
433 ctx.allAddresses, ctx.allAddresses, builder.getStringAttr(""));
434
435 Value negWeight =
436 arith::NegFOp::create(builder, loc, generateOp.getWeight());
437 Value jacobianCorrection =
438 computeTotalJacobianCorrection(builder, loc, qArg, ctx.supports);
439 Value U =
440 arith::SubFOp::create(builder, loc, negWeight, jacobianCorrection);
441
442 SmallVector<Value> yieldValues{U, generateOp.getResult(2)};
443 enzyme::YieldOp::create(builder, loc, yieldValues);
444 }
445
446 builder.setInsertionPointAfter(autodiffOp);
447
448 Value grad = autodiffOp.getResult(2);
449 if (isCustomLogpdf) {
450 grad = impulse::ReshapeOp::create(builder, loc, positionType, grad);
451 }
452
453 return {
454 autodiffOp.getResult(0), // U
455 grad, // grad
456 autodiffOp.getResult(1) // rng
457 };
458}
459
461 Location loc,
462 const IntegratorState &leaf,
463 Value rng, Value direction,
464 const HMCContext &ctx) {
465 auto positionType = ctx.getPositionType();
466 auto scalarType = ctx.getScalarType();
467 auto elemType = ctx.getElementType();
468
469 auto negStepSize = arith::NegFOp::create(builder, loc, ctx.stepSize);
470 Value signedStepSize = impulse::SelectOp::create(
471 builder, loc, scalarType, direction, ctx.stepSize, negStepSize);
472
473 auto halfConst = arith::ConstantOp::create(
474 builder, loc, scalarType,
475 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.5)));
476
477 ArrayRef<int64_t> shape = positionType.getShape();
478 auto stepSizeBroadcast =
479 BroadcastOp::create(builder, loc, positionType, signedStepSize,
480 builder.getDenseI64ArrayAttr(shape));
481 auto halfStep =
482 arith::MulFOp::create(builder, loc, halfConst, signedStepSize);
483 auto halfStepBroadcast =
484 BroadcastOp::create(builder, loc, positionType, halfStep,
485 builder.getDenseI64ArrayAttr(shape));
486
487 // 1. Half step momentum: p_half = p - 0.5 * eps * grad
488 auto deltaP1 =
489 arith::MulFOp::create(builder, loc, halfStepBroadcast, leaf.grad);
490 Value pHalf = arith::SubFOp::create(builder, loc, leaf.p, deltaP1);
491
492 // 2. Full step position: q_new = q + eps * M^-1 * p_half
493 Value v =
494 applyInverseMassMatrix(builder, loc, ctx.invMass, pHalf, positionType);
495 auto deltaQ = arith::MulFOp::create(builder, loc, stepSizeBroadcast, v);
496 Value qNew = arith::AddFOp::create(builder, loc, leaf.q, deltaQ);
497
498 // 3. Compute gradient at new position
499 auto gradResult = computePotentialAndGradient(builder, loc, qNew, rng, ctx);
500
501 // 4. Final half step momentum: p_new = p_half - 0.5 * eps * grad_new
502 auto deltaP2 =
503 arith::MulFOp::create(builder, loc, halfStepBroadcast, gradResult.grad);
504 Value pNew = arith::SubFOp::create(builder, loc, pHalf, deltaP2);
505
506 return {qNew, pNew, gradResult.grad, gradResult.U, gradResult.rng};
507}
508
509Value impulse::checkTurning(OpBuilder &builder, Location loc, Value pLeft,
510 Value pRight, Value pSum, const NUTSContext &ctx) {
511 auto positionType = ctx.getPositionType();
512 auto scalarType = ctx.getScalarType();
513 auto elemType = ctx.getElementType();
514
515 auto zeroConst = arith::ConstantOp::create(
516 builder, loc, scalarType,
517 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
518 auto halfConst = arith::ConstantOp::create(
519 builder, loc, scalarType,
520 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.5)));
521
522 Value vLeft =
523 applyInverseMassMatrix(builder, loc, ctx.invMass, pLeft, positionType);
524 Value vRight =
525 applyInverseMassMatrix(builder, loc, ctx.invMass, pRight, positionType);
526
527 // p_sum_centered = p_sum - (p_left + p_right) / 2
528 auto halfBroadcast = BroadcastOp::create(
529 builder, loc, positionType, halfConst,
530 builder.getDenseI64ArrayAttr(positionType.getShape()));
531
532 auto pLeftPlusPRight = arith::AddFOp::create(builder, loc, pLeft, pRight);
533 auto halfSum =
534 arith::MulFOp::create(builder, loc, halfBroadcast, pLeftPlusPRight);
535 Value pSumCentered = arith::SubFOp::create(builder, loc, pSum, halfSum);
536
537 auto leftAngle = impulse::DotOp::create(
538 builder, loc, scalarType, vLeft, pSumCentered,
539 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
540 builder.getDenseI64ArrayAttr({0, 1}),
541 builder.getDenseI64ArrayAttr({0, 1}));
542 auto rightAngle = impulse::DotOp::create(
543 builder, loc, scalarType, vRight, pSumCentered,
544 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
545 builder.getDenseI64ArrayAttr({0, 1}),
546 builder.getDenseI64ArrayAttr({0, 1}));
547
548 // turning = (left_angle <= 0) OR (right_angle <= 0)
549 auto leftNeg = arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::OLE,
550 leftAngle, zeroConst);
551 auto rightNeg = arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::OLE,
552 rightAngle, zeroConst);
553
554 return arith::OrIOp::create(builder, loc, leftNeg, rightNeg);
555}
556
557Value impulse::computeUniformTransitionProb(OpBuilder &builder, Location loc,
558 Value currentWeight,
559 Value newWeight) {
560 Value weightDiff =
561 arith::SubFOp::create(builder, loc, newWeight, currentWeight);
562 return impulse::LogisticOp::create(builder, loc, weightDiff.getType(),
563 weightDiff);
564}
565
566Value impulse::computeBiasedTransitionProb(OpBuilder &builder, Location loc,
567 Value currentWeight, Value newWeight,
568 Value turning, Value diverging) {
569 auto resultType = cast<RankedTensorType>(currentWeight.getType());
570 auto elemType = resultType.getElementType();
571
572 auto zeroConst = arith::ConstantOp::create(
573 builder, loc, resultType,
574 DenseElementsAttr::get(resultType, builder.getFloatAttr(elemType, 0.0)));
575 auto oneConst = arith::ConstantOp::create(
576 builder, loc, resultType,
577 DenseElementsAttr::get(resultType, builder.getFloatAttr(elemType, 1.0)));
578
579 Value weightDiff =
580 arith::SubFOp::create(builder, loc, newWeight, currentWeight);
581 Value expDiff = math::ExpOp::create(builder, loc, weightDiff);
582 Value clippedProb =
583 arith::MinimumFOp::create(builder, loc, oneConst, expDiff);
584 Value turningOrDiverging =
585 arith::OrIOp::create(builder, loc, turning, diverging);
586 return arith::SelectOp::create(builder, loc, resultType, turningOrDiverging,
587 zeroConst, clippedProb);
588}
589
590NUTSTreeState impulse::combineTrees(OpBuilder &builder, Location loc,
591 const NUTSTreeState &tree,
592 const NUTSTreeState &subTree,
593 Value direction, Value rng, bool biased,
594 const NUTSContext &ctx) {
595 auto positionType = ctx.getPositionType();
596 auto scalarType = ctx.getScalarType();
597 auto elemType = ctx.getElementType();
598 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
599
600 auto zeroConst = arith::ConstantOp::create(
601 builder, loc, scalarType,
602 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
603 auto oneConst = arith::ConstantOp::create(
604 builder, loc, scalarType,
605 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
606
607 auto qLeft = impulse::SelectOp::create(builder, loc, positionType, direction,
608 tree.q_left, subTree.q_left);
609 auto pLeft = impulse::SelectOp::create(builder, loc, positionType, direction,
610 tree.p_left, subTree.p_left);
611 auto gradLeft = impulse::SelectOp::create(
612 builder, loc, positionType, direction, tree.grad_left, subTree.grad_left);
613 auto qRight = impulse::SelectOp::create(builder, loc, positionType, direction,
614 subTree.q_right, tree.q_right);
615 auto pRight = impulse::SelectOp::create(builder, loc, positionType, direction,
616 subTree.p_right, tree.p_right);
617 auto gradRight =
618 impulse::SelectOp::create(builder, loc, positionType, direction,
619 subTree.grad_right, tree.grad_right);
620
621 auto combinedWeight = impulse::LogAddExpOp::create(
622 builder, loc, scalarType, tree.weight, subTree.weight);
623
624 // Compute transition probability
625 Value transitionProb;
626 if (biased) {
627 transitionProb =
628 computeBiasedTransitionProb(builder, loc, tree.weight, subTree.weight,
629 subTree.turning, subTree.diverging);
630 } else {
631 transitionProb =
632 computeUniformTransitionProb(builder, loc, tree.weight, subTree.weight);
633 }
634
635 auto randomOp = impulse::RandomOp::create(
636 builder, loc, TypeRange{rng.getType(), scalarType}, rng, zeroConst,
637 oneConst,
638 impulse::RngDistributionAttr::get(builder.getContext(),
639 impulse::RngDistribution::UNIFORM));
640 auto rngOut = randomOp.getOutputRngState();
641 auto uniformSample = randomOp.getResult();
642
643 auto acceptNew = arith::CmpFOp::create(
644 builder, loc, arith::CmpFPredicate::OLT, uniformSample, transitionProb);
645
646 auto qProposal =
647 impulse::SelectOp::create(builder, loc, positionType, acceptNew,
648 subTree.q_proposal, tree.q_proposal);
649 auto gradProposal =
650 impulse::SelectOp::create(builder, loc, positionType, acceptNew,
651 subTree.grad_proposal, tree.grad_proposal);
652 auto UProposal = impulse::SelectOp::create(
653 builder, loc, scalarType, acceptNew, subTree.U_proposal, tree.U_proposal);
654 auto HProposal = impulse::SelectOp::create(
655 builder, loc, scalarType, acceptNew, subTree.H_proposal, tree.H_proposal);
656
657 auto oneI64 = arith::ConstantOp::create(
658 builder, loc, i64TensorType,
659 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
660 auto combinedDepth = arith::AddIOp::create(builder, loc, tree.depth, oneI64);
661
662 Value combinedTurning;
663 if (biased) {
664 auto turningCheck = checkTurning(
665 builder, loc, pLeft, pRight,
666 arith::AddFOp::create(builder, loc, tree.p_sum, subTree.p_sum), ctx);
667 combinedTurning =
668 arith::OrIOp::create(builder, loc, subTree.turning, turningCheck);
669 } else {
670 combinedTurning = tree.turning;
671 }
672
673 auto combinedDiverging =
674 arith::OrIOp::create(builder, loc, tree.diverging, subTree.diverging);
675 auto sumAcceptProbs = arith::AddFOp::create(
676 builder, loc, tree.sum_accept_probs, subTree.sum_accept_probs);
677 auto numProposals = arith::AddIOp::create(builder, loc, tree.num_proposals,
678 subTree.num_proposals);
679 auto pSum = arith::AddFOp::create(builder, loc, tree.p_sum, subTree.p_sum);
680
681 return {.q_left = qLeft,
682 .p_left = pLeft,
683 .grad_left = gradLeft,
684 .q_right = qRight,
685 .p_right = pRight,
686 .grad_right = gradRight,
687 .q_proposal = qProposal,
688 .grad_proposal = gradProposal,
689 .U_proposal = UProposal,
690 .H_proposal = HProposal,
691 .depth = combinedDepth,
692 .weight = combinedWeight,
693 .turning = combinedTurning,
694 .diverging = combinedDiverging,
695 .sum_accept_probs = sumAcceptProbs,
696 .num_proposals = numProposals,
697 .p_sum = pSum,
698 .rng = rngOut};
699}
700
701InitialHMCState impulse::InitHMC(OpBuilder &builder, Location loc, Value rng,
702 const HMCContext &ctx, Value initialPosition,
703 bool debugDump) {
704 auto positionType = ctx.getPositionType();
705 auto scalarType = ctx.getScalarType();
706 auto elemType = ctx.getElementType();
707
708 auto initSplit = impulse::RandomSplitOp::create(
709 builder, loc, TypeRange{rng.getType(), rng.getType()}, rng);
710 auto kernelSplit = impulse::RandomSplitOp::create(
711 builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
712 initSplit.getResult(0));
713 auto rngForSampleKernel = kernelSplit.getResult(0);
714 auto rngForAutodiff = kernelSplit.getResult(1);
715
716 Value q0;
717 Value U0;
718
719 if (ctx.hasCustomLogpdf()) {
720 q0 = initialPosition;
721 auto flatType = RankedTensorType::get({ctx.positionSize}, elemType);
722 auto q0Flat = impulse::ReshapeOp::create(builder, loc, flatType, q0);
723 SmallVector<Value> callArgs;
724 callArgs.push_back(q0Flat);
725 callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
726 auto callOp = func::CallOp::create(builder, loc, ctx.logpdfFn,
727 TypeRange{scalarType}, callArgs);
728 U0 = arith::NegFOp::create(builder, loc, callOp.getResult(0));
729 } else {
730 auto fullTraceType =
731 RankedTensorType::get({1, ctx.getFullTraceSize()}, elemType);
732
733 // 1. Extract initial position vector (constrained)
734 auto q0_constrained =
735 gatherPositionFromTrace(builder, loc, ctx.originalTrace, ctx);
736
737 // 2. Unconstrain to get position vector for HMC
738 q0 = unconstrainPosition(builder, loc, q0_constrained, ctx.supports);
739
740 // 3. Compute initial potential energy: U0 = -weight + correction
741 Value fullTraceInit = scatterPositionToTrace(builder, loc, q0_constrained,
742 ctx.originalTrace, ctx);
743
744 SmallVector<Value> generateInputsInit;
745 generateInputsInit.push_back(rngForAutodiff);
746 generateInputsInit.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
747
748 SmallVector<Type> generateResultTypesInit;
749 generateResultTypesInit.push_back(fullTraceType);
750 generateResultTypesInit.push_back(scalarType);
751 generateResultTypesInit.append(ctx.fnResultTypes.begin(),
752 ctx.fnResultTypes.end());
753
754 auto generateOpInit = impulse::GenerateOp::create(
755 builder, loc, generateResultTypesInit, ctx.fn, generateInputsInit,
756 fullTraceInit, ctx.allAddresses, ctx.allAddresses,
757 builder.getStringAttr(""));
758
759 auto weight0 = generateOpInit.getWeight();
760 auto negWeight0 = arith::NegFOp::create(builder, loc, weight0);
761 auto jacobian0 =
762 computeTotalJacobianCorrection(builder, loc, q0, ctx.supports);
763 U0 = arith::SubFOp::create(builder, loc, negWeight0, jacobian0);
764 }
765
766 // 4. Compute initial gradient at q0
767 bool isCustomLogpdf = ctx.hasCustomLogpdf();
768 auto flatType = RankedTensorType::get({ctx.positionSize}, elemType);
769 Value autodiffQ0 = q0;
770 auto autodiffQ0Type = positionType;
771 auto autodiffGradType = positionType;
772 if (isCustomLogpdf) {
773 autodiffQ0 = impulse::ReshapeOp::create(builder, loc, flatType, q0);
774 autodiffQ0Type = flatType;
775 autodiffGradType = flatType;
776 }
777
778 auto gradSeedInit = arith::ConstantOp::create(
779 builder, loc, scalarType,
780 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
781 SmallVector<Value> autodiffInputs{autodiffQ0, gradSeedInit};
782 SmallVector<NamedAttribute> adInitAttrs{
783 builder.getNamedAttr(
784 "activity",
785 builder.getArrayAttr({enzyme::ActivityAttr::get(
786 builder.getContext(), enzyme::Activity::enzyme_active)})),
787 builder.getNamedAttr(
788 "ret_activity",
789 builder.getArrayAttr(
790 {enzyme::ActivityAttr::get(builder.getContext(),
791 enzyme::Activity::enzyme_active),
792 enzyme::ActivityAttr::get(builder.getContext(),
793 enzyme::Activity::enzyme_const)})),
794 };
795 if (ctx.autodiffAttrs) {
796 for (auto attr : ctx.autodiffAttrs)
797 adInitAttrs.push_back(attr);
798 }
799 auto autodiffInit = enzyme::AutoDiffRegionOp::create(
800 builder, loc,
801 TypeRange{scalarType, rngForAutodiff.getType(), autodiffGradType},
802 autodiffInputs, adInitAttrs);
803
804 Block *autodiffInitBlock = builder.createBlock(&autodiffInit.getBody());
805 autodiffInitBlock->addArgument(autodiffQ0Type, loc);
806
807 builder.setInsertionPointToStart(autodiffInitBlock);
808 auto q0Arg = autodiffInitBlock->getArgument(0);
809
810 if (isCustomLogpdf) {
811 SmallVector<Value> callArgs;
812 callArgs.push_back(q0Arg);
813 callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
814 auto callOpInner = func::CallOp::create(builder, loc, ctx.logpdfFn,
815 TypeRange{scalarType}, callArgs);
816 auto U0_init =
817 arith::NegFOp::create(builder, loc, callOpInner.getResult(0));
818 enzyme::YieldOp::create(builder, loc, {U0_init, rngForAutodiff});
819 } else {
820 auto fullTraceType =
821 RankedTensorType::get({1, ctx.getFullTraceSize()}, elemType);
822 auto q0Arg_constrained =
823 constrainPosition(builder, loc, q0Arg, ctx.supports);
824 Value fullTraceInner = scatterPositionToTrace(
825 builder, loc, q0Arg_constrained, ctx.originalTrace, ctx);
826
827 SmallVector<Value> generateInputsInner;
828 generateInputsInner.push_back(rngForAutodiff);
829 generateInputsInner.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
830
831 SmallVector<Type> generateResultTypesInner;
832 generateResultTypesInner.push_back(fullTraceType);
833 generateResultTypesInner.push_back(scalarType);
834 generateResultTypesInner.append(ctx.fnResultTypes.begin(),
835 ctx.fnResultTypes.end());
836
837 auto generateOpInner = impulse::GenerateOp::create(
838 builder, loc, generateResultTypesInner, ctx.fn, generateInputsInner,
839 fullTraceInner, ctx.allAddresses, ctx.allAddresses,
840 builder.getStringAttr(""));
841
842 auto negWeightInit =
843 arith::NegFOp::create(builder, loc, generateOpInner.getWeight());
844 auto jacobianInit =
845 computeTotalJacobianCorrection(builder, loc, q0Arg, ctx.supports);
846 auto U0_init =
847 arith::SubFOp::create(builder, loc, negWeightInit, jacobianInit);
848
849 SmallVector<Value> yieldValues{U0_init, generateOpInner.getResult(2)};
850 enzyme::YieldOp::create(builder, loc, yieldValues);
851 }
852 builder.setInsertionPointAfter(autodiffInit);
853
854 Value grad0 = autodiffInit.getResult(2);
855 if (isCustomLogpdf) {
856 grad0 = impulse::ReshapeOp::create(builder, loc, positionType, grad0);
857 }
858
859 return {q0, U0, grad0, rngForSampleKernel};
860}
861
862MCMCKernelResult impulse::SampleHMC(OpBuilder &builder, Location loc, Value q,
863 Value grad, Value U, Value rng,
864 const HMCContext &ctx, bool debugDump) {
865 auto positionType = ctx.getPositionType();
866 auto scalarType = ctx.getScalarType();
867 auto elemType = ctx.getElementType();
868 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
869 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
870
871 // 0. Compute num_steps and adjusted step_size
872 // num_steps = ceil(trajectory_length / step_size)
873 // adjusted_step_size = trajectory_length / num_steps
874 auto trajDivStep =
875 arith::DivFOp::create(builder, loc, ctx.trajectoryLength, ctx.stepSize);
876 auto numStepsF64 = math::CeilOp::create(builder, loc, trajDivStep);
877 auto numSteps =
878 arith::FPToSIOp::create(builder, loc, i64TensorType, numStepsF64);
879 auto adjustedStepSize =
880 arith::DivFOp::create(builder, loc, ctx.trajectoryLength, numStepsF64);
881
882 auto adjustedCtx = ctx.withStepSize(adjustedStepSize);
883
884 // 1. Split RNG: [rngNext, rngMomentum, rngTransition]
885 auto sampleKernelSplit = impulse::RandomSplitOp::create(
886 builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
887 rng);
888 auto rngNext = sampleKernelSplit.getResult(0);
889 auto rngMomentum = sampleKernelSplit.getResult(1);
890 auto rngTransition = sampleKernelSplit.getResult(2);
891
892 // 2. Sample fresh momentum p ~ N(0, M)
893 Value rngForMomentum = rngMomentum;
894 if (!ctx.hasCustomLogpdf()) {
895 auto momSplit = impulse::RandomSplitOp::create(
896 builder, loc, TypeRange{rng.getType(), rng.getType()}, rngMomentum);
897 rngForMomentum = momSplit.getResult(0);
898 }
899 auto [p0, rngAfterMomentum] =
900 sampleMomentum(builder, loc, rngForMomentum, ctx.invMass,
901 ctx.massMatrixSqrt, positionType, debugDump);
902
903 // 3. Compute K0 = 0.5 * p^T * M^-1 * p
904 auto K0 = computeKineticEnergy(builder, loc, p0, ctx.invMass, positionType);
905
906 // 4. Compute H0 = U + K0
907 auto H0 = arith::AddFOp::create(builder, loc, U, K0);
908
909 // 5. Leapfrog integration loop
910 auto direction = arith::ConstantOp::create(
911 builder, loc, i1TensorType,
912 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(true)));
913
914 auto c0 = arith::ConstantOp::create(
915 builder, loc, i64TensorType,
916 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
917 auto c1 = arith::ConstantOp::create(
918 builder, loc, i64TensorType,
919 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
920
921 // Loop carries: [q, p, grad, U, rng]
922 SmallVector<Type> loopResultTypes = {positionType, positionType, positionType,
923 scalarType, rngTransition.getType()};
924 auto forLoopOp =
925 impulse::ForOp::create(builder, loc, loopResultTypes, c0, numSteps, c1,
926 ValueRange{q, p0, grad, U, rngTransition});
927
928 Block *loopBody = builder.createBlock(&forLoopOp.getRegion());
929 loopBody->addArgument(i64TensorType, loc); // iv
930 loopBody->addArgument(positionType, loc); // q
931 loopBody->addArgument(positionType, loc); // p
932 loopBody->addArgument(positionType, loc); // gradient
933 loopBody->addArgument(scalarType, loc); // U
934 loopBody->addArgument(rngTransition.getType(), loc); // rng
935
936 builder.setInsertionPointToStart(loopBody);
937 auto qLoop = loopBody->getArgument(1);
938 auto pLoop = loopBody->getArgument(2);
939 auto gradLoop = loopBody->getArgument(3);
940 auto rngLoop = loopBody->getArgument(5);
941
942 IntegratorState leaf = {qLoop, pLoop, gradLoop};
943 auto step = computeIntegrationStep(builder, loc, leaf, rngLoop, direction,
944 adjustedCtx);
945
946 // Yield [q, p, grad, U, rng]
947 impulse::YieldOp::create(
948 builder, loc, ValueRange{step.q, step.p, step.grad, step.U, step.rng});
949
950 builder.setInsertionPointAfter(forLoopOp);
951 auto qProposal = forLoopOp.getResult(0);
952 auto pProposal = forLoopOp.getResult(1);
953 auto gradProposal = forLoopOp.getResult(2);
954 auto UProposal = forLoopOp.getResult(3);
955 auto rngAfterLeapfrog = forLoopOp.getResult(4);
956
957 // 6. Compute K1, H1 for proposal
958 auto K1 =
959 computeKineticEnergy(builder, loc, pProposal, ctx.invMass, positionType);
960 auto H1 = arith::AddFOp::create(builder, loc, UProposal, K1);
961
962 // 7. MH accept/reject
963 auto zeroConst = arith::ConstantOp::create(
964 builder, loc, scalarType,
965 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
966 auto oneConst = arith::ConstantOp::create(
967 builder, loc, scalarType,
968 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
969
970 // α = min(1, exp(H0 - H1))
971 auto dH = arith::SubFOp::create(builder, loc, H0, H1);
972 auto expDH = math::ExpOp::create(builder, loc, dH);
973 auto accProb = arith::MinimumFOp::create(builder, loc, oneConst, expDH);
974
975 // u ~ Uniform(0, 1)
976 auto randomOp = impulse::RandomOp::create(
977 builder, loc, TypeRange{rngAfterLeapfrog.getType(), scalarType},
978 rngAfterLeapfrog, zeroConst, oneConst,
979 impulse::RngDistributionAttr::get(builder.getContext(),
980 impulse::RngDistribution::UNIFORM));
981 auto randUniform = randomOp.getResult();
982
983 // accepted = u < α
984 auto acceptedTensor = arith::CmpFOp::create(
985 builder, loc, arith::CmpFPredicate::OLT, randUniform, accProb);
986
987 // 8. Select between original and proposal
988 auto qFinal = impulse::SelectOp::create(builder, loc, positionType,
989 acceptedTensor, qProposal, q);
990 auto gradFinal = impulse::SelectOp::create(
991 builder, loc, positionType, acceptedTensor, gradProposal, grad);
992 auto UFinal = impulse::SelectOp::create(builder, loc, scalarType,
993 acceptedTensor, UProposal, U);
994
995 return {qFinal, gradFinal, UFinal, acceptedTensor, accProb, rngNext};
996}
997
998MCMCKernelResult impulse::SampleNUTS(OpBuilder &builder, Location loc, Value q,
999 Value grad, Value U, Value rng,
1000 const NUTSContext &ctx, bool debugDump) {
1001 auto positionType = ctx.getPositionType();
1002 auto scalarType = ctx.getScalarType();
1003 auto elemType = ctx.getElementType();
1004 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1005 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
1006
1007 // 1. Split RNG: [rngNext, rngMomentum, rngTree]
1008 auto sampleKernelSplit = impulse::RandomSplitOp::create(
1009 builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
1010 rng);
1011 auto rngNext = sampleKernelSplit.getResult(0);
1012 auto rngMomentum = sampleKernelSplit.getResult(1);
1013 auto rngTree = sampleKernelSplit.getResult(2);
1014
1015 // 2. Sample fresh momentum p ~ N(0, M)
1016 Value rngForMomentum = rngMomentum;
1017 if (!ctx.hasCustomLogpdf()) {
1018 auto momSplit = impulse::RandomSplitOp::create(
1019 builder, loc, TypeRange{rng.getType(), rng.getType()}, rngMomentum);
1020 rngForMomentum = momSplit.getResult(0);
1021 }
1022 auto [p0, rngAfterMomentum] =
1023 sampleMomentum(builder, loc, rngForMomentum, ctx.invMass,
1024 ctx.massMatrixSqrt, positionType, debugDump);
1025
1026 // 3. Compute K0 = 0.5 * p^T * M^-1 * p
1027 auto K0 = computeKineticEnergy(builder, loc, p0, ctx.invMass, positionType);
1028
1029 // 4. Compute H0 = U + K0
1030 auto H0 = arith::AddFOp::create(builder, loc, U, K0);
1031
1032 // 5. Initialize NUTS tree state
1033 auto iterCtx = ctx.withH0(H0);
1034
1035 auto zeroI64 = arith::ConstantOp::create(
1036 builder, loc, i64TensorType,
1037 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1038 auto falseConst = arith::ConstantOp::create(
1039 builder, loc, i1TensorType,
1040 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(false)));
1041 auto zeroConst = arith::ConstantOp::create(
1042 builder, loc, scalarType,
1043 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
1044
1045 NUTSTreeState initialTree = {.q_left = q,
1046 .p_left = p0,
1047 .grad_left = grad,
1048 .q_right = q,
1049 .p_right = p0,
1050 .grad_right = grad,
1051 .q_proposal = q,
1052 .grad_proposal = grad,
1053 .U_proposal = U,
1054 .H_proposal = H0,
1055 .depth = zeroI64,
1056 .weight = zeroConst,
1057 .turning = falseConst,
1058 .diverging = falseConst,
1059 .sum_accept_probs = zeroConst,
1060 .num_proposals = zeroI64,
1061 .p_sum = p0,
1062 .rng = rngTree};
1063
1064 // 6. Build NUTS tree
1065 auto finalTree = buildTree(builder, loc, initialTree, iterCtx, debugDump);
1066
1067 // 7. NUTS always accepts the proposal (implicit acceptance in the tree)
1068 auto trueConst = arith::ConstantOp::create(
1069 builder, loc, i1TensorType,
1070 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(true)));
1071
1072 // 8. Compute mean acceptance probability for step size adaptation
1073 auto oneI64 = arith::ConstantOp::create(
1074 builder, loc, i64TensorType,
1075 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1076 auto numProposalsClamped =
1077 arith::MaxSIOp::create(builder, loc, finalTree.num_proposals, oneI64);
1078 auto numProposalsFloat =
1079 arith::SIToFPOp::create(builder, loc, scalarType, numProposalsClamped);
1080 auto meanAcceptProb = arith::DivFOp::create(
1081 builder, loc, finalTree.sum_accept_probs, numProposalsFloat);
1082
1083 return {finalTree.q_proposal, finalTree.grad_proposal,
1084 finalTree.U_proposal, trueConst,
1085 meanAcceptProb, rngNext};
1086}
1087
1088NUTSTreeState impulse::buildBaseTree(OpBuilder &builder, Location loc,
1089 const IntegratorState &leaf, Value rng,
1090 Value direction, const NUTSContext &ctx) {
1091 auto positionType = ctx.getPositionType();
1092 auto scalarType = ctx.getScalarType();
1093 auto elemType = ctx.getElementType();
1094 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1095 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
1096
1097 auto oneConst = arith::ConstantOp::create(
1098 builder, loc, scalarType,
1099 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
1100 auto zeroI64 = arith::ConstantOp::create(
1101 builder, loc, i64TensorType,
1102 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1103 auto oneI64 = arith::ConstantOp::create(
1104 builder, loc, i64TensorType,
1105 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1106 auto falseConst = arith::ConstantOp::create(
1107 builder, loc, i1TensorType,
1108 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(false)));
1109
1110 IntegrationResult leap =
1111 computeIntegrationStep(builder, loc, leaf, rng, direction, ctx);
1112
1113 auto qNew = leap.q;
1114 auto pNew = leap.p;
1115 auto gradNew = leap.grad;
1116 auto UNew = leap.U;
1117 auto rngOut = leap.rng;
1118
1119 auto KNew =
1120 computeKineticEnergy(builder, loc, pNew, ctx.invMass, positionType);
1121 auto HNew = arith::AddFOp::create(builder, loc, UNew, KNew);
1122 Value deltaH = arith::SubFOp::create(builder, loc, HNew, ctx.H0);
1123
1124 // NaN check
1125 auto isNan = arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::UNE,
1126 deltaH, deltaH);
1127 auto infConst = arith::ConstantOp::create(
1128 builder, loc, scalarType,
1129 DenseElementsAttr::get(
1130 scalarType, builder.getFloatAttr(
1131 elemType, std::numeric_limits<double>::infinity())));
1132 deltaH = arith::SelectOp::create(builder, loc, scalarType, isNan, infConst,
1133 deltaH);
1134
1135 auto treeWeight = arith::NegFOp::create(builder, loc, deltaH);
1136
1137 // Check for divergence
1138 auto diverging = arith::CmpFOp::create(
1139 builder, loc, arith::CmpFPredicate::OGT, deltaH, ctx.maxDeltaEnergy);
1140
1141 auto negDeltaH = arith::NegFOp::create(builder, loc, deltaH);
1142 auto expNegDelta = math::ExpOp::create(builder, loc, negDeltaH);
1143 auto acceptProb =
1144 arith::MinimumFOp::create(builder, loc, oneConst, expNegDelta);
1145
1146 return {.q_left = qNew,
1147 .p_left = pNew,
1148 .grad_left = gradNew,
1149 .q_right = qNew,
1150 .p_right = pNew,
1151 .grad_right = gradNew,
1152 .q_proposal = qNew,
1153 .grad_proposal = gradNew,
1154 .U_proposal = UNew,
1155 .H_proposal = HNew,
1156 .depth = zeroI64,
1157 .weight = treeWeight,
1158 .turning = falseConst,
1159 .diverging = diverging,
1160 .sum_accept_probs = acceptProb,
1161 .num_proposals = oneI64,
1162 .p_sum = pNew,
1163 .rng = rngOut};
1164}
1165
1166IntegratorState impulse::getLeafFromTree(OpBuilder &builder, Location loc,
1167 const NUTSTreeState &tree,
1168 Value direction,
1169 const NUTSContext &ctx) {
1170 auto positionType = ctx.getPositionType();
1171
1172 auto leafQ = impulse::SelectOp::create(builder, loc, positionType, direction,
1173 tree.q_right, tree.q_left);
1174 auto leafP = impulse::SelectOp::create(builder, loc, positionType, direction,
1175 tree.p_right, tree.p_left);
1176 auto leafGrad = impulse::SelectOp::create(
1177 builder, loc, positionType, direction, tree.grad_right, tree.grad_left);
1178 return {leafQ, leafP, leafGrad};
1179}
1180
impulse::buildIterativeSubtree(OpBuilder &builder, Location loc,
                               const NUTSTreeState &initialTree,
                               Value direction, Value pCkpts, Value pSumCkpts,
                               const NUTSContext &ctx, bool debugDump) {
  // Grows a new subtree next to `initialTree` one leaf at a time in
  // `direction`, emitted as an impulse.while loop. At most
  // 2^initialTree.depth leaves are built. The loop stops early when the
  // iterative U-turn check fires or the tree is flagged as diverging.
  //
  // `pCkpts` / `pSumCkpts` are the momentum and momentum-sum checkpoint
  // tensors consumed by updateCheckpoints / checkIterativeTurning.
  //
  // Returns the new subtree, which contains only the newly built leaves, with
  // its depth reset to initialTree.depth, plus the final checkpoint tensors.
  auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
  auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
  auto pCkptsType = cast<RankedTensorType>(pCkpts.getType());
  auto trueConst = arith::ConstantOp::create(
      builder, loc, i1TensorType,
      DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(true)));
  auto oneI64 = arith::ConstantOp::create(
      builder, loc, i64TensorType,
      DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
  auto zeroI64 = arith::ConstantOp::create(
      builder, loc, i64TensorType,
      DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));

  // 2 ^ (initialTree.depth)
  auto maxNumProposals =
      arith::ShLIOp::create(builder, loc, oneI64, initialTree.depth);

  // Loop carries: the 18 NUTSTreeState fields, then pCkpts (18),
  // pSumCkpts (19) and the current leaf index (20).
  SmallVector<Type> whileTypes = initialTree.getTypes();
  whileTypes.push_back(pCkptsType);
  whileTypes.push_back(pCkptsType);
  whileTypes.push_back(i64TensorType);

  SmallVector<Value> whileInitVals = initialTree.toValues();
  whileInitVals[15] = zeroI64; // zero `num_proposals`
  whileInitVals.push_back(pCkpts);
  whileInitVals.push_back(pSumCkpts);
  whileInitVals.push_back(zeroI64);

  auto whileOp =
      impulse::WhileOp::create(builder, loc, whileTypes, whileInitVals);

  // Check: num_proposals < max_num_proposals && !turning && !diverging
  Block *condBlock = builder.createBlock(&whileOp.getConditionRegion());
  for (auto type : whileTypes)
    condBlock->addArgument(type, loc);

  builder.setInsertionPointToStart(condBlock);
  SmallVector<Value> condTreeArgs(condBlock->getArguments().begin(),
                                  condBlock->getArguments().begin() + 18);
  NUTSTreeState condTree = NUTSTreeState::fromValues(condTreeArgs);

  auto numProposalsCheck =
      arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::slt,
                            condTree.num_proposals, maxNumProposals);
  // i1 negation is expressed as XOR with true.
  auto notTurning =
      arith::XOrIOp::create(builder, loc, condTree.turning, trueConst);
  auto notDiverging =
      arith::XOrIOp::create(builder, loc, condTree.diverging, trueConst);
  auto continueCond = arith::AndIOp::create(
      builder, loc,
      arith::AndIOp::create(builder, loc, numProposalsCheck, notTurning),
      notDiverging);

  // Yield continue condition
  impulse::YieldOp::create(builder, loc, ValueRange{continueCond});

  Block *bodyBlock = builder.createBlock(&whileOp.getBodyRegion());
  for (auto type : whileTypes)
    bodyBlock->addArgument(type, loc);

  builder.setInsertionPointToStart(bodyBlock);

  SmallVector<Value> bodyTreeArgs(bodyBlock->getArguments().begin(),
                                  bodyBlock->getArguments().begin() + 18);
  NUTSTreeState bodyTree = NUTSTreeState::fromValues(bodyTreeArgs);
  auto bodyPCkpts = bodyBlock->getArgument(18);
  auto bodyPSumCkpts = bodyBlock->getArgument(19);
  auto bodyLeafIdx = bodyBlock->getArgument(20);

  // Extract leaf based on direction
  IntegratorState leaf =
      getLeafFromTree(builder, loc, bodyTree, direction, ctx);

  auto rngSplit2 = impulse::RandomSplitOp::create(
      builder, loc, TypeRange{bodyTree.rng.getType(), bodyTree.rng.getType()},
      bodyTree.rng);
  auto rngNext = rngSplit2.getResult(0);
  auto rngCombine = rngSplit2.getResult(1);

  // Build base tree
  NUTSTreeState newLeaf =
      buildBaseTree(builder, loc, leaf, rngNext, direction, ctx);

  // First leaf check
  // On the first iteration (num_proposals == 0) the new leaf becomes the
  // subtree. Afterwards it is merged into the running subtree with the
  // unbiased transition kernel.
  auto isFirstLeaf = arith::CmpIOp::create(
      builder, loc, arith::CmpIPredicate::eq, bodyTree.num_proposals, zeroI64);

  SmallVector<Type> treeTypes = newLeaf.getTypes();
  auto ifOp = impulse::IfOp::create(builder, loc, treeTypes, isFirstLeaf);
  {
    Block *trueBranch = builder.createBlock(&ifOp.getTrueBranch());
    builder.setInsertionPointToStart(trueBranch);
    impulse::YieldOp::create(builder, loc, newLeaf.toValues());
  }
  {
    Block *falseBranch = builder.createBlock(&ifOp.getFalseBranch());
    builder.setInsertionPointToStart(falseBranch);
    NUTSTreeState combinedTree =
        combineTrees(builder, loc, bodyTree, newLeaf, direction, rngCombine,
                     /*biased=*/false, ctx);
    impulse::YieldOp::create(builder, loc, combinedTree.toValues());
  }

  // NOTE(review): `rngNext` is passed to buildBaseTree above and is also
  // carried forward as the subtree's rng (and split again next iteration).
  // Confirm this key reuse is intended.
  builder.setInsertionPointAfter(ifOp);
      SmallVector<Value>(ifOp.getResults().begin(), ifOp.getResults().end()));
  updatedTree.rng = rngNext;

  // Update and check iterative turning
  auto [ckptIdxMin, ckptIdxMax] =
      leafIdxToCheckpointIdxs(builder, loc, bodyLeafIdx);
  auto [updatedPCkpts, updatedPSumCkpts] = updateCheckpoints(
      builder, loc, bodyLeafIdx, ckptIdxMax, newLeaf.p_right, updatedTree.p_sum,
      bodyPCkpts, bodyPSumCkpts, ctx, debugDump);
  auto iterativeTurning = checkIterativeTurning(
      builder, loc, newLeaf.p_right, updatedTree.p_sum, updatedPCkpts,
      updatedPSumCkpts, ckptIdxMin, ckptIdxMax, ctx, debugDump);

  // The subtree's turning flag is the iterative check (the first leaf keeps
  // its own flag), overriding whatever combineTrees produced.
  updatedTree.turning =
      impulse::SelectOp::create(builder, loc, i1TensorType, isFirstLeaf,
                                newLeaf.turning, iterativeTurning);

  auto nextLeafIdx = arith::AddIOp::create(builder, loc, bodyLeafIdx, oneI64);

  SmallVector<Value> yieldVals = updatedTree.toValues();
  yieldVals.push_back(updatedPCkpts);
  yieldVals.push_back(updatedPSumCkpts);
  yieldVals.push_back(nextLeafIdx);
  impulse::YieldOp::create(builder, loc, yieldVals);

  builder.setInsertionPointAfter(whileOp);

  SmallVector<Value> resultTreeArgs(whileOp.getResults().begin(),
                                    whileOp.getResults().begin() + 18);
  NUTSTreeState resultTree = NUTSTreeState::fromValues(resultTreeArgs);
  auto resultPCkpts = whileOp.getResult(18);
  auto resultPSumCkpts = whileOp.getResult(19);

  // `combineTrees` increments depth at each leaf building step; we need to
  // restore to target depth here
  resultTree.depth = initialTree.depth;

  return {resultTree, resultPCkpts, resultPSumCkpts};
}
1330
1331SubtreeBuildResult impulse::doubleTree(OpBuilder &builder, Location loc,
1332 const NUTSTreeState &tree,
1333 Value direction, Value pCkpts,
1334 Value pSumCkpts, const NUTSContext &ctx,
1335 bool debugDump) {
1336 auto rngSplit2 = impulse::RandomSplitOp::create(
1337 builder, loc, TypeRange{tree.rng.getType(), tree.rng.getType()},
1338 tree.rng);
1339 auto rngSubtree = rngSplit2.getResult(0);
1340 auto rngTransition = rngSplit2.getResult(1);
1341
1342 NUTSTreeState subTreeInit = tree;
1343 subTreeInit.rng = rngSubtree;
1344 auto subtreeResult = buildIterativeSubtree(
1345 builder, loc, subTreeInit, direction, pCkpts, pSumCkpts, ctx, debugDump);
1346
1347 // Tree combine using *biased* transition kernel
1348 NUTSTreeState combinedTree =
1349 combineTrees(builder, loc, tree, subtreeResult.tree, direction,
1350 rngTransition, /*biased=*/true, ctx);
1351
1352 return {combinedTree, subtreeResult.pCkpts, subtreeResult.pSumCkpts};
1353}
1354
1355NUTSTreeState impulse::buildTree(OpBuilder &builder, Location loc,
1356 const NUTSTreeState &initialTree,
1357 const NUTSContext &ctx, bool debugDump) {
1358 auto elemType =
1359 cast<RankedTensorType>(ctx.stepSize.getType()).getElementType();
1360 auto F64TensorType = RankedTensorType::get({}, elemType);
1361 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1362 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
1363
1364 auto trueConst = arith::ConstantOp::create(
1365 builder, loc, i1TensorType,
1366 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(true)));
1367 auto halfConst = arith::ConstantOp::create(
1368 builder, loc, F64TensorType,
1369 DenseElementsAttr::get(F64TensorType, builder.getF64FloatAttr(0.5)));
1370 auto zeroConst = arith::ConstantOp::create(
1371 builder, loc, F64TensorType,
1372 DenseElementsAttr::get(F64TensorType, builder.getF64FloatAttr(0.0)));
1373 auto oneConst = arith::ConstantOp::create(
1374 builder, loc, F64TensorType,
1375 DenseElementsAttr::get(F64TensorType, builder.getF64FloatAttr(1.0)));
1376
1377 auto maxTreeDepth = arith::ConstantOp::create(
1378 builder, loc, i64TensorType,
1379 DenseElementsAttr::get(i64TensorType,
1380 builder.getI64IntegerAttr(ctx.maxTreeDepth)));
1381
1382 auto checkpointType =
1383 RankedTensorType::get({ctx.maxTreeDepth, ctx.positionSize}, elemType);
1384
1385 SmallVector<Type> whileTypes = initialTree.getTypes();
1386 SmallVector<Value> whileInitVals = initialTree.toValues();
1387
1388 auto whileOp =
1389 impulse::WhileOp::create(builder, loc, whileTypes, whileInitVals);
1390
1391 // Check: (depth < maxTreeDepth) && !turning && !diverging
1392 Block *condBlock = builder.createBlock(&whileOp.getConditionRegion());
1393 for (auto type : whileTypes)
1394 condBlock->addArgument(type, loc);
1395
1396 builder.setInsertionPointToStart(condBlock);
1397
1398 SmallVector<Value> condArgs(condBlock->getArguments().begin(),
1399 condBlock->getArguments().end());
1400 NUTSTreeState condTree = NUTSTreeState::fromValues(condArgs);
1401
1402 auto depthCheck = arith::CmpIOp::create(
1403 builder, loc, arith::CmpIPredicate::slt, condTree.depth, maxTreeDepth);
1404 auto notTurning =
1405 arith::XOrIOp::create(builder, loc, condTree.turning, trueConst);
1406 auto notDiverging =
1407 arith::XOrIOp::create(builder, loc, condTree.diverging, trueConst);
1408
1409 // Yield continue condition
1410 auto continueCond = arith::AndIOp::create(
1411 builder, loc, arith::AndIOp::create(builder, loc, depthCheck, notTurning),
1412 notDiverging);
1413
1414 impulse::YieldOp::create(builder, loc, ValueRange{continueCond});
1415
1416 Block *bodyBlock = builder.createBlock(&whileOp.getBodyRegion());
1417 for (auto type : whileTypes)
1418 bodyBlock->addArgument(type, loc);
1419
1420 builder.setInsertionPointToStart(bodyBlock);
1421
1422 SmallVector<Value> bodyArgs(bodyBlock->getArguments().begin(),
1423 bodyBlock->getArguments().end());
1424 NUTSTreeState bodyTree = NUTSTreeState::fromValues(bodyArgs);
1425
1426 // Create fresh checkpoint tensors
1427 auto zeroCkpts = arith::ConstantOp::create(
1428 builder, loc, checkpointType,
1429 DenseElementsAttr::get(checkpointType, builder.getF64FloatAttr(0.0)));
1430 Value bodyPCkpts = zeroCkpts;
1431 Value bodyPSumCkpts = zeroCkpts;
1432
1433 auto rngSplit3 = impulse::RandomSplitOp::create(
1434 builder, loc,
1435 TypeRange{bodyTree.rng.getType(), bodyTree.rng.getType(),
1436 bodyTree.rng.getType()},
1437 bodyTree.rng);
1438 auto rngNext = rngSplit3.getResult(0);
1439 auto rngDir = rngSplit3.getResult(1);
1440 auto rngDbl = rngSplit3.getResult(2);
1441
1442 auto directionRandom = impulse::RandomOp::create(
1443 builder, loc, TypeRange{rngDir.getType(), F64TensorType}, rngDir,
1444 zeroConst, oneConst,
1445 impulse::RngDistributionAttr::get(builder.getContext(),
1446 impulse::RngDistribution::UNIFORM));
1447 auto direction =
1448 arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::OLT,
1449 directionRandom.getResult(), halfConst);
1450
1451 // Double the tree
1452 NUTSTreeState treeToDouble = bodyTree;
1453 treeToDouble.rng = rngDbl;
1454 auto doubleResult = doubleTree(builder, loc, treeToDouble, direction,
1455 bodyPCkpts, bodyPSumCkpts, ctx, debugDump);
1456
1457 NUTSTreeState treeToYield = doubleResult.tree;
1458 treeToYield.rng = rngNext;
1459 impulse::YieldOp::create(builder, loc, treeToYield.toValues());
1460
1461 builder.setInsertionPointAfter(whileOp);
1462
1463 SmallVector<Value> results(whileOp.getResults().begin(),
1464 whileOp.getResults().end());
1465 return NUTSTreeState::fromValues(results);
1466}
1467
1468std::pair<Value, Value> impulse::leafIdxToCheckpointIdxs(OpBuilder &builder,
1469 Location loc,
1470 Value leafIdx) {
1471 auto i64TensorType = cast<RankedTensorType>(leafIdx.getType());
1472
1473 auto oneConst = arith::ConstantOp::create(
1474 builder, loc, i64TensorType,
1475 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1476
1477 // idx_max = popcount(leafIdx >> 1)
1478 auto shiftedIdx = arith::ShRUIOp::create(builder, loc, leafIdx, oneConst);
1479 auto idxMax =
1480 impulse::PopcountOp::create(builder, loc, i64TensorType, shiftedIdx);
1481
1482 // num_subtrees = popcount((~leafIdx & (leafIdx + 1)) - 1)
1483 auto leafIdxPlusOne = arith::AddIOp::create(builder, loc, leafIdx, oneConst);
1484
1485 auto minusOneConst = arith::ConstantOp::create(
1486 builder, loc, i64TensorType,
1487 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(-1)));
1488 auto notLeafIdx = arith::XOrIOp::create(builder, loc, leafIdx, minusOneConst);
1489
1490 Value andResult =
1491 arith::AndIOp::create(builder, loc, notLeafIdx, leafIdxPlusOne);
1492 Value andMinusOne = arith::SubIOp::create(builder, loc, andResult, oneConst);
1493 Value numSubtrees =
1494 impulse::PopcountOp::create(builder, loc, i64TensorType, andMinusOne);
1495
1496 // idx_min = idx_max - num_subtrees + 1
1497 Value idxMaxMinusNumSubtrees =
1498 arith::SubIOp::create(builder, loc, idxMax, numSubtrees);
1499 Value idxMin =
1500 arith::AddIOp::create(builder, loc, idxMaxMinusNumSubtrees, oneConst);
1501
1502 return {idxMin, idxMax};
1503}
1504
/// Iterative U-turn check for the leaf with momentum `p`.
///
/// The check emits an impulse.while loop over checkpoint slots
/// i = idxMax, idxMax-1, ..., idxMin. It stops at the first slot where
/// checkTurning reports a U-turn. For slot i, the loop does the following:
///   - reads pCkpts[i] and pSumCkpts[i] as {1, positionSize} slices;
///   - forms the subtree momentum sum  pSum - pSumCkpts[i] + pCkpts[i];
///   - calls checkTurning(pCkpts[i], p, subtreePSum).
///
/// If idxMin > idxMax, no slot is checked and the result is false.
/// `debugDump` is currently unused here.
///
/// \returns an i1 tensor that is true if any checked slot was turning.
Value impulse::checkIterativeTurning(OpBuilder &builder, Location loc, Value p,
                                     Value pSum, Value pCkpts, Value pSumCkpts,
                                     Value idxMin, Value idxMax,
                                     const NUTSContext &ctx, bool debugDump) {
  auto positionType = ctx.getPositionType();
  auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
  auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());

  auto falseConst = arith::ConstantOp::create(
      builder, loc, i1TensorType,
      DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(false)));
  auto oneI64 = arith::ConstantOp::create(
      builder, loc, i64TensorType,
      DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));

  // Iterate from `idx_max` down to `idx_min`, check turning at each checkpoint
  // Loop carries: [i, turning].
  SmallVector<Type> whileTypes = {i64TensorType, i1TensorType};
  SmallVector<Value> whileInitVals = {idxMax, falseConst};

  auto whileOp =
      impulse::WhileOp::create(builder, loc, whileTypes, whileInitVals);
  Block *condBlock = builder.createBlock(&whileOp.getConditionRegion());
  condBlock->addArgument(i64TensorType, loc);
  condBlock->addArgument(i1TensorType, loc);
  builder.setInsertionPointToStart(condBlock);

  Value iCond = condBlock->getArgument(0);
  Value turningCond = condBlock->getArgument(1);

  // Continue while i >= idx_min && !turning (negation via XOR with true).
  auto iGeMin = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sge,
                                      iCond, idxMin);
  auto trueConst = arith::ConstantOp::create(
      builder, loc, i1TensorType,
      DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(true)));
  auto notTurning = arith::XOrIOp::create(builder, loc, turningCond, trueConst);
  auto continueLoop = arith::AndIOp::create(builder, loc, iGeMin, notTurning);

  impulse::YieldOp::create(builder, loc, ValueRange{continueLoop.getResult()});

  Block *bodyBlock = builder.createBlock(&whileOp.getBodyRegion());
  bodyBlock->addArgument(i64TensorType, loc);
  bodyBlock->addArgument(i1TensorType, loc);
  builder.setInsertionPointToStart(bodyBlock);
  Value iBody = bodyBlock->getArgument(0);

  // Row i of each checkpoint tensor, sliced at offsets [i, 0] with size
  // {1, positionSize}.
  auto zeroI64 = arith::ConstantOp::create(
      builder, loc, i64TensorType,
      DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
  Value pLeft = impulse::DynamicSliceOp::create(
      builder, loc, positionType, pCkpts, ValueRange{iBody, zeroI64},
      builder.getDenseI64ArrayAttr({1, ctx.positionSize}));
  Value pSumCkptI = impulse::DynamicSliceOp::create(
      builder, loc, positionType, pSumCkpts, ValueRange{iBody, zeroI64},
      builder.getDenseI64ArrayAttr({1, ctx.positionSize}));

  // Compute subtree momentum sum: pSum - pSumCkpts[i] + pCkpts[i]
  auto pSumMinusCkpt = arith::SubFOp::create(builder, loc, pSum, pSumCkptI);
  Value subtreePSum = arith::AddFOp::create(builder, loc, pSumMinusCkpt, pLeft);

  // Check turning
  Value turningAtCkpt = checkTurning(builder, loc, pLeft, p, subtreePSum, ctx);

  Value iNext = arith::SubIOp::create(builder, loc, iBody, oneI64);
  impulse::YieldOp::create(builder, loc, ValueRange{iNext, turningAtCkpt});

  builder.setInsertionPointAfter(whileOp);
  // Result 1 is the carried turning flag.
  return whileOp.getResult(1);
}
1573
1574std::pair<Value, Value>
1575impulse::updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx,
1576 Value ckptIdxMax, Value p, Value pSum, Value pCkpts,
1577 Value pSumCkpts, const NUTSContext &ctx,
1578 bool debugDump) {
1579 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1580 auto oneI64 = arith::ConstantOp::create(
1581 builder, loc, i64TensorType,
1582 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1583 auto zeroI64 = arith::ConstantOp::create(
1584 builder, loc, i64TensorType,
1585 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1586
1587 Value leafIdxBit0 = arith::AndIOp::create(builder, loc, leafIdx, oneI64);
1588 Value isEven = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
1589 leafIdxBit0, zeroI64);
1590
1591 auto pCkptsType = cast<RankedTensorType>(pCkpts.getType());
1592
1593 // Compute updates only on even leafIdx
1594 SmallVector<Type> ifResultTypes = {pCkptsType, pCkptsType};
1595 auto ifOp = impulse::IfOp::create(builder, loc, ifResultTypes, isEven);
1596 {
1597 Block *trueBranch = builder.createBlock(&ifOp.getTrueBranch());
1598 builder.setInsertionPointToStart(trueBranch);
1599
1600 auto updatedPCkpts = impulse::DynamicUpdateSliceOp::create(
1601 builder, loc, pCkptsType, pCkpts, p, ValueRange{ckptIdxMax, zeroI64});
1602 auto updatedPSumCkpts = impulse::DynamicUpdateSliceOp::create(
1603 builder, loc, pCkptsType, pSumCkpts, pSum,
1604 ValueRange{ckptIdxMax, zeroI64});
1605
1606 impulse::YieldOp::create(builder, loc,
1607 ValueRange{updatedPCkpts, updatedPSumCkpts});
1608 }
1609 {
1610 Block *falseBranch = builder.createBlock(&ifOp.getFalseBranch());
1611 builder.setInsertionPointToStart(falseBranch);
1612
1613 impulse::YieldOp::create(builder, loc, ValueRange{pCkpts, pSumCkpts});
1614 }
1615
1616 builder.setInsertionPointAfter(ifOp);
1617
1618 Value finalPCkpts = ifOp.getResult(0);
1619 Value finalPSumCkpts = ifOp.getResult(1);
1620
1621 return {finalPCkpts, finalPSumCkpts};
1622}
1623
1624DualAveragingState impulse::initDualAveraging(OpBuilder &builder, Location loc,
1625 Value stepSize) {
1626 auto stepSizeType = cast<RankedTensorType>(stepSize.getType());
1627 auto elemType = stepSizeType.getElementType();
1628 auto scalarType = RankedTensorType::get({}, elemType);
1629 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1630
1631 // prox_center = log(10) + log(step_size)
1632 auto log10Const = arith::ConstantOp::create(
1633 builder, loc, scalarType,
1634 DenseElementsAttr::get(scalarType,
1635 builder.getFloatAttr(elemType, std::log(10.0))));
1636 Value logStepSize = math::LogOp::create(builder, loc, stepSize);
1637 Value proxCenter =
1638 arith::AddFOp::create(builder, loc, log10Const, logStepSize);
1639
1640 auto zeroConst = arith::ConstantOp::create(
1641 builder, loc, scalarType,
1642 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
1643 auto zeroI64 = arith::ConstantOp::create(
1644 builder, loc, i64TensorType,
1645 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1646
1647 return {
1648 .log_step_size = zeroConst,
1649 .log_step_size_avg = zeroConst,
1650 .gradient_avg = zeroConst,
1651 .step_count = zeroI64,
1652 .prox_center = proxCenter,
1653 };
1654}
1655
1657impulse::updateDualAveraging(OpBuilder &builder, Location loc,
1658 const DualAveragingState &state, Value acceptProb,
1659 const DualAveragingConfig &config) {
1660 // Dual Averaging update:
1661 // g = target_accept_prob - accept_prob
1662 // g_avg = (1 - 1/(t+t0)) * g_avg + g/(t+t0)
1663 // log_step_size = prox_center - sqrt(t)/gamma * g_avg
1664 // log_step_size_avg = (1 - t^(-kappa)) * log_step_size_avg +
1665 // t^(-kappa) * log_step_size
1666 auto acceptProbType = cast<RankedTensorType>(acceptProb.getType());
1667 auto elemType = acceptProbType.getElementType();
1668 auto scalarType = RankedTensorType::get({}, elemType);
1669 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1670
1671 // t = t + 1
1672 auto oneI64 = arith::ConstantOp::create(
1673 builder, loc, i64TensorType,
1674 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1675 Value tNew = arith::AddIOp::create(builder, loc, state.step_count, oneI64);
1676 Value tFloat = arith::SIToFPOp::create(builder, loc, scalarType, tNew);
1677
1678 // t0
1679 auto t0Const = arith::ConstantOp::create(
1680 builder, loc, scalarType,
1681 DenseElementsAttr::get(scalarType,
1682 builder.getFloatAttr(elemType, config.t0)));
1683
1684 // g = target_accept_prob - accept_prob
1685 auto targetConst = arith::ConstantOp::create(
1686 builder, loc, scalarType,
1687 DenseElementsAttr::get(
1688 scalarType,
1689 builder.getFloatAttr(elemType, config.target_accept_prob)));
1690 Value g = arith::SubFOp::create(builder, loc, targetConst, acceptProb);
1691
1692 // t_plus_t0 = t + t0
1693 Value tPlusT0 = arith::AddFOp::create(builder, loc, tFloat, t0Const);
1694
1695 // weight = 1 / (t + t0)
1696 auto oneConst = arith::ConstantOp::create(
1697 builder, loc, scalarType,
1698 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
1699 Value weight = arith::DivFOp::create(builder, loc, oneConst, tPlusT0);
1700
1701 // decay = 1 - weight = 1 - 1/(t + t0)
1702 Value decay = arith::SubFOp::create(builder, loc, oneConst, weight);
1703
1704 // g_avg = decay * g_avg + weight * g
1705 // g_avg = (1 - 1/(t+t0)) * g_avg + g/(t+t0)
1706 Value gAvgDecayed =
1707 arith::MulFOp::create(builder, loc, decay, state.gradient_avg);
1708 Value gWeighted = arith::MulFOp::create(builder, loc, weight, g);
1709 Value gAvgNew = arith::AddFOp::create(builder, loc, gAvgDecayed, gWeighted);
1710
1711 // x_t = prox_center - sqrt(t) / gamma * g_avg
1712 Value sqrtT = math::SqrtOp::create(builder, loc, tFloat);
1713 auto gammaConst = arith::ConstantOp::create(
1714 builder, loc, scalarType,
1715 DenseElementsAttr::get(scalarType,
1716 builder.getFloatAttr(elemType, config.gamma)));
1717 Value sqrtTOverGamma = arith::DivFOp::create(builder, loc, sqrtT, gammaConst);
1718 Value adjustment =
1719 arith::MulFOp::create(builder, loc, sqrtTOverGamma, gAvgNew);
1720 Value logStepSizeNew =
1721 arith::SubFOp::create(builder, loc, state.prox_center, adjustment);
1722
1723 // weight_t = t^(-kappa)
1724 auto negKappaConst = arith::ConstantOp::create(
1725 builder, loc, scalarType,
1726 DenseElementsAttr::get(scalarType,
1727 builder.getFloatAttr(elemType, -config.kappa)));
1728 Value weightT = math::PowFOp::create(builder, loc, tFloat, negKappaConst);
1729
1730 // x_avg = (1 - weight_t) * x_avg + weight_t * x_t
1731 Value oneMinusWeightT =
1732 arith::SubFOp::create(builder, loc, oneConst, weightT);
1733 Value avgDecayed = arith::MulFOp::create(builder, loc, oneMinusWeightT,
1734 state.log_step_size_avg);
1735 Value newContribution =
1736 arith::MulFOp::create(builder, loc, weightT, logStepSizeNew);
1737 Value logStepSizeAvgNew =
1738 arith::AddFOp::create(builder, loc, avgDecayed, newContribution);
1739
1740 return {
1741 .log_step_size = logStepSizeNew,
1742 .log_step_size_avg = logStepSizeAvgNew,
1743 .gradient_avg = gAvgNew,
1744 .step_count = tNew,
1745 .prox_center = state.prox_center,
1746 };
1747}
1748
1749Value impulse::getStepSizeFromDualAveraging(OpBuilder &builder, Location loc,
1750 const DualAveragingState &state,
1751 bool final) {
1752 Value logStepSize = final ? state.log_step_size_avg : state.log_step_size;
1753 return math::ExpOp::create(builder, loc, logStepSize);
1754}
1755
1756WelfordState impulse::initWelford(OpBuilder &builder, Location loc,
1757 int64_t positionSize, bool diagonal) {
1758 auto elemType = builder.getF64Type();
1759 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1760 auto meanType = RankedTensorType::get({positionSize}, elemType);
1761
1762 Value mean = arith::ConstantOp::create(
1763 builder, loc, meanType,
1764 DenseElementsAttr::get(meanType, builder.getFloatAttr(elemType, 0.0)));
1765
1766 // Diagonal -> tensor<positionSize>
1767 // Dense -> tensor<positionSize x positionSize>
1768 Value m2;
1769 if (diagonal) {
1770 m2 = arith::ConstantOp::create(
1771 builder, loc, meanType,
1772 DenseElementsAttr::get(meanType, builder.getFloatAttr(elemType, 0.0)));
1773 } else {
1774 auto m2Type = RankedTensorType::get({positionSize, positionSize}, elemType);
1775 m2 = arith::ConstantOp::create(
1776 builder, loc, m2Type,
1777 DenseElementsAttr::get(m2Type, builder.getFloatAttr(elemType, 0.0)));
1778 }
1779
1780 Value n = arith::ConstantOp::create(
1781 builder, loc, i64TensorType,
1782 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1783
1784 return {mean, m2, n};
1785}
1786
1787WelfordState impulse::updateWelford(OpBuilder &builder, Location loc,
1788 const WelfordState &state, Value sample,
1789 const WelfordConfig &config) {
1790 // Algorithm:
1791 // n = n + 1
1792 // delta_pre = sample - mean
1793 // mean = mean + delta_pre / n
1794 // delta_post = sample - mean
1795 // (if diagonal) m2 = m2 + delta_pre * delta_post
1796 // (if dense) m2 = m2 + outer(delta_post, delta_pre)
1797
1798 auto sampleType = cast<RankedTensorType>(sample.getType());
1799 auto elemType = sampleType.getElementType();
1800 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1801
1802 auto oneI64 = arith::ConstantOp::create(
1803 builder, loc, i64TensorType,
1804 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1805 Value nNew = arith::AddIOp::create(builder, loc, state.n, oneI64);
1806
1807 auto scalarType = RankedTensorType::get({}, elemType);
1808 Value nFloat = arith::SIToFPOp::create(builder, loc, scalarType, nNew);
1809
1810 Value nBroadcast = BroadcastOp::create(builder, loc, sampleType, nFloat,
1811 sampleType.getShape());
1812
1813 Value deltaPre = arith::SubFOp::create(builder, loc, sample, state.mean);
1814
1815 Value deltaPreOverN =
1816 arith::DivFOp::create(builder, loc, deltaPre, nBroadcast);
1817 Value meanNew =
1818 arith::AddFOp::create(builder, loc, state.mean, deltaPreOverN);
1819
1820 Value deltaPost = arith::SubFOp::create(builder, loc, sample, meanNew);
1821
1822 Value m2New;
1823 if (config.diagonal) {
1824 Value product = arith::MulFOp::create(builder, loc, deltaPre, deltaPost);
1825 m2New = arith::AddFOp::create(builder, loc, state.m2, product);
1826 } else { // Dense
1827 auto m2Type = cast<RankedTensorType>(state.m2.getType());
1828 Value outerProduct = impulse::DotOp::create(
1829 builder, loc, m2Type, deltaPost, deltaPre,
1830 /*lhs_batching_dimensions=*/builder.getDenseI64ArrayAttr({}),
1831 /*rhs_batching_dimensions=*/builder.getDenseI64ArrayAttr({}),
1832 /*lhs_contracting_dimensions=*/builder.getDenseI64ArrayAttr({}),
1833 /*rhs_contracting_dimensions=*/builder.getDenseI64ArrayAttr({}));
1834 m2New = arith::AddFOp::create(builder, loc, state.m2, outerProduct);
1835 }
1836
1837 return {meanNew, m2New, nNew};
1838}
1839
1840Value impulse::finalizeWelford(OpBuilder &builder, Location loc,
1841 const WelfordState &state,
1842 const WelfordConfig &config) {
1843 // Compute sample covariance: cov = m2 / (n - 1)
1844 auto m2Type = cast<RankedTensorType>(state.m2.getType());
1845 auto elemType = m2Type.getElementType();
1846 auto scalarType = RankedTensorType::get({}, elemType);
1847 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1848
1849 auto oneI64 = arith::ConstantOp::create(
1850 builder, loc, i64TensorType,
1851 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1852 Value nMinus1 = arith::SubIOp::create(builder, loc, state.n, oneI64);
1853 Value nMinus1Float =
1854 arith::SIToFPOp::create(builder, loc, scalarType, nMinus1);
1855
1856 Value nMinus1Bcast = BroadcastOp::create(builder, loc, m2Type, nMinus1Float,
1857 m2Type.getShape());
1858
1859 Value cov = arith::DivFOp::create(builder, loc, state.m2, nMinus1Bcast);
1860
1861 // (Optional) Regularization (Stan's shrinkage):
1862 // scaled_cov = (n / (n + 5)) * cov
1863 // shrinkage = 1e-3 * (5 / (n + 5))
1864 // (if diagonal) cov = scaled_cov + shrinkage
1865 // (if dense) cov = scaled_cov + shrinkage * I
1866 if (config.regularize) {
1867 Value nFloat = arith::SIToFPOp::create(builder, loc, scalarType, state.n);
1868
1869 auto fiveConst = arith::ConstantOp::create(
1870 builder, loc, scalarType,
1871 DenseElementsAttr::get(scalarType,
1872 builder.getFloatAttr(elemType, 5.0)));
1873 Value nPlusFive = arith::AddFOp::create(builder, loc, nFloat, fiveConst);
1874
1875 Value scale = arith::DivFOp::create(builder, loc, nFloat, nPlusFive);
1876 Value scaleBcast =
1877 BroadcastOp::create(builder, loc, m2Type, scale, m2Type.getShape());
1878 Value scaledCov = arith::MulFOp::create(builder, loc, scaleBcast, cov);
1879
1880 auto shrinkageBaseConst = arith::ConstantOp::create(
1881 builder, loc, scalarType,
1882 DenseElementsAttr::get(scalarType,
1883 builder.getFloatAttr(elemType, 1e-3 * 5.0)));
1884 Value shrinkage =
1885 arith::DivFOp::create(builder, loc, shrinkageBaseConst, nPlusFive);
1886
1887 if (config.diagonal) {
1888 Value shrinkageBcast = BroadcastOp::create(builder, loc, m2Type,
1889 shrinkage, m2Type.getShape());
1890 cov = arith::AddFOp::create(builder, loc, scaledCov, shrinkageBcast);
1891 } else {
1892 Value identity = createIdentityMatrix(builder, loc, m2Type);
1893 Value shrinkageBcast = BroadcastOp::create(builder, loc, m2Type,
1894 shrinkage, m2Type.getShape());
1895 Value shrinkageI =
1896 arith::MulFOp::create(builder, loc, shrinkageBcast, identity);
1897 cov = arith::AddFOp::create(builder, loc, scaledCov, shrinkageI);
1898 }
1899 }
1900
1901 return cov;
1902}
1903
1904SmallVector<AdaptWindow> impulse::buildAdaptationSchedule(int64_t numSteps) {
1905 // |<-- start buffer -->|<-- middle windows (doubling) -->|<-- end buffer -->|
1906 // | (no mass) | (collect + adapt mass) | (no mass) |
1907 // | step size only | step size + mass matrix | step size only |
1908
1909 SmallVector<AdaptWindow> schedule;
1910
1911 if (numSteps < 20) {
1912 schedule.push_back({0, numSteps - 1});
1913 return schedule;
1914 }
1915
1916 // Stan-style window schedule
1917 int64_t startBufferSize = 75;
1918 int64_t endBufferSize = 50;
1919 int64_t initWindowSize = 25;
1920
1921 if ((startBufferSize + endBufferSize + initWindowSize) > numSteps) {
1922 startBufferSize = static_cast<int64_t>(0.15 * numSteps);
1923 endBufferSize = static_cast<int64_t>(0.1 * numSteps);
1924 initWindowSize = numSteps - startBufferSize - endBufferSize;
1925 }
1926
1927 // Start buffer window
1928 schedule.push_back({0, startBufferSize - 1});
1929
1930 int64_t endWindowStart = numSteps - endBufferSize;
1931 int64_t nextWindowSize = initWindowSize;
1932 int64_t nextWindowStart = startBufferSize;
1933
1934 // Middle windows
1935 while (nextWindowStart < endWindowStart) {
1936 int64_t curWindowStart = nextWindowStart;
1937 int64_t curWindowSize = nextWindowSize;
1938
1939 if (3 * curWindowSize <= endWindowStart - curWindowStart) {
1940 nextWindowSize = 2 * curWindowSize;
1941 } else {
1942 curWindowSize = endWindowStart - curWindowStart;
1943 }
1944
1945 nextWindowStart = curWindowStart + curWindowSize;
1946 schedule.push_back({curWindowStart, nextWindowStart - 1});
1947 }
1948
1949 // End buffer window
1950 schedule.push_back({endWindowStart, numSteps - 1});
1951
1952 return schedule;
1953}
1954
1955Value impulse::unconstrainPosition(OpBuilder &builder, Location loc,
1956 Value constrained,
1957 ArrayRef<SupportInfo> supports) {
1958 bool hasConstraints = false;
1959 for (const auto &info : supports) {
1960 if (info.support && info.support.getKind() != impulse::SupportKind::REAL) {
1961 hasConstraints = true;
1962 break;
1963 }
1964 }
1965 if (!hasConstraints || supports.empty())
1966 return constrained;
1967
1968 auto inputType = cast<RankedTensorType>(constrained.getType());
1969 auto elemType = inputType.getElementType();
1970 int64_t positionSize = inputType.getShape()[1];
1971
1972 auto positionType1D = RankedTensorType::get({positionSize}, elemType);
1973 Value constrained1D =
1974 impulse::ReshapeOp::create(builder, loc, positionType1D, constrained);
1975
1976 SmallVector<Value> slices;
1977 for (const auto &info : supports) {
1978 auto sliceType = RankedTensorType::get({info.size}, elemType);
1979 auto slice = impulse::SliceOp::create(
1980 builder, loc, sliceType, constrained1D,
1981 builder.getDenseI64ArrayAttr({info.offset}),
1982 builder.getDenseI64ArrayAttr({info.offset + info.size}),
1983 builder.getDenseI64ArrayAttr({1}));
1984 Value unconstrainedSlice;
1985 if (info.support && info.support.getKind() != impulse::SupportKind::REAL) {
1986 unconstrainedSlice =
1987 transforms::unconstrain(builder, loc, slice, info.support);
1988 } else {
1989 unconstrainedSlice = slice;
1990 }
1991
1992 slices.push_back(unconstrainedSlice);
1993 }
1994
1995 Value result1D;
1996 if (slices.size() == 1) {
1997 result1D = slices[0];
1998 } else {
1999 auto resultType1D = RankedTensorType::get({positionSize}, elemType);
2000 result1D = arith::ConstantOp::create(
2001 builder, loc, resultType1D,
2002 DenseElementsAttr::get(resultType1D,
2003 builder.getFloatAttr(elemType, 0.0)));
2004
2005 auto i64ScalarType = RankedTensorType::get({}, builder.getI64Type());
2006 auto elemType1DSlice = RankedTensorType::get({1}, elemType);
2007 auto elemType0D = RankedTensorType::get({}, elemType);
2008 int64_t offset = 0;
2009 for (size_t i = 0; i < slices.size(); ++i) {
2010 auto sliceType = cast<RankedTensorType>(slices[i].getType());
2011 int64_t sliceSize = sliceType.getShape()[0];
2012
2013 for (int64_t j = 0; j < sliceSize; ++j) {
2014 auto elemIdx = arith::ConstantOp::create(
2015 builder, loc, i64ScalarType,
2016 DenseElementsAttr::get(i64ScalarType,
2017 builder.getI64IntegerAttr(j)));
2018 auto resultIdx = arith::ConstantOp::create(
2019 builder, loc, i64ScalarType,
2020 DenseElementsAttr::get(i64ScalarType,
2021 builder.getI64IntegerAttr(offset + j)));
2022
2023 auto elemSliced = impulse::DynamicSliceOp::create(
2024 builder, loc, elemType1DSlice, slices[i], ValueRange{elemIdx},
2025 builder.getDenseI64ArrayAttr({1}));
2026 auto elem =
2027 impulse::ReshapeOp::create(builder, loc, elemType0D, elemSliced);
2028
2029 auto resultSliced = impulse::ReshapeOp::create(
2030 builder, loc, RankedTensorType::get({1}, elemType), elem);
2031 result1D = impulse::DynamicUpdateSliceOp::create(
2032 builder, loc, resultType1D, result1D, resultSliced,
2033 ValueRange{resultIdx});
2034 }
2035
2036 offset += sliceSize;
2037 }
2038 }
2039
2040 auto resultType2D = RankedTensorType::get({1, positionSize}, elemType);
2041 return impulse::ReshapeOp::create(builder, loc, resultType2D, result1D);
2042}
2043
2044Value impulse::constrainPosition(OpBuilder &builder, Location loc,
2045 Value unconstrained,
2046 ArrayRef<SupportInfo> supports) {
2047 bool hasConstraints = false;
2048 for (const auto &info : supports) {
2049 if (info.support && info.support.getKind() != impulse::SupportKind::REAL) {
2050 hasConstraints = true;
2051 break;
2052 }
2053 }
2054 if (!hasConstraints || supports.empty())
2055 return unconstrained;
2056
2057 auto inputType = cast<RankedTensorType>(unconstrained.getType());
2058 auto elemType = inputType.getElementType();
2059 int64_t positionSize = inputType.getShape()[1];
2060
2061 auto positionType1D = RankedTensorType::get({positionSize}, elemType);
2062 Value unconstrained1D =
2063 impulse::ReshapeOp::create(builder, loc, positionType1D, unconstrained);
2064
2065 SmallVector<Value> slices;
2066 for (const auto &info : supports) {
2067 auto sliceType = RankedTensorType::get({info.size}, elemType);
2068 auto slice = impulse::SliceOp::create(
2069 builder, loc, sliceType, unconstrained1D,
2070 builder.getDenseI64ArrayAttr({info.offset}),
2071 builder.getDenseI64ArrayAttr({info.offset + info.size}),
2072 builder.getDenseI64ArrayAttr({1}));
2073
2074 Value constrainedSlice;
2075 if (info.support && info.support.getKind() != impulse::SupportKind::REAL) {
2076 constrainedSlice =
2077 transforms::constrain(builder, loc, slice, info.support);
2078 } else {
2079 constrainedSlice = slice;
2080 }
2081
2082 slices.push_back(constrainedSlice);
2083 }
2084
2085 Value result1D;
2086 if (slices.size() == 1) {
2087 result1D = slices[0];
2088 } else {
2089 auto resultType1D = RankedTensorType::get({positionSize}, elemType);
2090 result1D = arith::ConstantOp::create(
2091 builder, loc, resultType1D,
2092 DenseElementsAttr::get(resultType1D,
2093 builder.getFloatAttr(elemType, 0.0)));
2094
2095 auto i64ScalarType = RankedTensorType::get({}, builder.getI64Type());
2096 auto elemType1DSlice = RankedTensorType::get({1}, elemType);
2097 auto elemType0D = RankedTensorType::get({}, elemType);
2098 int64_t offset = 0;
2099 for (size_t i = 0; i < slices.size(); ++i) {
2100 auto sliceType = cast<RankedTensorType>(slices[i].getType());
2101 int64_t sliceSize = sliceType.getShape()[0];
2102
2103 for (int64_t j = 0; j < sliceSize; ++j) {
2104 auto elemIdx = arith::ConstantOp::create(
2105 builder, loc, i64ScalarType,
2106 DenseElementsAttr::get(i64ScalarType,
2107 builder.getI64IntegerAttr(j)));
2108 auto resultIdx = arith::ConstantOp::create(
2109 builder, loc, i64ScalarType,
2110 DenseElementsAttr::get(i64ScalarType,
2111 builder.getI64IntegerAttr(offset + j)));
2112
2113 auto elemSliced = impulse::DynamicSliceOp::create(
2114 builder, loc, elemType1DSlice, slices[i], ValueRange{elemIdx},
2115 builder.getDenseI64ArrayAttr({1}));
2116 auto elem =
2117 impulse::ReshapeOp::create(builder, loc, elemType0D, elemSliced);
2118
2119 auto resultSliced =
2120 impulse::ReshapeOp::create(builder, loc, elemType1DSlice, elem);
2121 result1D = impulse::DynamicUpdateSliceOp::create(
2122 builder, loc, resultType1D, result1D, resultSliced,
2123 ValueRange{resultIdx});
2124 }
2125
2126 offset += sliceSize;
2127 }
2128 }
2129
2130 auto resultType2D = RankedTensorType::get({1, positionSize}, elemType);
2131 return impulse::ReshapeOp::create(builder, loc, resultType2D, result1D);
2132}
2133
2134Value impulse::computeTotalJacobianCorrection(OpBuilder &builder, Location loc,
2135 Value unconstrained,
2136 ArrayRef<SupportInfo> supports) {
2137 auto inputType = cast<RankedTensorType>(unconstrained.getType());
2138 auto elemType = inputType.getElementType();
2139 auto scalarType = RankedTensorType::get({}, elemType);
2140
2141 Value total = arith::ConstantOp::create(
2142 builder, loc, scalarType,
2143 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
2144
2145 if (supports.empty())
2146 return total;
2147
2148 int64_t positionSize = inputType.getShape()[1];
2149 auto positionType1D = RankedTensorType::get({positionSize}, elemType);
2150 Value unconstrained1D =
2151 impulse::ReshapeOp::create(builder, loc, positionType1D, unconstrained);
2152
2153 for (const auto &info : supports) {
2154 if (!info.support || info.support.getKind() == impulse::SupportKind::REAL)
2155 continue;
2156
2157 auto sliceType = RankedTensorType::get({info.size}, elemType);
2158 auto slice = impulse::SliceOp::create(
2159 builder, loc, sliceType, unconstrained1D,
2160 builder.getDenseI64ArrayAttr({info.offset}),
2161 builder.getDenseI64ArrayAttr({info.offset + info.size}),
2162 builder.getDenseI64ArrayAttr({1}));
2163 auto jacobian =
2164 transforms::logAbsDetJacobian(builder, loc, slice, info.support);
2165
2166 total = arith::AddFOp::create(builder, loc, total, jacobian);
2167 }
2168
2169 return total;
2170}
static Value gatherPositionFromTrace(OpBuilder &builder, Location loc, Value fullTrace, const HMCContext &ctx)
Definition HMCUtils.cpp:315
static Value createPermutationMatrix(OpBuilder &builder, Location loc, RankedTensorType matrixType)
Creates a permutation matrix of size n x n.
Definition HMCUtils.cpp:98
static Value scatterPositionToTrace(OpBuilder &builder, Location loc, Value position2d, Value fullTrace, const HMCContext &ctx)
Definition HMCUtils.cpp:281
static Value createIdentityMatrix(OpBuilder &builder, Location loc, RankedTensorType matrixType)
Creates a 2D identity matrix of the specified type.
Definition HMCUtils.cpp:75
static Value reverseRowsAndColumns(OpBuilder &builder, Location loc, Value matrix)
Computes A[::-1, ::-1] using permutation matrix through P @ A @ P.
Definition HMCUtils.cpp:121
PointerType * traceType(LLVMContext &C)
Value logAbsDetJacobian(OpBuilder &builder, Location loc, Value unconstrained, impulse::SupportAttr support)
Compute log |det J| of the transform from unconstrained to constrained.
Value unconstrain(OpBuilder &builder, Location loc, Value constrained, impulse::SupportAttr support)
Transform from constrained to unconstrained space.
Value constrain(OpBuilder &builder, Location loc, Value unconstrained, impulse::SupportAttr support)
Transform from unconstrained to constrained space.
Value computeTotalJacobianCorrection(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Compute total Jacobian correction for the constrain transform over all position vector slices.
Value finalizeWelford(OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config)
Finalize Welford state to produce sample covariance (returned as inverse mass matrix).
Value checkTurning(OpBuilder &builder, Location loc, Value pLeft, Value pRight, Value pSum, const NUTSContext &ctx)
U-turn termination criterion.
Definition HMCUtils.cpp:509
GradientResult computePotentialAndGradient(OpBuilder &builder, Location loc, Value position, Value rng, const HMCContext &ctx)
Computes potential energy U(q) = -log p(q) and its gradient dU/dq
Definition HMCUtils.cpp:352
std::pair< Value, Value > leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx)
Computes checkpoint indices from leaf index for iterative turning check.
IntegratorState getLeafFromTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, const NUTSContext &ctx)
Extracts an appropriate leaf based on direction.
std::pair< Value, Value > updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, Value ckptIdxMax, Value p, Value pSum, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Update checkpoint arrays at even leaf indices.
SubtreeBuildResult buildIterativeSubtree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Builds a subtree iteratively by appending leaves one at a time.
std::pair< Value, Value > sampleMomentum(OpBuilder &builder, Location loc, Value rng, Value invMass, Value massMatrixSqrt, RankedTensorType positionType, bool debugDump=false)
Samples momentum from N(0, M) where M is the mass matrix.
Definition HMCUtils.cpp:233
Value computeUniformTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight)
Computes the uniform transition probability for subtree combination.
Definition HMCUtils.cpp:557
SmallVector< AdaptWindow > buildAdaptationSchedule(int64_t numSteps)
Build warmup adaptation schedule.
NUTSTreeState buildBaseTree(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const NUTSContext &ctx)
Builds a base tree (leaf node) by taking one leapfrog step.
Value computeKineticEnergy(OpBuilder &builder, Location loc, Value momentum, Value invMass, RankedTensorType positionType)
Computes K = 0.5 * p^T @ M^-1 @ p
Definition HMCUtils.cpp:166
NUTSTreeState combineTrees(OpBuilder &builder, Location loc, const NUTSTreeState &tree, const NUTSTreeState &subTree, Value direction, Value rng, bool biased, const NUTSContext &ctx)
Combines a tree with a newly-built subtree during NUTS doubling process.
Definition HMCUtils.cpp:590
Value computeBiasedTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight, Value turning, Value diverging)
Computes the biased transition probability for main tree combination.
Definition HMCUtils.cpp:566
Value checkIterativeTurning(OpBuilder &builder, Location loc, Value p, Value pSum, Value pCkpts, Value pSumCkpts, Value idxMin, Value idxMax, const NUTSContext &ctx, bool debugDump=false)
Checkpoint-based iterative turning check.
NUTSTreeState buildTree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, const NUTSContext &ctx, bool debugDump=false)
Main NUTS tree building loop.
Value conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump)
Conditionally dump a value for debugging.
Definition HMCUtils.cpp:64
MCMCKernelResult SampleNUTS(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump=false)
Single NUTS iteration: momentum sampling + tree building.
Definition HMCUtils.cpp:998
InitialHMCState InitHMC(OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition=Value(), bool debugDump=false)
Initializes HMC/NUTS state from a trace Specifically:
Definition HMCUtils.cpp:701
SubtreeBuildResult doubleTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Tree doubling by building a subtree of same depth and combining.
IntegrationResult computeIntegrationStep(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const HMCContext &ctx)
Computes a single leapfrog integration step.
Definition HMCUtils.cpp:460
DualAveragingState updateDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config)
Update dual averaging state with observed acceptance probability.
Value unconstrainPosition(OpBuilder &builder, Location loc, Value constrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from constrained to unconstrained space based on the support info...
DualAveragingState initDualAveraging(OpBuilder &builder, Location loc, Value stepSize)
Initialize dual averaging state from initial step size.
Value computeMassMatrixSqrt(OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType)
Computes the square root of the mass matrix from the inverse mass matrix.
Definition HMCUtils.cpp:190
WelfordState updateWelford(OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config)
Update Welford state with a new sample.
WelfordState initWelford(OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal)
Initialize state for Welford covariance estimation.
Value getStepSizeFromDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, bool final=false)
Get step size from dual averaging state.
MCMCKernelResult SampleHMC(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump=false)
Single HMC iteration: momentum sampling + leapfrog + MH accept/reject.
Definition HMCUtils.cpp:862
Value constrainPosition(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from unconstrained to constrained space.
Value applyInverseMassMatrix(OpBuilder &builder, Location loc, Value invMass, Value momentum, RankedTensorType positionType)
Computes v = M^-1 @ p If invMass is nullptr, returns momentum unchanged (assumes identity) If invMass...
Definition HMCUtils.cpp:139
DictionaryAttr autodiffAttrs
Definition HMCUtils.h:112
Type getElementType() const
Definition HMCUtils.h:145
FlatSymbolRefAttr fn
Definition HMCUtils.h:99
HMCContext withStepSize(Value newStepSize) const
Definition HMCUtils.h:165
ArrayRef< Value > fnInputs
Definition HMCUtils.h:100
RankedTensorType getPositionType() const
Definition HMCUtils.h:149
int64_t getFullTraceSize() const
Definition HMCUtils.h:140
SmallVector< Type > fnResultTypes
Definition HMCUtils.h:101
SmallVector< SupportInfo > supports
Definition HMCUtils.h:110
bool hasCustomLogpdf() const
Definition HMCUtils.h:138
RankedTensorType getScalarType() const
Definition HMCUtils.h:153
FlatSymbolRefAttr logpdfFn
Definition HMCUtils.h:111
Result of one MCMC kernel step.
Definition HMCUtils.h:62
NUTSContext withH0(Value newH0) const
Definition HMCUtils.h:198
SmallVector< Value > toValues() const
Definition HMCUtils.cpp:26
SmallVector< Type > getTypes() const
Definition HMCUtils.cpp:57
static NUTSTreeState fromValues(ArrayRef< Value > values)
Definition HMCUtils.cpp:35
Configuration for Welford covariance estimation.
Definition HMCUtils.h:420
State for Welford covariance estimation.
Definition HMCUtils.h:405