diff --git a/Project.toml b/Project.toml index 5b69a55c5b..79aa0340d1 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" @@ -75,7 +76,6 @@ RecursiveArrayTools = "2.36, 3" Reexport = "1.0" SciMLBase = "2.27.1" SciMLOperators = "0.3" -SciMLStructures = "1" SimpleNonlinearSolve = "1" SimpleUnPack = "1" SparseArrays = "1.9" diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index ae6b2001ab..4aa2e5b187 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -26,6 +26,8 @@ using LinearSolve, SimpleNonlinearSolve using LineSearches +import EnumX + import FillArrays: Trues # Interfaces @@ -141,6 +143,7 @@ include("nlsolve/functional.jl") include("nlsolve/newton.jl") include("generic_rosenbrock.jl") +include("composite_algs.jl") include("caches/basic_caches.jl") include("caches/low_order_rk_caches.jl") @@ -234,7 +237,6 @@ include("constants.jl") include("solve.jl") include("initdt.jl") include("interp_func.jl") -include("composite_algs.jl") import PrecompileTools @@ -253,9 +255,14 @@ PrecompileTools.@compile_workload begin Tsit5(), Vern7() ] - stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false), - Rodas5P(), Rodas5P(autodiff = false), - FBDF(), FBDF(autodiff = false) + stiff = [Rosenbrock23(), + Rodas5P(), + FBDF() + ] + + default_ode = [ + DefaultODEAlgorithm(autodiff=false), + DefaultODEAlgorithm() ] autoswitch = [ @@ -284,7 +291,11 @@ PrecompileTools.@compile_workload begin append!(solver_list, stiff) end - if Preferences.@load_preference("PrecompileAutoSwitch", true) + if Preferences.@load_preference("PrecompileDefault", true) + append!(solver_list, default_ode) + end + + if Preferences.@load_preference("PrecompileAutoSwitch", false) append!(solver_list, autoswitch) end diff --git a/src/alg_utils.jl b/src/alg_utils.jl index d82fce2f1e..59512be35e 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -172,6 +172,7 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs)) isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs)) + function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4, CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2}) false @@ -205,31 +206,35 @@ qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs)) qmax_default(alg::DP8) = 6 qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8 +function has_chunksize(alg::OrdinaryDiffEqAlgorithm) + return alg isa Union{OrdinaryDiffEqExponentialAlgorithm, + OrdinaryDiffEqAdaptiveExponentialAlgorithm, + OrdinaryDiffEqImplicitAlgorithm, + OrdinaryDiffEqAdaptiveImplicitAlgorithm, + DAEAlgorithm, + CompositeAlgorithm} +end function get_chunksize(alg::OrdinaryDiffEqAlgorithm) error("This algorithm does not have a chunk size defined.") end -get_chunksize(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS) -get_chunksize(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS) -get_chunksize(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = Val(CS) -function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD}, - OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where { - CS, - AD -} +function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS}, + OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS}, + OrdinaryDiffEqImplicitAlgorithm{CS}, + OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS}, + DAEAlgorithm{CS}, + CompositeAlgorithm{CS}}) where {CS} Val(CS) end function get_chunksize_int(alg::OrdinaryDiffEqAlgorithm) error("This algorithm does not have a chunk size defined.") end -get_chunksize_int(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = CS -get_chunksize_int(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = CS -get_chunksize_int(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = CS -function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD}, - OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where { - CS, - AD -} +function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS}, + OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS}, + OrdinaryDiffEqImplicitAlgorithm{CS}, + OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS}, + DAEAlgorithm{CS}, + CompositeAlgorithm{CS}}) where {CS} CS end # get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg]) @@ -965,10 +970,12 @@ alg_can_repeat_jac(alg::OrdinaryDiffEqNewtonAdaptiveAlgorithm) = true alg_can_repeat_jac(alg::IRKC) = false function unwrap_alg(alg::SciMLBase.DEAlgorithm, is_stiff) - iscomp = alg isa CompositeAlgorithm - if !iscomp + if !(alg isa CompositeAlgorithm) return alg elseif alg.choice_function isa AutoSwitchCache + if length(alg.algs) > 2 + return alg.algs[alg.choice_function.current] + end if is_stiff === nothing throwautoswitch(alg) end @@ -985,18 +992,21 @@ end function unwrap_alg(integrator, is_stiff) alg = integrator.alg - iscomp = alg isa CompositeAlgorithm - if !iscomp + if !(alg isa CompositeAlgorithm) return alg elseif alg.choice_function isa AutoSwitchCache - if is_stiff === nothing - throwautoswitch(alg) - end - num = is_stiff ? 2 : 1 - if num == 1 - return alg.algs[1] + if length(alg.algs) > 2 + alg.algs[alg.choice_function.current] else - return alg.algs[2] + if is_stiff === nothing + throwautoswitch(alg) + end + num = is_stiff ? 2 : 1 + if num == 1 + return alg.algs[1] + else + return alg.algs[2] + end end else return _eval_index(identity, alg.algs, integrator.cache.current) @@ -1071,3 +1081,5 @@ is_mass_matrix_alg(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false is_mass_matrix_alg(alg::CompositeAlgorithm) = all(is_mass_matrix_alg, alg.algs) is_mass_matrix_alg(alg::RosenbrockAlgorithm) = true is_mass_matrix_alg(alg::NewtonAlgorithm) = !isesdirk(alg) +# hack for the default alg +is_mass_matrix_alg(alg::CompositeAlgorithm{<:Any, <:Tuple{Tsit5, Vern7, Rosenbrock23, Rodas5P, FBDF, FBDF}}) = true diff --git a/src/algorithms.jl b/src/algorithms.jl index bc2a0ac0f5..da3b321e31 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -2969,7 +2969,7 @@ Scientific Computing, 18 (1), pp. 1-22. differential-algebraic problems. Computational mathematics (2nd revised ed.), Springer (1996) #### ROS2PR, ROS2S, ROS3PR, Scholz4_7 --Rang, Joachim (2014): The Prothero and Robinson example: +-Rang, Joachim (2014): The Prothero and Robinson example: Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods. https://doi.org/10.24355/dbbs.084-201408121139-0 @@ -3014,16 +3014,16 @@ University of Geneva, Switzerland. https://doi.org/10.1016/j.cam.2015.03.010 #### ROS3PRL, ROS3PRL2 --Rang, Joachim (2014): The Prothero and Robinson example: +-Rang, Joachim (2014): The Prothero and Robinson example: Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods. https://doi.org/10.24355/dbbs.084-201408121139-0 #### Rodas5P -- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package. +- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package. In: BIT Numerical Mathematics, 63(2), 2023 #### Rodas23W, Rodas3P, Rodas5Pe, Rodas5Pr -- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications - +- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications - Preprint 2024 https://github.com/hbrs-cse/RosenbrockMethods/blob/main/paper/JuliaPaper.pdf @@ -3239,9 +3239,13 @@ end ######################################### -struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm +struct CompositeAlgorithm{CS, T, F} <: OrdinaryDiffEqCompositeAlgorithm algs::T choice_function::F + function CompositeAlgorithm(algs::T, choice_function::F) where {T,F} + CS = mapreduce(alg->has_chunksize(alg) ? get_chunksize_int(alg) : 0, max, algs) + new{CS, T, F}(algs, choice_function) + end end TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1 @@ -3250,6 +3254,62 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) Base.Experimental.silence!(CompositeAlgorithm) end +mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T} + count::Int + successive_switches::Int + nonstiffalg::nAlg + stiffalg::sAlg + is_stiffalg::Bool + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int + current::Int + function AutoSwitchCache(count::Int, + successive_switches::Int, + nonstiffalg::nAlg, + stiffalg::sAlg, + is_stiffalg::Bool, + maxstiffstep::Int, + maxnonstiffstep::Int, + nonstifftol::tolType, + stifftol::tolType, + dtfac::T, + stiffalgfirst::Bool, + switch_max::Int, + current::Int=0) where {nAlg, sAlg, tolType, T} + new{nAlg, sAlg, tolType, T}(count, + successive_switches, + nonstiffalg, + stiffalg, + is_stiffalg, + maxstiffstep, + maxnonstiffstep, + nonstifftol, + stifftol, + dtfac, + stiffalgfirst, + switch_max, + current) + end + +end + +struct AutoSwitch{nAlg, sAlg, tolType, T} + nonstiffalg::nAlg + stiffalg::sAlg + maxstiffstep::Int + maxnonstiffstep::Int + nonstifftol::tolType + stifftol::tolType + dtfac::T + stiffalgfirst::Bool + switch_max::Int +end + ################################################################################ """ MEBDF2: Multistep Method diff --git a/src/cache_utils.jl b/src/cache_utils.jl index bb9810562a..0297e8578f 100644 --- a/src/cache_utils.jl +++ b/src/cache_utils.jl @@ -8,6 +8,21 @@ function DiffEqBase.unwrap_cache(integrator::ODEIntegrator, is_stiff) iscomp = alg isa CompositeAlgorithm if !iscomp return cache + elseif cache isa DefaultCache + current = integrator.cache.current + if current == 1 + return cache.cache1 + elseif current == 2 + return cache.cache2 + elseif current == 3 + return cache.cache3 + elseif current == 4 + return cache.cache4 + elseif current == 5 + return cache.cache5 + elseif current == 6 + return cache.cache6 + end elseif alg.choice_function isa AutoSwitch num = is_stiff ? 2 : 1 return cache.caches[num] diff --git a/src/caches/basic_caches.jl b/src/caches/basic_caches.jl index fa492dbe7a..7134524176 100644 --- a/src/caches/basic_caches.jl +++ b/src/caches/basic_caches.jl @@ -12,24 +12,26 @@ end TruncatedStacktraces.@truncate_stacktrace CompositeCache 1 -if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) - Base.Experimental.silence!(CompositeCache) +mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F} <: OrdinaryDiffEqCache + args::A + choice_function::F + current::Int + cache1::T1 + cache2::T2 + cache3::T3 + cache4::T4 + cache5::T5 + cache6::T6 + function DefaultCache{T1, T2, T3, T4, T5, T6, F}(args, choice_function, current) where {T1, T2, T3, T4, T5, T6, F} + new{T1, T2, T3, T4, T5, T6, typeof(args), F}(args, choice_function, current) + end end -function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype, - ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, - uprev2, f, t, dt, reltol, p, calck, - ::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits, - tTypeNoUnits} - caches = ( - alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, - tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)), - alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, - tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V))) - CompositeCache(caches, alg.choice_function, 1) +TruncatedStacktraces.@truncate_stacktrace DefaultCache 1 + +if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!) + Base.Experimental.silence!(CompositeCache) + Base.Experimental.silence!(DefaultCache) end function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -41,6 +43,24 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU CompositeCache(caches, alg.choice_function, 1) end +function alg_cache(alg::CompositeAlgorithm{CS, Tuple{A1, A2, A3, A4, A5, A6}}, u, + rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{V}) where {CS, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, A1, A2, A3, A4, A5, A6} + + args = (u, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt, + reltol, p, calck, Val(V)) + argT = map(typeof, args) + T1 = Base.promote_op(alg_cache, A1, argT...) + T2 = Base.promote_op(alg_cache, A2, argT...) + T3 = Base.promote_op(alg_cache, A3, argT...) + T4 = Base.promote_op(alg_cache, A4, argT...) + T5 = Base.promote_op(alg_cache, A5, argT...) + T6 = Base.promote_op(alg_cache, A6, argT...) + DefaultCache{T1, T2, T3, T4, T5, T6, typeof(alg.choice_function)}(args, alg.choice_function, 1) +end + # map + closure approach doesn't infer @generated function __alg_cache(algs::T, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, diff --git a/src/caches/verner_caches.jl b/src/caches/verner_caches.jl index 08b1de0919..6de34d0559 100644 --- a/src/caches/verner_caches.jl +++ b/src/caches/verner_caches.jl @@ -20,6 +20,7 @@ stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern6Cache 1 @@ -44,11 +45,12 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern6Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, utilde, tmp, rtmp, atmp, tab, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern6ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -56,7 +58,7 @@ function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern6Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern6ConstantCache(tab) + Vern6ConstantCache(tab, alg.lazy) end @cache struct Vern7Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -81,6 +83,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern7Cache 1 @@ -105,16 +108,18 @@ function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern7Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, rtmp, atmp, - alg.stage_limiter!, alg.step_limiter!, alg.thread) + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end -struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern7ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern7ConstantCache() + Vern7ConstantCache(alg.lazy) end @cache struct Vern8Cache{uType, rateType, uNoUnitsType, TabType, StageLimiter, StepLimiter, @@ -143,6 +148,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern8Cache 1 @@ -171,11 +177,12 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, recursivefill!(atmp, false) rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern8Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, utilde, - tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread) + tmp, rtmp, atmp, tab, alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) end struct Vern8ConstantCache{TabType} <: OrdinaryDiffEqConstantCache tab::TabType + lazy::Bool end function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -183,7 +190,7 @@ function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = Vern8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - Vern8ConstantCache(tab) + Vern8ConstantCache(tab, alg.lazy) end @cache struct Vern9Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter, @@ -214,6 +221,7 @@ end stage_limiter!::StageLimiter step_limiter!::StepLimiter thread::Thread + lazy::Bool end TruncatedStacktraces.@truncate_stacktrace Vern9Cache 1 @@ -245,14 +253,16 @@ function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) Vern9Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, utilde, tmp, rtmp, atmp, alg.stage_limiter!, alg.step_limiter!, - alg.thread) + alg.thread, alg.lazy) end -struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache end +struct Vern9ConstantCache <: OrdinaryDiffEqConstantCache + lazy::Bool +end function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - Vern9ConstantCache() + Vern9ConstantCache(alg.lazy) end diff --git a/src/composite_algs.jl b/src/composite_algs.jl index 9f1d4dcf6b..2c730e76b1 100644 --- a/src/composite_algs.jl +++ b/src/composite_algs.jl @@ -1,30 +1,8 @@ -mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T} - count::Int - successive_switches::Int - nonstiffalg::nAlg - stiffalg::sAlg - is_stiffalg::Bool - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end +### AutoSwitch +### Designed to switch between two solvers, stiff and non-stiff -struct AutoSwitch{nAlg, sAlg, tolType, T} - nonstiffalg::nAlg - stiffalg::sAlg - maxstiffstep::Int - maxnonstiffstep::Int - nonstifftol::tolType - stifftol::tolType - dtfac::T - stiffalgfirst::Bool - switch_max::Int -end -function AutoSwitch(nonstiffalg, stiffalg; maxstiffstep = 10, maxnonstiffstep = 3, +function AutoSwitch(nonstiffalg, stiffalg; + maxstiffstep = 10, maxnonstiffstep = 3, nonstifftol = 9 // 10, stifftol = 9 // 10, dtfac = 2, stiffalgfirst = false, switch_max = 5) @@ -41,7 +19,6 @@ function is_stiff(integrator, alg, ntol, stol, is_stiffalg) if !bool integrator.alg.choice_function.successive_switches += 1 - integrator.do_error_check = false else integrator.alg.choice_function.successive_switches = 0 end @@ -53,7 +30,16 @@ function is_stiff(integrator, alg, ntol, stol, is_stiffalg) end function (AS::AutoSwitchCache)(integrator) - integrator.iter == 0 && return Int(AS.stiffalgfirst) + 1 + #horrible awful hack + isdefault = integrator.alg isa CompositeAlgorithm{<:Any, <:Tuple{Tsit5, Vern7, Rosenbrock23, Rodas5P, FBDF, FBDF}} + if isdefault + return default_autoswitch(AS, integrator) + end + if AS.current == 0 + AS.current = Int(AS.stiffalgfirst) + 1 + return AS.current + end + dt = integrator.dt # Successive stiffness test positives are counted by a positive integer, # and successive stiffness test negatives are counted by a negative integer @@ -68,17 +54,133 @@ function (AS::AutoSwitchCache)(integrator) integrator.dt = dt / AS.dtfac AS.is_stiffalg = false end - return Int(AS.is_stiffalg) + 1 + AS.current = Int(AS.is_stiffalg) + 1 + return AS.current end -function AutoAlgSwitch(nonstiffalg, stiffalg; kwargs...) +function AutoAlgSwitch(nonstiffalg::OrdinaryDiffEqAlgorithm, stiffalg::OrdinaryDiffEqAlgorithm; kwargs...) AS = AutoSwitch(nonstiffalg, stiffalg; kwargs...) CompositeAlgorithm((nonstiffalg, stiffalg), AS) end +function AutoAlgSwitch(nonstiffalg::Tuple, stiffalg::Tuple; kwargs...) + AS = AutoSwitch(nonstiffalg, stiffalg; kwargs...) + CompositeAlgorithm((nonstiffalg..., stiffalg...), AS) +end + AutoTsit5(alg; kwargs...) = AutoAlgSwitch(Tsit5(), alg; kwargs...) AutoDP5(alg; kwargs...) = AutoAlgSwitch(DP5(), alg; kwargs...) AutoVern6(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern6(lazy = lazy), alg; kwargs...) AutoVern7(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern7(lazy = lazy), alg; kwargs...) AutoVern8(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern8(lazy = lazy), alg; kwargs...) AutoVern9(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern9(lazy = lazy), alg; kwargs...) + +### Default ODE Solver + +EnumX.@enumx DefaultSolverChoice begin + Tsit5 = 1 + Vern7 = 2 + Rosenbrock23 = 3 + Rodas5P = 4 + FBDF = 5 + KrylovFBDF = 6 +end + +const NUM_NONSTIFF = 2 +const NUM_STIFF = 4 +const LOW_TOL = 1e-6 +const MED_TOL = 1e-2 +const EXTREME_TOL = 1e-9 +const SMALLSIZE = 50 +const MEDIUMSIZE = 500 +const STABILITY_SIZES = (alg_stability_size(Tsit5()), alg_stability_size(Vern7())) +const DEFAULTBETA2S = (beta2_default(Tsit5()), beta2_default(Vern7()), beta2_default(Rosenbrock23()), beta2_default(Rodas5P()), beta2_default(FBDF()), beta2_default(FBDF())) +const DEFAULTBETA1S = (beta1_default(Tsit5(),DEFAULTBETA2S[1]), beta1_default(Vern7(),DEFAULTBETA2S[2]), + beta1_default(Rosenbrock23(), DEFAULTBETA2S[3]), beta1_default(Rodas5P(), DEFAULTBETA2S[4]), + beta1_default(FBDF(), DEFAULTBETA2S[5]), beta1_default(FBDF(), DEFAULTBETA2S[6])) + +callbacks_exists(integrator) = !isempty(integrator.opts.callbacks) +current_nonstiff(current) = ifelse(current <= NUM_NONSTIFF,current,current-NUM_STIFF) + +function DefaultODEAlgorithm(; lazy = true, stiffalgfirst = false, kwargs...) + nonstiff = (Tsit5(), Vern7(lazy = lazy)) + stiff = (Rosenbrock23(;kwargs...), Rodas5P(;kwargs...), FBDF(;kwargs...), FBDF(;linsolve = LinearSolve.KrylovJL_GMRES())) + AutoAlgSwitch(nonstiff, stiff; stiffalgfirst) +end + +function is_stiff(integrator, alg, ntol, stol, is_stiffalg, current) + eigen_est, dt = integrator.eigen_est, integrator.dt + stiffness = abs(eigen_est * dt / STABILITY_SIZES[nonstiffchoice(integrator.opts.reltol)]) # `abs` here is just for safety + tol = is_stiffalg ? stol : ntol + os = oneunit(stiffness) + bool = stiffness > os * tol + + if !bool + integrator.alg.choice_function.successive_switches += 1 + else + integrator.alg.choice_function.successive_switches = 0 + end + + integrator.do_error_check = (integrator.alg.choice_function.successive_switches > + integrator.alg.choice_function.switch_max || !bool) || + is_stiffalg + bool +end + +function nonstiffchoice(reltol) + x = if reltol < LOW_TOL + DefaultSolverChoice.Vern7 + else + DefaultSolverChoice.Tsit5 + end + Int(x) +end + +function stiffchoice(reltol, len) + x = if len > MEDIUMSIZE + DefaultSolverChoice.KrylovFBDF + elseif len > SMALLSIZE + DefaultSolverChoice.FBDF + else + if reltol < LOW_TOL + DefaultSolverChoice.Rodas5P + else + DefaultSolverChoice.Rosenbrock23 + end + end + Int(x) +end + +function default_autoswitch(AS::AutoSwitchCache, integrator) + len = length(integrator.u) + reltol = integrator.opts.reltol + + # Chooose the starting method + if AS.current == 0 + choice = if AS.stiffalgfirst || integrator.f.mass_matrix != I + stiffchoice(reltol, len) + else + nonstiffchoice(reltol) + end + AS.current = choice + return AS.current + end + + dt = integrator.dt + # Successive stiffness test positives are counted by a positive integer, + # and successive stiffness test negatives are counted by a negative integer + AS.count = is_stiff(integrator, AS.nonstiffalg, AS.nonstifftol, AS.stifftol, + AS.is_stiffalg, AS.current) ? + AS.count < 0 ? 1 : AS.count + 1 : + AS.count > 0 ? -1 : AS.count - 1 + if integrator.f.mass_matrix != I || (!AS.is_stiffalg && AS.count > AS.maxstiffstep) + integrator.dt = dt * AS.dtfac + AS.is_stiffalg = true + AS.current = stiffchoice(reltol, len) + elseif (AS.is_stiffalg && AS.count < -AS.maxnonstiffstep) + integrator.dt = dt / AS.dtfac + AS.is_stiffalg = false + AS.current = nonstiffchoice(reltol) + end + return AS.current +end diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 236b9a82b0..13306d9bd2 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -284,7 +284,6 @@ function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2, if alg_autodiff(alg) isa AutoForwardDiff _chunksize = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg) # SparseDiffEq uses different convection... - T = if standardtag(alg) typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(u))) else diff --git a/src/perform_step/composite_perform_step.jl b/src/perform_step/composite_perform_step.jl index 75057a3adc..2eb19e1688 100644 --- a/src/perform_step/composite_perform_step.jl +++ b/src/perform_step/composite_perform_step.jl @@ -1,38 +1,61 @@ -#= - -Maybe do generated functions to reduce dispatch times? - -f(x) = x -g(x,i) = f(x[i]) -g{i}(x,::Type{Val{i}}) = f(x[i]) -@generated function gg(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(f(tup[i])) i->error("unreachable")) - end -h(i) = g((1,1.0,"foo"), i) -h2{i}(::Type{Val{i}}) = g((1,1.0,"foo"), Val{i}) -h3(i) = gg((1,1.0,"foo"), i) -@benchmark h(1) -mean time: 31.822 ns (0.00% GC) -@benchmark h2(Val{1}) -mean time: 1.585 ns (0.00% GC) -@benchmark h3(1) -mean time: 6.423 ns (0.00% GC) - -@generated function foo(tup::Tuple, num) - N = length(tup.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) +function init_ith_default_cache(cache::DefaultCache, algs, i) + if i == 1 + if !isdefined(cache, :cache1) + cache.cache1 = alg_cache(algs[1], cache.args...) + end + elseif i == 2 + if !isdefined(cache, :cache2) + cache.cache2 = alg_cache(algs[2], cache.args...) + end + elseif i == 3 + if !isdefined(cache, :cache3) + cache.cache3 = alg_cache(algs[3], cache.args...) + end + elseif i == 4 + if !isdefined(cache, :cache4) + cache.cache4 = alg_cache(algs[4], cache.args...) + end + elseif i == 5 + if !isdefined(cache, :cache5) + cache.cache5 = alg_cache(algs[5], cache.args...) + end + elseif i == 6 + if !isdefined(cache, :cache6) + cache.cache6 = alg_cache(algs[6], cache.args...) + end + end end -@code_typed foo((1,1.0), 1) - -@generated function perform_step!(integrator, cache::CompositeCache, repeat_step=false) - N = length(cache.parameters) - :(@nif $(N+1) i->(i == num) i->(tup[i]) i->error("unreachable")) +function initialize!(integrator, cache::DefaultCache) + cache.current = cache.choice_function(integrator) + algs = integrator.alg.algs + init_ith_default_cache(cache, algs, cache.current) + if cache.current == 1 + initialize!(integrator, cache.cache1) + elseif cache.current == 2 + initialize!(integrator, cache.cache2) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[2]) + elseif cache.current == 3 + initialize!(integrator, cache.cache3) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[3]) + elseif cache.current == 4 + initialize!(integrator, cache.cache4) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[4]) + elseif cache.current == 5 + initialize!(integrator, cache.cache5) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[5]) + elseif cache.current == 6 + initialize!(integrator, cache.cache6) + # the controller was initialized by default for algs[1] + reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[6]) + end + resize!(integrator.k, integrator.kshortsize) end -=# - function initialize!(integrator, cache::CompositeCache) cache.current = cache.choice_function(integrator) if cache.current == 1 @@ -69,28 +92,37 @@ the behaviour is consistent. In particular, prevents dt ⟶ 0 if starting with non-adaptive alg and opts.adaptive=true, and dt=cst if starting with adaptive alg and opts.adaptive=false. """ -function ensure_behaving_adaptivity!(integrator, cache::CompositeCache) +function ensure_behaving_adaptivity!(integrator, cache::Union{DefaultCache, CompositeCache}) if anyadaptive(integrator.alg) && !isadaptive(integrator.alg) integrator.opts.adaptive = isadaptive(integrator.alg.algs[cache.current]) end end -function perform_step!(integrator, cache::CompositeCache, repeat_step = false) +function perform_step!(integrator, cache::DefaultCache, repeat_step = false) + algs = integrator.alg.algs + init_ith_default_cache(cache, algs, cache.current) if cache.current == 1 - perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) + perform_step!(integrator, @inbounds(cache.cache1), repeat_step) elseif cache.current == 2 - perform_step!(integrator, @inbounds(cache.caches[2]), repeat_step) - else - perform_step!(integrator, @inbounds(cache.caches[cache.current]), repeat_step) + perform_step!(integrator, @inbounds(cache.cache2), repeat_step) + elseif cache.current == 3 + perform_step!(integrator, @inbounds(cache.cache3), repeat_step) + elseif cache.current == 4 + perform_step!(integrator, @inbounds(cache.cache4), repeat_step) + elseif cache.current == 5 + perform_step!(integrator, @inbounds(cache.cache5), repeat_step) + elseif cache.current == 6 + perform_step!(integrator, @inbounds(cache.cache6), repeat_step) end end -function perform_step!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}, - repeat_step = false) where {T1, T2, F} +function perform_step!(integrator, cache::CompositeCache, repeat_step = false) if cache.current == 1 perform_step!(integrator, @inbounds(cache.caches[1]), repeat_step) elseif cache.current == 2 perform_step!(integrator, @inbounds(cache.caches[2]), repeat_step) + else + perform_step!(integrator, @inbounds(cache.caches[cache.current]), repeat_step) end end @@ -121,34 +153,52 @@ function choose_algorithm!(integrator, end end -function choose_algorithm!(integrator, cache::CompositeCache) +function choose_algorithm!(integrator, cache::DefaultCache) new_current = cache.choice_function(integrator) old_current = cache.current @inbounds if new_current != old_current + algs = integrator.alg.algs cache.current = new_current + init_ith_default_cache(cache, algs, new_current) if new_current == 1 - initialize!(integrator, @inbounds(cache.caches[1])) + initialize!(integrator, @inbounds(cache.cache1)) + new_cache = cache.cache1 elseif new_current == 2 - initialize!(integrator, @inbounds(cache.caches[2])) - else - initialize!(integrator, @inbounds(cache.caches[new_current])) + initialize!(integrator, @inbounds(cache.cache2)) + new_cache = cache.cache2 + elseif new_current == 3 + initialize!(integrator, @inbounds(cache.cache3)) + new_cache = cache.cache3 + elseif new_current == 4 + initialize!(integrator, @inbounds(cache.cache4)) + new_cache = cache.cache4 + elseif new_current == 5 + initialize!(integrator, @inbounds(cache.cache5)) + new_cache = cache.cache5 + elseif new_current == 6 + initialize!(integrator, @inbounds(cache.cache6)) + new_cache = cache.cache6 end - if old_current == 1 && new_current == 2 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[1], - integrator.alg.algs[2]) - transfer_cache!(integrator, integrator.cache.caches[1], - integrator.cache.caches[2]) - elseif old_current == 2 && new_current == 1 - reset_alg_dependent_opts!(integrator, integrator.alg.algs[2], - integrator.alg.algs[1]) - transfer_cache!(integrator, integrator.cache.caches[2], - integrator.cache.caches[1]) - else - reset_alg_dependent_opts!(integrator, integrator.alg.algs[old_current], - integrator.alg.algs[new_current]) - transfer_cache!(integrator, integrator.cache.caches[old_current], - integrator.cache.caches[new_current]) + + if old_current == 1 + old_cache = cache.cache1 + elseif old_current == 2 + old_cache = cache.cache2 + elseif old_current == 3 + old_cache = cache.cache3 + elseif old_current == 4 + old_cache = cache.cache4 + elseif old_current == 5 + old_cache = cache.cache5 + elseif old_current == 6 + old_cache = cache.cache6 end + + integrator.opts.controller.beta2 = beta2 = beta2_default(algs[new_current]) + integrator.opts.controller.beta1 = beta1_default(algs[new_current], beta2) + + reset_alg_dependent_opts!(integrator, algs[old_current], algs[new_current]) + transfer_cache!(integrator, old_cache, new_cache) end end @@ -170,6 +220,7 @@ function reset_alg_dependent_opts!(integrator, alg1, alg2) integrator.opts.qmax == qmax_default(alg2) end reset_alg_dependent_opts!(integrator.opts.controller, alg1, alg2) + nothing end # Write how to transfer the cache variables from one cache to the other diff --git a/src/perform_step/verner_rk_perform_step.jl b/src/perform_step/verner_rk_perform_step.jl index 0a44e217da..7c76e26719 100644 --- a/src/perform_step/verner_rk_perform_step.jl +++ b/src/perform_step/verner_rk_perform_step.jl @@ -2,7 +2,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal integrator.stats.nf += 1 alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -13,7 +13,7 @@ function initialize!(integrator, cache::Vern6ConstantCache) end integrator.k[integrator.kshortsize] = integrator.fsallast - if !alg.lazy + if !cache.lazy @inbounds for i in 10:12 integrator.k[i] = zero(integrator.fsalfirst) end @@ -63,7 +63,7 @@ end integrator.k[9] = k9 alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -94,7 +94,7 @@ end function initialize!(integrator, cache::Vern6Cache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) + cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12) integrator.fsalfirst = cache.k1 integrator.fsallast = cache.k9 @unpack k = integrator @@ -109,7 +109,7 @@ function initialize!(integrator, cache::Vern6Cache) k[8] = cache.k8 k[9] = cache.k9 # Set the pointers - if !alg.lazy + if !cache.lazy k[10] = similar(cache.k1) k[11] = similar(cache.k1) k[12] = similar(cache.k1) @@ -182,7 +182,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c10, a1001, a1004, a1005, a1006, a1007, a1008, a1009, c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, a1110, c12, a1201, a1204, a1205, a1206, a1207, a1208, a1209, a1210, a1211 = cache.tab.extra @@ -214,7 +214,7 @@ end function initialize!(integrator, cache::Vern7ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -277,7 +277,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern7ExtraStages T T2 @@ -329,7 +329,7 @@ function initialize!(integrator, cache::Vern7Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 16) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -342,7 +342,7 @@ function initialize!(integrator, cache::Vern7Cache) k[9] = k9 k[10] = k10 # Setup pointers - if !alg.lazy + if !cache.lazy k[11] = similar(cache.k1) k[12] = similar(cache.k1) k[13] = similar(cache.k1) @@ -433,7 +433,7 @@ end integrator.EEst = integrator.opts.internalnorm(atmp, t) end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache @@ -489,7 +489,7 @@ end function initialize!(integrator, cache::Vern8ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -575,7 +575,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -642,7 +642,7 @@ function initialize!(integrator, cache::Vern8Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) + cache.lazy ? (integrator.kshortsize = 13) : (integrator.kshortsize = 21) resize!(k, integrator.kshortsize) k[1] = k1 k[2] = k2 @@ -658,7 +658,7 @@ function initialize!(integrator, cache::Vern8Cache) k[12] = k12 k[13] = k13 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 14:21 k[i] = similar(cache.k1) end @@ -764,7 +764,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack c14, a1401, a1406, a1407, a1408, a1409, a1410, a1411, a1412, c15, a1501, a1506, a1507, a1508, a1509, a1510, a1511, a1512, a1514, c16, a1601, a1606, a1607, a1608, a1609, a1610, a1611, a1612, a1614, a1615, c17, a1701, a1706, a1707, a1708, a1709, a1710, a1711, a1712, a1714, a1715, a1716, c18, a1801, a1806, a1807, a1808, a1809, a1810, a1811, a1812, a1814, a1815, a1816, a1817, c19, a1901, a1906, a1907, a1908, a1909, a1910, a1911, a1912, a1914, a1915, a1916, a1917, c20, a2001, a2006, a2007, a2008, a2009, a2010, a2011, a2012, a2014, a2015, a2016, a2017, c21, a2101, a2106, a2107, a2108, a2109, a2110, a2111, a2112, a2114, a2115, a2116, a2117 = cache.tab.extra @@ -851,7 +851,7 @@ end function initialize!(integrator, cache::Vern9ConstantCache) alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays @@ -945,7 +945,7 @@ end integrator.u = u alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @OnDemandTableauExtract Vern9ExtraStages T T2 @@ -1032,7 +1032,7 @@ function initialize!(integrator, cache::Vern9Cache) @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16 = cache @unpack k = integrator alg = unwrap_alg(integrator, false) - alg.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 20) resize!(k, integrator.kshortsize) # k2, k3,k4,k5,k6,k7 are not used in the code (not even in interpolations), we dont need their pointers. # So we mapped k[2] (from integrator) with k8 (from cache), k[3] with k9 and so on. @@ -1047,7 +1047,7 @@ function initialize!(integrator, cache::Vern9Cache) k[9] = k15 k[10] = k16 # Setup pointers - if !alg.lazy + if !cache.lazy for i in 11:20 k[i] = similar(cache.k1) end @@ -1174,7 +1174,7 @@ end end alg = unwrap_alg(integrator, false) - if !alg.lazy && (integrator.opts.adaptive == false || + if !cache.lazy && (integrator.opts.adaptive == false || accept_step_controller(integrator, integrator.opts.controller)) k = integrator.k @unpack tmp = cache diff --git a/src/solve.jl b/src/solve.jl index e29136450e..dc91f4644c 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -146,7 +146,7 @@ function DiffEqBase.__init( auto.stifftol, auto.dtfac, auto.stiffalgfirst, - auto.switch_max)) + auto.switch_max, 0)) else _alg = alg end @@ -415,11 +415,9 @@ function DiffEqBase.__init( differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u) - id = InterpolationData( - f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false) + id = InterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false) sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, - dense = dense, k = ks, interp = id, - alg_choice = alg_choice, + dense = dense, k = ks, interp = id, alg_choice = alg_choice, calculate_error = false, stats = stats) if recompile_flag == true @@ -534,6 +532,13 @@ function DiffEqBase.__init( integrator end +function DiffEqBase.__init(prob::ODEProblem, ::Nothing, args...; kwargs...) + DiffEqBase.init(prob, DefaultODEAlgorithm(autodiff=false), args...; kwargs...) +end +function DiffEqBase.__solve(prob::ODEProblem, ::Nothing, args...; kwargs...) + DiffEqBase.solve(prob, DefaultODEAlgorithm(autodiff=false), args...; kwargs...) +end + function DiffEqBase.solve!(integrator::ODEIntegrator) @inbounds while !isempty(integrator.opts.tstops) while integrator.tdir * integrator.t < first(integrator.opts.tstops) diff --git a/test/interface/default_solver_tests.jl b/test/interface/default_solver_tests.jl new file mode 100644 index 0000000000..0a0af32148 --- /dev/null +++ b/test/interface/default_solver_tests.jl @@ -0,0 +1,82 @@ +using OrdinaryDiffEq, Test, LinearSolve, LinearAlgebra, SparseArrays + +f_2dlinear = (du, u, p, t) -> (@. du = p * u) + +prob_ode_2Dlinear = ODEProblem(f_2dlinear, rand(4, 2), (0.0, 1.0), 1.01) +sol = @inferred solve(prob_ode_2Dlinear) + +tsitsol = solve(prob_ode_2Dlinear, Tsit5()) +# test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). +@test sol.stats.naccept == tsitsol.stats.naccept +@test sol.stats.nf == tsitsol.stats.nf +@test all(isequal(1), sol.alg_choice) + +sol = solve(prob_ode_2Dlinear, reltol=1e-10) +vernsol = solve(prob_ode_2Dlinear, Vern7(), reltol=1e-10) +# test that default isn't much worse than Tsit5 (we expect it to use Tsit5 for this). +@test sol.stats.naccept == vernsol.stats.naccept +@test sol.stats.nf == vernsol.stats.nf +@test all(isequal(2), sol.alg_choice) + +prob_ode_linear_fast = ODEProblem(ODEFunction(f_2dlinear, mass_matrix=2*I(2)), rand(2), (0.0, 1.0), 1.01) +sol = solve(prob_ode_linear_fast) +@test all(isequal(3), sol.alg_choice) +# for some reason the timestepping here is different from regular Rosenbrock23 (including the initial timestep) + +function rober(u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p + [-k₁ * y₁ + k₃ * y₂ * y₃, + k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2, + k₂ * y₂^2] +end +prob_rober = ODEProblem(rober, [1.0,0.0,0.0],(0.0,1e3),(0.04,3e7,1e4)) +sol = solve(prob_rober) +rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23(autodiff=false))) +# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). +@test sol.stats.naccept == rosensol.stats.naccept +@test sol.stats.nf == rosensol.stats.nf +@test unique(sol.alg_choice) == [1,3] +@test sol.alg_choice[1] == 1 +@test sol.alg_choice[end] == 3 + +sol = solve(prob_rober, reltol=1e-7, abstol=1e-7) +rosensol = solve(prob_rober, AutoVern7(Rodas5P(autodiff=false)), reltol=1e-7, abstol=1e-7) +# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). +@test sol.stats.naccept == rosensol.stats.naccept +@test sol.stats.nf == rosensol.stats.nf +@test unique(sol.alg_choice) == [2,4] +@test sol.alg_choice[1] == 2 +@test sol.alg_choice[end] == 4 + +function exrober(du, u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p + du .= vcat([-k₁ * y₁ + k₃ * y₂ * y₃, + k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2, + k₂ * y₂^2, ], u[4:end]) +end + +for n in (100, ) # 600 should be added but currently is broken for unknown reasons + stiffalg = n < 50 ? 4 : n < 500 ? 5 : 6 + linsolve = n < 500 ? nothing : KrylovJL_GMRES() + jac_prototype = sparse(I(n+3)) + jac_prototype[1:3, 1:3] .= 1.0 + + prob_ex_rober = ODEProblem(ODEFunction(exrober; jac_prototype), vcat([1.0,0.0,0.0], ones(n)),(0.0,100.0),(0.04,3e7,1e4)) + sol = solve(prob_ex_rober) + fsol = solve(prob_ex_rober, AutoTsit5(FBDF(;autodiff=false, linsolve))) + # test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this). + @test sol.stats.naccept == fsol.stats.naccept + @test sol.stats.nf == fsol.stats.nf + @test unique(sol.alg_choice) == [1,stiffalg] +end + +function swaplinear(u, p, t) + [u[2], u[1]].*p +end +swaplinearf = ODEFunction(swaplinear, mass_matrix=ones(2,2)-I(2)) +prob_swaplinear = ODEProblem(swaplinearf, rand(2), (0., 1.), 1.01) +sol = solve(prob_swaplinear, reltol=1e-7) # reltol must be set to avoid running into a bug with Rosenbrock23 +@test all(isequal(4), sol.alg_choice) +# for some reason the timestepping here is different from regular Rodas5P (including the initial timestep) diff --git a/test/runtests.jl b/test/runtests.jl index 6cdf509647..a15d6044dc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,7 @@ end @time @safetestset "Controller Tests" include("interface/controllers.jl") @time @safetestset "Inplace Interpolation Tests" include("interface/inplace_interpolation.jl") @time @safetestset "Algebraic Interpolation Tests" include("interface/algebraic_interpolation.jl") + @time @safetestset "Default Solver Tests" include("interface/default_solver_tests.jl") end if !is_APPVEYOR && (GROUP == "All" || GROUP == "InterfaceII" || GROUP == "Interface")