Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redesign default ODE solver to be fully type-grounded #2103

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Expand Down
21 changes: 16 additions & 5 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ using LinearSolve, SimpleNonlinearSolve

using LineSearches

import EnumX

import FillArrays: Trues

# Interfaces
Expand Down Expand Up @@ -141,6 +143,7 @@ include("nlsolve/functional.jl")
include("nlsolve/newton.jl")

include("generic_rosenbrock.jl")
include("composite_algs.jl")

include("caches/basic_caches.jl")
include("caches/low_order_rk_caches.jl")
Expand Down Expand Up @@ -231,7 +234,6 @@ include("constants.jl")
include("solve.jl")
include("initdt.jl")
include("interp_func.jl")
include("composite_algs.jl")

import PrecompileTools

Expand All @@ -250,9 +252,14 @@ PrecompileTools.@compile_workload begin
Tsit5(), Vern7(),
]

stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false),
Rodas5P(), Rodas5P(autodiff = false),
FBDF(), FBDF(autodiff = false),
stiff = [Rosenbrock23(),
Rodas5P(),
FBDF()
]

default_ode = [
DefaultODEAlgorithm(autodiff=false),
DefaultODEAlgorithm()
]

autoswitch = [
Expand Down Expand Up @@ -281,7 +288,11 @@ PrecompileTools.@compile_workload begin
append!(solver_list, stiff)
end

if Preferences.@load_preference("PrecompileAutoSwitch", true)
if Preferences.@load_preference("PrecompileDefault", true)
append!(solver_list, default_ode)
end

if Preferences.@load_preference("PrecompileAutoSwitch", false)
append!(solver_list, autoswitch)
end

Expand Down
29 changes: 21 additions & 8 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs))

isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true
isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs))
isdtchangeable(alg::DefaultSolverAlgorithm) = true

function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4,
CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2})
false
Expand All @@ -176,12 +178,14 @@ ismultistep(alg::ETD2) = true
isadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
isadaptive(alg::OrdinaryDiffEqAdaptiveAlgorithm) = true
isadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = all(isadaptive.(alg.algs))
isadaptive(alg::DefaultSolverAlgorithm) = true
isadaptive(alg::DImplicitEuler) = true
isadaptive(alg::DABDF2) = true
isadaptive(alg::DFBDF) = true

anyadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = isadaptive(alg)
anyadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = any(isadaptive, alg.algs)
anyadaptive(alg::DefaultSolverAlgorithm) = true

isautoswitch(alg) = false
isautoswitch(alg::CompositeAlgorithm) = alg.choice_function isa AutoSwitch
Expand All @@ -191,9 +195,11 @@ function qmin_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm})
end
qmin_default(alg::CompositeAlgorithm) = maximum(qmin_default.(alg.algs))
qmin_default(alg::DP8) = 1 // 3
qmin_default(alg::DefaultSolverAlgorithm) = 1 // 5

qmax_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = 10
qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs))
qmax_default(alg::DefaultSolverAlgorithm) = 10
qmax_default(alg::DP8) = 6
qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8

Expand Down Expand Up @@ -277,7 +283,7 @@ end

function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob)
algs = map(alg -> DiffEqBase.prepare_alg(alg, u0, p, prob), alg.algs)
CompositeAlgorithm(algs, alg.choice_function)
CompositeAlgorithm(algs, alg.choice_function, Val(allowfallbacks(alg)))
end

# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
Expand Down Expand Up @@ -352,6 +358,7 @@ end

alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
alg_extrapolates(alg::CompositeAlgorithm) = any(alg_extrapolates.(alg.algs))
alg_extrapolates(alg::DefaultSolverAlgorithm) = false
alg_extrapolates(alg::ImplicitEuler) = true
alg_extrapolates(alg::DImplicitEuler) = true
alg_extrapolates(alg::DABDF2) = true
Expand Down Expand Up @@ -695,6 +702,7 @@ alg_order(alg::Alshina6) = 6

alg_maximum_order(alg) = alg_order(alg)
alg_maximum_order(alg::CompositeAlgorithm) = maximum(alg_order(x) for x in alg.algs)
alg_maximum_order(alg::DefaultSolverAlgorithm) = 7
alg_maximum_order(alg::ExtrapolationMidpointDeuflhard) = 2(alg.max_order + 1)
alg_maximum_order(alg::ImplicitDeuflhardExtrapolation) = 2(alg.max_order + 1)
alg_maximum_order(alg::ExtrapolationMidpointHairerWanner) = 2(alg.max_order + 1)
Expand Down Expand Up @@ -829,6 +837,7 @@ function gamma_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm})
isadaptive(alg) ? 9 // 10 : 0
end
gamma_default(alg::CompositeAlgorithm) = maximum(gamma_default, alg.algs)
gamma_default(alg::DefaultSolverAlgorithm) = 9 // 10
gamma_default(alg::RKC) = 8 // 10
gamma_default(alg::IRKC) = 8 // 10
function gamma_default(alg::ExtrapolationMidpointDeuflhard)
Expand Down Expand Up @@ -949,14 +958,18 @@ function unwrap_alg(integrator, is_stiff)
if !iscomp
return alg
elseif alg.choice_function isa AutoSwitchCache
if is_stiff === nothing
throwautoswitch(alg)
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]
if alg.choice_function.algtrait isa DefaultODESolver
alg.algs[alg.choice_function.current]
else
return alg.algs[2]
if is_stiff === nothing
throwautoswitch(alg)
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]
else
return alg.algs[2]
end
end
else
return _eval_index(identity, alg.algs, integrator.cache.current)
Expand Down
41 changes: 40 additions & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3100,17 +3100,56 @@ end

