Enzyme custom rules tutorial
The tutorial below focuses on a simple setting to illustrate the basic concepts of writing custom rules. For more complex custom rules beyond the scope of this tutorial, you may take inspiration from the following in-the-wild examples:
The goal of this tutorial is to give a simple example of defining a custom rule with Enzyme. Specifically, our goal will be to write custom rules for the following function f:
function f(y, x)
    y .= x.^2
    return sum(y)
end
f (generic function with 1 method)
Our function f populates its first input y with the element-wise square of x. In addition, it returns sum(y) as output. What a sneaky function!
In this case, Enzyme can differentiate through f automatically. For example, using forward mode:
using Enzyme
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2 # function to differentiate
@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when g is run
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
dy = [6.0, 0.0]
(See the AutoDiff API tutorial for more information on using autodiff.)
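The values reported above can be cross-checked without Enzyme. The following is a minimal sketch, assuming a central finite difference with a hand-picked step size; fd_directional is a hypothetical helper introduced only for this check and is not part of Enzyme.

```julia
# Plain-Julia restatement of f and g from the tutorial (no Enzyme needed).
function f(y, x)
    y .= x .^ 2
    return sum(y)
end
g(y, x) = f(y, x)^2

# Hypothetical helper: central finite difference of g and of y along direction dx.
function fd_directional(x, dx; h=1e-6)
    yp, ym = similar(x), similar(x)
    gp = g(yp, x .+ h .* dx)   # g evaluated slightly ahead of x, fills yp
    gm = g(ym, x .- h .* dx)   # g evaluated slightly behind x, fills ym
    return (gp - gm) / (2h), (yp .- ym) ./ (2h)
end

x  = [3.0, 1.0]
dx = [1.0, 0.0]
dg_fd, dy_fd = fd_directional(x, dx)

# Analytic values: dg/dx[1] = 2 * sum(x.^2) * 2 * x[1] and dy/dx[1] = [2 * x[1], 0].
dg_exact = 2 * sum(x .^ 2) * 2 * x[1]
dy_exact = [2 * x[1], 0.0]
@show dg_fd dg_exact dy_fd dy_exact
```

The finite-difference estimates should agree with the analytic values 120.0 and [6.0, 0.0] up to rounding, matching what autodiff printed.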
But there may be special cases where we need to write a custom rule to help Enzyme out. Let's see how to write a custom rule for f!
Enzyme can efficiently handle a wide range of constructs, and so a custom rule should only be required in certain special cases. For example, a function may make a foreign call that Enzyme cannot differentiate, or we may have higher-level mathematical knowledge that enables us to write a more efficient rule. Even in these cases, try to make your custom rule encapsulate the minimum possible construct that Enzyme cannot differentiate, rather than expanding the scope of the rule unnecessarily. For pedagogical purposes, we will disregard this principle here and go ahead and write a custom rule for f :)
Defining our first rule
First, we import the functions EnzymeRules.forward, EnzymeRules.augmented_primal, and EnzymeRules.reverse. We need to overload forward in order to define a custom forward rule, and we need to overload augmented_primal and reverse in order to define a custom reverse rule.
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules
In this section, we write a simple forward rule to start out:
function forward(config::FwdConfig, func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x::Duplicated)
    println("Using custom rule!")
    ret = func.val(y.val, x.val)
    y.dval .= 2 .* x.val .* x.dval
    return Duplicated(ret, sum(y.dval))
end
forward (generic function with 15 methods)
In the signature of our rule, we have made use of Enzyme's activity annotations. Let's break down each one:
- the EnzymeRules.FwdConfig configuration passes certain compile-time information about the differentiation procedure (the width, and whether we're using runtime activity).
- the Const annotation on f indicates that we accept a function f that does not have a derivative component, which makes sense since f is not a closure with data that could be differentiated.
- the Duplicated annotation given as the third argument annotates the return value of f. This means that our forward function should return an output of type Duplicated, containing the original output sum(y) and its derivative.
- the Duplicated annotations for x and y mean that our forward function handles inputs x and y which have been marked as Duplicated. We should update their shadows with their derivative contributions.
In the logic of our forward function, we run the original function, populate y.dval (the shadow of y), and finally return a Duplicated for the output as promised. Let's see our rule in action! With the same setup as before:
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2 # function to differentiate
@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when g is run
Using custom rule!
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
dy = [6.0, 0.0]
We see that our custom forward rule has been triggered and gives the same answer as before.
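The arithmetic performed by the rule can also be reproduced by hand. The sketch below is plain Julia and does not call Enzyme; f_tangent! is a hypothetical helper that mirrors the body of the rule, with ordinary arrays standing in for the .val and .dval fields of Duplicated.

```julia
# Hypothetical plain-Julia mirror of the forward rule: returns (primal, tangent).
function f_tangent!(y, dy, x, dx)
    y  .= x .^ 2             # primal update of y, exactly as f does
    dy .= 2 .* x .* dx       # tangent of y, since d(x^2) = 2x dx
    return sum(y), sum(dy)   # primal and tangent of the return value
end

x, dx = [3.0, 1.0], [1.0, 0.0]
y, dy = zeros(2), zeros(2)
ret, dret = f_tangent!(y, dy, x, dx)

# Chain rule for g = f^2: dg = 2 * f * df.
dg = 2 * ret * dret
@show dy dg
```

Here ret is 10.0 and dret is 6.0, so dg is 120.0 and dy is [6.0, 0.0], the same numbers Enzyme reports when it dispatches to the custom rule.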
Handling more activities
Our custom rule applies for the specific set of activities that are annotated for f in the above autodiff call. However, Enzyme has a number of other annotations. Let us consider a particular example, where the output has a DuplicatedNoNeed annotation. This means we are only interested in its derivative, not its value. To squeeze out the last drop of performance, the rule below avoids computing the output of the original function and just computes its derivative.
function forward(config, func::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, y::Duplicated, x::Duplicated)
    println("Using custom rule with DuplicatedNoNeed output.")
    y.val .= x.val.^2
    y.dval .= 2 .* x.val .* x.dval
    return sum(y.dval)
end
forward (generic function with 16 methods)
Our rule is triggered, for example, when we call autodiff directly on f, as only the derivative of the return value is needed, not the value itself:
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
@show autodiff(Forward, f, Duplicated(y, dy), Duplicated(x, dx)) # derivative of f w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when f is run
Using custom rule with DuplicatedNoNeed output.
autodiff(Forward, f, Duplicated(y, dy), Duplicated(x, dx)) = (6.0,)
dy = [6.0, 0.0]
When multiple custom rules for a function are defined, the correct rule is chosen using Julia's multiple dispatch. In particular, it is important to understand that the custom rule does not determine the activities of the inputs and the return value: rather, Enzyme decides the activity annotations independently, and then dispatches to the custom rule handling the activities, if one exists. If a custom rule is specified for the correct function/argument types, but not the correct activity annotation, a runtime error will be thrown alerting the user to the missing activity rule rather than silently ignoring the rule.
Finally, it may be that either x, y, or the return value are marked as Const, in which case we can simply return the original result. However, Enzyme also may determine the return is not differentiable and also not needed for other computations, in which case we should simply return nothing.
We can in fact handle this case, along with the previous two cases, all together in a single rule by leveraging the utility functions EnzymeRules.needs_primal and EnzymeRules.needs_shadow, which return true if the original return value or its derivative, respectively, must be returned:
Base.delete_method.(methods(forward, (Const{typeof(f)}, Vararg{Any}))) # delete our old rules
function forward(config, func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
                 y::Union{Const, Duplicated}, x::Union{Const, Duplicated})
    println("Using our general custom rule!")
    y.val .= x.val.^2
    if !(x isa Const) && !(y isa Const)
        y.dval .= 2 .* x.val .* x.dval
    elseif !(y isa Const)
        make_zero!(y.dval)
    end
    dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val))
    if needs_primal(config) && needs_shadow(config)
        return Duplicated(sum(y.val), dret)
    elseif needs_primal(config)
        return sum(y.val)
    elseif needs_shadow(config)
        return dret
    else
        return nothing
    end
end
forward (generic function with 16 methods)
Let's try out our rule:
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2 # function to differentiate
@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show autodiff(Forward, g, Const(y), Duplicated(x, dx)) # derivative of g w.r.t. x[1], with y annotated Const
@show autodiff(Forward, g, Const(y), Const(x)); # derivative of g w.r.t. x[1], with x and y annotated Const
Using custom rule!
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
Using our general custom rule!
autodiff(Forward, g, Const(y), Duplicated(x, dx)) = (0.0,)
autodiff(Forward, g, Const(y), Const(x)) = (0.0,)
Note that there also exist batched duplicated annotations for forward mode, namely BatchDuplicated and BatchDuplicatedNoNeed, which are not covered in this tutorial.
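The four-way branch at the end of the general rule can be looked at in isolation. The following is a rough sketch in plain Julia: select_return is a hypothetical helper, its two Bool arguments stand in for needs_primal(config) and needs_shadow(config), and a NamedTuple stands in for Duplicated so that the snippet runs without Enzyme.

```julia
# Hypothetical stand-in for the branch on needs_primal(config) / needs_shadow(config).
function select_return(needs_primal::Bool, needs_shadow::Bool, primal, shadow)
    if needs_primal && needs_shadow
        return (val = primal, dval = shadow)  # the rule returns Duplicated(primal, shadow)
    elseif needs_primal
        return primal                         # only the original return value
    elseif needs_shadow
        return shadow                         # only the derivative
    else
        return nothing                        # neither is needed
    end
end

# Enumerate all four combinations with the values from the running example.
for (p, s) in ((true, true), (true, false), (false, true), (false, false))
    println((p, s), " => ", select_return(p, s, 10.0, 6.0))
end
```

Branching on what the caller needs, rather than on the return annotation type, is what lets a single method cover the Duplicated, DuplicatedNoNeed, and Const cases described above.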
Defining a reverse-mode rule
Let's look at how to write a simple reverse-mode rule! First, we write a method for EnzymeRules.augmented_primal:
function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active},
                          y::Duplicated, x::Duplicated)
    println("In custom augmented primal rule.")
    # Compute primal
    if needs_primal(config)
        primal = func.val(y.val, x.val)
    else
        y.val .= x.val.^2 # y still needs to be mutated even if primal not needed!
        primal = nothing
    end
    # Save x in tape if x will be overwritten
    if overwritten(config)[3]
        tape = copy(x.val)
    else
        tape = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end
augmented_primal (generic function with 14 methods)
Let's unpack our signature for augmented_primal:
- We accepted an EnzymeRules.RevConfig object with a specified width of 1, which means that our rule does not support batched reverse mode.
- We annotated f with Const as usual.
- We dispatched on an Active annotation for the return value. This is a special annotation for scalar values, such as our return value, that indicates that we care about the value's derivative but we need not explicitly allocate a mutable shadow since it is a scalar value.
- We annotated x and y with Duplicated, similar to our first simple forward rule.
Now, let's unpack the body of our augmented_primal rule:
- We checked if the config requires the primal. If not, we need not compute the return value, but we make sure to mutate y in all cases.
- We checked if x could possibly be overwritten using the Overwritten attribute of EnzymeRules.RevConfig. If so, we save the elements of x on the tape of the returned EnzymeRules.AugmentedReturn object.
- We return a shadow of nothing since the return value is Active and hence does not need a shadow.
Now, we write a method for EnzymeRules.reverse:
function reverse(config::RevConfigWidth{1}, func::Const{typeof(f)}, dret::Active, tape,
                 y::Duplicated, x::Duplicated)
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    xval = overwritten(config)[3] ? tape : x.val
    # accumulate dret into x's shadow. don't assign!
    x.dval .+= 2 .* xval .* dret.val
    # also accumulate any derivative in y's shadow into x's shadow.
    x.dval .+= 2 .* xval .* y.dval
    make_zero!(y.dval)
    return (nothing, nothing)
end
reverse (generic function with 14 methods)
Let's make a few observations about our reverse rule:
- The activities used in the signature correspond to what we used for augmented_primal.
- However, for Active return types such as in this case, we now receive an instance dret of Active for the return type, not just a type annotation, which stores the derivative value for ret (not the original return value!). For the other annotations (e.g. Duplicated), we still receive only the type. In that case, if necessary a reference to the shadow of the output should be placed on the tape in augmented_primal.
- Using dret.val and y.dval, we accumulate the backpropagated derivatives for x into its shadow x.dval. Note that we have to accumulate from both y.dval and dret.val. This is because in reverse-mode AD we have to sum up the derivatives from all uses: if y was read after our function, we need to consider derivatives from that use as well.
- We zero-out y's shadow. This is because y is overwritten within f, so there is no derivative w.r.t. the y that was originally input.
- Finally, since all derivatives are accumulated in place (in the shadows of the Duplicated arguments), these derivatives must not be communicated via the return value. Hence, we return (nothing, nothing). If, instead, one of our arguments was annotated as Active, we would have to provide its derivative at the corresponding index in the tuple returned.
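The accumulation described in these observations can be traced by hand for g(y, x) = f(y, x)^2. This is a minimal sketch in plain Julia, not Enzyme's implementation: f_pullback! is a hypothetical helper that mirrors the three update lines of the reverse rule, with ordinary arrays in place of the shadows.

```julia
# Hypothetical plain-Julia mirror of the reverse rule's accumulation step.
function f_pullback!(dx, dy, xval, dret)
    dx .+= 2 .* xval .* dret   # contribution through the return value sum(y)
    dx .+= 2 .* xval .* dy     # contribution through any later reads of y
    fill!(dy, 0.0)             # y was overwritten by f, so its shadow is reset
    return nothing
end

x  = [3.0, 1.0]
dx = zeros(2)
dy = zeros(2)        # g never reads y after calling f, so y's shadow stays zero

ret  = sum(x .^ 2)   # value of f in the forward sweep
dret = 2 * ret       # derivative of g = ret^2 with respect to ret
f_pullback!(dx, dy, x, dret)
@show dx dy
```

With ret equal to 10.0 the seed dret is 20.0, so dx accumulates 2 .* x .* 20.0, that is [120.0, 40.0], while dy stays at zero.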
Finally, let's see our reverse rule in action!
x = [3.0, 1.0]
dx = [0.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2
autodiff(Reverse, g, Duplicated(y, dy), Duplicated(x, dx))
@show dx # derivative of g w.r.t. x
@show dy; # derivative of g w.r.t. y
In custom augmented primal rule.
In custom reverse rule.
dx = [120.0, 40.0]
dy = [0.0, 0.0]
Let's also try a function which mutates x after running f, and also uses y directly rather than only ret after running f (but ultimately gives the same result as above):
function h(y, x)
    ret = f(y, x)
    x .= x.^2
    return ret * sum(y)
end
x = [3.0, 1.0]
y = [0.0, 0.0]
make_zero!(dx)
make_zero!(dy)
autodiff(Reverse, h, Duplicated(y, dy), Duplicated(x, dx))
@show dx # derivative of h w.r.t. x
@show dy; # derivative of h w.r.t. y
In custom augmented primal rule.
In custom reverse rule.
dx = [120.0, 40.0]
dy = [0.0, 0.0]
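To see why h needs both the tape and the accumulation from y's shadow, the reverse sweep can be written out by hand. The sketch below is plain Julia under the same assumptions as the rule above; f_pullback! is a hypothetical helper mirroring the reverse rule, and tape is an ordinary copy of x taken before x is overwritten.

```julia
# Hypothetical plain-Julia mirror of the reverse rule's accumulation step.
function f_pullback!(dx, dy, xval, dret)
    dx .+= 2 .* xval .* dret   # contribution through the return value of f
    dx .+= 2 .* xval .* dy     # contribution through later reads of y
    fill!(dy, 0.0)             # y was overwritten by f, so its shadow is reset
    return nothing
end

x, y   = [3.0, 1.0], zeros(2)
dx, dy = zeros(2), zeros(2)

# Forward sweep of h, saving x before it is overwritten (the "tape").
tape = copy(x)
y   .= x .^ 2
ret  = sum(y)
x   .= x .^ 2                # x is now [9.0, 1.0]; the rule must not read it
out  = ret * sum(y)

# Reverse sweep, seeded with d(out) = 1, visiting the statements in reverse order.
dret = sum(y)                # d(out)/d(ret)
dy  .+= ret                  # d(out)/d(y[i]) from the direct use of y
dx  .= 2 .* tape .* dx       # pullback of x .= x.^2 (the shadow is zero here)
f_pullback!(dx, dy, tape, dret)
@show out dx dy
```

Each of the two paths contributes 2 .* tape .* 10.0, giving dx = [120.0, 40.0] as printed above. Reading the overwritten x instead of tape would give a wrong first component.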
Marking functions inactive
If we want to tell Enzyme that the function call does not affect the differentiation result in any form (i.e. not by side effects or through its return values), we can simply use EnzymeRules.inactive. So long as there exists a matching dispatch to EnzymeRules.inactive, the function will be considered inactive. For example:
printhi() = println("Hi!")
EnzymeRules.inactive(::typeof(printhi), args...) = nothing
function k(x)
    printhi()
    return x^2
end
autodiff(Forward, k, Duplicated(2.0, 1.0))
(4.0,)
Or for a case where we incorrectly mark a function inactive:
double(x) = 2*x
EnzymeRules.inactive(::typeof(double), args...) = nothing
autodiff(Forward, x -> x + double(x), Duplicated(2.0, 1.0)) # mathematically should be 3.0, inactive rule causes it to be 1.0
(1.0,)
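The gap between 3.0 and 1.0 can be reproduced without Enzyme. In the sketch below, which is only an illustration, fd and x_plus_double are hypothetical helper names; treating double as inactive is modelled by freezing its result to a constant before differentiating.

```julia
double(x) = 2 * x
x_plus_double(x) = x + double(x)

# Hypothetical helper: central finite difference of a scalar function.
fd(fun, x; h=1e-6) = (fun(x + h) - fun(x - h)) / (2h)

# Derivative when double is differentiated normally.
true_derivative = fd(x_plus_double, 2.0)

# Marking double inactive amounts to treating its output as a constant.
frozen = double(2.0)
inactive_derivative = fd(x -> x + frozen, 2.0)
@show true_derivative inactive_derivative
```

The first estimate is close to 3.0 and the second close to 1.0, which is the incorrect value Enzyme returns once the inactive rule hides double's contribution.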
Testing our rules
We can test our rules against finite differences using EnzymeTestUtils.test_forward and EnzymeTestUtils.test_reverse.
using EnzymeTestUtils, Test
@testset "f rules" begin
    @testset "forward" begin
        @testset for RT in (Const, DuplicatedNoNeed, Duplicated),
            Tx in (Const, Duplicated),
            Ty in (Const, Duplicated)
            x = [3.0, 1.0]
            y = [0.0, 0.0]
            test_forward(g, RT, (x, Tx), (y, Ty))
        end
    end
    @testset "reverse" begin
        @testset for RT in (Active,),
            Tx in (Duplicated,),
            Ty in (Duplicated,),
            fun in (g, h)
            x = [3.0, 1.0]
            y = [0.0, 0.0]
            test_reverse(fun, RT, (x, Tx), (y, Ty))
        end
    end
end
In any package that implements Enzyme rules using EnzymeRules, it is recommended to add EnzymeTestUtils as a test dependency to test the rules.
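The idea of checking a rule against finite differences can be sketched with only the Test standard library. This is not how EnzymeTestUtils is implemented; f_value, f_tangent, and fd are hypothetical helpers, and the step size and tolerance are assumptions chosen for this small example.

```julia
using Test

# Value of f's return and a hand-written tangent mirroring the forward rule's arithmetic.
f_value(x) = sum(x .^ 2)
f_tangent(x, dx) = sum(2 .* x .* dx)

# Hypothetical helper: central finite difference of f_value along direction dx.
fd(x, dx; h=1e-6) = (f_value(x .+ h .* dx) - f_value(x .- h .* dx)) / (2h)

@testset "hand rule vs finite differences" begin
    x = [3.0, 1.0]
    # Compare along several directions, not just the coordinate axes.
    for dx in ([1.0, 0.0], [0.0, 1.0], [0.5, -2.0])
        @test f_tangent(x, dx) ≈ fd(x, dx) rtol=1e-6
    end
end
```

Comparing along more than one direction helps catch rules that are correct for one input component but wrong for another.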
This page was generated using Literate.jl.