Enzyme main
mlir::impulse::NUTSContext Struct Reference

#include "MLIR/Interfaces/HMCUtils.h"

Base class: mlir::impulse::HMCContext

Public Member Functions

 NUTSContext (FlatSymbolRefAttr fn, ArrayRef< Value > fnInputs, ArrayRef< Type > fnResultTypes, Value originalTrace, ArrayAttr selection, ArrayAttr allAddresses, Value invMass, Value massMatrixSqrt, Value stepSize, int64_t positionSize, ArrayRef< SupportInfo > supports, Value H0, Value maxDeltaEnergy, int64_t maxTreeDepth, DictionaryAttr autodiffAttrs={})
 
 NUTSContext (FlatSymbolRefAttr logpdfFn, ArrayRef< Value > fnInputs, Value invMass, Value massMatrixSqrt, Value stepSize, int64_t positionSize, Value H0, Value maxDeltaEnergy, int64_t maxTreeDepth, DictionaryAttr autodiffAttrs={})
 
NUTSContext withH0 (Value newH0) const
 
- Public Member Functions inherited from mlir::impulse::HMCContext
 HMCContext (FlatSymbolRefAttr fn, ArrayRef< Value > fnInputs, ArrayRef< Type > fnResultTypes, Value originalTrace, ArrayAttr selection, ArrayAttr allAddresses, Value invMass, Value massMatrixSqrt, Value stepSize, Value trajectoryLength, int64_t positionSize, ArrayRef< SupportInfo > supports, DictionaryAttr autodiffAttrs={})
 
 HMCContext (FlatSymbolRefAttr logpdfFn, ArrayRef< Value > fnInputs, Value invMass, Value massMatrixSqrt, Value stepSize, Value trajectoryLength, int64_t positionSize, DictionaryAttr autodiffAttrs={})
 
bool hasCustomLogpdf () const
 
int64_t getFullTraceSize () const
 
Type getElementType () const
 
RankedTensorType getPositionType () const
 
RankedTensorType getScalarType () const
 
bool hasConstrainedSupports () const
 
HMCContext withStepSize (Value newStepSize) const
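Judging from the signatures listed on this page, NUTSContext does not override withStepSize. Calling it on a NUTSContext therefore returns a plain HMCContext, and the NUTS-specific members are sliced off the result. The sketch below illustrates that C++ behavior with hypothetical, reduced stand-ins: `Value` is modeled as `double`, only `stepSize` and `H0` are kept, and the helper `withStepSizeKeepingNUTS` is invented for illustration. It is not the Enzyme implementation.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for illustration only: `Value` models mlir::Value as
// a double, and each struct keeps just the fields needed to show the effect.
using Value = double;

struct HMCContext {
  Value stepSize = 0.0;

  // Same shape as the documented member: returns a copy typed as the base.
  HMCContext withStepSize(Value newStepSize) const {
    HMCContext copy = *this;
    copy.stepSize = newStepSize;
    return copy;
  }
};

struct NUTSContext : HMCContext {
  Value H0 = 0.0;
};

// Calling the inherited member on a NUTSContext yields an HMCContext, so the
// NUTS-only fields (here H0) are not part of the result.
static_assert(std::is_same_v<
              decltype(std::declval<const NUTSContext &>().withStepSize(0.0)),
              HMCContext>);

// One way to change the step size while keeping the NUTS fields: copy the
// derived object and assign the inherited member directly.
NUTSContext withStepSizeKeepingNUTS(const NUTSContext &ctx, Value newStepSize) {
  NUTSContext copy = ctx;
  copy.stepSize = newStepSize;
  return copy;
}
```

If a NUTSContext with a new step size is needed, the result of withStepSize cannot be used directly; either copy the derived object as above or add a derived-type overload.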
 

Public Attributes

Value H0
 
Value maxDeltaEnergy
 
int64_t maxTreeDepth
 
- Public Attributes inherited from mlir::impulse::HMCContext
FlatSymbolRefAttr fn
 
ArrayRef< Value > fnInputs
 
SmallVector< Type > fnResultTypes
 
Value originalTrace
 
ArrayAttr selection
 
ArrayAttr allAddresses
 
Value invMass
 
Value massMatrixSqrt
 
Value stepSize
 
Value trajectoryLength
 
int64_t positionSize
 
SmallVector< SupportInfo > supports
 
FlatSymbolRefAttr logpdfFn
 
DictionaryAttr autodiffAttrs
 

Detailed Description

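Context for the No-U-Turn Sampler (NUTS). It extends HMCContext with three NUTS-specific members: H0, maxDeltaEnergy, and maxTreeDepth. Unlike the HMCContext constructors, the NUTSContext constructors take no trajectoryLength argument.

Definition at line 172 of file HMCUtils.h.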

Constructor & Destructor Documentation

◆ NUTSContext() [1/2]

mlir::impulse::NUTSContext::NUTSContext ( FlatSymbolRefAttr fn,
ArrayRef< Value > fnInputs,
ArrayRef< Type > fnResultTypes,
Value originalTrace,
ArrayAttr selection,
ArrayAttr allAddresses,
Value invMass,
Value massMatrixSqrt,
Value stepSize,
int64_t positionSize,
ArrayRef< SupportInfo > supports,
Value H0,
Value maxDeltaEnergy,
int64_t maxTreeDepth,
DictionaryAttr autodiffAttrs = {} )
inline

Definition at line 177 of file HMCUtils.h.
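This constructor takes the same trace-related arguments as the first HMCContext constructor, minus trajectoryLength, plus H0, maxDeltaEnergy, and maxTreeDepth. A minimal sketch of that delegation pattern follows, using hypothetical stand-ins (`Value` modeled as `std::optional<double>`, most arguments omitted). What the real constructor stores in the inherited trajectoryLength member is not documented here; leaving it null is an assumption of the sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical, heavily reduced stand-ins: `Value` models mlir::Value as an
// optional double, with std::nullopt playing the role of a null Value.
using Value = std::optional<double>;

struct HMCContext {
  Value stepSize;
  Value trajectoryLength;
  int64_t positionSize;

  HMCContext(Value stepSize, Value trajectoryLength, int64_t positionSize)
      : stepSize(stepSize), trajectoryLength(trajectoryLength),
        positionSize(positionSize) {}
};

struct NUTSContext : HMCContext {
  Value H0;
  Value maxDeltaEnergy;
  int64_t maxTreeDepth;

  // Assumption: NUTS grows its trajectory adaptively, so no fixed
  // trajectoryLength is passed and the base-class field is left null.
  NUTSContext(Value stepSize, int64_t positionSize, Value H0,
              Value maxDeltaEnergy, int64_t maxTreeDepth)
      : HMCContext(stepSize, /*trajectoryLength=*/std::nullopt, positionSize),
        H0(H0), maxDeltaEnergy(maxDeltaEnergy), maxTreeDepth(maxTreeDepth) {
    // Sanity check added for the sketch only.
    assert(maxTreeDepth >= 0 && "tree depth must be non-negative");
  }
};
```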

◆ NUTSContext() [2/2]

mlir::impulse::NUTSContext::NUTSContext ( FlatSymbolRefAttr logpdfFn,
ArrayRef< Value > fnInputs,
Value invMass,
Value massMatrixSqrt,
Value stepSize,
int64_t positionSize,
Value H0,
Value maxDeltaEnergy,
int64_t maxTreeDepth,
DictionaryAttr autodiffAttrs = {} )
inline

Definition at line 189 of file HMCUtils.h.
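This second constructor starts from a log-density function (logpdfFn) instead of a traced generative function, so it takes no trace, selection, address, or support arguments. The inherited hasCustomLogpdf() presumably distinguishes the two construction modes. The sketch below models that under the assumption that it simply checks whether logpdfFn is set; `SymbolRef`, `ContextSketch`, and its factory functions are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in: a symbol reference is modeled as a std::string,
// with the empty string playing the role of a null FlatSymbolRefAttr.
using SymbolRef = std::string;

struct ContextSketch {
  SymbolRef fn;       // generative function (trace-based constructor)
  SymbolRef logpdfFn; // custom log-density (logpdf-based constructor)

  static ContextSketch fromTracedFunction(SymbolRef f) {
    return {std::move(f), ""};
  }
  static ContextSketch fromLogpdf(SymbolRef logpdf) {
    return {"", std::move(logpdf)};
  }

  // Assumption: hasCustomLogpdf() reports whether the logpdf-based
  // constructor was used, i.e. whether logpdfFn is set.
  bool hasCustomLogpdf() const { return !logpdfFn.empty(); }
};
```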

Member Function Documentation

◆ withH0()

NUTSContext mlir::impulse::NUTSContext::withH0 ( Value newH0) const
inline

Definition at line 198 of file HMCUtils.h.

References H0.

Referenced by mlir::impulse::SampleNUTS().
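withH0 is a const member that returns a NUTSContext, which suggests a copy-and-replace update: the caller gets a new context that differs only in H0, while the original is left unchanged. The sketch below shows that pattern on a hypothetical reduced struct (`NUTSContextSketch`, with `Value` modeled as `double`); the actual body in HMCUtils.h may differ.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, reduced stand-in: `Value` models mlir::Value as a double.
using Value = double;

struct NUTSContextSketch {
  Value stepSize;
  Value H0;
  Value maxDeltaEnergy;
  int64_t maxTreeDepth;

  // Copy-and-replace: return a new context that differs only in H0.
  // The method is const, so *this is never modified.
  NUTSContextSketch withH0(Value newH0) const {
    NUTSContextSketch copy = *this;
    copy.H0 = newH0;
    return copy;
  }
};
```

Since SampleNUTS is the listed caller, a plausible use is refreshing H0 once per sampling iteration, for example after new momenta are drawn; this is an inference, not something the page states.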

Member Data Documentation

◆ H0

Value mlir::impulse::NUTSContext::H0

Definition at line 173 of file HMCUtils.h.

Referenced by mlir::impulse::buildBaseTree(), and withH0().
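H0 is read by buildBaseTree. Assuming it holds the Hamiltonian at the start of the trajectory, as the name suggests, it would be the potential energy (the negative log density) plus the kinetic energy of the momentum. The sketch below assumes a diagonal inverse mass matrix stored as a vector; whether the invMass value in the context is diagonal or dense is not stated on this page, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Kinetic energy 0.5 * p^T M^{-1} p for a diagonal inverse mass matrix.
double kineticEnergy(const std::vector<double> &p,
                     const std::vector<double> &invMassDiag) {
  assert(p.size() == invMassDiag.size() && "size mismatch");
  double k = 0.0;
  for (std::size_t i = 0; i < p.size(); ++i)
    k += invMassDiag[i] * p[i] * p[i];
  return 0.5 * k;
}

// Hamiltonian H(q, p) = U(q) + K(p), where U(q) = -log density at q.
double hamiltonian(double logDensity, const std::vector<double> &p,
                   const std::vector<double> &invMassDiag) {
  return -logDensity + kineticEnergy(p, invMassDiag);
}
```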

◆ maxDeltaEnergy

Value mlir::impulse::NUTSContext::maxDeltaEnergy

Definition at line 174 of file HMCUtils.h.

Referenced by mlir::impulse::buildBaseTree().
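maxDeltaEnergy is also read by buildBaseTree. In common NUTS implementations such as NumPyro, whose base-tree builder plays the same role, a step is marked divergent when its energy exceeds the initial energy by more than this threshold, and a NaN energy error also counts as divergent. The sketch below encodes that rule; treating it as exactly what buildBaseTree does is an assumption, and `isDivergent` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>

// Divergence test modeled on common NUTS implementations: flag the step when
// the energy error relative to the starting Hamiltonian H0 is too large.
bool isDivergent(double newEnergy, double H0, double maxDeltaEnergy) {
  assert(maxDeltaEnergy >= 0.0 && "threshold must be non-negative");
  double delta = newEnergy - H0;
  // A NaN energy error is treated as divergent as well.
  if (std::isnan(delta))
    return true;
  return delta > maxDeltaEnergy;
}
```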

◆ maxTreeDepth

int64_t mlir::impulse::NUTSContext::maxTreeDepth

Definition at line 175 of file HMCUtils.h.

Referenced by mlir::impulse::buildTree().
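maxTreeDepth is read by buildTree and caps the NUTS doubling procedure. Under the standard scheme each doubling multiplies the number of trajectory states by two, so a tree of depth d holds 2^d states reached with 2^d - 1 leapfrog steps; a depth limit of 10 therefore allows at most 1023 steps. The helper names below, and the `uTurnAtDepth` parameter standing in for the real termination check, are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Upper bound on leapfrog steps for a fully built tree of the given depth.
int64_t maxLeapfrogSteps(int64_t maxTreeDepth) {
  assert(maxTreeDepth >= 0 && maxTreeDepth < 63 && "depth out of range");
  return (int64_t{1} << maxTreeDepth) - 1;
}

// Driver loop sketch: keep doubling until the stand-in U-turn check fires
// (at depth `uTurnAtDepth`) or the depth limit is reached.
int64_t finalDepth(int64_t maxTreeDepth, int64_t uTurnAtDepth) {
  int64_t depth = 0;
  while (depth < maxTreeDepth) {
    ++depth;
    if (depth >= uTurnAtDepth)
      break;
  }
  return depth;
}
```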


The documentation for this struct was generated from the following file: