Automatic Differentiation
Reactant integrates with Enzyme.jl to provide high-performance automatic differentiation (AD). This tutorial walks through computing derivatives and gradients with Enzyme.jl inside Reactant-compiled functions, using both forward- and reverse-mode AD.
using Reactant, Enzyme, Random
Forward Mode Automatic Differentiation
Basic Forward Mode
# Define a simple function
square(x) = x .^ 2
# Create input data
x = Reactant.to_rarray(Float32[3.0, 4.0, 5.0])
function sq_fwd(x)
    return Enzyme.autodiff(Forward, square, Duplicated(x, fill!(similar(x), true)))[1]
end
# Compute the derivative with forward-mode AD using the Duplicated activity
result = @jit sq_fwd(x)
println("Result: ", result)
Result: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[6.0, 8.0, 10.0])
The Duplicated activity type pairs the primal input with a tangent (seed) array. Here the seed is all ones; since square acts elementwise, the result is the elementwise derivative 2x. Plain Forward mode returns only the derivative; use ForwardWithPrimal (below) to also get the primal value.
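The tangent seed does not have to be all ones. Seeding Duplicated with an arbitrary direction v gives the directional derivative (the Jacobian-vector product J(x) * v) in a single pass. A minimal sketch, where the direction v is a hypothetical example input:
function sq_fwd_dir(x, v)
    return Enzyme.autodiff(Forward, square, Duplicated(x, v))[1]
end
v = Reactant.to_rarray(Float32[1.0, 0.0, -1.0])
# For the elementwise square, J(x) * v = 2 .* x .* v, so this should give [6.0, 0.0, -10.0]
println("Directional derivative: ", @jit sq_fwd_dir(x, v))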
Forward Mode with Primal
You can also get both the function value and its derivative:
# Forward mode with primal value
function sq_fwd_primal(x)
    return Enzyme.autodiff(
        ForwardWithPrimal, square, Duplicated(x, fill!(similar(x), true))
    )
end
tangent, primal = @jit sq_fwd_primal(x)
println("Primal: ", primal)
println("Tangent: ", tangent)
Primal: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[9.0, 16.0, 25.0])
Tangent: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[6.0, 8.0, 10.0])
Computing Gradients
For scalar-valued functions, Enzyme.gradient computes the full gradient. In forward mode this is built from one directional derivative per input component (seeded with one-hot tangents), so its cost grows with the input dimension:
sum_squares(x) = sum(abs2, x)
x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0])
# Compute gradient using forward mode
grad_result = @jit Enzyme.gradient(Forward, sum_squares, x)
println("Gradient: ", grad_result[1])
┌ Warning: `Adapt.parent_type` is not implemented for Enzyme.TupleArray{Reactant.TracedRNumber{Float32}, (3,), 3, 1}. Assuming Enzyme.TupleArray{Reactant.TracedRNumber{Float32}, (3,), 3, 1} isn't a wrapped array.
└ @ Reactant ~/work/Reactant.jl/Reactant.jl/src/Reactant.jl:50
Gradient: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}[ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(4.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(6.0f0)]
Reverse Mode Automatic Differentiation
Basic Reverse Mode
loss_function(x) = sum(x .^ 3)
x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0])
# Compute gradient using reverse mode
grad = @jit Enzyme.gradient(Reverse, loss_function, x)
println("Gradient: ", grad[1])
Gradient: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[3.0, 12.0, 27.0])
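As a quick sanity check, the gradient of sum(x .^ 3) is 3x², which you can compare against the result after materializing it on the host with Array (the same conversion used later in this tutorial):
# Expected gradient: 3 .* x .^ 2 = [3.0, 12.0, 27.0]
println(Array(grad[1]) ≈ 3 .* Float32[1.0, 2.0, 3.0] .^ 2)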
Reverse Mode with Primal
Get both the function value and gradient:
# Reverse mode with primal
result = @jit Enzyme.gradient(ReverseWithPrimal, loss_function, x)
println("Value: ", result.val)
println("Gradient: ", result.derivs[1])
Value: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(36.0f0)
Gradient: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[3.0, 12.0, 27.0])
More Examples
Multi-argument Functions
function multi_arg_func(x, y)
    return sum(x .* y .^ 2)
end
x = Reactant.to_rarray(Float32[1.0, 2.0])
y = Reactant.to_rarray(Float32[3.0, 4.0])
# Gradient w.r.t. both arguments
grad = @jit Enzyme.gradient(Reverse, multi_arg_func, x, y)
println("∂f/∂x: ", grad[1])
println("∂f/∂y: ", grad[2])
∂f/∂x: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[9.0, 16.0])
∂f/∂y: ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[6.0, 16.0])
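If you only need the derivative with respect to some arguments, wrap the others in Const so Enzyme treats them as constants; a minimal sketch reusing the function above:
# Gradient w.r.t. x only; y is held constant
grad_x_only = @jit Enzyme.gradient(Reverse, multi_arg_func, x, Const(y))
println("∂f/∂x: ", grad_x_only[1])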
Vector Mode AD
Vector (batch) mode propagates several tangent directions through a single forward pass. Seeding with the one-hot basis vectors recovers every component of the gradient:
vector_func(x) = sum(abs2, x)
x = Reactant.to_rarray(collect(Float32, 1:4))
# Create one-hot seed vectors, one per input component
onehot_vectors = @jit Enzyme.onehot(x)
# Vector forward mode
result = @jit Enzyme.autodiff(Forward, vector_func, BatchDuplicated(x, onehot_vectors))
println("Vector gradients: ", result[1])
Vector gradients: (ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(4.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(6.0f0), ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(8.0f0))
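Each entry of the returned tuple is the directional derivative along one of the one-hot seeds, i.e. one component of the gradient 2x. Broadcasting Float32 over the tuple converts the entries to plain host numbers (a small sketch):
# Convert the per-direction derivatives to ordinary Float32 values
println(Float32.(result[1]))  # expected (2.0f0, 4.0f0, 6.0f0, 8.0f0)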
Nested Automatic Differentiation
Compute higher-order derivatives:
power4(x) = x^4
x = Reactant.ConcreteRNumber(3.1)
# First derivative
first_deriv(x) = Enzyme.gradient(Reverse, power4, x)[1]
# Second derivative
second_deriv(x) = Enzyme.gradient(Reverse, first_deriv, x)[1]
# Compute second derivative
result = @jit second_deriv(x)
result_enz = second_deriv(Float32(x))  # reference value from plain Enzyme on a host Float32, without Reactant compilation
println("Second derivative: ", result)
Second derivative: ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(115.32000000000002)
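For reference, the analytic second derivative of x^4 is 12x², so at x = 3.1 the expected value is:
println(12 * 3.1^2)  # ≈ 115.32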
Division by Zero with Strong Zero
Functions with singularities can produce NaN gradients even when the primal value is finite. Here min(1.0, 1 / x) evaluates to 1.0 at x = 0, but the chain rule still multiplies a zero derivative by the infinite derivative of 1 / x:
div_by_zero(x) = min(1.0, 1 / x)
x = Reactant.ConcreteRNumber(0.0)
# Regular gradient: the chain rule produces a 0 * Inf product, which evaluates to NaN
regular_grad = @jit Enzyme.gradient(Reverse, div_by_zero, x)
# With strong zero, products with a zero derivative are forced to zero, even against Inf or NaN
strong_zero_grad = @jit Enzyme.gradient(Enzyme.set_strong_zero(Reverse), div_by_zero, x)
println("Regular gradient: ", Float32(regular_grad[1]))
println("Strong zero gradient: ", Float32(strong_zero_grad[1]))
Regular gradient: NaN
Strong zero gradient: -0.0
Ignoring Derivatives
Use Reactant.ignore_derivatives to exclude part of a computation from the gradient:
function func_with_ignore(x)
    # This term is excluded from the gradient
    ignored_sum = Reactant.ignore_derivatives(sum(x))
    # This term contributes to the gradient
    return sum(x .^ 2) + ignored_sum
end
x = Reactant.to_rarray([1.0, 2.0, 3.0])
grad = @jit Enzyme.gradient(Reverse, func_with_ignore, x)
println("Gradient: ", grad[1])
Gradient: ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([2.0, 4.0, 6.0])
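For comparison, differentiating the same expression without ignore_derivatives includes the sum(x) term, so the gradient becomes 2x .+ 1. A quick sketch using an anonymous function:
grad_full = @jit Enzyme.gradient(Reverse, x -> sum(x .^ 2) + sum(x), x)
println("Gradient without ignore: ", grad_full[1])  # expected [3.0, 5.0, 7.0]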
Complex Numbers and Special Arrays
Reactant supports complex numbers and various array types. For a real-valued function of complex inputs such as sum(abs2, z), the gradient returned below is 2z, i.e. the real- and imaginary-part derivatives (2a, 2b) packed into a complex number:
# Complex arrays
x_complex = Reactant.to_rarray([1.0 + 2.0im, 3.0 + 4.0im])
function complex_func(z)
    return sum(abs2, z)
end
grad_complex = @jit Enzyme.gradient(Reverse, complex_func, x_complex)
println("Complex gradient: ", grad_complex[1])
Complex gradient: ConcretePJRTArray{ComplexF64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(ComplexF64[2.0 + 4.0im, 6.0 + 8.0im])
Complete Example: Neural Network Training
Here's a complete example of training a simple neural network:
# Define network
function neural_net(x, w1, w2, b1, b2)
    h = tanh.(w1 * x .+ b1)
    return w2 * h .+ b2
end
# Loss function
function loss(x, y, w1, w2, b1, b2)
    pred = neural_net(x, w1, w2, b1, b2)
    return sum(abs2, pred .- y)
end
# Generate data
x = Reactant.to_rarray(rand(Float32, 10, 32))
y = Reactant.to_rarray(2 .* sum(abs2, Array(x); dims=1) .+ rand(Float32, 1, 32) .* 0.001f0)
# Initialize parameters
w1 = Reactant.to_rarray(rand(Float32, 20, 10))
w2 = Reactant.to_rarray(rand(Float32, 1, 20))
b1 = Reactant.to_rarray(rand(Float32, 20))
b2 = Reactant.to_rarray(rand(Float32, 1))
# Training step
function train_step(x, y, w1, w2, b1, b2, lr)
    # Compute the loss and gradients; x and y are wrapped in Const so no gradients are taken w.r.t. the data
    (; val, derivs) = Enzyme.gradient(
        ReverseWithPrimal, loss, Const(x), Const(y), w1, w2, b1, b2
    )
    # Update parameters (simple gradient descent)
    w1 .-= lr .* derivs[3]
    w2 .-= lr .* derivs[4]
    b1 .-= lr .* derivs[5]
    b2 .-= lr .* derivs[6]
    return val, w1, w2, b1, b2
end
# Training loop: compile the training step once; the compiled executable is reused for every epoch without retracing
compiled_train_step = @compile train_step(x, y, w1, w2, b1, b2, 0.001f0)
for epoch in 1:100
    global w1, w2, b1, b2
    loss_val, w1, w2, b1, b2 = compiled_train_step(x, y, w1, w2, b1, b2, 0.001f0)
    if epoch % 10 == 0
        @info "Epoch: $epoch, Loss: $loss_val"
    end
end
println("Training completed!")
[ Info: Epoch: 10, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(151.27728f0)
[ Info: Epoch: 20, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(149.66284f0)
[ Info: Epoch: 30, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(147.34663f0)
[ Info: Epoch: 40, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(143.70265f0)
[ Info: Epoch: 50, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(137.2854f0)
[ Info: Epoch: 60, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(125.445015f0)
[ Info: Epoch: 70, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(107.30259f0)
[ Info: Epoch: 80, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(83.6433f0)
[ Info: Epoch: 90, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(59.056812f0)
[ Info: Epoch: 100, Loss: ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(40.826714f0)
Training completed!
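After training you can evaluate the trained parameters by compiling the forward computation in the same way; a minimal check reusing the loss function defined above:
final_loss = @jit loss(x, y, w1, w2, b1, b2)
println("Final loss: ", final_loss)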