
Sampling and distributions

Random choices are made with ProbProg.sample, which has two call forms: one that takes a built-in Distribution, and one that takes a user-supplied sampler function with an optional log-density.

Built-in distribution

```julia
using Reactant
using Reactant: ProbProg, ReactantRNG

seed = Reactant.to_rarray(UInt64[1, 4])
rng  = ReactantRNG(seed)

_, x = ProbProg.sample(rng, ProbProg.Normal(0.0, 1.0, (10,)); symbol=:x)
```

ProbProg.sample returns (updated_rng, value). symbol is the trace address of the choice; omitting it produces a gensym name that cannot be constrained or inspected later.
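Why an omitted symbol cannot be reused: in Julia, a gensym is a fresh, unique Symbol created on every call, so there is no stable name to pass back in later. The snippet below uses only Base.gensym and does not involve ProbProg; it assumes the auto-generated address behaves like an ordinary gensym.

```julia
# Base.gensym returns a new, unique Symbol each time it is called
s1 = gensym(:x)
s2 = gensym(:x)
println(s1 == s2)   # false: the two generated names differ
```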

RNG

ReactantRNG is a counter-based generator, seeded by a length-2 UInt64 vector:

```julia
seed = Reactant.to_rarray(UInt64[1, 4])
rng  = ReactantRNG(seed)
```

The same seed reproduces the same trajectory, and rng.seed is updated in place after each compiled call.
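To make "counter-based" concrete, here is a toy generator in plain Julia. It is an illustration only, not Reactant's algorithm: it assumes the two seed words are read as (key, starting counter) and uses the SplitMix64 mixing function. Each output is a pure function of (key, counter), so replaying the same seed replays the same values.

```julia
# Toy counter-based RNG, for illustration only; NOT Reactant's generator.
# Assumption: seed[1] is the key, seed[2] is the starting counter.

# SplitMix64 output mixer: a fixed bijection on 64-bit words (UInt64 arithmetic wraps).
function mix64(z::UInt64)
    z = (z ⊻ (z >> 30)) * 0xbf58476d1ce4e5b9
    z = (z ⊻ (z >> 27)) * 0x94d049bb133111eb
    return z ⊻ (z >> 31)
end

struct ToyCounterRNG
    key::UInt64
    counter::UInt64
end

ToyCounterRNG(seed::AbstractVector{UInt64}) = ToyCounterRNG(seed[1], seed[2])

# A draw depends only on (key, counter); advancing the RNG just increments the counter.
function draw(rng::ToyCounterRNG)
    value = mix64(rng.key ⊻ mix64(rng.counter))
    return ToyCounterRNG(rng.key, rng.counter + 1), value
end

# Collect n values, threading the updated RNG through each call
function take_draws(rng::ToyCounterRNG, n::Integer)
    values = UInt64[]
    for _ in 1:n
        rng, v = draw(rng)
        push!(values, v)
    end
    return values
end

a = take_draws(ToyCounterRNG(UInt64[1, 4]), 5)
b = take_draws(ToyCounterRNG(UInt64[1, 4]), 5)   # same seed
c = take_draws(ToyCounterRNG(UInt64[2, 4]), 5)   # different key
println(a == b)   # true: identical trajectories
println(a == c)   # false
```

The toy threads the RNG functionally, mirroring the (updated_rng, value) return of ProbProg.sample; per the text above, Reactant additionally writes the new state back into rng.seed.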

Distributions

| Type | Constructor | support |
|---|---|---|
| Normal | Normal(μ, σ, shape) | :real |
| Exponential | Exponential(λ, shape) | :positive |
| LogNormal | LogNormal(μ, σ, shape) | :positive |
| Bernoulli | Bernoulli(logits, shape) | :real (logit scale) |
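Bernoulli is parameterised by logits, that is, log-odds ℓ = log(p / (1 - p)), rather than by a probability p. The plain-Julia sketch below (not ProbProg's implementation) shows the conversion p = logistic(ℓ) and the corresponding log-probability of an outcome x ∈ {0, 1}, written with a numerically stable softplus.

```julia
# Logit-parameterised Bernoulli in plain Julia (illustrative; not ProbProg's code).
logistic(ℓ) = 1 / (1 + exp(-ℓ))

# softplus(ℓ) = log(1 + exp(ℓ)), arranged to avoid overflow for large |ℓ|
softplus(ℓ) = max(ℓ, zero(ℓ)) + log1p(exp(-abs(ℓ)))

# log P(x | ℓ) for x ∈ {0, 1} equals x * ℓ - log(1 + exp(ℓ))
bernoulli_logpmf(x, ℓ) = x * ℓ - softplus(ℓ)

ℓ = 0.8                 # log-odds
p = logistic(ℓ)         # P(x = 1)
println(exp(bernoulli_logpmf(1, ℓ)) ≈ p)       # true
println(exp(bernoulli_logpmf(0, ℓ)) ≈ 1 - p)   # true
```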

shape is a non-empty tuple. Parameters broadcast against shape, so μ and σ can be scalars, ConcreteRNumbers, or arrays.

```julia
ProbProg.Normal()                      # μ=0, σ=1, shape=(1,)
ProbProg.Normal(0.0, 1.0)              # shape=(1,)
ProbProg.Normal(0.0, 1.0, (5,))
ProbProg.Exponential(2.0, (3,))
ProbProg.LogNormal(0.0, 0.5, (2, 2))
ProbProg.Bernoulli(logits, (4,))       # logits: log-odds, scalar or array
```
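The broadcasting rule can be mimicked on the CPU with plain Julia: draw a standard-normal array of size shape, then scale and shift it elementwise, so scalar parameters apply everywhere and array parameters apply per element. This is a stand-in using Random.Xoshiro, not ProbProg; reference_normal is a hypothetical name.

```julia
using Random

# CPU stand-in (plain Julia, not ProbProg): parameters broadcast against a draw of size `shape`.
reference_normal(rng, μ, σ, shape::Tuple) = μ .+ σ .* randn(rng, shape)

rng = Xoshiro(42)
x1 = reference_normal(rng, 0.0, 1.0, (5,))                   # scalar μ and σ
x2 = reference_normal(rng, [0.0, 10.0, 20.0], 0.5, (3,))     # one mean per element
x3 = reference_normal(rng, 1.0, [1.0 2.0; 3.0 4.0], (2, 2))  # one σ per element

println(size(x1), " ", size(x2), " ", size(x3))  # (5,) (3,) (2, 2)
```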

Custom sampler

```julia
ProbProg.sample(
    rng, my_sampler, args...;
    symbol  = :my_site,
    logpdf  = my_logpdf,            # (sample, args...) -> scalar
    support = :real,
    bounds  = (nothing, nothing),
)
```
  • my_sampler(rng, args...) must be traceable by Reactant.

  • With logpdf, the site contributes to the model weight and is an inference target. Without it, the site is traced but contributes no log-density.

  • logpdf is called with the sampled value and the original args, no rng.
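The calling convention in the list above can be summarised with a hypothetical toy tracer in plain Julia. ToySiteTrace and toy_sample! are illustrative names, not ProbProg API, and the real tracer does far more; the sketch only shows who receives the rng, what logpdf is called with, and which sites add to the weight.

```julia
using Random

# Hypothetical toy tracer (NOT ProbProg's implementation) following the contract above.
mutable struct ToySiteTrace
    choices::Dict{Symbol,Any}
    weight::Float64
end
ToySiteTrace() = ToySiteTrace(Dict{Symbol,Any}(), 0.0)

function toy_sample!(trace::ToySiteTrace, rng, sampler, args...; symbol::Symbol, logpdf=nothing)
    x = sampler(rng, args...)                 # the sampler receives the rng
    trace.choices[symbol] = x                 # every site is recorded
    if logpdf !== nothing
        trace.weight += logpdf(x, args...)    # logpdf gets the sample and args, no rng
    end
    return x
end

normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
normal_logpdf(x, μ, σ, _) =
    -length(x) * log(σ) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* σ .^ 2))

tr = ToySiteTrace()
rng = Xoshiro(0)
a = toy_sample!(tr, rng, normal, 0.0, 1.0, (3,); symbol=:a, logpdf=normal_logpdf)
b = toy_sample!(tr, rng, normal, 0.0, 1.0, (3,); symbol=:b)   # traced, adds no log-density
```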

Example:

```julia
# Sampler: draws from N(μ, σ²) with the given shape
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)

# Log-density of length(x) iid normals with scalar σ; the last argument (shape) is unused
function normal_logpdf(x, μ, σ, _)
    return -length(x) * log(σ) - length(x) / 2 * log(2π) -
           sum((x .- μ) .^ 2 ./ (2 .* σ .^ 2))
end

# Linear regression: priors on slope and intercept, Gaussian likelihood for ys
function model(rng, xs)
    _, slope = ProbProg.sample(
        rng, normal, 0.0, 2.0, (1,);
        symbol=:slope, logpdf=normal_logpdf,
    )
    _, intercept = ProbProg.sample(
        rng, normal, 0.0, 10.0, (1,);
        symbol=:intercept, logpdf=normal_logpdf,
    )
    _, ys = ProbProg.sample(
        rng, normal, slope .* xs .+ intercept, 1.0, (length(xs),);
        symbol=:ys, logpdf=normal_logpdf,
    )
    return ys
end
```
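The closed form in normal_logpdf is the sum of n independent normal log-densities, −n log σ − (n/2) log 2π − Σᵢ (xᵢ − μᵢ)² / (2σ²), which assumes a scalar σ as in the model. A quick plain-Julia check against an elementwise sum, with a mean vector built like slope .* xs .+ intercept (the numbers are arbitrary):

```julia
# Closed-form log-density of n iid N(μᵢ, σ²) draws with scalar σ
normal_logpdf(x, μ, σ, _) =
    -length(x) * log(σ) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* σ .^ 2))

# Reference: sum of per-element log N(xᵢ | μᵢ, σ²)
elementwise_logpdf(x, μ, σ) = sum(@. -log(σ) - log(2π) / 2 - (x - μ)^2 / (2 * σ^2))

xs = [1.0, 2.0, 3.0]
μ  = 0.5 .* xs .+ 1.0          # like slope .* xs .+ intercept
ys = [1.2, 2.1, 2.4]

println(normal_logpdf(ys, μ, 1.0, (3,)) ≈ elementwise_logpdf(ys, μ, 1.0))   # true
```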

support and bounds

Before proposing, HMC and NUTS map each site to an unconstrained space according to its support:

| support | Meaning |
|---|---|
| :real | Unconstrained (default) |
| :positive | x > 0 |
| :unit_interval | x ∈ (0, 1) |
| :interval | x ∈ (lower, upper), set via bounds |
| :greater_than | x > lower |
| :less_than | x < upper |
| :simplex | Probability simplex |
| :lower_cholesky | Lower-triangular Cholesky factor |

Pass bounds = (lower, upper) for the bounded supports (:interval, :greater_than, :less_than); either endpoint can be nothing.

```julia
# Bounded on both sides: x ∈ (0, 1)
ProbProg.sample(
    rng, my_sampler, args...;
    symbol=:my_site, logpdf=my_logpdf,
    support=:interval, bounds=(0.0, 1.0),
)

# Bounded below only: x > 0.5
ProbProg.sample(
    rng, my_sampler, args...;
    symbol=:my_site, logpdf=my_logpdf,
    support=:greater_than, bounds=(0.5, nothing),
)
```

Built-in distributions set support automatically.
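For intuition about what "unconstrain" means, the sketch below gives the usual textbook bijections from an unconstrained y ∈ ℝ to each scalar support, together with log|dx/dy|, the Jacobian term a sampler adds to the log-density when it works in y-space. This is an assumption-labelled illustration in plain Julia: ProbProg's exact transforms are not stated here, and :simplex and :lower_cholesky are omitted. constrain is a hypothetical name.

```julia
# Illustrative constraining bijections y ∈ ℝ ↦ x, each returning (x, log|dx/dy|).
# Assumption: standard textbook choices; ProbProg's actual transforms may differ.
logistic(y) = 1 / (1 + exp(-y))

constrain(::Val{:real}, y, bounds) = (y, 0.0)
constrain(::Val{:positive}, y, bounds) = (exp(y), y)             # x = exp(y)

function constrain(::Val{:unit_interval}, y, bounds)             # x = logistic(y)
    s = logistic(y)
    return s, log(s) + log1p(-s)
end

function constrain(::Val{:interval}, y, bounds)                  # scaled logistic
    lo, hi = bounds
    s = logistic(y)
    return lo + (hi - lo) * s, log(hi - lo) + log(s) + log1p(-s)
end

constrain(::Val{:greater_than}, y, bounds) = (bounds[1] + exp(y), y)   # x = lower + exp(y)
constrain(::Val{:less_than}, y, bounds) = (bounds[2] - exp(y), y)      # x = upper - exp(y)

# Front end taking a support symbol and a (lower, upper) tuple
constrain(support::Symbol, y, bounds=(nothing, nothing)) = constrain(Val(support), y, bounds)

x, logjac = constrain(:interval, 0.3, (0.0, 1.0))
```

Any y proposed by the sampler maps to a valid x, which is why the bounds only need to be declared once at the sample site.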

Submodels

A sampler that itself calls ProbProg.sample yields nested traces. Inner sites become child addresses under the outer symbol:

```julia
function inner(rng, μ, σ, shape)
    _, a = ProbProg.sample(rng, ProbProg.Normal(μ, σ, shape); symbol=:a)
    _, b = ProbProg.sample(rng, ProbProg.Normal(μ, σ, shape); symbol=:b)
    return a .* b
end

function outer(rng, μ, σ, shape)
    _, s = ProbProg.sample(rng, inner, μ, σ, shape; symbol=:s)
    _, t = ProbProg.sample(rng, inner, s, σ, shape; symbol=:t)
    return t
end
```

The resulting trace exposes trace.subtraces[:s].choices[:a], etc.
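The nesting can be pictured with a hypothetical toy trace in plain Julia that reuses the field names choices and subtraces. It is not ProbProg's trace type and the stored values are placeholders; it only shows how the inner sites :a and :b end up under the outer symbols :s and :t, and how the full addresses flatten out.

```julia
# Hypothetical toy trace (NOT ProbProg's type) reusing the field names from the text.
struct ToyNestedTrace
    choices::Dict{Symbol,Any}
    subtraces::Dict{Symbol,ToyNestedTrace}
end
ToyNestedTrace() = ToyNestedTrace(Dict{Symbol,Any}(), Dict{Symbol,ToyNestedTrace}())

# Flatten a trace into full addresses such as (:s, :a)
function addresses(tr::ToyNestedTrace, prefix::Tuple = ())
    out = Tuple[]
    for k in keys(tr.choices)
        push!(out, (prefix..., k))
    end
    for (k, sub) in tr.subtraces
        append!(out, addresses(sub, (prefix..., k)))
    end
    return out
end

# Shape of a trace for `outer`: two calls of `inner` (:s, :t), each with sites :a and :b
trace = ToyNestedTrace()
for site in (:s, :t)
    sub = ToyNestedTrace()
    sub.choices[:a] = zeros(3)   # placeholder values
    sub.choices[:b] = zeros(3)
    trace.subtraces[site] = sub
end

trace.subtraces[:s].choices[:a]          # same access path as in the text
println(sort(addresses(trace); by=string))
```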

untraced_call calls a probabilistic function without recording its choices:

```julia
ProbProg.untraced_call(rng, inner, μ, σ, shape)
```

Inspecting IR

The unoptimised form shows the raw impulse.sample ops:

```julia
@code_hlo optimize=false ProbProg.sample(rng, ProbProg.Normal(μ, σ, (10,)))
```

Lowered form:

```julia
@code_hlo optimize=:probprog ProbProg.untraced_call(rng, model, μ, σ, (10,))
```

Next: traces and constrained inference.