Skip to content

Commit

Permalink
Perform_step! refactor for Rodas5*
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith authored Sep 19, 2024
2 parents 2fa672b + 0dfb039 commit f17e149
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 1,718 deletions.
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end
function DiffEqBase.interp_summary(::Type{cacheType},
dense::Bool) where {
cacheType <:
Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache,
Union{RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache,
RosenbrockCache, Rodas23WCache, Rodas3PCache}}
dense ? "specialized 3rd order \"free\" stiffness-aware interpolation" :
"1st order linear"
Expand All @@ -20,8 +20,8 @@ end
function DiffEqBase.interp_summary(::Type{cacheType},
dense::Bool) where {
cacheType <:
Union{Rosenbrock5ConstantCache,
Rosenbrock5Cache}}
Union{RosenbrockCombinedConstantCache,
RosenbrockCache}}
dense ? "specialized 4rd order \"free\" stiffness-aware interpolation" :
"1st order linear"
end
227 changes: 26 additions & 201 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,24 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
interp_order::Int
end
function full_cache(c::RosenbrockCache)
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp]
end

struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
autodiff::AD
interp_order::Int
end

@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
TabType, TFType, UFType, F, JCType, GCType,
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
Expand Down Expand Up @@ -702,22 +714,16 @@ end

### Rodas4 methods

struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
autodiff::AD
end

tabtype(::Rodas4) = Rodas4Tableau
tabtype(::Rodas42) = Rodas42Tableau
tabtype(::Rodas4P) = Rodas4PTableau
tabtype(::Rodas4P2) = Rodas4P2Tableau
tabtype(::Rodas5) = Rodas5Tableau
tabtype(::Rodas5P) = Rodas5PTableau
tabtype(::Rodas5Pr) = Rodas5PTableau
tabtype(::Rodas5Pe) = Rodas5PTableau

function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
Expand All @@ -727,21 +733,22 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
Rodas4ConstantCache(tf, uf,
tabtype(alg)(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve,
alg_autodiff(alg))
tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
RosenbrockCombinedConstantCache(tf, uf,
tab, J, W, linsolve,
alg_autodiff(alg), size(tab.H, 1))
end

function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}

tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
# Initialize vectors
dense = [zero(rate_prototype) for _ in 1:2]
ks = [zero(rate_prototype) for _ in 1:6]
dense = [zero(rate_prototype) for _ in 1:size(tab.H, 1)]
ks = [zero(rate_prototype) for _ in 1:size(tab.A, 1)]
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
Expand All @@ -760,7 +767,6 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))

tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
Expand All @@ -783,190 +789,9 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg,
alg.step_limiter!, alg.stage_limiter!)
end

################################################################################

### Rosenbrock5

struct Rosenbrock5ConstantCache{TF, UF, Tab, JType, WType, F} <: RosenbrockConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
end

@cache mutable struct Rosenbrock5Cache{
uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
RosenbrockMutableCache
u::uType
uprev::uType
dense1::rateType
dense2::rateType
dense3::rateType
du::rateType
du1::rateType
du2::rateType
k1::rateType
k2::rateType
k3::rateType
k4::rateType
k5::rateType
k6::rateType
k7::rateType
k8::rateType
fsalfirst::rateType
fsallast::rateType
dT::rateType
J::JType
W::WType
tmp::rateType
atmp::uNoUnitsType
weight::uNoUnitsType
tab::TabType
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve::F
jac_config::JCType
grad_config::GCType
reltol::RTolType
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1))
end

function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
dense1 = zero(rate_prototype)
dense2 = zero(rate_prototype)
dense3 = zero(rate_prototype)
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = zero(rate_prototype)
k4 = zero(rate_prototype)
k5 = zero(rate_prototype)
k6 = zero(rate_prototype)
k7 = zero(rate_prototype)
k8 = zero(rate_prototype)
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = Rodas5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))

tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
k5, k6, k7, k8,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tf = TimeDerivativeWrapper(f, u, p)
uf = UDerivativeWrapper(f, t, p)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
Rosenbrock5ConstantCache(tf, uf,
Rodas5Tableau(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve)
end

function alg_cache(
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
dense1 = zero(rate_prototype)
dense2 = zero(rate_prototype)
dense3 = zero(rate_prototype)
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = zero(rate_prototype)
k4 = zero(rate_prototype)
k5 = zero(rate_prototype)
k6 = zero(rate_prototype)
k7 = zero(rate_prototype)
k8 = zero(rate_prototype)
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = Rodas5PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))

tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
k5, k6, k7, k8,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

function alg_cache(
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tf = TimeDerivativeWrapper(f, u, p)
uf = UDerivativeWrapper(f, t, p)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
Rosenbrock5ConstantCache(tf, uf,
Rodas5PTableau(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve)
end

function get_fsalfirstlast(
cache::Union{Rosenbrock23Cache, Rosenbrock32Cache, Rosenbrock33Cache,
Expand Down
Loading

0 comments on commit f17e149

Please sign in to comment.