Skip to content

Commit

Permalink
pass u for allowing the creation of constant cache fsal
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Aug 18, 2024
1 parent d3829c1 commit 645520f
Show file tree
Hide file tree
Showing 31 changed files with 83 additions and 79 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
abstract type ABMMutableCache <: OrdinaryDiffEqMutableCache end
abstract type ABMVariableCoefficientMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::ABMMutableCache) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::ABMVariableCoefficientMutableCache) = (cache.fsalfirst, cache.k4)
get_fsalfirstlast(cache::ABMMutableCache,u) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::ABMVariableCoefficientMutableCache,u) = (cache.fsalfirst, cache.k4)
@cache mutable struct AB3Cache{uType, rateType} <: ABMMutableCache
u::uType
uprev::uType
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqBDF/src/bdf_caches.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
abstract type BDFMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::BDFMutableCache) = (cache.fsalfirst, du_alias_or_new(cache.nlsolver, cache.fsalfirst))
get_fsalfirstlast(cache::BDFMutableCache,u) = (cache.fsalfirst, du_alias_or_new(cache.nlsolver, cache.fsalfirst))

@cache mutable struct ABDF2ConstantCache{N, dtType, rate_prototype} <:
OrdinaryDiffEqConstantCache
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqBDF/src/dae_caches.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
abstract type DAEBDFMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::DAEBDFMutableCache) = (cache.fsalfirst, du_alias_or_new(cache.nlsolver, cache.fsalfirst))
get_fsalfirstlast(cache::DAEBDFMutableCache,u) = (cache.fsalfirst, du_alias_or_new(cache.nlsolver, cache.fsalfirst))

@cache mutable struct DImplicitEulerCache{uType, rateType, uNoUnitsType, N} <:
DAEBDFMutableCache
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ struct ODEEmptyCache <: OrdinaryDiffEqConstantCache end
struct ODEChunkCache{CS} <: OrdinaryDiffEqConstantCache end

# Don't worry about the potential alloc on a constant cache
get_fsalfirstlast(cache::OrdinaryDiffEqConstantCache) = (zero(cache.u), zero(cache.u))
get_fsalfirstlast(cache::OrdinaryDiffEqConstantCache,u) = (zero(u), zero(u))

mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache
caches::T
choice_function::F
current::Int
end

get_fsalfirstlast(cache::CompositeCache) = get_fsalfirstlast(cache.caches[1])
get_fsalfirstlast(cache::CompositeCache,u) = get_fsalfirstlast(cache.caches[1],u)

mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F, uType} <: OrdinaryDiffEqCache
args::A
Expand All @@ -32,7 +32,7 @@ mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F, uType} <: OrdinaryDiff
end
end

function get_fsalfirstlast(cache::DefaultCache)
function get_fsalfirstlast(cache::DefaultCache,u)
(cache.u,cache.u) # will be overwritten by the cache choice
end

Expand Down
25 changes: 14 additions & 11 deletions lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,42 +30,43 @@ function initialize!(integrator, cache::DefaultCache)
cache.current = cache.choice_function(integrator)
algs = integrator.alg.algs
init_ith_default_cache(cache, algs, cache.current)
u = integrator.u
if cache.current == 1
initialize!(integrator, cache.cache1)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache1)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache1,u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
elseif cache.current == 2
initialize!(integrator, cache.cache2)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache2)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache2,u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
# the controller was initialized by default for algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[2])
elseif cache.current == 3
initialize!(integrator, cache.cache3)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache3)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache3,u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
# the controller was initialized by default for algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[3])
elseif cache.current == 4
initialize!(integrator, cache.cache4)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache4)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache4,u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
# the controller was initialized by default for algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[4])
elseif cache.current == 5
initialize!(integrator, cache.cache5)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache5)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache5,u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
# the controller was initialized by default for algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[5])
elseif cache.current == 6
initialize!(integrator, cache.cache6)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache6)
fsalfirst, fsallast = get_fsalfirstlast(cache.cache6,u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
# the controller was initialized by default for algs[1]
Expand All @@ -76,22 +77,23 @@ end

function initialize!(integrator, cache::CompositeCache)
cache.current = cache.choice_function(integrator)
u = integrator.u
if cache.current == 1
initialize!(integrator, @inbounds(cache.caches[1]))
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1])
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1],u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
elseif cache.current == 2
initialize!(integrator, @inbounds(cache.caches[2]))
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2])
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2],u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
# the controller was initialized by default for integrator.alg.algs[1]
reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1],
integrator.alg.algs[2])
else
initialize!(integrator, @inbounds(cache.caches[cache.current]))
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[cache.current])
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[cache.current],u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1],
Expand All @@ -102,14 +104,15 @@ end

function initialize!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}) where {T1, T2, F}
cache.current = cache.choice_function(integrator)
u = integrator.u
if cache.current == 1
initialize!(integrator, @inbounds(cache.caches[1]))
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1])
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1],u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
elseif cache.current == 2
initialize!(integrator, @inbounds(cache.caches[2]))
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2])
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2],u)
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1],
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ function DiffEqBase.__init(
reinitiailize = true
saveiter = 0 # Starts at 0 so first save is at 1
saveiter_dense = 0
faslfirst, fsallast = get_fsalfirstlast(cache)
faslfirst, fsallast = get_fsalfirstlast(cache,u)

integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du),
tType, typeof(p),
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqExplicitRK/src/explicit_rk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
tab::TabType
end

get_fsalfirstlast(cache::ExplicitRKCache) = (cache.kk[1], cache.fsallast)
get_fsalfirstlast(cache::ExplicitRKCache,u) = (cache.kk[1], cache.fsallast)

