Skip to content

Commit

Permalink
Merge pull request #2445 from SciML/fsal_du
Browse files Browse the repository at this point in the history
Improve NULL FSAL and fix du for non-FSAL
  • Loading branch information
ChrisRackauckas authored Aug 31, 2024
2 parents 951df5d + 624a9f9 commit f0b8470
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 59 deletions.
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqBDF/src/dae_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
end

# Not FSAL
get_fsalfirstlast(cache::DImplicitEulerCache, u) = (u, u)
get_fsalfirstlast(cache::DImplicitEulerCache, u) = (nothing, nothing)

mutable struct DImplicitEulerConstantCache{N} <: OrdinaryDiffEqConstantCache
nlsolver::N
Expand Down
9 changes: 8 additions & 1 deletion lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@ mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache
current::Int
end

get_fsalfirstlast(cache::CompositeCache, u) = get_fsalfirstlast(cache.caches[1], u)
function get_fsalfirstlast(cache::CompositeCache, u)
_x = get_fsalfirstlast(cache.caches[1], u)
if first(_x) !== nothing
return _x
else
return get_fsalfirstlast(cache.caches[2], u)
end
end

mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F, uType} <: OrdinaryDiffEqCache
args::A
Expand Down
12 changes: 6 additions & 6 deletions lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
@inline function DiffEqBase.get_du(integrator::ODEIntegrator)
isdiscretecache(integrator.cache) &&
error("Derivatives are not defined for this stepper.")
return if isdefined(integrator, :fsallast)
return if isfsal(integrator.alg)
integrator.fsallast
else
integrator(integrator.t, Val{1})
Expand All @@ -72,7 +72,7 @@ end
if isdiscretecache(integrator.cache)
out .= integrator.cache.tmp
else
return if isdefined(integrator, :fsallast) &&
return if isfsal(integrator.alg) &&
!has_stiff_interpolation(integrator.alg)
# Special stiff interpolations do not store the
# right value in fsallast
Expand Down Expand Up @@ -221,8 +221,8 @@ function resize!(integrator::ODEIntegrator, i::Int)
# may be required for things like units
c !== nothing && resize!(c, i)
end
resize!(integrator.fsalfirst, i)
resize!(integrator.fsallast, i)
!isnothing(integrator.fsalfirst) && resize!(integrator.fsalfirst, i)
!isnothing(integrator.fsallast) && resize!(integrator.fsallast, i)
resize_f!(integrator.f, i)
resize_nlsolver!(integrator, i)
resize_J_W!(cache, integrator, i)
Expand All @@ -235,8 +235,8 @@ function resize!(integrator::ODEIntegrator, i::NTuple{N, Int}) where {N}
for c in full_cache(cache)
resize!(c, i)
end
resize!(integrator.fsalfirst, i)
resize!(integrator.fsallast, i)
!isnothing(integrator.fsalfirst) && resize!(integrator.fsalfirst, i)
!isnothing(integrator.fsallast) && resize!(integrator.fsallast, i)
resize_f!(integrator.f, i)
# TODO the parts below need to be adapted for implicit methods
isdefined(integrator.cache, :nlsolver) && resize_nlsolver!(integrator, i)
Expand Down
76 changes: 38 additions & 38 deletions lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,41 +33,41 @@ function initialize!(integrator, cache::DefaultCache)
u = integrator.u
if cache.current == 1
fsalfirst, fsallast = get_fsalfirstlast(cache.cache1, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, cache.cache1)
elseif cache.current == 2
fsalfirst, fsallast = get_fsalfirstlast(cache.cache2, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, cache.cache2)
# the controller was initialized by default for algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[2])
elseif cache.current == 3
fsalfirst, fsallast = get_fsalfirstlast(cache.cache3, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, cache.cache3)
# the controller was initialized by default for algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[3])
elseif cache.current == 4
fsalfirst, fsallast = get_fsalfirstlast(cache.cache4, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, cache.cache4)
# the controller was initialized by default for algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[4])
elseif cache.current == 5
fsalfirst, fsallast = get_fsalfirstlast(cache.cache5, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, cache.cache5)
# the controller was initialized by default for algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[5])
elseif cache.current == 6
fsalfirst, fsallast = get_fsalfirstlast(cache.cache6, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, cache.cache6)
# the controller was initialized by default for algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[6])
Expand All @@ -80,21 +80,21 @@ function initialize!(integrator, cache::CompositeCache)
u = integrator.u
if cache.current == 1
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.caches[1]))
elseif cache.current == 2
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.caches[2]))
# the controller was initialized by default for integrator.alg.algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1],
integrator.alg.algs[2])
else
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[cache.current], u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.caches[cache.current]))
reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1],
integrator.alg.algs[cache.current])
Expand All @@ -107,13 +107,13 @@ function initialize!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}) where
u = integrator.u
if cache.current == 1
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.caches[1]))
elseif cache.current == 2
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.caches[2]))
reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1],
integrator.alg.algs[2])
Expand Down Expand Up @@ -173,13 +173,13 @@ function choose_algorithm!(integrator,
cache.current = new_current
if new_current == 1
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.caches[1]))
elseif new_current == 2
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.caches[2]))
end
if old_current == 1 && new_current == 2
Expand All @@ -206,38 +206,38 @@ function choose_algorithm!(integrator, cache::DefaultCache)
init_ith_default_cache(cache, algs, new_current)
if new_current == 1
fsalfirst, fsallast = get_fsalfirstlast(cache.cache1, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.cache1))
new_cache = cache.cache1
elseif new_current == 2
fsalfirst, fsallast = get_fsalfirstlast(cache.cache2, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.cache2))
new_cache = cache.cache2
elseif new_current == 3
fsalfirst, fsallast = get_fsalfirstlast(cache.cache3, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.cache3))
new_cache = cache.cache3
elseif new_current == 4
fsalfirst, fsallast = get_fsalfirstlast(cache.cache4, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.cache4))
new_cache = cache.cache4
elseif new_current == 5
fsalfirst, fsallast = get_fsalfirstlast(cache.cache5, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.cache5))
new_cache = cache.cache5
elseif new_current == 6
fsalfirst, fsallast = get_fsalfirstlast(cache.cache6, u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
!isnothing(fsallast) && (integrator.fsallast = fsallast)
initialize!(integrator, @inbounds(cache.cache6))
new_cache = cache.cache6
end
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,14 +469,14 @@ function DiffEqBase.__init(
reinitiailize = true
saveiter = 0 # Starts at 0 so first save is at 1
saveiter_dense = 0
faslfirst, fsallast = get_fsalfirstlast(cache, rate_prototype)
fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype)

integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du),
tType, typeof(p),
typeof(eigen_est), typeof(EEst),
QT, typeof(tdir), typeof(k), SolType,
FType, cacheType,
typeof(opts), typeof(faslfirst),
typeof(opts), typeof(fsalfirst),
typeof(last_event_error), typeof(callback_cache),
typeof(initializealg), typeof(differential_vars)}(
sol, u, du, k, t, tType(dt), f, p,
Expand All @@ -496,7 +496,7 @@ function DiffEqBase.__init(
isout, reeval_fsal,
u_modified, reinitiailize, isdae,
opts, stats, initializealg, differential_vars,
faslfirst, fsallast)
fsalfirst, fsallast)

