Skip to content

Commit

Permalink
Initial pass at adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Nov 26, 2024
1 parent 99156d1 commit 0e5a563
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -71,6 +72,7 @@ LinearSolve = "2"
Lux = "1"
Markdown = "1.10"
ModelingToolkit = "9.42"
Mooncake = "0.4.50"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1"
Optimization = "4"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, Zer
using Enzyme: Enzyme
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Mooncake: Mooncake
using Tracker: Tracker, TrackedArray
using ReverseDiff: ReverseDiff
using Zygote: Zygote
Expand Down
44 changes: 44 additions & 0 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
paramjac_config = get_paramjac_config(autojacvec, p, f, y, _p, _t; numindvar, alg)
pf = get_pf(autojacvec; _f = unwrappedf, isinplace = isinplace, isRODE = isRODE)
paramjac_config = (paramjac_config..., Enzyme.make_zero(pf))
elseif autojacvec isa MooncakeVJP
pf = get_pf(autojacvec; _f = unwrappedf, isinplace, isRODE)
paramjac_config = get_paramjac_config(autojacvec, pf, p, f, y, _p, _t; numindvar, alg)
elseif SciMLBase.has_paramjac(f) || quad || !(autojacvec isa Bool) ||
autojacvec isa EnzymeVJP
paramjac_config = nothing
Expand Down Expand Up @@ -460,6 +463,20 @@ function get_paramjac_config(autojacvec::EnzymeVJP, p, f, y, _p, _t; numindvar,
return paramjac_config
end

function get_paramjac_config(
::MooncakeVJP, pf, p, f, y, _p, _t;
numindvar, alg, isinplace = nothing, isRODE = nothing, _W = nothing,
)
dy_mem = zero(y)
dy_mem_grad = Mooncake.zero_tangent(dy_mem)
pf_grad = Mooncake.zero_tangent(pf)
y_grad = Mooncake.zero_tangent(y)
p_grad = Mooncake.zero_tangent(p)
λ_mem = zero(y)
rule = Mooncake.build_rrule(pf, dy_mem, y, p, _t)
return rule, pf, pf_grad, dy_mem, dy_mem_grad, y_grad, p_grad, λ_mem
end

function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing,
isRODE = nothing)
nothing
Expand Down Expand Up @@ -492,6 +509,33 @@ function get_pf(autojacvec::EnzymeVJP; _f, isinplace, isRODE)
end
end

function get_pf(autojacvec::MooncakeVJP; _f, isinplace, isRODE)
pf = let f = _f
if isinplace && isRODE
function (out, u, _p, t, W)
f(out, u, _p, t, W)
return out
end
elseif isinplace
function (out, u, _p, t)
f(out, u, _p, t)
return out
end
elseif !isinplace && isRODE
function (out, u, _p, t, W)
out .= f(u, _p, t, W)
return out
end
else
# !isinplace
function (out, u, _p, t)
out .= f(u, _p, t)
return out
end
end
end
end

function getprob(S::SensitivityFunction)
(S isa ODEBacksolveSensitivityFunction) ? S.prob : S.sol.prob
end
Expand Down
62 changes: 62 additions & 0 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,68 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
return
end

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::MooncakeVJP, dgrad, dy,
W) where {TS <: SensitivityFunction}
(; sensealg) = S
prob = getprob(S)
f = unwrapped_f(S.f)

if inplace_sensitivity(S)
rule, pf, pf_grad, dy_mem, dy_mem_grad, y_grad, p_grad, λ_mem = S.diffcache.paramjac_config
λ_mem .= λ
coduals = (
Mooncake.CoDual(pf, Mooncake.set_to_zero!!(pf_grad)),
Mooncake.CoDual(dy_mem, dy_mem_grad),
Mooncake.CoDual(y, Mooncake.set_to_zero!!(y_grad)),
Mooncake.CoDual(p, Mooncake.set_to_zero!!(p_grad)),
Mooncake.zero_codual(t), # t is a Float64, so zero_codual is basically free.
)
__dy, pb!! = rule(map(Mooncake.to_fwds, coduals)...)
dy !== nothing && recursive_copyto!(dy, __dy) # grab a copy of the final value.

