# Enzyme custom rules tutorial

The tutorial below focuses on a simple setting to illustrate the basic concepts of writing custom rules. For more complex custom rules beyond the scope of this tutorial, you may take inspiration from the following in-the-wild examples:

The goal of this tutorial is to give a simple example of defining a custom rule with Enzyme. Specifically, our goal will be to write custom rules for the following function `f`:

```
function f(y, x)
    y .= x.^2
    return sum(y)
end
```

`f (generic function with 1 method)`

Our function `f` populates its first input `y` with the element-wise square of `x`. In addition, it returns `sum(y)` as output. What a sneaky function!

In this case, Enzyme can differentiate through `f` automatically. For example, using forward mode:

```
using Enzyme
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2 # function to differentiate
@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when g is run
```

```
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
dy = [6.0, 0.0]
```
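You can sanity-check these numbers without Enzyme by taking a central finite difference along the direction `dx`. The plain-Julia sketch below redefines `f` and `g` so that it runs on its own. The step size `δ` is an arbitrary choice, so the results match the output above only up to finite-difference error.

```julia
# Central finite-difference check of the forward-mode results (plain Julia, no Enzyme).
function f(y, x)
    y .= x.^2        # overwrite y with the element-wise square of x
    return sum(y)
end
g(y, x) = f(y, x)^2

x  = [3.0, 1.0]
dx = [1.0, 0.0]      # perturb x[1] only
δ  = 1e-6            # step size (an arbitrary but typical choice)

# Each evaluation gets its own y buffer, because f overwrites its first argument.
yp, ym = zeros(2), zeros(2)
gp = g(yp, x .+ δ .* dx)
gm = g(ym, x .- δ .* dx)

dg = (gp - gm) / (2δ)      # approximates the derivative of g along dx
dy = (yp .- ym) ./ (2δ)    # approximates the derivative of y along dx
@show dg dy
```

Here `dg` comes out close to `120.0` and `dy` close to `[6.0, 0.0]`, in line with the `autodiff` output.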

(See the AutoDiff API tutorial for more information on using `autodiff`.)

But there may be special cases where we need to write a custom rule to help Enzyme out. Let's see how to write a custom rule for `f`!

Enzyme can efficiently handle a wide range of constructs, and so a custom rule should only be required in certain special cases. For example, a function may make a foreign call that Enzyme cannot differentiate, or we may have higher-level mathematical knowledge that enables us to write a more efficient rule. Even in these cases, try to make your custom rule encapsulate the minimum possible construct that Enzyme cannot differentiate, rather than expanding the scope of the rule unnecessarily. For pedagogical purposes, we will disregard this principle here and go ahead and write a custom rule for `f` :)

## Defining our first rule

First, we import the functions `EnzymeRules.forward`, `EnzymeRules.augmented_primal`, and `EnzymeRules.reverse`. We need to overload `forward` in order to define a custom forward rule, and we need to overload `augmented_primal` and `reverse` in order to define a custom reverse rule.

```
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules
```

In this section, we write a simple forward rule to start out:

```
function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x::Duplicated)
    println("Using custom rule!")
    ret = func.val(y.val, x.val)         # run the original function (this mutates y.val)
    y.dval .= 2 .* x.val .* x.dval       # tangent of y = x.^2
    return Duplicated(ret, sum(y.dval))  # original output and its tangent
end
```

`forward (generic function with 11 methods)`

In the signature of our rule, we have made use of `Enzyme`'s activity annotations. Let's break down each one:

- The `Const` annotation on `f` indicates that we accept a function `f` that does not have a derivative component. This makes sense because `f` is not a closure with data that could be differentiated.
- The `Duplicated` annotation given in the second argument annotates the return value of `f`. This means that our `forward` function should return an output of type `Duplicated`, containing the original output `sum(y)` and its derivative.
- The `Duplicated` annotations for `x` and `y` mean that our `forward` function handles inputs `x` and `y` that have been marked as `Duplicated`. We should update their shadows with their derivative contributions.
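The tangent formulas behind this rule follow from differentiating `f` by hand. Write r for the return value `sum(y)`, and put a dot over a quantity to denote its tangent, meaning the value held in a `dval` field. Plugging in the setup used below, which keeps the outer function `g = r^2`, reproduces the numbers printed by `autodiff`:

```math
\dot y_i = 2\,x_i\,\dot x_i, \qquad \dot r = \sum_i \dot y_i, \qquad \dot g = 2\,r\,\dot r
```

```math
x = (3, 1),\ \dot x = (1, 0) \;\Rightarrow\; \dot y = (6, 0),\ r = 10,\ \dot r = 6,\ \dot g = 2 \cdot 10 \cdot 6 = 120
```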

In the logic of our forward function, we run the original function, populate `y.dval` (the shadow of `y`), and finally return a `Duplicated` for the output as promised. Let's see our rule in action! With the same setup as before:

```
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2 # function to differentiate
@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when g is run
```

```
Using custom rule!
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
dy = [6.0, 0.0]
```

We see that our custom forward rule has been triggered and gives the same answer as before.

## Handling more activities

Our custom rule applies to the specific set of activities that are annotated for `f` in the above `autodiff` call. However, Enzyme has a number of other annotations. Let us consider a particular example, where the output has a `DuplicatedNoNeed` annotation. This means we are only interested in its derivative, not its value. To squeeze out the last drop of performance, the rule below avoids computing the output of the original function and just computes its derivative.

```
function forward(func::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, y::Duplicated, x::Duplicated)
    println("Using custom rule with DuplicatedNoNeed output.")
    y.val .= x.val.^2
    y.dval .= 2 .* x.val .* x.dval
    return sum(y.dval)
end
```

`forward (generic function with 12 methods)`

Our rule is triggered, for example, when we call `autodiff` directly on `f`. In that case only the derivative of the return value is requested, and its value is not needed:

```
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
@show autodiff(Forward, f, Duplicated(y, dy), Duplicated(x, dx)) # derivative of f w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when f is run
```

```
Using custom rule with DuplicatedNoNeed output.
autodiff(Forward, f, Duplicated(y, dy), Duplicated(x, dx)) = (6.0,)
dy = [6.0, 0.0]
```

When multiple custom rules for a function are defined, the correct rule is chosen using Julia's multiple dispatch. In particular, it is important to understand that the custom rule does not *determine* the activities of the inputs and the return value. Rather, `Enzyme` decides the activity annotations independently, and then *dispatches* to the custom rule handling those activities, if one exists. If a custom rule is specified for the correct function/argument types, but not for the correct activity annotation, a runtime error is thrown alerting the user to the missing activity rule, rather than the rule being silently ignored.
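The following toy model in plain Julia makes the dispatch behavior concrete without relying on Enzyme internals. Every `Toy*` type and the function `toy_forward` are hypothetical stand-ins that only mimic the shape of the annotations. Enzyme's real types, and the error it reports for a missing rule, are different. In the toy, a missing activity combination surfaces as an ordinary `MethodError`.

```julia
# Toy model of activity-based dispatch. All Toy* names are hypothetical, not Enzyme types.
struct ToyConst{T}
    val::T
end
struct ToyDuplicated{T}
    val::T
    dval::T
end
struct ToyDup end        # marker: return value and its tangent are both needed
struct ToyDupNoNeed end  # marker: only the tangent of the return value is needed

# Method for a duplicated return: compute primal and tangent.
function toy_forward(::Type{ToyDup}, y::ToyDuplicated, x::ToyDuplicated)
    y.val .= x.val .^ 2
    y.dval .= 2 .* x.val .* x.dval
    return (sum(y.val), sum(y.dval))
end

# Method for a no-need return: skip building the primal result.
function toy_forward(::Type{ToyDupNoNeed}, y::ToyDuplicated, x::ToyDuplicated)
    y.val .= x.val .^ 2   # y must still be mutated
    y.dval .= 2 .* x.val .* x.dval
    return sum(y.dval)
end

x, dx = [3.0, 1.0], [1.0, 0.0]
y, dy = zeros(2), zeros(2)

# The caller fixes the annotations; multiple dispatch then picks the matching method.
both = toy_forward(ToyDup, ToyDuplicated(y, dy), ToyDuplicated(x, dx))
tangent_only = toy_forward(ToyDupNoNeed, ToyDuplicated(y, dy), ToyDuplicated(x, dx))

# No method covers a constant x, so the call fails loudly instead of using another method.
missing_rule = try
    toy_forward(ToyDup, ToyDuplicated(y, dy), ToyConst(x))
    false
catch err
    err isa MethodError
end
@show both tangent_only missing_rule
```

This prints `both = (10.0, 6.0)`, `tangent_only = 6.0`, and `missing_rule = true`.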

Finally, it may be that any of `x`, `y`, or the return value is marked as `Const`. We can in fact handle this case, along with the previous two cases, all together in a single rule:

```
Base.delete_method.(methods(forward, (Const{typeof(f)}, Vararg{Any}))) # delete our old rules
function forward(func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
                 y::Union{Const, Duplicated}, x::Union{Const, Duplicated})
    println("Using our general custom rule!")
    y.val .= x.val.^2
    if !(x isa Const) && !(y isa Const)
        # both duplicated: propagate x's tangent into y's shadow
        y.dval .= 2 .* x.val .* x.dval
    elseif !(y isa Const)
        # x is constant, so y = x.^2 carries no derivative: zero y's shadow
        make_zero!(y.dval)
    end
    # tangent of the return value sum(y); zero when y is Const
    dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val))
    if RT <: Const
        return sum(y.val)
    elseif RT <: DuplicatedNoNeed
        return dret
    else
        return Duplicated(sum(y.val), dret)
    end
end
```

`forward (generic function with 11 methods)`

Let's try out our rule:

```
x = [3.0, 1.0]
dx = [1.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2 # function to differentiate
@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show autodiff(Forward, g, Const(y), Duplicated(x, dx)) # derivative of g w.r.t. x[1], with y annotated Const
@show autodiff(Forward, g, Const(y), Const(x)); # derivative of g w.r.t. x[1], with x and y annotated Const
```

```
Using our general custom rule!
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
Using our general custom rule!
autodiff(Forward, g, Const(y), Duplicated(x, dx)) = (0.0,)
autodiff(Forward, g, Const(y), Const(x)) = (0.0,)
```

Note that there also exist batched duplicated annotations for forward mode, namely `BatchDuplicated` and `BatchDuplicatedNoNeed`, which are not covered in this tutorial.

## Defining a reverse-mode rule

Let's look at how to write a simple reverse-mode rule! First, we write a method for `EnzymeRules.augmented_primal`:

```
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active},
                          y::Duplicated, x::Duplicated)
    println("In custom augmented primal rule.")
    # Compute primal
    if needs_primal(config)
        primal = func.val(y.val, x.val)
    else
        y.val .= x.val.^2 # y still needs to be mutated even if primal not needed!
        primal = nothing
    end
    # Save x in tape if x will be overwritten
    # (overwritten(config) also covers the function itself: index 1 is f, 2 is y, 3 is x)
    if overwritten(config)[3]
        tape = copy(x.val)
    else
        tape = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end
```

`augmented_primal (generic function with 10 methods)`

Let's unpack our signature for `augmented_primal`:

- We accepted an `EnzymeRules.Config` object with a specified width of 1, which means that our rule does not support batched reverse mode.
- We annotated `f` with `Const` as usual.
- We dispatched on an `Active` annotation for the return value. This is a special annotation for scalar values, such as our return value. It indicates that we care about the value's derivative but need not explicitly allocate a mutable shadow, since it is a scalar value.
- We annotated `x` and `y` with `Duplicated`, similar to our first simple forward rule.

Now, let's unpack the body of our `augmented_primal` rule:

- We checked if the `config` requires the primal. If not, we need not compute the return value, but we make sure to mutate `y` in all cases.
- We checked if `x` could possibly be overwritten, using `overwritten(config)` to query the `Overwritten` attribute of `EnzymeRules.Config`. If so, we save the elements of `x` on the `tape` of the returned `EnzymeRules.AugmentedReturn` object.
- We return a shadow of `nothing` since the return value is `Active` and hence does not need a shadow.
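The following plain-Julia sketch shows why the tape matters. It hand-writes the forward and reverse passes for the tutorial's function `h`, which overwrites `x` after calling `f`. The helper `run_h_with_tape` is hypothetical and not part of Enzyme. It only imitates what saving `x` on the tape buys us, by comparing the saved copy of `x` against the already-overwritten one.

```julia
# Hand-written forward and reverse passes for h(y, x) = (ret = f(y, x); x .= x.^2; ret * sum(y)).
# run_h_with_tape is a hypothetical helper, not part of Enzyme.
function run_h_with_tape(x0; use_tape::Bool)
    x = copy(x0)
    y = zeros(length(x))

    # forward (augmented primal) pass
    tape = copy(x)           # snapshot of x before anything overwrites it
    y .= x .^ 2              # body of f
    ret = sum(y)
    x .= x .^ 2              # h overwrites x after calling f
    out = ret * sum(y)

    # reverse pass, seeded with an adjoint of 1 for out
    dret = sum(y)                 # ∂out/∂ret
    dy = fill(ret, length(y))     # ∂out/∂y from the later use sum(y)
    dx = zeros(length(x))         # the final x is unused, so the reverse of x .= x.^2 adds 0
    xval = use_tape ? tape : x    # the value of x that f actually saw (or a stale one)
    dx .+= 2 .* xval .* dret      # contribution through f's return value
    dx .+= 2 .* xval .* dy        # contribution through f's write to y
    return out, dx
end

out, dx_tape = run_h_with_tape([3.0, 1.0]; use_tape = true)
_, dx_stale = run_h_with_tape([3.0, 1.0]; use_tape = false)
@show out dx_tape dx_stale
```

With the tape, the gradient is `[120.0, 40.0]`. Reading the already-squared `x` instead gives `[360.0, 40.0]`, which is wrong.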

Now, we write a method for `EnzymeRules.reverse`:

```
function reverse(config::ConfigWidth{1}, func::Const{typeof(f)}, dret::Active, tape,
                 y::Duplicated, x::Duplicated)
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    xval = overwritten(config)[3] ? tape : x.val
    # accumulate dret into x's shadow. don't assign!
    x.dval .+= 2 .* xval .* dret.val
    # also accumulate any derivative in y's shadow into x's shadow.
    x.dval .+= 2 .* xval .* y.dval
    # y is overwritten by f, so no derivative flows back to its original value.
    make_zero!(y.dval)
    return (nothing, nothing)
end
```

`reverse (generic function with 10 methods)`

Let's make a few observations about our reverse rule:

- The activities used in the signature correspond to what we used for `augmented_primal`.
- However, for `Active` return types such as in this case, we now receive an *instance* `dret` of `Active` for the return type, not just a type annotation. It stores the derivative with respect to the return value, not the original return value itself! For the other annotations (e.g. `Duplicated`), we still receive only the type. In that case, if necessary, a reference to the shadow of the output should be placed on the tape in `augmented_primal`.
- Using `dret.val` and `y.dval`, we accumulate the backpropagated derivatives for `x` into its shadow `x.dval`. Note that we have to accumulate from both `y.dval` and `dret.val`. This is because in reverse-mode AD we have to sum up the derivatives from all uses: if `y` was read after our function, we need to consider derivatives from that use as well.
- We zero out `y`'s shadow. This is because `y` is overwritten within `f`, so there is no derivative w.r.t. the `y` that was originally passed in.
- Finally, since all derivatives are accumulated *in place* (in the shadows of the `Duplicated` arguments), these derivatives must not be communicated via the return value. Hence, we return `(nothing, nothing)`. If, instead, one of our arguments was annotated as `Active`, we would have to provide its derivative at the corresponding index in the returned tuple.
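The accumulation in this reverse rule is the adjoint of `y .= x.^2` followed by `sum(y)`. Put a bar over a quantity to denote its adjoint, meaning the value held in its shadow, with `dret.val` playing the role of the adjoint of the return value r. Here x means the value that `f` saw, taken from the tape if `x` was overwritten. The update is shown first. The second block checks it against the results printed below for `g` and `h`:

```math
\bar x_i \leftarrow \bar x_i + 2\,x_i\,\bigl(\bar r + \bar y_i\bigr), \qquad \bar y_i \leftarrow 0
```

```math
\begin{aligned}
g = r^2:&\quad \bar r = 2r = 20,\ \bar y = (0, 0) &&\Rightarrow\ \bar x = 2\,(3, 1)\cdot 20 = (120, 40),\\
h = r \textstyle\sum_i y_i:&\quad \bar r = \textstyle\sum_i y_i = 10,\ \bar y = (r, r) = (10, 10) &&\Rightarrow\ \bar x = 2\,(3, 1)\cdot(10 + 10) = (120, 40).
\end{aligned}
```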

Finally, let's see our reverse rule in action!

```
x = [3.0, 1.0]
dx = [0.0, 0.0]
y = [0.0, 0.0]
dy = [0.0, 0.0]
g(y, x) = f(y, x)^2
autodiff(Reverse, g, Duplicated(y, dy), Duplicated(x, dx))
@show dx # derivative of g w.r.t. x
@show dy; # derivative of g w.r.t. y
```

```
In custom augmented primal rule.
In custom reverse rule.
dx = [120.0, 40.0]
dy = [0.0, 0.0]
```

Let's also try a function that mutates `x` after running `f`, and that also uses `y` directly, rather than only `ret`, after running `f`. It ultimately gives the same result as above:

```
function h(y, x)
    ret = f(y, x)
    x .= x.^2
    return ret * sum(y)
end
x = [3.0, 1.0]
y = [0.0, 0.0]
make_zero!(dx)
make_zero!(dy)
autodiff(Reverse, h, Duplicated(y, dy), Duplicated(x, dx))
@show dx # derivative of h w.r.t. x
@show dy; # derivative of h w.r.t. y
```

```
In custom augmented primal rule.
In custom reverse rule.
dx = [120.0, 40.0]
dy = [0.0, 0.0]
```

## Marking functions inactive

If we want to tell Enzyme that a function call does not affect the differentiation result in any form (i.e. neither through side effects nor through its return value), we can simply use `EnzymeRules.inactive`. So long as there exists a matching dispatch to `EnzymeRules.inactive`, the function will be considered inactive. For example:

```
printhi() = println("Hi!")
EnzymeRules.inactive(::typeof(printhi), args...) = nothing
function k(x)
    printhi()
    return x^2
end
autodiff(Forward, k, Duplicated(2.0, 1.0))
```

`(4.0,)`

Or, for a case where we incorrectly mark a function as inactive:

```
double(x) = 2*x
EnzymeRules.inactive(::typeof(double), args...) = nothing
autodiff(Forward, x -> x + double(x), Duplicated(2.0, 1.0)) # mathematically should be 3.0, inactive rule causes it to be 1.0
```

`(1.0,)`

## Testing our rules

We can test our rules against finite differences using `EnzymeTestUtils.test_forward` and `EnzymeTestUtils.test_reverse`.

```
using EnzymeTestUtils, Test
@testset "f rules" begin
    @testset "forward" begin
        @testset for RT in (Const, DuplicatedNoNeed, Duplicated),
                     Tx in (Const, Duplicated),
                     Ty in (Const, Duplicated)
            x = [3.0, 1.0]
            y = [0.0, 0.0]
            test_forward(g, RT, (x, Tx), (y, Ty))
        end
    end
    @testset "reverse" begin
        @testset for RT in (Active,),
                     Tx in (Duplicated,),
                     Ty in (Duplicated,),
                     fun in (g, h)
            x = [3.0, 1.0]
            y = [0.0, 0.0]
            test_reverse(fun, RT, (x, Tx), (y, Ty))
        end
    end
end
```

The block returns a `Test.DefaultTestSet` for `"f rules"`; its full printed value is omitted here. It records no failures or errors in any subset. That covers all 12 combinations of `RT`, `Tx`, and `Ty` under `"forward"`, and both `fun = g` and `fun = h` under `"reverse"`.

In any package that implements Enzyme rules using EnzymeRules, it is recommended to add EnzymeTestUtils as a test dependency to test the rules.
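To build intuition for what such a test compares, here is a rough hand-rolled analogue in plain Julia. It sets a central-difference gradient of `g` against the gradient implied by the reverse rule's formula. The helper `fd_gradient` and its step size are illustrative assumptions. EnzymeTestUtils' actual implementation differs and is more thorough.

```julia
# Hand-rolled finite-difference gradient check (plain Julia, no Enzyme).
# fd_gradient is a hypothetical helper, not part of EnzymeTestUtils.
function f(y, x)
    y .= x.^2
    return sum(y)
end
g(y, x) = f(y, x)^2

# Central-difference gradient of a scalar function of the vector x.
function fd_gradient(fun, x; δ = 1e-6)
    grad = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = δ
        grad[i] = (fun(x .+ e) - fun(x .- e)) / (2δ)
    end
    return grad
end

x = [3.0, 1.0]
# Give every evaluation a fresh y buffer, since f overwrites its first argument.
fd = fd_gradient(xx -> g(zeros(length(xx)), xx), x)

# Gradient from the reverse rule's formula: dx = 2x * dret, where dret = 2f for g = f^2.
rule_grad = 2 .* x .* (2 * f(zeros(2), x))
@show fd rule_grad
@assert isapprox(fd, rule_grad; rtol = 1e-6)
```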

*This page was generated using Literate.jl.*