diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c6927002a..e1583ef7c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,7 +26,7 @@ jobs: - SDE2 - SDE3 version: - - '1' + - '1.9' steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 25609c50c..06b5dbb93 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1289,7 +1289,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscre p isa DiffEqBase.NullParameters || ReverseDiff.value!(tp, p) ReverseDiff.forward_pass!(tape) function reversediff_adjoint_backpass(ybar) - _ybar = if ybar isa VectorOfArray + _ybar = if ybar isa AbstractVectorOfArray Array(ybar) elseif eltype(ybar) <: AbstractArray Array(VectorOfArray(ybar)) @@ -1309,7 +1309,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscre ntuple(_ -> NoTangent(), length(args))...) end end - Array(VectorOfArray(u)), reversediff_adjoint_backpass + sol, reversediff_adjoint_backpass end function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem, alg, diff --git a/test/adjoint.jl b/test/adjoint.jl index 6fb54a5fa..642958796 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -245,7 +245,7 @@ _, easy_res6 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discre reltol = 1e-14, sensealg = InterpolatingAdjoint(checkpointing = true), checkpoints = soloop_nodense.t[1:5:end]) -@test_broken _, easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, +_, easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discrete = dg, abstol = 1e-14, reltol = 1e-14, sensealg = InterpolatingAdjoint(checkpointing = true, diff --git a/test/forward_prob_kwargs.jl b/test/forward_prob_kwargs.jl index 68a671681..d1de53953 100644 --- a/test/forward_prob_kwargs.jl +++ b/test/forward_prob_kwargs.jl @@ -31,5 +31,5 @@ end res4 = Zygote.gradient(cost, p[1])[1] # (7.720368430265481,) @test res ≈ res2 -@test res ≈ res3 -@test res ≈ res4 +@test res2 ≈ res3 +@test res2 ≈ res4 diff --git a/test/reversediff_output_types.jl b/test/reversediff_output_types.jl new file mode 100644 index 000000000..eb32e354f --- /dev/null +++ b/test/reversediff_output_types.jl @@ -0,0 +1,19 @@ +using OrdinaryDiffEq, SciMLSensitivity, Zygote, Test + +function lotka_volterra(u, p, t) + du1 = p[1]*u[1] - p[2]*u[1]*u[2] + du2 = -p[3]*u[2] + p[4]*u[1]*u[2] + return [du1, du2] +end +u0 = [1.0,1.0] +tspan = (0.0,10.0) +p = [1.5,1.0,3.0,1.0] +prob = ODEProblem(lotka_volterra,u0,tspan,p) + +loss(u0; kwargs...) = solve(remake(prob, u0=u0), Tsit5(); reltol=1e-10, abstol=1e-10, kwargs...).u |> last |> sum + +dp1 = Zygote.gradient(loss, u0)[1] +dp2 = Zygote.gradient(u0 -> loss(u0; sensealg=TrackerAdjoint()), u0)[1] +dp3 = Zygote.gradient(u0 -> loss(u0; sensealg=ReverseDiffAdjoint()), u0)[1] +@test dp1 ≈ dp2 +@test dp1 ≈ dp3 \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e4fce0f8d..b5c739860 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,9 @@ end @time @safetestset "Complex Adjoints" begin include("complex_adjoints.jl") end + @time @safetestset "ReverseDiffAdjoint Output Type" begin + include("reversediff_output_types.jl") + end @time @safetestset "Forward Remake" begin include("forward_remake.jl") end