Enzyme Dialect
Reactant.MLIR.Dialects.enzyme.broadcast Method
broadcast
Broadcasts the operand by prepending extra dimensions whose sizes are given by the shape attribute. For a scalar operand, a ranked tensor is created.
NOTE: Currently only scalar and ranked tensor operands are supported.
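The following is a minimal pure-Python sketch of the described semantics, not Reactant's implementation. Tensors are modeled as nested lists, and the helper names broadcast and shape_of are hypothetical. For a NumPy array x, the same result is np.broadcast_to(x, tuple(shape) + x.shape).

```python
import copy


def broadcast(operand, shape):
    """Prepend dimensions of the given sizes to `operand` (a scalar or nested list)."""
    result = operand
    # Wrap from the innermost new dimension outward so that shape[0]
    # ends up as the outermost (front) dimension of the result.
    for size in reversed(shape):
        result = [copy.deepcopy(result) for _ in range(size)]
    return result


def shape_of(x):
    """Shape of a rectangular nested list; a scalar has shape ()."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)
```

For example, broadcast(5.0, (2, 3)) yields a 2 x 3 tensor filled with 5.0, and broadcasting a length-2 vector with shape (3,) yields a tensor of shape (3, 2).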
Reactant.MLIR.Dialects.enzyme.cholesky Method
cholesky
Computes the Cholesky decomposition of a symmetric positive definite matrix A. If lower=true, returns the lower-triangular L such that A = L @ L^T; if lower=false, returns the upper-triangular U such that A = U^T @ U.
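A minimal pure-Python sketch of the factorization using the row-by-row Cholesky-Banachiewicz scheme, assuming A is given as a list of lists. The function name and the lower keyword mirror the description above but are illustrative, not Reactant's API.

```python
import math


def cholesky(a, lower=True):
    """Return L with A = L @ L^T (lower=True) or U with A = U^T @ U (lower=False)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            # Dot product of the already-computed parts of rows i and j.
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                d = a[i][i] - s
                if d <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][j] = math.sqrt(d)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    if lower:
        return L
    # U = L^T satisfies A = U^T @ U.
    return [[L[j][i] for j in range(n)] for i in range(n)]
```

For A = [[4, 12, -16], [12, 37, -43], [-16, -43, 98]], the lower factor is [[2, 0, 0], [6, 1, 0], [-8, 5, 3]].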
Reactant.MLIR.Dialects.enzyme.concat Method
concat
Concatenates a list of input arguments into a generic value.
Reactant.MLIR.Dialects.enzyme.dot Method
dot
Computes a general dot product. Intended to be lowered to stablehlo.dot_general.
Reactant.MLIR.Dialects.enzyme.dump Method
dump
Debug operation that dumps a tensor value with a label.
Reactant.MLIR.Dialects.enzyme.dynamic_slice Method
dynamic_slice
Extracts a slice from a tensor at dynamic start indices.
Reactant.MLIR.Dialects.enzyme.dynamic_update_slice Method
dynamic_update_slice
Updates a slice of a tensor at dynamic start indices.
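A 2-D pure-Python sketch of dynamic_slice and dynamic_update_slice. It assumes StableHLO-style semantics, in which start indices are clamped so the slice stays in bounds; this page does not say whether the enzyme ops clamp. The update returns a new tensor instead of mutating its input, matching the value semantics of tensors in MLIR. All helper names are hypothetical.

```python
def _clamp_start(start, size, dim):
    # Assumed StableHLO-style clamping: keep [start, start + size) inside [0, dim).
    return min(max(start, 0), dim - size)


def dynamic_slice(x, starts, sizes):
    """Extract a sizes[0] x sizes[1] window of the nested list `x`."""
    r0 = _clamp_start(starts[0], sizes[0], len(x))
    c0 = _clamp_start(starts[1], sizes[1], len(x[0]))
    return [row[c0:c0 + sizes[1]] for row in x[r0:r0 + sizes[0]]]


def dynamic_update_slice(x, update, starts):
    """Return a copy of `x` with `update` written at the (clamped) start indices."""
    rows, cols = len(update), len(update[0])
    r0 = _clamp_start(starts[0], rows, len(x))
    c0 = _clamp_start(starts[1], cols, len(x[0]))
    out = [list(row) for row in x]  # copy; the input tensor is left unchanged
    for i in range(rows):
        for j in range(cols):
            out[r0 + i][c0 + j] = update[i][j]
    return out
```

Under this clamping assumption, requesting a 2 x 2 slice at (2, 2) of a 3 x 3 tensor returns the same window as starting at (1, 1).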
Reactant.MLIR.Dialects.enzyme.extract Method
extract
Extracts the value at the given index from a batched operand.
Reactant.MLIR.Dialects.enzyme.for_loop Method
for_loop
A counted loop operation that iterates from lowerBound to upperBound by step, carrying iter_args through each iteration. The iteration variable and iter_args are passed to the body region.
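A sketch of the loop semantics in plain Python, assuming upperBound is exclusive and step is positive (the description above states neither). The body receives the iteration variable followed by the current iter_args and returns the next iter_args; for_loop here is a hypothetical helper, not Reactant's API.

```python
def for_loop(lower, upper, step, init_args, body):
    """Run `body` for i = lower, lower + step, ... while i < upper, threading iter_args."""
    args = tuple(init_args)
    i = lower
    while i < upper:
        # The body sees the iteration variable and all loop-carried values.
        args = tuple(body(i, *args))
        i += step
    return args
```

For instance, for_loop(0, 5, 1, (0,), lambda i, acc: (acc + i,)) returns (10,).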
Reactant.MLIR.Dialects.enzyme.generate Method
generate
Generates a trace from a generative function with some addresses constrained. The constraint tensor contains the flattened constrained values in the order specified by constrained_addresses.
Returns: (trace, weight, rng, retvals...)
Reactant.MLIR.Dialects.enzyme.if_ Method
if_
A conditional operation that executes exactly one of two branches based on a boolean predicate.
Reactant.MLIR.Dialects.enzyme.logistic Method
logistic
Computes the logistic (sigmoid) function: 1 / (1 + exp(-x)).
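In Python, evaluating 1 / (1 + exp(-x)) literally with math.exp raises OverflowError for large negative x. The sketch below shows one common numerically stable formulation, applied to a single value; how enzyme.logistic is actually lowered is not stated here.

```python
import math


def logistic(x):
    """Numerically stable 1 / (1 + exp(-x))."""
    # Branch on the sign of x so exp() only ever sees a non-positive argument.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)
```

Both branches are algebraically equal to the formula above; the second multiplies numerator and denominator by exp(x).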
Reactant.MLIR.Dialects.enzyme.mcmc Function
mcmc
Runs MCMC inference on selected addresses.
Two modes of operation:
Trace-based mode: fn and original_trace are provided. The model function with enzyme.sample ops defines the density.
Custom logpdf mode: logpdf_fn and initial_position are provided. The logpdf function maps a position directly to a scalar log-density.
The selection attribute determines which addresses to sample via HMC/NUTS. All sample addresses are included in the trace tensor for consistency.
Returns: (trace, diagnostics, rng)
trace: tensor<num_samples x position_size x f64>
diagnostics: tensor<num_samples x i1> - placeholder for future expansion
rng: updated RNG state
Reactant.MLIR.Dialects.enzyme.mh Method
mh
Performs one Metropolis-Hastings (MH) step: regenerates the selected addresses, then accepts or rejects the proposal based on the weight ratio.
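The accept/reject decision can be sketched as follows, assuming the weights are log-space values (as the regenerate weight below is), so the weight ratio becomes a difference of log weights. mh_accept and mh_step are hypothetical helpers, and mh_step assumes any proposal correction is already folded into the weights.

```python
import math
import random


def mh_accept(log_ratio, rng):
    """Accept with probability min(1, exp(log_ratio))."""
    if log_ratio >= 0.0:
        return True
    # 1.0 - rng.random() lies in (0, 1], so the log is always defined.
    u = 1.0 - rng.random()
    return math.log(u) < log_ratio


def mh_step(old_trace, old_log_weight, new_trace, new_log_weight, rng):
    """Keep the proposed trace if accepted, otherwise keep the old one."""
    if mh_accept(new_log_weight - old_log_weight, rng):
        return new_trace, new_log_weight, True
    return old_trace, old_log_weight, False
```

Comparing log(u) against the log ratio avoids overflow when exponentiating large weight differences.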
Reactant.MLIR.Dialects.enzyme.popcount Method
popcount
Returns the number of 1-bits elementwise.
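A pure-Python sketch over a list of integers, assuming fixed-width two's-complement elements. bit_width is a hypothetical parameter standing in for the element type's width, so negative values count their sign-extended bits.

```python
def popcount(values, bit_width=32):
    """Elementwise count of 1-bits, treating each value as a bit_width-bit integer."""
    mask = (1 << bit_width) - 1
    # Masking maps negative Python ints to their two's-complement bit pattern.
    return [bin(v & mask).count("1") for v in values]
```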
Reactant.MLIR.Dialects.enzyme.random Method
random
Generates random numbers from the distribution selected by rng_distribution and produces a result tensor.
If rng_distribution = UNIFORM, then the random numbers are generated following the uniform distribution over the interval [a, b). If a >= b, the behavior is undefined.
If rng_distribution = NORMAL, then the random numbers are generated following the normal distribution with mean = a and standard deviation = b. If b < 0, the behavior is undefined.
If rng_distribution = MULTINORMAL, then the random numbers are generated following the multivariate normal distribution with mean = a (scalar or vector) and covariance matrix = b. The parameter b should be a positive definite matrix.
By convention, the 0th operand in inputs is the initial RNG state and the 0th operand in results is the updated RNG state.
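The sketch below models the three distributions and the RNG-state convention in plain Python. Every name is hypothetical: state is an integer seed standing in for the RNG state operand, the updated state is drawn from the same generator and returned first, and for MULTINORMAL the mean a is taken to be a vector. Multivariate normal samples are produced as x = a + L z, where b = L L^T is the Cholesky factorization and z is standard normal; this is a standard technique, not necessarily the one Reactant uses.

```python
import math
import random


def _cholesky_lower(cov):
    """Lower-triangular L with cov = L @ L^T (cov must be positive definite)."""
    n = len(cov)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][i] = math.sqrt(cov[i][i] - s)
            else:
                L[i][j] = (cov[i][j] - s) / L[j][j]
    return L


def random_op(state, a, b, n, dist):
    """Hypothetical model of enzyme.random returning (new_state, samples)."""
    rng = random.Random(state)
    if dist == "UNIFORM":
        # rng.random() lies in [0, 1), so samples fall in [a, b) up to rounding.
        samples = [a + (b - a) * rng.random() for _ in range(n)]
    elif dist == "NORMAL":
        # Mean a, standard deviation b.
        samples = [rng.gauss(a, b) for _ in range(n)]
    elif dist == "MULTINORMAL":
        # Mean vector a, covariance b: x = a + L z with b = L L^T, z ~ N(0, I).
        L = _cholesky_lower(b)
        d = len(a)
        samples = []
        for _ in range(n):
            z = [rng.gauss(0.0, 1.0) for _ in range(d)]
            samples.append([a[i] + sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(d)])
    else:
        raise ValueError(f"unknown rng_distribution: {dist}")
    # Following the convention above, the updated RNG state is returned first.
    new_state = rng.getrandbits(64)
    return new_state, samples
```

Threading the returned state into the next call, rather than reusing the old one, is what keeps successive draws distinct.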
Reactant.MLIR.Dialects.enzyme.randomSplit Method
randomSplit
Splits an RNG state into multiple independent RNG states. Reference: https://github.com/jax-ml/jax/blob/c25e095fcec9678a4ce5f723afce0c6a3c48a5e7/jax/_src/random.py#L281-L294
Reactant.MLIR.Dialects.enzyme.regenerate Method
regenerate
Regenerates selected addresses while keeping others fixed. Used internally by MH.
Takes an explicit old_trace and returns a new trace together with its weight.
Returns: (new_trace, weight, retvals...)
new_trace: tensor<1 x position_size x f64> - flattened samples
weight: tensor<f64> - accumulated log probability
retvals: original function return values
Reactant.MLIR.Dialects.enzyme.sample Method
sample
Samples from a distribution. By convention, the 0th operand in both inputs and outputs is the RNG state (seed).
Reactant.MLIR.Dialects.enzyme.select Method
select
Extended select operation that supports:
tensor<i1> conditions with differently-sized operands
standard cases supported by arith.select
Reactant.MLIR.Dialects.enzyme.simulate Method
simulate
Simulates a generative function, building a trace tensor containing all sampled values and computing the accumulated log probability weight.
The selection attribute specifies all sample addresses in order, determining the trace tensor layout.
Returns: (trace, weight, rng, retvals...)
trace: tensor<1 x position_size x f64> - flattened samples
weight: tensor<f64> - accumulated log probability
rng: updated RNG state
retvals: original function return values
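A toy Python model of the simulate contract under deliberately simplified assumptions: each address holds one scalar drawn from a normal distribution, the model receives a hypothetical sample(address, mu, sigma) callback, and the weight is the sum of the log densities of all sampled values. It illustrates how the selection order fixes the 1 x position_size trace layout; none of these names are Reactant's API.

```python
import math
import random


def normal_logpdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def simulate(model, selection, rng):
    """Toy model of simulate: returns (trace, weight, rng, retval)."""
    samples = {}
    weight = 0.0

    def sample(address, mu, sigma):
        nonlocal weight
        x = rng.gauss(mu, sigma)
        samples[address] = x
        weight += normal_logpdf(x, mu, sigma)  # accumulate log probability
        return x

    retval = model(sample)
    # The selection lists every sample address in order; that order fixes
    # the column layout of the 1 x position_size trace.
    trace = [[samples[address] for address in selection]]
    return trace, weight, rng, retval


def model(sample):
    # Hypothetical two-address model: mu ~ N(0, 1), y ~ N(mu, 0.5).
    mu = sample("mu", 0.0, 1.0)
    return sample("y", mu, 0.5)
```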
Reactant.MLIR.Dialects.enzyme.triangular_solve Method
triangular_solve
Solves a system of linear equations with a triangular coefficient matrix. If left_side=true, solves op(A) @ X = B for X. If left_side=false, solves X @ op(A) = B for X. op(A) is determined by transpose_a: NO_TRANSPOSE, TRANSPOSE, or ADJOINT.
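A pure-Python sketch for the case where A is lower-triangular and transpose_a = NO_TRANSPOSE; how the triangle is selected is not described above, so that is an assumption. The left-side case solves each column of B by forward substitution. The right-side case uses the fact that X @ A = B holds if and only if A^T @ X^T = B^T, where A^T is upper-triangular, and solves each row of B by back substitution. All names are hypothetical.

```python
def forward_substitution(L, b):
    """Solve L @ x = b for lower-triangular L (one right-hand side)."""
    n = len(L)
    x = [0.0] * n
    for i in range(n):
        s = sum(L[i][k] * x[k] for k in range(i))
        x[i] = (b[i] - s) / L[i][i]
    return x


def back_substitution(U, b):
    """Solve U @ x = b for upper-triangular U (one right-hand side)."""
    n = len(U)
    x = [0.0] * n
    for i in reversed(range(n)):
        s = sum(U[i][k] * x[k] for k in range(i + 1, n))
        x[i] = (b[i] - s) / U[i][i]
    return x


def transpose(M):
    return [list(row) for row in zip(*M)]


def triangular_solve_lower(A, B, left_side=True):
    """Solve A @ X = B (left_side=True) or X @ A = B (left_side=False), op(A) = A."""
    if left_side:
        # Column j of X solves A @ x_j = (column j of B).
        return transpose([forward_substitution(A, col) for col in transpose(B)])
    # Row j of X solves A^T @ x = (row j of B); A^T is upper-triangular.
    At = transpose(A)
    return [back_substitution(At, row) for row in B]
```

With A = [[2, 0], [1, 4]] and X = [[1, 2], [3, 4]], solving A @ X = [[2, 4], [13, 18]] or X @ A = [[4, 8], [10, 16]] both recover X.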
Reactant.MLIR.Dialects.enzyme.untracedCall Method
untracedCall
Calls a probabilistic function without tracing. By convention, the 0th operand in both inputs and outputs is the RNG state (seed).
Reactant.MLIR.Dialects.enzyme.while_loop Method
while_loop
A while loop operation that continues iterating as long as the condition evaluates to true. Intended to be lowered to stablehlo.while.
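A sketch of the described semantics in plain Python. cond and body are hypothetical callables over the loop-carried values, mirroring the condition and body regions of stablehlo.while; while_loop itself is illustrative, not Reactant's API.

```python
def while_loop(init_args, cond, body):
    """Apply `body` to the loop-carried values while `cond` holds."""
    args = tuple(init_args)
    while cond(*args):
        args = tuple(body(*args))
    return args
```

For instance, while_loop((1, 0), lambda x, n: x < 100, lambda x, n: (2 * x, n + 1)) returns (128, 7).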