Enzyme main
Loading...
Searching...
No Matches
mlir::impulse Namespace Reference

Classes

struct  AdaptWindow
 
struct  DualAveragingConfig
 
struct  DualAveragingState
 
struct  GradientResult
 
struct  HMCContext
 
class  ImpulseUtils
 
struct  InitialHMCState
 
struct  IntegrationResult
 
struct  IntegratorState
 
struct  MCMCKernelResult
 Result of one MCMC kernel step. More...
 
struct  NUTSContext
 
struct  NUTSTreeState
 
struct  SubtreeBuildResult
 
struct  SupportInfo
 
struct  WelfordConfig
 Configuration for Welford covariance estimation. More...
 
struct  WelfordState
 State for Welford covariance estimation. More...
 

Enumerations

enum class  ImpulseMode { Call = 0 , Simulate = 1 , Generate = 2 , Regenerate = 3 }
 

Functions

Value conditionalDump (OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump)
 Conditionally dump a value for debugging.
 
Value applyInverseMassMatrix (OpBuilder &builder, Location loc, Value invMass, Value momentum, RankedTensorType positionType)
 Computes v = M^-1 @ p. If invMass is nullptr, returns momentum unchanged (assumes identity). If invMass is a diagonal matrix, computes v = invMass * momentum. If invMass is a dense matrix, computes v = invMass @ momentum.
 
Value computeKineticEnergy (OpBuilder &builder, Location loc, Value momentum, Value invMass, RankedTensorType positionType)
 Computes K = 0.5 * p^T @ M^-1 @ p
 
Value computeMassMatrixSqrt (OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType)
 Computes the square root of the mass matrix from the inverse mass matrix.
 
std::pair< Value, Value > sampleMomentum (OpBuilder &builder, Location loc, Value rng, Value invMass, Value massMatrixSqrt, RankedTensorType positionType, bool debugDump=false)
 Samples momentum from N(0, M) where M is the mass matrix.
 
GradientResult computePotentialAndGradient (OpBuilder &builder, Location loc, Value position, Value rng, const HMCContext &ctx)
 Computes potential energy U(q) = -log p(q) and its gradient dU/dq
 
IntegrationResult computeIntegrationStep (OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const HMCContext &ctx)
 Computes a single leapfrog integration step.
 
Value checkTurning (OpBuilder &builder, Location loc, Value pLeft, Value pRight, Value pSum, const NUTSContext &ctx)
 U-turn termination criterion.
 
Value computeUniformTransitionProb (OpBuilder &builder, Location loc, Value currentWeight, Value newWeight)
 Computes the uniform transition probability for subtree combination.
 
Value computeBiasedTransitionProb (OpBuilder &builder, Location loc, Value currentWeight, Value newWeight, Value turning, Value diverging)
 Computes the biased transition probability for main tree combination.
 
NUTSTreeState combineTrees (OpBuilder &builder, Location loc, const NUTSTreeState &tree, const NUTSTreeState &subTree, Value direction, Value rng, bool biased, const NUTSContext &ctx)
 Combines a tree with a newly-built subtree during NUTS doubling process.
 
InitialHMCState InitHMC (OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition=Value(), bool debugDump=false)
 Initializes HMC/NUTS state from a trace.
 
MCMCKernelResult SampleHMC (OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump=false)
 Single HMC iteration: momentum sampling + leapfrog + MH accept/reject.
 
MCMCKernelResult SampleNUTS (OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump=false)
 Single NUTS iteration: momentum sampling + tree building.
 
NUTSTreeState buildBaseTree (OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const NUTSContext &ctx)
 Builds a base tree (leaf node) by taking one leapfrog step.
 
IntegratorState getLeafFromTree (OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, const NUTSContext &ctx)
 Extracts an appropriate leaf based on direction.
 
