Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redesign default ODE solver to be type-grounded and lazy #2184

Merged
merged 32 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
aa1c546
Redesign default ODE solver to be fully type-grounded
ChrisRackauckas Jan 1, 2024
dc7e96f
fix typo
oscardssmith May 9, 2024
4648f3a
typos
oscardssmith May 9, 2024
96147c9
rebase typos
oscardssmith May 9, 2024
b946f2e
typo
oscardssmith May 9, 2024
5f19a80
fix DelayDiffEq issue
oscardssmith May 13, 2024
51a84e0
bugfix
oscardssmith May 14, 2024
75ef069
use the default solver by default
oscardssmith May 14, 2024
0458d40
init current to 0
oscardssmith May 14, 2024
9bdc84b
add tests
oscardssmith May 14, 2024
59539d8
improve tests
oscardssmith May 15, 2024
66fd522
switching works other than functionwrappers
oscardssmith May 15, 2024
4901a02
it works
oscardssmith May 15, 2024
971fdf6
better test
oscardssmith May 15, 2024
868d836
better tests
oscardssmith May 15, 2024
217fb53
better tests
oscardssmith May 15, 2024
43e117a
fix composite chunksize
oscardssmith May 15, 2024
a5fa00f
fix test
oscardssmith May 15, 2024
b6ddc9e
forgot to save before commit
oscardssmith May 15, 2024
9e82e80
fix rober test
oscardssmith May 16, 2024
e046bbe
add FBDF tests
oscardssmith May 16, 2024
091623a
in place works
oscardssmith May 16, 2024
569a25e
don't bypass pipeline
oscardssmith May 17, 2024
28c7fe7
Update src/solve.jl
ChrisRackauckas May 17, 2024
3bd4a71
Update src/solve.jl
ChrisRackauckas May 17, 2024
ba897b9
fix test failures and re-enable autodiff
oscardssmith May 17, 2024
fdf946b
disable autodiff
oscardssmith May 17, 2024
b832f2f
fix non-identity mass matrix with default solver and test
oscardssmith May 20, 2024
4cef511
typo
oscardssmith May 20, 2024
cbf273a
fix chunksize inference
oscardssmith May 20, 2024
fc76006
add inferred test
oscardssmith May 20, 2024
d7193fb
fix tests
oscardssmith May 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Expand Down Expand Up @@ -75,7 +76,6 @@ RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SciMLBase = "2.27.1"
SciMLOperators = "0.3"
SciMLStructures = "1"
SimpleNonlinearSolve = "1"
SimpleUnPack = "1"
SparseArrays = "1.9"
Expand Down
21 changes: 16 additions & 5 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ using LinearSolve, SimpleNonlinearSolve

using LineSearches

import EnumX

import FillArrays: Trues

# Interfaces
Expand Down Expand Up @@ -141,6 +143,7 @@ include("nlsolve/functional.jl")
include("nlsolve/newton.jl")

include("generic_rosenbrock.jl")
include("composite_algs.jl")

include("caches/basic_caches.jl")
include("caches/low_order_rk_caches.jl")
Expand Down Expand Up @@ -234,7 +237,6 @@ include("constants.jl")
include("solve.jl")
include("initdt.jl")
include("interp_func.jl")
include("composite_algs.jl")

import PrecompileTools

Expand All @@ -253,9 +255,14 @@ PrecompileTools.@compile_workload begin
Tsit5(), Vern7()
]

stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false),
Rodas5P(), Rodas5P(autodiff = false),
FBDF(), FBDF(autodiff = false)
stiff = [Rosenbrock23(),
Rodas5P(),
FBDF()
]

default_ode = [
DefaultODEAlgorithm(autodiff=false),
DefaultODEAlgorithm()
]

autoswitch = [
Expand Down Expand Up @@ -284,7 +291,11 @@ PrecompileTools.@compile_workload begin
append!(solver_list, stiff)
end

if Preferences.@load_preference("PrecompileAutoSwitch", true)
if Preferences.@load_preference("PrecompileDefault", true)
append!(solver_list, default_ode)
end

if Preferences.@load_preference("PrecompileAutoSwitch", false)
append!(solver_list, autoswitch)
end

Expand Down
64 changes: 37 additions & 27 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@

isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true
isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs))

function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4,
CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2})
false
Expand Down Expand Up @@ -205,31 +206,35 @@
qmax_default(alg::DP8) = 6
qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8

