Skip to content

Commit

Permalink
Redesign default ODE solver to be fully type-grounded
Browse files Browse the repository at this point in the history
This accomplishes a few things:

* Faster precompile times by precompiling less
* Full inference of results when using the automatic algorithm
* Hopefully faster load times by also precompiling less

This is done the same way as

* linearsolve SciML/LinearSolve.jl#307
* nonlinearsolve SciML/NonlinearSolve.jl#238

and is thus the more modern SciML way of doing it. It avoids dispatch by having a single algorithm that always generates the full cache and instead of dispatching between algorithms always branches for the choice.

It turns out, the mechanism already existed for this in OrdinaryDiffEq... it's CompositeAlgorithm, the same bones as AutoSwitch! As such, this reuses quite a bit of code from the auto-switch algorithms but instead of just having two choices it (currently) has 6 that it chooses between. This means that it has stiffness detection and switching behavior, but also in a size-dependent way.

There are still some optimizations to do though. Like LinearSolve.jl, it would be more efficient to have a way to initialize the caches to size zero and then have a way to re-initialize them to the correct size. Right now, it'll generate the same Jacobian N times and it shouldn't need to do that.
  • Loading branch information
ChrisRackauckas committed Jan 1, 2024
1 parent f2ca03e commit a457fd8
Show file tree
Hide file tree
Showing 10 changed files with 1,071 additions and 827 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Expand Down
21 changes: 16 additions & 5 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ using LinearSolve, SimpleNonlinearSolve

using LineSearches

import EnumX

import FillArrays: Trues

# Interfaces
Expand Down Expand Up @@ -141,6 +143,7 @@ include("nlsolve/functional.jl")
include("nlsolve/newton.jl")

include("generic_rosenbrock.jl")
include("composite_algs.jl")

include("caches/basic_caches.jl")
include("caches/low_order_rk_caches.jl")
Expand Down Expand Up @@ -231,7 +234,6 @@ include("constants.jl")
include("solve.jl")
include("initdt.jl")
include("interp_func.jl")
include("composite_algs.jl")

import PrecompileTools

Expand All @@ -250,9 +252,14 @@ PrecompileTools.@compile_workload begin
Tsit5(), Vern7(),
]

stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false),
Rodas5P(), Rodas5P(autodiff = false),
FBDF(), FBDF(autodiff = false),
stiff = [Rosenbrock23(),
Rodas5P(),
FBDF()
]

default_ode = [
DefaultODEAlgorithm(autodiff=false),
DefaultODEAlgorithm()
]

autoswitch = [
Expand Down Expand Up @@ -281,7 +288,11 @@ PrecompileTools.@compile_workload begin
append!(solver_list, stiff)
end

if Preferences.@load_preference("PrecompileAutoSwitch", true)
if Preferences.@load_preference("PrecompileDefault", true)
append!(solver_list, stiff)
end

if Preferences.@load_preference("PrecompileAutoSwitch", false)
append!(solver_list, autoswitch)
end

Expand Down
31 changes: 22 additions & 9 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs))

isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true
isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs))
isdtchangeable(alg::DefaultSolverAlgorithm) = true

function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4,
CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2})
false
Expand All @@ -176,12 +178,14 @@ ismultistep(alg::ETD2) = true
isadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
isadaptive(alg::OrdinaryDiffEqAdaptiveAlgorithm) = true
isadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = all(isadaptive.(alg.algs))
isadaptive(alg::DefaultSolverAlgorithm) = true
isadaptive(alg::DImplicitEuler) = true
isadaptive(alg::DABDF2) = true
isadaptive(alg::DFBDF) = true

anyadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = isadaptive(alg)
anyadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = any(isadaptive, alg.algs)
anyadaptive(alg::DefaultSolverAlgorithm) = true

isautoswitch(alg) = false
isautoswitch(alg::CompositeAlgorithm) = alg.choice_function isa AutoSwitch
Expand All @@ -191,9 +195,11 @@ function qmin_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm})
end
qmin_default(alg::CompositeAlgorithm) = maximum(qmin_default.(alg.algs))
qmin_default(alg::DP8) = 1 // 3
qmin_default(alg::DefaultSolverAlgorithm) = 1 // 5

