From aa1c546e0d754ec6abb00924829a178b5c225d45 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 1 Jan 2024 15:02:34 -0500 Subject: [PATCH 01/32] Redesign default ODE solver to be fully type-grounded This accomplishes a few things: * Faster precompile times by precompiling less * Full inference of results when using the automatic algorithm * Hopefully faster load times by also precompiling less This is done the same way as * linearsolve https://github.com/SciML/LinearSolve.jl/pull/307 * nonlinearsolve https://github.com/SciML/NonlinearSolve.jl/pull/238 and is thus the more modern SciML way of doing it. It avoids dispatch by having a single algorithm that always generates the full cache and instead of dispatching between algorithms always branches for the choice. It turns out, the mechanism already existed for this in OrdinaryDiffEq... it's CompositeAlgorithm, the same bones as AutoSwitch! As such, this reuses quite a bit of code from the auto-switch algorithms but instead of just having two choices it (currently) has 6 that it chooses between. This means that it has stiffness detection and switching behavior, but also in a size-dependent way. There are still some optimizations to do though. Like LinearSolve.jl, it would be more efficient to have a way to initialize the caches to size zero and then have a way to re-initialize them to the correct size. Right now, it'll generate the same Jacobian N times and it shouldn't need to do that. fix typo / test fix precompilation choices Update src/composite_algs.jl Co-authored-by: Nathanael Bosch Update src/composite_algs.jl switch CompositeCache away from tuple so it can start undef Default Cache fix precompile remove fallbacks remove fallbacks --- Project.toml | 2 +- src/OrdinaryDiffEq.jl | 21 ++- src/alg_utils.jl | 29 ++-- src/algorithms.jl | 37 +++++ src/caches/basic_caches.jl | 55 ++++--- src/caches/verner_caches.jl | 30 ++-- src/composite_algs.jl | 160 +++++++++++++++++---- src/perform_step/composite_perform_step.jl | 149 +++++++++++-------- src/perform_step/verner_rk_perform_step.jl | 42 +++--- src/solve.jl | 22 +-- 10 files changed, 389 insertions(+), 158 deletions(-) diff --git a/Project.toml b/Project.toml index 5b69a55c5b..79aa0340d1 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" @@ -75,7 +76,6 @@ RecursiveArrayTools = "2.36, 3" Reexport = "1.0" SciMLBase = "2.27.1" SciMLOperators = "0.3" -SciMLStructures = "1" SimpleNonlinearSolve = "1" SimpleUnPack = "1" SparseArrays = "1.9" diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index ae6b2001ab..4aa2e5b187 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -26,6 +26,8 @@ using LinearSolve, SimpleNonlinearSolve using LineSearches +import EnumX + import FillArrays: Trues # Interfaces @@ -141,6 +143,7 @@ include("nlsolve/functional.jl") include("nlsolve/newton.jl") include("generic_rosenbrock.jl") +include("composite_algs.jl") include("caches/basic_caches.jl") include("caches/low_order_rk_caches.jl") @@ -234,7 +237,6 @@ include("constants.jl") include("solve.jl") include("initdt.jl") include("interp_func.jl") -include("composite_algs.jl") import PrecompileTools @@ -253,9 +255,14 @@ PrecompileTools.@compile_workload begin Tsit5(), Vern7() ] - stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false), - Rodas5P(), Rodas5P(autodiff = false), - FBDF(), FBDF(autodiff = false) + stiff = [Rosenbrock23(), + Rodas5P(), + FBDF() + ] + + default_ode = [ + DefaultODEAlgorithm(autodiff=false), + DefaultODEAlgorithm() ] autoswitch = [ @@ -284,7 +291,11 @@ PrecompileTools.@compile_workload begin append!(solver_list, stiff) end - if Preferences.@load_preference("PrecompileAutoSwitch", true) + if Preferences.@load_preference("PrecompileDefault", true) + append!(solver_list, default_ode) + end + + if Preferences.@load_preference("PrecompileAutoSwitch", false) append!(solver_list, autoswitch) end diff --git a/src/alg_utils.jl b/src/alg_utils.jl index d82fce2f1e..236646398c 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -172,6 +172,8 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs)) isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs)) +isdtchangeable(alg::DefaultSolverAlgorithm) = true + function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4, CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2}) false @@ -184,12 +186,14 @@ ismultistep(alg::ETD2) = true isadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false isadaptive(alg::OrdinaryDiffEqAdaptiveAlgorithm) = true isadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = all(isadaptive.(alg.algs)) +isadaptive(alg::DefaultSolverAlgorithm) = true isadaptive(alg::DImplicitEuler) = true isadaptive(alg::DABDF2) = true isadaptive(alg::DFBDF) = true anyadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = isadaptive(alg) anyadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = any(isadaptive, alg.algs) +anyadaptive(alg::DefaultSolverAlgorithm) = true isautoswitch(alg) = false isautoswitch(alg::CompositeAlgorithm) = alg.choice_function isa AutoSwitch @@ -199,9 +203,11 @@ function qmin_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) end qmin_default(alg::CompositeAlgorithm) = maximum(qmin_default.(alg.algs)) qmin_default(alg::DP8) = 1 // 3 +qmin_default(alg::DefaultSolverAlgorithm) = 1 // 5 qmax_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = 10 qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs)) +qmax_default(alg::DefaultSolverAlgorithm) = 10 qmax_default(alg::DP8) = 6 qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8 @@ -287,7 +293,7 @@ end function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob) algs = map(alg -> DiffEqBase.prepare_alg(alg, u0, p, prob), alg.algs) - CompositeAlgorithm(algs, alg.choice_function) + CompositeAlgorithm(algs, alg.choice_function,) end has_autodiff(alg::OrdinaryDiffEqAlgorithm) = false @@ -370,6 +376,7 @@ end alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false alg_extrapolates(alg::CompositeAlgorithm) = any(alg_extrapolates.(alg.algs)) +alg_extrapolates(alg::DefaultSolverAlgorithm) = false alg_extrapolates(alg::ImplicitEuler) = true alg_extrapolates(alg::DImplicitEuler) = true alg_extrapolates(alg::DABDF2) = true @@ -733,6 +740,7 @@ alg_order(alg::QPRK98) = 9 alg_maximum_order(alg) = alg_order(alg) alg_maximum_order(alg::CompositeAlgorithm) = maximum(alg_order(x) for x in alg.algs) +alg_maximum_order(alg::DefaultSolverAlgorithm) = 7 alg_maximum_order(alg::ExtrapolationMidpointDeuflhard) = 2(alg.max_order + 1) alg_maximum_order(alg::ImplicitDeuflhardExtrapolation) = 2(alg.max_order + 1) alg_maximum_order(alg::ExtrapolationMidpointHairerWanner) = 2(alg.max_order + 1) @@ -869,6 +877,7 @@ function gamma_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) isadaptive(alg) ? 9 // 10 : 0 end gamma_default(alg::CompositeAlgorithm) = maximum(gamma_default, alg.algs) +gamma_default(alg::DefaultSolverAlgorithm) = 9 // 10 gamma_default(alg::RKC) = 8 // 10 gamma_default(alg::IRKC) = 8 // 10 function gamma_default(alg::ExtrapolationMidpointDeuflhard) @@ -989,14 +998,18 @@ function unwrap_alg(integrator, is_stiff) if !iscomp return alg elseif alg.choice_function isa AutoSwitchCache - if is_stiff === nothing - throwautoswitch(alg) - end - num = is_stiff ? 2 : 1 - if num == 1 - return alg.algs[1] + if alg.choice_function.algtrait isa DefaultODESolver + alg.algs[alg.choice_function.current] else - return alg.algs[2] + if is_stiff === nothing + throwautoswitch(alg) + end + num = is_stiff ? 2 : 1 + if num == 1 + return alg.algs[1] + else + return alg.algs[2] + end end else return _eval_index(identity, alg.algs, integrator.cache.current) diff --git a/src/algorithms.jl b/src/algorithms.jl index bc2a0ac0f5..f481221421 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3242,6 +3242,9 @@ end struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm algs::T choice_function::F + function CompositeAlgorithm(algs, choice_function) + new{typeof(algs), typeof(choice_function)}(algs, choice_function) + end end TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1 @@ -3250,6 +3253,40 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) Base.Experimental.silence!(CompositeAlgorithm) end +mutable struct AutoSwitchCache{Trait, nAlg, sAlg, tolType, T} + algtrait::Trait + count::Int + successive_switches::Int + nonstiffalg::nAlg + stiffalg::sAlg + is_stiffalg::Bool + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int + current::Int +end + +struct AutoSwitch{Trait, nAlg, sAlg, tolType, T} + algtrait::Trait + nonstiffalg::nAlg + stiffalg::sAlg + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int +end + +struct DefaultODESolver end +const DefaultSolverAlgorithm = Union{CompositeAlgorithm{<:Tuple, <:AutoSwitch{DefaultODESolver}}, +CompositeAlgorithm{<:Tuple, <:AutoSwitchCache{DefaultODESolver}}} + ################################################################################ """ MEBDF2: Multistep Method diff --git a/src/caches/basic_caches.jl b/src/caches/basic_caches.jl index fa492dbe7a..5853f6449f 100644 --- a/src/caches/basic_caches.jl +++ b/src/caches/basic_caches.jl @@ -12,24 +12,26 @@ end TruncatedStacktraces.@truncate_stacktrace CompositeCache 1 -if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) - Base.Experimental.silence!(CompositeCache) +mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F} <: OrdinaryDiffEqCache + args::A + choice_function::F + current::Int + cache1::T1 + cache2::T2 + cache3::T3 + cache4::T4 + cache5::T5 + cache6::T6 + function DefaultCache{T1, T2, T3, T4, T5, T6, F}(args, choice_function, current) where {T1, T2, T3, T4, T5, T6, F} + new{T1, T2, T3, T4, T5, T6, typeof(args), F}(args, choice_function, current) + end end -function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype, - ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, - uprev2, f, t, dt, reltol, p, calck, - ::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits, - tTypeNoUnits} - caches = ( - alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, - tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)), - alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, - tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V))) - CompositeCache(caches, alg.choice_function, 1) +TruncatedStacktraces.@truncate_stacktrace DefaultCache 1 + +if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) + Base.Experimental.silence!(CompositeCache) + Base.Experimental.silence!(DefaultCache) end function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -38,7 +40,26 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU ::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} caches = __alg_cache(alg.algs, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)) - CompositeCache(caches, alg.choice_function, 1) + CompositeCache{typeof(caches), typeof(alg.choice_function)}( + caches, alg.choice_function, 1) +end + +function alg_cache(alg::CompositeAlgorithm{Tuple{A1, A2, A3, A4, A5, A6}}, u, + rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, A1, A2, A3, A4, A5, A6} + + args = (u, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, + reltol, p, calck, Val(V)) + argT = map(typeof, args) + T1 = Base.promote_op(alg_cache, A1, argT...) + T2 = Base.promote_op(alg_cache, A2, argT...) + T3 = Base.promote_op(alg_cache, A3, argT...) + T4 = Base.promote_op(alg_cache, A4, argT...) + T5 = Base.promote_op(alg_cache, A5, argT...) + T6 = Base.promote_op(alg_cache, A6, argT...) + DefaultCache{T1, T2, T3, T4, T5, T6, typeof(alg.choice_function)}(args, alg.choice_function, 1) end # map + closure approach doesn't infer diff --git a/src/caches/verner_caches.jl b/src/caches/verner_caches.jl index 08b1de0919..6de34d0559 100644 --- a/src/caches/verner_caches.jl +++ b/src/caches/verner_caches.jl @@ -20,6 +20,7 @@ stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern6Cache 1 @@ -44,11 +45,12 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern6Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, utilde, tmp, rtmp, atmp, tab, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern6ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -56,7 +58,7 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern6Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern6ConstantCache(tab) + Vern6ConstantCache(tab, alg.lazy) end @cache struct Vern7Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -81,6 +83,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern7Cache 1 @@ -105,16 +108,18 @@ function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern7Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, rtmp, atmp, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end -struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern7ConstantCache() + Vern7ConstantCache(alg.lazy) end @cache struct Vern8Cache{uType, rateType, uNoUnitsType, TabType, StageLimiter, StepLimiter, @@ -143,6 +148,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern8Cache 1 @@ -171,11 +177,12 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern8Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, utilde, - tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread) + tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern8ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -183,7 +190,7 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern8ConstantCache(tab) + Vern8ConstantCache(tab, alg.lazy) end @cache struct Vern9Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -214,6 +221,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern9Cache 1 @@ -245,14 +253,16 @@ function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern9Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, utilde, tmp, rtmp, atmp, alg.stage_limiter!, alg.step_limiter!, - alg.thread) + alg.thread, alg.lazy) end -struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern9ConstantCache() + Vern9ConstantCache(alg.lazy) end diff --git a/src/composite_algs.jl b/src/composite_algs.jl index 9f1d4dcf6b..1f76765b68 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -1,30 +1,8 @@ -mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T} - count::Int - successive_switches::Int - nonstiffalg::nAlg - stiffalg::sAlg - is_stiffalg::Bool - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end +### AutoSwitch +### Designed to switch between two solvers, stiff and non-stiff -struct AutoSwitch{nAlg, sAlg, tolType, T} - nonstiffalg::nAlg - stiffalg::sAlg - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end -function AutoSwitch(nonstiffalg, stiffalg; maxstiffstep = 10, maxnonstiffstep = 3, +function AutoSwitch(nonstiffalg, stiffalg, algtrait = nothing; + maxstiffstep = 10, maxnonstiffstep = 3, nonstifftol = 9 // 10, stifftol = 9 // 10, dtfac = 2, stiffalgfirst = false, switch_max = 5) @@ -41,7 +19,6 @@ function is_stiff(integrator, alg, ntol, stol, is_stiffalg) if !bool integrator.alg.choice_function.successive_switches += 1 - integrator.do_error_check = false else integrator.alg.choice_function.successive_switches = 0 end @@ -53,7 +30,11 @@ function is_stiff(integrator, alg, ntol, stol, is_stiffalg) end function (AS::AutoSwitchCache)(integrator) - integrator.iter == 0 && return Int(AS.stiffalgfirst) + 1 + if AS.current == 0 + AS.current = Int(AS.stiffalgfirst) + 1 + return AS.current + end + dt = integrator.dt # Successive stiffness test positives are counted by a positive integer, # and successive stiffness test negatives are counted by a negative integer @@ -68,17 +49,134 @@ function (AS::AutoSwitchCache)(integrator) integrator.dt = dt / AS.dtfac AS.is_stiffalg = false end - return Int(AS.is_stiffalg) + 1 + AS.current = Int(AS.is_stiffalg) + 1 + return AS.current end -function AutoAlgSwitch(nonstiffalg, stiffalg; kwargs...) - AS = AutoSwitch(nonstiffalg, stiffalg; kwargs...) +function AutoAlgSwitch(nonstiffalg::OrdinaryDiffEqAlgorithm, stiffalg::OrdinaryDiffEqAlgorithm, algtrait = nothing; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) CompositeAlgorithm((nonstiffalg, stiffalg), AS) end +function AutoAlgSwitch(nonstiffalg::Tuple, stiffalg::Tuple, algtrait; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) + CompositeAlgorithm((nonstiffalg..., stiffalg...), AS) +end + AutoTsit5(alg; kwargs...) = AutoAlgSwitch(Tsit5(), alg; kwargs...) AutoDP5(alg; kwargs...) = AutoAlgSwitch(DP5(), alg; kwargs...) AutoVern6(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern6(lazy = lazy), alg; kwargs...) AutoVern7(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern7(lazy = lazy), alg; kwargs...) AutoVern8(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern8(lazy = lazy), alg; kwargs...) AutoVern9(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern9(lazy = lazy), alg; kwargs...) + +### Default ODE Solver + +EnumX.@enumx DefaultSolverChoice begin + Tsit5 = 1 + Vern7 = 2 + Rosenbrock23 = 3 + Rodas5P = 4 + FBDF = 5 + KrylovFBDF = 6 +end + +const NUM_NONSTIFF = 2 +const NUM_STIFF = 4 +const LOW_TOL = 1e-6 +const MED_TOL = 1e-2 +const EXTREME_TOL = 1e-9 +const SMALLSIZE = 50 +const MEDIUMSIZE = 500 +const STABILITY_SIZES = (alg_stability_size(Tsit5()), alg_stability_size(Vern7())) +const DEFAULTBETA2S = (beta2_default(Tsit5()), beta2_default(Vern7()), beta2_default(Rosenbrock23()), beta2_default(Rodas5P()), beta2_default(FBDF()), beta2_default(FBDF())) +const DEFAULTBETA1S = (beta1_default(Tsit5(),DEFAULTBETA2S[1]), beta1_default(Vern7(),DEFAULTBETA2S[2]), + beta1_default(Rosenbrock23(), DEFAULTBETA2S[3]), beta1_default(Rodas5P(), DEFAULTBETA2S[4]), + beta1_default(FBDF(), DEFAULTBETA2S[5]), beta1_default(FBDF(), DEFAULTBETA2S[6])) + +callbacks_exists(integrator) = !isempty(integrator.opts.callbacks) +current_nonstiff(current) = ifelse(current <= NUM_NONSTIFF,current,current-NUM_STIFF) + +function DefaultODEAlgorithm(; lazy = true, stiffalgfirst = false, kwargs...) + nonstiff = (Tsit5(), Vern7(lazy = lazy)) + stiff = (Rosenbrock23(;kwargs...), Rodas5P(;kwargs...), FBDF(;kwargs...), FBDF(;linsolve = LinearSolve.KrylovJL_GMRES())) + AutoAlgSwitch(nonstiff, stiff, DefaultODESolver(); stiffalgfirst) +end + +function is_stiff(integrator, alg, ntol, stol, is_stiffalg, current) + eigen_est, dt = integrator.eigen_est, integrator.dt + stiffness = abs(eigen_est * dt / STABILITY_SIZES[nonstiffchoice(integrator.opts.reltol)]) # `abs` here is just for safety + tol = is_stiffalg ? stol : ntol + os = oneunit(stiffness) + bool = stiffness > os * tol + + if !bool + integrator.alg.choice_function.successive_switches += 1 + else + integrator.alg.choice_function.successive_switches = 0 + end + + integrator.do_error_check = (integrator.alg.choice_function.successive_switches > + integrator.alg.choice_function.switch_max || !bool) || + is_stiffalg + bool +end + +function nonstiffchoice(reltol) + x = if reltol < LOW_TOL + DefaultSolverChoice.Vern7 + else + DefaultSolverChoice.Tsit5 + end + Int(x) +end + +function stiffchoice(reltol, len) + x = if len > MEDIUMSIZE + DefaultSolverChoice.KrylovFBDF + elseif len > SMALLSIZE + DefaultSolverChoice.FBDF + else + if reltol < LOW_TOL + DefaultSolverChoice.Rodas5P + else + DefaultSolverChoice.Rosenbrock23 + end + end + Int(x) +end + +function (AS::AutoSwitchCache{DefaultODESolver})(integrator) + + len = length(integrator.u) + reltol = integrator.opts.reltol + + # Chooose the starting method + if AS.current == 0 + choice = if AS.stiffalgfirst || integrator.f.mass_matrix != I + stiffchoice(reltol, len) + else + nonstiffchoice(reltol) + end + AS.current = choice + return AS.current + end + + dt = integrator.dt + # Successive stiffness test positives are counted by a positive integer, + # and successive stiffness test negatives are counted by a negative integer + AS.count = is_stiff(integrator, AS.nonstiffalg, AS.nonstifftol, AS.stifftol, + AS.is_stiffalg, AS.current) ? + AS.count < 0 ? 1 : AS.count + 1 : + AS.count > 0 ? -1 : AS.count - 1 + if (!AS.is_stiffalg && AS.count > AS.maxstiffstep) + integrator.dt = dt * AS.dtfac + AS.is_stiffalg = true + AS.current = stiffchoice(reltol, len) + elseif (AS.is_stiffalg && AS.count < -AS.maxnonstiffstep) + integrator.dt = dt / AS.dtfac + AS.is_stiffalg = false + AS.current = nonstiffchoice(reltol) + end + return AS.current +end diff --git a/src/perform_step/composite_perform_step.jl b/src/perform_step/composite_perform_step.jl index 75057a3adc..36cf00230d 100644 --- a/src/perform_step/composite_perform_step.jl +++ b/src/perform_step/composite_perform_step.jl @@ -1,38 +1,50 @@ -#= - -Maybe do generated functions to reduce dispatch times? - -f(x) = x -g(x,i) = f(x[i]) -g{i}(x,::Type{Val{i}}) = f(x[i]) -@generated function gg(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(f(tup[i])) i->error("unreachable")) - end -h(i) = g((1,1.0,"foo"), i) -h2{i}(::Type{Val{i}}) = g((1,1.0,"foo"), Val{i}) -h3(i) = gg((1,1.0,"foo"), i) -@benchmark h(1) -mean time: 31.822 ns (0.00% GC) -@benchmark h2(Val{1}) -mean time: 1.585 ns (0.00% GC) -@benchmark h3(1) -mean time: 6.423 ns (0.00% GC) - -@generated function foo(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) -end - -@code_typed foo((1,1.0), 1) - -@generated function perform_step!(integrator, cache::CompositeCache, repeat_step=false) - N = length(cache.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) +function initialize!(integrator, cache::DefaultCache) + cache.current = cache.choice_function(integrator) + algs = integrator.alg.algs + if cache.current == 1 + if !isdefined(cache, :cache1) + cache.cache1 = alg_cache(algs[1], cache.args...) + end + initialize!(integrator, cache.cache1) + elseif cache.current == 2 + if !isdefined(cache, :cache2) + cache.cache2 = alg_cache(algs[2], cache.args...) + end + initialize!(integrator, cache.cache2) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[2]) + elseif cache.current == 3 + if !isdefined(cache, :cache3) + cache.cache3 = alg_cache(algs[3], cache.args...) + end + initialize!(integrator, cache.cache3) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[3]) + elseif cache.current == 4 + if !isdefined(cache, :cache4) + cache.cache4 = alg_cache(algs[4], cache.args...) + end + initialize!(integrator, cache.cache4) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[4]) + elseif cache.current == 5 + if !isdefined(cache, :cache5) + cache.cache5 = alg_cache(algs[5], cache.args...) + end + initialize!(integrator, cache.cache5) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[5]) + elseif cache.current == 6 + if !isdefined(cache, :cache6) + cache.cache6 = alg_cache(algs[6], cache.args...) + end + initialize!(integrator, cache.cache6) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[6]) + end + resize!(integrator.k, integrator.kshortsize) end -=# - function initialize!(integrator, cache::CompositeCache) cache.current = cache.choice_function(integrator) if cache.current == 1 @@ -69,28 +81,35 @@ the behaviour is consistent. In particular, prevents dt ⟶ 0 if starting with non-adaptive alg and opts.adaptive=true, and dt=cst if starting with adaptive alg and opts.adaptive=false. """ -function ensure_behaving_adaptivity!(integrator, cache::CompositeCache) +function ensure_behaving_adaptivity!(integrator, cache::Union{DefaultCache, CompositeCache}) if anyadaptive(integrator.alg) && !isadaptive(integrator.alg) integrator.opts.adaptive = isadaptive(integrator.alg.algs[cache.current]) end end -function perform_step!(integrator, cache::CompositeCache, repeat_step = false) +function perform_step!(integrator, cache::DefaultCache, repeat_step = false) if cache.current == 1 - perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) + perform_step!(integrator, @inbounds(cache.cache1), repeat_step) elseif cache.current == 2 - perform_step!(integrator, @inbounds(cache.caches[2]), repeat_step) - else - perform_step!(integrator, @inbounds(cache.caches[cache.current]), repeat_step) + perform_step!(integrator, @inbounds(cache.cache2), repeat_step) + elseif cache.current == 3 + perform_step!(integrator, @inbounds(cache.cache3), repeat_step) + elseif cache.current == 4 + perform_step!(integrator, @inbounds(cache.cache4), repeat_step) + elseif cache.current == 5 + perform_step!(integrator, @inbounds(cache.cache5), repeat_step) + elseif cache.current == 6 + perform_step!(integrator, @inbounds(cache.cache6), repeat_step) end end -function perform_step!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}, - repeat_step = false) where {T1, T2, F} +function perform_step!(integrator, cache::CompositeCache, repeat_step = false) if cache.current == 1 perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) elseif cache.current == 2 perform_step!(integrator, @inbounds(cache.caches[2]), repeat_step) + else + perform_step!(integrator, @inbounds(cache.caches[cache.current]), repeat_step) end end @@ -122,6 +141,24 @@ function choose_algorithm!(integrator, end function choose_algorithm!(integrator, cache::CompositeCache) + new_current = cache.choice_function(integrator) + old_current = cache.current + @inbounds if new_current != old_current + cache.current = new_current + initialize!(integrator, @inbounds(cache.caches[new_current])) + + controller.beta2 = beta2_default(alg2) + controller.beta1 = beta2_default(alg2) + DEFAULTBETA2S + + reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], + integrator.alg.algs[new_current]) + transfer_cache!(integrator, integrator.cache.caches[old_current], + integrator.cache.caches[new_current]) + end +end + +function choose_algorithm!(integrator, cache::CompositeCache{<:Any, <:AutoSwitchCache{DefaultODESolver}}) new_current = cache.choice_function(integrator) old_current = cache.current @inbounds if new_current != old_current @@ -130,26 +167,23 @@ function choose_algorithm!(integrator, cache::CompositeCache) initialize!(integrator, @inbounds(cache.caches[1])) elseif new_current == 2 initialize!(integrator, @inbounds(cache.caches[2])) + elseif new_current == 3 + initialize!(integrator, @inbounds(cache.caches[3])) + elseif new_current == 4 + initialize!(integrator, @inbounds(cache.caches[4])) + elseif new_current == 5 + initialize!(integrator, @inbounds(cache.caches[5])) + elseif new_current == 6 + initialize!(integrator, @inbounds(cache.caches[6])) else initialize!(integrator, @inbounds(cache.caches[new_current])) end - if old_current == 1 && new_current == 2 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[1], - integrator.alg.algs[2]) - transfer_cache!(integrator, integrator.cache.caches[1], - integrator.cache.caches[2]) - elseif old_current == 2 && new_current == 1 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[2], - integrator.alg.algs[1]) - transfer_cache!(integrator, integrator.cache.caches[2], - integrator.cache.caches[1]) - else - reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], - integrator.alg.algs[new_current]) - transfer_cache!(integrator, integrator.cache.caches[old_current], - integrator.cache.caches[new_current]) - end + + # dtchangable, qmin_default, qmax_default, and isadaptive ignored since all same + integrator.opts.controller.beta1 = DEFAULTBETA1S[new_current] + integrator.opts.controller.beta2 = DEFAULTBETA2S[new_current] end + nothing end """ @@ -170,6 +204,7 @@ function reset_alg_dependent_opts!(integrator, alg1, alg2) integrator.opts.qmax == qmax_default(alg2) end reset_alg_dependent_opts!(integrator.opts.controller, alg1, alg2) + nothing end # Write how to transfer the cache variables from one cache to the other diff --git a/src/perform_step/verner_rk_perform_step.jl b/src/perform_step/verner_rk_perform_step.jl index 0a44e217da..7c76e26719 100644 --- a/src/perform_step/verner_rk_perform_step.jl +++ b/src/perform_step/verner_rk_perform_step.jl @@ -2,7 +2,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal integrator.stats.nf += 1 alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -13,7 +13,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) end integrator.k[integrator.kshortsize] = integrator.fsallast - if !alg.lazy + if !cache.lazy @inbounds for i in 10:12 integrator.k[i] = zero(integrator.fsalfirst) end @@ -63,7 +63,7 @@ end integrator.k[9] = k9 alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -94,7 +94,7 @@ end function initialize!(integrator, cache::Vern6Cache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.fsalfirst = cache.k1 integrator.fsallast = cache.k9 @unpack k = integrator @@ -109,7 +109,7 @@ function initialize!(integrator, cache::Vern6Cache) k[8] = cache.k8 k[9] = cache.k9 # Set the pointers - if !alg.lazy + if !cache.lazy k[10] = similar(cache.k1) k[11] = similar(cache.k1) k[12] = similar(cache.k1) @@ -182,7 +182,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -214,7 +214,7 @@ end function initialize!(integrator, cache::Vern7ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -277,7 +277,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern7ExtraStages T T2 @@ -329,7 +329,7 @@ function initialize!(integrator, cache::Vern7Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -342,7 +342,7 @@ function initialize!(integrator, cache::Vern7Cache) k[9] = k9 k[10] = k10 # Setup pointers - if !alg.lazy + if !cache.lazy k[11] = similar(cache.k1) k[12] = similar(cache.k1) k[13] = similar(cache.k1) @@ -433,7 +433,7 @@ end integrator.EEst = integrator.opts.internalnorm(atmp, t) end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache @@ -489,7 +489,7 @@ end function initialize!(integrator, cache::Vern8ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -575,7 +575,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -642,7 +642,7 @@ function initialize!(integrator, cache::Vern8Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -658,7 +658,7 @@ function initialize!(integrator, cache::Vern8Cache) k[12] = k12 k[13] = k13 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 14:21 k[i] = similar(cache.k1) end @@ -764,7 +764,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -851,7 +851,7 @@ end function initialize!(integrator, cache::Vern9ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -945,7 +945,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern9ExtraStages T T2 @@ -1032,7 +1032,7 @@ function initialize!(integrator, cache::Vern9Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) resize!(k, integrator.kshortsize) # k2, k3,k4,k5,k6,k7 are not used in the code (not even in interpolations), we dont need their pointers. # So we mapped k[2] (from integrator) with k8 (from cache), k[3] with k9 and so on. @@ -1047,7 +1047,7 @@ function initialize!(integrator, cache::Vern9Cache) k[9] = k15 k[10] = k16 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 11:20 k[i] = similar(cache.k1) end @@ -1174,7 +1174,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache diff --git a/src/solve.jl b/src/solve.jl index e29136450e..2b5c26efb6 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -136,7 +136,7 @@ function DiffEqBase.__init( if alg isa CompositeAlgorithm && alg.choice_function isa AutoSwitch auto = alg.choice_function _alg = CompositeAlgorithm(alg.algs, - AutoSwitchCache(0, 0, + AutoSwitchCache(auto.algtrait, 0, 0, auto.nonstiffalg, auto.stiffalg, auto.stiffalgfirst, @@ -146,7 +146,7 @@ function DiffEqBase.__init( auto.stifftol, auto.dtfac, auto.stiffalgfirst, - auto.switch_max)) + auto.switch_max, 0)) else _alg = alg end @@ -415,12 +415,18 @@ function DiffEqBase.__init( differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u) - id = InterpolationData( - f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false) - sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, - alg_choice = alg_choice, - calculate_error = false, stats = stats) + if _alg isa OrdinaryDiffEqCompositeAlgorithm + id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars) + sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, + dense = dense, k = ks, interp = id, + alg_choice = alg_choice, + calculate_error = false, stats = stats) + else + id = InterpolationData(f, timeseries, ts, ks, dense, cache, differential_vars) + sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, + dense = dense, k = ks, interp = id, + calculate_error = false, stats = stats) + end if recompile_flag == true FType = typeof(f) From dc7e96f7f70bd793d0fd4b8dd92de71d174744e0 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 9 May 2024 09:32:03 -0400 Subject: [PATCH 02/32] fix typo --- src/composite_algs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/composite_algs.jl b/src/composite_algs.jl index 1f76765b68..29272fcaca 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -6,7 +6,7 @@ function AutoSwitch(nonstiffalg, stiffalg, algtrait = nothing; nonstifftol = 9 // 10, stifftol = 9 // 10, dtfac = 2, stiffalgfirst = false, switch_max = 5) - AutoSwitch(nonstiffalg, stiffalg, maxstiffstep, maxnonstiffstep, + AutoSwitch(algtrait, nonstiffalg, stiffalg, maxstiffstep, maxnonstiffstep, promote(nonstifftol, stifftol)..., dtfac, stiffalgfirst, switch_max) end From 4648f3a755b74e0b29efb1a964603433a795306a Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 9 May 2024 13:55:44 -0400 Subject: [PATCH 03/32] typos --- src/algorithms.jl | 11 ++++------- src/caches/basic_caches.jl | 3 +-- src/solve.jl | 16 ++++------------ 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index f481221421..2e079e13ac 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -2969,7 +2969,7 @@ Scientific Computing, 18 (1), pp. 1-22. differential-algebraic problems. Computational mathematics (2nd revised ed.), Springer (1996) #### ROS2PR, ROS2S, ROS3PR, Scholz4_7 --Rang, Joachim (2014): The Prothero and Robinson example: +-Rang, Joachim (2014): The Prothero and Robinson example: Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods. https://doi.org/10.24355/dbbs.084-201408121139-0 @@ -3014,16 +3014,16 @@ University of Geneva, Switzerland. https://doi.org/10.1016/j.cam.2015.03.010 #### ROS3PRL, ROS3PRL2 --Rang, Joachim (2014): The Prothero and Robinson example: +-Rang, Joachim (2014): The Prothero and Robinson example: Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods. https://doi.org/10.24355/dbbs.084-201408121139-0 #### Rodas5P -- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package. +- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package. In: BIT Numerical Mathematics, 63(2), 2023 #### Rodas23W, Rodas3P, Rodas5Pe, Rodas5Pr -- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications - +- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications - Preprint 2024 https://github.com/hbrs-cse/RosenbrockMethods/blob/main/paper/JuliaPaper.pdf @@ -3242,9 +3242,6 @@ end struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm algs::T choice_function::F - function CompositeAlgorithm(algs, choice_function) - new{typeof(algs), typeof(choice_function)}(algs, choice_function) - end end TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1 diff --git a/src/caches/basic_caches.jl b/src/caches/basic_caches.jl index 5853f6449f..5b11769545 100644 --- a/src/caches/basic_caches.jl +++ b/src/caches/basic_caches.jl @@ -40,8 +40,7 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU ::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} caches = __alg_cache(alg.algs, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)) - CompositeCache{typeof(caches), typeof(alg.choice_function)}( - caches, alg.choice_function, 1) + CompositeCaches(caches, alg.choice_function, 1) end function alg_cache(alg::CompositeAlgorithm{Tuple{A1, A2, A3, A4, A5, A6}}, u, diff --git a/src/solve.jl b/src/solve.jl index 2b5c26efb6..2452e59db6 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -415,18 +415,10 @@ function DiffEqBase.__init( differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u) - if _alg isa OrdinaryDiffEqCompositeAlgorithm - id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars) - sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, - alg_choice = alg_choice, - calculate_error = false, stats = stats) - else - id = InterpolationData(f, timeseries, ts, ks, dense, cache, differential_vars) - sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, - calculate_error = false, stats = stats) - end + id = InterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false) + sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, + dense = dense, k = ks, interp = id, + calculate_error = false, stats = stats) if recompile_flag == true FType = typeof(f) From 96147c973472634fa65b320871b110d96b39ae94 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 9 May 2024 14:05:04 -0400 Subject: [PATCH 04/32] rebase typos --- src/alg_utils.jl | 2 +- src/solve.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 236646398c..2d26ed00a1 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -293,7 +293,7 @@ end function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob) algs = map(alg -> DiffEqBase.prepare_alg(alg, u0, p, prob), alg.algs) - CompositeAlgorithm(algs, alg.choice_function,) + CompositeAlgorithm(algs, alg.choice_function) end has_autodiff(alg::OrdinaryDiffEqAlgorithm) = false diff --git a/src/solve.jl b/src/solve.jl index 2452e59db6..70a890695e 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -417,7 +417,7 @@ function DiffEqBase.__init( id = InterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false) sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, + dense = dense, k = ks, interp = id, alg_choice = alg_choice, calculate_error = false, stats = stats) if recompile_flag == true From b946f2e592de22071cc1478976247b5c51069abc Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 9 May 2024 14:40:00 -0400 Subject: [PATCH 05/32] typo --- src/caches/basic_caches.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/caches/basic_caches.jl b/src/caches/basic_caches.jl index 5b11769545..74e4aec923 100644 --- a/src/caches/basic_caches.jl +++ b/src/caches/basic_caches.jl @@ -40,7 +40,7 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU ::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} caches = __alg_cache(alg.algs, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)) - CompositeCaches(caches, alg.choice_function, 1) + CompositeCache(caches, alg.choice_function, 1) end function alg_cache(alg::CompositeAlgorithm{Tuple{A1, A2, A3, A4, A5, A6}}, u, From 5f19a80e5d81440e25d78763a71f90c344369ac8 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 13 May 2024 16:44:20 -0400 Subject: [PATCH 06/32] fix DelayDiffEq issue --- src/alg_utils.jl | 13 +++------- src/algorithms.jl | 10 ++------ src/composite_algs.jl | 22 +++++++++------- src/perform_step/composite_perform_step.jl | 30 ++++++---------------- src/solve.jl | 2 +- 5 files changed, 28 insertions(+), 49 deletions(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 2d26ed00a1..6827dba8d6 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -172,7 +172,6 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs)) isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs)) -isdtchangeable(alg::DefaultSolverAlgorithm) = true function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4, CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2}) @@ -186,14 +185,12 @@ ismultistep(alg::ETD2) = true isadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false isadaptive(alg::OrdinaryDiffEqAdaptiveAlgorithm) = true isadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = all(isadaptive.(alg.algs)) -isadaptive(alg::DefaultSolverAlgorithm) = true isadaptive(alg::DImplicitEuler) = true isadaptive(alg::DABDF2) = true isadaptive(alg::DFBDF) = true anyadaptive(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = isadaptive(alg) anyadaptive(alg::OrdinaryDiffEqCompositeAlgorithm) = any(isadaptive, alg.algs) -anyadaptive(alg::DefaultSolverAlgorithm) = true isautoswitch(alg) = false isautoswitch(alg::CompositeAlgorithm) = alg.choice_function isa AutoSwitch @@ -203,11 +200,9 @@ function qmin_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) end qmin_default(alg::CompositeAlgorithm) = maximum(qmin_default.(alg.algs)) qmin_default(alg::DP8) = 1 // 3 -qmin_default(alg::DefaultSolverAlgorithm) = 1 // 5 qmax_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = 10 qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs)) -qmax_default(alg::DefaultSolverAlgorithm) = 10 qmax_default(alg::DP8) = 6 qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8 @@ -376,7 +371,6 @@ end alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false alg_extrapolates(alg::CompositeAlgorithm) = any(alg_extrapolates.(alg.algs)) -alg_extrapolates(alg::DefaultSolverAlgorithm) = false alg_extrapolates(alg::ImplicitEuler) = true alg_extrapolates(alg::DImplicitEuler) = true alg_extrapolates(alg::DABDF2) = true @@ -740,7 +734,6 @@ alg_order(alg::QPRK98) = 9 alg_maximum_order(alg) = alg_order(alg) alg_maximum_order(alg::CompositeAlgorithm) = maximum(alg_order(x) for x in alg.algs) -alg_maximum_order(alg::DefaultSolverAlgorithm) = 7 alg_maximum_order(alg::ExtrapolationMidpointDeuflhard) = 2(alg.max_order + 1) alg_maximum_order(alg::ImplicitDeuflhardExtrapolation) = 2(alg.max_order + 1) alg_maximum_order(alg::ExtrapolationMidpointHairerWanner) = 2(alg.max_order + 1) @@ -877,7 +870,6 @@ function gamma_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) isadaptive(alg) ? 9 // 10 : 0 end gamma_default(alg::CompositeAlgorithm) = maximum(gamma_default, alg.algs) -gamma_default(alg::DefaultSolverAlgorithm) = 9 // 10 gamma_default(alg::RKC) = 8 // 10 gamma_default(alg::IRKC) = 8 // 10 function gamma_default(alg::ExtrapolationMidpointDeuflhard) @@ -978,6 +970,9 @@ function unwrap_alg(alg::SciMLBase.DEAlgorithm, is_stiff) if !iscomp return alg elseif alg.choice_function isa AutoSwitchCache + if length(alg.algs) >2 + return alg.algs[alg.choice_function.current] + end if is_stiff === nothing throwautoswitch(alg) end @@ -998,7 +993,7 @@ function unwrap_alg(integrator, is_stiff) if !iscomp return alg elseif alg.choice_function isa AutoSwitchCache - if alg.choice_function.algtrait isa DefaultODESolver + if ralse alg.algs[alg.choice_function.current] else if is_stiff === nothing diff --git a/src/algorithms.jl b/src/algorithms.jl index 2e079e13ac..2287be75e0 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3250,8 +3250,7 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) Base.Experimental.silence!(CompositeAlgorithm) end -mutable struct AutoSwitchCache{Trait, nAlg, sAlg, tolType, T} - algtrait::Trait +mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T} count::Int successive_switches::Int nonstiffalg::nAlg @@ -3267,8 +3266,7 @@ mutable struct AutoSwitchCache{Trait, nAlg, sAlg, tolType, T} current::Int end -struct AutoSwitch{Trait, nAlg, sAlg, tolType, T} - algtrait::Trait +struct AutoSwitch{nAlg, sAlg, tolType, T} nonstiffalg::nAlg stiffalg::sAlg maxstiffstep::Int @@ -3280,10 +3278,6 @@ struct AutoSwitch{Trait, nAlg, sAlg, tolType, T} switch_max::Int end -struct DefaultODESolver end -const DefaultSolverAlgorithm = Union{CompositeAlgorithm{<:Tuple, <:AutoSwitch{DefaultODESolver}}, -CompositeAlgorithm{<:Tuple, <:AutoSwitchCache{DefaultODESolver}}} - ################################################################################ """ MEBDF2: Multistep Method diff --git a/src/composite_algs.jl b/src/composite_algs.jl index 29272fcaca..6e9447a4a4 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -1,12 +1,12 @@ ### AutoSwitch ### Designed to switch between two solvers, stiff and non-stiff -function AutoSwitch(nonstiffalg, stiffalg, algtrait = nothing; +function AutoSwitch(nonstiffalg, stiffalg; maxstiffstep = 10, maxnonstiffstep = 3, nonstifftol = 9 // 10, stifftol = 9 // 10, dtfac = 2, stiffalgfirst = false, switch_max = 5) - AutoSwitch(algtrait, nonstiffalg, stiffalg, maxstiffstep, maxnonstiffstep, + AutoSwitch(nonstiffalg, stiffalg, maxstiffstep, maxnonstiffstep, promote(nonstifftol, stifftol)..., dtfac, stiffalgfirst, switch_max) end @@ -30,6 +30,11 @@ function is_stiff(integrator, alg, ntol, stol, is_stiffalg) end function (AS::AutoSwitchCache)(integrator) + #horrible awful hack + isdefault = integrator.alg isa CompositeAlgorithm{<:Tuple{Tsit5, Vern7, Rosenbrock23, Rodas5P, FBDF, FBDF}} + if isdefault + return default_autoswitch(AS, integrator) + end if AS.current == 0 AS.current = Int(AS.stiffalgfirst) + 1 return AS.current @@ -53,13 +58,13 @@ function (AS::AutoSwitchCache)(integrator) return AS.current end -function AutoAlgSwitch(nonstiffalg::OrdinaryDiffEqAlgorithm, stiffalg::OrdinaryDiffEqAlgorithm, algtrait = nothing; kwargs...) - AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) +function AutoAlgSwitch(nonstiffalg::OrdinaryDiffEqAlgorithm, stiffalg::OrdinaryDiffEqAlgorithm; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg; kwargs...) CompositeAlgorithm((nonstiffalg, stiffalg), AS) end -function AutoAlgSwitch(nonstiffalg::Tuple, stiffalg::Tuple, algtrait; kwargs...) - AS = AutoSwitch(nonstiffalg, stiffalg, algtrait; kwargs...) +function AutoAlgSwitch(nonstiffalg::Tuple, stiffalg::Tuple; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg; kwargs...) CompositeAlgorithm((nonstiffalg..., stiffalg...), AS) end @@ -100,7 +105,7 @@ current_nonstiff(current) = ifelse(current <= NUM_NONSTIFF,current,current-NUM_S function DefaultODEAlgorithm(; lazy = true, stiffalgfirst = false, kwargs...) nonstiff = (Tsit5(), Vern7(lazy = lazy)) stiff = (Rosenbrock23(;kwargs...), Rodas5P(;kwargs...), FBDF(;kwargs...), FBDF(;linsolve = LinearSolve.KrylovJL_GMRES())) - AutoAlgSwitch(nonstiff, stiff, DefaultODESolver(); stiffalgfirst) + AutoAlgSwitch(nonstiff, stiff; stiffalgfirst) end function is_stiff(integrator, alg, ntol, stol, is_stiffalg, current) @@ -146,8 +151,7 @@ function stiffchoice(reltol, len) Int(x) end -function (AS::AutoSwitchCache{DefaultODESolver})(integrator) - +function default_autoswitch(AS::AutoSwitchCache, integrator) len = length(integrator.u) reltol = integrator.opts.reltol diff --git a/src/perform_step/composite_perform_step.jl b/src/perform_step/composite_perform_step.jl index 36cf00230d..ee798709d8 100644 --- a/src/perform_step/composite_perform_step.jl +++ b/src/perform_step/composite_perform_step.jl @@ -141,24 +141,6 @@ function choose_algorithm!(integrator, end function choose_algorithm!(integrator, cache::CompositeCache) - new_current = cache.choice_function(integrator) - old_current = cache.current - @inbounds if new_current != old_current - cache.current = new_current - initialize!(integrator, @inbounds(cache.caches[new_current])) - - controller.beta2 = beta2_default(alg2) - controller.beta1 = beta2_default(alg2) - DEFAULTBETA2S - - reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], - integrator.alg.algs[new_current]) - transfer_cache!(integrator, integrator.cache.caches[old_current], - integrator.cache.caches[new_current]) - end -end - -function choose_algorithm!(integrator, cache::CompositeCache{<:Any, <:AutoSwitchCache{DefaultODESolver}}) new_current = cache.choice_function(integrator) old_current = cache.current @inbounds if new_current != old_current @@ -179,11 +161,15 @@ function choose_algorithm!(integrator, cache::CompositeCache{<:Any, <:AutoSwitch initialize!(integrator, @inbounds(cache.caches[new_current])) end - # dtchangable, qmin_default, qmax_default, and isadaptive ignored since all same - integrator.opts.controller.beta1 = DEFAULTBETA1S[new_current] - integrator.opts.controller.beta2 = DEFAULTBETA2S[new_current] + controller.beta2 = beta2_default(alg2) + controller.beta1 = beta2_default(alg2) + DEFAULTBETA2S + + reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], + integrator.alg.algs[new_current]) + transfer_cache!(integrator, integrator.cache.caches[old_current], + integrator.cache.caches[new_current]) end - nothing end """ diff --git a/src/solve.jl b/src/solve.jl index 70a890695e..560b86dc62 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -136,7 +136,7 @@ function DiffEqBase.__init( if alg isa CompositeAlgorithm && alg.choice_function isa AutoSwitch auto = alg.choice_function _alg = CompositeAlgorithm(alg.algs, - AutoSwitchCache(auto.algtrait, 0, 0, + AutoSwitchCache(0, 0, auto.nonstiffalg, auto.stiffalg, auto.stiffalgfirst, From 51a84e0de281198035a2941b8ddaa793044af2bc Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Tue, 14 May 2024 09:47:12 -0400 Subject: [PATCH 07/32] bugfix --- src/alg_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 6827dba8d6..b4b2b125cb 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -993,7 +993,7 @@ function unwrap_alg(integrator, is_stiff) if !iscomp return alg elseif alg.choice_function isa AutoSwitchCache - if ralse + if length(alg.algs) > 2 alg.algs[alg.choice_function.current] else if is_stiff === nothing From 75ef069d593bb8a04781fc1cabf99725d4711a79 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 14 May 2024 10:34:56 -0400 Subject: [PATCH 08/32] use the default solver by default --- src/solve.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/solve.jl b/src/solve.jl index 560b86dc62..826d67a3ca 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -532,6 +532,13 @@ function DiffEqBase.__init( integrator end +function DiffEqBase.__init(prob::ODEProblem, ::Nothing, args...; kwargs...) + DiffEqBase.__init(prob, DefaultODEAlgorithm(), args...; kwargs...) +end +function DiffEqBase.__solve(prob::ODEProblem, ::Nothing, args...; kwargs...) + DiffEqBase.__solve(prob, DefaultODEAlgorithm(), args...; kwargs...) +end + function DiffEqBase.solve!(integrator::ODEIntegrator) @inbounds while !isempty(integrator.opts.tstops) while integrator.tdir * integrator.t < first(integrator.opts.tstops) From 0458d407b6e15214d4d210f6b9dff9240ee5151d Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 14 May 2024 10:37:49 -0400 Subject: [PATCH 09/32] init current to 0 --- src/algorithms.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/algorithms.jl b/src/algorithms.jl index 2287be75e0..ff2a9ce4f7 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3264,6 +3264,34 @@ mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T} stiffalgfirst::Bool switch_max::Int current::Int + function AutoSwitchCache(count::Int, + successive_switches::Int, + nonstiffalg::nAlg, + stiffalg::sAlg, + is_stiffalg::Bool, + maxstiffstep::Int, + maxnonstiffstep::Int, + nonstifftol::tolType, + stifftol::tolType, + dtfac::T, + stiffalgfirst::Bool, + switch_max::Int, + current::Int=0) where {nAlg, sAlg, tolType, T} + new{nAlg, sAlg, tolType, T}(count, + successive_switches, + nonstiffalg, + stiffalg, + is_stiffalg, + maxstiffstep, + maxnonstiffstep, + nonstifftol, + stifftol, + dtfac, + stiffalgfirst, + switch_max, + current) + end + end struct AutoSwitch{nAlg, sAlg, tolType, T} From 9bdc84bad118eba79b732ee0e494aa334ab6131e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 14 May 2024 15:00:21 -0400 Subject: [PATCH 10/32] add tests --- test/interface/default_solver_tests.jl | 19 +++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 20 insertions(+) create mode 100644 test/interface/default_solver_tests.jl diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl new file mode 100644 index 0000000000..39471b3d4a --- /dev/null +++ b/test/interface/default_solver_tests.jl @@ -0,0 +1,19 @@ +using OrdinaryDiffEq, Test + +f_2dlinear = (du, u, p, t) -> (@. du = p * u) + +prob_ode_2Dlinear = ODEProblem(f_2dlinear, rand(4, 2), (0.0, 1.0), 1.01) +sol = solve(prob_ode_2Dlinear) + +tsitsol = solve(prob_ode_2Dlinear, Tsit5()) +# test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). +@test sol.stats.naccept < tsitsol.stats.naccept + 2 +@test sol.stats.nf < tsitsol.stats.nf + 20 + +sol = solve(prob_ode_2Dlinear, reltol=1e-10) +vernsol = solve(prob_ode_2Dlinear, Vern7(), reltol=1e-10) +# test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). +@test sol.stats.naccept < tsitsol.stats.naccept + 2 +@test sol.stats.nf < tsitsol.stats.nf + 20 + +prob_ode_2Dlinear_stiff = ODEProblem(f_2dlinear, rand(4, 2), (0.0, 1.0), -1.01) diff --git a/test/runtests.jl b/test/runtests.jl index 6cdf509647..a15d6044dc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,7 @@ end @time @safetestset "Controller Tests" include("interface/controllers.jl") @time @safetestset "Inplace Interpolation Tests" include("interface/inplace_interpolation.jl") @time @safetestset "Algebraic Interpolation Tests" include("interface/algebraic_interpolation.jl") + @time @safetestset "Default Solver Tests" include("interface/default_solver_tests.jl") end if !is_APPVEYOR && (GROUP == "All" || GROUP == "InterfaceII" || GROUP == "Interface") From 59539d811d211b7f59256854d94d8723da84057e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 15 May 2024 09:57:57 -0400 Subject: [PATCH 11/32] improve tests --- test/interface/default_solver_tests.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 39471b3d4a..4b0fd2f8e7 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -16,4 +16,14 @@ vernsol = solve(prob_ode_2Dlinear, Vern7(), reltol=1e-10) @test sol.stats.naccept < tsitsol.stats.naccept + 2 @test sol.stats.nf < tsitsol.stats.nf + 20 -prob_ode_2Dlinear_stiff = ODEProblem(f_2dlinear, rand(4, 2), (0.0, 1.0), -1.01) +function rober(u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p + [-k₁ * y₁ + k₃ * y₂ * y₃, + k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2, + k₂ * y₂^2] +end + +prob_rober = ODEProblem(rober, [1.0,0.0,0.0],(0.0,1e5),(0.04,3e7,1e4)) +sol = solve(prob_rober) +rosensol = solve(prob_rober, Rosenbrock23()) From 66fd5222a7e8cdcc14aa1c877b530d45c65f52f0 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 15 May 2024 14:58:31 -0400 Subject: [PATCH 12/32] switching works other than functionwrappers --- src/alg_utils.jl | 6 +- src/perform_step/composite_perform_step.jl | 98 ++++++++++++++-------- test/interface/default_solver_tests.jl | 4 +- 3 files changed, 69 insertions(+), 39 deletions(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index b4b2b125cb..e7bef159cf 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -966,8 +966,7 @@ alg_can_repeat_jac(alg::OrdinaryDiffEqNewtonAdaptiveAlgorithm) = true alg_can_repeat_jac(alg::IRKC) = false function unwrap_alg(alg::SciMLBase.DEAlgorithm, is_stiff) - iscomp = alg isa CompositeAlgorithm - if !iscomp + if !(alg isa CompositeAlgorithm) return alg elseif alg.choice_function isa AutoSwitchCache if length(alg.algs) >2 @@ -989,8 +988,7 @@ end function unwrap_alg(integrator, is_stiff) alg = integrator.alg - iscomp = alg isa CompositeAlgorithm - if !iscomp + if !(alg isa CompositeAlgorithm) return alg elseif alg.choice_function isa AutoSwitchCache if length(alg.algs) > 2 diff --git a/src/perform_step/composite_perform_step.jl b/src/perform_step/composite_perform_step.jl index ee798709d8..2eb19e1688 100644 --- a/src/perform_step/composite_perform_step.jl +++ b/src/perform_step/composite_perform_step.jl @@ -1,43 +1,54 @@ -function initialize!(integrator, cache::DefaultCache) - cache.current = cache.choice_function(integrator) - algs = integrator.alg.algs - if cache.current == 1 +function init_ith_default_cache(cache::DefaultCache, algs, i) + if i == 1 if !isdefined(cache, :cache1) cache.cache1 = alg_cache(algs[1], cache.args...) end - initialize!(integrator, cache.cache1) - elseif cache.current == 2 + elseif i == 2 if !isdefined(cache, :cache2) cache.cache2 = alg_cache(algs[2], cache.args...) end + elseif i == 3 + if !isdefined(cache, :cache3) + cache.cache3 = alg_cache(algs[3], cache.args...) + end + elseif i == 4 + if !isdefined(cache, :cache4) + cache.cache4 = alg_cache(algs[4], cache.args...) + end + elseif i == 5 + if !isdefined(cache, :cache5) + cache.cache5 = alg_cache(algs[5], cache.args...) + end + elseif i == 6 + if !isdefined(cache, :cache6) + cache.cache6 = alg_cache(algs[6], cache.args...) + end + end +end + +function initialize!(integrator, cache::DefaultCache) + cache.current = cache.choice_function(integrator) + algs = integrator.alg.algs + init_ith_default_cache(cache, algs, cache.current) + if cache.current == 1 + initialize!(integrator, cache.cache1) + elseif cache.current == 2 initialize!(integrator, cache.cache2) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[2]) elseif cache.current == 3 - if !isdefined(cache, :cache3) - cache.cache3 = alg_cache(algs[3], cache.args...) - end initialize!(integrator, cache.cache3) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[3]) elseif cache.current == 4 - if !isdefined(cache, :cache4) - cache.cache4 = alg_cache(algs[4], cache.args...) - end initialize!(integrator, cache.cache4) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[4]) elseif cache.current == 5 - if !isdefined(cache, :cache5) - cache.cache5 = alg_cache(algs[5], cache.args...) - end initialize!(integrator, cache.cache5) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[5]) elseif cache.current == 6 - if !isdefined(cache, :cache6) - cache.cache6 = alg_cache(algs[6], cache.args...) - end initialize!(integrator, cache.cache6) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[6]) @@ -88,6 +99,8 @@ function ensure_behaving_adaptivity!(integrator, cache::Union{DefaultCache, Comp end function perform_step!(integrator, cache::DefaultCache, repeat_step = false) + algs = integrator.alg.algs + init_ith_default_cache(cache, algs, cache.current) if cache.current == 1 perform_step!(integrator, @inbounds(cache.cache1), repeat_step) elseif cache.current == 2 @@ -140,35 +153,52 @@ function choose_algorithm!(integrator, end end -function choose_algorithm!(integrator, cache::CompositeCache) +function choose_algorithm!(integrator, cache::DefaultCache) new_current = cache.choice_function(integrator) old_current = cache.current @inbounds if new_current != old_current + algs = integrator.alg.algs cache.current = new_current + init_ith_default_cache(cache, algs, new_current) if new_current == 1 - initialize!(integrator, @inbounds(cache.caches[1])) + initialize!(integrator, @inbounds(cache.cache1)) + new_cache = cache.cache1 elseif new_current == 2 - initialize!(integrator, @inbounds(cache.caches[2])) + initialize!(integrator, @inbounds(cache.cache2)) + new_cache = cache.cache2 elseif new_current == 3 - initialize!(integrator, @inbounds(cache.caches[3])) + initialize!(integrator, @inbounds(cache.cache3)) + new_cache = cache.cache3 elseif new_current == 4 - initialize!(integrator, @inbounds(cache.caches[4])) + initialize!(integrator, @inbounds(cache.cache4)) + new_cache = cache.cache4 elseif new_current == 5 - initialize!(integrator, @inbounds(cache.caches[5])) + initialize!(integrator, @inbounds(cache.cache5)) + new_cache = cache.cache5 elseif new_current == 6 - initialize!(integrator, @inbounds(cache.caches[6])) - else - initialize!(integrator, @inbounds(cache.caches[new_current])) + initialize!(integrator, @inbounds(cache.cache6)) + new_cache = cache.cache6 + end + + if old_current == 1 + old_cache = cache.cache1 + elseif old_current == 2 + old_cache = cache.cache2 + elseif old_current == 3 + old_cache = cache.cache3 + elseif old_current == 4 + old_cache = cache.cache4 + elseif old_current == 5 + old_cache = cache.cache5 + elseif old_current == 6 + old_cache = cache.cache6 end - controller.beta2 = beta2_default(alg2) - controller.beta1 = beta2_default(alg2) - DEFAULTBETA2S + integrator.opts.controller.beta2 = beta2 = beta2_default(algs[new_current]) + integrator.opts.controller.beta1 = beta1_default(algs[new_current], beta2) - reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], - integrator.alg.algs[new_current]) - transfer_cache!(integrator, integrator.cache.caches[old_current], - integrator.cache.caches[new_current]) + reset_alg_dependent_opts!(integrator, algs[old_current], algs[new_current]) + transfer_cache!(integrator, old_cache, new_cache) end end diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 4b0fd2f8e7..5ef6997635 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -23,7 +23,9 @@ function rober(u, p, t) k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2, k₂ * y₂^2] end - prob_rober = ODEProblem(rober, [1.0,0.0,0.0],(0.0,1e5),(0.04,3e7,1e4)) sol = solve(prob_rober) rosensol = solve(prob_rober, Rosenbrock23()) +# test that default isn't much worse than Rosenbrock23 (we expect it to use Rosenbrock23 for this). +@test sol.stats.naccept < rosensol.stats.naccept + 2 +@test sol.stats.nf < rosensol.stats.nf + 20 From 4901a0249b8d08ead5e682af48126a8d414ef08d Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 15 May 2024 15:20:46 -0400 Subject: [PATCH 13/32] it works --- src/alg_utils.jl | 26 ++++++++++---------------- src/algorithms.jl | 5 ++++- src/caches/basic_caches.jl | 4 ++-- src/composite_algs.jl | 2 +- test/interface/default_solver_tests.jl | 8 ++++---- 5 files changed, 21 insertions(+), 24 deletions(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index e7bef159cf..5369311e81 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -209,28 +209,22 @@ qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8 function get_chunksize(alg::OrdinaryDiffEqAlgorithm) error("This algorithm does not have a chunk size defined.") end -get_chunksize(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS) -get_chunksize(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS) -get_chunksize(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = Val(CS) -function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD}, - OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where { - CS, - AD -} +function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS}, + OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS}, + OrdinaryDiffEqImplicitAlgorithm{CS}, + OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS}, + DAEAlgorithm{CS}}) where {CS} Val(CS) end function get_chunksize_int(alg::OrdinaryDiffEqAlgorithm) error("This algorithm does not have a chunk size defined.") end -get_chunksize_int(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = CS -get_chunksize_int(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = CS -get_chunksize_int(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = CS -function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD}, - OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where { - CS, - AD -} +function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS}, + OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS}, + OrdinaryDiffEqImplicitAlgorithm{CS}, + OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS}, + DAEAlgorithm{CS}}) where {CS} CS end # get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg]) diff --git a/src/algorithms.jl b/src/algorithms.jl index ff2a9ce4f7..f2711809e5 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3239,9 +3239,12 @@ end ######################################### -struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm +struct CompositeAlgorithm{CS, T, F} <: OrdinaryDiffEqCompositeAlgorithm algs::T choice_function::F + function CompositeAlgorithm(algs::T, choice_function::F) where {T,F} + new{get_chunksize_int(algs[end]), T, F}(algs, choice_function) + end end TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1 diff --git a/src/caches/basic_caches.jl b/src/caches/basic_caches.jl index 74e4aec923..7134524176 100644 --- a/src/caches/basic_caches.jl +++ b/src/caches/basic_caches.jl @@ -43,10 +43,10 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU CompositeCache(caches, alg.choice_function, 1) end -function alg_cache(alg::CompositeAlgorithm{Tuple{A1, A2, A3, A4, A5, A6}}, u, +function alg_cache(alg::CompositeAlgorithm{CS, Tuple{A1, A2, A3, A4, A5, A6}}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{V}) where {V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, A1, A2, A3, A4, A5, A6} + ::Val{V}) where {CS, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, A1, A2, A3, A4, A5, A6} args = (u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, diff --git a/src/composite_algs.jl b/src/composite_algs.jl index 6e9447a4a4..32e60b2c58 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -31,7 +31,7 @@ end function (AS::AutoSwitchCache)(integrator) #horrible awful hack - isdefault = integrator.alg isa CompositeAlgorithm{<:Tuple{Tsit5, Vern7, Rosenbrock23, Rodas5P, FBDF, FBDF}} + isdefault = integrator.alg isa CompositeAlgorithm{<:Any, <:Tuple{Tsit5, Vern7, Rosenbrock23, Rodas5P, FBDF, FBDF}} if isdefault return default_autoswitch(AS, integrator) end diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 5ef6997635..f21362c3b0 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -25,7 +25,7 @@ function rober(u, p, t) end prob_rober = ODEProblem(rober, [1.0,0.0,0.0],(0.0,1e5),(0.04,3e7,1e4)) sol = solve(prob_rober) -rosensol = solve(prob_rober, Rosenbrock23()) -# test that default isn't much worse than Rosenbrock23 (we expect it to use Rosenbrock23 for this). -@test sol.stats.naccept < rosensol.stats.naccept + 2 -@test sol.stats.nf < rosensol.stats.nf + 20 +rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23())) +# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). +@test sol.stats.naccept == rosensol.stats.naccept +@test sol.stats.nf == rosensol.stats.nf From 971fdf6479531eff37fa575a9fd4a49df3345b5b Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 15 May 2024 15:30:19 -0400 Subject: [PATCH 14/32] better test --- test/interface/default_solver_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index f21362c3b0..68d7774f45 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -7,8 +7,8 @@ sol = solve(prob_ode_2Dlinear) tsitsol = solve(prob_ode_2Dlinear, Tsit5()) # test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). -@test sol.stats.naccept < tsitsol.stats.naccept + 2 -@test sol.stats.nf < tsitsol.stats.nf + 20 +@test sol.stats.naccept == tsitsol.stats.naccept +@test sol.stats.nf == tsitsol.stats.nf sol = solve(prob_ode_2Dlinear, reltol=1e-10) vernsol = solve(prob_ode_2Dlinear, Vern7(), reltol=1e-10) From 868d8363f47ac2758a7f44035579d27b523f0e41 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 15 May 2024 17:11:19 -0400 Subject: [PATCH 15/32] better tests --- test/interface/default_solver_tests.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 68d7774f45..8554e5d696 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -9,12 +9,14 @@ tsitsol = solve(prob_ode_2Dlinear, Tsit5()) # test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). @test sol.stats.naccept == tsitsol.stats.naccept @test sol.stats.nf == tsitsol.stats.nf +@test unique(sol.alg_choice) = [1] sol = solve(prob_ode_2Dlinear, reltol=1e-10) vernsol = solve(prob_ode_2Dlinear, Vern7(), reltol=1e-10) # test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). @test sol.stats.naccept < tsitsol.stats.naccept + 2 @test sol.stats.nf < tsitsol.stats.nf + 20 +@test unique(sol.alg_choice) = [2] function rober(u, p, t) y₁, y₂, y₃ = u @@ -23,9 +25,18 @@ function rober(u, p, t) k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2, k₂ * y₂^2] end -prob_rober = ODEProblem(rober, [1.0,0.0,0.0],(0.0,1e5),(0.04,3e7,1e4)) +prob_rober = ODEProblem(rober, [1.0,0.0,0.0],(0.0,1e3),(0.04,3e7,1e4)) sol = solve(prob_rober) rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23())) # test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). @test sol.stats.naccept == rosensol.stats.naccept @test sol.stats.nf == rosensol.stats.nf +@test sol.alg_choice[end] == 3 +@test unique(sol.alg_choice) == [1,3] + +sol = solve(prob_rober, reltol=1e-7, abstol=1e-7) +rosensol = solve(prob_rober, AutoVern7(Rodas5P(), lazy=false), reltol=1e-7, abstol=1e-7) +# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). +@test sol.stats.naccept == rosensol.stats.naccept +@test sol.stats.nf == rosensol.stats.nf +@test unique(sol.alg_choice) == [2,4] From 217fb5302e3e813750c3ec8cdfd89b8b79b030f4 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 15 May 2024 17:19:27 -0400 Subject: [PATCH 16/32] better tests --- test/interface/default_solver_tests.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 8554e5d696..e43ebc786c 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -9,14 +9,14 @@ tsitsol = solve(prob_ode_2Dlinear, Tsit5()) # test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). @test sol.stats.naccept == tsitsol.stats.naccept @test sol.stats.nf == tsitsol.stats.nf -@test unique(sol.alg_choice) = [1] +@test all(isequal(1), sol.alg_choice) sol = solve(prob_ode_2Dlinear, reltol=1e-10) vernsol = solve(prob_ode_2Dlinear, Vern7(), reltol=1e-10) # test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). @test sol.stats.naccept < tsitsol.stats.naccept + 2 @test sol.stats.nf < tsitsol.stats.nf + 20 -@test unique(sol.alg_choice) = [2] +@test all(isequal(2), sol.alg_choice) function rober(u, p, t) y₁, y₂, y₃ = u @@ -31,8 +31,9 @@ rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23())) # test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). @test sol.stats.naccept == rosensol.stats.naccept @test sol.stats.nf == rosensol.stats.nf -@test sol.alg_choice[end] == 3 @test unique(sol.alg_choice) == [1,3] +@test sol.alg_choice[1] == 1 +@test sol.alg_choice[end] == 3 sol = solve(prob_rober, reltol=1e-7, abstol=1e-7) rosensol = solve(prob_rober, AutoVern7(Rodas5P(), lazy=false), reltol=1e-7, abstol=1e-7) @@ -40,3 +41,5 @@ rosensol = solve(prob_rober, AutoVern7(Rodas5P(), lazy=false), reltol=1e-7, abst @test sol.stats.naccept == rosensol.stats.naccept @test sol.stats.nf == rosensol.stats.nf @test unique(sol.alg_choice) == [2,4] +@test sol.alg_choice[1] == 2 +@test sol.alg_choice[end] == 4 From 43e117a4d763b48580f1118abcdb2e3e4fc15d00 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 15 May 2024 17:31:44 -0400 Subject: [PATCH 17/32] fix composite chunksize --- src/alg_utils.jl | 7 +++++++ src/algorithms.jl | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 5369311e81..6dd89c5773 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -206,6 +206,13 @@ qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs)) qmax_default(alg::DP8) = 6 qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8 +function has_chunksize(alg::OrdinaryDiffEqAlgorithm) + return alg isa Union{OrdinaryDiffEqExponentialAlgorithm, + OrdinaryDiffEqAdaptiveExponentialAlgorithm, + OrdinaryDiffEqImplicitAlgorithm, + OrdinaryDiffEqAdaptiveImplicitAlgorithm, + DAEAlgorithm} +end function get_chunksize(alg::OrdinaryDiffEqAlgorithm) error("This algorithm does not have a chunk size defined.") end diff --git a/src/algorithms.jl b/src/algorithms.jl index f2711809e5..3ae407d4d3 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3243,7 +3243,13 @@ struct CompositeAlgorithm{CS, T, F} <: OrdinaryDiffEqCompositeAlgorithm algs::T choice_function::F function CompositeAlgorithm(algs::T, choice_function::F) where {T,F} - new{get_chunksize_int(algs[end]), T, F}(algs, choice_function) + CS = 0 + for alg in algs + if has_chunksize(alg) + CS = get_chunksize_int(alg) + end + end + new{CS, T, F}(algs, choice_function) end end From a5fa00f1e1c0c500a4eb60f7fef277c43f7621f1 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 15 May 2024 17:51:58 -0400 Subject: [PATCH 18/32] fix test --- test/interface/default_solver_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index e43ebc786c..0a3a8669c9 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -12,7 +12,7 @@ tsitsol = solve(prob_ode_2Dlinear, Tsit5()) @test all(isequal(1), sol.alg_choice) sol = solve(prob_ode_2Dlinear, reltol=1e-10) -vernsol = solve(prob_ode_2Dlinear, Vern7(), reltol=1e-10) +vernsol = solve(prob_ode_2Dlinear, AutoVern7(Rodas5P()), reltol=1e-10) # test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). @test sol.stats.naccept < tsitsol.stats.naccept + 2 @test sol.stats.nf < tsitsol.stats.nf + 20 From b6ddc9e969ba462f24e6bdd235741474e837a266 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 15 May 2024 17:55:39 -0400 Subject: [PATCH 19/32] forgot to save before commit --- test/interface/default_solver_tests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 0a3a8669c9..eb720ca521 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -12,10 +12,10 @@ tsitsol = solve(prob_ode_2Dlinear, Tsit5()) @test all(isequal(1), sol.alg_choice) sol = solve(prob_ode_2Dlinear, reltol=1e-10) -vernsol = solve(prob_ode_2Dlinear, AutoVern7(Rodas5P()), reltol=1e-10) +vernsol = solve(prob_ode_2Dlinear, Vern7(), reltol=1e-10) # test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). -@test sol.stats.naccept < tsitsol.stats.naccept + 2 -@test sol.stats.nf < tsitsol.stats.nf + 20 +@test sol.stats.naccept == vernsol.stats.naccept +@test sol.stats.nf == vernsol.stats.nf @test all(isequal(2), sol.alg_choice) function rober(u, p, t) From 9e82e8000c65b0e3b8901560590dddfdb6091156 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Thu, 16 May 2024 00:53:29 -0400 Subject: [PATCH 20/32] fix rober test --- test/interface/default_solver_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index eb720ca521..79477952e3 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -36,7 +36,7 @@ rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23())) @test sol.alg_choice[end] == 3 sol = solve(prob_rober, reltol=1e-7, abstol=1e-7) -rosensol = solve(prob_rober, AutoVern7(Rodas5P(), lazy=false), reltol=1e-7, abstol=1e-7) +rosensol = solve(prob_rober, AutoVern7(Rodas5P()), reltol=1e-7, abstol=1e-7) # test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). @test sol.stats.naccept == rosensol.stats.naccept @test sol.stats.nf == rosensol.stats.nf From e046bbe66217311f6d559dfbbee71422802c3fd5 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 16 May 2024 15:21:30 -0400 Subject: [PATCH 21/32] add FBDF tests --- src/alg_utils.jl | 11 +++++++---- test/interface/default_solver_tests.jl | 24 +++++++++++++++++++++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 6dd89c5773..1afac21fd9 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -211,7 +211,8 @@ function has_chunksize(alg::OrdinaryDiffEqAlgorithm) OrdinaryDiffEqAdaptiveExponentialAlgorithm, OrdinaryDiffEqImplicitAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm, - DAEAlgorithm} + DAEAlgorithm, + CompositeAlgorithm} end function get_chunksize(alg::OrdinaryDiffEqAlgorithm) error("This algorithm does not have a chunk size defined.") @@ -220,7 +221,8 @@ function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS}, OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS}, OrdinaryDiffEqImplicitAlgorithm{CS}, OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS}, - DAEAlgorithm{CS}}) where {CS} + DAEAlgorithm{CS}, + CompositeAlgorithm{CS}}) where {CS} Val(CS) end @@ -231,7 +233,8 @@ function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS}, OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS}, OrdinaryDiffEqImplicitAlgorithm{CS}, OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS}, - DAEAlgorithm{CS}}) where {CS} + DAEAlgorithm{CS}, + CompositeAlgorithm{CS}}) where {CS} CS end # get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg]) @@ -970,7 +973,7 @@ function unwrap_alg(alg::SciMLBase.DEAlgorithm, is_stiff) if !(alg isa CompositeAlgorithm) return alg elseif alg.choice_function isa AutoSwitchCache - if length(alg.algs) >2 + if length(alg.algs) > 2 return alg.algs[alg.choice_function.current] end if is_stiff === nothing diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 79477952e3..f5e4933448 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, Test +using OrdinaryDiffEq, Test, LinearSolve f_2dlinear = (du, u, p, t) -> (@. du = p * u) @@ -43,3 +43,25 @@ rosensol = solve(prob_rober, AutoVern7(Rodas5P()), reltol=1e-7, abstol=1e-7) @test unique(sol.alg_choice) == [2,4] @test sol.alg_choice[1] == 2 @test sol.alg_choice[end] == 4 + +function exrober(du, u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p + du .= vcat([-k₁ * y₁ + k₃ * y₂ * y₃, + k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2, + k₂ * y₂^2, ], u[4:end]) +end + +for n in (100, 1000) + stiffalg = n < 500 ? 5 : 6 + linsolve = n < 500 ? nothing : KrylovJL_GMRES() + prob_ex_rober = ODEProblem(exrober, vcat([1.0,0.0,0.0], ones(n)),(0.0,100.0),(0.04,3e7,1e4)) + sol = solve(prob_ex_rober) + fsol = solve(prob_ex_rober, AutoTsit5(FBDF(;linsolve))) + # test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). + @test sol.stats.naccept == fsol.stats.naccept + @test sol.stats.nf == fsol.stats.nf + @test unique(sol.alg_choice) == [1,stiffalg] + @test sol.alg_choice[1] == 1 + @test sol.alg_choice[end] == stiffalg +end From 091623aaba05f047c5142026972722c878adfe41 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 16 May 2024 16:23:07 -0400 Subject: [PATCH 22/32] in place works --- src/derivative_wrappers.jl | 1 - src/solve.jl | 6 ++++-- test/interface/default_solver_tests.jl | 9 ++++++--- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 236b9a82b0..13306d9bd2 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -284,7 +284,6 @@ function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2, if alg_autodiff(alg) isa AutoForwardDiff _chunksize = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg) # SparseDiffEq uses different convection... - T = if standardtag(alg) typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(u))) else diff --git a/src/solve.jl b/src/solve.jl index 826d67a3ca..330ef41560 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -533,10 +533,12 @@ function DiffEqBase.__init( end function DiffEqBase.__init(prob::ODEProblem, ::Nothing, args...; kwargs...) - DiffEqBase.__init(prob, DefaultODEAlgorithm(), args...; kwargs...) + alg = DiffEqBase.prepare_alg(DefaultODEAlgorithm(), prob.u0, prob.p, prob) + DiffEqBase.__init(prob, alg, args...; kwargs...) end function DiffEqBase.__solve(prob::ODEProblem, ::Nothing, args...; kwargs...) - DiffEqBase.__solve(prob, DefaultODEAlgorithm(), args...; kwargs...) + alg = DiffEqBase.prepare_alg(DefaultODEAlgorithm(), prob.u0, prob.p, prob) + DiffEqBase.__solve(prob, alg, args...; kwargs...) end function DiffEqBase.solve!(integrator::ODEIntegrator) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index f5e4933448..2999d99a66 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, Test, LinearSolve +using OrdinaryDiffEq, Test, LinearSolve, LinearAlgebra, SparseArrays f_2dlinear = (du, u, p, t) -> (@. du = p * u) @@ -52,10 +52,13 @@ function exrober(du, u, p, t) k₂ * y₂^2, ], u[4:end]) end -for n in (100, 1000) +for n in (100, 600) stiffalg = n < 500 ? 5 : 6 linsolve = n < 500 ? nothing : KrylovJL_GMRES() - prob_ex_rober = ODEProblem(exrober, vcat([1.0,0.0,0.0], ones(n)),(0.0,100.0),(0.04,3e7,1e4)) + jac_prototype = sparse(I(n+3)) + jac_prototype[1:3, 1:3] .= 1.0 + + prob_ex_rober = ODEProblem(ODEFunction(exrober; jac_prototype), vcat([1.0,0.0,0.0], ones(n)),(0.0,100.0),(0.04,3e7,1e4)) sol = solve(prob_ex_rober) fsol = solve(prob_ex_rober, AutoTsit5(FBDF(;linsolve))) # test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). From 569a25e39412aa370deec49e116697de4825bd7c Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Thu, 16 May 2024 20:20:30 -0400 Subject: [PATCH 23/32] don't bypass pipeline --- src/solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 330ef41560..9dcacc9156 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -534,11 +534,11 @@ end function DiffEqBase.__init(prob::ODEProblem, ::Nothing, args...; kwargs...) alg = DiffEqBase.prepare_alg(DefaultODEAlgorithm(), prob.u0, prob.p, prob) - DiffEqBase.__init(prob, alg, args...; kwargs...) + DiffEqBase.init(prob, alg, args...; kwargs...) end function DiffEqBase.__solve(prob::ODEProblem, ::Nothing, args...; kwargs...) alg = DiffEqBase.prepare_alg(DefaultODEAlgorithm(), prob.u0, prob.p, prob) - DiffEqBase.__solve(prob, alg, args...; kwargs...) + DiffEqBase.solve(prob, alg, args...; kwargs...) end function DiffEqBase.solve!(integrator::ODEIntegrator) From 28c7fe79e2f280f619b0a6fa3de6cc1778c314c8 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 17 May 2024 07:08:49 -0400 Subject: [PATCH 24/32] Update src/solve.jl --- src/solve.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 9dcacc9156..7c0bbe067e 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -533,8 +533,7 @@ function DiffEqBase.__init( end function DiffEqBase.__init(prob::ODEProblem, ::Nothing, args...; kwargs...) - alg = DiffEqBase.prepare_alg(DefaultODEAlgorithm(), prob.u0, prob.p, prob) - DiffEqBase.init(prob, alg, args...; kwargs...) + DiffEqBase.init(prob, DefaultODEAlgorithm(autodiff=false), args...; kwargs...) end function DiffEqBase.__solve(prob::ODEProblem, ::Nothing, args...; kwargs...) alg = DiffEqBase.prepare_alg(DefaultODEAlgorithm(), prob.u0, prob.p, prob) From 3bd4a71384248aab85fd857b41ec7b1f762ed322 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 17 May 2024 07:09:03 -0400 Subject: [PATCH 25/32] Update src/solve.jl --- src/solve.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 7c0bbe067e..dc91f4644c 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -536,8 +536,7 @@ function DiffEqBase.__init(prob::ODEProblem, ::Nothing, args...; kwargs...) DiffEqBase.init(prob, DefaultODEAlgorithm(autodiff=false), args...; kwargs...) end function DiffEqBase.__solve(prob::ODEProblem, ::Nothing, args...; kwargs...) - alg = DiffEqBase.prepare_alg(DefaultODEAlgorithm(), prob.u0, prob.p, prob) - DiffEqBase.solve(prob, alg, args...; kwargs...) + DiffEqBase.solve(prob, DefaultODEAlgorithm(autodiff=false), args...; kwargs...) end function DiffEqBase.solve!(integrator::ODEIntegrator) From ba897b9c9e634c1fffe9a0d610d274727d1cc3e9 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 17 May 2024 13:01:49 -0400 Subject: [PATCH 26/32] fix test failures and re-enable autodiff --- src/cache_utils.jl | 15 +++++++++++++++ src/solve.jl | 4 ++-- test/interface/default_solver_tests.jl | 4 ++-- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/cache_utils.jl b/src/cache_utils.jl index bb9810562a..0297e8578f 100644 --- a/src/cache_utils.jl +++ b/src/cache_utils.jl @@ -8,6 +8,21 @@ function DiffEqBase.unwrap_cache(integrator::ODEIntegrator, is_stiff) iscomp = alg isa CompositeAlgorithm if !iscomp return cache + elseif cache isa DefaultCache + current = integrator.cache.current + if current == 1 + return cache.cache1 + elseif current == 2 + return cache.cache2 + elseif current == 3 + return cache.cache3 + elseif current == 4 + return cache.cache4 + elseif current == 5 + return cache.cache5 + elseif current == 6 + return cache.cache6 + end elseif alg.choice_function isa AutoSwitch num = is_stiff ? 2 : 1 return cache.caches[num] diff --git a/src/solve.jl b/src/solve.jl index dc91f4644c..94704c38af 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -533,10 +533,10 @@ function DiffEqBase.__init( end function DiffEqBase.__init(prob::ODEProblem, ::Nothing, args...; kwargs...) - DiffEqBase.init(prob, DefaultODEAlgorithm(autodiff=false), args...; kwargs...) + DiffEqBase.init(prob, DefaultODEAlgorithm(), args...; kwargs...) end function DiffEqBase.__solve(prob::ODEProblem, ::Nothing, args...; kwargs...) - DiffEqBase.solve(prob, DefaultODEAlgorithm(autodiff=false), args...; kwargs...) + DiffEqBase.solve(prob, DefaultODEAlgorithm(), args...; kwargs...) end function DiffEqBase.solve!(integrator::ODEIntegrator) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 2999d99a66..13682d3930 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -52,8 +52,8 @@ function exrober(du, u, p, t) k₂ * y₂^2, ], u[4:end]) end -for n in (100, 600) - stiffalg = n < 500 ? 5 : 6 +for n in (100, ) # 600 should be added but currently is broken for unknown reasons + stiffalg = n < 50 ? 4 : n < 500 ? 5 : 6 linsolve = n < 500 ? nothing : KrylovJL_GMRES() jac_prototype = sparse(I(n+3)) jac_prototype[1:3, 1:3] .= 1.0 From fdf946bd1ab311519430a4e4da7992c8301b4e90 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 17 May 2024 16:38:38 -0400 Subject: [PATCH 27/32] disable autodiff --- src/solve.jl | 4 ++-- test/interface/default_solver_tests.jl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 94704c38af..dc91f4644c 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -533,10 +533,10 @@ function DiffEqBase.__init( end function DiffEqBase.__init(prob::ODEProblem, ::Nothing, args...; kwargs...) - DiffEqBase.init(prob, DefaultODEAlgorithm(), args...; kwargs...) + DiffEqBase.init(prob, DefaultODEAlgorithm(autodiff=false), args...; kwargs...) end function DiffEqBase.__solve(prob::ODEProblem, ::Nothing, args...; kwargs...) - DiffEqBase.solve(prob, DefaultODEAlgorithm(), args...; kwargs...) + DiffEqBase.solve(prob, DefaultODEAlgorithm(autodiff=false), args...; kwargs...) end function DiffEqBase.solve!(integrator::ODEIntegrator) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 13682d3930..0ea8bd1be8 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -27,7 +27,7 @@ function rober(u, p, t) end prob_rober = ODEProblem(rober, [1.0,0.0,0.0],(0.0,1e3),(0.04,3e7,1e4)) sol = solve(prob_rober) -rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23())) +rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23(autodiff=false))) # test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). @test sol.stats.naccept == rosensol.stats.naccept @test sol.stats.nf == rosensol.stats.nf @@ -36,7 +36,7 @@ rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23())) @test sol.alg_choice[end] == 3 sol = solve(prob_rober, reltol=1e-7, abstol=1e-7) -rosensol = solve(prob_rober, AutoVern7(Rodas5P()), reltol=1e-7, abstol=1e-7) +rosensol = solve(prob_rober, AutoVern7(Rodas5P(autodiff=false)), reltol=1e-7, abstol=1e-7) # test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). @test sol.stats.naccept == rosensol.stats.naccept @test sol.stats.nf == rosensol.stats.nf @@ -60,7 +60,7 @@ for n in (100, ) # 600 should be added but currently is broken for unknown reaso prob_ex_rober = ODEProblem(ODEFunction(exrober; jac_prototype), vcat([1.0,0.0,0.0], ones(n)),(0.0,100.0),(0.04,3e7,1e4)) sol = solve(prob_ex_rober) - fsol = solve(prob_ex_rober, AutoTsit5(FBDF(;linsolve))) + fsol = solve(prob_ex_rober, AutoTsit5(FBDF(;autodiff=false, linsolve))) # test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). @test sol.stats.naccept == fsol.stats.naccept @test sol.stats.nf == fsol.stats.nf From b832f2fa4e0af69455d23a2a5bdabe39298a07fd Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Mon, 20 May 2024 11:09:50 -0400 Subject: [PATCH 28/32] fix non-identity mass matrix with default solver and test --- src/alg_utils.jl | 2 ++ src/composite_algs.jl | 2 +- test/interface/default_solver_tests.jl | 14 ++++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 1afac21fd9..59512be35e 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -1081,3 +1081,5 @@ is_mass_matrix_alg(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false is_mass_matrix_alg(alg::CompositeAlgorithm) = all(is_mass_matrix_alg, alg.algs) is_mass_matrix_alg(alg::RosenbrockAlgorithm) = true is_mass_matrix_alg(alg::NewtonAlgorithm) = !isesdirk(alg) +# hack for the default alg +is_mass_matrix_alg(alg::CompositeAlgorithm{<:Any, <:Tuple{Tsit5, Vern7, Rosenbrock23, Rodas5P, FBDF, FBDF}}) = true diff --git a/src/composite_algs.jl b/src/composite_algs.jl index 32e60b2c58..2c730e76b1 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -173,7 +173,7 @@ function default_autoswitch(AS::AutoSwitchCache, integrator) AS.is_stiffalg, AS.current) ? AS.count < 0 ? 1 : AS.count + 1 : AS.count > 0 ? -1 : AS.count - 1 - if (!AS.is_stiffalg && AS.count > AS.maxstiffstep) + if integrator.f.mass_matrix != I || (!AS.is_stiffalg && AS.count > AS.maxstiffstep) integrator.dt = dt * AS.dtfac AS.is_stiffalg = true AS.current = stiffchoice(reltol, len) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 0ea8bd1be8..8868644889 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -18,6 +18,11 @@ vernsol = solve(prob_ode_2Dlinear, Vern7(), reltol=1e-10) @test sol.stats.nf == vernsol.stats.nf @test all(isequal(2), sol.alg_choice) +prob_ode_linear_fast = ODEProblem(ODEFunction(f_2dlinear, mass_matrix=2*I(2)), rand(2), (0.0, 1.0), 1.01) +sol = solve(prob_ode_linear_fast) +@test all(isequal(3), sol.alg_choice) +# for some reason the timestepping here is different from regular Rosenbrock23 (including the initial timestep) + function rober(u, p, t) y₁, y₂, y₃ = u k₁, k₂, k₃ = p @@ -68,3 +73,12 @@ for n in (100, ) # 600 should be added but currently is broken for unknown reaso @test sol.alg_choice[1] == 1 @test sol.alg_choice[end] == stiffalg end + +function swaplinear(u, p, t) + [u[2], u[1]].*p +end +swaplinearf = ODEFunction(swaplinear, mass_matrix=ones(2,2)-I(2)) +prob_swaplinear = ODEProblem(swaplinearf, rand(2), (0., 1.) 1.01) +sol = solve(prob_swaplinear, reltol=1e-7) # reltol must be set to avoid running into a bug with Rosenbrock23 +@test all(isequal(4), sol.alg_choice) +# for some reason the timestepping here is different from regular Rodas5P (including the initial timestep) From 4cef511d63d277eb37f7c92cf66e86e4ba4ec0d4 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 20 May 2024 13:24:00 -0400 Subject: [PATCH 29/32] typo --- test/interface/default_solver_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 8868644889..98ca79df13 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -78,7 +78,7 @@ function swaplinear(u, p, t) [u[2], u[1]].*p end swaplinearf = ODEFunction(swaplinear, mass_matrix=ones(2,2)-I(2)) -prob_swaplinear = ODEProblem(swaplinearf, rand(2), (0., 1.) 1.01) +prob_swaplinear = ODEProblem(swaplinearf, rand(2), (0., 1.), 1.01) sol = solve(prob_swaplinear, reltol=1e-7) # reltol must be set to avoid running into a bug with Rosenbrock23 @test all(isequal(4), sol.alg_choice) # for some reason the timestepping here is different from regular Rodas5P (including the initial timestep) From cbf273accf551bfe6c8f9ee2a161cef9589e12d6 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 20 May 2024 13:52:29 -0400 Subject: [PATCH 30/32] fix chunksize inference --- src/algorithms.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 3ae407d4d3..da3b321e31 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3243,12 +3243,7 @@ struct CompositeAlgorithm{CS, T, F} <: OrdinaryDiffEqCompositeAlgorithm algs::T choice_function::F function CompositeAlgorithm(algs::T, choice_function::F) where {T,F} - CS = 0 - for alg in algs - if has_chunksize(alg) - CS = get_chunksize_int(alg) - end - end + CS = mapreduce(alg->has_chunksize(alg) ? get_chunksize_int(alg) : 0, max, algs) new{CS, T, F}(algs, choice_function) end end From fc7600622d4340374ff4f8e659ecd2fd3f2752c9 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 20 May 2024 13:53:23 -0400 Subject: [PATCH 31/32] add inferred test --- test/interface/default_solver_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 98ca79df13..5bc49a4306 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -3,7 +3,7 @@ using OrdinaryDiffEq, Test, LinearSolve, LinearAlgebra, SparseArrays f_2dlinear = (du, u, p, t) -> (@. du = p * u) prob_ode_2Dlinear = ODEProblem(f_2dlinear, rand(4, 2), (0.0, 1.0), 1.01) -sol = solve(prob_ode_2Dlinear) +sol = @inferred solve(prob_ode_2Dlinear) tsitsol = solve(prob_ode_2Dlinear, Tsit5()) # test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). From d7193fbd2d9b92ae6e5ede07ca6434ffaa0f465e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 20 May 2024 15:49:34 -0400 Subject: [PATCH 32/32] fix tests --- test/interface/default_solver_tests.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl index 5bc49a4306..0a0af32148 100644 --- a/test/interface/default_solver_tests.jl +++ b/test/interface/default_solver_tests.jl @@ -70,8 +70,6 @@ for n in (100, ) # 600 should be added but currently is broken for unknown reaso @test sol.stats.naccept == fsol.stats.naccept @test sol.stats.nf == fsol.stats.nf @test unique(sol.alg_choice) == [1,stiffalg] - @test sol.alg_choice[1] == 1 - @test sol.alg_choice[end] == stiffalg end function swaplinear(u, p, t)