Skip to content

Commit

Permalink
Merge pull request #993 from ArnoStrouwen/default_alg
Browse files Browse the repository at this point in the history
use GaussAdjoint in automatic sensealg
  • Loading branch information
ChrisRackauckas authored May 15, 2024
2 parents 548b8ae + 381cc26 commit b279b07
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 11 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "0.1, 0.2, 1.0"
ADTypes = "0.1, 0.2, 1"
Adapt = "1.0, 2.0, 3.0, 4"
AlgebraicMultigrid = "0.6.0"
Aqua = "0.8.4"
Expand Down Expand Up @@ -85,7 +85,7 @@ RecursiveArrayTools = "3.18.1"
Reexport = "1.0"
ReverseDiff = "1.15.1"
SafeTestsets = "0.1.0"
SciMLBase = "2.28"
SciMLBase = "2.37"
SciMLOperators = "0.3"
SparseArrays = "1.10"
SparseDiffTools = "2.5"
Expand Down
6 changes: 4 additions & 2 deletions docs/src/examples/dde/delay_diffeq.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ using Plots
callback = function (state, l...; doplot = false)
display(loss_dde(state.u))
doplot &&
display(plot(solve(remake(prob_dde, p = state.u), MethodOfSteps(Tsit5()), saveat = 0.1),
display(plot(
solve(remake(prob_dde, p = state.u), MethodOfSteps(Tsit5()), saveat = 0.1),
ylim = (0, 6)))
return false
end
Expand All @@ -64,7 +65,8 @@ using Plots
callback = function (state, l...; doplot = false)
display(loss_dde(state.u))
doplot &&
display(plot(solve(remake(prob_dde, p = state.u), MethodOfSteps(Tsit5()), saveat = 0.1),
display(plot(
solve(remake(prob_dde, p = state.u), MethodOfSteps(Tsit5()), saveat = 0.1),
ylim = (0, 6)))
return false
end
Expand Down
9 changes: 8 additions & 1 deletion src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,12 @@ function automatic_sensealg_choice(
if verbose
@warn "Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs"
end

if p === nothing || p === DiffEqBase.NullParameters()
# QuadratureAdjoint skips all p calculations until the end
# So it's the fastest when there are no parameters
QuadratureAdjoint(autodiff = false, autojacvec = vjp)
elseif prob isa ODEProblem
GaussAdjoint(autodiff = false, autojacvec = vjp)
else
InterpolatingAdjoint(autodiff = false, autojacvec = vjp)
end
Expand All @@ -157,6 +158,8 @@ function automatic_sensealg_choice(
# QuadratureAdjoint skips all p calculations until the end
# So it's the fastest when there are no parameters
QuadratureAdjoint(autojacvec = vjp)
elseif prob isa ODEProblem
GaussAdjoint(autojacvec = vjp)
else
InterpolatingAdjoint(autojacvec = vjp)
end
Expand All @@ -170,12 +173,16 @@ function automatic_sensealg_choice(
# If reverse-mode isn't working, just fallback to numerical vjps
if p === nothing || p === DiffEqBase.NullParameters()
QuadratureAdjoint(autodiff = false, autojacvec = vjp)
elseif prob isa ODEProblem
GaussAdjoint(autodiff = false, autojacvec = vjp)
else
InterpolatingAdjoint(autodiff = false, autojacvec = vjp)
end
else
if p === nothing || p === DiffEqBase.NullParameters()
QuadratureAdjoint(autojacvec = vjp)
elseif prob isa ODEProblem
GaussAdjoint(autojacvec = vjp)
else
InterpolatingAdjoint(autojacvec = vjp)
end
Expand Down
5 changes: 2 additions & 3 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,12 @@ _, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
reltol = 1e-14,
sensealg = QuadratureAdjoint(abstol = 1e-14,
reltol = 1e-14))
@test_broken easy_res22 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
_, easy_res22 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = false,
abstol = 1e-14,
reltol = 1e-14))[1] isa
AbstractArray
reltol = 1e-14))
_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
Expand Down
1 change: 0 additions & 1 deletion test/concrete_solve_derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ du06, dp6 = Zygote.gradient(
u0,
p)


@test ū0du01 rtol=1e-12
@test ū0 == du02
@test ū0du03 rtol=1e-12
Expand Down
3 changes: 1 addition & 2 deletions test/null_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ end
@test_broken Zygote.gradient(loss7, zeros(123))[1] == zeros(123)
@test Zygote.gradient(loss8, zeros(123))[1] == zeros(123)
@test Zygote.gradient(loss9, zeros(123))[1] == zeros(123)
@test_throws SciMLSensitivity.ZygoteVJPNothingError Zygote.gradient(loss10,
zeros(123))[1]==zeros(123)
@test Zygote.gradient(loss10, zeros(123))[1] == zeros(123)

## OOP tests for initial condition
function loss_oop(u0; sensealg = nothing)
Expand Down

0 comments on commit b279b07

Please sign in to comment.