Skip to content

Commit

Permalink
Merge pull request #2184 from oscardssmith/os/default_solver-v2
Browse files Browse the repository at this point in the history
Redesign default ODE solver to be type-grounded and lazy
  • Loading branch information
ChrisRackauckas authored May 20, 2024
2 parents 6d35d93 + d7193fb commit b69b4ca
Show file tree
Hide file tree
Showing 14 changed files with 548 additions and 180 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Expand Down Expand Up @@ -75,7 +76,6 @@ RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SciMLBase = "2.27.1"
SciMLOperators = "0.3"
SciMLStructures = "1"
SimpleNonlinearSolve = "1"
SimpleUnPack = "1"
SparseArrays = "1.9"
Expand Down
21 changes: 16 additions & 5 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ using LinearSolve, SimpleNonlinearSolve

using LineSearches

import EnumX

import FillArrays: Trues

# Interfaces
Expand Down Expand Up @@ -141,6 +143,7 @@ include("nlsolve/functional.jl")
include("nlsolve/newton.jl")

include("generic_rosenbrock.jl")
include("composite_algs.jl")

include("caches/basic_caches.jl")
include("caches/low_order_rk_caches.jl")
Expand Down Expand Up @@ -234,7 +237,6 @@ include("constants.jl")
include("solve.jl")
include("initdt.jl")
include("interp_func.jl")
include("composite_algs.jl")

import PrecompileTools

Expand All @@ -253,9 +255,14 @@ PrecompileTools.@compile_workload begin
Tsit5(), Vern7()
]

stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false),
Rodas5P(), Rodas5P(autodiff = false),
FBDF(), FBDF(autodiff = false)
stiff = [Rosenbrock23(),
Rodas5P(),
FBDF()
]

default_ode = [
DefaultODEAlgorithm(autodiff=false),
DefaultODEAlgorithm()
]

autoswitch = [
Expand Down Expand Up @@ -284,7 +291,11 @@ PrecompileTools.@compile_workload begin
append!(solver_list, stiff)
end

if Preferences.@load_preference("PrecompileAutoSwitch", true)
if Preferences.@load_preference("PrecompileDefault", true)
append!(solver_list, default_ode)
end

if Preferences.@load_preference("PrecompileAutoSwitch", false)
append!(solver_list, autoswitch)
end

Expand Down
66 changes: 39 additions & 27 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs))

isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true
isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs))

function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4,
CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2})
false
Expand Down Expand Up @@ -205,31 +206,35 @@ qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs))
qmax_default(alg::DP8) = 6
qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8

function has_chunksize(alg::OrdinaryDiffEqAlgorithm)
return alg isa Union{OrdinaryDiffEqExponentialAlgorithm,
OrdinaryDiffEqAdaptiveExponentialAlgorithm,
OrdinaryDiffEqImplicitAlgorithm,
OrdinaryDiffEqAdaptiveImplicitAlgorithm,
DAEAlgorithm,
CompositeAlgorithm}
end
function get_chunksize(alg::OrdinaryDiffEqAlgorithm)
error("This algorithm does not have a chunk size defined.")
end
get_chunksize(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS)
get_chunksize(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS)
get_chunksize(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = Val(CS)
function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where {
CS,
AD
}
function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS},
OrdinaryDiffEqImplicitAlgorithm{CS},
OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS},
DAEAlgorithm{CS},
CompositeAlgorithm{CS}}) where {CS}
Val(CS)
end

function get_chunksize_int(alg::OrdinaryDiffEqAlgorithm)
error("This algorithm does not have a chunk size defined.")
end
get_chunksize_int(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = CS
get_chunksize_int(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = CS
get_chunksize_int(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = CS
function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where {
CS,
AD
}
function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS},
OrdinaryDiffEqImplicitAlgorithm{CS},
OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS},
DAEAlgorithm{CS},
CompositeAlgorithm{CS}}) where {CS}
CS
end
# get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg])
Expand Down Expand Up @@ -965,10 +970,12 @@ alg_can_repeat_jac(alg::OrdinaryDiffEqNewtonAdaptiveAlgorithm) = true
alg_can_repeat_jac(alg::IRKC) = false

function unwrap_alg(alg::SciMLBase.DEAlgorithm, is_stiff)
iscomp = alg isa CompositeAlgorithm
if !iscomp
if !(alg isa CompositeAlgorithm)
return alg
elseif alg.choice_function isa AutoSwitchCache
if length(alg.algs) > 2
return alg.algs[alg.choice_function.current]
end
if is_stiff === nothing
throwautoswitch(alg)
end
Expand All @@ -985,18 +992,21 @@ end