qmax_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = 10
qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs))
qmax_default(alg::DefaultSolverAlgorithm) = 10
qmax_default(alg::DP8) = 6
qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8

Expand Down Expand Up @@ -277,7 +283,7 @@ end

function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob)
algs = map(alg -> DiffEqBase.prepare_alg(alg, u0, p, prob), alg.algs)
CompositeAlgorithm(algs, alg.choice_function)
CompositeAlgorithm(algs, alg.choice_function, Val(allowfallbacks(alg)))
end

# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
Expand Down Expand Up @@ -351,7 +357,8 @@ function concrete_jac(alg::Union{
end

alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
alg_extrapolates(alg::CompositeAlgorithm) = any(alg_extrapolates.(alg.algs))
alg_extrapolates(alg::CompositeAlgorithm) = error("any(alg_extrapolates.(alg.algs))")
alg_extrapolates(alg::DefaultSolverAlgorithm) = false
alg_extrapolates(alg::ImplicitEuler) = true
alg_extrapolates(alg::DImplicitEuler) = true
alg_extrapolates(alg::DABDF2) = true
Expand Down Expand Up @@ -695,6 +702,7 @@ alg_order(alg::Alshina6) = 6

alg_maximum_order(alg) = alg_order(alg)
alg_maximum_order(alg::CompositeAlgorithm) = maximum(alg_order(x) for x in alg.algs)
alg_maximum_order(alg::DefaultSolverAlgorithm) = 7
alg_maximum_order(alg::ExtrapolationMidpointDeuflhard) = 2(alg.max_order + 1)
alg_maximum_order(alg::ImplicitDeuflhardExtrapolation) = 2(alg.max_order + 1)
alg_maximum_order(alg::ExtrapolationMidpointHairerWanner) = 2(alg.max_order + 1)
Expand Down Expand Up @@ -829,6 +837,7 @@ function gamma_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm})
isadaptive(alg) ? 9 // 10 : 0
end
gamma_default(alg::CompositeAlgorithm) = maximum(gamma_default, alg.algs)
gamma_default(alg::DefaultSolverAlgorithm) = 9 // 10
gamma_default(alg::RKC) = 8 // 10
gamma_default(alg::IRKC) = 8 // 10
function gamma_default(alg::ExtrapolationMidpointDeuflhard)
Expand Down Expand Up @@ -949,14 +958,18 @@ function unwrap_alg(integrator, is_stiff)
if !iscomp
return alg
elseif alg.choice_function isa AutoSwitchCache
if is_stiff === nothing
throwautoswitch(alg)
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]
if alg.choice_function.algtrait isa DefaultODESolver
alg.algs[alg.choice_function.current]
else
return alg.algs[2]
if is_stiff === nothing
throwautoswitch(alg)
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]
else
return alg.algs[2]
end
end
else
return _eval_index(identity, alg.algs, integrator.cache.current)
Expand Down
41 changes: 40 additions & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3100,17 +3100,56 @@ end

#########################################

struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm
struct CompositeAlgorithm{Fallbacks, T, F} <: OrdinaryDiffEqCompositeAlgorithm
algs::T
choice_function::F
function CompositeAlgorithm(algs, choice_function, fallbacks_enabled::Val{X} = Val(true)) where X
new{X, typeof(algs), typeof(choice_function)}(algs, choice_function)
end
end

allowfallbacks(::CompositeAlgorithm{Fallbacks, T, F}) where {Fallbacks, T, F} = Fallbacks

TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeAlgorithm)
end

mutable struct AutoSwitchCache{Trait, nAlg, sAlg, tolType, T}
algtrait::Trait
count::Int
successive_switches::Int
nonstiffalg::nAlg
stiffalg::sAlg
is_stiffalg::Bool
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
current::Int
end

struct AutoSwitch{Trait, nAlg, sAlg, tolType, T}
algtrait::Trait
nonstiffalg::nAlg
stiffalg::sAlg
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
end

