Enzyme for adjoint tutorial: Stommel three-box ocean model

The goal of this tutorial is to show one specific use of Enzyme's automatic differentiation capabilities, built around the Stommel ocean model. The model is a nice example of how powerful Enzyme is, because it can differentiate a complicated function, namely one that has many parts and parameters. The tutorial focuses first on the computations and on getting Enzyme running. For those interested, a mathematical explanation of the model and of what an adjoint variable is will be provided at the end.

Brief model overview

The Stommel box model can be viewed as a highly simplified version of a full ocean model. In our example we have three boxes (Box One, Box Two, and Box Three), and we model the transport of fluid between them. The full equations of our system are given by:

\[\begin{aligned} U &= u_0 \left\{ \rho_2 - \left[ \delta \rho_1 + (1 - \delta) \rho_3 \right] \right\} \\ \rho_i &= -\alpha T_i + \beta S_i, \; \; \; \; i = 1, 2, 3 \end{aligned}\]

for the transport $U$ and densities $\rho_i$. Here $\delta$ is the depth of Box One divided by the combined depth of Boxes One and Three. The time derivatives are then

\[\begin{aligned} \dot{T_1} &= U(T_3 - T_1)/V_1 + \gamma (T_1^* - T_1 ) & \dot{S_1} &= U(S_3 - S_1)/V_1 + FW_1/V_1 \\ \dot{T_2} &= U(T_1 - T_2)/V_2 + \gamma (T_2^* - T_2 ) & \dot{S_2} &= U(S_1 - S_2)/V_2 + FW_2/V_2 \\ \dot{T_3} &= U(T_2 - T_3)/V_3 & \dot{S_3} &= U(S_2 - S_3)/V_3 \end{aligned}\]

for positive transport, $U > 0$, and

\[\begin{aligned} \dot{T_1} &= U(T_2 - T_1)/V_1 + \gamma (T_1^* - T_1) & \dot{S_1} &= U(S_2 - S_1)/V_1 + FW_1/V_1 \\ \dot{T_2} &= U(T_3 - T_2)/V_2 + \gamma (T_2^* - T_2 ) & \dot{S_2} &= U(S_3 - S_2)/V_2 + FW_2/V_2 \\ \dot{T_3} &= U(T_1 - T_3)/V_3 & \dot{S_3} &= U(S_1 - S_3)/V_3 \end{aligned}\]

for $U \leq 0$. The transport is driven only by the density difference between the boxes, which in turn comes from their temperature and salinity differences. This makes it a really easy model to play around with! With this in mind, the model is run forward with the steps:

  1. Compute densities
  2. Compute transport
  3. Compute time derivatives of the box temperatures and salinities
  4. Update the state vector (a leapfrog step followed by a Robert filter)
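As a quick hand check of steps 1 and 2, the sketch below evaluates the density and transport for the initial condition used in the examples later on. The temperatures are [20, 1, 1] and the salinities are [35.5, 34.5, 34.5]. The sketch is standalone: the constants are copied from the setup function defined below rather than obtained by calling it. For this state $\rho_2 = \rho_3$, so the transport reduces to $u_0 \delta (\rho_2 - \rho_1)$.

```julia
# Standalone sketch of steps 1 and 2 for the initial condition used later.
# Constants copied from the tutorial's setup() (CGS units).
alpha = 1668e-7                   # thermal coefficient in the equation of state
beta  = 0.7811e-3                 # haline coefficient in the equation of state
Sv    = 1e12                      # one Sverdrup in cm^3/s
u0    = 16.0 * Sv / 0.0004        # transport coefficient
delta = 1.0e5 / (1.0e5 + 4.0e5)   # depth(box1) / (depth(box1) + depth(box3))

T = [20.0, 1.0, 1.0]              # initial temperatures
S = [35.5, 34.5, 34.5]            # initial salinities

# Step 1: densities from the linear equation of state
rho = -alpha .* T .+ beta .* S

# Step 2: transport
U = u0 * (rho[2] - (delta * rho[1] + (1 - delta) * rho[3]))

println("rho = ", rho)
println("U   = ", U / Sv, " Sv")
```

For these values $U$ is positive (roughly 19 Sv), so the $U > 0$ equations govern the first step.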

We'll start by going through the model setup step by step, then work through two examples with Enzyme.

Model setup

Model dependencies

Let's first load the package needed to run everything:

using Enzyme

Initialize constants

The system equations contain quite a few constants, so here we initialize them for later use. We'll do this the Julia way: a structure holds all the parameters, and a function (we'll call it setup) initializes them. As long as we don't need to change parameters, we only need to run setup once.

struct ModelParameters

    # handy to have constants
    day::Float64
    year::Float64

    # Information related to the boxes
    boxlength::Vector{Float64}      ## Vector with north-south size of each box  [cm]
    boxdepth::Vector{Float64}       ## "          " the depth of each box  [cm]
    boxwidth::Float64               ## "          " the width of each box  [cm]
    boxarea::Vector{Float64}        ## "          " the area of each box   [cm^2]
    boxvol::Vector{Float64}         ## "          " the volume of each box   [cm^3]

    delta::Float64                  ## Constant ratio depth(box1) / (depth(box1) + depth(box3))

    # Parameters that appear in the box model equations
    u0::Float64
    alpha::Float64
    beta::Float64
    gamma::Float64

    # Coefficient for the Robert filter smoother
    rf_coeff::Float64

    # Freshwater forcing
    FW::Vector{Float64}

    # Restoring atmospheric temperatures and salinities
    Tstar::Vector{Float64}
    Sstar::Vector{Float64}

end

function setup()

    blength = [5000.0e5; 1000.0e5; 5000.0e5]
    bdepth = [1.0e5; 5.0e5; 4.0e5]

    delta = bdepth[1]/(bdepth[1] + bdepth[3])

    bwidth = 4000.0*1e5  ## box width, centimeters

    # box areas
    barea = [blength[1]*bwidth;
            blength[2]*bwidth;
            blength[3]*bwidth]

    # box volumes
    bvolume = [barea[1]*bdepth[1];
            barea[2]*bdepth[2];
            barea[3]*bdepth[3]]

    # parameters that are used to ensure units are in CGS (centimeter-gram-second)

    day = 3600.0*24.0
    year = day*365.0
    Sv = 1e12                       ## one Sverdrup (a unit of ocean transport), 1e6 meters^3/second

    # parameters that appear in box model equations
    u0 = 16.0*Sv/0.0004
    alpha = 1668e-7
    beta = 0.7811e-3

    gamma = 1/(300*day)

    # robert filter coefficient for the smoother part of the timestep
    robert_filter_coeff = 0.25

    # freshwater forcing
    FW = [(100/year) * 35.0 * barea[1]; -(100/year) * 35.0 * barea[1]]

    # restoring atmospheric temperatures and salinities
    Tstar = [22.0; 0.0]
    Sstar = [36.0; 34.0]

    structure_with_parameters = ModelParameters(day,
        year,
        blength,
        bdepth,
        bwidth,
        barea,
        bvolume,
        delta,
        u0,
        alpha,
        beta,
        gamma,
        robert_filter_coeff,
        FW,
        Tstar,
        Sstar
    )

    return structure_with_parameters

end
setup (generic function with 1 method)
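The derived quantities in setup are simple enough to check by hand. This standalone sketch repeats the geometry arithmetic with the same raw numbers instead of calling setup. It shows that Boxes One and Two have equal volumes, while Box Three is four times larger.

```julia
# Standalone recomputation of the derived geometry in setup() (CGS units)
blength = [5000.0e5, 1000.0e5, 5000.0e5]   # north-south lengths [cm]
bdepth  = [1.0e5, 5.0e5, 4.0e5]            # depths [cm]
bwidth  = 4000.0e5                         # common width [cm]

barea   = blength .* bwidth                # surface areas [cm^2]
bvolume = barea .* bdepth                  # volumes [cm^3]
delta   = bdepth[1] / (bdepth[1] + bdepth[3])

day   = 3600.0 * 24.0
gamma = 1 / (300 * day)                    # restoring rate [1/s], a 300-day timescale

println("volumes = ", bvolume)             # [2.0e22, 2.0e22, 8.0e22]
println("delta   = ", delta)               # 0.2
println("gamma   = ", gamma)
```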

Define model functions

Here we define functions that will calculate quantities used in the forward steps.

# function to compute transport
#       Input: rho - the density vector
#       Output: U - transport value

function compute_transport(rho, params)

    U = params.u0 * (rho[2] - (params.delta * rho[1] + (1 - params.delta)*rho[3]))
    return U

end

# function to compute density
#       Input: state = [T1; T2; T3; S1; S2; S3]
#       Output: rho

function compute_density(state, params)

    rho = -params.alpha * state[1:3] + params.beta * state[4:6]
    return rho

end

# lastly, a function that takes one step forward
#       Input: state_now = [T1(t), T2(t), ..., S3(t)]
#              state_old = [T1(t-dt), ..., S3(t-dt)]
#              u = transport(t)
#              dt = time step
#       Output: state_new = [T1(t+dt), ..., S3(t+dt)]

function compute_update(state_now, state_old, u, params, dt)

    dstate_now_dt = zeros(6)
    state_new = zeros(6)

    # first computing the time derivatives of the various temperatures and salinities
    if u > 0

        dstate_now_dt[1] = u * (state_now[3] - state_now[1]) / params.boxvol[1] + params.gamma * (params.Tstar[1] - state_now[1])
        dstate_now_dt[2] = u * (state_now[1] - state_now[2]) / params.boxvol[2] + params.gamma * (params.Tstar[2] - state_now[2])
        dstate_now_dt[3] = u * (state_now[2] - state_now[3]) / params.boxvol[3]

        dstate_now_dt[4] = u * (state_now[6] - state_now[4]) / params.boxvol[1] + params.FW[1] / params.boxvol[1]
        dstate_now_dt[5] = u * (state_now[4] - state_now[5]) / params.boxvol[2] + params.FW[2] / params.boxvol[2]
        dstate_now_dt[6] = u * (state_now[5] - state_now[6]) / params.boxvol[3]

    elseif u <= 0

        dstate_now_dt[1] = u * (state_now[2] - state_now[1]) / params.boxvol[1] + params.gamma * (params.Tstar[1] - state_now[1])
        dstate_now_dt[2] = u * (state_now[3] - state_now[2]) / params.boxvol[2] + params.gamma * (params.Tstar[2] - state_now[2])
        dstate_now_dt[3] = u * (state_now[1] - state_now[3]) / params.boxvol[3]

        dstate_now_dt[4] = u * (state_now[5] - state_now[4]) / params.boxvol[1] + params.FW[1] / params.boxvol[1]
        dstate_now_dt[5] = u * (state_now[6] - state_now[5]) / params.boxvol[2] + params.FW[2] / params.boxvol[2]
        dstate_now_dt[6] = u * (state_now[4] - state_now[6]) / params.boxvol[3]

    end

    # leapfrog step: state(t+dt) = state(t-dt) + 2*dt*d(state)/dt, with the derivative evaluated at time t
    state_new .= state_old + 2.0 * dt * dstate_now_dt

    return state_new
end
compute_update (generic function with 1 method)
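The two branches of compute_update differ only in which neighbour is upstream of each box. The sketch below shows an alternative way to write this, and it is not the tutorial's implementation. The advective part can be expressed with circshift: for $U > 0$ box $i$ is fed by box $i-1$ (cyclically), and for $U \leq 0$ by box $i+1$. The helper advective_tendency and the transport value used are hypothetical. The volumes are the ones produced by setup.

```julia
# Hypothetical helper: advective tendency U*(upstream - c)/V for a tracer c in three boxes
function advective_tendency(c, U, V)
    # U > 0:  box 1 <- box 3, box 2 <- box 1, box 3 <- box 2  (circshift(c, 1)  == [c3, c1, c2])
    # U <= 0: box 1 <- box 2, box 2 <- box 3, box 3 <- box 1  (circshift(c, -1) == [c2, c3, c1])
    upstream = U > 0 ? circshift(c, 1) : circshift(c, -1)
    return U .* (upstream .- c) ./ V
end

T = [20.0, 1.0, 1.0]            # temperatures
V = [2.0e22, 2.0e22, 8.0e22]    # box volumes from setup [cm^3]
U = 1.9e13                      # an illustrative positive transport [cm^3/s]

# Advective parts of the explicit formulas in compute_update
explicit_pos = [U * (T[3] - T[1]) / V[1], U * (T[1] - T[2]) / V[2], U * (T[2] - T[3]) / V[3]]
explicit_neg = [-U * (T[2] - T[1]) / V[1], -U * (T[3] - T[2]) / V[2], -U * (T[1] - T[3]) / V[3]]

println(isapprox(advective_tendency(T, U, V), explicit_pos))    # true
println(isapprox(advective_tendency(T, -U, V), explicit_neg))   # true
```

The same helper applies unchanged to the salinities. The restoring ($\gamma$) and freshwater (FW) terms would still be added separately for Boxes One and Two.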

Define forward functions

Finally, we create two functions, the first of which computes and stores all the states of the system, and the second will take just a single step forward.

Let's start with the standard forward function. This is just going to be used to store the states at every timestep:

function integrate(state_now, state_old, dt, M, parameters)

    # Because of the adjoint problem we're setting up, we need to store both the states before
    # and after the Robert filter smoother has been applied
    states_before = [state_old]
    states_after = [state_old]

    for t = 1:M

        rho = compute_density(state_now, parameters)
        u = compute_transport(rho, parameters)
        state_new = compute_update(state_now, state_old, u, parameters, dt)

        # Applying the Robert filter smoother (needed for stability)
        state_new_smoothed = state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old)

        push!(states_after, state_new_smoothed)
        push!(states_before, state_new)

        # cycle the "now, new, old" states
        state_old = state_new_smoothed
        state_now = state_new

    end

    return states_after, states_before
end
integrate (generic function with 1 method)
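To see why the Robert filter is needed for stability, the following sketch applies the same leapfrog-plus-filter update as integrate to the hypothetical scalar test equation $dx/dt = -\lambda x$, which is not part of the box model. Like the tutorial, it starts with identical "old" and "now" states. For a decaying equation, plain leapfrog has a spurious computational mode that grows from step to step. The filter with rf_coeff = 0.25 damps that mode.

```julia
# Scalar sketch of the time stepping in integrate, applied to the hypothetical
# test equation dx/dt = -lambda * x (not part of the box model).
function leapfrog_decay(lambda, dt, nsteps, rf_coeff)
    x_old = 1.0   # filtered state at the previous level (starts equal to x_now, as in the tutorial)
    x_now = 1.0   # unfiltered state at the current level
    for _ in 1:nsteps
        x_new = x_old + 2.0 * dt * (-lambda * x_now)                   # leapfrog step
        x_smoothed = x_now + rf_coeff * (x_new - 2.0 * x_now + x_old)  # Robert filter
        x_old, x_now = x_smoothed, x_new                               # cycle the states
    end
    return x_now
end

lambda, dt, nsteps = 1.0, 0.01, 1000                   # integrate to t = 10
exact      = exp(-lambda * dt * nsteps)
filtered   = leapfrog_decay(lambda, dt, nsteps, 0.25)  # rf_coeff as in setup
unfiltered = leapfrog_decay(lambda, dt, nsteps, 0.0)   # plain leapfrog

println("exact      = ", exact)
println("filtered   = ", filtered)
println("unfiltered = ", unfiltered)
```

With the filter, the result stays close to $e^{-10}$. Without it, the run is dominated by the growing, sign-alternating computational mode.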

For the purposes of Enzyme, it is convenient to have a function that runs a single step of the model forward rather than the whole integration. This lets us save as many of the adjoint variables as we wish when running the adjoint method, although the example we'll discuss later technically needs only one of them.

function one_step_forward(state_now, state_old, out_now, out_old, parameters, dt)

    state_new_smoothed = zeros(6)
    rho = compute_density(state_now, parameters)                             ## compute density
    u = compute_transport(rho, parameters)                                   ## compute transport
    state_new = compute_update(state_now, state_old, u, parameters, dt)      ## compute new state values

    # Robert filter smoother
    state_new_smoothed[:] = state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old)

    out_old[:] = state_new_smoothed
    out_now[:] = state_new

    return nothing

end
one_step_forward (generic function with 1 method)

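One difference to note is that one_step_forward returns nothing. Instead, it writes its results into the preallocated arrays out_now and out_old, so it is a function of both its inputs and its outputs. The outputs are vectors rather than a single scalar, and writing them in place while returning nothing is the simplest way to handle this in Enzyme's reverse mode. The shadows attached to out_now and out_old then carry the seeds for the derivative. Now we can move on to some examples using Enzyme.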
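The same calling convention is easier to see on a minimal function. Here square! is a hypothetical example that writes the elementwise square of x into out and returns nothing. Seeding the shadow of out with ones should leave 2x in the shadow of x. This sketch assumes Enzyme.jl is installed. It uses only autodiff, Reverse, and Duplicated, exactly as the tutorial does.

```julia
using Enzyme

# Hypothetical in-place function: writes x.^2 into out and returns nothing
function square!(out, x)
    for i in eachindex(x)
        out[i] = x[i]^2
    end
    return nothing
end

x    = [1.0, 2.0, 3.0]
dx   = zeros(3)      # shadow of x: Enzyme accumulates the gradient here
out  = zeros(3)
dout = ones(3)       # seed: weight given to each output in the vector-Jacobian product

autodiff(Reverse, square!, Duplicated(out, dout), Duplicated(x, dx))

println(out)         # forward values x.^2
println(dx)          # 2 .* x .* (original seed)
```

As in the tutorial, dx is accumulated into rather than overwritten. It should therefore be reset to zeros before it is reused in another call.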

Example 1: Simply using Enzyme

For the first example let's just compute the gradient of our forward function and examine the output. We'll just run the model for one step, and take a dt of ten days. The initial conditions of the system are given as Tbar and Sbar. We run setup once here, and never have to run it again! (Unless we decide to change a parameter)

parameters = setup()

Tbar = [20.0; 1.0; 1.0]         ## initial temperatures
Sbar = [35.5; 34.5; 34.5]       ## initial salinities

# Running the model one step forward
states_after_smoother, states_before_smoother = integrate(
    copy([Tbar; Sbar]),
    copy([Tbar; Sbar]),
    10*parameters.day,
    1,
    parameters
)

# Run Enzyme one time on `one_step_forward`
dstate_now = zeros(6)
dstate_old = zeros(6)
out_now = zeros(6); dout_now = ones(6)
out_old = zeros(6); dout_old = ones(6)

autodiff(Reverse,
    one_step_forward,
    Duplicated([Tbar; Sbar], dstate_now),
    Duplicated([Tbar; Sbar], dstate_old),
    Duplicated(out_now, dout_now),
    Duplicated(out_old, dout_old),
    parameters,
    Const(10*parameters.day)
)
((nothing, nothing, nothing, nothing, nothing, nothing),)

In order to run Enzyme on one_step_forward, we've needed to provide quite a few placeholders. Each vector argument is wrapped in Duplicated together with a shadow vector of the same size, while parameters and the time step (wrapped in Const) are not differentiated. Let's go through and see what Enzyme did with all of those placeholders.

First we can look at what happened to the zero vectors out_now and out_old:

@show out_now, out_old
([20.101970893653334, 0.9646957730133332, 1.0, 35.50026715349918, 34.49973284650082, 34.5], [20.025492723413333, 0.9911739432533333, 1.0, 35.500066788374795, 34.499933211625205, 34.5])

Comparing with the states stored by the forward function integrate:

@show states_before_smoother[2], states_after_smoother[2]
([20.101970893653334, 0.9646957730133332, 1.0, 35.50026715349918, 34.49973284650082, 34.5], [20.025492723413333, 0.9911739432533333, 1.0, 35.500066788374795, 34.499933211625205, 34.5])

we see that Enzyme has computed and stored exactly the output of the forward step, because the reverse-mode call also runs the function itself. Next, let's look at dstate_now:

@show dstate_now
6-element Vector{Float64}:
 0.41666666666666663
 0.41511917786666663
 0.5015474888
 0.49999999999999994
 0.49845251119999995
 0.5015474888

These are just a few numbers, but they show what makes AD so nice. Enzyme has computed the derivatives of both outputs with respect to the input state_now, exactly up to floating-point rounding and evaluated at the given inputs. It has then applied their transposes to what we gave as dout_now and dout_old (in our case, all ones). Using AD notation for reverse mode, this is

\[\overline{\text{state\_now}} = \left( \left. \frac{\partial \text{out\_now}}{\partial \text{state\_now}} \right|_{\text{state\_now}} \right)^{T} \overline{\text{out\_now}} + \left( \left. \frac{\partial \text{out\_old}}{\partial \text{state\_now}} \right|_{\text{state\_now}} \right)^{T} \overline{\text{out\_old}}\]
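The corresponding expression for state_old can be worked out by hand. Write $f$ for the tendency that compute_update evaluates from state_now and $\nu$ for rf_coeff. Then one_step_forward computes

\[\begin{aligned} \text{out\_now} &= \text{state\_old} + 2\, dt\, f(\text{state\_now}) \\ \text{out\_old} &= \text{state\_now} + \nu \left( \text{out\_now} - 2\, \text{state\_now} + \text{state\_old} \right) \end{aligned}\]

so $\partial \text{out\_now} / \partial \text{state\_old} = I$ and $\partial \text{out\_old} / \partial \text{state\_old} = 2 \nu I$. This gives

\[\overline{\text{state\_old}} = \overline{\text{out\_now}} + 2 \nu \, \overline{\text{out\_old}}\]

With all-ones seeds and $\nu = 0.25$, every entry of dstate_old should equal 1.5, independent of the model state.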

Note that had we seeded dout_now and dout_old with something else, our results would have changed. Let's multiply the seeds by two and see what happens.

dstate_now_new = zeros(6)
dstate_old_new = zeros(6)
out_now = zeros(6); dout_now = 2*ones(6)
out_old = zeros(6); dout_old = 2*ones(6)
autodiff(Reverse,
    one_step_forward,
    Duplicated([Tbar; Sbar], dstate_now_new),
    Duplicated([Tbar; Sbar], dstate_old_new),
    Duplicated(out_now, dout_now),
    Duplicated(out_old, dout_old),
    parameters,
    Const(10*parameters.day)
)
((nothing, nothing, nothing, nothing, nothing, nothing),)

Now checking dstate_now_new, we see

@show dstate_now_new
6-element Vector{Float64}:
 0.8333333333333333
 0.8302383557333333
 1.0030949776
 0.9999999999999999
 0.9969050223999999
 1.0030949776

What happened? Enzyme applies the transposed Jacobians to whatever we give as dout_now and dout_old, and this product is linear in those seeds. Checking this, we see

@show 2*dstate_now
6-element Vector{Float64}:
 0.8333333333333333
 0.8302383557333333
 1.0030949776
 0.9999999999999999
 0.9969050223999999
 1.0030949776

The doubled original values match the new results. This is exactly what we'd expect, since we scaled both dout_now and dout_old by two.
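The doubling experiment reflects two general properties of reverse mode, which can be seen without Enzyme on a linear map out = A*x. First, the reverse pass adds transpose(A)*dout into the shadow of x, so the result is linear in the seed. Second, it accumulates onto whatever the shadow already held, which is why dstate_now_new and dstate_old_new start as zeros. Enzyme likewise accumulates into Duplicated shadows. The matrix A and the helper reverse_linear! are hypothetical.

```julia
using LinearAlgebra

# Hypothetical reverse-mode rule for the linear map out = A * x:
# accumulate the vector-Jacobian product transpose(A) * dout into the shadow dx
function reverse_linear!(dx, A, dout)
    dx .+= transpose(A) * dout
    return nothing
end

A = [1.0 2.0;
     3.0 4.0]

dx_ones = zeros(2)
reverse_linear!(dx_ones, A, ones(2))        # seed of ones

dx_twos = zeros(2)
reverse_linear!(dx_twos, A, 2 .* ones(2))   # doubled seed

dx_stale = [10.0, 10.0]                     # shadow that was not reset to zero
reverse_linear!(dx_stale, A, ones(2))

println(dx_ones)    # [4.0, 6.0]   column sums of A
println(dx_twos)    # [8.0, 12.0]  twice dx_ones
println(dx_stale)   # [14.0, 16.0] dx_ones plus the stale values
```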

Example 2: Full sensitivity calculations

Now we want to use Enzyme for a bit more than just a single derivative. Let's say we'd like to understand how sensitive the final temperature of Box One is to the initial salinity of Box Two. That is, given the function

\[J = (1,0,0,0,0,0)^T \cdot \mathbf{x}(t_f)\]

we want Enzyme to calculate the derivative

\[\frac{\partial J}{\partial \mathbf{x}(0)}\]

where $\mathbf{x}(t)$ is the state of the model at time $t$. If we think of $\mathbf{x}(t_f)$ as depending solely on the initial condition, then this derivative is really

\[\frac{\partial J}{\partial \mathbf{x}(0)} = \frac{\partial}{\partial \mathbf{x}(0)} \left( (1,0,0,0,0,0)^T \cdot L(\ldots(L(\mathbf{x}(0)))) \right)\]

with $L(\mathbf{x}(t)) = \mathbf{x}(t + dt)$, i.e. one forward step. Strictly, because of the leapfrog scheme, each step acts on the pair of states state_now and state_old. One could expand this derivative with the chain rule, and it would be very complicated, but this is where Enzyme comes in: each run of autodiff on one_step_forward computes one link of this big chain rule for us! Reverse mode applies the chain rule from the outside in. We therefore start with the derivative at the final state and work backwards until we reach the initial state. To get Enzyme to do this, we complete the following steps:

  1. Run the forward model and store outputs (in a real ocean model this wouldn't be feasible and we'd need to use checkpointing)
  2. Compute the initial derivative from the final state
  3. Use Enzyme to work backwards until we reach the desired derivative.
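These three steps can be seen in miniature on the scalar recurrence $x_{n+1} = x_n + h \sin(x_n)$, a hypothetical stand-in for $L$. The local derivative $1 + h \cos(x_n)$ depends on the stored state, which is why step 1 keeps the whole trajectory. The backward loop then multiplies these local derivatives from the last step to the first, starting from the seed $\partial J / \partial x_N = 1$ for $J = x_N$. All function names here are hypothetical.

```julia
# Hypothetical scalar stand-in for the one-step map L: x -> x + h*sin(x)
sin_step(x, h) = x + h * sin(x)

# Step 1: run forward and store every state (xs[1] = x0)
function run_sin_forward(x0, h, N)
    xs = [x0]
    for _ in 1:N
        push!(xs, sin_step(xs[end], h))
    end
    return xs
end

# Steps 2 and 3: seed with dJ/dx_N = 1 (J = x_N) and walk backwards
function sin_adjoint_sweep(xs, h)
    N = length(xs) - 1
    xbar = 1.0
    for n in N:-1:1
        xbar *= 1 + h * cos(xs[n])   # derivative of sin_step at the stored input state xs[n]
    end
    return xbar                      # dJ/dx0
end

h, N, x0 = 0.1, 50, 0.5
grad = sin_adjoint_sweep(run_sin_forward(x0, h, N), h)

# Central finite-difference check
fd_step = 1e-6
fd = (run_sin_forward(x0 + fd_step, h, N)[end] - run_sin_forward(x0 - fd_step, h, N)[end]) / (2 * fd_step)
println((grad, fd))
```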

For simplicity we define a function that carries out these AD steps:

function compute_adjoint_values(states_before_smoother, states_after_smoother, M, parameters)

    dout_now = [0.0;0.0;0.0;0.0;0.0;0.0]
    dout_old = [1.0;0.0;0.0;0.0;0.0;0.0]

    for j = M:-1:1

        dstate_now = zeros(6)
        dstate_old = zeros(6)

        autodiff(Reverse,
            one_step_forward,
            Duplicated(states_before_smoother[j], dstate_now),
            Duplicated(states_after_smoother[j], dstate_old),
            Duplicated(zeros(6), dout_now),
            Duplicated(zeros(6), dout_old),
            parameters,
            Const(10*parameters.day)
        )

        if j == 1
            return dstate_now, dstate_old
        end

        dout_now = copy(dstate_now)
        dout_old = copy(dstate_old)

    end

end
compute_adjoint_values (generic function with 1 method)
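compute_adjoint_values passes the pair (dout_now, dout_old) backwards because each step consumes two states. A scalar analogue makes this bookkeeping explicit. It uses the same leapfrog-plus-Robert-filter step applied to the hypothetical linear equation $dx/dt = -\lambda x$. The adjoint of one step maps the seeds on (new, smoothed) to sensitivities of (now, old). The sweep starts from (0, 1) because $J$ is the final smoothed state. All names here are hypothetical.

```julia
# Forward step: (x_now, x_old) -> (x_new, x_smoothed); a = lambda*dt, nu = rf_coeff
function lf_forward_step(x_now, x_old, a, nu)
    x_new = x_old - 2a * x_now
    x_smoothed = x_now + nu * (x_new - 2x_now + x_old)
    return x_new, x_smoothed
end

# Adjoint of one step: seeds on (x_new, x_smoothed) -> sensitivities of (x_now, x_old)
function lf_adjoint_step(bar_new, bar_smoothed, a, nu)
    bar_new_total = bar_new + nu * bar_smoothed   # x_new also feeds x_smoothed
    bar_now = -2a * bar_new_total + (1 - 2nu) * bar_smoothed
    bar_old = bar_new_total + nu * bar_smoothed
    return bar_now, bar_old
end

# J = final smoothed state after M steps
function lf_final_smoothed(x_now, x_old, a, nu, M)
    for _ in 1:M
        x_now, x_old = lf_forward_step(x_now, x_old, a, nu)
    end
    return x_old
end

# Backward sweep seeded like compute_adjoint_values: (dout_now, dout_old) = (0, 1)
function lf_adjoint_sweep(a, nu, M)
    bar_now, bar_old = 0.0, 1.0
    for _ in M:-1:1
        bar_now, bar_old = lf_adjoint_step(bar_now, bar_old, a, nu)
    end
    return bar_now, bar_old
end

a, nu, M = 0.01, 0.25, 200
bar_now, bar_old = lf_adjoint_sweep(a, nu, M)

# Finite-difference checks, perturbing one initial copy at a time
h = 1e-4
J0 = lf_final_smoothed(1.0, 1.0, a, nu, M)
fd_old = (lf_final_smoothed(1.0, 1.0 + h, a, nu, M) - J0) / h
fd_now = (lf_final_smoothed(1.0 + h, 1.0, a, nu, M) - J0) / h
println((bar_old, fd_old))
println((bar_now, fd_now))
```

Because this test equation is linear, its adjoint step does not need the stored states. For the nonlinear box model, the local derivatives depend on the trajectory. That is why compute_adjoint_values is given states_before_smoother and states_after_smoother.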

First we integrate the model forward:

M = 10000                       ## Total number of forward steps to take
Tbar = [20.0; 1.0; 1.0]         ## initial temperatures
Sbar = [35.5; 34.5; 34.5]       ## initial salinities

states_after_smoother, states_before_smoother = integrate(
    copy([Tbar; Sbar]),
    copy([Tbar; Sbar]),
    10*parameters.day,
    M,
    parameters
)
([[20.0, 1.0, 1.0, 35.5, 34.5, 34.5], [20.025492723413333, 0.9911739432533333, 1.0, 35.500066788374795, 34.499933211625205, 34.5], [20.087333738863112, 0.969903327879738, 0.9999962888698433, 35.5002258363465, 34.49977427598468, 34.499999971917205], [20.126887776693476, 0.9565730187361253, 0.9999864055857747, 35.50032159242199, 34.49967881565066, 34.49999989798184], [20.175800259146232, 0.939999404716326, 0.9999761385546613, 35.50044168490963, 34.499559023896516, 34.499999822798465], [20.217139147904874, 0.9262460893284369, 0.9999602790899905, 35.50053762316879, 34.49946355154754, 34.49999970632092], [20.26010741717093, 0.9119570801723601, 0.9999429417690262, 35.50063684306695, 34.49936483183514, 34.49999958127448], [20.299841004013917, 0.898923258333431, 0.9999214585308304, 35.500724517340934, 34.499277775651855, 34.499999426751806], [20.33904302830221, 0.8861260303696249, 0.999897728143107, 35.500809261973714, 34.49919370454616, 34.49999925837004], [20.376340852677306, 0.8740827941922498, 0.9998707135288942, 35.50088669147369, 34.49911703653957, 34.49999906799669]  …  [21.418887115526317, 0.5824185356980063, 0.6288641735279802, 35.52187385672855, 34.493658923600115, 34.496116804917875], [21.418888901485786, 0.5824164982869285, 0.6288533573611339, 35.521877133050594, 34.49365793736968, 34.49611623239497], [21.41889068688411, 0.5824144614821922, 0.6288425432781672, 35.521880409275084, 34.49365695161351, 34.49611565977789], [21.41889247172089, 0.5824124252841737, 0.6288317312787093, 35.52188368540099, 34.493655966332106, 34.49611508706677], [21.41889425599573, 0.5824103896932495, 0.6288209213623892, 35.521886961427256, 34.49365498152598, 34.496114514261734], [21.418896039708233, 0.5824083547097952, 0.6288101135288363, 35.52189023735286, 34.493653997195636, 34.49611394136292], [21.418897822858, 0.5824063203341853, 0.6287993077776801, 35.521893513176764, 34.493653013341586, 34.49611336837046], [21.418899605444633, 0.5824042865667947, 0.6287885041085504, 
35.52189678889792, 34.49365202996433, 34.49611279528448], [21.418901387467738, 0.5824022534079969, 0.6287777025210768, 35.5219000645153, 34.493651047064375, 34.496112222105126], [21.41890316892692, 0.5824002208581649, 0.6287669030148895, 35.52190334002786, 34.493650064642225, 34.496111648832525]], [[20.0, 1.0, 1.0, 35.5, 34.5, 34.5], [20.101970893653334, 0.9646957730133332, 1.0, 35.50026715349918, 34.49973284650082, 34.5], [20.119900444732444, 0.9590478222389521, 0.9999851554793732, 35.500302250012865, 34.49969819931187, 34.499999887668814], [20.180416478445903, 0.938293102586859, 0.999979022514509, 35.50045603331573, 34.49954458799422, 34.499999844672516], [20.215480302999644, 0.9268383949554608, 0.9999601036038525, 35.50053308058508, 34.49946810394696, 34.49999970386699], [20.26179572647398, 0.9113081626864997, 0.999944770597596, 35.50064264659536, 34.499358974399726, 34.49999959475123], [20.299699067830883, 0.8989659059880041, 0.9999219467909221, 35.50072445590829, 34.49927782699354, 34.499999429274546], [20.339858463222974, 0.8858041411853559, 0.9998989987724511, 35.5008123144802, 34.49919061678521, 34.499999267183654], [20.376614182748977, 0.8739725807743567, 0.9998714564966953, 35.50088790159353, 34.49911580896236, 34.499999072361035], [20.41309201690906, 0.8622599848506607, 0.9998422129790796, 35.500961700733974, 34.49904282368742, 34.49999886889466]  …  [21.418888901672787, 0.5824164980848566, 0.628853356666466, 35.521877133082995, 34.49365793721165, 34.496116232426374], [21.418890687071247, 0.5824144612799947, 0.6288425425836232, 35.52188040930783, 34.49365695145531, 34.496115659809256], [21.418892471908162, 0.5824124250818506, 0.6288317305842889, 35.521883685434084, 34.49365596617373, 34.49611508709808], [21.418894256183133, 0.5824103894908012, 0.6288209206680924, 35.521886961460694, 34.49365498136744, 34.496114514293005], [21.418896039895767, 0.5824083545072222, 0.6288101128346629, 35.52189023738665, 34.49365399703693, 34.49611394139415], 
[21.418897823045665, 0.582406320131487, 0.6287993070836302, 35.52189351321089, 34.49365301318271, 34.496113368401645], [21.41889960563243, 0.5824042863639719, 0.6287885034146238, 35.5218967889324, 34.49365202980529, 34.496112795315625], [21.41890138765567, 0.5824022532050498, 0.6287777018272735, 35.521900064550124, 34.493651046905164, 34.496112222136226], [21.418903169114984, 0.5824002206550933, 0.6287669023212096, 35.521903340063034, 34.49365006448285, 34.496111648863575], [21.41890495000998, 0.582398188714476, 0.6287561048960619, 35.521906615470094, 34.49364908253884, 34.49611107549781]])

Next, we pass all of our states to the AD function to get back to the desired derivative:

dstate_now, dstate_old = compute_adjoint_values(
    states_before_smoother,
    states_after_smoother,
    M,
    parameters
)
([4.536200515970697e-6, -5.5228719721780566e-5, -0.0036262930402804205, -0.0015248276469600676, 0.0030843576704276065, -0.00155953002346649], [9.076556101612683e-6, -0.00011108329224528888, -0.007258754140417836, -0.0030609928352190135, 0.00616139595759519, -0.003100403122374079])

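And we're done! We were interested in the sensitivity to the initial salinity of Box Two, the fifth entry of the state vector. Because the leapfrog scheme starts from two copies of the initial state, Enzyme returns one sensitivity for each. dstate_now belongs to the copy passed as state_now, and dstate_old to the copy passed as state_old. The finite-difference check below perturbs only state_old, so the value to compare is in dstate_old. Checking this value we see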

@show dstate_old[5]
0.00616139595759519

As it stands this is just a number. A good check that Enzyme has computed what we want is to approximate the derivative with a finite difference based on a first-order Taylor expansion. Writing $\mathbf{e}_5$ for the unit vector that selects the salinity of Box Two,

\[J(\mathbf{x}(0) + \varepsilon \mathbf{e}_5) \approx J(\mathbf{x}(0)) + \varepsilon \frac{\partial J}{\partial x_5(0)}\]

and a simple rearrangement yields

\[\frac{\partial J}{\partial x_5(0)} \approx \frac{J(\mathbf{x}(0) + \varepsilon \mathbf{e}_5) - J(\mathbf{x}(0))}{\varepsilon}\]

We expect these finite-difference approximations to approach the value found with Enzyme as $\varepsilon$ shrinks, until round-off error takes over:

# unperturbed final state
use_to_check = states_after_smoother[M+1]

# a loop to compute the perturbed final states
diffs = []
step_sizes = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10]
for eps in step_sizes

    state_new_smoothed = zeros(6)

    initial_temperature = [20.0; 1.0; 1.0]
    perturbed_initial_salinity = [35.5; 34.5; 34.5] + [0.0; eps; 0.0]

    state_old = [initial_temperature; perturbed_initial_salinity]
    state_now = [20.0; 1.0; 1.0; 35.5; 34.5; 34.5]

    for t = 1:M

        rho = compute_density(state_now, parameters)
        u = compute_transport(rho, parameters)
        state_new = compute_update(state_now, state_old, u, parameters, 10*parameters.day)

        state_new_smoothed[:] = state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old)

        state_old = state_new_smoothed
        state_now = state_new

    end

    push!(diffs, (state_old[1] - use_to_check[1])/eps)

end

Then checking the finite-difference approximations of the derivative:

@show diffs
10-element Vector{Any}:
 0.0057270806669862395
 0.006114968958925715
 0.006156721866545922
 0.0061609284074393145
 0.00616134023800896
 0.006161283039318732
 0.006160796317544737
 0.0061451288502212265
 0.006068034963391256
 0.006359357485052897

These come very close to our calculated value. We can go further and check the relative difference:

@show abs.(diffs .- dstate_old[5])./dstate_old[5]
10-element Vector{Float64}:
 0.07048975485394138
 0.007535142845712421
 0.0007586091011576398
 7.588380280917028e-5
 9.043338005407632e-6
 1.8326735894697888e-5
 9.732210923954578e-5
 0.0026401658789532625
 0.015152571729925521
 0.03212933056407103

The relative difference gets down to the order of $10^{-5}$ at $\varepsilon = 10^{-5}$, before round-off error makes smaller steps less accurate. This shows Enzyme calculated the correct derivative. Success!


This page was generated using Literate.jl.