Enzyme main
Loading...
Searching...
No Matches
HMCUtils.h
Go to the documentation of this file.
1//===- HMCUtils.h - Utilities for HMC/NUTS algorithms -------* C++ -*-===//
2//
3// This file declares utility functions for Hamiltonian Monte Carlo (HMC) and
4// No-U-Turn Sampler (NUTS) implementations.
5//
6// Reference:
7// https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/hmc_util.py
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef ENZYME_MLIR_INTERFACES_HMC_UTILS_H
12#define ENZYME_MLIR_INTERFACES_HMC_UTILS_H
13
15#include "Dialect/Ops.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/Value.h"
24
25namespace mlir {
26namespace impulse {
27
29 int64_t offset;
30 int64_t traceOffset;
31 int64_t size;
32 impulse::SupportAttr support;
33
34 SupportInfo(int64_t offset, int64_t traceOffset, int64_t size,
35 impulse::SupportAttr support)
37 }
38};
39
41 Value q;
42 Value p;
43 Value grad;
44 Value U;
45 Value rng;
46};
47
49 Value U; // U(q) = -log p(q)
50 Value grad; // dU/dq
51 Value rng; // Updated RNG state
52};
53
55 Value q0; // Flattened position vector
56 Value U0; // Initial potential energy
57 Value grad0; // Initial gradient at position q0
58 Value rng; // RNG state for SampleHMC
59};
60
61/// Result of one MCMC kernel step
63 Value q; // New position vector
64 Value grad; // Gradient at new position
65 Value U; // Potential energy at new position
66 Value accepted; // Whether proposal was accepted
67 Value accept_prob; // Mean acceptance probability
68 Value rng; // Updated RNG state
69};
70
77
78 SmallVector<Value> toValues() const {
81 }
82 static DualAveragingState fromValues(ArrayRef<Value> values) {
83 return {values[0], values[1], values[2], values[3], values[4]};
84 }
85 SmallVector<Type> getTypes() const {
86 return {log_step_size.getType(), log_step_size_avg.getType(),
87 gradient_avg.getType(), step_count.getType(),
88 prox_center.getType()};
89 }
90};
91
93 Value q;
94 Value p;
95 Value grad;
96};
97
/// Bundle of attributes and SSA values that the HMC helpers declared in this
/// header receive as `const HMCContext &ctx`.
// NOTE(review): the embedded listing numbers skip inside this struct
// (101 -> 103, 120 -> 126, 128 -> 137, 156 -> 158), so some members and
// constructor initializers are not visible here. `originalTrace` and
// `positionSize` are used by the accessors below but are presumably declared
// on the omitted lines -- confirm against the real header.
98struct HMCContext {
99 FlatSymbolRefAttr fn;
 // ArrayRef is a non-owning view and the constructor below initializes this
 // member directly from its ArrayRef parameter.
 // NOTE(review): the referenced storage presumably has to outlive this
 // context and every copy made via withStepSize() -- confirm at call sites.
100 ArrayRef<Value> fnInputs;
101 SmallVector<Type> fnResultTypes;
103 ArrayAttr selection;
104 ArrayAttr allAddresses;
105 Value invMass;
107 Value stepSize;
 // Owned copy: filled from the constructor's ArrayRef via begin()/end().
110 SmallVector<SupportInfo> supports;
 // Tested against nullptr by hasCustomLogpdf().
111 FlatSymbolRefAttr logpdfFn;
112 DictionaryAttr autodiffAttrs;
113
 // Constructor taking the traced function `fn` plus trace/selection metadata.
 // Only part of its initializer list is visible in this listing.
114 HMCContext(FlatSymbolRefAttr fn, ArrayRef<Value> fnInputs,
115 ArrayRef<Type> fnResultTypes, Value originalTrace,
116 ArrayAttr selection, ArrayAttr allAddresses, Value invMass,
117 Value massMatrixSqrt, Value stepSize, Value trajectoryLength,
118 int64_t positionSize, ArrayRef<SupportInfo> supports,
119 DictionaryAttr autodiffAttrs = {})
120 : fn(fn), fnInputs(fnInputs),
126 supports(supports.begin(), supports.end()),
128
137
 // Returns true when `logpdfFn` is non-null.
138 bool hasCustomLogpdf() const { return logpdfFn != nullptr; }
139
 // Returns dimension 1 of `originalTrace`'s shape. Uses cast<>, so the
 // trace's type is expected to be a RankedTensorType.
140 int64_t getFullTraceSize() const {
141 auto traceType = cast<RankedTensorType>(originalTrace.getType());
142 return traceType.getShape()[1];
143 }
144
 // Element type is taken from `stepSize`, whose type is expected to be a
 // RankedTensorType (cast<>).
145 Type getElementType() const {
146 return cast<RankedTensorType>(stepSize.getType()).getElementType();
147 }
148
 // Ranked tensor type of shape {1, positionSize} with the context's element
 // type.
149 RankedTensorType getPositionType() const {
150 return RankedTensorType::get({1, positionSize}, getElementType());
151 }
152
 // Rank-0 tensor type with the context's element type.
153 RankedTensorType getScalarType() const {
154 return RankedTensorType::get({}, getElementType());
155 }
156
 // NOTE(review): the function header for this body sits on listing line 157,
 // which is not shown; presumably `bool hasConstrainedSupports() const` --
 // confirm. The body returns true as soon as one entry has a non-null
 // `support` attribute whose kind differs from SupportKind::REAL, and false
 // otherwise.
158 for (const auto &info : supports) {
159 if (info.support && info.support.getKind() != impulse::SupportKind::REAL)
160 return true;
161 }
162 return false;
163 }
164
 // Returns a copy of this context with `stepSize` replaced; `*this` is left
 // unchanged (const member function).