function has_chunksize(alg::OrdinaryDiffEqAlgorithm)
return alg isa Union{OrdinaryDiffEqExponentialAlgorithm,
OrdinaryDiffEqAdaptiveExponentialAlgorithm,
OrdinaryDiffEqImplicitAlgorithm,
OrdinaryDiffEqAdaptiveImplicitAlgorithm,
DAEAlgorithm,
CompositeAlgorithm}
end
function get_chunksize(alg::OrdinaryDiffEqAlgorithm)
error("This algorithm does not have a chunk size defined.")
end
get_chunksize(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS)
get_chunksize(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS)
get_chunksize(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = Val(CS)
function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where {
CS,
AD
}
function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS},
OrdinaryDiffEqImplicitAlgorithm{CS},
OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS},
DAEAlgorithm{CS},
CompositeAlgorithm{CS}}) where {CS}
Val(CS)
end

function get_chunksize_int(alg::OrdinaryDiffEqAlgorithm)
error("This algorithm does not have a chunk size defined.")
end
get_chunksize_int(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = CS
get_chunksize_int(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = CS
get_chunksize_int(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = CS
function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where {
CS,
AD
}
function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS},
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS},
OrdinaryDiffEqImplicitAlgorithm{CS},
OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS},
DAEAlgorithm{CS},
CompositeAlgorithm{CS}}) where {CS}
CS
end
# get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg])
Expand Down Expand Up @@ -965,10 +970,12 @@
alg_can_repeat_jac(alg::IRKC) = false

function unwrap_alg(alg::SciMLBase.DEAlgorithm, is_stiff)
iscomp = alg isa CompositeAlgorithm
if !iscomp
if !(alg isa CompositeAlgorithm)

Check warning on line 973 in src/alg_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/alg_utils.jl#L973

Added line #L973 was not covered by tests
return alg
elseif alg.choice_function isa AutoSwitchCache
if length(alg.algs) > 2
return alg.algs[alg.choice_function.current]

Check warning on line 977 in src/alg_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/alg_utils.jl#L976-L977

Added lines #L976 - L977 were not covered by tests
end
if is_stiff === nothing
throwautoswitch(alg)
end
Expand All @@ -985,18 +992,21 @@

function unwrap_alg(integrator, is_stiff)
alg = integrator.alg
iscomp = alg isa CompositeAlgorithm
if !iscomp
if !(alg isa CompositeAlgorithm)
return alg
elseif alg.choice_function isa AutoSwitchCache
if is_stiff === nothing
throwautoswitch(alg)
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]
if length(alg.algs) > 2
alg.algs[alg.choice_function.current]

Check warning on line 999 in src/alg_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/alg_utils.jl#L998-L999

Added lines #L998 - L999 were not covered by tests
else
return alg.algs[2]
if is_stiff === nothing
throwautoswitch(alg)

Check warning on line 1002 in src/alg_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/alg_utils.jl#L1001-L1002

Added lines #L1001 - L1002 were not covered by tests
end
num = is_stiff ? 2 : 1
if num == 1
return alg.algs[1]

Check warning on line 1006 in src/alg_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/alg_utils.jl#L1004-L1006

Added lines #L1004 - L1006 were not covered by tests
else
return alg.algs[2]

Check warning on line 1008 in src/alg_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/alg_utils.jl#L1008

Added line #L1008 was not covered by tests
end
end
else
return _eval_index(identity, alg.algs, integrator.cache.current)
Expand Down
75 changes: 70 additions & 5 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2969,7 +2969,7 @@ Scientific Computing, 18 (1), pp. 1-22.
differential-algebraic problems. Computational mathematics (2nd revised ed.), Springer (1996)

#### ROS2PR, ROS2S, ROS3PR, Scholz4_7
-Rang, Joachim (2014): The Prothero and Robinson example:
-Rang, Joachim (2014): The Prothero and Robinson example:
Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods.
https://doi.org/10.24355/dbbs.084-201408121139-0

Expand Down Expand Up @@ -3014,16 +3014,16 @@ University of Geneva, Switzerland.
https://doi.org/10.1016/j.cam.2015.03.010

#### ROS3PRL, ROS3PRL2
-Rang, Joachim (2014): The Prothero and Robinson example:
-Rang, Joachim (2014): The Prothero and Robinson example:
Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods.
https://doi.org/10.24355/dbbs.084-201408121139-0

