Automatic Differentiation

Reactant integrates with Enzyme.jl to provide high-performance automatic differentiation (AD). This tutorial walks through using Enzyme.jl with Reactant to compute derivatives with both forward- and reverse-mode AD.

julia
using Reactant, Enzyme, Random

Forward Mode Automatic Differentiation

Basic Forward Mode

julia
# Define a simple function
square(x) = x .^ 2

# Create input data
x = Reactant.to_rarray(Float32[3.0, 4.0, 5.0])

function sq_fwd(x)
    return Enzyme.autodiff(Forward, square, Duplicated(x, fill!(similar(x), true)))[1]
end

# Compute the derivative with forward mode (Duplicated activity)
result = @jit sq_fwd(x)

println("Result: ", result)
Result: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[6.0, 8.0, 10.0])

The Duplicated activity pairs the input with a tangent seed (here fill!(similar(x), true), i.e. all ones); forward mode propagates that tangent through the computation and returns the resulting derivative.
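
The tangent passed to Duplicated is the seed of the directional derivative (a Jacobian-vector product). Seeding with ones, as above, gives the elementwise derivative; any direction can be used instead. A minimal sketch, with a made-up direction v:

julia
# Jacobian-vector product of square along the direction v
function sq_jvp(x, v)
    return Enzyme.autodiff(Forward, square, Duplicated(x, v))[1]
end

v = Reactant.to_rarray(Float32[1.0, 0.0, 0.0])

jvp = @jit sq_jvp(x, v)
println("JVP along v: ", jvp)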

Forward Mode with Primal

You can also get both the function value and its derivative:

julia
# Forward mode with primal value
function sq_fwd_primal(x)
    return Enzyme.autodiff(
        ForwardWithPrimal, square, Duplicated(x, fill!(similar(x), true))
    )
end

tangent, primal = @jit sq_fwd_primal(x)

println("Primal: ", primal)
println("Tangent: ", tangent)
Primal: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[9.0, 16.0, 25.0])
Tangent: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[6.0, 8.0, 10.0])

Computing Gradients

For computing gradients of scalar-valued functions:

julia
sum_squares(x) = sum(abs2, x)

x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0])

# Compute gradient using forward mode
grad_result = @jit Enzyme.gradient(Forward, sum_squares, x)

println("Gradient: ", grad_result[1])
Warning: `Adapt.parent_type` is not implemented for Enzyme.TupleArray{Reactant.TracedRNumber{Float32}, (3,), 3, 1}. Assuming Enzyme.TupleArray{Reactant.TracedRNumber{Float32}, (3,), 3, 1} isn't a wrapped array.
@ Reactant ~/work/Reactant.jl/Reactant.jl/src/Reactant.jl:65
Gradient: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}[ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(4.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(6.0f0)]
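
For scalar-valued functions with many inputs, reverse mode (covered in the next section) computes the whole gradient in a single pass and is usually the cheaper choice. As a quick cross-check, the same gradient can be computed with Reverse:

julia
# Same gradient via reverse mode; both give 2 .* x
grad_rev = @jit Enzyme.gradient(Reverse, sum_squares, x)
println("Reverse-mode gradient: ", grad_rev[1])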

Reverse Mode Automatic Differentiation

Basic Reverse Mode

julia
loss_function(x) = sum(x .^ 3)

x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0])

# Compute gradient using reverse mode
grad = @jit Enzyme.gradient(Reverse, loss_function, x)

println("Gradient: ", grad[1])
Gradient: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[3.0, 12.0, 27.0])
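
Enzyme.gradient is a convenience wrapper; the same gradient can also be obtained from the lower-level Enzyme.autodiff call by passing a zero-initialized shadow array that the reverse pass accumulates into. A minimal sketch of that pattern:

julia
function loss_grad(x)
    # Zero-initialized shadow; the reverse pass accumulates d(loss)/dx into it
    dx = fill!(similar(x), false)
    Enzyme.autodiff(Reverse, loss_function, Active, Duplicated(x, dx))
    return dx
end

grad_lowlevel = @jit loss_grad(x)
println("Gradient (autodiff): ", grad_lowlevel)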

Reverse Mode with Primal

Get both the function value and gradient:

julia
# Reverse mode with primal
result = @jit Enzyme.gradient(ReverseWithPrimal, loss_function, x)

println("Value: ", result.val)
println("Gradient: ", result.derivs[1])
Value: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(36.0f0)
Gradient: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[3.0, 12.0, 27.0])
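
The returned object has val and derivs fields, so it can also be destructured directly; the training example at the end of this tutorial uses this pattern:

julia
(; val, derivs) = @jit Enzyme.gradient(ReverseWithPrimal, loss_function, x)
println("Value: ", val)
println("Gradient: ", derivs[1])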

More Examples

Multi-argument Functions

julia
function multi_arg_func(x, y)
    return sum(x .* y .^ 2)
end

x = Reactant.to_rarray(Float32[1.0, 2.0])
y = Reactant.to_rarray(Float32[3.0, 4.0])

# Gradient w.r.t. both arguments
grad = @jit Enzyme.gradient(Reverse, multi_arg_func, x, y)

println("∂f/∂x: ", grad[1])
println("∂f/∂y: ", grad[2])
∂f/∂x: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[9.0, 16.0])
∂f/∂y: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[6.0, 16.0])
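
To differentiate with respect to only some of the arguments, wrap the others in Const (the same wrapper used in the training example below); no gradient is computed for constant arguments:

julia
# Gradient w.r.t. x only; y is held constant
grad_x_only = @jit Enzyme.gradient(Reverse, multi_arg_func, x, Const(y))

println("∂f/∂x: ", grad_x_only[1])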

Vector Mode AD

Vector (batched) forward mode propagates several tangent vectors through the function at once, computing multiple directional derivatives in a single pass:

julia
vector_func(x) = sum(abs2, x)

x = Reactant.to_rarray(collect(Float32, 1:4))

# Create onehot vectors for vector mode
onehot_vectors = @jit Enzyme.onehot(x)

# Vector forward mode
result = @jit Enzyme.autodiff(Forward, vector_func, BatchDuplicated(x, onehot_vectors))

println("Vector gradients: ", result[1])
Vector gradients: (ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(4.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(6.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(8.0f0))
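
BatchDuplicated is not limited to one-hot seeds: any tuple of directions can be propagated in a single pass, yielding the corresponding directional derivatives simultaneously. A minimal sketch with two made-up directions:

julia
v1 = Reactant.to_rarray(Float32[1.0, 0.0, 0.0, 0.0])
v2 = Reactant.to_rarray(Float32[0.0, 0.0, 1.0, 1.0])

# Two Jacobian-vector products computed in one forward pass
result2 = @jit Enzyme.autodiff(Forward, vector_func, BatchDuplicated(x, (v1, v2)))
println("Directional derivatives: ", result2[1])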

Nested Automatic Differentiation

Compute higher-order derivatives:

julia
power4(x) = x^4

x = Reactant.ConcreteRNumber(3.1)

# First derivative
first_deriv(x) = Enzyme.gradient(Reverse, power4, x)[1]

# Second derivative
second_deriv(x) = Enzyme.gradient(Reverse, first_deriv, x)[1]

# Compute the second derivative (12x^2, i.e. 115.32 at x = 3.1)
result = @jit second_deriv(x)
# For comparison: the same computation with plain Enzyme, without Reactant compilation
result_enz = second_deriv(Float32(x))
println("Second derivative: ", result)
Second derivative: ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(115.32000000000002)
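
The same nesting pattern extends to higher orders; a third derivative (24x for this function) is just one more level:

julia
third_deriv(x) = Enzyme.gradient(Reverse, second_deriv, x)[1]

result3 = @jit third_deriv(x)
println("Third derivative: ", result3)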

Division by Zero with Strong Zero

At x = 0, the min in the function below selects the constant branch, but the reverse pass still multiplies a zero cotangent by the infinite derivative of 1 / x, producing 0 * Inf = NaN. Enabling strong zero forces such products to zero.

julia
div_by_zero(x) = min(1.0, 1 / x)

x = Reactant.ConcreteRNumber(0.0)

# Regular gradient (may be NaN)
regular_grad = @jit Enzyme.gradient(Reverse, div_by_zero, x)

# Strong zero gradient (handles singularities better)
strong_zero_grad = @jit Enzyme.gradient(Enzyme.set_strong_zero(Reverse), div_by_zero, x)

println("Regular gradient: ", Float32(regular_grad[1]))
println("Strong zero gradient: ", Float32(strong_zero_grad[1]))
Regular gradient: NaN
Strong zero gradient: -0.0

Ignoring Derivatives

Use Enzyme.ignore_derivatives (defined in EnzymeCore) to exclude parts of the computation from the gradient:

julia
function func_with_ignore(x)
    # This part won't contribute to the gradient
    ignored_sum = Enzyme.ignore_derivatives(sum(x))
    # This part will contribute
    return sum(x .^ 2) + ignored_sum
end

x = Reactant.to_rarray([1.0, 2.0, 3.0])

grad = @jit Enzyme.gradient(Reverse, func_with_ignore, x)
println("Gradient: ", grad[1])
Gradient: ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([2.0, 4.0, 6.0])
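
A common use is a stop-gradient: a value that depends on x is used in the computation but treated as a constant during differentiation. A hypothetical sketch:

julia
function normalized_loss(x)
    # The normalization factor is treated as a constant by the gradient
    scale = Enzyme.ignore_derivatives(sum(abs2, x))
    return sum(x .^ 2) / scale
end

grad_norm = @jit Enzyme.gradient(Reverse, normalized_loss, x)
println("Gradient: ", grad_norm[1])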

Complex Numbers and Special Arrays

Reactant supports complex numbers and various array types:

julia
# Complex arrays
x_complex = Reactant.to_rarray([1.0 + 2.0im, 3.0 + 4.0im])

function complex_func(z)
    return sum(abs2, z)
end

grad_complex = @jit Enzyme.gradient(Reverse, complex_func, x_complex)
println("Complex gradient: ", grad_complex[1])
Complex gradient: ConcretePJRTArray{ComplexF64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(ComplexF64[2.0 + 4.0im, 6.0 + 8.0im])

Loops

When performing computations in a loop traced with the @trace macro (see the Control Flow tutorial), Enzyme has to save intermediate results during the primal computation so they can be reused in the reverse pass.

The @trace macro offers two options to limit how much memory has to be saved. The first is mincut, which strictly reduces the amount of saved memory by storing only the minimal set of values needed for each iteration and recomputing the rest when required.

The second is checkpointing, which saves the loop-carried values only at regular checkpoints (every √N iterations for a loop of N iterations) and recomputes the intermediate iterations during the reverse pass. This trades compute time for memory.

julia
function f_checkpointing(x, enable_checkpointing)
    y = copy(x)

    # The intermediate values of y need to be cached so they can be
    # reused in the reverse pass. With checkpointing enabled, the cache
    # has size `Int(sqrt(9)) * length(y)` instead of `9 * length(y)`.
    @trace checkpointing=enable_checkpointing for i in 1:9
        y .*= x
    end

    return y
end

f_checkpointing_diff(x, enable_checkpointing) =
    Enzyme.gradient(Reverse, f_checkpointing, x, Const(enable_checkpointing))
f_checkpointing_diff (generic function with 1 method)

Note

The currently implemented checkpointing scheme only supports a constant number of iterations with an integer square root. For a loop of N iterations, the values are cached √N times with checkpointing enabled, against N times when it is disabled.
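
The mincut option described above is enabled the same way (a sketch, assuming the keyword is spelled mincut as in the description; differentiation then proceeds exactly as in the checkpointing example):

julia
function f_mincut(x)
    y = copy(x)

    # Save only the minimal set of values needed by the reverse pass,
    # recomputing the rest when required.
    @trace mincut=true for i in 1:9
        y .*= x
    end

    return y
end

f_mincut_diff(x) = Enzyme.gradient(Reverse, f_mincut, x)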

Complete Example: Neural Network Training

Training Lux Neural Networks

If you are using Lux.jl for neural networks, prefer the TrainState API, which abstracts away most of these details.

Here's a complete example of training a simple neural network:

julia
# Define network
function neural_net(x, w1, w2, b1, b2)
    h = tanh.(w1 * x .+ b1)
    return w2 * h .+ b2
end

# Loss function
function loss(x, y, w1, w2, b1, b2)
    pred = neural_net(x, w1, w2, b1, b2)
    return sum(abs2, pred .- y)
end

# Generate data
x = Reactant.to_rarray(rand(Float32, 10, 32))
y = Reactant.to_rarray(2 .* sum(abs2, Array(x); dims=1) .+ rand(Float32, 1, 32) .* 0.001f0)

# Initialize parameters
w1 = Reactant.to_rarray(rand(Float32, 20, 10))
w2 = Reactant.to_rarray(rand(Float32, 1, 20))
b1 = Reactant.to_rarray(rand(Float32, 20))
b2 = Reactant.to_rarray(rand(Float32, 1))

# Training step
function train_step(x, y, w1, w2, b1, b2, lr)
    # Compute gradients
    (; val, derivs) = Enzyme.gradient(
        ReverseWithPrimal, loss, Const(x), Const(y), w1, w2, b1, b2
    )

    # Update parameters (simple gradient descent)
    w1 .-= lr .* derivs[3]
    w2 .-= lr .* derivs[4]
    b1 .-= lr .* derivs[5]
    b2 .-= lr .* derivs[6]

    return val, w1, w2, b1, b2
end

# Training loop
compiled_train_step = @compile train_step(x, y, w1, w2, b1, b2, 0.001f0)

for epoch in 1:100
    global w1, w2, b1, b2
    loss_val, w1, w2, b1, b2 = compiled_train_step(x, y, w1, w2, b1, b2, 0.001f0)
    if epoch % 10 == 0
        @info "Epoch: $epoch, Loss: $loss_val"
    end
end

println("Training completed!")
[ Info: Epoch: 10, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(128.3026f0)
[ Info: Epoch: 20, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(126.06101f0)
[ Info: Epoch: 30, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(122.12404f0)
[ Info: Epoch: 40, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(113.871796f0)
[ Info: Epoch: 50, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(98.97273f0)
[ Info: Epoch: 60, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(85.462265f0)
[ Info: Epoch: 70, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(72.88088f0)
[ Info: Epoch: 80, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(60.875923f0)
[ Info: Epoch: 90, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(50.332005f0)
[ Info: Epoch: 100, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(41.42317f0)
Training completed!
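
After training, the parameters are still Reactant device arrays. They can be copied back to ordinary Julia arrays (as already done for x in the data-generation step) for inspection or saving:

julia
w1_host = Array(w1)
b1_host = Array(b1)
println("First row of the trained w1: ", w1_host[1, :])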