From a457fd85d7a33e357ebd37520d50a12d1befbff0 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 1 Jan 2024 15:02:34 -0500 Subject: [PATCH 1/5] Redesign default ODE solver to be fully type-grounded This accomplishes a few things: * Faster precompile times by precompiling less * Full inference of results when using the automatic algorithm * Hopefully faster load times by also precompiling less This is done the same way as * linearsolve https://github.com/SciML/LinearSolve.jl/pull/307 * nonlinearsolve https://github.com/SciML/NonlinearSolve.jl/pull/238 and is thus the more modern SciML way of doing it. It avoids dispatch by having a single algorithm that always generates the full cache and instead of dispatching between algorithms always branches for the choice. It turns out, the mechanism already existed for this in OrdinaryDiffEq... it's CompositeAlgorithm, the same bones as AutoSwitch! As such, this reuses quite a bit of code from the auto-switch algorithms but instead of just having two choices it (currently) has 6 that it chooses between. This means that it has stiffness detection and switching behavior, but also in a size-dependent way. There are still some optimizations to do though. Like LinearSolve.jl, it would be more efficient to have a way to initialize the caches to size zero and then have a way to re-initialize them to the correct size. Right now, it'll generate the same Jacobian N times and it shouldn't need to do that. --- Project.toml | 1 + src/OrdinaryDiffEq.jl | 21 +- src/alg_utils.jl | 31 +- src/algorithms.jl | 41 +- src/caches/basic_caches.jl | 16 +- src/caches/verner_caches.jl | 30 +- src/composite_algs.jl | 163 ++- src/perform_step/composite_perform_step.jl | 189 ++- src/perform_step/verner_rk_perform_step.jl | 42 +- src/solve.jl | 1364 ++++++++++---------- 10 files changed, 1071 insertions(+), 827 deletions(-) diff --git a/Project.toml b/Project.toml index 328d2dfdd2..3d637e3a54 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 2cea945101..6dd1c1fa67 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -28,6 +28,8 @@ using LinearSolve, SimpleNonlinearSolve using LineSearches +import EnumX + import FillArrays: Trues # Interfaces @@ -141,6 +143,7 @@ include("nlsolve/functional.jl") include("nlsolve/newton.jl") include("generic_rosenbrock.jl") +include("composite_algs.jl") include("caches/basic_caches.jl") include("caches/low_order_rk_caches.jl") @@ -231,7 +234,6 @@ include("constants.jl") include("solve.jl") include("initdt.jl") include("interp_func.jl") -include("composite_algs.jl") import PrecompileTools @@ -250,9 +252,14 @@ PrecompileTools.@compile_workload begin Tsit5(), Vern7(), ] - stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false), - Rodas5P(), Rodas5P(autodiff = false), - FBDF(), FBDF(autodiff = false), + stiff = [Rosenbrock23(), + Rodas5P(), + FBDF() + ] + + default_ode = [ + DefaultODEAlgorithm(autodiff=false), + DefaultODEAlgorithm() ] autoswitch = [ @@ -281,7 +288,11 @@ PrecompileTools.@compile_workload begin append!(solver_list, stiff) end - if Preferences.@load_preference("PrecompileAutoSwitch", true) + if Preferences.@load_preference("PrecompileDefault", true) + append!(solver_list, stiff) + end + + if Preferences.@load_preference("PrecompileAutoSwitch", false) append!(solver_list, autoswitch) end diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 33f7f0b02b..08fc16235f 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -164,6 +164,8 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs)) isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs)) +isdtchangeable(alg::DefaultSolverAlgorithm) = true + function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4, CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2}) false @@ -176,12 +178,14 @@ ismultistep(alg::ETD2) = true isadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false isadaptive(alg::OrdinaryDiffEqAdaptiveAlgorithm) = true isadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = all(isadaptive.(alg.algs)) +isadaptive(alg::DefaultSolverAlgorithm) = true isadaptive(alg::DImplicitEuler) = true isadaptive(alg::DABDF2) = true isadaptive(alg::DFBDF) = true anyadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = isadaptive(alg) anyadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = any(isadaptive, alg.algs) +anyadaptive(alg::DefaultSolverAlgorithm) = true isautoswitch(alg) = false isautoswitch(alg::CompositeAlgorithm) = alg.choice_function isa AutoSwitch @@ -191,9 +195,11 @@ function qmin_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) end qmin_default(alg::CompositeAlgorithm) = maximum(qmin_default.(alg.algs)) qmin_default(alg::DP8) = 1 // 3 +qmin_default(alg::DefaultSolverAlgorithm) = 1 // 5 qmax_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = 10 qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs)) +qmax_default(alg::DefaultSolverAlgorithm) = 10 qmax_default(alg::DP8) = 6 qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8 @@ -277,7 +283,7 @@ end function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob) algs = map(alg -> DiffEqBase.prepare_alg(alg, u0, p, prob), alg.algs) - CompositeAlgorithm(algs, alg.choice_function) + CompositeAlgorithm(algs, alg.choice_function, Val(allowfallbacks(alg))) end # Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options. @@ -351,7 +357,8 @@ function concrete_jac(alg::Union{ end alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false -alg_extrapolates(alg::CompositeAlgorithm) = any(alg_extrapolates.(alg.algs)) +alg_extrapolates(alg::CompositeAlgorithm) = error("any(alg_extrapolates.(alg.algs))") +alg_extrapolates(alg::DefaultSolverAlgorithm) = false alg_extrapolates(alg::ImplicitEuler) = true alg_extrapolates(alg::DImplicitEuler) = true alg_extrapolates(alg::DABDF2) = true @@ -695,6 +702,7 @@ alg_order(alg::Alshina6) = 6 alg_maximum_order(alg) = alg_order(alg) alg_maximum_order(alg::CompositeAlgorithm) = maximum(alg_order(x) for x in alg.algs) +alg_maximum_order(alg::DefaultSolverAlgorithm) = 7 alg_maximum_order(alg::ExtrapolationMidpointDeuflhard) = 2(alg.max_order + 1) alg_maximum_order(alg::ImplicitDeuflhardExtrapolation) = 2(alg.max_order + 1) alg_maximum_order(alg::ExtrapolationMidpointHairerWanner) = 2(alg.max_order + 1) @@ -829,6 +837,7 @@ function gamma_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) isadaptive(alg) ? 9 // 10 : 0 end gamma_default(alg::CompositeAlgorithm) = maximum(gamma_default, alg.algs) +gamma_default(alg::DefaultSolverAlgorithm) = 9 // 10 gamma_default(alg::RKC) = 8 // 10 gamma_default(alg::IRKC) = 8 // 10 function gamma_default(alg::ExtrapolationMidpointDeuflhard) @@ -949,14 +958,18 @@ function unwrap_alg(integrator, is_stiff) if !iscomp return alg elseif alg.choice_function isa AutoSwitchCache - if is_stiff === nothing - throwautoswitch(alg) - end - num = is_stiff ? 2 : 1 - if num == 1 - return alg.algs[1] + if alg.choice_function.algtrait isa DefaultODESolver + alg.algs[alg.choice_function.current] else - return alg.algs[2] + if is_stiff === nothing + throwautoswitch(alg) + end + num = is_stiff ? 2 : 1 + if num == 1 + return alg.algs[1] + else + return alg.algs[2] + end end else return _eval_index(identity, alg.algs, integrator.cache.current) diff --git a/src/algorithms.jl b/src/algorithms.jl index a30e544070..b11392b8ad 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3100,17 +3100,56 @@ end ######################################### -struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm +struct CompositeAlgorithm{Fallbacks, T, F} <: OrdinaryDiffEqCompositeAlgorithm algs::T choice_function::F + function CompositeAlgorithm(algs, choice_function, fallbacks_enabled::Val{X} = Val(true)) where X + new{X, typeof(algs), typeof(choice_function)}(algs, choice_function) + end end +allowfallbacks(::CompositeAlgorithm{Fallbacks, T, F}) where {Fallbacks, T, F} = Fallbacks + TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1 if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) Base.Experimental.silence!(CompositeAlgorithm) end +mutable struct AutoSwitchCache{Trait, nAlg, sAlg, tolType, T} + algtrait::Trait + count::Int + successive_switches::Int + nonstiffalg::nAlg + stiffalg::sAlg + is_stiffalg::Bool + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int + current::Int +end + +struct AutoSwitch{Trait, nAlg, sAlg, tolType, T} + algtrait::Trait + nonstiffalg::nAlg + stiffalg::sAlg + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int +end + +struct DefaultODESolver end +const DefaultSolverAlgorithm = Union{CompositeAlgorithm{false, <:Tuple, <:AutoSwitch{DefaultODESolver}}, +CompositeAlgorithm{false, <:Tuple, <:AutoSwitchCache{DefaultODESolver}}} + ################################################################################ """ MEBDF2: Multistep Method diff --git a/src/caches/basic_caches.jl b/src/caches/basic_caches.jl index 4c627d7b71..658518a59c 100644 --- a/src/caches/basic_caches.jl +++ b/src/caches/basic_caches.jl @@ -4,7 +4,7 @@ abstract type OrdinaryDiffEqMutableCache <: OrdinaryDiffEqCache end struct ODEEmptyCache <: OrdinaryDiffEqConstantCache end struct ODEChunkCache{CS} <: OrdinaryDiffEqConstantCache end -mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache +mutable struct CompositeCache{Fallbacks, T, F} <: OrdinaryDiffEqCache caches::T choice_function::F current::Int @@ -16,11 +16,11 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) Base.Experimental.silence!(CompositeCache) end -function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype, +function alg_cache(alg::CompositeAlgorithm{Fallbacks, Tuple{T1, T2}, F}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits, + ::Val{V}) where {Fallbacks, T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} caches = (alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, @@ -28,16 +28,18 @@ function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype, alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V))) - CompositeCache(caches, alg.choice_function, 1) + CompositeCache{Fallbacks, typeof(caches), typeof(alg.choice_function)}( + caches, alg.choice_function, 1) end -function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits}, +function alg_cache(alg::CompositeAlgorithm{Fallbacks}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + ::Val{V}) where {Fallbacks, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} caches = __alg_cache(alg.algs, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)) - CompositeCache(caches, alg.choice_function, 1) + CompositeCache{Fallbacks, typeof(caches), typeof(alg.choice_function)}( + caches, alg.choice_function, 1) end # map + closure approach doesn't infer diff --git a/src/caches/verner_caches.jl b/src/caches/verner_caches.jl index a87bedcccd..6a95f6188f 100644 --- a/src/caches/verner_caches.jl +++ b/src/caches/verner_caches.jl @@ -20,6 +20,7 @@ stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern6Cache 1 @@ -44,11 +45,12 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern6Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, utilde, tmp, rtmp, atmp, tab, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern6ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -56,7 +58,7 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern6Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern6ConstantCache(tab) + Vern6ConstantCache(tab, alg.lazy) end @cache struct Vern7Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -81,6 +83,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern7Cache 1 @@ -105,16 +108,18 @@ function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern7Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, rtmp, atmp, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end -struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern7ConstantCache() + Vern7ConstantCache(alg.lazy) end @cache struct Vern8Cache{uType, rateType, uNoUnitsType, TabType, StageLimiter, StepLimiter, @@ -143,6 +148,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern8Cache 1 @@ -171,11 +177,12 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern8Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, utilde, - tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread) + tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern8ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -183,7 +190,7 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern8ConstantCache(tab) + Vern8ConstantCache(tab, alg.lazy) end @cache struct Vern9Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -214,6 +221,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern9Cache 1 @@ -245,14 +253,16 @@ function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern9Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, utilde, tmp, rtmp, atmp, alg.stage_limiter!, alg.step_limiter!, - alg.thread) + alg.thread, alg.lazy) end -struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern9ConstantCache() + Vern9ConstantCache(alg.lazy) end diff --git a/src/composite_algs.jl b/src/composite_algs.jl index a6474fa00b..e82e2e54ca 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -1,34 +1,12 @@ -mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T} - count::Int - successive_switches::Int - nonstiffalg::nAlg - stiffalg::sAlg - is_stiffalg::Bool - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end +### AutoSwitch +### Designed to switch between two solvers, stiff and non-stiff -struct AutoSwitch{nAlg, sAlg, tolType, T} - nonstiffalg::nAlg - stiffalg::sAlg - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end -function AutoSwitch(nonstiffalg, stiffalg; maxstiffstep = 10, maxnonstiffstep = 3, +function AutoSwitch(nonstiffalg, stiffalg, algtrait = nothing; + maxstiffstep = 10, maxnonstiffstep = 3, nonstifftol = 9 // 10, stifftol = 9 // 10, dtfac = 2, stiffalgfirst = false, switch_max = 5) - AutoSwitch(nonstiffalg, stiffalg, maxstiffstep, maxnonstiffstep, + AutoSwitch(algtrait, nonstiffalg, stiffalg, maxstiffstep, maxnonstiffstep, promote(nonstifftol, stifftol)..., dtfac, stiffalgfirst, switch_max) end @@ -52,7 +30,11 @@ function is_stiff(integrator, alg, ntol, stol, is_stiffalg) end function (AS::AutoSwitchCache)(integrator) - integrator.iter == 0 && return Int(AS.stiffalgfirst) + 1 + if AS.current == 0 + AS.current = Int(AS.stiffalgfirst) + 1 + return AS.current + end + dt = integrator.dt # Successive stiffness test positives are counted by a positive integer, # and successive stiffness test negatives are counted by a negative integer @@ -67,12 +49,18 @@ function (AS::AutoSwitchCache)(integrator) integrator.dt = dt / AS.dtfac AS.is_stiffalg = false end - return Int(AS.is_stiffalg) + 1 + AS.current = Int(AS.is_stiffalg) + 1 + return AS.current end -function AutoAlgSwitch(nonstiffalg, stiffalg; kwargs...) - AS = AutoSwitch(nonstiffalg, stiffalg; kwargs...) - CompositeAlgorithm((nonstiffalg, stiffalg), AS) +function AutoAlgSwitch(nonstiffalg::OrdinaryDiffEqAlgorithm, stiffalg::OrdinaryDiffEqAlgorithm, algtrait = nothing; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) + CompositeAlgorithm((nonstiffalg, stiffalg), AS, Val(false)) +end + +function AutoAlgSwitch(nonstiffalg::Tuple, stiffalg::Tuple, algtrait; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) + CompositeAlgorithm((nonstiffalg..., stiffalg...), AS, Val(false)) end AutoTsit5(alg; kwargs...) = AutoAlgSwitch(Tsit5(), alg; kwargs...) @@ -81,3 +69,114 @@ AutoVern6(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern6(lazy = lazy), alg; AutoVern7(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern7(lazy = lazy), alg; kwargs...) AutoVern8(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern8(lazy = lazy), alg; kwargs...) AutoVern9(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern9(lazy = lazy), alg; kwargs...) + +### Default ODE Solver + +EnumX.@enumx DefaultSolverChoice begin + Tsit5 = 1 + Vern7 = 2 + Rosnebrock23 = 3 + Rodas5P = 4 + FBDF = 5 + KrylovFBDF = 6 +end + +const NUM_NONSTIFF = 2 +const NUM_STIFF = 4 +const LOW_TOL = 1e-6 +const MED_TOL = 1e-2 +const EXTREME_TOL = 1e-9 +const SMALLSIZE = 50 +const MEDIUMSIZE = 500 +const STABILITY_SIZES = (alg_stability_size(Tsit5()), alg_stability_size(Vern7())) +const DEFAULTBETA2S = (beta2_default(Tsit5()), beta2_default(Vern7()), beta2_default(Rosenbrock23()), beta2_default(Rodas5P()), beta2_default(FBDF()), beta2_default(FBDF())) +const DEFAULTBETA1S = (beta1_default(Tsit5(),DEFAULTBETA2S[1]), beta1_default(Vern7(),DEFAULTBETA2S[2]), + beta1_default(Rosenbrock23(), DEFAULTBETA2S[3]), beta1_default(Rodas5P(), DEFAULTBETA2S[4]), + beta1_default(FBDF(), DEFAULTBETA2S[5]), beta1_default(FBDF(), DEFAULTBETA2S[6])) + +callbacks_exists(integrator) = !isempty(integrator.opts.callbacks) +current_nonstiff(current) = ifelse(current <= NUM_NONSTIFF,current,current-NUM_STIFF) + +function DefaultODEAlgorithm(; lazy = true, stiffalgfirst = false, kwargs...) + nonstiff = (Tsit5(), Vern7(lazy = lazy)) + stiff = (Rosenbrock23(;kwargs...), Rodas5P(;kwargs...), FBDF(;kwargs...), FBDF(;linsolve = LinearSolve.KrylovJL_GMRES())) + AutoAlgSwitch(nonstiff, stiff, DefaultODESolver(); stiffalgfirst) +end + +function is_stiff(integrator, alg, ntol, stol, is_stiffalg, current) + eigen_est, dt = integrator.eigen_est, integrator.dt + stiffness = abs(eigen_est * dt / STABILITY_SIZES[nonstiffchoice(integrator.opts.reltol)]) # `abs` here is just for safety + tol = is_stiffalg ? stol : ntol + os = oneunit(stiffness) + bool = stiffness > os * tol + + if !bool + integrator.alg.choice_function.successive_switches += 1 + else + integrator.alg.choice_function.successive_switches = 0 + end + + integrator.do_error_check = (integrator.alg.choice_function.successive_switches > + integrator.alg.choice_function.switch_max || !bool) || + is_stiffalg + bool +end + +function nonstiffchoice(reltol) + x = if reltol < LOW_TOL + DefaultSolverChoice.Vern7 + else + DefaultSolverChoice.Tsit5 + end + Int(x) +end + +function stiffchoice(reltol, len) + x = if len > MEDIUMSIZE + DefaultSolverChoice.KrylovFBDF + elseif len > SMALLSIZE + DefaultSolverChoice.FBDF + else + if reltol < LOW_TOL + DefaultSolverChoice.Rodas5P + else + DefaultSolverChoice.Rosnebrock23 + end + end + Int(x) +end + +function (AS::AutoSwitchCache{DefaultODESolver})(integrator) + + len = length(integrator.u) + reltol = integrator.opts.reltol + + # Chooose the starting method + if AS.current == 0 + choice = if AS.stiffalgfirst || integrator.f.mass_matrix != I + stiffchoice(reltol, len) + else + nonstiffchoice(reltol) + end + AS.current = choice + return AS.current + end + + dt = integrator.dt + # Successive stiffness test positives are counted by a positive integer, + # and successive stiffness test negatives are counted by a negative integer + AS.count = is_stiff(integrator, AS.nonstiffalg, AS.nonstifftol, AS.stifftol, + AS.is_stiffalg, AS.current) ? + AS.count < 0 ? 1 : AS.count + 1 : + AS.count > 0 ? -1 : AS.count - 1 + if (!AS.is_stiffalg && AS.count > AS.maxstiffstep) + integrator.dt = dt * AS.dtfac + AS.is_stiffalg = true + AS.current = stiffchoice(reltol, len) + elseif (AS.is_stiffalg && AS.count < -AS.maxnonstiffstep) + integrator.dt = dt / AS.dtfac + AS.is_stiffalg = false + AS.current = nonstiffchoice(reltol) + end + return AS.current +end diff --git a/src/perform_step/composite_perform_step.jl b/src/perform_step/composite_perform_step.jl index ae872fd521..d2bc4ee5a1 100644 --- a/src/perform_step/composite_perform_step.jl +++ b/src/perform_step/composite_perform_step.jl @@ -1,39 +1,4 @@ -#= - -Maybe do generated functions to reduce dispatch times? - -f(x) = x -g(x,i) = f(x[i]) -g{i}(x,::Type{Val{i}}) = f(x[i]) -@generated function gg(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(f(tup[i])) i->error("unreachable")) - end -h(i) = g((1,1.0,"foo"), i) -h2{i}(::Type{Val{i}}) = g((1,1.0,"foo"), Val{i}) -h3(i) = gg((1,1.0,"foo"), i) -@benchmark h(1) -mean time: 31.822 ns (0.00% GC) -@benchmark h2(Val{1}) -mean time: 1.585 ns (0.00% GC) -@benchmark h3(1) -mean time: 6.423 ns (0.00% GC) - -@generated function foo(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) -end - -@code_typed foo((1,1.0), 1) - -@generated function perform_step!(integrator, cache::CompositeCache, repeat_step=false) - N = length(cache.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) -end - -=# - -function initialize!(integrator, cache::CompositeCache) +function initialize!(integrator, cache::CompositeCache{Fallbacks}) where Fallbacks cache.current = cache.choice_function(integrator) if cache.current == 1 initialize!(integrator, @inbounds(cache.caches[1])) @@ -42,7 +7,27 @@ function initialize!(integrator, cache::CompositeCache) # the controller was initialized by default for integrator.alg.algs[1] reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[2]) - else + elseif cache.current == 3 + initialize!(integrator, @inbounds(cache.caches[3])) + # the controller was initialized by default for integrator.alg.algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], + integrator.alg.algs[3]) + elseif cache.current == 4 + initialize!(integrator, @inbounds(cache.caches[4])) + # the controller was initialized by default for integrator.alg.algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], + integrator.alg.algs[4]) + elseif cache.current == 5 + initialize!(integrator, @inbounds(cache.caches[5])) + # the controller was initialized by default for integrator.alg.algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], + integrator.alg.algs[5]) + elseif cache.current == 6 + initialize!(integrator, @inbounds(cache.caches[6])) + # the controller was initialized by default for integrator.alg.algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], + integrator.alg.algs[6]) + elseif Fallbacks initialize!(integrator, @inbounds(cache.caches[cache.current])) reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[cache.current]) @@ -50,7 +35,7 @@ function initialize!(integrator, cache::CompositeCache) resize!(integrator.k, integrator.kshortsize) end -function initialize!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}) where {T1, T2, F} +function initialize!(integrator, cache::CompositeCache{false, Tuple{T1, T2}, F}) where {T1, T2, F} cache.current = cache.choice_function(integrator) if cache.current == 1 initialize!(integrator, @inbounds(cache.caches[1])) @@ -75,17 +60,25 @@ function ensure_behaving_adaptivity!(integrator, cache::CompositeCache) end end -function perform_step!(integrator, cache::CompositeCache, repeat_step = false) +function perform_step!(integrator, cache::CompositeCache{Fallbacks}, repeat_step = false) where Fallbacks if cache.current == 1 perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) elseif cache.current == 2 perform_step!(integrator, @inbounds(cache.caches[2]), repeat_step) - else + elseif cache.current == 3 + perform_step!(integrator, @inbounds(cache.caches[3]), repeat_step) + elseif cache.current == 4 + perform_step!(integrator, @inbounds(cache.caches[4]), repeat_step) + elseif cache.current == 5 + perform_step!(integrator, @inbounds(cache.caches[5]), repeat_step) + elseif cache.current == 6 + perform_step!(integrator, @inbounds(cache.caches[6]), repeat_step) + elseif Fallbacks perform_step!(integrator, @inbounds(cache.caches[cache.current]), repeat_step) end end -function perform_step!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}, +function perform_step!(integrator, cache::CompositeCache{false, Tuple{T1, T2}, F}, repeat_step = false) where {T1, T2, F} if cache.current == 1 perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) @@ -97,7 +90,7 @@ end choose_algorithm!(integrator, cache::OrdinaryDiffEqCache) = nothing function choose_algorithm!(integrator, - cache::CompositeCache{Tuple{T1, T2}, F}) where {T1, T2, F} + cache::CompositeCache{false, Tuple{T1, T2}, F}) where {T1, T2, F} new_current = cache.choice_function(integrator) old_current = cache.current @inbounds if new_current != old_current @@ -121,36 +114,111 @@ function choose_algorithm!(integrator, end end -function choose_algorithm!(integrator, cache::CompositeCache) +function choose_algorithm!(integrator, cache::CompositeCache{Fallbacks, T, F}) where {Fallbacks, T, F} + new_current = cache.choice_function(integrator) + old_current = cache.current + !Fallbacks && error("Hitting fallbacks!") + @inbounds if new_current != old_current + cache.current = new_current + initialize!(integrator, @inbounds(cache.caches[new_current])) + + controller.beta2 = beta2_default(alg2) + controller.beta1 = beta2_default(alg2) + DEFAULTBETA2S + + reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], + integrator.alg.algs[new_current]) + transfer_cache!(integrator, integrator.cache.caches[old_current], + integrator.cache.caches[new_current]) + end +end + +function choose_algorithm!(integrator, cache::CompositeCache{Fallbacks, T, <:AutoSwitchCache{DefaultODESolver}}) where {Fallbacks, T} new_current = cache.choice_function(integrator) old_current = cache.current @inbounds if new_current != old_current cache.current = new_current if new_current == 1 - initialize!(integrator, @inbounds(cache.caches[1])) + initialize!(integrator, @inbounds(cache.caches[1])); nothing elseif new_current == 2 - initialize!(integrator, @inbounds(cache.caches[2])) + initialize!(integrator, @inbounds(cache.caches[2])); nothing + elseif new_current == 3 + initialize!(integrator, @inbounds(cache.caches[3])); nothing + elseif new_current == 4 + initialize!(integrator, @inbounds(cache.caches[4])); nothing + elseif new_current == 5 + initialize!(integrator, @inbounds(cache.caches[5])); nothing + elseif new_current == 6 + initialize!(integrator, @inbounds(cache.caches[6])); nothing else - initialize!(integrator, @inbounds(cache.caches[new_current])) + error("Unrachable reached. Report this error") end - if old_current == 1 && new_current == 2 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[1], - integrator.alg.algs[2]) - transfer_cache!(integrator, integrator.cache.caches[1], - integrator.cache.caches[2]) - elseif old_current == 2 && new_current == 1 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[2], - integrator.alg.algs[1]) - transfer_cache!(integrator, integrator.cache.caches[2], - integrator.cache.caches[1]) + + # dtchangable, qmin_default, qmax_default, and isadaptive ignored since all same + integrator.opts.controller.beta1 = DEFAULTBETA1S[new_current] + integrator.opts.controller.beta2 = DEFAULTBETA2S[new_current] + end +end + +#= +""" +function choose_algorithm!(integrator, cache::CompositeCache{Fallbacks}) where Fallbacks + new_current = cache.choice_function(integrator) + old_current = cache.current + @inbounds if new_current != old_current + cache.current = new_current + initialize!(integrator, @inbounds(cache.caches[new_current])) + reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], + integrator.alg.algs[new_current]) + transfer_cache!(integrator, integrator.cache.caches[old_current], + integrator.cache.caches[new_current]) + end +end +""" +@generated function choose_algorithm!(integrator, cache::CompositeCache{Fallbacks, T, F}) where {Fallbacks, T, F} + initialize_ex = :() + for idx in 1:length(T.types) + newex = quote + initialize!(integrator, @inbounds(cache.caches[$idx])) + end + initialize_ex = if initialize_ex == :() + Expr(:elseif, :(new_current == $idx)), newex, + :(error("Algorithm Choice not Allowed")) else - reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], - integrator.alg.algs[new_current]) - transfer_cache!(integrator, integrator.cache.caches[old_current], - integrator.cache.caches[new_current]) + Expr(:elseif, :(new_current == $idx), newex, initialize_ex) + end + end + initialize_ex = Expr(:if, initialize_ex.args...) + + swap_ex = :() + for idx in 1:length(T.types), idx2 in 1:length(T.types) + new_swap_ex = quote + reset_alg_dependent_opts!(integrator, integrator.alg.algs[$idx], + integrator.alg.algs[$idx2]) + transfer_cache!(integrator, integrator.cache.caches[$idx], + integrator.cache.caches[$idx2]) + end + swap_ex = if swap_ex == :() + Expr(:elseif, :(old_current == $idx && new_current == $idx2)), new_swap_ex, + :(error("Algorithm Choice not Allowed")) + else + Expr(:elseif, :(new_current == $idx), swap_ex, swap_ex) + end + end + swap_ex = Expr(:if, swap_ex.args...) + + + quote + new_current = cache.choice_function(integrator) + old_current = cache.current + @inbounds if new_current != old_current + cache.current = new_current + $initialize_ex + $swap_ex end end end +=# """ If no user default, then this will change the default to the defaults @@ -170,6 +238,7 @@ function reset_alg_dependent_opts!(integrator, alg1, alg2) integrator.opts.qmax == qmax_default(alg2) end reset_alg_dependent_opts!(integrator.opts.controller, alg1, alg2) + nothing end # Write how to transfer the cache variables from one cache to the other diff --git a/src/perform_step/verner_rk_perform_step.jl b/src/perform_step/verner_rk_perform_step.jl index 5f2d876354..283cf7b95a 100644 --- a/src/perform_step/verner_rk_perform_step.jl +++ b/src/perform_step/verner_rk_perform_step.jl @@ -2,7 +2,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal integrator.stats.nf += 1 alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -13,7 +13,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) end integrator.k[integrator.kshortsize] = integrator.fsallast - if !alg.lazy + if !cache.lazy @inbounds for i in 10:12 integrator.k[i] = zero(integrator.fsalfirst) end @@ -63,7 +63,7 @@ end integrator.k[9] = k9 alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -86,7 +86,7 @@ end function initialize!(integrator, cache::Vern6Cache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.fsalfirst = cache.k1 integrator.fsallast = cache.k9 @unpack k = integrator @@ -101,7 +101,7 @@ function initialize!(integrator, cache::Vern6Cache) k[8] = cache.k8 k[9] = cache.k9 # Set the pointers - if !alg.lazy + if !cache.lazy k[10] = similar(cache.k1) k[11] = similar(cache.k1) k[12] = similar(cache.k1) @@ -174,7 +174,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -206,7 +206,7 @@ end function initialize!(integrator, cache::Vern7ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -267,7 +267,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern7ExtraStages T T2 @@ -302,7 +302,7 @@ function initialize!(integrator, cache::Vern7Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -315,7 +315,7 @@ function initialize!(integrator, cache::Vern7Cache) k[9] = k9 k[10] = k10 # Setup pointers - if !alg.lazy + if !cache.lazy k[11] = similar(cache.k1) k[12] = similar(cache.k1) k[13] = similar(cache.k1) @@ -406,7 +406,7 @@ end integrator.EEst = integrator.opts.internalnorm(atmp, t) end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache @@ -462,7 +462,7 @@ end function initialize!(integrator, cache::Vern8ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -539,7 +539,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -587,7 +587,7 @@ function initialize!(integrator, cache::Vern8Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -603,7 +603,7 @@ function initialize!(integrator, cache::Vern8Cache) k[12] = k12 k[13] = k13 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 14:21 k[i] = similar(cache.k1) end @@ -709,7 +709,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -796,7 +796,7 @@ end function initialize!(integrator, cache::Vern9ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -880,7 +880,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern9ExtraStages T T2 @@ -940,7 +940,7 @@ function initialize!(integrator, cache::Vern9Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) resize!(k, integrator.kshortsize) # k2, k3,k4,k5,k6,k7 are not used in the code (not even in interpolations), we dont need their pointers. # So we mapped k[2] (from integrator) with k8 (from cache), k[3] with k9 and so on. @@ -955,7 +955,7 @@ function initialize!(integrator, cache::Vern9Cache) k[9] = k15 k[10] = k16 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 11:20 k[i] = similar(cache.k1) end @@ -1082,7 +1082,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache diff --git a/src/solve.jl b/src/solve.jl index c0a01ad126..7e037ea999 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,682 +1,682 @@ -function DiffEqBase.__solve(prob::Union{DiffEqBase.AbstractODEProblem, - DiffEqBase.AbstractDAEProblem}, - alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, args...; - kwargs...) - integrator = DiffEqBase.__init(prob, alg, args...; kwargs...) - solve!(integrator) - integrator.sol -end - -function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem, - DiffEqBase.AbstractDAEProblem}, - alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, - timeseries_init = (), - ts_init = (), - ks_init = (), - recompile::Type{Val{recompile_flag}} = Val{true}; - saveat = (), - tstops = (), - d_discontinuities = (), - save_idxs = nothing, - save_everystep = isempty(saveat), - save_on = true, - save_start = save_everystep || isempty(saveat) || - saveat isa Number || prob.tspan[1] in saveat, - save_end = nothing, - callback = nothing, - dense = save_everystep && - !(alg isa Union{DAEAlgorithm, FunctionMap}) && - isempty(saveat), - calck = (callback !== nothing && callback !== CallbackSet()) || - (dense) || !isempty(saveat), # and no dense output - dt = alg isa FunctionMap && isempty(tstops) ? - eltype(prob.tspan)(1) : eltype(prob.tspan)(0), - dtmin = eltype(prob.tspan)(0), - dtmax = eltype(prob.tspan)((prob.tspan[end] - prob.tspan[1])), - force_dtmin = false, - adaptive = anyadaptive(alg), - gamma = gamma_default(alg), - abstol = nothing, - reltol = nothing, - qmin = qmin_default(alg), - qmax = qmax_default(alg), - qsteady_min = qsteady_min_default(alg), - qsteady_max = qsteady_max_default(alg), - beta1 = nothing, - beta2 = nothing, - qoldinit = anyadaptive(alg) ? 1 // 10^4 : 0, - controller = nothing, - fullnormalize = true, - failfactor = 2, - maxiters = anyadaptive(alg) ? 1000000 : typemax(Int), - internalnorm = ODE_DEFAULT_NORM, - internalopnorm = LinearAlgebra.opnorm, - isoutofdomain = ODE_DEFAULT_ISOUTOFDOMAIN, - unstable_check = ODE_DEFAULT_UNSTABLE_CHECK, - verbose = true, - timeseries_errors = true, - dense_errors = false, - advance_to_tstop = false, - stop_at_next_tstop = false, - initialize_save = true, - progress = false, - progress_steps = 1000, - progress_name = "ODE", - progress_message = ODE_DEFAULT_PROG_MESSAGE, - progress_id = gensym("OrdinaryDiffEq"), - userdata = nothing, - allow_extrapolation = alg_extrapolates(alg), - initialize_integrator = true, - alias_u0 = false, - alias_du0 = false, - initializealg = DefaultInit(), - kwargs...) where {recompile_flag} - if prob isa DiffEqBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm - error("You cannot use an ODE Algorithm with a DAEProblem") - end - - if prob isa DiffEqBase.AbstractODEProblem && alg isa DAEAlgorithm - error("You cannot use an DAE Algorithm with a ODEProblem") - end - - if prob isa DiffEqBase.ODEProblem - if !(prob.f isa DiffEqBase.DynamicalODEFunction) && alg isa PartitionedAlgorithm - error("You can not use a solver designed for partitioned ODE with this problem. Please choose a solver suitable for your problem") - end - end - - if prob.f isa DynamicalODEFunction && prob.f.mass_matrix isa Tuple - if any(mm != I for mm in prob.f.mass_matrix) - error("This solver is not able to use mass matrices.") - end - elseif !(prob isa DiscreteProblem) && - !(prob isa DiffEqBase.AbstractDAEProblem) && - !is_mass_matrix_alg(alg) && - prob.f.mass_matrix != I - error("This solver is not able to use mass matrices.") - end - - if alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm && - prob.f.mass_matrix isa AbstractMatrix && - all(isequal(0), prob.f.mass_matrix) - # technically this should also warn for zero operators but those are hard to check for - alg isa Union{Rosenbrock23, Rosenbrock32} && error("Rosenbrock23 and Rosenbrock32 require at least one differential variable to produce valid solutions") - if (dense || !isempty(saveat)) && verbose - @warn("Rosenbrock methods on equations without differential states do not bound the error on interpolations.") - end - end - - if !isempty(saveat) && dense - @warn("Dense output is incompatible with saveat. Please use the SavingCallback from the Callback Library to mix the two behaviors.") - end - - progress && @logmsg(LogLevel(-1), progress_name, _id=progress_id, progress=0) - - tType = eltype(prob.tspan) - tspan = prob.tspan - tdir = sign(tspan[end] - tspan[1]) - - t = tspan[1] - - if (((!(alg isa OrdinaryDiffEqAdaptiveAlgorithm) && - !(alg isa OrdinaryDiffEqCompositeAlgorithm) && - !(alg isa DAEAlgorithm)) || !adaptive || !isadaptive(alg)) && - dt == tType(0) && isempty(tstops)) && - !(alg isa Union{FunctionMap, LinearExponential}) - error("Fixed timestep methods require a choice of dt or choosing the tstops") - end - - isdae = alg isa DAEAlgorithm || (!(prob isa DiscreteProblem) && - prob.f.mass_matrix != I && - !(prob.f.mass_matrix isa Tuple) && - ArrayInterface.issingular(prob.f.mass_matrix)) - if alg isa CompositeAlgorithm && alg.choice_function isa AutoSwitch - auto = alg.choice_function - _alg = CompositeAlgorithm(alg.algs, - AutoSwitchCache(0, 0, - auto.nonstiffalg, - auto.stiffalg, - auto.stiffalgfirst, - auto.maxstiffstep, - auto.maxnonstiffstep, - auto.nonstifftol, - auto.stifftol, - auto.dtfac, - auto.stiffalgfirst, - auto.switch_max)) - else - _alg = alg - end - f = prob.f - p = prob.p - - # Get the control variables - - if alias_u0 - u = prob.u0 - else - u = recursivecopy(prob.u0) - end - - if _alg isa DAEAlgorithm - if alias_du0 - du = prob.du0 - else - du = recursivecopy(prob.du0) - end - duprev = recursivecopy(du) - else - du = nothing - duprev = nothing - end - - uType = typeof(u) - uBottomEltype = recursive_bottom_eltype(u) - uBottomEltypeNoUnits = recursive_unitless_bottom_eltype(u) - - uEltypeNoUnits = recursive_unitless_eltype(u) - tTypeNoUnits = typeof(one(tType)) - - if _alg isa FunctionMap - abstol_internal = false - elseif abstol === nothing - if uBottomEltypeNoUnits == uBottomEltype - abstol_internal = ForwardDiff.value(real(convert(uBottomEltype, - oneunit(uBottomEltype) * - 1 // 10^6))) - else - abstol_internal = ForwardDiff.value.(real.(oneunit.(u) .* 1 // 10^6)) - end - else - abstol_internal = real.(abstol) - end - - if _alg isa FunctionMap - reltol_internal = false - elseif reltol === nothing - if uBottomEltypeNoUnits == uBottomEltype - reltol_internal = real(convert(uBottomEltype, - oneunit(uBottomEltype) * 1 // 10^3)) - else - reltol_internal = real.(oneunit.(u) .* 1 // 10^3) - end - else - reltol_internal = real.(reltol) - end - - dtmax > zero(dtmax) && tdir < 0 && (dtmax *= tdir) # Allow positive dtmax, but auto-convert - # dtmin is all abs => does not care about sign already. - - if !isdae && isinplace(prob) && u isa AbstractArray && eltype(u) <: Number && - uBottomEltypeNoUnits == uBottomEltype && tType == tTypeNoUnits # Could this be more efficient for other arrays? - rate_prototype = recursivecopy(u) - elseif prob isa DAEProblem - rate_prototype = prob.du0 - else - if (uBottomEltypeNoUnits == uBottomEltype && tType == tTypeNoUnits) || - eltype(u) <: Enum - rate_prototype = u - else # has units! - rate_prototype = u / oneunit(tType) - end - end - rateType = typeof(rate_prototype) ## Can be different if united - - if isdae - if uBottomEltype == uBottomEltypeNoUnits - res_prototype = u - else - res_prototype = one(u) - end - resType = typeof(res_prototype) - end - - tstops_internal = initialize_tstops(tType, tstops, d_discontinuities, tspan) - saveat_internal = initialize_saveat(tType, saveat, tspan) - d_discontinuities_internal = initialize_d_discontinuities(tType, d_discontinuities, - tspan) - - callbacks_internal = CallbackSet(callback) - - max_len_cb = DiffEqBase.max_vector_callback_length_int(callbacks_internal) - if max_len_cb !== nothing - uBottomEltypeReal = real(uBottomEltype) - if isinplace(prob) - callback_cache = DiffEqBase.CallbackCache(u, max_len_cb, uBottomEltypeReal, - uBottomEltypeReal) - else - callback_cache = DiffEqBase.CallbackCache(max_len_cb, uBottomEltypeReal, - uBottomEltypeReal) - end - else - callback_cache = nothing - end - - ### Algorithm-specific defaults ### - if save_idxs === nothing - ksEltype = Vector{rateType} - else - ks_prototype = rate_prototype[save_idxs] - ksEltype = Vector{typeof(ks_prototype)} - end - - # Have to convert in case passed in wrong. - if save_idxs === nothing - timeseries = timeseries_init === () ? uType[] : - convert(Vector{uType}, timeseries_init) - else - u_initial = u[save_idxs] - timeseries = timeseries_init === () ? typeof(u_initial)[] : - convert(Vector{uType}, timeseries_init) - end - - ts = ts_init === () ? tType[] : convert(Vector{tType}, ts_init) - ks = ks_init === () ? ksEltype[] : convert(Vector{ksEltype}, ks_init) - alg_choice = _alg isa CompositeAlgorithm ? Int[] : () - - if (!adaptive || !isadaptive(_alg)) && save_everystep && tspan[2] - tspan[1] != Inf - if dt == 0 - steps = length(tstops) - else - # For fixed dt, the only time dtmin makes sense is if it's smaller than eps(). - # Therefore user specified dtmin doesn't matter, but we need to ensure dt>=eps() - # to prevent infinite loops. - abs(dt) < dtmin && - throw(ArgumentError("Supplied dt is smaller than dtmin")) - steps = ceil(Int, internalnorm((tspan[2] - tspan[1]) / dt, tspan[1])) - end - sizehint!(timeseries, steps + 1) - sizehint!(ts, steps + 1) - sizehint!(ks, steps + 1) - elseif save_everystep - sizehint!(timeseries, 50) - sizehint!(ts, 50) - sizehint!(ks, 50) - elseif !isempty(saveat_internal) - savelength = length(saveat_internal) + 1 - if save_start == false - savelength -= 1 - end - if save_end == false && prob.tspan[2] in saveat_internal.valtree - savelength -= 1 - end - sizehint!(timeseries, savelength) - sizehint!(ts, savelength) - sizehint!(ks, savelength) - else - sizehint!(timeseries, 2) - sizehint!(ts, 2) - sizehint!(ks, 2) - end - - QT, EEstT = if tTypeNoUnits <: Integer - typeof(qmin), typeof(qmin) - elseif prob isa DiscreteProblem - # The QT fields are not used for DiscreteProblems - constvalue(tTypeNoUnits), constvalue(tTypeNoUnits) - else - typeof(DiffEqBase.value(internalnorm(u, t))), typeof(internalnorm(u, t)) - end - - k = rateType[] - - if uses_uprev(_alg, adaptive) || calck - uprev = recursivecopy(u) - else - # Some algorithms do not use `uprev` explicitly. In that case, we can save - # some memory by aliasing `uprev = u`, e.g. for "2N" low storage methods. - uprev = u - end - if allow_extrapolation - uprev2 = recursivecopy(u) - else - uprev2 = uprev - end - - if prob isa DAEProblem - cache = alg_cache(_alg, du, u, res_prototype, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, - reltol_internal, p, calck, Val(isinplace(prob))) - else - cache = alg_cache(_alg, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, - tTypeNoUnits, uprev, uprev2, f, t, dt, reltol_internal, p, calck, - Val(isinplace(prob))) - end - - # Setting up the step size controller - if (beta1 !== nothing || beta2 !== nothing) && controller !== nothing - throw(ArgumentError("Setting both the legacy PID parameters `beta1, beta2 = $((beta1, beta2))` and the `controller = $controller` is not allowed.")) - end - - if (beta1 !== nothing || beta2 !== nothing) - message = "Providing the legacy PID parameters `beta1, beta2` is deprecated. Use the keyword argument `controller` instead." - Base.depwarn(message, :init) - Base.depwarn(message, :solve) - end - - if controller === nothing - controller = default_controller(_alg, cache, qoldinit, beta1, beta2) - end - - save_end_user = save_end - save_end = save_end === nothing ? - save_everystep || isempty(saveat) || saveat isa Number || - prob.tspan[2] in saveat : save_end - - opts = DEOptions{typeof(abstol_internal), typeof(reltol_internal), - QT, tType, typeof(controller), - typeof(internalnorm), typeof(internalopnorm), - typeof(save_end_user), - typeof(callbacks_internal), - typeof(isoutofdomain), - typeof(progress_message), typeof(unstable_check), - typeof(tstops_internal), - typeof(d_discontinuities_internal), typeof(userdata), - typeof(save_idxs), - typeof(maxiters), typeof(tstops), - typeof(saveat), typeof(d_discontinuities)}(maxiters, save_everystep, - adaptive, abstol_internal, - reltol_internal, - QT(gamma), QT(qmax), - QT(qmin), - QT(qsteady_max), - QT(qsteady_min), - QT(qoldinit), - QT(failfactor), - tType(dtmax), tType(dtmin), - controller, - internalnorm, - internalopnorm, - save_idxs, tstops_internal, - saveat_internal, - d_discontinuities_internal, - tstops, saveat, - d_discontinuities, - userdata, progress, - progress_steps, - progress_name, - progress_message, - progress_id, - timeseries_errors, - dense_errors, dense, - save_on, save_start, - save_end, save_end_user, - callbacks_internal, - isoutofdomain, - unstable_check, - verbose, calck, force_dtmin, - advance_to_tstop, - stop_at_next_tstop) - - stats = SciMLBase.DEStats(0) - differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u) - - if _alg isa OrdinaryDiffEqCompositeAlgorithm - id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars) - sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, - alg_choice = alg_choice, - calculate_error = false, stats = stats) - else - id = InterpolationData(f, timeseries, ts, ks, dense, cache, differential_vars) - sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, - calculate_error = false, stats = stats) - end - - if recompile_flag == true - FType = typeof(f) - SolType = typeof(sol) - cacheType = typeof(cache) - else - FType = Function - if _alg isa OrdinaryDiffEqAlgorithm - SolType = DiffEqBase.AbstractODESolution - cacheType = OrdinaryDiffEqCache - else - SolType = DiffEqBase.AbstractDAESolution - cacheType = DAECache - end - end - - # rate/state = (state/time)/state = 1/t units, internalnorm drops units - # we don't want to differentiate through eigenvalue estimation - eigen_est = inv(one(tType)) - tprev = t - dtcache = tType(dt) - dtpropose = tType(dt) - iter = 0 - kshortsize = 0 - reeval_fsal = false - u_modified = false - EEst = EEstT(1) - just_hit_tstop = false - isout = false - accept_step = false - force_stepfail = false - last_stepfail = false - do_error_check = true - event_last_time = 0 - vector_event_last_time = 1 - last_event_error = _alg isa FunctionMap ? false : zero(uBottomEltypeNoUnits) - dtchangeable = isdtchangeable(_alg) - q11 = QT(1) - success_iter = 0 - erracc = QT(1) - dtacc = tType(1) - reinitiailize = true - saveiter = 0 # Starts at 0 so first save is at 1 - saveiter_dense = 0 - - integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du), - tType, typeof(p), - typeof(eigen_est), typeof(EEst), - QT, typeof(tdir), typeof(k), SolType, - FType, cacheType, - typeof(opts), fsal_typeof(_alg, rate_prototype), - typeof(last_event_error), typeof(callback_cache), - typeof(initializealg), typeof(differential_vars)}(sol, u, du, k, t, tType(dt), f, p, - uprev, uprev2, duprev, tprev, - _alg, dtcache, dtchangeable, - dtpropose, tdir, eigen_est, EEst, - QT(qoldinit), q11, - erracc, dtacc, success_iter, - iter, saveiter, saveiter_dense, cache, - callback_cache, - kshortsize, force_stepfail, - last_stepfail, - just_hit_tstop, do_error_check, - event_last_time, - vector_event_last_time, - last_event_error, accept_step, - isout, reeval_fsal, - u_modified, reinitiailize, isdae, - opts, stats, initializealg, differential_vars) - - if initialize_integrator - if isdae - DiffEqBase.initialize_dae!(integrator) - end - - if save_start - integrator.saveiter += 1 # Starts at 1 so first save is at 2 - integrator.saveiter_dense += 1 - copyat_or_push!(ts, 1, t) - if save_idxs === nothing - copyat_or_push!(timeseries, 1, integrator.u) - copyat_or_push!(ks, 1, [rate_prototype]) - else - copyat_or_push!(timeseries, 1, u_initial, Val{false}) - copyat_or_push!(ks, 1, [ks_prototype]) - end - else - integrator.saveiter = 0 # Starts at 0 so first save is at 1 - integrator.saveiter_dense = 0 - end - - initialize_callbacks!(integrator, initialize_save) - initialize!(integrator, integrator.cache) - - if _alg isa OrdinaryDiffEqCompositeAlgorithm - # in case user mixes adaptive and non-adaptive algorithms - ensure_behaving_adaptivity!(integrator, integrator.cache) - - if save_start - # Loop to get all of the extra possible saves in callback initialization - for i in 1:(integrator.saveiter) - copyat_or_push!(alg_choice, i, integrator.cache.current) - end - end - end - end - - handle_dt!(integrator) - integrator -end - -function DiffEqBase.solve!(integrator::ODEIntegrator) - @inbounds while !isempty(integrator.opts.tstops) - while integrator.tdir * integrator.t < first(integrator.opts.tstops) - loopheader!(integrator) - if integrator.do_error_check && check_error!(integrator) != ReturnCode.Success - return integrator.sol - end - perform_step!(integrator, integrator.cache) - loopfooter!(integrator) - if isempty(integrator.opts.tstops) - break - end - end - handle_tstop!(integrator) - end - postamble!(integrator) - - f = integrator.sol.prob.f - - if DiffEqBase.has_analytic(f) - DiffEqBase.calculate_solution_errors!(integrator.sol; - timeseries_errors = integrator.opts.timeseries_errors, - dense_errors = integrator.opts.dense_errors) - end - if integrator.sol.retcode != ReturnCode.Default - return integrator.sol - end - integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, ReturnCode.Success) -end - -# Helpers - -function handle_dt!(integrator) - if iszero(integrator.dt) && integrator.opts.adaptive - auto_dt_reset!(integrator) - if sign(integrator.dt) != integrator.tdir && !iszero(integrator.dt) && - !isnan(integrator.dt) - error("Automatic dt setting has the wrong sign. Exiting. Please report this error.") - end - if isnan(integrator.dt) - if integrator.opts.verbose - @warn("Automatic dt set the starting dt as NaN, causing instability. Exiting.") - end - end - elseif integrator.opts.adaptive && integrator.dt > zero(integrator.dt) && - integrator.tdir < 0 - integrator.dt *= integrator.tdir # Allow positive dt, but auto-convert - end -end - -# time stops -@inline function initialize_tstops(::Type{T}, tstops, d_discontinuities, tspan) where {T} - tstops_internal = BinaryHeap{T}(DataStructures.FasterForward()) - - t0, tf = tspan - tdir = sign(tf - t0) - tdir_t0 = tdir * t0 - tdir_tf = tdir * tf - - for t in tstops - tdir_t = tdir * t - tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) - end - for t in d_discontinuities - tdir_t = tdir * t - tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) - end - push!(tstops_internal, tdir_tf) - - return tstops_internal -end - -# saving time points -function initialize_saveat(::Type{T}, saveat, tspan) where {T} - saveat_internal = BinaryHeap{T}(DataStructures.FasterForward()) - - t0, tf = tspan - tdir = sign(tf - t0) - tdir_t0 = tdir * t0 - tdir_tf = tdir * tf - - if saveat isa Number - directional_saveat = tdir * abs(saveat) - for t in (t0 + directional_saveat):directional_saveat:tf - push!(saveat_internal, tdir * t) - end - elseif !isempty(saveat) - for t in saveat - tdir_t = tdir * t - tdir_t0 < tdir_t ≤ tdir_tf && push!(saveat_internal, tdir_t) - end - end - - return saveat_internal -end - -# discontinuities -function initialize_d_discontinuities(::Type{T}, d_discontinuities, tspan) where {T} - d_discontinuities_internal = BinaryHeap{T}(DataStructures.FasterForward()) - sizehint!(d_discontinuities_internal, length(d_discontinuities)) - - t0, tf = tspan - tdir = sign(tf - t0) - - for t in d_discontinuities - push!(d_discontinuities_internal, tdir * t) - end - - return d_discontinuities_internal -end - -function initialize_callbacks!(integrator, initialize_save = true) - t = integrator.t - u = integrator.u - callbacks = integrator.opts.callback - integrator.u_modified = true - - u_modified = initialize!(callbacks, u, t, integrator) - - # if the user modifies u, we need to fix previous values before initializing - # FSAL in order for the starting derivatives to be correct - if u_modified - if isinplace(integrator.sol.prob) - recursivecopy!(integrator.uprev, integrator.u) - else - integrator.uprev = integrator.u - end - - if alg_extrapolates(integrator.alg) - if isinplace(integrator.sol.prob) - recursivecopy!(integrator.uprev2, integrator.uprev) - else - integrator.uprev2 = integrator.uprev - end - end - - if initialize_save && - (any((c) -> c.save_positions[2], callbacks.discrete_callbacks) || - any((c) -> c.save_positions[2], callbacks.continuous_callbacks)) - savevalues!(integrator, true) - end - end - - # reset this as it is now handled so the integrators should proceed as normal - integrator.u_modified = false -end +function DiffEqBase.__solve(prob::Union{DiffEqBase.AbstractODEProblem, + DiffEqBase.AbstractDAEProblem}, + alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, args...; + kwargs...) + integrator = DiffEqBase.__init(prob, alg, args...; kwargs...) + solve!(integrator) + integrator.sol +end + +function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem, + DiffEqBase.AbstractDAEProblem}, + alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, + timeseries_init = (), + ts_init = (), + ks_init = (), + recompile::Type{Val{recompile_flag}} = Val{true}; + saveat = (), + tstops = (), + d_discontinuities = (), + save_idxs = nothing, + save_everystep = isempty(saveat), + save_on = true, + save_start = save_everystep || isempty(saveat) || + saveat isa Number || prob.tspan[1] in saveat, + save_end = nothing, + callback = nothing, + dense = save_everystep && + !(alg isa Union{DAEAlgorithm, FunctionMap}) && + isempty(saveat), + calck = (callback !== nothing && callback !== CallbackSet()) || + (dense) || !isempty(saveat), # and no dense output + dt = alg isa FunctionMap && isempty(tstops) ? + eltype(prob.tspan)(1) : eltype(prob.tspan)(0), + dtmin = eltype(prob.tspan)(0), + dtmax = eltype(prob.tspan)((prob.tspan[end] - prob.tspan[1])), + force_dtmin = false, + adaptive = anyadaptive(alg), + gamma = gamma_default(alg), + abstol = nothing, + reltol = nothing, + qmin = qmin_default(alg), + qmax = qmax_default(alg), + qsteady_min = qsteady_min_default(alg), + qsteady_max = qsteady_max_default(alg), + beta1 = nothing, + beta2 = nothing, + qoldinit = anyadaptive(alg) ? 1 // 10^4 : 0, + controller = nothing, + fullnormalize = true, + failfactor = 2, + maxiters = anyadaptive(alg) ? 1000000 : typemax(Int), + internalnorm = ODE_DEFAULT_NORM, + internalopnorm = LinearAlgebra.opnorm, + isoutofdomain = ODE_DEFAULT_ISOUTOFDOMAIN, + unstable_check = ODE_DEFAULT_UNSTABLE_CHECK, + verbose = true, + timeseries_errors = true, + dense_errors = false, + advance_to_tstop = false, + stop_at_next_tstop = false, + initialize_save = true, + progress = false, + progress_steps = 1000, + progress_name = "ODE", + progress_message = ODE_DEFAULT_PROG_MESSAGE, + progress_id = gensym("OrdinaryDiffEq"), + userdata = nothing, + allow_extrapolation = alg_extrapolates(alg), + initialize_integrator = true, + alias_u0 = false, + alias_du0 = false, + initializealg = DefaultInit(), + kwargs...) where {recompile_flag} + if prob isa DiffEqBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm + error("You cannot use an ODE Algorithm with a DAEProblem") + end + + if prob isa DiffEqBase.AbstractODEProblem && alg isa DAEAlgorithm + error("You cannot use an DAE Algorithm with a ODEProblem") + end + + if prob isa DiffEqBase.ODEProblem + if !(prob.f isa DiffEqBase.DynamicalODEFunction) && alg isa PartitionedAlgorithm + error("You can not use a solver designed for partitioned ODE with this problem. Please choose a solver suitable for your problem") + end + end + + if prob.f isa DynamicalODEFunction && prob.f.mass_matrix isa Tuple + if any(mm != I for mm in prob.f.mass_matrix) + error("This solver is not able to use mass matrices.") + end + elseif !(prob isa DiscreteProblem) && + !(prob isa DiffEqBase.AbstractDAEProblem) && + !is_mass_matrix_alg(alg) && + prob.f.mass_matrix != I + error("This solver is not able to use mass matrices.") + end + + if alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm && + prob.f.mass_matrix isa AbstractMatrix && + all(isequal(0), prob.f.mass_matrix) + # technically this should also warn for zero operators but those are hard to check for + alg isa Union{Rosenbrock23, Rosenbrock32} && error("Rosenbrock23 and Rosenbrock32 require at least one differential variable to produce valid solutions") + if (dense || !isempty(saveat)) && verbose + @warn("Rosenbrock methods on equations without differential states do not bound the error on interpolations.") + end + end + + if !isempty(saveat) && dense + @warn("Dense output is incompatible with saveat. Please use the SavingCallback from the Callback Library to mix the two behaviors.") + end + + progress && @logmsg(LogLevel(-1), progress_name, _id=progress_id, progress=0) + + tType = eltype(prob.tspan) + tspan = prob.tspan + tdir = sign(tspan[end] - tspan[1]) + + t = tspan[1] + + if (((!(alg isa OrdinaryDiffEqAdaptiveAlgorithm) && + !(alg isa OrdinaryDiffEqCompositeAlgorithm) && + !(alg isa DAEAlgorithm)) || !adaptive || !isadaptive(alg)) && + dt == tType(0) && isempty(tstops)) && + !(alg isa Union{FunctionMap, LinearExponential}) + error("Fixed timestep methods require a choice of dt or choosing the tstops") + end + + isdae = alg isa DAEAlgorithm || (!(prob isa DiscreteProblem) && + prob.f.mass_matrix != I && + !(prob.f.mass_matrix isa Tuple) && + ArrayInterface.issingular(prob.f.mass_matrix)) + if alg isa CompositeAlgorithm && alg.choice_function isa AutoSwitch + auto = alg.choice_function + _alg = CompositeAlgorithm(alg.algs, + AutoSwitchCache(auto.algtrait, 0, 0, + auto.nonstiffalg, + auto.stiffalg, + auto.stiffalgfirst, + auto.maxstiffstep, + auto.maxnonstiffstep, + auto.nonstifftol, + auto.stifftol, + auto.dtfac, + auto.stiffalgfirst, + auto.switch_max, 0), Val(allowfallbacks(alg))) + else + _alg = alg + end + f = prob.f + p = prob.p + + # Get the control variables + + if alias_u0 + u = prob.u0 + else + u = recursivecopy(prob.u0) + end + + if _alg isa DAEAlgorithm + if alias_du0 + du = prob.du0 + else + du = recursivecopy(prob.du0) + end + duprev = recursivecopy(du) + else + du = nothing + duprev = nothing + end + + uType = typeof(u) + uBottomEltype = recursive_bottom_eltype(u) + uBottomEltypeNoUnits = recursive_unitless_bottom_eltype(u) + + uEltypeNoUnits = recursive_unitless_eltype(u) + tTypeNoUnits = typeof(one(tType)) + + if _alg isa FunctionMap + abstol_internal = false + elseif abstol === nothing + if uBottomEltypeNoUnits == uBottomEltype + abstol_internal = ForwardDiff.value(real(convert(uBottomEltype, + oneunit(uBottomEltype) * + 1 // 10^6))) + else + abstol_internal = ForwardDiff.value.(real.(oneunit.(u) .* 1 // 10^6)) + end + else + abstol_internal = real.(abstol) + end + + if _alg isa FunctionMap + reltol_internal = false + elseif reltol === nothing + if uBottomEltypeNoUnits == uBottomEltype + reltol_internal = real(convert(uBottomEltype, + oneunit(uBottomEltype) * 1 // 10^3)) + else + reltol_internal = real.(oneunit.(u) .* 1 // 10^3) + end + else + reltol_internal = real.(reltol) + end + + dtmax > zero(dtmax) && tdir < 0 && (dtmax *= tdir) # Allow positive dtmax, but auto-convert + # dtmin is all abs => does not care about sign already. + + if !isdae && isinplace(prob) && u isa AbstractArray && eltype(u) <: Number && + uBottomEltypeNoUnits == uBottomEltype && tType == tTypeNoUnits # Could this be more efficient for other arrays? + rate_prototype = recursivecopy(u) + elseif prob isa DAEProblem + rate_prototype = prob.du0 + else + if (uBottomEltypeNoUnits == uBottomEltype && tType == tTypeNoUnits) || + eltype(u) <: Enum + rate_prototype = u + else # has units! + rate_prototype = u / oneunit(tType) + end + end + rateType = typeof(rate_prototype) ## Can be different if united + + if isdae + if uBottomEltype == uBottomEltypeNoUnits + res_prototype = u + else + res_prototype = one(u) + end + resType = typeof(res_prototype) + end + + tstops_internal = initialize_tstops(tType, tstops, d_discontinuities, tspan) + saveat_internal = initialize_saveat(tType, saveat, tspan) + d_discontinuities_internal = initialize_d_discontinuities(tType, d_discontinuities, + tspan) + + callbacks_internal = CallbackSet(callback) + + max_len_cb = DiffEqBase.max_vector_callback_length_int(callbacks_internal) + if max_len_cb !== nothing + uBottomEltypeReal = real(uBottomEltype) + if isinplace(prob) + callback_cache = DiffEqBase.CallbackCache(u, max_len_cb, uBottomEltypeReal, + uBottomEltypeReal) + else + callback_cache = DiffEqBase.CallbackCache(max_len_cb, uBottomEltypeReal, + uBottomEltypeReal) + end + else + callback_cache = nothing + end + + ### Algorithm-specific defaults ### + if save_idxs === nothing + ksEltype = Vector{rateType} + else + ks_prototype = rate_prototype[save_idxs] + ksEltype = Vector{typeof(ks_prototype)} + end + + # Have to convert in case passed in wrong. + if save_idxs === nothing + timeseries = timeseries_init === () ? uType[] : + convert(Vector{uType}, timeseries_init) + else + u_initial = u[save_idxs] + timeseries = timeseries_init === () ? typeof(u_initial)[] : + convert(Vector{uType}, timeseries_init) + end + + ts = ts_init === () ? tType[] : convert(Vector{tType}, ts_init) + ks = ks_init === () ? ksEltype[] : convert(Vector{ksEltype}, ks_init) + alg_choice = _alg isa CompositeAlgorithm ? Int[] : () + + if (!adaptive || !isadaptive(_alg)) && save_everystep && tspan[2] - tspan[1] != Inf + if dt == 0 + steps = length(tstops) + else + # For fixed dt, the only time dtmin makes sense is if it's smaller than eps(). + # Therefore user specified dtmin doesn't matter, but we need to ensure dt>=eps() + # to prevent infinite loops. + abs(dt) < dtmin && + throw(ArgumentError("Supplied dt is smaller than dtmin")) + steps = ceil(Int, internalnorm((tspan[2] - tspan[1]) / dt, tspan[1])) + end + sizehint!(timeseries, steps + 1) + sizehint!(ts, steps + 1) + sizehint!(ks, steps + 1) + elseif save_everystep + sizehint!(timeseries, 50) + sizehint!(ts, 50) + sizehint!(ks, 50) + elseif !isempty(saveat_internal) + savelength = length(saveat_internal) + 1 + if save_start == false + savelength -= 1 + end + if save_end == false && prob.tspan[2] in saveat_internal.valtree + savelength -= 1 + end + sizehint!(timeseries, savelength) + sizehint!(ts, savelength) + sizehint!(ks, savelength) + else + sizehint!(timeseries, 2) + sizehint!(ts, 2) + sizehint!(ks, 2) + end + + QT, EEstT = if tTypeNoUnits <: Integer + typeof(qmin), typeof(qmin) + elseif prob isa DiscreteProblem + # The QT fields are not used for DiscreteProblems + constvalue(tTypeNoUnits), constvalue(tTypeNoUnits) + else + typeof(DiffEqBase.value(internalnorm(u, t))), typeof(internalnorm(u, t)) + end + + k = rateType[] + + if uses_uprev(_alg, adaptive) || calck + uprev = recursivecopy(u) + else + # Some algorithms do not use `uprev` explicitly. In that case, we can save + # some memory by aliasing `uprev = u`, e.g. for "2N" low storage methods. + uprev = u + end + if allow_extrapolation + uprev2 = recursivecopy(u) + else + uprev2 = uprev + end + + if prob isa DAEProblem + cache = alg_cache(_alg, du, u, res_prototype, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, + reltol_internal, p, calck, Val(isinplace(prob))) + else + cache = alg_cache(_alg, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, + tTypeNoUnits, uprev, uprev2, f, t, dt, reltol_internal, p, calck, + Val(isinplace(prob))) + end + + # Setting up the step size controller + if (beta1 !== nothing || beta2 !== nothing) && controller !== nothing + throw(ArgumentError("Setting both the legacy PID parameters `beta1, beta2 = $((beta1, beta2))` and the `controller = $controller` is not allowed.")) + end + + if (beta1 !== nothing || beta2 !== nothing) + message = "Providing the legacy PID parameters `beta1, beta2` is deprecated. Use the keyword argument `controller` instead." + Base.depwarn(message, :init) + Base.depwarn(message, :solve) + end + + if controller === nothing + controller = default_controller(_alg, cache, qoldinit, beta1, beta2) + end + + save_end_user = save_end + save_end = save_end === nothing ? + save_everystep || isempty(saveat) || saveat isa Number || + prob.tspan[2] in saveat : save_end + + opts = DEOptions{typeof(abstol_internal), typeof(reltol_internal), + QT, tType, typeof(controller), + typeof(internalnorm), typeof(internalopnorm), + typeof(save_end_user), + typeof(callbacks_internal), + typeof(isoutofdomain), + typeof(progress_message), typeof(unstable_check), + typeof(tstops_internal), + typeof(d_discontinuities_internal), typeof(userdata), + typeof(save_idxs), + typeof(maxiters), typeof(tstops), + typeof(saveat), typeof(d_discontinuities)}(maxiters, save_everystep, + adaptive, abstol_internal, + reltol_internal, + QT(gamma), QT(qmax), + QT(qmin), + QT(qsteady_max), + QT(qsteady_min), + QT(qoldinit), + QT(failfactor), + tType(dtmax), tType(dtmin), + controller, + internalnorm, + internalopnorm, + save_idxs, tstops_internal, + saveat_internal, + d_discontinuities_internal, + tstops, saveat, + d_discontinuities, + userdata, progress, + progress_steps, + progress_name, + progress_message, + progress_id, + timeseries_errors, + dense_errors, dense, + save_on, save_start, + save_end, save_end_user, + callbacks_internal, + isoutofdomain, + unstable_check, + verbose, calck, force_dtmin, + advance_to_tstop, + stop_at_next_tstop) + + stats = SciMLBase.DEStats(0) + differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u) + + if _alg isa OrdinaryDiffEqCompositeAlgorithm + id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars) + sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, + dense = dense, k = ks, interp = id, + alg_choice = alg_choice, + calculate_error = false, stats = stats) + else + id = InterpolationData(f, timeseries, ts, ks, dense, cache, differential_vars) + sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, + dense = dense, k = ks, interp = id, + calculate_error = false, stats = stats) + end + + if recompile_flag == true + FType = typeof(f) + SolType = typeof(sol) + cacheType = typeof(cache) + else + FType = Function + if _alg isa OrdinaryDiffEqAlgorithm + SolType = DiffEqBase.AbstractODESolution + cacheType = OrdinaryDiffEqCache + else + SolType = DiffEqBase.AbstractDAESolution + cacheType = DAECache + end + end + + # rate/state = (state/time)/state = 1/t units, internalnorm drops units + # we don't want to differentiate through eigenvalue estimation + eigen_est = inv(one(tType)) + tprev = t + dtcache = tType(dt) + dtpropose = tType(dt) + iter = 0 + kshortsize = 0 + reeval_fsal = false + u_modified = false + EEst = EEstT(1) + just_hit_tstop = false + isout = false + accept_step = false + force_stepfail = false + last_stepfail = false + do_error_check = true + event_last_time = 0 + vector_event_last_time = 1 + last_event_error = _alg isa FunctionMap ? false : zero(uBottomEltypeNoUnits) + dtchangeable = isdtchangeable(_alg) + q11 = QT(1) + success_iter = 0 + erracc = QT(1) + dtacc = tType(1) + reinitiailize = true + saveiter = 0 # Starts at 0 so first save is at 1 + saveiter_dense = 0 + + integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du), + tType, typeof(p), + typeof(eigen_est), typeof(EEst), + QT, typeof(tdir), typeof(k), SolType, + FType, cacheType, + typeof(opts), fsal_typeof(_alg, rate_prototype), + typeof(last_event_error), typeof(callback_cache), + typeof(initializealg), typeof(differential_vars)}(sol, u, du, k, t, tType(dt), f, p, + uprev, uprev2, duprev, tprev, + _alg, dtcache, dtchangeable, + dtpropose, tdir, eigen_est, EEst, + QT(qoldinit), q11, + erracc, dtacc, success_iter, + iter, saveiter, saveiter_dense, cache, + callback_cache, + kshortsize, force_stepfail, + last_stepfail, + just_hit_tstop, do_error_check, + event_last_time, + vector_event_last_time, + last_event_error, accept_step, + isout, reeval_fsal, + u_modified, reinitiailize, isdae, + opts, stats, initializealg, differential_vars) + + if initialize_integrator + if isdae + DiffEqBase.initialize_dae!(integrator) + end + + if save_start + integrator.saveiter += 1 # Starts at 1 so first save is at 2 + integrator.saveiter_dense += 1 + copyat_or_push!(ts, 1, t) + if save_idxs === nothing + copyat_or_push!(timeseries, 1, integrator.u) + copyat_or_push!(ks, 1, [rate_prototype]) + else + copyat_or_push!(timeseries, 1, u_initial, Val{false}) + copyat_or_push!(ks, 1, [ks_prototype]) + end + else + integrator.saveiter = 0 # Starts at 0 so first save is at 1 + integrator.saveiter_dense = 0 + end + + initialize_callbacks!(integrator, initialize_save) + initialize!(integrator, integrator.cache) + + if _alg isa OrdinaryDiffEqCompositeAlgorithm + # in case user mixes adaptive and non-adaptive algorithms + ensure_behaving_adaptivity!(integrator, integrator.cache) + + if save_start + # Loop to get all of the extra possible saves in callback initialization + for i in 1:(integrator.saveiter) + copyat_or_push!(alg_choice, i, integrator.cache.current) + end + end + end + end + + handle_dt!(integrator) + integrator +end + +function DiffEqBase.solve!(integrator::ODEIntegrator) + @inbounds while !isempty(integrator.opts.tstops) + while integrator.tdir * integrator.t < first(integrator.opts.tstops) + loopheader!(integrator) + if integrator.do_error_check && check_error!(integrator) != ReturnCode.Success + return integrator.sol + end + perform_step!(integrator, integrator.cache) + loopfooter!(integrator) + if isempty(integrator.opts.tstops) + break + end + end + handle_tstop!(integrator) + end + postamble!(integrator) + + f = integrator.sol.prob.f + + if DiffEqBase.has_analytic(f) + DiffEqBase.calculate_solution_errors!(integrator.sol; + timeseries_errors = integrator.opts.timeseries_errors, + dense_errors = integrator.opts.dense_errors) + end + if integrator.sol.retcode != ReturnCode.Default + return integrator.sol + end + integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, ReturnCode.Success) +end + +# Helpers + +function handle_dt!(integrator) + if iszero(integrator.dt) && integrator.opts.adaptive + auto_dt_reset!(integrator) + if sign(integrator.dt) != integrator.tdir && !iszero(integrator.dt) && + !isnan(integrator.dt) + error("Automatic dt setting has the wrong sign. Exiting. Please report this error.") + end + if isnan(integrator.dt) + if integrator.opts.verbose + @warn("Automatic dt set the starting dt as NaN, causing instability. Exiting.") + end + end + elseif integrator.opts.adaptive && integrator.dt > zero(integrator.dt) && + integrator.tdir < 0 + integrator.dt *= integrator.tdir # Allow positive dt, but auto-convert + end +end + +# time stops +@inline function initialize_tstops(::Type{T}, tstops, d_discontinuities, tspan) where {T} + tstops_internal = BinaryHeap{T}(DataStructures.FasterForward()) + + t0, tf = tspan + tdir = sign(tf - t0) + tdir_t0 = tdir * t0 + tdir_tf = tdir * tf + + for t in tstops + tdir_t = tdir * t + tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) + end + for t in d_discontinuities + tdir_t = tdir * t + tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) + end + push!(tstops_internal, tdir_tf) + + return tstops_internal +end + +# saving time points +function initialize_saveat(::Type{T}, saveat, tspan) where {T} + saveat_internal = BinaryHeap{T}(DataStructures.FasterForward()) + + t0, tf = tspan + tdir = sign(tf - t0) + tdir_t0 = tdir * t0 + tdir_tf = tdir * tf + + if saveat isa Number + directional_saveat = tdir * abs(saveat) + for t in (t0 + directional_saveat):directional_saveat:tf + push!(saveat_internal, tdir * t) + end + elseif !isempty(saveat) + for t in saveat + tdir_t = tdir * t + tdir_t0 < tdir_t ≤ tdir_tf && push!(saveat_internal, tdir_t) + end + end + + return saveat_internal +end + +# discontinuities +function initialize_d_discontinuities(::Type{T}, d_discontinuities, tspan) where {T} + d_discontinuities_internal = BinaryHeap{T}(DataStructures.FasterForward()) + sizehint!(d_discontinuities_internal, length(d_discontinuities)) + + t0, tf = tspan + tdir = sign(tf - t0) + + for t in d_discontinuities + push!(d_discontinuities_internal, tdir * t) + end + + return d_discontinuities_internal +end + +function initialize_callbacks!(integrator, initialize_save = true) + t = integrator.t + u = integrator.u + callbacks = integrator.opts.callback + integrator.u_modified = true + + u_modified = initialize!(callbacks, u, t, integrator) + + # if the user modifies u, we need to fix previous values before initializing + # FSAL in order for the starting derivatives to be correct + if u_modified + if isinplace(integrator.sol.prob) + recursivecopy!(integrator.uprev, integrator.u) + else + integrator.uprev = integrator.u + end + + if alg_extrapolates(integrator.alg) + if isinplace(integrator.sol.prob) + recursivecopy!(integrator.uprev2, integrator.uprev) + else + integrator.uprev2 = integrator.uprev + end + end + + if initialize_save && + (any((c) -> c.save_positions[2], callbacks.discrete_callbacks) || + any((c) -> c.save_positions[2], callbacks.continuous_callbacks)) + savevalues!(integrator, true) + end + end + + # reset this as it is now handled so the integrators should proceed as normal + integrator.u_modified = false +end From 678ccf20e3fab9c6cd36f5b3afcc4b9581c73988 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 1 Jan 2024 15:35:48 -0500 Subject: [PATCH 2/5] fix typo / test --- src/alg_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 08fc16235f..af07151ed2 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -357,7 +357,7 @@ function concrete_jac(alg::Union{ end alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false -alg_extrapolates(alg::CompositeAlgorithm) = error("any(alg_extrapolates.(alg.algs))") +alg_extrapolates(alg::CompositeAlgorithm) = any(alg_extrapolates.(alg.algs)) alg_extrapolates(alg::DefaultSolverAlgorithm) = false alg_extrapolates(alg::ImplicitEuler) = true alg_extrapolates(alg::DImplicitEuler) = true From a22f41b594d9993a0385354e6f80c89271c8f70b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 1 Jan 2024 15:41:10 -0500 Subject: [PATCH 3/5] fix precompilation choices --- src/OrdinaryDiffEq.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 6dd1c1fa67..3030686c4c 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -289,7 +289,7 @@ PrecompileTools.@compile_workload begin end if Preferences.@load_preference("PrecompileDefault", true) - append!(solver_list, stiff) + append!(solver_list, default_ode) end if Preferences.@load_preference("PrecompileAutoSwitch", false) From a780fa685bde14fbe494aaf8155a713c5b553813 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 2 Jan 2024 05:46:01 -0500 Subject: [PATCH 4/5] Update src/composite_algs.jl Co-authored-by: Nathanael Bosch --- src/composite_algs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/composite_algs.jl b/src/composite_algs.jl index e82e2e54ca..24c157f025 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -75,7 +75,7 @@ AutoVern9(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern9(lazy = lazy), alg; EnumX.@enumx DefaultSolverChoice begin Tsit5 = 1 Vern7 = 2 - Rosnebrock23 = 3 + Rosenbrock23 = 3 Rodas5P = 4 FBDF = 5 KrylovFBDF = 6 From 1db42e90315e05f3c67aaeed67e6a5b253dd91b9 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 2 Jan 2024 05:46:25 -0500 Subject: [PATCH 5/5] Update src/composite_algs.jl --- src/composite_algs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/composite_algs.jl b/src/composite_algs.jl index 24c157f025..9f07defcb3 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -140,7 +140,7 @@ function stiffchoice(reltol, len) if reltol < LOW_TOL DefaultSolverChoice.Rodas5P else - DefaultSolverChoice.Rosnebrock23 + DefaultSolverChoice.Rosenbrock23 end end Int(x)