
Enzyme Dialect

Reactant.MLIR.Dialects.enzyme.addTo Method

addTo

TODO

Reactant.MLIR.Dialects.enzyme.broadcast Method

broadcast

Broadcasts the operand by prepending extra dimensions whose sizes are given by the shape attribute. For a scalar operand, a ranked tensor is created.

NOTE: Only works for scalar and ranked tensor operands for now.
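
A minimal NumPy sketch of the shape behavior described above, assuming the new dimensions are simply prepended and the operand's data is repeated along them. `broadcast_prepend` is a hypothetical helper, not part of Reactant.

```python
import numpy as np

def broadcast_prepend(operand, shape):
    # Hypothetical model of enzyme.broadcast: new dims from `shape` go in front.
    operand = np.asarray(operand)  # a Python scalar becomes a rank-0 array
    return np.broadcast_to(operand, tuple(shape) + operand.shape)

x = np.array([1.0, 2.0, 3.0])
y = broadcast_prepend(x, [2])    # shape (2, 3): two copies of x
s = broadcast_prepend(5.0, [4])  # scalar operand -> ranked tensor of shape (4,)
```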

Reactant.MLIR.Dialects.enzyme.cholesky Method

cholesky

Computes the Cholesky decomposition of a symmetric positive definite matrix A. If lower=true, returns the lower-triangular factor L such that A = L @ L^T; if lower=false, returns the upper-triangular factor U such that A = U^T @ U.
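
For a real matrix, the two conventions are related by a transpose. The NumPy check below illustrates the identities; it uses `numpy.linalg.cholesky` (which returns the lower factor by default) and is not the op's implementation.

```python
import numpy as np

A = np.array([[4.0, 2.0],
              [2.0, 3.0]])     # symmetric positive definite
L = np.linalg.cholesky(A)      # lower factor, as in the lower=true case
U = L.T                        # upper factor, as in the lower=false case (real A)

assert np.allclose(L @ L.T, A)
assert np.allclose(U.T @ U, A)
```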

Reactant.MLIR.Dialects.enzyme.concat Method

concat

Concatenates a list of input arguments into a generic value.

Reactant.MLIR.Dialects.enzyme.dot Method

dot

Computes a general dot product operation. Intended to be lowered to stablehlo.dot_general.
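
The sketch below assumes this op mirrors stablehlo.dot_general. In StableHLO, batching and contracting dimensions are listed separately for each operand. The result lists the batch dimensions first, then the remaining lhs dimensions, then the remaining rhs dimensions. The attribute names in the comments are StableHLO's and may not match this dialect's. `numpy.einsum` reproduces a batched matrix multiply:

```python
import numpy as np

lhs = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
rhs = np.arange(2 * 4 * 5, dtype=float).reshape(2, 4, 5)

# StableHLO-style dimension numbers for this example:
#   lhs_batching_dimensions = [0],    rhs_batching_dimensions = [0]
#   lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]
out = np.einsum("bij,bjk->bik", lhs, rhs)  # result dims: batch, lhs free, rhs free
```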

Reactant.MLIR.Dialects.enzyme.dump Method

dump

Debug operation that dumps a tensor value with a label.

Reactant.MLIR.Dialects.enzyme.dynamic_slice Method

dynamic_slice

Extract a slice from a tensor at dynamic start indices.
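
A reference-semantics sketch, assuming this op behaves like stablehlo.dynamic_slice. In that op, start indices are clamped so the requested slice always fits inside the operand. `dynamic_slice` here is a hypothetical NumPy helper.

```python
import numpy as np

def dynamic_slice(operand, start_indices, slice_sizes):
    # Clamp each start to [0, dim - size] (assumed, following StableHLO).
    starts = [min(max(s, 0), d - n)
              for s, d, n in zip(start_indices, operand.shape, slice_sizes)]
    index = tuple(slice(s, s + n) for s, n in zip(starts, slice_sizes))
    return operand[index]

x = np.arange(12).reshape(3, 4)
a = dynamic_slice(x, [1, 1], [2, 2])  # [[5, 6], [9, 10]]
b = dynamic_slice(x, [2, 3], [2, 2])  # starts clamped to [1, 2]: [[6, 7], [10, 11]]
```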

Reactant.MLIR.Dialects.enzyme.dynamic_update_slice Method

dynamic_update_slice

Update a slice in a tensor at dynamic start indices.
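
Under the same assumption (stablehlo.dynamic_update_slice semantics with clamped starts), the update returns a new tensor and leaves the input unchanged. `dynamic_update_slice` is a hypothetical NumPy helper:

```python
import numpy as np

def dynamic_update_slice(operand, update, start_indices):
    # Clamp starts so the update fits (assumed, following StableHLO).
    starts = [min(max(s, 0), d - n)
              for s, d, n in zip(start_indices, operand.shape, update.shape)]
    result = operand.copy()  # value semantics: the input tensor is not mutated
    index = tuple(slice(s, s + n) for s, n in zip(starts, update.shape))
    result[index] = update
    return result

x = np.zeros((3, 4), dtype=int)
y = dynamic_update_slice(x, np.ones((2, 2), dtype=int), [1, 3])  # start clamped to [1, 2]
```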

Reactant.MLIR.Dialects.enzyme.extract Method

extract

Extracts the value at the given index from a batched operand.

Reactant.MLIR.Dialects.enzyme.for_loop Method

for_loop

A counted loop operation that iterates from lowerBound to upperBound by step, carrying iter_args through each iteration. The iteration variable and iter_args are passed to the body region.
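
A host-side Python model of the loop semantics. It assumes an exclusive upper bound and a positive step, as in MLIR's scf.for. The body receives the iteration variable plus the current iter_args and returns the next iter_args. `for_loop` is a hypothetical helper.

```python
def for_loop(lower_bound, upper_bound, step, iter_args, body):
    # Assumed: upper bound is exclusive and step > 0.
    i = lower_bound
    while i < upper_bound:
        iter_args = body(i, *iter_args)  # body returns the next carried values
        i += step
    return iter_args

# Carry (running sum, iteration count) over i = 0, 2, 4, 6, 8.
total, count = for_loop(0, 10, 2, (0, 0), lambda i, acc, n: (acc + i, n + 1))
```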

Reactant.MLIR.Dialects.enzyme.generate Method

generate

Runs a generative function with some addresses constrained to given values. The constraint tensor contains the flattened constrained values, in the order given by constrained_addresses.

Returns: (trace, weight, rng, retvals...)

Reactant.MLIR.Dialects.enzyme.if_ Method

if_

A conditional operation that executes exactly one of two branches based on a boolean predicate.

Reactant.MLIR.Dialects.enzyme.log_add_exp Method

log_add_exp

Computes log(exp(x) + exp(y)).
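
Evaluated naively, exp(x) overflows for large inputs. A common stable formulation factors out the larger argument: log(exp(x) + exp(y)) = max(x, y) + log1p(exp(-|x - y|)). The sketch below shows that rewrite; whether the op's lowering uses exactly this form is not stated here.

```python
import math

def log_add_exp(x, y):
    if x == y:
        return x + math.log(2.0)  # also covers x == y == +/-inf (avoids inf - inf)
    m = max(x, y)
    return m + math.log1p(math.exp(-abs(x - y)))  # exp() argument is always <= 0

big = log_add_exp(1000.0, 1000.0)  # 1000 + log(2); naive exp(1000.0) would overflow
```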

Reactant.MLIR.Dialects.enzyme.logistic Method

logistic

Computes the logistic (sigmoid) function: 1 / (1 + exp(-x)).
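
A scalar sketch of a numerically stable evaluation. It uses the identity 1 / (1 + exp(-x)) = exp(x) / (1 + exp(x)) for negative x, so exp() never receives a large positive argument. This illustrates the formula only and does not describe how the op is lowered.

```python
import math

def logistic(x):
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)       # x < 0, so z is in (0, 1) and cannot overflow
    return z / (1.0 + z)  # algebraically equal to 1 / (1 + exp(-x))
```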

Reactant.MLIR.Dialects.enzyme.mcmc Function

mcmc

Runs MCMC inference on selected addresses.

Two modes of operation:

  1. Trace-based mode: fn and original_trace are provided. The model function with enzyme.sample ops defines the density.

  2. Custom logpdf mode: logpdf_fn and initial_position are provided. The logpdf function maps position → scalar log-density directly.

The selection attribute determines which addresses to sample via HMC/NUTS. All sample addresses are included in the trace tensor for consistency.

Returns: (trace, diagnostics, rng)

  • trace: tensor<num_samples x position_size x f64>

  • diagnostics: tensor<num_samples x i1> - placeholder for future expansion

  • rng: updated RNG state

Reactant.MLIR.Dialects.enzyme.mh Method

mh

Performs one Metropolis-Hastings (MH) step: regenerates the selected addresses, then accepts or rejects the proposal based on the weight ratio.
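
A host-side sketch of the accept/reject logic. It assumes the weight returned by regenerate is the log of the acceptance ratio, so the proposal is accepted with probability min(1, exp(weight)). `mh_step` and the `regenerate` callable passed to it are hypothetical stand-ins, not Reactant APIs.

```python
import math
import random

def mh_step(trace, regenerate, rng):
    # `regenerate` stands in for enzyme.regenerate; its weight is assumed to
    # be the log acceptance ratio for the proposed trace.
    new_trace, log_weight = regenerate(trace, rng)
    u = 1.0 - rng.random()        # uniform on (0, 1], so log(u) is finite
    if math.log(u) < log_weight:  # true with probability min(1, exp(log_weight))
        return new_trace, True
    return trace, False

rng = random.Random(0)
state, accepted = mh_step({"x": 0.0}, lambda t, r: ({"x": r.gauss(0.0, 1.0)}, -0.3), rng)
```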

Reactant.MLIR.Dialects.enzyme.popcount Method

popcount

Returns the number of 1-bits elementwise.
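
A NumPy illustration of elementwise popcount for fixed-width integer arrays (ndim >= 1). Each element is reinterpreted as raw bytes and the set bits are counted, so signed values are counted in their two's-complement form. `popcount` is a hypothetical helper.

```python
import numpy as np

def popcount(x):
    # Reinterpret each element as itemsize bytes, unpack to bits, sum per element.
    x = np.ascontiguousarray(x)
    as_bytes = x.view(np.uint8).reshape(x.shape + (x.itemsize,))
    return np.unpackbits(as_bytes, axis=-1).sum(axis=-1)

counts = popcount(np.array([0, 1, 3, 255], dtype=np.int32))  # [0, 1, 2, 8]
```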

Reactant.MLIR.Dialects.enzyme.random Method

random

Generates random numbers using the rng_distribution algorithm and produces a result tensor.

If rng_distribution = UNIFORM, then the random numbers are generated following the uniform distribution over the interval [a, b). If a >= b, the behavior is undefined.

If rng_distribution = NORMAL, then the random numbers are generated following the normal distribution with mean = a and standard deviation = b. If b < 0, the behavior is undefined.

If rng_distribution = MULTINORMAL, then the random numbers are generated following the multivariate normal distribution with mean = a (scalar or vector) and covariance matrix = b. The parameter b should be a positive definite matrix.

By convention, the 0th operand in inputs is the initial RNG state and the 0th operand in results is the updated RNG state.
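
A NumPy sketch of the three distributions described above. It uses a stateful `numpy.random.Generator` in place of the op's explicit RNG-state operand and result. The multivariate case draws standard normals z and returns a + L z, where L is the Cholesky factor of b, so the covariance is L L^T = b. `rng_draw` is a hypothetical helper.

```python
import numpy as np

def rng_draw(rng, distribution, a, b, shape):
    if distribution == "UNIFORM":      # uniform on [a, b)
        return rng.uniform(a, b, size=shape)
    if distribution == "NORMAL":       # mean a, standard deviation b
        return rng.normal(a, b, size=shape)
    if distribution == "MULTINORMAL":  # mean a, covariance b (positive definite)
        L = np.linalg.cholesky(b)
        z = rng.standard_normal(size=shape)  # last axis must equal len(a)
        return a + z @ L.T                   # each row is a + L @ z_row
    raise ValueError(f"unknown distribution: {distribution}")

rng = np.random.default_rng(0)
mean = np.array([1.0, -1.0])
cov = np.array([[2.0, 0.5], [0.5, 1.0]])
samples = rng_draw(rng, "MULTINORMAL", mean, cov, (100000, 2))
```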

Reactant.MLIR.Dialects.enzyme.randomSplit Method

randomSplit

Splits an RNG state into multiple independent RNG states. Reference: https://github.com/jax-ml/jax/blob/c25e095fcec9678a4ce5f723afce0c6a3c48a5e7/jax/_src/random.py#L281-L294
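
To illustrate the idea of splitting, the sketch below uses NumPy's `SeedSequence.spawn`. This is a different splitting mechanism from the JAX implementation referenced above. Each child seed gives an independent stream, and the results are reproducible from the parent seed.

```python
import numpy as np

parent = np.random.SeedSequence(42)
children = parent.spawn(3)  # three independent child seed sequences
draws = [np.random.default_rng(s).random() for s in children]
```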

Reactant.MLIR.Dialects.enzyme.regenerate Method

regenerate

Regenerates selected addresses while keeping others fixed. Used internally by MH.

Takes explicit old_trace and returns new trace with weight.

Returns: (new_trace, weight, retvals...)

  • new_trace: tensor<1 x position_size x f64> - flattened samples

  • weight: tensor<f64> - accumulated log probability

  • retvals: original function return values

Reactant.MLIR.Dialects.enzyme.sample Method

sample

Samples from a distribution. By convention, the 0th input is the initial RNG state (seed) and the 0th output is the updated RNG state.

Reactant.MLIR.Dialects.enzyme.select Method

select

Extended select operation that supports:

  • tensor<i1> conditions with differently-sized operands

  • standard cases supported by arith.select

Reactant.MLIR.Dialects.enzyme.simulate Method

simulate

Simulates a generative function, building a trace tensor containing all sampled values and computing the accumulated log probability weight.

The selection attribute specifies all sample addresses in order, determining the trace tensor layout.

Returns: (trace, weight, rng, retvals...)

  • trace: tensor<1 x position_size x f64> - flattened samples

  • weight: tensor<f64> - accumulated log probability

  • rng: updated RNG state

  • retvals: original function return values
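
A toy host-side model of the trace layout described above. Each sample is appended in the order the samples occur, which plays the role of the selection order here. The log-density of each sample is added to the weight, and the results come back in the order (trace, weight, rng, retvals). Everything here is hypothetical: `simulate`, `model`, and the normal-only `sample` callback are illustrations, not Reactant APIs.

```python
import math
import random

def normal_logpdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))

def simulate(model, rng):
    trace, weight = [], 0.0
    def sample(mu, sigma):
        nonlocal weight
        x = rng.gauss(mu, sigma)
        trace.append(x)                         # flattened, in sampling order
        weight += normal_logpdf(x, mu, sigma)   # accumulated log probability
        return x
    retval = model(sample)
    return [trace], weight, rng, retval         # [trace] mimics the 1 x position_size shape

def model(sample):
    mu = sample(0.0, 1.0)
    return sample(mu, 0.5)

trace, weight, rng, retval = simulate(model, random.Random(0))
```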

Reactant.MLIR.Dialects.enzyme.slice Method

slice

Extract a static slice from a tensor.

Reactant.MLIR.Dialects.enzyme.triangular_solve Method

triangular_solve

Solves a system of linear equations with a triangular coefficient matrix. If left_side=true, solves op(A) @ X = B for X. If left_side=false, solves X @ op(A) = B for X. op(A) is determined by transpose_a: NO_TRANSPOSE, TRANSPOSE, or ADJOINT.
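
A reference-semantics sketch of the left_side and transpose_a options. For simplicity it uses NumPy's general solver rather than triangular substitution. The right-side case relies on X @ op(A) = B being equivalent to op(A)^T @ X^T = B^T. `triangular_solve` here is a hypothetical helper.

```python
import numpy as np

def triangular_solve(a, b, left_side, transpose_a="NO_TRANSPOSE"):
    op_a = {"NO_TRANSPOSE": a, "TRANSPOSE": a.T, "ADJOINT": a.conj().T}[transpose_a]
    if left_side:
        return np.linalg.solve(op_a, b)       # solves op(A) @ X = B
    return np.linalg.solve(op_a.T, b.T).T     # solves X @ op(A) = B

A = np.array([[2.0, 0.0],
              [1.0, 3.0]])                    # lower triangular
B = np.array([[4.0, 1.0],
              [5.0, 2.0]])
X = triangular_solve(A, B, left_side=True)
Y = triangular_solve(A, B, left_side=False, transpose_a="TRANSPOSE")
```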

Reactant.MLIR.Dialects.enzyme.untracedCall Method

untracedCall

Calls a probabilistic function without tracing. By convention, the 0th input is the initial RNG state (seed) and the 0th output is the updated RNG state.

Reactant.MLIR.Dialects.enzyme.while_loop Method

while_loop

A while loop operation that continues iterating as long as the condition evaluates to true. Intended to be lowered to stablehlo.while.
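
A host-side Python model of the loop, assuming separate condition and body functions that both take the full loop-carried state. This mirrors the cond/body region structure of stablehlo.while. `while_loop` is a hypothetical helper, shown here computing a GCD with Euclid's algorithm.

```python
def while_loop(cond, body, init):
    state = init
    while cond(*state):      # condition region: state -> bool
        state = body(*state) # body region: state -> next state
    return state

# Euclid's algorithm: carry (a, b) until b == 0; a then holds gcd(48, 18).
g, r = while_loop(lambda a, b: b != 0, lambda a, b: (b, a % b), (48, 18))
```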
