From 81ce811099acd511be85f245028fca28cf074bbe Mon Sep 17 00:00:00 2001 From: Gerd Steinebach Date: Fri, 22 Dec 2023 09:14:32 +0100 Subject: [PATCH 1/2] New methods Rodas23W/Rodas3P with error control for interpolation See issue https://github.com/SciML/OrdinaryDiffEq.jl/issues/2054 --- src/OrdinaryDiffEq.jl | 2 +- src/alg_utils.jl | 4 + src/algorithms.jl | 13 +- src/caches/rosenbrock_caches.jl | 209 ++++++++ src/dense/rosenbrock_interpolants.jl | 35 +- src/dense/stiff_addsteps.jl | 349 ++++++++++++++ src/derivative_utils.jl | 2 +- src/integrators/integrator_interface.jl | 2 +- src/interp_func.jl | 4 +- src/perform_step/rosenbrock_perform_step.jl | 505 +++++++++++++++++++- src/tableaus/rosenbrock_tableaus.jl | 82 ++++ 11 files changed, 1182 insertions(+), 25 deletions(-) diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 4cefb0aec1..90359e7296 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -386,7 +386,7 @@ export MagnusMidpoint, LinearExponential, MagnusLeapfrog, LieEuler, CayleyEuler, MagnusAdapt4, RKMK2, RKMK4, LieRK4, CG2, CG3, CG4a export Rosenbrock23, Rosenbrock32, RosShamp4, Veldd4, Velds4, GRK4T, GRK4A, - Ros4LStab, ROS3P, Rodas3, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, + Ros4LStab, ROS3P, Rodas3, Rodas23W, Rodas3P, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, RosenbrockW6S4OS, ROS34PW1a, ROS34PW1b, ROS34PW2, ROS34PW3 export LawsonEuler, NorsettEuler, ETD1, ETDRK2, ETDRK3, ETDRK4, HochOst4, Exp4, EPIRK4s3A, diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 33f7f0b02b..ec19eb34bf 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -30,6 +30,8 @@ isfsal(tab::DiffEqBase.ExplicitRKTableau) = tab.fsal # isfsal(alg::CompositeAlgorithm) = isfsal(alg.algs[alg.current]) isfsal(alg::FunctionMap) = false +isfsal(alg::Rodas3P) = false +isfsal(alg::Rodas23W) = false isfsal(alg::Rodas5) = false isfsal(alg::Rodas5P) = false isfsal(alg::Rodas4) = false @@ -620,9 +622,11 @@ alg_order(alg::Feagin14) = 14 alg_order(alg::PFRK87) = 8 alg_order(alg::Rosenbrock23) = 2 +alg_order(alg::Rodas23W) = 3 alg_order(alg::Rosenbrock32) = 3 alg_order(alg::ROS3P) = 3 alg_order(alg::Rodas3) = 3 +alg_order(alg::Rodas3P) = 3 alg_order(alg::ROS34PW1a) = 3 alg_order(alg::ROS34PW1b) = 3 alg_order(alg::ROS34PW2) = 3 diff --git a/src/algorithms.jl b/src/algorithms.jl index a30e544070..cbf8391ac7 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -2910,6 +2910,11 @@ Scientific Computing, 18 (1), pp. 1-22. - Kaps, P. & Rentrop, Generalized Runge-Kutta methods of order four with stepsize control for stiff ordinary differential equations. P. Numer. Math. (1979) 33: 55. doi:10.1007/BF01396495 +#### Rodas23W, Rodas3P + +- Steinebach G., Rodas23W / Rodas32P - a Rosenbrock-type method for DAEs with additional error estimate for dense output and Julia implementation, + in progress + #### Rodas4P - Steinebach G. Order-reduction of ROW-methods for DAEs and method of lines @@ -2921,10 +2926,14 @@ Scientific Computing, 18 (1), pp. 1-22. Differential-Algebraic Equations Forum. Springer, Cham. https://doi.org/10.1007/978-3-030-53905-4_6 #### Rodas5 - - Di Marzo G. RODAS5(4) – Méthodes de Rosenbrock d’ordre 5(4) adaptées aux problemes différentiels-algébriques. MSc mathematics thesis, Faculty of Science, University of Geneva, Switzerland. + +#### Rodas5P +- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package. + In: BIT Numerical Mathematics, 63(2), 2023 + =# for Alg in [ @@ -2942,6 +2951,8 @@ for Alg in [ :GRK4T, :GRK4A, :Ros4LStab, + :Rodas23W, + :Rodas3P, :Rodas4, :Rodas42, :Rodas4P, diff --git a/src/caches/rosenbrock_caches.jl b/src/caches/rosenbrock_caches.jl index d553078f32..6671d1db05 100644 --- a/src/caches/rosenbrock_caches.jl +++ b/src/caches/rosenbrock_caches.jl @@ -418,6 +418,215 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W) ### Rodas methods +struct Rodas23WConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqConstantCache + tf::TF + uf::UF + tab::Tab + J::JType + W::WType + linsolve::F + autodiff::AD +end + +struct Rodas3PConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqConstantCache + tf::TF + uf::UF + tab::Tab + J::JType + W::WType + linsolve::F + autodiff::AD +end + +@cache mutable struct Rodas23WCache{uType, rateType, uNoUnitsType, JType, WType, TabType, + TFType, UFType, F, JCType, GCType, RTolType, A} <: + RosenbrockMutableCache + u::uType + uprev::uType + dense1::rateType + dense2::rateType + dense3::rateType + du::rateType + du1::rateType + du2::rateType + k1::rateType + k2::rateType + k3::rateType + k4::rateType + k5::rateType + fsalfirst::rateType + fsallast::rateType + dT::rateType + J::JType + W::WType + tmp::rateType + atmp::uNoUnitsType + weight::uNoUnitsType + tab::TabType + tf::TFType + uf::UFType + linsolve_tmp::rateType + linsolve::F + jac_config::JCType + grad_config::GCType + reltol::RTolType + alg::A +end + +@cache mutable struct Rodas3PCache{uType, rateType, uNoUnitsType, JType, WType, TabType, + TFType, UFType, F, JCType, GCType, RTolType, A} <: + RosenbrockMutableCache + u::uType + uprev::uType + dense1::rateType + dense2::rateType + dense3::rateType + du::rateType + du1::rateType + du2::rateType + k1::rateType + k2::rateType + k3::rateType + k4::rateType + k5::rateType + fsalfirst::rateType + fsallast::rateType + dT::rateType + J::JType + W::WType + tmp::rateType + atmp::uNoUnitsType + weight::uNoUnitsType + tab::TabType + tf::TFType + uf::UFType + linsolve_tmp::rateType + linsolve::F + jac_config::JCType + grad_config::GCType + reltol::RTolType + alg::A +end + +function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + dense1 = zero(rate_prototype) + dense2 = zero(rate_prototype) + dense3 = zero(rate_prototype) + du = zero(rate_prototype) + du1 = zero(rate_prototype) + du2 = zero(rate_prototype) + k1 = zero(rate_prototype) + k2 = zero(rate_prototype) + k3 = zero(rate_prototype) + k4 = zero(rate_prototype) + k5 = zero(rate_prototype) + fsalfirst = zero(rate_prototype) + fsallast = zero(rate_prototype) + dT = zero(rate_prototype) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) + tmp = zero(rate_prototype) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) + weight = similar(u, uEltypeNoUnits) + recursivefill!(weight, false) + tab = Rodas3PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + + tf = TimeGradientWrapper(f, uprev, p) + uf = UJacobianWrapper(f, t, p) + linsolve_tmp = zero(rate_prototype) + linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) + Pl, Pr = wrapprecs(alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, + nothing)..., weight, tmp) + linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, + Pl = Pl, Pr = Pr, + assumptions = LinearSolve.OperatorAssumptions(true)) + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + Rodas23WCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5, + fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, + linsolve, jac_config, grad_config, reltol, alg) +end + +TruncatedStacktraces.@truncate_stacktrace Rodas23WCache 1 +function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + dense1 = zero(rate_prototype) + dense2 = zero(rate_prototype) + dense3 = zero(rate_prototype) + du = zero(rate_prototype) + du1 = zero(rate_prototype) + du2 = zero(rate_prototype) + k1 = zero(rate_prototype) + k2 = zero(rate_prototype) + k3 = zero(rate_prototype) + k4 = zero(rate_prototype) + k5 = zero(rate_prototype) + fsalfirst = zero(rate_prototype) + fsallast = zero(rate_prototype) + dT = zero(rate_prototype) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) + tmp = zero(rate_prototype) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) + weight = similar(u, uEltypeNoUnits) + recursivefill!(weight, false) + tab = Rodas3PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + + tf = TimeGradientWrapper(f, uprev, p) + uf = UJacobianWrapper(f, t, p) + linsolve_tmp = zero(rate_prototype) + linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) + Pl, Pr = wrapprecs(alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, + nothing)..., weight, tmp) + linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, + Pl = Pl, Pr = Pr, + assumptions = LinearSolve.OperatorAssumptions(true)) + grad_config = build_grad_config(alg, f, tf, du1, t) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) + Rodas3PCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5, + fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, + linsolve, jac_config, grad_config, reltol, alg) +end + +TruncatedStacktraces.@truncate_stacktrace Rodas3PCache 1 + +function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tf = TimeDerivativeWrapper(f, u, p) + uf = UDerivativeWrapper(f, t, p) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) + linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) + Rodas23WConstantCache(tf, uf, + Rodas3PTableau(constvalue(uBottomEltypeNoUnits), + constvalue(tTypeNoUnits)), J, W, linsolve, + alg_autodiff(alg)) +end + +function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tf = TimeDerivativeWrapper(f, u, p) + uf = UDerivativeWrapper(f, t, p) + J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) + linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) + linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) + Rodas3PConstantCache(tf, uf, + Rodas3PTableau(constvalue(uBottomEltypeNoUnits), + constvalue(tTypeNoUnits)), J, W, linsolve, + alg_autodiff(alg)) +end + +### Rodas4 methods + struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqConstantCache tf::TF uf::UF diff --git a/src/dense/rosenbrock_interpolants.jl b/src/dense/rosenbrock_interpolants.jl index 7247938945..762dfe10b7 100644 --- a/src/dense/rosenbrock_interpolants.jl +++ b/src/dense/rosenbrock_interpolants.jl @@ -113,34 +113,35 @@ end """ From MATLAB ODE Suite by Shampine """ -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rodas4ConstantCache, +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache}, idxs::Nothing, T::Type{Val{0}}) Θ1 = 1 - Θ @inbounds Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * k[2])) end -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rodas4Cache, idxs::Nothing, - T::Type{Val{0}}) +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Union{Rodas4Cache, Rodas23WCache, Rodas3PCache}, + idxs::Nothing, T::Type{Val{0}}) Θ1 = 1 - Θ @inbounds @.. broadcast=false Θ1 * y₀+Θ * (y₁ + Θ1 * (k[1] + Θ * k[2])) end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache}, idxs, - T::Type{Val{0}}) + cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + idxs, T::Type{Val{0}}) Θ1 = 1 - Θ @.. broadcast=false Θ1 * y₀[idxs]+Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs])) end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache}, + cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs::Nothing, T::Type{Val{0}}) Θ1 = 1 - Θ @.. broadcast=false out=Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * k[2])) out end -@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache::Rodas4Cache{<:Array}, +@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::Union{Rodas4Cache{<:Array}, Rodas23WCache{<:Array}, Rodas3PCache{<:Array}}, idxs::Nothing, T::Type{Val{0}}) Θ1 = 1 - Θ @inbounds @simd ivdep for i in eachindex(out) @@ -150,8 +151,8 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache}, idxs, - T::Type{Val{0}}) + cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + idxs, T::Type{Val{0}}) Θ1 = 1 - Θ @views @.. broadcast=false out=Θ1 * y₀[idxs] + Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs])) @@ -159,27 +160,27 @@ end end # First Derivative -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rodas4ConstantCache, +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache}, idxs::Nothing, T::Type{Val{1}}) @inbounds (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) / dt end -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rodas4Cache, idxs::Nothing, - T::Type{Val{1}}) +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Union{Rodas4Cache, Rodas23WCache, Rodas3PCache}, + idxs::Nothing, T::Type{Val{1}}) @inbounds @.. broadcast=false (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁)/dt end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache}, idxs, - T::Type{Val{1}}) + cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + idxs, T::Type{Val{1}}) @.. broadcast=false (k[1][idxs] + Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] - 3 * k[2][idxs] * Θ) - y₀[idxs] + y₁[idxs])/dt end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache}, + cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs::Nothing, T::Type{Val{1}}) @.. broadcast=false out=(k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) / dt @@ -187,8 +188,8 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache}, idxs, - T::Type{Val{1}}) + cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, + idxs, T::Type{Val{1}}) @views @.. broadcast=false out=(k[1][idxs] + Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] - diff --git a/src/dense/stiff_addsteps.jl b/src/dense/stiff_addsteps.jl index 2dfbe9ba2f..699252c9fe 100644 --- a/src/dense/stiff_addsteps.jl +++ b/src/dense/stiff_addsteps.jl @@ -147,6 +147,355 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rosenbrock23Cache{<:Arr nothing end +function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Union{Rodas23WConstantCache, Rodas3PConstantCache}, + always_calc_begin = false, allow_calc_end = true, + force_calc_end = false) + if length(k) < 2 || always_calc_begin + @unpack tf, uf = cache + @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab + + # Precalculations + dtC21 = C21 / dt + dtC31 = C31 / dt + dtC32 = C32 / dt + dtC41 = C41 / dt + dtC42 = C42 / dt + dtC43 = C43 / dt + dtC51 = C51 / dt + dtC52 = C52 / dt + dtC53 = C53 / dt + dtC54 = C54 / dt + + dtd1 = dt * d1 + dtd2 = dt * d2 + dtd3 = dt * d3 + dtgamma = dt * gamma + mass_matrix = f.mass_matrix + + # Time derivative + tf.u = uprev + if cache.autodiff isa AutoForwardDiff + dT = ForwardDiff.derivative(tf, t) + else + dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt)) + end + + # Jacobian + uf.t = t + if uprev isa AbstractArray + J = ForwardDiff.jacobian(uf, uprev) + W = mass_matrix / dtgamma - J + else + J = ForwardDiff.derivative(uf, uprev) + W = 1 / dtgamma - J + end + + du = f(uprev, p, t) + k3 = copy(du) + + linsolve_tmp = du + dtd1 * dT + + k1 = W \ linsolve_tmp + u = uprev + a21 * k1 + du = f(u, p, t + c2 * dt) + + linsolve_tmp = du + dtd2 * dT + dtC21 * k1 + + k2 = W \ linsolve_tmp + + linsolve_tmp = k3 + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) + + k3 = W \ linsolve_tmp + u = uprev + a41 * k1 + a42 * k2 + a43 * k3 + du = f(u, p, t + dt) + + linsolve_tmp = du + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) + + k4 = W \ linsolve_tmp + + linsolve_tmp = du + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) + + k5 = W \ linsolve_tmp + + @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab + k₁ = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 + k₂ = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 + #k₃ = h2_21 * k1 + h2_22 * k2 + h2_23 * k3 + h2_24 * k4 + h2_25 * k5 + copyat_or_push!(k, 1, k₁) + copyat_or_push!(k, 2, k₂) + #copyat_or_push!(k, 3, k₃) + end + nothing +end + +function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Union{Rodas23WCache, Rodas3PCache}, + always_calc_begin = false, allow_calc_end = true, + force_calc_end = false) + if length(k) < 2 || always_calc_begin + @unpack du, du1, du2, tmp, k1, k2, k3, k4, k5, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst, weight = cache + @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab + + # Assignments + sizeu = size(u) + uidx = eachindex(uprev) + mass_matrix = f.mass_matrix + + # Precalculations + dtC21 = C21 / dt + dtC31 = C31 / dt + dtC32 = C32 / dt + dtC41 = C41 / dt + dtC42 = C42 / dt + dtC43 = C43 / dt + dtC51 = C51 / dt + dtC52 = C52 / dt + dtC53 = C53 / dt + dtC54 = C54 / dt + + dtd1 = dt * d1 + dtd2 = dt * d2 + dtd3 = dt * d3 + dtgamma = dt * gamma + + @.. broadcast=false linsolve_tmp=@muladd fsalfirst + dtgamma * dT + + ### Jacobian does not need to be re-evaluated after an event + ### Since it's unchanged + jacobian2W!(W, mass_matrix, dtgamma, J, true) + + linsolve = cache.linsolve + + linres = dolinsolve(cache, linsolve; A = W, b = _vec(linsolve_tmp), + reltol = cache.reltol) + vecu = _vec(linres.u) + veck1 = _vec(k1) + + @.. broadcast=false veck1=-vecu + @.. broadcast=false tmp=uprev + a21 * k1 + f(du, tmp, p, t + c2 * dt) + + if mass_matrix === I + @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 + else + @.. broadcast=false du1=dtC21 * k1 + mul!(du2, mass_matrix, du1) + @.. broadcast=false linsolve_tmp=du + dtd2 * dT + du2 + end + + linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), + reltol = cache.reltol) + vecu = _vec(linres.u) + veck2 = _vec(k2) + @.. broadcast=false veck2=-vecu + + if mass_matrix === I + @.. broadcast=false linsolve_tmp=fsalfirst + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) + else + @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 + mul!(du2, mass_matrix, du1) + @.. broadcast=false linsolve_tmp=fsalfirst + dtd3 * dT + du2 + end + + linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), + reltol = cache.reltol) + vecu = _vec(linres.u) + veck3 = _vec(k3) + @.. broadcast=false veck3=-vecu + @.. broadcast=false tmp=uprev + a41 * k1 + a42 * k2 + a43 * k3 + f(du, tmp, p, t + dt) + + if mass_matrix === I + @.. broadcast=false linsolve_tmp=du + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) + else + @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 + mul!(du2, mass_matrix, du1) + @.. broadcast=false linsolve_tmp=du + du2 + end + + linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), + reltol = cache.reltol) + vecu = _vec(linres.u) + veck4 = _vec(k4) + @.. broadcast=false veck4=-vecu + + if mass_matrix === I + @.. broadcast=false linsolve_tmp=du + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + + dtC53 * k3) + else + @.. broadcast=false du1=dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3 + mul!(du2, mass_matrix, du1) + @.. broadcast=false linsolve_tmp=du + du2 + end + + linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), + reltol = cache.reltol) + vecu = _vec(linres.u) + veck5 = _vec(k5) + @.. broadcast=false veck5=-vecu + @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab + @.. broadcast=false du=h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 + copyat_or_push!(k, 1, copy(du)) + + @.. broadcast=false du=h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 + copyat_or_push!(k, 2, copy(du)) + end + nothing +end + +function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Union{Rodas23WCache{<:Array}, Rodas3PCache{<:Array}}, + always_calc_begin = false, allow_calc_end = true, + force_calc_end = false) + if length(k) < 2 || always_calc_begin + @unpack du, du1, du2, tmp, k1, k2, k3, k4, k5, k6, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst = cache + @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab + + # Assignments + sizeu = size(u) + uidx = eachindex(uprev) + mass_matrix = f.mass_matrix + + # Precalculations + dtC21 = C21 / dt + dtC31 = C31 / dt + dtC32 = C32 / dt + dtC41 = C41 / dt + dtC42 = C42 / dt + dtC43 = C43 / dt + dtC51 = C51 / dt + dtC52 = C52 / dt + dtC53 = C53 / dt + dtC54 = C54 / dt + + dtd1 = dt * d1 + dtd2 = dt * d2 + dtd3 = dt * d3 + dtgamma = dt * gamma + + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = @muladd fsalfirst[i] + dtgamma * dT[i] + end + + ### Jacobian does not need to be re-evaluated after an event + ### Since it's unchanged + jacobian2W!(W, mass_matrix, dtgamma, J, true) + + linsolve = cache.linsolve + + linres = dolinsolve(cache, linsolve; A = W, b = _vec(linsolve_tmp), + reltol = cache.reltol) + @inbounds @simd ivdep for i in eachindex(u) + k1[i] = -linres.u[i] + end + @inbounds @simd ivdep for i in eachindex(u) + tmp[i] = uprev[i] + a21 * k1[i] + end + f(du, tmp, p, t + c2 * dt) + + if mass_matrix === I + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + dtd2 * dT[i] + dtC21 * k1[i] + end + else + @inbounds @simd ivdep for i in eachindex(u) + du1[i] = dtC21 * k1[i] + end + mul!(du2, mass_matrix, du1) + + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + dtd2 * dT[i] + du2[i] + end + end + + linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), + reltol = cache.reltol) + @inbounds @simd ivdep for i in eachindex(u) + k2[i] = -linres.u[i] + end + + if mass_matrix === I + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = fsalfirst[i] + dtd3 * dT[i] + (dtC31 * k1[i] + dtC32 * k2[i]) + end + else + @inbounds @simd ivdep for i in eachindex(u) + du1[i] = dtC31 * k1[i] + dtC32 * k2[i] + end + mul!(du2, mass_matrix, du1) + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = fsalfirst[i] + dtd3 * dT[i] + du2[i] + end + end + + linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), + reltol = cache.reltol) + @inbounds @simd ivdep for i in eachindex(u) + k3[i] = -linres.u[i] + tmp[i] = uprev[i] + a41 * k1[i] + a42 * k2[i] + a43 * k3[i] + end + f(du, tmp, p, t + dt) + + if mass_matrix === I + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + (dtC41 * k1[i] + dtC42 * k2[i] + dtC43 * k3[i]) + end + else + @inbounds @simd ivdep for i in eachindex(u) + du1[i] = dtC41 * k1[i] + dtC42 * k2[i] + dtC43 * k3[i] + end + mul!(du2, mass_matrix, du1) + + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + du2[i] + end + end + + linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), + reltol = cache.reltol) + @inbounds @simd ivdep for i in eachindex(u) + k4[i] = -linres.u[i] + end + + if mass_matrix === I + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + (dtC52 * k2[i] + dtC54 * k4[i] + dtC51 * k1[i] + + dtC53 * k3[i]) + end + else + @inbounds @simd ivdep for i in eachindex(u) + du1[i] = dtC52 * k2[i] + dtC54 * k4[i] + dtC51 * k1[i] + dtC53 * k3[i] + end + mul!(du2, mass_matrix, du1) + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + du2[i] + end + end + + linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), + reltol = cache.reltol) + @inbounds @simd ivdep for i in eachindex(u) + k5[i] = -linres.u[i] + end + @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab + + @inbounds @simd ivdep for i in eachindex(u) + du[i] = h21 * k1[i] + h22 * k2[i] + h23 * k3[i] + h24 * k4[i] + h25 * k5[i] + end + copyat_or_push!(k, 1, copy(du)) + + @inbounds @simd ivdep for i in eachindex(u) + du[i] = h31 * k1[i] + h32 * k2[i] + h33 * k3[i] + h34 * k4[i] + h35 * k5[i] + end + copyat_or_push!(k, 2, copy(du)) + + #@inbounds @simd ivdep for i in eachindex(u) + # du[i] = h2_21 * k1[i] + h2_22 * k2[i] + h2_23 * k3[i] + h2_24 * k4[i] + h2_25 * k5[i] + #end + #copyat_or_push!(k, 3, copy(du)) + end + nothing +end + + function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache, always_calc_begin = false, allow_calc_end = true, force_calc_end = false) diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index c671634a02..ab6997a13f 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -748,7 +748,7 @@ end J = f.jac(uprev, p, t) if J isa StaticArray && integrator.alg isa - Union{Rosenbrock23, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P} + Union{Rosenbrock23, Rodas23W, Rodas3P,Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P} W = W_transform ? J - mass_matrix * inv(dtgamma) : dtgamma * J - mass_matrix else diff --git a/src/integrators/integrator_interface.jl b/src/integrators/integrator_interface.jl index f40fdc0f3c..8bbb691479 100644 --- a/src/integrators/integrator_interface.jl +++ b/src/integrators/integrator_interface.jl @@ -92,7 +92,7 @@ end else return if isdefined(integrator, :fsallast) && !(integrator.alg isa - Union{Rosenbrock23, Rosenbrock32, Rodas4, Rodas4P, Rodas4P2, Rodas5, + Union{Rosenbrock23, Rosenbrock32, Rodas23W, Rodas3P,Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P}) # Special stiff interpolations do not store the right value in fsallast out .= integrator.fsallast diff --git a/src/interp_func.jl b/src/interp_func.jl index 0f43051a87..6ff43bbabe 100644 --- a/src/interp_func.jl +++ b/src/interp_func.jl @@ -57,8 +57,8 @@ end function DiffEqBase.interp_summary(::Type{cacheType}, dense::Bool) where { cacheType <: - Union{Rodas4ConstantCache, - Rodas4Cache}} + Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache, + Rodas4Cache, Rodas23WCache, Rodas3PCache}} dense ? "specialized 3rd order \"free\" stiffness-aware interpolation" : "1st order linear" end diff --git a/src/perform_step/rosenbrock_perform_step.jl b/src/perform_step/rosenbrock_perform_step.jl index 2256c403ae..fa74f6d4c8 100644 --- a/src/perform_step/rosenbrock_perform_step.jl +++ b/src/perform_step/rosenbrock_perform_step.jl @@ -677,7 +677,7 @@ end repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator @unpack tf, uf = cache - @unpack a21, a31, a32, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab + @unpack a21, a31, a32, a41, a42, a43, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab # Precalculations dtC21 = C21 / dt @@ -731,6 +731,9 @@ end k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) integrator.stats.nsolve += 1 + u = uprev + a41 * k1 + a42 * k2 + a43 * k3 + du = f(u, p, t + dt) #-- c4 = 1 + integrator.stats.nf += 1 if mass_matrix === I linsolve_tmp = du + dtd4 * dT + dtC41 * k1 + dtC42 * k2 + dtC43 * k3 @@ -760,7 +763,7 @@ end @muladd function perform_step!(integrator, cache::Rosenbrock34Cache, repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator @unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, k4, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight = cache - @unpack a21, a31, a32, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab + @unpack a21, a31, a32, a41, a42, a43, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab # Assignments uidx = eachindex(integrator.uprev) @@ -841,6 +844,9 @@ end veck3 = _vec(k3) @.. broadcast=false veck3=-vecu integrator.stats.nsolve += 1 + @.. broadcast=false u=uprev + a41 * k1 + a42 * k2 + a43 * k3 + f(du, u, p, t + dt) #-- c4 = 1 + integrator.stats.nf += 1 if mass_matrix === I @.. broadcast=false linsolve_tmp=du + dtd4 * dT + dtC41 * k1 + dtC42 * k2 + @@ -884,6 +890,501 @@ end ################################################################################ +#### Rodas3P type method + +function initialize!(integrator, cache::Union{Rodas23WConstantCache, Rodas3PConstantCache}) + integrator.kshortsize = 3 + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) + # Avoid undefined entries if k is an array of arrays + integrator.k[1] = zero(integrator.u) + integrator.k[2] = zero(integrator.u) + integrator.k[3] = zero(integrator.u) +end + +@muladd function perform_step!(integrator, cache::Union{Rodas23WConstantCache, Rodas3PConstantCache}, repeat_step = false) + @unpack t, dt, uprev, u, f, p = integrator + @unpack tf, uf = cache + @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab + + # Precalculations + dtC21 = C21 / dt + dtC31 = C31 / dt + dtC32 = C32 / dt + dtC41 = C41 / dt + dtC42 = C42 / dt + dtC43 = C43 / dt + dtC51 = C51 / dt + dtC52 = C52 / dt + dtC53 = C53 / dt + dtC54 = C54 / dt + + dtd1 = dt * d1 + dtd2 = dt * d2 + dtd3 = dt * d3 + dtgamma = dt * gamma + + mass_matrix = integrator.f.mass_matrix + + # Time derivative + tf.u = uprev + dT = calc_tderivative(integrator, cache) + + W = calc_W(integrator, cache, dtgamma, repeat_step, true) + if !issuccess_W(W) + integrator.EEst = 2 + return nothing + end + + du = f(uprev, p, t) + integrator.stats.nf += 1 + k3 = copy(du) #-- save for stage 3 + + linsolve_tmp = du + dtd1 * dT + + k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) + integrator.stats.nsolve += 1 + u = uprev + a21 * k1 + du = f(u, p, t + c2 * dt) + integrator.stats.nf += 1 + + if mass_matrix === I + linsolve_tmp = du + dtd2 * dT + dtC21 * k1 + else + linsolve_tmp = du + dtd2 * dT + mass_matrix * (dtC21 * k1) + end + + k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) + integrator.stats.nsolve += 1 + + if mass_matrix === I + linsolve_tmp = k3 + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) + else + linsolve_tmp = k3 + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2) + end + + k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) + integrator.stats.nsolve += 1 + u = uprev + a41 * k1 + a42 * k2 + a43 * k3 + du = f(u, p, t + dt) + integrator.stats.nf += 1 + + if mass_matrix === I + linsolve_tmp = du + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) + else + linsolve_tmp = du + mass_matrix * (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) + end + + k4 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) + integrator.stats.nsolve += 1 + + if mass_matrix === I + linsolve_tmp = du + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) + else + linsolve_tmp = du + + mass_matrix * (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) + end + + k5 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) + integrator.stats.nsolve += 1 + du = u + k4 #-- solution p=2 + u = u + k5 #-- solution p=3 + + EEst = 0.0 + if integrator.opts.calck + @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab + integrator.k[1] = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 + integrator.k[2] = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 + integrator.k[3] = h2_21 * k1 + h2_22 * k2 + h2_23 * k3 + h2_24 * k4 + h2_25 * k5 + if integrator.opts.adaptive + calculate_interpoldiff!(k1, k2, uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3]) + atm = calculate_residuals!(k2, uprev, k1, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + EEst = max(EEst,integrator.opts.internalnorm(atmp, t)) #-- role of t unclear + end + end + + if (integrator.alg isa Rodas23W) + k1[:] = u[:]; u[:] = du[:]; du[:] = k1[:] + if integrator.opts.calck + integrator.k[1][:] = integrator.k[3][:] + integrator.k[2][:] .= 0.0 + end + end + + if integrator.opts.adaptive + atmp = calculate_residuals(u-du, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = max(EEst,integrator.opts.internalnorm(atmp, t)) + end + + integrator.u = u + return nothing +end + +function initialize!(integrator, cache::Union{Rodas23WCache, Rodas3PCache}) + integrator.kshortsize = 3 + @unpack dense1, dense2, dense3 = cache + resize!(integrator.k, integrator.kshortsize) + integrator.k[1] = dense1 + integrator.k[2] = dense2 + integrator.k[3] = dense3 +end + +@muladd function perform_step!(integrator, cache::Union{Rodas23WCache, Rodas3PCache}, repeat_step = false) + @unpack t, dt, uprev, u, f, p = integrator + @unpack du, du1, du2, dT, J, W, uf, tf, k1, k2, k3, k4, k5, linsolve_tmp, jac_config, atmp, weight = cache + @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab + + # Assignments + sizeu = size(u) + uidx = eachindex(integrator.uprev) + mass_matrix = integrator.f.mass_matrix + + # Precalculations + dtC21 = C21 / dt + dtC31 = C31 / dt + dtC32 = C32 / dt + dtC41 = C41 / dt + dtC42 = C42 / dt + dtC43 = C43 / dt + dtC51 = C51 / dt + dtC52 = C52 / dt + dtC53 = C53 / dt + dtC54 = C54 / dt + + dtd1 = dt * d1 + dtd2 = dt * d2 + dtd3 = dt * d3 + dtgamma = dt * gamma + + f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation! + integrator.stats.nf += 1 + + calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true) + + calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t) + + if repeat_step + linres = dolinsolve(integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), + du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, + solverdata = (; gamma = dtgamma)) + else + linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), + du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, + solverdata = (; gamma = dtgamma)) + end + + @.. broadcast=false $(_vec(k1))=-linres.u + + integrator.stats.nsolve += 1 + + @.. broadcast=false u=uprev + a21 * k1 + f(du, u, p, t + c2 * dt) + integrator.stats.nf += 1 + + if mass_matrix === I + @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 + else + @.. broadcast=false du1=dtC21 * k1 + mul!(_vec(du2), mass_matrix, _vec(du1)) + @.. broadcast=false linsolve_tmp=du + dtd2 * dT + du2 + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + @.. broadcast=false $(_vec(k2))=-linres.u + integrator.stats.nsolve += 1 + + if mass_matrix === I + @.. broadcast=false linsolve_tmp = cache.fsalfirst + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) + else + @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 + mul!(_vec(du2), mass_matrix, _vec(du1)) + @.. broadcast=false linsolve_tmp = cache.fsalfirst + dtd3 * dT + du2 + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + @.. broadcast=false $(_vec(k3))=-linres.u + integrator.stats.nsolve += 1 + + @.. broadcast=false u=uprev + a41 * k1 + a42 * k2 + a43 * k3 + f(du, u, p, t + dt) + integrator.stats.nf += 1 + + if mass_matrix === I + @.. broadcast=false linsolve_tmp=du + + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) + else + @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 + mul!(_vec(du2), mass_matrix, _vec(du1)) + @.. broadcast=false linsolve_tmp=du + du2 + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + @.. broadcast=false $(_vec(k4))=-linres.u + integrator.stats.nsolve += 1 + + if mass_matrix === I + @.. broadcast=false linsolve_tmp=du + + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) + else + @.. broadcast=false du1=dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3 + mul!(_vec(du2), mass_matrix, _vec(du1)) + @.. broadcast=false linsolve_tmp=du + du2 + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + @.. broadcast=false $(_vec(k5))=-linres.u + integrator.stats.nsolve += 1 + + du = u + k4 #-- p=2 solution + u .+= k5 + + EEst = 0.0 + if integrator.opts.calck + @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab + @.. broadcast=false integrator.k[1]=h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + + h25 * k5 + @.. broadcast=false integrator.k[2]=h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + + h35 * k5 + @.. broadcast=false integrator.k[3]=h2_21 * k1 + h2_22 * k2 + h2_23 * k3 + h2_24 * k4 + h2_25 * k5 + if integrator.opts.adaptive + calculate_interpoldiff!(du1, du2, uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3]) + calculate_residuals!(atmp, du2, uprev, du1, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + EEst = max(EEst,integrator.opts.internalnorm(atmp, t)) #-- role of t unclear + end + end + + if (integrator.alg isa Rodas23W) + du1[:] = u[:]; u[:] = du[:]; du[:] = du1[:] + if integrator.opts.calck + integrator.k[1][:] = integrator.k[3][:] + integrator.k[2][:] .= 0.0 + end + end + + if integrator.opts.adaptive + calculate_residuals!(atmp, u-du, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = max(EEst,integrator.opts.internalnorm(atmp, t)) + end + cache.linsolve = linres.cache +end + +@muladd function perform_step!(integrator, cache::Union{Rodas23WCache{<:Array}, Rodas3PCache{<:Array}}, repeat_step = false) + @unpack t, dt, uprev, u, f, p = integrator + @unpack du, du1, du2, dT, J, W, uf, tf, k1, k2, k3, k4, k5, linsolve_tmp, jac_config, atmp, weight = cache + @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab + + # Assignments + sizeu = size(u) + uidx = eachindex(integrator.uprev) + mass_matrix = integrator.f.mass_matrix + + # Precalculations + dtC21 = C21 / dt + dtC31 = C31 / dt + dtC32 = C32 / dt + dtC41 = C41 / dt + dtC42 = C42 / dt + dtC43 = C43 / dt + dtC51 = C51 / dt + dtC52 = C52 / dt + dtC53 = C53 / dt + dtC54 = C54 / dt + + dtd1 = dt * d1 + dtd2 = dt * d2 + dtd3 = dt * d3 + dtgamma = dt * gamma + + f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation! + integrator.stats.nf += 1 + + calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true) + + calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t) + + if repeat_step + linres = dolinsolve(integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), + du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, + solverdata = (; gamma = dtgamma)) + else + linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), + du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, + solverdata = (; gamma = dtgamma)) + end + + @inbounds @simd ivdep for i in eachindex(u) + k1[i] = -linres.u[i] + end + integrator.stats.nsolve += 1 + + @inbounds @simd ivdep for i in eachindex(u) + u[i] = uprev[i] + a21 * k1[i] + end + f(du, u, p, t + c2 * dt) + integrator.stats.nf += 1 + + if mass_matrix === I + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + dtd2 * dT[i] + dtC21 * k1[i] + end + else + @inbounds @simd ivdep for i in eachindex(u) + du1[i] = dtC21 * k1[i] + end + mul!(_vec(du2), mass_matrix, _vec(du1)) + + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + dtd2 * dT[i] + du2[i] + end + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + @inbounds @simd ivdep for i in eachindex(u) + k2[i] = -linres.u[i] + end + integrator.stats.nsolve += 1 + + if mass_matrix === I + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = cache.fsalfirst[i] + dtd3 * dT[i] + (dtC31 * k1[i] + dtC32 * k2[i]) + end + else + @inbounds @simd ivdep for i in eachindex(u) + du1[i] = dtC31 * k1[i] + dtC32 * k2[i] + end + mul!(_vec(du2), mass_matrix, _vec(du1)) + + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = cache.fsalfirst[i] + dtd3 * dT[i] + du2[i] + end + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + @inbounds @simd ivdep for i in eachindex(u) + k3[i] = -linres.u[i] + end + integrator.stats.nsolve += 1 + + @inbounds @simd ivdep for i in eachindex(u) + u[i] = uprev[i] + a41 * k1[i] + a42 * k2[i] + a43 * k3[i] + end + f(du, u, p, t + dt) + integrator.stats.nf += 1 + + if mass_matrix === I + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + + (dtC41 * k1[i] + dtC42 * k2[i] + dtC43 * k3[i]) + end + else + @inbounds @simd ivdep for i in eachindex(u) + du1[i] = dtC41 * k1[i] + dtC42 * k2[i] + dtC43 * k3[i] + end + mul!(_vec(du2), mass_matrix, _vec(du1)) + + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + du2[i] + end + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + @inbounds @simd ivdep for i in eachindex(u) + k4[i] = -linres.u[i] + end + integrator.stats.nsolve += 1 + + if mass_matrix === I + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + (dtC52 * k2[i] + dtC54 * k4[i] + dtC51 * k1[i] + + dtC53 * k3[i]) + end + else + @inbounds @simd ivdep for i in eachindex(u) + du1[i] = dtC52 * k2[i] + dtC54 * k4[i] + dtC51 * k1[i] + dtC53 * k3[i] + end + mul!(_vec(du2), mass_matrix, _vec(du1)) + + @inbounds @simd ivdep for i in eachindex(u) + linsolve_tmp[i] = du[i] + du2[i] + end + end + + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + @inbounds @simd ivdep for i in eachindex(u) + k5[i] = -linres.u[i] + end + integrator.stats.nsolve += 1 + + @inbounds @simd ivdep for i in eachindex(u) #-- sol p=2 + du[i] = u[i] + k4[i] + end + @inbounds @simd ivdep for i in eachindex(u) + u[i] += k5[i] + end + + EEst = 0.0 + if integrator.opts.calck + @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab + @inbounds @simd ivdep for i in eachindex(u) + integrator.k[1][i] = h21 * k1[i] + h22 * k2[i] + h23 * k3[i] + h24 * k4[i] + h25 * k5[i] + integrator.k[2][i] = h31 * k1[i] + h32 * k2[i] + h33 * k3[i] + h34 * k4[i] + h35 * k5[i] + integrator.k[3][i] = h2_21 * k1[i] + h2_22 * k2[i] + h2_23 * k3[i] + h2_24 * k4[i] + h2_25 * k5[i] + end + if integrator.opts.adaptive + calculate_interpoldiff!(du1, du2, uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3]) + calculate_residuals!(atmp, du2, uprev, du1, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + EEst = max(EEst,integrator.opts.internalnorm(atmp, t)) #-- role of t unclear + #println(t," ",EEst," ",du2) + end + end + + if (integrator.alg isa Rodas23W) + @inbounds @simd ivdep for i in eachindex(u) + tt = u[i]; u[i] = du[i]; du[i] = tt + if integrator.opts.calck + integrator.k[1][i] = integrator.k[3][i]; + integrator.k[2][i] = 0.0 + end + end + end + if integrator.opts.adaptive + calculate_residuals!(atmp, u-du, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = max(EEst,integrator.opts.internalnorm(atmp, t)) + #println(t," ",EEst," ",integrator.EEst) + end + cache.linsolve = linres.cache +end + +function calculate_interpoldiff!(u_int, u_diff, uprev, up2, up3, c_koeff, d_koeff, c2_koeff) + for i in eachindex(up2) + a1 = up3[i] + c_koeff[i] - up2[i] - c2_koeff[i]; a2 = d_koeff[i] - c_koeff[i] + c2_koeff[i]; a3 = -d_koeff[i] + dis = a2^2 - 3*a1*a3 + u_int[i] = up3[i]; u_diff[i] = 0.0 + if dis > 0.0 #-- Min/Max occurs + tau1 = (-a2 - sqrt(dis))/(3*a3); tau2 = (-a2 + sqrt(dis))/(3*a3) + if tau1 > tau2 tau1,tau2 = tau2,tau1; end + for tau in (tau1,tau2) + if (tau > 0.0) && (tau < 1.0) + y_tau = (1 - tau)*uprev[i] + tau*(up3[i] + (1 - tau)*(c_koeff[i] + tau*d_koeff[i])) + dy_tau = ((a3*tau + a2)*tau + a1)*tau + if abs(dy_tau) > abs(u_diff[i]) + u_diff[i] = dy_tau; u_int[i] = y_tau + end + end + end + end + end +end + #### Rodas4 type method function initialize!(integrator, cache::Rodas4ConstantCache) diff --git a/src/tableaus/rosenbrock_tableaus.jl b/src/tableaus/rosenbrock_tableaus.jl index f615a5f559..92f7546bf5 100644 --- a/src/tableaus/rosenbrock_tableaus.jl +++ b/src/tableaus/rosenbrock_tableaus.jl @@ -133,6 +133,88 @@ function Rodas3Tableau(T, T2) b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4) end +struct Rodas3PTableau{T, T2} + a21::T + a41::T + a42::T + a43::T + C21::T + C31::T + C32::T + C41::T + C42::T + C43::T + C51::T + C52::T + C53::T + C54::T + gamma::T + c2::T2 + c3::T2 + d1::T + d2::T + d3::T + h21::T + h22::T + h23::T + h24::T + h25::T + h31::T + h32::T + h33::T + h34::T + h35::T + h2_21::T + h2_22::T + h2_23::T + h2_24::T + h2_25::T +end + +function Rodas3PTableau(T, T2) + gamma = convert(T, 1 // 3) + a21 = convert(T, 4.0/3.0) + a41 = convert(T, 2.90625) + a42 = convert(T, 3.375) + a43 = convert(T, 0.40625) + C21 = -convert(T, 4.0) + C31 = convert(T, 8.25) + C32 = convert(T, 6.75) + C41 = convert(T, 1.21875) + C42 = -convert(T, 5.0625) + C43 = -convert(T, 1.96875) + C51 = convert(T, 4.03125) + C52 = -convert(T, 15.1875) + C53 = -convert(T, 4.03125) + C54 = convert(T, 6.0) + c2 = convert(T2, 4.0/9.0) + c3 = convert(T2, 0.0) + d1 = convert(T, 1.0/3.0) + d2 = -convert(T, 1.0/9.0) + d3 = convert(T, 1.0) + h21 = convert(T, 1.78125) + h22 = convert(T, 6.75) + h23 = convert(T, 0.15625) + h24 = -convert(T, 6.0) + h25 = -convert(T, 1.0) + h31 = convert(T, 4.21875) + h32 = -convert(T, 15.1875) + h33 = -convert(T, 3.09375) + h34 = convert(T, 9.0) + h35 = convert(T, 0.0) + h2_21 = convert(T, 4.21875) + h2_22 = -convert(T, 2.025) + h2_23 = -convert(T, 1.63125) + h2_24 = -convert(T, 1.7) + h2_25 = -convert(T, 0.1) + Rodas3PTableau(a21, a41, a42, a43, + C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, + gamma, c2, c3, d1, d2, d3, + h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25) +end + + + @ROS34PW(:tableau) @Rosenbrock4(:tableau) From 88e3062c57c2f06921e104660213d4b50646a8e4 Mon Sep 17 00:00:00 2001 From: Gerd Steinebach <64948537+gstein3m@users.noreply.github.com> Date: Wed, 3 Jan 2024 14:23:37 +0100 Subject: [PATCH 2/2] Update rosenbrock_interpolants.jl --- src/dense/rosenbrock_interpolants.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dense/rosenbrock_interpolants.jl b/src/dense/rosenbrock_interpolants.jl index 2974f64640..2e73d536bc 100644 --- a/src/dense/rosenbrock_interpolants.jl +++ b/src/dense/rosenbrock_interpolants.jl @@ -2,6 +2,8 @@ ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, Rosenbrock32ConstantCache, Rosenbrock32Cache, + Rodas23WConstantCache, Rodas3PConstantCache, + Rodas23WCache, Rodas3PCache, Rodas4ConstantCache, Rosenbrock5ConstantCache, Rodas4Cache, Rosenbrock5Cache}