Enzyme main
Loading...
Searching...
No Matches
ExpandImpulsePass.cpp
Go to the documentation of this file.
1//===- ExpandImpulsePass.cpp - Expand Impulse region ops ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass that expands high-level Impulse region ops
10// (simulate, generate, regenerate, mcmc, mh, untraced_call) into lower-level
11// Impulse ops (sample, random, etc.) plus arith/math/cf.
12//===----------------------------------------------------------------------===//
13
15#include "Dialect/Ops.h"
16#include "Interfaces/HMCUtils.h"
18#include "PassDetails.h"
19#include "Passes/Passes.h"
20
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/Math/IR/Math.h"
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Interfaces/FunctionInterfaces.h"
26#include "mlir/Pass/PassManager.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28
29#include "llvm/ADT/APFloat.h"
30
31#define DEBUG_TYPE "expand-impulse"
32
33using namespace mlir;
34using namespace mlir::enzyme;
35using namespace mlir::impulse;
36
37namespace mlir {
38namespace enzyme {
39#define GEN_PASS_DEF_EXPANDIMPULSEPASS
40#include "Passes/Passes.h.inc"
41} // namespace enzyme
42} // namespace mlir
43
44namespace {
45
46static int64_t computeTensorElementCount(RankedTensorType tensorType) {
47 int64_t elemCount = 1;
48 for (auto dim : tensorType.getShape()) {
49 if (dim == ShapedType::kDynamic)
50 return -1;
51 elemCount *= dim;
52 }
53 return elemCount;
54}
55
56using SampleOpMap = DenseMap<Attribute, impulse::SampleOp>;
57
58static SampleOpMap buildSampleOpMap(FunctionOpInterface fn) {
59 SampleOpMap map;
60 fn.walk([&](impulse::SampleOp sampleOp) {
61 if (auto symbol = sampleOp.getSymbolAttr())
62 map[symbol] = sampleOp;
63 });
64 return map;
65}
66
67static impulse::SampleOp findSampleBySymbol(const SampleOpMap &map,
68 Attribute targetSymbol) {
69 auto it = map.find(targetSymbol);
70 return it != map.end() ? it->second : nullptr;
71}
72
73static int64_t computeSampleElementCount(Operation *op,
74 impulse::SampleOp sampleOp) {
75 int64_t totalCount = 0;
76 for (unsigned i = 1; i < sampleOp.getNumResults(); ++i) {
77 auto resultType = sampleOp.getResult(i).getType();
78 auto tensorType = dyn_cast<RankedTensorType>(resultType);
79 if (!tensorType) {
80 op->emitError("Expected ranked tensor type for sample result");
81 return -1;
82 }
83 int64_t elemCount = computeTensorElementCount(tensorType);
84 if (elemCount < 0) {
85 op->emitError("Dynamic tensor dimensions not supported");
86 return -1;
87 }
88 totalCount += elemCount;
89 }
90 return totalCount;
91}
92
93static bool computePositionSizeForAddress(Operation *op,
94 const SampleOpMap &sampleMap,
95 ArrayRef<Attribute> address,
96 SymbolTableCollection &symbolTable,
97 int64_t &positionSize) {
98 if (address.empty())
99 return false;
100
101 auto sampleOp = findSampleBySymbol(sampleMap, address[0]);
102 if (!sampleOp)
103 return false;
104
105 if (address.size() > 1) {
106 if (sampleOp.getLogpdfAttr()) {
107 op->emitError("Cannot select nested address in distribution function");
108 return false;
109 }
110
111 auto genFn = cast<FunctionOpInterface>(
112 symbolTable.lookupNearestSymbolFrom(sampleOp, sampleOp.getFnAttr()));
113 if (!genFn || genFn.getFunctionBody().empty()) {
114 op->emitError("Cannot find generative function for nested address");
115 return false;
116 }
117
118 auto nestedMap = buildSampleOpMap(genFn);
119 return computePositionSizeForAddress(op, nestedMap, address.drop_front(),
120 symbolTable, positionSize);
121 }
122
123 int64_t elemCount = computeSampleElementCount(op, sampleOp);
124 if (elemCount < 0)
125 return false;
126
127 positionSize += elemCount;
128 return true;
129}
130
131static int64_t
132computePositionSizeForSelection(Operation *op, FunctionOpInterface fn,
133 ArrayAttr selection,
134 SymbolTableCollection &symbolTable) {
135 auto sampleMap = buildSampleOpMap(fn);
136 int64_t positionSize = 0;
137
138 for (auto addr : selection) {
139 auto address = cast<ArrayAttr>(addr);
140 if (address.empty()) {
141 op->emitError("Empty address in selection");
142 return -1;
143 }
144
145 SmallVector<Attribute> tailAddresses(address.begin(), address.end());
146 if (!computePositionSizeForAddress(op, sampleMap, tailAddresses,
147 symbolTable, positionSize)) {
148 op->emitError("Could not find sample with symbol in address chain");
149 return -1;
150 }
151 }
152
153 return positionSize;
154}
155
156static int64_t
157computeOffsetForSampleInSelection(Operation *op, FunctionOpInterface fn,
158 ArrayAttr selection, Attribute targetSymbol,
159 SymbolTableCollection &symbolTable) {
160 auto sampleMap = buildSampleOpMap(fn);
161 int64_t offset = 0;
162
163 for (auto addr : selection) {
164 auto address = cast<ArrayAttr>(addr);
165 if (address.empty())
166 continue;
167
168 auto firstSymbol = address[0];
169
170 if (firstSymbol == targetSymbol) {
171 return offset;
172 }
173
174 SmallVector<Attribute> tailAddresses(address.begin(), address.end());
175 if (!computePositionSizeForAddress(op, sampleMap, tailAddresses,
176 symbolTable, offset)) {
177 return -1;
178 }
179 }
180
181 return -1;
182}
183
184static SmallVector<impulse::SupportInfo>
185collectSupportInfoForSelection(Operation *op, FunctionOpInterface fn,
186 ArrayAttr selection, ArrayAttr allAddresses,
187 SymbolTableCollection &symbolTable) {
188 auto sampleMap = buildSampleOpMap(fn);
189 SmallVector<impulse::SupportInfo> supports;
190 int64_t currentPositionOffset = 0;
191
192 for (auto addr : selection) {
193 auto address = cast<ArrayAttr>(addr);
194 if (address.empty())
195 continue;
196
197 // TODO: Handle nested cases
198 if (address.size() != 1)
199 continue;
200
201 auto targetSymbol = address[0];
202 auto sampleOp = findSampleBySymbol(sampleMap, targetSymbol);
203 if (!sampleOp)
204 continue;
205
206 auto supportAttr = sampleOp.getSupportAttr();
207
208 int64_t sampleSize = computeSampleElementCount(op, sampleOp);
209 if (sampleSize < 0)
210 continue;
211
212 int64_t traceOffset = computeOffsetForSampleInSelection(
213 op, fn, allAddresses, targetSymbol, symbolTable);
214 if (traceOffset < 0) {
215 op->emitError("Symbol in selection not found in all_addresses - cannot "
216 "determine trace offset for scattered selection");
217 return {};
218 }
219
220 supports.emplace_back(currentPositionOffset, traceOffset, sampleSize,
221 supportAttr);
222 currentPositionOffset += sampleSize;
223 }
224
225 return supports;
226}
227
228static ArrayAttr buildSubSelection(OpBuilder &builder, ArrayAttr selection,
229 Attribute targetSymbol) {
230 SmallVector<Attribute> subAddresses;
231 for (auto addr : selection) {
232 auto address = cast<ArrayAttr>(addr);
233 if (address.empty())
234 continue;
235 if (address[0] == targetSymbol && address.size() > 1) {
236 SmallVector<Attribute> tail(address.begin() + 1, address.end());
237 subAddresses.push_back(builder.getArrayAttr(tail));
238 }
239 }
240 return builder.getArrayAttr(subAddresses);
241}
242
243static int64_t
244computeOffsetForNestedSample(Operation *op, FunctionOpInterface fn,
245 ArrayAttr selection, Attribute targetSymbol,
246 SymbolTableCollection &symbolTable) {
247 auto sampleMap = buildSampleOpMap(fn);
248 int64_t offset = 0;
249
250 for (auto addr : selection) {
251 auto address = cast<ArrayAttr>(addr);
252 if (address.empty())
253 continue;
254
255 if (address[0] == targetSymbol) {
256 return offset;
257 }
258
259 SmallVector<Attribute> tailAddresses(address.begin(), address.end());
260 if (!computePositionSizeForAddress(op, sampleMap, tailAddresses,
261 symbolTable, offset)) {
262 return -1;
263 }
264 }
265
266 return -1;
267}
268
269struct ExpandImpulsePass
270 : public enzyme::impl::ExpandImpulsePassBase<ExpandImpulsePass> {
271 using ExpandImpulsePassBase::ExpandImpulsePassBase;
272
273 MEnzymeLogic Logic;
274
275 void runOnOperation() override;
276
277 void getDependentDialects(DialectRegistry &registry) const override {
278 mlir::OpPassManager pm;
279 mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
280 if (!mlir::failed(result)) {
281 pm.getDependentDialects(registry);
282 }
283
284 registry
285 .insert<mlir::arith::ArithDialect, mlir::math::MathDialect,
286 mlir::complex::ComplexDialect, mlir::cf::ControlFlowDialect,
287 mlir::enzyme::EnzymeDialect, mlir::impulse::ImpulseDialect>();
288 }
289
  /// Lowers impulse.untraced_call: clones the callee, rewrites every
  /// impulse.sample op inside the clone into a direct func.call of its
  /// distribution function (dropping all tracing), then replaces the
  /// untraced_call itself with a func.call to the clone.
  struct LowerUntracedCallPattern
      : public mlir::OpRewritePattern<impulse::UntracedCallOp> {
    using mlir::OpRewritePattern<impulse::UntracedCallOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(impulse::UntracedCallOp CI,
                                  PatternRewriter &rewriter) const override {
      SymbolTableCollection symbolTable;

      auto fn = cast<FunctionOpInterface>(
          symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()));

      if (fn.getFunctionBody().empty()) {
        CI.emitError("Impulse: trying to call an empty function");
        return failure();
      }

      // Clone the callee in Call mode. CreateFromClone returns a raw pointer
      // that we delete at the end of this method.
      auto putils = ImpulseUtils::CreateFromClone(fn, ImpulseMode::Call);
      FunctionOpInterface NewF = putils->newFunc;

      // Replace each sample op in the clone with a plain call of its
      // distribution function. Erasure is deferred so the walk does not
      // invalidate the op it is visiting.
      SmallVector<Operation *, 4> toErase;
      NewF.walk([&](impulse::SampleOp sampleOp) {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(sampleOp);

        auto distFn =
            cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
                sampleOp, sampleOp.getFnAttr()));
        auto distCall =
            func::CallOp::create(rewriter, sampleOp.getLoc(), distFn.getName(),
                                 distFn.getResultTypes(), sampleOp.getInputs());
        sampleOp.replaceAllUsesWith(distCall);

        toErase.push_back(sampleOp);
      });

      for (Operation *op : toErase)
        rewriter.eraseOp(op);

      // Replace the untraced_call with a call to the lowered clone.
      rewriter.setInsertionPoint(CI);
      auto newCI =
          func::CallOp::create(rewriter, CI.getLoc(), NewF.getName(),
                               NewF.getResultTypes(), CI.getOperands());

      rewriter.replaceOp(CI, newCI.getResults());

      delete putils;

      return success();
    }
  };
