Automatic Differentiation
Reactant integrates seamlessly with Enzyme.jl to provide high-performance automatic differentiation (AD). This tutorial walks through using Enzyme.jl with Reactant to compute derivatives in both forward and reverse mode.
using Reactant, Enzyme, Random
Forward Mode Automatic Differentiation
Basic Forward Mode
# Define a simple function
square(x) = x .^ 2
# Create input data
x = Reactant.to_rarray(Float32[3.0, 4.0, 5.0])
function sq_fwd(x)
    return Enzyme.autodiff(Forward, square, Duplicated(x, fill!(similar(x), true)))[1]
end
# Compute the derivative with forward-mode AD (Duplicated activity)
result = @jit sq_fwd(x)
println("Result: ", result)
Result: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[6.0, 8.0, 10.0])
The Duplicated activity type pairs the primal input with a tangent (shadow) value, so the derivative is propagated alongside the primal computation.
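The tangent you pass in is the seed direction: seeding with a one-hot vector instead of all ones gives a Jacobian-vector product along that direction. A minimal sketch assuming the same @jit workflow as above (sq_jvp and v are illustrative names):
function sq_jvp(x, v)
    # Directional derivative of `square` at `x` along `v`
    return Enzyme.autodiff(Forward, square, Duplicated(x, v))[1]
end
v = Reactant.to_rarray(Float32[1.0, 0.0, 0.0])
jvp = @jit sq_jvp(x, v)
println("JVP: ", jvp)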
Forward Mode with Primal
You can also get both the function value and its derivative:
# Forward mode with primal value
function sq_fwd_primal(x)
    return Enzyme.autodiff(
        ForwardWithPrimal, square, Duplicated(x, fill!(similar(x), true))
    )
end
tangent, primal = @jit sq_fwd_primal(x)
println("Primal: ", primal)
println("Tangent: ", tangent)
Primal: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[9.0, 16.0, 25.0])
Tangent: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[6.0, 8.0, 10.0])
Computing Gradients
For computing gradients of scalar-valued functions:
sum_squares(x) = sum(abs2, x)
x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0])
# Compute gradient using forward mode
grad_result = @jit Enzyme.gradient(Forward, sum_squares, x)
println("Gradient: ", grad_result[1])
Gradient: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}[ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(4.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(6.0f0)]
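For vector-valued functions, Enzyme.jacobian can be used analogously; the following is a hedged sketch that assumes Enzyme.jacobian composes with @jit the same way Enzyme.gradient does (vector_valued and xj are illustrative names):
# Element-wise function, so the Jacobian is diagonal
vector_valued(x) = x .^ 2 .+ 2 .* x
xj = Reactant.to_rarray(Float32[1.0, 2.0, 3.0])
J = @jit Enzyme.jacobian(Forward, vector_valued, xj)
println("Jacobian: ", J[1])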
Reverse Mode Automatic Differentiation
Basic Reverse Mode
loss_function(x) = sum(x .^ 3)
x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0])
# Compute gradient using reverse mode
grad = @jit Enzyme.gradient(Reverse, loss_function, x)
println("Gradient: ", grad[1])
Gradient: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[3.0, 12.0, 27.0])
Reverse Mode with Primal
Get both the function value and gradient:
# Reverse mode with primal
result = @jit Enzyme.gradient(ReverseWithPrimal, loss_function, x)
println("Value: ", result.val)
println("Gradient: ", result.derivs[1])
Value: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(36.0f0)
Gradient: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[3.0, 12.0, 27.0])
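Enzyme.gradient is built on the lower-level Enzyme.autodiff API, which you can also call directly by passing a zero-initialized shadow array that the reverse pass accumulates the gradient into. A minimal sketch, assuming the same @jit workflow (loss_grad_lowlevel is an illustrative name):
function loss_grad_lowlevel(x)
    # Zero-initialized shadow (cotangent) buffer, mirroring the fill!(similar(x), ...) idiom above
    dx = fill!(similar(x), 0)
    Enzyme.autodiff(Reverse, loss_function, Active, Duplicated(x, dx))
    return dx
end
dx = @jit loss_grad_lowlevel(x)
println("Gradient via autodiff: ", dx)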
More Examples
Multi-argument Functions
function multi_arg_func(x, y)
    return sum(x .* y .^ 2)
end
x = Reactant.to_rarray(Float32[1.0, 2.0])
y = Reactant.to_rarray(Float32[3.0, 4.0])
# Gradient w.r.t. both arguments
grad = @jit Enzyme.gradient(Reverse, multi_arg_func, x, y)
println("∂f/∂x: ", grad[1])
println("∂f/∂y: ", grad[2])
∂f/∂x: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[9.0, 16.0])
∂f/∂y: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[6.0, 16.0])
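To differentiate with respect to only some arguments, wrap the others in Const, just as the neural network training example later in this tutorial does:
# Gradient w.r.t. `x` only; `y` is held constant
grad_x_only = @jit Enzyme.gradient(Reverse, multi_arg_func, x, Const(y))
println("∂f/∂x (y constant): ", grad_x_only[1])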
Vector Mode AD
Vector mode computes multiple derivatives simultaneously:
vector_func(x) = sum(abs2, x)
x = Reactant.to_rarray(collect(Float32, 1:4))
# Create onehot vectors for vector mode
onehot_vectors = @jit Enzyme.onehot(x)
# Vector forward mode
result = @jit Enzyme.autodiff(Forward, vector_func, BatchDuplicated(x, onehot_vectors))
println("Vector gradients: ", result[1])
Vector gradients: (ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(4.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(6.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(8.0f0))
Nested Automatic Differentiation
Compute higher-order derivatives:
power4(x) = x^4
x = Reactant.ConcreteRNumber(3.1)
# First derivative
first_deriv(x) = Enzyme.gradient(Reverse, power4, x)[1]
# Second derivative
second_deriv(x) = Enzyme.gradient(Reverse, first_deriv, x)[1]
# Compute second derivative
result = @jit second_deriv(x)
# For comparison: the same nested derivative with plain Enzyme on a regular Float32 value
result_enz = second_deriv(Float32(x))
println("Second derivative: ", result)
Second derivative: ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(115.32000000000002)
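Reverse-over-reverse is not the only option; forward-over-reverse is a common alternative for second derivatives. A hedged sketch reusing first_deriv from above, assuming Enzyme.gradient(Forward, ...) handles the scalar input the same way (second_deriv_for is an illustrative name):
second_deriv_for(x) = Enzyme.gradient(Forward, first_deriv, x)[1]
result_for = @jit second_deriv_for(x)
println("Second derivative (forward-over-reverse): ", result_for)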
Division by Zero with Strong Zero
At x = 0 the unselected 1 / x branch receives a zero cotangent, but its own derivative -1/x^2 is infinite, and 0 * Inf produces NaN. Enzyme's strong-zero mode forces such products to zero instead:
div_by_zero(x) = min(1.0, 1 / x)
x = Reactant.ConcreteRNumber(0.0)
# Regular gradient (may be NaN)
regular_grad = @jit Enzyme.gradient(Reverse, div_by_zero, x)
# Strong zero gradient (handles singularities better)
strong_zero_grad = @jit Enzyme.gradient(Enzyme.set_strong_zero(Reverse), div_by_zero, x)
println("Regular gradient: ", Float32(regular_grad[1]))
println("Strong zero gradient: ", Float32(strong_zero_grad[1]))
Regular gradient: NaN
Strong zero gradient: -0.0
Ignoring Derivatives
Use EnzymeCore.ignore_derivatives to exclude parts of the computation from the gradient:
function func_with_ignore(x)
    # This part won't contribute to the gradient
    ignored_sum = Enzyme.ignore_derivatives(sum(x))
    # This part will contribute
    return sum(x .^ 2) + ignored_sum
end
x = Reactant.to_rarray([1.0, 2.0, 3.0])
grad = @jit Enzyme.gradient(Reverse, func_with_ignore, x)
println("Gradient: ", grad[1])
Gradient: ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([2.0, 4.0, 6.0])
Complex Numbers and Special Arrays
Reactant supports complex numbers and various array types:
# Complex arrays
x_complex = Reactant.to_rarray([1.0 + 2.0im, 3.0 + 4.0im])
function complex_func(z)
    return sum(abs2, z)
end
grad_complex = @jit Enzyme.gradient(Reverse, complex_func, x_complex)
println("Complex gradient: ", grad_complex[1])
Complex gradient: ConcretePJRTArray{ComplexF64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(ComplexF64[2.0 + 4.0im, 6.0 + 8.0im])
Loops
When performing computations in a loop with the @trace macro (see the Control Flow tutorial), Enzyme has to save intermediate results during the primal computation so they can be reused in the reverse pass.
The @trace macro offers two parameters to limit the amount of memory that has to be saved. The first is the mincut option, which strictly reduces the amount of saved memory by storing only the minimal set of values needed for each iteration and recomputing the rest when required.
The second is the checkpointing option, which saves the loop-carried values only every √N iterations (for a loop of N iterations) and recomputes the intermediate iterations during the reverse pass:
function f_checkpointing(x, enable_checkpointing)
    y = copy(x)
    # The intermediary values of y will need to be cached
    # to be reused in the reverse pass. With checkpointing enabled,
    # the cache will be of size `Int(sqrt(9)) * length(y)` instead of
    # `9 * length(y)`.
    @trace checkpointing=enable_checkpointing for i in 1:9
        y .*= x
    end
    return y
end
f_checkpointing_diff(x, enable_checkpointing) =
    Enzyme.gradient(Reverse, f_checkpointing, x, Const(enable_checkpointing))
f_checkpointing_diff (generic function with 1 method)
Note
The currently implemented checkpointing scheme only supports a constant number of iterations that has an integer square root.
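The mincut option is enabled in the same way; the following is a minimal sketch that assumes it uses the same keyword syntax as checkpointing above and mirrors the differentiation pattern of f_checkpointing_diff (f_mincut and f_mincut_diff are illustrative names):
function f_mincut(x)
    y = copy(x)
    # Only the minimal set of intermediate values is saved for the reverse pass;
    # everything else is recomputed when needed.
    @trace mincut=true for i in 1:9
        y .*= x
    end
    return y
end
f_mincut_diff(x) = Enzyme.gradient(Reverse, f_mincut, x)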
Complete Example: Neural Network Training
Training Lux Neural Networks
If you are using Lux.jl for neural networks, prefer the TrainState API, which abstracts away many of these details.
Here's a complete example of training a simple neural network:
# Define network
function neural_net(x, w1, w2, b1, b2)
    h = tanh.(w1 * x .+ b1)
    return w2 * h .+ b2
end
# Loss function
function loss(x, y, w1, w2, b1, b2)
    pred = neural_net(x, w1, w2, b1, b2)
    return sum(abs2, pred .- y)
end
# Generate data
x = Reactant.to_rarray(rand(Float32, 10, 32))
y = Reactant.to_rarray(2 .* sum(abs2, Array(x); dims=1) .+ rand(Float32, 1, 32) .* 0.001f0)
# Initialize parameters
w1 = Reactant.to_rarray(rand(Float32, 20, 10))
w2 = Reactant.to_rarray(rand(Float32, 1, 20))
b1 = Reactant.to_rarray(rand(Float32, 20))
b2 = Reactant.to_rarray(rand(Float32, 1))
# Training step
function train_step(x, y, w1, w2, b1, b2, lr)
    # Compute gradients
    (; val, derivs) = Enzyme.gradient(
        ReverseWithPrimal, loss, Const(x), Const(y), w1, w2, b1, b2
    )
    # Update parameters (simple gradient descent)
    w1 .-= lr .* derivs[3]
    w2 .-= lr .* derivs[4]
    b1 .-= lr .* derivs[5]
    b2 .-= lr .* derivs[6]
    return val, w1, w2, b1, b2
end
# Training loop
compiled_train_step = @compile train_step(x, y, w1, w2, b1, b2, 0.001f0)
for epoch in 1:100
    global w1, w2, b1, b2
    loss_val, w1, w2, b1, b2 = compiled_train_step(x, y, w1, w2, b1, b2, 0.001f0)
    if epoch % 10 == 0
        @info "Epoch: $epoch, Loss: $loss_val"
    end
end
println("Training completed!")
[ Info: Epoch: 10, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(128.3026f0)
[ Info: Epoch: 20, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(126.06101f0)
[ Info: Epoch: 30, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(122.12404f0)
[ Info: Epoch: 40, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(113.871796f0)
[ Info: Epoch: 50, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(98.97273f0)
[ Info: Epoch: 60, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(85.462265f0)
[ Info: Epoch: 70, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(72.88088f0)
[ Info: Epoch: 80, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(60.875923f0)
[ Info: Epoch: 90, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(50.332005f0)
[ Info: Epoch: 100, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(41.42317f0)
Training completed!