Enzyme main
Classes | |
| struct | AdaptWindow |
| struct | DualAveragingConfig |
| struct | DualAveragingState |
| struct | GradientResult |
| struct | HMCContext |
| class | ImpulseUtils |
| struct | InitialHMCState |
| struct | IntegrationResult |
| struct | IntegratorState |
| struct | MCMCKernelResult |
| Result of one MCMC kernel step. | |
| struct | NUTSContext |
| struct | NUTSTreeState |
| struct | SubtreeBuildResult |
| struct | SupportInfo |
| struct | WelfordConfig |
| Configuration for Welford covariance estimation. | |
| struct | WelfordState |
| State for Welford covariance estimation. | |
Enumerations | |
| enum class | ImpulseMode { Call = 0 , Simulate = 1 , Generate = 2 , Regenerate = 3 } |
Functions | |
| Value | conditionalDump (OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump) |
| Conditionally dump a value for debugging. | |
| Value | applyInverseMassMatrix (OpBuilder &builder, Location loc, Value invMass, Value momentum, RankedTensorType positionType) |
| Computes v = M^-1 @ p. If invMass is nullptr, returns momentum unchanged (assumes identity); if invMass is a diagonal matrix, computes v = invMass * momentum; if invMass is a dense matrix, computes v = invMass @ momentum. | |
| Value | computeKineticEnergy (OpBuilder &builder, Location loc, Value momentum, Value invMass, RankedTensorType positionType) |
| Computes K = 0.5 * p^T @ M^-1 @ p. | |
| Value | computeMassMatrixSqrt (OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType) |
| Computes the square root of the mass matrix from the inverse mass matrix. | |
| std::pair< Value, Value > | sampleMomentum (OpBuilder &builder, Location loc, Value rng, Value invMass, Value massMatrixSqrt, RankedTensorType positionType, bool debugDump=false) |
| Samples momentum from N(0, M) where M is the mass matrix. | |
| GradientResult | computePotentialAndGradient (OpBuilder &builder, Location loc, Value position, Value rng, const HMCContext &ctx) |
| Computes the potential energy U(q) = -log p(q) and its gradient dU/dq. | |
| IntegrationResult | computeIntegrationStep (OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const HMCContext &ctx) |
| Computes a single leapfrog integration step. | |
| Value | checkTurning (OpBuilder &builder, Location loc, Value pLeft, Value pRight, Value pSum, const NUTSContext &ctx) |
| U-turn termination criterion. | |
| Value | computeUniformTransitionProb (OpBuilder &builder, Location loc, Value currentWeight, Value newWeight) |
| Computes the uniform transition probability for subtree combination. | |
| Value | computeBiasedTransitionProb (OpBuilder &builder, Location loc, Value currentWeight, Value newWeight, Value turning, Value diverging) |
| Computes the biased transition probability for main tree combination. | |
| NUTSTreeState | combineTrees (OpBuilder &builder, Location loc, const NUTSTreeState &tree, const NUTSTreeState &subTree, Value direction, Value rng, bool biased, const NUTSContext &ctx) |
| Combines a tree with a newly-built subtree during NUTS doubling process. | |
| InitialHMCState | InitHMC (OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition=Value(), bool debugDump=false) |
| Initializes HMC/NUTS state from a trace. | |
| MCMCKernelResult | SampleHMC (OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump=false) |
| Single HMC iteration: momentum sampling + leapfrog + MH accept/reject. | |
| MCMCKernelResult | SampleNUTS (OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump=false) |
| Single NUTS iteration: momentum sampling + tree building. | |
| NUTSTreeState | buildBaseTree (OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const NUTSContext &ctx) |
| Builds a base tree (leaf node) by taking one leapfrog step. | |
| IntegratorState | getLeafFromTree (OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, const NUTSContext &ctx) |
| Extracts an appropriate leaf based on direction. | |
| SubtreeBuildResult | buildIterativeSubtree (OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false) |
| Builds a subtree iteratively by appending leaves one at a time. | |
| SubtreeBuildResult | doubleTree (OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false) |
| Tree doubling by building a subtree of same depth and combining. | |
| NUTSTreeState | buildTree (OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, const NUTSContext &ctx, bool debugDump=false) |
| Main NUTS tree building loop. | |
| std::pair< Value, Value > | leafIdxToCheckpointIdxs (OpBuilder &builder, Location loc, Value leafIdx) |
| Computes checkpoint indices from leaf index for iterative turning check. | |
| Value | checkIterativeTurning (OpBuilder &builder, Location loc, Value p, Value pSum, Value pCkpts, Value pSumCkpts, Value idxMin, Value idxMax, const NUTSContext &ctx, bool debugDump=false) |
| Checkpoint-based iterative turning check. | |
| std::pair< Value, Value > | updateCheckpoints (OpBuilder &builder, Location loc, Value leafIdx, Value ckptIdxMax, Value p, Value pSum, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false) |
| Update checkpoint arrays at even leaf indices. | |
| DualAveragingState | initDualAveraging (OpBuilder &builder, Location loc, Value stepSize) |
| Initialize dual averaging state from initial step size. | |
| DualAveragingState | updateDualAveraging (OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config) |
| Update dual averaging state with observed acceptance probability. | |
| Value | getStepSizeFromDualAveraging (OpBuilder &builder, Location loc, const DualAveragingState &state, bool final=false) |
| Get step size from dual averaging state. | |
| WelfordState | initWelford (OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal) |
| Initialize state for Welford covariance estimation. | |
| WelfordState | updateWelford (OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config) |
| Update Welford state with a new sample. | |
| Value | finalizeWelford (OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config) |
| Finalize Welford state to produce sample covariance (returned as inverse mass matrix). | |
| SmallVector< AdaptWindow > | buildAdaptationSchedule (int64_t numSteps) |
| Build warmup adaptation schedule. | |
| Value | unconstrainPosition (OpBuilder &builder, Location loc, Value constrained, ArrayRef< SupportInfo > supports) |
| Transform an entire position vector from constrained to unconstrained space based on the support information. | |
| Value | constrainPosition (OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports) |
| Transform an entire position vector from unconstrained to constrained space. | |
| Value | computeTotalJacobianCorrection (OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports) |
| Compute total Jacobian correction for the constrain transform over all position vector slices. | |
| enum class mlir::impulse::ImpulseMode |
| Enumerator | |
|---|---|
| Call | |
| Simulate | |
| Generate | |
| Regenerate | |
Definition at line 24 of file ImpulseUtils.h.
| Value mlir::impulse::applyInverseMassMatrix | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | invMass, | ||
| Value | momentum, | ||
| RankedTensorType | positionType ) |
Computes v = M^-1 @ p. If invMass is nullptr, returns momentum unchanged (assumes identity); if invMass is a diagonal matrix, computes v = invMass * momentum; if invMass is a dense matrix, computes v = invMass @ momentum.
Definition at line 139 of file HMCUtils.cpp.
Referenced by checkTurning(), computeIntegrationStep(), and computeKineticEnergy().
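The three cases can be illustrated outside MLIR with plain std::vector arithmetic. This is a minimal sketch, not the Impulse implementation: InvMass and applyInverseMass are hypothetical names, an Identity kind stands in for the nullptr case, a diagonal is stored as its n entries, and a dense matrix is assumed to be row-major.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the invMass operand (not an Impulse type).
struct InvMass {
  enum Kind { Identity, Diagonal, Dense } kind = Identity;
  std::vector<double> values; // empty, n diagonal entries, or n*n row-major
};

// v = M^-1 p for the three supported layouts.
std::vector<double> applyInverseMass(const InvMass &invMass,
                                     const std::vector<double> &p) {
  const std::size_t n = p.size();
  switch (invMass.kind) {
  case InvMass::Identity:
    return p; // identity: momentum unchanged
  case InvMass::Diagonal: {
    assert(invMass.values.size() == n);
    std::vector<double> v(n);
    for (std::size_t i = 0; i < n; ++i)
      v[i] = invMass.values[i] * p[i]; // element-wise product
    return v;
  }
  case InvMass::Dense: {
    assert(invMass.values.size() == n * n);
    std::vector<double> v(n, 0.0);
    for (std::size_t i = 0; i < n; ++i)
      for (std::size_t j = 0; j < n; ++j)
        v[i] += invMass.values[i * n + j] * p[j]; // matrix-vector product
    return v;
  }
  }
  return p; // unreachable
}
```

The diagonal case costs O(n) instead of O(n^2), which is why a diagonal inverse mass matrix is usually stored as a vector rather than a full matrix.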
| SmallVector< AdaptWindow > mlir::impulse::buildAdaptationSchedule | ( | int64_t | numSteps | ) |
Build warmup adaptation schedule.
TODO: Make customizable
Definition at line 1904 of file HMCUtils.cpp.
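The page gives no detail of the schedule beyond the TODO. For orientation, the sketch below reproduces the Stan-style windowed schedule as ported in NumPyro: a 75-step initial buffer, slow windows that start at 25 steps and double, and a 50-step terminal buffer, with proportional buffers for short runs. Whether buildAdaptationSchedule uses these constants is an assumption; Window and buildSchedule are hypothetical stand-ins for AdaptWindow and the real function.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct Window { int64_t start, end; }; // inclusive bounds (hypothetical)

std::vector<Window> buildSchedule(int64_t numSteps) {
  assert(numSteps > 0);
  std::vector<Window> schedule;
  if (numSteps < 20) { // too short for buffers: one window over all steps
    schedule.push_back({0, numSteps - 1});
    return schedule;
  }
  int64_t startBuffer = 75, endBuffer = 50, initWindow = 25;
  if (startBuffer + endBuffer + initWindow > numSteps) {
    startBuffer = static_cast<int64_t>(0.15 * numSteps);
    endBuffer = static_cast<int64_t>(0.1 * numSteps);
    initWindow = numSteps - startBuffer - endBuffer;
  }
  schedule.push_back({0, startBuffer - 1}); // fast (step-size only) buffer
  const int64_t endWindowStart = numSteps - endBuffer;
  int64_t nextStart = startBuffer, nextSize = initWindow;
  while (nextStart < endWindowStart) {
    const int64_t curStart = nextStart;
    int64_t curSize = nextSize;
    if (3 * curSize <= endWindowStart - curStart)
      nextSize = 2 * curSize; // enough room left: next window doubles
    else
      curSize = endWindowStart - curStart; // stretch to the end buffer
    nextStart = curStart + curSize;
    schedule.push_back({curStart, nextStart - 1});
  }
  schedule.push_back({endWindowStart, numSteps - 1}); // terminal buffer
  return schedule;
}
```

For numSteps = 1000 this yields windows of 75, 25, 50, 100, 200, 500 and 50 steps; the last slow window is stretched so that it ends exactly where the terminal buffer begins.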
| NUTSTreeState mlir::impulse::buildBaseTree | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const IntegratorState & | leaf, | ||
| Value | rng, | ||
| Value | direction, | ||
| const NUTSContext & | ctx ) |
Builds a base tree (leaf node) by taking one leapfrog step.
Definition at line 1088 of file HMCUtils.cpp.
References computeIntegrationStep(), computeKineticEnergy(), mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), mlir::impulse::IntegrationResult::grad, mlir::impulse::NUTSContext::H0, mlir::impulse::HMCContext::invMass, mlir::impulse::NUTSContext::maxDeltaEnergy, mlir::impulse::IntegrationResult::p, mlir::impulse::IntegrationResult::q, mlir::impulse::IntegrationResult::rng, and mlir::impulse::IntegrationResult::U.
Referenced by buildIterativeSubtree().
| SubtreeBuildResult mlir::impulse::buildIterativeSubtree | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const NUTSTreeState & | initialTree, | ||
| Value | direction, | ||
| Value | pCkpts, | ||
| Value | pSumCkpts, | ||
| const NUTSContext & | ctx, | ||
| bool | debugDump = false ) |
Builds a subtree iteratively by appending leaves one at a time.
Definition at line 1182 of file HMCUtils.cpp.
References buildBaseTree(), checkIterativeTurning(), combineTrees(), mlir::impulse::NUTSTreeState::depth, mlir::impulse::NUTSTreeState::diverging, mlir::impulse::NUTSTreeState::fromValues(), getLeafFromTree(), mlir::impulse::NUTSTreeState::getTypes(), leafIdxToCheckpointIdxs(), mlir::impulse::NUTSTreeState::num_proposals, mlir::impulse::NUTSTreeState::p_right, mlir::impulse::NUTSTreeState::p_sum, mlir::impulse::NUTSTreeState::rng, mlir::impulse::NUTSTreeState::toValues(), mlir::impulse::NUTSTreeState::turning, and updateCheckpoints().
Referenced by doubleTree().
| NUTSTreeState mlir::impulse::buildTree | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const NUTSTreeState & | initialTree, | ||
| const NUTSContext & | ctx, | ||
| bool | debugDump = false ) |
Main NUTS tree building loop.
Definition at line 1355 of file HMCUtils.cpp.
References mlir::impulse::NUTSTreeState::depth, mlir::impulse::NUTSTreeState::diverging, doubleTree(), mlir::impulse::NUTSTreeState::fromValues(), mlir::impulse::NUTSTreeState::getTypes(), mlir::impulse::NUTSContext::maxTreeDepth, mlir::impulse::HMCContext::positionSize, mlir::impulse::NUTSTreeState::rng, mlir::impulse::HMCContext::stepSize, mlir::impulse::NUTSTreeState::toValues(), and mlir::impulse::NUTSTreeState::turning.
Referenced by SampleNUTS().
| Value mlir::impulse::checkIterativeTurning | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | p, | ||
| Value | pSum, | ||
| Value | pCkpts, | ||
| Value | pSumCkpts, | ||
| Value | idxMin, | ||
| Value | idxMax, | ||
| const NUTSContext & | ctx, | ||
| bool | debugDump = false ) |
Checkpoint-based iterative turning check.
Definition at line 1505 of file HMCUtils.cpp.
References checkTurning(), mlir::impulse::HMCContext::getPositionType(), and mlir::impulse::HMCContext::positionSize.
Referenced by buildIterativeSubtree().
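The page does not spell out the checkpoint layout. The sketch below follows the iterative NUTS scheme used by NumPyro, which matches the parameter names here (pCkpts, pSumCkpts, idxMin, idxMax), but treating it as Impulse's exact algorithm is an assumption. It uses a diagonal inverse mass matrix, and isTurning and isIterativeTurning are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vec = std::vector<double>;

// Momentum-sum U-turn test (assumed criterion) with a diagonal M^-1:
// turning if either end velocity has a non-positive projection onto the
// momentum sum centered by the two end momenta.
bool isTurning(const Vec &invMassDiag, const Vec &pLeft, const Vec &pRight,
               const Vec &pSum) {
  double left = 0.0, right = 0.0;
  for (std::size_t i = 0; i < pSum.size(); ++i) {
    const double centered = pSum[i] - 0.5 * (pLeft[i] + pRight[i]);
    left += invMassDiag[i] * pLeft[i] * centered;
    right += invMassDiag[i] * pRight[i] * centered;
  }
  return left <= 0.0 || right <= 0.0;
}

// Checkpoint i holds the momentum of the left-most leaf of a candidate
// subtree (pCkpts[i]) and the running momentum sum right after that leaf
// was added (pSumCkpts[i]); the candidate's own sum is therefore
// pSum - pSumCkpts[i] + pCkpts[i]. Checks run from idxMax down to idxMin
// and stop at the first turning subtree; an empty range never turns.
bool isIterativeTurning(const Vec &invMassDiag, const Vec &p, const Vec &pSum,
                        const std::vector<Vec> &pCkpts,
                        const std::vector<Vec> &pSumCkpts, int64_t idxMin,
                        int64_t idxMax) {
  for (int64_t i = idxMax; i >= idxMin; --i) {
    const auto k = static_cast<std::size_t>(i);
    assert(k < pCkpts.size() && k < pSumCkpts.size());
    Vec subtreeSum(pSum.size());
    for (std::size_t d = 0; d < pSum.size(); ++d)
      subtreeSum[d] = pSum[d] - pSumCkpts[k][d] + pCkpts[k][d];
    if (isTurning(invMassDiag, pCkpts[k], p, subtreeSum))
      return true;
  }
  return false;
}
```

Storing only momenta and running sums at a logarithmic number of checkpoints lets the subtree be built leaf by leaf without recursion, which suits loop-based MLIR code generation.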
| Value mlir::impulse::checkTurning | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | pLeft, | ||
| Value | pRight, | ||
| Value | pSum, | ||
| const NUTSContext & | ctx ) |
U-turn termination criterion.
Returns true if a U-turn is detected in the trajectory.
Definition at line 509 of file HMCUtils.cpp.
References applyInverseMassMatrix(), mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), and mlir::impulse::HMCContext::invMass.
Referenced by checkIterativeTurning(), and combineTrees().
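The exact criterion is not documented on this page. A widely used form, the momentum-sum U-turn criterion as implemented in NumPyro, centers the momentum sum on the trajectory's end momenta and flags a U-turn when either end's velocity points against it. The sketch assumes that form and a diagonal inverse mass matrix; dot and isTurning are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

double dot(const Vec &a, const Vec &b) {
  assert(a.size() == b.size());
  double s = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i)
    s += a[i] * b[i];
  return s;
}

// Returns true if a U-turn is detected at either end of the trajectory.
bool isTurning(const Vec &invMassDiag, const Vec &pLeft, const Vec &pRight,
               const Vec &pSum) {
  const std::size_t n = pSum.size();
  Vec vLeft(n), vRight(n), centered(n);
  for (std::size_t i = 0; i < n; ++i) {
    vLeft[i] = invMassDiag[i] * pLeft[i];   // velocity M^-1 p at the left end
    vRight[i] = invMassDiag[i] * pRight[i]; // velocity M^-1 p at the right end
    centered[i] = pSum[i] - 0.5 * (pLeft[i] + pRight[i]);
  }
  return dot(vLeft, centered) <= 0.0 || dot(vRight, centered) <= 0.0;
}
```

A trajectory whose three leaves all carry momentum +1 is not turning, while one whose end momenta point in opposite directions with zero total momentum is.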
| NUTSTreeState mlir::impulse::combineTrees | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const NUTSTreeState & | tree, | ||
| const NUTSTreeState & | subTree, | ||
| Value | direction, | ||
| Value | rng, | ||
| bool | biased, | ||
| const NUTSContext & | ctx ) |
Combines a tree with a newly-built subtree during NUTS doubling process.
The subtree is merged into tree based on direction:
- Rightward (direction is true): keep the tree's left boundary and take the subtree's right boundary.
- Leftward (direction is false): take the subtree's left boundary and keep the tree's right boundary.
Proposal selection has two options:
- Biased (biased is true): prob = min(1, exp(subtree.weight - tree.weight)), zeroed if turning or diverging.
- Uniform (biased is false): prob = sigmoid(subtree.weight - tree.weight).
Also checks the U-turn criterion on the combined momentum sum.
Definition at line 590 of file HMCUtils.cpp.
References checkTurning(), computeBiasedTransitionProb(), computeUniformTransitionProb(), mlir::impulse::NUTSTreeState::depth, mlir::impulse::NUTSTreeState::diverging, mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), mlir::impulse::NUTSTreeState::grad_left, mlir::impulse::NUTSTreeState::grad_proposal, mlir::impulse::NUTSTreeState::grad_right, mlir::impulse::NUTSTreeState::H_proposal, mlir::impulse::NUTSTreeState::num_proposals, mlir::impulse::NUTSTreeState::p_left, mlir::impulse::NUTSTreeState::p_right, mlir::impulse::NUTSTreeState::p_sum, mlir::impulse::NUTSTreeState::q_left, mlir::impulse::NUTSTreeState::q_proposal, mlir::impulse::NUTSTreeState::q_right, mlir::impulse::NUTSTreeState::sum_accept_probs, mlir::impulse::NUTSTreeState::turning, mlir::impulse::NUTSTreeState::U_proposal, and mlir::impulse::NUTSTreeState::weight.
Referenced by buildIterativeSubtree(), and doubleTree().
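The boundary bookkeeping can be sketched on a one-dimensional stand-in for NUTSTreeState. Tree, logAddExp and mergeBoundaries are hypothetical names; summing the momentum sums and accumulating the tree weight as a log-sum-exp of the two log-weights follows NumPyro's NUTS and is an assumption about Impulse. Proposal selection and the U-turn check are omitted.

```cpp
#include <algorithm>
#include <cmath>

struct Tree {
  double qLeft, pLeft;   // left boundary leaf
  double qRight, pRight; // right boundary leaf
  double pSum;           // sum of momenta over all leaves
  double weight;         // log of the summed leaf weights
};

// Stable log(exp(a) + exp(b)).
double logAddExp(double a, double b) {
  const double m = std::max(a, b);
  if (std::isinf(m))
    return m; // both -inf, or one +inf
  return m + std::log(std::exp(a - m) + std::exp(b - m));
}

// direction == true: the subtree was built to the right of tree.
Tree mergeBoundaries(const Tree &tree, const Tree &subTree, bool direction) {
  Tree out = tree;
  if (direction) { // keep tree's left boundary, take subtree's right one
    out.qRight = subTree.qRight;
    out.pRight = subTree.pRight;
  } else {         // take subtree's left boundary, keep tree's right one
    out.qLeft = subTree.qLeft;
    out.pLeft = subTree.pLeft;
  }
  out.pSum = tree.pSum + subTree.pSum;
  out.weight = logAddExp(tree.weight, subTree.weight);
  return out;
}
```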
| Value mlir::impulse::computeBiasedTransitionProb | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | currentWeight, | ||
| Value | newWeight, | ||
| Value | turning, | ||
| Value | diverging ) |
Computes the biased transition probability for main tree combination.
Specifically: min(1, exp(new_weight - current_weight)), zeroed if turning or diverging.
Definition at line 566 of file HMCUtils.cpp.
Referenced by combineTrees().
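Because weights are stored as log-weights, the formula needs a single exponential. A minimal scalar sketch; biasedTransitionProb is a hypothetical stand-in that operates on doubles and bools rather than MLIR Values.

```cpp
#include <algorithm>
#include <cmath>

// min(1, exp(newWeight - currentWeight)), forced to 0 for a turning or
// diverging subtree. Weights are log-weights.
double biasedTransitionProb(double currentWeight, double newWeight,
                            bool turning, bool diverging) {
  if (turning || diverging)
    return 0.0; // never move into a turning or diverging subtree
  return std::min(1.0, std::exp(newWeight - currentWeight));
}
```

The probability saturates at 1 whenever the new subtree carries at least as much weight as the current tree, so the sampler is biased toward moving into the newer subtree.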
| IntegrationResult mlir::impulse::computeIntegrationStep | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const IntegratorState & | leaf, | ||
| Value | rng, | ||
| Value | direction, | ||
| const HMCContext & | ctx ) |
Computes a single leapfrog integration step.
Specifically:
- p_half = p - (eps/2) * grad
- q_new = q + eps * M^-1 * p_half
- grad_new = dU/dq(q_new)
- p_new = p_half - (eps/2) * grad_new
Definition at line 460 of file HMCUtils.cpp.
References applyInverseMassMatrix(), computePotentialAndGradient(), mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), mlir::impulse::IntegratorState::grad, mlir::impulse::HMCContext::invMass, mlir::impulse::IntegratorState::p, mlir::impulse::IntegratorState::q, and mlir::impulse::HMCContext::stepSize.
Referenced by buildBaseTree(), and SampleHMC().
| Value mlir::impulse::computeKineticEnergy | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | momentum, | ||
| Value | invMass, | ||
| RankedTensorType | positionType ) |
Computes K = 0.5 * p^T @ M^-1 @ p.
Definition at line 166 of file HMCUtils.cpp.
References applyInverseMassMatrix().
Referenced by buildBaseTree(), SampleHMC(), and SampleNUTS().
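With a diagonal inverse mass matrix the quadratic form reduces to a weighted sum of squares. A minimal sketch under that assumption (kineticEnergyDiag is a hypothetical name); the dense case would replace the element-wise product with a full matrix-vector product.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// K = 0.5 * sum_i p_i * invMass_i * p_i for a diagonal inverse mass matrix.
double kineticEnergyDiag(const std::vector<double> &p,
                         const std::vector<double> &invMassDiag) {
  assert(p.size() == invMassDiag.size());
  double k = 0.0;
  for (std::size_t i = 0; i < p.size(); ++i)
    k += p[i] * invMassDiag[i] * p[i]; // p_i * (M^-1 p)_i
  return 0.5 * k;
}
```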
| Value mlir::impulse::computeMassMatrixSqrt | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | invMass, | ||
| RankedTensorType | positionType ) |
Computes the square root of the mass matrix from the inverse mass matrix.
Definition at line 190 of file HMCUtils.cpp.
References createIdentityMatrix(), and reverseRowsAndColumns().
| GradientResult mlir::impulse::computePotentialAndGradient | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | position, | ||
| Value | rng, | ||
| const HMCContext & | ctx ) |
Computes the potential energy U(q) = -log p(q) and its gradient dU/dq.
Definition at line 352 of file HMCUtils.cpp.
References mlir::impulse::HMCContext::allAddresses, mlir::impulse::HMCContext::autodiffAttrs, computeTotalJacobianCorrection(), constrainPosition(), mlir::impulse::HMCContext::fn, mlir::impulse::HMCContext::fnInputs, mlir::impulse::HMCContext::fnResultTypes, mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getFullTraceSize(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), mlir::impulse::HMCContext::hasCustomLogpdf(), mlir::impulse::HMCContext::logpdfFn, mlir::impulse::HMCContext::originalTrace, mlir::impulse::HMCContext::positionSize, scatterPositionToTrace(), mlir::impulse::HMCContext::supports, and traceType().
Referenced by computeIntegrationStep().
| Value mlir::impulse::computeTotalJacobianCorrection | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | unconstrained, | ||
| ArrayRef< SupportInfo > | supports ) |
Compute total Jacobian correction for the constrain transform over all position vector slices.
Definition at line 2134 of file HMCUtils.cpp.
References mlir::enzyme::transforms::logAbsDetJacobian().
Referenced by computePotentialAndGradient(), and InitHMC().
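For an element constrained by x = f(u), the change of variables adds log|f'(u)| to the log-density, so a potential over unconstrained coordinates takes the form U(u) = -log p(f(u)) - sum_i log|f'(u_i)|. The sketch below assumes three common transforms (identity for real support, exp for positive support, logistic sigmoid for the unit interval) and works per element rather than per slice; Support, softplus, logAbsDetJac and totalJacobianCorrection are hypothetical names, not the mlir::enzyme::transforms API.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

enum class Support { Real, Positive, UnitInterval };

// Numerically stable log(1 + exp(x)).
double softplus(double x) {
  return std::max(x, 0.0) + std::log1p(std::exp(-std::fabs(x)));
}

// log|d constrain(u) / du| for one element under the assumed transforms:
//   Real: x = u -> 0;  Positive: x = exp(u) -> u;
//   UnitInterval: x = sigmoid(u) -> log sigmoid(u) + log(1 - sigmoid(u)).
double logAbsDetJac(Support s, double u) {
  switch (s) {
  case Support::Real:         return 0.0;
  case Support::Positive:     return u;
  case Support::UnitInterval: return -softplus(-u) - softplus(u);
  }
  return 0.0;
}

// Sum of per-element corrections (assumed usage: subtracted from -log p).
double totalJacobianCorrection(const std::vector<double> &u,
                               const std::vector<Support> &supports) {
  assert(u.size() == supports.size());
  double total = 0.0;
  for (std::size_t i = 0; i < u.size(); ++i)
    total += logAbsDetJac(supports[i], u[i]);
  return total;
}
```

At u = 0 the sigmoid has slope 0.25, so a unit-interval element contributes log 0.25 = -2 log 2.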
| Value mlir::impulse::computeUniformTransitionProb | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | currentWeight, | ||
| Value | newWeight ) |
Computes the uniform transition probability for subtree combination.
Specifically: sigmoid(new_weight - current_weight).
Definition at line 557 of file HMCUtils.cpp.
Referenced by combineTrees().
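With log-weights a = current_weight and b = new_weight, sigmoid(b - a) = 1 / (1 + exp(a - b)) = exp(b) / (exp(a) + exp(b)), so the new subtree is chosen with probability equal to its share of the combined weight. A scalar sketch (uniformTransitionProb is a hypothetical name operating on doubles):

```cpp
#include <cmath>

// sigmoid(newWeight - currentWeight), written as 1 / (1 + exp(current - new)).
double uniformTransitionProb(double currentWeight, double newWeight) {
  return 1.0 / (1.0 + std::exp(currentWeight - newWeight));
}
```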
| Value mlir::impulse::conditionalDump | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | value, | ||
| StringRef | label, | ||
| bool | debugDump ) |
Conditionally dump a value for debugging.
Emits an enzyme::DumpOp if debugDump is true; otherwise has no effect.
Definition at line 64 of file HMCUtils.cpp.
Referenced by sampleMomentum().
| Value mlir::impulse::constrainPosition | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | unconstrained, | ||
| ArrayRef< SupportInfo > | supports ) |
Transform an entire position vector from unconstrained to constrained space based on the support information.
Definition at line 2044 of file HMCUtils.cpp.
References mlir::enzyme::transforms::constrain().
Referenced by computePotentialAndGradient(), and InitHMC().
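Assuming element-wise transforms (exp/log for positive support, sigmoid/logit for the unit interval), constraining and unconstraining are exact inverses, which a round trip makes easy to check. Support, constrainScalar, unconstrainScalar and the two vector helpers are hypothetical; the real functions delegate to mlir::enzyme::transforms::constrain() and unconstrain() per slice according to SupportInfo.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

enum class Support { Real, Positive, UnitInterval };

// Unconstrained -> constrained (assumed transform per support).
double constrainScalar(Support s, double u) {
  switch (s) {
  case Support::Real:         return u;
  case Support::Positive:     return std::exp(u);
  case Support::UnitInterval: return 1.0 / (1.0 + std::exp(-u)); // sigmoid
  }
  return u;
}

// Constrained -> unconstrained: the inverse of constrainScalar.
double unconstrainScalar(Support s, double x) {
  switch (s) {
  case Support::Real:
    return x;
  case Support::Positive:
    assert(x > 0.0);
    return std::log(x);
  case Support::UnitInterval:
    assert(x > 0.0 && x < 1.0);
    return std::log(x / (1.0 - x)); // logit
  }
  return x;
}

std::vector<double> constrainPositionSketch(const std::vector<double> &u,
                                            const std::vector<Support> &sup) {
  assert(u.size() == sup.size());
  std::vector<double> x(u.size());
  for (std::size_t i = 0; i < u.size(); ++i)
    x[i] = constrainScalar(sup[i], u[i]);
  return x;
}

std::vector<double> unconstrainPositionSketch(const std::vector<double> &x,
                                              const std::vector<Support> &sup) {
  assert(x.size() == sup.size());
  std::vector<double> u(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    u[i] = unconstrainScalar(sup[i], x[i]);
  return u;
}
```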
| SubtreeBuildResult mlir::impulse::doubleTree | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const NUTSTreeState & | tree, | ||
| Value | direction, | ||
| Value | pCkpts, | ||
| Value | pSumCkpts, | ||
| const NUTSContext & | ctx, | ||
| bool | debugDump = false ) |
Tree doubling by building a subtree of same depth and combining.
Definition at line 1331 of file HMCUtils.cpp.
References buildIterativeSubtree(), combineTrees(), and mlir::impulse::NUTSTreeState::rng.
Referenced by buildTree().
| Value mlir::impulse::finalizeWelford | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const WelfordState & | state, | ||
| const WelfordConfig & | config ) |
Finalize Welford state to produce sample covariance (returned as inverse mass matrix).
Definition at line 1840 of file HMCUtils.cpp.
References createIdentityMatrix(), mlir::impulse::WelfordConfig::diagonal, mlir::impulse::WelfordState::m2, mlir::impulse::WelfordState::n, and mlir::impulse::WelfordConfig::regularize.
| IntegratorState mlir::impulse::getLeafFromTree | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const NUTSTreeState & | tree, | ||
| Value | direction, | ||
| const NUTSContext & | ctx ) |
Extracts an appropriate leaf based on direction.
Definition at line 1166 of file HMCUtils.cpp.
References mlir::impulse::HMCContext::getPositionType(), mlir::impulse::NUTSTreeState::grad_left, mlir::impulse::NUTSTreeState::grad_right, mlir::impulse::NUTSTreeState::p_left, mlir::impulse::NUTSTreeState::p_right, mlir::impulse::NUTSTreeState::q_left, and mlir::impulse::NUTSTreeState::q_right.
Referenced by buildIterativeSubtree().
| Value mlir::impulse::getStepSizeFromDualAveraging | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const DualAveragingState & | state, | ||
| bool | final = false ) |
Get step size from dual averaging state.
If final is true, returns the averaged step size. Otherwise, returns the updated step size.
Definition at line 1749 of file HMCUtils.cpp.
References mlir::impulse::DualAveragingState::log_step_size, and mlir::impulse::DualAveragingState::log_step_size_avg.
| DualAveragingState mlir::impulse::initDualAveraging | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | stepSize ) |
Initialize dual averaging state from initial step size.
Definition at line 1624 of file HMCUtils.cpp.
| InitialHMCState mlir::impulse::InitHMC | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | rng, | ||
| const HMCContext & | ctx, | ||
| Value | initialPosition = Value(), | ||
| bool | debugDump = false ) |
Initializes HMC/NUTS state from a trace.
Definition at line 701 of file HMCUtils.cpp.
References mlir::impulse::HMCContext::allAddresses, mlir::impulse::HMCContext::autodiffAttrs, computeTotalJacobianCorrection(), constrainPosition(), mlir::impulse::HMCContext::fn, mlir::impulse::HMCContext::fnInputs, mlir::impulse::HMCContext::fnResultTypes, gatherPositionFromTrace(), mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getFullTraceSize(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), mlir::impulse::HMCContext::hasCustomLogpdf(), mlir::impulse::HMCContext::logpdfFn, mlir::impulse::HMCContext::originalTrace, mlir::impulse::HMCContext::positionSize, scatterPositionToTrace(), mlir::impulse::HMCContext::supports, and unconstrainPosition().
| WelfordState mlir::impulse::initWelford | ( | OpBuilder & | builder, |
| Location | loc, | ||
| int64_t | positionSize, | ||
| bool | diagonal ) |
Initialize state for Welford covariance estimation.
Definition at line 1756 of file HMCUtils.cpp.
| std::pair< Value, Value > mlir::impulse::leafIdxToCheckpointIdxs | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | leafIdx ) |
Computes checkpoint indices from leaf index for iterative turning check.
Definition at line 1468 of file HMCUtils.cpp.
Referenced by buildIterativeSubtree().
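One scheme consistent with the surrounding functions (checkpoints written only at even leaves, and a range [idxMin, idxMax] scanned by checkIterativeTurning) is the bit-counting rule from NumPyro's iterative NUTS; that Impulse uses the same rule is an assumption. leafIdxToCkptIdxs is a hypothetical name, and leaf indices are 0-based within the subtree being built.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Returns (idxMin, idxMax): the checkpoints leaf n must be checked against.
std::pair<int64_t, int64_t> leafIdxToCkptIdxs(int64_t n) {
  assert(n >= 0);
  // idxMax: number of set bits in n, ignoring the lowest bit.
  int64_t idxMax = 0;
  for (int64_t m = n >> 1; m > 0; m >>= 1)
    idxMax += m & 1;
  // numSubtrees: length of the run of trailing 1 bits in n.
  int64_t numSubtrees = 0;
  for (int64_t m = n; (m & 1) != 0; m >>= 1)
    ++numSubtrees;
  return {idxMax - numSubtrees + 1, idxMax};
}
```

For example, leaf 7 (binary 111) completes balanced subtrees of 2, 4 and 8 leaves, so it is checked against checkpoints 0 through 2, while an even leaf such as 6 yields idxMin > idxMax and triggers no check.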
| MCMCKernelResult mlir::impulse::SampleHMC | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | q, | ||
| Value | grad, | ||
| Value | U, | ||
| Value | rng, | ||
| const HMCContext & | ctx, | ||
| bool | debugDump = false ) |
Single HMC iteration: momentum sampling + leapfrog + MH accept/reject.
Definition at line 862 of file HMCUtils.cpp.
References computeIntegrationStep(), computeKineticEnergy(), mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), mlir::impulse::HMCContext::hasCustomLogpdf(), mlir::impulse::HMCContext::invMass, mlir::impulse::HMCContext::massMatrixSqrt, sampleMomentum(), mlir::impulse::HMCContext::stepSize, mlir::impulse::HMCContext::trajectoryLength, and mlir::impulse::HMCContext::withStepSize().
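The accept/reject step compares the total energy H = U + K before and after the trajectory. A minimal scalar sketch: acceptProb and acceptProposal are hypothetical names, std::mt19937_64 stands in for the RNG state threaded through the MLIR code, and rejecting NaN energies (as NumPyro does) is an assumption about Impulse.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <random>

// Probability of accepting a proposal with energy H1 from a state with H0.
double acceptProb(double H0, double H1) {
  const double delta = H1 - H0;
  if (std::isnan(delta))
    return 0.0; // assumption: NaN energies are rejected
  return std::min(1.0, std::exp(-delta));
}

// Draws u ~ Uniform[0, 1) and accepts when u < acceptProb(H0, H1).
bool acceptProposal(double H0, double H1, std::mt19937_64 &rng) {
  const double a = acceptProb(H0, H1);
  assert(a >= 0.0 && a <= 1.0);
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  return unif(rng) < a;
}
```

A proposal that lowers the energy is always accepted; one that raises it by log 2 is accepted half the time.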
| std::pair< Value, Value > mlir::impulse::sampleMomentum | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | rng, | ||
| Value | invMass, | ||
| Value | massMatrixSqrt, | ||
| RankedTensorType | positionType, | ||
| bool | debugDump = false ) |
Samples momentum from N(0, M) where M is the mass matrix.
Returns (momentum, updated_rng_state).
Definition at line 233 of file HMCUtils.cpp.
References conditionalDump().
Referenced by SampleHMC(), and SampleNUTS().
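For a diagonal mass matrix, sqrt(M) has entries 1 / sqrt(invMass_i), and scaling a standard normal draw by it yields p ~ N(0, M). The sketch covers only that diagonal case (Impulse also accepts a dense massMatrixSqrt); massMatrixSqrtDiag and sampleMomentumDiag are hypothetical names, and returning the engine by value mimics the (momentum, updated_rng_state) pair.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

using Vec = std::vector<double>;

// Diagonal case: sqrt(M)_i = 1 / sqrt(invMass_i).
Vec massMatrixSqrtDiag(const Vec &invMassDiag) {
  Vec s(invMassDiag.size());
  for (std::size_t i = 0; i < s.size(); ++i) {
    assert(invMassDiag[i] > 0.0);
    s[i] = 1.0 / std::sqrt(invMassDiag[i]);
  }
  return s;
}

// p = sqrt(M) * z with z ~ N(0, I); returns (momentum, updated engine).
std::pair<Vec, std::mt19937_64> sampleMomentumDiag(std::mt19937_64 rng,
                                                   const Vec &massSqrtDiag) {
  std::normal_distribution<double> stdNormal(0.0, 1.0);
  Vec p(massSqrtDiag.size());
  for (std::size_t i = 0; i < p.size(); ++i)
    p[i] = massSqrtDiag[i] * stdNormal(rng);
  return {p, rng};
}
```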
| MCMCKernelResult mlir::impulse::SampleNUTS | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | q, | ||
| Value | grad, | ||
| Value | U, | ||
| Value | rng, | ||
| const NUTSContext & | ctx, | ||
| bool | debugDump = false ) |
Single NUTS iteration: momentum sampling + tree building.
Definition at line 998 of file HMCUtils.cpp.
References buildTree(), computeKineticEnergy(), mlir::impulse::HMCContext::getElementType(), mlir::impulse::HMCContext::getPositionType(), mlir::impulse::HMCContext::getScalarType(), mlir::impulse::HMCContext::hasCustomLogpdf(), mlir::impulse::HMCContext::invMass, mlir::impulse::HMCContext::massMatrixSqrt, sampleMomentum(), and mlir::impulse::NUTSContext::withH0().
| Value mlir::impulse::unconstrainPosition | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | constrained, | ||
| ArrayRef< SupportInfo > | supports ) |
Transform an entire position vector from constrained to unconstrained space based on the support information.
Definition at line 1955 of file HMCUtils.cpp.
References mlir::enzyme::transforms::unconstrain().
Referenced by InitHMC().
| std::pair< Value, Value > mlir::impulse::updateCheckpoints | ( | OpBuilder & | builder, |
| Location | loc, | ||
| Value | leafIdx, | ||
| Value | ckptIdxMax, | ||
| Value | p, | ||
| Value | pSum, | ||
| Value | pCkpts, | ||
| Value | pSumCkpts, | ||
| const NUTSContext & | ctx, | ||
| bool | debugDump = false ) |
Update checkpoint arrays at even leaf indices.
Definition at line 1575 of file HMCUtils.cpp.
Referenced by buildIterativeSubtree().
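Written functionally, as the std::pair return type suggests, the update is a conditional overwrite of slot ckptIdxMax. The sketch follows NumPyro's iterative NUTS and assumes p is the newest leaf's momentum and pSum the running momentum sum; this updateCheckpoints is a hypothetical double-based version, not the MLIR function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using Vec = std::vector<double>;
using Ckpts = std::vector<Vec>;

// Returns updated (pCkpts, pSumCkpts). Odd leaves return the inputs unchanged.
std::pair<Ckpts, Ckpts> updateCheckpoints(int64_t leafIdx, int64_t ckptIdxMax,
                                          const Vec &p, const Vec &pSum,
                                          Ckpts pCkpts, Ckpts pSumCkpts) {
  if (leafIdx % 2 == 0) {
    const auto k = static_cast<std::size_t>(ckptIdxMax);
    assert(ckptIdxMax >= 0 && k < pCkpts.size() && k < pSumCkpts.size());
    pCkpts[k] = p;       // momentum of the leaf that may start a subtree
    pSumCkpts[k] = pSum; // running momentum sum including that leaf
  }
  return {pCkpts, pSumCkpts};
}
```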
| DualAveragingState mlir::impulse::updateDualAveraging | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const DualAveragingState & | state, | ||
| Value | acceptProb, | ||
| const DualAveragingConfig & | config ) |
Update dual averaging state with observed acceptance probability.
Definition at line 1657 of file HMCUtils.cpp.
References mlir::impulse::DualAveragingConfig::gamma, mlir::impulse::DualAveragingState::gradient_avg, mlir::impulse::DualAveragingConfig::kappa, mlir::impulse::DualAveragingState::log_step_size_avg, mlir::impulse::DualAveragingState::prox_center, mlir::impulse::DualAveragingState::step_count, mlir::impulse::DualAveragingConfig::t0, and mlir::impulse::DualAveragingConfig::target_accept_prob.
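The state and config fields (gradient_avg, prox_center, step_count, log_step_size_avg; gamma, kappa, t0, target_accept_prob) match the Nesterov-style dual averaging used for step-size adaptation in Hoffman and Gelman's NUTS, as implemented in Stan and NumPyro. The sketch assumes that scheme, NumPyro's default constants, and a prox center of log(10 * initial step size); DAConfig, DAState, initDA, updateDA and stepSizeFromDA are hypothetical names mirroring initDualAveraging, updateDualAveraging and getStepSizeFromDualAveraging.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

struct DAConfig { // assumed defaults
  double gamma = 0.05;
  double kappa = 0.75;
  double t0 = 10.0;
  double targetAcceptProb = 0.8;
};

struct DAState {
  double logStepSize = 0.0;    // current iterate x_t
  double logStepSizeAvg = 0.0; // weighted average of x_t
  double gradientAvg = 0.0;    // running average of (target - accept)
  double proxCenter = 0.0;     // point x_t is shrunk toward
  int64_t stepCount = 0;
};

DAState initDA(double stepSize) {
  assert(stepSize > 0.0);
  DAState s;
  s.logStepSize = std::log(stepSize);
  s.proxCenter = std::log(10.0 * stepSize);
  return s;
}

DAState updateDA(DAState s, double acceptProb, const DAConfig &c) {
  s.stepCount += 1;
  const double t = static_cast<double>(s.stepCount);
  const double g = c.targetAcceptProb - acceptProb; // < 0: step can grow
  s.gradientAvg = (1.0 - 1.0 / (t + c.t0)) * s.gradientAvg + g / (t + c.t0);
  s.logStepSize = s.proxCenter - std::sqrt(t) / c.gamma * s.gradientAvg;
  const double w = std::pow(t, -c.kappa); // averaging weight, decays with t
  s.logStepSizeAvg = (1.0 - w) * s.logStepSizeAvg + w * s.logStepSize;
  return s;
}

// final == true returns the averaged step size used after warmup.
double stepSizeFromDA(const DAState &s, bool final = false) {
  return std::exp(final ? s.logStepSizeAvg : s.logStepSize);
}
```

Observing an acceptance probability above the target (for example 1.0 against 0.8) makes the gradient term negative and pushes log_step_size above the prox center; observations below the target pull it below.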
| WelfordState mlir::impulse::updateWelford | ( | OpBuilder & | builder, |
| Location | loc, | ||
| const WelfordState & | state, | ||
| Value | sample, | ||
| const WelfordConfig & | config ) |
Update Welford state with a new sample.
Definition at line 1787 of file HMCUtils.cpp.
References mlir::impulse::WelfordConfig::diagonal, mlir::impulse::WelfordState::m2, mlir::impulse::WelfordState::mean, and mlir::impulse::WelfordState::n.