Enzyme main
Loading...
Searching...
No Matches
HMCUtils.h File Reference
#include "Dialect/Impulse/Impulse.h"
#include "Dialect/Ops.h"
#include "Interfaces/TransformUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
Include dependency graph for HMCUtils.h:
This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Classes

struct  mlir::impulse::SupportInfo
 
struct  mlir::impulse::IntegrationResult
 
struct  mlir::impulse::GradientResult
 
struct  mlir::impulse::InitialHMCState
 
struct  mlir::impulse::MCMCKernelResult
 Result of one MCMC kernel step. More...
 
struct  mlir::impulse::DualAveragingState
 
struct  mlir::impulse::IntegratorState
 
struct  mlir::impulse::HMCContext
 
struct  mlir::impulse::NUTSContext
 
struct  mlir::impulse::NUTSTreeState
 
struct  mlir::impulse::SubtreeBuildResult
 
struct  mlir::impulse::DualAveragingConfig
 
struct  mlir::impulse::WelfordState
 State for Welford covariance estimation. More...
 
struct  mlir::impulse::WelfordConfig
 Configuration for Welford covariance estimation. More...
 
struct  mlir::impulse::AdaptWindow
 

Namespaces

namespace  mlir
 
namespace  mlir::impulse
 

Functions

Value mlir::impulse::conditionalDump (OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump)
 Conditionally dump a value for debugging.
 
Value mlir::impulse::applyInverseMassMatrix (OpBuilder &builder, Location loc, Value invMass, Value momentum, RankedTensorType positionType)
 Computes v = M^-1 @ p. If invMass is null (a default-constructed Value), returns momentum unchanged (an identity mass matrix is assumed). If invMass is a diagonal matrix, computes v = invMass * momentum (element-wise). If invMass is a dense matrix, computes v = invMass @ momentum.
 
Value mlir::impulse::computeKineticEnergy (OpBuilder &builder, Location loc, Value momentum, Value invMass, RankedTensorType positionType)
 Computes K = 0.5 * p^T @ M^-1 @ p
 
Value mlir::impulse::computeMassMatrixSqrt (OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType)
 Computes the square root of the mass matrix from the inverse mass matrix.
 
std::pair< Value, Value > mlir::impulse::sampleMomentum (OpBuilder &builder, Location loc, Value rng, Value invMass, Value massMatrixSqrt, RankedTensorType positionType, bool debugDump=false)
 Samples momentum from N(0, M) where M is the mass matrix.
 
GradientResult mlir::impulse::computePotentialAndGradient (OpBuilder &builder, Location loc, Value position, Value rng, const HMCContext &ctx)
 Computes potential energy U(q) = -log p(q) and its gradient dU/dq
 
IntegrationResult mlir::impulse::computeIntegrationStep (OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const HMCContext &ctx)
 Computes a single leapfrog integration step.
 
Value mlir::impulse::checkTurning (OpBuilder &builder, Location loc, Value pLeft, Value pRight, Value pSum, const NUTSContext &ctx)
 U-turn termination criterion.
 
Value mlir::impulse::computeUniformTransitionProb (OpBuilder &builder, Location loc, Value currentWeight, Value newWeight)
 Computes the uniform transition probability for subtree combination.
 
Value mlir::impulse::computeBiasedTransitionProb (OpBuilder &builder, Location loc, Value currentWeight, Value newWeight, Value turning, Value diverging)
 Computes the biased transition probability for main tree combination.
 
NUTSTreeState mlir::impulse::combineTrees (OpBuilder &builder, Location loc, const NUTSTreeState &tree, const NUTSTreeState &subTree, Value direction, Value rng, bool biased, const NUTSContext &ctx)
 Combines a tree with a newly-built subtree during NUTS doubling process.
 
InitialHMCState mlir::impulse::InitHMC (OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition=Value(), bool debugDump=false)
 Initializes HMC/NUTS state from a trace. Specifically: (the individual steps are listed in the function's detailed description).
 
MCMCKernelResult mlir::impulse::SampleHMC (OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump=false)
 Single HMC iteration: momentum sampling + leapfrog + MH accept/reject.
 
MCMCKernelResult mlir::impulse::SampleNUTS (OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump=false)
 Single NUTS iteration: momentum sampling + tree building.
 
NUTSTreeState mlir::impulse::buildBaseTree (OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const NUTSContext &ctx)
 Builds a base tree (leaf node) by taking one leapfrog step.
 
IntegratorState mlir::impulse::getLeafFromTree (OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, const NUTSContext &ctx)
 Extracts an appropriate leaf based on direction.
 
SubtreeBuildResult mlir::impulse::buildIterativeSubtree (OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
 Builds a subtree iteratively by appending leaves one at a time.
 
SubtreeBuildResult mlir::impulse::doubleTree (OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
 Tree doubling by building a subtree of same depth and combining.
 
NUTSTreeState mlir::impulse::buildTree (OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, const NUTSContext &ctx, bool debugDump=false)
 Main NUTS tree building loop.
 
std::pair< Value, Value > mlir::impulse::leafIdxToCheckpointIdxs (OpBuilder &builder, Location loc, Value leafIdx)
 Computes checkpoint indices from leaf index for iterative turning check.
 
Value mlir::impulse::checkIterativeTurning (OpBuilder &builder, Location loc, Value p, Value pSum, Value pCkpts, Value pSumCkpts, Value idxMin, Value idxMax, const NUTSContext &ctx, bool debugDump=false)
 Checkpoint-based iterative turning check.
 
std::pair< Value, Value > mlir::impulse::updateCheckpoints (OpBuilder &builder, Location loc, Value leafIdx, Value ckptIdxMax, Value p, Value pSum, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
 Update checkpoint arrays at even leaf indices.
 
DualAveragingState mlir::impulse::initDualAveraging (OpBuilder &builder, Location loc, Value stepSize)
 Initialize dual averaging state from initial step size.
 
DualAveragingState mlir::impulse::updateDualAveraging (OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config)
 Update dual averaging state with observed acceptance probability.
 
Value mlir::impulse::getStepSizeFromDualAveraging (OpBuilder &builder, Location loc, const DualAveragingState &state, bool final=false)
 Get step size from dual averaging state.
 
WelfordState mlir::impulse::initWelford (OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal)
 Initialize state for Welford covariance estimation.
 
WelfordState mlir::impulse::updateWelford (OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config)
 Update Welford state with a new sample.
 
Value mlir::impulse::finalizeWelford (OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config)
 Finalize Welford state to produce sample covariance (returned as inverse mass matrix).
 
SmallVector< AdaptWindow > mlir::impulse::buildAdaptationSchedule (int64_t numSteps)
 Build warmup adaptation schedule.
 
Value mlir::impulse::unconstrainPosition (OpBuilder &builder, Location loc, Value constrained, ArrayRef< SupportInfo > supports)
 Transform an entire position vector from constrained to unconstrained space based on the support information.
 
Value mlir::impulse::constrainPosition (OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
 Transform an entire position vector from unconstrained to constrained space.
 
Value mlir::impulse::computeTotalJacobianCorrection (OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
 Compute total Jacobian correction for the constrain transform over all position vector slices.