340
341 struct LowerSimulatePattern
342 : public mlir::OpRewritePattern<impulse::SimulateOp> {
343 using mlir::OpRewritePattern<impulse::SimulateOp>::OpRewritePattern;
344
345 LogicalResult matchAndRewrite(impulse::SimulateOp CI,
346 PatternRewriter &rewriter) const override {
347 SymbolTableCollection symbolTable;
348
349 auto fn = cast<FunctionOpInterface>(
350 symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()));
351
352 if (fn.getFunctionBody().empty()) {
353 CI.emitError(
354 "Impulse: calling `simulate` on an empty function; if this "
355 "is a distribution function, its sample op should have a "
356 "logpdf attribute to avoid recursive `simulate` calls which is "
357 "intended for generative functions");
358 return failure();
359 }
360
361 ArrayAttr selection = CI.getSelectionAttr();
362 int64_t positionSize =
363 computePositionSizeForSelection(CI, fn, selection, symbolTable);
364 if (positionSize <= 0) {
365 CI.emitError("Impulse: failed to compute position size for simulate");
366 return failure();
367 }
368
369 auto putils = ImpulseUtils::CreateFromClone(fn, ImpulseMode::Simulate,
370 positionSize);
371 FunctionOpInterface NewF = putils->newFunc;
372
373 OpBuilder entryBuilder(putils->initializationBlock,
374 putils->initializationBlock->begin());
375 Location initLoc = putils->initializationBlock->begin()->getLoc();
376 auto scalarType = RankedTensorType::get({}, entryBuilder.getF64Type());
377 auto zeroWeight =
378 arith::ConstantOp::create(entryBuilder, initLoc, scalarType,
379 DenseElementsAttr::get(scalarType, 0.0));
380 Value weightAccumulator = zeroWeight;
381
382 auto traceType =
383 RankedTensorType::get({1, positionSize}, entryBuilder.getF64Type());
384 auto zeroTrace =
385 arith::ConstantOp::create(entryBuilder, initLoc, traceType,
386 DenseElementsAttr::get(traceType, 0.0));
387 Value currTrace = zeroTrace;
388 int64_t currentOffset = 0;
389
390 SmallVector<Operation *> toErase;
391 auto result = NewF.walk([&](impulse::SampleOp sampleOp) -> WalkResult {
392 OpBuilder::InsertionGuard guard(rewriter);
393 rewriter.setInsertionPoint(sampleOp);
394
395 SmallVector<Value> sampledValues; // Values to replace uses of sample op
396 bool isDistribution = static_cast<bool>(sampleOp.getLogpdfAttr());
397
398 if (isDistribution) {
399 // A1. Distribution function: call the distribution function.
400 auto distFn =
401 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
402 sampleOp, sampleOp.getFnAttr()));
403
404 auto distCall = func::CallOp::create(
405 rewriter, sampleOp.getLoc(), distFn.getName(),
406 distFn.getResultTypes(), sampleOp.getInputs());
407
408 sampledValues.append(distCall.getResults().begin(),
409 distCall.getResults().end());
410
411 auto logpdfFn =
412 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
413 sampleOp, sampleOp.getLogpdfAttr()));
414
415 // logpdf operands: (<non-RNG outputs>..., <non-RNG inputs>...)
416 SmallVector<Value> logpdfOperands;
417 for (unsigned i = 1; i < sampledValues.size(); ++i) {
418 logpdfOperands.push_back(sampledValues[i]);
419 }
420 for (unsigned i = 1; i < sampleOp.getNumOperands(); ++i) {
421 logpdfOperands.push_back(sampleOp.getOperand(i));
422 }
423
424 if (logpdfOperands.size() != logpdfFn.getNumArguments()) {
425 sampleOp.emitError("Impulse: failed to construct logpdf call; "
426 "logpdf function has wrong number of arguments");
427 return WalkResult::interrupt();
428 }
429
430 // A2. Compute and accumulate weight.
431 auto logpdf = func::CallOp::create(
432 rewriter, sampleOp.getLoc(), logpdfFn.getName(),
433 logpdfFn.getResultTypes(), logpdfOperands);
434 weightAccumulator =
435 arith::AddFOp::create(rewriter, sampleOp.getLoc(),
436 weightAccumulator, logpdf.getResult(0));
437
438 // A3. Check if this sample is in the selection and insert into trace
439 bool inSelection = false;
440 for (auto addr : selection) {
441 auto address = cast<ArrayAttr>(addr);
442 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
443 inSelection = true;
444 break;
445 }
446 }
447
448 if (inSelection) {
449 for (unsigned i = 1; i < sampledValues.size(); ++i) {
450 auto sampleValue = sampledValues[i];
451 auto sampleType = cast<RankedTensorType>(sampleValue.getType());
452 int64_t numElements = computeTensorElementCount(sampleType);
453 if (numElements < 0) {
454 sampleOp.emitError(
455 "Impulse: dynamic tensor dimensions not supported");
456 return WalkResult::interrupt();
457 }
458
459 auto flatSampleType = RankedTensorType::get(
460 {1, numElements}, sampleType.getElementType());
461 auto flatSample = impulse::ReshapeOp::create(
462 rewriter, sampleOp.getLoc(), flatSampleType, sampleValue);
463 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
464 auto row0 = arith::ConstantOp::create(
465 rewriter, sampleOp.getLoc(), i64S,
466 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
467 auto colOff = arith::ConstantOp::create(
468 rewriter, sampleOp.getLoc(), i64S,
469 DenseElementsAttr::get(
470 i64S, rewriter.getI64IntegerAttr(currentOffset)));
471 currTrace = impulse::DynamicUpdateSliceOp::create(
472 rewriter, sampleOp.getLoc(), traceType, currTrace,
473 flatSample, ValueRange{row0, colOff})
474 .getResult();
475 currentOffset += numElements;
476 }
477 }
478 } else {
479 // B. Generative function: recursively simulate the nested function
480 auto genFn =
481 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
482 sampleOp, sampleOp.getFnAttr()));
483
484 if (genFn.getFunctionBody().empty()) {
485 sampleOp.emitError(
486 "Impulse: generative function body is empty; "
487 "if this is a distribution, add a logpdf attribute");
488 return WalkResult::interrupt();
489 }
490
491 ArrayAttr subSelection =
492 buildSubSelection(rewriter, selection, sampleOp.getSymbolAttr());
493 if (subSelection.empty()) {
494 // No samples from this generative function are in the selection
495 // Just call the function directly
496 auto genCall = func::CallOp::create(
497 rewriter, sampleOp.getLoc(), genFn.getName(),
498 genFn.getResultTypes(), sampleOp.getInputs());
499 sampledValues.append(genCall.getResults().begin(),
500 genCall.getResults().end());
501 } else {
502 int64_t subPositionSize = computePositionSizeForSelection(
503 sampleOp, genFn, subSelection, symbolTable);
504 if (subPositionSize <= 0) {
505 sampleOp.emitError(
506 "Impulse: failed to compute sub-position size");
507 return WalkResult::interrupt();
508 }
509
510 // Build result types: (trace, weight, original_returns...)
511 auto subTraceType = RankedTensorType::get({1, subPositionSize},
512 rewriter.getF64Type());
513 auto scalarTy = RankedTensorType::get({}, rewriter.getF64Type());
514 SmallVector<Type> simResultTypes;
515 simResultTypes.push_back(subTraceType);
516 simResultTypes.push_back(scalarTy);
517 for (auto t : genFn.getResultTypes())
518 simResultTypes.push_back(t);
519
520 auto nestedSimulate = impulse::SimulateOp::create(
521 rewriter, sampleOp.getLoc(), simResultTypes,
522 sampleOp.getFnAttr(), sampleOp.getInputs(), subSelection);
523 auto subTrace = nestedSimulate.getTrace();
524 auto subWeight = nestedSimulate.getWeight();
525
526 weightAccumulator = arith::AddFOp::create(
527 rewriter, sampleOp.getLoc(), weightAccumulator, subWeight);
528
529 int64_t mergeOffset = computeOffsetForNestedSample(
530 sampleOp, fn, selection, sampleOp.getSymbolAttr(), symbolTable);
531 if (mergeOffset < 0) {
532 sampleOp.emitError("Impulse: failed to compute merge offset");
533 return WalkResult::interrupt();
534 }
535
536 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
537 auto row0 = arith::ConstantOp::create(
538 rewriter, sampleOp.getLoc(), i64S,
539 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
540 auto colOff = arith::ConstantOp::create(
541 rewriter, sampleOp.getLoc(), i64S,
542 DenseElementsAttr::get(
543 i64S, rewriter.getI64IntegerAttr(mergeOffset)));
544 currTrace = impulse::DynamicUpdateSliceOp::create(
545 rewriter, sampleOp.getLoc(), traceType, currTrace,
546 subTrace, ValueRange{row0, colOff})
547 .getResult();
548 currentOffset =
549 std::max(currentOffset, mergeOffset + subPositionSize);
550
551 for (auto output : nestedSimulate.getOutputs())
552 sampledValues.push_back(output);
553 }
554 }
555
556 // D. Replace uses of the original sample op with the new values.
557 sampleOp.replaceAllUsesWith(sampledValues);
558
559 toErase.push_back(sampleOp);
560 return WalkResult::advance();
561 });
562
563 for (Operation *op : toErase)
564 rewriter.eraseOp(op);
565
566 if (result.wasInterrupted()) {
567 CI.emitError("Impulse: failed to walk sample ops");
568 return failure();
569 }
570
571 // E. Rewrite the return to return (trace, weight, <original returns>...)
572 NewF.walk([&](func::ReturnOp retOp) {
573 OpBuilder::InsertionGuard guard(rewriter);
574 rewriter.setInsertionPoint(retOp);
575 SmallVector<Value> newRetVals;
576 newRetVals.push_back(currTrace);
577 newRetVals.push_back(weightAccumulator);
578 newRetVals.append(retOp.getOperands().begin(),
579 retOp.getOperands().end());
580
581 func::ReturnOp::create(rewriter, retOp.getLoc(), newRetVals);
582 rewriter.eraseOp(retOp);
583 });
584
585 rewriter.setInsertionPoint(CI);
586 auto newCI = func::CallOp::create(rewriter, CI.getLoc(), NewF.getName(),
587 NewF.getResultTypes(), CI.getInputs());
588
589 rewriter.replaceOp(CI, newCI.getResults());
590
591 delete putils;
592
593 return success();
594 }
595 };
596
597 struct LowerMCMCPattern : public mlir::OpRewritePattern<impulse::InferOp> {
598 bool debugDump;
599
600 LowerMCMCPattern(MLIRContext *context, bool debugDump,
601 PatternBenefit benefit = 1)
602 : OpRewritePattern(context, benefit), debugDump(debugDump) {}
603
604 LogicalResult matchAndRewrite(impulse::InferOp mcmcOp,
605 PatternRewriter &rewriter) const override {
606 SymbolTableCollection symbolTable;
607
608 bool hasLogpdfFn = static_cast<bool>(mcmcOp.getLogpdfFnAttr());
609
610 if (!hasLogpdfFn) {
611 auto fnAttr = mcmcOp.getFnAttr();
612 if (!fnAttr) {
613 mcmcOp.emitError("Impulse: either fn or logpdf_fn must be provided");
614 return failure();
615 }
616 auto fn = cast<FunctionOpInterface>(
617 symbolTable.lookupNearestSymbolFrom(mcmcOp, fnAttr));
618 if (fn.getFunctionBody().empty()) {
619 mcmcOp.emitError("Impulse: calling `mcmc` on an empty function");
620 return failure();
621 }
622 }
623
624 if (!mcmcOp.getStepSize()) {
625 mcmcOp.emitError("Impulse: MCMC requires step_size parameter");
626 return failure();
627 }
628
629 bool isHMC = mcmcOp.getHmcConfig().has_value();
630 bool isNUTS = mcmcOp.getNutsConfig().has_value();
631 if (!isHMC && !isNUTS) {
632 mcmcOp.emitError("Impulse: Unknown MCMC algorithm");
633 return failure();
634 }
635
636 auto loc = mcmcOp.getLoc();
637 auto invMass = mcmcOp.getInverseMassMatrix();
638 Value adaptedInvMass = invMass;
639 auto stepSize = mcmcOp.getStepSize();
640
641 auto inputs = mcmcOp.getInputs();
642 if (inputs.empty()) {
643 mcmcOp.emitError("Impulse: MCMC requires at least rng_state input");
644 return failure();
645 }
646
647 auto rngInput = inputs[0];
648
649 int64_t positionSize;
650 SmallVector<Value> fnInputs;
651 SmallVector<Type> fnResultTypes;
652 Value originalTrace;
653 ArrayAttr selection, allAddresses;
654 SmallVector<SupportInfo> supports;
655 FlatSymbolRefAttr logpdfFnAttr;
656
657 if (hasLogpdfFn) {
658 logpdfFnAttr = mcmcOp.getLogpdfFnAttr();
659 fnInputs.assign(inputs.begin() + 1, inputs.end());
660 auto initialPos = mcmcOp.getInitialPosition();
661 auto initPosType = cast<RankedTensorType>(initialPos.getType());
662 positionSize = initPosType.getNumElements();
663 selection = mcmcOp.getSelectionAttr();
664 allAddresses = mcmcOp.getAllAddressesAttr();
665 } else {
666 fnInputs.assign(inputs.begin() + 1, inputs.end());
667 originalTrace = mcmcOp.getOriginalTrace();
668 selection = mcmcOp.getSelectionAttr();
669 allAddresses = mcmcOp.getAllAddressesAttr();
670
671 auto fn = cast<FunctionOpInterface>(
672 symbolTable.lookupNearestSymbolFrom(mcmcOp, mcmcOp.getFnAttr()));
673 positionSize =
674 computePositionSizeForSelection(mcmcOp, fn, selection, symbolTable);
675 if (positionSize <= 0)
676 return failure();
677
678 supports = collectSupportInfoForSelection(mcmcOp, fn, selection,
679 allAddresses, symbolTable);
680
681 auto fnType = cast<FunctionType>(fn.getFunctionType());
682 fnResultTypes.assign(fnType.getResults().begin(),
683 fnType.getResults().end());
684 }
685
686 int64_t numSamples = mcmcOp.getNumSamples();
687 int64_t thinning = mcmcOp.getThinning();
688 int64_t numWarmup = mcmcOp.getNumWarmup();
689
690 auto elemType =
691 cast<RankedTensorType>(stepSize.getType()).getElementType();
692 auto positionType = RankedTensorType::get({1, positionSize}, elemType);
693 auto scalarType = RankedTensorType::get({}, elemType);
694 auto i64TensorType = RankedTensorType::get({}, rewriter.getI64Type());
695 auto i1TensorType = RankedTensorType::get({}, rewriter.getI1Type());
696
697 // Algorithm-specific configuration
698 Value trajectoryLength;
699 Value maxDeltaEnergy;
700 int64_t maxTreeDepth = 0;
701
702 bool adaptStepSize = false;
703 bool adaptMassMatrix = false;
704 auto F64TensorType = RankedTensorType::get({}, rewriter.getF64Type());
705 if (isHMC) {
706 auto hmcConfig = mcmcOp.getHmcConfig().value();
707 double length = hmcConfig.getTrajectoryLength().getValueAsDouble();
708 trajectoryLength = arith::ConstantOp::create(
709 rewriter, loc, F64TensorType,
710 DenseElementsAttr::get(F64TensorType,
711 rewriter.getF64FloatAttr(length)));
712 adaptStepSize = hmcConfig.getAdaptStepSize();
713 adaptMassMatrix = hmcConfig.getAdaptMassMatrix();
714 } else {
715 auto nutsConfig = mcmcOp.getNutsConfig().value();
716 maxTreeDepth = nutsConfig.getMaxTreeDepth();
717 adaptStepSize = nutsConfig.getAdaptStepSize();
718 adaptMassMatrix = nutsConfig.getAdaptMassMatrix();
719 double maxDeltaEnergyVal =
720 nutsConfig.getMaxDeltaEnergy()
721 ? nutsConfig.getMaxDeltaEnergy().getValueAsDouble()
722 : 1000.0;
723 maxDeltaEnergy = arith::ConstantOp::create(
724 rewriter, loc, F64TensorType,
725 DenseElementsAttr::get(
726 F64TensorType, rewriter.getF64FloatAttr(maxDeltaEnergyVal)));
727 }
728
729 bool diagonal = true;
730 if (invMass) {
731 auto invMassType = cast<RankedTensorType>(invMass.getType());
732 diagonal = (invMassType.getRank() == 1);
733 }
734
735 auto adaptedMassMatrixSqrt =
736 computeMassMatrixSqrt(rewriter, loc, adaptedInvMass, positionType);
737
738 auto autodiffAttrs = mcmcOp.getAutodiffAttrsAttr();
739
740 auto makeHMCContext = [&](Value currentInvMass,
741 Value currentMassMatrixSqrt,
742 Value currentStepSize) -> HMCContext {
743 if (hasLogpdfFn) {
744 return HMCContext(logpdfFnAttr, fnInputs, currentInvMass,
745 currentMassMatrixSqrt, currentStepSize,
746 trajectoryLength, positionSize, autodiffAttrs);
747 } else {
748 return HMCContext(mcmcOp.getFnAttr(), fnInputs, fnResultTypes,
749 originalTrace, selection, allAddresses,
750 currentInvMass, currentMassMatrixSqrt,
751 currentStepSize, trajectoryLength, positionSize,
752 supports, autodiffAttrs);
753 }
754 };
755
756 auto makeNUTSContext =
757 [&](Value currentInvMass, Value currentMassMatrixSqrt,
758 Value currentStepSize, Value U) -> NUTSContext {
759 if (hasLogpdfFn) {
760 return NUTSContext(logpdfFnAttr, fnInputs, currentInvMass,
761 currentMassMatrixSqrt, currentStepSize,
762 positionSize, U, maxDeltaEnergy, maxTreeDepth,
763 autodiffAttrs);
764 } else {
765 return NUTSContext(mcmcOp.getFnAttr(), fnInputs, fnResultTypes,
766 originalTrace, selection, allAddresses,
767 currentInvMass, currentMassMatrixSqrt,
768 currentStepSize, positionSize, supports, U,
769 maxDeltaEnergy, maxTreeDepth, autodiffAttrs);
770 }
771 };
772
773 Value currentQ, currentGrad, currentU, currentRng;
774
775 auto initialGrad = mcmcOp.getInitialGradient();
776 auto initialPE = mcmcOp.getInitialPotentialEnergy();
777
778 if (hasLogpdfFn && initialGrad && initialPE) {
779 currentQ = mcmcOp.getInitialPosition();
780 currentGrad = initialGrad;
781 currentU = initialPE;
782 currentRng = rngInput;
783 } else {
784 auto baseCtx =
785 makeHMCContext(adaptedInvMass, adaptedMassMatrixSqrt, stepSize);
786 auto initState = InitHMC(
787 rewriter, loc, rngInput, baseCtx,
788 hasLogpdfFn ? mcmcOp.getInitialPosition() : Value(), debugDump);
789 currentQ = initState.q0;
790 currentGrad = initState.grad0;
791 currentU = initState.U0;
792 currentRng = initState.rng;
793 }
794
795 auto runSampleStepWithStepSize =
796 [&](OpBuilder &builder, Location loc, Value q, Value grad, Value U,
797 Value rng, Value currentStepSize) -> MCMCKernelResult {
798 if (isHMC) {
799 auto ctx = makeHMCContext(adaptedInvMass, adaptedMassMatrixSqrt,
800 currentStepSize);
801 return SampleHMC(builder, loc, q, grad, U, rng, ctx, debugDump);
802 } else {
803 auto nutsCtx = makeNUTSContext(adaptedInvMass, adaptedMassMatrixSqrt,
804 currentStepSize, U);
805 return SampleNUTS(builder, loc, q, grad, U, rng, nutsCtx, debugDump);
806 }
807 };
808
809 Value adaptedStepSize = stepSize;
810
811 auto runSampleStepWithInvMass =
812 [&](OpBuilder &builder, Location loc, Value q, Value grad, Value U,
813 Value rng, Value currentStepSize, Value currentInvMass,
814 Value currentMassMatrixSqrt) -> MCMCKernelResult {
815 if (isHMC) {
816 auto ctx = makeHMCContext(currentInvMass, currentMassMatrixSqrt,
817 currentStepSize);
818 return SampleHMC(builder, loc, q, grad, U, rng, ctx, debugDump);
819 } else {
820 auto nutsCtx = makeNUTSContext(currentInvMass, currentMassMatrixSqrt,
821 currentStepSize, U);
822 return SampleNUTS(builder, loc, q, grad, U, rng, nutsCtx, debugDump);
823 }
824 };
825
826 if (!adaptedInvMass) {
827 adaptedInvMass = arith::ConstantOp::create(
828 rewriter, loc, positionType,
829 DenseElementsAttr::get(positionType,
830 rewriter.getFloatAttr(elemType, 1.0)));
831 adaptedMassMatrixSqrt = arith::ConstantOp::create(
832 rewriter, loc, positionType,
833 DenseElementsAttr::get(positionType,
834 rewriter.getFloatAttr(elemType, 1.0)));
835 }
836
837 if (numWarmup > 0) {
838 auto c0 = arith::ConstantOp::create(
839 rewriter, loc, i64TensorType,
840 DenseElementsAttr::get(i64TensorType,
841 rewriter.getI64IntegerAttr(0)));
842 auto c1 = arith::ConstantOp::create(
843 rewriter, loc, i64TensorType,
844 DenseElementsAttr::get(i64TensorType,
845 rewriter.getI64IntegerAttr(1)));
846 auto numWarmupConst = arith::ConstantOp::create(
847 rewriter, loc, i64TensorType,
848 DenseElementsAttr::get(i64TensorType,
849 rewriter.getI64IntegerAttr(numWarmup)));
850
851 auto schedule = buildAdaptationSchedule(numWarmup);
852 int64_t numWindows = static_cast<int64_t>(schedule.size());
853
854 SmallVector<Value> windowEndConstants;
855 for (const auto &window : schedule) {
856 windowEndConstants.push_back(arith::ConstantOp::create(
857 rewriter, loc, i64TensorType,
858 DenseElementsAttr::get(i64TensorType,
859 rewriter.getI64IntegerAttr(window.end))));
860 }
861
862 auto numWindowsMinusOne = arith::ConstantOp::create(
863 rewriter, loc, i64TensorType,
864 DenseElementsAttr::get(i64TensorType,
865 rewriter.getI64IntegerAttr(numWindows - 1)));
866 auto lastIterConst = arith::ConstantOp::create(
867 rewriter, loc, i64TensorType,
868 DenseElementsAttr::get(i64TensorType,
869 rewriter.getI64IntegerAttr(numWarmup - 1)));
870
871 if (!adaptedInvMass) {
872 adaptedInvMass = arith::ConstantOp::create(
873 rewriter, loc, positionType,
874 DenseElementsAttr::get(positionType,
875 rewriter.getFloatAttr(elemType, 1.0)));
876 adaptedMassMatrixSqrt = arith::ConstantOp::create(
877 rewriter, loc, positionType,
878 DenseElementsAttr::get(positionType,
879 rewriter.getFloatAttr(elemType, 1.0)));
880 }
881
882 Value initialStepSize = stepSize;
883 initialStepSize =
884 conditionalDump(rewriter, loc, initialStepSize,
885 "MCMC: initial step size before warmup", debugDump);
886 DualAveragingState daState =
887 initDualAveraging(rewriter, loc, initialStepSize);
888
889 WelfordState welfordState;
890 WelfordConfig welfordConfig;
891 if (adaptMassMatrix) {
892 welfordState = initWelford(rewriter, loc, positionSize, diagonal);
893 welfordConfig.diagonal = diagonal;
894 welfordConfig.regularize = true;
895 }
896
897 Value windowIdx = arith::ConstantOp::create(
898 rewriter, loc, i64TensorType,
899 DenseElementsAttr::get(i64TensorType,
900 rewriter.getI64IntegerAttr(0)));
901
902 // Warmup loop carries by default:
903 // [q, grad, U, rng, stepSize, invMass, massMatrixSqrt, daState(5),
904 // welfordState(3)?, windowIdx]
905 SmallVector<Type> warmupLoopTypes = {positionType,
906 positionType,
907 scalarType,
908 currentRng.getType(),
909 scalarType, // stepSize
910 adaptedInvMass.getType(),
911 adaptedMassMatrixSqrt.getType()};
912 for (Type t : daState.getTypes())
913 warmupLoopTypes.push_back(t);
914 if (adaptMassMatrix) {
915 for (Type t : welfordState.getTypes())
916 warmupLoopTypes.push_back(t);
917 }
918 warmupLoopTypes.push_back(i64TensorType); // windowIdx
919
920 SmallVector<Value> warmupInitArgs = {currentQ,
921 currentGrad,
922 currentU,
923 currentRng,
924 initialStepSize,
925 adaptedInvMass,
926 adaptedMassMatrixSqrt};
927 for (Value v : daState.toValues())
928 warmupInitArgs.push_back(v);
929 if (adaptMassMatrix) {
930 for (Value v : welfordState.toValues())
931 warmupInitArgs.push_back(v);
932 }
933 warmupInitArgs.push_back(windowIdx);
934
935 auto warmupLoop =
936 impulse::ForOp::create(rewriter, loc, warmupLoopTypes, c0,
937 numWarmupConst, c1, warmupInitArgs);
938
939 Block *warmupBody = rewriter.createBlock(&warmupLoop.getRegion());
940 warmupBody->addArgument(i64TensorType, loc); // iteration index t
941 for (Type t : warmupLoopTypes)
942 warmupBody->addArgument(t, loc);
943
944 rewriter.setInsertionPointToStart(warmupBody);
945
946 Value iterT = warmupBody->getArgument(0);
947 Value qLoop = warmupBody->getArgument(1);
948 Value gradLoop = warmupBody->getArgument(2);
949 Value ULoop = warmupBody->getArgument(3);
950 Value rngLoop = warmupBody->getArgument(4);
951 Value stepSizeLoop = warmupBody->getArgument(5);
952 Value invMassLoop = warmupBody->getArgument(6);
953 Value massMatrixSqrtLoop = warmupBody->getArgument(7);
954
955 SmallVector<Value> daStateLoopValues;
956 for (int i = 0; i < 5; ++i)
957 daStateLoopValues.push_back(warmupBody->getArgument(8 + i));
958 auto daStateLoop = DualAveragingState::fromValues(daStateLoopValues);
959
960 WelfordState welfordStateLoop;
961 Value windowIdxLoop;
962 if (adaptMassMatrix) {
963 SmallVector<Value> welfordStateLoopValues;
964 for (int i = 0; i < 3; ++i)
965 welfordStateLoopValues.push_back(warmupBody->getArgument(13 + i));
966 welfordStateLoop = WelfordState::fromValues(welfordStateLoopValues);
967 windowIdxLoop = warmupBody->getArgument(16);
968 } else {
969 windowIdxLoop = warmupBody->getArgument(13);
970 }
971
972 auto sample = runSampleStepWithInvMass(rewriter, loc, qLoop, gradLoop,
973 ULoop, rngLoop, stepSizeLoop,
974 invMassLoop, massMatrixSqrtLoop);
975
976 // Update dual averaging state
977 DualAveragingConfig daConfig;
978 DualAveragingState updatedDaState;
979 Value currentStepSizeFromDA;
980 Value finalStepSizeFromDA;
981
982 if (adaptStepSize) {
983 updatedDaState = updateDualAveraging(rewriter, loc, daStateLoop,
984 sample.accept_prob, daConfig);
985 currentStepSizeFromDA =
986 getStepSizeFromDualAveraging(rewriter, loc, updatedDaState);
987 finalStepSizeFromDA =
988 getStepSizeFromDualAveraging(rewriter, loc, updatedDaState, true);
989 } else {
990 updatedDaState = daStateLoop;
991 currentStepSizeFromDA = stepSizeLoop;
992 finalStepSizeFromDA = stepSizeLoop;
993 }
994
995 // Use log_step_size_avg at last iteration
996 auto isLastIter = arith::CmpIOp::create(
997 rewriter, loc, arith::CmpIPredicate::eq, iterT, lastIterConst);
998 Value adaptedStepSizeInLoop = impulse::SelectOp::create(
999 rewriter, loc, scalarType, isLastIter, finalStepSizeFromDA,
1000 currentStepSizeFromDA);
1001
1002 const auto &floatSemantics =
1003 cast<FloatType>(elemType).getFloatSemantics();
1004 auto tinyConst = arith::ConstantOp::create(
1005 rewriter, loc, scalarType,
1006 DenseElementsAttr::get(
1007 scalarType, FloatAttr::get(elemType, llvm::APFloat::getSmallest(
1008 floatSemantics))));
1009 auto maxConst = arith::ConstantOp::create(
1010 rewriter, loc, scalarType,
1011 DenseElementsAttr::get(
1012 scalarType, FloatAttr::get(elemType, llvm::APFloat::getLargest(
1013 floatSemantics))));
1014 adaptedStepSizeInLoop = arith::MaximumFOp::create(
1015 rewriter, loc, adaptedStepSizeInLoop, tinyConst);
1016 adaptedStepSizeInLoop = arith::MinimumFOp::create(
1017 rewriter, loc, adaptedStepSizeInLoop, maxConst);
1018
1019 auto windowIdxGtZero = arith::CmpIOp::create(
1020 rewriter, loc, arith::CmpIPredicate::sgt, windowIdxLoop, c0);
1021 auto windowIdxLtLast =
1022 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1023 windowIdxLoop, numWindowsMinusOne);
1024 auto isMiddleWindow = arith::AndIOp::create(
1025 rewriter, loc, windowIdxGtZero, windowIdxLtLast);
1026
1027 // Conditionally update Welford
1028 WelfordState conditionalWelford;
1029 if (adaptMassMatrix) {
1030 auto sampleType1D = RankedTensorType::get({positionSize}, elemType);
1031 Value sample1D =
1032 impulse::ReshapeOp::create(rewriter, loc, sampleType1D, sample.q);
1033 WelfordState updatedWelfordAfterSample = updateWelford(
1034 rewriter, loc, welfordStateLoop, sample1D, welfordConfig);
1035
1036 conditionalWelford.mean = impulse::SelectOp::create(
1037 rewriter, loc, welfordStateLoop.mean.getType(), isMiddleWindow,
1038 updatedWelfordAfterSample.mean, welfordStateLoop.mean);
1039 conditionalWelford.m2 = impulse::SelectOp::create(
1040 rewriter, loc, welfordStateLoop.m2.getType(), isMiddleWindow,
1041 updatedWelfordAfterSample.m2, welfordStateLoop.m2);
1042 conditionalWelford.n = impulse::SelectOp::create(
1043 rewriter, loc, welfordStateLoop.n.getType(), isMiddleWindow,
1044 updatedWelfordAfterSample.n, welfordStateLoop.n);
1045 }
1046
1047 Value atWindowEnd = arith::ConstantOp::create(
1048 rewriter, loc, i1TensorType,
1049 DenseElementsAttr::get(i1TensorType, rewriter.getBoolAttr(false)));
1050
1051 for (int64_t w = 0; w < numWindows; ++w) {
1052 auto windowIdxIsW = arith::CmpIOp::create(
1053 rewriter, loc, arith::CmpIPredicate::eq, windowIdxLoop,
1054 arith::ConstantOp::create(
1055 rewriter, loc, i64TensorType,
1056 DenseElementsAttr::get(i64TensorType,
1057 rewriter.getI64IntegerAttr(w))));
1058 auto tEqualsWindowEnd =
1059 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1060 iterT, windowEndConstants[w]);
1061 auto matchesThisWindow = arith::AndIOp::create(
1062 rewriter, loc, windowIdxIsW, tEqualsWindowEnd);
1063 atWindowEnd = arith::OrIOp::create(rewriter, loc, atWindowEnd,
1064 matchesThisWindow);
1065 }
1066
1067 Value newWindowIdx =
1068 arith::AddIOp::create(rewriter, loc, windowIdxLoop, c1);
1069 Value windowIdxAfterIncrement =
1070 impulse::SelectOp::create(rewriter, loc, i64TensorType, atWindowEnd,
1071 newWindowIdx, windowIdxLoop);
1072
1073 auto atMiddleWindowEnd =
1074 arith::AndIOp::create(rewriter, loc, atWindowEnd, isMiddleWindow);
1075
1076 Value finalInvMass;
1077 Value finalMassMatrixSqrt;
1078 WelfordState finalWelfordState;
1079 Value finalStepSizeValue;
1080 DualAveragingState finalDaState;
1081
1082 SmallVector<Type> ifResultTypes;
1083 ifResultTypes.push_back(invMassLoop.getType());
1084 ifResultTypes.push_back(massMatrixSqrtLoop.getType());
1085 if (adaptMassMatrix) {
1086 ifResultTypes.push_back(conditionalWelford.mean.getType());
1087 ifResultTypes.push_back(conditionalWelford.m2.getType());
1088 ifResultTypes.push_back(conditionalWelford.n.getType());
1089 }
1090 for (Type t : updatedDaState.getTypes())
1091 ifResultTypes.push_back(t);
1092
1093 auto ifOp = impulse::IfOp::create(rewriter, loc, ifResultTypes,
1094 atMiddleWindowEnd);
1095
1096 {
1097 Block *trueBranch = rewriter.createBlock(&ifOp.getTrueBranch());
1098 rewriter.setInsertionPointToStart(trueBranch);
1099
1100 SmallVector<Value> trueYieldValues;
1101
1102 if (adaptMassMatrix) {
1103 auto newInvMass = finalizeWelford(rewriter, loc, conditionalWelford,
1104 welfordConfig);
1105 auto newMassMatrixSqrt =
1106 computeMassMatrixSqrt(rewriter, loc, newInvMass, positionType);
1107 auto reinitWelford =
1108 initWelford(rewriter, loc, positionSize, diagonal);
1109
1110 trueYieldValues.push_back(newInvMass);
1111 trueYieldValues.push_back(newMassMatrixSqrt);
1112 trueYieldValues.push_back(reinitWelford.mean);
1113 trueYieldValues.push_back(reinitWelford.m2);
1114 trueYieldValues.push_back(reinitWelford.n);
1115 } else {
1116 trueYieldValues.push_back(invMassLoop);
1117 trueYieldValues.push_back(massMatrixSqrtLoop);
1118 }
1119
1120 if (adaptStepSize) {
1121 auto reinitDaState =
1122 initDualAveraging(rewriter, loc, adaptedStepSizeInLoop);
1123 for (auto v : reinitDaState.toValues())
1124 trueYieldValues.push_back(v);
1125 } else {
1126 for (auto v : updatedDaState.toValues())
1127 trueYieldValues.push_back(v);
1128 }
1129
1130 impulse::YieldOp::create(rewriter, loc, trueYieldValues);
1131 }
1132
1133 {
1134 Block *falseBranch = rewriter.createBlock(&ifOp.getFalseBranch());
1135 rewriter.setInsertionPointToStart(falseBranch);
1136
1137 SmallVector<Value> falseYieldValues;
1138 falseYieldValues.push_back(invMassLoop);
1139 falseYieldValues.push_back(massMatrixSqrtLoop);
1140 if (adaptMassMatrix) {
1141 falseYieldValues.push_back(conditionalWelford.mean);
1142 falseYieldValues.push_back(conditionalWelford.m2);
1143 falseYieldValues.push_back(conditionalWelford.n);
1144 }
1145 for (auto v : updatedDaState.toValues())
1146 falseYieldValues.push_back(v);
1147
1148 impulse::YieldOp::create(rewriter, loc, falseYieldValues);
1149 }
1150
1151 rewriter.setInsertionPointAfter(ifOp);
1152
1153 size_t resultIdx = 0;
1154 finalInvMass = ifOp.getResult(resultIdx++);
1155 finalMassMatrixSqrt = ifOp.getResult(resultIdx++);
1156 if (adaptMassMatrix) {
1157 finalWelfordState.mean = ifOp.getResult(resultIdx++);
1158 finalWelfordState.m2 = ifOp.getResult(resultIdx++);
1159 finalWelfordState.n = ifOp.getResult(resultIdx++);
1160 }
1161 finalDaState.log_step_size = ifOp.getResult(resultIdx++);
1162 finalDaState.log_step_size_avg = ifOp.getResult(resultIdx++);
1163 finalDaState.gradient_avg = ifOp.getResult(resultIdx++);
1164 finalDaState.step_count = ifOp.getResult(resultIdx++);
1165 finalDaState.prox_center = ifOp.getResult(resultIdx++);
1166
1167 finalStepSizeValue = adaptedStepSizeInLoop;
1168
1169 SmallVector<Value> warmupYieldValues = {
1170 sample.q, sample.grad, sample.U, sample.rng,
1171 finalStepSizeValue, finalInvMass, finalMassMatrixSqrt};
1172 for (Value v : finalDaState.toValues())
1173 warmupYieldValues.push_back(v);
1174 if (adaptMassMatrix) {
1175 for (Value v : finalWelfordState.toValues())
1176 warmupYieldValues.push_back(v);
1177 }
1178 warmupYieldValues.push_back(windowIdxAfterIncrement);
1179
1180 impulse::YieldOp::create(rewriter, loc, warmupYieldValues);
1181
1182 rewriter.setInsertionPointAfter(warmupLoop);
1183
1184 currentQ = warmupLoop.getResult(0);
1185 currentGrad = warmupLoop.getResult(1);
1186 currentU = warmupLoop.getResult(2);
1187 currentRng = warmupLoop.getResult(3);
1188 adaptedStepSize = warmupLoop.getResult(4);
1189 adaptedInvMass = warmupLoop.getResult(5);
1190 adaptedMassMatrixSqrt = warmupLoop.getResult(6);
1191
1192 adaptedStepSize =
1193 conditionalDump(rewriter, loc, adaptedStepSize,
1194 "MCMC: adapted step size after warmup", debugDump);
1195 if (adaptMassMatrix) {
1196 adaptedInvMass = conditionalDump(
1197 rewriter, loc, adaptedInvMass,
1198 "MCMC: adapted inverse mass matrix after warmup", debugDump);
1199 }
1200 }
1201
1202 int64_t collectionSize = numSamples / thinning;
1203 int64_t startIdx = numSamples % thinning;
1204
1205 auto samplesBufferType =
1206 RankedTensorType::get({collectionSize, positionSize}, elemType);
1207 auto acceptedBufferType =
1208 RankedTensorType::get({collectionSize}, rewriter.getI1Type());
1209
1210 auto samplesBuffer = arith::ConstantOp::create(
1211 rewriter, loc, samplesBufferType,
1212 DenseElementsAttr::get(samplesBufferType,
1213 rewriter.getFloatAttr(elemType, 0.0)));
1214 auto acceptedBuffer = arith::ConstantOp::create(
1215 rewriter, loc, acceptedBufferType,
1216 DenseElementsAttr::get(acceptedBufferType,
1217 rewriter.getBoolAttr(isNUTS)));
1218
1219 auto c0 = arith::ConstantOp::create(
1220 rewriter, loc, i64TensorType,
1221 DenseElementsAttr::get(i64TensorType, rewriter.getI64IntegerAttr(0)));
1222 auto c1 = arith::ConstantOp::create(
1223 rewriter, loc, i64TensorType,
1224 DenseElementsAttr::get(i64TensorType, rewriter.getI64IntegerAttr(1)));
1225 auto numSamplesConst = arith::ConstantOp::create(
1226 rewriter, loc, i64TensorType,
1227 DenseElementsAttr::get(i64TensorType,
1228 rewriter.getI64IntegerAttr(numSamples)));
1229 auto startIdxConst = arith::ConstantOp::create(
1230 rewriter, loc, i64TensorType,
1231 DenseElementsAttr::get(i64TensorType,
1232 rewriter.getI64IntegerAttr(startIdx)));
1233 auto thinningConst = arith::ConstantOp::create(
1234 rewriter, loc, i64TensorType,
1235 DenseElementsAttr::get(i64TensorType,
1236 rewriter.getI64IntegerAttr(thinning)));
1237
1238 // Loop carries: [q, grad, U, rng, samplesBuffer, acceptedBuffer]
1239 SmallVector<Type> loopResultTypes = {
1240 positionType, positionType, scalarType,
1241 currentRng.getType(), samplesBufferType, acceptedBufferType};
1242 auto forLoopOp = impulse::ForOp::create(
1243 rewriter, loc, loopResultTypes, c0, numSamplesConst, c1,
1244 ValueRange{currentQ, currentGrad, currentU, currentRng, samplesBuffer,
1245 acceptedBuffer});
1246
1247 Block *loopBody = rewriter.createBlock(&forLoopOp.getRegion());
1248 loopBody->addArgument(i64TensorType, loc); // i (iteration index)
1249 loopBody->addArgument(positionType, loc); // q
1250 loopBody->addArgument(positionType, loc); // grad
1251 loopBody->addArgument(scalarType, loc); // U
1252 loopBody->addArgument(currentRng.getType(), loc); // rng
1253 loopBody->addArgument(samplesBufferType, loc); // samplesBuffer
1254 loopBody->addArgument(acceptedBufferType, loc); // acceptedBuffer
1255
1256 rewriter.setInsertionPointToStart(loopBody);
1257 Value iterIdx = loopBody->getArgument(0);
1258 Value qLoop = loopBody->getArgument(1);
1259 Value gradLoop = loopBody->getArgument(2);
1260 Value ULoop = loopBody->getArgument(3);
1261 Value rngLoop = loopBody->getArgument(4);
1262 Value samplesBufferLoop = loopBody->getArgument(5);
1263 Value acceptedBufferLoop = loopBody->getArgument(6);
1264
1265 auto sample = runSampleStepWithStepSize(rewriter, loc, qLoop, gradLoop,
1266 ULoop, rngLoop, adaptedStepSize);
1267 auto q_constrained =
1268 impulse::constrainPosition(rewriter, loc, sample.q, supports);
1269
1270 // Storage index: idx = (i - start_idx) / thinning
1271 auto iMinusStart =
1272 arith::SubIOp::create(rewriter, loc, iterIdx, startIdxConst);
1273 auto storageIdx =
1274 arith::DivSIOp::create(rewriter, loc, iMinusStart, thinningConst);
1275
1276 // Store condition:
1277 // (i >= start_idx) && ((i - start_idx) % thinning == 0)
1278 auto geStartIdx = arith::CmpIOp::create(
1279 rewriter, loc, arith::CmpIPredicate::sge, iterIdx, startIdxConst);
1280 auto modThinning =
1281 arith::RemSIOp::create(rewriter, loc, iMinusStart, thinningConst);
1282 auto modIsZero = arith::CmpIOp::create(
1283 rewriter, loc, arith::CmpIPredicate::eq, modThinning, c0);
1284 auto shouldStore =
1285 arith::AndIOp::create(rewriter, loc, geStartIdx, modIsZero);
1286
1287 auto zeroCol = arith::ConstantOp::create(
1288 rewriter, loc, i64TensorType,
1289 DenseElementsAttr::get(i64TensorType, rewriter.getI64IntegerAttr(0)));
1290 auto updatedSamplesBuffer = impulse::DynamicUpdateSliceOp::create(
1291 rewriter, loc, samplesBufferType, samplesBufferLoop, q_constrained,
1292 ValueRange{storageIdx, zeroCol});
1293 auto selectedSamplesBuffer = impulse::SelectOp::create(
1294 rewriter, loc, samplesBufferType, shouldStore, updatedSamplesBuffer,
1295 samplesBufferLoop);
1296
1297 auto accepted1D = impulse::ReshapeOp::create(
1298 rewriter, loc, RankedTensorType::get({1}, rewriter.getI1Type()),
1299 sample.accepted);
1300 auto updatedAcceptedBuffer = impulse::DynamicUpdateSliceOp::create(
1301 rewriter, loc, acceptedBufferType, acceptedBufferLoop, accepted1D,
1302 ValueRange{storageIdx});
1303 auto selectedAcceptedBuffer = impulse::SelectOp::create(
1304 rewriter, loc, acceptedBufferType, shouldStore, updatedAcceptedBuffer,
1305 acceptedBufferLoop);
1306
1307 impulse::YieldOp::create(rewriter, loc,
1308 ValueRange{sample.q, sample.grad, sample.U,
1309 sample.rng, selectedSamplesBuffer,
1310 selectedAcceptedBuffer});
1311
1312 rewriter.setInsertionPointAfter(forLoopOp);
1313 Value finalQ = forLoopOp.getResult(0);
1314 Value finalGrad = forLoopOp.getResult(1);
1315 Value finalU = forLoopOp.getResult(2);
1316 Value finalRng = forLoopOp.getResult(3);
1317 Value finalSamplesBuffer = forLoopOp.getResult(4);
1318 Value finalAcceptedBuffer = forLoopOp.getResult(5);
1319
1320 finalSamplesBuffer =
1321 conditionalDump(rewriter, loc, finalSamplesBuffer,
1322 "MCMC: collected samples", debugDump);
1323
1324 rewriter.replaceOp(mcmcOp, {finalSamplesBuffer, finalAcceptedBuffer,
1325 finalRng, finalQ, finalGrad, finalU,
1326 adaptedStepSize, adaptedInvMass});
1327
1328 return success();
1329 }
1330 };
1331
1332 struct LowerMHPattern : public mlir::OpRewritePattern<impulse::MHOp> {
1333 using mlir::OpRewritePattern<impulse::MHOp>::OpRewritePattern;
1334
1335 LogicalResult matchAndRewrite(impulse::MHOp mhOp,
1336 PatternRewriter &rewriter) const override {
1337 SymbolTableCollection symbolTable;
1338
1339 auto fn = cast<FunctionOpInterface>(
1340 symbolTable.lookupNearestSymbolFrom(mhOp, mhOp.getFnAttr()));
1341
1342 if (fn.getFunctionBody().empty()) {
1343 mhOp.emitError(
1344 "Impulse: calling `mh` on an empty function; if this is a "
1345 "distribution function, its sample op should have a logpdf "
1346 "attribute to avoid recursive `mh` calls which is intended for "
1347 "generative functions");
1348 return failure();
1349 }
1350
1351 auto loc = mhOp.getLoc();
1352
1353 Value oldTrace = mhOp.getOperand(0);
1354 Value oldWeight = mhOp.getOperand(1);
1355 SmallVector<Value> inputs;
1356 for (unsigned i = 2; i < mhOp.getNumOperands(); ++i)
1357 inputs.push_back(mhOp.getOperand(i));
1358 auto selection = mhOp.getSelectionAttr();
1359
1360 auto traceType = oldTrace.getType();
1361 auto weightType = cast<RankedTensorType>(oldWeight.getType());
1362 auto rngStateType = inputs[0].getType();
1363
1364 // 1. Create regenerate op with the same function and selection
1365 auto nameAttr = mhOp.getNameAttr();
1366 if (!nameAttr)
1367 nameAttr = rewriter.getStringAttr("");
1368
1369 auto regenerateAddresses = mhOp.getRegenerateAddressesAttr();
1370
1371 SmallVector<Type> regenResultTypes;
1372 regenResultTypes.push_back(traceType);
1373 regenResultTypes.push_back(weightType);
1374 for (auto t : fn.getResultTypes())
1375 regenResultTypes.push_back(t);
1376
1377 auto regenerateOp = rewriter.create<impulse::RegenerateOp>(
1378 loc,
1379 /*resultTypes*/ regenResultTypes,
1380 /*fn*/ mhOp.getFnAttr(),
1381 /*inputs*/ inputs,
1382 /*original_trace*/ oldTrace,
1383 /*selection*/ selection,
1384 /*regenerate_addresses*/ regenerateAddresses,
1385 /*name*/ nameAttr);
1386
1387 Value newTrace = regenerateOp.getNewTrace();
1388 Value newWeight = regenerateOp.getWeight();
1389 Value newRng = regenerateOp.getOutputs()[0];
1390
1391 // 2. Compute log_alpha = new_weight - old_weight
1392 auto logAlpha =
1393 arith::SubFOp::create(rewriter, loc, newWeight, oldWeight);
1394
1395 // 3. Sample uniform random in (0, 1) and compute log
1396 auto zeroConst = arith::ConstantOp::create(
1397 rewriter, loc, weightType, DenseElementsAttr::get(weightType, 0.0));
1398 auto oneConst = arith::ConstantOp::create(
1399 rewriter, loc, weightType, DenseElementsAttr::get(weightType, 1.0));
1400
1401 auto randomOp = impulse::RandomOp::create(
1402 rewriter, loc, TypeRange{rngStateType, weightType}, newRng, zeroConst,
1403 oneConst,
1404 impulse::RngDistributionAttr::get(rewriter.getContext(),
1405 impulse::RngDistribution::UNIFORM));
1406 auto logRand = math::LogOp::create(rewriter, loc, randomOp.getResult());
1407 Value finalRng = randomOp.getOutputRngState();
1408
1409 // 4. Check if proposal is accepted: log(rand()) < log_alpha
1410 auto accepted = arith::CmpFOp::create(
1411 rewriter, loc, arith::CmpFPredicate::OLT, logRand, logAlpha);
1412
1413 // 5. Select trace and weight based on acceptance
1414 auto selectedTrace = impulse::SelectOp::create(
1415 rewriter, loc, traceType, accepted, newTrace, oldTrace);
1416 auto selectedWeight = arith::SelectOp::create(rewriter, loc, accepted,
1417 newWeight, oldWeight);
1418
1419 rewriter.replaceOp(mhOp,
1420 {selectedTrace, selectedWeight, accepted, finalRng});
1421 return success();
1422 }
1423 };
1424
1425 struct LowerGeneratePattern
1426 : public mlir::OpRewritePattern<impulse::GenerateOp> {
1427 using mlir::OpRewritePattern<impulse::GenerateOp>::OpRewritePattern;
1428
1429 LogicalResult matchAndRewrite(impulse::GenerateOp CI,
1430 PatternRewriter &rewriter) const override {
1431 SymbolTableCollection symbolTable;
1432
1433 auto fn = cast<FunctionOpInterface>(
1434 symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()));
1435
1436 if (fn.getFunctionBody().empty()) {
1437 CI.emitError(
1438 "Impulse: calling `generate` on an empty function; if this "
1439 "is a distribution function, its sample op should have a "
1440 "logpdf attribute to avoid recursive `generate` calls which is "
1441 "intended for generative functions");
1442 return failure();
1443 }
1444
1445 ArrayAttr selection = CI.getSelectionAttr();
1446 int64_t positionSize =
1447 computePositionSizeForSelection(CI, fn, selection, symbolTable);
1448 if (positionSize <= 0) {
1449 CI.emitError("Impulse: failed to compute position size for generate");
1450 return failure();
1451 }
1452
1453 int64_t constraintSize = computePositionSizeForSelection(
1454 CI, fn, CI.getConstrainedAddressesAttr(), symbolTable);
1455 if (constraintSize < 0) {
1456 CI.emitError("Impulse: failed to compute constraint size for generate");
1457 return failure();
1458 }
1459
1460 auto putils = ImpulseUtils::CreateFromClone(fn, ImpulseMode::Generate,
1461 positionSize, constraintSize);
1462 FunctionOpInterface NewF = putils->newFunc;
1463
1464 OpBuilder entryBuilder(putils->initializationBlock,
1465 putils->initializationBlock->begin());
1466 Location initLoc = putils->initializationBlock->begin()->getLoc();
1467
1468 auto scalarType = RankedTensorType::get({}, entryBuilder.getF64Type());
1469 auto zeroWeight =
1470 arith::ConstantOp::create(entryBuilder, initLoc, scalarType,
1471 DenseElementsAttr::get(scalarType, 0.0));
1472 Value weightAccumulator = zeroWeight;
1473
1474 auto traceType =
1475 RankedTensorType::get({1, positionSize}, entryBuilder.getF64Type());
1476 auto zeroTrace =
1477 arith::ConstantOp::create(entryBuilder, initLoc, traceType,
1478 DenseElementsAttr::get(traceType, 0.0));
1479 Value currTrace = zeroTrace;
1480 Value constraint = NewF.getArgument(0);
1481 int64_t currentTraceOffset = 0;
1482
1483 SmallVector<Operation *> toErase;
1484 auto result = NewF.walk([&](impulse::SampleOp sampleOp) -> WalkResult {
1485 OpBuilder::InsertionGuard guard(rewriter);
1486 rewriter.setInsertionPoint(sampleOp);
1487
1488 SmallVector<Value> sampledValues;
1489 bool isDistribution = static_cast<bool>(sampleOp.getLogpdfAttr());
1490
1491 if (isDistribution) {
1492 // A1. Distribution function: call the distribution function.
1493 bool isConstrained = false;
1494 int64_t constrainedOffset = -1;
1495 for (auto addr : CI.getConstrainedAddressesAttr()) {
1496 auto address = cast<ArrayAttr>(addr);
1497 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
1498 if (address.size() != 1) {
1499 sampleOp.emitError(
1500 "Impulse: distribution function cannot have composite "
1501 "constrained address");
1502 return WalkResult::interrupt();
1503 }
1504 isConstrained = true;
1505 constrainedOffset = computeOffsetForSampleInSelection(
1506 CI, fn, CI.getConstrainedAddressesAttr(),
1507 sampleOp.getSymbolAttr(), symbolTable);
1508 break;
1509 }
1510 }
1511
1512 if (isConstrained) {
1513 // Extract sampled values from constraint tensor
1514 sampledValues.resize(sampleOp.getNumResults());
1515 sampledValues[0] = sampleOp.getOperand(0); // RNG state
1516
1517 for (unsigned i = 1; i < sampleOp.getNumResults(); ++i) {
1518 auto resultType =
1519 cast<RankedTensorType>(sampleOp.getResult(i).getType());
1520 int64_t numElements = computeTensorElementCount(resultType);
1521 if (numElements < 0) {
1522 sampleOp.emitError(
1523 "Impulse: dynamic tensor dimensions not supported");
1524 return WalkResult::interrupt();
1525 }
1526
1527 auto sliceType = RankedTensorType::get(
1528 {1, numElements}, resultType.getElementType());
1529 auto sliced = impulse::SliceOp::create(
1530 rewriter, sampleOp.getLoc(), sliceType, constraint,
1531 rewriter.getDenseI64ArrayAttr({0, constrainedOffset}),
1532 rewriter.getDenseI64ArrayAttr(
1533 {1, constrainedOffset + numElements}),
1534 rewriter.getDenseI64ArrayAttr({1, 1}));
1535 auto extracted = impulse::ReshapeOp::create(
1536 rewriter, sampleOp.getLoc(), resultType, sliced);
1537 sampledValues[i] = extracted.getResult();
1538 constrainedOffset += numElements;
1539 }
1540
1541 // Compute weight via logpdf using constrained values.
1542 auto logpdfFn =
1543 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1544 sampleOp, sampleOp.getLogpdfAttr()));
1545
1546 // logpdf operands: (<non-RNG outputs>..., <non-RNG inputs>...)
1547 SmallVector<Value> logpdfOperands;
1548 for (unsigned i = 1; i < sampledValues.size(); ++i) {
1549 logpdfOperands.push_back(sampledValues[i]);
1550 }
1551 for (unsigned i = 1; i < sampleOp.getNumOperands(); ++i) {
1552 logpdfOperands.push_back(sampleOp.getOperand(i));
1553 }
1554
1555 if (logpdfOperands.size() != logpdfFn.getNumArguments()) {
1556 sampleOp.emitError(
1557 "Impulse: failed to construct logpdf call for constrained "
1558 "sample; logpdf function has wrong number of arguments");
1559 return WalkResult::interrupt();
1560 }
1561
1562 auto logpdf = func::CallOp::create(
1563 rewriter, sampleOp.getLoc(), logpdfFn.getName(),
1564 logpdfFn.getResultTypes(), logpdfOperands);
1565 weightAccumulator =
1566 arith::AddFOp::create(rewriter, sampleOp.getLoc(),
1567 weightAccumulator, logpdf.getResult(0));
1568 } else {
1569 // Unconstrained: call the distribution function
1570 auto distFn =
1571 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1572 sampleOp, sampleOp.getFnAttr()));
1573
1574 auto distCall = func::CallOp::create(
1575 rewriter, sampleOp.getLoc(), distFn.getName(),
1576 distFn.getResultTypes(), sampleOp.getInputs());
1577
1578 sampledValues.append(distCall.getResults().begin(),
1579 distCall.getResults().end());
1580
1581 auto logpdfFn =
1582 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1583 sampleOp, sampleOp.getLogpdfAttr()));
1584
1585 SmallVector<Value> logpdfOperands;
1586 for (unsigned i = 1; i < sampledValues.size(); ++i) {
1587 logpdfOperands.push_back(sampledValues[i]);
1588 }
1589 for (unsigned i = 1; i < sampleOp.getNumOperands(); ++i) {
1590 logpdfOperands.push_back(sampleOp.getOperand(i));
1591 }
1592
1593 if (logpdfOperands.size() != logpdfFn.getNumArguments()) {
1594 sampleOp.emitError(
1595 "Impulse: failed to construct logpdf call; "
1596 "logpdf function has wrong number of arguments");
1597 return WalkResult::interrupt();
1598 }
1599
1600 auto logpdf = func::CallOp::create(
1601 rewriter, sampleOp.getLoc(), logpdfFn.getName(),
1602 logpdfFn.getResultTypes(), logpdfOperands);
1603 weightAccumulator =
1604 arith::AddFOp::create(rewriter, sampleOp.getLoc(),
1605 weightAccumulator, logpdf.getResult(0));
1606 }
1607
1608 bool inSelection = false;
1609 for (auto addr : selection) {
1610 auto address = cast<ArrayAttr>(addr);
1611 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
1612 inSelection = true;
1613 break;
1614 }
1615 }
1616
1617 if (inSelection) {
1618 for (unsigned i = 1; i < sampledValues.size(); ++i) {
1619 auto sampleValue = sampledValues[i];
1620 auto sampleType = cast<RankedTensorType>(sampleValue.getType());
1621 int64_t numElements = computeTensorElementCount(sampleType);
1622 if (numElements < 0) {
1623 sampleOp.emitError(
1624 "Impulse: dynamic tensor dimensions not supported");
1625 return WalkResult::interrupt();
1626 }
1627
1628 auto flatSampleType = RankedTensorType::get(
1629 {1, numElements}, sampleType.getElementType());
1630 auto flatSample = impulse::ReshapeOp::create(
1631 rewriter, sampleOp.getLoc(), flatSampleType, sampleValue);
1632 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
1633 auto row0 = arith::ConstantOp::create(
1634 rewriter, sampleOp.getLoc(), i64S,
1635 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
1636 auto colOff = arith::ConstantOp::create(
1637 rewriter, sampleOp.getLoc(), i64S,
1638 DenseElementsAttr::get(
1639 i64S, rewriter.getI64IntegerAttr(currentTraceOffset)));
1640 currTrace = impulse::DynamicUpdateSliceOp::create(
1641 rewriter, sampleOp.getLoc(), traceType, currTrace,
1642 flatSample, ValueRange{row0, colOff})
1643 .getResult();
1644 currentTraceOffset += numElements;
1645 }
1646 }
1647 } else {
1648 // B. Generative function: recursively generate the nested function
1649 auto genFn =
1650 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1651 sampleOp, sampleOp.getFnAttr()));
1652
1653 if (genFn.getFunctionBody().empty()) {
1654 sampleOp.emitError(
1655 "Impulse: generative function body is empty; "
1656 "if this is a distribution, add a logpdf attribute");
1657 return WalkResult::interrupt();
1658 }
1659
1660 ArrayAttr subSelection =
1661 buildSubSelection(rewriter, selection, sampleOp.getSymbolAttr());
1662 ArrayAttr subConstrainedAddrs =
1663 buildSubSelection(rewriter, CI.getConstrainedAddressesAttr(),
1664 sampleOp.getSymbolAttr());
1665
1666 if (subSelection.empty()) {
1667 // No samples from this generative function are in the selection
1668 // Just call the function directly
1669 auto genCall = func::CallOp::create(
1670 rewriter, sampleOp.getLoc(), genFn.getName(),
1671 genFn.getResultTypes(), sampleOp.getInputs());
1672 sampledValues.append(genCall.getResults().begin(),
1673 genCall.getResults().end());
1674 } else {
1675 int64_t subPositionSize = computePositionSizeForSelection(
1676 sampleOp, genFn, subSelection, symbolTable);
1677 int64_t subConstraintSize = computePositionSizeForSelection(
1678 sampleOp, genFn, subConstrainedAddrs, symbolTable);
1679 if (subPositionSize <= 0 || subConstraintSize < 0) {
1680 sampleOp.emitError("Impulse: failed to compute sub-position or "
1681 "sub-constraint size");
1682 return WalkResult::interrupt();
1683 }
1684
1685 Value subConstraint;
1686 auto subConstraintType = RankedTensorType::get(
1687 {1, subConstraintSize}, rewriter.getF64Type());
1688
1689 if (subConstraintSize > 0) {
1690 int64_t subConstraintOffset = computeOffsetForNestedSample(
1691 sampleOp, fn, CI.getConstrainedAddressesAttr(),
1692 sampleOp.getSymbolAttr(), symbolTable);
1693
1694 subConstraint = impulse::SliceOp::create(
1695 rewriter, sampleOp.getLoc(), subConstraintType, constraint,
1696 rewriter.getDenseI64ArrayAttr({0, subConstraintOffset}),
1697 rewriter.getDenseI64ArrayAttr(
1698 {1, subConstraintOffset + subConstraintSize}),
1699 rewriter.getDenseI64ArrayAttr({1, 1}));
1700 } else {
1701 subConstraint = arith::ConstantOp::create(
1702 rewriter, sampleOp.getLoc(), subConstraintType,
1703 DenseElementsAttr::get(subConstraintType, {0.0}));
1704 }
1705
1706 // Build result types: (trace, weight, original_returns...)
1707 auto subTraceType = RankedTensorType::get({1, subPositionSize},
1708 rewriter.getF64Type());
1709 auto scalarTy = RankedTensorType::get({}, rewriter.getF64Type());
1710 SmallVector<Type> genResultTypes;
1711 genResultTypes.push_back(subTraceType);
1712 genResultTypes.push_back(scalarTy);
1713 for (auto t : genFn.getResultTypes())
1714 genResultTypes.push_back(t);
1715
1716 auto nestedGenerate = impulse::GenerateOp::create(
1717 rewriter, sampleOp.getLoc(), genResultTypes,
1718 sampleOp.getFnAttr(), sampleOp.getInputs(), subConstraint,
1719 subSelection, subConstrainedAddrs);
1720
1721 Value subTrace = nestedGenerate.getTrace();
1722 Value subWeight = nestedGenerate.getWeight();
1723
1724 weightAccumulator = arith::AddFOp::create(
1725 rewriter, sampleOp.getLoc(), weightAccumulator, subWeight);
1726
1727 int64_t mergeOffset = computeOffsetForNestedSample(
1728 sampleOp, fn, selection, sampleOp.getSymbolAttr(), symbolTable);
1729
1730 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
1731 auto row0 = arith::ConstantOp::create(
1732 rewriter, sampleOp.getLoc(), i64S,
1733 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
1734 auto colOff = arith::ConstantOp::create(
1735 rewriter, sampleOp.getLoc(), i64S,
1736 DenseElementsAttr::get(
1737 i64S, rewriter.getI64IntegerAttr(mergeOffset)));
1738 currTrace = impulse::DynamicUpdateSliceOp::create(
1739 rewriter, sampleOp.getLoc(), traceType, currTrace,
1740 subTrace, ValueRange{row0, colOff})
1741 .getResult();
1742 currentTraceOffset =
1743 std::max(currentTraceOffset, mergeOffset + subPositionSize);
1744
1745 for (auto output : nestedGenerate.getOutputs())
1746 sampledValues.push_back(output);
1747 }
1748 }
1749
1750 sampleOp.replaceAllUsesWith(sampledValues);
1751 toErase.push_back(sampleOp);
1752 return WalkResult::advance();
1753 });
1754
1755 for (Operation *op : toErase)
1756 rewriter.eraseOp(op);
1757
1758 if (result.wasInterrupted()) {
1759 CI.emitError("Impulse: failed to walk sample ops");
1760 return failure();
1761 }
1762
1763 // Rewrite the return to return (trace, weight, <original returns>...)
1764 NewF.walk([&](func::ReturnOp retOp) {
1765 OpBuilder::InsertionGuard guard(rewriter);
1766 rewriter.setInsertionPoint(retOp);
1767
1768 SmallVector<Value> newRetVals;
1769 newRetVals.push_back(currTrace);
1770 newRetVals.push_back(weightAccumulator);
1771 newRetVals.append(retOp.getOperands().begin(),
1772 retOp.getOperands().end());
1773
1774 func::ReturnOp::create(rewriter, retOp.getLoc(), newRetVals);
1775 rewriter.eraseOp(retOp);
1776 });
1777
1778 rewriter.setInsertionPoint(CI);
1779 SmallVector<Value> operands;
1780 operands.push_back(CI.getConstraint());
1781 operands.append(CI.getInputs().begin(), CI.getInputs().end());
1782 auto newCI = func::CallOp::create(rewriter, CI.getLoc(), NewF.getName(),
1783 NewF.getResultTypes(), operands);
1784
1785 rewriter.replaceOp(CI, newCI.getResults());
1786
1787 delete putils;
1788
1789 return success();
1790 }
1791 };
1792
1793 struct LowerRegeneratePattern
1794 : public mlir::OpRewritePattern<impulse::RegenerateOp> {
1795 using mlir::OpRewritePattern<impulse::RegenerateOp>::OpRewritePattern;
1796
1797 LogicalResult matchAndRewrite(impulse::RegenerateOp CI,
1798 PatternRewriter &rewriter) const override {
1799 SymbolTableCollection symbolTable;
1800
1801 auto fn = cast<FunctionOpInterface>(
1802 symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()));
1803
1804 if (fn.getFunctionBody().empty()) {
1805 CI.emitError(
1806 "Impulse: calling `regenerate` on an empty function; if this "
1807 "is a distribution function, its sample op should have a "
1808 "logpdf attribute to avoid recursive `regenerate` calls which is "
1809 "intended for generative functions");
1810 return failure();
1811 }
1812
1813 ArrayAttr selection = CI.getSelectionAttr();
1814 int64_t positionSize =
1815 computePositionSizeForSelection(CI, fn, selection, symbolTable);
1816 if (positionSize <= 0) {
1817 CI.emitError("Impulse: failed to compute position size for regenerate");
1818 return failure();
1819 }
1820
1821 auto putils = ImpulseUtils::CreateFromClone(fn, ImpulseMode::Regenerate,
1822 positionSize);
1823 FunctionOpInterface NewF = putils->newFunc;
1824
1825 OpBuilder entryBuilder(putils->initializationBlock,
1826 putils->initializationBlock->begin());
1827 Location initLoc = putils->initializationBlock->begin()->getLoc();
1828
1829 auto scalarType = RankedTensorType::get({}, entryBuilder.getF64Type());
1830 auto zeroWeight =
1831 arith::ConstantOp::create(entryBuilder, initLoc, scalarType,
1832 DenseElementsAttr::get(scalarType, 0.0));
1833 Value weightAccumulator = zeroWeight;
1834
1835 auto traceType =
1836 RankedTensorType::get({1, positionSize}, entryBuilder.getF64Type());
1837 auto zeroTrace =
1838 arith::ConstantOp::create(entryBuilder, initLoc, traceType,
1839 DenseElementsAttr::get(traceType, 0.0));
1840 Value currTrace = zeroTrace;
1841
1842 Value prevTrace = NewF.getArgument(0);
1843 int64_t currentTraceOffset = 0;
1844
1845 SmallVector<Operation *> toErase;
1846 auto result = NewF.walk([&](impulse::SampleOp sampleOp) -> WalkResult {
1847 OpBuilder::InsertionGuard guard(rewriter);
1848 rewriter.setInsertionPoint(sampleOp);
1849
1850 SmallVector<Value> sampledValues;
1851 bool isDistribution = static_cast<bool>(sampleOp.getLogpdfAttr());
1852
1853 if (isDistribution) {
1854 // A1. Distribution function: call the distribution function.
1855 bool isSelected = false;
1856 for (auto addr : CI.getRegenerateAddressesAttr()) {
1857 auto address = cast<ArrayAttr>(addr);
1858 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
1859 if (address.size() != 1) {
1860 sampleOp.emitError(
1861 "Impulse: distribution function cannot have composite "
1862 "selected address");
1863 return WalkResult::interrupt();
1864 }
1865 isSelected = true;
1866 break;
1867 }
1868 }
1869
1870 int64_t sampleOffset = computeOffsetForSampleInSelection(
1871 CI, fn, selection, sampleOp.getSymbolAttr(), symbolTable);
1872
1873 if (isSelected) {
1874 // A2. Regenerate: call the distribution function.
1875 auto distFn =
1876 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1877 sampleOp, sampleOp.getFnAttr()));
1878
1879 auto distCall = func::CallOp::create(
1880 rewriter, sampleOp.getLoc(), distFn.getName(),
1881 distFn.getResultTypes(), sampleOp.getInputs());
1882
1883 sampledValues.append(distCall.getResults().begin(),
1884 distCall.getResults().end());
1885 } else {
1886 // B. Generative function: extract from original trace.
1887 sampledValues.resize(sampleOp.getNumResults());
1888 sampledValues[0] = sampleOp.getOperand(0); // RNG state
1889
1890 int64_t extractOffset = sampleOffset;
1891 for (unsigned i = 1; i < sampleOp.getNumResults(); ++i) {
1892 auto resultType =
1893 cast<RankedTensorType>(sampleOp.getResult(i).getType());
1894 int64_t numElements = computeTensorElementCount(resultType);
1895 if (numElements < 0) {
1896 sampleOp.emitError(
1897 "Impulse: dynamic tensor dimensions not supported");
1898 return WalkResult::interrupt();
1899 }
1900
1901 auto sliceType = RankedTensorType::get(
1902 {1, numElements}, resultType.getElementType());
1903 auto sliced = impulse::SliceOp::create(
1904 rewriter, sampleOp.getLoc(), sliceType, prevTrace,
1905 rewriter.getDenseI64ArrayAttr({0, extractOffset}),
1906 rewriter.getDenseI64ArrayAttr(
1907 {1, extractOffset + numElements}),
1908 rewriter.getDenseI64ArrayAttr({1, 1}));
1909 auto extracted = impulse::ReshapeOp::create(
1910 rewriter, sampleOp.getLoc(), resultType, sliced);
1911 sampledValues[i] = extracted.getResult();
1912 extractOffset += numElements;
1913 }
1914 }
1915
1916 auto logpdfFn =
1917 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1918 sampleOp, sampleOp.getLogpdfAttr()));
1919
1920 SmallVector<Value> logpdfOperands;
1921 for (unsigned i = 1; i < sampledValues.size(); ++i) {
1922 logpdfOperands.push_back(sampledValues[i]);
1923 }
1924 for (unsigned i = 1; i < sampleOp.getNumOperands(); ++i) {
1925 logpdfOperands.push_back(sampleOp.getOperand(i));
1926 }
1927
1928 if (logpdfOperands.size() != logpdfFn.getNumArguments()) {
1929 sampleOp.emitError("Impulse: failed to construct logpdf call; "
1930 "logpdf function has wrong number of arguments");
1931 return WalkResult::interrupt();
1932 }
1933
1934 auto logpdf = func::CallOp::create(
1935 rewriter, sampleOp.getLoc(), logpdfFn.getName(),
1936 logpdfFn.getResultTypes(), logpdfOperands);
1937 weightAccumulator =
1938 arith::AddFOp::create(rewriter, sampleOp.getLoc(),
1939 weightAccumulator, logpdf.getResult(0));
1940
1941 bool inSelection = false;
1942 for (auto addr : selection) {
1943 auto address = cast<ArrayAttr>(addr);
1944 if (!address.empty() && address[0] == sampleOp.getSymbolAttr()) {
1945 inSelection = true;
1946 break;
1947 }
1948 }
1949
1950 if (inSelection) {
1951 for (unsigned i = 1; i < sampledValues.size(); ++i) {
1952 auto sampleValue = sampledValues[i];
1953 auto sampleType = cast<RankedTensorType>(sampleValue.getType());
1954 int64_t numElements = computeTensorElementCount(sampleType);
1955 if (numElements < 0) {
1956 sampleOp.emitError(
1957 "Impulse: dynamic tensor dimensions not supported");
1958 return WalkResult::interrupt();
1959 }
1960
1961 auto flatSampleType = RankedTensorType::get(
1962 {1, numElements}, sampleType.getElementType());
1963 auto flatSample = impulse::ReshapeOp::create(
1964 rewriter, sampleOp.getLoc(), flatSampleType, sampleValue);
1965 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
1966 auto row0 = arith::ConstantOp::create(
1967 rewriter, sampleOp.getLoc(), i64S,
1968 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
1969 auto colOff = arith::ConstantOp::create(
1970 rewriter, sampleOp.getLoc(), i64S,
1971 DenseElementsAttr::get(
1972 i64S, rewriter.getI64IntegerAttr(currentTraceOffset)));
1973 currTrace = impulse::DynamicUpdateSliceOp::create(
1974 rewriter, sampleOp.getLoc(), traceType, currTrace,
1975 flatSample, ValueRange{row0, colOff})
1976 .getResult();
1977 currentTraceOffset += numElements;
1978 }
1979 }
1980 } else {
1981 // B. Generative function: recursively regenerate the nested function
1982 auto genFn =
1983 cast<FunctionOpInterface>(symbolTable.lookupNearestSymbolFrom(
1984 sampleOp, sampleOp.getFnAttr()));
1985
1986 if (genFn.getFunctionBody().empty()) {
1987 sampleOp.emitError(
1988 "Impulse: generative function body is empty; "
1989 "if this is a distribution, add a logpdf attribute");
1990 return WalkResult::interrupt();
1991 }
1992
1993 ArrayAttr subSelection =
1994 buildSubSelection(rewriter, selection, sampleOp.getSymbolAttr());
1995 ArrayAttr subRegenerateAddrs =
1996 buildSubSelection(rewriter, CI.getRegenerateAddressesAttr(),
1997 sampleOp.getSymbolAttr());
1998
1999 if (subSelection.empty()) {
2000 auto genCall = func::CallOp::create(
2001 rewriter, sampleOp.getLoc(), genFn.getName(),
2002 genFn.getResultTypes(), sampleOp.getInputs());
2003 sampledValues.append(genCall.getResults().begin(),
2004 genCall.getResults().end());
2005 } else {
2006 int64_t subPositionSize = computePositionSizeForSelection(
2007 sampleOp, genFn, subSelection, symbolTable);
2008 if (subPositionSize <= 0) {
2009 sampleOp.emitError(
2010 "Impulse: failed to compute sub-position size");
2011 return WalkResult::interrupt();
2012 }
2013
2014 int64_t mergeOffset = computeOffsetForNestedSample(
2015 sampleOp, fn, selection, sampleOp.getSymbolAttr(), symbolTable);
2016 if (mergeOffset < 0) {
2017 sampleOp.emitError("Impulse: failed to compute merge offset");
2018 return WalkResult::interrupt();
2019 }
2020
2021 auto subTraceType = RankedTensorType::get({1, subPositionSize},
2022 rewriter.getF64Type());
2023 Value subPrevTrace = impulse::SliceOp::create(
2024 rewriter, sampleOp.getLoc(), subTraceType, prevTrace,
2025 rewriter.getDenseI64ArrayAttr({0, mergeOffset}),
2026 rewriter.getDenseI64ArrayAttr(
2027 {1, mergeOffset + subPositionSize}),
2028 rewriter.getDenseI64ArrayAttr({1, 1}));
2029
2030 // Build result types: (new_trace, weight, original_returns...)
2031 auto scalarTy = RankedTensorType::get({}, rewriter.getF64Type());
2032 SmallVector<Type> regenResultTypes;
2033 regenResultTypes.push_back(subTraceType);
2034 regenResultTypes.push_back(scalarTy);
2035 for (auto t : genFn.getResultTypes())
2036 regenResultTypes.push_back(t);
2037
2038 auto nestedRegenerate = impulse::RegenerateOp::create(
2039 rewriter, sampleOp.getLoc(), regenResultTypes,
2040 sampleOp.getFnAttr(), sampleOp.getInputs(), subPrevTrace,
2041 subSelection, subRegenerateAddrs);
2042
2043 Value subTrace = nestedRegenerate.getNewTrace();
2044 Value subWeight = nestedRegenerate.getWeight();
2045
2046 weightAccumulator = arith::AddFOp::create(
2047 rewriter, sampleOp.getLoc(), weightAccumulator, subWeight);
2048
2049 auto i64S = RankedTensorType::get({}, rewriter.getI64Type());
2050 auto row0 = arith::ConstantOp::create(
2051 rewriter, sampleOp.getLoc(), i64S,
2052 DenseElementsAttr::get(i64S, rewriter.getI64IntegerAttr(0)));
2053 auto colOff = arith::ConstantOp::create(
2054 rewriter, sampleOp.getLoc(), i64S,
2055 DenseElementsAttr::get(
2056 i64S, rewriter.getI64IntegerAttr(mergeOffset)));
2057 currTrace = impulse::DynamicUpdateSliceOp::create(
2058 rewriter, sampleOp.getLoc(), traceType, currTrace,
2059 subTrace, ValueRange{row0, colOff})
2060 .getResult();
2061 currentTraceOffset =
2062 std::max(currentTraceOffset, mergeOffset + subPositionSize);
2063
2064 for (auto output : nestedRegenerate.getOutputs())
2065 sampledValues.push_back(output);
2066 }
2067 }
2068
2069 sampleOp.replaceAllUsesWith(sampledValues);
2070 toErase.push_back(sampleOp);
2071 return WalkResult::advance();
2072 });
2073
2074 for (Operation *op : toErase)
2075 rewriter.eraseOp(op);
2076
2077 if (result.wasInterrupted()) {
2078 CI.emitError("Impulse: failed to walk sample ops");
2079 return failure();
2080 }
2081
2082 NewF.walk([&](func::ReturnOp retOp) {
2083 OpBuilder::InsertionGuard guard(rewriter);
2084 rewriter.setInsertionPoint(retOp);
2085
2086 SmallVector<Value> newRetVals;
2087 newRetVals.push_back(currTrace);
2088 newRetVals.push_back(weightAccumulator);
2089 newRetVals.append(retOp.getOperands().begin(),
2090 retOp.getOperands().end());
2091
2092 func::ReturnOp::create(rewriter, retOp.getLoc(), newRetVals);
2093 rewriter.eraseOp(retOp);
2094 });
2095
2096 rewriter.setInsertionPoint(CI);
2097 SmallVector<Value> operands;
2098 operands.push_back(CI.getOriginalTrace());
2099 operands.append(CI.getInputs().begin(), CI.getInputs().end());
2100 auto newCI = func::CallOp::create(rewriter, CI.getLoc(), NewF.getName(),
2101 NewF.getResultTypes(), operands);
2102
2103 rewriter.replaceOp(CI, newCI.getResults());
2104
2105 delete putils;
2106
2107 return success();
2108 }
2109 };
2110};
2111
2112} // end anonymous namespace
2113
2114void ExpandImpulsePass::runOnOperation() {
2115 RewritePatternSet patterns(&getContext());
2116 patterns.add<LowerUntracedCallPattern, LowerSimulatePattern,
2117 LowerGeneratePattern, LowerMHPattern, LowerRegeneratePattern>(
2118 &getContext());
2119 patterns.add<LowerMCMCPattern>(&getContext(), debugDump);
2120
2121 mlir::GreedyRewriteConfig config;
2122
2123 if (failed(
2124 applyPatternsGreedily(getOperation(), std::move(patterns), config))) {
2125 signalPassFailure();
2126 return;
2127 }
2128
2129 if (!postpasses.empty()) {
2130 mlir::PassManager pm(getOperation()->getContext());
2131
2132 if (mlir::failed(mlir::parsePassPipeline(postpasses, pm))) {
2133 getOperation()->emitError()
2134 << "Failed to parse expand-impulse post-passes pipeline: "
2135 << postpasses;
2136 signalPassFailure();
2137 return;
2138 }
2139
2140 if (mlir::failed(pm.run(getOperation()))) {
2141 signalPassFailure();
2142 return;
2143 }
2144 }
2145}
PointerType * traceType(LLVMContext &C)
static ImpulseUtils * CreateFromClone(FunctionOpInterface toeval, ImpulseMode mode, int64_t positionSize=-1, int64_t constraintSize=-1)
Value finalizeWelford(OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config)
Finalize Welford state to produce sample covariance (returned as inverse mass matrix).
SmallVector< AdaptWindow > buildAdaptationSchedule(int64_t numSteps)
Build warmup adaptation schedule.
Value conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump)
Conditionally dump a value for debugging.
Definition HMCUtils.cpp:64
MCMCKernelResult SampleNUTS(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump=false)
Single NUTS iteration: momentum sampling + tree building.
Definition HMCUtils.cpp:998
InitialHMCState InitHMC(OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition=Value(), bool debugDump=false)
Initializes HMC/NUTS state from a trace Specifically:
Definition HMCUtils.cpp:701
DualAveragingState updateDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config)
Update dual averaging state with observed acceptance probability.
DualAveragingState initDualAveraging(OpBuilder &builder, Location loc, Value stepSize)
Initialize dual averaging state from initial step size.
Value computeMassMatrixSqrt(OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType)
Computes the square root of the mass matrix from the inverse mass matrix.
Definition HMCUtils.cpp:190
WelfordState updateWelford(OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config)
Update Welford state with a new sample.
WelfordState initWelford(OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal)
Initialize state for Welford covariance estimation.
Value getStepSizeFromDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, bool final=false)
Get step size from dual averaging state.
MCMCKernelResult SampleHMC(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump=false)
Single HMC iteration: momentum sampling + leapfrog + MH accept/reject.
Definition HMCUtils.cpp:862
Value constrainPosition(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from unconstrained to constrained space.
SmallVector< Type > getTypes() const
Definition HMCUtils.h:85
static DualAveragingState fromValues(ArrayRef< Value > values)
Definition HMCUtils.h:82
SmallVector< Value > toValues() const
Definition HMCUtils.h:78
Result of one MCMC kernel step.
Definition HMCUtils.h:62
Configuration for Welford covariance estimation.
Definition HMCUtils.h:420
State for Welford covariance estimation.
Definition HMCUtils.h:405
SmallVector< Value > toValues() const
Definition HMCUtils.h:410
SmallVector< Type > getTypes() const
Definition HMCUtils.h:414
static WelfordState fromValues(ArrayRef< Value > values)
Definition HMCUtils.h:411