Skip to content

Commit

Permalink
Merge pull request #1152 from willtebbutt/wct/mooncake-inside-problems
Browse files Browse the repository at this point in the history
Mooncake Inside Problems
  • Loading branch information
ChrisRackauckas authored Dec 5, 2024
2 parents 99156d1 + 6ac50d1 commit 78bcac6
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 13 deletions.
12 changes: 10 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLSensitivity"
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
version = "7.71.2"
version = "7.72.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -42,6 +42,12 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
SciMLSensitivityMooncakeExt = "Mooncake"

[compat]
ADTypes = "1.9"
Accessors = "0.1.36"
Expand Down Expand Up @@ -71,6 +77,7 @@ LinearSolve = "2"
Lux = "1"
Markdown = "1.10"
ModelingToolkit = "9.42"
Mooncake = "0.4.52"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1"
Optimization = "4"
Expand Down Expand Up @@ -110,6 +117,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Expand All @@ -123,4 +131,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
1 change: 1 addition & 0 deletions docs/src/manual/differential_equation_sensitivities.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ ZygoteVJP
EnzymeVJP
TrackerVJP
ReverseDiffVJP
MooncakeVJP
```

## More Details on Sensitivity Algorithm Choices
Expand Down
22 changes: 22 additions & 0 deletions ext/SciMLSensitivityMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module SciMLSensitivityMooncakeExt

using SciMLSensitivity, Mooncake
import SciMLSensitivity: get_paramjac_config, mooncake_run_ad, MooncakeVJP, MooncakeLoaded

# Build the reusable state for Mooncake-based vector-Jacobian products.
#
# Two work buffers shaped like `y` are allocated (one handed to `pf` as its output
# argument, one that later holds the cotangent seed) and a Mooncake pullback cache is
# prepared for the call `pf(out_buffer, y, p, _t)`.
#
# Returns the tuple `(cache, pf, seed_buffer, out_buffer)`, which is the layout that
# `mooncake_run_ad` below destructures. The argument `f` is not used by this method.
function get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t)
    out_buffer = zero(y)
    seed_buffer = zero(y)
    pullback_cache = Mooncake.prepare_pullback_cache(pf, out_buffer, y, p, _t)
    return (pullback_cache, pf, seed_buffer, out_buffer)
end

# Run one reverse pass through `pf` at `(y, p, t)` seeded with `λ`.
#
# `paramjac_config` must be the tuple produced by `get_paramjac_config` above.
# Returns `(dy, y_grad, p_grad)`, where `dy` is the first element of the result of
# `Mooncake.value_and_pullback!!` and the two gradients are read from slots 3 and 4 of
# the cache's `tangents` field.
# NOTE(review): slots 3 and 4 presumably correspond to `y` and `p` in the argument list
# `(pf, out_buffer, y, p, t)` -- confirm against the Mooncake cache layout.
function mooncake_run_ad(paramjac_config::Tuple, y, p, t, λ)
    pullback_cache, pf, seed_buffer, out_buffer = paramjac_config
    # Copy the seed into the preallocated buffer before it is handed to Mooncake.
    seed_buffer .= λ
    result = Mooncake.value_and_pullback!!(
        pullback_cache, seed_buffer, pf, out_buffer, y, p, t)
    dy = first(result)
    tangents = pullback_cache.tangents
    return dy, tangents[3], tangents[4]
end

end
47 changes: 47 additions & 0 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
paramjac_config = get_paramjac_config(autojacvec, p, f, y, _p, _t; numindvar, alg)
pf = get_pf(autojacvec; _f = unwrappedf, isinplace = isinplace, isRODE = isRODE)
paramjac_config = (paramjac_config..., Enzyme.make_zero(pf))
elseif autojacvec isa MooncakeVJP
pf = get_pf(autojacvec, prob, unwrappedf)
paramjac_config = get_paramjac_config(MooncakeLoaded(), autojacvec, pf, p, f, y, _t)
elseif SciMLBase.has_paramjac(f) || quad || !(autojacvec isa Bool) ||
autojacvec isa EnzymeVJP
paramjac_config = nothing
Expand Down Expand Up @@ -460,6 +463,15 @@ function get_paramjac_config(autojacvec::EnzymeVJP, p, f, y, _p, _t; numindvar,
return paramjac_config
end

# Marker type passed as the first argument of the Mooncake `get_paramjac_config` methods.
# The SciMLSensitivityMooncakeExt extension defines a method that dispatches on
# `MooncakeLoaded`; the `::Any` fallback defined in this file raises an error that asks
# the user to load Mooncake.
struct MooncakeLoaded end

# Fallback for `MooncakeVJP`: always throws, telling the user to install and load
# Mooncake. None of the arguments are inspected.
function get_paramjac_config(::Any, ::MooncakeVJP, pf, p, f, y, _t)
    return error(
        "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do " *
        "`using Mooncake` to use this functionality",
    )
end

function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing,
isRODE = nothing)
nothing
Expand Down Expand Up @@ -492,6 +504,41 @@ function get_pf(autojacvec::EnzymeVJP; _f, isinplace, isRODE)
end
end

# Wrap `_f` in a closure with a uniform calling convention for the Mooncake VJP path:
# the closure takes the output buffer `out` first, fills it, and returns it.
#
# - In-place problems: the closure forwards to `f(out, u, _p, t[, W])`.
# - Out-of-place problems: the closure broadcasts `f(u, _p, t[, W])` into `out`.
#
# The trailing `W` argument is only part of the closure when `prob` is a `RODEProblem`.
function get_pf(::MooncakeVJP, prob, _f)
    inplace = DiffEqBase.isinplace(prob)
    rode = prob isa RODEProblem
    return let f = _f
        if rode
            if inplace
                (out, u, _p, t, W) -> begin
                    f(out, u, _p, t, W)
                    out
                end
            else
                (out, u, _p, t, W) -> begin
                    out .= f(u, _p, t, W)
                    out
                end
            end
        elseif inplace
            (out, u, _p, t) -> begin
                f(out, u, _p, t)
                out
            end
        else
            (out, u, _p, t) -> begin
                out .= f(u, _p, t)
                out
            end
        end
    end
end

# Generic fallback for `mooncake_run_ad`: always throws, telling the user to install and
# load Mooncake. None of the arguments are inspected.
function mooncake_run_ad(paramjac_config, y, p, t, λ)
    return error(
        "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do " *
        "`using Mooncake` to use this functionality",
    )
end

# Return the problem attached to a sensitivity function: `S.prob` when `S` is an
# `ODEBacksolveSensitivityFunction`, otherwise the problem stored on its solution
# (`S.sol.prob`).
function getprob(S::SensitivityFunction)
    if S isa ODEBacksolveSensitivityFunction
        return S.prob
    else
        return S.sol.prob
    end
end
Expand Down
8 changes: 8 additions & 0 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,14 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
return
end

# Vector-Jacobian product for the `MooncakeVJP` choice.
#
# Calls `mooncake_run_ad` with the configuration stored in `S.diffcache.paramjac_config`
# and copies each of its three results into the matching output buffer:
#   - `dy`    receives the first result (`_dy`),
#   - `dλ`    receives the gradient with respect to `y`,
#   - `dgrad` receives the gradient with respect to `p`.
# Any output buffer passed as `nothing` is skipped. `W` is accepted but not used by this
# method. Returns `nothing`.
function _vecjacobian!(dλ, y, λ, p, t, S::SensitivityFunction, ::MooncakeVJP, dgrad, dy, W)
    _dy, y_grad, p_grad = mooncake_run_ad(S.diffcache.paramjac_config, y, p, t, λ)
    dy !== nothing && recursive_copyto!(dy, _dy)
    # Guard on `dλ` itself, mirroring the `dy` and `dgrad` guards.
    dλ !== nothing && recursive_copyto!(dλ, y_grad)
    dgrad !== nothing && recursive_copyto!(dgrad, p_grad)
    return
end

function jacNoise!(λ, y, p, t, S::SensitivityFunction;
dgrad = nothing, dλ = nothing, dy = nothing)
_jacNoise!(λ, y, p, t, S, S.sensealg.autojacvec, dgrad, dλ, dy)
Expand Down
8 changes: 8 additions & 0 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,11 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
end
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
pJ = nothing
elseif sensealg.autojacvec isa MooncakeVJP
pf = get_pf(sensealg.autojacvec, prob, f)
paramjac_config = get_paramjac_config(
MooncakeLoaded(), sensealg.autojacvec, pf, p, f, y, tspan[2])
pJ = nothing
elseif isautojacvec # Zygote
paramjac_config = nothing
pf = nothing
Expand Down Expand Up @@ -500,6 +505,9 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
elseif sensealg.autojacvec isa MooncakeVJP
_, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ)
out .= p_grad
else
error("autojacvec choice $(sensealg.autojacvec) is not supported by GaussAdjoint")
end
Expand Down
8 changes: 8 additions & 0 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,11 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
end
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
pJ = nothing
elseif sensealg.autojacvec isa MooncakeVJP
pf = get_pf(sensealg.autojacvec, prob, f)
paramjac_config = get_paramjac_config(
MooncakeLoaded(), sensealg.autojacvec, pf, p, f, y, tspan[2])
pJ = nothing
elseif isautojacvec # Zygote
paramjac_config = nothing
pf = nothing
Expand Down Expand Up @@ -288,6 +293,9 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
else
out[:] .= vec(tmp[1])
end
elseif sensealg.autojacvec isa MooncakeVJP
_, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ)
out .= p_grad
elseif sensealg.autojacvec isa EnzymeVJP
tmp3, tmp4, tmp6 = paramjac_config
tmp4 .= λ
Expand Down
17 changes: 17 additions & 0 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,23 @@ struct ReverseDiffVJP{compile} <: VJPChoice
ReverseDiffVJP(compile = false) = new{compile}()
end

"""
```julia
MooncakeVJP <: VJPChoice
```

Uses Mooncake.jl to compute the vector-Jacobian products.
Does not support GPUs (CuArrays).

Mooncake.jl must be installed and loaded (`using Mooncake`) before this choice is used;
otherwise an error is raised when the VJP configuration is built.

## Constructor

```julia
MooncakeVJP()
```
"""
struct MooncakeVJP <: VJPChoice end

@inline convert_tspan(::ForwardDiffSensitivity{CS, CTS}) where {CS, CTS} = CTS
@inline convert_tspan(::Any) = nothing
@inline function alg_autodiff(alg::AbstractSensitivityAlgorithm{
Expand Down
70 changes: 59 additions & 11 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ _, easy_res14 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint())
# Mooncake-backed VJP on `solb`: first with InterpolatingAdjoint, then with
# QuadratureAdjoint.
_, easy_res15 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res16 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res142 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
Expand All @@ -158,6 +166,10 @@ _, easy_res146 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discret
sensealg = GaussAdjoint(checkpointing = true,
autojacvec = false),
checkpoints = sol.t[1:500:end])
# Mooncake-backed VJP on `solb` with GaussAdjoint.
_, easy_res147 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
adj_prob = ODEAdjointProblem(sol,
QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14,
autojacvec = SciMLSensitivity.ReverseDiffVJP()),
Expand Down Expand Up @@ -189,11 +201,14 @@ res, err = quadgk(integrand, 0.0, 10.0, atol = 1e-14, rtol = 1e-12)
@test isapprox(res, easy_res12, rtol = 1e-9)
@test isapprox(res, easy_res13, rtol = 1e-9)
@test isapprox(res, easy_res14, rtol = 1e-9)
@test isapprox(res, easy_res15, rtol = 1e-9)
@test isapprox(res, easy_res16, rtol = 1e-9)
@test isapprox(res, easy_res142, rtol = 1e-9)
@test isapprox(res, easy_res143, rtol = 1e-9)
@test isapprox(res, easy_res144, rtol = 1e-9)
@test isapprox(res, easy_res145, rtol = 1e-9)
@test isapprox(res, easy_res146, rtol = 1e-9)
@test isapprox(res, easy_res147, rtol = 1e-9)

println("OOP adjoint sensitivities ")

Expand All @@ -203,14 +218,11 @@ _, easy_res = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(abstol = 1e-14,
reltol = 1e-14))
sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14))
_, easy_res22 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = false,
abstol = 1e-14,
reltol = 1e-14))
sensealg = QuadratureAdjoint(autojacvec = false, abstol = 1e-14, reltol = 1e-14))
_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
Expand All @@ -224,17 +236,15 @@ _, easy_res3 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
@test easy_res32 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa
AbstractArray
sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa AbstractArray
_, easy_res4 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint())
@test easy_res42 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint(autojacvec = false))[1] isa
AbstractArray
sensealg = BacksolveAdjoint(autojacvec = false))[1] isa AbstractArray
_, easy_res5 = adjoint_sensitivities(soloop,
Kvaerno5(nlsolve = NLAnderson(), smooth_est = false),
t = t, dgdu_discrete = dg, abstol = 1e-12,
Expand All @@ -248,8 +258,7 @@ _, easy_res6 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discre
_, easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg, abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(checkpointing = true,
autojacvec = false),
sensealg = InterpolatingAdjoint(checkpointing = true, autojacvec = false),
checkpoints = soloop_nodense.t[1:5:end])

_, easy_res8 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discrete = dg,
Expand Down Expand Up @@ -289,6 +298,39 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc
reltol = 1e-14,
sensealg = GaussAdjoint(checkpointing = true),
checkpoints = soloop_nodense.t[1:5:end])

# Mooncake-backed VJP on `soloop` with QuadratureAdjoint, InterpolatingAdjoint and
# BacksolveAdjoint.
_, easy_res2_mc_quad = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(
abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res2_mc_interp = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res2_mc_back = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
# The same three adjoint methods on `soloop_nodense`; the InterpolatingAdjoint run
# enables checkpointing and passes every fifth saved time point as a checkpoint.
_, easy_res6_mc_quad = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(
abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res6_mc_interp = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(checkpointing = true,
autojacvec = SciMLSensitivity.MooncakeVJP()),
checkpoints = soloop_nodense.t[1:5:end])
_, easy_res6_mc_back = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))

@test isapprox(res, easy_res, rtol = 1e-10)
@test isapprox(res, easy_res2, rtol = 1e-10)
@test isapprox(res, easy_res22, rtol = 1e-10)
Expand All @@ -309,6 +351,12 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc
@test isapprox(res, easy_res12, rtol = 1e-9)
@test isapprox(res, easy_res122, rtol = 1e-9)
@test isapprox(res, easy_res123, rtol = 1e-4)
# Mooncake-backed sensitivities are compared against the reference value `res`. The
# QuadratureAdjoint run on `soloop_nodense` (`easy_res6_mc_quad`) is checked with a
# looser tolerance (rtol = 1e-4) than the others (rtol = 1e-9).
@test isapprox(res, easy_res2_mc_quad, rtol = 1e-9)
@test isapprox(res, easy_res2_mc_interp, rtol = 1e-9)
@test isapprox(res, easy_res2_mc_back, rtol = 1e-9)
@test isapprox(res, easy_res6_mc_quad, rtol = 1e-4)
@test isapprox(res, easy_res6_mc_interp, rtol = 1e-9)
@test isapprox(res, easy_res6_mc_back, rtol = 1e-9)

println("Calculate adjoint sensitivities ")

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SciMLSensitivity, SafeTestsets
using Test, Pkg
import Mooncake

const GROUP = get(ENV, "GROUP", "All")

Expand Down

1 comment on commit 78bcac6

@willtebbutt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisRackauckas should this be registered, or are you waiting on something else before releasing?

Please sign in to comment.