From 20e46e7562b99b1bd770dda75589af90a33045a4 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 30 Aug 2024 11:45:33 -0400 Subject: [PATCH 1/9] Improve NULL FSAL and fix du for non-FSAL Fixes the downstream SciMLSensitivity --- lib/OrdinaryDiffEqBDF/src/dae_caches.jl | 2 +- .../src/integrators/integrator_interface.jl | 4 +-- .../src/functionmap_caches.jl | 2 +- .../src/low_storage_rk_perform_step.jl | 2 +- lib/OrdinaryDiffEqPDIRK/src/pdirk_caches.jl | 2 +- .../src/rosenbrock_caches.jl | 2 +- lib/OrdinaryDiffEqVerner/src/verner_caches.jl | 6 ++--- test/interface/get_du.jl | 26 +++++++++++++++++++ test/runtests.jl | 1 + 9 files changed, 37 insertions(+), 10 deletions(-) create mode 100644 test/interface/get_du.jl diff --git a/lib/OrdinaryDiffEqBDF/src/dae_caches.jl b/lib/OrdinaryDiffEqBDF/src/dae_caches.jl index 708a3d119f..aa03fc0f25 100644 --- a/lib/OrdinaryDiffEqBDF/src/dae_caches.jl +++ b/lib/OrdinaryDiffEqBDF/src/dae_caches.jl @@ -15,7 +15,7 @@ end end # Not FSAL -get_fsalfirstlast(cache::DImplicitEulerCache, u) = (u, u) +get_fsalfirstlast(cache::DImplicitEulerCache, u) = (nothing, nothing) mutable struct DImplicitEulerConstantCache{N} <: OrdinaryDiffEqConstantCache nlsolver::N diff --git a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl index 896d4ecf51..f134ec4b81 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl @@ -59,7 +59,7 @@ end @inline function DiffEqBase.get_du(integrator::ODEIntegrator) isdiscretecache(integrator.cache) && error("Derivatives are not defined for this stepper.") - return if isdefined(integrator, :fsallast) + return if isfsal(integrator.alg) integrator.fsallast else integrator(integrator.t, Val{1}) @@ -72,7 +72,7 @@ end if isdiscretecache(integrator.cache) out .= integrator.cache.tmp else - return if isdefined(integrator, :fsallast) && + return if isfsal(integrator.alg) && !has_stiff_interpolation(integrator.alg) # Special stiff interpolations do not store the # right value in fsallast diff --git a/lib/OrdinaryDiffEqFunctionMap/src/functionmap_caches.jl b/lib/OrdinaryDiffEqFunctionMap/src/functionmap_caches.jl index 38d7afe1d8..a2b03c0818 100644 --- a/lib/OrdinaryDiffEqFunctionMap/src/functionmap_caches.jl +++ b/lib/OrdinaryDiffEqFunctionMap/src/functionmap_caches.jl @@ -3,7 +3,7 @@ uprev::uType tmp::rateType end -get_fsalfirstlast(cache::FunctionMapCache, u) = (cache.u, cache.uprev) +get_fsalfirstlast(cache::FunctionMapCache, u) = (nothing, nothing) function alg_cache(alg::FunctionMap, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, diff --git a/lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl b/lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl index 7565d8778e..60a7863dbf 100644 --- a/lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl +++ b/lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl @@ -33,7 +33,7 @@ end integrator.u = u end -get_fsalfirstlast(cache::LowStorageRK2NCache, u) = (cache.k, cache.k) +get_fsalfirstlast(cache::LowStorageRK2NCache, u) = (nothing, nothing) function initialize!(integrator, cache::LowStorageRK2NCache) @unpack k, tmp, williamson_condition = cache diff --git a/lib/OrdinaryDiffEqPDIRK/src/pdirk_caches.jl b/lib/OrdinaryDiffEqPDIRK/src/pdirk_caches.jl index 2cbbbf828f..2818424697 100644 --- a/lib/OrdinaryDiffEqPDIRK/src/pdirk_caches.jl +++ b/lib/OrdinaryDiffEqPDIRK/src/pdirk_caches.jl @@ -8,7 +8,7 @@ end # Non-FSAL -get_fsalfirstlast(cache::PDIRK44Cache, u) = (cache.u, cache.uprev) +get_fsalfirstlast(cache::PDIRK44Cache, u) = (nothing, nothing) struct PDIRK44ConstantCache{N, TabType} <: OrdinaryDiffEqConstantCache nlsolver::N diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 0a70a7fdde..1b2430ada8 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -2,7 +2,7 @@ abstract type RosenbrockMutableCache <: OrdinaryDiffEqMutableCache end abstract type RosenbrockConstantCache <: OrdinaryDiffEqConstantCache end # Fake values since non-FSAL -get_fsalfirstlast(cache::RosenbrockMutableCache, u) = (zero(u), zero(u)) +get_fsalfirstlast(cache::RosenbrockMutableCache, u) = (nothing, nothing) ################################################################################ diff --git a/lib/OrdinaryDiffEqVerner/src/verner_caches.jl b/lib/OrdinaryDiffEqVerner/src/verner_caches.jl index 19289eb2c2..9a3444d101 100644 --- a/lib/OrdinaryDiffEqVerner/src/verner_caches.jl +++ b/lib/OrdinaryDiffEqVerner/src/verner_caches.jl @@ -87,7 +87,7 @@ end end # fake values since non-FSAL method -get_fsalfirstlast(cache::Vern7Cache, u) = (cache.k1, cache.k2) +get_fsalfirstlast(cache::Vern7Cache, u) = (nothing, nothing) function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, @@ -153,7 +153,7 @@ end end # fake values since non-FSAL method -get_fsalfirstlast(cache::Vern8Cache, u) = (cache.k1, cache.k2) +get_fsalfirstlast(cache::Vern8Cache, u) = (nothing, nothing) function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, @@ -227,7 +227,7 @@ end end # fake values since non-FSAL method -get_fsalfirstlast(cache::Vern9Cache, u) = (cache.k1, cache.k2) +get_fsalfirstlast(cache::Vern9Cache, u) = (nothing, nothing) function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, diff --git a/test/interface/get_du.jl b/test/interface/get_du.jl new file mode 100644 index 0000000000..f4ad7c9c37 --- /dev/null +++ b/test/interface/get_du.jl @@ -0,0 +1,26 @@ +using OrdinaryDiffEq, OrdinaryDiffEqCore, Test +function lorenz!(du, u, p, t) + du[1] = 10.0(u[2] - u[1]) + du[2] = u[1] * (28.0 - u[3]) - u[2] + du[3] = u[1] * u[2] - (8 / 3) * u[3] +end +u0 = [1.0; 0.0; 0.0] +tspan = (0.0, 3.0) +condition(u,t,integrator) = t == 0.2 +cache = zeros(3) +affect!(integrator) = cache .= get_du(integrator) +dusave = DiscreteCallback(condition, affect!) +affect2!(integrator) = get_du!(cache, integrator) +dusave_inplace = DiscreteCallback(condition, affect2!) + +prob = ODEProblem(lorenz!, u0, tspan) +sol = solve(prob, Tsit5(), tstops = [0.2], callback = dusave, abstol=1e-12, reltol=1e-12) +res = copy(cache) + +for alg in [Vern6(), Vern7(), Vern8(), Vern9(), Rodas4(), Rodas4P(), Rodas5(), Rodas5P(), TRBDF2(), KenCarp4(), FBDF(), QNDF()] + sol = solve(prob, alg, tstops = [0.2], callback = dusave, abstol=1e-12, reltol=1e-12) + @test res ≈ cache rtol=1e-5 + + sol = solve(prob, alg, tstops = [0.2], callback = dusave_inplace, abstol=1e-12, reltol=1e-12) + @test res ≈ cache rtol=1e-5 +end diff --git a/test/runtests.jl b/test/runtests.jl index a452d56ee8..5c89af81b7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,6 +66,7 @@ end @time @safetestset "Linear Solver Split ODE Tests" include("interface/linear_solver_split_ode_test.jl") @time @safetestset "Sparse Diff Tests" include("interface/sparsediff_tests.jl") @time @safetestset "Enum Tests" include("interface/enums.jl") + @time @safetestset "Enum Tests" include("interface/get_du.jl") @time @safetestset "Mass Matrix Tests" include("interface/mass_matrix_tests.jl") @time @safetestset "W-Operator prototype tests" include("interface/wprototype_tests.jl") end From 5283d7b151e732fe9e5db46110dbd1fdfb620e89 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 30 Aug 2024 14:15:42 -0400 Subject: [PATCH 2/9] fix composite handling of nothing fsal --- .../perform_step/composite_perform_step.jl | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl b/lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl index 91d6d2a781..975754fc5d 100644 --- a/lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl +++ b/lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl @@ -33,41 +33,41 @@ function initialize!(integrator, cache::DefaultCache) u = integrator.u if cache.current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.cache1, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, cache.cache1) elseif cache.current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.cache2, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, cache.cache2) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[2]) elseif cache.current == 3 fsalfirst, fsallast = get_fsalfirstlast(cache.cache3, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, cache.cache3) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[3]) elseif cache.current == 4 fsalfirst, fsallast = get_fsalfirstlast(cache.cache4, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, cache.cache4) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[4]) elseif cache.current == 5 fsalfirst, fsallast = get_fsalfirstlast(cache.cache5, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, cache.cache5) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[5]) elseif cache.current == 6 fsalfirst, fsallast = get_fsalfirstlast(cache.cache6, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, cache.cache6) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[6]) @@ -80,21 +80,21 @@ function initialize!(integrator, cache::CompositeCache) u = integrator.u if cache.current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.caches[1])) elseif cache.current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.caches[2])) # the controller was initialized by default for integrator.alg.algs[1] reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[2]) else fsalfirst, fsallast = get_fsalfirstlast(cache.caches[cache.current], u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.caches[cache.current])) reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[cache.current]) @@ -107,13 +107,13 @@ function initialize!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}) where u = integrator.u if cache.current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.caches[1])) elseif cache.current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.caches[2])) reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[2]) @@ -173,13 +173,13 @@ function choose_algorithm!(integrator, cache.current = new_current if new_current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.caches[1])) elseif new_current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.caches[2])) end if old_current == 1 && new_current == 2 @@ -206,38 +206,38 @@ function choose_algorithm!(integrator, cache::DefaultCache) init_ith_default_cache(cache, algs, new_current) if new_current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.cache1, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.cache1)) new_cache = cache.cache1 elseif new_current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.cache2, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.cache2)) new_cache = cache.cache2 elseif new_current == 3 fsalfirst, fsallast = get_fsalfirstlast(cache.cache3, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.cache3)) new_cache = cache.cache3 elseif new_current == 4 fsalfirst, fsallast = get_fsalfirstlast(cache.cache4, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.cache4)) new_cache = cache.cache4 elseif new_current == 5 fsalfirst, fsallast = get_fsalfirstlast(cache.cache5, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.cache5)) new_cache = cache.cache5 elseif new_current == 6 fsalfirst, fsallast = get_fsalfirstlast(cache.cache6, u) - integrator.fsalfirst = fsalfirst - integrator.fsallast = fsallast + !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst + !isnothing(fsallast) && integrator.fsallast = fsallast initialize!(integrator, @inbounds(cache.cache6)) new_cache = cache.cache6 end From bc4e96f3dce8e846e4a2cf41b1ff33e22381cc02 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 06:08:07 -0400 Subject: [PATCH 3/9] fix assignment --- .../perform_step/composite_perform_step.jl | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl b/lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl index 975754fc5d..c3c101a47f 100644 --- a/lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl +++ b/lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl @@ -33,41 +33,41 @@ function initialize!(integrator, cache::DefaultCache) u = integrator.u if cache.current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.cache1, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, cache.cache1) elseif cache.current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.cache2, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, cache.cache2) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[2]) elseif cache.current == 3 fsalfirst, fsallast = get_fsalfirstlast(cache.cache3, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, cache.cache3) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[3]) elseif cache.current == 4 fsalfirst, fsallast = get_fsalfirstlast(cache.cache4, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, cache.cache4) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[4]) elseif cache.current == 5 fsalfirst, fsallast = get_fsalfirstlast(cache.cache5, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, cache.cache5) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[5]) elseif cache.current == 6 fsalfirst, fsallast = get_fsalfirstlast(cache.cache6, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, cache.cache6) # the controller was initialized by default for algs[1] reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[6]) @@ -80,21 +80,21 @@ function initialize!(integrator, cache::CompositeCache) u = integrator.u if cache.current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.caches[1])) elseif cache.current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.caches[2])) # the controller was initialized by default for integrator.alg.algs[1] reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[2]) else fsalfirst, fsallast = get_fsalfirstlast(cache.caches[cache.current], u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.caches[cache.current])) reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[cache.current]) @@ -107,13 +107,13 @@ function initialize!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}) where u = integrator.u if cache.current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.caches[1])) elseif cache.current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.caches[2])) reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1], integrator.alg.algs[2]) @@ -173,13 +173,13 @@ function choose_algorithm!(integrator, cache.current = new_current if new_current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.caches[1])) elseif new_current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.caches[2])) end if old_current == 1 && new_current == 2 @@ -206,38 +206,38 @@ function choose_algorithm!(integrator, cache::DefaultCache) init_ith_default_cache(cache, algs, new_current) if new_current == 1 fsalfirst, fsallast = get_fsalfirstlast(cache.cache1, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.cache1)) new_cache = cache.cache1 elseif new_current == 2 fsalfirst, fsallast = get_fsalfirstlast(cache.cache2, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.cache2)) new_cache = cache.cache2 elseif new_current == 3 fsalfirst, fsallast = get_fsalfirstlast(cache.cache3, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.cache3)) new_cache = cache.cache3 elseif new_current == 4 fsalfirst, fsallast = get_fsalfirstlast(cache.cache4, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.cache4)) new_cache = cache.cache4 elseif new_current == 5 fsalfirst, fsallast = get_fsalfirstlast(cache.cache5, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.cache5)) new_cache = cache.cache5 elseif new_current == 6 fsalfirst, fsallast = get_fsalfirstlast(cache.cache6, u) - !isnothing(fsalfirst) && integrator.fsalfirst = fsalfirst - !isnothing(fsallast) && integrator.fsallast = fsallast + !isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst) + !isnothing(fsallast) && (integrator.fsallast = fsallast) initialize!(integrator, @inbounds(cache.cache6)) new_cache = cache.cache6 end From f9bfffe6888a13dc61eb86f1af4bf99f4bb4841e Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 07:00:33 -0400 Subject: [PATCH 4/9] print --- lib/OrdinaryDiffEqCore/src/solve.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 563c7654db..5cf0dbc8e2 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -471,6 +471,8 @@ function DiffEqBase.__init( saveiter_dense = 0 faslfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) + @show typeof(fsalfirst) + integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du), tType, typeof(p), typeof(eigen_est), typeof(EEst), From f6799e442e336f5ea455b0d4877ea8f23fc18b82 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 11:56:39 -0400 Subject: [PATCH 5/9] fix typo --- lib/OrdinaryDiffEqCore/src/solve.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 5cf0dbc8e2..cef8195e98 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -469,7 +469,7 @@ function DiffEqBase.__init( reinitiailize = true saveiter = 0 # Starts at 0 so first save is at 1 saveiter_dense = 0 - faslfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) + fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) @show typeof(fsalfirst) @@ -478,7 +478,7 @@ function DiffEqBase.__init( typeof(eigen_est), typeof(EEst), QT, typeof(tdir), typeof(k), SolType, FType, cacheType, - typeof(opts), typeof(faslfirst), + typeof(opts), typeof(fsalfirst), typeof(last_event_error), typeof(callback_cache), typeof(initializealg), typeof(differential_vars)}( sol, u, du, k, t, tType(dt), f, p, @@ -498,7 +498,7 @@ function DiffEqBase.__init( isout, reeval_fsal, u_modified, reinitiailize, isdae, opts, stats, initializealg, differential_vars, - faslfirst, fsallast) + fsalfirst, fsallast) if initialize_integrator if isdae || SciMLBase.has_initializeprob(prob.f) From 241b3933d8c1e2ed6a5d16b01959b8f4c199315e Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 12:33:13 -0400 Subject: [PATCH 6/9] fix not real FSAL --- .../src/low_storage_rk_perform_step.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl b/lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl index 60a7863dbf..5344dc296a 100644 --- a/lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl +++ b/lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl @@ -40,8 +40,6 @@ function initialize!(integrator, cache::LowStorageRK2NCache) integrator.kshortsize = 1 resize!(integrator.k, integrator.kshortsize) integrator.k[1] = k - integrator.fsalfirst = k # used for get_du - integrator.fsallast = k integrator.f(k, integrator.uprev, integrator.p, integrator.t) # FSAL for interpolation OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) end From 41db9c2eb80062869a55338bfb8a004f32fe4a30 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 14:44:10 -0400 Subject: [PATCH 7/9] handle nothing first fsal case --- lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl | 9 ++++++++- lib/OrdinaryDiffEqCore/src/solve.jl | 2 -- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl b/lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl index 088fcc4b76..9bbc9bd177 100644 --- a/lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl +++ b/lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl @@ -13,7 +13,14 @@ mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache current::Int end -get_fsalfirstlast(cache::CompositeCache, u) = get_fsalfirstlast(cache.caches[1], u) +function get_fsalfirstlast(cache::CompositeCache, u) + _x = get_fsalfirstlast(cache.caches[1], u) + if first(_x) !== nothing + return _x + else + return get_fsalfirstlast(cache.caches[2], u) + end +end mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F, uType} <: OrdinaryDiffEqCache args::A diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index cef8195e98..a5efe16d43 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -471,8 +471,6 @@ function DiffEqBase.__init( saveiter_dense = 0 fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) - @show typeof(fsalfirst) - integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du), tType, typeof(p), typeof(eigen_est), typeof(EEst), From 55a8d67b72e7788ff0cbbea24dc3a350ad5661c3 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 15:07:34 -0400 Subject: [PATCH 8/9] fix generic rosenbrock --- lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl | 2 +- lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl b/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl index e6bfaaf4d9..a67b4a9548 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl @@ -176,7 +176,7 @@ function gen_cache_struct(tab::RosenbrockTableau,cachename::Symbol,constcachenam end end cacheexpr=quote - @cache mutable struct $cachename{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: RosenbrockMutableCache + @cache mutable struct $cachename{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: GenericRosenbrockMutableCache u::uType uprev::uType du::rateType diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 1b2430ada8..bf8cfb33cd 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -1,8 +1,10 @@ abstract type RosenbrockMutableCache <: OrdinaryDiffEqMutableCache end +abstract type GenericRosenbrockMutableCache <: RosenbrockMutableCache end abstract type RosenbrockConstantCache <: OrdinaryDiffEqConstantCache end # Fake values since non-FSAL get_fsalfirstlast(cache::RosenbrockMutableCache, u) = (nothing, nothing) +get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u) = (cache.fsalfirst, cache.fsallast) ################################################################################ From 624a9f9cec233af7bf0b76d1844e9e2ac053ce06 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 15:47:49 -0400 Subject: [PATCH 9/9] oh betty --- .../src/integrators/integrator_interface.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl index f134ec4b81..ccae5a0609 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl @@ -221,8 +221,8 @@ function resize!(integrator::ODEIntegrator, i::Int) # may be required for things like units c !== nothing && resize!(c, i) end - resize!(integrator.fsalfirst, i) - resize!(integrator.fsallast, i) + !isnothing(integrator.fsalfirst) && resize!(integrator.fsalfirst, i) + !isnothing(integrator.fsallast) && resize!(integrator.fsallast, i) resize_f!(integrator.f, i) resize_nlsolver!(integrator, i) resize_J_W!(cache, integrator, i) @@ -235,8 +235,8 @@ function resize!(integrator::ODEIntegrator, i::NTuple{N, Int}) where {N} for c in full_cache(cache) resize!(c, i) end - resize!(integrator.fsalfirst, i) - resize!(integrator.fsallast, i) + !isnothing(integrator.fsalfirst) && resize!(integrator.fsalfirst, i) + !isnothing(integrator.fsallast) && resize!(integrator.fsallast, i) resize_f!(integrator.f, i) # TODO the parts below need to be adapted for implicit methods isdefined(integrator.cache, :nlsolver) && resize_nlsolver!(integrator, i)