begin
    import Pkg
    # careful: this is _not_ a reproducible environment
    # activate the local environment
    Pkg.activate(".")
    Pkg.instantiate()
    using PlutoUI, PlutoLinks
end
@revise using Enzyme
@revise using EnzymeCore
begin
    using CairoMakie
    # Global Makie theme for every figure in this notebook: LaTeX-style fonts,
    # font size 16, marker size 16, and line width 2 for `lines` plots.
    set_theme!(
        theme_latexfonts();
        Lines = (linewidth = 2,),
        fontsize = 16,
        markersize = 16,
    )
end

Reproducing "Stabilizing backpropagation through time to learn complex physics"

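Figure 1 of https://openreview.net/pdf?id=bozbTTWcaw: the loss landscape and gradient fields of a toy problem in which a two-parameter network N steers a simple additive simulator S over n unrolled steps.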

begin
    # Toy "network": a quadratic in the state with coefficients θ[1], θ[2].
    function N(xᵢ, θ)
        quadratic_term = θ[1] * xᵢ^2
        linear_term = θ[2] * xᵢ
        return quadratic_term + linear_term
    end

    # Toy "simulator": one step that adds the control cᵢ to the state xᵢ.
    function S(xᵢ, cᵢ)
        return xᵢ + cᵢ
    end
end
S (generic function with 1 method)
"""
    simulate(N, S, x₀, y, θ, n)

Unroll `n` steps of the controlled system starting from `x₀` and return the loss
`½ (xₙ - y)²`. At each step the network `N(state, θ)` proposes a control and the
simulator `S(state, control)` produces the next state.
"""
function simulate(N, S, x₀, y, θ, n)
    state = x₀
    for _ in 1:n
        control = N(state, θ)
        state = S(state, control)
    end
    residual = state - y
    return 0.5 * residual^2
end
simulate (generic function with 1 method)
begin
    # Initial state of the unrolled simulation (passed as `x₀` to `simulate`).
    x₀ = -0.3
    # Target value the final state is compared against in the loss.
    y = 2.0
    # Number of network + simulator steps unrolled in `simulate`.
    n = 4
end
4
begin
    # Parameter axes: 801 points each, covering [-4, 4] with spacing 0.01.
    θ₁ = -4:0.01:4
    θ₂ = -4:0.01:4
    # Cartesian grid of (θ₁[i], θ₂[j]) pairs; the first index runs over θ₁.
    θ_space = [(a, b) for a in θ₁, b in θ₂]
end;
L_space = [simulate(N, S, x₀, y, θ, n) for θ in θ_space];  # loss at every grid point, same shape as θ_space
let
    # Loss landscape L(θ₁, θ₂) on the parameter grid, shown with a log10 color
    # scale whose limits are set to [1e-5, 1e5].
    fig, ax, hm = heatmap(
        θ₁, θ₂, L_space;
        axis = (xlabel = L"\theta_1", ylabel = L"\theta_2"),
        colorscale = log10,
        colormap = Makie.Reverse(:Blues),
        colorrange = (10^-5, 10^5)
    )
    # Colorbar in a new column to the right of the heatmap.
    Colorbar(fig[:, end + 1], hm)

    fig
end
"""
    ∇simulate(N, S, x₀, y, θ, n)

Gradient of the `simulate` loss with respect to `θ`, computed with Enzyme in
reverse mode. All other arguments are held constant. The result has the same type
as `Enzyme.make_zero(θ)`.
"""
function ∇simulate(N, S, x₀, y, θ, n)
    # Zero-initialised shadow for θ, wrapped in a Ref as MixedDuplicated expects;
    # its contents after `autodiff` are read back as the gradient.
    ∂θ = Ref(Enzyme.make_zero(θ))
    Enzyme.autodiff(
        Enzyme.Reverse,
        simulate,
        Const(N),
        Const(S),
        Const(x₀),
        Const(y),
        MixedDuplicated(θ, ∂θ),
        Const(n),
    )
    return ∂θ[]
end
∇simulate (generic function with 1 method)
"""
    plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n; colorrange = (10^-3, 10^3))

Plot the gradient of the `simulate` loss with respect to θ over the grid `θ₁ × θ₂`.
The plot is a log-scaled heatmap of the gradient norm ‖∇θ L‖, overlaid with a
streamplot of the descent direction −∇θ L. Returns the Makie `Figure`.

`colorrange` sets the heatmap's color limits; the default matches the previously
hard-coded `(10^-3, 10^3)`.
"""
function plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n; colorrange = (10^-3, 10^3))
    θ_space = collect(Iterators.product(θ₁, θ₂))
    gradient_field = ∇simulate.(N, S, x₀, y, θ_space, n)
    # `hypot` computes sqrt(a² + b²) without overflow/underflow in the squares.
    gradient_norm = map(g -> hypot(g[1], g[2]), gradient_field)

    fig, ax, hm = heatmap(
        θ₁, θ₂, gradient_norm;
        axis = (xlabel = L"\theta_1", ylabel = L"\theta_2"),
        colorscale = log10,
        colormap = Makie.Reverse(:Blues),
        colorrange = colorrange
    )
    Colorbar(fig[:, end + 1], hm)

    # Streamlines follow the negative gradient, i.e. the gradient-descent direction.
    streamplot!(
        ax, (θ) -> -∇simulate(N, S, x₀, y, θ, n), θ₁, θ₂,
        alpha = 0.5,
        colorscale = log10, color = p -> :red,
        arrow_size = 10
    )
    return fig
end
plot_gradientfield (generic function with 1 method)
# Gradient field for the plain network N: no stop-gradients, so the derivative flows through N and S at every unrolled step.
plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n)
begin
    """
        ignore_derivatives(x)

    Return `x` unchanged in the primal computation, while acting as a stop-gradient
    for Enzyme: the custom rules below make the derivative of the result with
    respect to `x` zero.
    """
    @noinline function ignore_derivatives(x::T) where {T}
        # `@noinline` keeps this a separate call that Enzyme can intercept with the
        # rules below; `inferencebarrier` hides `x` from the optimizer (presumably
        # so the identity is not folded away — TODO confirm), and `::T` restores
        # the concrete return type.
        return Core.inferencebarrier(x)::T
    end

    # Forward mode: the output tangent is zero regardless of `x.dval`. The return
    # value must match what `config` requests: Duplicated(primal, shadow), only
    # the shadow, only the primal, or nothing.
    function EnzymeRules.forward(
            config,
            ::Const{typeof(ignore_derivatives)},
            A, x::Duplicated
        )
        primal = x.val
        shadow = Enzyme.make_zero(x.val)
        if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
            return Duplicated(primal, shadow)
        elseif EnzymeRules.needs_primal(config)
            return primal
        elseif EnzymeRules.needs_shadow(config)
            return shadow
        else
            return nothing
        end
    end

    # Reverse mode, forward sweep: return the primal if requested. If a return
    # shadow is requested (non-Active returns), hand back a fresh zero shadow that
    # is disconnected from `x`, so nothing accumulated into it reaches `x`. No tape
    # is needed.
    function EnzymeRules.augmented_primal(
            config,
            ::Const{typeof(ignore_derivatives)},
            FA, x
        )
        primal = EnzymeRules.needs_primal(config) ? x.val : nothing
        shadow = EnzymeRules.needs_shadow(config) ? Enzyme.make_zero(x.val) : nothing
        return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
    end

    # Reverse sweep for an Active (immutable) argument: its adjoint is zero.
    function EnzymeRules.reverse(
            config,
            ::Const{typeof(ignore_derivatives)},
            dret::Active, tape, x::Active
        )
        return (Enzyme.make_zero(x.val),)
    end

    # Reverse sweep for a Duplicated argument: do not accumulate into `x.dval`.
    function EnzymeRules.reverse(
            config,
            ::Const{typeof(ignore_derivatives)},
            ::Type{<:Duplicated}, tape, x::Duplicated
        )
        return (nothing,)
    end
end
function N_stop(xᵢ, θ)
    # Same quadratic as N, but the state features are passed through
    # ignore_derivatives, so no gradient flows back into xᵢ through the network.
    quad_feature = ignore_derivatives(xᵢ^2)
    lin_feature = ignore_derivatives(xᵢ)
    return θ[1] * quad_feature + θ[2] * lin_feature
end
N_stop (generic function with 1 method)
# Gradient field with N_stop: the network's dependence on the state is cut by ignore_derivatives, so state gradients pass back only through S.
plot_gradientfield(N_stop, S, x₀, y, θ₁, θ₂, n)