Skip to content

Commit

Permalink
Merge pull request #2390 from SciML/undef_fsal
Browse files Browse the repository at this point in the history
Refactor ODEIntegrator to not allow undef fsal states
  • Loading branch information
ChrisRackauckas authored Aug 19, 2024
2 parents f43bcaf + 4856d99 commit cb87c1c
Show file tree
Hide file tree
Showing 92 changed files with 524 additions and 824 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import OrdinaryDiffEqCore: OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCac
OrdinaryDiffEqAdaptiveAlgorithm,
OrdinaryDiffEqAdamsVarOrderVarStepAlgorithm,
constvalue, calculate_residuals, calculate_residuals!,
trivial_limiter!,
trivial_limiter!, get_fsalfirstlast,
full_cache
import OrdinaryDiffEqLowOrderRK: BS3ConstantCache, BS3Cache, RK4ConstantCache, RK4Cache
import RecursiveArrayTools: recursivefill!
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
@cache mutable struct AB3Cache{uType, rateType} <: OrdinaryDiffEqMutableCache
abstract type ABMMutableCache <: OrdinaryDiffEqMutableCache end
abstract type ABMVariableCoefficientMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::ABMMutableCache,u) = (cache.fsalfirst, cache.k)
get_fsalfirstlast(cache::ABMVariableCoefficientMutableCache,u) = (cache.fsalfirst, cache.k4)
@cache mutable struct AB3Cache{uType, rateType} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -38,7 +42,7 @@ function alg_cache(alg::AB3, u, rate_prototype, ::Type{uEltypeNoUnits},
AB3ConstantCache(k2, k3, 1)
end

@cache mutable struct ABM32Cache{uType, rateType} <: OrdinaryDiffEqMutableCache
@cache mutable struct ABM32Cache{uType, rateType} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -78,7 +82,7 @@ function alg_cache(alg::ABM32, u, rate_prototype, ::Type{uEltypeNoUnits},
ABM32ConstantCache(k2, k3, 1)
end

@cache mutable struct AB4Cache{uType, rateType} <: OrdinaryDiffEqMutableCache
@cache mutable struct AB4Cache{uType, rateType} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -128,7 +132,7 @@ function alg_cache(alg::AB4, u, rate_prototype, ::Type{uEltypeNoUnits},
AB4ConstantCache(k2, k3, k4, 1)
end

@cache mutable struct ABM43Cache{uType, rateType} <: OrdinaryDiffEqMutableCache
@cache mutable struct ABM43Cache{uType, rateType} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -184,7 +188,7 @@ function alg_cache(alg::ABM43, u, rate_prototype, ::Type{uEltypeNoUnits},
ABM43ConstantCache(k2, k3, k4, 1)
end

@cache mutable struct AB5Cache{uType, rateType} <: OrdinaryDiffEqMutableCache
@cache mutable struct AB5Cache{uType, rateType} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -236,7 +240,7 @@ function alg_cache(alg::AB5, u, rate_prototype, ::Type{uEltypeNoUnits},
AB5ConstantCache(k2, k3, k4, k5, 1)
end

@cache mutable struct ABM54Cache{uType, rateType} <: OrdinaryDiffEqMutableCache
@cache mutable struct ABM54Cache{uType, rateType} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -312,7 +316,7 @@ end

@cache mutable struct VCAB3Cache{uType, rateType, TabType, bs3Type, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
OrdinaryDiffEqMutableCache
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -408,7 +412,7 @@ end

@cache mutable struct VCAB4Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
OrdinaryDiffEqMutableCache
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -504,7 +508,7 @@ end

@cache mutable struct VCAB5Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
OrdinaryDiffEqMutableCache
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -602,7 +606,7 @@ end
@cache mutable struct VCABM3Cache{
uType, rateType, TabType, bs3Type, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
OrdinaryDiffEqMutableCache
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -708,7 +712,7 @@ end

@cache mutable struct VCABM4Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
OrdinaryDiffEqMutableCache
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -813,7 +817,7 @@ end

@cache mutable struct VCABM5Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
OrdinaryDiffEqMutableCache
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down Expand Up @@ -919,7 +923,7 @@ end

@cache mutable struct VCABMCache{uType, rateType, dtType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
OrdinaryDiffEqMutableCache
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ function initialize!(integrator,
ABM43Cache,
ABM54Cache})
@unpack fsalfirst, k = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = k

integrator.kshortsize = 2
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
Expand Down Expand Up @@ -551,8 +550,7 @@ end

function initialize!(integrator, cache::VCAB3Cache)
@unpack fsalfirst, k4 = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = k4

integrator.kshortsize = 2
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
Expand Down Expand Up @@ -685,8 +683,7 @@ end

function initialize!(integrator, cache::VCAB4Cache)
@unpack fsalfirst, k4 = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = k4

integrator.kshortsize = 2
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
Expand Down Expand Up @@ -834,8 +831,7 @@ end

function initialize!(integrator, cache::VCAB5Cache)
@unpack fsalfirst, k4 = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = k4

integrator.kshortsize = 2
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
Expand Down Expand Up @@ -981,8 +977,7 @@ end

function initialize!(integrator, cache::VCABM3Cache)
@unpack fsalfirst, k4 = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = k4

integrator.kshortsize = 2
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
Expand Down Expand Up @@ -1125,8 +1120,7 @@ end

function initialize!(integrator, cache::VCABM4Cache)
@unpack fsalfirst, k4 = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = k4

integrator.kshortsize = 2
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
Expand Down Expand Up @@ -1282,8 +1276,7 @@ end

function initialize!(integrator, cache::VCABM5Cache)
@unpack fsalfirst, k4 = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = k4

integrator.kshortsize = 2
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
Expand Down Expand Up @@ -1459,8 +1452,7 @@ end

function initialize!(integrator, cache::VCABMCache)
@unpack fsalfirst, k4 = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = k4

integrator.kshortsize = 2
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
Expand Down
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
default_controller, stepsize_controller!,
step_accept_controller!,
step_reject_controller!, post_newton_controller!,
u_modified!, DAEAlgorithm, _unwrap_val, DummyController
u_modified!, DAEAlgorithm, _unwrap_val, DummyController,
get_fsalfirstlast
using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache

using TruncatedStacktraces, MuladdMacro, MacroTools, FastBroadcast, RecursiveArrayTools
Expand Down
17 changes: 10 additions & 7 deletions lib/OrdinaryDiffEqBDF/src/bdf_caches.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
abstract type BDFMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::BDFMutableCache,u) = (cache.fsalfirst, du_alias_or_new(cache.nlsolver, cache.fsalfirst))

@cache mutable struct ABDF2ConstantCache{N, dtType, rate_prototype} <:
OrdinaryDiffEqConstantCache
nlsolver::N
Expand All @@ -22,7 +25,7 @@ function alg_cache(alg::ABDF2, u, rate_prototype, ::Type{uEltypeNoUnits},
end

@cache mutable struct ABDF2Cache{uType, rateType, uNoUnitsType, N, dtType, StepLimiter} <:
OrdinaryDiffEqMutableCache
BDFMutableCache
uₙ::uType
uₙ₋₁::uType
uₙ₋₂::uType
Expand Down Expand Up @@ -78,7 +81,7 @@ end
du₂::rateType
end

@cache mutable struct SBDFCache{uType, rateType, N} <: OrdinaryDiffEqMutableCache
@cache mutable struct SBDFCache{uType, rateType, N} <: BDFMutableCache
cnt::Int
ark::Bool
u::uType
Expand Down Expand Up @@ -164,7 +167,7 @@ end
end

@cache mutable struct QNDF1Cache{uType, rateType, coefType, coefType1, coefType2,
uNoUnitsType, N, dtType, StepLimiter} <: OrdinaryDiffEqMutableCache
uNoUnitsType, N, dtType, StepLimiter} <: BDFMutableCache
uprev2::uType
fsalfirst::rateType
D::coefType1
Expand Down Expand Up @@ -252,7 +255,7 @@ end
end

@cache mutable struct QNDF2Cache{uType, rateType, coefType, coefType1, coefType2,
uNoUnitsType, N, dtType, StepLimiter} <: OrdinaryDiffEqMutableCache
uNoUnitsType, N, dtType, StepLimiter} <: BDFMutableCache
uprev2::uType
uprev3::uType
fsalfirst::rateType
Expand Down Expand Up @@ -383,7 +386,7 @@ end

@cache mutable struct QNDFCache{MO, UType, RUType, rateType, N, coefType, dtType, EEstType,
gammaType, uType, uNoUnitsType, StepLimiter} <:
OrdinaryDiffEqMutableCache
BDFMutableCache
fsalfirst::rateType
dd::uType
utilde::uType
Expand Down Expand Up @@ -462,7 +465,7 @@ function alg_cache(alg::QNDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
end

@cache mutable struct MEBDF2Cache{uType, rateType, uNoUnitsType, N} <:
OrdinaryDiffEqMutableCache
BDFMutableCache
u::uType
uprev::uType
uprev2::uType
Expand Down Expand Up @@ -574,7 +577,7 @@ end
@cache mutable struct FBDFCache{
MO, N, rateType, uNoUnitsType, tsType, tType, uType, uuType,
coeffType, EEstType, rType, wType, StepLimiter} <:
OrdinaryDiffEqMutableCache
BDFMutableCache
fsalfirst::rateType
nlsolver::N
ts::tsType
Expand Down
21 changes: 7 additions & 14 deletions lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ end

function initialize!(integrator, cache::ABDF2Cache)
integrator.kshortsize = 2
integrator.fsalfirst = cache.fsalfirst
integrator.fsallast = du_alias_or_new(cache.nlsolver, integrator.fsalfirst)

resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
Expand Down Expand Up @@ -255,8 +254,7 @@ function initialize!(integrator, cache::SBDFCache)
@unpack uprev, p, t = integrator
@unpack f1, f2 = integrator.f
integrator.kshortsize = 2
integrator.fsalfirst = cache.fsalfirst
integrator.fsallast = du_alias_or_new(cache.nlsolver, integrator.fsalfirst)

resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
Expand Down Expand Up @@ -412,8 +410,7 @@ end

function initialize!(integrator, cache::QNDF1Cache)
integrator.kshortsize = 2
integrator.fsalfirst = cache.fsalfirst
integrator.fsallast = du_alias_or_new(cache.nlsolver, integrator.fsalfirst)

resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
Expand Down Expand Up @@ -608,8 +605,7 @@ end

function initialize!(integrator, cache::QNDF2Cache)
integrator.kshortsize = 2
integrator.fsalfirst = cache.fsalfirst
integrator.fsallast = du_alias_or_new(cache.nlsolver, integrator.fsalfirst)

resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
Expand Down Expand Up @@ -833,8 +829,7 @@ end

function initialize!(integrator, cache::QNDFCache)
integrator.kshortsize = 2
integrator.fsalfirst = cache.fsalfirst
integrator.fsallast = du_alias_or_new(cache.nlsolver, integrator.fsalfirst)

resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
Expand Down Expand Up @@ -1010,8 +1005,7 @@ end

function initialize!(integrator, cache::MEBDF2Cache)
integrator.kshortsize = 2
integrator.fsalfirst = cache.fsalfirst
integrator.fsallast = du_alias_or_new(cache.nlsolver, integrator.fsalfirst)

resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
Expand Down Expand Up @@ -1218,8 +1212,7 @@ end

function initialize!(integrator, cache::FBDFCache)
integrator.kshortsize = 2
integrator.fsalfirst = cache.fsalfirst
integrator.fsallast = du_alias_or_new(cache.nlsolver, integrator.fsalfirst)

resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
Expand Down
12 changes: 9 additions & 3 deletions lib/OrdinaryDiffEqBDF/src/dae_caches.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
abstract type DAEBDFMutableCache <: OrdinaryDiffEqMutableCache end
get_fsalfirstlast(cache::DAEBDFMutableCache,u) = (cache.fsalfirst, du_alias_or_new(cache.nlsolver, cache.fsalfirst))

@cache mutable struct DImplicitEulerCache{uType, rateType, uNoUnitsType, N} <:
OrdinaryDiffEqMutableCache
DAEBDFMutableCache
u::uType
uprev::uType
uprev2::uType
Expand All @@ -9,6 +12,9 @@
nlsolver::N
end

# Not FSAL
get_fsalfirstlast(cache::DImplicitEulerCache,u) = (u,u)

mutable struct DImplicitEulerConstantCache{N} <: OrdinaryDiffEqConstantCache
nlsolver::N
end
Expand Down Expand Up @@ -68,7 +74,7 @@ function alg_cache(alg::DABDF2, du, u, res_prototype, rate_prototype,
end

@cache mutable struct DABDF2Cache{uType, rateType, uNoUnitsType, N, dtType} <:
OrdinaryDiffEqMutableCache
DAEBDFMutableCache
uₙ::uType
uₙ₋₁::uType
uₙ₋₂::uType
Expand Down Expand Up @@ -171,7 +177,7 @@ end

@cache mutable struct DFBDFCache{MO, N, rateType, uNoUnitsType, tsType, tType, uType,
uuType, coeffType, EEstType, rType, wType} <:
OrdinaryDiffEqMutableCache
DAEBDFMutableCache
fsalfirst::rateType
nlsolver::N
ts::tsType
Expand Down
Loading

0 comments on commit cb87c1c

Please sign in to comment.