Probabilistic Programming
Reactant.jl compiles ordinary Julia code through MLIR, running it on CPU, GPU, or TPU without rewriting. You write a normal Julia function, wrap a call to it in @compile, and Reactant traces the function, lowers it through MLIR, and hands you back a callable compiled program. Array inputs are staged through ConcreteRArray (created with Reactant.to_rarray) so they live on the target device.
Reactant.ProbProg is the Julia front-end for the impulse dialect, implemented across Enzyme (dialect definition and inference materialization passes) and Enzyme-JAX (backend-specific lowering). The impulse dialect provides high-level MLIR ops for describing probabilistic modeling and inference, materializes inference computation through compiler passes, and applies general-purpose and probabilistic-programming-specific optimizations during lowering.
optimize = :probprog opt-in required
For now, @compile needs an explicit optimize = :probprog argument on probabilistic programs to enable the impulse-specific MLIR passes (you'll see this in every @compile call below). Merging those passes into the default @compile pipeline is work in progress; once it lands, the explicit opt-in will no longer be required.
Next, we walk through two operating modes of Reactant.ProbProg: a trace-based mode built around a generative function, and a custom log-density mode that runs inference directly on a user-supplied log-density function.
Trace-based mode
We describe a Bayesian linear regression problem:
Both regression coefficients are given Gaussian priors, tighter on slope (standard deviation 2) and looser on intercept (standard deviation 10). Each observation y_i is then drawn from a Gaussian centered on the fitted value slope · x_i + intercept with fixed noise (standard deviation 1).
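The same model in symbols, writing β for the slope and α for the intercept. Here $\mathcal{N}(\mu, \sigma^2)$ has mean μ and variance σ², while ProbProg.Normal(μ, σ, shape) is parameterized by the standard deviation σ. The last line gives the log posterior up to an additive constant, which is what a hand-written log density needs to compute:

```latex
\begin{aligned}
\beta &\sim \mathcal{N}(0,\, 2^2), \qquad \alpha \sim \mathcal{N}(0,\, 10^2), \\
y_i \mid \beta, \alpha &\sim \mathcal{N}(\beta x_i + \alpha,\, 1^2), \qquad i = 1, \dots, n, \\
\log p(\beta, \alpha \mid y) &= -\tfrac{1}{2} \sum_{i=1}^{n} \left(y_i - \beta x_i - \alpha\right)^2
  - \frac{\beta^2}{2 \cdot 4} - \frac{\alpha^2}{2 \cdot 100} + \text{const}.
\end{aligned}
```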
The data
Synthetic data drawn from slope = -2, intercept = 10. The xs / ys pair is what we'll condition on.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
Describing the Model
We describe this model in Reactant.ProbProg as follows:
using Reactant, Statistics  # Statistics provides `mean`, used below
using Reactant: ProbProg, ReactantRNG
function model(rng, xs)
_, slope = ProbProg.sample(
rng, ProbProg.Normal(0.0, 2.0, (1,)); symbol=:slope,
)
_, intercept = ProbProg.sample(
rng, ProbProg.Normal(0.0, 10.0, (1,)); symbol=:intercept,
)
_, ys = ProbProg.sample(
rng,
ProbProg.Normal(slope .* xs .+ intercept, 1.0, (length(xs),));
symbol=:ys,
)
return ys
end
model (generic function with 1 method)
Each random choice is introduced by a ProbProg.sample(rng, dist; symbol=...) call, which takes a random number generator (RNG) and a distribution. The symbol keyword names the sample site. Sites are referred to by these names when conditioning on observations and when selecting the parameters to infer.
As a calling convention, ProbProg.sample returns (rng, value). The first element, discarded with _ above, is the updated RNG. In the current implementation rng is a traced ReactantRNG whose state corresponds to a tensor<2xui64> RNG state in the generated MLIR. We don't thread it through manually: Reactant tracing handles the input/output threading at the IR level, and ReactantRNG's internal state is updated through Julia mutability (see here for details).
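ProbProg doesn't need manual threading, but the (rng, value) convention is easy to see in plain Julia. The following is a minimal sketch built around a hypothetical sample_normal helper on top of Random.Xoshiro. It is not part of ProbProg. The helper never mutates the state it receives. Instead it returns an advanced copy alongside the value, and the caller passes that state to the next call:

```julia
using Random

# Hypothetical functional sampler mirroring the (rng, value) convention:
# the input state is never mutated; an advanced copy is returned instead.
function sample_normal(rng::Xoshiro, μ, σ, shape::Tuple)
    next_rng = copy(rng)                          # leave the caller's state untouched
    value = μ .+ σ .* randn(next_rng, shape...)   # draw from Normal(μ, σ) with sd σ
    return next_rng, value
end

rng0 = Xoshiro(2024)
rng1, slope = sample_normal(rng0, 0.0, 2.0, (1,))       # thread rng1 onward
rng2, intercept = sample_normal(rng1, 0.0, 10.0, (1,))
```

Because rng0 is never modified, calling sample_normal on it again reproduces the same draw. That determinism is what makes explicit state threading safe to trace.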
Describing Inference
We condition on the observed ys with a Constraint object:
obs = ProbProg.Constraint(:ys => ys)
Reactant.ProbProg.Constraint with 1 entry:
  Address([:ys]) => [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, …
The current implementation requires a bit of boilerplate before calling the @compile'd function below. You flatten the Constraint into a tensor representation and extract its address set, then pass both in (see Traces and constrained inference for details):
obs_tensor = ProbProg.flatten_constraint(obs)
1×10 ConcretePJRTArray{Float64,2}:
 8.23 5.87 3.99 2.59 0.23 -0.66 -3.53 -6.91 -7.24 -9.9
constrained_addresses = ProbProg.extract_addresses(obs)
Set{Reactant.ProbProg.Address} with 1 element:
  Reactant.ProbProg.Address([:ys])
We then specify which parameters to infer:
selection = ProbProg.select(
ProbProg.Address(:slope),
ProbProg.Address(:intercept),
)
OrderedCollections.OrderedSet{Reactant.ProbProg.Address} with 2 elements:
  Reactant.ProbProg.Address([:intercept])
  Reactant.ProbProg.Address([:slope])
We express inference in a single function that conditions the model on the constraint with generate and then runs NUTS over the selected sites with mcmc.
function infer(rng, xs, obs_tensor, step_size, inverse_mass_matrix)
trace, = ProbProg.generate(
rng, obs_tensor, model, xs; constrained_addresses,
)
trace, = ProbProg.mcmc(
rng, trace, model, xs;
selection, algorithm=:NUTS,
step_size, inverse_mass_matrix,
num_warmup=200, num_samples=500,
)
return trace
end
infer (generic function with 1 method)
The returned trace holds the sampling result as a 2D tensor. Each row is the concatenation of all selected sites' flattened values for one post-warmup sample. (A trace produced for this problem is shown below.)
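Here is a minimal sketch of that row layout in plain Julia. It uses two hypothetical helpers that are not part of ProbProg. flatten_row concatenates each site's flattened value into one row in a fixed site order. split_sites recovers per-site column blocks from a samples matrix, given each site's flattened length. The 1×10 obs_tensor above has the same shape for a single site: one row holding the flattened :ys values.

```julia
# Hypothetical helper: concatenate each site's flattened value, in `order`, into a 1×N row.
function flatten_row(values::AbstractDict, order::AbstractVector{Symbol})
    return permutedims(reduce(vcat, [vec(values[s]) for s in order]))
end

# Hypothetical helper: split a samples matrix (one row per sample) back into
# per-site column blocks, given the flattened length of each site.
function split_sites(samples::AbstractMatrix, order::AbstractVector{Symbol},
                     sizes::AbstractVector{Int})
    blocks = Dict{Symbol,Matrix{eltype(samples)}}()
    col = 1
    for (site, n) in zip(order, sizes)
        blocks[site] = samples[:, col:col+n-1]   # n consecutive columns per site
        col += n
    end
    return blocks
end

order = [:intercept, :slope]   # same order as the selection set above
row1 = flatten_row(Dict(:slope => [-2.0], :intercept => [10.0]), order)  # [10.0 -2.0]
row2 = flatten_row(Dict(:slope => [-1.9], :intercept => [10.3]), order)
samples = vcat(row1, row2)     # 2 samples × 2 columns
sites = split_sites(samples, order, [1, 1])
```

Vector-valued sites fit the same scheme. A site of shape (3,) contributes three consecutive columns, so its entry in sizes would be 3.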
Compiling with @compile
We compile infer with Reactant's @compile for compiler-optimized probabilistic inference:
rng = ReactantRNG()
step_size = Reactant.ConcreteRNumber(0.1)
inverse_mass_matrix = Reactant.ConcreteRArray([1.0 0.0; 0.0 1.0])
compiled_fn = @compile optimize=:probprog infer(
rng, xs, obs_tensor, step_size, inverse_mass_matrix,
)
Reactant compiled function infer (with tag ##infer_reactant#2095)
Defaults
It is often sufficient to start with step_size = 1.0 and an identity inverse_mass_matrix. With the default adapt_step_size = true and adapt_mass_matrix = true, mcmc adaptively selects appropriate values during the warmup iterations.
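This page doesn't spell out how mcmc adapts the step size during warmup. A widely used rule for NUTS is the dual-averaging scheme of Hoffman and Gelman (2014). The following is a minimal sketch of that rule in plain Julia. It uses a hypothetical dual_averaging function and a toy acceptance-rate function. We don't claim ProbProg uses exactly this scheme or these constants. The rule nudges log ε so that the running average of the acceptance statistic approaches the target δ, then returns an averaged step size.

```julia
# Dual-averaging step-size adaptation (Hoffman & Gelman 2014, Algorithm 5),
# driven here by a hypothetical deterministic acceptance-rate function.
function dual_averaging(accept_rate::Function, ε0::Float64, num_warmup::Int;
                        δ = 0.8, γ = 0.05, t0 = 10.0, κ = 0.75)
    μ = log(10 * ε0)   # value that log ε is shrunk toward
    H̄ = 0.0            # running average of (δ - acceptance)
    log_ε̄ = 0.0        # averaged log step size, used after warmup
    ε = ε0
    for m in 1:num_warmup
        α = accept_rate(ε)                                # acceptance at the current ε
        H̄ = (1 - 1 / (m + t0)) * H̄ + (δ - α) / (m + t0)
        log_ε = μ - sqrt(m) / γ * H̄                       # raise ε when α > δ, lower it when α < δ
        w = m^(-κ)
        log_ε̄ = w * log_ε + (1 - w) * log_ε̄
        ε = exp(log_ε)
    end
    return exp(log_ε̄)
end

# Toy acceptance curve exp(-ε): the target δ = 0.8 is reached at ε = -log(0.8) ≈ 0.223.
ε_adapted = dual_averaging(ε -> exp(-ε), 1.0, 5000)
```

In a real sampler, α would be each warmup iteration's acceptance statistic rather than a closed-form function. Adapting the mass matrix is a separate step. It typically estimates the posterior variances from warmup draws.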
The compiled_fn is a callable object that takes the same arguments as infer and returns the inference result. We can execute the compiled inference program any number of times by calling it:
trace_tensor = compiled_fn(rng, xs, obs_tensor, step_size, inverse_mass_matrix)
500×2 ConcretePJRTArray{Float64,2}:
9.66731 -1.98223
10.825 -2.06784
10.7683 -2.06023
10.7558 -2.08704
10.1402 -2.01146
12.1224 -2.24268
9.79632 -2.06333
9.78376 -2.03945
10.6829 -2.14995
9.9425 -1.8106
⋮
9.75526 -1.87955
9.82201 -1.89766
9.65348 -1.86772
10.3429 -1.93074
10.2749 -2.01294
10.6157 -2.08754
9.7794 -1.97121
10.9895 -2.09238
9.35415 -1.84762
Each row is one post-warmup sample. The columns hold the sampled values of the selected sites, in the order of the selection set shown above (:intercept, then :slope):
           :intercept  :slope
sample 1:  ...         ...
sample 2:  ...         ...
⋮          ⋮           ⋮
sample N:  ...         ...
From the sampler output, the posterior means are:
(
posterior_mean_intercept = mean(trace_tensor[:, 1]),
posterior_mean_slope = mean(trace_tensor[:, 2]),
)
(posterior_mean_intercept = 10.107614414740327, posterior_mean_slope = -1.977662947710627)
The data were generated from slope = -2 and intercept = 10. Both NUTS posterior means land close to these values.
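This model has Gaussian priors and a Gaussian likelihood with known noise, so its posterior is exactly Gaussian. That gives a reference answer to check the sampler against. The following is a sketch in plain Julia using only the LinearAlgebra standard library, independent of ProbProg. Note that the coefficient vector here is ordered [slope, intercept], unlike the trace columns above.

```julia
using LinearAlgebra

xs = collect(1.0:10.0)
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]

# Design matrix with columns [x 1], so coefficients are [slope, intercept].
X = hcat(xs, ones(length(xs)))
prior_precision = Diagonal([1 / 2.0^2, 1 / 10.0^2])  # slope sd 2, intercept sd 10
noise_var = 1.0^2                                     # noise sd 1

# Conjugate Gaussian update: posterior precision, mean, and marginal standard deviations.
post_precision = X' * X / noise_var + prior_precision
post_mean = post_precision \ (X' * ys / noise_var)
post_sd = sqrt.(diag(inv(post_precision)))
```

This gives post_mean ≈ [-1.979, 10.143] and post_sd ≈ [0.11, 0.68]. The NUTS means above (-1.978 and 10.108) sit well within one posterior standard deviation of the exact values.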
Custom logpdf mode
In larger applications, it is often impractical to express the model as a generative program of named sample sites, as we did in the trace-based mode above. The custom logpdf mode lets Reactant.ProbProg compile and run its inference algorithms directly on a hand-written log-density function.
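Under standard probabilistic-programming semantics, the joint log density of a fully specified trace is the sum of one log-density term per sample site. The following is a minimal sketch of that sum for the regression model in plain Julia. normal_logpdf and joint_logpdf are hypothetical helpers, not ProbProg functions. MCMC only needs the density up to an additive constant, so the normalizing terms can be dropped when writing the density by hand.

```julia
# Hypothetical helper: log-density of Normal(μ, σ) at x, with σ a standard deviation.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Joint log-density of one trace of the regression model: one term per sample site.
function joint_logpdf(slope, intercept, xs, ys)
    lp = normal_logpdf(slope, 0.0, 2.0)                            # :slope site
    lp += normal_logpdf(intercept, 0.0, 10.0)                      # :intercept site
    lp += sum(normal_logpdf.(ys, slope .* xs .+ intercept, 1.0))   # :ys site
    return lp
end

xs = collect(1.0:10.0)
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
```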
For example, we can write the log-density function of the previous Bayesian linear regression model directly:
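function logdensity(θ, xs, ys)
# θ = [slope, intercept]; X has columns [xs 1]
X = hcat(xs, ones(length(xs)))
residuals = ys .- X * θ
# Gaussian priors up to a constant: slope sd 2 (variance 4), intercept sd 10 (variance 100)
pr = -0.5 * sum(θ .^ 2 ./ [4.0, 100.0])
# Gaussian likelihood with noise sd 1, up to a constant
ll = -0.5 * sum(residuals .^ 2)
return ll + pr
end
logdensity (generic function with 1 method)
We pass logdensity to the mcmc_logpdf interface along with an initial position θ0. This holds the parameter values the chain starts from, here [slope, intercept] = [0, 0] stored as a 1×2 array: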
function infer_logpdf(rng, θ0, xs, ys, step_size, inverse_mass_matrix)
trace, = ProbProg.mcmc_logpdf(rng, logdensity, θ0, xs, ys;
algorithm=:NUTS,
step_size, inverse_mass_matrix,
num_warmup=200, num_samples=500,
)
return trace
end
θ0 = Reactant.to_rarray(reshape([0.0, 0.0], 1, 2))
compiled_logpdf = @compile optimize=:probprog infer_logpdf(
rng, θ0, xs, ys, step_size, inverse_mass_matrix,
)
trace = compiled_logpdf(rng, θ0, xs, ys, step_size, inverse_mass_matrix)
500×2 ConcretePJRTArray{Float64,2}:
-1.97305 10.2574
-2.05138 10.7388
-1.94223 9.45331
-1.97712 10.5116
-1.97949 10.2917
-2.0006 9.93694
-1.98166 10.2793
-1.8983 9.85859
-1.8983 9.85859
-1.82802 9.12852
⋮
-2.14399 10.9687
-1.88633 9.38085
-1.99174 9.90648
-1.99228 9.93794
-1.91589 10.0715
-1.99145 10.3571
-2.11713 11.0849
-1.94442 10.0922
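-1.85799 9.54698
We get similar inference results. The columns here follow the order of θ, [slope, intercept], which is the reverse of the trace-based output: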
(
posterior_mean_slope = mean(trace[:, 1]),
posterior_mean_intercept = mean(trace[:, 2]),
)
(posterior_mean_slope = -1.9806765926108978, posterior_mean_intercept = 10.165651338216733)
Again, both posterior means land close to the generating values.
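NUTS evaluates the gradient of logdensity at every leapfrog step. In ProbProg that gradient comes from the compiler rather than from user code, presumably via Enzyme's automatic differentiation. For this model the gradient also has a closed form, which makes it a handy test case for a finite-difference check of a hand-written density. The following is a sketch in plain Julia. grad_logdensity and fd_grad are hypothetical helpers, and logdensity is copied from above.

```julia
# Copy of the logdensity above, with θ = [slope, intercept].
function logdensity(θ, xs, ys)
    X = hcat(xs, ones(length(xs)))
    residuals = ys .- X * θ
    pr = -0.5 * sum(θ .^ 2 ./ [4.0, 100.0])
    ll = -0.5 * sum(residuals .^ 2)
    return ll + pr
end

# Closed-form gradient: Xᵀ(y - Xθ) minus the prior term θ ./ [prior variances].
function grad_logdensity(θ, xs, ys)
    X = hcat(xs, ones(length(xs)))
    return X' * (ys .- X * θ) .- θ ./ [4.0, 100.0]
end

# Central finite differences, usable to check any hand-written gradient.
function fd_grad(f, θ; h = 1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        e = zeros(length(θ))
        e[i] = h
        g[i] = (f(θ .+ e) - f(θ .- e)) / (2h)
    end
    return g
end

xs = collect(1.0:10.0)
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
θ = [-1.5, 8.0]
g_exact = grad_logdensity(θ, xs, ys)
g_fd = fd_grad(t -> logdensity(t, xs, ys), θ)
```

Since logdensity is quadratic in θ, central differences agree with the closed form up to floating-point rounding.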
More Explanations
Sampling and distributions - semantics of sample, the built-in Distribution hierarchy, custom samplers with user-supplied logpdf, and constrained supports.
Traces and constrained inference - simulate versus generate, Constraint/Address construction, and the trace round-trip between flat position vector and tree-shaped Trace.
MCMC: MH, HMC, NUTS - Metropolis-Hastings over a Selection, gradient-based chains via mcmc, and log-density-driven chains via mcmc_logpdf.
Running and resuming chains - run_chain, warmup and checkpointing through MCMCState, save_state/load_state, and posterior summaries with mcmc_summary.