#### Rodas5P
- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package.
- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package.
In: BIT Numerical Mathematics, 63(2), 2023

#### Rodas23W, Rodas3P, Rodas5Pe, Rodas5Pr
- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications -
- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications -
Preprint 2024
https://github.com/hbrs-cse/RosenbrockMethods/blob/main/paper/JuliaPaper.pdf

Expand Down Expand Up @@ -3239,9 +3239,18 @@ end

#########################################

struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm
struct CompositeAlgorithm{CS, T, F} <: OrdinaryDiffEqCompositeAlgorithm
algs::T
choice_function::F
function CompositeAlgorithm(algs::T, choice_function::F) where {T,F}
CS = 0
for alg in algs
if has_chunksize(alg)
CS = get_chunksize_int(alg)
end
end
new{CS, T, F}(algs, choice_function)
end
end

TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1
Expand All @@ -3250,6 +3259,62 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeAlgorithm)
end

mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T}
count::Int
successive_switches::Int
nonstiffalg::nAlg
stiffalg::sAlg
is_stiffalg::Bool
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
current::Int
function AutoSwitchCache(count::Int,
successive_switches::Int,
nonstiffalg::nAlg,
stiffalg::sAlg,
is_stiffalg::Bool,
maxstiffstep::Int,
maxnonstiffstep::Int,
nonstifftol::tolType,
stifftol::tolType,
dtfac::T,
stiffalgfirst::Bool,
switch_max::Int,
current::Int=0) where {nAlg, sAlg, tolType, T}
new{nAlg, sAlg, tolType, T}(count,
successive_switches,
nonstiffalg,
stiffalg,
is_stiffalg,
maxstiffstep,
maxnonstiffstep,
nonstifftol,
stifftol,
dtfac,
stiffalgfirst,
switch_max,
current)
end

end

struct AutoSwitch{nAlg, sAlg, tolType, T}
nonstiffalg::nAlg
stiffalg::sAlg
maxstiffstep::Int
maxnonstiffstep::Int
nonstifftol::tolType
stifftol::tolType
dtfac::T
stiffalgfirst::Bool
switch_max::Int
end

################################################################################
"""
MEBDF2: Multistep Method
Expand Down
52 changes: 36 additions & 16 deletions src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,26 @@ end

TruncatedStacktraces.@truncate_stacktrace CompositeCache 1

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeCache)
mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F} <: OrdinaryDiffEqCache
args::A
choice_function::F
current::Int
cache1::T1
cache2::T2
cache3::T3
cache4::T4
cache5::T5
cache6::T6
function DefaultCache{T1, T2, T3, T4, T5, T6, F}(args, choice_function, current) where {T1, T2, T3, T4, T5, T6, F}
new{T1, T2, T3, T4, T5, T6, typeof(args), F}(args, choice_function, current)
end
end

function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype,
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev,
uprev2, f, t, dt, reltol, p, calck,
::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits,
tTypeNoUnits}
caches = (
alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)),
alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits,
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)))
CompositeCache(caches, alg.choice_function, 1)
TruncatedStacktraces.@truncate_stacktrace DefaultCache 1

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeCache)
Base.Experimental.silence!(DefaultCache)
end

function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -41,6 +43,24 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU
CompositeCache(caches, alg.choice_function, 1)
end

function alg_cache(alg::CompositeAlgorithm{CS, Tuple{A1, A2, A3, A4, A5, A6}}, u,
rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits},
uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{V}) where {CS, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, A1, A2, A3, A4, A5, A6}

args = (u, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt,
reltol, p, calck, Val(V))
argT = map(typeof, args)
T1 = Base.promote_op(alg_cache, A1, argT...)
T2 = Base.promote_op(alg_cache, A2, argT...)
T3 = Base.promote_op(alg_cache, A3, argT...)
T4 = Base.promote_op(alg_cache, A4, argT...)
T5 = Base.promote_op(alg_cache, A5, argT...)
T6 = Base.promote_op(alg_cache, A6, argT...)
DefaultCache{T1, T2, T3, T4, T5, T6, typeof(alg.choice_function)}(args, alg.choice_function, 1)
end

# map + closure approach doesn't infer
@generated function __alg_cache(algs::T, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev,
Expand Down
Loading
Loading