# Copy fdata across and run the reverse-pass.
Mooncake.increment!!(__dy.dx, Mooncake.fdata(λ_mem))
rdata_out = pb!!(Mooncake.rdata(λ_mem))
(_, _, grad_y, grad_p, _) = map(
(f, r) -> Mooncake.tangent(Mooncake.fdata(f.dx), r), coduals, rdata_out
)

# Copy values to target memory.
!== nothing && recursive_copyto!(dλ, grad_y)
dgrad !== nothing && recursive_copyto!(dgrad, grad_p)
else
@show "out-of-place"
if W === nothing
_dy, back = Zygote.pullback(y, p) do u, p
vec(f(u, p, t))
end
else
_dy, back = Zygote.pullback(y, p) do u, p
vec(f(u, p, t, W))
end
end

# Grab values from `_dy` before `back` in case mutated
dy !== nothing && recursive_copyto!(dy, _dy)

tmp1, tmp2 = back(λ)
if tmp1 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
elseif tmp1 !== nothing
!== nothing && recursive_copyto!(dλ, tmp1)
end

if dgrad !== nothing
if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
throw(ZygoteVJPNothingError())
elseif tmp2 !== nothing
!isempty(dgrad) && recursive_copyto!(dgrad, tmp2)
end
end
end
return
end

function jacNoise!(λ, y, p, t, S::SensitivityFunction;
dgrad = nothing, dλ = nothing, dy = nothing)
_jacNoise!(λ, y, p, t, S, S.sensealg.autojacvec, dgrad, dλ, dy)
Expand Down
37 changes: 37 additions & 0 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,31 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
end
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
pJ = nothing
elseif sensealg.autojacvec isa MooncakeVJP
pf = let f = unwrappedf
if DiffEqBase.isinplace(prob)
function (out, u, _p, t)
f(out, u, _p, t)
return out
end
else
!DiffEqBase.isinplace(prob)
function (out, u, _p, t)
out .= f(u, _p, t)
return out
end
end
end
pf_grad_mem = Mooncake.zero_tangent(pf)
dy_mem = zero(y)
dy_grad_mem = zero(y)
y_grad_mem = zero(y)
p_grad_mem = Mooncake.zero_tangent(p)
rule = Mooncake.build_rrule(pf, dy_mem, y, p, tspan[2])
paramjac_config = (
pf_grad_mem, dy_mem, dy_grad_mem, y_grad_mem, p_grad_mem, rule
)
pJ = nothing
elseif isautojacvec # Zygote
paramjac_config = nothing
pf = nothing
Expand Down Expand Up @@ -500,6 +525,18 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
elseif sensealg.autojacvec isa MooncakeVJP
pf_grad_mem, dy_mem, dy_grad_mem, y_grad_mem, p_grad_mem, rule = paramjac_config
_, (_, _, _, _out, _) = Mooncake.__value_and_pullback!!(
rule,
λ,
Mooncake.CoDual(pf, pf_grad_mem),
Mooncake.CoDual(dy_mem, dy_grad_mem),
Mooncake.CoDual(y, Mooncake.set_to_zero!!(y_grad_mem)),
Mooncake.CoDual(p, Mooncake.set_to_zero!!(p_grad_mem)),
Mooncake.zero_codual(t),
)
out .= _out
else
error("autojacvec choice $(sensealg.autojacvec) is not supported by GaussAdjoint")
end
Expand Down
40 changes: 40 additions & 0 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,34 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
end
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
pJ = nothing
elseif sensealg.autojacvec isa MooncakeVJP
pf = let f = unwrappedf
if DiffEqBase.isinplace(prob) && prob isa RODEProblem
function (out, u, _p, t, W)
f(out, u, _p, t, W)
return out
end
elseif DiffEqBase.isinplace(prob)
function (out, u, _p, t)
f(out, u, _p, t)
return out
end
elseif !DiffEqBase.isinplace(prob) && prob isa RODEProblem
f
else
f
end
end
pf_grad_mem = Mooncake.zero_tangent(pf)
dy_mem = zero(y)
dy_grad_mem = zero(y)
y_grad_mem = zero(y)
p_grad_mem = Mooncake.zero_tangent(p)
rule = Mooncake.build_rrule(pf, dy_mem, y, p, tspan[2])
paramjac_config = (
pf_grad_mem, dy_mem, dy_grad_mem, y_grad_mem, p_grad_mem, rule
)
pJ = nothing
elseif isautojacvec # Zygote
paramjac_config = nothing
pf = nothing
Expand Down Expand Up @@ -288,6 +316,18 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
else
out[:] .= vec(tmp[1])
end
elseif sensealg.autojacvec isa MooncakeVJP
pf_grad_mem, dy_mem, dy_grad_mem, y_grad_mem, p_grad_mem, rule = paramjac_config
_, (_, _, _, _out, _) = Mooncake.__value_and_pullback!!(
rule,
λ,
Mooncake.CoDual(pf, pf_grad_mem),
Mooncake.CoDual(dy_mem, dy_grad_mem),
Mooncake.CoDual(y, Mooncake.set_to_zero!!(y_grad_mem)),
Mooncake.CoDual(p, Mooncake.set_to_zero!!(p_grad_mem)),
Mooncake.zero_codual(t),
)
out .= _out
elseif sensealg.autojacvec isa EnzymeVJP
tmp3, tmp4, tmp6 = paramjac_config
tmp4 .= λ
Expand Down
17 changes: 17 additions & 0 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,23 @@ struct ReverseDiffVJP{compile} <: VJPChoice
ReverseDiffVJP(compile = false) = new{compile}()
end

