Probabilistic Programming

Reactant.jl compiles ordinary Julia code through MLIR, running it on CPU, GPU, or TPU without rewriting. You write a normal Julia function, wrap a call to it in @compile, and Reactant traces the function, lowers it through MLIR, and hands you back a callable compiled program. Array inputs are staged through ConcreteRArray (created with Reactant.to_rarray) so they live on the target device.

Reactant.ProbProg is the Julia front-end for the impulse dialect, implemented across Enzyme (dialect definition and inference materialization passes) and Enzyme-JAX (backend-specific lowering). The impulse dialect provides high-level MLIR ops for describing probabilistic modeling and inference, materializes inference computation through compiler passes, and applies general-purpose and probabilistic-programming-specific optimizations during lowering.

optimize = :probprog opt-in required

For now, @compile needs an explicit optimize = :probprog argument on probabilistic programs to enable the impulse-specific MLIR passes (you'll see this in every @compile call below). Merging those passes into the default @compile pipeline is work in progress; once it lands, the explicit opt-in will no longer be required.

Next, we walk through two operating modes of Reactant.ProbProg: a trace-based mode built around a generative function, and a custom log-density mode that takes a custom log-density function.

Trace-based mode

We consider a Bayesian linear regression problem:

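$$
\begin{aligned}
\text{slope} &\sim \mathcal{N}(0,\ 2) \\
\text{intercept} &\sim \mathcal{N}(0,\ 10) \\
y_i \mid \text{slope}, \text{intercept} &\sim \mathcal{N}(\text{slope} \cdot x_i + \text{intercept},\ 1)
\end{aligned}
$$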

Both regression coefficients are given Gaussian priors, tighter on slope (standard deviation 2) and looser on intercept (standard deviation 10). Each observation y_i is then drawn from a Gaussian centered on the fitted value slope · x_i + intercept with fixed noise (standard deviation 1).
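To make the generative story concrete, here is a minimal plain-Julia sketch of the same process using only the Random standard library. It does not use Reactant.ProbProg; the helper name simulate_regression and the seed are illustrative assumptions, not part of the package.

```julia
using Random

# Hypothetical helper (not part of Reactant.ProbProg): draws one joint sample
# from the model by following the three sampling statements in order.
function simulate_regression(rng::AbstractRNG, xs::AbstractVector{<:Real})
    slope = 2.0 * randn(rng)        # slope ~ N(0, sd = 2)
    intercept = 10.0 * randn(rng)   # intercept ~ N(0, sd = 10)
    # y_i ~ N(slope * x_i + intercept, sd = 1), one draw per x_i
    ys = slope .* xs .+ intercept .+ randn(rng, length(xs))
    return (; slope, intercept, ys)
end

draw = simulate_regression(MersenneTwister(0), collect(1.0:10.0))
```

Calling the helper again with a freshly seeded generator reproduces the same draw, which is handy when comparing against another implementation.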

The data

Synthetic data drawn from slope = -2, intercept = 10. The xs / ys pair is what we'll condition on.

julia
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
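As a quick sanity check before any Bayesian inference, an ordinary least-squares fit should land near the generating values. The sketch below uses only base Julia; the names X, slope_ols, and intercept_ols are illustrative.

```julia
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]

# Design matrix with a slope column and an intercept column.
X = hcat(xs, ones(length(xs)))

# Least-squares solution of X * [slope, intercept] ≈ ys.
slope_ols, intercept_ols = X \ ys
# slope_ols ≈ -1.992, intercept_ols ≈ 10.223
```

These point estimates ignore the priors, so the posterior means reported later are expected to differ slightly from them.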

Describing the Model

We describe this model in Reactant.ProbProg as follows:

julia
using Reactant            # exports @compile; needed for `Reactant.`-qualified names below
using Reactant: ProbProg

function model(rng, xs)
    _, slope = ProbProg.sample(
        rng, ProbProg.Normal(0.0, 2.0, (1,)); symbol=:slope,
    )
    _, intercept = ProbProg.sample(
        rng, ProbProg.Normal(0.0, 10.0, (1,)); symbol=:intercept,
    )
    _, ys = ProbProg.sample(
        rng,
        ProbProg.Normal(slope .* xs .+ intercept, 1.0, (length(xs),));
        symbol=:ys,
    )
    return ys
end
model (generic function with 1 method)

Each random choice is introduced by a ProbProg.sample(rng, dist; symbol=...) call that takes a random number generator (RNG) and a distribution. The symbol keyword names the sample site; that name is what we later use to condition on observed values and to select the parameters to infer.

As a calling convention, ProbProg.sample returns (rng, value); the first element (omitted with _ above) is the updated RNG. In the current implementation rng is a traced ReactantRNG whose state corresponds to a tensor<2xui64> RNG state in the generated MLIR. We don't thread it through manually because Reactant tracing handles the input/output threading at the IR level, and ReactantRNG's internal state is updated via Julia mutability (see here for details).

Describing Inference

We condition on the observed ys with a Constraint object:

julia
obs = ProbProg.Constraint(:ys => ys)
Reactant.ProbProg.Constraint with 1 entry:
  Address([:ys]) => [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, …

The current implementation requires a bit of boilerplate to flatten the Constraint into a tensor representation and to extract its address set before passing them to the @compile'd function below (see Traces and constrained inference for details):

julia
obs_tensor = ProbProg.flatten_constraint(obs)
1×10 ConcretePJRTArray{Float64,2}:
 8.23  5.87  3.99  2.59  0.23  -0.66  -3.53  -6.91  -7.24  -9.9
julia
constrained_addresses = ProbProg.extract_addresses(obs)
Set{Reactant.ProbProg.Address} with 1 element:
  Reactant.ProbProg.Address([:ys])

We then specify what parameters to infer:

julia
selection = ProbProg.select(
    ProbProg.Address(:slope),
    ProbProg.Address(:intercept),
)
OrderedCollections.OrderedSet{Reactant.ProbProg.Address} with 2 elements:
  Reactant.ProbProg.Address([:intercept])
  Reactant.ProbProg.Address([:slope])

We express inference in a single function that conditions the model on the constraint with generate and then runs NUTS over the selected sites with mcmc.

julia
function infer(rng, xs, obs_tensor, step_size, inverse_mass_matrix)
    trace, = ProbProg.generate(
        rng, obs_tensor, model, xs; constrained_addresses,
    )
    trace, = ProbProg.mcmc(
        rng, trace, model, xs;
        selection, algorithm=:NUTS,
        step_size, inverse_mass_matrix,
        num_warmup=200, num_samples=500,
    )
    return trace
end
infer (generic function with 1 method)

The returned trace contains the sampling result as a 2D tensor: each row is the concatenation of all selected sites' flattened values for one post-warmup sample. (We will show a possible trace for this example problem below.)

Compiling with @compile

We compile infer with Reactant's @compile for compiler-optimized probabilistic inference:

julia
rng                 = ReactantRNG()
step_size           = Reactant.ConcreteRNumber(0.1)
inverse_mass_matrix = Reactant.ConcreteRArray([1.0 0.0; 0.0 1.0])

compiled_fn = @compile optimize=:probprog infer(
    rng, xs, obs_tensor, step_size, inverse_mass_matrix,
)
Reactant compiled function infer (with tag ##infer_reactant#2098)

Defaults

It is often sufficient to start with step_size = 1.0 and an identity inverse_mass_matrix. With the default adapt_step_size = true and adapt_mass_matrix = true, mcmc adaptively selects appropriate values during the warmup iterations.
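For intuition about what these two tuning parameters control, the sketch below shows a textbook leapfrog integrator, the building block of HMC and NUTS, on a toy standard-Gaussian target. It is a generic illustration in plain Julia and not the code that mcmc generates; grad_logp, leapfrog, and hamiltonian are hypothetical names.

```julia
using LinearAlgebra

# Toy target: standard 2-D Gaussian, log p(q) = -0.5 * dot(q, q) + const.
grad_logp(q) = -q

# One leapfrog trajectory of n_steps steps (textbook HMC integrator).
function leapfrog(q, p, step_size, inverse_mass_matrix, n_steps)
    p = p .+ (step_size / 2) .* grad_logp(q)              # initial half step for momentum
    for i in 1:n_steps
        q = q .+ step_size .* (inverse_mass_matrix * p)   # full step for position
        scale = i == n_steps ? step_size / 2 : step_size  # last momentum update is a half step
        p = p .+ scale .* grad_logp(q)
    end
    return q, p
end

# Total energy: negative log-density plus kinetic energy.
hamiltonian(q, p, inverse_mass_matrix) =
    0.5 * dot(q, q) + 0.5 * dot(p, inverse_mass_matrix * p)

Minv = [1.0 0.0; 0.0 1.0]
q0, p0 = [1.0, 0.0], [0.0, 1.0]
q1, p1 = leapfrog(q0, p0, 0.1, Minv, 10)
energy_error = hamiltonian(q1, p1, Minv) - hamiltonian(q0, p0, Minv)
```

A smaller step_size shrinks energy_error at the cost of more gradient evaluations, and inverse_mass_matrix rescales how far each coordinate moves per step. Warmup adaptation tunes both to balance this trade-off.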

The compiled_fn is a callable object that takes the same arguments as infer and returns the inference result. We can execute the compiled inference program any number of times by calling it:

julia
trace_tensor = compiled_fn(rng, xs, obs_tensor, step_size, inverse_mass_matrix)
500×2 ConcretePJRTArray{Float64,2}:
 11.1216   -2.0421
 10.4061   -1.86466
  9.74547  -1.96183
 11.1968   -2.15209
 11.2325   -2.10342
 11.1054   -2.09806
 10.0431   -1.97479
 10.3341   -2.03595
 10.4008   -2.03963
 10.3813   -1.98268
  ⋮
  9.30446  -1.80016
 11.0091   -2.06852
  8.83636  -1.81066
  8.5156   -1.69056
  9.47989  -1.85481
 10.2086   -1.98107
 10.7282   -2.06577
  8.79433  -1.78854
  8.75925  -1.79561

Each row is one post-warmup sample; columns hold the sampled values for each selected site:

text
            :intercept   :slope
sample 1:      ...         ...
sample 2:      ...         ...
   ⋮            ⋮           ⋮
sample N:      ...         ...
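Given that layout, the columns can be split into per-site vectors. The sketch below works on a small stand-in matrix (the first three rows printed above) rather than the device array, and assumes the column order shown in the layout; site_names and draws are illustrative names.

```julia
# Stand-in for the sampler output: the first three rows printed above.
samples = [11.1216  -2.0421;
           10.4061  -1.86466;
            9.74547 -1.96183]

# Column order as shown in the layout above (assumed fixed for this sketch).
site_names = (:intercept, :slope)

# Map each site name to its column of draws.
draws = NamedTuple{site_names}(Tuple(samples[:, j] for j in eachindex(site_names)))
```

After this, draws.slope and draws.intercept can be summarized independently without tracking column indices by hand.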

From the sampler output, the posterior mean is:

julia
using Statistics: mean

(
    posterior_mean_intercept = mean(trace_tensor[:, 1]),
    posterior_mean_slope     = mean(trace_tensor[:, 2]),
)
(posterior_mean_intercept = 10.09658312079464, posterior_mean_slope = -1.9720936724063631)

The data were generated from slope = -2, intercept = 10; the NUTS posterior means (about -1.97 and 10.10) land close to both values.
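Because both the priors and the likelihood are Gaussian and the noise standard deviation is fixed, this particular posterior is also available in closed form, which gives an independent check on the sampler. The sketch below applies the standard conjugate result in plain Julia; the variable names are illustrative.

```julia
using LinearAlgebra

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]

X = hcat(xs, ones(length(xs)))                        # columns: slope, intercept
prior_precision = Diagonal([1 / 2.0^2, 1 / 10.0^2])   # inverse prior variances

# Gaussian prior plus unit-variance Gaussian noise gives a Gaussian posterior
# with precision X'X + prior precision and mean precision \ (X'y).
posterior_precision = X' * X + prior_precision
posterior_mean = posterior_precision \ (X' * ys)
# posterior_mean ≈ [-1.979, 10.143]  (slope, intercept)
```

The sampled means differ from these exact values only in the second decimal place, which is plausible Monte Carlo error for 500 draws.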

Custom logpdf mode

In larger applications, it is often impractical to express the model in a PPL modeling language as we did in the trace-based mode above. In that case, Reactant.ProbProg can compile and run its inference algorithms directly on a hand-written log-density function via the custom logpdf mode.

For example, we can write the log-density function of the previous Bayesian linear regression model directly (up to an additive constant that does not depend on θ):

julia
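function logdensity(θ, xs, ys)
    # Design matrix: column 1 multiplies the slope, column 2 the intercept.
    X = hcat(xs, ones(length(xs)))
    residuals = ys .- X * θ
    # Gaussian log-prior up to a constant; variances are 2^2 = 4 and 10^2 = 100.
    pr = -0.5 * sum(θ .^ 2 ./ [4.0, 100.0])
    # Gaussian log-likelihood with unit noise, up to a constant.
    ll = -0.5 * sum(residuals .^ 2)
    return ll + pr
end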
logdensity (generic function with 1 method)
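A hand-written log-density of this kind can drop every term that does not depend on θ, since MCMC only needs the density up to a constant. The plain-Julia sketch below checks this by comparing a fully normalized joint log-density against the unnormalized form at two parameter values. It assumes θ is a plain two-element vector (slope, intercept); normal_logpdf, full_logjoint, and unnormalized_logjoint are hypothetical names.

```julia
# Log-pdf of a univariate Normal(μ, σ), written out so no packages are needed.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Full joint log-density: priors on slope and intercept plus the likelihood.
function full_logjoint(θ, xs, ys)
    slope, intercept = θ
    lp = normal_logpdf(slope, 0.0, 2.0) + normal_logpdf(intercept, 0.0, 10.0)
    return lp + sum(normal_logpdf.(ys, slope .* xs .+ intercept, 1.0))
end

# Same quantity with every θ-independent constant dropped.
function unnormalized_logjoint(θ, xs, ys)
    X = hcat(xs, ones(length(xs)))
    residuals = ys .- X * θ
    return -0.5 * sum(residuals .^ 2) - 0.5 * sum(θ .^ 2 ./ [4.0, 100.0])
end

xs = collect(1.0:10.0)
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]

θa, θb = [-2.0, 10.0], [0.5, -1.0]
gap_a = full_logjoint(θa, xs, ys) - unnormalized_logjoint(θa, xs, ys)
gap_b = full_logjoint(θb, xs, ys) - unnormalized_logjoint(θb, xs, ys)
```

If the two forms describe the same posterior, gap_a and gap_b agree up to floating-point error, because the difference is exactly the dropped normalizing constant.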

We pass logdensity to the mcmc_logpdf interface along with an initial position vector (the parameter values the chain starts from):

julia
function infer_logpdf(rng, θ0, xs, ys, step_size, inverse_mass_matrix)
    trace, = ProbProg.mcmc_logpdf(rng, logdensity, θ0, xs, ys;
        algorithm=:NUTS,
        step_size, inverse_mass_matrix,
        num_warmup=200, num_samples=500,
    )
    return trace
end

θ0 = Reactant.to_rarray(reshape([0.0, 0.0], 1, 2))
compiled_logpdf = @compile optimize=:probprog infer_logpdf(
    rng, θ0, xs, ys, step_size, inverse_mass_matrix,
)
trace = compiled_logpdf(rng, θ0, xs, ys, step_size, inverse_mass_matrix)
500×2 ConcretePJRTArray{Float64,2}:
 -1.78213   8.84844
 -2.19068  11.2294
 -2.03949  10.4488
 -1.88667   9.48885
 -2.0951   10.8067
 -2.10363  10.8939
 -2.14528  11.17
 -1.86508  10.2658
 -1.85651   9.81319
 -1.87651   9.74547
  ⋮
 -1.95414  10.2514
 -1.96849  10.1704
 -1.9388   10.4164
 -1.91567  10.3361
 -1.92606  10.5246
 -1.95926  10.2643
 -1.85822   9.16157
 -2.05255  10.8243
 -2.09756  11.0914

We get similar inference results; note that the columns here follow the order of θ (slope first, then intercept):

julia
(
    posterior_mean_slope     = mean(trace[:, 1]),
    posterior_mean_intercept = mean(trace[:, 2]),
)
(posterior_mean_slope = -1.9706218810377627, posterior_mean_intercept = 10.096908646270892)

The posterior means are again close to the generating values slope = -2 and intercept = 10.

More Explanations