From 0e2685abcbac3b48a5f17b9db11b73206dd6cb9e Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Fri, 20 Sep 2024 09:06:28 -0400 Subject: [PATCH 1/6] broken 3rd order Rosenbrock refactor --- .../src/interp_func.jl | 21 +- .../src/rosenbrock_caches.jl | 706 +--------- .../src/rosenbrock_interpolants.jl | 147 +- .../src/rosenbrock_perform_step.jl | 1178 +---------------- .../src/rosenbrock_tableaus.jl | 270 ++-- .../src/stiff_addsteps.jl | 45 - 6 files changed, 133 insertions(+), 2234 deletions(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl index 07f4bbeee7..077e27c7b0 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl @@ -1,27 +1,8 @@ -function DiffEqBase.interp_summary(::Type{cacheType}, - dense::Bool) where { - cacheType <: - Union{Rosenbrock23ConstantCache, - Rosenbrock32ConstantCache, - Rosenbrock23Cache, - Rosenbrock32Cache}} - dense ? "specialized 2nd order \"free\" stiffness-aware interpolation" : - "1st order linear" -end -function DiffEqBase.interp_summary(::Type{cacheType}, - dense::Bool) where { - cacheType <: - Union{RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache, - RosenbrockCache, Rodas23WCache, Rodas3PCache}} - dense ? "specialized 3rd order \"free\" stiffness-aware interpolation" : - "1st order linear" -end - function DiffEqBase.interp_summary(::Type{cacheType}, dense::Bool) where { cacheType <: Union{RosenbrockCombinedConstantCache, RosenbrockCache}} - dense ? "specialized 4rd order \"free\" stiffness-aware interpolation" : + dense ? "specialized $(cache.interp_order) order \"free\" stiffness-aware interpolation" : "1st order linear" end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 33fc5dcd2b..81ceca315e 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -11,16 +11,15 @@ end ################################################################################ # Shampine's Low-order Rosenbrocks - mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType, - TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <: - RosenbrockMutableCache + TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache u::uType uprev::uType dense::Vector{rateType} du::rateType du1::rateType du2::rateType + f₁::rateType ks::Vector{rateType} fsalfirst::rateType fsallast::rateType @@ -39,6 +38,7 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT grad_config::GCType reltol::RTolType alg::A + algebraic_vars::AV step_limiter!::StepLimiter stage_limiter!::StageLimiter interp_order::Int @@ -59,421 +59,6 @@ struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: Rose interp_order::Int end -@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType, - TabType, TFType, UFType, F, JCType, GCType, - RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache - u::uType - uprev::uType - k₁::rateType - k₂::rateType - k₃::rateType - du1::rateType - du2::rateType - f₁::rateType - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - reltol::RTolType - alg::A - algebraic_vars::AV - step_limiter!::StepLimiter - stage_limiter!::StageLimiter -end - -@cache mutable struct Rosenbrock32Cache{uType, rateType, uNoUnitsType, JType, WType, - TabType, TFType, UFType, F, JCType, GCType, - RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache - u::uType - uprev::uType - k₁::rateType - k₂::rateType - k₃::rateType - du1::rateType - du2::rateType - f₁::rateType - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - reltol::RTolType - alg::A - algebraic_vars::AV - step_limiter!::StepLimiter - stage_limiter!::StageLimiter -end - -function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - k₁ = zero(rate_prototype) - k₂ = zero(rate_prototype) - k₃ = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - # f₀ = zero(u) fsalfirst - f₁ = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rosenbrock23Tableau(constvalue(uBottomEltypeNoUnits)) - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - algebraic_vars = f.mass_matrix === I ? nothing : - [all(iszero, x) for x in eachcol(f.mass_matrix)] - - Rosenbrock23Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, - linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, - alg.stage_limiter!) -end - -function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - k₁ = zero(rate_prototype) - k₂ = zero(rate_prototype) - k₃ = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - # f₀ = zero(u) fsalfirst - f₁ = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rosenbrock32Tableau(constvalue(uBottomEltypeNoUnits)) - - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - algebraic_vars = f.mass_matrix === I ? nothing : - [all(iszero, x) for x in eachcol(f.mass_matrix)] - - Rosenbrock32Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, - tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config, - grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!) -end - -struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <: - RosenbrockConstantCache - c₃₂::T - d::T - tf::TF - uf::UF - J::JType - W::WType - linsolve::F - autodiff::AD -end - -function Rosenbrock23ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T} - tab = Rosenbrock23Tableau(T) - Rosenbrock23ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff) -end - -function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rosenbrock23ConstantCache(constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve, - alg_autodiff(alg)) -end - -struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} <: - RosenbrockConstantCache - c₃₂::T - d::T - tf::TF - uf::UF - J::JType - W::WType - linsolve::F - autodiff::AD -end - -function Rosenbrock32ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T} - tab = Rosenbrock32Tableau(T) - Rosenbrock32ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff) -end - -function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rosenbrock32ConstantCache(constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve, - alg_autodiff(alg)) -end - -################################################################################ - -### 3rd order specialized Rosenbrocks - -struct Rosenbrock33ConstantCache{TF, UF, Tab, JType, WType, F} <: - RosenbrockConstantCache - tf::TF - uf::UF - tab::Tab - J::JType - W::WType - linsolve::F -end - -@cache mutable struct Rosenbrock33Cache{uType, rateType, uNoUnitsType, JType, WType, - TabType, TFType, UFType, F, JCType, GCType, - RTolType, A, StepLimiter, StageLimiter} <: RosenbrockMutableCache - u::uType - uprev::uType - du::rateType - du1::rateType - du2::rateType - k1::rateType - k2::rateType - k3::rateType - k4::rateType - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - reltol::RTolType - alg::A - step_limiter!::StepLimiter - stage_limiter!::StageLimiter -end - -function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - du = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = ROS3PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rosenbrock33Cache(u, uprev, du, du1, du2, k1, k2, k3, k4, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, - linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, - alg.stage_limiter!) -end - -function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rosenbrock33ConstantCache(tf, uf, - ROS3PTableau(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve) -end - -@cache mutable struct Rosenbrock34Cache{uType, rateType, uNoUnitsType, JType, WType, - TabType, TFType, UFType, F, JCType, GCType, StepLimiter, StageLimiter} <: - RosenbrockMutableCache - u::uType - uprev::uType - du::rateType - du1::rateType - du2::rateType - k1::rateType - k2::rateType - k3::rateType - k4::rateType - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - step_limiter!::StepLimiter - stage_limiter!::StageLimiter -end - -function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - du = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rodas3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rosenbrock34Cache(u, uprev, du, du1, du2, k1, k2, k3, k4, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, - linsolve_tmp, - linsolve, jac_config, grad_config, alg.step_limiter!, - alg.stage_limiter!) -end - -struct Rosenbrock34ConstantCache{TF, UF, Tab, JType, WType, F} <: - RosenbrockConstantCache - tf::TF - uf::UF - tab::Tab - J::JType - W::WType - linsolve::F -end - -function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rosenbrock34ConstantCache(tf, uf, - Rodas3Tableau(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve) -end - -################################################################################ - -### ROS2 methods - @ROS2(:cache) ################################################################################ @@ -498,222 +83,11 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W) ############################################################################### ### Rodas methods - -struct Rodas23WConstantCache{TF, UF, Tab, JType, WType, F, AD} <: - RosenbrockConstantCache - tf::TF - uf::UF - tab::Tab - J::JType - W::WType - linsolve::F - autodiff::AD -end - -struct Rodas3PConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache - tf::TF - uf::UF - tab::Tab - J::JType - W::WType - linsolve::F - autodiff::AD -end - -@cache mutable struct Rodas23WCache{uType, rateType, uNoUnitsType, JType, WType, TabType, - TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <: - RosenbrockMutableCache - u::uType - uprev::uType - dense1::rateType - dense2::rateType - dense3::rateType - du::rateType - du1::rateType - du2::rateType - k1::rateType - k2::rateType - k3::rateType - k4::rateType - k5::rateType - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - reltol::RTolType - alg::A - step_limiter!::StepLimiter - stage_limiter!::StageLimiter -end - -@cache mutable struct Rodas3PCache{uType, rateType, uNoUnitsType, JType, WType, TabType, - TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <: - RosenbrockMutableCache - u::uType - uprev::uType - dense1::rateType - dense2::rateType - dense3::rateType - du::rateType - du1::rateType - du2::rateType - k1::rateType - k2::rateType - k3::rateType - k4::rateType - k5::rateType - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - reltol::RTolType - alg::A - step_limiter!::StepLimiter - stage_limiter!::StageLimiter -end - -function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - dense1 = zero(rate_prototype) - dense2 = zero(rate_prototype) - dense3 = zero(rate_prototype) - du = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - k5 = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rodas3PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rodas23WCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, - alg.stage_limiter!) -end - -function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - dense1 = zero(rate_prototype) - dense2 = zero(rate_prototype) - dense3 = zero(rate_prototype) - du = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - k5 = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rodas3PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rodas3PCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, - alg.stage_limiter!) -end - -function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rodas23WConstantCache(tf, uf, - Rodas3PTableau(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve, - alg_autodiff(alg)) -end - -function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rodas3PConstantCache(tf, uf, - Rodas3PTableau(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve, - alg_autodiff(alg)) -end - -### Rodas4 methods - +tabtype(::Rosenbrock23) = Rosenbrock23Tableau +tabtype(::Rosenbrock32) = Rosenbrock32Tableau +tabtype(::Rodas23W) = Rodas23WTableau +tabtype(::ROS3P) = ROS3PTableau +tabtype(::Rodas3) = Rodas3Tableau tabtype(::Rodas4) = Rodas4Tableau tabtype(::Rodas42) = Rodas42Tableau tabtype(::Rodas4P) = Rodas4PTableau @@ -723,45 +97,24 @@ tabtype(::Rodas5P) = Rodas5PTableau tabtype(::Rodas5Pr) = Rodas5PTableau tabtype(::Rodas5Pe) = Rodas5PTableau -function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, - u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - RosenbrockCombinedConstantCache(tf, uf, - tab, J, W, linsolve, - alg_autodiff(alg), size(tab.H, 1)) -end - -function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, +function alg_cache( + alg::Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - - tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - # Initialize vectors + tab = Rodas5PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) dense = [zero(rate_prototype) for _ in 1:size(tab.H, 1)] - ks = [zero(rate_prototype) for _ in 1:size(tab.A, 1)] du = zero(rate_prototype) du1 = zero(rate_prototype) du2 = zero(rate_prototype) + ks = [zero(rate_prototype) for _ in 1:size(tab.A, 1)] - # Initialize other variables + f₁ = zero(rate_prototype) fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) - - # Build J and W matrices J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - - # Temporary and helper variables tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) @@ -772,30 +125,39 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5 uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - - # Return the cache struct with vectors - RosenbrockCache( - u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast, - dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, - alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1)) + algebraic_vars = f.mass_matrix === I ? nothing : + [all(iszero, x) for x in eachcol(f.mass_matrix)] + RosenbrockCache(u, uprev, dense, du, du1, du2, ks, f₁, fsalfirst, fsallast, + dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, + linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, + alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1)) end +function alg_cache( + alg::Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, + u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tf = TimeDerivativeWrapper(f, u, p) + uf = UDerivativeWrapper(f, t, p) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) + linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) + tab = + RosenbrockCombinedConstantCache(tf, uf, tab, J, W, linsolve, alg_autodiff(alg), size(tab.H, 1)) +end function get_fsalfirstlast( - cache::Union{Rosenbrock23Cache, Rosenbrock32Cache, Rosenbrock33Cache, - Rosenbrock34Cache, + cache::Union{RosenbrockCache, Rosenbrock4Cache}, u) (cache.fsalfirst, cache.fsallast) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl index 003595567e..20cf702e14 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl @@ -1,134 +1,9 @@ -### Fallbacks to capture -ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, - Rosenbrock32ConstantCache, Rosenbrock32Cache, - Rodas23WConstantCache, Rodas3PConstantCache, - Rodas23WCache, Rodas3PCache, - RosenbrockCombinedConstantCache, - RosenbrockCache} - -function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::ROSENBROCKS_WITH_INTERPOLATIONS, - idxs, T::Type{Val{D}}, differential_vars) where {D} - throw(DerivativeOrderNotPossibleError()) -end - -function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::ROSENBROCKS_WITH_INTERPOLATIONS, - idxs, T::Type{Val{D}}, differential_vars) where {D} - throw(DerivativeOrderNotPossibleError()) -end - -#### - -""" -From MATLAB ODE Suite by Shampine -""" -@def rosenbrock2332unpack begin - if cache isa OrdinaryDiffEqMutableCache - d = cache.tab.d - else - d = cache.d - end -end - -@def rosenbrock2332pre0 begin - @rosenbrock2332unpack - c1 = Θ * (1 - Θ) / (1 - 2d) - c2 = Θ * (Θ - 2d) / (1 - 2d) -end - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock23ConstantCache, - Rosenbrock32ConstantCache}, idxs::Nothing, - T::Type{Val{0}}, differential_vars) - @rosenbrock2332pre0 - @inbounds y₀ + dt * (c1 * k[1] + c2 * k[2]) -end - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, - idxs::Nothing, T::Type{Val{0}}, differential_vars) - @rosenbrock2332pre0 - @inbounds @.. y₀+dt * (c1 * k[1] + c2 * k[2]) -end - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, - Rosenbrock32ConstantCache, Rosenbrock32Cache - }, idxs, T::Type{Val{0}}, differential_vars) - @rosenbrock2332pre0 - @.. y₀[idxs]+dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) -end - -@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock23ConstantCache, - Rosenbrock23Cache, - Rosenbrock32ConstantCache, Rosenbrock32Cache - }, idxs::Nothing, T::Type{Val{0}}, differential_vars) - @rosenbrock2332pre0 - @inbounds @.. out=y₀ + dt * (c1 * k[1] + c2 * k[2]) - out -end - -@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock23ConstantCache, - Rosenbrock23Cache, - Rosenbrock32ConstantCache, Rosenbrock32Cache - }, idxs, T::Type{Val{0}}, differential_vars) - @rosenbrock2332pre0 - @views @.. out=y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) - out -end - -# First Derivative of the dense output -@def rosenbrock2332pre1 begin - @rosenbrock2332unpack - c1diff = (1 - 2 * Θ) / (1 - 2 * d) - c2diff = (2 * Θ - 2 * d) / (1 - 2 * d) -end - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, - Rosenbrock32ConstantCache, Rosenbrock32Cache - }, idxs::Nothing, T::Type{Val{1}}, differential_vars) - @rosenbrock2332pre1 - @.. c1diff * k[1]+c2diff * k[2] -end - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, - Rosenbrock32ConstantCache, Rosenbrock32Cache - }, idxs, T::Type{Val{1}}, differential_vars) - @rosenbrock2332pre1 - @.. c1diff * k[1][idxs]+c2diff * k[2][idxs] -end - -@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock23ConstantCache, - Rosenbrock23Cache, - Rosenbrock32ConstantCache, Rosenbrock32Cache - }, idxs::Nothing, T::Type{Val{1}}, differential_vars) - @rosenbrock2332pre1 - @.. out=c1diff * k[1] + c2diff * k[2] - out -end - -@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock23ConstantCache, - Rosenbrock23Cache, - Rosenbrock32ConstantCache, Rosenbrock32Cache - }, idxs, T::Type{Val{1}}, differential_vars) - @rosenbrock2332pre1 - @views @.. out=c1diff * k[1][idxs] + c2diff * k[2][idxs] - out -end - """ From MATLAB ODE Suite by Shampine """ @muladd function _ode_interpolant( - Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache, RosenbrockCache, Rodas23WCache, Rodas3PCache}, + Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs::Nothing, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ if !hasproperty(cache, :interp_order) || cache.interp_order == 2 @@ -139,8 +14,7 @@ From MATLAB ODE Suite by Shampine end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, - Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ if !hasproperty(cache, :interp_order) || cache.interp_order == 2 @@ -151,8 +25,7 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, - Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs::Nothing, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ if !hasproperty(cache, :interp_order) || cache.interp_order == 2 @@ -164,8 +37,7 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, - Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ if !hasproperty(cache, :interp_order) || cache.interp_order == 2 @@ -179,7 +51,7 @@ end # First Derivative @muladd function _ode_interpolant( - Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCache, Rodas23WCache, Rodas3PCache, RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache}, + Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCache, RosenbrockCombinedConstantCache}, idxs::Nothing, T::Type{Val{1}}, differential_vars) if !hasproperty(cache, :interp_order) || cache.interp_order == 2 @.. (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁)/dt @@ -190,8 +62,7 @@ end end end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, - Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs, T::Type{Val{1}}, differential_vars) if !hasproperty(cache, :interp_order) || cache.interp_order == 2 @views @.. (k[1][idxs] + Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] - 3 * k[2][idxs] * Θ) - @@ -203,8 +74,7 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, - Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs::Nothing, T::Type{Val{1}}, differential_vars) if !hasproperty(cache, :interp_order) || cache.interp_order == 2 @.. out=(k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) / dt @@ -216,8 +86,7 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, - Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs, T::Type{Val{1}}, differential_vars) if !hasproperty(cache, :interp_order) || cache.interp_order == 2 @views @.. out=(k[1][idxs] + diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 1414a93f92..a60c57af6c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -1,811 +1,3 @@ -function initialize!(integrator, cache::Union{Rosenbrock23Cache, - Rosenbrock32Cache}) - integrator.kshortsize = 2 - @unpack k₁, k₂, fsalfirst, fsallast = cache - resize!(integrator.k, integrator.kshortsize) - integrator.k[1] = k₁ - integrator.k[2] = k₂ - integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) -end - -function initialize!(integrator, - cache::Union{Rosenbrock23ConstantCache, - Rosenbrock32ConstantCache}) - integrator.kshortsize = 2 - integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) - integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - # Avoid undefined entries if k is an array of arrays - integrator.fsallast = zero(integrator.fsalfirst) - integrator.k[1] = zero(integrator.fsalfirst) - integrator.k[2] = zero(integrator.fsalfirst) -end - -@muladd function perform_step!(integrator, cache::Rosenbrock23Cache, repeat_step = false) - @unpack t, dt, uprev, u, f, p, opts = integrator - @unpack k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, tmp, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache - @unpack c₃₂, d = cache.tab - - # Assignments - sizeu = size(u) - mass_matrix = integrator.f.mass_matrix - - # Precalculations - dtγ = dt * d - neginvdtγ = -inv(dtγ) - dto2 = dt / 2 - dto6 = dt / 6 - - if repeat_step - f(integrator.fsalfirst, uprev, p, t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - end - - calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step) - - calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t) - - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtγ)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtγ)) - end - - vecu = _vec(linres.u) - veck₁ = _vec(k₁) - - @.. veck₁ = vecu * neginvdtγ - integrator.stats.nsolve += 1 - - @.. u = uprev + dto2 * k₁ - stage_limiter!(u, integrator, p, t + dto2) - f(f₁, u, p, t + dto2) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - copyto!(tmp, k₁) - else - mul!(_vec(tmp), mass_matrix, _vec(k₁)) - end - - @.. linsolve_tmp = f₁ - tmp - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) - veck₂ = _vec(k₂) - - @.. veck₂ = vecu * neginvdtγ + veck₁ - integrator.stats.nsolve += 1 - - @.. u = uprev + dt * k₂ - stage_limiter!(u, integrator, p, t + dt) - step_limiter!(u, integrator, p, t + dt) - - if integrator.opts.adaptive - f(fsallast, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=fsallast - c₃₂ * (k₂ - f₁) - - 2(k₁ - fsalfirst) + dt * dT - else - @.. broadcast=false du2=c₃₂ * k₂ + 2k₁ - mul!(_vec(du1), mass_matrix, _vec(du2)) - @.. broadcast=false linsolve_tmp=fsallast - du1 + c₃₂ * f₁ + 2fsalfirst + - dt * dT - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) - veck3 = _vec(k₃) - @.. veck3 = vecu * neginvdtγ - - integrator.stats.nsolve += 1 - - if mass_matrix === I - @.. broadcast=false tmp=dto6 * (k₁ - 2 * k₂ + k₃) - else - veck₁ = _vec(k₁) - veck₂ = _vec(k₂) - veck₃ = _vec(k₃) - vectmp = _vec(tmp) - @.. broadcast=false vectmp=ifelse(cache.algebraic_vars, - false, dto6 * (veck₁ - 2 * veck₂ + veck₃)) - end - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - - if mass_matrix !== I - algvar = reshape(cache.algebraic_vars, size(u)) - invatol = inv(integrator.opts.abstol) - @.. atmp = ifelse(algvar, fsallast, false) * invatol - integrator.EEst += integrator.opts.internalnorm(atmp, t) - end - end - cache.linsolve = linres.cache -end - -@muladd function perform_step!(integrator, cache::Rosenbrock32Cache, repeat_step = false) - @unpack t, dt, uprev, u, f, p, opts = integrator - @unpack k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, tmp, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache - @unpack c₃₂, d = cache.tab - - # Assignments - sizeu = size(u) - mass_matrix = integrator.f.mass_matrix - - # Precalculations - dtγ = dt * d - neginvdtγ = -inv(dtγ) - dto2 = dt / 2 - dto6 = dt / 6 - - if repeat_step - f(integrator.fsalfirst, uprev, p, t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - end - - calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step) - - calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t) - - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtγ)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtγ)) - end - - vecu = _vec(linres.u) - veck₁ = _vec(k₁) - - @.. veck₁ = vecu * neginvdtγ - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + dto2 * k₁ - stage_limiter!(u, integrator, p, t + dto2) - f(f₁, u, p, t + dto2) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - tmp .= k₁ - else - mul!(_vec(tmp), mass_matrix, _vec(k₁)) - end - - @.. broadcast=false linsolve_tmp=f₁ - tmp - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) - veck₂ = _vec(k₂) - - @.. veck₂ = vecu * neginvdtγ + veck₁ - integrator.stats.nsolve += 1 - - @.. tmp = uprev + dt * k₂ - stage_limiter!(u, integrator, p, t + dt) - f(fsallast, tmp, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=fsallast - c₃₂ * (k₂ - f₁) - 2(k₁ - fsalfirst) + - dt * dT - else - @.. broadcast=false du2=c₃₂ * k₂ + 2k₁ - mul!(_vec(du1), mass_matrix, _vec(du2)) - @.. broadcast=false linsolve_tmp=fsallast - du1 + c₃₂ * f₁ + 2fsalfirst + dt * dT - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) - veck3 = _vec(k₃) - - @.. veck3 = vecu * neginvdtγ - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + dto6 * (k₁ + 4k₂ + k₃) - - step_limiter!(u, integrator, p, t + dt) - - if integrator.opts.adaptive - @.. broadcast=false tmp=dto6 * (k₁ - 2 * k₂ + k₃) - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - - if mass_matrix !== I - invatol = inv(integrator.opts.abstol) - @.. atmp = ifelse(cache.algebraic_vars, fsallast, false) * invatol - integrator.EEst += integrator.opts.internalnorm(atmp, t) - end - end - cache.linsolve = linres.cache -end - -@muladd function perform_step!(integrator, cache::Rosenbrock23ConstantCache, - repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack c₃₂, d, tf, uf = cache - - # Precalculations - dtγ = dt * d - neginvdtγ = -inv(dtγ) - dto2 = dt / 2 - dto6 = dt / 6 - - if repeat_step - integrator.fsalfirst = f(uprev, p, t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - end - - mass_matrix = integrator.f.mass_matrix - - # Time derivative - dT = calc_tderivative(integrator, cache) - - W = calc_W(integrator, cache, dtγ, repeat_step) - if !issuccess_W(W) - integrator.EEst = 2 - return nothing - end - - k₁ = _reshape(W \ _vec((integrator.fsalfirst + dtγ * dT)), axes(uprev)) * neginvdtγ - integrator.stats.nsolve += 1 - tmp = @.. uprev + dto2 * k₁ - f₁ = f(tmp, p, t + dto2) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - k₂ = _reshape(W \ _vec(f₁ - k₁), axes(uprev)) - else - k₂ = _reshape(W \ _vec(f₁ - mass_matrix * k₁), axes(uprev)) - end - k₂ = @.. k₂ * neginvdtγ + k₁ - integrator.stats.nsolve += 1 - u = uprev + dt * k₂ - - if integrator.opts.adaptive - integrator.fsallast = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = @.. (integrator.fsallast - c₃₂ * (k₂ - f₁) - - 2 * (k₁ - integrator.fsalfirst) + dt * dT) - else - linsolve_tmp = mass_matrix * (@.. c₃₂ * k₂ + 2 * k₁) - linsolve_tmp = @.. (integrator.fsallast - linsolve_tmp + - c₃₂ * f₁ + 2 * integrator.fsalfirst + dt * dT) - end - k₃ = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) * neginvdtγ - integrator.stats.nsolve += 1 - - if u isa Number - utilde = dto6 * f.mass_matrix[1, 1] * (k₁ - 2 * k₂ + k₃) - else - utilde = f.mass_matrix * (@.. dto6 * (k₁ - 2 * k₂ + k₃)) - end - atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - - if mass_matrix !== I - invatol = inv(integrator.opts.abstol) - atmp = @. ifelse(integrator.differential_vars, false, integrator.fsallast) * - invatol - integrator.EEst += integrator.opts.internalnorm(atmp, t) - end - end - integrator.k[1] = k₁ - integrator.k[2] = k₂ - integrator.u = u - return nothing -end - -@muladd function perform_step!(integrator, cache::Rosenbrock32ConstantCache, - repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack c₃₂, d, tf, uf = cache - - # Precalculations - dtγ = dt * d - neginvdtγ = -inv(dtγ) - dto2 = dt / 2 - dto6 = dt / 6 - - mass_matrix = integrator.f.mass_matrix - - if repeat_step - integrator.fsalfirst = f(uprev, p, t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - end - - # Time derivative - dT = calc_tderivative(integrator, cache) - - W = calc_W(integrator, cache, dtγ, repeat_step) - if !issuccess_W(W) - integrator.EEst = 2 - return nothing - end - - k₁ = _reshape(W \ -_vec((integrator.fsalfirst + dtγ * dT)), axes(uprev)) / dtγ - integrator.stats.nsolve += 1 - tmp = @.. uprev + dto2 * k₁ - f₁ = f(tmp, p, t + dto2) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - k₂ = _reshape(W \ _vec(f₁ - k₁), axes(uprev)) - else - linsolve_tmp = f₁ - mass_matrix * k₁ - k₂ = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) - end - k₂ = @.. k₂ * neginvdtγ + k₁ - - integrator.stats.nsolve += 1 - tmp = @.. uprev + dt * k₂ - integrator.fsallast = f(tmp, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = @.. (integrator.fsallast - c₃₂ * (k₂ - f₁) - - 2(k₁ - integrator.fsalfirst) + dt * dT) - else - linsolve_tmp = mass_matrix * (@.. c₃₂ * k₂ + 2 * k₁) - linsolve_tmp = @.. (integrator.fsallast - linsolve_tmp + - c₃₂ * f₁ + 2 * integrator.fsalfirst + dt * dT) - end - k₃ = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) * neginvdtγ - integrator.stats.nsolve += 1 - u = @.. uprev + dto6 * (k₁ + 4k₂ + k₃) - - if integrator.opts.adaptive - utilde = @.. dto6 * (k₁ - 2k₂ + k₃) - atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - - if mass_matrix !== I - invatol = inv(integrator.opts.abstol) - atmp = ifelse(integrator.differential_vars, false, integrator.fsallast) .* - invatol - integrator.EEst += integrator.opts.internalnorm(atmp, t) - end - end - - integrator.k[1] = k₁ - integrator.k[2] = k₂ - integrator.u = u - return nothing -end - -function initialize!(integrator, - cache::Union{Rosenbrock33ConstantCache, - Rosenbrock34ConstantCache, - Rosenbrock4ConstantCache}) - integrator.kshortsize = 2 - integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) - integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - # Avoid undefined entries if k is an array of arrays - integrator.fsallast = zero(integrator.fsalfirst) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast -end - -function initialize!(integrator, - cache::Union{Rosenbrock33Cache, - Rosenbrock34Cache, - Rosenbrock4Cache}) - integrator.kshortsize = 2 - @unpack fsalfirst, fsallast = cache - resize!(integrator.k, integrator.kshortsize) - integrator.k[1] = fsalfirst - integrator.k[2] = fsallast - integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) -end - -@muladd function perform_step!(integrator, cache::Rosenbrock33ConstantCache, - repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack tf, uf = cache - @unpack a21, a31, a32, C21, C31, C32, b1, b2, b3, btilde1, btilde2, btilde3, gamma, c2, c3, d1, d2, d3 = cache.tab - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtgamma = dt * gamma - - mass_matrix = integrator.f.mass_matrix - - # Time derivative - dT = calc_tderivative(integrator, cache) - - W = calc_W(integrator, cache, dtgamma, repeat_step) - if !issuccess_W(W) - integrator.EEst = 2 - return nothing - end - - linsolve_tmp = integrator.fsalfirst + dtd1 * dT - - k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a21 * k1 - du = f(u, p, t + c2 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd2 * dT + dtC21 * k1 - else - linsolve_tmp = du + dtd2 * dT + mass_matrix * (dtC21 * k1) - end - - k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a31 * k1 + a32 * k2 - du = f(u, p, t + c3 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd3 * dT + dtC31 * k1 + dtC32 * k2 - else - linsolve_tmp = du + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2) - end - - k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + b1 * k1 + b2 * k2 + b3 * k3 - integrator.fsallast = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if integrator.opts.adaptive - utilde = btilde1 * k1 + btilde2 * k2 + btilde3 * k3 - atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return nothing -end - -@muladd function perform_step!(integrator, cache::Rosenbrock33Cache, repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache - @unpack a21, a31, a32, C21, C31, C32, b1, b2, b3, btilde1, btilde2, btilde3, gamma, c2, c3, d1, d2, d3 = cache.tab - - # Assignments - mass_matrix = integrator.f.mass_matrix - sizeu = size(u) - utilde = du - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtgamma = dt * gamma - - calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step) - - calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t) - - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end - - vecu = _vec(linres.u) - veck1 = _vec(k1) - - @.. broadcast=false veck1=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a21 * k1 - stage_limiter!(u, integrator, p, t + c2 * dt) - f(du, u, p, t + c2 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 - else - @.. broadcast=false du1=dtC21 * k1 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) - veck2 = _vec(k2) - - @.. broadcast=false veck2=-vecu - - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a31 * k1 + a32 * k2 - stage_limiter!(u, integrator, p, t + c3 * dt) - f(du, u, p, t + c3 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + dtC31 * k1 + dtC32 * k2 - else - @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) - veck3 = _vec(k3) - - @.. broadcast=false veck3=-vecu - - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + b1 * k1 + b2 * k2 + b3 * k3 - - step_limiter!(u, integrator, p, t + dt) - - f(fsallast, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if integrator.opts.adaptive - @.. broadcast=false utilde=btilde1 * k1 + btilde2 * k2 + btilde3 * k3 - calculate_residuals!(atmp, utilde, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - cache.linsolve = linres.cache -end - -################################################################################ - -@muladd function perform_step!(integrator, cache::Rosenbrock34ConstantCache, - repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack tf, uf = cache - @unpack a21, a31, a32, a41, a42, a43, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 - dtgamma = dt * gamma - - mass_matrix = integrator.f.mass_matrix - # Time derivative - tf.u = uprev - dT = calc_tderivative(integrator, cache) - - W = calc_W(integrator, cache, dtgamma, repeat_step) - if !issuccess_W(W) - integrator.EEst = 2 - return nothing - end - - linsolve_tmp = integrator.fsalfirst + dtd1 * dT - - k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev # +a21*k1 a21 == 0 - # du = f(u, p, t+c2*dt) c2 == 0 and a21 == 0 => du = f(uprev, p, t) == fsalfirst - - if mass_matrix === I - linsolve_tmp = integrator.fsalfirst + dtd2 * dT + dtC21 * k1 - else - linsolve_tmp = integrator.fsalfirst + dtd2 * dT + mass_matrix * (dtC21 * k1) - end - - k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a31 * k1 + a32 * k2 - du = f(u, p, t + c3 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd3 * dT + dtC31 * k1 + dtC32 * k2 - else - linsolve_tmp = du + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2) - end - - k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a41 * k1 + a42 * k2 + a43 * k3 - du = f(u, p, t + dt) #-- c4 = 1 - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd4 * dT + dtC41 * k1 + dtC42 * k2 + dtC43 * k3 - else - linsolve_tmp = du + dtd4 * dT + mass_matrix * (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - end - - k4 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 - integrator.fsallast = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if integrator.opts.adaptive - utilde = btilde1 * k1 + btilde2 * k2 + btilde3 * k3 + btilde4 * k4 - atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return nothing -end - -@muladd function perform_step!(integrator, cache::Rosenbrock34Cache, repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, k4, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache - @unpack a21, a31, a32, a41, a42, a43, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab - - # Assignments - uidx = eachindex(integrator.uprev) - sizeu = size(u) - mass_matrix = integrator.f.mass_matrix - utilde = du - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 - dtgamma = dt * gamma - - calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step) - - calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t) - - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end - - vecu = _vec(linres.u) - veck1 = _vec(k1) - - @.. broadcast=false veck1=-vecu - integrator.stats.nsolve += 1 - - #= - a21 == 0 and c2 == 0 - so du = integrator.fsalfirst! - @.. broadcast=false u = uprev + a21*k1 - - f(du, u, p, t+c2*dt) - =# - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=fsalfirst + dtd2 * dT + dtC21 * k1 - else - @.. broadcast=false du1=dtC21 * k1 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=fsalfirst + dtd2 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck2 = _vec(k2) - @.. broadcast=false veck2=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a31 * k1 + a32 * k2 - stage_limiter!(u, integrator, p, t + c3 * dt) - f(du, u, p, t + c3 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + dtC31 * k1 + dtC32 * k2 - else - @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck3 = _vec(k3) - @.. broadcast=false veck3=-vecu - integrator.stats.nsolve += 1 - @.. broadcast=false u=uprev + a41 * k1 + a42 * k2 + a43 * k3 - stage_limiter!(u, integrator, p, t + dt) - f(du, u, p, t + dt) #-- c4 = 1 - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + dtC41 * k1 + dtC42 * k2 + - dtC43 * k3 - else - @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck4 = _vec(k4) - @.. broadcast=false veck4=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 - - step_limiter!(u, integrator, p, t + dt) - - f(fsallast, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if integrator.opts.adaptive - @.. broadcast=false utilde=btilde1 * k1 + btilde2 * k2 + btilde3 * k3 + btilde4 * k4 - calculate_residuals!(atmp, utilde, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - cache.linsolve = linres.cache -end - ################################################################################ #### ROS2 type method @@ -833,375 +25,7 @@ end @Rosenbrock4(:performstep) -################################################################################ - -#### Rodas3P type method - -function initialize!(integrator, cache::Union{Rodas23WConstantCache, Rodas3PConstantCache}) - integrator.kshortsize = 3 - integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) - # Avoid undefined entries if k is an array of arrays - integrator.k[1] = zero(integrator.u) - integrator.k[2] = zero(integrator.u) - integrator.k[3] = zero(integrator.u) -end - -@muladd function perform_step!( - integrator, cache::Union{Rodas23WConstantCache, Rodas3PConstantCache}, - repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack tf, uf = cache - @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtgamma = dt * gamma - - mass_matrix = integrator.f.mass_matrix - - # Time derivative - tf.u = uprev - dT = calc_tderivative(integrator, cache) - - W = calc_W(integrator, cache, dtgamma, repeat_step) - if !issuccess_W(W) - integrator.EEst = 2 - return nothing - end - - du = f(uprev, p, t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - k3 = copy(du) #-- save for stage 3 - - linsolve_tmp = du + dtd1 * dT - - k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a21 * k1 - du = f(u, p, t + c2 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd2 * dT + dtC21 * k1 - else - linsolve_tmp = du + dtd2 * dT + mass_matrix * (dtC21 * k1) - end - - k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - - if mass_matrix === I - linsolve_tmp = k3 + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - else - linsolve_tmp = k3 + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2) - end - - k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a41 * k1 + a42 * k2 + a43 * k3 - du = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - else - linsolve_tmp = du + mass_matrix * (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - end - - k4 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - - if mass_matrix === I - linsolve_tmp = du + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - else - linsolve_tmp = du + - mass_matrix * (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - end - - k5 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - du = u + k4 #-- solution p=2 - u = u + k5 #-- solution p=3 - - EEst = 0.0 - if integrator.opts.calck - @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab - integrator.k[1] = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 - integrator.k[2] = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 - integrator.k[3] = h2_21 * k1 + h2_22 * k2 + h2_23 * k3 + h2_24 * k4 + h2_25 * k5 - if integrator.opts.adaptive - if isa(linsolve_tmp, AbstractFloat) - u_int, u_diff = calculate_interpoldiff( - uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3]) - else - u_int = linsolve_tmp - u_diff = linsolve_tmp .+ 0 - calculate_interpoldiff!(u_int, u_diff, uprev, du, u, integrator.k[1], - integrator.k[2], integrator.k[3]) - end - atmp = calculate_residuals(u_diff, uprev, u_int, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - EEst = max(EEst, integrator.opts.internalnorm(atmp, t)) #-- role of t unclear - end - end - - if (integrator.alg isa Rodas23W) - k1 = u .+ 0 - u = du .+ 0 - du = k1 .+ 0 - if integrator.opts.calck - integrator.k[1] = integrator.k[3] .+ 0 - integrator.k[2] = 0 * integrator.k[2] - end - end - - if integrator.opts.adaptive - atmp = calculate_residuals(u - du, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = max(EEst, integrator.opts.internalnorm(atmp, t)) - end - - integrator.u = u - return nothing -end - -function initialize!(integrator, cache::Union{Rodas23WCache, Rodas3PCache}) - integrator.kshortsize = 3 - @unpack dense1, dense2, dense3 = cache - resize!(integrator.k, integrator.kshortsize) - integrator.k[1] = dense1 - integrator.k[2] = dense2 - integrator.k[3] = dense3 -end - -@muladd function perform_step!( - integrator, cache::Union{Rodas23WCache, Rodas3PCache}, repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack du, du1, du2, dT, J, W, uf, tf, k1, k2, k3, k4, k5, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache - @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab - - # Assignments - sizeu = size(u) - uidx = eachindex(integrator.uprev) - mass_matrix = integrator.f.mass_matrix - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtgamma = dt * gamma - - f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation! - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step) - - calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t) - - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end - - @.. broadcast=false $(_vec(k1))=-linres.u - - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a21 * k1 - stage_limiter!(u, integrator, p, t + c2 * dt) - f(du, u, p, t + c2 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 - else - @.. broadcast=false du1=dtC21 * k1 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - @.. broadcast=false $(_vec(k2))=-linres.u - integrator.stats.nsolve += 1 - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=cache.fsalfirst + dtd3 * dT + - (dtC31 * k1 + dtC32 * k2) - else - @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=cache.fsalfirst + dtd3 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - @.. broadcast=false $(_vec(k3))=-linres.u - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a41 * k1 + a42 * k2 + a43 * k3 - stage_limiter!(u, integrator, p, t + c2 * dt) - f(du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + - (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - else - @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - @.. broadcast=false $(_vec(k4))=-linres.u - integrator.stats.nsolve += 1 - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + - (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - else - @.. broadcast=false du1=dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - @.. broadcast=false $(_vec(k5))=-linres.u - integrator.stats.nsolve += 1 - - du = u + k4 #-- p=2 solution - u .+= k5 - - step_limiter!(u, integrator, p, t + dt) - - EEst = 0.0 - if integrator.opts.calck - @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab - @.. broadcast=false integrator.k[1]=h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + - h25 * k5 - @.. broadcast=false integrator.k[2]=h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + - h35 * k5 - @.. broadcast=false integrator.k[3]=h2_21 * k1 + h2_22 * k2 + h2_23 * k3 + - h2_24 * k4 + h2_25 * k5 - if integrator.opts.adaptive - calculate_interpoldiff!( - du1, du2, uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3]) - calculate_residuals!(atmp, du2, uprev, du1, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - EEst = max(EEst, integrator.opts.internalnorm(atmp, t)) #-- role of t unclear - end - end - - if (integrator.alg isa Rodas23W) - du1[:] = u[:] - u[:] = du[:] - du[:] = du1[:] - if integrator.opts.calck - integrator.k[1][:] = integrator.k[3][:] - integrator.k[2][:] .= 0.0 - end - end - - if integrator.opts.adaptive - calculate_residuals!(atmp, u - du, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = max(EEst, integrator.opts.internalnorm(atmp, t)) - end - cache.linsolve = linres.cache -end - -function calculate_interpoldiff(uprev, up2, up3, c_koeff, d_koeff, c2_koeff) - u_int = 0.0 - u_diff = 0.0 - a1 = up3 + c_koeff - up2 - c2_koeff - a2 = d_koeff - c_koeff + c2_koeff - a3 = -d_koeff - dis = a2^2 - 3 * a1 * a3 - u_int = up3 - u_diff = 0.0 - if dis > 0.0 #-- Min/Max occurs - tau1 = (-a2 - sqrt(dis)) / (3 * a3) - tau2 = (-a2 + sqrt(dis)) / (3 * a3) - if tau1 > tau2 - tau1, tau2 = tau2, tau1 - end - for tau in (tau1, tau2) - if (tau > 0.0) && (tau < 1.0) - y_tau = (1 - tau) * uprev + - tau * (up3 + (1 - tau) * (c_koeff + tau * d_koeff)) - dy_tau = ((a3 * tau + a2) * tau + a1) * tau - if abs(dy_tau) > abs(u_diff) - u_diff = dy_tau - u_int = y_tau - end - end - end - end - return u_int, u_diff -end - -function calculate_interpoldiff!(u_int, u_diff, uprev, up2, up3, c_koeff, d_koeff, c2_koeff) - for i in eachindex(up2) - a1 = up3[i] + c_koeff[i] - up2[i] - c2_koeff[i] - a2 = d_koeff[i] - c_koeff[i] + c2_koeff[i] - a3 = -d_koeff[i] - dis = a2^2 - 3 * a1 * a3 - u_int[i] = up3[i] - u_diff[i] = 0.0 - if dis > 0.0 #-- Min/Max occurs - tau1 = (-a2 - sqrt(dis)) / (3 * a3) - tau2 = (-a2 + sqrt(dis)) / (3 * a3) - if tau1 > tau2 - tau1, tau2 = tau2, tau1 - end - for tau in (tau1, tau2) - if (tau > 0.0) && (tau < 1.0) - y_tau = (1 - tau) * uprev[i] + - tau * (up3[i] + (1 - tau) * (c_koeff[i] + tau * d_koeff[i])) - dy_tau = ((a3 * tau + a2) * tau + a1) * tau - if abs(dy_tau) > abs(u_diff[i]) - u_diff[i] = dy_tau - u_int[i] = y_tau - end - end - end - end - end -end - -#### Rodas4 type method - +#### Arbirary order method function initialize!(integrator, cache::RosenbrockCombinedConstantCache) integrator.kshortsize = size(cache.tab.H, 1) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl index d7968cc572..e98bcf5d7f 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl @@ -21,197 +21,105 @@ function Rosenbrock32Tableau(T) end struct ROS3PTableau{T, T2} - a21::T - a31::T - a32::T - C21::T - C31::T - C32::T - b1::T - b2::T - b3::T - btilde1::T - btilde2::T - btilde3::T + A::Matrix{T} + C::Matrix{T} + b::Vector{T} + btilde::Vector{T} gamma::T2 - c2::T2 - c3::T2 - d1::T - d2::T - d3::T + c::Vector{T2} + d::Vector{T} end function ROS3PTableau(T, T2) gamma = convert(T, 1 / 2 + sqrt(3) / 6) igamma = inv(gamma) - a21 = convert(T, igamma) - a31 = convert(T, igamma) - a32 = convert(T, 0) - C21 = convert(T, -igamma^2) + A = T[ + 0 0 0 + convert(T, igamma) 0 0 + convert(T, igamma) convert(T, 0) 0 + ] tmp = -igamma * (convert(T, 2) - convert(T, 1 / 2) * igamma) - C31 = -igamma * (convert(T, 1) - tmp) - C32 = tmp + C = T[ + 0 0 0 + convert(T, -igamma^2) 0 0 + -igamma * (convert(T, 1) - tmp) tmp 0 + ] tmp = igamma * (convert(T, 2 / 3) - convert(T, 1 / 6) * igamma) - b1 = igamma * (convert(T, 1) + tmp) - b2 = tmp - b3 = convert(T, 1 / 3) * igamma + b = [(igamma * (convert(T, 1) + tmp)), (tmp), (convert(T, 1 / 3) * igamma)] # btilde1 = convert(T,2.113248654051871) # btilde2 = convert(T,1.000000000000000) # btilde3 = convert(T,0.4226497308103742) - btilde1 = b1 - convert(T, 2.113248654051871) - btilde2 = b2 - convert(T, 1.000000000000000) - btilde3 = b3 - convert(T, 0.4226497308103742) - c2 = convert(T, 1) - c3 = convert(T, 1) - d1 = convert(T, 0.7886751345948129) - d2 = convert(T, -0.2113248654051871) - d3 = convert(T, -1.077350269189626) - ROS3PTableau( - a21, a31, a32, C21, C31, C32, b1, b2, b3, btilde1, btilde2, btilde3, gamma, - c2, c3, d1, d2, d3) + btilde = [(convert(T, 2.113248654051871)), (convert(T, 1.000000000000000)), (convert(T, 0.4226497308103742))] + c = T[1, 1] + d = T[0.7886751345948129, -0.2113248654051871, -1.077350269189626] + ROS3PTableau(A, C, b, btilde, gamma, c, d) end struct Rodas3Tableau{T, T2} - a21::T - a31::T - a32::T - a41::T - a42::T - a43::T - C21::T - C31::T - C32::T - C41::T - C42::T - C43::T - b1::T - b2::T - b3::T - b4::T - btilde1::T - btilde2::T - btilde3::T - btilde4::T + A::Matrix{T} + C::Matrix{T} + b::Vector{T} + btilde::Vector{T} gamma::T2 - c2::T2 - c3::T2 - d1::T - d2::T - d3::T - d4::T + c::Vector{T2} + d::Vector{T} end function Rodas3Tableau(T, T2) gamma = convert(T, 1 // 2) - a21 = convert(T, 0) - a31 = convert(T, 2) - a32 = convert(T, 0) - a41 = convert(T, 2) - a42 = convert(T, 0) - a43 = convert(T, 1) - C21 = convert(T, 4) - C31 = convert(T, 1) - C32 = convert(T, -1) - C41 = convert(T, 1) - C42 = convert(T, -1) - C43 = convert(T, -8 // 3) - b1 = convert(T, 2) - b2 = convert(T, 0) - b3 = convert(T, 1) - b4 = convert(T, 1) - btilde1 = convert(T, 0.0) - btilde2 = convert(T, 0.0) - btilde3 = convert(T, 0.0) - btilde4 = convert(T, 1.0) - c2 = convert(T, 0.0) - c3 = convert(T, 1.0) - c4 = convert(T, 1.0) - d1 = convert(T, 1 // 2) - d2 = convert(T, 3 // 2) - d3 = convert(T, 0) - d4 = convert(T, 0) - Rodas3Tableau(a21, a31, a32, a41, a42, a43, C21, C31, C32, C41, C42, C43, b1, b2, b3, - b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4) + A = T[ + 0 0 0 + 0 0 0 + 2 0 0 + 2 0 1 + ] + C = T[ + 0 0 0 + 4 0 0 + 1 -1 0 + 1 -1 -8 // 3 + ] + b = T[2, 0, 1, 1] + btilde = T[0.0, 0.0, 0.0, 1.0] + c = T[0.0, 1.0, 1.0] + d = T[1 // 2, 3 // 2, 0, 0] + Rodas3Tableau(A, C, b, btilde, gamma, c, d) end struct Rodas3PTableau{T, T2} - a21::T - a41::T - a42::T - a43::T - C21::T - C31::T - C32::T - C41::T - C42::T - C43::T - C51::T - C52::T - C53::T - C54::T + A::Matrix{T} + C::Matrix{T} gamma::T - c2::T2 - c3::T2 - d1::T - d2::T - d3::T - h21::T - h22::T - h23::T - h24::T - h25::T - h31::T - h32::T - h33::T - h34::T - h35::T - h2_21::T - h2_22::T - h2_23::T - h2_24::T - h2_25::T + c::Vector{T2} + d::Vector{T} + h::Matrix{T} + h2_2::Vector{T} end function Rodas3PTableau(T, T2) gamma = convert(T, 1 // 3) - a21 = convert(T, 4.0 / 3.0) - a41 = convert(T, 2.90625) - a42 = convert(T, 3.375) - a43 = convert(T, 0.40625) - C21 = -convert(T, 4.0) - C31 = convert(T, 8.25) - C32 = convert(T, 6.75) - C41 = convert(T, 1.21875) - C42 = -convert(T, 5.0625) - C43 = -convert(T, 1.96875) - C51 = convert(T, 4.03125) - C52 = -convert(T, 15.1875) - C53 = -convert(T, 4.03125) - C54 = convert(T, 6.0) - c2 = convert(T2, 4.0 / 9.0) - c3 = convert(T2, 0.0) - d1 = convert(T, 1.0 / 3.0) - d2 = -convert(T, 1.0 / 9.0) - d3 = convert(T, 1.0) - h21 = convert(T, 1.78125) - h22 = convert(T, 6.75) - h23 = convert(T, 0.15625) - h24 = -convert(T, 6.0) - h25 = -convert(T, 1.0) - h31 = convert(T, 4.21875) - h32 = -convert(T, 15.1875) - h33 = -convert(T, 3.09375) - h34 = convert(T, 9.0) - h35 = convert(T, 0.0) - h2_21 = convert(T, 4.21875) - h2_22 = -convert(T, 2.025) - h2_23 = -convert(T, 1.63125) - h2_24 = -convert(T, 1.7) - h2_25 = -convert(T, 0.1) - Rodas3PTableau(a21, a41, a42, a43, - C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, - gamma, c2, c3, d1, d2, d3, - h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25) + A = T[ + 0 0 0 0 + 4.0 / 3.0 0 0 0 + 0 0 0 0 + 2.90625 3.375 0.40625 0 + ] + C = T[ + 0 0 0 0 + 4.0 0 0 0 + 8.25 6.75 0 0 + 1.21875 5.0625 1.96875 0 + 4.03125 15.1875 4.03125 6.0 + ] + c = T2[4.0 / 9.0, 0.0] + d = T[1.0 / 3.0, 1.0 / 9.0, 1.0] + H = T[ + 0 0 0 0 0 + 1.78125 6.75 0.15625 6.0 1.0 + 4.21875 15.1875 3.09375 9.0 0.0 + ] + h2_2 = T[4.21875, 2.025, 1.63125, 1.7, 0.1] + Rodas3PTableau(A, C, gamma, c, d, H, h2_2) end @ROS2(:tableau) @@ -338,13 +246,13 @@ function Rodas5Tableau(T, T2) -14.09640773051259 6.925207756232704 -41.47510893210728 2.343771018586405 24.13215229196062 1 1 0 ] C = T[ - 0 0 0 0 0 0 0 - -10.31323885133993 0 0 0 0 0 0 - -21.04823117650003 -7.234992135176716 0 0 0 0 0 - 32.22751541853323 -4.943732386540191 19.44922031041879 0 0 0 0 - -20.69865579590063 -8.816374604402768 1.260436877740897 -0.7495647613787146 0 0 0 - -46.22004352711257 -17.49534862857472 -289.6389582892057 93.60855400400906 318.3822534212147 0 0 - 34.20013733472935 -14.15535402717690 57.82335640988400 25.83362985412365 1.408950972071624 -6.551835421242162 0 + 0 0 0 0 0 0 0 + -10.31323885133993 0 0 0 0 0 0 + -21.04823117650003 -7.234992135176716 0 0 0 0 0 + 32.22751541853323 -4.943732386540191 19.44922031041879 0 0 0 0 + -20.69865579590063 -8.816374604402768 1.260436877740897 -0.7495647613787146 0 0 0 + -46.22004352711257 -17.49534862857472 -289.6389582892057 93.60855400400906 318.3822534212147 0 0 + 34.20013733472935 -14.15535402717690 57.82335640988400 25.83362985412365 1.408950972071624 -6.551835421242162 0 42.57076742291101 -13.80770672017997 93.98938432427124 18.77919633714503 -31.58359187223370 -6.685968952921985 -5.810979938412932 ] c = T2[0, 0.38, 0.3878509998321533, 0.4839718937873840, 0.4570477008819580, 1, 1, 1] @@ -380,7 +288,7 @@ function Rodas5Tableau(T, T2) b7 = convert(T,1) b8 = convert(T,1) =# - RodasTableau(A, C, gamma, c, d, H) + RodasTableau(A, C, gamma, d, c, H) end function Rodas5PTableau(T, T2) @@ -396,13 +304,13 @@ function Rodas5PTableau(T, T2) -7.502846399306121 2.561846144803919 -11.627539656261098 -0.18268767659942256 0.030198172008377946 1 1 0 ] C = T[ - 0 0 0 0 0 0 0 - -14.155112264123755 0 0 0 0 0 0 - -17.97296035885952 -2.859693295451294 0 0 0 0 0 - 147.12150275711716 -1.41221402718213 71.68940251302358 0 0 0 0 - 165.43517024871676 -0.4592823456491126 42.90938336958603 -5.961986721573306 0 0 0 - 24.854864614690072 -3.0009227002832186 47.4931110020768 5.5814197821558125 -0.6610691825249471 0 0 - 30.91273214028599 -3.1208243349937974 77.79954646070892 34.28646028294783 -19.097331116725623 -28.087943162872662 0 + 0 0 0 0 0 0 0 + -14.155112264123755 0 0 0 0 0 0 + -17.97296035885952 -2.859693295451294 0 0 0 0 0 + 147.12150275711716 -1.41221402718213 71.68940251302358 0 0 0 0 + 165.43517024871676 -0.4592823456491126 42.90938336958603 -5.961986721573306 0 0 0 + 24.854864614690072 -3.0009227002832186 47.4931110020768 5.5814197821558125 -0.6610691825249471 0 0 + 30.91273214028599 -3.1208243349937974 77.79954646070892 34.28646028294783 -19.097331116725623 -28.087943162872662 0 37.80277123390563 -3.2571969029072276 112.26918849496327 66.9347231244047 -40.06618937091002 -54.66780262877968 -9.48861652309627 ] c = T2[0, 0.6358126895828704, 0.4095798393397535, 0.9769306725060716, 0.4288403609558664, 1, 1, 1] @@ -412,7 +320,7 @@ function Rodas5PTableau(T, T2) -9.91568850695171 -0.9689944594115154 3.0438037242978453 -24.495224566215796 20.176138334709044 15.98066361424651 -6.789040303419874 -6.710236069923372 11.419903575922262 2.8879645146136994 72.92137995996029 80.12511834622643 -52.072871366152654 -59.78993625266729 -0.15582684282751913 4.883087185713722 ] - RodasTableau(A, C, gamma, c, d, H) + RodasTableau(A, C, gamma, d, c, H) end @RosenbrockW6S4OS(:tableau) @@ -462,4 +370,4 @@ beta3 = 6.8619167645278386e-2 beta4 = 0.8289547562599182 beta5 = 7.9630136489868164e-2 alpha64 = -0.2076823627400282 -=# +=# \ No newline at end of file diff --git a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl index d3967b40ee..59674e9683 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl @@ -1,48 +1,3 @@ -function _ode_addsteps!(k, t, uprev, u, dt, f, p, - cache::Union{Rosenbrock23ConstantCache, - Rosenbrock32ConstantCache}, - always_calc_begin = false, allow_calc_end = true, - force_calc_end = false) - if length(k) < 2 || always_calc_begin - @unpack tf, uf, d = cache - dtγ = dt * d - neginvdtγ = -inv(dtγ) - dto2 = dt / 2 - tf.u = uprev - if cache.autodiff isa AutoForwardDiff - dT = ForwardDiff.derivative(tf, t) - else - dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt)) - end - - mass_matrix = f.mass_matrix - if uprev isa Number - J = ForwardDiff.derivative(uf, uprev) - W = neginvdtγ .+ J - else - J = ForwardDiff.jacobian(uf, uprev) - if mass_matrix isa UniformScaling - W = neginvdtγ * mass_matrix + J - else - W = @.. neginvdtγ * mass_matrix .+ J - end - end - f₀ = f(uprev, p, t) - k₁ = _reshape(W \ _vec((f₀ + dtγ * dT)), axes(uprev)) * neginvdtγ - tmp = @.. uprev + dto2 * k₁ - f₁ = f(tmp, p, t + dto2) - if mass_matrix === I - k₂ = _reshape(W \ _vec(f₁ - k₁), axes(uprev)) - else - k₂ = _reshape(W \ _vec(f₁ - mass_matrix * k₁), axes(uprev)) - end - k₂ = @.. k₂ * neginvdtγ + k₁ - copyat_or_push!(k, 1, k₁) - copyat_or_push!(k, 2, k₂) - end - nothing -end - function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::RosenbrockCombinedConstantCache, always_calc_begin = false, allow_calc_end = true, force_calc_end = false) From ad6fdc228373f715231c6842e2321f5fcc1d808f Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 20 Sep 2024 15:57:19 -0400 Subject: [PATCH 2/6] don't update Rosenbrock23/32 --- .../src/rosenbrock_caches.jl | 227 +++++++++- .../src/rosenbrock_perform_step.jl | 396 ++++++++++++++++++ .../src/rosenbrock_tableaus.jl | 115 +++-- .../src/stiff_addsteps.jl | 45 ++ 4 files changed, 709 insertions(+), 74 deletions(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 81ceca315e..3079601ed9 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -8,6 +8,220 @@ function get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u) (cache.fsalfirst, cache.fsallast) end +@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType, + TabType, TFType, UFType, F, JCType, GCType, + RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache + u::uType + uprev::uType + k₁::rateType + k₂::rateType + k₃::rateType + du1::rateType + du2::rateType + f₁::rateType + fsalfirst::rateType + fsallast::rateType + dT::rateType + J::JType + W::WType + tmp::rateType + atmp::uNoUnitsType + weight::uNoUnitsType + tab::TabType + tf::TFType + uf::UFType + linsolve_tmp::rateType + linsolve::F + jac_config::JCType + grad_config::GCType + reltol::RTolType + alg::A + algebraic_vars::AV + step_limiter!::StepLimiter + stage_limiter!::StageLimiter +end + +@cache mutable struct Rosenbrock32Cache{uType, rateType, uNoUnitsType, JType, WType, + TabType, TFType, UFType, F, JCType, GCType, + RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache + u::uType + uprev::uType + k₁::rateType + k₂::rateType + k₃::rateType + du1::rateType + du2::rateType + f₁::rateType + fsalfirst::rateType + fsallast::rateType + dT::rateType + J::JType + W::WType + tmp::rateType + atmp::uNoUnitsType + weight::uNoUnitsType + tab::TabType + tf::TFType + uf::UFType + linsolve_tmp::rateType + linsolve::F + jac_config::JCType + grad_config::GCType + reltol::RTolType + alg::A + algebraic_vars::AV + step_limiter!::StepLimiter + stage_limiter!::StageLimiter +end + +function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + k₁ = zero(rate_prototype) + k₂ = zero(rate_prototype) + k₃ = zero(rate_prototype) + du1 = zero(rate_prototype) + du2 = zero(rate_prototype) + # f₀ = zero(u) fsalfirst + f₁ = zero(rate_prototype) + fsalfirst = zero(rate_prototype) + fsallast = zero(rate_prototype) + dT = zero(rate_prototype) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) + tmp = zero(rate_prototype) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) + weight = similar(u, uEltypeNoUnits) + recursivefill!(weight, false) + tab = Rosenbrock23Tableau(constvalue(uBottomEltypeNoUnits)) + tf = TimeGradientWrapper(f, uprev, p) + uf = UJacobianWrapper(f, t, p) + linsolve_tmp = zero(rate_prototype) + + linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) + Pl, Pr = wrapprecs( + alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, + nothing)..., weight, tmp) + linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, + Pl = Pl, Pr = Pr, + assumptions = LinearSolve.OperatorAssumptions(true)) + + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + algebraic_vars = f.mass_matrix === I ? nothing : + [all(iszero, x) for x in eachcol(f.mass_matrix)] + + Rosenbrock23Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁, + fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, + linsolve_tmp, + linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, + alg.stage_limiter!) +end + +function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + k₁ = zero(rate_prototype) + k₂ = zero(rate_prototype) + k₃ = zero(rate_prototype) + du1 = zero(rate_prototype) + du2 = zero(rate_prototype) + # f₀ = zero(u) fsalfirst + f₁ = zero(rate_prototype) + fsalfirst = zero(rate_prototype) + fsallast = zero(rate_prototype) + dT = zero(rate_prototype) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) + tmp = zero(rate_prototype) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) + weight = similar(u, uEltypeNoUnits) + recursivefill!(weight, false) + tab = Rosenbrock32Tableau(constvalue(uBottomEltypeNoUnits)) + + tf = TimeGradientWrapper(f, uprev, p) + uf = UJacobianWrapper(f, t, p) + linsolve_tmp = zero(rate_prototype) + linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) + + Pl, Pr = wrapprecs( + alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, + nothing)..., weight, tmp) + linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, + Pl = Pl, Pr = Pr, + assumptions = LinearSolve.OperatorAssumptions(true)) + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + algebraic_vars = f.mass_matrix === I ? nothing : + [all(iszero, x) for x in eachcol(f.mass_matrix)] + + Rosenbrock32Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, + tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config, + grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!) +end + +struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <: + RosenbrockConstantCache + c₃₂::T + d::T + tf::TF + uf::UF + J::JType + W::WType + linsolve::F + autodiff::AD +end + +function Rosenbrock23ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T} + tab = Rosenbrock23Tableau(T) + Rosenbrock23ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff) +end + +function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tf = TimeDerivativeWrapper(f, u, p) + uf = UDerivativeWrapper(f, t, p) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) + linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) + Rosenbrock23ConstantCache(constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve, + alg_autodiff(alg)) +end + +struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} <: + RosenbrockConstantCache + c₃₂::T + d::T + tf::TF + uf::UF + J::JType + W::WType + linsolve::F + autodiff::AD +end + +function Rosenbrock32ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T} + tab = Rosenbrock32Tableau(T) + Rosenbrock32ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff) +end + +function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tf = TimeDerivativeWrapper(f, u, p) + uf = UDerivativeWrapper(f, t, p) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) + linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) + Rosenbrock32ConstantCache(constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve, + alg_autodiff(alg)) +end + ################################################################################ # Shampine's Low-order Rosenbrocks @@ -19,7 +233,6 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT du::rateType du1::rateType du2::rateType - f₁::rateType ks::Vector{rateType} fsalfirst::rateType fsallast::rateType @@ -43,6 +256,7 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT stage_limiter!::StageLimiter interp_order::Int end + function full_cache(c::RosenbrockCache) return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2, c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp] @@ -98,19 +312,18 @@ tabtype(::Rodas5Pr) = Rodas5PTableau tabtype(::Rodas5Pe) = Rodas5PTableau function alg_cache( - alg::Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, + alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Rodas5PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) dense = [zero(rate_prototype) for _ in 1:size(tab.H, 1)] du = zero(rate_prototype) du1 = zero(rate_prototype) du2 = zero(rate_prototype) ks = [zero(rate_prototype) for _ in 1:size(tab.A, 1)] - f₁ = zero(rate_prototype) fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) @@ -135,14 +348,14 @@ function alg_cache( jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) algebraic_vars = f.mass_matrix === I ? nothing : [all(iszero, x) for x in eachcol(f.mass_matrix)] - RosenbrockCache(u, uprev, dense, du, du1, du2, ks, f₁, fsalfirst, fsallast, + RosenbrockCache(u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1)) end function alg_cache( - alg::Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, + alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, @@ -152,7 +365,7 @@ function alg_cache( J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - tab = + tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) RosenbrockCombinedConstantCache(tf, uf, tab, J, W, linsolve, alg_autodiff(alg), size(tab.H, 1)) end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index a60c57af6c..438be408e2 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -1,4 +1,400 @@ ################################################################################ +function initialize!(integrator, cache::Union{Rosenbrock23Cache, + Rosenbrock32Cache}) + integrator.kshortsize = 2 + @unpack k₁, k₂, fsalfirst, fsallast = cache + resize!(integrator.k, integrator.kshortsize) + integrator.k[1] = k₁ + integrator.k[2] = k₂ + integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) +end + +function initialize!(integrator, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock32ConstantCache}) + integrator.kshortsize = 2 + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) + integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + # Avoid undefined entries if k is an array of arrays + integrator.fsallast = zero(integrator.fsalfirst) + integrator.k[1] = zero(integrator.fsalfirst) + integrator.k[2] = zero(integrator.fsalfirst) +end + +@muladd function perform_step!(integrator, cache::Rosenbrock23Cache, repeat_step = false) + @unpack t, dt, uprev, u, f, p, opts = integrator + @unpack k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, tmp, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache + @unpack c₃₂, d = cache.tab + + # Assignments + sizeu = size(u) + mass_matrix = integrator.f.mass_matrix + + # Precalculations + dtγ = dt * d + neginvdtγ = -inv(dtγ) + dto2 = dt / 2 + dto6 = dt / 6 + + if repeat_step + f(integrator.fsalfirst, uprev, p, t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + end + + calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step) + + calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t) + + if repeat_step + linres = dolinsolve( + integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), + du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, + solverdata = (; gamma = dtγ)) + else + linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), + du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, + solverdata = (; gamma = dtγ)) + end + + vecu = _vec(linres.u) + veck₁ = _vec(k₁) + + @.. veck₁ = vecu * neginvdtγ + integrator.stats.nsolve += 1 + + @.. u = uprev + dto2 * k₁ + stage_limiter!(u, integrator, p, t + dto2) + f(f₁, u, p, t + dto2) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + if mass_matrix === I + copyto!(tmp, k₁) + else + mul!(_vec(tmp), mass_matrix, _vec(k₁)) + end + + @.. linsolve_tmp = f₁ - tmp + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + vecu = _vec(linres.u) + veck₂ = _vec(k₂) + + @.. veck₂ = vecu * neginvdtγ + veck₁ + integrator.stats.nsolve += 1 + + @.. u = uprev + dt * k₂ + stage_limiter!(u, integrator, p, t + dt) + step_limiter!(u, integrator, p, t + dt) + + if integrator.opts.adaptive + f(fsallast, u, p, t + dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + if mass_matrix === I + @.. broadcast=false linsolve_tmp=fsallast - c₃₂ * (k₂ - f₁) - + 2(k₁ - fsalfirst) + dt * dT + else + @.. broadcast=false du2=c₃₂ * k₂ + 2k₁ + mul!(_vec(du1), mass_matrix, _vec(du2)) + @.. broadcast=false linsolve_tmp=fsallast - du1 + c₃₂ * f₁ + 2fsalfirst + + dt * dT + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + vecu = _vec(linres.u) + veck3 = _vec(k₃) + @.. veck3 = vecu * neginvdtγ + + integrator.stats.nsolve += 1 + + if mass_matrix === I + @.. broadcast=false tmp=dto6 * (k₁ - 2 * k₂ + k₃) + else + veck₁ = _vec(k₁) + veck₂ = _vec(k₂) + veck₃ = _vec(k₃) + vectmp = _vec(tmp) + @.. broadcast=false vectmp=ifelse(cache.algebraic_vars, + false, dto6 * (veck₁ - 2 * veck₂ + veck₃)) + end + calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + + if mass_matrix !== I + algvar = reshape(cache.algebraic_vars, size(u)) + invatol = inv(integrator.opts.abstol) + @.. atmp = ifelse(algvar, fsallast, false) * invatol + integrator.EEst += integrator.opts.internalnorm(atmp, t) + end + end + cache.linsolve = linres.cache +end + +@muladd function perform_step!(integrator, cache::Rosenbrock32Cache, repeat_step = false) + @unpack t, dt, uprev, u, f, p, opts = integrator + @unpack k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, tmp, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache + @unpack c₃₂, d = cache.tab + + # Assignments + sizeu = size(u) + mass_matrix = integrator.f.mass_matrix + + # Precalculations + dtγ = dt * d + neginvdtγ = -inv(dtγ) + dto2 = dt / 2 + dto6 = dt / 6 + + if repeat_step + f(integrator.fsalfirst, uprev, p, t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + end + + calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step) + + calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t) + + if repeat_step + linres = dolinsolve( + integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), + du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, + solverdata = (; gamma = dtγ)) + else + linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), + du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, + solverdata = (; gamma = dtγ)) + end + + vecu = _vec(linres.u) + veck₁ = _vec(k₁) + + @.. veck₁ = vecu * neginvdtγ + integrator.stats.nsolve += 1 + + @.. broadcast=false u=uprev + dto2 * k₁ + stage_limiter!(u, integrator, p, t + dto2) + f(f₁, u, p, t + dto2) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + if mass_matrix === I + tmp .= k₁ + else + mul!(_vec(tmp), mass_matrix, _vec(k₁)) + end + + @.. broadcast=false linsolve_tmp=f₁ - tmp + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + vecu = _vec(linres.u) + veck₂ = _vec(k₂) + + @.. veck₂ = vecu * neginvdtγ + veck₁ + integrator.stats.nsolve += 1 + + @.. tmp = uprev + dt * k₂ + stage_limiter!(u, integrator, p, t + dt) + f(fsallast, tmp, p, t + dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + if mass_matrix === I + @.. broadcast=false linsolve_tmp=fsallast - c₃₂ * (k₂ - f₁) - 2(k₁ - fsalfirst) + + dt * dT + else + @.. broadcast=false du2=c₃₂ * k₂ + 2k₁ + mul!(_vec(du1), mass_matrix, _vec(du2)) + @.. broadcast=false linsolve_tmp=fsallast - du1 + c₃₂ * f₁ + 2fsalfirst + dt * dT + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + vecu = _vec(linres.u) + veck3 = _vec(k₃) + + @.. veck3 = vecu * neginvdtγ + integrator.stats.nsolve += 1 + + @.. broadcast=false u=uprev + dto6 * (k₁ + 4k₂ + k₃) + + step_limiter!(u, integrator, p, t + dt) + + if integrator.opts.adaptive + @.. broadcast=false tmp=dto6 * (k₁ - 2 * k₂ + k₃) + calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + + if mass_matrix !== I + invatol = inv(integrator.opts.abstol) + @.. atmp = ifelse(cache.algebraic_vars, fsallast, false) * invatol + integrator.EEst += integrator.opts.internalnorm(atmp, t) + end + end + cache.linsolve = linres.cache +end + +@muladd function perform_step!(integrator, cache::Rosenbrock23ConstantCache, + repeat_step = false) + @unpack t, dt, uprev, u, f, p = integrator + @unpack c₃₂, d, tf, uf = cache + + # Precalculations + dtγ = dt * d + neginvdtγ = -inv(dtγ) + dto2 = dt / 2 + dto6 = dt / 6 + + if repeat_step + integrator.fsalfirst = f(uprev, p, t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + end + + mass_matrix = integrator.f.mass_matrix + + # Time derivative + dT = calc_tderivative(integrator, cache) + + W = calc_W(integrator, cache, dtγ, repeat_step) + if !issuccess_W(W) + integrator.EEst = 2 + return nothing + end + + k₁ = _reshape(W \ _vec((integrator.fsalfirst + dtγ * dT)), axes(uprev)) * neginvdtγ + integrator.stats.nsolve += 1 + tmp = @.. uprev + dto2 * k₁ + f₁ = f(tmp, p, t + dto2) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + if mass_matrix === I + k₂ = _reshape(W \ _vec(f₁ - k₁), axes(uprev)) + else + k₂ = _reshape(W \ _vec(f₁ - mass_matrix * k₁), axes(uprev)) + end + k₂ = @.. k₂ * neginvdtγ + k₁ + integrator.stats.nsolve += 1 + u = uprev + dt * k₂ + + if integrator.opts.adaptive + integrator.fsallast = f(u, p, t + dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + if mass_matrix === I + linsolve_tmp = @.. (integrator.fsallast - c₃₂ * (k₂ - f₁) - + 2 * (k₁ - integrator.fsalfirst) + dt * dT) + else + linsolve_tmp = mass_matrix * (@.. c₃₂ * k₂ + 2 * k₁) + linsolve_tmp = @.. (integrator.fsallast - linsolve_tmp + + c₃₂ * f₁ + 2 * integrator.fsalfirst + dt * dT) + end + k₃ = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) * neginvdtγ + integrator.stats.nsolve += 1 + + if u isa Number + utilde = dto6 * f.mass_matrix[1, 1] * (k₁ - 2 * k₂ + k₃) + else + utilde = f.mass_matrix * (@.. dto6 * (k₁ - 2 * k₂ + k₃)) + end + atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + + if mass_matrix !== I + invatol = inv(integrator.opts.abstol) + atmp = @. ifelse(integrator.differential_vars, false, integrator.fsallast) * + invatol + integrator.EEst += integrator.opts.internalnorm(atmp, t) + end + end + integrator.k[1] = k₁ + integrator.k[2] = k₂ + integrator.u = u + return nothing +end + +@muladd function perform_step!(integrator, cache::Rosenbrock32ConstantCache, + repeat_step = false) + @unpack t, dt, uprev, u, f, p = integrator + @unpack c₃₂, d, tf, uf = cache + + # Precalculations + dtγ = dt * d + neginvdtγ = -inv(dtγ) + dto2 = dt / 2 + dto6 = dt / 6 + + mass_matrix = integrator.f.mass_matrix + + if repeat_step + integrator.fsalfirst = f(uprev, p, t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + end + + # Time derivative + dT = calc_tderivative(integrator, cache) + + W = calc_W(integrator, cache, dtγ, repeat_step) + if !issuccess_W(W) + integrator.EEst = 2 + return nothing + end + + k₁ = _reshape(W \ -_vec((integrator.fsalfirst + dtγ * dT)), axes(uprev)) / dtγ + integrator.stats.nsolve += 1 + tmp = @.. uprev + dto2 * k₁ + f₁ = f(tmp, p, t + dto2) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + if mass_matrix === I + k₂ = _reshape(W \ _vec(f₁ - k₁), axes(uprev)) + else + linsolve_tmp = f₁ - mass_matrix * k₁ + k₂ = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) + end + k₂ = @.. k₂ * neginvdtγ + k₁ + + integrator.stats.nsolve += 1 + tmp = @.. uprev + dt * k₂ + integrator.fsallast = f(tmp, p, t + dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + + if mass_matrix === I + linsolve_tmp = @.. (integrator.fsallast - c₃₂ * (k₂ - f₁) - + 2(k₁ - integrator.fsalfirst) + dt * dT) + else + linsolve_tmp = mass_matrix * (@.. c₃₂ * k₂ + 2 * k₁) + linsolve_tmp = @.. (integrator.fsallast - linsolve_tmp + + c₃₂ * f₁ + 2 * integrator.fsalfirst + dt * dT) + end + k₃ = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) * neginvdtγ + integrator.stats.nsolve += 1 + u = @.. uprev + dto6 * (k₁ + 4k₂ + k₃) + + if integrator.opts.adaptive + utilde = @.. dto6 * (k₁ - 2k₂ + k₃) + atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + + if mass_matrix !== I + invatol = inv(integrator.opts.abstol) + atmp = ifelse(integrator.differential_vars, false, integrator.fsallast) .* + invatol + integrator.EEst += integrator.opts.internalnorm(atmp, t) + end + end + + integrator.k[1] = k₁ + integrator.k[2] = k₂ + integrator.u = u + return nothing +end #### ROS2 type method diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl index e98bcf5d7f..7ae9b138e8 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl @@ -1,3 +1,12 @@ +struct RodasTableau{T, T2} + A::Matrix{T} + C::Matrix{T} + gamma::T + c::Vector{T2} + d::Vector{T} + H::Matrix{T} +end + struct Rosenbrock23Tableau{T} c₃₂::T d::T @@ -20,39 +29,30 @@ function Rosenbrock32Tableau(T) Rosenbrock32Tableau(c₃₂, d) end -struct ROS3PTableau{T, T2} - A::Matrix{T} - C::Matrix{T} - b::Vector{T} - btilde::Vector{T} - gamma::T2 - c::Vector{T2} - d::Vector{T} -end function ROS3PTableau(T, T2) gamma = convert(T, 1 / 2 + sqrt(3) / 6) igamma = inv(gamma) A = T[ - 0 0 0 - convert(T, igamma) 0 0 - convert(T, igamma) convert(T, 0) 0 + 0 0 0 + igamma 0 0 + igamma 0 0 ] tmp = -igamma * (convert(T, 2) - convert(T, 1 / 2) * igamma) C = T[ 0 0 0 - convert(T, -igamma^2) 0 0 - -igamma * (convert(T, 1) - tmp) tmp 0 + -igamma^2 0 0 + -igamma * (1 - tmp) tmp 0 ] - tmp = igamma * (convert(T, 2 / 3) - convert(T, 1 / 6) * igamma) - b = [(igamma * (convert(T, 1) + tmp)), (tmp), (convert(T, 1 / 3) * igamma)] + tmp = igamma * (convert(T, 2 // 3) - convert(T, 1 // 6) * igamma) + b = [(igamma * (convert(T, 1) + tmp)), (tmp), (convert(T, 1 // 3) * igamma)] # btilde1 = convert(T,2.113248654051871) # btilde2 = convert(T,1.000000000000000) # btilde3 = convert(T,0.4226497308103742) - btilde = [(convert(T, 2.113248654051871)), (convert(T, 1.000000000000000)), (convert(T, 0.4226497308103742))] - c = T[1, 1] + btilde = T[2.113248654051871, 1, 0.4226497308103742] + c = T2[1, 1] d = T[0.7886751345948129, -0.2113248654051871, -1.077350269189626] - ROS3PTableau(A, C, b, btilde, gamma, c, d) + RodasTableau(A, C, b, btilde, gamma, c, d) end struct Rodas3Tableau{T, T2} @@ -74,34 +74,24 @@ function Rodas3Tableau(T, T2) 2 0 1 ] C = T[ - 0 0 0 - 4 0 0 - 1 -1 0 - 1 -1 -8 // 3 + 0 0 0 + 4 0 0 + 1 -1 0 + 1 -1 -8 // 3 ] b = T[2, 0, 1, 1] btilde = T[0.0, 0.0, 0.0, 1.0] c = T[0.0, 1.0, 1.0] d = T[1 // 2, 3 // 2, 0, 0] - Rodas3Tableau(A, C, b, btilde, gamma, c, d) -end - -struct Rodas3PTableau{T, T2} - A::Matrix{T} - C::Matrix{T} - gamma::T - c::Vector{T2} - d::Vector{T} - h::Matrix{T} - h2_2::Vector{T} + RodasTableau(A, C, b, btilde, gamma, c, d) end function Rodas3PTableau(T, T2) gamma = convert(T, 1 // 3) A = T[ - 0 0 0 0 - 4.0 / 3.0 0 0 0 - 0 0 0 0 + 0 0 0 0 + 4 // 3 0 0 0 + 0 0 0 0 2.90625 3.375 0.40625 0 ] C = T[ @@ -111,15 +101,15 @@ function Rodas3PTableau(T, T2) 1.21875 5.0625 1.96875 0 4.03125 15.1875 4.03125 6.0 ] - c = T2[4.0 / 9.0, 0.0] - d = T[1.0 / 3.0, 1.0 / 9.0, 1.0] + c = T2[4 // 9, 0] + d = T[1 // 3, 1 // 9, 1] H = T[ - 0 0 0 0 0 - 1.78125 6.75 0.15625 6.0 1.0 - 4.21875 15.1875 3.09375 9.0 0.0 + 0 0 0 0 0 + 1.78125 6.75 0.15625 6 1 + 4.21875 15.1875 3.09375 9 0 ] h2_2 = T[4.21875, 2.025, 1.63125, 1.7, 0.1] - Rodas3PTableau(A, C, gamma, c, d, H, h2_2) + RodasTableau(A, C, gamma, c, d, H, h2_2) end @ROS2(:tableau) @@ -130,15 +120,6 @@ end @Rosenbrock4(:tableau) -struct RodasTableau{T, T2} - A::Matrix{T} - C::Matrix{T} - gamma::T - c::Vector{T2} - d::Vector{T} - H::Matrix{T} -end - function Rodas4Tableau(T, T2) gamma = convert(T, 1 // 4) #BET2P=0.0317D0 @@ -246,13 +227,13 @@ function Rodas5Tableau(T, T2) -14.09640773051259 6.925207756232704 -41.47510893210728 2.343771018586405 24.13215229196062 1 1 0 ] C = T[ - 0 0 0 0 0 0 0 - -10.31323885133993 0 0 0 0 0 0 - -21.04823117650003 -7.234992135176716 0 0 0 0 0 - 32.22751541853323 -4.943732386540191 19.44922031041879 0 0 0 0 - -20.69865579590063 -8.816374604402768 1.260436877740897 -0.7495647613787146 0 0 0 - -46.22004352711257 -17.49534862857472 -289.6389582892057 93.60855400400906 318.3822534212147 0 0 - 34.20013733472935 -14.15535402717690 57.82335640988400 25.83362985412365 1.408950972071624 -6.551835421242162 0 + 0 0 0 0 0 0 0 + -10.31323885133993 0 0 0 0 0 0 + -21.04823117650003 -7.234992135176716 0 0 0 0 0 + 32.22751541853323 -4.943732386540191 19.44922031041879 0 0 0 0 + -20.69865579590063 -8.816374604402768 1.260436877740897 -0.7495647613787146 0 0 0 + -46.22004352711257 -17.49534862857472 -289.6389582892057 93.60855400400906 318.3822534212147 0 0 + 34.20013733472935 -14.15535402717690 57.82335640988400 25.83362985412365 1.408950972071624 -6.551835421242162 0 42.57076742291101 -13.80770672017997 93.98938432427124 18.77919633714503 -31.58359187223370 -6.685968952921985 -5.810979938412932 ] c = T2[0, 0.38, 0.3878509998321533, 0.4839718937873840, 0.4570477008819580, 1, 1, 1] @@ -304,13 +285,13 @@ function Rodas5PTableau(T, T2) -7.502846399306121 2.561846144803919 -11.627539656261098 -0.18268767659942256 0.030198172008377946 1 1 0 ] C = T[ - 0 0 0 0 0 0 0 - -14.155112264123755 0 0 0 0 0 0 - -17.97296035885952 -2.859693295451294 0 0 0 0 0 - 147.12150275711716 -1.41221402718213 71.68940251302358 0 0 0 0 - 165.43517024871676 -0.4592823456491126 42.90938336958603 -5.961986721573306 0 0 0 - 24.854864614690072 -3.0009227002832186 47.4931110020768 5.5814197821558125 -0.6610691825249471 0 0 - 30.91273214028599 -3.1208243349937974 77.79954646070892 34.28646028294783 -19.097331116725623 -28.087943162872662 0 + 0 0 0 0 0 0 0 + -14.155112264123755 0 0 0 0 0 0 + -17.97296035885952 -2.859693295451294 0 0 0 0 0 + 147.12150275711716 -1.41221402718213 71.68940251302358 0 0 0 0 + 165.43517024871676 -0.4592823456491126 42.90938336958603 -5.961986721573306 0 0 0 + 24.854864614690072 -3.0009227002832186 47.4931110020768 5.5814197821558125 -0.6610691825249471 0 0 + 30.91273214028599 -3.1208243349937974 77.79954646070892 34.28646028294783 -19.097331116725623 -28.087943162872662 0 37.80277123390563 -3.2571969029072276 112.26918849496327 66.9347231244047 -40.06618937091002 -54.66780262877968 -9.48861652309627 ] c = T2[0, 0.6358126895828704, 0.4095798393397535, 0.9769306725060716, 0.4288403609558664, 1, 1, 1] @@ -370,4 +351,4 @@ beta3 = 6.8619167645278386e-2 beta4 = 0.8289547562599182 beta5 = 7.9630136489868164e-2 alpha64 = -0.2076823627400282 -=# \ No newline at end of file +=# diff --git a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl index 59674e9683..d3967b40ee 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl @@ -1,3 +1,48 @@ +function _ode_addsteps!(k, t, uprev, u, dt, f, p, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock32ConstantCache}, + always_calc_begin = false, allow_calc_end = true, + force_calc_end = false) + if length(k) < 2 || always_calc_begin + @unpack tf, uf, d = cache + dtγ = dt * d + neginvdtγ = -inv(dtγ) + dto2 = dt / 2 + tf.u = uprev + if cache.autodiff isa AutoForwardDiff + dT = ForwardDiff.derivative(tf, t) + else + dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt)) + end + + mass_matrix = f.mass_matrix + if uprev isa Number + J = ForwardDiff.derivative(uf, uprev) + W = neginvdtγ .+ J + else + J = ForwardDiff.jacobian(uf, uprev) + if mass_matrix isa UniformScaling + W = neginvdtγ * mass_matrix + J + else + W = @.. neginvdtγ * mass_matrix .+ J + end + end + f₀ = f(uprev, p, t) + k₁ = _reshape(W \ _vec((f₀ + dtγ * dT)), axes(uprev)) * neginvdtγ + tmp = @.. uprev + dto2 * k₁ + f₁ = f(tmp, p, t + dto2) + if mass_matrix === I + k₂ = _reshape(W \ _vec(f₁ - k₁), axes(uprev)) + else + k₂ = _reshape(W \ _vec(f₁ - mass_matrix * k₁), axes(uprev)) + end + k₂ = @.. k₂ * neginvdtγ + k₁ + copyat_or_push!(k, 1, k₁) + copyat_or_push!(k, 2, k₂) + end + nothing +end + function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::RosenbrockCombinedConstantCache, always_calc_begin = false, allow_calc_end = true, force_calc_end = false) From 9d2a4c83b92d64a9e92ba3658539f87d3cae609c Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 20 Sep 2024 16:17:07 -0400 Subject: [PATCH 3/6] undo Rosenbrock23 changes --- .../src/interp_func.jl | 11 ++ .../src/rosenbrock_interpolants.jl | 121 ++++++++++++++++++ 2 files changed, 132 insertions(+) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl index 077e27c7b0..43f66149fc 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl @@ -1,3 +1,14 @@ +function DiffEqBase.interp_summary(::Type{cacheType}, + dense::Bool) where { + cacheType <: + Union{Rosenbrock23ConstantCache, + Rosenbrock32ConstantCache, + Rosenbrock23Cache, + Rosenbrock32Cache}} + dense ? "specialized 2nd order \"free\" stiffness-aware interpolation" : + "1st order linear" +end + function DiffEqBase.interp_summary(::Type{cacheType}, dense::Bool) where { cacheType <: diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl index 20cf702e14..5a65299042 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl @@ -1,3 +1,124 @@ +### Fallbacks to capture +ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache, + RosenbrockCombinedConstantCache, + RosenbrockCache} + +function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::ROSENBROCKS_WITH_INTERPOLATIONS, + idxs, T::Type{Val{D}}, differential_vars) where {D} + throw(DerivativeOrderNotPossibleError()) +end + +function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::ROSENBROCKS_WITH_INTERPOLATIONS, + idxs, T::Type{Val{D}}, differential_vars) where {D} + throw(DerivativeOrderNotPossibleError()) +end + +""" +From MATLAB ODE Suite by Shampine +""" +@def rosenbrock2332unpack begin + if cache isa OrdinaryDiffEqMutableCache + d = cache.tab.d + else + d = cache.d + end +end + +@def rosenbrock2332pre0 begin + @rosenbrock2332unpack + c1 = Θ * (1 - Θ) / (1 - 2d) + c2 = Θ * (Θ - 2d) / (1 - 2d) +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock32ConstantCache}, idxs::Nothing, + T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @inbounds y₀ + dt * (c1 * k[1] + c2 * k[2]) +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, + idxs::Nothing, T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @inbounds @.. y₀+dt * (c1 * k[1] + c2 * k[2]) +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs, T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @.. y₀[idxs]+dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) +end + +@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs::Nothing, T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @inbounds @.. out=y₀ + dt * (c1 * k[1] + c2 * k[2]) + out +end + +@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs, T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @views @.. out=y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) + out +end + +# First Derivative of the dense output +@def rosenbrock2332pre1 begin + @rosenbrock2332unpack + c1diff = (1 - 2 * Θ) / (1 - 2 * d) + c2diff = (2 * Θ - 2 * d) / (1 - 2 * d) +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs::Nothing, T::Type{Val{1}}, differential_vars) + @rosenbrock2332pre1 + @.. c1diff * k[1]+c2diff * k[2] +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs, T::Type{Val{1}}, differential_vars) + @rosenbrock2332pre1 + @.. c1diff * k[1][idxs]+c2diff * k[2][idxs] +end + +@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs::Nothing, T::Type{Val{1}}, differential_vars) + @rosenbrock2332pre1 + @.. out=c1diff * k[1] + c2diff * k[2] + out +end + +@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs, T::Type{Val{1}}, differential_vars) + @rosenbrock2332pre1 + @views @.. out=c1diff * k[1][idxs] + c2diff * k[2][idxs] + out +end + """ From MATLAB ODE Suite by Shampine """ From a8c0a52f8ea0a8413d13c42b23853be1b7719c45 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 20 Sep 2024 18:21:43 -0400 Subject: [PATCH 4/6] fixes --- .../src/interp_func.jl | 4 +- .../src/rosenbrock_caches.jl | 106 +++++++++--------- .../src/rosenbrock_perform_step.jl | 32 ++++-- .../src/rosenbrock_tableaus.jl | 87 ++++++-------- 4 files changed, 110 insertions(+), 119 deletions(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl index 43f66149fc..71b9435af6 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl @@ -9,11 +9,11 @@ function DiffEqBase.interp_summary(::Type{cacheType}, "1st order linear" end -function DiffEqBase.interp_summary(::Type{cacheType}, +function DiffEqBase.interp_summary(cache::Type{cacheType}, dense::Bool) where { cacheType <: Union{RosenbrockCombinedConstantCache, RosenbrockCache}} - dense ? "specialized $(cache.interp_order) order \"free\" stiffness-aware interpolation" : + dense ? "specialized ? order \"free\" stiffness-aware interpolation" : "1st order linear" end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 3079601ed9..4349415e2c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -8,6 +8,54 @@ function get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u) (cache.fsalfirst, cache.fsallast) end +mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType, + TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache + u::uType + uprev::uType + dense::Vector{rateType} + du::rateType + du1::rateType + du2::rateType + ks::Vector{rateType} + fsalfirst::rateType + fsallast::rateType + dT::rateType + J::JType + W::WType + tmp::rateType + atmp::uNoUnitsType + weight::uNoUnitsType + tab::TabType + tf::TFType + uf::UFType + linsolve_tmp::rateType + linsolve::F + jac_config::JCType + grad_config::GCType + reltol::RTolType + alg::A + algebraic_vars::AV + step_limiter!::StepLimiter + stage_limiter!::StageLimiter + interp_order::Int +end + +function full_cache(c::RosenbrockCache) + return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2, + c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp] +end + +struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache + tf::TF + uf::UF + tab::Tab + J::JType + W::WType + linsolve::F + autodiff::AD + interp_order::Int +end + @cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType, TabType, TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache @@ -74,6 +122,10 @@ end stage_limiter!::StageLimiter end +function get_fsalfirstlast(cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, u) + (cache.fsalfirst, cache.fsallast) +end + function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, @@ -222,57 +274,6 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, alg_autodiff(alg)) end -################################################################################ - -# Shampine's Low-order Rosenbrocks -mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType, - TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache - u::uType - uprev::uType - dense::Vector{rateType} - du::rateType - du1::rateType - du2::rateType - ks::Vector{rateType} - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - reltol::RTolType - alg::A - algebraic_vars::AV - step_limiter!::StepLimiter - stage_limiter!::StageLimiter - interp_order::Int -end - -function full_cache(c::RosenbrockCache) - return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2, - c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp] -end - -struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache - tf::TF - uf::UF - tab::Tab - J::JType - W::WType - linsolve::F - autodiff::AD - interp_order::Int -end - @ROS2(:cache) ################################################################################ @@ -296,9 +297,6 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W) ############################################################################### -### Rodas methods -tabtype(::Rosenbrock23) = Rosenbrock23Tableau -tabtype(::Rosenbrock32) = Rosenbrock32Tableau tabtype(::Rodas23W) = Rodas23WTableau tabtype(::ROS3P) = ROS3PTableau tabtype(::Rodas3) = Rodas3Tableau diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 438be408e2..9b2d292e10 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -434,7 +434,7 @@ end @muladd function perform_step!(integrator, cache::RosenbrockCombinedConstantCache, repeat_step = false) (;t, dt, uprev, u, f, p) = integrator (;tf, uf) = cache - (;A, C, gamma, c, d, H) = cache.tab + (;A, C, b, btilde, gamma, c, d, H) = cache.tab # Precalculations dtC = C ./ dt @@ -489,10 +489,17 @@ end integrator.stats.nsolve += 1 end #@show ks - u = u .+ ks[num_stages] + u = uprev + for i in 1:num_stages + u = @.. u + b[i] * ks[i] + end if integrator.opts.adaptive - atmp = calculate_residuals(ks[num_stages], uprev, u, integrator.opts.abstol, + utilde = uprev + for i in 1:num_stages + utilde = @.. utilde + btilde[i] * ks[i] + end + atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) integrator.EEst = integrator.opts.internalnorm(atmp, t) end @@ -538,7 +545,7 @@ end @muladd function perform_step!(integrator, cache::RosenbrockCache, repeat_step = false) (; t, dt, uprev, u, f, p) = integrator (; du, du1, du2, dT, J, W, uf, tf, ks, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter!) = cache - (; A, C, gamma, c, d, H) = cache.tab + (; A, C, b, btilde, gamma, c, d, H) = cache.tab # Assignments sizeu = size(u) @@ -549,6 +556,7 @@ end dtC = C .* inv(dt) dtd = dt .* d dtgamma = dt * gamma + utilde = du f(cache.fsalfirst, uprev, p, t) OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) @@ -572,8 +580,8 @@ end @.. $(_vec(ks[1])) = -linres.u integrator.stats.nsolve += 1 - - for stage in 2:length(ks) + num_stages = length(ks) + for stage in 2:num_stages u .= uprev for i in 1:(stage - 1) @.. u += A[stage, i] * ks[i] @@ -601,19 +609,25 @@ end @.. $(_vec(ks[stage])) = -linres.u integrator.stats.nsolve += 1 end - du .= ks[end] - u .+= ks[end] + u .= uprev + for i in 1:num_stages + @.. u += b[i] * ks[i] + end step_limiter!(u, integrator, p, t + dt) if integrator.opts.adaptive + utilde .= 0 + for i in 1:num_stages + @.. utilde += btilde[i] * ks[i] + end if (integrator.alg isa Rodas5Pe) @.. du = 0.2606326497975715 * ks[1] - 0.005158627295444251 * ks[2] + 1.3038988631109731 * ks[3] + 1.235000722062074 * ks[4] + -0.7931985603795049 * ks[5] - 1.005448461135913 * ks[6] - 0.18044626132120234 * ks[7] + 0.17051519239113755 * ks[8] end - calculate_residuals!(atmp, ks[end], uprev, u, integrator.opts.abstol, + calculate_residuals!(atmp, utilde, uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) integrator.EEst = integrator.opts.internalnorm(atmp, t) end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl index 7ae9b138e8..08710eaed1 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl @@ -1,6 +1,8 @@ struct RodasTableau{T, T2} A::Matrix{T} C::Matrix{T} + b::Vector{T} + btilde::Vector{T} gamma::T c::Vector{T2} d::Vector{T} @@ -40,29 +42,17 @@ function ROS3PTableau(T, T2) ] tmp = -igamma * (convert(T, 2) - convert(T, 1 / 2) * igamma) C = T[ - 0 0 0 + 0 0 0 -igamma^2 0 0 -igamma * (1 - tmp) tmp 0 ] tmp = igamma * (convert(T, 2 // 3) - convert(T, 1 // 6) * igamma) - b = [(igamma * (convert(T, 1) + tmp)), (tmp), (convert(T, 1 // 3) * igamma)] - # btilde1 = convert(T,2.113248654051871) - # btilde2 = convert(T,1.000000000000000) - # btilde3 = convert(T,0.4226497308103742) + b = T[igamma * (1 + tmp), tmp, igamma / 3] btilde = T[2.113248654051871, 1, 0.4226497308103742] - c = T2[1, 1] + c = T2[0, 1, 1] d = T[0.7886751345948129, -0.2113248654051871, -1.077350269189626] - RodasTableau(A, C, b, btilde, gamma, c, d) -end - -struct Rodas3Tableau{T, T2} - A::Matrix{T} - C::Matrix{T} - b::Vector{T} - btilde::Vector{T} - gamma::T2 - c::Vector{T2} - d::Vector{T} + H = zeros(T, 3, 3) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas3Tableau(T, T2) @@ -80,10 +70,11 @@ function Rodas3Tableau(T, T2) 1 -1 -8 // 3 ] b = T[2, 0, 1, 1] - btilde = T[0.0, 0.0, 0.0, 1.0] - c = T[0.0, 1.0, 1.0] + btilde = T[0, 0, 0, 1] + c = T[0, 1, 1] d = T[1 // 2, 3 // 2, 0, 0] - RodasTableau(A, C, b, btilde, gamma, c, d) + H = zeros(T, 3, 3) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas3PTableau(T, T2) @@ -91,7 +82,7 @@ function Rodas3PTableau(T, T2) A = T[ 0 0 0 0 4 // 3 0 0 0 - 0 0 0 0 + 4 // 3 0 0 0 2.90625 3.375 0.40625 0 ] C = T[ @@ -101,15 +92,16 @@ function Rodas3PTableau(T, T2) 1.21875 5.0625 1.96875 0 4.03125 15.1875 4.03125 6.0 ] - c = T2[4 // 9, 0] + b = A[end, :] + btilde = T[0, 0, 0, 1] + c = T2[0, 4 // 9, 1] d = T[1 // 3, 1 // 9, 1] H = T[ - 0 0 0 0 0 1.78125 6.75 0.15625 6 1 4.21875 15.1875 3.09375 9 0 ] h2_2 = T[4.21875, 2.025, 1.63125, 1.7, 0.1] - RodasTableau(A, C, gamma, c, d, H, h2_2) + RodasTableau(A, C, b, btilde, gamma, c, d, H)#, h2_2) end @ROS2(:tableau) @@ -141,11 +133,13 @@ function Rodas4Tableau(T, T2) 7.496443313967647 -10.24680431464352 -33.99990352819905 11.70890893206160 0 8.083246795921522 -7.981132988064893 -31.52159432874371 16.31930543123136 -6.058818238834054 ] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 1] c = T2[0, 0.386, 0.21, 0.63, 1, 1] d = T[0.25, -0.1043, 0.1035, -0.0362, 0, 0] H = T[10.12623508344586 -7.487995877610167 -34.80091861555747 -7.992771707568823 1.025137723295662 0 -0.6762803392801253 6.087714651680015 16.43084320892478 24.76722511418386 -6.594389125716872 0] - RodasTableau(A, C, gamma, c, d, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas42Tableau(T, T2) @@ -162,11 +156,13 @@ function Rodas42Tableau(T, T2) -32.64449927841361 -99.35311008728094 49.99119122405989 0 0 -76.46023087151691 -278.5942120829058 153.9294840910643 10.97101866258358 0 -76.29701586804983 -294.2795630511232 162.0029695867566 23.65166903095270 -7.652977706771382] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 1] c = T2[0, 0.3507221, 0.2557041, 0.681779, 1, 1] d = T[0.25, -0.0690221, -0.0009672, -0.087979, 0, 0] H = T[-38.71940424117216 -135.8025833007622 64.51068857505875 -4.192663174613162 -2.531932050335060 0 -14.99268484949843 -76.30242396627033 58.65928432851416 16.61359034616402 -0.6758691794084156 0] - RodasTableau(A, C, gamma, c, d, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas4PTableau(T, T2) @@ -186,11 +182,13 @@ function Rodas4PTableau(T, T2) 10.81793056857153 6.780270611428266 19.53485944642410 0 0 34.19095006749676 15.49671153725963 54.74760875964130 14.16005392148534 0 34.62605830930532 15.30084976114473 56.99955578662667 18.40807009793095 -5.714285714285717] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 1] c = T2[0, 0.75, 0.21, 0.63, 1, 1] d = T[0.25, -0.5, -0.023504, -0.0362, 0, 0] H = T[25.09876703708589 11.62013104361867 28.49148307714626 -5.664021568594133 0 0 1.638054557396973 -0.7373619806678748 8.477918219238990 15.99253148779520 -1.882352941176471 0] - RodasTableau(A, C, gamma, c, d, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas4P2Tableau(T, T2) @@ -207,6 +205,8 @@ function Rodas4P2Tableau(T, T2) -8.575016317114033 -7.606483992117508 12.224997650124820 0 0 -5.888975457523102 -8.157396617841821 24.805546872612922 12.790401512796979 0 -4.408651676063871 -6.692003137674639 24.625568527593117 16.627521966636085 -5.714285714285718] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 1] c = T2[0, 0.75, 0.321448134013046, 0.519745732277726, 1, 1] d = T[0.25, -0.5, -0.189532918363016, 0.085612108792769, 0, 0] H = [-5.323528268423303 -10.042123754867493 17.175254928256965 -5.079931171878093 -0.016185991706112 0 @@ -236,6 +236,8 @@ function Rodas5Tableau(T, T2) 34.20013733472935 -14.15535402717690 57.82335640988400 25.83362985412365 1.408950972071624 -6.551835421242162 0 42.57076742291101 -13.80770672017997 93.98938432427124 18.77919633714503 -31.58359187223370 -6.685968952921985 -5.810979938412932 ] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 0, 0, 1] c = T2[0, 0.38, 0.3878509998321533, 0.4839718937873840, 0.4570477008819580, 1, 1, 1] d = T[gamma, -0.1823079225333714636, -0.319231832186874912, 0.3449828624725343, -0.377417564392089818, 0, 0, 0] @@ -244,32 +246,7 @@ function Rodas5Tableau(T, T2) 44.19024239501722 1.3677947663381929e-13 202.93261852171622 -35.5669339789154 -181.91095152160645 3.4116351403665033 2.5793540257308067 2.2435122582734066 -44.0988150021747 -5.755396159656812e-13 -181.26175034586677 56.99302194811676 183.21182741427398 -7.480257918273637 -5.792426076169686 -5.32503859794143 ] - # println("---Rodas5---") - - #= - a71 = -14.09640773051259 - a72 = 6.925207756232704 - a73 = -41.47510893210728 - a74 = 2.343771018586405 - a75 = 24.13215229196062 - a76 = convert(T,1) - a81 = -14.09640773051259 - a82 = 6.925207756232704 - a83 = -41.47510893210728 - a84 = 2.343771018586405 - a85 = 24.13215229196062 - a86 = convert(T,1) - a87 = convert(T,1) - b1 = -14.09640773051259 - b2 = 6.925207756232704 - b3 = -41.47510893210728 - b4 = 2.343771018586405 - b5 = 24.13215229196062 - b6 = convert(T,1) - b7 = convert(T,1) - b8 = convert(T,1) - =# - RodasTableau(A, C, gamma, d, c, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas5PTableau(T, T2) @@ -294,6 +271,8 @@ function Rodas5PTableau(T, T2) 30.91273214028599 -3.1208243349937974 77.79954646070892 34.28646028294783 -19.097331116725623 -28.087943162872662 0 37.80277123390563 -3.2571969029072276 112.26918849496327 66.9347231244047 -40.06618937091002 -54.66780262877968 -9.48861652309627 ] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 0, 0, 1] c = T2[0, 0.6358126895828704, 0.4095798393397535, 0.9769306725060716, 0.4288403609558664, 1, 1, 1] d = T[0.21193756319429014, -0.42387512638858027, -0.3384627126235924, 1.8046452872882734, 2.325825639765069, 0, 0, 0] H = T[ @@ -301,7 +280,7 @@ function Rodas5PTableau(T, T2) -9.91568850695171 -0.9689944594115154 3.0438037242978453 -24.495224566215796 20.176138334709044 15.98066361424651 -6.789040303419874 -6.710236069923372 11.419903575922262 2.8879645146136994 72.92137995996029 80.12511834622643 -52.072871366152654 -59.78993625266729 -0.15582684282751913 4.883087185713722 ] - RodasTableau(A, C, gamma, d, c, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end @RosenbrockW6S4OS(:tableau) From 3e9e51c40086c828ba1af8ba21441195feda8e25 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Sat, 21 Sep 2024 01:25:14 -0400 Subject: [PATCH 5/6] fixes --- .../src/rosenbrock_caches.jl | 5 ++- .../src/rosenbrock_perform_step.jl | 16 +++----- .../src/rosenbrock_tableaus.jl | 41 +++++++++++-------- 3 files changed, 32 insertions(+), 30 deletions(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 4349415e2c..df02b82e5c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -300,6 +300,7 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W) tabtype(::Rodas23W) = Rodas23WTableau tabtype(::ROS3P) = ROS3PTableau tabtype(::Rodas3) = Rodas3Tableau +tabtype(::Rodas3P) = Rodas3PTableau tabtype(::Rodas4) = Rodas4Tableau tabtype(::Rodas42) = Rodas42Tableau tabtype(::Rodas4P) = Rodas4PTableau @@ -310,7 +311,7 @@ tabtype(::Rodas5Pr) = Rodas5PTableau tabtype(::Rodas5Pe) = Rodas5PTableau function alg_cache( - alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, + alg::Union{ROS3P, Rodas3, Rodas3P, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, @@ -353,7 +354,7 @@ function alg_cache( end function alg_cache( - alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, + alg::Union{ROS3P, Rodas3, Rodas3P, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 9b2d292e10..b9cf1df1db 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -495,7 +495,7 @@ end end if integrator.opts.adaptive - utilde = uprev + utilde = zero(u) for i in 1:num_stages utilde = @.. utilde + btilde[i] * ks[i] end @@ -592,14 +592,10 @@ end OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) du1 .= 0 - if mass_matrix === I - for i in 1:(stage - 1) - @.. du1 += dtC[stage, i] * ks[i] - end - else - for i in 1:(stage - 1) - @.. du1 += dtC[stage, i] * ks[i] - end + for i in 1:(stage - 1) + @.. du1 += dtC[stage, i] * ks[i] + end + if mass_matrix !== I mul!(_vec(du2), mass_matrix, _vec(du1)) du1 .= du2 end @@ -617,7 +613,7 @@ end step_limiter!(u, integrator, p, t + dt) if integrator.opts.adaptive - utilde .= 0 + @.. utilde = 0 * u for i in 1:num_stages @.. utilde += btilde[i] * ks[i] end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl index 08710eaed1..fa20a973c7 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl @@ -51,17 +51,17 @@ function ROS3PTableau(T, T2) btilde = T[2.113248654051871, 1, 0.4226497308103742] c = T2[0, 1, 1] d = T[0.7886751345948129, -0.2113248654051871, -1.077350269189626] - H = zeros(T, 3, 3) + H = zeros(T, 2, 3) RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas3Tableau(T, T2) gamma = convert(T, 1 // 2) A = T[ - 0 0 0 - 0 0 0 - 2 0 0 - 2 0 1 + 0 0 0 0 + 0 0 0 0 + 2 0 0 0 + 2 0 1 0 ] C = T[ 0 0 0 @@ -71,31 +71,32 @@ function Rodas3Tableau(T, T2) ] b = T[2, 0, 1, 1] btilde = T[0, 0, 0, 1] - c = T[0, 1, 1] + c = T[0, 0, 1, 1] d = T[1 // 2, 3 // 2, 0, 0] - H = zeros(T, 3, 3) + H = zeros(T, 2, 4) RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas3PTableau(T, T2) gamma = convert(T, 1 // 3) A = T[ - 0 0 0 0 - 4 // 3 0 0 0 - 4 // 3 0 0 0 - 2.90625 3.375 0.40625 0 + 0 0 0 0 0 + 4 // 3 0 0 0 0 + 4 // 3 0 0 0 0 + 2.90625 3.375 0.40625 0 0 + 2.90625 3.375 0.40625 0 0 ] C = T[ 0 0 0 0 - 4.0 0 0 0 + -4.0 0 0 0 8.25 6.75 0 0 - 1.21875 5.0625 1.96875 0 - 4.03125 15.1875 4.03125 6.0 + 1.21875 -5.0625 -1.96875 0 + 4.03125 -15.1875 -4.03125 6.0 ] - b = A[end, :] - btilde = T[0, 0, 0, 1] - c = T2[0, 4 // 9, 1] - d = T[1 // 3, 1 // 9, 1] + b = T[2.90625, 3.375, 0.40625, 1, 0] + btilde = T[0, 0, 0, 1, -1] + c = T2[0, 4 // 9, 4 // 9, 1, 1] + d = T[1 // 3, -1 // 9, 1, 0, 0] H = T[ 1.78125 6.75 0.15625 6 1 4.21875 15.1875 3.09375 9 0 @@ -104,6 +105,10 @@ function Rodas3PTableau(T, T2) RodasTableau(A, C, b, btilde, gamma, c, d, H)#, h2_2) end +function Rodas23WTableau(T, T2) + tab = Rodas3PTableau(T, T2) + RodasTableau(tab.A, tab.C, tab.btilde, tab.b, tab.gamma, tab.c, tab.d, tab.H)#, h2_2) +end @ROS2(:tableau) @ROS23(:tableau) From bc3d1a4ecda1bbe1e4241450c1540548ebef0a73 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 27 Sep 2024 17:39:48 -0400 Subject: [PATCH 6/6] fix ROS3P, RODAS3P --- .../src/rosenbrock_perform_step.jl | 1 - .../src/rosenbrock_tableaus.jl | 43 ++++++++++--------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index b9cf1df1db..63d063cd25 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -488,7 +488,6 @@ end ks = Base.setindex(ks, _reshape(W \ -_vec(linsolve_tmp), axes(uprev)), stage) integrator.stats.nsolve += 1 end - #@show ks u = uprev for i in 1:num_stages u = @.. u + b[i] * ks[i] diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl index fa20a973c7..0acafc189b 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl @@ -33,22 +33,21 @@ end function ROS3PTableau(T, T2) - gamma = convert(T, 1 / 2 + sqrt(3) / 6) - igamma = inv(gamma) + sqrt3 = convert(T, -sqrt(3)) + gamma = convert(T, 0.5 + sqrt3 / 6) + igamma = 3 - sqrt3 A = T[ 0 0 0 igamma 0 0 igamma 0 0 ] - tmp = -igamma * (convert(T, 2) - convert(T, 1 / 2) * igamma) C = T[ - 0 0 0 - -igamma^2 0 0 - -igamma * (1 - tmp) tmp 0 + 0 0 0 + -igamma^2 0 0 + 2*sqrt3 -sqrt3 0 ] - tmp = igamma * (convert(T, 2 // 3) - convert(T, 1 // 6) * igamma) - b = T[igamma * (1 + tmp), tmp, igamma / 3] - btilde = T[2.113248654051871, 1, 0.4226497308103742] + b = T[2, inv(sqrt3), igamma / 3] + btilde = b .- T[2.113248654051871, 1, 0.4226497308103742] c = T2[0, 1, 1] d = T[0.7886751345948129, -0.2113248654051871, -1.077350269189626] H = zeros(T, 2, 3) @@ -80,20 +79,20 @@ end function Rodas3PTableau(T, T2) gamma = convert(T, 1 // 3) A = T[ - 0 0 0 0 0 - 4 // 3 0 0 0 0 - 4 // 3 0 0 0 0 - 2.90625 3.375 0.40625 0 0 - 2.90625 3.375 0.40625 0 0 + 0 0 0 0 0 + 4 // 3 0 0 0 0 + 0 0 0 0 0 + 2.90625 3.375 0.40625 0 0 + 2.90625 3.375 0.40625 0 0 ] C = T[ - 0 0 0 0 - -4.0 0 0 0 - 8.25 6.75 0 0 - 1.21875 -5.0625 -1.96875 0 - 4.03125 -15.1875 -4.03125 6.0 + 0 0 0 0 + -4 0 0 0 + 8.25 6.75 0 0 + 1.21875 -5.0625 -1.96875 0 + 4.03125 -15.1875 -4.03125 6 ] - b = T[2.90625, 3.375, 0.40625, 1, 0] + b = T[2.90625, 3.375, 0.40625, 0, 1] btilde = T[0, 0, 0, 1, -1] c = T2[0, 4 // 9, 4 // 9, 1, 1] d = T[1 // 3, -1 // 9, 1, 0, 0] @@ -107,7 +106,9 @@ end function Rodas23WTableau(T, T2) tab = Rodas3PTableau(T, T2) - RodasTableau(tab.A, tab.C, tab.btilde, tab.b, tab.gamma, tab.c, tab.d, tab.H)#, h2_2) + b = T[2.90625, 3.375, 0.40625, 1, 0] + btilde = T[0, 0, 0, 1, -1] + RodasTableau(tab.A, tab.C, b, btilde, tab.gamma, tab.c, tab.d, tab.H)#, h2_2) end @ROS2(:tableau)