"""
```julia
MooncakeVJP <: VJPChoice
```
Uses Mooncake.jl to compute the vector-Jacobian products.
Does not support GPUs (CuArrays).
## Constructor
```julia
MooncakeVJP()
```
"""
struct MooncakeVJP <: VJPChoice end

@inline convert_tspan(::ForwardDiffSensitivity{CS, CTS}) where {CS, CTS} = CTS
@inline convert_tspan(::Any) = nothing
@inline function alg_autodiff(alg::AbstractSensitivityAlgorithm{
Expand Down
15 changes: 15 additions & 0 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ _, easy_res14 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint())
_, easy_res15 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res16 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res142 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
Expand All @@ -158,6 +166,10 @@ _, easy_res146 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discret
sensealg = GaussAdjoint(checkpointing = true,
autojacvec = false),
checkpoints = sol.t[1:500:end])
_, easy_res147 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
adj_prob = ODEAdjointProblem(sol,
QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14,
autojacvec = SciMLSensitivity.ReverseDiffVJP()),
Expand Down Expand Up @@ -189,11 +201,14 @@ res, err = quadgk(integrand, 0.0, 10.0, atol = 1e-14, rtol = 1e-12)
@test isapprox(res, easy_res12, rtol = 1e-9)
@test isapprox(res, easy_res13, rtol = 1e-9)
@test isapprox(res, easy_res14, rtol = 1e-9)
@test isapprox(res, easy_res15, rtol = 1e-9)
@test isapprox(res, easy_res16, rtol = 1e-9)
@test isapprox(res, easy_res142, rtol = 1e-9)
@test isapprox(res, easy_res143, rtol = 1e-9)
@test isapprox(res, easy_res144, rtol = 1e-9)
@test isapprox(res, easy_res145, rtol = 1e-9)
@test isapprox(res, easy_res146, rtol = 1e-9)
@test isapprox(res, easy_res147, rtol = 1e-9)

println("OOP adjoint sensitivities ")

Expand Down

0 comments on commit 0e5a563

Please sign in to comment.