diff --git a/Project.toml b/Project.toml index 328d2dfdd2..3d637e3a54 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 2cea945101..3030686c4c 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -28,6 +28,8 @@ using LinearSolve, SimpleNonlinearSolve using LineSearches +import EnumX + import FillArrays: Trues # Interfaces @@ -141,6 +143,7 @@ include("nlsolve/functional.jl") include("nlsolve/newton.jl") include("generic_rosenbrock.jl") +include("composite_algs.jl") include("caches/basic_caches.jl") include("caches/low_order_rk_caches.jl") @@ -231,7 +234,6 @@ include("constants.jl") include("solve.jl") include("initdt.jl") include("interp_func.jl") -include("composite_algs.jl") import PrecompileTools @@ -250,9 +252,14 @@ PrecompileTools.@compile_workload begin Tsit5(), Vern7(), ] - stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false), - Rodas5P(), Rodas5P(autodiff = false), - FBDF(), FBDF(autodiff = false), + stiff = [Rosenbrock23(), + Rodas5P(), + FBDF() + ] + + default_ode = [ + DefaultODEAlgorithm(autodiff=false), + DefaultODEAlgorithm() ] autoswitch = [ @@ -281,7 +288,11 @@ PrecompileTools.@compile_workload begin append!(solver_list, stiff) end - if Preferences.@load_preference("PrecompileAutoSwitch", true) + if Preferences.@load_preference("PrecompileDefault", true) + append!(solver_list, default_ode) + end + + if Preferences.@load_preference("PrecompileAutoSwitch", false) append!(solver_list, autoswitch) end diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 33f7f0b02b..af07151ed2 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -164,6 +164,8 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs)) isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs)) +isdtchangeable(alg::DefaultSolverAlgorithm) = true + function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4, CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2}) false @@ -176,12 +178,14 @@ ismultistep(alg::ETD2) = true isadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false isadaptive(alg::OrdinaryDiffEqAdaptiveAlgorithm) = true isadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = all(isadaptive.(alg.algs)) +isadaptive(alg::DefaultSolverAlgorithm) = true isadaptive(alg::DImplicitEuler) = true isadaptive(alg::DABDF2) = true isadaptive(alg::DFBDF) = true anyadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = isadaptive(alg) anyadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = any(isadaptive, alg.algs) +anyadaptive(alg::DefaultSolverAlgorithm) = true isautoswitch(alg) = false isautoswitch(alg::CompositeAlgorithm) = alg.choice_function isa AutoSwitch @@ -191,9 +195,11 @@ function qmin_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) end qmin_default(alg::CompositeAlgorithm) = maximum(qmin_default.(alg.algs)) qmin_default(alg::DP8) = 1 // 3 +qmin_default(alg::DefaultSolverAlgorithm) = 1 // 5 qmax_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = 10 qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs)) +qmax_default(alg::DefaultSolverAlgorithm) = 10 qmax_default(alg::DP8) = 6 qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8 @@ -277,7 +283,7 @@ end function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob) algs = map(alg -> DiffEqBase.prepare_alg(alg, u0, p, prob), alg.algs) - CompositeAlgorithm(algs, alg.choice_function) + CompositeAlgorithm(algs, alg.choice_function, Val(allowfallbacks(alg))) end # Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options. @@ -352,6 +358,7 @@ end alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false alg_extrapolates(alg::CompositeAlgorithm) = any(alg_extrapolates.(alg.algs)) +alg_extrapolates(alg::DefaultSolverAlgorithm) = false alg_extrapolates(alg::ImplicitEuler) = true alg_extrapolates(alg::DImplicitEuler) = true alg_extrapolates(alg::DABDF2) = true @@ -695,6 +702,7 @@ alg_order(alg::Alshina6) = 6 alg_maximum_order(alg) = alg_order(alg) alg_maximum_order(alg::CompositeAlgorithm) = maximum(alg_order(x) for x in alg.algs) +alg_maximum_order(alg::DefaultSolverAlgorithm) = 7 alg_maximum_order(alg::ExtrapolationMidpointDeuflhard) = 2(alg.max_order + 1) alg_maximum_order(alg::ImplicitDeuflhardExtrapolation) = 2(alg.max_order + 1) alg_maximum_order(alg::ExtrapolationMidpointHairerWanner) = 2(alg.max_order + 1) @@ -829,6 +837,7 @@ function gamma_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) isadaptive(alg) ? 9 // 10 : 0 end gamma_default(alg::CompositeAlgorithm) = maximum(gamma_default, alg.algs) +gamma_default(alg::DefaultSolverAlgorithm) = 9 // 10 gamma_default(alg::RKC) = 8 // 10 gamma_default(alg::IRKC) = 8 // 10 function gamma_default(alg::ExtrapolationMidpointDeuflhard) @@ -949,14 +958,18 @@ function unwrap_alg(integrator, is_stiff) if !iscomp return alg elseif alg.choice_function isa AutoSwitchCache - if is_stiff === nothing - throwautoswitch(alg) - end - num = is_stiff ? 2 : 1 - if num == 1 - return alg.algs[1] + if alg.choice_function.algtrait isa DefaultODESolver + alg.algs[alg.choice_function.current] else - return alg.algs[2] + if is_stiff === nothing + throwautoswitch(alg) + end + num = is_stiff ? 2 : 1 + if num == 1 + return alg.algs[1] + else + return alg.algs[2] + end end else return _eval_index(identity, alg.algs, integrator.cache.current) diff --git a/src/algorithms.jl b/src/algorithms.jl index a30e544070..b11392b8ad 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3100,17 +3100,56 @@ end ######################################### -struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm +struct CompositeAlgorithm{Fallbacks, T, F} <: OrdinaryDiffEqCompositeAlgorithm algs::T choice_function::F + function CompositeAlgorithm(algs, choice_function, fallbacks_enabled::Val{X} = Val(true)) where X + new{X, typeof(algs), typeof(choice_function)}(algs, choice_function) + end end +allowfallbacks(::CompositeAlgorithm{Fallbacks, T, F}) where {Fallbacks, T, F} = Fallbacks + TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1 if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) Base.Experimental.silence!(CompositeAlgorithm) end +mutable struct AutoSwitchCache{Trait, nAlg, sAlg, tolType, T} + algtrait::Trait + count::Int + successive_switches::Int + nonstiffalg::nAlg + stiffalg::sAlg + is_stiffalg::Bool + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int + current::Int +end + +struct AutoSwitch{Trait, nAlg, sAlg, tolType, T} + algtrait::Trait + nonstiffalg::nAlg + stiffalg::sAlg + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int +end + +struct DefaultODESolver end +const DefaultSolverAlgorithm = Union{CompositeAlgorithm{false, <:Tuple, <:AutoSwitch{DefaultODESolver}}, +CompositeAlgorithm{false, <:Tuple, <:AutoSwitchCache{DefaultODESolver}}} + ################################################################################ """ MEBDF2: Multistep Method diff --git a/src/caches/basic_caches.jl b/src/caches/basic_caches.jl index 4c627d7b71..658518a59c 100644 --- a/src/caches/basic_caches.jl +++ b/src/caches/basic_caches.jl @@ -4,7 +4,7 @@ abstract type OrdinaryDiffEqMutableCache <: OrdinaryDiffEqCache end struct ODEEmptyCache <: OrdinaryDiffEqConstantCache end struct ODEChunkCache{CS} <: OrdinaryDiffEqConstantCache end -mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache +mutable struct CompositeCache{Fallbacks, T, F} <: OrdinaryDiffEqCache caches::T choice_function::F current::Int @@ -16,11 +16,11 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) Base.Experimental.silence!(CompositeCache) end -function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype, +function alg_cache(alg::CompositeAlgorithm{Fallbacks, Tuple{T1, T2}, F}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits, + ::Val{V}) where {Fallbacks, T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} caches = (alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, @@ -28,16 +28,18 @@ function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype, alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V))) - CompositeCache(caches, alg.choice_function, 1) + CompositeCache{Fallbacks, typeof(caches), typeof(alg.choice_function)}( + caches, alg.choice_function, 1) end -function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits}, +function alg_cache(alg::CompositeAlgorithm{Fallbacks}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + ::Val{V}) where {Fallbacks, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} caches = __alg_cache(alg.algs, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)) - CompositeCache(caches, alg.choice_function, 1) + CompositeCache{Fallbacks, typeof(caches), typeof(alg.choice_function)}( + caches, alg.choice_function, 1) end # map + closure approach doesn't infer diff --git a/src/caches/verner_caches.jl b/src/caches/verner_caches.jl index a87bedcccd..6a95f6188f 100644 --- a/src/caches/verner_caches.jl +++ b/src/caches/verner_caches.jl @@ -20,6 +20,7 @@ stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern6Cache 1 @@ -44,11 +45,12 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern6Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, utilde, tmp, rtmp, atmp, tab, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern6ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -56,7 +58,7 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern6Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern6ConstantCache(tab) + Vern6ConstantCache(tab, alg.lazy) end @cache struct Vern7Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -81,6 +83,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern7Cache 1 @@ -105,16 +108,18 @@ function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern7Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, rtmp, atmp, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end -struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern7ConstantCache() + Vern7ConstantCache(alg.lazy) end @cache struct Vern8Cache{uType, rateType, uNoUnitsType, TabType, StageLimiter, StepLimiter, @@ -143,6 +148,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern8Cache 1 @@ -171,11 +177,12 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern8Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, utilde, - tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread) + tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern8ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -183,7 +190,7 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern8ConstantCache(tab) + Vern8ConstantCache(tab, alg.lazy) end @cache struct Vern9Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -214,6 +221,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern9Cache 1 @@ -245,14 +253,16 @@ function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern9Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, utilde, tmp, rtmp, atmp, alg.stage_limiter!, alg.step_limiter!, - alg.thread) + alg.thread, alg.lazy) end -struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern9ConstantCache() + Vern9ConstantCache(alg.lazy) end diff --git a/src/composite_algs.jl b/src/composite_algs.jl index a6474fa00b..9f07defcb3 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -1,34 +1,12 @@ -mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T} - count::Int - successive_switches::Int - nonstiffalg::nAlg - stiffalg::sAlg - is_stiffalg::Bool - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end +### AutoSwitch +### Designed to switch between two solvers, stiff and non-stiff -struct AutoSwitch{nAlg, sAlg, tolType, T} - nonstiffalg::nAlg - stiffalg::sAlg - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end -function AutoSwitch(nonstiffalg, stiffalg; maxstiffstep = 10, maxnonstiffstep = 3, +function AutoSwitch(nonstiffalg, stiffalg, algtrait = nothing; + maxstiffstep = 10, maxnonstiffstep = 3, nonstifftol = 9 // 10, stifftol = 9 // 10, dtfac = 2, stiffalgfirst = false, switch_max = 5) - AutoSwitch(nonstiffalg, stiffalg, maxstiffstep, maxnonstiffstep, + AutoSwitch(algtrait, nonstiffalg, stiffalg, maxstiffstep, maxnonstiffstep, promote(nonstifftol, stifftol)..., dtfac, stiffalgfirst, switch_max) end @@ -52,7 +30,11 @@ function is_stiff(integrator, alg, ntol, stol, is_stiffalg) end function (AS::AutoSwitchCache)(integrator) - integrator.iter == 0 && return Int(AS.stiffalgfirst) + 1 + if AS.current == 0 + AS.current = Int(AS.stiffalgfirst) + 1 + return AS.current + end + dt = integrator.dt # Successive stiffness test positives are counted by a positive integer, # and successive stiffness test negatives are counted by a negative integer @@ -67,12 +49,18 @@ function (AS::AutoSwitchCache)(integrator) integrator.dt = dt / AS.dtfac AS.is_stiffalg = false end - return Int(AS.is_stiffalg) + 1 + AS.current = Int(AS.is_stiffalg) + 1 + return AS.current end -function AutoAlgSwitch(nonstiffalg, stiffalg; kwargs...) - AS = AutoSwitch(nonstiffalg, stiffalg; kwargs...) - CompositeAlgorithm((nonstiffalg, stiffalg), AS) +function AutoAlgSwitch(nonstiffalg::OrdinaryDiffEqAlgorithm, stiffalg::OrdinaryDiffEqAlgorithm, algtrait = nothing; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) + CompositeAlgorithm((nonstiffalg, stiffalg), AS, Val(false)) +end + +function AutoAlgSwitch(nonstiffalg::Tuple, stiffalg::Tuple, algtrait; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) + CompositeAlgorithm((nonstiffalg..., stiffalg...), AS, Val(false)) end AutoTsit5(alg; kwargs...) = AutoAlgSwitch(Tsit5(), alg; kwargs...) @@ -81,3 +69,114 @@ AutoVern6(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern6(lazy = lazy), alg; AutoVern7(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern7(lazy = lazy), alg; kwargs...) AutoVern8(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern8(lazy = lazy), alg; kwargs...) AutoVern9(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern9(lazy = lazy), alg; kwargs...) + +### Default ODE Solver + +EnumX.@enumx DefaultSolverChoice begin + Tsit5 = 1 + Vern7 = 2 + Rosenbrock23 = 3 + Rodas5P = 4 + FBDF = 5 + KrylovFBDF = 6 +end + +const NUM_NONSTIFF = 2 +const NUM_STIFF = 4 +const LOW_TOL = 1e-6 +const MED_TOL = 1e-2 +const EXTREME_TOL = 1e-9 +const SMALLSIZE = 50 +const MEDIUMSIZE = 500 +const STABILITY_SIZES = (alg_stability_size(Tsit5()), alg_stability_size(Vern7())) +const DEFAULTBETA2S = (beta2_default(Tsit5()), beta2_default(Vern7()), beta2_default(Rosenbrock23()), beta2_default(Rodas5P()), beta2_default(FBDF()), beta2_default(FBDF())) +const DEFAULTBETA1S = (beta1_default(Tsit5(),DEFAULTBETA2S[1]), beta1_default(Vern7(),DEFAULTBETA2S[2]), + beta1_default(Rosenbrock23(), DEFAULTBETA2S[3]), beta1_default(Rodas5P(), DEFAULTBETA2S[4]), + beta1_default(FBDF(), DEFAULTBETA2S[5]), beta1_default(FBDF(), DEFAULTBETA2S[6])) + +callbacks_exists(integrator) = !isempty(integrator.opts.callbacks) +current_nonstiff(current) = ifelse(current <= NUM_NONSTIFF,current,current-NUM_STIFF) + +function DefaultODEAlgorithm(; lazy = true, stiffalgfirst = false, kwargs...) + nonstiff = (Tsit5(), Vern7(lazy = lazy)) + stiff = (Rosenbrock23(;kwargs...), Rodas5P(;kwargs...), FBDF(;kwargs...), FBDF(;linsolve = LinearSolve.KrylovJL_GMRES())) + AutoAlgSwitch(nonstiff, stiff, DefaultODESolver(); stiffalgfirst) +end + +function is_stiff(integrator, alg, ntol, stol, is_stiffalg, current) + eigen_est, dt = integrator.eigen_est, integrator.dt + stiffness = abs(eigen_est * dt / STABILITY_SIZES[nonstiffchoice(integrator.opts.reltol)]) # `abs` here is just for safety + tol = is_stiffalg ? stol : ntol + os = oneunit(stiffness) + bool = stiffness > os * tol + + if !bool + integrator.alg.choice_function.successive_switches += 1 + else + integrator.alg.choice_function.successive_switches = 0 + end + + integrator.do_error_check = (integrator.alg.choice_function.successive_switches > + integrator.alg.choice_function.switch_max || !bool) || + is_stiffalg + bool +end + +function nonstiffchoice(reltol) + x = if reltol < LOW_TOL + DefaultSolverChoice.Vern7 + else + DefaultSolverChoice.Tsit5 + end + Int(x) +end + +function stiffchoice(reltol, len) + x = if len > MEDIUMSIZE + DefaultSolverChoice.KrylovFBDF + elseif len > SMALLSIZE + DefaultSolverChoice.FBDF + else + if reltol < LOW_TOL + DefaultSolverChoice.Rodas5P + else + DefaultSolverChoice.Rosenbrock23 + end + end + Int(x) +end + +function (AS::AutoSwitchCache{DefaultODESolver})(integrator) + + len = length(integrator.u) + reltol = integrator.opts.reltol + + # Chooose the starting method + if AS.current == 0 + choice = if AS.stiffalgfirst || integrator.f.mass_matrix != I + stiffchoice(reltol, len) + else + nonstiffchoice(reltol) + end + AS.current = choice + return AS.current + end + + dt = integrator.dt + # Successive stiffness test positives are counted by a positive integer, + # and successive stiffness test negatives are counted by a negative integer + AS.count = is_stiff(integrator, AS.nonstiffalg, AS.nonstifftol, AS.stifftol, + AS.is_stiffalg, AS.current) ? + AS.count < 0 ? 1 : AS.count + 1 : + AS.count > 0 ? -1 : AS.count - 1 + if (!AS.is_stiffalg && AS.count > AS.maxstiffstep) + integrator.dt = dt * AS.dtfac + AS.is_stiffalg = true + AS.current = stiffchoice(reltol, len) + elseif (AS.is_stiffalg && AS.count < -AS.maxnonstiffstep) + integrator.dt = dt / AS.dtfac + AS.is_stiffalg = false + AS.current = nonstiffchoice(reltol) + end + return AS.current +end diff --git a/src/perform_step/composite_perform_step.jl b/src/perform_step/composite_perform_step.jl index ae872fd521..d2bc4ee5a1 100644 --- a/src/perform_step/composite_perform_step.jl +++ b/src/perform_step/composite_perform_step.jl @@ -1,39 +1,4 @@ -#= - -Maybe do generated functions to reduce dispatch times? - -f(x) = x -g(x,i) = f(x[i]) -g{i}(x,::Type{Val{i}}) = f(x[i]) -@generated function gg(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(f(tup[i])) i->error("unreachable")) - end -h(i) = g((1,1.0,"foo"), i) -h2{i}(::Type{Val{i}}) = g((1,1.0,"foo"), Val{i}) -h3(i) = gg((1,1.0,"foo"), i) -@benchmark h(1) -mean time: 31.822 ns (0.00% GC) -@benchmark h2(Val{1}) -mean time: 1.585 ns (0.00% GC) -@benchmark h3(1) -mean time: 6.423 ns (0.00% GC) - -@generated function foo(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) -end - -@code_typed foo((1,1.0), 1) - -@generated function perform_step!(integrator, cache::CompositeCache, repeat_step=false) - N = length(cache.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) -end - -=# - -function initialize!(integrator, cache::CompositeCache) +function initialize!(integrator, cache::CompositeCache{Fallbacks}) where Fallbacks cache.current = cache.choice_function(integrator) if cache.current == 1 initialize!(integrator, @inbounds(cache.caches[1])) @@ -42,7 +7,27 @@ function initialize!(integrator, cache::CompositeCache) # the controller was initialized by default for integrator.alg.algs[1] reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[2]) - else + elseif cache.current == 3 + initialize!(integrator, @inbounds(cache.caches[3])) + # the controller was initialized by default for integrator.alg.algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], + integrator.alg.algs[3]) + elseif cache.current == 4 + initialize!(integrator, @inbounds(cache.caches[4])) + # the controller was initialized by default for integrator.alg.algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], + integrator.alg.algs[4]) + elseif cache.current == 5 + initialize!(integrator, @inbounds(cache.caches[5])) + # the controller was initialized by default for integrator.alg.algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], + integrator.alg.algs[5]) + elseif cache.current == 6 + initialize!(integrator, @inbounds(cache.caches[6])) + # the controller was initialized by default for integrator.alg.algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], + integrator.alg.algs[6]) + elseif Fallbacks initialize!(integrator, @inbounds(cache.caches[cache.current])) reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[cache.current]) @@ -50,7 +35,7 @@ function initialize!(integrator, cache::CompositeCache) resize!(integrator.k, integrator.kshortsize) end -function initialize!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}) where {T1, T2, F} +function initialize!(integrator, cache::CompositeCache{false, Tuple{T1, T2}, F}) where {T1, T2, F} cache.current = cache.choice_function(integrator) if cache.current == 1 initialize!(integrator, @inbounds(cache.caches[1])) @@ -75,17 +60,25 @@ function ensure_behaving_adaptivity!(integrator, cache::CompositeCache) end end -function perform_step!(integrator, cache::CompositeCache, repeat_step = false) +function perform_step!(integrator, cache::CompositeCache{Fallbacks}, repeat_step = false) where Fallbacks if cache.current == 1 perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) elseif cache.current == 2 perform_step!(integrator, @inbounds(cache.caches[2]), repeat_step) - else + elseif cache.current == 3 + perform_step!(integrator, @inbounds(cache.caches[3]), repeat_step) + elseif cache.current == 4 + perform_step!(integrator, @inbounds(cache.caches[4]), repeat_step) + elseif cache.current == 5 + perform_step!(integrator, @inbounds(cache.caches[5]), repeat_step) + elseif cache.current == 6 + perform_step!(integrator, @inbounds(cache.caches[6]), repeat_step) + elseif Fallbacks perform_step!(integrator, @inbounds(cache.caches[cache.current]), repeat_step) end end -function perform_step!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}, +function perform_step!(integrator, cache::CompositeCache{false, Tuple{T1, T2}, F}, repeat_step = false) where {T1, T2, F} if cache.current == 1 perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) @@ -97,7 +90,7 @@ end choose_algorithm!(integrator, cache::OrdinaryDiffEqCache) = nothing function choose_algorithm!(integrator, - cache::CompositeCache{Tuple{T1, T2}, F}) where {T1, T2, F} + cache::CompositeCache{false, Tuple{T1, T2}, F}) where {T1, T2, F} new_current = cache.choice_function(integrator) old_current = cache.current @inbounds if new_current != old_current @@ -121,36 +114,111 @@ function choose_algorithm!(integrator, end end -function choose_algorithm!(integrator, cache::CompositeCache) +function choose_algorithm!(integrator, cache::CompositeCache{Fallbacks, T, F}) where {Fallbacks, T, F} + new_current = cache.choice_function(integrator) + old_current = cache.current + !Fallbacks && error("Hitting fallbacks!") + @inbounds if new_current != old_current + cache.current = new_current + initialize!(integrator, @inbounds(cache.caches[new_current])) + + controller.beta2 = beta2_default(alg2) + controller.beta1 = beta2_default(alg2) + DEFAULTBETA2S + + reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], + integrator.alg.algs[new_current]) + transfer_cache!(integrator, integrator.cache.caches[old_current], + integrator.cache.caches[new_current]) + end +end + +function choose_algorithm!(integrator, cache::CompositeCache{Fallbacks, T, <:AutoSwitchCache{DefaultODESolver}}) where {Fallbacks, T} new_current = cache.choice_function(integrator) old_current = cache.current @inbounds if new_current != old_current cache.current = new_current if new_current == 1 - initialize!(integrator, @inbounds(cache.caches[1])) + initialize!(integrator, @inbounds(cache.caches[1])); nothing elseif new_current == 2 - initialize!(integrator, @inbounds(cache.caches[2])) + initialize!(integrator, @inbounds(cache.caches[2])); nothing + elseif new_current == 3 + initialize!(integrator, @inbounds(cache.caches[3])); nothing + elseif new_current == 4 + initialize!(integrator, @inbounds(cache.caches[4])); nothing + elseif new_current == 5 + initialize!(integrator, @inbounds(cache.caches[5])); nothing + elseif new_current == 6 + initialize!(integrator, @inbounds(cache.caches[6])); nothing else - initialize!(integrator, @inbounds(cache.caches[new_current])) + error("Unrachable reached. Report this error") end - if old_current == 1 && new_current == 2 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[1], - integrator.alg.algs[2]) - transfer_cache!(integrator, integrator.cache.caches[1], - integrator.cache.caches[2]) - elseif old_current == 2 && new_current == 1 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[2], - integrator.alg.algs[1]) - transfer_cache!(integrator, integrator.cache.caches[2], - integrator.cache.caches[1]) + + # dtchangable, qmin_default, qmax_default, and isadaptive ignored since all same + integrator.opts.controller.beta1 = DEFAULTBETA1S[new_current] + integrator.opts.controller.beta2 = DEFAULTBETA2S[new_current] + end +end + +#= +""" +function choose_algorithm!(integrator, cache::CompositeCache{Fallbacks}) where Fallbacks + new_current = cache.choice_function(integrator) + old_current = cache.current + @inbounds if new_current != old_current + cache.current = new_current + initialize!(integrator, @inbounds(cache.caches[new_current])) + reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], + integrator.alg.algs[new_current]) + transfer_cache!(integrator, integrator.cache.caches[old_current], + integrator.cache.caches[new_current]) + end +end +""" +@generated function choose_algorithm!(integrator, cache::CompositeCache{Fallbacks, T, F}) where {Fallbacks, T, F} + initialize_ex = :() + for idx in 1:length(T.types) + newex = quote + initialize!(integrator, @inbounds(cache.caches[$idx])) + end + initialize_ex = if initialize_ex == :() + Expr(:elseif, :(new_current == $idx)), newex, + :(error("Algorithm Choice not Allowed")) else - reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], - integrator.alg.algs[new_current]) - transfer_cache!(integrator, integrator.cache.caches[old_current], - integrator.cache.caches[new_current]) + Expr(:elseif, :(new_current == $idx), newex, initialize_ex) + end + end + initialize_ex = Expr(:if, initialize_ex.args...) + + swap_ex = :() + for idx in 1:length(T.types), idx2 in 1:length(T.types) + new_swap_ex = quote + reset_alg_dependent_opts!(integrator, integrator.alg.algs[$idx], + integrator.alg.algs[$idx2]) + transfer_cache!(integrator, integrator.cache.caches[$idx], + integrator.cache.caches[$idx2]) + end + swap_ex = if swap_ex == :() + Expr(:elseif, :(old_current == $idx && new_current == $idx2)), new_swap_ex, + :(error("Algorithm Choice not Allowed")) + else + Expr(:elseif, :(new_current == $idx), swap_ex, swap_ex) + end + end + swap_ex = Expr(:if, swap_ex.args...) + + + quote + new_current = cache.choice_function(integrator) + old_current = cache.current + @inbounds if new_current != old_current + cache.current = new_current + $initialize_ex + $swap_ex end end end +=# """ If no user default, then this will change the default to the defaults @@ -170,6 +238,7 @@ function reset_alg_dependent_opts!(integrator, alg1, alg2) integrator.opts.qmax == qmax_default(alg2) end reset_alg_dependent_opts!(integrator.opts.controller, alg1, alg2) + nothing end # Write how to transfer the cache variables from one cache to the other diff --git a/src/perform_step/verner_rk_perform_step.jl b/src/perform_step/verner_rk_perform_step.jl index 5f2d876354..283cf7b95a 100644 --- a/src/perform_step/verner_rk_perform_step.jl +++ b/src/perform_step/verner_rk_perform_step.jl @@ -2,7 +2,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal integrator.stats.nf += 1 alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -13,7 +13,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) end integrator.k[integrator.kshortsize] = integrator.fsallast - if !alg.lazy + if !cache.lazy @inbounds for i in 10:12 integrator.k[i] = zero(integrator.fsalfirst) end @@ -63,7 +63,7 @@ end integrator.k[9] = k9 alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -86,7 +86,7 @@ end function initialize!(integrator, cache::Vern6Cache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.fsalfirst = cache.k1 integrator.fsallast = cache.k9 @unpack k = integrator @@ -101,7 +101,7 @@ function initialize!(integrator, cache::Vern6Cache) k[8] = cache.k8 k[9] = cache.k9 # Set the pointers - if !alg.lazy + if !cache.lazy k[10] = similar(cache.k1) k[11] = similar(cache.k1) k[12] = similar(cache.k1) @@ -174,7 +174,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -206,7 +206,7 @@ end function initialize!(integrator, cache::Vern7ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -267,7 +267,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern7ExtraStages T T2 @@ -302,7 +302,7 @@ function initialize!(integrator, cache::Vern7Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -315,7 +315,7 @@ function initialize!(integrator, cache::Vern7Cache) k[9] = k9 k[10] = k10 # Setup pointers - if !alg.lazy + if !cache.lazy k[11] = similar(cache.k1) k[12] = similar(cache.k1) k[13] = similar(cache.k1) @@ -406,7 +406,7 @@ end integrator.EEst = integrator.opts.internalnorm(atmp, t) end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache @@ -462,7 +462,7 @@ end function initialize!(integrator, cache::Vern8ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -539,7 +539,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -587,7 +587,7 @@ function initialize!(integrator, cache::Vern8Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -603,7 +603,7 @@ function initialize!(integrator, cache::Vern8Cache) k[12] = k12 k[13] = k13 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 14:21 k[i] = similar(cache.k1) end @@ -709,7 +709,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -796,7 +796,7 @@ end function initialize!(integrator, cache::Vern9ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -880,7 +880,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern9ExtraStages T T2 @@ -940,7 +940,7 @@ function initialize!(integrator, cache::Vern9Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) resize!(k, integrator.kshortsize) # k2, k3,k4,k5,k6,k7 are not used in the code (not even in interpolations), we dont need their pointers. # So we mapped k[2] (from integrator) with k8 (from cache), k[3] with k9 and so on. @@ -955,7 +955,7 @@ function initialize!(integrator, cache::Vern9Cache) k[9] = k15 k[10] = k16 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 11:20 k[i] = similar(cache.k1) end @@ -1082,7 +1082,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache diff --git a/src/solve.jl b/src/solve.jl index c0a01ad126..7e037ea999 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,682 +1,682 @@ -function DiffEqBase.__solve(prob::Union{DiffEqBase.AbstractODEProblem, - DiffEqBase.AbstractDAEProblem}, - alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, args...; - kwargs...) - integrator = DiffEqBase.__init(prob, alg, args...; kwargs...) - solve!(integrator) - integrator.sol -end - -function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem, - DiffEqBase.AbstractDAEProblem}, - alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, - timeseries_init = (), - ts_init = (), - ks_init = (), - recompile::Type{Val{recompile_flag}} = Val{true}; - saveat = (), - tstops = (), - d_discontinuities = (), - save_idxs = nothing, - save_everystep = isempty(saveat), - save_on = true, - save_start = save_everystep || isempty(saveat) || - saveat isa Number || prob.tspan[1] in saveat, - save_end = nothing, - callback = nothing, - dense = save_everystep && - !(alg isa Union{DAEAlgorithm, FunctionMap}) && - isempty(saveat), - calck = (callback !== nothing && callback !== CallbackSet()) || - (dense) || !isempty(saveat), # and no dense output - dt = alg isa FunctionMap && isempty(tstops) ? - eltype(prob.tspan)(1) : eltype(prob.tspan)(0), - dtmin = eltype(prob.tspan)(0), - dtmax = eltype(prob.tspan)((prob.tspan[end] - prob.tspan[1])), - force_dtmin = false, - adaptive = anyadaptive(alg), - gamma = gamma_default(alg), - abstol = nothing, - reltol = nothing, - qmin = qmin_default(alg), - qmax = qmax_default(alg), - qsteady_min = qsteady_min_default(alg), - qsteady_max = qsteady_max_default(alg), - beta1 = nothing, - beta2 = nothing, - qoldinit = anyadaptive(alg) ? 1 // 10^4 : 0, - controller = nothing, - fullnormalize = true, - failfactor = 2, - maxiters = anyadaptive(alg) ? 1000000 : typemax(Int), - internalnorm = ODE_DEFAULT_NORM, - internalopnorm = LinearAlgebra.opnorm, - isoutofdomain = ODE_DEFAULT_ISOUTOFDOMAIN, - unstable_check = ODE_DEFAULT_UNSTABLE_CHECK, - verbose = true, - timeseries_errors = true, - dense_errors = false, - advance_to_tstop = false, - stop_at_next_tstop = false, - initialize_save = true, - progress = false, - progress_steps = 1000, - progress_name = "ODE", - progress_message = ODE_DEFAULT_PROG_MESSAGE, - progress_id = gensym("OrdinaryDiffEq"), - userdata = nothing, - allow_extrapolation = alg_extrapolates(alg), - initialize_integrator = true, - alias_u0 = false, - alias_du0 = false, - initializealg = DefaultInit(), - kwargs...) where {recompile_flag} - if prob isa DiffEqBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm - error("You cannot use an ODE Algorithm with a DAEProblem") - end - - if prob isa DiffEqBase.AbstractODEProblem && alg isa DAEAlgorithm - error("You cannot use an DAE Algorithm with a ODEProblem") - end - - if prob isa DiffEqBase.ODEProblem - if !(prob.f isa DiffEqBase.DynamicalODEFunction) && alg isa PartitionedAlgorithm - error("You can not use a solver designed for partitioned ODE with this problem. Please choose a solver suitable for your problem") - end - end - - if prob.f isa DynamicalODEFunction && prob.f.mass_matrix isa Tuple - if any(mm != I for mm in prob.f.mass_matrix) - error("This solver is not able to use mass matrices.") - end - elseif !(prob isa DiscreteProblem) && - !(prob isa DiffEqBase.AbstractDAEProblem) && - !is_mass_matrix_alg(alg) && - prob.f.mass_matrix != I - error("This solver is not able to use mass matrices.") - end - - if alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm && - prob.f.mass_matrix isa AbstractMatrix && - all(isequal(0), prob.f.mass_matrix) - # technically this should also warn for zero operators but those are hard to check for - alg isa Union{Rosenbrock23, Rosenbrock32} && error("Rosenbrock23 and Rosenbrock32 require at least one differential variable to produce valid solutions") - if (dense || !isempty(saveat)) && verbose - @warn("Rosenbrock methods on equations without differential states do not bound the error on interpolations.") - end - end - - if !isempty(saveat) && dense - @warn("Dense output is incompatible with saveat. Please use the SavingCallback from the Callback Library to mix the two behaviors.") - end - - progress && @logmsg(LogLevel(-1), progress_name, _id=progress_id, progress=0) - - tType = eltype(prob.tspan) - tspan = prob.tspan - tdir = sign(tspan[end] - tspan[1]) - - t = tspan[1] - - if (((!(alg isa OrdinaryDiffEqAdaptiveAlgorithm) && - !(alg isa OrdinaryDiffEqCompositeAlgorithm) && - !(alg isa DAEAlgorithm)) || !adaptive || !isadaptive(alg)) && - dt == tType(0) && isempty(tstops)) && - !(alg isa Union{FunctionMap, LinearExponential}) - error("Fixed timestep methods require a choice of dt or choosing the tstops") - end - - isdae = alg isa DAEAlgorithm || (!(prob isa DiscreteProblem) && - prob.f.mass_matrix != I && - !(prob.f.mass_matrix isa Tuple) && - ArrayInterface.issingular(prob.f.mass_matrix)) - if alg isa CompositeAlgorithm && alg.choice_function isa AutoSwitch - auto = alg.choice_function - _alg = CompositeAlgorithm(alg.algs, - AutoSwitchCache(0, 0, - auto.nonstiffalg, - auto.stiffalg, - auto.stiffalgfirst, - auto.maxstiffstep, - auto.maxnonstiffstep, - auto.nonstifftol, - auto.stifftol, - auto.dtfac, - auto.stiffalgfirst, - auto.switch_max)) - else - _alg = alg - end - f = prob.f - p = prob.p - - # Get the control variables - - if alias_u0 - u = prob.u0 - else - u = recursivecopy(prob.u0) - end - - if _alg isa DAEAlgorithm - if alias_du0 - du = prob.du0 - else - du = recursivecopy(prob.du0) - end - duprev = recursivecopy(du) - else - du = nothing - duprev = nothing - end - - uType = typeof(u) - uBottomEltype = recursive_bottom_eltype(u) - uBottomEltypeNoUnits = recursive_unitless_bottom_eltype(u) - - uEltypeNoUnits = recursive_unitless_eltype(u) - tTypeNoUnits = typeof(one(tType)) - - if _alg isa FunctionMap - abstol_internal = false - elseif abstol === nothing - if uBottomEltypeNoUnits == uBottomEltype - abstol_internal = ForwardDiff.value(real(convert(uBottomEltype, - oneunit(uBottomEltype) * - 1 // 10^6))) - else - abstol_internal = ForwardDiff.value.(real.(oneunit.(u) .* 1 // 10^6)) - end - else - abstol_internal = real.(abstol) - end - - if _alg isa FunctionMap - reltol_internal = false - elseif reltol === nothing - if uBottomEltypeNoUnits == uBottomEltype - reltol_internal = real(convert(uBottomEltype, - oneunit(uBottomEltype) * 1 // 10^3)) - else - reltol_internal = real.(oneunit.(u) .* 1 // 10^3) - end - else - reltol_internal = real.(reltol) - end - - dtmax > zero(dtmax) && tdir < 0 && (dtmax *= tdir) # Allow positive dtmax, but auto-convert - # dtmin is all abs => does not care about sign already. - - if !isdae && isinplace(prob) && u isa AbstractArray && eltype(u) <: Number && - uBottomEltypeNoUnits == uBottomEltype && tType == tTypeNoUnits # Could this be more efficient for other arrays? - rate_prototype = recursivecopy(u) - elseif prob isa DAEProblem - rate_prototype = prob.du0 - else - if (uBottomEltypeNoUnits == uBottomEltype && tType == tTypeNoUnits) || - eltype(u) <: Enum - rate_prototype = u - else # has units! - rate_prototype = u / oneunit(tType) - end - end - rateType = typeof(rate_prototype) ## Can be different if united - - if isdae - if uBottomEltype == uBottomEltypeNoUnits - res_prototype = u - else - res_prototype = one(u) - end - resType = typeof(res_prototype) - end - - tstops_internal = initialize_tstops(tType, tstops, d_discontinuities, tspan) - saveat_internal = initialize_saveat(tType, saveat, tspan) - d_discontinuities_internal = initialize_d_discontinuities(tType, d_discontinuities, - tspan) - - callbacks_internal = CallbackSet(callback) - - max_len_cb = DiffEqBase.max_vector_callback_length_int(callbacks_internal) - if max_len_cb !== nothing - uBottomEltypeReal = real(uBottomEltype) - if isinplace(prob) - callback_cache = DiffEqBase.CallbackCache(u, max_len_cb, uBottomEltypeReal, - uBottomEltypeReal) - else - callback_cache = DiffEqBase.CallbackCache(max_len_cb, uBottomEltypeReal, - uBottomEltypeReal) - end - else - callback_cache = nothing - end - - ### Algorithm-specific defaults ### - if save_idxs === nothing - ksEltype = Vector{rateType} - else - ks_prototype = rate_prototype[save_idxs] - ksEltype = Vector{typeof(ks_prototype)} - end - - # Have to convert in case passed in wrong. - if save_idxs === nothing - timeseries = timeseries_init === () ? uType[] : - convert(Vector{uType}, timeseries_init) - else - u_initial = u[save_idxs] - timeseries = timeseries_init === () ? typeof(u_initial)[] : - convert(Vector{uType}, timeseries_init) - end - - ts = ts_init === () ? tType[] : convert(Vector{tType}, ts_init) - ks = ks_init === () ? ksEltype[] : convert(Vector{ksEltype}, ks_init) - alg_choice = _alg isa CompositeAlgorithm ? Int[] : () - - if (!adaptive || !isadaptive(_alg)) && save_everystep && tspan[2] - tspan[1] != Inf - if dt == 0 - steps = length(tstops) - else - # For fixed dt, the only time dtmin makes sense is if it's smaller than eps(). - # Therefore user specified dtmin doesn't matter, but we need to ensure dt>=eps() - # to prevent infinite loops. - abs(dt) < dtmin && - throw(ArgumentError("Supplied dt is smaller than dtmin")) - steps = ceil(Int, internalnorm((tspan[2] - tspan[1]) / dt, tspan[1])) - end - sizehint!(timeseries, steps + 1) - sizehint!(ts, steps + 1) - sizehint!(ks, steps + 1) - elseif save_everystep - sizehint!(timeseries, 50) - sizehint!(ts, 50) - sizehint!(ks, 50) - elseif !isempty(saveat_internal) - savelength = length(saveat_internal) + 1 - if save_start == false - savelength -= 1 - end - if save_end == false && prob.tspan[2] in saveat_internal.valtree - savelength -= 1 - end - sizehint!(timeseries, savelength) - sizehint!(ts, savelength) - sizehint!(ks, savelength) - else - sizehint!(timeseries, 2) - sizehint!(ts, 2) - sizehint!(ks, 2) - end - - QT, EEstT = if tTypeNoUnits <: Integer - typeof(qmin), typeof(qmin) - elseif prob isa DiscreteProblem - # The QT fields are not used for DiscreteProblems - constvalue(tTypeNoUnits), constvalue(tTypeNoUnits) - else - typeof(DiffEqBase.value(internalnorm(u, t))), typeof(internalnorm(u, t)) - end - - k = rateType[] - - if uses_uprev(_alg, adaptive) || calck - uprev = recursivecopy(u) - else - # Some algorithms do not use `uprev` explicitly. In that case, we can save - # some memory by aliasing `uprev = u`, e.g. for "2N" low storage methods. - uprev = u - end - if allow_extrapolation - uprev2 = recursivecopy(u) - else - uprev2 = uprev - end - - if prob isa DAEProblem - cache = alg_cache(_alg, du, u, res_prototype, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, - reltol_internal, p, calck, Val(isinplace(prob))) - else - cache = alg_cache(_alg, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, - tTypeNoUnits, uprev, uprev2, f, t, dt, reltol_internal, p, calck, - Val(isinplace(prob))) - end - - # Setting up the step size controller - if (beta1 !== nothing || beta2 !== nothing) && controller !== nothing - throw(ArgumentError("Setting both the legacy PID parameters `beta1, beta2 = $((beta1, beta2))` and the `controller = $controller` is not allowed.")) - end - - if (beta1 !== nothing || beta2 !== nothing) - message = "Providing the legacy PID parameters `beta1, beta2` is deprecated. Use the keyword argument `controller` instead." - Base.depwarn(message, :init) - Base.depwarn(message, :solve) - end - - if controller === nothing - controller = default_controller(_alg, cache, qoldinit, beta1, beta2) - end - - save_end_user = save_end - save_end = save_end === nothing ? - save_everystep || isempty(saveat) || saveat isa Number || - prob.tspan[2] in saveat : save_end - - opts = DEOptions{typeof(abstol_internal), typeof(reltol_internal), - QT, tType, typeof(controller), - typeof(internalnorm), typeof(internalopnorm), - typeof(save_end_user), - typeof(callbacks_internal), - typeof(isoutofdomain), - typeof(progress_message), typeof(unstable_check), - typeof(tstops_internal), - typeof(d_discontinuities_internal), typeof(userdata), - typeof(save_idxs), - typeof(maxiters), typeof(tstops), - typeof(saveat), typeof(d_discontinuities)}(maxiters, save_everystep, - adaptive, abstol_internal, - reltol_internal, - QT(gamma), QT(qmax), - QT(qmin), - QT(qsteady_max), - QT(qsteady_min), - QT(qoldinit), - QT(failfactor), - tType(dtmax), tType(dtmin), - controller, - internalnorm, - internalopnorm, - save_idxs, tstops_internal, - saveat_internal, - d_discontinuities_internal, - tstops, saveat, - d_discontinuities, - userdata, progress, - progress_steps, - progress_name, - progress_message, - progress_id, - timeseries_errors, - dense_errors, dense, - save_on, save_start, - save_end, save_end_user, - callbacks_internal, - isoutofdomain, - unstable_check, - verbose, calck, force_dtmin, - advance_to_tstop, - stop_at_next_tstop) - - stats = SciMLBase.DEStats(0) - differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u) - - if _alg isa OrdinaryDiffEqCompositeAlgorithm - id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars) - sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, - alg_choice = alg_choice, - calculate_error = false, stats = stats) - else - id = InterpolationData(f, timeseries, ts, ks, dense, cache, differential_vars) - sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, - calculate_error = false, stats = stats) - end - - if recompile_flag == true - FType = typeof(f) - SolType = typeof(sol) - cacheType = typeof(cache) - else - FType = Function - if _alg isa OrdinaryDiffEqAlgorithm - SolType = DiffEqBase.AbstractODESolution - cacheType = OrdinaryDiffEqCache - else - SolType = DiffEqBase.AbstractDAESolution - cacheType = DAECache - end - end - - # rate/state = (state/time)/state = 1/t units, internalnorm drops units - # we don't want to differentiate through eigenvalue estimation - eigen_est = inv(one(tType)) - tprev = t - dtcache = tType(dt) - dtpropose = tType(dt) - iter = 0 - kshortsize = 0 - reeval_fsal = false - u_modified = false - EEst = EEstT(1) - just_hit_tstop = false - isout = false - accept_step = false - force_stepfail = false - last_stepfail = false - do_error_check = true - event_last_time = 0 - vector_event_last_time = 1 - last_event_error = _alg isa FunctionMap ? false : zero(uBottomEltypeNoUnits) - dtchangeable = isdtchangeable(_alg) - q11 = QT(1) - success_iter = 0 - erracc = QT(1) - dtacc = tType(1) - reinitiailize = true - saveiter = 0 # Starts at 0 so first save is at 1 - saveiter_dense = 0 - - integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du), - tType, typeof(p), - typeof(eigen_est), typeof(EEst), - QT, typeof(tdir), typeof(k), SolType, - FType, cacheType, - typeof(opts), fsal_typeof(_alg, rate_prototype), - typeof(last_event_error), typeof(callback_cache), - typeof(initializealg), typeof(differential_vars)}(sol, u, du, k, t, tType(dt), f, p, - uprev, uprev2, duprev, tprev, - _alg, dtcache, dtchangeable, - dtpropose, tdir, eigen_est, EEst, - QT(qoldinit), q11, - erracc, dtacc, success_iter, - iter, saveiter, saveiter_dense, cache, - callback_cache, - kshortsize, force_stepfail, - last_stepfail, - just_hit_tstop, do_error_check, - event_last_time, - vector_event_last_time, - last_event_error, accept_step, - isout, reeval_fsal, - u_modified, reinitiailize, isdae, - opts, stats, initializealg, differential_vars) - - if initialize_integrator - if isdae - DiffEqBase.initialize_dae!(integrator) - end - - if save_start - integrator.saveiter += 1 # Starts at 1 so first save is at 2 - integrator.saveiter_dense += 1 - copyat_or_push!(ts, 1, t) - if save_idxs === nothing - copyat_or_push!(timeseries, 1, integrator.u) - copyat_or_push!(ks, 1, [rate_prototype]) - else - copyat_or_push!(timeseries, 1, u_initial, Val{false}) - copyat_or_push!(ks, 1, [ks_prototype]) - end - else - integrator.saveiter = 0 # Starts at 0 so first save is at 1 - integrator.saveiter_dense = 0 - end - - initialize_callbacks!(integrator, initialize_save) - initialize!(integrator, integrator.cache) - - if _alg isa OrdinaryDiffEqCompositeAlgorithm - # in case user mixes adaptive and non-adaptive algorithms - ensure_behaving_adaptivity!(integrator, integrator.cache) - - if save_start - # Loop to get all of the extra possible saves in callback initialization - for i in 1:(integrator.saveiter) - copyat_or_push!(alg_choice, i, integrator.cache.current) - end - end - end - end - - handle_dt!(integrator) - integrator -end - -function DiffEqBase.solve!(integrator::ODEIntegrator) - @inbounds while !isempty(integrator.opts.tstops) - while integrator.tdir * integrator.t < first(integrator.opts.tstops) - loopheader!(integrator) - if integrator.do_error_check && check_error!(integrator) != ReturnCode.Success - return integrator.sol - end - perform_step!(integrator, integrator.cache) - loopfooter!(integrator) - if isempty(integrator.opts.tstops) - break - end - end - handle_tstop!(integrator) - end - postamble!(integrator) - - f = integrator.sol.prob.f - - if DiffEqBase.has_analytic(f) - DiffEqBase.calculate_solution_errors!(integrator.sol; - timeseries_errors = integrator.opts.timeseries_errors, - dense_errors = integrator.opts.dense_errors) - end - if integrator.sol.retcode != ReturnCode.Default - return integrator.sol - end - integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, ReturnCode.Success) -end - -# Helpers - -function handle_dt!(integrator) - if iszero(integrator.dt) && integrator.opts.adaptive - auto_dt_reset!(integrator) - if sign(integrator.dt) != integrator.tdir && !iszero(integrator.dt) && - !isnan(integrator.dt) - error("Automatic dt setting has the wrong sign. Exiting. Please report this error.") - end - if isnan(integrator.dt) - if integrator.opts.verbose - @warn("Automatic dt set the starting dt as NaN, causing instability. Exiting.") - end - end - elseif integrator.opts.adaptive && integrator.dt > zero(integrator.dt) && - integrator.tdir < 0 - integrator.dt *= integrator.tdir # Allow positive dt, but auto-convert - end -end - -# time stops -@inline function initialize_tstops(::Type{T}, tstops, d_discontinuities, tspan) where {T} - tstops_internal = BinaryHeap{T}(DataStructures.FasterForward()) - - t0, tf = tspan - tdir = sign(tf - t0) - tdir_t0 = tdir * t0 - tdir_tf = tdir * tf - - for t in tstops - tdir_t = tdir * t - tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) - end - for t in d_discontinuities - tdir_t = tdir * t - tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) - end - push!(tstops_internal, tdir_tf) - - return tstops_internal -end - -# saving time points -function initialize_saveat(::Type{T}, saveat, tspan) where {T} - saveat_internal = BinaryHeap{T}(DataStructures.FasterForward()) - - t0, tf = tspan - tdir = sign(tf - t0) - tdir_t0 = tdir * t0 - tdir_tf = tdir * tf - - if saveat isa Number - directional_saveat = tdir * abs(saveat) - for t in (t0 + directional_saveat):directional_saveat:tf - push!(saveat_internal, tdir * t) - end - elseif !isempty(saveat) - for t in saveat - tdir_t = tdir * t - tdir_t0 < tdir_t ≤ tdir_tf && push!(saveat_internal, tdir_t) - end - end - - return saveat_internal -end - -# discontinuities -function initialize_d_discontinuities(::Type{T}, d_discontinuities, tspan) where {T} - d_discontinuities_internal = BinaryHeap{T}(DataStructures.FasterForward()) - sizehint!(d_discontinuities_internal, length(d_discontinuities)) - - t0, tf = tspan - tdir = sign(tf - t0) - - for t in d_discontinuities - push!(d_discontinuities_internal, tdir * t) - end - - return d_discontinuities_internal -end - -function initialize_callbacks!(integrator, initialize_save = true) - t = integrator.t - u = integrator.u - callbacks = integrator.opts.callback - integrator.u_modified = true - - u_modified = initialize!(callbacks, u, t, integrator) - - # if the user modifies u, we need to fix previous values before initializing - # FSAL in order for the starting derivatives to be correct - if u_modified - if isinplace(integrator.sol.prob) - recursivecopy!(integrator.uprev, integrator.u) - else - integrator.uprev = integrator.u - end - - if alg_extrapolates(integrator.alg) - if isinplace(integrator.sol.prob) - recursivecopy!(integrator.uprev2, integrator.uprev) - else - integrator.uprev2 = integrator.uprev - end - end - - if initialize_save && - (any((c) -> c.save_positions[2], callbacks.discrete_callbacks) || - any((c) -> c.save_positions[2], callbacks.continuous_callbacks)) - savevalues!(integrator, true) - end - end - - # reset this as it is now handled so the integrators should proceed as normal - integrator.u_modified = false -end +function DiffEqBase.__solve(prob::Union{DiffEqBase.AbstractODEProblem, + DiffEqBase.AbstractDAEProblem}, + alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, args...; + kwargs...) + integrator = DiffEqBase.__init(prob, alg, args...; kwargs...) + solve!(integrator) + integrator.sol +end + +function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem, + DiffEqBase.AbstractDAEProblem}, + alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, + timeseries_init = (), + ts_init = (), + ks_init = (), + recompile::Type{Val{recompile_flag}} = Val{true}; + saveat = (), + tstops = (), + d_discontinuities = (), + save_idxs = nothing, + save_everystep = isempty(saveat), + save_on = true, + save_start = save_everystep || isempty(saveat) || + saveat isa Number || prob.tspan[1] in saveat, + save_end = nothing, + callback = nothing, + dense = save_everystep && + !(alg isa Union{DAEAlgorithm, FunctionMap}) && + isempty(saveat), + calck = (callback !== nothing && callback !== CallbackSet()) || + (dense) || !isempty(saveat), # and no dense output + dt = alg isa FunctionMap && isempty(tstops) ? + eltype(prob.tspan)(1) : eltype(prob.tspan)(0), + dtmin = eltype(prob.tspan)(0), + dtmax = eltype(prob.tspan)((prob.tspan[end] - prob.tspan[1])), + force_dtmin = false, + adaptive = anyadaptive(alg), + gamma = gamma_default(alg), + abstol = nothing, + reltol = nothing, + qmin = qmin_default(alg), + qmax = qmax_default(alg), + qsteady_min = qsteady_min_default(alg), + qsteady_max = qsteady_max_default(alg), + beta1 = nothing, + beta2 = nothing, + qoldinit = anyadaptive(alg) ? 1 // 10^4 : 0, + controller = nothing, + fullnormalize = true, + failfactor = 2, + maxiters = anyadaptive(alg) ? 1000000 : typemax(Int), + internalnorm = ODE_DEFAULT_NORM, + internalopnorm = LinearAlgebra.opnorm, + isoutofdomain = ODE_DEFAULT_ISOUTOFDOMAIN, + unstable_check = ODE_DEFAULT_UNSTABLE_CHECK, + verbose = true, + timeseries_errors = true, + dense_errors = false, + advance_to_tstop = false, + stop_at_next_tstop = false, + initialize_save = true, + progress = false, + progress_steps = 1000, + progress_name = "ODE", + progress_message = ODE_DEFAULT_PROG_MESSAGE, + progress_id = gensym("OrdinaryDiffEq"), + userdata = nothing, + allow_extrapolation = alg_extrapolates(alg), + initialize_integrator = true, + alias_u0 = false, + alias_du0 = false, + initializealg = DefaultInit(), + kwargs...) where {recompile_flag} + if prob isa DiffEqBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm + error("You cannot use an ODE Algorithm with a DAEProblem") + end + + if prob isa DiffEqBase.AbstractODEProblem && alg isa DAEAlgorithm + error("You cannot use an DAE Algorithm with a ODEProblem") + end + + if prob isa DiffEqBase.ODEProblem + if !(prob.f isa DiffEqBase.DynamicalODEFunction) && alg isa PartitionedAlgorithm + error("You can not use a solver designed for partitioned ODE with this problem. Please choose a solver suitable for your problem") + end + end + + if prob.f isa DynamicalODEFunction && prob.f.mass_matrix isa Tuple + if any(mm != I for mm in prob.f.mass_matrix) + error("This solver is not able to use mass matrices.") + end + elseif !(prob isa DiscreteProblem) && + !(prob isa DiffEqBase.AbstractDAEProblem) && + !is_mass_matrix_alg(alg) && + prob.f.mass_matrix != I + error("This solver is not able to use mass matrices.") + end + + if alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm && + prob.f.mass_matrix isa AbstractMatrix && + all(isequal(0), prob.f.mass_matrix) + # technically this should also warn for zero operators but those are hard to check for + alg isa Union{Rosenbrock23, Rosenbrock32} && error("Rosenbrock23 and Rosenbrock32 require at least one differential variable to produce valid solutions") + if (dense || !isempty(saveat)) && verbose + @warn("Rosenbrock methods on equations without differential states do not bound the error on interpolations.") + end + end + + if !isempty(saveat) && dense + @warn("Dense output is incompatible with saveat. Please use the SavingCallback from the Callback Library to mix the two behaviors.") + end + + progress && @logmsg(LogLevel(-1), progress_name, _id=progress_id, progress=0) + + tType = eltype(prob.tspan) + tspan = prob.tspan + tdir = sign(tspan[end] - tspan[1]) + + t = tspan[1] + + if (((!(alg isa OrdinaryDiffEqAdaptiveAlgorithm) && + !(alg isa OrdinaryDiffEqCompositeAlgorithm) && + !(alg isa DAEAlgorithm)) || !adaptive || !isadaptive(alg)) && + dt == tType(0) && isempty(tstops)) && + !(alg isa Union{FunctionMap, LinearExponential}) + error("Fixed timestep methods require a choice of dt or choosing the tstops") + end + + isdae = alg isa DAEAlgorithm || (!(prob isa DiscreteProblem) && + prob.f.mass_matrix != I && + !(prob.f.mass_matrix isa Tuple) && + ArrayInterface.issingular(prob.f.mass_matrix)) + if alg isa CompositeAlgorithm && alg.choice_function isa AutoSwitch + auto = alg.choice_function + _alg = CompositeAlgorithm(alg.algs, + AutoSwitchCache(auto.algtrait, 0, 0, + auto.nonstiffalg, + auto.stiffalg, + auto.stiffalgfirst, + auto.maxstiffstep, + auto.maxnonstiffstep, + auto.nonstifftol, + auto.stifftol, + auto.dtfac, + auto.stiffalgfirst, + auto.switch_max, 0), Val(allowfallbacks(alg))) + else + _alg = alg + end + f = prob.f + p = prob.p + + # Get the control variables + + if alias_u0 + u = prob.u0 + else + u = recursivecopy(prob.u0) + end + + if _alg isa DAEAlgorithm + if alias_du0 + du = prob.du0 + else + du = recursivecopy(prob.du0) + end + duprev = recursivecopy(du) + else + du = nothing + duprev = nothing + end + + uType = typeof(u) + uBottomEltype = recursive_bottom_eltype(u) + uBottomEltypeNoUnits = recursive_unitless_bottom_eltype(u) + + uEltypeNoUnits = recursive_unitless_eltype(u) + tTypeNoUnits = typeof(one(tType)) + + if _alg isa FunctionMap + abstol_internal = false + elseif abstol === nothing + if uBottomEltypeNoUnits == uBottomEltype + abstol_internal = ForwardDiff.value(real(convert(uBottomEltype, + oneunit(uBottomEltype) * + 1 // 10^6))) + else + abstol_internal = ForwardDiff.value.(real.(oneunit.(u) .* 1 // 10^6)) + end + else + abstol_internal = real.(abstol) + end + + if _alg isa FunctionMap + reltol_internal = false + elseif reltol === nothing + if uBottomEltypeNoUnits == uBottomEltype + reltol_internal = real(convert(uBottomEltype, + oneunit(uBottomEltype) * 1 // 10^3)) + else + reltol_internal = real.(oneunit.(u) .* 1 // 10^3) + end + else + reltol_internal = real.(reltol) + end + + dtmax > zero(dtmax) && tdir < 0 && (dtmax *= tdir) # Allow positive dtmax, but auto-convert + # dtmin is all abs => does not care about sign already. + + if !isdae && isinplace(prob) && u isa AbstractArray && eltype(u) <: Number && + uBottomEltypeNoUnits == uBottomEltype && tType == tTypeNoUnits # Could this be more efficient for other arrays? + rate_prototype = recursivecopy(u) + elseif prob isa DAEProblem + rate_prototype = prob.du0 + else + if (uBottomEltypeNoUnits == uBottomEltype && tType == tTypeNoUnits) || + eltype(u) <: Enum + rate_prototype = u + else # has units! + rate_prototype = u / oneunit(tType) + end + end + rateType = typeof(rate_prototype) ## Can be different if united + + if isdae + if uBottomEltype == uBottomEltypeNoUnits + res_prototype = u + else + res_prototype = one(u) + end + resType = typeof(res_prototype) + end + + tstops_internal = initialize_tstops(tType, tstops, d_discontinuities, tspan) + saveat_internal = initialize_saveat(tType, saveat, tspan) + d_discontinuities_internal = initialize_d_discontinuities(tType, d_discontinuities, + tspan) + + callbacks_internal = CallbackSet(callback) + + max_len_cb = DiffEqBase.max_vector_callback_length_int(callbacks_internal) + if max_len_cb !== nothing + uBottomEltypeReal = real(uBottomEltype) + if isinplace(prob) + callback_cache = DiffEqBase.CallbackCache(u, max_len_cb, uBottomEltypeReal, + uBottomEltypeReal) + else + callback_cache = DiffEqBase.CallbackCache(max_len_cb, uBottomEltypeReal, + uBottomEltypeReal) + end + else + callback_cache = nothing + end + + ### Algorithm-specific defaults ### + if save_idxs === nothing + ksEltype = Vector{rateType} + else + ks_prototype = rate_prototype[save_idxs] + ksEltype = Vector{typeof(ks_prototype)} + end + + # Have to convert in case passed in wrong. + if save_idxs === nothing + timeseries = timeseries_init === () ? uType[] : + convert(Vector{uType}, timeseries_init) + else + u_initial = u[save_idxs] + timeseries = timeseries_init === () ? typeof(u_initial)[] : + convert(Vector{uType}, timeseries_init) + end + + ts = ts_init === () ? tType[] : convert(Vector{tType}, ts_init) + ks = ks_init === () ? ksEltype[] : convert(Vector{ksEltype}, ks_init) + alg_choice = _alg isa CompositeAlgorithm ? Int[] : () + + if (!adaptive || !isadaptive(_alg)) && save_everystep && tspan[2] - tspan[1] != Inf + if dt == 0 + steps = length(tstops) + else + # For fixed dt, the only time dtmin makes sense is if it's smaller than eps(). + # Therefore user specified dtmin doesn't matter, but we need to ensure dt>=eps() + # to prevent infinite loops. + abs(dt) < dtmin && + throw(ArgumentError("Supplied dt is smaller than dtmin")) + steps = ceil(Int, internalnorm((tspan[2] - tspan[1]) / dt, tspan[1])) + end + sizehint!(timeseries, steps + 1) + sizehint!(ts, steps + 1) + sizehint!(ks, steps + 1) + elseif save_everystep + sizehint!(timeseries, 50) + sizehint!(ts, 50) + sizehint!(ks, 50) + elseif !isempty(saveat_internal) + savelength = length(saveat_internal) + 1 + if save_start == false + savelength -= 1 + end + if save_end == false && prob.tspan[2] in saveat_internal.valtree + savelength -= 1 + end + sizehint!(timeseries, savelength) + sizehint!(ts, savelength) + sizehint!(ks, savelength) + else + sizehint!(timeseries, 2) + sizehint!(ts, 2) + sizehint!(ks, 2) + end + + QT, EEstT = if tTypeNoUnits <: Integer + typeof(qmin), typeof(qmin) + elseif prob isa DiscreteProblem + # The QT fields are not used for DiscreteProblems + constvalue(tTypeNoUnits), constvalue(tTypeNoUnits) + else + typeof(DiffEqBase.value(internalnorm(u, t))), typeof(internalnorm(u, t)) + end + + k = rateType[] + + if uses_uprev(_alg, adaptive) || calck + uprev = recursivecopy(u) + else + # Some algorithms do not use `uprev` explicitly. In that case, we can save + # some memory by aliasing `uprev = u`, e.g. for "2N" low storage methods. + uprev = u + end + if allow_extrapolation + uprev2 = recursivecopy(u) + else + uprev2 = uprev + end + + if prob isa DAEProblem + cache = alg_cache(_alg, du, u, res_prototype, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, + reltol_internal, p, calck, Val(isinplace(prob))) + else + cache = alg_cache(_alg, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, + tTypeNoUnits, uprev, uprev2, f, t, dt, reltol_internal, p, calck, + Val(isinplace(prob))) + end + + # Setting up the step size controller + if (beta1 !== nothing || beta2 !== nothing) && controller !== nothing + throw(ArgumentError("Setting both the legacy PID parameters `beta1, beta2 = $((beta1, beta2))` and the `controller = $controller` is not allowed.")) + end + + if (beta1 !== nothing || beta2 !== nothing) + message = "Providing the legacy PID parameters `beta1, beta2` is deprecated. Use the keyword argument `controller` instead." + Base.depwarn(message, :init) + Base.depwarn(message, :solve) + end + + if controller === nothing + controller = default_controller(_alg, cache, qoldinit, beta1, beta2) + end + + save_end_user = save_end + save_end = save_end === nothing ? + save_everystep || isempty(saveat) || saveat isa Number || + prob.tspan[2] in saveat : save_end + + opts = DEOptions{typeof(abstol_internal), typeof(reltol_internal), + QT, tType, typeof(controller), + typeof(internalnorm), typeof(internalopnorm), + typeof(save_end_user), + typeof(callbacks_internal), + typeof(isoutofdomain), + typeof(progress_message), typeof(unstable_check), + typeof(tstops_internal), + typeof(d_discontinuities_internal), typeof(userdata), + typeof(save_idxs), + typeof(maxiters), typeof(tstops), + typeof(saveat), typeof(d_discontinuities)}(maxiters, save_everystep, + adaptive, abstol_internal, + reltol_internal, + QT(gamma), QT(qmax), + QT(qmin), + QT(qsteady_max), + QT(qsteady_min), + QT(qoldinit), + QT(failfactor), + tType(dtmax), tType(dtmin), + controller, + internalnorm, + internalopnorm, + save_idxs, tstops_internal, + saveat_internal, + d_discontinuities_internal, + tstops, saveat, + d_discontinuities, + userdata, progress, + progress_steps, + progress_name, + progress_message, + progress_id, + timeseries_errors, + dense_errors, dense, + save_on, save_start, + save_end, save_end_user, + callbacks_internal, + isoutofdomain, + unstable_check, + verbose, calck, force_dtmin, + advance_to_tstop, + stop_at_next_tstop) + + stats = SciMLBase.DEStats(0) + differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u) + + if _alg isa OrdinaryDiffEqCompositeAlgorithm + id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars) + sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, + dense = dense, k = ks, interp = id, + alg_choice = alg_choice, + calculate_error = false, stats = stats) + else + id = InterpolationData(f, timeseries, ts, ks, dense, cache, differential_vars) + sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, + dense = dense, k = ks, interp = id, + calculate_error = false, stats = stats) + end + + if recompile_flag == true + FType = typeof(f) + SolType = typeof(sol) + cacheType = typeof(cache) + else + FType = Function + if _alg isa OrdinaryDiffEqAlgorithm + SolType = DiffEqBase.AbstractODESolution + cacheType = OrdinaryDiffEqCache + else + SolType = DiffEqBase.AbstractDAESolution + cacheType = DAECache + end + end + + # rate/state = (state/time)/state = 1/t units, internalnorm drops units + # we don't want to differentiate through eigenvalue estimation + eigen_est = inv(one(tType)) + tprev = t + dtcache = tType(dt) + dtpropose = tType(dt) + iter = 0 + kshortsize = 0 + reeval_fsal = false + u_modified = false + EEst = EEstT(1) + just_hit_tstop = false + isout = false + accept_step = false + force_stepfail = false + last_stepfail = false + do_error_check = true + event_last_time = 0 + vector_event_last_time = 1 + last_event_error = _alg isa FunctionMap ? false : zero(uBottomEltypeNoUnits) + dtchangeable = isdtchangeable(_alg) + q11 = QT(1) + success_iter = 0 + erracc = QT(1) + dtacc = tType(1) + reinitiailize = true + saveiter = 0 # Starts at 0 so first save is at 1 + saveiter_dense = 0 + + integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du), + tType, typeof(p), + typeof(eigen_est), typeof(EEst), + QT, typeof(tdir), typeof(k), SolType, + FType, cacheType, + typeof(opts), fsal_typeof(_alg, rate_prototype), + typeof(last_event_error), typeof(callback_cache), + typeof(initializealg), typeof(differential_vars)}(sol, u, du, k, t, tType(dt), f, p, + uprev, uprev2, duprev, tprev, + _alg, dtcache, dtchangeable, + dtpropose, tdir, eigen_est, EEst, + QT(qoldinit), q11, + erracc, dtacc, success_iter, + iter, saveiter, saveiter_dense, cache, + callback_cache, + kshortsize, force_stepfail, + last_stepfail, + just_hit_tstop, do_error_check, + event_last_time, + vector_event_last_time, + last_event_error, accept_step, + isout, reeval_fsal, + u_modified, reinitiailize, isdae, + opts, stats, initializealg, differential_vars) + + if initialize_integrator + if isdae + DiffEqBase.initialize_dae!(integrator) + end + + if save_start + integrator.saveiter += 1 # Starts at 1 so first save is at 2 + integrator.saveiter_dense += 1 + copyat_or_push!(ts, 1, t) + if save_idxs === nothing + copyat_or_push!(timeseries, 1, integrator.u) + copyat_or_push!(ks, 1, [rate_prototype]) + else + copyat_or_push!(timeseries, 1, u_initial, Val{false}) + copyat_or_push!(ks, 1, [ks_prototype]) + end + else + integrator.saveiter = 0 # Starts at 0 so first save is at 1 + integrator.saveiter_dense = 0 + end + + initialize_callbacks!(integrator, initialize_save) + initialize!(integrator, integrator.cache) + + if _alg isa OrdinaryDiffEqCompositeAlgorithm + # in case user mixes adaptive and non-adaptive algorithms + ensure_behaving_adaptivity!(integrator, integrator.cache) + + if save_start + # Loop to get all of the extra possible saves in callback initialization + for i in 1:(integrator.saveiter) + copyat_or_push!(alg_choice, i, integrator.cache.current) + end + end + end + end + + handle_dt!(integrator) + integrator +end + +function DiffEqBase.solve!(integrator::ODEIntegrator) + @inbounds while !isempty(integrator.opts.tstops) + while integrator.tdir * integrator.t < first(integrator.opts.tstops) + loopheader!(integrator) + if integrator.do_error_check && check_error!(integrator) != ReturnCode.Success + return integrator.sol + end + perform_step!(integrator, integrator.cache) + loopfooter!(integrator) + if isempty(integrator.opts.tstops) + break + end + end + handle_tstop!(integrator) + end + postamble!(integrator) + + f = integrator.sol.prob.f + + if DiffEqBase.has_analytic(f) + DiffEqBase.calculate_solution_errors!(integrator.sol; + timeseries_errors = integrator.opts.timeseries_errors, + dense_errors = integrator.opts.dense_errors) + end + if integrator.sol.retcode != ReturnCode.Default + return integrator.sol + end + integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, ReturnCode.Success) +end + +# Helpers + +function handle_dt!(integrator) + if iszero(integrator.dt) && integrator.opts.adaptive + auto_dt_reset!(integrator) + if sign(integrator.dt) != integrator.tdir && !iszero(integrator.dt) && + !isnan(integrator.dt) + error("Automatic dt setting has the wrong sign. Exiting. Please report this error.") + end + if isnan(integrator.dt) + if integrator.opts.verbose + @warn("Automatic dt set the starting dt as NaN, causing instability. Exiting.") + end + end + elseif integrator.opts.adaptive && integrator.dt > zero(integrator.dt) && + integrator.tdir < 0 + integrator.dt *= integrator.tdir # Allow positive dt, but auto-convert + end +end + +# time stops +@inline function initialize_tstops(::Type{T}, tstops, d_discontinuities, tspan) where {T} + tstops_internal = BinaryHeap{T}(DataStructures.FasterForward()) + + t0, tf = tspan + tdir = sign(tf - t0) + tdir_t0 = tdir * t0 + tdir_tf = tdir * tf + + for t in tstops + tdir_t = tdir * t + tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) + end + for t in d_discontinuities + tdir_t = tdir * t + tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) + end + push!(tstops_internal, tdir_tf) + + return tstops_internal +end + +# saving time points +function initialize_saveat(::Type{T}, saveat, tspan) where {T} + saveat_internal = BinaryHeap{T}(DataStructures.FasterForward()) + + t0, tf = tspan + tdir = sign(tf - t0) + tdir_t0 = tdir * t0 + tdir_tf = tdir * tf + + if saveat isa Number + directional_saveat = tdir * abs(saveat) + for t in (t0 + directional_saveat):directional_saveat:tf + push!(saveat_internal, tdir * t) + end + elseif !isempty(saveat) + for t in saveat + tdir_t = tdir * t + tdir_t0 < tdir_t ≤ tdir_tf && push!(saveat_internal, tdir_t) + end + end + + return saveat_internal +end + +# discontinuities +function initialize_d_discontinuities(::Type{T}, d_discontinuities, tspan) where {T} + d_discontinuities_internal = BinaryHeap{T}(DataStructures.FasterForward()) + sizehint!(d_discontinuities_internal, length(d_discontinuities)) + + t0, tf = tspan + tdir = sign(tf - t0) + + for t in d_discontinuities + push!(d_discontinuities_internal, tdir * t) + end + + return d_discontinuities_internal +end + +function initialize_callbacks!(integrator, initialize_save = true) + t = integrator.t + u = integrator.u + callbacks = integrator.opts.callback + integrator.u_modified = true + + u_modified = initialize!(callbacks, u, t, integrator) + + # if the user modifies u, we need to fix previous values before initializing + # FSAL in order for the starting derivatives to be correct + if u_modified + if isinplace(integrator.sol.prob) + recursivecopy!(integrator.uprev, integrator.u) + else + integrator.uprev = integrator.u + end + + if alg_extrapolates(integrator.alg) + if isinplace(integrator.sol.prob) + recursivecopy!(integrator.uprev2, integrator.uprev) + else + integrator.uprev2 = integrator.uprev + end + end + + if initialize_save && + (any((c) -> c.save_positions[2], callbacks.discrete_callbacks) || + any((c) -> c.save_positions[2], callbacks.continuous_callbacks)) + savevalues!(integrator, true) + end + end + + # reset this as it is now handled so the integrators should proceed as normal + integrator.u_modified = false +end