Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid needing adjoints of SciMLStructures' constructor #1135

Merged
merged 12 commits into from
Nov 8, 2024
39 changes: 31 additions & 8 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -639,15 +639,26 @@ function DiffEqBase._concrete_solve_adjoint(

du0 = reshape(du0, size(u0))

dp = p === nothing || p === SciMLBase.NullParameters() ? nothing :
dp isa AbstractArray ? reshape(dp', size(p)) : dp
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
dp isa AbstractArray ? reshape(dp', size(tunables)) : dp

_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
!isscimlstructure(p)
nothing, x -> (x,)
else
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
end
end

if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
(NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
(NoTangent(), NoTangent(), NoTangent(),
du0, repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
Expand Down Expand Up @@ -835,7 +846,7 @@ function DiffEqBase._concrete_solve_adjoint(
pparts = typeof(tunables[1:1])[]
for j in 0:(num_chunks - 1)
local chunk
if ((j + 1) * chunk_size) <= length(p)
if ((j + 1) * chunk_size) <= length(tunables)
chunk = ((j * chunk_size + 1):((j + 1) * chunk_size))
pchunk = vec(tunables)[chunk]
pdualpart = seed_duals(pchunk, prob.f,
Expand Down Expand Up @@ -957,7 +968,7 @@ function DiffEqBase._concrete_solve_adjoint(
end
push!(pparts, vec(_dp))
end
SciMLStructures.replace(Tunable(), p, reduce(vcat, pparts))
reduce(vcat, pparts)
end
else
dp = nothing
Expand Down Expand Up @@ -1134,12 +1145,24 @@ function DiffEqBase._concrete_solve_adjoint(
end
end

_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
!isscimlstructure(p)
nothing, x -> (x,)
else
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
end
end

if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), unthunk(du0), unthunk(dp), NoTangent(),
(NoTangent(), NoTangent(), unthunk(du0),
repack_adjoint(unthunk(dp))[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
(NoTangent(), NoTangent(), NoTangent(),
du0, repack_adjoint(unthunk(dp))[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,8 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
ReverseDiff.reverse_pass!(tape)
copyto!(vec(out), ReverseDiff.deriv(tp))
elseif sensealg.autojacvec isa ZygoteVJP
_dy, back = Zygote.pullback(p) do p
vec(f(y, p, t))
_dy, back = Zygote.pullback(tunables) do tunables
vec(f(y, tunables, t))
end
tmp = back(λ)
if tmp[1] === nothing
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ end
@time @safetestset "Prob Kwargs" include("prob_kwargs.jl")
@time @safetestset "DiscreteProblem Adjoints" include("discrete.jl")
@time @safetestset "Time Type Mixing Adjoints" include("time_type_mixing.jl")
@time @safetestset "SciMLStructures Interface" include("scimlstructures_interface.jl")
end
end

Expand Down
80 changes: 80 additions & 0 deletions test/scimlstructures_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# taken from https://github.com/SciML/SciMLStructures.jl/pull/28
using OrdinaryDiffEq, SciMLSensitivity, Zygote
using LinearAlgebra
import SciMLStructures as SS

# Coefficients of one scalar subproblem; `rhs!` below reads them as
# `du[i] = p * u[i]^2 + q * u[i] + r * t`.
# Declared mutable so that `SS.replace!` can overwrite `p` in place.
mutable struct SubproblemParameters{P, Q, R}
# quadratic coefficient; the only field exposed as tunable by `SS.canonicalize`
p::P # tunable
# linear coefficient; carried over unchanged by `SS.replace`
q::Q
# coefficient of the time term; carried over unchanged by `SS.replace`
r::R
end
# Top-level parameter object handed to the ODE problem: a collection of
# per-equation subproblem parameters plus a coefficient matrix.
mutable struct Parameters{P, C}
# iterable of `SubproblemParameters`, one per leading state component (see `rhs!`)
subparams::P
# matrix applied to the full state in `rhs!`; all of its entries are tunable
coeffs::C # tunable matrix
end
# In-place ODE right-hand side.
# With N = length(p.subparams):
#   du[i]         = p_i * u[i]^2 + q_i * u[i] + r_i * t    for i in 1:N
#   du[(N+1):end] = p.coeffs * u
function rhs!(du, u, p::Parameters, t)
    # Scalar subproblems fill the leading entries of `du`.
    for (k, sp) in enumerate(p.subparams)
        du[k] = sp.p * u[k]^2 + sp.q * u[k] + sp.r * t
    end
    # The remaining entries receive the matrix-vector product, written in place
    # through a view so no temporary is allocated.
    nsub = length(p.subparams)
    tail = view(du, (nsub + 1):(length(du)))
    mul!(tail, p.coeffs, u)
    return nothing
end
# Initial state: sin evaluated on the ten grid points 0.1, 0.2, ..., 1.0.
u = [sin(x) for x in 0.1:0.1:1.0]
# Five scalar subproblems whose coefficients scale with the index.
subparams = map(i -> SubproblemParameters(0.1i, 0.2i, 0.3i), 1:5)
# The 5x10 coefficient matrix drives the last five state components.
p = Parameters(subparams, [cos(0.1i + 0.33j) for i in 1:5, j in 1:10])
tspan = (0.0, 1.0)
prob = ODEProblem(rhs!, u, tspan, p)
# Plain forward solve, run before the SciMLStructures methods are defined.
solve(prob, Tsit5())

# Opt `Parameters` into the SciMLStructures interface.
function SS.isscimlstructure(::Parameters)
    return true
end

# `Parameters` is declared mutable, so advertise in-place updates (`replace!`).
function SS.ismutablescimlstructure(::Parameters)
    return true
end

# `Tunable` is the only portion implemented here. A `Constants` portion holding
# the non-tunable values could be added along the same lines.
function SS.hasportion(::SS.Tunable, ::Parameters)
    return true
end
# Flatten the tunable values of `p` into one vector.
#
# Returns a 3-tuple:
#   1. the tunables: every `subpar.p` followed by the entries of `vec(p.coeffs)`;
#   2. a function mapping a new vector of the same length to a fresh `Parameters`
#      (tunables taken from the vector, everything else carried over from `p`);
#   3. `false`, since the vector is freshly built by `vcat` and does not alias `p`.
function SS.canonicalize(::SS.Tunable, p::Parameters)
    scalar_tunables = [subpar.p for subpar in p.subparams]
    tunables = vcat(scalar_tunables, vec(p.coeffs))
    # Rebuilding from a vector is exactly what `SS.replace` does, so delegate to it.
    rebuild = newvalues -> SS.replace(SS.Tunable(), p, newvalues)
    return tunables, rebuild, false
end
# Build a new `Parameters` whose tunables come from `newbuffer`, laid out as in
# `SS.canonicalize`: one `p` value per subproblem first, then the coefficient
# matrix entries. The `q` and `r` fields are copied from the existing
# subproblems; `p` itself is not modified.
function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
    nsub = length(p.subparams)
    N = nsub + length(p.coeffs)
    @assert length(newbuffer) == N
    new_subparams = [SubproblemParameters(newbuffer[k], old.q, old.r)
                     for (k, old) in enumerate(p.subparams)]
    # The trailing entries are reshaped to the matrix size through a view,
    # so they are not copied.
    coeff_values = view(newbuffer, (nsub + 1):length(newbuffer))
    new_coeffs = reshape(coeff_values, size(p.coeffs))
    return Parameters(new_subparams, new_coeffs)
end
# In-place counterpart of `SS.replace`: overwrite the tunable values stored in
# `p` with the contents of `newbuffer` and return `p`.
#
# `newbuffer` uses the same layout as `SS.canonicalize`: one `p` value per
# subproblem first, then the entries of the coefficient matrix.
function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
    N = length(p.subparams) + length(p.coeffs)
    @assert length(newbuffer) == N
    # `zip` stops once the subproblems are exhausted, so only the leading
    # entries of `newbuffer` are consumed here.
    for (subpar, val) in zip(p.subparams, newbuffer)
        subpar.p = val
    end
    # Write into the matrix owned by `p`. The destination used to be a bare
    # `coeffs`, a name that is not defined in this function, so the call failed
    # with an `UndefVarError`.
    copyto!(p.coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
    return p
end

# Loss used by the smoke test below: rebuild the parameters from a flat tunable
# vector, solve the ODE, and sum the final state.
function final_state_sum(tunables)
    newp = SS.replace(SS.Tunable(), p, tunables)
    newprob = remake(prob; p = newp)
    sol = solve(newprob, Tsit5())
    return sum(sol.u[end])
end

# Differentiate with respect to the tunables, starting from a vector of 0.1s with
# the same length as the canonicalized tunable buffer of `p`.
Zygote.gradient(
    final_state_sum, fill(0.1, length(first(SS.canonicalize(SS.Tunable(), p)))))
Loading