Enzyme custom rules tutorial
The tutorial below focuses on a simple setting to illustrate the basic concepts of writing custom rules. For more complex custom rules beyond the scope of this tutorial, you may take inspiration from the following in-the-wild examples:
The goal of this tutorial is to give a simple example of defining a custom rule with Enzyme. Specifically, our goal will be to write custom rules for the following function f:
function f(y, x)
    y .= x.^2
    return sum(y)
end
f (generic function with 1 method)
Our function f populates its first input y with the element-wise square of x. In addition, it returns sum(y) as output. What a sneaky function!
In this case, Enzyme can differentiate through f automatically. For example, using forward mode:
using Enzyme
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2 # function to differentiate
@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when g is run
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
dy = [6.0, 0.0]
(See the AutoDiff API tutorial for more information on using autodiff.)
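As a rough sanity check on the numbers above, the same directional derivative can be approximated with central finite differences in plain Julia, without Enzyme. This is only a sketch: f_ref, g_ref, central_diff and y_of are hypothetical names introduced here (standalone copies of f and g plus small helpers), and the step size h is an arbitrary choice.

```julia
# Standalone copies of the tutorial's f and g under hypothetical names,
# so that this check does not touch the functions used elsewhere.
function f_ref(y, x)
    y .= x .^ 2
    return sum(y)
end
g_ref(y, x) = f_ref(y, x)^2

# Central difference of `fun` at x along the direction dx (hypothetical helper).
central_diff(fun, x, dx; h = 1e-6) = (fun(x .+ h .* dx) .- fun(x .- h .* dx)) ./ (2h)

# y viewed as a function of x: run f_ref on a fresh buffer and return the buffer.
function y_of(x)
    y = zeros(length(x))
    f_ref(y, x)
    return y
end

x = [3.0, 1.0]
dx = [1.0, 0.0]  # perturb x[1] only, as in the autodiff call above

dg_fd = central_diff(xs -> g_ref(zeros(2), xs), x, dx)  # approximates d g / d x[1]
dy_fd = central_diff(y_of, x, dx)                       # approximates d y / d x[1]
```

Both estimates should land close to the values Enzyme reported: about 120.0 for g and about [6.0, 0.0] for y.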
But there may be special cases where we need to write a custom rule to help Enzyme out. Let's see how to write a custom rule for f!
Enzyme can efficiently handle a wide range of constructs, and so a custom rule should only be required in certain special cases. For example, a function may make a foreign call that Enzyme cannot differentiate, or we may have higher-level mathematical knowledge that enables us to write a more efficient rule. Even in these cases, try to make your custom rule encapsulate the minimum possible construct that Enzyme cannot differentiate, rather than expanding the scope of the rule unnecessarily. For pedagogical purposes, we will disregard this principle here and go ahead and write a custom rule for f :)
Defining our first rule
First, we import the functions EnzymeRules.forward, EnzymeRules.augmented_primal, and EnzymeRules.reverse. We need to overload forward in order to define a custom forward rule, and we need to overload augmented_primal and reverse in order to define a custom reverse rule.
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules
In this section, we write a simple forward rule to start out:
function forward(config::FwdConfig, func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x::Duplicated)
    println("Using custom rule!")
    ret = func.val(y.val, x.val)
    y.dval .= 2 .* x.val .* x.dval
    return Duplicated(ret, sum(y.dval))
end
forward (generic function with 19 methods)
In the signature of our rule, we have made use of Enzyme's activity annotations. Let's break down each one:
- the EnzymeRules.FwdConfig configuration passes certain compile-time information about the differentiation procedure (the width, and whether we're using runtime activity),
- the Const annotation on f indicates that we accept a function f that does not have a derivative component, which makes sense since f is not a closure with data that could be differentiated.
- the Duplicated annotation given in the third argument annotates the return value of f. This means that our forward function should return an output of type Duplicated, containing the original output sum(y) and its derivative.
- the Duplicated annotations for x and y mean that our forward function handles inputs x and y which have been marked as Duplicated. We should update their shadows with their derivative contributions.
In the logic of our forward function, we run the original function, populate y.dval (the shadow of y), and finally return a Duplicated for the output as promised. Let's see our rule in action! With the same setup as before:
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2 # function to differentiate
@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when g is run
Using custom rule!
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
dy = [6.0, 0.0]
We see that our custom forward rule has been triggered and gives the same answer as before.
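To see where 120.0 comes from, the arithmetic of the rule can be replayed on plain arrays. The sketch below uses nothing from Enzyme: f_jvp! is a hypothetical stand-in for the rule body, and the last line applies the chain rule for g = f^2 by hand.

```julia
# Plain-array stand-in for the forward rule body (hypothetical name).
# Returns the primal output of f together with its tangent.
function f_jvp!(y, dy, x, dx)
    y .= x .^ 2              # primal: the same mutation that f performs
    dy .= 2 .* x .* dx       # tangent of y, as the rule writes into y.dval
    return sum(y), sum(dy)   # primal return value and its tangent
end

x, dx = [3.0, 1.0], [1.0, 0.0]
y, dy = zeros(2), zeros(2)

ret, dret = f_jvp!(y, dy, x, dx)  # ret is f's output, dret its derivative w.r.t. x[1]
dg = 2 * ret * dret               # chain rule: d(f^2) = 2 * f * df
```

With x = [3.0, 1.0], f returns 10.0 with tangent 6.0, so the derivative of g is 2 * 10.0 * 6.0 = 120.0 and the shadow of y is [6.0, 0.0].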
Handling more activities
Our custom rule applies for the specific set of activities that are annotated for f in the above autodiff call. However, Enzyme has a number of other annotations. Let us consider a particular example, where the output has a DuplicatedNoNeed annotation. This means we are only interested in its derivative, not its value. To squeeze out the last drop of performance, the below rule avoids computing the output of the original function and just computes its derivative.
function forward(config, func::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, y::Duplicated, x::Duplicated)
    println("Using custom rule with DuplicatedNoNeed output.")
    y.val .= x.val.^2
    y.dval .= 2 .* x.val .* x.dval
    return sum(y.dval)
end
forward (generic function with 20 methods)
Our rule is triggered, for example, when we call autodiff directly on f, as the return value itself isn't needed, only its derivative:
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
@show autodiff(Forward, f, Duplicated(y, dy), Duplicated(x, dx)) # derivative of f w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when f is run
Using custom rule with DuplicatedNoNeed output.
autodiff(Forward, f, Duplicated(y, dy), Duplicated(x, dx)) = (6.0,)
dy = [6.0, 0.0]
When multiple custom rules for a function are defined, the correct rule is chosen using Julia's multiple dispatch. In particular, it is important to understand that the custom rule does not determine the activities of the inputs and the return value: rather, Enzyme decides the activity annotations independently, and then dispatches to the custom rule handling those activities, if one exists. If a custom rule is specified for the correct function and argument types, but not for the correct activity annotation, a runtime error is thrown alerting the user to the missing activity rule, rather than the rule being silently ignored.
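The dispatch behaviour described above can be imitated with a toy example in plain Julia. Everything here is hypothetical and unrelated to Enzyme's real types: ToyConst, ToyDuplicated and ToyActive merely play the role of annotations, and toy_rule plays the role of a rule. The point is only that the annotation type of the argument selects the method, and that an annotation without a matching method produces an error instead of falling back silently (a MethodError in this toy, whereas Enzyme reports its own error).

```julia
# Toy stand-ins for activity annotations (hypothetical, not Enzyme's types).
struct ToyConst{T}
    val::T
end
struct ToyDuplicated{T}
    val::T
    dval::T
end
struct ToyActive{T}  # an annotation for which no rule is defined below
    val::T
end

# Two "rules" for the same operation, distinguished only by the annotation type.
toy_rule(x::ToyDuplicated) = "rule for Duplicated input"
toy_rule(x::ToyConst) = "rule for Const input"

picked_dup = toy_rule(ToyDuplicated(3.0, 1.0))
picked_const = toy_rule(ToyConst(3.0))

# No method matches ToyActive, so the call fails loudly.
missing_rule_errors = try
    toy_rule(ToyActive(3.0))
    false
catch err
    err isa MethodError
end
```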
Finally, it may be that either x, y, or the return value are marked as Const, in which case we can simply return the original result. However, Enzyme also may determine the return is not differentiable and also not needed for other computations, in which case we should simply return nothing.
We can in fact handle this case, along with the previous two cases, all together in a single rule by leveraging the utility functions EnzymeRules.needs_primal and EnzymeRules.needs_shadow, which return true if the original return value or its derivative, respectively, must be returned:
Base.delete_method.(methods(forward, (Const{typeof(f)}, Vararg{Any}))) # delete our old rules
function forward(config, func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
                 y::Union{Const, Duplicated}, x::Union{Const, Duplicated})
    println("Using our general custom rule!")
    y.val .= x.val.^2
    if !(x isa Const) && !(y isa Const)
        y.dval .= 2 .* x.val .* x.dval
    elseif !(y isa Const)
        make_zero!(y.dval)
    end
    dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val))
    if needs_primal(config) && needs_shadow(config)
        return Duplicated(sum(y.val), dret)
    elseif needs_primal(config)
        return sum(y.val)
    elseif needs_shadow(config)
        return dret
    else
        return nothing
    end
end
forward (generic function with 17 methods)
Let's try out our rule:
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2 # function to differentiate
@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show autodiff(Forward, g, Const(y), Duplicated(x, dx)) # derivative of g w.r.t. x[1], with y annotated Const
@show autodiff(Forward, g, Const(y), Const(x)); # derivative of g w.r.t. x[1], with x and y annotated Const
Using custom rule!
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
Using our general custom rule!
autodiff(Forward, g, Const(y), Duplicated(x, dx)) = (0.0,)
Using our general custom rule!
autodiff(Forward, g, Const(y), Const(x)) = (0.0,)
Note that there also exist batched duplicated annotations for forward mode, namely BatchDuplicated and BatchDuplicatedNoNeed, which are not covered in this tutorial.
Defining a reverse-mode rule
Let's look at how to write a simple reverse-mode rule! First, we write a method for EnzymeRules.augmented_primal:
function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active},
                          y::Duplicated, x::Duplicated)
    println("In custom augmented primal rule.")
    # Compute primal
    if needs_primal(config)
        primal = func.val(y.val, x.val)
    else
        y.val .= x.val.^2 # y still needs to be mutated even if primal not needed!
        primal = nothing
    end
    # Save x in tape if x will be overwritten
    if overwritten(config)[3]
        tape = copy(x.val)
    else
        tape = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end
augmented_primal (generic function with 17 methods)
Let's unpack our signature for augmented_primal:
- We accepted an EnzymeRules.RevConfig object with a specified width of 1, which means that our rule does not support batched reverse mode.
- We annotated f with Const as usual.
- We dispatched on an Active annotation for the return value. This is a special annotation for scalar values, such as our return value, that indicates that we care about the value's derivative but need not explicitly allocate a mutable shadow, since it is a scalar value.
- We annotated x and y with Duplicated, similar to our first simple forward rule.
Now, let's unpack the body of our augmented_primal rule:
- We checked if the config requires the primal. If not, we need not compute the return value, but we make sure to mutate y in all cases.
- We checked if x could possibly be overwritten using the Overwritten attribute of EnzymeRules.RevConfig. If so, we save the elements of x on the tape of the returned EnzymeRules.AugmentedReturn object.
- We return a shadow of nothing since the return value is Active and hence does not need a shadow.
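The reason for saving x on the tape can be illustrated with plain arrays. In the sketch below, which does not use Enzyme, x is overwritten after the forward pass, and the derivative factor 2 .* x is then evaluated once from the saved copy and once from the mutated array. The value of seed is a hypothetical incoming derivative chosen only for illustration.

```julia
x = [3.0, 1.0]
tape = copy(x)   # copy taken during the forward pass, before x can change
x .= x .^ 2      # a later in-place update by the caller; x is now [9.0, 1.0]

seed = 20.0      # hypothetical incoming derivative, the same for every element

with_tape = 2 .* tape .* seed   # uses the value x had when f ran
without_tape = 2 .* x .* seed   # uses the overwritten value instead
```

Here with_tape is [120.0, 40.0], while without_tape is [360.0, 40.0]: reading the overwritten x in the reverse pass would silently give a wrong derivative for the first element.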
Now, we write a method for EnzymeRules.reverse:
function reverse(config::RevConfigWidth{1}, func::Const{typeof(f)}, dret::Active, tape,
                 y::Duplicated, x::Duplicated)
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    xval = overwritten(config)[3] ? tape : x.val
    # accumulate dret into x's shadow. don't assign!
    x.dval .+= 2 .* xval .* dret.val
    # also accumulate any derivative in y's shadow into x's shadow.
    x.dval .+= 2 .* xval .* y.dval
    make_zero!(y.dval)
    return (nothing, nothing)
end
reverse (generic function with 20 methods)
Let's make a few observations about our reverse rule:
- The activities used in the signature correspond to what we used for augmented_primal.
- However, for Active return types such as in this case, we now receive an instance dret of Active for the return type, not just a type annotation, which stores the derivative value for ret (not the original return value!). For the other annotations (e.g. Duplicated), we still receive only the type. In that case, if necessary, a reference to the shadow of the output should be placed on the tape in augmented_primal.
- Using dret.val and y.dval, we accumulate the backpropagated derivatives for x into its shadow x.dval. Note that we have to accumulate from both y.dval and dret.val. This is because in reverse-mode AD we have to sum up the derivatives from all uses: if y was read after our function, we need to consider derivatives from that use as well.
- We zero out y's shadow. This is because y is overwritten within f, so there is no derivative w.r.t. the y that was originally passed in.
- Finally, since all derivatives are accumulated in place (in the shadows of the Duplicated arguments), these derivatives must not be communicated via the return value. Hence, we return (nothing, nothing). If, instead, one of our arguments was annotated as Active, we would have to provide its derivative at the corresponding index in the tuple returned.
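The need to accumulate from both sources can be replayed by hand on plain arrays. The sketch below does not use Enzyme; f_pullback! is a hypothetical stand-in for the body of the reverse rule. It assumes a caller that computes ret = f(y, x) and then returns ret * sum(y), so that the derivative arrives both through the return value (dret) and through the later read of y (dy).

```julia
# Plain-array stand-in for the reverse rule body (hypothetical name).
function f_pullback!(dx, dy, xval, dret)
    dx .+= 2 .* xval .* dret  # contribution through the returned sum
    dx .+= 2 .* xval .* dy    # contribution through later reads of y
    fill!(dy, 0.0)            # y was overwritten by f, so its shadow is reset
    return nothing
end

xval = [3.0, 1.0]       # value of x at the time f ran
ret = sum(xval .^ 2)    # 10.0, which also equals sum(y)

# For the assumed caller ret * sum(y): d/d(ret) = sum(y) and d/d(y[i]) = ret.
dret = ret
dy = fill(ret, 2)
dx = zeros(2)

f_pullback!(dx, dy, xval, dret)
```

Each source contributes [60.0, 20.0], so dx ends up as [120.0, 40.0] and dy as [0.0, 0.0]. Dropping either accumulation would halve the result.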
Finally, let's see our reverse rule in action!
x = [3.0, 1.0]
dx = [0.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2
autodiff(Reverse, g, Duplicated(y, dy), Duplicated(x, dx))
@show dx # derivative of g w.r.t. x
@show dy; # derivative of g w.r.t. y
In custom augmented primal rule.
In custom reverse rule.
dx = [120.0, 40.0]
dy = [0.0, 0.0]
Let's also try a function which mutates x after running f, and also uses y directly rather than only ret after running f (but ultimately gives the same result as above):
function h(y, x)
    ret = f(y, x)
    x .= x.^2
    return ret * sum(y)
end
x = [3.0, 1.0]
y = [0.0, 0.0]
make_zero!(dx)
make_zero!(dy)
autodiff(Reverse, h, Duplicated(y, dy), Duplicated(x, dx))
@show dx # derivative of h w.r.t. x
@show dy; # derivative of h w.r.t. y
In custom augmented primal rule.
In custom reverse rule.
dx = [120.0, 40.0]
dy = [0.0, 0.0]
Marking functions inactive
If we want to tell Enzyme that the function call does not affect the differentiation result in any form (i.e. not by side effects or through its return values), we can simply use EnzymeRules.inactive. So long as there exists a matching dispatch to EnzymeRules.inactive, the function will be considered inactive. For example:
printhi() = println("Hi!")
EnzymeRules.inactive(::typeof(printhi), args...) = nothing
function k(x)
    printhi()
    return x^2
end
autodiff(Forward, k, Duplicated(2.0, 1.0))
(4.0,)
Or for a case where we incorrectly mark a function inactive:
double(x) = 2*x
EnzymeRules.inactive(::typeof(double), args...) = nothing
autodiff(Forward, x -> x + double(x), Duplicated(2.0, 1.0)) # mathematically should be 3.0, inactive rule causes it to be 1.0
(1.0,)
Testing our rules
We can test our rules against finite differences with EnzymeTestUtils.test_forward and EnzymeTestUtils.test_reverse.
using EnzymeTestUtils, Test
@testset "f rules" begin
    @testset "forward" begin
        @testset for RT in (Const, DuplicatedNoNeed, Duplicated),
            Tx in (Const, Duplicated),
            Ty in (Const, Duplicated)
            x = [3.0, 1.0]
            y = [0.0, 0.0]
            test_forward(g, RT, (x, Tx), (y, Ty))
        end
    end
    @testset "reverse" begin
        @testset for RT in (Active,),
            Tx in (Duplicated,),
            Ty in (Duplicated,),
            fun in (g, h)
            x = [3.0, 1.0]
            y = [0.0, 0.0]
            test_reverse(fun, RT, (x, Tx), (y, Ty))
        end
    end
end
Test.DefaultTestSet("f rules", ...)
In any package that implements Enzyme rules using EnzymeRules, it is recommended to add EnzymeTestUtils as a test dependency to test the rules.
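For intuition about what a finite-difference comparison involves, here is a hand-rolled analogue in plain Julia. It is not how EnzymeTestUtils is implemented, only a sketch of the idea: g_of_x and fd_gradient are hypothetical names, the step size h and the tolerance are arbitrary choices, and the analytic gradient is derived by hand from g(x) = (sum(x.^2))^2.

```julia
# The tutorial's g written as a function of x alone (hypothetical name).
g_of_x(x) = sum(x .^ 2)^2

# Central-difference gradient, one coordinate at a time (hypothetical helper).
function fd_gradient(fun, x; h = 1e-6)
    grad = similar(x)
    for i in eachindex(x)
        step = zeros(length(x))
        step[i] = h
        grad[i] = (fun(x .+ step) - fun(x .- step)) / (2h)
    end
    return grad
end

x = [3.0, 1.0]
analytic = 4 .* sum(x .^ 2) .* x   # hand-derived gradient of (sum(x.^2))^2
numeric = fd_gradient(g_of_x, x)   # finite-difference estimate of the same gradient
agrees = isapprox(numeric, analytic; rtol = 1e-6)
```

The analytic gradient is [120.0, 40.0], matching the dx reported by the reverse rule earlier, and the finite-difference estimate agrees with it up to the chosen tolerance.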
This page was generated using Literate.jl.