Skip to content

Commit

Permalink
Merge pull request #2540 from oscardssmith/os/fix-DefaultCache-type-s…
Browse files Browse the repository at this point in the history
…tability

fix type stability for `DefaultCache`
  • Loading branch information
ChrisRackauckas authored Nov 21, 2024
2 parents 9799ee1 + 64defc0 commit 1b537a0
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ function alg_cache(alg::CompositeAlgorithm{CS, Tuple{A1, A2, A3, A4, A5, A6}}, u
args = (u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt,
reltol, p, calck, Val(V))
argT = map(typeof, args)
# Core.Typeof to turn uEltypeNoUnits into Type{uEltypeNoUnits} rather than DataType
argT = map(Core.Typeof, args)
T1 = Base.promote_op(alg_cache, A1, argT...)
T2 = Base.promote_op(alg_cache, A2, argT...)
T3 = Base.promote_op(alg_cache, A3, argT...)
Expand Down
18 changes: 16 additions & 2 deletions lib/OrdinaryDiffEqCore/src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,12 @@ function default_ode_interpolant(
return ode_interpolant(Θ, integrator.dt, integrator.uprev,
integrator.u, integrator.k, cache.cache5, idxs,
deriv, integrator.differential_vars)
else # alg_choice == 6
elseif alg_choice == 6
return ode_interpolant(Θ, integrator.dt, integrator.uprev,
integrator.u, integrator.k, cache.cache6, idxs,
deriv, integrator.differential_vars)
else
error("DefaultCache invalid alg_choice. File an issue.")
end
end

Expand Down Expand Up @@ -227,6 +229,8 @@ end
ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u,
integrator.k, integrator.cache.cache6,
idxs, deriv, integrator.differential_vars)
else
error("DefaultCache invalid alg_choice. File an issue.")
end
else
ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u,
Expand Down Expand Up @@ -256,10 +260,12 @@ function default_ode_interpolant!(
return ode_interpolant!(val, Θ, integrator.dt, integrator.uprev,
integrator.u, integrator.k, cache.cache5, idxs,
deriv, integrator.differential_vars)
else # alg_choice == 6
elseif alg_choice == 6
return ode_interpolant!(val, Θ, integrator.dt, integrator.uprev,
integrator.u, integrator.k, cache.cache6, idxs,
deriv, integrator.differential_vars)
else
error("DefaultCache invalid alg_choice. File an issue.")
end
end

Expand Down Expand Up @@ -380,6 +386,8 @@ function default_ode_extrapolant!(
ode_interpolant!(val, Θ, integrator.t - integrator.tprev,
integrator.uprev2, integrator.uprev,
integrator.k, cache.cache6, idxs, deriv, integrator.differential_vars)
else
error("DefaultCache invalid alg_choice. File an issue.")
end
end

Expand Down Expand Up @@ -444,6 +452,8 @@ function default_ode_extrapolant(
ode_interpolant(Θ, integrator.t - integrator.tprev,
integrator.uprev2, integrator.uprev,
integrator.k, cache.cache6, idxs, deriv, integrator.differential_vars)
else
error("DefaultCache invalid alg_choice. File an issue.")
end
end

Expand Down Expand Up @@ -810,6 +820,8 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
cache.cache6) # update the kcurrent
val = ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
cache.cache6, idxs, deriv, differential_vars)
else
error("DefaultCache invalid alg_choice. File an issue.")
end
else
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
Expand Down Expand Up @@ -892,6 +904,8 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
cache.cache6) # update the kcurrent
ode_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
cache.cache6, idxs, deriv, differential_vars)
else
error("DefaultCache invalid alg_choice. File an issue.")
end
else
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
Expand Down
4 changes: 4 additions & 0 deletions lib/OrdinaryDiffEqDefault/test/default_solver_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ end
prob_rober = ODEProblem(rober, [1.0, 0.0, 0.0], (0.0, 1e3), (0.04, 3e7, 1e4))
sol = solve(prob_rober)
rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23(autodiff = false)))
#test that cache is type stable
@test typeof(sol.interp.cache.cache3) == typeof(rosensol.interp.cache.caches[2])
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
@test sol.stats.naccept == rosensol.stats.naccept
@test sol.stats.nf == rosensol.stats.nf
Expand All @@ -50,6 +52,8 @@ rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23(autodiff = false)))
sol = solve(prob_rober, reltol = 1e-7, abstol = 1e-7)
rosensol = solve(
prob_rober, AutoVern7(Rodas5P(autodiff = false)), reltol = 1e-7, abstol = 1e-7)
#test that cache is type stable
@test typeof(sol.interp.cache.cache4) == typeof(rosensol.interp.cache.caches[2])
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
@test sol.stats.naccept == rosensol.stats.naccept
@test sol.stats.nf == rosensol.stats.nf
Expand Down

0 comments on commit 1b537a0

Please sign in to comment.