15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/Math/IR/Math.h"
36 assert(values.size() == 18 &&
"Expected 18 NUTSTreeState fields");
37 return {.q_left = values[0],
39 .grad_left = values[2],
42 .grad_right = values[5],
43 .q_proposal = values[6],
44 .grad_proposal = values[7],
45 .U_proposal = values[8],
46 .H_proposal = values[9],
49 .turning = values[12],
50 .diverging = values[13],
51 .sum_accept_probs = values[14],
52 .num_proposals = values[15],
58 SmallVector<Type> types;
60 types.push_back(val.getType());
65 StringRef label,
bool debugDump) {
67 return enzyme::DumpOp::create(builder, loc, value.getType(), value,
68 builder.getStringAttr(label))
76 RankedTensorType matrixType) {
77 assert(matrixType.getRank() == 2 &&
"Expected 2D tensor type");
78 assert(matrixType.getShape()[0] == matrixType.getShape()[1] &&
79 "Expected square matrix type");
81 int64_t size = matrixType.getShape()[0];
82 auto elemType = matrixType.getElementType();
84 SmallVector<Attribute> values;
85 values.reserve(size * size);
86 for (int64_t i = 0; i < size; ++i) {
87 for (int64_t j = 0; j < size; ++j) {
88 double val = (i == j) ? 1.0 : 0.0;
89 values.push_back(builder.getFloatAttr(elemType, val));
93 auto denseAttr = DenseElementsAttr::get(matrixType, values);
94 return arith::ConstantOp::create(builder, loc, matrixType, denseAttr);
99 RankedTensorType matrixType) {
100 assert(matrixType.getRank() == 2 &&
"Expected 2D tensor type");
101 assert(matrixType.getShape()[0] == matrixType.getShape()[1] &&
102 "Expected square matrix type");
104 int64_t size = matrixType.getShape()[0];
105 auto elemType = matrixType.getElementType();
107 SmallVector<Attribute> values;
108 values.reserve(size * size);
109 for (int64_t i = 0; i < size; ++i) {
110 for (int64_t j = 0; j < size; ++j) {
111 double val = (j == size - 1 - i) ? 1.0 : 0.0;
112 values.push_back(builder.getFloatAttr(elemType, val));
116 auto denseAttr = DenseElementsAttr::get(matrixType, values);
117 return arith::ConstantOp::create(builder, loc, matrixType, denseAttr);
123 auto matrixType = cast<RankedTensorType>(matrix.getType());
127 auto PA = impulse::DotOp::create(
128 builder, loc, matrixType, P, matrix, builder.getDenseI64ArrayAttr({}),
129 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({1}),
130 builder.getDenseI64ArrayAttr({0}));
133 return impulse::DotOp::create(
134 builder, loc, matrixType, PA, P, builder.getDenseI64ArrayAttr({}),
135 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({1}),
136 builder.getDenseI64ArrayAttr({0}));
140 Value invMass, Value momentum,
141 RankedTensorType positionType) {
146 auto invMassType = cast<RankedTensorType>(invMass.getType());
148 if (invMassType.getRank() == 1) {
151 impulse::ReshapeOp::create(builder, loc, positionType, invMass);
152 return arith::MulFOp::create(builder, loc, diagMass, momentum);
153 }
else if (invMassType.getRank() == 2) {
155 return impulse::DotOp::create(
156 builder, loc, positionType, momentum, invMass,
157 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
158 builder.getDenseI64ArrayAttr({1}), builder.getDenseI64ArrayAttr({0}));
161 emitError(loc,
"ProbProg: Provided invMass must have rank 1 or 2, got rank " +
162 std::to_string(invMassType.getRank()));
167 Value momentum, Value invMass,
168 RankedTensorType positionType) {
169 auto elemType = positionType.getElementType();
170 auto scalarType = RankedTensorType::get({}, elemType);
172 auto halfConst = arith::ConstantOp::create(
173 builder, loc, scalarType,
174 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.5)));
182 auto pDotV = impulse::DotOp::create(
183 builder, loc, scalarType, momentum, v, builder.getDenseI64ArrayAttr({}),
184 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({0, 1}),
185 builder.getDenseI64ArrayAttr({0, 1}));
187 return arith::MulFOp::create(builder, loc, halfConst, pDotV);
192 RankedTensorType positionType) {
197 auto invMassType = cast<RankedTensorType>(invMass.getType());
198 auto elemType = invMassType.getElementType();
200 if (invMassType.getRank() == 1) {
202 auto sqrtInvMass = math::SqrtOp::create(builder, loc, invMass);
203 auto onesVector = arith::ConstantOp::create(
204 builder, loc, invMassType,
205 DenseElementsAttr::get(invMassType,
206 builder.getFloatAttr(elemType, 1.0)));
207 return arith::DivFOp::create(builder, loc, onesVector, sqrtInvMass);
215 impulse::CholeskyOp::create(builder, loc, invMassType, reversedInvMass,
216 builder.getBoolAttr(
true));
219 auto massMatrixSqrt = impulse::TriangularSolveOp::create(
220 builder, loc, invMassType, massMatrixSqrtInvT, identityMatrix,
221 builder.getBoolAttr(
true),
222 builder.getBoolAttr(
false),
223 builder.getBoolAttr(
false),
225 impulse::TransposeAttr::get(builder.getContext(),
226 impulse::Transpose::TRANSPOSE));
228 return massMatrixSqrt;
232std::pair<Value, Value>
234 Value invMass, Value massMatrixSqrt,
235 RankedTensorType positionType,
bool debugDump) {
236 auto elemType = positionType.getElementType();
237 auto scalarType = RankedTensorType::get({}, elemType);
239 auto zeroConst = arith::ConstantOp::create(
240 builder, loc, scalarType,
241 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
242 auto oneConst = arith::ConstantOp::create(
243 builder, loc, scalarType,
244 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
246 rng =
conditionalDump(builder, loc, rng,
"sampleMomentum: input rng state",
250 auto randomOp = impulse::RandomOp::create(
251 builder, loc, TypeRange{rng.getType(), positionType}, rng, zeroConst,
253 impulse::RngDistributionAttr::get(builder.getContext(),
254 impulse::RngDistribution::NORMAL));
256 auto rngOut = randomOp.getOutputRngState();
257 auto eps = randomOp.getResult();
259 if (!massMatrixSqrt) {
260 return {eps, rngOut};
263 auto massMatrixSqrtType = cast<RankedTensorType>(massMatrixSqrt.getType());
265 if (massMatrixSqrtType.getRank() == 1) {
268 impulse::ReshapeOp::create(builder, loc, positionType, massMatrixSqrt);
269 auto p = arith::MulFOp::create(builder, loc, diagSqrt, eps);
273 auto p = impulse::DotOp::create(
274 builder, loc, positionType, eps, massMatrixSqrt,
275 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
276 builder.getDenseI64ArrayAttr({1}), builder.getDenseI64ArrayAttr({1}));
282 Value position2d, Value fullTrace,
285 cast<RankedTensorType>(ctx.
stepSize.getType()).getElementType();
287 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
288 auto c0 = arith::ConstantOp::create(
289 builder, loc, i64TensorType,
290 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
292 Value result = fullTrace;
293 for (
const auto &info : ctx.
supports) {
294 auto sliceType = RankedTensorType::get({1, info.size}, elemType);
295 auto posOffset = arith::ConstantOp::create(
296 builder, loc, i64TensorType,
297 DenseElementsAttr::get(i64TensorType,
298 builder.getI64IntegerAttr(info.offset)));
299 SmallVector<Value> extractIndices{c0, posOffset};
300 auto slice = impulse::DynamicSliceOp::create(
301 builder, loc, sliceType, position2d, extractIndices,
302 builder.getDenseI64ArrayAttr({1, info.size}));
304 auto traceOffset = arith::ConstantOp::create(
305 builder, loc, i64TensorType,
306 DenseElementsAttr::get(i64TensorType,
307 builder.getI64IntegerAttr(info.traceOffset)));
308 SmallVector<Value> updateIndices{c0, traceOffset};
309 result = impulse::DynamicUpdateSliceOp::create(
310 builder, loc,
traceType, result, slice, updateIndices);
318 cast<RankedTensorType>(ctx.
stepSize.getType()).getElementType();
319 auto positionType2d = RankedTensorType::get({1, ctx.
positionSize}, elemType);
320 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
321 auto c0 = arith::ConstantOp::create(
322 builder, loc, i64TensorType,
323 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
325 auto zeroConst = arith::ConstantOp::create(
326 builder, loc, positionType2d,
327 DenseElementsAttr::get(positionType2d,
328 builder.getFloatAttr(elemType, 0.0)));
329 Value result = zeroConst;
330 for (
const auto &info : ctx.
supports) {
331 auto sliceType = RankedTensorType::get({1, info.size}, elemType);
332 auto traceOffset = arith::ConstantOp::create(
333 builder, loc, i64TensorType,
334 DenseElementsAttr::get(i64TensorType,
335 builder.getI64IntegerAttr(info.traceOffset)));
336 SmallVector<Value> extractIndices{c0, traceOffset};
337 auto slice = impulse::DynamicSliceOp::create(
338 builder, loc, sliceType, fullTrace, extractIndices,
339 builder.getDenseI64ArrayAttr({1, info.size}));
341 auto posOffset = arith::ConstantOp::create(
342 builder, loc, i64TensorType,
343 DenseElementsAttr::get(i64TensorType,
344 builder.getI64IntegerAttr(info.offset)));
345 SmallVector<Value> updateIndices{c0, posOffset};
346 result = impulse::DynamicUpdateSliceOp::create(
347 builder, loc, positionType2d, result, slice, updateIndices);
354 Value position, Value rng,
360 auto gradSeed = arith::ConstantOp::create(
361 builder, loc, scalarType,
362 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
365 auto flatType = RankedTensorType::get({ctx.
positionSize}, elemType);
366 Value autodiffPosition = position;
367 auto autodiffPositionType = positionType;
368 auto autodiffGradType = positionType;
369 if (isCustomLogpdf) {
371 impulse::ReshapeOp::create(builder, loc, flatType, position);
372 autodiffPositionType = flatType;
373 autodiffGradType = flatType;
376 SmallVector<Value> autodiffInputs{autodiffPosition, gradSeed};
377 SmallVector<NamedAttribute> adAttrs{
378 builder.getNamedAttr(
380 builder.getArrayAttr({enzyme::ActivityAttr::get(
381 builder.getContext(), enzyme::Activity::enzyme_active)})),
382 builder.getNamedAttr(
384 builder.getArrayAttr(
385 {enzyme::ActivityAttr::get(builder.getContext(),
386 enzyme::Activity::enzyme_active),
387 enzyme::ActivityAttr::get(builder.getContext(),
388 enzyme::Activity::enzyme_const)})),
392 adAttrs.push_back(attr);
394 auto autodiffOp = enzyme::AutoDiffRegionOp::create(
395 builder, loc, TypeRange{scalarType, rng.getType(), autodiffGradType},
396 autodiffInputs, adAttrs);
398 Block *autodiffBlock = builder.createBlock(&autodiffOp.getBody());
399 autodiffBlock->addArgument(autodiffPositionType, loc);
401 builder.setInsertionPointToStart(autodiffBlock);
402 Value qArg = autodiffBlock->getArgument(0);
404 if (isCustomLogpdf) {
405 SmallVector<Value> callArgs;
406 callArgs.push_back(qArg);
408 auto callOp = func::CallOp::create(builder, loc, ctx.
logpdfFn,
409 TypeRange{scalarType}, callArgs);
410 Value U = arith::NegFOp::create(builder, loc, callOp.getResult(0));
412 enzyme::YieldOp::create(builder, loc, {U, rng});
421 SmallVector<Value> generateInputs;
422 generateInputs.push_back(rng);
425 SmallVector<Type> generateResultTypes;
426 generateResultTypes.push_back(
traceType);
427 generateResultTypes.push_back(scalarType);
431 auto generateOp = impulse::GenerateOp::create(
432 builder, loc, generateResultTypes, ctx.
fn, generateInputs, fullTrace,
436 arith::NegFOp::create(builder, loc, generateOp.getWeight());
437 Value jacobianCorrection =
440 arith::SubFOp::create(builder, loc, negWeight, jacobianCorrection);
442 SmallVector<Value> yieldValues{U, generateOp.getResult(2)};
443 enzyme::YieldOp::create(builder, loc, yieldValues);
446 builder.setInsertionPointAfter(autodiffOp);
448 Value grad = autodiffOp.getResult(2);
449 if (isCustomLogpdf) {
450 grad = impulse::ReshapeOp::create(builder, loc, positionType, grad);
454 autodiffOp.getResult(0),
456 autodiffOp.getResult(1)
463 Value rng, Value direction,
469 auto negStepSize = arith::NegFOp::create(builder, loc, ctx.
stepSize);
470 Value signedStepSize = impulse::SelectOp::create(
471 builder, loc, scalarType, direction, ctx.
stepSize, negStepSize);
473 auto halfConst = arith::ConstantOp::create(
474 builder, loc, scalarType,
475 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.5)));
477 ArrayRef<int64_t> shape = positionType.getShape();
478 auto stepSizeBroadcast =
479 BroadcastOp::create(builder, loc, positionType, signedStepSize,
480 builder.getDenseI64ArrayAttr(shape));
482 arith::MulFOp::create(builder, loc, halfConst, signedStepSize);
483 auto halfStepBroadcast =
484 BroadcastOp::create(builder, loc, positionType, halfStep,
485 builder.getDenseI64ArrayAttr(shape));
489 arith::MulFOp::create(builder, loc, halfStepBroadcast, leaf.
grad);
490 Value pHalf = arith::SubFOp::create(builder, loc, leaf.
p, deltaP1);
495 auto deltaQ = arith::MulFOp::create(builder, loc, stepSizeBroadcast, v);
496 Value qNew = arith::AddFOp::create(builder, loc, leaf.
q, deltaQ);
503 arith::MulFOp::create(builder, loc, halfStepBroadcast, gradResult.grad);
504 Value pNew = arith::SubFOp::create(builder, loc, pHalf, deltaP2);
506 return {qNew, pNew, gradResult.grad, gradResult.U, gradResult.rng};
510 Value pRight, Value pSum,
const NUTSContext &ctx) {
515 auto zeroConst = arith::ConstantOp::create(
516 builder, loc, scalarType,
517 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
518 auto halfConst = arith::ConstantOp::create(
519 builder, loc, scalarType,
520 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.5)));
528 auto halfBroadcast = BroadcastOp::create(
529 builder, loc, positionType, halfConst,
530 builder.getDenseI64ArrayAttr(positionType.getShape()));
532 auto pLeftPlusPRight = arith::AddFOp::create(builder, loc, pLeft, pRight);
534 arith::MulFOp::create(builder, loc, halfBroadcast, pLeftPlusPRight);
535 Value pSumCentered = arith::SubFOp::create(builder, loc, pSum, halfSum);
537 auto leftAngle = impulse::DotOp::create(
538 builder, loc, scalarType, vLeft, pSumCentered,
539 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
540 builder.getDenseI64ArrayAttr({0, 1}),
541 builder.getDenseI64ArrayAttr({0, 1}));
542 auto rightAngle = impulse::DotOp::create(
543 builder, loc, scalarType, vRight, pSumCentered,
544 builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}),
545 builder.getDenseI64ArrayAttr({0, 1}),
546 builder.getDenseI64ArrayAttr({0, 1}));
549 auto leftNeg = arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::OLE,
550 leftAngle, zeroConst);
551 auto rightNeg = arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::OLE,
552 rightAngle, zeroConst);
554 return arith::OrIOp::create(builder, loc, leftNeg, rightNeg);
561 arith::SubFOp::create(builder, loc, newWeight, currentWeight);
562 return impulse::LogisticOp::create(builder, loc, weightDiff.getType(),
567 Value currentWeight, Value newWeight,
568 Value turning, Value diverging) {
569 auto resultType = cast<RankedTensorType>(currentWeight.getType());
570 auto elemType = resultType.getElementType();
572 auto zeroConst = arith::ConstantOp::create(
573 builder, loc, resultType,
574 DenseElementsAttr::get(resultType, builder.getFloatAttr(elemType, 0.0)));
575 auto oneConst = arith::ConstantOp::create(
576 builder, loc, resultType,
577 DenseElementsAttr::get(resultType, builder.getFloatAttr(elemType, 1.0)));
580 arith::SubFOp::create(builder, loc, newWeight, currentWeight);
581 Value expDiff = math::ExpOp::create(builder, loc, weightDiff);
583 arith::MinimumFOp::create(builder, loc, oneConst, expDiff);
584 Value turningOrDiverging =
585 arith::OrIOp::create(builder, loc, turning, diverging);
586 return arith::SelectOp::create(builder, loc, resultType, turningOrDiverging,
587 zeroConst, clippedProb);
593 Value direction, Value rng,
bool biased,
598 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
600 auto zeroConst = arith::ConstantOp::create(
601 builder, loc, scalarType,
602 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
603 auto oneConst = arith::ConstantOp::create(
604 builder, loc, scalarType,
605 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
607 auto qLeft = impulse::SelectOp::create(builder, loc, positionType, direction,
609 auto pLeft = impulse::SelectOp::create(builder, loc, positionType, direction,
611 auto gradLeft = impulse::SelectOp::create(
613 auto qRight = impulse::SelectOp::create(builder, loc, positionType, direction,
615 auto pRight = impulse::SelectOp::create(builder, loc, positionType, direction,
618 impulse::SelectOp::create(builder, loc, positionType, direction,
621 auto combinedWeight = impulse::LogAddExpOp::create(
625 Value transitionProb;
635 auto randomOp = impulse::RandomOp::create(
636 builder, loc, TypeRange{rng.getType(), scalarType}, rng, zeroConst,
638 impulse::RngDistributionAttr::get(builder.getContext(),
639 impulse::RngDistribution::UNIFORM));
640 auto rngOut = randomOp.getOutputRngState();
641 auto uniformSample = randomOp.getResult();
643 auto acceptNew = arith::CmpFOp::create(
644 builder, loc, arith::CmpFPredicate::OLT, uniformSample, transitionProb);
647 impulse::SelectOp::create(builder, loc, positionType, acceptNew,
650 impulse::SelectOp::create(builder, loc, positionType, acceptNew,
652 auto UProposal = impulse::SelectOp::create(
654 auto HProposal = impulse::SelectOp::create(
657 auto oneI64 = arith::ConstantOp::create(
658 builder, loc, i64TensorType,
659 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
660 auto combinedDepth = arith::AddIOp::create(builder, loc, tree.
depth, oneI64);
662 Value combinedTurning;
665 builder, loc, pLeft, pRight,
666 arith::AddFOp::create(builder, loc, tree.
p_sum, subTree.
p_sum), ctx);
668 arith::OrIOp::create(builder, loc, subTree.
turning, turningCheck);
670 combinedTurning = tree.
turning;
673 auto combinedDiverging =
675 auto sumAcceptProbs = arith::AddFOp::create(
677 auto numProposals = arith::AddIOp::create(builder, loc, tree.
num_proposals,
679 auto pSum = arith::AddFOp::create(builder, loc, tree.
p_sum, subTree.
p_sum);
681 return {.q_left = qLeft,
683 .grad_left = gradLeft,
686 .grad_right = gradRight,
687 .q_proposal = qProposal,
688 .grad_proposal = gradProposal,
689 .U_proposal = UProposal,
690 .H_proposal = HProposal,
691 .depth = combinedDepth,
692 .weight = combinedWeight,
693 .turning = combinedTurning,
694 .diverging = combinedDiverging,
695 .sum_accept_probs = sumAcceptProbs,
696 .num_proposals = numProposals,
708 auto initSplit = impulse::RandomSplitOp::create(
709 builder, loc, TypeRange{rng.getType(), rng.getType()}, rng);
710 auto kernelSplit = impulse::RandomSplitOp::create(
711 builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
712 initSplit.getResult(0));
713 auto rngForSampleKernel = kernelSplit.getResult(0);
714 auto rngForAutodiff = kernelSplit.getResult(1);
720 q0 = initialPosition;
721 auto flatType = RankedTensorType::get({ctx.
positionSize}, elemType);
722 auto q0Flat = impulse::ReshapeOp::create(builder, loc, flatType, q0);
723 SmallVector<Value> callArgs;
724 callArgs.push_back(q0Flat);
726 auto callOp = func::CallOp::create(builder, loc, ctx.
logpdfFn,
727 TypeRange{scalarType}, callArgs);
728 U0 = arith::NegFOp::create(builder, loc, callOp.getResult(0));
734 auto q0_constrained =
744 SmallVector<Value> generateInputsInit;
745 generateInputsInit.push_back(rngForAutodiff);
748 SmallVector<Type> generateResultTypesInit;
749 generateResultTypesInit.push_back(fullTraceType);
750 generateResultTypesInit.push_back(scalarType);
754 auto generateOpInit = impulse::GenerateOp::create(
755 builder, loc, generateResultTypesInit, ctx.
fn, generateInputsInit,
757 builder.getStringAttr(
""));
759 auto weight0 = generateOpInit.getWeight();
760 auto negWeight0 = arith::NegFOp::create(builder, loc, weight0);
763 U0 = arith::SubFOp::create(builder, loc, negWeight0, jacobian0);
768 auto flatType = RankedTensorType::get({ctx.
positionSize}, elemType);
769 Value autodiffQ0 = q0;
770 auto autodiffQ0Type = positionType;
771 auto autodiffGradType = positionType;
772 if (isCustomLogpdf) {
773 autodiffQ0 = impulse::ReshapeOp::create(builder, loc, flatType, q0);
774 autodiffQ0Type = flatType;
775 autodiffGradType = flatType;
778 auto gradSeedInit = arith::ConstantOp::create(
779 builder, loc, scalarType,
780 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
781 SmallVector<Value> autodiffInputs{autodiffQ0, gradSeedInit};
782 SmallVector<NamedAttribute> adInitAttrs{
783 builder.getNamedAttr(
785 builder.getArrayAttr({enzyme::ActivityAttr::get(
786 builder.getContext(), enzyme::Activity::enzyme_active)})),
787 builder.getNamedAttr(
789 builder.getArrayAttr(
790 {enzyme::ActivityAttr::get(builder.getContext(),
791 enzyme::Activity::enzyme_active),
792 enzyme::ActivityAttr::get(builder.getContext(),
793 enzyme::Activity::enzyme_const)})),
797 adInitAttrs.push_back(attr);
799 auto autodiffInit = enzyme::AutoDiffRegionOp::create(
801 TypeRange{scalarType, rngForAutodiff.getType(), autodiffGradType},
802 autodiffInputs, adInitAttrs);
804 Block *autodiffInitBlock = builder.createBlock(&autodiffInit.getBody());
805 autodiffInitBlock->addArgument(autodiffQ0Type, loc);
807 builder.setInsertionPointToStart(autodiffInitBlock);
808 auto q0Arg = autodiffInitBlock->getArgument(0);
810 if (isCustomLogpdf) {
811 SmallVector<Value> callArgs;
812 callArgs.push_back(q0Arg);
814 auto callOpInner = func::CallOp::create(builder, loc, ctx.
logpdfFn,
815 TypeRange{scalarType}, callArgs);
817 arith::NegFOp::create(builder, loc, callOpInner.getResult(0));
818 enzyme::YieldOp::create(builder, loc, {U0_init, rngForAutodiff});
822 auto q0Arg_constrained =
827 SmallVector<Value> generateInputsInner;
828 generateInputsInner.push_back(rngForAutodiff);
831 SmallVector<Type> generateResultTypesInner;
832 generateResultTypesInner.push_back(fullTraceType);
833 generateResultTypesInner.push_back(scalarType);
837 auto generateOpInner = impulse::GenerateOp::create(
838 builder, loc, generateResultTypesInner, ctx.
fn, generateInputsInner,
840 builder.getStringAttr(
""));
843 arith::NegFOp::create(builder, loc, generateOpInner.getWeight());
847 arith::SubFOp::create(builder, loc, negWeightInit, jacobianInit);
849 SmallVector<Value> yieldValues{U0_init, generateOpInner.getResult(2)};
850 enzyme::YieldOp::create(builder, loc, yieldValues);
852 builder.setInsertionPointAfter(autodiffInit);
854 Value grad0 = autodiffInit.getResult(2);
855 if (isCustomLogpdf) {
856 grad0 = impulse::ReshapeOp::create(builder, loc, positionType, grad0);
859 return {q0, U0, grad0, rngForSampleKernel};
863 Value grad, Value U, Value rng,
868 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
869 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
876 auto numStepsF64 = math::CeilOp::create(builder, loc, trajDivStep);
878 arith::FPToSIOp::create(builder, loc, i64TensorType, numStepsF64);
879 auto adjustedStepSize =
885 auto sampleKernelSplit = impulse::RandomSplitOp::create(
886 builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
888 auto rngNext = sampleKernelSplit.getResult(0);
889 auto rngMomentum = sampleKernelSplit.getResult(1);
890 auto rngTransition = sampleKernelSplit.getResult(2);
893 Value rngForMomentum = rngMomentum;
895 auto momSplit = impulse::RandomSplitOp::create(
896 builder, loc, TypeRange{rng.getType(), rng.getType()}, rngMomentum);
897 rngForMomentum = momSplit.getResult(0);
899 auto [p0, rngAfterMomentum] =
907 auto H0 = arith::AddFOp::create(builder, loc, U, K0);
910 auto direction = arith::ConstantOp::create(
911 builder, loc, i1TensorType,
912 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(
true)));
914 auto c0 = arith::ConstantOp::create(
915 builder, loc, i64TensorType,
916 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
917 auto c1 = arith::ConstantOp::create(
918 builder, loc, i64TensorType,
919 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
922 SmallVector<Type> loopResultTypes = {positionType, positionType, positionType,
923 scalarType, rngTransition.getType()};
925 impulse::ForOp::create(builder, loc, loopResultTypes, c0, numSteps, c1,
926 ValueRange{q, p0, grad, U, rngTransition});
928 Block *loopBody = builder.createBlock(&forLoopOp.getRegion());
929 loopBody->addArgument(i64TensorType, loc);
930 loopBody->addArgument(positionType, loc);
931 loopBody->addArgument(positionType, loc);
932 loopBody->addArgument(positionType, loc);
933 loopBody->addArgument(scalarType, loc);
934 loopBody->addArgument(rngTransition.getType(), loc);
936 builder.setInsertionPointToStart(loopBody);
937 auto qLoop = loopBody->getArgument(1);
938 auto pLoop = loopBody->getArgument(2);
939 auto gradLoop = loopBody->getArgument(3);
940 auto rngLoop = loopBody->getArgument(5);
947 impulse::YieldOp::create(
948 builder, loc, ValueRange{step.q, step.p, step.grad, step.U, step.rng});
950 builder.setInsertionPointAfter(forLoopOp);
951 auto qProposal = forLoopOp.getResult(0);
952 auto pProposal = forLoopOp.getResult(1);
953 auto gradProposal = forLoopOp.getResult(2);
954 auto UProposal = forLoopOp.getResult(3);
955 auto rngAfterLeapfrog = forLoopOp.getResult(4);
960 auto H1 = arith::AddFOp::create(builder, loc, UProposal, K1);
963 auto zeroConst = arith::ConstantOp::create(
964 builder, loc, scalarType,
965 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
966 auto oneConst = arith::ConstantOp::create(
967 builder, loc, scalarType,
968 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
971 auto dH = arith::SubFOp::create(builder, loc, H0, H1);
972 auto expDH = math::ExpOp::create(builder, loc, dH);
973 auto accProb = arith::MinimumFOp::create(builder, loc, oneConst, expDH);
976 auto randomOp = impulse::RandomOp::create(
977 builder, loc, TypeRange{rngAfterLeapfrog.getType(), scalarType},
978 rngAfterLeapfrog, zeroConst, oneConst,
979 impulse::RngDistributionAttr::get(builder.getContext(),
980 impulse::RngDistribution::UNIFORM));
981 auto randUniform = randomOp.getResult();
984 auto acceptedTensor = arith::CmpFOp::create(
985 builder, loc, arith::CmpFPredicate::OLT, randUniform, accProb);
988 auto qFinal = impulse::SelectOp::create(builder, loc, positionType,
989 acceptedTensor, qProposal, q);
990 auto gradFinal = impulse::SelectOp::create(
991 builder, loc, positionType, acceptedTensor, gradProposal, grad);
992 auto UFinal = impulse::SelectOp::create(builder, loc, scalarType,
993 acceptedTensor, UProposal, U);
995 return {qFinal, gradFinal, UFinal, acceptedTensor, accProb, rngNext};
999 Value grad, Value U, Value rng,
1004 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1005 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
1008 auto sampleKernelSplit = impulse::RandomSplitOp::create(
1009 builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
1011 auto rngNext = sampleKernelSplit.getResult(0);
1012 auto rngMomentum = sampleKernelSplit.getResult(1);
1013 auto rngTree = sampleKernelSplit.getResult(2);
1016 Value rngForMomentum = rngMomentum;
1018 auto momSplit = impulse::RandomSplitOp::create(
1019 builder, loc, TypeRange{rng.getType(), rng.getType()}, rngMomentum);
1020 rngForMomentum = momSplit.getResult(0);
1022 auto [p0, rngAfterMomentum] =
1030 auto H0 = arith::AddFOp::create(builder, loc, U, K0);
1033 auto iterCtx = ctx.
withH0(H0);
1035 auto zeroI64 = arith::ConstantOp::create(
1036 builder, loc, i64TensorType,
1037 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1038 auto falseConst = arith::ConstantOp::create(
1039 builder, loc, i1TensorType,
1040 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(
false)));
1041 auto zeroConst = arith::ConstantOp::create(
1042 builder, loc, scalarType,
1043 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
1052 .grad_proposal = grad,
1056 .weight = zeroConst,
1057 .turning = falseConst,
1058 .diverging = falseConst,
1059 .sum_accept_probs = zeroConst,
1060 .num_proposals = zeroI64,
1065 auto finalTree =
buildTree(builder, loc, initialTree, iterCtx, debugDump);
1068 auto trueConst = arith::ConstantOp::create(
1069 builder, loc, i1TensorType,
1070 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(
true)));
1073 auto oneI64 = arith::ConstantOp::create(
1074 builder, loc, i64TensorType,
1075 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1076 auto numProposalsClamped =
1077 arith::MaxSIOp::create(builder, loc, finalTree.num_proposals, oneI64);
1078 auto numProposalsFloat =
1079 arith::SIToFPOp::create(builder, loc, scalarType, numProposalsClamped);
1080 auto meanAcceptProb = arith::DivFOp::create(
1081 builder, loc, finalTree.sum_accept_probs, numProposalsFloat);
1083 return {finalTree.q_proposal, finalTree.grad_proposal,
1084 finalTree.U_proposal, trueConst,
1085 meanAcceptProb, rngNext};
1094 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1095 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
1097 auto oneConst = arith::ConstantOp::create(
1098 builder, loc, scalarType,
1099 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
1100 auto zeroI64 = arith::ConstantOp::create(
1101 builder, loc, i64TensorType,
1102 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1103 auto oneI64 = arith::ConstantOp::create(
1104 builder, loc, i64TensorType,
1105 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1106 auto falseConst = arith::ConstantOp::create(
1107 builder, loc, i1TensorType,
1108 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(
false)));
1115 auto gradNew = leap.
grad;
1117 auto rngOut = leap.
rng;
1121 auto HNew = arith::AddFOp::create(builder, loc, UNew, KNew);
1122 Value deltaH = arith::SubFOp::create(builder, loc, HNew, ctx.
H0);
1125 auto isNan = arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::UNE,
1127 auto infConst = arith::ConstantOp::create(
1128 builder, loc, scalarType,
1129 DenseElementsAttr::get(
1130 scalarType, builder.getFloatAttr(
1131 elemType, std::numeric_limits<double>::infinity())));
1132 deltaH = arith::SelectOp::create(builder, loc, scalarType, isNan, infConst,
1135 auto treeWeight = arith::NegFOp::create(builder, loc, deltaH);
1138 auto diverging = arith::CmpFOp::create(
1139 builder, loc, arith::CmpFPredicate::OGT, deltaH, ctx.
maxDeltaEnergy);
1141 auto negDeltaH = arith::NegFOp::create(builder, loc, deltaH);
1142 auto expNegDelta = math::ExpOp::create(builder, loc, negDeltaH);
1144 arith::MinimumFOp::create(builder, loc, oneConst, expNegDelta);
1146 return {.q_left = qNew,
1148 .grad_left = gradNew,
1151 .grad_right = gradNew,
1153 .grad_proposal = gradNew,
1157 .weight = treeWeight,
1158 .turning = falseConst,
1159 .diverging = diverging,
1160 .sum_accept_probs = acceptProb,
1161 .num_proposals = oneI64,
1172 auto leafQ = impulse::SelectOp::create(builder, loc, positionType, direction,
1174 auto leafP = impulse::SelectOp::create(builder, loc, positionType, direction,
1176 auto leafGrad = impulse::SelectOp::create(
1178 return {leafQ, leafP, leafGrad};
1184 Value direction, Value pCkpts, Value pSumCkpts,
1186 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
1187 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1188 auto pCkptsType = cast<RankedTensorType>(pCkpts.getType());
1189 auto trueConst = arith::ConstantOp::create(
1190 builder, loc, i1TensorType,
1191 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(
true)));
1192 auto oneI64 = arith::ConstantOp::create(
1193 builder, loc, i64TensorType,
1194 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1195 auto zeroI64 = arith::ConstantOp::create(
1196 builder, loc, i64TensorType,
1197 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1200 auto maxNumProposals =
1201 arith::ShLIOp::create(builder, loc, oneI64, initialTree.
depth);
1203 SmallVector<Type> whileTypes = initialTree.
getTypes();
1204 whileTypes.push_back(pCkptsType);
1205 whileTypes.push_back(pCkptsType);
1206 whileTypes.push_back(i64TensorType);
1208 SmallVector<Value> whileInitVals = initialTree.
toValues();
1209 whileInitVals[15] = zeroI64;
1210 whileInitVals.push_back(pCkpts);
1211 whileInitVals.push_back(pSumCkpts);
1212 whileInitVals.push_back(zeroI64);
1215 impulse::WhileOp::create(builder, loc, whileTypes, whileInitVals);
1218 Block *condBlock = builder.createBlock(&whileOp.getConditionRegion());
1219 for (
auto type : whileTypes)
1220 condBlock->addArgument(type, loc);
1222 builder.setInsertionPointToStart(condBlock);
1223 SmallVector<Value> condTreeArgs(condBlock->getArguments().begin(),
1224 condBlock->getArguments().begin() + 18);
1227 auto numProposalsCheck =
1228 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::slt,
1231 arith::XOrIOp::create(builder, loc, condTree.
turning, trueConst);
1233 arith::XOrIOp::create(builder, loc, condTree.
diverging, trueConst);
1234 auto continueCond = arith::AndIOp::create(
1236 arith::AndIOp::create(builder, loc, numProposalsCheck, notTurning),
1240 impulse::YieldOp::create(builder, loc, ValueRange{continueCond});
1242 Block *bodyBlock = builder.createBlock(&whileOp.getBodyRegion());
1243 for (
auto type : whileTypes)
1244 bodyBlock->addArgument(type, loc);
1246 builder.setInsertionPointToStart(bodyBlock);
1248 SmallVector<Value> bodyTreeArgs(bodyBlock->getArguments().begin(),
1249 bodyBlock->getArguments().begin() + 18);
1251 auto bodyPCkpts = bodyBlock->getArgument(18);
1252 auto bodyPSumCkpts = bodyBlock->getArgument(19);
1253 auto bodyLeafIdx = bodyBlock->getArgument(20);
1259 auto rngSplit2 = impulse::RandomSplitOp::create(
1260 builder, loc, TypeRange{bodyTree.
rng.getType(), bodyTree.
rng.getType()},
1262 auto rngNext = rngSplit2.getResult(0);
1263 auto rngCombine = rngSplit2.getResult(1);
1270 auto isFirstLeaf = arith::CmpIOp::create(
1271 builder, loc, arith::CmpIPredicate::eq, bodyTree.
num_proposals, zeroI64);
1273 SmallVector<Type> treeTypes = newLeaf.
getTypes();
1274 auto ifOp = impulse::IfOp::create(builder, loc, treeTypes, isFirstLeaf);
1276 Block *trueBranch = builder.createBlock(&ifOp.getTrueBranch());
1277 builder.setInsertionPointToStart(trueBranch);
1278 impulse::YieldOp::create(builder, loc, newLeaf.
toValues());
1281 Block *falseBranch = builder.createBlock(&ifOp.getFalseBranch());
1282 builder.setInsertionPointToStart(falseBranch);
1284 combineTrees(builder, loc, bodyTree, newLeaf, direction, rngCombine,
1286 impulse::YieldOp::create(builder, loc, combinedTree.
toValues());
1289 builder.setInsertionPointAfter(ifOp);
1291 SmallVector<Value>(ifOp.getResults().begin(), ifOp.getResults().end()));
1292 updatedTree.
rng = rngNext;
1295 auto [ckptIdxMin, ckptIdxMax] =
1298 builder, loc, bodyLeafIdx, ckptIdxMax, newLeaf.
p_right, updatedTree.
p_sum,
1299 bodyPCkpts, bodyPSumCkpts, ctx, debugDump);
1301 builder, loc, newLeaf.
p_right, updatedTree.
p_sum, updatedPCkpts,
1302 updatedPSumCkpts, ckptIdxMin, ckptIdxMax, ctx, debugDump);
1305 impulse::SelectOp::create(builder, loc, i1TensorType, isFirstLeaf,
1306 newLeaf.
turning, iterativeTurning);
1308 auto nextLeafIdx = arith::AddIOp::create(builder, loc, bodyLeafIdx, oneI64);
1310 SmallVector<Value> yieldVals = updatedTree.
toValues();
1311 yieldVals.push_back(updatedPCkpts);
1312 yieldVals.push_back(updatedPSumCkpts);
1313 yieldVals.push_back(nextLeafIdx);
1314 impulse::YieldOp::create(builder, loc, yieldVals);
1316 builder.setInsertionPointAfter(whileOp);
1318 SmallVector<Value> resultTreeArgs(whileOp.getResults().begin(),
1319 whileOp.getResults().begin() + 18);
1321 auto resultPCkpts = whileOp.getResult(18);
1322 auto resultPSumCkpts = whileOp.getResult(19);
1328 return {resultTree, resultPCkpts, resultPSumCkpts};
1333 Value direction, Value pCkpts,
1336 auto rngSplit2 = impulse::RandomSplitOp::create(
1337 builder, loc, TypeRange{tree.
rng.getType(), tree.
rng.getType()},
1339 auto rngSubtree = rngSplit2.getResult(0);
1340 auto rngTransition = rngSplit2.getResult(1);
1343 subTreeInit.
rng = rngSubtree;
1345 builder, loc, subTreeInit, direction, pCkpts, pSumCkpts, ctx, debugDump);
1349 combineTrees(builder, loc, tree, subtreeResult.tree, direction,
1350 rngTransition,
true, ctx);
1352 return {combinedTree, subtreeResult.pCkpts, subtreeResult.pSumCkpts};
1359 cast<RankedTensorType>(ctx.
stepSize.getType()).getElementType();
1360 auto F64TensorType = RankedTensorType::get({}, elemType);
1361 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1362 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
1364 auto trueConst = arith::ConstantOp::create(
1365 builder, loc, i1TensorType,
1366 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(
true)));
1367 auto halfConst = arith::ConstantOp::create(
1368 builder, loc, F64TensorType,
1369 DenseElementsAttr::get(F64TensorType, builder.getF64FloatAttr(0.5)));
1370 auto zeroConst = arith::ConstantOp::create(
1371 builder, loc, F64TensorType,
1372 DenseElementsAttr::get(F64TensorType, builder.getF64FloatAttr(0.0)));
1373 auto oneConst = arith::ConstantOp::create(
1374 builder, loc, F64TensorType,
1375 DenseElementsAttr::get(F64TensorType, builder.getF64FloatAttr(1.0)));
1377 auto maxTreeDepth = arith::ConstantOp::create(
1378 builder, loc, i64TensorType,
1379 DenseElementsAttr::get(i64TensorType,
1382 auto checkpointType =
1385 SmallVector<Type> whileTypes = initialTree.
getTypes();
1386 SmallVector<Value> whileInitVals = initialTree.
toValues();
1389 impulse::WhileOp::create(builder, loc, whileTypes, whileInitVals);
1392 Block *condBlock = builder.createBlock(&whileOp.getConditionRegion());
1393 for (
auto type : whileTypes)
1394 condBlock->addArgument(type, loc);
1396 builder.setInsertionPointToStart(condBlock);
1398 SmallVector<Value> condArgs(condBlock->getArguments().begin(),
1399 condBlock->getArguments().end());
1402 auto depthCheck = arith::CmpIOp::create(
1403 builder, loc, arith::CmpIPredicate::slt, condTree.
depth, maxTreeDepth);
1405 arith::XOrIOp::create(builder, loc, condTree.
turning, trueConst);
1407 arith::XOrIOp::create(builder, loc, condTree.
diverging, trueConst);
1410 auto continueCond = arith::AndIOp::create(
1411 builder, loc, arith::AndIOp::create(builder, loc, depthCheck, notTurning),
1414 impulse::YieldOp::create(builder, loc, ValueRange{continueCond});
1416 Block *bodyBlock = builder.createBlock(&whileOp.getBodyRegion());
1417 for (
auto type : whileTypes)
1418 bodyBlock->addArgument(type, loc);
1420 builder.setInsertionPointToStart(bodyBlock);
1422 SmallVector<Value> bodyArgs(bodyBlock->getArguments().begin(),
1423 bodyBlock->getArguments().end());
1427 auto zeroCkpts = arith::ConstantOp::create(
1428 builder, loc, checkpointType,
1429 DenseElementsAttr::get(checkpointType, builder.getF64FloatAttr(0.0)));
1430 Value bodyPCkpts = zeroCkpts;
1431 Value bodyPSumCkpts = zeroCkpts;
1433 auto rngSplit3 = impulse::RandomSplitOp::create(
1435 TypeRange{bodyTree.
rng.getType(), bodyTree.
rng.getType(),
1436 bodyTree.
rng.getType()},
1438 auto rngNext = rngSplit3.getResult(0);
1439 auto rngDir = rngSplit3.getResult(1);
1440 auto rngDbl = rngSplit3.getResult(2);
1442 auto directionRandom = impulse::RandomOp::create(
1443 builder, loc, TypeRange{rngDir.getType(), F64TensorType}, rngDir,
1444 zeroConst, oneConst,
1445 impulse::RngDistributionAttr::get(builder.getContext(),
1446 impulse::RngDistribution::UNIFORM));
1448 arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::OLT,
1449 directionRandom.getResult(), halfConst);
1453 treeToDouble.
rng = rngDbl;
1454 auto doubleResult =
doubleTree(builder, loc, treeToDouble, direction,
1455 bodyPCkpts, bodyPSumCkpts, ctx, debugDump);
1458 treeToYield.
rng = rngNext;
1459 impulse::YieldOp::create(builder, loc, treeToYield.
toValues());
1461 builder.setInsertionPointAfter(whileOp);
1463 SmallVector<Value> results(whileOp.getResults().begin(),
1464 whileOp.getResults().end());
1471 auto i64TensorType = cast<RankedTensorType>(leafIdx.getType());
1473 auto oneConst = arith::ConstantOp::create(
1474 builder, loc, i64TensorType,
1475 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1478 auto shiftedIdx = arith::ShRUIOp::create(builder, loc, leafIdx, oneConst);
1480 impulse::PopcountOp::create(builder, loc, i64TensorType, shiftedIdx);
1483 auto leafIdxPlusOne = arith::AddIOp::create(builder, loc, leafIdx, oneConst);
1485 auto minusOneConst = arith::ConstantOp::create(
1486 builder, loc, i64TensorType,
1487 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(-1)));
1488 auto notLeafIdx = arith::XOrIOp::create(builder, loc, leafIdx, minusOneConst);
1491 arith::AndIOp::create(builder, loc, notLeafIdx, leafIdxPlusOne);
1492 Value andMinusOne = arith::SubIOp::create(builder, loc, andResult, oneConst);
1494 impulse::PopcountOp::create(builder, loc, i64TensorType, andMinusOne);
1497 Value idxMaxMinusNumSubtrees =
1498 arith::SubIOp::create(builder, loc, idxMax, numSubtrees);
1500 arith::AddIOp::create(builder, loc, idxMaxMinusNumSubtrees, oneConst);
1502 return {idxMin, idxMax};
1506 Value pSum, Value pCkpts, Value pSumCkpts,
1507 Value idxMin, Value idxMax,
1510 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1511 auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
1513 auto falseConst = arith::ConstantOp::create(
1514 builder, loc, i1TensorType,
1515 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(
false)));
1516 auto oneI64 = arith::ConstantOp::create(
1517 builder, loc, i64TensorType,
1518 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1521 SmallVector<Type> whileTypes = {i64TensorType, i1TensorType};
1522 SmallVector<Value> whileInitVals = {idxMax, falseConst};
1525 impulse::WhileOp::create(builder, loc, whileTypes, whileInitVals);
1526 Block *condBlock = builder.createBlock(&whileOp.getConditionRegion());
1527 condBlock->addArgument(i64TensorType, loc);
1528 condBlock->addArgument(i1TensorType, loc);
1529 builder.setInsertionPointToStart(condBlock);
1531 Value iCond = condBlock->getArgument(0);
1532 Value turningCond = condBlock->getArgument(1);
1534 auto iGeMin = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sge,
1536 auto trueConst = arith::ConstantOp::create(
1537 builder, loc, i1TensorType,
1538 DenseElementsAttr::get(i1TensorType, builder.getBoolAttr(
true)));
1539 auto notTurning = arith::XOrIOp::create(builder, loc, turningCond, trueConst);
1540 auto continueLoop = arith::AndIOp::create(builder, loc, iGeMin, notTurning);
1542 impulse::YieldOp::create(builder, loc, ValueRange{continueLoop.getResult()});
1544 Block *bodyBlock = builder.createBlock(&whileOp.getBodyRegion());
1545 bodyBlock->addArgument(i64TensorType, loc);
1546 bodyBlock->addArgument(i1TensorType, loc);
1547 builder.setInsertionPointToStart(bodyBlock);
1548 Value iBody = bodyBlock->getArgument(0);
1550 auto zeroI64 = arith::ConstantOp::create(
1551 builder, loc, i64TensorType,
1552 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1553 Value pLeft = impulse::DynamicSliceOp::create(
1554 builder, loc, positionType, pCkpts, ValueRange{iBody, zeroI64},
1556 Value pSumCkptI = impulse::DynamicSliceOp::create(
1557 builder, loc, positionType, pSumCkpts, ValueRange{iBody, zeroI64},
1561 auto pSumMinusCkpt = arith::SubFOp::create(builder, loc, pSum, pSumCkptI);
1562 Value subtreePSum = arith::AddFOp::create(builder, loc, pSumMinusCkpt, pLeft);
1565 Value turningAtCkpt =
checkTurning(builder, loc, pLeft, p, subtreePSum, ctx);
1567 Value iNext = arith::SubIOp::create(builder, loc, iBody, oneI64);
1568 impulse::YieldOp::create(builder, loc, ValueRange{iNext, turningAtCkpt});
1570 builder.setInsertionPointAfter(whileOp);
1571 return whileOp.getResult(1);
1574std::pair<Value, Value>
1576 Value ckptIdxMax, Value p, Value pSum, Value pCkpts,
1579 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1580 auto oneI64 = arith::ConstantOp::create(
1581 builder, loc, i64TensorType,
1582 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1583 auto zeroI64 = arith::ConstantOp::create(
1584 builder, loc, i64TensorType,
1585 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1587 Value leafIdxBit0 = arith::AndIOp::create(builder, loc, leafIdx, oneI64);
1588 Value isEven = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
1589 leafIdxBit0, zeroI64);
1591 auto pCkptsType = cast<RankedTensorType>(pCkpts.getType());
1594 SmallVector<Type> ifResultTypes = {pCkptsType, pCkptsType};
1595 auto ifOp = impulse::IfOp::create(builder, loc, ifResultTypes, isEven);
1597 Block *trueBranch = builder.createBlock(&ifOp.getTrueBranch());
1598 builder.setInsertionPointToStart(trueBranch);
1600 auto updatedPCkpts = impulse::DynamicUpdateSliceOp::create(
1601 builder, loc, pCkptsType, pCkpts, p, ValueRange{ckptIdxMax, zeroI64});
1602 auto updatedPSumCkpts = impulse::DynamicUpdateSliceOp::create(
1603 builder, loc, pCkptsType, pSumCkpts, pSum,
1604 ValueRange{ckptIdxMax, zeroI64});
1606 impulse::YieldOp::create(builder, loc,
1607 ValueRange{updatedPCkpts, updatedPSumCkpts});
1610 Block *falseBranch = builder.createBlock(&ifOp.getFalseBranch());
1611 builder.setInsertionPointToStart(falseBranch);
1613 impulse::YieldOp::create(builder, loc, ValueRange{pCkpts, pSumCkpts});
1616 builder.setInsertionPointAfter(ifOp);
1618 Value finalPCkpts = ifOp.getResult(0);
1619 Value finalPSumCkpts = ifOp.getResult(1);
1621 return {finalPCkpts, finalPSumCkpts};
1626 auto stepSizeType = cast<RankedTensorType>(stepSize.getType());
1627 auto elemType = stepSizeType.getElementType();
1628 auto scalarType = RankedTensorType::get({}, elemType);
1629 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1632 auto log10Const = arith::ConstantOp::create(
1633 builder, loc, scalarType,
1634 DenseElementsAttr::get(scalarType,
1635 builder.getFloatAttr(elemType, std::log(10.0))));
1636 Value logStepSize = math::LogOp::create(builder, loc, stepSize);
1638 arith::AddFOp::create(builder, loc, log10Const, logStepSize);
1640 auto zeroConst = arith::ConstantOp::create(
1641 builder, loc, scalarType,
1642 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
1643 auto zeroI64 = arith::ConstantOp::create(
1644 builder, loc, i64TensorType,
1645 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1648 .log_step_size = zeroConst,
1649 .log_step_size_avg = zeroConst,
1650 .gradient_avg = zeroConst,
1651 .step_count = zeroI64,
1652 .prox_center = proxCenter,
1666 auto acceptProbType = cast<RankedTensorType>(acceptProb.getType());
1667 auto elemType = acceptProbType.getElementType();
1668 auto scalarType = RankedTensorType::get({}, elemType);
1669 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1672 auto oneI64 = arith::ConstantOp::create(
1673 builder, loc, i64TensorType,
1674 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1675 Value tNew = arith::AddIOp::create(builder, loc, state.
step_count, oneI64);
1676 Value tFloat = arith::SIToFPOp::create(builder, loc, scalarType, tNew);
1679 auto t0Const = arith::ConstantOp::create(
1680 builder, loc, scalarType,
1681 DenseElementsAttr::get(scalarType,
1682 builder.getFloatAttr(elemType, config.
t0)));
1685 auto targetConst = arith::ConstantOp::create(
1686 builder, loc, scalarType,
1687 DenseElementsAttr::get(
1690 Value g = arith::SubFOp::create(builder, loc, targetConst, acceptProb);
1693 Value tPlusT0 = arith::AddFOp::create(builder, loc, tFloat, t0Const);
1696 auto oneConst = arith::ConstantOp::create(
1697 builder, loc, scalarType,
1698 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0)));
1699 Value weight = arith::DivFOp::create(builder, loc, oneConst, tPlusT0);
1702 Value decay = arith::SubFOp::create(builder, loc, oneConst, weight);
1707 arith::MulFOp::create(builder, loc, decay, state.
gradient_avg);
1708 Value gWeighted = arith::MulFOp::create(builder, loc, weight, g);
1709 Value gAvgNew = arith::AddFOp::create(builder, loc, gAvgDecayed, gWeighted);
1712 Value sqrtT = math::SqrtOp::create(builder, loc, tFloat);
1713 auto gammaConst = arith::ConstantOp::create(
1714 builder, loc, scalarType,
1715 DenseElementsAttr::get(scalarType,
1716 builder.getFloatAttr(elemType, config.
gamma)));
1717 Value sqrtTOverGamma = arith::DivFOp::create(builder, loc, sqrtT, gammaConst);
1719 arith::MulFOp::create(builder, loc, sqrtTOverGamma, gAvgNew);
1720 Value logStepSizeNew =
1721 arith::SubFOp::create(builder, loc, state.
prox_center, adjustment);
1724 auto negKappaConst = arith::ConstantOp::create(
1725 builder, loc, scalarType,
1726 DenseElementsAttr::get(scalarType,
1727 builder.getFloatAttr(elemType, -config.
kappa)));
1728 Value weightT = math::PowFOp::create(builder, loc, tFloat, negKappaConst);
1731 Value oneMinusWeightT =
1732 arith::SubFOp::create(builder, loc, oneConst, weightT);
1733 Value avgDecayed = arith::MulFOp::create(builder, loc, oneMinusWeightT,
1735 Value newContribution =
1736 arith::MulFOp::create(builder, loc, weightT, logStepSizeNew);
1737 Value logStepSizeAvgNew =
1738 arith::AddFOp::create(builder, loc, avgDecayed, newContribution);
1741 .log_step_size = logStepSizeNew,
1742 .log_step_size_avg = logStepSizeAvgNew,
1743 .gradient_avg = gAvgNew,
1753 return math::ExpOp::create(builder, loc, logStepSize);
1757 int64_t positionSize,
bool diagonal) {
1758 auto elemType = builder.getF64Type();
1759 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1760 auto meanType = RankedTensorType::get({positionSize}, elemType);
1762 Value mean = arith::ConstantOp::create(
1763 builder, loc, meanType,
1764 DenseElementsAttr::get(meanType, builder.getFloatAttr(elemType, 0.0)));
1770 m2 = arith::ConstantOp::create(
1771 builder, loc, meanType,
1772 DenseElementsAttr::get(meanType, builder.getFloatAttr(elemType, 0.0)));
1774 auto m2Type = RankedTensorType::get({positionSize, positionSize}, elemType);
1775 m2 = arith::ConstantOp::create(
1776 builder, loc, m2Type,
1777 DenseElementsAttr::get(m2Type, builder.getFloatAttr(elemType, 0.0)));
1780 Value n = arith::ConstantOp::create(
1781 builder, loc, i64TensorType,
1782 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));
1784 return {mean, m2, n};
1798 auto sampleType = cast<RankedTensorType>(sample.getType());
1799 auto elemType = sampleType.getElementType();
1800 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1802 auto oneI64 = arith::ConstantOp::create(
1803 builder, loc, i64TensorType,
1804 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1805 Value nNew = arith::AddIOp::create(builder, loc, state.
n, oneI64);
1807 auto scalarType = RankedTensorType::get({}, elemType);
1808 Value nFloat = arith::SIToFPOp::create(builder, loc, scalarType, nNew);
1810 Value nBroadcast = BroadcastOp::create(builder, loc, sampleType, nFloat,
1811 sampleType.getShape());
1813 Value deltaPre = arith::SubFOp::create(builder, loc, sample, state.
mean);
1815 Value deltaPreOverN =
1816 arith::DivFOp::create(builder, loc, deltaPre, nBroadcast);
1818 arith::AddFOp::create(builder, loc, state.
mean, deltaPreOverN);
1820 Value deltaPost = arith::SubFOp::create(builder, loc, sample, meanNew);
1824 Value product = arith::MulFOp::create(builder, loc, deltaPre, deltaPost);
1825 m2New = arith::AddFOp::create(builder, loc, state.
m2, product);
1827 auto m2Type = cast<RankedTensorType>(state.
m2.getType());
1828 Value outerProduct = impulse::DotOp::create(
1829 builder, loc, m2Type, deltaPost, deltaPre,
1830 builder.getDenseI64ArrayAttr({}),
1831 builder.getDenseI64ArrayAttr({}),
1832 builder.getDenseI64ArrayAttr({}),
1833 builder.getDenseI64ArrayAttr({}));
1834 m2New = arith::AddFOp::create(builder, loc, state.
m2, outerProduct);
1837 return {meanNew, m2New, nNew};
1844 auto m2Type = cast<RankedTensorType>(state.
m2.getType());
1845 auto elemType = m2Type.getElementType();
1846 auto scalarType = RankedTensorType::get({}, elemType);
1847 auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());
1849 auto oneI64 = arith::ConstantOp::create(
1850 builder, loc, i64TensorType,
1851 DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(1)));
1852 Value nMinus1 = arith::SubIOp::create(builder, loc, state.
n, oneI64);
1853 Value nMinus1Float =
1854 arith::SIToFPOp::create(builder, loc, scalarType, nMinus1);
1856 Value nMinus1Bcast = BroadcastOp::create(builder, loc, m2Type, nMinus1Float,
1859 Value cov = arith::DivFOp::create(builder, loc, state.
m2, nMinus1Bcast);
1867 Value nFloat = arith::SIToFPOp::create(builder, loc, scalarType, state.
n);
1869 auto fiveConst = arith::ConstantOp::create(
1870 builder, loc, scalarType,
1871 DenseElementsAttr::get(scalarType,
1872 builder.getFloatAttr(elemType, 5.0)));
1873 Value nPlusFive = arith::AddFOp::create(builder, loc, nFloat, fiveConst);
1875 Value scale = arith::DivFOp::create(builder, loc, nFloat, nPlusFive);
1877 BroadcastOp::create(builder, loc, m2Type, scale, m2Type.getShape());
1878 Value scaledCov = arith::MulFOp::create(builder, loc, scaleBcast, cov);
1880 auto shrinkageBaseConst = arith::ConstantOp::create(
1881 builder, loc, scalarType,
1882 DenseElementsAttr::get(scalarType,
1883 builder.getFloatAttr(elemType, 1e-3 * 5.0)));
1885 arith::DivFOp::create(builder, loc, shrinkageBaseConst, nPlusFive);
1888 Value shrinkageBcast = BroadcastOp::create(builder, loc, m2Type,
1889 shrinkage, m2Type.getShape());
1890 cov = arith::AddFOp::create(builder, loc, scaledCov, shrinkageBcast);
1893 Value shrinkageBcast = BroadcastOp::create(builder, loc, m2Type,
1894 shrinkage, m2Type.getShape());
1896 arith::MulFOp::create(builder, loc, shrinkageBcast, identity);
1897 cov = arith::AddFOp::create(builder, loc, scaledCov, shrinkageI);
1909 SmallVector<AdaptWindow> schedule;
1911 if (numSteps < 20) {
1912 schedule.push_back({0, numSteps - 1});
1917 int64_t startBufferSize = 75;
1918 int64_t endBufferSize = 50;
1919 int64_t initWindowSize = 25;
1921 if ((startBufferSize + endBufferSize + initWindowSize) > numSteps) {
1922 startBufferSize =
static_cast<int64_t
>(0.15 * numSteps);
1923 endBufferSize =
static_cast<int64_t
>(0.1 * numSteps);
1924 initWindowSize = numSteps - startBufferSize - endBufferSize;
1928 schedule.push_back({0, startBufferSize - 1});
1930 int64_t endWindowStart = numSteps - endBufferSize;
1931 int64_t nextWindowSize = initWindowSize;
1932 int64_t nextWindowStart = startBufferSize;
1935 while (nextWindowStart < endWindowStart) {
1936 int64_t curWindowStart = nextWindowStart;
1937 int64_t curWindowSize = nextWindowSize;
1939 if (3 * curWindowSize <= endWindowStart - curWindowStart) {
1940 nextWindowSize = 2 * curWindowSize;
1942 curWindowSize = endWindowStart - curWindowStart;
1945 nextWindowStart = curWindowStart + curWindowSize;
1946 schedule.push_back({curWindowStart, nextWindowStart - 1});
1950 schedule.push_back({endWindowStart, numSteps - 1});
1957 ArrayRef<SupportInfo> supports) {
1958 bool hasConstraints =
false;
1959 for (
const auto &info : supports) {
1960 if (info.support && info.support.getKind() != impulse::SupportKind::REAL) {
1961 hasConstraints =
true;
1965 if (!hasConstraints || supports.empty())
1968 auto inputType = cast<RankedTensorType>(constrained.getType());
1969 auto elemType = inputType.getElementType();
1970 int64_t positionSize = inputType.getShape()[1];
1972 auto positionType1D = RankedTensorType::get({positionSize}, elemType);
1973 Value constrained1D =
1974 impulse::ReshapeOp::create(builder, loc, positionType1D, constrained);
1976 SmallVector<Value> slices;
1977 for (
const auto &info : supports) {
1978 auto sliceType = RankedTensorType::get({info.size}, elemType);
1979 auto slice = impulse::SliceOp::create(
1980 builder, loc, sliceType, constrained1D,
1981 builder.getDenseI64ArrayAttr({info.offset}),
1982 builder.getDenseI64ArrayAttr({info.offset + info.size}),
1983 builder.getDenseI64ArrayAttr({1}));
1984 Value unconstrainedSlice;
1985 if (info.support && info.support.getKind() != impulse::SupportKind::REAL) {
1986 unconstrainedSlice =
1989 unconstrainedSlice = slice;
1992 slices.push_back(unconstrainedSlice);
1996 if (slices.size() == 1) {
1997 result1D = slices[0];
1999 auto resultType1D = RankedTensorType::get({positionSize}, elemType);
2000 result1D = arith::ConstantOp::create(
2001 builder, loc, resultType1D,
2002 DenseElementsAttr::get(resultType1D,
2003 builder.getFloatAttr(elemType, 0.0)));
2005 auto i64ScalarType = RankedTensorType::get({}, builder.getI64Type());
2006 auto elemType1DSlice = RankedTensorType::get({1}, elemType);
2007 auto elemType0D = RankedTensorType::get({}, elemType);
2009 for (
size_t i = 0; i < slices.size(); ++i) {
2010 auto sliceType = cast<RankedTensorType>(slices[i].getType());
2011 int64_t sliceSize = sliceType.getShape()[0];
2013 for (int64_t j = 0; j < sliceSize; ++j) {
2014 auto elemIdx = arith::ConstantOp::create(
2015 builder, loc, i64ScalarType,
2016 DenseElementsAttr::get(i64ScalarType,
2017 builder.getI64IntegerAttr(j)));
2018 auto resultIdx = arith::ConstantOp::create(
2019 builder, loc, i64ScalarType,
2020 DenseElementsAttr::get(i64ScalarType,
2021 builder.getI64IntegerAttr(offset + j)));
2023 auto elemSliced = impulse::DynamicSliceOp::create(
2024 builder, loc, elemType1DSlice, slices[i], ValueRange{elemIdx},
2025 builder.getDenseI64ArrayAttr({1}));
2027 impulse::ReshapeOp::create(builder, loc, elemType0D, elemSliced);
2029 auto resultSliced = impulse::ReshapeOp::create(
2030 builder, loc, RankedTensorType::get({1}, elemType), elem);
2031 result1D = impulse::DynamicUpdateSliceOp::create(
2032 builder, loc, resultType1D, result1D, resultSliced,
2033 ValueRange{resultIdx});
2036 offset += sliceSize;
2040 auto resultType2D = RankedTensorType::get({1, positionSize}, elemType);
2041 return impulse::ReshapeOp::create(builder, loc, resultType2D, result1D);
2045 Value unconstrained,
2046 ArrayRef<SupportInfo> supports) {
2047 bool hasConstraints =
false;
2048 for (
const auto &info : supports) {
2049 if (info.support && info.support.getKind() != impulse::SupportKind::REAL) {
2050 hasConstraints =
true;
2054 if (!hasConstraints || supports.empty())
2055 return unconstrained;
2057 auto inputType = cast<RankedTensorType>(unconstrained.getType());
2058 auto elemType = inputType.getElementType();
2059 int64_t positionSize = inputType.getShape()[1];
2061 auto positionType1D = RankedTensorType::get({positionSize}, elemType);
2062 Value unconstrained1D =
2063 impulse::ReshapeOp::create(builder, loc, positionType1D, unconstrained);
2065 SmallVector<Value> slices;
2066 for (
const auto &info : supports) {
2067 auto sliceType = RankedTensorType::get({info.size}, elemType);
2068 auto slice = impulse::SliceOp::create(
2069 builder, loc, sliceType, unconstrained1D,
2070 builder.getDenseI64ArrayAttr({info.offset}),
2071 builder.getDenseI64ArrayAttr({info.offset + info.size}),
2072 builder.getDenseI64ArrayAttr({1}));
2074 Value constrainedSlice;
2075 if (info.support && info.support.getKind() != impulse::SupportKind::REAL) {
2079 constrainedSlice = slice;
2082 slices.push_back(constrainedSlice);
2086 if (slices.size() == 1) {
2087 result1D = slices[0];
2089 auto resultType1D = RankedTensorType::get({positionSize}, elemType);
2090 result1D = arith::ConstantOp::create(
2091 builder, loc, resultType1D,
2092 DenseElementsAttr::get(resultType1D,
2093 builder.getFloatAttr(elemType, 0.0)));
2095 auto i64ScalarType = RankedTensorType::get({}, builder.getI64Type());
2096 auto elemType1DSlice = RankedTensorType::get({1}, elemType);
2097 auto elemType0D = RankedTensorType::get({}, elemType);
2099 for (
size_t i = 0; i < slices.size(); ++i) {
2100 auto sliceType = cast<RankedTensorType>(slices[i].getType());
2101 int64_t sliceSize = sliceType.getShape()[0];
2103 for (int64_t j = 0; j < sliceSize; ++j) {
2104 auto elemIdx = arith::ConstantOp::create(
2105 builder, loc, i64ScalarType,
2106 DenseElementsAttr::get(i64ScalarType,
2107 builder.getI64IntegerAttr(j)));
2108 auto resultIdx = arith::ConstantOp::create(
2109 builder, loc, i64ScalarType,
2110 DenseElementsAttr::get(i64ScalarType,
2111 builder.getI64IntegerAttr(offset + j)));
2113 auto elemSliced = impulse::DynamicSliceOp::create(
2114 builder, loc, elemType1DSlice, slices[i], ValueRange{elemIdx},
2115 builder.getDenseI64ArrayAttr({1}));
2117 impulse::ReshapeOp::create(builder, loc, elemType0D, elemSliced);
2120 impulse::ReshapeOp::create(builder, loc, elemType1DSlice, elem);
2121 result1D = impulse::DynamicUpdateSliceOp::create(
2122 builder, loc, resultType1D, result1D, resultSliced,
2123 ValueRange{resultIdx});
2126 offset += sliceSize;
2130 auto resultType2D = RankedTensorType::get({1, positionSize}, elemType);
2131 return impulse::ReshapeOp::create(builder, loc, resultType2D, result1D);
2135 Value unconstrained,
2136 ArrayRef<SupportInfo> supports) {
2137 auto inputType = cast<RankedTensorType>(unconstrained.getType());
2138 auto elemType = inputType.getElementType();
2139 auto scalarType = RankedTensorType::get({}, elemType);
2141 Value total = arith::ConstantOp::create(
2142 builder, loc, scalarType,
2143 DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0)));
2145 if (supports.empty())
2148 int64_t positionSize = inputType.getShape()[1];
2149 auto positionType1D = RankedTensorType::get({positionSize}, elemType);
2150 Value unconstrained1D =
2151 impulse::ReshapeOp::create(builder, loc, positionType1D, unconstrained);
2153 for (
const auto &info : supports) {
2154 if (!info.support || info.support.getKind() == impulse::SupportKind::REAL)
2157 auto sliceType = RankedTensorType::get({info.size}, elemType);
2158 auto slice = impulse::SliceOp::create(
2159 builder, loc, sliceType, unconstrained1D,
2160 builder.getDenseI64ArrayAttr({info.offset}),
2161 builder.getDenseI64ArrayAttr({info.offset + info.size}),
2162 builder.getDenseI64ArrayAttr({1}));
2166 total = arith::AddFOp::create(builder, loc, total, jacobian);
static Value gatherPositionFromTrace(OpBuilder &builder, Location loc, Value fullTrace, const HMCContext &ctx)
static Value createPermutationMatrix(OpBuilder &builder, Location loc, RankedTensorType matrixType)
Creates the n x n exchange permutation matrix (ones on the anti-diagonal, zeros elsewhere).
static Value scatterPositionToTrace(OpBuilder &builder, Location loc, Value position2d, Value fullTrace, const HMCContext &ctx)
static Value createIdentityMatrix(OpBuilder &builder, Location loc, RankedTensorType matrixType)
Creates a 2D identity matrix of the specified type.
static Value reverseRowsAndColumns(OpBuilder &builder, Location loc, Value matrix)
Computes A[::-1, ::-1] (reversing both rows and columns) as the product P @ A @ P, where P is the exchange permutation matrix.
PointerType * traceType(LLVMContext &C)
Value computeTotalJacobianCorrection(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Compute total Jacobian correction for the constrain transform over all position vector slices.
Value finalizeWelford(OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config)
Finalize Welford state to produce sample covariance (returned as inverse mass matrix).
Value checkTurning(OpBuilder &builder, Location loc, Value pLeft, Value pRight, Value pSum, const NUTSContext &ctx)
U-turn termination criterion.
GradientResult computePotentialAndGradient(OpBuilder &builder, Location loc, Value position, Value rng, const HMCContext &ctx)
Computes the potential energy U(q) = -log p(q) and its gradient dU/dq.
std::pair< Value, Value > leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx)
Computes checkpoint indices from leaf index for iterative turning check.
IntegratorState getLeafFromTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, const NUTSContext &ctx)
Extracts an appropriate leaf based on direction.
std::pair< Value, Value > updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, Value ckptIdxMax, Value p, Value pSum, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Update checkpoint arrays at even leaf indices.
SubtreeBuildResult buildIterativeSubtree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Builds a subtree iteratively by appending leaves one at a time.
std::pair< Value, Value > sampleMomentum(OpBuilder &builder, Location loc, Value rng, Value invMass, Value massMatrixSqrt, RankedTensorType positionType, bool debugDump=false)
Samples momentum from N(0, M) where M is the mass matrix.
Value computeUniformTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight)
Computes the uniform transition probability for subtree combination.
SmallVector< AdaptWindow > buildAdaptationSchedule(int64_t numSteps)
Build warmup adaptation schedule.
NUTSTreeState buildBaseTree(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const NUTSContext &ctx)
Builds a base tree (leaf node) by taking one leapfrog step.
Value computeKineticEnergy(OpBuilder &builder, Location loc, Value momentum, Value invMass, RankedTensorType positionType)
Computes K = 0.5 * p^T @ M^-1 @ p
NUTSTreeState combineTrees(OpBuilder &builder, Location loc, const NUTSTreeState &tree, const NUTSTreeState &subTree, Value direction, Value rng, bool biased, const NUTSContext &ctx)
Combines a tree with a newly-built subtree during NUTS doubling process.
Value computeBiasedTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight, Value turning, Value diverging)
Computes the biased transition probability for main tree combination.
Value checkIterativeTurning(OpBuilder &builder, Location loc, Value p, Value pSum, Value pCkpts, Value pSumCkpts, Value idxMin, Value idxMax, const NUTSContext &ctx, bool debugDump=false)
Checkpoint-based iterative turning check.
NUTSTreeState buildTree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, const NUTSContext &ctx, bool debugDump=false)
Main NUTS tree building loop.
Value conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump)
Conditionally dump a value for debugging.
MCMCKernelResult SampleNUTS(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump=false)
Single NUTS iteration: momentum sampling + tree building.
InitialHMCState InitHMC(OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition=Value(), bool debugDump=false)
Initializes HMC/NUTS state from a trace. Specifically:
SubtreeBuildResult doubleTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Tree doubling by building a subtree of same depth and combining.
IntegrationResult computeIntegrationStep(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const HMCContext &ctx)
Computes a single leapfrog integration step.
DualAveragingState updateDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config)
Update dual averaging state with observed acceptance probability.
Value unconstrainPosition(OpBuilder &builder, Location loc, Value constrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from constrained to unconstrained space based on the support info...
DualAveragingState initDualAveraging(OpBuilder &builder, Location loc, Value stepSize)
Initialize dual averaging state from initial step size.
Value computeMassMatrixSqrt(OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType)
Computes the square root of the mass matrix from the inverse mass matrix.
WelfordState updateWelford(OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config)
Update Welford state with a new sample.
WelfordState initWelford(OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal)
Initialize state for Welford covariance estimation.
Value getStepSizeFromDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, bool final=false)
Get step size from dual averaging state.
MCMCKernelResult SampleHMC(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump=false)
Single HMC iteration: momentum sampling + leapfrog + MH accept/reject.
Value constrainPosition(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from unconstrained to constrained space.
Value applyInverseMassMatrix(OpBuilder &builder, Location loc, Value invMass, Value momentum, RankedTensorType positionType)
Computes v = M^-1 @ p. If invMass is nullptr, returns momentum unchanged (assumes identity). If invMass...
double target_accept_prob
DictionaryAttr autodiffAttrs
Type getElementType() const
HMCContext withStepSize(Value newStepSize) const
ArrayRef< Value > fnInputs
RankedTensorType getPositionType() const
int64_t getFullTraceSize() const
SmallVector< Type > fnResultTypes
SmallVector< SupportInfo > supports
bool hasCustomLogpdf() const
RankedTensorType getScalarType() const
FlatSymbolRefAttr logpdfFn
Result of one MCMC kernel step.
NUTSContext withH0(Value newH0) const
SmallVector< Value > toValues() const
SmallVector< Type > getTypes() const
static NUTSTreeState fromValues(ArrayRef< Value > values)
Configuration for Welford covariance estimation.
State for Welford covariance estimation.