struct DefaultODESolver end
const DefaultSolverAlgorithm = Union{CompositeAlgorithm{false, <:Tuple, <:AutoSwitch{DefaultODESolver}},
CompositeAlgorithm{false, <:Tuple, <:AutoSwitchCache{DefaultODESolver}}}

################################################################################
"""
MEBDF2: Multistep Method
Expand Down
16 changes: 9 additions & 7 deletions src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ abstract type OrdinaryDiffEqMutableCache <: OrdinaryDiffEqCache end
struct ODEEmptyCache <: OrdinaryDiffEqConstantCache end
struct ODEChunkCache{CS} <: OrdinaryDiffEqConstantCache end

mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache
mutable struct CompositeCache{Fallbacks, T, F} <: OrdinaryDiffEqCache
caches::T
choice_function::F
current::Int
Expand All @@ -16,28 +16,30 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeCache)
end

function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype,
function alg_cache(alg::CompositeAlgorithm{Fallbacks, Tuple{T1, T2}, F}, u, rate_prototype,
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev,
uprev2, f, t, dt, reltol, p, calck,
::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits,
::Val{V}) where {Fallbacks, T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits,
tTypeNoUnits}
caches = (alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)),
alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)))
CompositeCache(caches, alg.choice_function, 1)
CompositeCache{Fallbacks, typeof(caches), typeof(alg.choice_function)}(
caches, alg.choice_function, 1)
end

function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
function alg_cache(alg::CompositeAlgorithm{Fallbacks}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
::Val{V}) where {Fallbacks, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
caches = __alg_cache(alg.algs, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V))
CompositeCache(caches, alg.choice_function, 1)
CompositeCache{Fallbacks, typeof(caches), typeof(alg.choice_function)}(
caches, alg.choice_function, 1)
end

# map + closure approach doesn't infer
Expand Down
30 changes: 20 additions & 10 deletions src/caches/verner_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern6Cache 1
Expand All @@ -44,19 +45,20 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype)
Vern6Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, utilde, tmp, rtmp, atmp, tab,
alg.stage_limiter!, alg.step_limiter!, alg.thread)
alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy)
end

struct Vern6ConstantCache{TabType} <: OrdinaryDiffEqConstantCache
tab::TabType
lazy::Bool
end

function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tab = Vern6Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
Vern6ConstantCache(tab)
Vern6ConstantCache(tab, alg.lazy)
end

@cache struct Vern7Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter,
Expand All @@ -81,6 +83,7 @@ end
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern7Cache 1
Expand All @@ -105,16 +108,18 @@ function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype)
Vern7Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, rtmp, atmp,
alg.stage_limiter!, alg.step_limiter!, alg.thread)
alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy)
end

struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache end
struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache
lazy::Bool
end

function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Vern7ConstantCache()
Vern7ConstantCache(alg.lazy)
end

@cache struct Vern8Cache{uType, rateType, uNoUnitsType, TabType, StageLimiter, StepLimiter,
Expand Down Expand Up @@ -143,6 +148,7 @@ end
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern8Cache 1
Expand Down Expand Up @@ -171,19 +177,20 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype)
Vern8Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, utilde,
tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread)
tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy)
end

struct Vern8ConstantCache{TabType} <: OrdinaryDiffEqConstantCache
tab::TabType
lazy::Bool
end

function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tab = Vern8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
Vern8ConstantCache(tab)
Vern8ConstantCache(tab, alg.lazy)
end

@cache struct Vern9Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter,
Expand Down Expand Up @@ -214,6 +221,7 @@ end
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern9Cache 1
Expand Down Expand Up @@ -245,14 +253,16 @@ function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits},
rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype)
Vern9Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15,
k16, utilde, tmp, rtmp, atmp, alg.stage_limiter!, alg.step_limiter!,
alg.thread)
alg.thread, alg.lazy)
end

struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache end
struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache
lazy::Bool
end

function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Vern9ConstantCache()
Vern9ConstantCache(alg.lazy)
end
Loading

0 comments on commit a457fd8

Please sign in to comment.