165 HMCContext withStepSize(Value newStepSize) const {
166 HMCContext copy = *this;
167 copy.stepSize = newStepSize;
168 return copy;
169 }
170};
171
/// HMCContext extended with the extra state that the NUTS helpers in this
/// header take as `const NUTSContext &ctx`.
// NOTE(review): the embedded listing numbers skip inside this struct
// (173 -> 176, 182 -> 185, 185 -> 188, 192 -> 194, 194 -> 197), so further
// members and most of both constructor initializer lists are not visible.
172struct NUTSContext : public HMCContext {
 // NOTE(review): presumably the Hamiltonian at the start of the trajectory --
 // confirm. Replaceable through withH0().
173 Value H0;
176
 // Overload taking the traced function `fn` plus trace/selection metadata.
 // The visible initializer fragment passes a null Value() where the base
 // class expects `trajectoryLength` (marked "Unused" in the source).
177 NUTSContext(FlatSymbolRefAttr fn, ArrayRef<Value> fnInputs,
178 ArrayRef<Type> fnResultTypes, Value originalTrace,
179 ArrayAttr selection, ArrayAttr allAddresses, Value invMass,
180 Value massMatrixSqrt, Value stepSize, int64_t positionSize,
181 ArrayRef<SupportInfo> supports, Value H0, Value maxDeltaEnergy,
182 int64_t maxTreeDepth, DictionaryAttr autodiffAttrs = {})
185 /* Unused trajectoryLength */ Value(), positionSize,
188
 // Overload taking `logpdfFn` instead of a traced function. It likewise
 // passes a null Value() for the unused `trajectoryLength` slot.
189 NUTSContext(FlatSymbolRefAttr logpdfFn, ArrayRef<Value> fnInputs,
190 Value invMass, Value massMatrixSqrt, Value stepSize,
191 int64_t positionSize, Value H0, Value maxDeltaEnergy,
192 int64_t maxTreeDepth, DictionaryAttr autodiffAttrs = {})
194 /* Unused trajectoryLength */ Value(), positionSize,
197
 // Returns a copy of this context with `H0` replaced; `*this` is left
 // unchanged (const member function).
198 NUTSContext withH0(Value newH0) const {
199 NUTSContext copy = *this;
200 copy.H0 = newH0;
201 return copy;
202 }
203};
204
211 Value rng;
212
213 SmallVector<Value> toValues() const;
214 static NUTSTreeState fromValues(ArrayRef<Value> values);
215 SmallVector<Type> getTypes() const;
216
219 return {q_right, p_right, grad_right};
220 }
221};
222
228
229/// Conditionally dump a value for debugging.
230/// Emits an enzyme::DumpOp if `debugDump` is true; otherwise has no effect.
231Value conditionalDump(OpBuilder &builder, Location loc, Value value,
232 StringRef label, bool debugDump);
233
234/// Computes `v = M^-1 @ p`.
235/// If `invMass` is nullptr, returns `momentum` unchanged (assumes identity).
236/// If `invMass` is a diagonal matrix, computes `v = invMass * momentum`.
237/// If `invMass` is a dense matrix, computes `v = invMass @ momentum`.
238Value applyInverseMassMatrix(OpBuilder &builder, Location loc, Value invMass,
239 Value momentum, RankedTensorType positionType);
240
241/// Computes `K = 0.5 * p^T @ M^-1 @ p`
242Value computeKineticEnergy(OpBuilder &builder, Location loc, Value momentum,
243 Value invMass, RankedTensorType positionType);
244
245/// Computes the square root of the mass matrix from the inverse mass matrix.
246Value computeMassMatrixSqrt(OpBuilder &builder, Location loc, Value invMass,
247 RankedTensorType positionType);
248
249/// Samples momentum from `N(0, M)` where M is the mass matrix.
250/// Returns `(momentum, updated_rng_state)`.
251std::pair<Value, Value> sampleMomentum(OpBuilder &builder, Location loc,
252 Value rng, Value invMass,
253 Value massMatrixSqrt,
254 RankedTensorType positionType,
255 bool debugDump = false);
256
257/// Computes potential energy `U(q) = -log p(q)` and its gradient `dU/dq`
258GradientResult computePotentialAndGradient(OpBuilder &builder, Location loc,
259 Value position, Value rng,
260 const HMCContext &ctx);
261
262/// Computes a single leapfrog integration step.
263/// Specifically:
264/// - `p_half = p - (eps/2) * grad`
265/// - `q_new = q + eps * M^-1 * p_half`
266/// - `grad_new = dU/dq(q_new)`
267/// - `p_new = p_half - (eps/2) * grad_new`
268IntegrationResult computeIntegrationStep(OpBuilder &builder, Location loc,
269 const IntegratorState &leaf, Value rng,
270 Value direction,
271 const HMCContext &ctx);
272
273/// U-turn termination criterion.
274/// Returns `true` if U-turn detected in the trajectory.
275Value checkTurning(OpBuilder &builder, Location loc, Value pLeft, Value pRight,
276 Value pSum, const NUTSContext &ctx);
277
278/// Computes the uniform transition probability for subtree combination.
279/// Specifically: `sigmoid(new_weight - current_weight)`.
280Value computeUniformTransitionProb(OpBuilder &builder, Location loc,
281 Value currentWeight, Value newWeight);
282
283/// Computes the biased transition probability for main tree combination.
284/// Specifically: `min(1, exp(new_weight - current_weight))`, zeroed if
285/// turning or diverging.
286Value computeBiasedTransitionProb(OpBuilder &builder, Location loc,
287 Value currentWeight, Value newWeight,
288 Value turning, Value diverging);
289
290/// Combines a tree with a newly-built subtree during NUTS doubling process.
291///
292/// The `subtree` is merged into `tree` based on `direction`:
293/// - If going right (`direction` is true), keep `tree`'s left boundary
294/// and take `subtree`'s right boundary
295/// - If going left (`direction` is false), take `subtree`'s left boundary
296/// and keep `tree`'s right boundary
297///
298/// Proposal selection has two options:
299/// - Biased kernel (`biased` is true):
300/// `prob = min(1, exp(subtree.weight - tree.weight))`,
301/// zeroed if `turning` or `diverging`
302/// - Uniform kernel (`biased` is false):
303/// `prob = sigmoid(subtree.weight - tree.weight)`
304///
305/// Also checks the U-turn criterion on the combined momentum sum.
306NUTSTreeState combineTrees(OpBuilder &builder, Location loc,
307 const NUTSTreeState &tree,
308 const NUTSTreeState &subTree, Value direction,
309 Value rng, bool biased, const NUTSContext &ctx);
310
311/// Initializes HMC/NUTS state from a trace.
312/// Specifically:
313/// - Extracts position from trace
314/// - Computes initial potential energy U0 = -weight
315/// - Samples initial momentum p0 ~ N(0, M)
316/// - Computes initial kinetic energy and Hamiltonian
317/// - Computes initial gradient via AutoDiffRegionOp
318InitialHMCState InitHMC(OpBuilder &builder, Location loc, Value rng,
319 const HMCContext &ctx, Value initialPosition = Value(),
320 bool debugDump = false);
321
322/// Single HMC iteration: momentum sampling + leapfrog + MH accept/reject
323MCMCKernelResult SampleHMC(OpBuilder &builder, Location loc, Value q,
324 Value grad, Value U, Value rng,
325 const HMCContext &ctx, bool debugDump = false);
326
327/// Single NUTS iteration: momentum sampling + tree building
328MCMCKernelResult SampleNUTS(OpBuilder &builder, Location loc, Value q,
329 Value grad, Value U, Value rng,
330 const NUTSContext &ctx, bool debugDump = false);
331
332/// Builds a base tree (leaf node) by taking one leapfrog step.
333NUTSTreeState buildBaseTree(OpBuilder &builder, Location loc,
334 const IntegratorState &leaf, Value rng,
335 Value direction, const NUTSContext &ctx);
336
337/// Extracts an appropriate leaf based on direction.
338IntegratorState getLeafFromTree(OpBuilder &builder, Location loc,
339 const NUTSTreeState &tree, Value direction,
340 const NUTSContext &ctx);
341
342/// Builds a subtree iteratively by appending leaves one at a time.
343SubtreeBuildResult buildIterativeSubtree(OpBuilder &builder, Location loc,
344 const NUTSTreeState &initialTree,
345 Value direction, Value pCkpts,
346 Value pSumCkpts,
347 const NUTSContext &ctx,
348 bool debugDump = false);
349
350/// Tree doubling by building a subtree of same depth and combining.
351SubtreeBuildResult doubleTree(OpBuilder &builder, Location loc,
352 const NUTSTreeState &tree, Value direction,
353 Value pCkpts, Value pSumCkpts,
354 const NUTSContext &ctx, bool debugDump = false);
355
356/// Main NUTS tree building loop.
357NUTSTreeState buildTree(OpBuilder &builder, Location loc,
358 const NUTSTreeState &initialTree,
359 const NUTSContext &ctx, bool debugDump = false);
360
361/// Computes checkpoint indices from leaf index for iterative turning check.
362std::pair<Value, Value> leafIdxToCheckpointIdxs(OpBuilder &builder,
363 Location loc, Value leafIdx);
364
365/// Checkpoint-based iterative turning check.
366Value checkIterativeTurning(OpBuilder &builder, Location loc, Value p,
367 Value pSum, Value pCkpts, Value pSumCkpts,
368 Value idxMin, Value idxMax, const NUTSContext &ctx,
369 bool debugDump = false);
370
371/// Update checkpoint arrays at even leaf indices.
372std::pair<Value, Value> updateCheckpoints(OpBuilder &builder, Location loc,
373 Value leafIdx, Value ckptIdxMax,
374 Value p, Value pSum, Value pCkpts,
375 Value pSumCkpts,
376 const NUTSContext &ctx,
377 bool debugDump = false);
378
379// TODO: Proper customization
381 double t0 = 10.0; // Stabilization
382 double kappa = 0.75; // Weight decay
383 double gamma = 0.05; // Convergence
384 double target_accept_prob = 0.8;
385};
386
387/// Initialize dual averaging state from initial step size.
388DualAveragingState initDualAveraging(OpBuilder &builder, Location loc,
389 Value stepSize);
390
391/// Update dual averaging state with observed acceptance probability.
392DualAveragingState updateDualAveraging(OpBuilder &builder, Location loc,
393 const DualAveragingState &state,
394 Value acceptProb,
395 const DualAveragingConfig &config);
396
397/// Get step size from dual averaging state.
398/// If `final` is true, returns the averaged step size.
399/// Otherwise, returns the updated step size.
400Value getStepSizeFromDualAveraging(OpBuilder &builder, Location loc,
401 const DualAveragingState &state,
402 bool final = false);
403
404/// State for Welford covariance estimation
406 Value mean;
407 Value m2; // Sum of squared deviations
408 Value n;
409
410 SmallVector<Value> toValues() const { return {mean, m2, n}; }
411 static WelfordState fromValues(ArrayRef<Value> values) {
412 return {values[0], values[1], values[2]};
413 }
414 SmallVector<Type> getTypes() const {
415 return {mean.getType(), m2.getType(), n.getType()};
416 }
417};
418
419/// Configuration for Welford covariance estimation
421 bool diagonal = true;
422 bool regularize = true; // Optional regularization (Stan's shrinkage)
423};
424
425/// Initialize state for Welford covariance estimation.
426WelfordState initWelford(OpBuilder &builder, Location loc, int64_t positionSize,
427 bool diagonal);
428
429/// Update Welford state with a new sample.
430WelfordState updateWelford(OpBuilder &builder, Location loc,
431 const WelfordState &state, Value sample,
432 const WelfordConfig &config);
433
434/// Finalize Welford state to produce sample covariance (returned as inverse
435/// mass matrix).
436Value finalizeWelford(OpBuilder &builder, Location loc,
437 const WelfordState &state, const WelfordConfig &config);
438
440 int64_t start;
441 int64_t end;
442};
443
444/// Build warmup adaptation schedule.
445/// TODO: Make customizable
446SmallVector<AdaptWindow> buildAdaptationSchedule(int64_t numSteps);
447
448/// Transform an entire position vector from constrained to unconstrained space
449/// based on the support information.
450Value unconstrainPosition(OpBuilder &builder, Location loc, Value constrained,
451 ArrayRef<SupportInfo> supports);
452
453/// Transform an entire position vector from unconstrained to constrained space
454/// based on the support information.
455Value constrainPosition(OpBuilder &builder, Location loc, Value unconstrained,
456 ArrayRef<SupportInfo> supports);
457
458/// Compute total Jacobian correction for the constrain transform over all
459/// position vector slices.
460Value computeTotalJacobianCorrection(OpBuilder &builder, Location loc,
461 Value unconstrained,
462 ArrayRef<SupportInfo> supports);
463} // namespace impulse
464} // namespace mlir
465
466#endif // ENZYME_MLIR_INTERFACES_HMC_UTILS_H
PointerType * traceType(LLVMContext &C)
Value computeTotalJacobianCorrection(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Compute total Jacobian correction for the constrain transform over all position vector slices.
Value finalizeWelford(OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config)
Finalize Welford state to produce sample covariance (returned as inverse mass matrix).
Value checkTurning(OpBuilder &builder, Location loc, Value pLeft, Value pRight, Value pSum, const NUTSContext &ctx)
U-turn termination criterion.
Definition HMCUtils.cpp:509
GradientResult computePotentialAndGradient(OpBuilder &builder, Location loc, Value position, Value rng, const HMCContext &ctx)
Computes potential energy U(q) = -log p(q) and its gradient dU/dq
Definition HMCUtils.cpp:352
std::pair< Value, Value > leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx)
Computes checkpoint indices from leaf index for iterative turning check.
IntegratorState getLeafFromTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, const NUTSContext &ctx)
Extracts an appropriate leaf based on direction.
std::pair< Value, Value > updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, Value ckptIdxMax, Value p, Value pSum, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Update checkpoint arrays at even leaf indices.
SubtreeBuildResult buildIterativeSubtree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Builds a subtree iteratively by appending leaves one at a time.
std::pair< Value, Value > sampleMomentum(OpBuilder &builder, Location loc, Value rng, Value invMass, Value massMatrixSqrt, RankedTensorType positionType, bool debugDump=false)
Samples momentum from N(0, M) where M is the mass matrix.
Definition HMCUtils.cpp:233
Value computeUniformTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight)
Computes the uniform transition probability for subtree combination.
Definition HMCUtils.cpp:557
SmallVector< AdaptWindow > buildAdaptationSchedule(int64_t numSteps)
Build warmup adaptation schedule.
NUTSTreeState buildBaseTree(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const NUTSContext &ctx)
Builds a base tree (leaf node) by taking one leapfrog step.
Value computeKineticEnergy(OpBuilder &builder, Location loc, Value momentum, Value invMass, RankedTensorType positionType)
Computes K = 0.5 * p^T @ M^-1 @ p
Definition HMCUtils.cpp:166
NUTSTreeState combineTrees(OpBuilder &builder, Location loc, const NUTSTreeState &tree, const NUTSTreeState &subTree, Value direction, Value rng, bool biased, const NUTSContext &ctx)
Combines a tree with a newly-built subtree during NUTS doubling process.
Definition HMCUtils.cpp:590
Value computeBiasedTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight, Value turning, Value diverging)
Computes the biased transition probability for main tree combination.
Definition HMCUtils.cpp:566
Value checkIterativeTurning(OpBuilder &builder, Location loc, Value p, Value pSum, Value pCkpts, Value pSumCkpts, Value idxMin, Value idxMax, const NUTSContext &ctx, bool debugDump=false)
Checkpoint-based iterative turning check.
NUTSTreeState buildTree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, const NUTSContext &ctx, bool debugDump=false)
Main NUTS tree building loop.
Value conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump)
Conditionally dump a value for debugging.
Definition HMCUtils.cpp:64
MCMCKernelResult SampleNUTS(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump=false)
Single NUTS iteration: momentum sampling + tree building.
Definition HMCUtils.cpp:998
InitialHMCState InitHMC(OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition=Value(), bool debugDump=false)
Initializes HMC/NUTS state from a trace Specifically:
Definition HMCUtils.cpp:701
SubtreeBuildResult doubleTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump=false)
Tree doubling by building a subtree of same depth and combining.
IntegrationResult computeIntegrationStep(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const HMCContext &ctx)
Computes a single leapfrog integration step.
Definition HMCUtils.cpp:460
DualAveragingState updateDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config)
Update dual averaging state with observed acceptance probability.
Value unconstrainPosition(OpBuilder &builder, Location loc, Value constrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from constrained to unconstrained space based on the support info...
DualAveragingState initDualAveraging(OpBuilder &builder, Location loc, Value stepSize)
Initialize dual averaging state from initial step size.
Value computeMassMatrixSqrt(OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType)
Computes the square root of the mass matrix from the inverse mass matrix.
Definition HMCUtils.cpp:190
WelfordState updateWelford(OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config)
Update Welford state with a new sample.
WelfordState initWelford(OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal)
Initialize state for Welford covariance estimation.
Value getStepSizeFromDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, bool final=false)
Get step size from dual averaging state.
MCMCKernelResult SampleHMC(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump=false)
Single HMC iteration: momentum sampling + leapfrog + MH accept/reject.
Definition HMCUtils.cpp:862
Value constrainPosition(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef< SupportInfo > supports)
Transform an entire position vector from unconstrained to constrained space.
Value applyInverseMassMatrix(OpBuilder &builder, Location loc, Value invMass, Value momentum, RankedTensorType positionType)
Computes v = M^-1 @ p If invMass is nullptr, returns momentum unchanged (assumes identity) If invMass...
Definition HMCUtils.cpp:139
SmallVector< Type > getTypes() const
Definition HMCUtils.h:85
static DualAveragingState fromValues(ArrayRef< Value > values)
Definition HMCUtils.h:82
SmallVector< Value > toValues() const
Definition HMCUtils.h:78
DictionaryAttr autodiffAttrs
Definition HMCUtils.h:112
bool hasConstrainedSupports() const
Definition HMCUtils.h:157
Type getElementType() const
Definition HMCUtils.h:145
FlatSymbolRefAttr fn
Definition HMCUtils.h:99
HMCContext(FlatSymbolRefAttr logpdfFn, ArrayRef< Value > fnInputs, Value invMass, Value massMatrixSqrt, Value stepSize, Value trajectoryLength, int64_t positionSize, DictionaryAttr autodiffAttrs={})
Definition HMCUtils.h:129
HMCContext withStepSize(Value newStepSize) const
Definition HMCUtils.h:165
ArrayRef< Value > fnInputs
Definition HMCUtils.h:100
HMCContext(FlatSymbolRefAttr fn, ArrayRef< Value > fnInputs, ArrayRef< Type > fnResultTypes, Value originalTrace, ArrayAttr selection, ArrayAttr allAddresses, Value invMass, Value massMatrixSqrt, Value stepSize, Value trajectoryLength, int64_t positionSize, ArrayRef< SupportInfo > supports, DictionaryAttr autodiffAttrs={})
Definition HMCUtils.h:114
RankedTensorType getPositionType() const
Definition HMCUtils.h:149
int64_t getFullTraceSize() const
Definition HMCUtils.h:140
SmallVector< Type > fnResultTypes
Definition HMCUtils.h:101
SmallVector< SupportInfo > supports
Definition HMCUtils.h:110
bool hasCustomLogpdf() const
Definition HMCUtils.h:138
RankedTensorType getScalarType() const
Definition HMCUtils.h:153
FlatSymbolRefAttr logpdfFn
Definition HMCUtils.h:111
Result of one MCMC kernel step.
Definition HMCUtils.h:62
NUTSContext withH0(Value newH0) const
Definition HMCUtils.h:198
NUTSContext(FlatSymbolRefAttr fn, ArrayRef< Value > fnInputs, ArrayRef< Type > fnResultTypes, Value originalTrace, ArrayAttr selection, ArrayAttr allAddresses, Value invMass, Value massMatrixSqrt, Value stepSize, int64_t positionSize, ArrayRef< SupportInfo > supports, Value H0, Value maxDeltaEnergy, int64_t maxTreeDepth, DictionaryAttr autodiffAttrs={})
Definition HMCUtils.h:177
NUTSContext(FlatSymbolRefAttr logpdfFn, ArrayRef< Value > fnInputs, Value invMass, Value massMatrixSqrt, Value stepSize, int64_t positionSize, Value H0, Value maxDeltaEnergy, int64_t maxTreeDepth, DictionaryAttr autodiffAttrs={})
Definition HMCUtils.h:189
IntegratorState getRightLeaf() const
Definition HMCUtils.h:218
SmallVector< Value > toValues() const
Definition HMCUtils.cpp:26
IntegratorState getLeftLeaf() const
Definition HMCUtils.h:217
SmallVector< Type > getTypes() const
Definition HMCUtils.cpp:57
static NUTSTreeState fromValues(ArrayRef< Value > values)
Definition HMCUtils.cpp:35
SupportInfo(int64_t offset, int64_t traceOffset, int64_t size, impulse::SupportAttr support)
Definition HMCUtils.h:34
impulse::SupportAttr support
Definition HMCUtils.h:32
Configuration for Welford covariance estimation.
Definition HMCUtils.h:420
State for Welford covariance estimation.
Definition HMCUtils.h:405
SmallVector< Value > toValues() const
Definition HMCUtils.h:410
SmallVector< Type > getTypes() const
Definition HMCUtils.h:414
static WelfordState fromValues(ArrayRef< Value > values)
Definition HMCUtils.h:411