Skip to content

Commit

Permalink
Merge pull request #2220 from termi-official/do/rosenbrock-stage-limi…
Browse files Browse the repository at this point in the history
…ters

Introduce and propagate stage limiters for Rosenbrock methods.
  • Loading branch information
ChrisRackauckas authored May 24, 2024
2 parents c9a578c + b5b0c45 commit a1838b7
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 36 deletions.
11 changes: 7 additions & 4 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3085,20 +3085,23 @@ for Alg in [
:Rodas5Pe,
:Rodas5Pr]
@eval begin
struct $Alg{CS, AD, F, P, FDT, ST, CJ, StepLimiter} <:
struct $Alg{CS, AD, F, P, FDT, ST, CJ, StepLimiter, StageLimiter} <:
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end
function $Alg(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing,
precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!)
precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!,
stage_limiter! = trivial_limiter!)
$Alg{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(precs), diff_type, _unwrap_val(standardtag),
_unwrap_val(concrete_jac), typeof(step_limiter!)}(linsolve,
precs, step_limiter!)
_unwrap_val(concrete_jac), typeof(step_limiter!),
typeof(stage_limiter!)}(linsolve, precs, step_limiter!,
stage_limiter!)
end
end

Expand Down
59 changes: 39 additions & 20 deletions src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ abstract type RosenbrockMutableCache <: OrdinaryDiffEqMutableCache end

@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
TabType, TFType, UFType, F, JCType, GCType,
RTolType, A, AV, StepLimiter} <: RosenbrockMutableCache
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
u::uType
uprev::uType
k₁::rateType
Expand Down Expand Up @@ -33,13 +33,14 @@ abstract type RosenbrockMutableCache <: OrdinaryDiffEqMutableCache end
alg::A
algebraic_vars::AV
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end

TruncatedStacktraces.@truncate_stacktrace Rosenbrock23Cache 1

@cache mutable struct Rosenbrock32Cache{uType, rateType, uNoUnitsType, JType, WType,
TabType, TFType, UFType, F, JCType, GCType,
RTolType, A, AV, StepLimiter} <: RosenbrockMutableCache
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
u::uType
uprev::uType
k₁::rateType
Expand Down Expand Up @@ -67,6 +68,7 @@ TruncatedStacktraces.@truncate_stacktrace Rosenbrock23Cache 1
alg::A
algebraic_vars::AV
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end

function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -110,7 +112,8 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
Rosenbrock23Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!,
alg.stage_limiter!)
end

function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -153,7 +156,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},

Rosenbrock32Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!)
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!)
end

struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <:
Expand Down Expand Up @@ -232,7 +235,7 @@ end

@cache mutable struct Rosenbrock33Cache{uType, rateType, uNoUnitsType, JType, WType,
TabType, TFType, UFType, F, JCType, GCType,
RTolType, A, StepLimiter} <: RosenbrockMutableCache
RTolType, A, StepLimiter, StageLimiter} <: RosenbrockMutableCache
u::uType
uprev::uType
du::rateType
Expand Down Expand Up @@ -260,6 +263,7 @@ end
reltol::RTolType
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end

function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -298,7 +302,8 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
Rosenbrock33Cache(u, uprev, du, du1, du2, k1, k2, k3, k4,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -316,7 +321,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
end

@cache mutable struct Rosenbrock34Cache{uType, rateType, uNoUnitsType, JType, WType,
TabType, TFType, UFType, F, JCType, GCType, StepLimiter} <:
TabType, TFType, UFType, F, JCType, GCType, StepLimiter, StageLimiter} <:
RosenbrockMutableCache
u::uType
uprev::uType
Expand All @@ -343,6 +348,7 @@ end
jac_config::JCType
grad_config::GCType
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end

function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -382,7 +388,8 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
Rosenbrock34Cache(u, uprev, du, du1, du2, k1, k2, k3, k4,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, alg.step_limiter!)
linsolve, jac_config, grad_config, alg.step_limiter!,
alg.stage_limiter!)
end

struct Rosenbrock34ConstantCache{TF, UF, Tab, JType, WType, F} <:
Expand Down Expand Up @@ -460,7 +467,7 @@ struct Rodas3PConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqC
end

@cache mutable struct Rodas23WCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter} <:
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
RosenbrockMutableCache
u::uType
uprev::uType
Expand Down Expand Up @@ -493,10 +500,11 @@ end
reltol::RTolType
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end

@cache mutable struct Rodas3PCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter} <:
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
RosenbrockMutableCache
u::uType
uprev::uType
Expand Down Expand Up @@ -529,6 +537,7 @@ end
reltol::RTolType
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end

function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -571,7 +580,8 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rodas23WCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

TruncatedStacktraces.@truncate_stacktrace Rodas23WCache 1
Expand Down Expand Up @@ -615,7 +625,8 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rodas3PCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

TruncatedStacktraces.@truncate_stacktrace Rodas3PCache 1
Expand Down Expand Up @@ -663,7 +674,7 @@ struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqCo
end

@cache mutable struct Rodas4Cache{uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter} <:
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
RosenbrockMutableCache
u::uType
uprev::uType
Expand Down Expand Up @@ -696,6 +707,7 @@ end
reltol::RTolType
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end

function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -739,7 +751,8 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
k5, k6,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

TruncatedStacktraces.@truncate_stacktrace Rodas4Cache 1
Expand Down Expand Up @@ -800,7 +813,8 @@ function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
k5, k6,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -859,7 +873,8 @@ function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
k5, k6,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -918,7 +933,8 @@ function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
k5, k6,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -951,7 +967,7 @@ end

@cache mutable struct Rosenbrock5Cache{
uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter} <:
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
RosenbrockMutableCache
u::uType
uprev::uType
Expand Down Expand Up @@ -987,6 +1003,7 @@ end
reltol::RTolType
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
end

TruncatedStacktraces.@truncate_stacktrace Rosenbrock5Cache 1
Expand Down Expand Up @@ -1036,7 +1053,8 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
k5, k6, k7, k8,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -1099,7 +1117,8 @@ function alg_cache(
k5, k6, k7, k8,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!)
end

function alg_cache(
Expand Down
Loading

0 comments on commit a1838b7

Please sign in to comment.