function unwrap_alg(integrator, is_stiff)
alg = integrator.alg
iscomp = alg isa CompositeAlgorithm
if !iscomp
if !(alg isa CompositeAlgorithm)
return alg
elseif alg.choice_function isa AutoSwitchCache
if is_stiff === nothing
throwautoswitch(alg)
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]
if length(alg.algs) > 2
alg.algs[alg.choice_function.current]
else
return alg.algs[2]
if is_stiff === nothing
throwautoswitch(alg)
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]
else
return alg.algs[2]
end
end
else
return _eval_index(identity, alg.algs, integrator.cache.current)
Expand Down Expand Up @@ -1071,3 +1081,5 @@ is_mass_matrix_alg(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
is_mass_matrix_alg(alg::CompositeAlgorithm) = all(is_mass_matrix_alg, alg.algs)
is_mass_matrix_alg(alg::RosenbrockAlgorithm) = true
is_mass_matrix_alg(alg::NewtonAlgorithm) = !isesdirk(alg)
# hack for the default alg
is_mass_matrix_alg(alg::CompositeAlgorithm{<:Any, <:Tuple{Tsit5, Vern7, Rosenbrock23, Rodas5P, FBDF, FBDF}}) = true
70 changes: 65 additions & 5 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2969,7 +2969,7 @@ Scientific Computing, 18 (1), pp. 1-22.
differential-algebraic problems. Computational mathematics (2nd revised ed.), Springer (1996)
#### ROS2PR, ROS2S, ROS3PR, Scholz4_7
-Rang, Joachim (2014): The Prothero and Robinson example:
-Rang, Joachim (2014): The Prothero and Robinson example:
Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods.
https://doi.org/10.24355/dbbs.084-201408121139-0
Expand Down Expand Up @@ -3014,16 +3014,16 @@ University of Geneva, Switzerland.
https://doi.org/10.1016/j.cam.2015.03.010
#### ROS3PRL, ROS3PRL2
-Rang, Joachim (2014): The Prothero and Robinson example:
-Rang, Joachim (2014): The Prothero and Robinson example:
Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods.
https://doi.org/10.24355/dbbs.084-201408121139-0
#### Rodas5P
- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package.
- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package.
In: BIT Numerical Mathematics, 63(2), 2023
#### Rodas23W, Rodas3P, Rodas5Pe, Rodas5Pr
- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications -
- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications -
Preprint 2024
https://github.com/hbrs-cse/RosenbrockMethods/blob/main/paper/JuliaPaper.pdf
Expand Down Expand Up @@ -3239,9 +3239,13 @@ end

#########################################

struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm
struct CompositeAlgorithm{CS, T, F} <: OrdinaryDiffEqCompositeAlgorithm
algs::T
choice_function::F
function CompositeAlgorithm(algs::T, choice_function::F) where {T,F}
CS = mapreduce(alg->has_chunksize(alg) ? get_chunksize_int(alg) : 0, max, algs)
new{CS, T, F}(algs, choice_function)
end
end

TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1
Expand All @@ -3250,6 +3254,62 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeAlgorithm)
end

mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T}
count::Int
successive_switches::Int
nonstiffalg::nAlg
stiffalg::sAlg
is_stiffalg::Bool
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
current::Int
function AutoSwitchCache(count::Int,
successive_switches::Int,
nonstiffalg::nAlg,
stiffalg::sAlg,
is_stiffalg::Bool,
maxstiffstep::Int,
maxnonstiffstep::Int,
nonstifftol::tolType,
stifftol::tolType,
dtfac::T,
stiffalgfirst::Bool,
switch_max::Int,
current::Int=0) where {nAlg, sAlg, tolType, T}
new{nAlg, sAlg, tolType, T}(count,
successive_switches,
nonstiffalg,
stiffalg,
is_stiffalg,
maxstiffstep,
maxnonstiffstep,
nonstifftol,
stifftol,
dtfac,
stiffalgfirst,
switch_max,
current)
end

end

struct AutoSwitch{nAlg, sAlg, tolType, T}
nonstiffalg::nAlg
stiffalg::sAlg
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
end

################################################################################
"""
MEBDF2: Multistep Method
Expand Down
15 changes: 15 additions & 0 deletions src/cache_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ function DiffEqBase.unwrap_cache(integrator::ODEIntegrator, is_stiff)
iscomp = alg isa CompositeAlgorithm
if !iscomp
return cache
elseif cache isa DefaultCache
current = integrator.cache.current
if current == 1
return cache.cache1
elseif current == 2
return cache.cache2
elseif current == 3
return cache.cache3
elseif current == 4
return cache.cache4
elseif current == 5
return cache.cache5
elseif current == 6
return cache.cache6
end
elseif alg.choice_function isa AutoSwitch
num = is_stiff ? 2 : 1
return cache.caches[num]
Expand Down
52 changes: 36 additions & 16 deletions src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,26 @@ end

TruncatedStacktraces.@truncate_stacktrace CompositeCache 1

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeCache)
mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F} <: OrdinaryDiffEqCache
args::A
choice_function::F
current::Int
cache1::T1
cache2::T2
cache3::T3
cache4::T4
cache5::T5
cache6::T6
function DefaultCache{T1, T2, T3, T4, T5, T6, F}(args, choice_function, current) where {T1, T2, T3, T4, T5, T6, F}
new{T1, T2, T3, T4, T5, T6, typeof(args), F}(args, choice_function, current)
end
end

function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype,
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev,
uprev2, f, t, dt, reltol, p, calck,
::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits,
tTypeNoUnits}
caches = (
alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)),
alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)))
CompositeCache(caches, alg.choice_function, 1)
TruncatedStacktraces.@truncate_stacktrace DefaultCache 1

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeCache)
Base.Experimental.silence!(DefaultCache)
end

function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -41,6 +43,24 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU
CompositeCache(caches, alg.choice_function, 1)
end

function alg_cache(alg::CompositeAlgorithm{CS, Tuple{A1, A2, A3, A4, A5, A6}}, u,
rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits},
uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{V}) where {CS, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, A1, A2, A3, A4, A5, A6}

args = (u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt,
reltol, p, calck, Val(V))
argT = map(typeof, args)
T1 = Base.promote_op(alg_cache, A1, argT...)
T2 = Base.promote_op(alg_cache, A2, argT...)
T3 = Base.promote_op(alg_cache, A3, argT...)
T4 = Base.promote_op(alg_cache, A4, argT...)
T5 = Base.promote_op(alg_cache, A5, argT...)
T6 = Base.promote_op(alg_cache, A6, argT...)
DefaultCache{T1, T2, T3, T4, T5, T6, typeof(alg.choice_function)}(args, alg.choice_function, 1)
end

# map + closure approach doesn't infer
@generated function __alg_cache(algs::T, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev,
Expand Down
Loading

0 comments on commit b69b4ca

Please sign in to comment.