if initialize_integrator
if isdae || SciMLBase.has_initializeprob(prob.f)
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqFunctionMap/src/functionmap_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
uprev::uType
tmp::rateType
end
get_fsalfirstlast(cache::FunctionMapCache, u) = (cache.u, cache.uprev)
get_fsalfirstlast(cache::FunctionMapCache, u) = (nothing, nothing)

function alg_cache(alg::FunctionMap, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ end
integrator.u = u
end

get_fsalfirstlast(cache::LowStorageRK2NCache, u) = (cache.k, cache.k)
get_fsalfirstlast(cache::LowStorageRK2NCache, u) = (nothing, nothing)

function initialize!(integrator, cache::LowStorageRK2NCache)
@unpack k, tmp, williamson_condition = cache
integrator.kshortsize = 1
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = k
integrator.fsalfirst = k # used for get_du
integrator.fsallast = k
integrator.f(k, integrator.uprev, integrator.p, integrator.t) # FSAL for interpolation
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqPDIRK/src/pdirk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
end

# Non-FSAL
get_fsalfirstlast(cache::PDIRK44Cache, u) = (cache.u, cache.uprev)
get_fsalfirstlast(cache::PDIRK44Cache, u) = (nothing, nothing)

struct PDIRK44ConstantCache{N, TabType} <: OrdinaryDiffEqConstantCache
nlsolver::N
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ function gen_cache_struct(tab::RosenbrockTableau,cachename::Symbol,constcachenam
end
end
cacheexpr=quote
@cache mutable struct $cachename{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: RosenbrockMutableCache
@cache mutable struct $cachename{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: GenericRosenbrockMutableCache
u::uType
uprev::uType
du::rateType
Expand Down
4 changes: 3 additions & 1 deletion lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
abstract type RosenbrockMutableCache <: OrdinaryDiffEqMutableCache end
abstract type GenericRosenbrockMutableCache <: RosenbrockMutableCache end
abstract type RosenbrockConstantCache <: OrdinaryDiffEqConstantCache end

# Fake values since non-FSAL
get_fsalfirstlast(cache::RosenbrockMutableCache, u) = (zero(u), zero(u))
get_fsalfirstlast(cache::RosenbrockMutableCache, u) = (nothing, nothing)
get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u) = (cache.fsalfirst, cache.fsallast)

################################################################################

Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqVerner/src/verner_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ end
end

# fake values since non-FSAL method
get_fsalfirstlast(cache::Vern7Cache, u) = (cache.k1, cache.k2)
get_fsalfirstlast(cache::Vern7Cache, u) = (nothing, nothing)

function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down Expand Up @@ -153,7 +153,7 @@ end
end

# fake values since non-FSAL method
get_fsalfirstlast(cache::Vern8Cache, u) = (cache.k1, cache.k2)
get_fsalfirstlast(cache::Vern8Cache, u) = (nothing, nothing)

function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down Expand Up @@ -227,7 +227,7 @@ end
end

# fake values since non-FSAL method
get_fsalfirstlast(cache::Vern9Cache, u) = (cache.k1, cache.k2)
get_fsalfirstlast(cache::Vern9Cache, u) = (nothing, nothing)

function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down
26 changes: 26 additions & 0 deletions test/interface/get_du.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using OrdinaryDiffEq, OrdinaryDiffEqCore, Test
function lorenz!(du, u, p, t)
du[1] = 10.0(u[2] - u[1])
du[2] = u[1] * (28.0 - u[3]) - u[2]
du[3] = u[1] * u[2] - (8 / 3) * u[3]
end
u0 = [1.0; 0.0; 0.0]
tspan = (0.0, 3.0)
condition(u,t,integrator) = t == 0.2
cache = zeros(3)
affect!(integrator) = cache .= get_du(integrator)
dusave = DiscreteCallback(condition, affect!)
affect2!(integrator) = get_du!(cache, integrator)
dusave_inplace = DiscreteCallback(condition, affect2!)

prob = ODEProblem(lorenz!, u0, tspan)
sol = solve(prob, Tsit5(), tstops = [0.2], callback = dusave, abstol=1e-12, reltol=1e-12)
res = copy(cache)

for alg in [Vern6(), Vern7(), Vern8(), Vern9(), Rodas4(), Rodas4P(), Rodas5(), Rodas5P(), TRBDF2(), KenCarp4(), FBDF(), QNDF()]
sol = solve(prob, alg, tstops = [0.2], callback = dusave, abstol=1e-12, reltol=1e-12)
@test res cache rtol=1e-5

sol = solve(prob, alg, tstops = [0.2], callback = dusave_inplace, abstol=1e-12, reltol=1e-12)
@test res cache rtol=1e-5
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ end
@time @safetestset "Linear Solver Split ODE Tests" include("interface/linear_solver_split_ode_test.jl")
@time @safetestset "Sparse Diff Tests" include("interface/sparsediff_tests.jl")
@time @safetestset "Enum Tests" include("interface/enums.jl")
@time @safetestset "Enum Tests" include("interface/get_du.jl")
@time @safetestset "Mass Matrix Tests" include("interface/mass_matrix_tests.jl")
@time @safetestset "W-Operator prototype tests" include("interface/wprototype_tests.jl")
end
Expand Down

0 comments on commit f0b8470

Please sign in to comment.