#########################################

struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm
struct CompositeAlgorithm{Fallbacks, T, F} <: OrdinaryDiffEqCompositeAlgorithm
algs::T
choice_function::F
function CompositeAlgorithm(algs, choice_function, fallbacks_enabled::Val{X} = Val(true)) where X
new{X, typeof(algs), typeof(choice_function)}(algs, choice_function)
end
end

allowfallbacks(::CompositeAlgorithm{Fallbacks, T, F}) where {Fallbacks, T, F} = Fallbacks

TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeAlgorithm)
end

mutable struct AutoSwitchCache{Trait, nAlg, sAlg, tolType, T}
algtrait::Trait
count::Int
successive_switches::Int
nonstiffalg::nAlg
stiffalg::sAlg
is_stiffalg::Bool
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
current::Int
end

struct AutoSwitch{Trait, nAlg, sAlg, tolType, T}
algtrait::Trait
nonstiffalg::nAlg
stiffalg::sAlg
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
end

struct DefaultODESolver end
const DefaultSolverAlgorithm = Union{CompositeAlgorithm{false, <:Tuple, <:AutoSwitch{DefaultODESolver}},
CompositeAlgorithm{false, <:Tuple, <:AutoSwitchCache{DefaultODESolver}}}

################################################################################
"""
MEBDF2: Multistep Method
Expand Down
16 changes: 9 additions & 7 deletions src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ abstract type OrdinaryDiffEqMutableCache <: OrdinaryDiffEqCache end
struct ODEEmptyCache <: OrdinaryDiffEqConstantCache end
struct ODEChunkCache{CS} <: OrdinaryDiffEqConstantCache end

mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache
mutable struct CompositeCache{Fallbacks, T, F} <: OrdinaryDiffEqCache
caches::T
choice_function::F
current::Int
Expand All @@ -16,28 +16,30 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeCache)
end

function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype,
function alg_cache(alg::CompositeAlgorithm{Fallbacks, Tuple{T1, T2}, F}, u, rate_prototype,
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev,
uprev2, f, t, dt, reltol, p, calck,
::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits,
::Val{V}) where {Fallbacks, T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits,
tTypeNoUnits}
caches = (alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)),
alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)))
CompositeCache(caches, alg.choice_function, 1)
CompositeCache{Fallbacks, typeof(caches), typeof(alg.choice_function)}(
caches, alg.choice_function, 1)
end

function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
function alg_cache(alg::CompositeAlgorithm{Fallbacks}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
::Val{V}) where {Fallbacks, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
caches = __alg_cache(alg.algs, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V))
CompositeCache(caches, alg.choice_function, 1)
CompositeCache{Fallbacks, typeof(caches), typeof(alg.choice_function)}(
caches, alg.choice_function, 1)
end

# map + closure approach doesn't infer
Expand Down
30 changes: 20 additions & 10 deletions src/caches/verner_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern6Cache 1
Expand All @@ -44,19 +45,20 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype)
Vern6Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, utilde, tmp, rtmp, atmp, tab,
alg.stage_limiter!, alg.step_limiter!, alg.thread)
alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy)
end

struct Vern6ConstantCache{TabType} <: OrdinaryDiffEqConstantCache
tab::TabType
lazy::Bool
end

function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tab = Vern6Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
Vern6ConstantCache(tab)
Vern6ConstantCache(tab, alg.lazy)
end

@cache struct Vern7Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter,
Expand All @@ -81,6 +83,7 @@ end
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern7Cache 1
Expand All @@ -105,16 +108,18 @@ function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype)
Vern7Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, rtmp, atmp,
alg.stage_limiter!, alg.step_limiter!, alg.thread)
alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy)
end

struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache end
struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache
lazy::Bool
end

function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Vern7ConstantCache()
Vern7ConstantCache(alg.lazy)
end

@cache struct Vern8Cache{uType, rateType, uNoUnitsType, TabType, StageLimiter, StepLimiter,
Expand Down Expand Up @@ -143,6 +148,7 @@ end
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern8Cache 1
Expand Down Expand Up @@ -171,19 +177,20 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype)
Vern8Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, utilde,
tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread)
tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy)
end

struct Vern8ConstantCache{TabType} <: OrdinaryDiffEqConstantCache
tab::TabType
lazy::Bool
end

function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tab = Vern8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
Vern8ConstantCache(tab)
Vern8ConstantCache(tab, alg.lazy)
end

@cache struct Vern9Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter,
Expand Down Expand Up @@ -214,6 +221,7 @@ end
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern9Cache 1
Expand Down Expand Up @@ -245,14 +253,16 @@ function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits},
rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype)
Vern9Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15,
k16, utilde, tmp, rtmp, atmp, alg.stage_limiter!, alg.step_limiter!,
alg.thread)
alg.thread, alg.lazy)
end

struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache end
struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache
lazy::Bool
end

function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Vern9ConstantCache()
Vern9ConstantCache(alg.lazy)
end
Loading
Loading