SubtreeBuildResult buildIterativeSubtree (OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
 Builds a subtree iteratively by appending leaves one at a time.
 
SubtreeBuildResult doubleTree (OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
 Tree doubling by building a subtree of same depth and combining.
 
NUTSTreeState buildTree (OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, const NUTSContext &ctx, bool debugDump=false)
 Main NUTS tree building loop.
 
std::pair< Value, Value > leafIdxToCheckpointIdxs (OpBuilder &builder, Location loc, Value leafIdx)
 Computes checkpoint indices from leaf index for iterative turning check.
 
Value checkIterativeTurning (OpBuilder &builder, Location loc, Value p, Value pSum, Value pCkpts, Value pSumCkpts, Value idxMin, Value idxMax, const NUTSContext &ctx, bool debugDump=false)
 Checkpoint-based iterative turning check.
 
std::pair< Value, Value > updateCheckpoints (OpBuilder &builder, Location loc, Value leafIdx, Value ckptIdxMax, Value p, Value pSum, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
 Update checkpoint arrays at even leaf indices.
 
DualAveragingState initDualAveraging (OpBuilder &builder, Location loc, Value stepSize)
 Initialize dual averaging state from initial step size.
 
DualAveragingState updateDualAveraging (OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config)
 Update dual averaging state with observed acceptance probability.
 
Value getStepSizeFromDualAveraging (OpBuilder &builder, Location loc, const DualAveragingState &state, bool final=false)
 Get step size from dual averaging state.
 
WelfordState initWelford (OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal)
 Initialize state for Welford covariance estimation.
 
WelfordState updateWelford (OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config)
 Update Welford state with a new sample.
 
Value finalizeWelford (OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config)
 Finalize Welford state to produce sample covariance (returned as inverse mass matrix).
 
SmallVector< AdaptWindow > buildAdaptationSchedule (int64_t numSteps)
 Build warmup adaptation schedule.
 
Value unconstrainPosition (OpBuilder &builder, Location loc, Value constrained, ArrayRef< SupportInfo > supports)
 Transform an entire position vector from constrained to unconstrained space based on the support information.
 
Value constrainPosition (OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
 Transform an entire position vector from unconstrained to constrained space.
 
Value computeTotalJacobianCorrection (OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
 Compute total Jacobian correction for the constrain transform over all position vector slices.
 

Enumeration Type Documentation

◆ ImpulseMode

enum class mlir::impulse::ImpulseMode
strong
Enumerator
Call 
Simulate 
Generate 
Regenerate 

Definition at line 24 of file ImpulseUtils.h.

Function Documentation

◆ applyInverseMassMatrix()

Value mlir::impulse::applyInverseMassMatrix ( OpBuilder & builder,
Location loc,
Value invMass,
Value momentum,
RankedTensorType positionType )

Computes v = M^-1 @ p. If invMass is nullptr, returns momentum unchanged (assumes identity). If invMass is a diagonal matrix, computes v = invMass * momentum. If invMass is a dense matrix, computes v = invMass @ momentum.

Definition at line 139 of file HMCUtils.cpp.

Referenced by checkTurning(), computeIntegrationStep(), and computeKineticEnergy().

◆ buildAdaptationSchedule()

SmallVector< AdaptWindow > mlir::impulse::buildAdaptationSchedule ( int64_t numSteps)

Build warmup adaptation schedule.

TODO: Make customizable

Definition at line 1904 of file HMCUtils.cpp.

◆ buildBaseTree()

◆ buildIterativeSubtree()

◆ buildTree()

◆ checkIterativeTurning()

Value mlir::impulse::checkIterativeTurning ( OpBuilder & builder,
Location loc,
Value p,
Value pSum,
Value pCkpts,
Value pSumCkpts,
Value idxMin,
Value idxMax,
const NUTSContext & ctx,
bool debugDump = false )

Checkpoint-based iterative turning check.

Definition at line 1505 of file HMCUtils.cpp.

References checkTurning(), mlir::impulse::HMCContext::getPositionType(), and mlir::impulse::HMCContext::positionSize.

Referenced by buildIterativeSubtree().

◆ checkTurning()

Value mlir::impulse::checkTurning ( OpBuilder & builder,
Location loc,
Value pLeft,
Value pRight,
Value pSum,
const NUTSContext & ctx )

◆ combineTrees()

NUTSTreeState mlir::impulse::combineTrees ( OpBuilder & builder,
Location loc,
const NUTSTreeState & tree,
const NUTSTreeState & subTree,
Value direction,
Value rng,
bool biased,
const NUTSContext & ctx )

Combines a tree with a newly-built subtree during NUTS doubling process.

The subtree is merged into the tree based on direction:

  • If going right (direction is true), keep the tree's left boundary and take the subtree's right boundary.
  • If going left (direction is false), take the subtree's left boundary and keep the tree's right boundary.

Proposal selection has two options:

  • Biased kernel (biased is true): prob = min(1, exp(subtree.weight - tree.weight)), zeroed if turning or diverging
  • Uniform kernel (biased is false): prob = sigmoid(subtree.weight - tree.weight)

Also checks the U-turn criterion on the combined momentum sum.

Definition at line 590 of file HMCUtils.cpp.

References checkTurning(), computeBiasedTransitionProb(), computeUniformTransitionProb(), mlir::impulse::NUTSTreeState::depth, mlir::impulse::NUTSTreeState::diverging, mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), mlir::impulse::NUTSTreeState::grad_left, mlir::impulse::NUTSTreeState::grad_proposal, mlir::impulse::NUTSTreeState::grad_right, mlir::impulse::NUTSTreeState::H_proposal, mlir::impulse::NUTSTreeState::num_proposals, mlir::impulse::NUTSTreeState::p_left, mlir::impulse::NUTSTreeState::p_right, mlir::impulse::NUTSTreeState::p_sum, mlir::impulse::NUTSTreeState::q_left, mlir::impulse::NUTSTreeState::q_proposal, mlir::impulse::NUTSTreeState::q_right, mlir::impulse::NUTSTreeState::sum_accept_probs, mlir::impulse::NUTSTreeState::turning, mlir::impulse::NUTSTreeState::U_proposal, and mlir::impulse::NUTSTreeState::weight.

Referenced by buildIterativeSubtree(), and doubleTree().

◆ computeBiasedTransitionProb()

Value mlir::impulse::computeBiasedTransitionProb ( OpBuilder & builder,
Location loc,
Value currentWeight,
Value newWeight,
Value turning,
Value diverging )

Computes the biased transition probability for main tree combination.

Specifically: min(1, exp(new_weight - current_weight)), zeroed if turning or diverging.

Definition at line 566 of file HMCUtils.cpp.

Referenced by combineTrees().

◆ computeIntegrationStep()

IntegrationResult mlir::impulse::computeIntegrationStep ( OpBuilder & builder,
Location loc,
const IntegratorState & leaf,
Value rng,
Value direction,
const HMCContext & ctx )

Computes a single leapfrog integration step.

Specifically:

  • p_half = p - (eps/2) * grad
  • q_new = q + eps * M^-1 * p_half
  • grad_new = dU/dq(q_new)
  • p_new = p_half - (eps/2) * grad_new

Definition at line 460 of file HMCUtils.cpp.

References applyInverseMassMatrix(), computePotentialAndGradient(), mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), mlir::impulse::IntegratorState::grad, mlir::impulse::HMCContext::invMass, mlir::impulse::IntegratorState::p, mlir::impulse::IntegratorState::q, and mlir::impulse::HMCContext::stepSize.

Referenced by buildBaseTree(), and SampleHMC().

◆ computeKineticEnergy()

Value mlir::impulse::computeKineticEnergy ( OpBuilder & builder,
Location loc,
Value momentum,
Value invMass,
RankedTensorType positionType )

Computes K = 0.5 * p^T @ M^-1 @ p

Definition at line 166 of file HMCUtils.cpp.

References applyInverseMassMatrix().

Referenced by buildBaseTree(), SampleHMC(), and SampleNUTS().

◆ computeMassMatrixSqrt()

Value mlir::impulse::computeMassMatrixSqrt ( OpBuilder & builder,
Location loc,
Value invMass,
RankedTensorType positionType )

Computes the square root of the mass matrix from the inverse mass matrix.

Definition at line 190 of file HMCUtils.cpp.

References createIdentityMatrix(), and reverseRowsAndColumns().

◆ computePotentialAndGradient()

◆ computeTotalJacobianCorrection()

Value mlir::impulse::computeTotalJacobianCorrection ( OpBuilder & builder,
Location loc,
Value unconstrained,
ArrayRef< SupportInfo > supports )

Compute total Jacobian correction for the constrain transform over all position vector slices.

Definition at line 2134 of file HMCUtils.cpp.

References mlir::enzyme::transforms::logAbsDetJacobian().

Referenced by computePotentialAndGradient(), and InitHMC().

◆ computeUniformTransitionProb()

Value mlir::impulse::computeUniformTransitionProb ( OpBuilder & builder,
Location loc,
Value currentWeight,
Value newWeight )

Computes the uniform transition probability for subtree combination.

Specifically: sigmoid(new_weight - current_weight).

Definition at line 557 of file HMCUtils.cpp.

Referenced by combineTrees().

◆ conditionalDump()

Value mlir::impulse::conditionalDump ( OpBuilder & builder,
Location loc,
Value value,
StringRef label,
bool debugDump )

Conditionally dump a value for debugging.

Emits an enzyme::DumpOp if debugDump is true; otherwise has no effect.

Definition at line 64 of file HMCUtils.cpp.

Referenced by sampleMomentum().

◆ constrainPosition()

Value mlir::impulse::constrainPosition ( OpBuilder & builder,
Location loc,
Value unconstrained,
ArrayRef< SupportInfo > supports )

Transform an entire position vector from unconstrained to constrained space based on the support information.

Definition at line 2044 of file HMCUtils.cpp.

References mlir::enzyme::transforms::constrain().

Referenced by computePotentialAndGradient(), and InitHMC().

◆ doubleTree()

SubtreeBuildResult mlir::impulse::doubleTree ( OpBuilder & builder,
Location loc,
const NUTSTreeState & tree,
Value direction,
Value pCkpts,
Value pSumCkpts,
const NUTSContext & ctx,
bool debugDump = false )

Tree doubling by building a subtree of same depth and combining.

Definition at line 1331 of file HMCUtils.cpp.

References buildIterativeSubtree(), combineTrees(), and mlir::impulse::NUTSTreeState::rng.

Referenced by buildTree().

◆ finalizeWelford()

Value mlir::impulse::finalizeWelford ( OpBuilder & builder,
Location loc,
const WelfordState & state,
const WelfordConfig & config )

Finalize Welford state to produce sample covariance (returned as inverse mass matrix).

Definition at line 1840 of file HMCUtils.cpp.

References createIdentityMatrix(), mlir::impulse::WelfordConfig::diagonal, mlir::impulse::WelfordState::m2, mlir::impulse::WelfordState::n, and mlir::impulse::WelfordConfig::regularize.

◆ getLeafFromTree()

IntegratorState mlir::impulse::getLeafFromTree ( OpBuilder & builder,
Location loc,
const NUTSTreeState & tree,
Value direction,
const NUTSContext & ctx )

◆ getStepSizeFromDualAveraging()

Value mlir::impulse::getStepSizeFromDualAveraging ( OpBuilder & builder,
Location loc,
const DualAveragingState & state,
bool final = false )

Get step size from dual averaging state.

If final is true, returns the averaged step size. Otherwise, returns the current (non-averaged) step size.

Definition at line 1749 of file HMCUtils.cpp.

References mlir::impulse::DualAveragingState::log_step_size, and mlir::impulse::DualAveragingState::log_step_size_avg.

◆ initDualAveraging()

DualAveragingState mlir::impulse::initDualAveraging ( OpBuilder & builder,
Location loc,
Value stepSize )

Initialize dual averaging state from initial step size.

Definition at line 1624 of file HMCUtils.cpp.

◆ InitHMC()

InitialHMCState mlir::impulse::InitHMC ( OpBuilder & builder,
Location loc,
Value rng,
const HMCContext & ctx,
Value initialPosition = Value(),
bool debugDump = false )

◆ initWelford()

WelfordState mlir::impulse::initWelford ( OpBuilder & builder,
Location loc,
int64_t positionSize,
bool diagonal )

Initialize state for Welford covariance estimation.

Definition at line 1756 of file HMCUtils.cpp.

◆ leafIdxToCheckpointIdxs()

std::pair< Value, Value > mlir::impulse::leafIdxToCheckpointIdxs ( OpBuilder & builder,
Location loc,
Value leafIdx )

Computes checkpoint indices from leaf index for iterative turning check.

Definition at line 1468 of file HMCUtils.cpp.

Referenced by buildIterativeSubtree().

◆ SampleHMC()

◆ sampleMomentum()

std::pair< Value, Value > mlir::impulse::sampleMomentum ( OpBuilder & builder,
Location loc,
Value rng,
Value invMass,
Value massMatrixSqrt,
RankedTensorType positionType,
bool debugDump = false )

Samples momentum from N(0, M) where M is the mass matrix.

Returns (momentum, updated_rng_state).

Definition at line 233 of file HMCUtils.cpp.

References conditionalDump().

Referenced by SampleHMC(), and SampleNUTS().

◆ SampleNUTS()

MCMCKernelResult mlir::impulse::SampleNUTS ( OpBuilder & builder,
Location loc,
Value q,
Value grad,
Value U,
Value rng,
const NUTSContext & ctx,
bool debugDump = false )

◆ unconstrainPosition()

Value mlir::impulse::unconstrainPosition ( OpBuilder & builder,
Location loc,
Value constrained,
ArrayRef< SupportInfo > supports )

Transform an entire position vector from constrained to unconstrained space based on the support information.

Definition at line 1955 of file HMCUtils.cpp.

References mlir::enzyme::transforms::unconstrain().

Referenced by InitHMC().

◆ updateCheckpoints()

std::pair< Value, Value > mlir::impulse::updateCheckpoints ( OpBuilder & builder,
Location loc,
Value leafIdx,
Value ckptIdxMax,
Value p,
Value pSum,
Value pCkpts,
Value pSumCkpts,
const NUTSContext & ctx,
bool debugDump = false )

Update checkpoint arrays at even leaf indices.

Definition at line 1575 of file HMCUtils.cpp.

Referenced by buildIterativeSubtree().

◆ updateDualAveraging()

◆ updateWelford()

WelfordState mlir::impulse::updateWelford ( OpBuilder & builder,
Location loc,
const WelfordState & state,
Value sample,
const WelfordConfig & config )