From 40b086a7359d9ad1fdc94e8129c505946e2fd614 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 26 Aug 2024 13:44:17 +0530 Subject: [PATCH 01/11] chore: use tunables in vec_pjac --- src/concrete_solve.jl | 2 +- src/gauss_adjoint.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index e044c8220..dc0c859bc 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -640,7 +640,7 @@ function DiffEqBase._concrete_solve_adjoint( du0 = reshape(du0, size(u0)) dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : - dp isa AbstractArray ? reshape(dp', size(p)) : dp + dp isa AbstractArray ? reshape(dp', size(tunables)) : dp if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..482556b89 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -484,8 +484,8 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) ReverseDiff.reverse_pass!(tape) copyto!(vec(out), ReverseDiff.deriv(tp)) elseif sensealg.autojacvec isa ZygoteVJP - _dy, back = Zygote.pullback(p) do p - vec(f(y, p, t)) + _dy, back = Zygote.pullback(tunables) do tunables + vec(f(y, repack(tunables), t)) end tmp = back(λ) if tmp[1] === nothing From b3b42c7e1e59c655ac5099edd6b382241e88d5e2 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Wed, 28 Aug 2024 17:13:17 +0530 Subject: [PATCH 02/11] chore: return tangent types for dp --- src/concrete_solve.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index dc0c859bc..758f0c985 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -642,12 +642,17 @@ function DiffEqBase._concrete_solve_adjoint( dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : dp isa AbstractArray ? reshape(dp', size(tunables)) : dp + _, repack_adjoint = Zygote.pullback(p) do p + t, _, _ = canonicalize(Tunable(), p) + t + end + if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(), NoTangent(), du0, dp, NoTangent(), + (NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) else - (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), + (NoTangent(), NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) end end From 6d89492b599708bd986845b32c02db4f218ac139 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 21 Oct 2024 15:48:45 +0530 Subject: [PATCH 03/11] chore: also implement repack_adjoint for forward mode --- src/concrete_solve.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 758f0c985..e9a52f9ce 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -827,7 +827,7 @@ function DiffEqBase._concrete_solve_adjoint( pparts = typeof(tunables[1:1])[] for j in 0:(num_chunks - 1) local chunk - if ((j + 1) * chunk_size) <= length(p) + if ((j + 1) * chunk_size) <= length(tunables) chunk = ((j * chunk_size + 1):((j + 1) * chunk_size)) pchunk = vec(tunables)[chunk] pdualpart = seed_duals(pchunk, prob.f, @@ -947,7 +947,7 @@ function DiffEqBase._concrete_solve_adjoint( end push!(pparts, vec(_dp)) end - SciMLStructures.replace(Tunable(), p, reduce(vcat, pparts)) + reduce(vcat, pparts) end else dp = nothing @@ -1116,12 +1116,17 @@ function DiffEqBase._concrete_solve_adjoint( end end + _, repack_adjoint = Zygote.pullback(p) do p + t, _, _ = canonicalize(Tunable(), p) + t + end + if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(), NoTangent(), unthunk(du0), unthunk(dp), NoTangent(), + (NoTangent(), NoTangent(), unthunk(du0), repack_adjoint(unthunk(dp))[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) else - (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), + (NoTangent(), NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) end end From 492587c6dde51644c498968f019da3d27b47f42d Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 21 Oct 2024 16:20:52 +0530 Subject: [PATCH 04/11] test: add docs example as test for SciMLStructures interface --- test/runtests.jl | 1 + test/scimlstructures_interface.jl | 78 +++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 test/scimlstructures_interface.jl diff --git a/test/runtests.jl b/test/runtests.jl index d42a4f6ae..a3ed6e3a5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ end @time @safetestset "Prob Kwargs" include("prob_kwargs.jl") @time @safetestset "DiscreteProblem Adjoints" include("discrete.jl") @time @safetestset "Time Type Mixing Adjoints" include("time_type_mixing.jl") + @time @safetestset "SciMLStructures Interface" include("scimlstructures_interface.jl") end end diff --git a/test/scimlstructures_interface.jl b/test/scimlstructures_interface.jl new file mode 100644 index 000000000..51395b7f7 --- /dev/null +++ b/test/scimlstructures_interface.jl @@ -0,0 +1,78 @@ +# taken from https://github.com/SciML/SciMLStructures.jl/pull/28 +using OrdinaryDiffEq, SciMLSensitivity, Zygote +using LinearAlgebra +import SciMLStructures as SS + +mutable struct SubproblemParameters{P, Q, R} + p::P # tunable + q::Q + r::R +end +mutable struct Parameters{P, C} + subparams::P + coeffs::C # tunable matrix +end +# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams) +# and `du[length(subparams)+1:end] .= coeffs * u` +function rhs!(du, u, p::Parameters, t) + for (i, subpars) in enumerate(p.subparams) + du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t + end + N = length(p.subparams) + mul!(view(du, (N+1):(length(du))), p.coeffs, u) + return nothing +end +u = sin.(0.1:0.1:1.0) +subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5] +p = Parameters(subparams, cos.([0.1i+0.33j for i in 1:5, j in 1:10])) +tspan = (0.0, 1.0) +prob = ODEProblem(rhs!, u, tspan, p) +solve(prob, Tsit5()) + +# Mark the struct as a SciMLStructure +SS.isscimlstructure(::Parameters) = true +# It is mutable +SS.ismutablescimlstructure(::Parameters) = true +# Only contains `Tunable` portion +# We could also add a `Constants` portion to contain the values that are +# not tunable. The implementation would be similar to this one. +SS.hasportion(::SS.Tunable, ::Parameters) = true +function SS.canonicalize(::SS.Tunable, p::Parameters) + # concatenate all tunable values into a single vector + buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs)) + # repack takes a new vector of the same length as `buffer`, and constructs + # a new `Parameters` object using the values from the new vector for tunables + # and retaining old values for other parameters. This is exactly what replace does, + # so we can use that instead. + repack = let p = p + function repack(newbuffer) + SS.replace(SS.Tunable(), p, newbuffer) + end + end + # the canonicalized vector, the repack function, and a boolean indicating + # whether the buffer aliases values in the parameter object (here, it doesn't) + return buffer, repack, false +end +function SS.replace(::SS.Tunable, p::Parameters, newbuffer) + N = length(p.subparams) + length(p.coeffs) + @assert length(newbuffer) == N + subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)] + coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs)) + return Parameters(subparams, coeffs) +end +function SS.replace!(::SS.Tunable, p::Parameters, newbuffer) + N = length(p.subparams) + length(p.coeffs) + @assert length(newbuffer) == N + for (subpar, val) in zip(p.subparams, newbuffer) + subpar.p = val + end + copyto!(coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer))) + return p +end + +Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunables + newp = SS.replace(SS.Tunable(), p, tunables) + newprob = remake(prob; p = newp) + sol = solve(newprob, Tsit5()) + return sum(sol.u[end]) +end From 158ec6845814ee9bf8f8c74b61c6779e94b6c2f4 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 21 Oct 2024 18:37:44 +0530 Subject: [PATCH 05/11] chore: handle null paramters in repack_adjoint --- src/concrete_solve.jl | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 4c504d3f3..8d03554c6 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -642,9 +642,13 @@ function DiffEqBase._concrete_solve_adjoint( dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : dp isa AbstractArray ? reshape(dp', size(tunables)) : dp - _, repack_adjoint = Zygote.pullback(p) do p - t, _, _ = canonicalize(Tunable(), p) - t + _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() + nothing, x -> (nothing, x) + else + Zygote.pullback(p) do p + t, _, _ = canonicalize(Tunable(), p) + t + end end if originator isa SciMLBase.TrackerOriginator || @@ -1139,9 +1143,13 @@ function DiffEqBase._concrete_solve_adjoint( end end - _, repack_adjoint = Zygote.pullback(p) do p - t, _, _ = canonicalize(Tunable(), p) - t + _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() + nothing, x -> (nothing, x) + else + Zygote.pullback(p) do p + t, _, _ = canonicalize(Tunable(), p) + t + end end if originator isa SciMLBase.TrackerOriginator || From 9d2e932c4864fc4bfdfb8a4836c29533a1625ce5 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 7 Nov 2024 16:45:26 +0530 Subject: [PATCH 06/11] chore: do not try to work with non - SciMLStructures structs --- src/concrete_solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 8d03554c6..b82c4973f 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -642,8 +642,8 @@ function DiffEqBase._concrete_solve_adjoint( dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : dp isa AbstractArray ? reshape(dp', size(tunables)) : dp - _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() - nothing, x -> (nothing, x) + _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() || !isscimlstructure(p) + nothing, x -> (x,) else Zygote.pullback(p) do p t, _, _ = canonicalize(Tunable(), p) From d7a3ea80cd63dfdab2fcd10c240f4174599ea6b2 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 8 Nov 2024 18:30:55 +0530 Subject: [PATCH 07/11] chore: formatting --- src/concrete_solve.jl | 12 +++-- test/scimlstructures_interface.jl | 84 ++++++++++++++++--------------- 2 files changed, 51 insertions(+), 45 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index b82c4973f..78fef2287 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -642,7 +642,8 @@ function DiffEqBase._concrete_solve_adjoint( dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : dp isa AbstractArray ? reshape(dp', size(tunables)) : dp - _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() || !isscimlstructure(p) + _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() || + !isscimlstructure(p) nothing, x -> (x,) else Zygote.pullback(p) do p @@ -656,7 +657,8 @@ function DiffEqBase._concrete_solve_adjoint( (NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) else - (NoTangent(), NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(), + (NoTangent(), NoTangent(), NoTangent(), + du0, repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) end end @@ -1154,10 +1156,12 @@ function DiffEqBase._concrete_solve_adjoint( if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(), NoTangent(), unthunk(du0), repack_adjoint(unthunk(dp))[1], NoTangent(), + (NoTangent(), NoTangent(), unthunk(du0), + repack_adjoint(unthunk(dp))[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) else - (NoTangent(), NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(), + (NoTangent(), NoTangent(), NoTangent(), + du0, repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) end end diff --git a/test/scimlstructures_interface.jl b/test/scimlstructures_interface.jl index 51395b7f7..0754d3980 100644 --- a/test/scimlstructures_interface.jl +++ b/test/scimlstructures_interface.jl @@ -4,27 +4,27 @@ using LinearAlgebra import SciMLStructures as SS mutable struct SubproblemParameters{P, Q, R} - p::P # tunable - q::Q - r::R + p::P # tunable + q::Q + r::R end mutable struct Parameters{P, C} - subparams::P - coeffs::C # tunable matrix + subparams::P + coeffs::C # tunable matrix end # the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams) # and `du[length(subparams)+1:end] .= coeffs * u` function rhs!(du, u, p::Parameters, t) - for (i, subpars) in enumerate(p.subparams) - du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t - end - N = length(p.subparams) - mul!(view(du, (N+1):(length(du))), p.coeffs, u) - return nothing + for (i, subpars) in enumerate(p.subparams) + du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t + end + N = length(p.subparams) + mul!(view(du, (N + 1):(length(du))), p.coeffs, u) + return nothing end u = sin.(0.1:0.1:1.0) subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5] -p = Parameters(subparams, cos.([0.1i+0.33j for i in 1:5, j in 1:10])) +p = Parameters(subparams, cos.([0.1i + 0.33j for i in 1:5, j in 1:10])) tspan = (0.0, 1.0) prob = ODEProblem(rhs!, u, tspan, p) solve(prob, Tsit5()) @@ -38,41 +38,43 @@ SS.ismutablescimlstructure(::Parameters) = true # not tunable. The implementation would be similar to this one. SS.hasportion(::SS.Tunable, ::Parameters) = true function SS.canonicalize(::SS.Tunable, p::Parameters) - # concatenate all tunable values into a single vector - buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs)) - # repack takes a new vector of the same length as `buffer`, and constructs - # a new `Parameters` object using the values from the new vector for tunables - # and retaining old values for other parameters. This is exactly what replace does, - # so we can use that instead. - repack = let p = p - function repack(newbuffer) - SS.replace(SS.Tunable(), p, newbuffer) + # concatenate all tunable values into a single vector + buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs)) + # repack takes a new vector of the same length as `buffer`, and constructs + # a new `Parameters` object using the values from the new vector for tunables + # and retaining old values for other parameters. This is exactly what replace does, + # so we can use that instead. + repack = let p = p + function repack(newbuffer) + SS.replace(SS.Tunable(), p, newbuffer) + end end - end - # the canonicalized vector, the repack function, and a boolean indicating - # whether the buffer aliases values in the parameter object (here, it doesn't) - return buffer, repack, false + # the canonicalized vector, the repack function, and a boolean indicating + # whether the buffer aliases values in the parameter object (here, it doesn't) + return buffer, repack, false end function SS.replace(::SS.Tunable, p::Parameters, newbuffer) - N = length(p.subparams) + length(p.coeffs) - @assert length(newbuffer) == N - subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)] - coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs)) - return Parameters(subparams, coeffs) + N = length(p.subparams) + length(p.coeffs) + @assert length(newbuffer) == N + subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) + for (i, subpar) in enumerate(p.subparams)] + coeffs = reshape( + view(newbuffer, (length(p.subparams) + 1):length(newbuffer)), size(p.coeffs)) + return Parameters(subparams, coeffs) end function SS.replace!(::SS.Tunable, p::Parameters, newbuffer) - N = length(p.subparams) + length(p.coeffs) - @assert length(newbuffer) == N - for (subpar, val) in zip(p.subparams, newbuffer) - subpar.p = val - end - copyto!(coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer))) - return p + N = length(p.subparams) + length(p.coeffs) + @assert length(newbuffer) == N + for (subpar, val) in zip(p.subparams, newbuffer) + subpar.p = val + end + copyto!(coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer))) + return p end Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunables - newp = SS.replace(SS.Tunable(), p, tunables) - newprob = remake(prob; p = newp) - sol = solve(newprob, Tsit5()) - return sum(sol.u[end]) + newp = SS.replace(SS.Tunable(), p, tunables) + newprob = remake(prob; p = newp) + sol = solve(newprob, Tsit5()) + return sum(sol.u[end]) end From a1c5607ba31f8989330c33cb56db738cfb942826 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 8 Nov 2024 18:43:47 +0530 Subject: [PATCH 08/11] chore: dont repack --- src/concrete_solve.jl | 5 +++-- src/gauss_adjoint.jl | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 78fef2287..52339f035 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1145,8 +1145,9 @@ function DiffEqBase._concrete_solve_adjoint( end end - _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() - nothing, x -> (nothing, x) + _, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() || + !isscimlstructure(p) + nothing, x -> (x,) else Zygote.pullback(p) do p t, _, _ = canonicalize(Tunable(), p) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 238d3f2ba..baa8df50b 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -483,7 +483,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) copyto!(vec(out), ReverseDiff.deriv(tp)) elseif sensealg.autojacvec isa ZygoteVJP _dy, back = Zygote.pullback(tunables) do tunables - vec(f(y, repack(tunables), t)) + vec(f(y, tunables, t)) end tmp = back(λ) if tmp[1] === nothing From 338496635242fa135a65810ce77e4ca7e65454d9 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 8 Nov 2024 19:55:06 +0530 Subject: [PATCH 09/11] chore: unthunk dp --- src/concrete_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 52339f035..d646ddbf6 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1162,7 +1162,7 @@ function DiffEqBase._concrete_solve_adjoint( ntuple(_ -> NoTangent(), length(args))...) else (NoTangent(), NoTangent(), NoTangent(), - du0, repack_adjoint(dp)[1], NoTangent(), + du0, repack_adjoint(unthunk(dp))[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) end end From 1ad4419ac2da1e92b09ff81876a3842f868fa24c Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 8 Nov 2024 20:33:20 +0530 Subject: [PATCH 10/11] chore: revert unthunk --- src/concrete_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index d646ddbf6..52339f035 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1162,7 +1162,7 @@ function DiffEqBase._concrete_solve_adjoint( ntuple(_ -> NoTangent(), length(args))...) else (NoTangent(), NoTangent(), NoTangent(), - du0, repack_adjoint(unthunk(dp))[1], NoTangent(), + du0, repack_adjoint(dp)[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) end end From 20d64fd454d4c3f285f663b49bdbcbe2f6bee3bc Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 8 Nov 2024 21:57:19 +0530 Subject: [PATCH 11/11] chore: solve SciMLStructures#30 --- src/concrete_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 52339f035..d646ddbf6 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1162,7 +1162,7 @@ function DiffEqBase._concrete_solve_adjoint( ntuple(_ -> NoTangent(), length(args))...) else (NoTangent(), NoTangent(), NoTangent(), - du0, repack_adjoint(dp)[1], NoTangent(), + du0, repack_adjoint(unthunk(dp))[1], NoTangent(), ntuple(_ -> NoTangent(), length(args))...) end end