function alg_cache(alg::ExplicitRK, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqExponentialRK/src/exponential_rk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Classical ExpRK method caches
abstract type ExpRKCache <: OrdinaryDiffEqMutableCache end
abstract type ExpRKConstantCache <: OrdinaryDiffEqConstantCache end
get_fsalfirstlast(cache::ExpRKCache) = (zero(cache.rtmp), zero(cache.rtmp))
get_fsalfirstlast(cache::ExpRKCache,u) = (zero(cache.rtmp), zero(cache.rtmp))

# Precomputation of exponential-like operators
"""
Expand Down Expand Up @@ -892,7 +892,7 @@ end
B1::expType # ϕ1(hA) + ϕ2(hA)
B0::expType # -ϕ2(hA)
end
get_fsalfirstlast(cache::ETD2) = (ETD2Fsal(cache.rtmp1), ETD2Fsal(cache.rtmp1))
get_fsalfirstlast(cache::ETD2,u) = (ETD2Fsal(cache.rtmp1), ETD2Fsal(cache.rtmp1))

function alg_cache(alg::ETD2, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
abstract type ExtrapolationMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::ExtrapolationMutableCache) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::ExtrapolationMutableCache,u) = (cache.fsalfirst, cache.k)

@cache mutable struct AitkenNevilleCache{
uType,
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
abstract type FIRKMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::FIRKMutableCache) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::FIRKMutableCache,u) = (cache.fsalfirst, cache.k)

mutable struct RadauIIA3ConstantCache{F, Tab, Tol, Dt, U, JType} <:
OrdinaryDiffEqConstantCache
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqFeagin/src/feagin_caches.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
abstract type FeaginCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::FeaginCache) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::FeaginCache,u) = (cache.fsalfirst, cache.k)

@cache struct Feagin10Cache{uType, uNoUnitsType, rateType, TabType, StepLimiter} <:
FeaginCache
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqFunctionMap/src/functionmap_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
uprev::uType
tmp::rateType
end
get_fsalfirstlast(cache::FunctionMapCache) = (cache.u, cache.uprev)
get_fsalfirstlast(cache::FunctionMapCache,u) = (cache.u, cache.uprev)

function alg_cache(alg::FunctionMap, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqHighOrderRK/src/high_order_rk_caches.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
abstract type HighOrderRKMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::HighOrderRKMutableCache) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::HighOrderRKMutableCache,u) = (cache.fsalfirst, cache.k)
@cache struct TanYam7Cache{uType, rateType, uNoUnitsType, TabType, StageLimiter,
StepLimiter, Thread} <:
HighOrderRKMutableCache
Expand Down Expand Up @@ -92,7 +92,7 @@ end
step_limiter!::StepLimiter
thread::Thread
end
get_fsalfirstlast(cache::DP8Cache) = (cache.k1, cache.k13)
get_fsalfirstlast(cache::DP8Cache,u) = (cache.k1, cache.k13)

function alg_cache(alg::DP8, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# IMEX Multistep methods
abstract type IMEXMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::IMEXMutableCache) = (cache.fsalfirst, du_alias_or_new(cache.nlsolver, cache.fsalfirst))
get_fsalfirstlast(cache::IMEXMutableCache,u) = (cache.fsalfirst, du_alias_or_new(cache.nlsolver, cache.fsalfirst))

# CNAB2

Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqLinear/src/linear_caches.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
abstract type LinearMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::LinearMutableCache) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::LinearMutableCache,u) = (cache.fsalfirst, cache.k)

@cache struct MagnusMidpointCache{uType, rateType, WType, expType} <:
LinearMutableCache
Expand Down
10 changes: 5 additions & 5 deletions lib/OrdinaryDiffEqLowOrderRK/src/fixed_timestep_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function perform_step!(integrator, cache::EulerConstantCache, repeat_step = fals
integrator.u = u
end

get_fsalfirstlast(cache::EulerCache) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::EulerCache,u) = (cache.fsalfirst, cache.k)
function initialize!(integrator, cache::EulerCache)
integrator.kshortsize = 2
@unpack k, fsalfirst = cache
Expand Down Expand Up @@ -97,7 +97,7 @@ end
integrator.u = u
end

get_fsalfirstlast(cache::Union{HeunCache, RalstonCache}) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::Union{HeunCache, RalstonCache},u) = (cache.fsalfirst, cache.k)
function initialize!(integrator, cache::Union{HeunCache, RalstonCache})
integrator.kshortsize = 2
@unpack k, fsalfirst = cache
Expand Down Expand Up @@ -189,7 +189,7 @@ end
integrator.u = u
end

get_fsalfirstlast(cache::MidpointCache) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::MidpointCache,u) = (cache.fsalfirst, cache.k)
function initialize!(integrator, cache::MidpointCache)
@unpack k, fsalfirst = cache
integrator.fsalfirst = fsalfirst
Expand Down Expand Up @@ -288,7 +288,7 @@ end
integrator.u = u
end

get_fsalfirstlast(cache::RK4Cache) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::RK4Cache,u) = (cache.fsalfirst, cache.k)
function initialize!(integrator, cache::RK4Cache)
@unpack tmp, fsalfirst, k₂, k₃, k₄, k = cache
integrator.fsalfirst = fsalfirst
Expand Down Expand Up @@ -420,7 +420,7 @@ end
integrator.u = u
end

get_fsalfirstlast(cache::Anas5Cache) = (cache.k1, cache.k7)
get_fsalfirstlast(cache::Anas5Cache,u) = (cache.k1, cache.k7)
function initialize!(integrator, cache::Anas5Cache)
integrator.kshortsize = 7
resize!(integrator.k, integrator.kshortsize)
Expand Down
Loading

0 comments on commit 645520f

Please sign in to comment.