Skip to content

Commit

Permalink
Merge pull request #961 from SciML/reversediff_sol
Browse files Browse the repository at this point in the history
Make ReverseDiffAdjoint return ODE solutions
  • Loading branch information
ChrisRackauckas authored Dec 27, 2023
2 parents bd6e048 + 329939a commit 709a1d5
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- SDE2
- SDE3
version:
- '1'
- '1.9'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 2 additions & 2 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1289,7 +1289,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscre
p isa DiffEqBase.NullParameters || ReverseDiff.value!(tp, p)
ReverseDiff.forward_pass!(tape)
function reversediff_adjoint_backpass(ybar)
_ybar = if ybar isa VectorOfArray
_ybar = if ybar isa AbstractVectorOfArray
Array(ybar)
elseif eltype(ybar) <: AbstractArray
Array(VectorOfArray(ybar))
Expand All @@ -1309,7 +1309,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscre
ntuple(_ -> NoTangent(), length(args))...)
end
end
Array(VectorOfArray(u)), reversediff_adjoint_backpass
sol, reversediff_adjoint_backpass
end

function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem, alg,
Expand Down
2 changes: 1 addition & 1 deletion test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ _, easy_res6 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discre
reltol = 1e-14,
sensealg = InterpolatingAdjoint(checkpointing = true),
checkpoints = soloop_nodense.t[1:5:end])
@test_broken _, easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
_, easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg, abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(checkpointing = true,
Expand Down
4 changes: 2 additions & 2 deletions test/forward_prob_kwargs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ end
res4 = Zygote.gradient(cost, p[1])[1] # (7.720368430265481,)

@test res res2
@test res res3
@test res res4
@test res2 res3
@test res2 res4
19 changes: 19 additions & 0 deletions test/reversediff_output_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using OrdinaryDiffEq, SciMLSensitivity, Zygote, Test

function lotka_volterra(u, p, t)
du1 = p[1]*u[1] - p[2]*u[1]*u[2]
du2 = -p[3]*u[2] + p[4]*u[1]*u[2]
return [du1, du2]
end
u0 = [1.0,1.0]
tspan = (0.0,10.0)
p = [1.5,1.0,3.0,1.0]
prob = ODEProblem(lotka_volterra,u0,tspan,p)

loss(u0; kwargs...) = solve(remake(prob, u0=u0), Tsit5(); reltol=1e-10, abstol=1e-10, kwargs...).u |> last |> sum

dp1 = Zygote.gradient(loss, u0)[1]
dp2 = Zygote.gradient(u0 -> loss(u0; sensealg=TrackerAdjoint()), u0)[1]
dp3 = Zygote.gradient(u0 -> loss(u0; sensealg=ReverseDiffAdjoint()), u0)[1]
@test dp1 dp2
@test dp1 dp3
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ end
@time @safetestset "Complex Adjoints" begin
include("complex_adjoints.jl")
end
@time @safetestset "ReverseDiffAdjoint Output Type" begin
include("reversediff_output_types.jl")
end
@time @safetestset "Forward Remake" begin
include("forward_remake.jl")
end
Expand Down

0 comments on commit 709a1d5

Please sign in to comment.