begin
    import Pkg
    # careful: this is _not_ a reproducible environment
    # activate the local environment
    Pkg.activate(".")
    Pkg.instantiate()
    using PlutoUI, PlutoLinks
end
@revise using Enzyme
@revise using EnzymeCore
begin
    using CairoMakie
    set_theme!(
        theme_latexfonts();
        fontsize = 16,
        Lines = (linewidth = 2,),
        markersize = 16
    )
end
Reproducing "Stabilizing backpropagation through time to learn complex physics"
Fig 1 from https://openreview.net/pdf?id=bozbTTWcaw
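In the cells below, a scalar state is rolled forward for n steps; at each step the "network" N produces a correction that the "solver" step S adds onto the state, and the loss measures the distance of the final state to a target y:

x_{i+1} = S(x_i, N(x_i, \theta)) = x_i + \theta_1 x_i^2 + \theta_2 x_i, \qquad L(\theta) = \tfrac{1}{2}\,(x_n - y)^2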
begin
    # "network": a simple quadratic correction parameterized by θ = (θ₁, θ₂)
    N(xᵢ, θ) = θ[1] * xᵢ^2 + θ[2] * xᵢ
    # "solver": one step of the toy dynamics, adding the correction cᵢ onto the state
    S(xᵢ, cᵢ) = xᵢ + cᵢ
end
S (generic function with 1 method)
function simulate(N, S, x₀, y, θ, n)
    # roll the coupled network/solver system forward for n steps
    xᵢ = x₀
    for i in 1:n
        cᵢ = N(xᵢ, θ)
        xᵢ = S(xᵢ, cᵢ)
    end
    # loss: squared distance of the final state to the target y
    return 1 / 2 * (xᵢ - y)^2
end
simulate (generic function with 1 method)
begin
    x₀ = -0.3  # initial state
    y = 2.0    # target state
    n = 4      # number of simulation steps
end
4
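A quick spot check (not part of the original notebook; the parameter pair is chosen arbitrarily) to see the loss for a single θ:

# sanity check, not from the original notebook: evaluate the loss once
# for an arbitrarily chosen parameter pair θ = (1.0, 0.5)
simulate(N, S, x₀, y, (1.0, 0.5), n)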
begin
    θ₁ = -4:0.01:4
    θ₂ = -4:0.01:4
    # grid of (θ₁, θ₂) parameter pairs to scan the loss landscape over
    θ_space = collect(Iterators.product(θ₁, θ₂))
end;
L_space = simulate.(N, S, x₀, y, θ_space, n);
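Optionally (not in the original notebook), the grid point with the smallest loss gives a rough idea of where the minima sit before plotting:

# optional check, not from the original notebook: grid point with the smallest loss
let
    idx = argmin(L_space)
    (θ = θ_space[idx], L = L_space[idx])
end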
let
    fig, ax, hm = heatmap(
        θ₁, θ₂, L_space,
        colorscale = log10,
        colormap = Makie.Reverse(:Blues),
        colorrange = (10^-5, 10^5)
    )
    Colorbar(fig[:, end + 1], hm)
    fig
end
function ∇simulate(N, S, x₀, y, θ, n)
    # reverse-mode gradient of the loss with respect to θ via Enzyme;
    # all other arguments are held constant
    dθ = MixedDuplicated(θ, Ref(Enzyme.make_zero(θ)))
    Enzyme.autodiff(Enzyme.Reverse, simulate, Const(N), Const(S), Const(x₀), Const(y), dθ, Const(n))
    return dθ.dval[]
end
∇simulate (generic function with 1 method)
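As a sanity check (not in the original notebook; the evaluation point is arbitrary, and θ is passed as a plain tuple as in the heatmap cell above), the Enzyme gradient can be compared against a central finite difference in θ₁:

# sanity check, not from the original notebook: compare the Enzyme gradient in θ₁
# with a central finite difference at an arbitrarily chosen point
let
    θ = (1.0, 0.5)
    h = 1e-6
    fd = (simulate(N, S, x₀, y, (θ[1] + h, θ[2]), n) -
          simulate(N, S, x₀, y, (θ[1] - h, θ[2]), n)) / (2h)
    (enzyme = ∇simulate(N, S, x₀, y, θ, n)[1], finite_difference = fd)
end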
function plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n)
    θ_space = collect(Iterators.product(θ₁, θ₂))
    gradient_field = ∇simulate.(N, S, x₀, y, θ_space, n)
    # heatmap of the gradient magnitude ‖∇L‖ over the (θ₁, θ₂) grid
    fig, ax, hm = heatmap(
        θ₁, θ₂, map(x -> sqrt(x[1]^2 + x[2]^2), gradient_field),
        colorscale = log10,
        colormap = Makie.Reverse(:Blues),
        colorrange = (10^-3, 10^3)
    )
    Colorbar(fig[:, end + 1], hm)
    # overlay the negative gradient flow field
    streamplot!(
        ax, (θ) -> -∇simulate(N, S, x₀, y, θ, n), θ₁, θ₂,
        alpha = 0.5,
        colorscale = log10, color = p -> :red,
        arrow_size = 10
    )
    return fig
end
plot_gradientfield (generic function with 1 method)
plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n)
begin
    # identity function; @noinline and the inference barrier keep the call opaque
    # so that the custom Enzyme rules below are applied instead of being optimized away
    @noinline function ignore_derivatives(x::T) where {T}
        return Base.inferencebarrier(x)::T
    end

    # forward mode: the tangent of the output is zero
    function EnzymeRules.forward(
        config,
        ::Const{typeof(ignore_derivatives)},
        A, x::Duplicated
    )
        return Enzyme.make_zero(x.val)
    end

    # reverse mode, augmented primal: return the primal value if it is needed
    # downstream and a zero shadow for non-Active arguments; no tape is required
    function EnzymeRules.augmented_primal(
        config,
        ::Const{typeof(ignore_derivatives)},
        FA, x
    )
        primal = EnzymeRules.needs_primal(config) ? x.val : nothing
        if x isa Active
            shadow = nothing
        else
            shadow = Enzyme.make_zero(x.val)
        end
        return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
    end

    # reverse mode, Active scalar argument: the incoming adjoint is dropped,
    # so a zero gradient flows back
    function EnzymeRules.reverse(
        config,
        ::Const{typeof(ignore_derivatives)},
        dret::Active, tape, x::Active
    )
        return (Enzyme.make_zero(x.val),)
    end

    # reverse mode, Duplicated argument: nothing is accumulated into the shadow
    function EnzymeRules.reverse(
        config,
        ::Const{typeof(ignore_derivatives)},
        ::Type{<:Duplicated}, tape, x::Duplicated
    )
        return (nothing,)
    end
end
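A small check (not in the original notebook) that the custom reverse rule really drops the derivative contribution of anything wrapped in ignore_derivatives:

# sanity check, not from the original notebook: with the rule above,
# d/dx [ignore_derivatives(x) * x] should only see the second factor (≈ x),
# whereas d/dx [x * x] is 2x
let
    f(x) = ignore_derivatives(x) * x
    g(x) = x * x
    (
        stopped = Enzyme.autodiff(Enzyme.Reverse, f, Active, Active(3.0))[1][1],
        full    = Enzyme.autodiff(Enzyme.Reverse, g, Active, Active(3.0))[1][1],
    )
end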
# variant of N whose derivative with respect to the state xᵢ is suppressed,
# while its derivative with respect to θ is unchanged
N_stop(xᵢ, θ) = θ[1] * ignore_derivatives(xᵢ^2) + θ[2] * ignore_derivatives(xᵢ)
N_stop (generic function with 1 method)
plot_gradientfield(N_stop, S, x₀, y, θ₁, θ₂, n)