11#ifndef ENZYME_MLIR_INTERFACES_HMC_UTILS_H
12#define ENZYME_MLIR_INTERFACES_HMC_UTILS_H
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/Value.h"
83 return {values[0], values[1], values[2], values[3], values[4]};
146 return cast<RankedTensorType>(
stepSize.getType()).getElementType();
159 if (info.support && info.support.getKind() != impulse::SupportKind::REAL)
213 SmallVector<Value>
toValues()
const;
232 StringRef label,
bool debugDump);
239 Value momentum, RankedTensorType positionType);
243 Value invMass, RankedTensorType positionType);
247 RankedTensorType positionType);
251std::pair<Value, Value>
sampleMomentum(OpBuilder &builder, Location loc,
252 Value rng, Value invMass,
253 Value massMatrixSqrt,
254 RankedTensorType positionType,
255 bool debugDump =
false);
259 Value position, Value rng,
275Value
checkTurning(OpBuilder &builder, Location loc, Value pLeft, Value pRight,
281 Value currentWeight, Value newWeight);
287 Value currentWeight, Value newWeight,
288 Value turning, Value diverging);
319 const HMCContext &ctx, Value initialPosition = Value(),
320 bool debugDump =
false);
324 Value grad, Value U, Value rng,
325 const HMCContext &ctx,
bool debugDump =
false);
329 Value grad, Value U, Value rng,
345 Value direction, Value pCkpts,
348 bool debugDump =
false);
353 Value pCkpts, Value pSumCkpts,
363 Location loc, Value leafIdx);
367 Value pSum, Value pCkpts, Value pSumCkpts,
368 Value idxMin, Value idxMax,
const NUTSContext &ctx,
369 bool debugDump =
false);
373 Value leafIdx, Value ckptIdxMax,
374 Value p, Value pSum, Value pCkpts,
377 bool debugDump =
false);
412 return {values[0], values[1], values[2]};
415 return {
mean.getType(),
m2.getType(),
n.getType()};
451 ArrayRef<SupportInfo> supports);
456 ArrayRef<SupportInfo> supports);
462 ArrayRef<SupportInfo> supports);
PointerType * traceType(LLVMContext &C)
Value computeTotalJacobianCorrection(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Compute total Jacobian correction for the constrain transform over all position vector slices.
Value finalizeWelford(OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config)
Finalize Welford state to produce sample covariance (returned as inverse mass matrix).
Value checkTurning(OpBuilder &builder, Location loc, Value pLeft, Value pRight, Value pSum, const NUTSContext &ctx)
U-turn termination criterion.
GradientResult computePotentialAndGradient(OpBuilder &builder, Location loc, Value position, Value rng, const HMCContext &ctx)
Computes potential energy U(q) = -log p(q) and its gradient dU/dq.
std::pair< Value, Value > leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx)
Computes checkpoint indices from leaf index for iterative turning check.
IntegratorState getLeafFromTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, const NUTSContext &ctx)
Extracts an appropriate leaf based on direction.
std::pair< Value, Value > updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, Value ckptIdxMax, Value p, Value pSum, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Update checkpoint arrays at even leaf indices.
SubtreeBuildResult buildIterativeSubtree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Builds a subtree iteratively by appending leaves one at a time.
std::pair< Value, Value > sampleMomentum(OpBuilder &builder, Location loc, Value rng, Value invMass, Value massMatrixSqrt, RankedTensorType positionType, bool debugDump=false)
Samples momentum from N(0, M) where M is the mass matrix.
Value computeUniformTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight)
Computes the uniform transition probability for subtree combination.
SmallVector< AdaptWindow > buildAdaptationSchedule(int64_t numSteps)
Build warmup adaptation schedule.
NUTSTreeState buildBaseTree(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const NUTSContext &ctx)
Builds a base tree (leaf node) by taking one leapfrog step.
Value computeKineticEnergy(OpBuilder &builder, Location loc, Value momentum, Value invMass, RankedTensorType positionType)
Computes K = 0.5 * p^T @ M^-1 @ p.
NUTSTreeState combineTrees(OpBuilder &builder, Location loc, const NUTSTreeState &tree, const NUTSTreeState &subTree, Value direction, Value rng, bool biased, const NUTSContext &ctx)
Combines a tree with a newly-built subtree during NUTS doubling process.
Value computeBiasedTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight, Value turning, Value diverging)
Computes the biased transition probability for main tree combination.
Value checkIterativeTurning(OpBuilder &builder, Location loc, Value p, Value pSum, Value pCkpts, Value pSumCkpts, Value idxMin, Value idxMax, const NUTSContext &ctx, bool debugDump=false)
Checkpoint-based iterative turning check.
NUTSTreeState buildTree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, const NUTSContext &ctx, bool debugDump=false)
Main NUTS tree building loop.
Value conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump)
Conditionally dump a value for debugging.
MCMCKernelResult SampleNUTS(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump=false)
Single NUTS iteration: momentum sampling + tree building.
InitialHMCState InitHMC(OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition=Value(), bool debugDump=false)
Initializes HMC/NUTS state from a trace. Specifically:
SubtreeBuildResult doubleTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Tree doubling by building a subtree of same depth and combining.
IntegrationResult computeIntegrationStep(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const HMCContext &ctx)
Computes a single leapfrog integration step.
DualAveragingState updateDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config)
Update dual averaging state with observed acceptance probability.
Value unconstrainPosition(OpBuilder &builder, Location loc, Value constrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from constrained to unconstrained space based on the support info...
DualAveragingState initDualAveraging(OpBuilder &builder, Location loc, Value stepSize)
Initialize dual averaging state from initial step size.
Value computeMassMatrixSqrt(OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType)
Computes the square root of the mass matrix from the inverse mass matrix.
WelfordState updateWelford(OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config)
Update Welford state with a new sample.
WelfordState initWelford(OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal)
Initialize state for Welford covariance estimation.
Value getStepSizeFromDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, bool final=false)
Get step size from dual averaging state.
MCMCKernelResult SampleHMC(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump=false)
Single HMC iteration: momentum sampling + leapfrog + MH accept/reject.
Value constrainPosition(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from unconstrained to constrained space.
Value applyInverseMassMatrix(OpBuilder &builder, Location loc, Value invMass, Value momentum, RankedTensorType positionType)
Computes v = M^-1 @ p. If invMass is nullptr, returns momentum unchanged (assumes identity). If invMass...
double target_accept_prob
SmallVector< Type > getTypes() const
static DualAveragingState fromValues(ArrayRef< Value > values)
SmallVector< Value > toValues() const
DictionaryAttr autodiffAttrs
bool hasConstrainedSupports() const
Type getElementType() const
HMCContext(FlatSymbolRefAttr logpdfFn, ArrayRef< Value > fnInputs, Value invMass, Value massMatrixSqrt, Value stepSize, Value trajectoryLength, int64_t positionSize, DictionaryAttr autodiffAttrs={})
HMCContext withStepSize(Value newStepSize) const
ArrayRef< Value > fnInputs
HMCContext(FlatSymbolRefAttr fn, ArrayRef< Value > fnInputs, ArrayRef< Type > fnResultTypes, Value originalTrace, ArrayAttr selection, ArrayAttr allAddresses, Value invMass, Value massMatrixSqrt, Value stepSize, Value trajectoryLength, int64_t positionSize, ArrayRef< SupportInfo > supports, DictionaryAttr autodiffAttrs={})
RankedTensorType getPositionType() const
int64_t getFullTraceSize() const
SmallVector< Type > fnResultTypes
SmallVector< SupportInfo > supports
bool hasCustomLogpdf() const
RankedTensorType getScalarType() const
FlatSymbolRefAttr logpdfFn
Result of one MCMC kernel step.
NUTSContext withH0(Value newH0) const
NUTSContext(FlatSymbolRefAttr fn, ArrayRef< Value > fnInputs, ArrayRef< Type > fnResultTypes, Value originalTrace, ArrayAttr selection, ArrayAttr allAddresses, Value invMass, Value massMatrixSqrt, Value stepSize, int64_t positionSize, ArrayRef< SupportInfo > supports, Value H0, Value maxDeltaEnergy, int64_t maxTreeDepth, DictionaryAttr autodiffAttrs={})
NUTSContext(FlatSymbolRefAttr logpdfFn, ArrayRef< Value > fnInputs, Value invMass, Value massMatrixSqrt, Value stepSize, int64_t positionSize, Value H0, Value maxDeltaEnergy, int64_t maxTreeDepth, DictionaryAttr autodiffAttrs={})
IntegratorState getRightLeaf() const
SmallVector< Value > toValues() const
IntegratorState getLeftLeaf() const
SmallVector< Type > getTypes() const
static NUTSTreeState fromValues(ArrayRef< Value > values)
SupportInfo(int64_t offset, int64_t traceOffset, int64_t size, impulse::SupportAttr support)
impulse::SupportAttr support
Configuration for Welford covariance estimation.
State for Welford covariance estimation.
SmallVector< Value > toValues() const
SmallVector< Type > getTypes() const
static WelfordState fromValues(ArrayRef< Value > values)