21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/Math/IR/Math.h"
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Interfaces/FunctionInterfaces.h"
26#include "mlir/Pass/PassManager.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29#include "llvm/ADT/APFloat.h"
31#define DEBUG_TYPE "expand-impulse"
39#define GEN_PASS_DEF_EXPANDIMPULSEPASS
40#include "Passes/Passes.h.inc"
// Returns the total number of elements in `tensorType`, i.e. the product of
// its static dimensions.
// NOTE(review): this listing is elided (original lines 50-55 are missing);
// the dynamic-dimension branch and the final return are not visible here.
// Presumably a negative sentinel is returned for dynamic shapes (callers
// below test `elemCount < 0`) — confirm against the full source.
46static int64_t computeTensorElementCount(RankedTensorType tensorType) {
47 int64_t elemCount = 1;
// Walk every dimension of the tensor's shape.
48 for (
auto dim : tensorType.getShape()) {
// Dynamic dimensions have no statically known element count.
49 if (dim == ShapedType::kDynamic)
// Maps a sample symbol attribute to the impulse::SampleOp that declares it.
56using SampleOpMap = DenseMap<Attribute, impulse::SampleOp>;
// Builds a SampleOpMap for all SampleOps (that carry a symbol attribute)
// directly reachable by walking `fn`'s body.
// NOTE(review): listing elided — the local `map` declaration and the final
// return are missing from this view; confirm against the full source.
58static SampleOpMap buildSampleOpMap(FunctionOpInterface fn) {
60 fn.walk([&](impulse::SampleOp sampleOp) {
// Only SampleOps with a symbol are addressable by selections.
61 if (
auto symbol = sampleOp.getSymbolAttr())
62 map[symbol] = sampleOp;
// Looks up `targetSymbol` in a previously built SampleOpMap.
// Returns the matching SampleOp, or a null op if the symbol is absent.
67static impulse::SampleOp findSampleBySymbol(
const SampleOpMap &map,
68 Attribute targetSymbol) {
69 auto it = map.find(targetSymbol);
70 return it != map.end() ? it->second :
nullptr;
// Sums the element counts of all of `sampleOp`'s results except result 0
// (the loop starts at i = 1, so result 0 — presumably an RNG state or
// similar bookkeeping value — is excluded; confirm against the op
// definition). Diagnostics are attached to `op`.
// NOTE(review): listing elided — the error-path returns after each
// emitError and the final return of `totalCount` are not visible here.
73static int64_t computeSampleElementCount(Operation *op,
74 impulse::SampleOp sampleOp) {
75 int64_t totalCount = 0;
76 for (
unsigned i = 1; i < sampleOp.getNumResults(); ++i) {
77 auto resultType = sampleOp.getResult(i).getType();
// Only ranked tensors have a computable static element count.
78 auto tensorType = dyn_cast<RankedTensorType>(resultType);
80 op->emitError(
"Expected ranked tensor type for sample result");
83 int64_t elemCount = computeTensorElementCount(tensorType);
// computeTensorElementCount signals dynamic dims via a negative value.
85 op->emitError(
"Dynamic tensor dimensions not supported");
88 totalCount += elemCount;
// Accumulates into `positionSize` the number of flattened trace elements
// contributed by the sample addressed by `address` (a path of symbols,
// outermost first). For multi-segment addresses the search recurses into
// the generative function referenced by the sample's fn attribute.
// Returns false on failure (symbol not found / invalid nesting); true
// otherwise. NOTE(review): listing elided — the null-check after the
// findSampleBySymbol lookup and the error-path/final returns are missing
// from this view.
93static bool computePositionSizeForAddress(Operation *op,
94 const SampleOpMap &sampleMap,
95 ArrayRef<Attribute> address,
96 SymbolTableCollection &symbolTable,
97 int64_t &positionSize) {
// Resolve the first path segment in the current function's sample map.
101 auto sampleOp = findSampleBySymbol(sampleMap, address[0]);
105 if (address.size() > 1) {
// A logpdf attribute marks a distribution; distributions are leaves and
// cannot contain nested addressable samples.
106 if (sampleOp.getLogpdfAttr()) {
107 op->emitError(
"Cannot select nested address in distribution function");
111 auto genFn = cast<FunctionOpInterface>(
112 symbolTable.lookupNearestSymbolFrom(sampleOp, sampleOp.getFnAttr()));
113 if (!genFn || genFn.getFunctionBody().empty()) {
114 op->emitError(
"Cannot find generative function for nested address");
// Recurse one level deeper with the remaining path segments.
118 auto nestedMap = buildSampleOpMap(genFn);
119 return computePositionSizeForAddress(op, nestedMap, address.drop_front(),
120 symbolTable, positionSize);
// Single-segment address: this sample itself contributes its elements.
123 int64_t elemCount = computeSampleElementCount(op, sampleOp);
127 positionSize += elemCount;
// Computes the total flattened position size (number of f64 trace slots)
// covered by every address in `selection` within `fn`.
// NOTE(review): listing elided — the return type (original line 131,
// presumably `static int64_t`), the `ArrayAttr selection` parameter line
// (133), error-path returns, and the final `return positionSize;` are not
// visible here. Callers treat a result <= 0 as failure.
132computePositionSizeForSelection(Operation *op, FunctionOpInterface fn,
134 SymbolTableCollection &symbolTable) {
135 auto sampleMap = buildSampleOpMap(fn);
136 int64_t positionSize = 0;
// Each entry of `selection` is itself an ArrayAttr: a symbol path.
138 for (
auto addr : selection) {
139 auto address = cast<ArrayAttr>(addr);
140 if (address.empty()) {
141 op->emitError(
"Empty address in selection");
145 SmallVector<Attribute> tailAddresses(address.begin(), address.end());
// Delegates per-address accumulation (and nested recursion) to the helper.
146 if (!computePositionSizeForAddress(op, sampleMap, tailAddresses,
147 symbolTable, positionSize)) {
148 op->emitError(
"Could not find sample with symbol in address chain");
// Computes the flattened trace offset of the sample named `targetSymbol`
// by summing the position sizes of all addresses in `selection` that
// precede it. NOTE(review): listing elided — the return type (presumably
// `static int64_t`), the local `offset` declaration/initialization, the
// early return when the target is reached, and the not-found return
// (callers treat < 0 as "not found") are missing from this view.
157computeOffsetForSampleInSelection(Operation *op, FunctionOpInterface fn,
158 ArrayAttr selection, Attribute targetSymbol,
159 SymbolTableCollection &symbolTable) {
160 auto sampleMap = buildSampleOpMap(fn);
// Walk addresses in selection order, accumulating sizes until the target.
163 for (
auto addr : selection) {
164 auto address = cast<ArrayAttr>(addr);
168 auto firstSymbol = address[0];
// Reached the target: the accumulated offset is the answer.
170 if (firstSymbol == targetSymbol) {
174 SmallVector<Attribute> tailAddresses(address.begin(), address.end());
175 if (!computePositionSizeForAddress(op, sampleMap, tailAddresses,
176 symbolTable, offset)) {
// Collects a SupportInfo record (position offset within the selection,
// trace offset within `allAddresses`, element count, and — presumably —
// the support attribute; the emplace_back argument list is elided) for
// every single-segment address in `selection`.
// NOTE(review): listing elided — the skip path for multi-segment addresses
// (after the `address.size() != 1` check), null-sample handling, and the
// final `return supports;` are not visible here.
184static SmallVector<impulse::SupportInfo>
185collectSupportInfoForSelection(Operation *op, FunctionOpInterface fn,
186 ArrayAttr selection, ArrayAttr allAddresses,
187 SymbolTableCollection &symbolTable) {
188 auto sampleMap = buildSampleOpMap(fn);
189 SmallVector<impulse::SupportInfo> supports;
190 int64_t currentPositionOffset = 0;
192 for (
auto addr : selection) {
193 auto address = cast<ArrayAttr>(addr);
// Only top-level (single-segment) addresses carry support info here.
198 if (address.size() != 1)
201 auto targetSymbol = address[0];
202 auto sampleOp = findSampleBySymbol(sampleMap, targetSymbol);
206 auto supportAttr = sampleOp.getSupportAttr();
208 int64_t sampleSize = computeSampleElementCount(op, sampleOp);
// The trace offset is measured against the full address list so that a
// scattered (non-prefix) selection still maps into the right trace slots.
212 int64_t traceOffset = computeOffsetForSampleInSelection(
213 op, fn, allAddresses, targetSymbol, symbolTable);
214 if (traceOffset < 0) {
215 op->emitError(
"Symbol in selection not found in all_addresses - cannot "
216 "determine trace offset for scattered selection");
220 supports.emplace_back(currentPositionOffset, traceOffset, sampleSize,
222 currentPositionOffset += sampleSize;
// Builds the selection to forward into a nested generative function: for
// every address in `selection` whose first segment equals `targetSymbol`
// and that has further segments, the first segment is stripped and the
// remaining path is kept. Addresses not rooted at `targetSymbol` (and
// exact single-segment matches) are dropped.
228static ArrayAttr buildSubSelection(OpBuilder &builder, ArrayAttr selection,
229 Attribute targetSymbol) {
230 SmallVector<Attribute> subAddresses;
231 for (
auto addr : selection) {
232 auto address = cast<ArrayAttr>(addr);
// Keep only nested paths rooted at the target symbol, minus that root.
235 if (address[0] == targetSymbol && address.size() > 1) {
236 SmallVector<Attribute> tail(address.begin() + 1, address.end());
237 subAddresses.push_back(builder.getArrayAttr(tail));
240 return builder.getArrayAttr(subAddresses);
// Computes the trace offset at which a nested sample's sub-trace should be
// merged: the sum of position sizes of all selection addresses that come
// before the first address rooted at `targetSymbol`.
// NOTE(review): listing elided — the return type (presumably
// `static int64_t`), the `offset` local, the early return on match, and
// the not-found (< 0) return are missing from this view. Structure mirrors
// computeOffsetForSampleInSelection but matches on address roots.
244computeOffsetForNestedSample(Operation *op, FunctionOpInterface fn,
245 ArrayAttr selection, Attribute targetSymbol,
246 SymbolTableCollection &symbolTable) {
247 auto sampleMap = buildSampleOpMap(fn);
250 for (
auto addr : selection) {
251 auto address = cast<ArrayAttr>(addr);
// Stop accumulating once the target root is reached.
255 if (address[0] == targetSymbol) {
259 SmallVector<Attribute> tailAddresses(address.begin(), address.end());
260 if (!computePositionSizeForAddress(op, sampleMap, tailAddresses,
261 symbolTable, offset)) {
// Tablegen-generated pass wrapper (see GEN_PASS_DEF_EXPANDIMPULSEPASS
// above) that expands impulse dialect ops. runOnOperation is defined
// out of line (not visible in this chunk).
269struct ExpandImpulsePass
270 :
public enzyme::impl::ExpandImpulsePassBase<ExpandImpulsePass> {
271 using ExpandImpulsePassBase::ExpandImpulsePassBase;
275 void runOnOperation()
override;
// Registers the dialects produced by this pass plus whatever the
// user-supplied `postpasses` pipeline needs.
// NOTE(review): "®istry" below is an HTML-entity mangling of
// "&registry" — repair when editing the full source.
277 void getDependentDialects(DialectRegistry ®istry)
const override {
278 mlir::OpPassManager pm;
// Parse the post-pass pipeline string only to harvest its dialect deps;
// a parse failure here is non-fatal (the pipeline is re-parsed at run
// time, where errors are reported).
279 mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
280 if (!mlir::failed(result)) {
281 pm.getDependentDialects(registry);
// NOTE(review): the receiver of this .insert chain (presumably
// `registry`) is on an elided line (originally 282-284).
285 .insert<mlir::arith::ArithDialect, mlir::math::MathDialect,
286 mlir::complex::ComplexDialect, mlir::cf::ControlFlowDialect,
287 mlir::enzyme::EnzymeDialect, mlir::impulse::ImpulseDialect>();
// Rewrites impulse.untraced_call into a plain func.call of a cloned
// function in which every impulse.sample op has been replaced by a direct
// call of its distribution function (no trace/weight bookkeeping).
// NOTE(review): listing elided throughout — `putils` creation, the
// `distFn`/`distCall`/`newCI` declarations, failure returns, and the final
// `return success();` are not visible in this chunk.
290 struct LowerUntracedCallPattern
291 :
public mlir::OpRewritePattern<impulse::UntracedCallOp> {
294 LogicalResult matchAndRewrite(impulse::UntracedCallOp CI,
295 PatternRewriter &rewriter)
const override {
296 SymbolTableCollection symbolTable;
// Resolve the callee referenced by the untraced_call's fn attribute.
298 auto fn = cast<FunctionOpInterface>(
299 symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()));
301 if (fn.getFunctionBody().empty()) {
302 CI.emitError(
"Impulse: trying to call an empty function");
// `putils` (declaration elided) presumably clones `fn` into NewF.
307 FunctionOpInterface NewF = putils->newFunc;
// Replace each sample op in the clone with a direct distribution call.
309 SmallVector<Operation *, 4> toErase;
310 NewF.walk([&](impulse::SampleOp sampleOp) {
311 OpBuilder::InsertionGuard guard(rewriter);
312 rewriter.setInsertionPoint(sampleOp);
315 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
316 sampleOp, sampleOp.getFnAttr()));
318 func::CallOp::create(rewriter, sampleOp.getLoc(), distFn.getName(),
319 distFn.getResultTypes(), sampleOp.getInputs());
320 sampleOp.replaceAllUsesWith(distCall);
// Deferred erasure: erasing inside walk would invalidate iteration.
322 toErase.push_back(sampleOp);
325 for (Operation *op : toErase)
326 rewriter.eraseOp(op);
// Finally, call the rewritten clone in place of the untraced_call.
328 rewriter.setInsertionPoint(CI);
330 func::CallOp::create(rewriter, CI.getLoc(), NewF.getName(),
331 NewF.getResultTypes(), CI.getOperands());
333 rewriter.replaceOp(CI, newCI.getResults());
// Lowers impulse.simulate: clones the generative function, threads a
// flattened {1 x positionSize} f64 trace tensor and a scalar f64 log-weight
// accumulator through it (each sample op either calls its distribution and
// adds a logpdf term, or recursively simulates a nested generative
// function), prepends {trace, weight} to every return, and replaces the
// simulate op with a func.call of the rewritten clone.
// NOTE(review): this listing is heavily elided — `putils` creation,
// several declarations (zeroWeight, traceType, zeroTrace, distFn, logpdfFn,
// genFn, welford/DA state, etc.), many failure returns, and the final
// `return success();` are missing from this view.
341 struct LowerSimulatePattern
342 :
public mlir::OpRewritePattern<impulse::SimulateOp> {
345 LogicalResult matchAndRewrite(impulse::SimulateOp CI,
346 PatternRewriter &rewriter)
const override {
347 SymbolTableCollection symbolTable;
348 auto fn = cast<FunctionOpInterface>(
350 symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()));
// Simulate only makes sense on a generative function with a body;
// distributions must be marked with a logpdf attribute instead.
352 if (fn.getFunctionBody().empty()) {
354 "Impulse: calling `simulate` on an empty function; if this "
355 "is a distribution function, its sample op should have a "
356 "logpdf attribute to avoid recursive `simulate` calls which is "
357 "intended for generative functions");
// Total flattened trace width for the requested selection.
361 ArrayAttr selection = CI.getSelectionAttr();
362 int64_t positionSize =
363 computePositionSizeForSelection(CI, fn, selection, symbolTable);
364 if (positionSize <= 0) {
365 CI.emitError(
"Impulse: failed to compute position size for simulate");
371 FunctionOpInterface NewF = putils->newFunc;
// Seed the weight accumulator (scalar f64 zero) and the trace buffer
// ({1 x positionSize} f64) in the clone's initialization block.
373 OpBuilder entryBuilder(putils->initializationBlock,
374 putils->initializationBlock->begin());
375 Location initLoc = putils->initializationBlock->begin()->getLoc();
376 auto scalarType = RankedTensorType::get({}, entryBuilder.getF64Type());
378 arith::ConstantOp::create(entryBuilder, initLoc, scalarType,
379 DenseElementsAttr::get(scalarType, 0.0));
380 Value weightAccumulator = zeroWeight;
383 RankedTensorType::get({1, positionSize}, entryBuilder.getF64Type());
385 arith::ConstantOp::create(entryBuilder, initLoc,
traceType,
387 Value currTrace = zeroTrace;
388 int64_t currentOffset = 0;
// Rewrite every sample op in the clone; interrupt the walk on error.
390 SmallVector<Operation *> toErase;
391 auto result = NewF.walk([&](impulse::SampleOp sampleOp) -> WalkResult {
392 OpBuilder::InsertionGuard guard(rewriter);
393 rewriter.setInsertionPoint(sampleOp);
395 SmallVector<Value> sampledValues;
// A logpdf attribute marks this sample's callee as a distribution.
396 bool isDistribution =
static_cast<bool>(sampleOp.getLogpdfAttr());
398 if (isDistribution) {
// Distribution case: call the sampler, then add its logpdf to the
// running weight.
401 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
402 sampleOp, sampleOp.getFnAttr()));
404 auto distCall = func::CallOp::create(
405 rewriter, sampleOp.getLoc(), distFn.getName(),
406 distFn.getResultTypes(), sampleOp.getInputs());
408 sampledValues.append(distCall.getResults().begin(),
409 distCall.getResults().end());
412 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
413 sampleOp, sampleOp.getLogpdfAttr()));
// logpdf operands: sampled values then sample operands, both skipping
// index 0 (presumably the RNG state — confirm against op definition).
416 SmallVector<Value> logpdfOperands;
417 for (
unsigned i = 1; i < sampledValues.size(); ++i) {
418 logpdfOperands.push_back(sampledValues[i]);
420 for (
unsigned i = 1; i < sampleOp.getNumOperands(); ++i) {
421 logpdfOperands.push_back(sampleOp.getOperand(i));
424 if (logpdfOperands.size() != logpdfFn.getNumArguments()) {
425 sampleOp.emitError(
"Impulse: failed to construct logpdf call; "
426 "logpdf function has wrong number of arguments");
427 return WalkResult::interrupt();
431 auto logpdf = func::CallOp::create(
432 rewriter, sampleOp.getLoc(), logpdfFn.getName(),
433 logpdfFn.getResultTypes(), logpdfOperands);
435 arith::AddFOp::create(rewriter, sampleOp.getLoc(),
436 weightAccumulator, logpdf.getResult(0));
// Only samples named by the selection get written into the trace.
439 bool inSelection =
false;
440 for (
auto addr : selection) {
441 auto address = cast<ArrayAttr>(addr);
442 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
// Flatten each sampled tensor to {1 x n} and splice it into the trace
// at the running column offset via dynamic_update_slice.
449 for (
unsigned i = 1; i < sampledValues.size(); ++i) {
450 auto sampleValue = sampledValues[i];
451 auto sampleType = cast<RankedTensorType>(sampleValue.getType());
452 int64_t numElements = computeTensorElementCount(sampleType);
453 if (numElements < 0) {
455 "Impulse: dynamic tensor dimensions not supported");
456 return WalkResult::interrupt();
459 auto flatSampleType = RankedTensorType::get(
460 {1, numElements}, sampleType.getElementType());
461 auto flatSample = impulse::ReshapeOp::create(
462 rewriter, sampleOp.getLoc(), flatSampleType, sampleValue);
463 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
464 auto row0 = arith::ConstantOp::create(
465 rewriter, sampleOp.getLoc(), i64S,
466 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
467 auto colOff = arith::ConstantOp::create(
468 rewriter, sampleOp.getLoc(), i64S,
469 DenseElementsAttr::get(
470 i64S, rewriter.getI64IntegerAttr(currentOffset)));
471 currTrace = impulse::DynamicUpdateSliceOp::create(
472 rewriter, sampleOp.getLoc(),
traceType, currTrace,
473 flatSample, ValueRange{row0, colOff})
475 currentOffset += numElements;
// Generative case: the sample's callee is itself a generative function.
481 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
482 sampleOp, sampleOp.getFnAttr()));
484 if (genFn.getFunctionBody().empty()) {
486 "Impulse: generative function body is empty; "
487 "if this is a distribution, add a logpdf attribute");
488 return WalkResult::interrupt();
// Forward only the nested portion of the selection.
491 ArrayAttr subSelection =
492 buildSubSelection(rewriter, selection, sampleOp.getSymbolAttr());
493 if (subSelection.empty()) {
// Nothing selected inside: plain recursive call, no nested trace.
496 auto genCall = func::CallOp::create(
497 rewriter, sampleOp.getLoc(), genFn.getName(),
498 genFn.getResultTypes(), sampleOp.getInputs());
499 sampledValues.append(genCall.getResults().begin(),
500 genCall.getResults().end());
// Otherwise emit a nested impulse.simulate returning
// {subTrace, subWeight, original results...}.
502 int64_t subPositionSize = computePositionSizeForSelection(
503 sampleOp, genFn, subSelection, symbolTable);
504 if (subPositionSize <= 0) {
506 "Impulse: failed to compute sub-position size");
507 return WalkResult::interrupt();
511 auto subTraceType = RankedTensorType::get({1, subPositionSize},
512 rewriter.getF64Type());
513 auto scalarTy = RankedTensorType::get({}, rewriter.getF64Type());
514 SmallVector<Type> simResultTypes;
515 simResultTypes.push_back(subTraceType);
516 simResultTypes.push_back(scalarTy);
517 for (
auto t : genFn.getResultTypes())
518 simResultTypes.push_back(t);
520 auto nestedSimulate = impulse::SimulateOp::create(
521 rewriter, sampleOp.getLoc(), simResultTypes,
522 sampleOp.getFnAttr(), sampleOp.getInputs(), subSelection);
523 auto subTrace = nestedSimulate.getTrace();
524 auto subWeight = nestedSimulate.getWeight();
// Nested weight folds into the outer accumulator.
526 weightAccumulator = arith::AddFOp::create(
527 rewriter, sampleOp.getLoc(), weightAccumulator, subWeight);
// Merge the nested trace into this trace at its computed offset.
529 int64_t mergeOffset = computeOffsetForNestedSample(
530 sampleOp, fn, selection, sampleOp.getSymbolAttr(), symbolTable);
531 if (mergeOffset < 0) {
532 sampleOp.emitError(
"Impulse: failed to compute merge offset");
533 return WalkResult::interrupt();
536 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
537 auto row0 = arith::ConstantOp::create(
538 rewriter, sampleOp.getLoc(), i64S,
539 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
540 auto colOff = arith::ConstantOp::create(
541 rewriter, sampleOp.getLoc(), i64S,
542 DenseElementsAttr::get(
543 i64S, rewriter.getI64IntegerAttr(mergeOffset)));
544 currTrace = impulse::DynamicUpdateSliceOp::create(
545 rewriter, sampleOp.getLoc(),
traceType, currTrace,
546 subTrace, ValueRange{row0, colOff})
// Keep currentOffset past the merged sub-trace.
549 std::max(currentOffset, mergeOffset + subPositionSize);
551 for (
auto output : nestedSimulate.getOutputs())
552 sampledValues.push_back(output);
557 sampleOp.replaceAllUsesWith(sampledValues);
// Deferred erasure: erasing inside walk would invalidate iteration.
559 toErase.push_back(sampleOp);
560 return WalkResult::advance();
563 for (Operation *op : toErase)
564 rewriter.eraseOp(op);
566 if (result.wasInterrupted()) {
567 CI.emitError(
"Impulse: failed to walk sample ops");
// Prepend {trace, weight} to every return of the rewritten clone.
572 NewF.walk([&](func::ReturnOp retOp) {
573 OpBuilder::InsertionGuard guard(rewriter);
574 rewriter.setInsertionPoint(retOp);
575 SmallVector<Value> newRetVals;
576 newRetVals.push_back(currTrace);
577 newRetVals.push_back(weightAccumulator);
578 newRetVals.append(retOp.getOperands().begin(),
579 retOp.getOperands().end());
581 func::ReturnOp::create(rewriter, retOp.getLoc(), newRetVals);
582 rewriter.eraseOp(retOp);
// Replace the simulate op with a call of the rewritten clone.
585 rewriter.setInsertionPoint(CI);
586 auto newCI = func::CallOp::create(rewriter, CI.getLoc(), NewF.getName(),
587 NewF.getResultTypes(), CI.getInputs());
589 rewriter.replaceOp(CI, newCI.getResults());
597 struct LowerMCMCPattern :
public mlir::OpRewritePattern<impulse::InferOp> {
600 LowerMCMCPattern(MLIRContext *context,
bool debugDump,
601 PatternBenefit benefit = 1)
604 LogicalResult matchAndRewrite(impulse::InferOp mcmcOp,
605 PatternRewriter &rewriter)
const override {
606 SymbolTableCollection symbolTable;
608 bool hasLogpdfFn =
static_cast<bool>(mcmcOp.getLogpdfFnAttr());
611 auto fnAttr = mcmcOp.getFnAttr();
613 mcmcOp.emitError(
"Impulse: either fn or logpdf_fn must be provided");
616 auto fn = cast<FunctionOpInterface>(
617 symbolTable.lookupNearestSymbolFrom(mcmcOp, fnAttr));
618 if (fn.getFunctionBody().empty()) {
619 mcmcOp.emitError(
"Impulse: calling `mcmc` on an empty function");
624 if (!mcmcOp.getStepSize()) {
625 mcmcOp.emitError(
"Impulse: MCMC requires step_size parameter");
629 bool isHMC = mcmcOp.getHmcConfig().has_value();
630 bool isNUTS = mcmcOp.getNutsConfig().has_value();
631 if (!isHMC && !isNUTS) {
632 mcmcOp.emitError(
"Impulse: Unknown MCMC algorithm");
636 auto loc = mcmcOp.getLoc();
637 auto invMass = mcmcOp.getInverseMassMatrix();
638 Value adaptedInvMass = invMass;
639 auto stepSize = mcmcOp.getStepSize();
641 auto inputs = mcmcOp.getInputs();
642 if (inputs.empty()) {
643 mcmcOp.emitError(
"Impulse: MCMC requires at least rng_state input");
647 auto rngInput = inputs[0];
649 int64_t positionSize;
650 SmallVector<Value> fnInputs;
651 SmallVector<Type> fnResultTypes;
653 ArrayAttr selection, allAddresses;
654 SmallVector<SupportInfo> supports;
655 FlatSymbolRefAttr logpdfFnAttr;
658 logpdfFnAttr = mcmcOp.getLogpdfFnAttr();
659 fnInputs.assign(inputs.begin() + 1, inputs.end());
660 auto initialPos = mcmcOp.getInitialPosition();
661 auto initPosType = cast<RankedTensorType>(initialPos.getType());
662 positionSize = initPosType.getNumElements();
663 selection = mcmcOp.getSelectionAttr();
664 allAddresses = mcmcOp.getAllAddressesAttr();
666 fnInputs.assign(inputs.begin() + 1, inputs.end());
667 originalTrace = mcmcOp.getOriginalTrace();
668 selection = mcmcOp.getSelectionAttr();
669 allAddresses = mcmcOp.getAllAddressesAttr();
671 auto fn = cast<FunctionOpInterface>(
672 symbolTable.lookupNearestSymbolFrom(mcmcOp, mcmcOp.getFnAttr()));
674 computePositionSizeForSelection(mcmcOp, fn, selection, symbolTable);
675 if (positionSize <= 0)
678 supports = collectSupportInfoForSelection(mcmcOp, fn, selection,
679 allAddresses, symbolTable);
681 auto fnType = cast<FunctionType>(fn.getFunctionType());
682 fnResultTypes.assign(fnType.getResults().begin(),
683 fnType.getResults().end());
686 int64_t numSamples = mcmcOp.getNumSamples();
687 int64_t thinning = mcmcOp.getThinning();
688 int64_t numWarmup = mcmcOp.getNumWarmup();
691 cast<RankedTensorType>(stepSize.getType()).getElementType();
692 auto positionType = RankedTensorType::get({1, positionSize}, elemType);
693 auto scalarType = RankedTensorType::get({}, elemType);
694 auto i64TensorType = RankedTensorType::get({}, rewriter.getI64Type());
695 auto i1TensorType = RankedTensorType::get({}, rewriter.getI1Type());
698 Value trajectoryLength;
699 Value maxDeltaEnergy;
700 int64_t maxTreeDepth = 0;
702 bool adaptStepSize =
false;
703 bool adaptMassMatrix =
false;
704 auto F64TensorType = RankedTensorType::get({}, rewriter.getF64Type());
706 auto hmcConfig = mcmcOp.getHmcConfig().value();
707 double length = hmcConfig.getTrajectoryLength().getValueAsDouble();
708 trajectoryLength = arith::ConstantOp::create(
709 rewriter, loc, F64TensorType,
710 DenseElementsAttr::get(F64TensorType,
711 rewriter.getF64FloatAttr(length)));
712 adaptStepSize = hmcConfig.getAdaptStepSize();
713 adaptMassMatrix = hmcConfig.getAdaptMassMatrix();
715 auto nutsConfig = mcmcOp.getNutsConfig().value();
716 maxTreeDepth = nutsConfig.getMaxTreeDepth();
717 adaptStepSize = nutsConfig.getAdaptStepSize();
718 adaptMassMatrix = nutsConfig.getAdaptMassMatrix();
719 double maxDeltaEnergyVal =
720 nutsConfig.getMaxDeltaEnergy()
721 ? nutsConfig.getMaxDeltaEnergy().getValueAsDouble()
723 maxDeltaEnergy = arith::ConstantOp::create(
724 rewriter, loc, F64TensorType,
725 DenseElementsAttr::get(
726 F64TensorType, rewriter.getF64FloatAttr(maxDeltaEnergyVal)));
729 bool diagonal =
true;
731 auto invMassType = cast<RankedTensorType>(invMass.getType());
732 diagonal = (invMassType.getRank() == 1);
735 auto adaptedMassMatrixSqrt =
738 auto autodiffAttrs = mcmcOp.getAutodiffAttrsAttr();
740 auto makeHMCContext = [&](Value currentInvMass,
741 Value currentMassMatrixSqrt,
744 return HMCContext(logpdfFnAttr, fnInputs, currentInvMass,
745 currentMassMatrixSqrt, currentStepSize,
746 trajectoryLength, positionSize, autodiffAttrs);
748 return HMCContext(mcmcOp.getFnAttr(), fnInputs, fnResultTypes,
749 originalTrace, selection, allAddresses,
750 currentInvMass, currentMassMatrixSqrt,
751 currentStepSize, trajectoryLength, positionSize,
752 supports, autodiffAttrs);
756 auto makeNUTSContext =
757 [&](Value currentInvMass, Value currentMassMatrixSqrt,
760 return NUTSContext(logpdfFnAttr, fnInputs, currentInvMass,
761 currentMassMatrixSqrt, currentStepSize,
762 positionSize, U, maxDeltaEnergy, maxTreeDepth,
765 return NUTSContext(mcmcOp.getFnAttr(), fnInputs, fnResultTypes,
766 originalTrace, selection, allAddresses,
767 currentInvMass, currentMassMatrixSqrt,
768 currentStepSize, positionSize, supports, U,
769 maxDeltaEnergy, maxTreeDepth, autodiffAttrs);
773 Value currentQ, currentGrad, currentU, currentRng;
775 auto initialGrad = mcmcOp.getInitialGradient();
776 auto initialPE = mcmcOp.getInitialPotentialEnergy();
778 if (hasLogpdfFn && initialGrad && initialPE) {
779 currentQ = mcmcOp.getInitialPosition();
780 currentGrad = initialGrad;
781 currentU = initialPE;
782 currentRng = rngInput;
785 makeHMCContext(adaptedInvMass, adaptedMassMatrixSqrt, stepSize);
787 rewriter, loc, rngInput, baseCtx,
788 hasLogpdfFn ? mcmcOp.getInitialPosition() : Value(), debugDump);
789 currentQ = initState.q0;
790 currentGrad = initState.grad0;
791 currentU = initState.U0;
792 currentRng = initState.rng;
795 auto runSampleStepWithStepSize =
796 [&](OpBuilder &builder, Location loc, Value q, Value grad, Value U,
799 auto ctx = makeHMCContext(adaptedInvMass, adaptedMassMatrixSqrt,
801 return SampleHMC(builder, loc, q, grad, U, rng, ctx, debugDump);
803 auto nutsCtx = makeNUTSContext(adaptedInvMass, adaptedMassMatrixSqrt,
805 return SampleNUTS(builder, loc, q, grad, U, rng, nutsCtx, debugDump);
809 Value adaptedStepSize = stepSize;
811 auto runSampleStepWithInvMass =
812 [&](OpBuilder &builder, Location loc, Value q, Value grad, Value U,
813 Value rng, Value currentStepSize, Value currentInvMass,
816 auto ctx = makeHMCContext(currentInvMass, currentMassMatrixSqrt,
818 return SampleHMC(builder, loc, q, grad, U, rng, ctx, debugDump);
820 auto nutsCtx = makeNUTSContext(currentInvMass, currentMassMatrixSqrt,
822 return SampleNUTS(builder, loc, q, grad, U, rng, nutsCtx, debugDump);
826 if (!adaptedInvMass) {
827 adaptedInvMass = arith::ConstantOp::create(
828 rewriter, loc, positionType,
829 DenseElementsAttr::get(positionType,
830 rewriter.getFloatAttr(elemType, 1.0)));
831 adaptedMassMatrixSqrt = arith::ConstantOp::create(
832 rewriter, loc, positionType,
833 DenseElementsAttr::get(positionType,
834 rewriter.getFloatAttr(elemType, 1.0)));
838 auto c0 = arith::ConstantOp::create(
839 rewriter, loc, i64TensorType,
840 DenseElementsAttr::get(i64TensorType,
841 rewriter.getI64IntegerAttr(0)));
842 auto c1 = arith::ConstantOp::create(
843 rewriter, loc, i64TensorType,
844 DenseElementsAttr::get(i64TensorType,
845 rewriter.getI64IntegerAttr(1)));
846 auto numWarmupConst = arith::ConstantOp::create(
847 rewriter, loc, i64TensorType,
848 DenseElementsAttr::get(i64TensorType,
849 rewriter.getI64IntegerAttr(numWarmup)));
852 int64_t numWindows =
static_cast<int64_t
>(schedule.size());
854 SmallVector<Value> windowEndConstants;
855 for (
const auto &window : schedule) {
856 windowEndConstants.push_back(arith::ConstantOp::create(
857 rewriter, loc, i64TensorType,
858 DenseElementsAttr::get(i64TensorType,
859 rewriter.getI64IntegerAttr(window.end))));
862 auto numWindowsMinusOne = arith::ConstantOp::create(
863 rewriter, loc, i64TensorType,
864 DenseElementsAttr::get(i64TensorType,
865 rewriter.getI64IntegerAttr(numWindows - 1)));
866 auto lastIterConst = arith::ConstantOp::create(
867 rewriter, loc, i64TensorType,
868 DenseElementsAttr::get(i64TensorType,
869 rewriter.getI64IntegerAttr(numWarmup - 1)));
871 if (!adaptedInvMass) {
872 adaptedInvMass = arith::ConstantOp::create(
873 rewriter, loc, positionType,
874 DenseElementsAttr::get(positionType,
875 rewriter.getFloatAttr(elemType, 1.0)));
876 adaptedMassMatrixSqrt = arith::ConstantOp::create(
877 rewriter, loc, positionType,
878 DenseElementsAttr::get(positionType,
879 rewriter.getFloatAttr(elemType, 1.0)));
882 Value initialStepSize = stepSize;
885 "MCMC: initial step size before warmup", debugDump);
891 if (adaptMassMatrix) {
892 welfordState =
initWelford(rewriter, loc, positionSize, diagonal);
897 Value windowIdx = arith::ConstantOp::create(
898 rewriter, loc, i64TensorType,
899 DenseElementsAttr::get(i64TensorType,
900 rewriter.getI64IntegerAttr(0)));
905 SmallVector<Type> warmupLoopTypes = {positionType,
908 currentRng.getType(),
910 adaptedInvMass.getType(),
911 adaptedMassMatrixSqrt.getType()};
913 warmupLoopTypes.push_back(t);
914 if (adaptMassMatrix) {
915 for (Type t : welfordState.
getTypes())
916 warmupLoopTypes.push_back(t);
918 warmupLoopTypes.push_back(i64TensorType);
920 SmallVector<Value> warmupInitArgs = {currentQ,
926 adaptedMassMatrixSqrt};
928 warmupInitArgs.push_back(v);
929 if (adaptMassMatrix) {
930 for (Value v : welfordState.
toValues())
931 warmupInitArgs.push_back(v);
933 warmupInitArgs.push_back(windowIdx);
936 impulse::ForOp::create(rewriter, loc, warmupLoopTypes, c0,
937 numWarmupConst, c1, warmupInitArgs);
939 Block *warmupBody = rewriter.createBlock(&warmupLoop.getRegion());
940 warmupBody->addArgument(i64TensorType, loc);
941 for (Type t : warmupLoopTypes)
942 warmupBody->addArgument(t, loc);
944 rewriter.setInsertionPointToStart(warmupBody);
946 Value iterT = warmupBody->getArgument(0);
947 Value qLoop = warmupBody->getArgument(1);
948 Value gradLoop = warmupBody->getArgument(2);
949 Value ULoop = warmupBody->getArgument(3);
950 Value rngLoop = warmupBody->getArgument(4);
951 Value stepSizeLoop = warmupBody->getArgument(5);
952 Value invMassLoop = warmupBody->getArgument(6);
953 Value massMatrixSqrtLoop = warmupBody->getArgument(7);
955 SmallVector<Value> daStateLoopValues;
956 for (
int i = 0; i < 5; ++i)
957 daStateLoopValues.push_back(warmupBody->getArgument(8 + i));
962 if (adaptMassMatrix) {
963 SmallVector<Value> welfordStateLoopValues;
964 for (
int i = 0; i < 3; ++i)
965 welfordStateLoopValues.push_back(warmupBody->getArgument(13 + i));
967 windowIdxLoop = warmupBody->getArgument(16);
969 windowIdxLoop = warmupBody->getArgument(13);
972 auto sample = runSampleStepWithInvMass(rewriter, loc, qLoop, gradLoop,
973 ULoop, rngLoop, stepSizeLoop,
974 invMassLoop, massMatrixSqrtLoop);
979 Value currentStepSizeFromDA;
980 Value finalStepSizeFromDA;
984 sample.accept_prob, daConfig);
985 currentStepSizeFromDA =
987 finalStepSizeFromDA =
990 updatedDaState = daStateLoop;
991 currentStepSizeFromDA = stepSizeLoop;
992 finalStepSizeFromDA = stepSizeLoop;
996 auto isLastIter = arith::CmpIOp::create(
997 rewriter, loc, arith::CmpIPredicate::eq, iterT, lastIterConst);
998 Value adaptedStepSizeInLoop = impulse::SelectOp::create(
999 rewriter, loc, scalarType, isLastIter, finalStepSizeFromDA,
1000 currentStepSizeFromDA);
1002 const auto &floatSemantics =
1003 cast<FloatType>(elemType).getFloatSemantics();
1004 auto tinyConst = arith::ConstantOp::create(
1005 rewriter, loc, scalarType,
1006 DenseElementsAttr::get(
1007 scalarType, FloatAttr::get(elemType, llvm::APFloat::getSmallest(
1009 auto maxConst = arith::ConstantOp::create(
1010 rewriter, loc, scalarType,
1011 DenseElementsAttr::get(
1012 scalarType, FloatAttr::get(elemType, llvm::APFloat::getLargest(
1014 adaptedStepSizeInLoop = arith::MaximumFOp::create(
1015 rewriter, loc, adaptedStepSizeInLoop, tinyConst);
1016 adaptedStepSizeInLoop = arith::MinimumFOp::create(
1017 rewriter, loc, adaptedStepSizeInLoop, maxConst);
1019 auto windowIdxGtZero = arith::CmpIOp::create(
1020 rewriter, loc, arith::CmpIPredicate::sgt, windowIdxLoop, c0);
1021 auto windowIdxLtLast =
1022 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1023 windowIdxLoop, numWindowsMinusOne);
1024 auto isMiddleWindow = arith::AndIOp::create(
1025 rewriter, loc, windowIdxGtZero, windowIdxLtLast);
1029 if (adaptMassMatrix) {
1030 auto sampleType1D = RankedTensorType::get({positionSize}, elemType);
1032 impulse::ReshapeOp::create(rewriter, loc, sampleType1D, sample.q);
1034 rewriter, loc, welfordStateLoop, sample1D, welfordConfig);
1036 conditionalWelford.
mean = impulse::SelectOp::create(
1037 rewriter, loc, welfordStateLoop.
mean.getType(), isMiddleWindow,
1038 updatedWelfordAfterSample.
mean, welfordStateLoop.
mean);
1039 conditionalWelford.
m2 = impulse::SelectOp::create(
1040 rewriter, loc, welfordStateLoop.
m2.getType(), isMiddleWindow,
1041 updatedWelfordAfterSample.
m2, welfordStateLoop.
m2);
1042 conditionalWelford.
n = impulse::SelectOp::create(
1043 rewriter, loc, welfordStateLoop.
n.getType(), isMiddleWindow,
1044 updatedWelfordAfterSample.
n, welfordStateLoop.
n);
1047 Value atWindowEnd = arith::ConstantOp::create(
1048 rewriter, loc, i1TensorType,
1049 DenseElementsAttr::get(i1TensorType, rewriter.getBoolAttr(
false)));
1051 for (int64_t w = 0; w < numWindows; ++w) {
1052 auto windowIdxIsW = arith::CmpIOp::create(
1053 rewriter, loc, arith::CmpIPredicate::eq, windowIdxLoop,
1054 arith::ConstantOp::create(
1055 rewriter, loc, i64TensorType,
1056 DenseElementsAttr::get(i64TensorType,
1057 rewriter.getI64IntegerAttr(w))));
1058 auto tEqualsWindowEnd =
1059 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1060 iterT, windowEndConstants[w]);
1061 auto matchesThisWindow = arith::AndIOp::create(
1062 rewriter, loc, windowIdxIsW, tEqualsWindowEnd);
1063 atWindowEnd = arith::OrIOp::create(rewriter, loc, atWindowEnd,
1067 Value newWindowIdx =
1068 arith::AddIOp::create(rewriter, loc, windowIdxLoop, c1);
1069 Value windowIdxAfterIncrement =
1070 impulse::SelectOp::create(rewriter, loc, i64TensorType, atWindowEnd,
1071 newWindowIdx, windowIdxLoop);
1073 auto atMiddleWindowEnd =
1074 arith::AndIOp::create(rewriter, loc, atWindowEnd, isMiddleWindow);
1077 Value finalMassMatrixSqrt;
1079 Value finalStepSizeValue;
1082 SmallVector<Type> ifResultTypes;
1083 ifResultTypes.push_back(invMassLoop.getType());
1084 ifResultTypes.push_back(massMatrixSqrtLoop.getType());
1085 if (adaptMassMatrix) {
1086 ifResultTypes.push_back(conditionalWelford.
mean.getType());
1087 ifResultTypes.push_back(conditionalWelford.
m2.getType());
1088 ifResultTypes.push_back(conditionalWelford.
n.getType());
1090 for (Type t : updatedDaState.
getTypes())
1091 ifResultTypes.push_back(t);
1093 auto ifOp = impulse::IfOp::create(rewriter, loc, ifResultTypes,
1097 Block *trueBranch = rewriter.createBlock(&ifOp.getTrueBranch());
1098 rewriter.setInsertionPointToStart(trueBranch);
1100 SmallVector<Value> trueYieldValues;
1102 if (adaptMassMatrix) {
1105 auto newMassMatrixSqrt =
1107 auto reinitWelford =
1108 initWelford(rewriter, loc, positionSize, diagonal);
1110 trueYieldValues.push_back(newInvMass);
1111 trueYieldValues.push_back(newMassMatrixSqrt);
1112 trueYieldValues.push_back(reinitWelford.mean);
1113 trueYieldValues.push_back(reinitWelford.m2);
1114 trueYieldValues.push_back(reinitWelford.n);
1116 trueYieldValues.push_back(invMassLoop);
1117 trueYieldValues.push_back(massMatrixSqrtLoop);
1120 if (adaptStepSize) {
1121 auto reinitDaState =
1123 for (
auto v : reinitDaState.toValues())
1124 trueYieldValues.push_back(v);
1126 for (
auto v : updatedDaState.
toValues())
1127 trueYieldValues.push_back(v);
1130 impulse::YieldOp::create(rewriter, loc, trueYieldValues);
1134 Block *falseBranch = rewriter.createBlock(&ifOp.getFalseBranch());
1135 rewriter.setInsertionPointToStart(falseBranch);
1137 SmallVector<Value> falseYieldValues;
1138 falseYieldValues.push_back(invMassLoop);
1139 falseYieldValues.push_back(massMatrixSqrtLoop);
1140 if (adaptMassMatrix) {
1141 falseYieldValues.push_back(conditionalWelford.
mean);
1142 falseYieldValues.push_back(conditionalWelford.
m2);
1143 falseYieldValues.push_back(conditionalWelford.
n);
1145 for (
auto v : updatedDaState.
toValues())
1146 falseYieldValues.push_back(v);
1148 impulse::YieldOp::create(rewriter, loc, falseYieldValues);
1151 rewriter.setInsertionPointAfter(ifOp);
1153 size_t resultIdx = 0;
1154 finalInvMass = ifOp.getResult(resultIdx++);
1155 finalMassMatrixSqrt = ifOp.getResult(resultIdx++);
1156 if (adaptMassMatrix) {
1157 finalWelfordState.
mean = ifOp.getResult(resultIdx++);
1158 finalWelfordState.
m2 = ifOp.getResult(resultIdx++);
1159 finalWelfordState.
n = ifOp.getResult(resultIdx++);
1163 finalDaState.
gradient_avg = ifOp.getResult(resultIdx++);
1164 finalDaState.
step_count = ifOp.getResult(resultIdx++);
1165 finalDaState.
prox_center = ifOp.getResult(resultIdx++);
1167 finalStepSizeValue = adaptedStepSizeInLoop;
1169 SmallVector<Value> warmupYieldValues = {
1170 sample.q, sample.grad, sample.U, sample.rng,
1171 finalStepSizeValue, finalInvMass, finalMassMatrixSqrt};
1172 for (Value v : finalDaState.
toValues())
1173 warmupYieldValues.push_back(v);
1174 if (adaptMassMatrix) {
1175 for (Value v : finalWelfordState.
toValues())
1176 warmupYieldValues.push_back(v);
1178 warmupYieldValues.push_back(windowIdxAfterIncrement);
1180 impulse::YieldOp::create(rewriter, loc, warmupYieldValues);
1182 rewriter.setInsertionPointAfter(warmupLoop);
1184 currentQ = warmupLoop.getResult(0);
1185 currentGrad = warmupLoop.getResult(1);
1186 currentU = warmupLoop.getResult(2);
1187 currentRng = warmupLoop.getResult(3);
1188 adaptedStepSize = warmupLoop.getResult(4);
1189 adaptedInvMass = warmupLoop.getResult(5);
1190 adaptedMassMatrixSqrt = warmupLoop.getResult(6);
1194 "MCMC: adapted step size after warmup", debugDump);
1195 if (adaptMassMatrix) {
1197 rewriter, loc, adaptedInvMass,
1198 "MCMC: adapted inverse mass matrix after warmup", debugDump);
1202 int64_t collectionSize = numSamples / thinning;
1203 int64_t startIdx = numSamples % thinning;
1205 auto samplesBufferType =
1206 RankedTensorType::get({collectionSize, positionSize}, elemType);
1207 auto acceptedBufferType =
1208 RankedTensorType::get({collectionSize}, rewriter.getI1Type());
1210 auto samplesBuffer = arith::ConstantOp::create(
1211 rewriter, loc, samplesBufferType,
1212 DenseElementsAttr::get(samplesBufferType,
1213 rewriter.getFloatAttr(elemType, 0.0)));
1214 auto acceptedBuffer = arith::ConstantOp::create(
1215 rewriter, loc, acceptedBufferType,
1216 DenseElementsAttr::get(acceptedBufferType,
1217 rewriter.getBoolAttr(isNUTS)));
1219 auto c0 = arith::ConstantOp::create(
1220 rewriter, loc, i64TensorType,
1221 DenseElementsAttr::get(i64TensorType, rewriter.getI64IntegerAttr(0)));
1222 auto c1 = arith::ConstantOp::create(
1223 rewriter, loc, i64TensorType,
1224 DenseElementsAttr::get(i64TensorType, rewriter.getI64IntegerAttr(1)));
1225 auto numSamplesConst = arith::ConstantOp::create(
1226 rewriter, loc, i64TensorType,
1227 DenseElementsAttr::get(i64TensorType,
1228 rewriter.getI64IntegerAttr(numSamples)));
1229 auto startIdxConst = arith::ConstantOp::create(
1230 rewriter, loc, i64TensorType,
1231 DenseElementsAttr::get(i64TensorType,
1232 rewriter.getI64IntegerAttr(startIdx)));
1233 auto thinningConst = arith::ConstantOp::create(
1234 rewriter, loc, i64TensorType,
1235 DenseElementsAttr::get(i64TensorType,
1236 rewriter.getI64IntegerAttr(thinning)));
1239 SmallVector<Type> loopResultTypes = {
1240 positionType, positionType, scalarType,
1241 currentRng.getType(), samplesBufferType, acceptedBufferType};
1242 auto forLoopOp = impulse::ForOp::create(
1243 rewriter, loc, loopResultTypes, c0, numSamplesConst, c1,
1244 ValueRange{currentQ, currentGrad, currentU, currentRng, samplesBuffer,
1247 Block *loopBody = rewriter.createBlock(&forLoopOp.getRegion());
1248 loopBody->addArgument(i64TensorType, loc);
1249 loopBody->addArgument(positionType, loc);
1250 loopBody->addArgument(positionType, loc);
1251 loopBody->addArgument(scalarType, loc);
1252 loopBody->addArgument(currentRng.getType(), loc);
1253 loopBody->addArgument(samplesBufferType, loc);
1254 loopBody->addArgument(acceptedBufferType, loc);
1256 rewriter.setInsertionPointToStart(loopBody);
1257 Value iterIdx = loopBody->getArgument(0);
1258 Value qLoop = loopBody->getArgument(1);
1259 Value gradLoop = loopBody->getArgument(2);
1260 Value ULoop = loopBody->getArgument(3);
1261 Value rngLoop = loopBody->getArgument(4);
1262 Value samplesBufferLoop = loopBody->getArgument(5);
1263 Value acceptedBufferLoop = loopBody->getArgument(6);
1265 auto sample = runSampleStepWithStepSize(rewriter, loc, qLoop, gradLoop,
1266 ULoop, rngLoop, adaptedStepSize);
1267 auto q_constrained =
1272 arith::SubIOp::create(rewriter, loc, iterIdx, startIdxConst);
1274 arith::DivSIOp::create(rewriter, loc, iMinusStart, thinningConst);
1278 auto geStartIdx = arith::CmpIOp::create(
1279 rewriter, loc, arith::CmpIPredicate::sge, iterIdx, startIdxConst);
1281 arith::RemSIOp::create(rewriter, loc, iMinusStart, thinningConst);
1282 auto modIsZero = arith::CmpIOp::create(
1283 rewriter, loc, arith::CmpIPredicate::eq, modThinning, c0);
1285 arith::AndIOp::create(rewriter, loc, geStartIdx, modIsZero);
1287 auto zeroCol = arith::ConstantOp::create(
1288 rewriter, loc, i64TensorType,
1289 DenseElementsAttr::get(i64TensorType, rewriter.getI64IntegerAttr(0)));
1290 auto updatedSamplesBuffer = impulse::DynamicUpdateSliceOp::create(
1291 rewriter, loc, samplesBufferType, samplesBufferLoop, q_constrained,
1292 ValueRange{storageIdx, zeroCol});
1293 auto selectedSamplesBuffer = impulse::SelectOp::create(
1294 rewriter, loc, samplesBufferType, shouldStore, updatedSamplesBuffer,
1297 auto accepted1D = impulse::ReshapeOp::create(
1298 rewriter, loc, RankedTensorType::get({1}, rewriter.getI1Type()),
1300 auto updatedAcceptedBuffer = impulse::DynamicUpdateSliceOp::create(
1301 rewriter, loc, acceptedBufferType, acceptedBufferLoop, accepted1D,
1302 ValueRange{storageIdx});
1303 auto selectedAcceptedBuffer = impulse::SelectOp::create(
1304 rewriter, loc, acceptedBufferType, shouldStore, updatedAcceptedBuffer,
1305 acceptedBufferLoop);
1307 impulse::YieldOp::create(rewriter, loc,
1308 ValueRange{sample.q, sample.grad, sample.U,
1309 sample.rng, selectedSamplesBuffer,
1310 selectedAcceptedBuffer});
1312 rewriter.setInsertionPointAfter(forLoopOp);
1313 Value finalQ = forLoopOp.getResult(0);
1314 Value finalGrad = forLoopOp.getResult(1);
1315 Value finalU = forLoopOp.getResult(2);
1316 Value finalRng = forLoopOp.getResult(3);
1317 Value finalSamplesBuffer = forLoopOp.getResult(4);
1318 Value finalAcceptedBuffer = forLoopOp.getResult(5);
1320 finalSamplesBuffer =
1322 "MCMC: collected samples", debugDump);
1324 rewriter.replaceOp(mcmcOp, {finalSamplesBuffer, finalAcceptedBuffer,
1325 finalRng, finalQ, finalGrad, finalU,
1326 adaptedStepSize, adaptedInvMass});
// NOTE(review): this block is a lossy source extraction — original file line
// numbers are embedded in the text and several statements lost tokens (e.g.
// the LHS of the `arith::SubFOp` below, presumably `Value logAlpha =`, and the
// struct's closing braces / `return success();`). Comments below describe only
// what the visible code establishes; confirm against the original file.
//
// Lowers `impulse.mh` (one Metropolis-Hastings step) into:
//   1. an `impulse.regenerate` of the target generative function, and
//   2. an accept/reject decision: accept iff log(u) < newWeight - oldWeight
//      for u ~ Uniform(0, 1), selecting between new and old trace/weight.
// Results replaced on the op: {selectedTrace, selectedWeight, accepted, rng}.
1332 struct LowerMHPattern :
public mlir::OpRewritePattern<impulse::MHOp> {
1335 LogicalResult matchAndRewrite(impulse::MHOp mhOp,
1336 PatternRewriter &rewriter)
const override {
1337 SymbolTableCollection symbolTable;
// Resolve the generative function referenced by the op's fn attribute.
1339 auto fn = cast<FunctionOpInterface>(
1340 symbolTable.lookupNearestSymbolFrom(mhOp, mhOp.getFnAttr()));
// `mh` is only meaningful on generative functions with a body; distribution
// functions must be marked with a logpdf attribute instead.
1342 if (fn.getFunctionBody().empty()) {
1344 "Impulse: calling `mh` on an empty function; if this is a "
1345 "distribution function, its sample op should have a logpdf "
1346 "attribute to avoid recursive `mh` calls which is intended for "
1347 "generative functions");
1351 auto loc = mhOp.getLoc();
// Operand layout: 0 = previous trace, 1 = previous weight, 2.. = inputs
// (inputs[0] is the RNG state, per the getType() use below).
1353 Value oldTrace = mhOp.getOperand(0);
1354 Value oldWeight = mhOp.getOperand(1);
1355 SmallVector<Value> inputs;
1356 for (
unsigned i = 2; i < mhOp.getNumOperands(); ++i)
1357 inputs.push_back(mhOp.getOperand(i));
1358 auto selection = mhOp.getSelectionAttr();
1361 auto weightType = cast<RankedTensorType>(oldWeight.getType());
1362 auto rngStateType = inputs[0].getType();
// Default the name attribute to the empty string when absent.
1365 auto nameAttr = mhOp.getNameAttr();
1367 nameAttr = rewriter.getStringAttr(
"");
1369 auto regenerateAddresses = mhOp.getRegenerateAddressesAttr();
// Regenerate result types: weight plus the callee's own results.
1371 SmallVector<Type> regenResultTypes;
1373 regenResultTypes.push_back(weightType);
1374 for (
auto t : fn.getResultTypes())
1375 regenResultTypes.push_back(t);
1377 auto regenerateOp = rewriter.create<impulse::RegenerateOp>(
1384 regenerateAddresses,
1387 Value newTrace = regenerateOp.getNewTrace();
1388 Value newWeight = regenerateOp.getWeight();
1389 Value newRng = regenerateOp.getOutputs()[0];
// MH log acceptance ratio: newWeight - oldWeight (result binding was lost
// in extraction; used as `logAlpha` below — TODO confirm).
1393 arith::SubFOp::create(rewriter, loc, newWeight, oldWeight);
// Bounds for the uniform draw below.
1396 auto zeroConst = arith::ConstantOp::create(
1397 rewriter, loc, weightType, DenseElementsAttr::get(weightType, 0.0));
1398 auto oneConst = arith::ConstantOp::create(
1399 rewriter, loc, weightType, DenseElementsAttr::get(weightType, 1.0));
// Draw u ~ Uniform(0, 1), threading the RNG state through.
1401 auto randomOp = impulse::RandomOp::create(
1402 rewriter, loc, TypeRange{rngStateType, weightType}, newRng, zeroConst,
1404 impulse::RngDistributionAttr::get(rewriter.getContext(),
1405 impulse::RngDistribution::UNIFORM));
1406 auto logRand = math::LogOp::create(rewriter, loc, randomOp.getResult());
1407 Value finalRng = randomOp.getOutputRngState();
// Accept iff log(u) < logAlpha (ordered less-than on floats).
1410 auto accepted = arith::CmpFOp::create(
1411 rewriter, loc, arith::CmpFPredicate::OLT, logRand, logAlpha);
// Select new vs. old trace/weight based on the acceptance flag.
1414 auto selectedTrace = impulse::SelectOp::create(
1415 rewriter, loc,
traceType, accepted, newTrace, oldTrace);
1416 auto selectedWeight = arith::SelectOp::create(rewriter, loc, accepted,
1417 newWeight, oldWeight);
1419 rewriter.replaceOp(mhOp,
1420 {selectedTrace, selectedWeight, accepted, finalRng});
// NOTE(review): lossy extraction — embedded original line numbers, statements
// split mid-expression, and dropped tokens (e.g. the LHS bindings for several
// `cast<FunctionOpInterface>` lookups, `else` keywords, closing braces, and
// `return` statements). Comments describe only what the visible code shows;
// verify details against the original file.
//
// Lowers `impulse.generate`: clones the target function (via `putils`,
// presumably created by ImpulseUtils::CreateFromClone — TODO confirm), then
// walks every `impulse.sample` in the clone and expands it inline, threading
// two accumulators through the walk:
//   - `weightAccumulator`: running sum of logpdf contributions, and
//   - `currTrace`: a {1, positionSize} f64 row into which selected sample
//     values are packed at `currentTraceOffset`.
// Returns of the clone are rewritten to yield {trace, weight, originals...},
// and the original op is replaced by a call to the clone.
1425 struct LowerGeneratePattern
1426 :
public mlir::OpRewritePattern<impulse::GenerateOp> {
1429 LogicalResult matchAndRewrite(impulse::GenerateOp CI,
1430 PatternRewriter &rewriter)
const override {
1431 SymbolTableCollection symbolTable;
1433 auto fn = cast<FunctionOpInterface>(
1434 symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()));
// Generate is only valid on generative functions that have a body.
1436 if (fn.getFunctionBody().empty()) {
1438 "Impulse: calling `generate` on an empty function; if this "
1439 "is a distribution function, its sample op should have a "
1440 "logpdf attribute to avoid recursive `generate` calls which is "
1441 "intended for generative functions");
// Total flattened element count of the selected addresses (trace width).
1445 ArrayAttr selection = CI.getSelectionAttr();
1446 int64_t positionSize =
1447 computePositionSizeForSelection(CI, fn, selection, symbolTable);
1448 if (positionSize <= 0) {
1449 CI.emitError(
"Impulse: failed to compute position size for generate");
// Width of the constraint vector (may legitimately be 0; negative = error).
1453 int64_t constraintSize = computePositionSizeForSelection(
1454 CI, fn, CI.getConstrainedAddressesAttr(), symbolTable);
1455 if (constraintSize < 0) {
1456 CI.emitError(
"Impulse: failed to compute constraint size for generate");
1461 positionSize, constraintSize);
1462 FunctionOpInterface NewF = putils->newFunc;
// Emit weight/trace zero-initializers at the top of the clone's entry block.
1464 OpBuilder entryBuilder(putils->initializationBlock,
1465 putils->initializationBlock->begin());
1466 Location initLoc = putils->initializationBlock->begin()->getLoc();
1468 auto scalarType = RankedTensorType::get({}, entryBuilder.getF64Type());
1470 arith::ConstantOp::create(entryBuilder, initLoc, scalarType,
1471 DenseElementsAttr::get(scalarType, 0.0));
1472 Value weightAccumulator = zeroWeight;
1475 RankedTensorType::get({1, positionSize}, entryBuilder.getF64Type());
1477 arith::ConstantOp::create(entryBuilder, initLoc,
traceType,
1478 DenseElementsAttr::get(
traceType, 0.0));
1479 Value currTrace = zeroTrace;
// Argument 0 of the clone is the flattened constraint vector.
1480 Value constraint = NewF.getArgument(0);
1481 int64_t currentTraceOffset = 0;
// Expand each sample op in-place; erase afterwards (not during the walk).
1483 SmallVector<Operation *> toErase;
1484 auto result = NewF.walk([&](impulse::SampleOp sampleOp) -> WalkResult {
1485 OpBuilder::InsertionGuard guard(rewriter);
1486 rewriter.setInsertionPoint(sampleOp);
1488 SmallVector<Value> sampledValues;
// Presence of a logpdf attribute marks a primitive distribution.
1489 bool isDistribution =
static_cast<bool>(sampleOp.getLogpdfAttr());
1491 if (isDistribution) {
// Check whether this sample's address is constrained; composite
// (nested) addresses are invalid for a flat distribution.
1493 bool isConstrained =
false;
1494 int64_t constrainedOffset = -1;
1495 for (
auto addr : CI.getConstrainedAddressesAttr()) {
1496 auto address = cast<ArrayAttr>(addr);
1497 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
1498 if (address.size() != 1) {
1500 "Impulse: distribution function cannot have composite "
1501 "constrained address");
1502 return WalkResult::interrupt();
1504 isConstrained =
true;
1505 constrainedOffset = computeOffsetForSampleInSelection(
1506 CI, fn, CI.getConstrainedAddressesAttr(),
1507 sampleOp.getSymbolAttr(), symbolTable);
1512 if (isConstrained) {
// Constrained: slice each result out of the flat constraint vector
// and reshape back to the declared result type. Result 0 appears to
// be the RNG state pass-through (operand 0) — TODO confirm.
1514 sampledValues.resize(sampleOp.getNumResults());
1515 sampledValues[0] = sampleOp.getOperand(0);
1517 for (
unsigned i = 1; i < sampleOp.getNumResults(); ++i) {
1519 cast<RankedTensorType>(sampleOp.getResult(i).getType());
1520 int64_t numElements = computeTensorElementCount(resultType);
1521 if (numElements < 0) {
1523 "Impulse: dynamic tensor dimensions not supported");
1524 return WalkResult::interrupt();
1527 auto sliceType = RankedTensorType::get(
1528 {1, numElements}, resultType.getElementType());
1529 auto sliced = impulse::SliceOp::create(
1530 rewriter, sampleOp.getLoc(), sliceType, constraint,
1531 rewriter.getDenseI64ArrayAttr({0, constrainedOffset}),
1532 rewriter.getDenseI64ArrayAttr(
1533 {1, constrainedOffset + numElements}),
1534 rewriter.getDenseI64ArrayAttr({1, 1}));
1535 auto extracted = impulse::ReshapeOp::create(
1536 rewriter, sampleOp.getLoc(), resultType, sliced);
1537 sampledValues[i] = extracted.getResult();
1538 constrainedOffset += numElements;
// Score the constrained values: call the logpdf function with
// (values..., distribution params...) and add to the weight.
1543 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1544 sampleOp, sampleOp.getLogpdfAttr()));
1547 SmallVector<Value> logpdfOperands;
1548 for (
unsigned i = 1; i < sampledValues.size(); ++i) {
1549 logpdfOperands.push_back(sampledValues[i]);
1551 for (
unsigned i = 1; i < sampleOp.getNumOperands(); ++i) {
1552 logpdfOperands.push_back(sampleOp.getOperand(i));
1555 if (logpdfOperands.size() != logpdfFn.getNumArguments()) {
1557 "Impulse: failed to construct logpdf call for constrained "
1558 "sample; logpdf function has wrong number of arguments");
1559 return WalkResult::interrupt();
1562 auto logpdf = func::CallOp::create(
1563 rewriter, sampleOp.getLoc(), logpdfFn.getName(),
1564 logpdfFn.getResultTypes(), logpdfOperands);
1566 arith::AddFOp::create(rewriter, sampleOp.getLoc(),
1567 weightAccumulator, logpdf.getResult(0));
// Unconstrained: actually sample by calling the distribution's
// sampler function, then score the drawn values with its logpdf.
1571 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1572 sampleOp, sampleOp.getFnAttr()));
1574 auto distCall = func::CallOp::create(
1575 rewriter, sampleOp.getLoc(), distFn.getName(),
1576 distFn.getResultTypes(), sampleOp.getInputs());
1578 sampledValues.append(distCall.getResults().begin(),
1579 distCall.getResults().end());
1582 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1583 sampleOp, sampleOp.getLogpdfAttr()));
1585 SmallVector<Value> logpdfOperands;
1586 for (
unsigned i = 1; i < sampledValues.size(); ++i) {
1587 logpdfOperands.push_back(sampledValues[i]);
1589 for (
unsigned i = 1; i < sampleOp.getNumOperands(); ++i) {
1590 logpdfOperands.push_back(sampleOp.getOperand(i));
1593 if (logpdfOperands.size() != logpdfFn.getNumArguments()) {
1595 "Impulse: failed to construct logpdf call; "
1596 "logpdf function has wrong number of arguments");
1597 return WalkResult::interrupt();
1600 auto logpdf = func::CallOp::create(
1601 rewriter, sampleOp.getLoc(), logpdfFn.getName(),
1602 logpdfFn.getResultTypes(), logpdfOperands);
1604 arith::AddFOp::create(rewriter, sampleOp.getLoc(),
1605 weightAccumulator, logpdf.getResult(0));
// If this address is in the selection, pack the sampled values into
// the trace row at the running column offset.
1608 bool inSelection =
false;
1609 for (
auto addr : selection) {
1610 auto address = cast<ArrayAttr>(addr);
1611 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
1618 for (
unsigned i = 1; i < sampledValues.size(); ++i) {
1619 auto sampleValue = sampledValues[i];
1620 auto sampleType = cast<RankedTensorType>(sampleValue.getType());
1621 int64_t numElements = computeTensorElementCount(sampleType);
1622 if (numElements < 0) {
1624 "Impulse: dynamic tensor dimensions not supported");
1625 return WalkResult::interrupt();
// Flatten to {1, numElements} and dynamic-update-slice into the
// trace at (row 0, col currentTraceOffset).
1628 auto flatSampleType = RankedTensorType::get(
1629 {1, numElements}, sampleType.getElementType());
1630 auto flatSample = impulse::ReshapeOp::create(
1631 rewriter, sampleOp.getLoc(), flatSampleType, sampleValue);
1632 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
1633 auto row0 = arith::ConstantOp::create(
1634 rewriter, sampleOp.getLoc(), i64S,
1635 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
1636 auto colOff = arith::ConstantOp::create(
1637 rewriter, sampleOp.getLoc(), i64S,
1638 DenseElementsAttr::get(
1639 i64S, rewriter.getI64IntegerAttr(currentTraceOffset)));
1640 currTrace = impulse::DynamicUpdateSliceOp::create(
1641 rewriter, sampleOp.getLoc(),
traceType, currTrace,
1642 flatSample, ValueRange{row0, colOff})
1644 currentTraceOffset += numElements;
// Non-distribution sample: the target is itself a generative function.
1650 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1651 sampleOp, sampleOp.getFnAttr()));
1653 if (genFn.getFunctionBody().empty()) {
1655 "Impulse: generative function body is empty; "
1656 "if this is a distribution, add a logpdf attribute");
1657 return WalkResult::interrupt();
// Project the selection/constraints down to this sample's sub-namespace.
1660 ArrayAttr subSelection =
1661 buildSubSelection(rewriter, selection, sampleOp.getSymbolAttr());
1662 ArrayAttr subConstrainedAddrs =
1663 buildSubSelection(rewriter, CI.getConstrainedAddressesAttr(),
1664 sampleOp.getSymbolAttr());
1666 if (subSelection.empty()) {
// Nothing selected beneath this address: plain untraced call.
1669 auto genCall = func::CallOp::create(
1670 rewriter, sampleOp.getLoc(), genFn.getName(),
1671 genFn.getResultTypes(), sampleOp.getInputs());
1672 sampledValues.append(genCall.getResults().begin(),
1673 genCall.getResults().end());
// Otherwise recurse: emit a nested impulse.generate for the callee.
1675 int64_t subPositionSize = computePositionSizeForSelection(
1676 sampleOp, genFn, subSelection, symbolTable);
1677 int64_t subConstraintSize = computePositionSizeForSelection(
1678 sampleOp, genFn, subConstrainedAddrs, symbolTable);
1679 if (subPositionSize <= 0 || subConstraintSize < 0) {
1680 sampleOp.emitError(
"Impulse: failed to compute sub-position or "
1681 "sub-constraint size");
1682 return WalkResult::interrupt();
// Sub-constraint is a slice of this function's constraint vector
// (or a zero tensor when no constraints apply below this address).
1685 Value subConstraint;
1686 auto subConstraintType = RankedTensorType::get(
1687 {1, subConstraintSize}, rewriter.getF64Type());
1689 if (subConstraintSize > 0) {
1690 int64_t subConstraintOffset = computeOffsetForNestedSample(
1691 sampleOp, fn, CI.getConstrainedAddressesAttr(),
1692 sampleOp.getSymbolAttr(), symbolTable);
1694 subConstraint = impulse::SliceOp::create(
1695 rewriter, sampleOp.getLoc(), subConstraintType, constraint,
1696 rewriter.getDenseI64ArrayAttr({0, subConstraintOffset}),
1697 rewriter.getDenseI64ArrayAttr(
1698 {1, subConstraintOffset + subConstraintSize}),
1699 rewriter.getDenseI64ArrayAttr({1, 1}));
1701 subConstraint = arith::ConstantOp::create(
1702 rewriter, sampleOp.getLoc(), subConstraintType,
1703 DenseElementsAttr::get(subConstraintType, {0.0}));
// Nested generate returns {subTrace, subWeight, callee results...}.
1707 auto subTraceType = RankedTensorType::get({1, subPositionSize},
1708 rewriter.getF64Type());
1709 auto scalarTy = RankedTensorType::get({}, rewriter.getF64Type());
1710 SmallVector<Type> genResultTypes;
1711 genResultTypes.push_back(subTraceType);
1712 genResultTypes.push_back(scalarTy);
1713 for (
auto t : genFn.getResultTypes())
1714 genResultTypes.push_back(t);
1716 auto nestedGenerate = impulse::GenerateOp::create(
1717 rewriter, sampleOp.getLoc(), genResultTypes,
1718 sampleOp.getFnAttr(), sampleOp.getInputs(), subConstraint,
1719 subSelection, subConstrainedAddrs);
1721 Value subTrace = nestedGenerate.getTrace();
1722 Value subWeight = nestedGenerate.getWeight();
1724 weightAccumulator = arith::AddFOp::create(
1725 rewriter, sampleOp.getLoc(), weightAccumulator, subWeight);
// Merge the sub-trace into this trace at its precomputed offset.
1727 int64_t mergeOffset = computeOffsetForNestedSample(
1728 sampleOp, fn, selection, sampleOp.getSymbolAttr(), symbolTable);
1730 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
1731 auto row0 = arith::ConstantOp::create(
1732 rewriter, sampleOp.getLoc(), i64S,
1733 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
1734 auto colOff = arith::ConstantOp::create(
1735 rewriter, sampleOp.getLoc(), i64S,
1736 DenseElementsAttr::get(
1737 i64S, rewriter.getI64IntegerAttr(mergeOffset)));
1738 currTrace = impulse::DynamicUpdateSliceOp::create(
1739 rewriter, sampleOp.getLoc(),
traceType, currTrace,
1740 subTrace, ValueRange{row0, colOff})
// Offsets may interleave, so advance with max() rather than +=.
1742 currentTraceOffset =
1743 std::max(currentTraceOffset, mergeOffset + subPositionSize);
1745 for (
auto output : nestedGenerate.getOutputs())
1746 sampledValues.push_back(output);
// Replace the sample op's results and queue it for erasure.
1750 sampleOp.replaceAllUsesWith(sampledValues);
1751 toErase.push_back(sampleOp);
1752 return WalkResult::advance();
1755 for (Operation *op : toErase)
1756 rewriter.eraseOp(op);
1758 if (result.wasInterrupted()) {
1759 CI.emitError(
"Impulse: failed to walk sample ops");
// Prepend {trace, weight} to every return of the cloned function.
1764 NewF.walk([&](func::ReturnOp retOp) {
1765 OpBuilder::InsertionGuard guard(rewriter);
1766 rewriter.setInsertionPoint(retOp);
1768 SmallVector<Value> newRetVals;
1769 newRetVals.push_back(currTrace);
1770 newRetVals.push_back(weightAccumulator);
1771 newRetVals.append(retOp.getOperands().begin(),
1772 retOp.getOperands().end());
1774 func::ReturnOp::create(rewriter, retOp.getLoc(), newRetVals);
1775 rewriter.eraseOp(retOp);
// Finally, replace the generate op with a direct call to the clone,
// passing the constraint first, then the original inputs.
1778 rewriter.setInsertionPoint(CI);
1779 SmallVector<Value> operands;
1780 operands.push_back(CI.getConstraint());
1781 operands.append(CI.getInputs().begin(), CI.getInputs().end());
1782 auto newCI = func::CallOp::create(rewriter, CI.getLoc(), NewF.getName(),
1783 NewF.getResultTypes(), operands);
1785 rewriter.replaceOp(CI, newCI.getResults());
// NOTE(review): lossy extraction — embedded original line numbers, split
// statements, dropped tokens (LHS bindings, `else` keywords, closing braces,
// returns). Comments describe only what the visible code shows.
//
// Lowers `impulse.regenerate`: structural mirror of LowerGeneratePattern,
// except the source of values is the PREVIOUS trace (clone argument 0)
// instead of a constraint vector. Samples whose address is in the regenerate
// set are re-drawn from their distribution; all others are read back out of
// `prevTrace` at their selection offset. Both paths are re-scored via logpdf
// into `weightAccumulator`, and selected values are packed into `currTrace`.
1793 struct LowerRegeneratePattern
1794 :
public mlir::OpRewritePattern<impulse::RegenerateOp> {
1797 LogicalResult matchAndRewrite(impulse::RegenerateOp CI,
1798 PatternRewriter &rewriter)
const override {
1799 SymbolTableCollection symbolTable;
1801 auto fn = cast<FunctionOpInterface>(
1802 symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()));
// Regenerate only applies to generative functions with a body.
1804 if (fn.getFunctionBody().empty()) {
1806 "Impulse: calling `regenerate` on an empty function; if this "
1807 "is a distribution function, its sample op should have a "
1808 "logpdf attribute to avoid recursive `regenerate` calls which is "
1809 "intended for generative functions");
// Trace width = flattened element count of the selected addresses.
1813 ArrayAttr selection = CI.getSelectionAttr();
1814 int64_t positionSize =
1815 computePositionSizeForSelection(CI, fn, selection, symbolTable);
1816 if (positionSize <= 0) {
1817 CI.emitError(
"Impulse: failed to compute position size for regenerate");
1823 FunctionOpInterface NewF = putils->newFunc;
// Zero-init the weight scalar and the {1, positionSize} trace row at the
// top of the clone's entry block.
1825 OpBuilder entryBuilder(putils->initializationBlock,
1826 putils->initializationBlock->begin());
1827 Location initLoc = putils->initializationBlock->begin()->getLoc();
1829 auto scalarType = RankedTensorType::get({}, entryBuilder.getF64Type());
1831 arith::ConstantOp::create(entryBuilder, initLoc, scalarType,
1832 DenseElementsAttr::get(scalarType, 0.0));
1833 Value weightAccumulator = zeroWeight;
1836 RankedTensorType::get({1, positionSize}, entryBuilder.getF64Type());
1838 arith::ConstantOp::create(entryBuilder, initLoc,
traceType,
1839 DenseElementsAttr::get(
traceType, 0.0));
1840 Value currTrace = zeroTrace;
// Argument 0 of the clone is the previous (input) trace.
1842 Value prevTrace = NewF.getArgument(0);
1843 int64_t currentTraceOffset = 0;
1845 SmallVector<Operation *> toErase;
1846 auto result = NewF.walk([&](impulse::SampleOp sampleOp) -> WalkResult {
1847 OpBuilder::InsertionGuard guard(rewriter);
1848 rewriter.setInsertionPoint(sampleOp);
1850 SmallVector<Value> sampledValues;
// A logpdf attribute marks a primitive distribution.
1851 bool isDistribution =
static_cast<bool>(sampleOp.getLogpdfAttr());
1853 if (isDistribution) {
// Is this address selected for regeneration? Composite addresses are
// invalid on a flat distribution.
1855 bool isSelected =
false;
1856 for (
auto addr : CI.getRegenerateAddressesAttr()) {
1857 auto address = cast<ArrayAttr>(addr);
1858 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
1859 if (address.size() != 1) {
1861 "Impulse: distribution function cannot have composite "
1862 "selected address");
1863 return WalkResult::interrupt();
// Column of this sample inside the trace layout.
1870 int64_t sampleOffset = computeOffsetForSampleInSelection(
1871 CI, fn, selection, sampleOp.getSymbolAttr(), symbolTable);
// Selected branch: re-draw fresh values from the sampler function.
1876 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1877 sampleOp, sampleOp.getFnAttr()));
1879 auto distCall = func::CallOp::create(
1880 rewriter, sampleOp.getLoc(), distFn.getName(),
1881 distFn.getResultTypes(), sampleOp.getInputs());
1883 sampledValues.append(distCall.getResults().begin(),
1884 distCall.getResults().end());
// Unselected branch: read the old values back out of prevTrace.
// Result 0 appears to be the RNG pass-through — TODO confirm.
1887 sampledValues.resize(sampleOp.getNumResults());
1888 sampledValues[0] = sampleOp.getOperand(0);
1890 int64_t extractOffset = sampleOffset;
1891 for (
unsigned i = 1; i < sampleOp.getNumResults(); ++i) {
1893 cast<RankedTensorType>(sampleOp.getResult(i).getType());
1894 int64_t numElements = computeTensorElementCount(resultType);
1895 if (numElements < 0) {
1897 "Impulse: dynamic tensor dimensions not supported");
1898 return WalkResult::interrupt();
1901 auto sliceType = RankedTensorType::get(
1902 {1, numElements}, resultType.getElementType());
1903 auto sliced = impulse::SliceOp::create(
1904 rewriter, sampleOp.getLoc(), sliceType, prevTrace,
1905 rewriter.getDenseI64ArrayAttr({0, extractOffset}),
1906 rewriter.getDenseI64ArrayAttr(
1907 {1, extractOffset + numElements}),
1908 rewriter.getDenseI64ArrayAttr({1, 1}));
1909 auto extracted = impulse::ReshapeOp::create(
1910 rewriter, sampleOp.getLoc(), resultType, sliced);
1911 sampledValues[i] = extracted.getResult();
1912 extractOffset += numElements;
// Score values (fresh or replayed) with the logpdf function:
// operands are (values..., distribution params...).
1917 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1918 sampleOp, sampleOp.getLogpdfAttr()));
1920 SmallVector<Value> logpdfOperands;
1921 for (
unsigned i = 1; i < sampledValues.size(); ++i) {
1922 logpdfOperands.push_back(sampledValues[i]);
1924 for (
unsigned i = 1; i < sampleOp.getNumOperands(); ++i) {
1925 logpdfOperands.push_back(sampleOp.getOperand(i));
1928 if (logpdfOperands.size() != logpdfFn.getNumArguments()) {
1929 sampleOp.emitError(
"Impulse: failed to construct logpdf call; "
1930 "logpdf function has wrong number of arguments");
1931 return WalkResult::interrupt();
1934 auto logpdf = func::CallOp::create(
1935 rewriter, sampleOp.getLoc(), logpdfFn.getName(),
1936 logpdfFn.getResultTypes(), logpdfOperands);
1938 arith::AddFOp::create(rewriter, sampleOp.getLoc(),
1939 weightAccumulator, logpdf.getResult(0));
// If in the selection, pack the values into the new trace row.
1941 bool inSelection =
false;
1942 for (
auto addr : selection) {
1943 auto address = cast<ArrayAttr>(addr);
1944 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
1951 for (
unsigned i = 1; i < sampledValues.size(); ++i) {
1952 auto sampleValue = sampledValues[i];
1953 auto sampleType = cast<RankedTensorType>(sampleValue.getType());
1954 int64_t numElements = computeTensorElementCount(sampleType);
1955 if (numElements < 0) {
1957 "Impulse: dynamic tensor dimensions not supported");
1958 return WalkResult::interrupt();
1961 auto flatSampleType = RankedTensorType::get(
1962 {1, numElements}, sampleType.getElementType());
1963 auto flatSample = impulse::ReshapeOp::create(
1964 rewriter, sampleOp.getLoc(), flatSampleType, sampleValue);
1965 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
1966 auto row0 = arith::ConstantOp::create(
1967 rewriter, sampleOp.getLoc(), i64S,
1968 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
1969 auto colOff = arith::ConstantOp::create(
1970 rewriter, sampleOp.getLoc(), i64S,
1971 DenseElementsAttr::get(
1972 i64S, rewriter.getI64IntegerAttr(currentTraceOffset)));
1973 currTrace = impulse::DynamicUpdateSliceOp::create(
1974 rewriter, sampleOp.getLoc(),
traceType, currTrace,
1975 flatSample, ValueRange{row0, colOff})
1977 currentTraceOffset += numElements;
// Non-distribution sample: target is itself a generative function.
1983 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1984 sampleOp, sampleOp.getFnAttr()));
1986 if (genFn.getFunctionBody().empty()) {
1988 "Impulse: generative function body is empty; "
1989 "if this is a distribution, add a logpdf attribute");
1990 return WalkResult::interrupt();
// Project selection/regenerate sets down to this sub-namespace.
1993 ArrayAttr subSelection =
1994 buildSubSelection(rewriter, selection, sampleOp.getSymbolAttr());
1995 ArrayAttr subRegenerateAddrs =
1996 buildSubSelection(rewriter, CI.getRegenerateAddressesAttr(),
1997 sampleOp.getSymbolAttr());
1999 if (subSelection.empty()) {
// Nothing traced beneath this address: plain call.
2000 auto genCall = func::CallOp::create(
2001 rewriter, sampleOp.getLoc(), genFn.getName(),
2002 genFn.getResultTypes(), sampleOp.getInputs());
2003 sampledValues.append(genCall.getResults().begin(),
2004 genCall.getResults().end());
// Otherwise recurse with a nested impulse.regenerate, feeding it the
// matching slice of prevTrace as its previous trace.
2006 int64_t subPositionSize = computePositionSizeForSelection(
2007 sampleOp, genFn, subSelection, symbolTable);
2008 if (subPositionSize <= 0) {
2010 "Impulse: failed to compute sub-position size");
2011 return WalkResult::interrupt();
2014 int64_t mergeOffset = computeOffsetForNestedSample(
2015 sampleOp, fn, selection, sampleOp.getSymbolAttr(), symbolTable);
2016 if (mergeOffset < 0) {
2017 sampleOp.emitError(
"Impulse: failed to compute merge offset");
2018 return WalkResult::interrupt();
2021 auto subTraceType = RankedTensorType::get({1, subPositionSize},
2022 rewriter.getF64Type());
2023 Value subPrevTrace = impulse::SliceOp::create(
2024 rewriter, sampleOp.getLoc(), subTraceType, prevTrace,
2025 rewriter.getDenseI64ArrayAttr({0, mergeOffset}),
2026 rewriter.getDenseI64ArrayAttr(
2027 {1, mergeOffset + subPositionSize}),
2028 rewriter.getDenseI64ArrayAttr({1, 1}));
// Nested regenerate returns {newTrace, weight, callee results...}.
2031 auto scalarTy = RankedTensorType::get({}, rewriter.getF64Type());
2032 SmallVector<Type> regenResultTypes;
2033 regenResultTypes.push_back(subTraceType);
2034 regenResultTypes.push_back(scalarTy);
2035 for (
auto t : genFn.getResultTypes())
2036 regenResultTypes.push_back(t);
2038 auto nestedRegenerate = impulse::RegenerateOp::create(
2039 rewriter, sampleOp.getLoc(), regenResultTypes,
2040 sampleOp.getFnAttr(), sampleOp.getInputs(), subPrevTrace,
2041 subSelection, subRegenerateAddrs);
2043 Value subTrace = nestedRegenerate.getNewTrace();
2044 Value subWeight = nestedRegenerate.getWeight();
2046 weightAccumulator = arith::AddFOp::create(
2047 rewriter, sampleOp.getLoc(), weightAccumulator, subWeight);
// Merge the sub-trace back into the new trace at the same offset.
2049 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
2050 auto row0 = arith::ConstantOp::create(
2051 rewriter, sampleOp.getLoc(), i64S,
2052 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
2053 auto colOff = arith::ConstantOp::create(
2054 rewriter, sampleOp.getLoc(), i64S,
2055 DenseElementsAttr::get(
2056 i64S, rewriter.getI64IntegerAttr(mergeOffset)));
2057 currTrace = impulse::DynamicUpdateSliceOp::create(
2058 rewriter, sampleOp.getLoc(),
traceType, currTrace,
2059 subTrace, ValueRange{row0, colOff})
// Offsets may interleave, so advance with max() rather than +=.
2061 currentTraceOffset =
2062 std::max(currentTraceOffset, mergeOffset + subPositionSize);
2064 for (
auto output : nestedRegenerate.getOutputs())
2065 sampledValues.push_back(output);
// Replace the sample op's results and queue it for erasure.
2069 sampleOp.replaceAllUsesWith(sampledValues);
2070 toErase.push_back(sampleOp);
2071 return WalkResult::advance();
2074 for (Operation *op : toErase)
2075 rewriter.eraseOp(op);
2077 if (result.wasInterrupted()) {
2078 CI.emitError(
"Impulse: failed to walk sample ops");
// Prepend {newTrace, weight} to every return of the clone.
2082 NewF.walk([&](func::ReturnOp retOp) {
2083 OpBuilder::InsertionGuard guard(rewriter);
2084 rewriter.setInsertionPoint(retOp);
2086 SmallVector<Value> newRetVals;
2087 newRetVals.push_back(currTrace);
2088 newRetVals.push_back(weightAccumulator);
2089 newRetVals.append(retOp.getOperands().begin(),
2090 retOp.getOperands().end());
2092 func::ReturnOp::create(rewriter, retOp.getLoc(), newRetVals);
2093 rewriter.eraseOp(retOp);
// Replace the regenerate op with a call to the clone: original trace
// first, then the original inputs.
2096 rewriter.setInsertionPoint(CI);
2097 SmallVector<Value> operands;
2098 operands.push_back(CI.getOriginalTrace());
2099 operands.append(CI.getInputs().begin(), CI.getInputs().end());
2100 auto newCI = func::CallOp::create(rewriter, CI.getLoc(), NewF.getName(),
2101 NewF.getResultTypes(), operands);
2103 rewriter.replaceOp(CI, newCI.getResults());
// NOTE(review): lossy extraction — e.g. the `if (mlir::failed(` opener before
// applyPatternsGreedily and the stream operand after `<<` were dropped.
//
// Pass driver: registers all lowering patterns (MCMC gets the extra
// `debugDump` flag), runs the greedy pattern rewrite driver, then optionally
// parses and runs a user-supplied post-pass pipeline (`postpasses` option).
// Any failure along the way signals pass failure.
2114void ExpandImpulsePass::runOnOperation() {
2115 RewritePatternSet patterns(&getContext());
2116 patterns.add<LowerUntracedCallPattern, LowerSimulatePattern,
2117 LowerGeneratePattern, LowerMHPattern, LowerRegeneratePattern>(
// LowerMCMCPattern additionally receives the debug-dump option.
2119 patterns.add<LowerMCMCPattern>(&getContext(), debugDump);
2121 mlir::GreedyRewriteConfig config;
2124 applyPatternsGreedily(getOperation(), std::move(patterns), config))) {
2125 signalPassFailure();
// Optional post-pass pipeline: parse the textual pipeline, then run it.
2129 if (!postpasses.empty()) {
2130 mlir::PassManager pm(getOperation()->getContext());
2132 if (mlir::failed(mlir::parsePassPipeline(postpasses, pm))) {
2133 getOperation()->emitError()
2134 <<
"Failed to parse expand-impulse post-passes pipeline: "
2136 signalPassFailure();
2140 if (mlir::failed(pm.run(getOperation()))) {
2141 signalPassFailure();
PointerType * traceType(LLVMContext &C)
static ImpulseUtils * CreateFromClone(FunctionOpInterface toeval, ImpulseMode mode, int64_t positionSize=-1, int64_t constraintSize=-1)
Value finalizeWelford(OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config)
Finalize Welford state to produce sample covariance (returned as inverse mass matrix).
SmallVector< AdaptWindow > buildAdaptationSchedule(int64_t numSteps)
Build warmup adaptation schedule.
Value conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump)
Conditionally dump a value for debugging.
MCMCKernelResult SampleNUTS(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump=false)
Single NUTS iteration: momentum sampling + tree building.
InitialHMCState InitHMC(OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition=Value(), bool debugDump=false)
Initializes HMC/NUTS state from a trace. Specifically:
DualAveragingState updateDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config)
Update dual averaging state with observed acceptance probability.
DualAveragingState initDualAveraging(OpBuilder &builder, Location loc, Value stepSize)
Initialize dual averaging state from initial step size.
Value computeMassMatrixSqrt(OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType)
Computes the square root of the mass matrix from the inverse mass matrix.
WelfordState updateWelford(OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config)
Update Welford state with a new sample.
WelfordState initWelford(OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal)
Initialize state for Welford covariance estimation.
Value getStepSizeFromDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, bool final=false)
Get step size from dual averaging state.
MCMCKernelResult SampleHMC(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump=false)
Single HMC iteration: momentum sampling + leapfrog + MH accept/reject.
Value constrainPosition(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from unconstrained to constrained space.
SmallVector< Type > getTypes() const
static DualAveragingState fromValues(ArrayRef< Value > values)
SmallVector< Value > toValues() const
Result of one MCMC kernel step.
Configuration for Welford covariance estimation.
State for Welford covariance estimation.
SmallVector< Value > toValues() const
SmallVector< Type > getTypes() const
static WelfordState fromValues(ArrayRef< Value > values)