
Traces and constrained inference

A trace records every random choice in one model execution, along with its log-density and the model's return value. There are two ways to produce a trace:

  • simulate runs forward; each sample draws from the prior.

  • generate conditions on observed choices and returns an importance weight.

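Each comes in two forms: a low-level function (simulate, generate) that returns MLIR-traced values for use inside a larger compiled program, and a helper (simulate_, generate_) that compiles, runs, and returns an unflattened Trace.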

Trace

Field      Contents
choices    Dict{Symbol,Any} of random choices at this level
subtraces  Dict{Symbol,Any} of traces for submodels
retval     Model return value
weight     Log importance weight (generate) or prior log-density (simulate)

choices is keyed by each site's symbol. Each value is an array whose first axis indexes samples, so a single simulate_ call on a site of shape (3,) yields an array of shape (1, 3).

Internal trace representation

Programs compiled with @compile optimize=:probprog return traces in Impulse's internal, tensor-based representation.

The flat form holds every random choice the model makes, flattened and concatenated into a single row whose layout is fixed at trace time. With multiple samples (e.g., a NUTS run with num_samples = 12), each sample occupies one row, so the trace tensor has shape (num_samples, position_size), where position_size is the total number of elements across all sampled sites. This representation involves no symbol lookups: per-site offsets and shapes are baked in at compile time, and the bridge helpers below carry that layout metadata so the tensor can be reconstituted into a tree.
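A minimal plain-Julia sketch of the idea, assuming the layout is a list of (symbol, offset, shape) entries: sites are assigned consecutive column ranges of a (num_samples, position_size) matrix and recovered by slicing and reshaping. Site, layout, and unflatten_row are hypothetical names, not Reactant or Impulse API, and Impulse's actual element ordering may differ from Julia's column-major reshape used here.

```julia
# Hypothetical layout entry: where one site's elements sit in a flat row.
struct Site
    symbol::Symbol
    offset::Int      # 0-based start column within the row
    shape::Tuple     # the site's own shape
end

# Assign consecutive column ranges to sites in sampling order.
function layout(sites)
    entries = Site[]
    offset = 0
    for (sym, shp) in sites
        push!(entries, Site(sym, offset, shp))
        offset += prod(shp)
    end
    return entries, offset    # final offset == position_size
end

# Rebuild per-site arrays from one row (one sample) of the flat tensor.
# Element order inside a site follows Julia's column-major reshape here.
function unflatten_row(row::AbstractVector, entries::Vector{Site})
    return Dict(e.symbol => reshape(row[e.offset+1:e.offset+prod(e.shape)], e.shape)
                for e in entries)
end

entries, position_size = layout([:s => (3,), :t => (2, 2)])
num_samples = 12
# Stand-in trace tensor with distinct values so the slicing is easy to follow.
flat = reshape(collect(1.0:num_samples * position_size), num_samples, position_size)

sample5 = unflatten_row(flat[5, :], entries)
```

Because offsets and shapes are fixed once the layout is built, recovering a site in this sketch is a slice plus a reshape, with no name lookups at run time.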

Helpers that bridge the two

The Reactant frontend provides convenience helpers that handle the conversion in either direction.

  • simulate_ and generate_ compile, run the model, and immediately convert the result back to a tree-shaped Trace. The flat form never surfaces.

  • unflatten_trace does the explicit tensor → tree conversion, given a trace tensor and layout metadata (per-site offset, shape, address-path) collected by the Impulse tracing context.

  • with_trace installs the Impulse tracing context that collects the layout metadata while a compiled program is being built. Wrap your @compile call in it.

simulate_

julia
using Reactant
using Reactant: ProbProg, ReactantRNG

function model(rng, μ, σ, shape)
    _, s = ProbProg.sample(rng, ProbProg.Normal(μ, σ, shape); symbol=:s)
    _, t = ProbProg.sample(rng, ProbProg.Normal(s, σ, shape); symbol=:t)
    return t
end

seed = Reactant.to_rarray(UInt64[1, 4])
rng  = ReactantRNG(seed)
μ    = Reactant.ConcreteRNumber(0.0)
σ    = Reactant.ConcreteRNumber(1.0)

trace, weight = ProbProg.simulate_(rng, model, μ, σ, (3,))

trace.choices[:s]   # (1, 3)
trace.choices[:t]   # (1, 3)
trace.retval[1]
trace.weight
-8.199174728997857

simulate_ calls @compile optimize=:probprog on ProbProg.simulate(rng, model, args...) and reshapes the flat position tensor using layout metadata collected during tracing.
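For this model, the simulate weight is the prior log-density of both sites, log N(s; μ, σ) + log N(t; s, σ), summed over elements. A plain-Julia sketch of that sum, assuming each site is an elementwise independent Normal; normal_logpdf is a hypothetical helper and the draws are made up, not the values from the run above.

```julia
# Hypothetical helper: Normal(μ, σ) log-density at x, summed elementwise.
# x and μ may be vectors; σ is a scalar here.
normal_logpdf(x, μ, σ) = sum(-0.5 .* ((x .- μ) ./ σ) .^ 2 .- log(σ) .- 0.5 * log(2π))

μ, σ = 0.0, 1.0
s = [0.5, -1.0, 2.0]       # made-up draw for :s ~ Normal(μ, σ)
t = [0.0, -0.5, 1.5]       # made-up draw for :t ~ Normal(s, σ)

# Prior log-density of the whole trace: one term per sample site.
weight = normal_logpdf(s, μ, σ) + normal_logpdf(t, s, σ)
```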

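Submodels

Passing a model function to ProbProg.sample in place of a distribution traces it as a submodel: its choices are recorded in a nested trace under trace.subtraces[symbol].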

julia
function pair(rng, μ, σ, shape)
    _, a = ProbProg.sample(rng, ProbProg.Normal(μ, σ, shape); symbol=:a)
    _, b = ProbProg.sample(rng, ProbProg.Normal(μ, σ, shape); symbol=:b)
    return a .* b
end

function outer(rng, μ, σ, shape)
    _, s = ProbProg.sample(rng, pair, μ, σ, shape; symbol=:s)
    _, t = ProbProg.sample(rng, pair, s,  σ, shape; symbol=:t)
    return t
end

trace, _ = ProbProg.simulate_(rng, outer, μ, σ, (3, 3, 3))
trace.subtraces[:s].choices[:a]
trace.subtraces[:t].choices[:b]
1×3×3×3 Array{Float64, 4}:
[:, :, 1, 1] =
 -1.61524  -1.62423  -0.206727

[:, :, 2, 1] =
 1.52398  -0.965305  -0.275609

[:, :, 3, 1] =
 3.2088  0.836275  0.85077

[:, :, 1, 2] =
 -1.03268  -0.351676  -1.87799

[:, :, 2, 2] =
 1.4149  1.0487  1.17869

[:, :, 3, 2] =
 1.34405  0.161754  0.0624658

[:, :, 1, 3] =
 1.45042  1.12212  0.235213

[:, :, 2, 3] =
 -0.558201  -0.401297  -3.025

[:, :, 3, 3] =
 -0.332149  0.868088  -0.506109
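Nested submodel choices correspond to address paths such as [:s, :a] and [:t, :b]. A plain-Julia sketch of how (path, value) entries could be folded into the choices/subtraces tree, assuming the last path element names the site and earlier elements name submodels; Node and insert_path! are hypothetical, not Reactant's unflatten_trace.

```julia
# Hypothetical tree node mirroring Trace's choices/subtraces split.
struct Node
    choices::Dict{Symbol,Any}      # site symbol => value at this level
    subtraces::Dict{Symbol,Node}   # submodel symbol => nested node
end
Node() = Node(Dict{Symbol,Any}(), Dict{Symbol,Node}())

# Walk (creating as needed) one subtrace per leading path element,
# then store the value under the final symbol.
function insert_path!(root::Node, path::Vector{Symbol}, value)
    node = root
    for sym in path[1:end-1]
        node = get!(Node, node.subtraces, sym)
    end
    node.choices[path[end]] = value
    return root
end

root = Node()
for (path, v) in [([:s, :a], 1.0), ([:s, :b], 2.0), ([:t, :a], 3.0), ([:t, :b], 4.0)]
    insert_path!(root, path, v)
end

root.subtraces[:s].choices[:a]   # same access pattern as trace.subtraces[:s].choices[:a]
```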

Low-level simulate

Inside a larger compiled program, simulate returns MLIR-traced values:

julia
trace_tensor, weight, retval = ProbProg.simulate(rng, model, μ, σ, shape)

trace_tensor is a rank-2 tensor of shape (1, position_size). To collect the layout metadata needed to rebuild a tree, wrap compilation in with_trace, which installs the tracing context:

julia
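shape = (3,)

code, tt = ProbProg.with_trace() do
    # Compile while the tracing context is active; tt collects the
    # per-site layout metadata later passed on as tt.entries.
    @code_hlo optimize=:probprog begin
        ProbProg.simulate(rng, model, μ, σ, shape)
    end
end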

Rebuild a Trace with unflatten_trace:

julia
trace = ProbProg.unflatten_trace(trace_tensor, weight, tt.entries, retval)

Conditioning

A Constraint pins addresses to observed values:

julia
obs = ProbProg.Constraint(
    :param_a => [0.0],
    :param_b => [0.0],
    :ys_a    => [-2.3, -1.6, -0.4, 0.6, 1.4],
    :ys_b    => [-2.6, -1.4, -0.6, 0.4, 1.6],
)
Reactant.ProbProg.Constraint with 4 entries:
  Address([:ys_b])    => [-2.6, -1.4, -0.6, 0.4, 1.6]
  Address([:ys_a])    => [-2.3, -1.6, -0.4, 0.6, 1.4]
  Address([:param_b]) => [0.0]
  Address([:param_a]) => [0.0]

To constrain a choice inside a submodel, nest the pairs: Constraint(:outer => :inner => value) targets the address path [:outer, :inner].
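Because => is right-associative, :outer => :inner => value is the pair :outer => (:inner => value). A plain-Julia sketch of unwinding such a pair into a symbol path and a value; pair_path is a hypothetical helper, and it assumes every key along the chain is a Symbol.

```julia
# Hypothetical helper: unwind a nested pair into (symbol path, value).
function pair_path(p::Pair)
    path = Symbol[first(p)]
    rest = last(p)
    while rest isa Pair            # keep descending while the value is another pair
        push!(path, first(rest))
        rest = last(rest)
    end
    return path, rest
end

path, value = pair_path(:outer => :inner => [1.0, 2.0])
```

A flat entry such as :ys_a => [...] unwinds to the single-element path [:ys_a].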

generate_ returns a trace whose weight is the log importance weight:

julia
# Here model stands for a model whose sites include the addresses in obs,
# and xs for its inputs.
trace, weight = ProbProg.generate_(rng, obs, model, xs)
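Assuming generate proposes unconstrained sites from their prior (the usual convention in systems such as Gen, and an assumption here), the log importance weight reduces to the log-density of the constrained sites given the rest of the trace. A plain-Julia sketch for the two-site model from the simulate_ example with :t observed; normal_logpdf is a hypothetical helper and the numbers are made up.

```julia
# Hypothetical helper: Normal(μ, σ) log-density at x, summed elementwise.
normal_logpdf(x, μ, σ) = sum(-0.5 .* ((x .- μ) ./ σ) .^ 2 .- log(σ) .- 0.5 * log(2π))

μ, σ = 0.0, 1.0
s     = [0.5, -1.0, 2.0]    # unconstrained :s, assumed drawn from its prior
t_obs = [0.0, -0.5, 1.5]    # constrained :t, fixed by the observation

# Prior-proposal importance weight: log p(s, t_obs) - log q(s), with q = prior.
log_joint    = normal_logpdf(s, μ, σ) + normal_logpdf(t_obs, s, σ)
log_proposal = normal_logpdf(s, μ, σ)
weight = log_joint - log_proposal   # equals log p(t_obs | s)
```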

To call generate inside a compiled function, flatten the observations manually:

julia
constrained_addresses = ProbProg.extract_addresses(obs)

# Concatenate observed values in extract_addresses order into one flat row.
obs_flat = Float64[]
for addr in constrained_addresses
    append!(obs_flat, vec(obs[addr]))
end
obs_tensor = Reactant.to_rarray(reshape(obs_flat, 1, :))   # shape (1, length(obs_flat))

trace_tensor, weight, _ = ProbProg.generate(
    rng, obs_tensor, model, xs; constrained_addresses,
)

Values must be appended in extract_addresses(obs) order so they line up with constrained_addresses; generate_ does this flattening automatically.

Addresses

An Address is a path of symbols:

julia
ProbProg.Address(:slope)
ProbProg.Address(:outer, :inner, :x)
ProbProg.Address([:outer, :inner, :x])
Reactant.ProbProg.Address([:outer, :inner, :x])

Two addresses are equal when their symbol paths are equal. A Selection, built with select, is an OrderedSet{Address}:

julia
ProbProg.select(
    ProbProg.Address(:slope),
    ProbProg.Address(:intercept),
)
OrderedCollections.OrderedSet{Reactant.ProbProg.Address} with 2 elements:
  Reactant.ProbProg.Address([:intercept])
  Reactant.ProbProg.Address([:slope])
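Because a Selection is a set, membership and deduplication rely on path equality. A plain-Julia sketch of those semantics with a hypothetical MockAddress type (not Reactant's Address); the key point is defining hash consistently with ==, so equal paths land in the same set slot.

```julia
# Hypothetical address type: equality and hashing depend only on the path.
struct MockAddress
    path::Vector{Symbol}
end
MockAddress(syms::Symbol...) = MockAddress(collect(syms))

Base.:(==)(a::MockAddress, b::MockAddress) = a.path == b.path
Base.hash(a::MockAddress, h::UInt) = hash(a.path, hash(:MockAddress, h))

a = MockAddress(:outer, :inner, :x)
b = MockAddress([:outer, :inner, :x])
a == b                 # same path, built two different ways
length(Set([a, b]))    # duplicates collapse because hash agrees with ==
```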

MCMC kernels use selections to choose which sites to update.
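One way to picture how a selection drives an update, assuming each site occupies a known range of the flat position row: build a mask of the selected positions and change only those. site_layout, selection_mask, and the toy update are hypothetical illustrations, not Reactant's kernel code.

```julia
# Hypothetical flat layout: address path => (0-based offset, length).
site_layout = [
    [:slope]     => (0, 1),
    [:intercept] => (1, 1),
    [:noise]     => (2, 3),
]

# Boolean mask over the flat row marking the positions of selected sites.
function selection_mask(selection, site_layout, position_size)
    mask = falses(position_size)
    for (path, (offset, len)) in site_layout
        if path in selection
            mask[offset+1:offset+len] .= true
        end
    end
    return mask
end

selection = [[:slope], [:intercept]]   # stand-in for ProbProg.select(...)
mask = selection_mask(selection, site_layout, 5)

# Toy "update": perturb only the selected positions, leave the rest fixed.
current = zeros(5)
proposal = current .+ 0.1 .* mask
```

In this sketch unselected sites keep their current values; a real MH or HMC kernel would substitute its own proposal for the toy perturbation.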

Summary

text
simulate(_)   forward sampling, prior log-density
generate(_)   observations applied, importance weight
Trace         choices + subtraces + retval + weight
Constraint    Address => observed value
Selection     OrderedSet{Address}

Next: MH, HMC, NUTS.