diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index dd2af4e3a..d646ddbf6 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -639,15 +639,26 @@ function DiffEqBase._concrete_solve_adjoint( du0 = reshape(du0, size(u0)) - dp = p === nothing || p === SciMLBase.NullParameters() ? nothing : - dp isa AbstractArray ? reshape(dp', size(p)) : dp + dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : + dp isa AbstractArray ? reshape(dp', size(tunables)) : dp + + _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() || + !isscimlstructure(p) + nothing, x -> (x,) + else + Zygote.pullback(p) do p + t, _, _ = canonicalize(Tunable(), p) + t + end + end if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(), NoTangent(), du0, dp, NoTangent(), + (NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) else - (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), + (NoTangent(), NoTangent(), NoTangent(), + du0, repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) 
end end @@ -835,7 +846,7 @@ function DiffEqBase._concrete_solve_adjoint( pparts = typeof(tunables[1:1])[] for j in 0:(num_chunks - 1) local chunk - if ((j + 1) * chunk_size) <= length(p) + if ((j + 1) * chunk_size) <= length(tunables) chunk = ((j * chunk_size + 1):((j + 1) * chunk_size)) pchunk = vec(tunables)[chunk] pdualpart = seed_duals(pchunk, prob.f, @@ -957,7 +968,7 @@ function DiffEqBase._concrete_solve_adjoint( end push!(pparts, vec(_dp)) end - SciMLStructures.replace(Tunable(), p, reduce(vcat, pparts)) + reduce(vcat, pparts) end else dp = nothing @@ -1134,12 +1145,24 @@ function DiffEqBase._concrete_solve_adjoint( end end + _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() || + !isscimlstructure(p) + nothing, x -> (x,) + else + Zygote.pullback(p) do p + t, _, _ = canonicalize(Tunable(), p) + t + end + end + if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(), NoTangent(), unthunk(du0), unthunk(dp), NoTangent(), + (NoTangent(), NoTangent(), unthunk(du0), + repack_adjoint(unthunk(dp))[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) else - (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), + (NoTangent(), NoTangent(), NoTangent(), + du0, repack_adjoint(unthunk(dp))[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) 
end end diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 74f69250d..baa8df50b 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -482,8 +482,8 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) ReverseDiff.reverse_pass!(tape) copyto!(vec(out), ReverseDiff.deriv(tp)) elseif sensealg.autojacvec isa ZygoteVJP - _dy, back = Zygote.pullback(p) do p - vec(f(y, p, t)) + _dy, back = Zygote.pullback(tunables) do tunables + vec(f(y, tunables, t)) end tmp = back(λ) if tmp[1] === nothing diff --git a/test/runtests.jl b/test/runtests.jl index 0a97ff62d..6fa3db51b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ end @time @safetestset "Prob Kwargs" include("prob_kwargs.jl") @time @safetestset "DiscreteProblem Adjoints" include("discrete.jl") @time @safetestset "Time Type Mixing Adjoints" include("time_type_mixing.jl") + @time @safetestset "SciMLStructures Interface" include("scimlstructures_interface.jl") end end diff --git a/test/scimlstructures_interface.jl b/test/scimlstructures_interface.jl new file mode 100644 index 000000000..0754d3980 --- /dev/null +++ b/test/scimlstructures_interface.jl @@ -0,0 +1,80 @@ +# taken from https://github.com/SciML/SciMLStructures.jl/pull/28 +using OrdinaryDiffEq, SciMLSensitivity, Zygote +using LinearAlgebra +import SciMLStructures as SS + +mutable struct SubproblemParameters{P, Q, R} + p::P # tunable + q::Q + r::R +end +mutable struct Parameters{P, C} + subparams::P + coeffs::C # tunable matrix +end +# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams) +# and `du[length(subparams)+1:end] .= coeffs * u` +function rhs!(du, u, p::Parameters, t) + for (i, subpars) in enumerate(p.subparams) + du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t + end + N = length(p.subparams) + mul!(view(du, (N + 1):(length(du))), p.coeffs, u) + return nothing +end +u = sin.(0.1:0.1:1.0) +subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5] +p = 
Parameters(subparams, cos.([0.1i + 0.33j for i in 1:5, j in 1:10]))
+tspan = (0.0, 1.0)
+prob = ODEProblem(rhs!, u, tspan, p)
+# Sanity check: the problem solves with the plain (un-canonicalized) struct.
+solve(prob, Tsit5())
+
+# Mark the struct as a SciMLStructure
+SS.isscimlstructure(::Parameters) = true
+# It is mutable
+SS.ismutablescimlstructure(::Parameters) = true
+# Only contains `Tunable` portion
+# We could also add a `Constants` portion to contain the values that are
+# not tunable. The implementation would be similar to this one.
+SS.hasportion(::SS.Tunable, ::Parameters) = true
+
+# Flatten the tunable values of `p` (every `subparams[i].p`, followed by `vec(coeffs)`)
+# into one vector. Returns `(buffer, repack, aliases)`.
+function SS.canonicalize(::SS.Tunable, p::Parameters)
+    # concatenate all tunable values into a single vector
+    buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
+    # repack takes a new vector of the same length as `buffer`, and constructs
+    # a new `Parameters` object using the values from the new vector for tunables
+    # and retaining old values for other parameters. This is exactly what replace does,
+    # so we can use that instead.
+    repack = let p = p
+        function repack(newbuffer)
+            SS.replace(SS.Tunable(), p, newbuffer)
+        end
+    end
+    # the canonicalized vector, the repack function, and a boolean indicating
+    # whether the buffer aliases values in the parameter object (here, it doesn't)
+    return buffer, repack, false
+end
+
+# Build a new `Parameters` whose tunables come from `newbuffer` (same layout as the
+# buffer produced by `canonicalize`); `q` and `r` are carried over from `p`.
+function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
+    N = length(p.subparams) + length(p.coeffs)
+    @assert length(newbuffer) == N
+    subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r)
+                 for (i, subpar) in enumerate(p.subparams)]
+    # NOTE: `coeffs` is a reshaped view, so the result shares memory with `newbuffer`.
+    coeffs = reshape(
+        view(newbuffer, (length(p.subparams) + 1):length(newbuffer)), size(p.coeffs))
+    return Parameters(subparams, coeffs)
+end
+
+# Overwrite the tunables of `p` in place from `newbuffer` and return `p`.
+function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
+    N = length(p.subparams) + length(p.coeffs)
+    @assert length(newbuffer) == N
+    # `zip` stops after `length(p.subparams)` entries, i.e. the scalar tunables.
+    for (subpar, val) in zip(p.subparams, newbuffer)
+        subpar.p = val
+    end
+    # The remaining entries fill the coefficient matrix. The destination must be the
+    # field `p.coeffs`; a bare `coeffs` is not defined inside this method.
+    copyto!(p.coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
+    return p
+end
+
+# Gradient of the summed final state with respect to the flat tunable vector; every
+# candidate vector is repacked into a fresh `Parameters` before the problem is solved.
+Zygote.gradient(0.1 * ones(length(first(SS.canonicalize(SS.Tunable(), p))))) do theta
+    repacked = remake(prob; p = SS.replace(SS.Tunable(), p, theta))
+    trajectory = solve(repacked, Tsit5())
+    return sum(trajectory.u[end])
+end