diff --git a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl index 652b7ea6af..07f4bbeee7 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl @@ -11,7 +11,7 @@ end function DiffEqBase.interp_summary(::Type{cacheType}, dense::Bool) where { cacheType <: - Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache, + Union{RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache, RosenbrockCache, Rodas23WCache, Rodas3PCache}} dense ? "specialized 3rd order \"free\" stiffness-aware interpolation" : "1st order linear" @@ -20,8 +20,8 @@ end function DiffEqBase.interp_summary(::Type{cacheType}, dense::Bool) where { cacheType <: - Union{Rosenbrock5ConstantCache, - Rosenbrock5Cache}} + Union{RosenbrockCombinedConstantCache, + RosenbrockCache}} dense ? "specialized 4rd order \"free\" stiffness-aware interpolation" : "1st order linear" end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 5d6abd9f42..33fc5dcd2b 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -41,12 +41,24 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT alg::A step_limiter!::StepLimiter stage_limiter!::StageLimiter + interp_order::Int end function full_cache(c::RosenbrockCache) return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2, c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp] end +struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache + tf::TF + uf::UF + tab::Tab + J::JType + W::WType + linsolve::F + autodiff::AD + interp_order::Int +end + @cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType, TabType, TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache @@ -702,22 +714,16 @@ end ### Rodas4 methods -struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache - tf::TF - uf::UF - tab::Tab - J::JType - W::WType - linsolve::F - autodiff::AD -end - tabtype(::Rodas4) = Rodas4Tableau tabtype(::Rodas42) = Rodas42Tableau tabtype(::Rodas4P) = Rodas4PTableau tabtype(::Rodas4P2) = Rodas4P2Tableau +tabtype(::Rodas5) = Rodas5Tableau +tabtype(::Rodas5P) = Rodas5PTableau +tabtype(::Rodas5Pr) = Rodas5PTableau +tabtype(::Rodas5Pe) = Rodas5PTableau -function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2}, +function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, @@ -727,21 +733,22 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2}, J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rodas4ConstantCache(tf, uf, - tabtype(alg)(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve, - alg_autodiff(alg)) + tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + RosenbrockCombinedConstantCache(tf, uf, + tab, J, W, linsolve, + alg_autodiff(alg), size(tab.H, 1)) end -function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2}, +function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) # Initialize vectors - dense = [zero(rate_prototype) for _ in 1:2] - ks = [zero(rate_prototype) for _ in 1:6] + dense = [zero(rate_prototype) for _ in 1:size(tab.H, 1)] + ks = [zero(rate_prototype) for _ in 1:size(tab.A, 1)] du = zero(rate_prototype) du1 = zero(rate_prototype) du2 = zero(rate_prototype) @@ -760,7 +767,6 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2}, recursivefill!(atmp, false) weight = similar(u, uEltypeNoUnits) recursivefill!(weight, false) - tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) @@ -783,190 +789,9 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2}, u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config, grad_config, reltol, alg, - alg.step_limiter!, alg.stage_limiter!) -end - -################################################################################ - -### Rosenbrock5 - -struct Rosenbrock5ConstantCache{TF, UF, Tab, JType, WType, F} <: RosenbrockConstantCache - tf::TF - uf::UF - tab::Tab - J::JType - W::WType - linsolve::F -end - -@cache mutable struct Rosenbrock5Cache{ - uType, rateType, uNoUnitsType, JType, WType, TabType, - TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <: - RosenbrockMutableCache - u::uType - uprev::uType - dense1::rateType - dense2::rateType - dense3::rateType - du::rateType - du1::rateType - du2::rateType - k1::rateType - k2::rateType - k3::rateType - k4::rateType - k5::rateType - k6::rateType - k7::rateType - k8::rateType - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - reltol::RTolType - alg::A - step_limiter!::StepLimiter - stage_limiter!::StageLimiter + alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1)) end -function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - dense1 = zero(rate_prototype) - dense2 = zero(rate_prototype) - dense3 = zero(rate_prototype) - du = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - k5 = zero(rate_prototype) - k6 = zero(rate_prototype) - k7 = zero(rate_prototype) - k8 = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rodas5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, - k5, k6, k7, k8, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, - linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, - alg.stage_limiter!) -end - -function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rosenbrock5ConstantCache(tf, uf, - Rodas5Tableau(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve) -end - -function alg_cache( - alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - dense1 = zero(rate_prototype) - dense2 = zero(rate_prototype) - dense3 = zero(rate_prototype) - du = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - k5 = zero(rate_prototype) - k6 = zero(rate_prototype) - k7 = zero(rate_prototype) - k8 = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rodas5PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, - k5, k6, k7, k8, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, - linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, - alg.stage_limiter!) -end - -function alg_cache( - alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rosenbrock5ConstantCache(tf, uf, - Rodas5PTableau(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve) -end function get_fsalfirstlast( cache::Union{Rosenbrock23Cache, Rosenbrock32Cache, Rosenbrock33Cache, diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl index 4d7e659410..003595567e 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl @@ -1,11 +1,10 @@ ### Fallbacks to capture - ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, Rosenbrock32ConstantCache, Rosenbrock32Cache, Rodas23WConstantCache, Rodas3PConstantCache, Rodas23WCache, Rodas3PCache, - Rodas4ConstantCache, Rosenbrock5ConstantCache, - RosenbrockCache, Rosenbrock5Cache} + RosenbrockCombinedConstantCache, + RosenbrockCache} function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::ROSENBROCKS_WITH_INTERPOLATIONS, @@ -50,7 +49,7 @@ end cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, idxs::Nothing, T::Type{Val{0}}, differential_vars) @rosenbrock2332pre0 - @inbounds @.. broadcast=false y₀+dt * (c1 * k[1] + c2 * k[2]) + @inbounds @.. y₀+dt * (c1 * k[1] + c2 * k[2]) end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, @@ -58,7 +57,7 @@ end Rosenbrock32ConstantCache, Rosenbrock32Cache }, idxs, T::Type{Val{0}}, differential_vars) @rosenbrock2332pre0 - @.. broadcast=false y₀[idxs]+dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) + @.. y₀[idxs]+dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, @@ -67,7 +66,7 @@ end Rosenbrock32ConstantCache, Rosenbrock32Cache }, idxs::Nothing, T::Type{Val{0}}, differential_vars) @rosenbrock2332pre0 - @inbounds @.. broadcast=false out=y₀ + dt * (c1 * k[1] + c2 * k[2]) + @inbounds @.. out=y₀ + dt * (c1 * k[1] + c2 * k[2]) out end @@ -77,7 +76,7 @@ end Rosenbrock32ConstantCache, Rosenbrock32Cache }, idxs, T::Type{Val{0}}, differential_vars) @rosenbrock2332pre0 - @views @.. broadcast=false out=y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) + @views @.. out=y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) out end @@ -93,7 +92,7 @@ end Rosenbrock32ConstantCache, Rosenbrock32Cache }, idxs::Nothing, T::Type{Val{1}}, differential_vars) @rosenbrock2332pre1 - @.. broadcast=false c1diff * k[1]+c2diff * k[2] + @.. c1diff * k[1]+c2diff * k[2] end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, @@ -101,7 +100,7 @@ end Rosenbrock32ConstantCache, Rosenbrock32Cache }, idxs, T::Type{Val{1}}, differential_vars) @rosenbrock2332pre1 - @.. broadcast=false c1diff * k[1][idxs]+c2diff * k[2][idxs] + @.. c1diff * k[1][idxs]+c2diff * k[2][idxs] end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, @@ -110,7 +109,7 @@ end Rosenbrock32ConstantCache, Rosenbrock32Cache }, idxs::Nothing, T::Type{Val{1}}, differential_vars) @rosenbrock2332pre1 - @.. broadcast=false out=c1diff * k[1] + c2diff * k[2] + @.. out=c1diff * k[1] + c2diff * k[2] out end @@ -120,204 +119,143 @@ end Rosenbrock32ConstantCache, Rosenbrock32Cache }, idxs, T::Type{Val{1}}, differential_vars) @rosenbrock2332pre1 - @views @.. broadcast=false out=c1diff * k[1][idxs] + c2diff * k[2][idxs] + @views @.. out=c1diff * k[1][idxs] + c2diff * k[2][idxs] out end """ From MATLAB ODE Suite by Shampine """ -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache}, - idxs::Nothing, T::Type{Val{0}}, differential_vars) - Θ1 = 1 - Θ - @inbounds Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * k[2])) -end @muladd function _ode_interpolant( - Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCache, Rodas23WCache, Rodas3PCache}, + Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache, RosenbrockCache, Rodas23WCache, Rodas3PCache}, idxs::Nothing, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ - @inbounds @.. broadcast=false Θ1 * y₀+Θ * (y₁ + Θ1 * (k[1] + Θ * k[2])) + if !hasproperty(cache, :interp_order) || cache.interp_order == 2 + @.. Θ1 * y₀+Θ * (y₁ + Θ1 * (k[1] + Θ * k[2])) + else + @.. Θ1 * y₀+Θ * (y₁ + Θ1 * (k[1] + Θ * (k[2] + Θ * k[3]))) + end end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ - @.. broadcast=false Θ1 * y₀[idxs]+Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs])) + if !hasproperty(cache, :interp_order) || cache.interp_order == 2 + @views @.. Θ1 * y₀[idxs]+Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs])) + else + @views @.. Θ1 * y₀[idxs]+Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * (k[2][idxs] + Θ * k[3][idxs]))) + end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs::Nothing, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ - @.. broadcast=false out=Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * k[2])) + if !hasproperty(cache, :interp_order) || cache.interp_order == 2 + @.. out=Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * k[2])) + else + @.. out=Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * (k[2] + Θ * k[3]))) + end out end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ - @views @.. broadcast=false out=Θ1 * y₀[idxs] + - Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs])) + if !hasproperty(cache, :interp_order) || cache.interp_order == 2 + @views @.. out=Θ1 * y₀[idxs] + Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs])) + else + @views @.. out=Θ1 * y₀[idxs]+Θ * (y₁[idxs] + + Θ1 * (k[1][idxs] + Θ * (k[2][idxs] + Θ * k[3][idxs]))) + end out end # First Derivative -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache}, - idxs::Nothing, T::Type{Val{1}}, differential_vars) - @inbounds (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) / dt -end - @muladd function _ode_interpolant( - Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCache, Rodas23WCache, Rodas3PCache}, + Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCache, Rodas23WCache, Rodas3PCache, RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache}, idxs::Nothing, T::Type{Val{1}}, differential_vars) - @inbounds @.. broadcast=false (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + - y₁)/dt + if !hasproperty(cache, :interp_order) || cache.interp_order == 2 + @.. (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁)/dt + else + @.. (k[1] + Θ * (-2 * k[1] + 2 * k[2] + + Θ * (-3 * k[2] + 3 * k[3] - 4 * Θ * k[3])) - + y₀ + y₁)/dt + end end - @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs, T::Type{Val{1}}, differential_vars) - @.. broadcast=false (k[1][idxs] + - Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] - 3 * k[2][idxs] * Θ) - - y₀[idxs] + y₁[idxs])/dt + if !hasproperty(cache, :interp_order) || cache.interp_order == 2 + @views @.. (k[1][idxs] + Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] - 3 * k[2][idxs] * Θ) - + y₀[idxs] + y₁[idxs])/dt + else + @views @.. (k[1][idxs] + Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] + + Θ * (-3 * k[2][idxs] + 3 * k[3][idxs] - 4 * Θ * k[3][idxs])) - y₀[idxs] + y₁[idxs])/dt + end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs::Nothing, T::Type{Val{1}}, differential_vars) - @.. broadcast=false out=(k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) / - dt + if !hasproperty(cache, :interp_order) || cache.interp_order == 2 + @.. out=(k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) / dt + else + @.. out=(k[1] + Θ * (-2 * k[1] + 2 * k[2] + + Θ * (-3 * k[2] + 3 * k[3] - 4 * Θ * k[3])) - y₀ + y₁) / dt + end out end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs, T::Type{Val{1}}, differential_vars) - @views @.. broadcast=false out=(k[1][idxs] + - Θ * - (-2 * k[1][idxs] + 2 * k[2][idxs] - - 3 * k[2][idxs] * Θ) - - y₀[idxs] + y₁[idxs]) / dt - out -end - -#- - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rosenbrock5ConstantCache, - idxs::Nothing, T::Type{Val{0}}, differential_vars) - Θ1 = 1 - Θ - @inbounds Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * (k[2] + Θ * k[3]))) -end - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rosenbrock5Cache, idxs::Nothing, - T::Type{Val{0}}, differential_vars) - Θ1 = 1 - Θ - @inbounds @.. broadcast=false Θ1 * y₀+Θ * (y₁ + Θ1 * (k[1] + Θ * (k[2] + Θ * k[3]))) -end - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, - idxs, T::Type{Val{0}}, differential_vars) - Θ1 = 1 - Θ - @.. broadcast=false Θ1 * - y₀[idxs]+Θ * (y₁[idxs] + - Θ1 * (k[1][idxs] + Θ * (k[2][idxs] + Θ * k[3][idxs]))) -end - -@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, - idxs::Nothing, T::Type{Val{0}}, differential_vars) - Θ1 = 1 - Θ - @.. broadcast=false out=Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * (k[2] + Θ * k[3]))) - out -end - -@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, - idxs, T::Type{Val{0}}, differential_vars) - Θ1 = 1 - Θ - @views @.. broadcast=false out=Θ1 * y₀[idxs] + - Θ * (y₁[idxs] + - Θ1 * (k[1][idxs] + Θ * (k[2][idxs] + Θ * k[3][idxs]))) - out -end - -# First Derivative -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rosenbrock5ConstantCache, - idxs::Nothing, T::Type{Val{1}}, differential_vars) - @inbounds (k[1] + - Θ * (-2 * k[1] + 2 * k[2] + Θ * (-3 * k[2] + 3 * k[3] - 4 * Θ * k[3])) - y₀ + - y₁) / dt -end - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rosenbrock5Cache, idxs::Nothing, - T::Type{Val{1}}, differential_vars) - @inbounds @.. broadcast=false (k[1] + - Θ * (-2 * k[1] + 2 * k[2] + - Θ * (-3 * k[2] + 3 * k[3] - 4 * Θ * k[3])) - y₀ + y₁)/dt -end - -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, - idxs, T::Type{Val{1}}, differential_vars) - @.. broadcast=false (k[1][idxs] + - Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] + - Θ * (-3 * k[2][idxs] + 3 * k[3][idxs] - 4 * Θ * k[3][idxs])) - - y₀[idxs] + y₁[idxs])/dt -end - -@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, - idxs::Nothing, T::Type{Val{1}}, differential_vars) - @.. broadcast=false out=(k[1] + - Θ * (-2 * k[1] + 2 * k[2] + - Θ * (-3 * k[2] + 3 * k[3] - 4 * Θ * k[3])) - y₀ + y₁) / dt - out -end - -@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, - idxs, T::Type{Val{1}}, differential_vars) - @views @.. broadcast=false out=(k[1][idxs] + - Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] + - Θ * - (-3 * k[2][idxs] + 3 * k[3][idxs] - 4 * Θ * k[3][idxs])) - - y₀[idxs] + y₁[idxs]) / dt + if !hasproperty(cache, :interp_order) || cache.interp_order == 2 + @views @.. out=(k[1][idxs] + + Θ * + (-2 * k[1][idxs] + 2 * k[2][idxs] - + 3 * k[2][idxs] * Θ) - + y₀[idxs] + y₁[idxs]) / dt + else + @views @.. broadcast=false out=(k[1][idxs] + + Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] + + Θ * + (-3 * k[2][idxs] + 3 * k[3][idxs] - 4 * Θ * k[3][idxs])) - + y₀[idxs] + y₁[idxs]) / dt + end out end # Second Derivative -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rosenbrock5ConstantCache, +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::RosenbrockCombinedConstantCache, idxs::Nothing, T::Type{Val{2}}, differential_vars) @inbounds (-2 * k[1] + 2 * k[2] + Θ * (-6 * k[2] + 6 * k[3] - 12 * Θ * k[3])) / dt^2 end -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rosenbrock5Cache, idxs::Nothing, +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::RosenbrockCache, idxs::Nothing, T::Type{Val{2}}, differential_vars) @inbounds @.. broadcast=false (-2 * k[1] + 2 * k[2] + Θ * (-6 * k[2] + 6 * k[3] - 12 * Θ * k[3]))/dt^2 end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs, T::Type{Val{2}}, differential_vars) @.. broadcast=false (-2 * k[1][idxs] + 2 * k[2][idxs] + Θ * (-6 * k[2][idxs] + 6 * k[3][idxs] - 12 * Θ * k[3][idxs]))/dt^2 end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs::Nothing, T::Type{Val{2}}, differential_vars) @.. broadcast=false out=(-2 * k[1] + 2 * k[2] + Θ * (-6 * k[2] + 6 * k[3] - 12 * Θ * k[3])) / dt^2 @@ -325,7 +263,7 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs, T::Type{Val{2}}, differential_vars) @views @.. broadcast=false out=(-2 * k[1][idxs] + 2 * k[2][idxs] + Θ * @@ -335,31 +273,31 @@ end end # Third Derivative -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rosenbrock5ConstantCache, +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::RosenbrockCombinedConstantCache, idxs::Nothing, T::Type{Val{3}}, differential_vars) @inbounds (-6 * k[2] + 6 * k[3] - 24 * Θ * k[3]) / dt^3 end -@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rosenbrock5Cache, idxs::Nothing, +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::RosenbrockCache, idxs::Nothing, T::Type{Val{3}}, differential_vars) @inbounds @.. broadcast=false (-6 * k[2] + 6 * k[3] - 24 * Θ * k[3])/dt^3 end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs, T::Type{Val{3}}, differential_vars) @.. broadcast=false (-6 * k[2][idxs] + 6 * k[3][idxs] - 24 * Θ * k[3][idxs])/dt^3 end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs::Nothing, T::Type{Val{3}}, differential_vars) @.. broadcast=false out=(-6 * k[2] + 6 * k[3] - 24 * Θ * k[3]) / dt^3 out end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache}, + cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache}, idxs, T::Type{Val{3}}, differential_vars) @views @.. broadcast=false out=(-6 * k[2][idxs] + 6 * k[3][idxs] - 24 * Θ * k[3][idxs]) / dt^3 diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 06e1a96e39..1414a93f92 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -1202,18 +1202,19 @@ end #### Rodas4 type method -function initialize!(integrator, cache::Rodas4ConstantCache) - integrator.kshortsize = 2 +function initialize!(integrator, cache::RosenbrockCombinedConstantCache) + integrator.kshortsize = size(cache.tab.H, 1) integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) # Avoid undefined entries if k is an array of arrays - integrator.k[1] = zero(integrator.u) - integrator.k[2] = zero(integrator.u) + for i in 1:integrator.kshortsize + integrator.k[i] = zero(integrator.u) + end end -@muladd function perform_step!(integrator, cache::Rodas4ConstantCache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; tf, uf) = cache - (; A, C, gamma, c, d, H) = cache.tab +@muladd function perform_step!(integrator, cache::RosenbrockCombinedConstantCache, repeat_step = false) + (;t, dt, uprev, u, f, p) = integrator + (;tf, uf) = cache + (;A, C, gamma, c, d, H) = cache.tab # Precalculations dtC = C ./ dt @@ -1235,6 +1236,7 @@ end # Initialize ks num_stages = size(A, 1) du = f(uprev, p, t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) linsolve_tmp = @.. du + dtd[1] * dT k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) # constant number for type stability make sure this is greater than num_stages @@ -1284,6 +1286,21 @@ end integrator.k[j] = @.. integrator.k[j] + H[j, i] * ks[i] end end + if (integrator.alg isa Rodas5Pr) && integrator.opts.adaptive && + (integrator.EEst < 1.0) + k2 = 0.5 * (uprev + u + + 0.5 * (integrator.k[1] + 0.5 * (integrator.k[2] + 0.5 * integrator.k[3]))) + du1 = (0.25 * (integrator.k[2] + integrator.k[3]) - uprev + u) / dt + du = f(k2, p, t + dt / 2) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + if mass_matrix === I + du2 = du1 - du + else + du2 = mass_matrix * du1 - du + end + EEst = norm(du2) / norm(integrator.opts.abstol .+ integrator.opts.reltol .* k2) + integrator.EEst = max(EEst, integrator.EEst) + end end integrator.u = u @@ -1291,12 +1308,11 @@ end end function initialize!(integrator, cache::RosenbrockCache) - dense = cache.dense - dense1, dense2 = dense[1], dense[2] - integrator.kshortsize = 2 + integrator.kshortsize = size(cache.tab.H, 1) resize!(integrator.k, integrator.kshortsize) - integrator.k[1] = dense1 - integrator.k[2] = dense2 + for i in 1:integrator.kshortsize + integrator.k[i] = cache.dense[i] + end end @muladd function perform_step!(integrator, cache::RosenbrockCache, repeat_step = false) @@ -1365,11 +1381,18 @@ end @.. $(_vec(ks[stage])) = -linres.u integrator.stats.nsolve += 1 end + du .= ks[end] u .+= ks[end] step_limiter!(u, integrator, p, t + dt) if integrator.opts.adaptive + if (integrator.alg isa Rodas5Pe) + @.. du = 0.2606326497975715 * ks[1] - 0.005158627295444251 * ks[2] + + 1.3038988631109731 * ks[3] + 1.235000722062074 * ks[4] + + -0.7931985603795049 * ks[5] - 1.005448461135913 * ks[6] - + 0.18044626132120234 * ks[7] + 0.17051519239113755 * ks[8] + end calculate_residuals!(atmp, ks[end], uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) integrator.EEst = integrator.opts.internalnorm(atmp, t) @@ -1384,481 +1407,22 @@ end @.. integrator.k[j] += H[j, i] * ks[i] end end - end - cache.linsolve = linres.cache -end - -############################################################################### - -### Rodas5 Method - -function initialize!(integrator, cache::Rosenbrock5ConstantCache) - integrator.kshortsize = 3 - integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) - # Avoid undefined entries if k is an array of arrays - integrator.k[1] = zero(integrator.u) - integrator.k[2] = zero(integrator.u) - integrator.k[3] = zero(integrator.u) -end - -@muladd function perform_step!(integrator, cache::Rosenbrock5ConstantCache, - repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack tf, uf = cache - @unpack a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, C71, C72, C73, C74, C75, C76, C81, C82, C83, C84, C85, C86, C87, gamma, d1, d2, d3, d4, d5, c2, c3, c4, c5 = cache.tab - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - dtC61 = C61 / dt - dtC62 = C62 / dt - dtC63 = C63 / dt - dtC64 = C64 / dt - dtC65 = C65 / dt - dtC71 = C71 / dt - dtC72 = C72 / dt - dtC73 = C73 / dt - dtC74 = C74 / dt - dtC75 = C75 / dt - dtC76 = C76 / dt - dtC81 = C81 / dt - dtC82 = C82 / dt - dtC83 = C83 / dt - dtC84 = C84 / dt - dtC85 = C85 / dt - dtC86 = C86 / dt - dtC87 = C87 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 - dtd5 = dt * d5 - dtgamma = dt * gamma - - mass_matrix = integrator.f.mass_matrix - - # Time derivative - dT = calc_tderivative(integrator, cache) - - W = calc_W(integrator, cache, dtgamma, repeat_step) - if !issuccess_W(W) - integrator.EEst = 2 - return nothing - end - - du1 = f(uprev, p, t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - linsolve_tmp = du1 + dtd1 * dT - - k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a21 * k1 - du = f(u, p, t + c2 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd2 * dT + dtC21 * k1 - else - linsolve_tmp = du + dtd2 * dT + mass_matrix * (dtC21 * k1) - end - - k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a31 * k1 + a32 * k2 - du = f(u, p, t + c3 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - else - linsolve_tmp = du + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2) - end - - k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a41 * k1 + a42 * k2 + a43 * k3 - du = f(u, p, t + c4 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd4 * dT + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - else - linsolve_tmp = du + dtd4 * dT + mass_matrix * (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - end - - k4 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4 - du = f(u, p, t + c5 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd5 * dT + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - else - linsolve_tmp = du + dtd5 * dT + - mass_matrix * (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - end - - k5 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5 - du = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + (dtC61 * k1 + dtC62 * k2 + dtC63 * k3 + dtC64 * k4 + dtC65 * k5) - else - linsolve_tmp = du + - mass_matrix * - (dtC61 * k1 + dtC62 * k2 + dtC63 * k3 + dtC64 * k4 + dtC65 * k5) - end - - k6 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = u + k6 - du = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + - (dtC71 * k1 + dtC72 * k2 + dtC73 * k3 + dtC74 * k4 + dtC75 * k5 + - dtC76 * k6) - else - linsolve_tmp = du + - mass_matrix * - (dtC71 * k1 + dtC72 * k2 + dtC73 * k3 + dtC74 * k4 + dtC75 * k5 + - dtC76 * k6) - end - - k7 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = u + k7 - du = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + - (dtC81 * k1 + dtC82 * k2 + dtC83 * k3 + dtC84 * k4 + dtC85 * k5 + - dtC86 * k6 + dtC87 * k7) - else - linsolve_tmp = du + - mass_matrix * - (dtC81 * k1 + dtC82 * k2 + dtC83 * k3 + dtC84 * k4 + dtC85 * k5 + - dtC86 * k6 + dtC87 * k7) - end - - k8 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = u + k8 - linsolve_tmp = k8 - - if integrator.opts.adaptive - if (integrator.alg isa Rodas5Pe) - linsolve_tmp = 0.2606326497975715 * k1 - 0.005158627295444251 * k2 + - 1.3038988631109731 * k3 + 1.235000722062074 * k4 + - -0.7931985603795049 * k5 - 1.005448461135913 * k6 - - 0.18044626132120234 * k7 + 0.17051519239113755 * k8 - end - atmp = calculate_residuals(linsolve_tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.opts.calck - @unpack h21, h22, h23, h24, h25, h26, h27, h28, h31, h32, h33, h34, h35, h36, h37, h38, h41, h42, h43, h44, h45, h46, h47, h48 = cache.tab - integrator.k[1] = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 + h26 * k6 + - h27 * k7 + h28 * k8 - integrator.k[2] = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 + h36 * k6 + - h37 * k7 + h38 * k8 - integrator.k[3] = h41 * k1 + h42 * k2 + h43 * k3 + h44 * k4 + h45 * k5 + h46 * k6 + - h47 * k7 + h48 * k8 - if (integrator.alg isa Rodas5Pr) && integrator.opts.adaptive && - (integrator.EEst < 1.0) - k2 = 0.5 * (uprev + u + - 0.5 * (integrator.k[1] + 0.5 * (integrator.k[2] + 0.5 * integrator.k[3]))) - du1 = (0.25 * (integrator.k[2] + integrator.k[3]) - uprev + u) / dt - du = f(k2, p, t + dt / 2) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - if mass_matrix === I - du2 = du1 - du - else - du2 = mass_matrix * du1 - du - end - EEst = norm(du2) / norm(integrator.opts.abstol .+ integrator.opts.reltol .* k2) - integrator.EEst = max(EEst, integrator.EEst) - end - end - - integrator.u = u - return nothing -end - -function initialize!(integrator, cache::Rosenbrock5Cache) - integrator.kshortsize = 3 - @unpack dense1, dense2, dense3 = cache - resize!(integrator.k, integrator.kshortsize) - integrator.k[1] = dense1 - integrator.k[2] = dense2 - integrator.k[3] = dense3 -end - -@muladd function perform_step!(integrator, cache::Rosenbrock5Cache, repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack du, du1, du2, k1, k2, k3, k4, k5, k6, k7, k8, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache - @unpack a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, C71, C72, C73, C74, C75, C76, C81, C82, C83, C84, C85, C86, C87, gamma, d1, d2, d3, d4, d5, c2, c3, c4, c5 = cache.tab - - # Assignments - sizeu = size(u) - uidx = eachindex(integrator.uprev) - mass_matrix = integrator.f.mass_matrix - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - dtC61 = C61 / dt - dtC62 = C62 / dt - dtC63 = C63 / dt - dtC64 = C64 / dt - dtC65 = C65 / dt - dtC71 = C71 / dt - dtC72 = C72 / dt - dtC73 = C73 / dt - dtC74 = C74 / dt - dtC75 = C75 / dt - dtC76 = C76 / dt - dtC81 = C81 / dt - dtC82 = C82 / dt - dtC83 = C83 / dt - dtC84 = C84 / dt - dtC85 = C85 / dt - dtC86 = C86 / dt - dtC87 = C87 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 - dtd5 = dt * d5 - dtgamma = dt * gamma - - f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation! - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step) - - calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t) - - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end - - vecu = _vec(linres.u) - veck1 = _vec(k1) - - @.. broadcast=false veck1=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a21 * k1 - stage_limiter!(u, integrator, p, t + c2 * dt) - f(du, u, p, t + c2 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 - else - @.. broadcast=false du1=dtC21 * k1 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck2 = _vec(k2) - @.. broadcast=false veck2=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a31 * k1 + a32 * k2 - stage_limiter!(u, integrator, p, t + c3 * dt) - f(du, u, p, t + c3 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - else - @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck3 = _vec(k3) - @.. broadcast=false veck3=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a41 * k1 + a42 * k2 + a43 * k3 - stage_limiter!(u, integrator, p, t + c4 * dt) - f(du, u, p, t + c4 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + - (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - else - @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck4 = _vec(k4) - @.. broadcast=false veck4=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4 - stage_limiter!(u, integrator, p, t + c5 * dt) - f(du, u, p, t + c5 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd5 * dT + - (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - else - @.. broadcast=false du1=dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd5 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck5 = _vec(k5) - @.. broadcast=false veck5=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5 - stage_limiter!(u, integrator, p, t + dt) - f(du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC61 * k1 + dtC62 * k2 + dtC63 * k3 + - dtC64 * k4 + dtC65 * k5) - else - @.. broadcast=false du1=dtC61 * k1 + dtC62 * k2 + dtC63 * k3 + dtC64 * k4 + - dtC65 * k5 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck6 = _vec(k6) - @.. broadcast=false veck6=-vecu - integrator.stats.nsolve += 1 - - u .+= k6 - stage_limiter!(u, integrator, p, t + dt) - f(du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC71 * k1 + dtC72 * k2 + dtC73 * k3 + - dtC74 * k4 + dtC75 * k5 + dtC76 * k6) - else - @.. broadcast=false du1=dtC71 * k1 + dtC72 * k2 + dtC73 * k3 + dtC74 * k4 + - dtC75 * k5 + dtC76 * k6 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck7 = _vec(k7) - @.. broadcast=false veck7=-vecu - integrator.stats.nsolve += 1 - - u .+= k7 - f(du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC81 * k1 + dtC82 * k2 + dtC83 * k3 + - dtC84 * k4 + dtC85 * k5 + dtC86 * k6 + dtC87 * k7) - else - @.. broadcast=false du1=dtC81 * k1 + dtC82 * k2 + dtC83 * k3 + dtC84 * k4 + - dtC85 * k5 + dtC86 * k6 + dtC87 * k7 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck8 = _vec(k8) - @.. broadcast=false veck8=-vecu - integrator.stats.nsolve += 1 - - du .= k8 - u .+= k8 - - step_limiter!(u, integrator, p, t + dt) - - if integrator.opts.adaptive - if (integrator.alg isa Rodas5Pe) - @. du = 0.2606326497975715 * k1 - 0.005158627295444251 * k2 + - 1.3038988631109731 * k3 + 1.235000722062074 * k4 + - -0.7931985603795049 * k5 - 1.005448461135913 * k6 - - 0.18044626132120234 * k7 + 0.17051519239113755 * k8 - end - calculate_residuals!(atmp, du, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.opts.calck - @unpack h21, h22, h23, h24, h25, h26, h27, h28, h31, h32, h33, h34, h35, h36, h37, h38, h41, h42, h43, h44, h45, h46, h47, h48 = cache.tab - @.. broadcast=false integrator.k[1]=h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + - h25 * k5 + h26 * k6 + h27 * k7 + h28 * k8 - @.. broadcast=false integrator.k[2]=h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + - h35 * k5 + h36 * k6 + h37 * k7 + h38 * k8 - @.. broadcast=false integrator.k[3]=h41 * k1 + h42 * k2 + h43 * k3 + h44 * k4 + - h45 * k5 + h46 * k6 + h47 * k7 + h48 * k8 if (integrator.alg isa Rodas5Pr) && integrator.opts.adaptive && - (integrator.EEst < 1.0) - k2 = 0.5 * (uprev + u + - 0.5 * (integrator.k[1] + 0.5 * (integrator.k[2] + 0.5 * integrator.k[3]))) - du1 = (0.25 * (integrator.k[2] + integrator.k[3]) - uprev + u) / dt - f(du, k2, p, t + dt / 2) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - if mass_matrix === I - du2 = du1 - du - else - mul!(_vec(du2), mass_matrix, _vec(du1)) - du2 = du2 - du - end - EEst = norm(du2) / norm(integrator.opts.abstol .+ integrator.opts.reltol .* k2) - integrator.EEst = max(EEst, integrator.EEst) - end + (integrator.EEst < 1.0) + ks[2] = 0.5 * (uprev + u + + 0.5 * (integrator.k[1] + 0.5 * (integrator.k[2] + 0.5 * integrator.k[3]))) + du1 = (0.25 * (integrator.k[2] + integrator.k[3]) - uprev + u) / dt + f(du, ks[2], p, t + dt / 2) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + if mass_matrix === I + @.. du2 = du1 - du + else + mul!(_vec(du2), mass_matrix, _vec(du1)) + @.. du2 -= du + end + EEst = norm(du2) / norm(integrator.opts.abstol .+ integrator.opts.reltol .* ks[2]) + integrator.EEst = max(EEst, integrator.EEst) + end end cache.linsolve = linres.cache end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl index f64d5e1165..d7968cc572 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl @@ -236,18 +236,22 @@ function Rodas4Tableau(T, T2) #BET2P=0.0317D0 #BET3P=0.0635D0 #BET4P=0.3438D0 - A = T[0 0 0 0 0 0 - 1.544 0 0 0 0 0 - 0.9466785280815826 0.2557011698983284 0 0 0 0 - 3.314825187068521 2.896124015972201 0.9986419139977817 0 0 0 - 1.221224509226641 6.019134481288629 12.53708332932087 -0.6878860361058950 0 0 - 1.221224509226641 6.019134481288629 12.53708332932087 -0.6878860361058950 1 0] - C = T[0 0 0 0 0 - -5.6688 0 0 0 0 - -2.430093356833875 -0.2063599157091915 0 0 0 - -0.1073529058151375 -9.594562251023355 -20.47028614809616 0 0 - 7.496443313967647 -10.24680431464352 -33.99990352819905 11.70890893206160 0 - 8.083246795921522 -7.981132988064893 -31.52159432874371 16.31930543123136 -6.058818238834054] + A = T[ + 0 0 0 0 0 0 + 1.544 0 0 0 0 0 + 0.9466785280815826 0.2557011698983284 0 0 0 0 + 3.314825187068521 2.896124015972201 0.9986419139977817 0 0 0 + 1.221224509226641 6.019134481288629 12.53708332932087 -0.6878860361058950 0 0 + 1.221224509226641 6.019134481288629 12.53708332932087 -0.6878860361058950 1 0 + ] + C = T[ + 0 0 0 0 0 + -5.6688 0 0 0 0 + -2.430093356833875 -0.2063599157091915 0 0 0 + -0.1073529058151375 -9.594562251023355 -20.47028614809616 0 0 + 7.496443313967647 -10.24680431464352 -33.99990352819905 11.70890893206160 0 + 8.083246795921522 -7.981132988064893 -31.52159432874371 16.31930543123136 -6.058818238834054 + ] c = T2[0, 0.386, 0.21, 0.63, 1, 1] d = T[0.25, -0.1043, 0.1035, -0.0362, 0, 0] H = T[10.12623508344586 -7.487995877610167 -34.80091861555747 -7.992771707568823 1.025137723295662 0 @@ -321,168 +325,36 @@ function Rodas4P2Tableau(T, T2) RodasTableau(A, C, gamma, c, d, H) end -struct Rodas5Tableau{T, T2} - a21::T - a31::T - a32::T - a41::T - a42::T - a43::T - a51::T - a52::T - a53::T - a54::T - a61::T - a62::T - a63::T - a64::T - a65::T - C21::T - C31::T - C32::T - C41::T - C42::T - C43::T - C51::T - C52::T - C53::T - C54::T - C61::T - C62::T - C63::T - C64::T - C65::T - C71::T - C72::T - C73::T - C74::T - C75::T - C76::T - C81::T - C82::T - C83::T - C84::T - C85::T - C86::T - C87::T - gamma::T2 - d1::T - d2::T - d3::T - d4::T - d5::T - c2::T2 - c3::T2 - c4::T2 - c5::T2 - h21::T - h22::T - h23::T - h24::T - h25::T - h26::T - h27::T - h28::T - h31::T - h32::T - h33::T - h34::T - h35::T - h36::T - h37::T - h38::T - h41::T - h42::T - h43::T - h44::T - h45::T - h46::T - h47::T - h48::T -end - function Rodas5Tableau(T, T2) gamma = convert(T2, 0.19) - a21 = convert(T, 2.0) - a31 = convert(T, 3.040894194418781) - a32 = convert(T, 1.041747909077569) - a41 = convert(T, 2.576417536461461) - a42 = convert(T, 1.622083060776640) - a43 = convert(T, -0.9089668560264532) - a51 = convert(T, 2.760842080225597) - a52 = convert(T, 1.446624659844071) - a53 = convert(T, -0.3036980084553738) - a54 = convert(T, 0.2877498600325443) - a61 = convert(T, -14.09640773051259) - a62 = convert(T, 6.925207756232704) - a63 = convert(T, -41.47510893210728) - a64 = convert(T, 2.343771018586405) - a65 = convert(T, 24.13215229196062) - C21 = convert(T, -10.31323885133993) - C31 = convert(T, -21.04823117650003) - C32 = convert(T, -7.234992135176716) - C41 = convert(T, 32.22751541853323) - C42 = convert(T, -4.943732386540191) - C43 = convert(T, 19.44922031041879) - C51 = convert(T, -20.69865579590063) - C52 = convert(T, -8.816374604402768) - C53 = convert(T, 1.260436877740897) - C54 = convert(T, -0.7495647613787146) - C61 = convert(T, -46.22004352711257) - C62 = convert(T, -17.49534862857472) - C63 = convert(T, -289.6389582892057) - C64 = convert(T, 93.60855400400906) - C65 = convert(T, 318.3822534212147) - C71 = convert(T, 34.20013733472935) - C72 = convert(T, -14.15535402717690) - C73 = convert(T, 57.82335640988400) - C74 = convert(T, 25.83362985412365) - C75 = convert(T, 1.408950972071624) - C76 = convert(T, -6.551835421242162) - C81 = convert(T, 42.57076742291101) - C82 = convert(T, -13.80770672017997) - C83 = convert(T, 93.98938432427124) - C84 = convert(T, 18.77919633714503) - C85 = convert(T, -31.58359187223370) - C86 = convert(T, -6.685968952921985) - C87 = convert(T, -5.810979938412932) - c2 = convert(T2, 0.38) - c3 = convert(T2, 0.3878509998321533) - c4 = convert(T2, 0.4839718937873840) - c5 = convert(T2, 0.4570477008819580) - d1 = convert(T, gamma) - d2 = convert(T, -0.1823079225333714636) - d3 = convert(T, -0.319231832186874912) - d4 = convert(T, 0.3449828624725343) - d5 = convert(T, -0.377417564392089818) - - h21 = convert(T, 27.354592673333357) - h22 = convert(T, -6.925207756232857) - h23 = convert(T, 26.40037733258859) - h24 = convert(T, 0.5635230501052979) - h25 = convert(T, -4.699151156849391) - h26 = convert(T, -1.6008677469422725) - h27 = convert(T, -1.5306074446748028) - h28 = convert(T, -1.3929872940716344) - - h31 = convert(T, 44.19024239501722) - h32 = convert(T, 1.3677947663381929e-13) - h33 = convert(T, 202.93261852171622) - h34 = convert(T, -35.5669339789154) - h35 = convert(T, -181.91095152160645) - h36 = convert(T, 3.4116351403665033) - h37 = convert(T, 2.5793540257308067) - h38 = convert(T, 2.2435122582734066) - - h41 = convert(T, -44.0988150021747) - h42 = convert(T, -5.755396159656812e-13) - h43 = convert(T, -181.26175034586677) - h44 = convert(T, 56.99302194811676) - h45 = convert(T, 183.21182741427398) - h46 = convert(T, -7.480257918273637) - h47 = convert(T, -5.792426076169686) - h48 = convert(T, -5.32503859794143) - + A = T[ + 0 0 0 0 0 0 0 0 + 2.0 0 0 0 0 0 0 0 + 3.040894194418781 1.041747909077569 0 0 0 0 0 0 + 2.576417536461461 1.622083060776640 -0.9089668560264532 0 0 0 0 0 + 2.760842080225597 1.446624659844071 -0.3036980084553738 0.2877498600325443 0 0 0 0 + -14.09640773051259 6.925207756232704 -41.47510893210728 2.343771018586405 24.13215229196062 0 0 0 + -14.09640773051259 6.925207756232704 -41.47510893210728 2.343771018586405 24.13215229196062 1 0 0 + -14.09640773051259 6.925207756232704 -41.47510893210728 2.343771018586405 24.13215229196062 1 1 0 + ] + C = T[ + 0 0 0 0 0 0 0 + -10.31323885133993 0 0 0 0 0 0 + -21.04823117650003 -7.234992135176716 0 0 0 0 0 + 32.22751541853323 -4.943732386540191 19.44922031041879 0 0 0 0 + -20.69865579590063 -8.816374604402768 1.260436877740897 -0.7495647613787146 0 0 0 + -46.22004352711257 -17.49534862857472 -289.6389582892057 93.60855400400906 318.3822534212147 0 0 + 34.20013733472935 -14.15535402717690 57.82335640988400 25.83362985412365 1.408950972071624 -6.551835421242162 0 + 42.57076742291101 -13.80770672017997 93.98938432427124 18.77919633714503 -31.58359187223370 -6.685968952921985 -5.810979938412932 + ] + c = T2[0, 0.38, 0.3878509998321533, 0.4839718937873840, 0.4570477008819580, 1, 1, 1] + d = T[gamma, -0.1823079225333714636, -0.319231832186874912, 0.3449828624725343, -0.377417564392089818, 0, 0, 0] + + H = T[ + 27.354592673333357 -6.925207756232857 26.40037733258859 0.5635230501052979 -4.699151156849391 -1.6008677469422725 -1.5306074446748028 -1.3929872940716344 + 44.19024239501722 1.3677947663381929e-13 202.93261852171622 -35.5669339789154 -181.91095152160645 3.4116351403665033 2.5793540257308067 2.2435122582734066 + -44.0988150021747 -5.755396159656812e-13 -181.26175034586677 56.99302194811676 183.21182741427398 -7.480257918273637 -5.792426076169686 -5.32503859794143 + ] # println("---Rodas5---") #= @@ -508,111 +380,39 @@ function Rodas5Tableau(T, T2) b7 = convert(T,1) b8 = convert(T,1) =# - - Rodas5Tableau(a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, - a61, a62, a63, a64, a65, - C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, - C61, C62, C63, C64, C65, C71, C72, C73, C74, C75, C76, - C81, C82, C83, C84, C85, C86, C87, - gamma, d1, d2, d3, d4, d5, c2, c3, c4, c5, - h21, h22, h23, h24, h25, h26, h27, h28, h31, h32, h33, h34, h35, h36, h37, - h38, h41, h42, h43, h44, h45, h46, h47, h48) + RodasTableau(A, C, gamma, c, d, H) end function Rodas5PTableau(T, T2) gamma = convert(T2, 0.21193756319429014) - - a21 = convert(T, 3.0) - a31 = convert(T, 2.849394379747939) - a32 = convert(T, 0.45842242204463923) - a41 = convert(T, -6.954028509809101) - a42 = convert(T, 2.489845061869568) - a43 = convert(T, -10.358996098473584) - a51 = convert(T, 2.8029986275628964) - a52 = convert(T, 0.5072464736228206) - a53 = convert(T, -0.3988312541770524) - a54 = convert(T, -0.04721187230404641) - a61 = convert(T, -7.502846399306121) - a62 = convert(T, 2.561846144803919) - a63 = convert(T, -11.627539656261098) - a64 = convert(T, -0.18268767659942256) - a65 = convert(T, 0.030198172008377946) - - C21 = convert(T, -14.155112264123755) - C31 = convert(T, -17.97296035885952) - C32 = convert(T, -2.859693295451294) - C41 = convert(T, 147.12150275711716) - C42 = convert(T, -1.41221402718213) - C43 = convert(T, 71.68940251302358) - C51 = convert(T, 165.43517024871676) - C52 = convert(T, -0.4592823456491126) - C53 = convert(T, 42.90938336958603) - C54 = convert(T, -5.961986721573306) - C61 = convert(T, 24.854864614690072) - C62 = convert(T, -3.0009227002832186) - C63 = convert(T, 47.4931110020768) - C64 = convert(T, 5.5814197821558125) - C65 = convert(T, -0.6610691825249471) - C71 = convert(T, 30.91273214028599) - C72 = convert(T, -3.1208243349937974) - C73 = convert(T, 77.79954646070892) - C74 = convert(T, 34.28646028294783) - C75 = convert(T, -19.097331116725623) - C76 = convert(T, -28.087943162872662) - C81 = convert(T, 37.80277123390563) - C82 = convert(T, -3.2571969029072276) - C83 = convert(T, 112.26918849496327) - C84 = convert(T, 66.9347231244047) - C85 = convert(T, -40.06618937091002) - C86 = convert(T, -54.66780262877968) - C87 = convert(T, -9.48861652309627) - - c2 = convert(T2, 0.6358126895828704) - c3 = convert(T2, 0.4095798393397535) - c4 = convert(T2, 0.9769306725060716) - c5 = convert(T2, 0.4288403609558664) - - d1 = convert(T, 0.21193756319429014) - d2 = convert(T, -0.42387512638858027) - d3 = convert(T, -0.3384627126235924) - d4 = convert(T, 1.8046452872882734) - d5 = convert(T, 2.325825639765069) - - h21 = convert(T, 25.948786856663858) - h22 = convert(T, -2.5579724845846235) - h23 = convert(T, 10.433815404888879) - h24 = convert(T, -2.3679251022685204) - h25 = convert(T, 0.524948541321073) - h26 = convert(T, 1.1241088310450404) - h27 = convert(T, 0.4272876194431874) - h28 = convert(T, -0.17202221070155493) - - h31 = convert(T, -9.91568850695171) - h32 = convert(T, -0.9689944594115154) - h33 = convert(T, 3.0438037242978453) - h34 = convert(T, -24.495224566215796) - h35 = convert(T, 20.176138334709044) - h36 = convert(T, 15.98066361424651) - h37 = convert(T, -6.789040303419874) - h38 = convert(T, -6.710236069923372) - - h41 = convert(T, 11.419903575922262) - h42 = convert(T, 2.8879645146136994) - h43 = convert(T, 72.92137995996029) - h44 = convert(T, 80.12511834622643) - h45 = convert(T, -52.072871366152654) - h46 = convert(T, -59.78993625266729) - h47 = convert(T, -0.15582684282751913) - h48 = convert(T, 4.883087185713722) - - Rodas5Tableau(a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, - a61, a62, a63, a64, a65, - C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, - C61, C62, C63, C64, C65, C71, C72, C73, C74, C75, C76, - C81, C82, C83, C84, C85, C86, C87, - gamma, d1, d2, d3, d4, d5, c2, c3, c4, c5, - h21, h22, h23, h24, h25, h26, h27, h28, h31, h32, h33, h34, h35, h36, h37, - h38, h41, h42, h43, h44, h45, h46, h47, h48) + A = T[ + 0 0 0 0 0 0 0 0 + 3.0 0 0 0 0 0 0 0 + 2.849394379747939 0.45842242204463923 0 0 0 0 0 0 + -6.954028509809101 2.489845061869568 -10.358996098473584 0 0 0 0 0 + 2.8029986275628964 0.5072464736228206 -0.3988312541770524 -0.04721187230404641 0 0 0 0 + -7.502846399306121 2.561846144803919 -11.627539656261098 -0.18268767659942256 0.030198172008377946 0 0 0 + -7.502846399306121 2.561846144803919 -11.627539656261098 -0.18268767659942256 0.030198172008377946 1 0 0 + -7.502846399306121 2.561846144803919 -11.627539656261098 -0.18268767659942256 0.030198172008377946 1 1 0 + ] + C = T[ + 0 0 0 0 0 0 0 + -14.155112264123755 0 0 0 0 0 0 + -17.97296035885952 -2.859693295451294 0 0 0 0 0 + 147.12150275711716 -1.41221402718213 71.68940251302358 0 0 0 0 + 165.43517024871676 -0.4592823456491126 42.90938336958603 -5.961986721573306 0 0 0 + 24.854864614690072 -3.0009227002832186 47.4931110020768 5.5814197821558125 -0.6610691825249471 0 0 + 30.91273214028599 -3.1208243349937974 77.79954646070892 34.28646028294783 -19.097331116725623 -28.087943162872662 0 + 37.80277123390563 -3.2571969029072276 112.26918849496327 66.9347231244047 -40.06618937091002 -54.66780262877968 -9.48861652309627 + ] + c = T2[0, 0.6358126895828704, 0.4095798393397535, 0.9769306725060716, 0.4288403609558664, 1, 1, 1] + d = T[0.21193756319429014, -0.42387512638858027, -0.3384627126235924, 1.8046452872882734, 2.325825639765069, 0, 0, 0] + H = T[ + 25.948786856663858 -2.5579724845846235 10.433815404888879 -2.3679251022685204 0.524948541321073 1.1241088310450404 0.4272876194431874 -0.17202221070155493 + -9.91568850695171 -0.9689944594115154 3.0438037242978453 -24.495224566215796 20.176138334709044 15.98066361424651 -6.789040303419874 -6.710236069923372 + 11.419903575922262 2.8879645146136994 72.92137995996029 80.12511834622643 -52.072871366152654 -59.78993625266729 -0.15582684282751913 4.883087185713722 + ] + RodasTableau(A, C, gamma, c, d, H) end @RosenbrockW6S4OS(:tableau) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl index d6fca3b14f..d3967b40ee 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl @@ -43,264 +43,10 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, nothing end -function _ode_addsteps!(k, t, uprev, u, dt, f, p, - cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, - always_calc_begin = false, allow_calc_end = true, - force_calc_end = false) - if length(k) < 2 || always_calc_begin - @unpack k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, tmp, uf, tf, linsolve_tmp, weight = cache - @unpack c₃₂, d = cache.tab - uidx = eachindex(uprev) - - # Assignments - sizeu = size(u) - mass_matrix = f.mass_matrix - dtγ = dt * d - neginvdtγ = -inv(dtγ) - dto2 = dt / 2 - - @.. linsolve_tmp = @muladd fsalfirst + dtγ * dT - - ### Jacobian does not need to be re-evaluated after an event - ### Since it's unchanged - jacobian2W!(W, mass_matrix, dtγ, J) - - linsolve = cache.linsolve - - linres = dolinsolve(cache, linsolve; A = W, b = _vec(linsolve_tmp), - reltol = cache.reltol) - - vecu = _vec(linres.u) - veck₁ = _vec(k₁) - @.. veck₁ = vecu * neginvdtγ - - @.. tmp = uprev + dto2 * k₁ - f(f₁, tmp, p, t + dto2) - - if mass_matrix === I - tmp .= k₁ - else - mul!(_vec(tmp), mass_matrix, _vec(k₁)) - end - - @.. linsolve_tmp = f₁ - tmp - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck₂ = _vec(k₂) - - @.. veck₂ = vecu * neginvdtγ + veck₁ - - copyat_or_push!(k, 1, k₁) - copyat_or_push!(k, 2, k₂) - cache.linsolve = linres.cache - end - nothing -end - -function _ode_addsteps!( - k, t, uprev, u, dt, f, p, cache::Union{Rodas23WConstantCache, Rodas3PConstantCache}, - always_calc_begin = false, allow_calc_end = true, - force_calc_end = false) - if length(k) < 2 || always_calc_begin - @unpack tf, uf = cache - @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtgamma = dt * gamma - mass_matrix = f.mass_matrix - - # Time derivative - tf.u = uprev - if cache.autodiff isa AutoForwardDiff - dT = ForwardDiff.derivative(tf, t) - else - dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt)) - end - - # Jacobian - uf.t = t - if uprev isa AbstractArray - J = ForwardDiff.jacobian(uf, uprev) - W = mass_matrix / dtgamma - J - else - J = ForwardDiff.derivative(uf, uprev) - W = 1 / dtgamma - J - end - - du = f(uprev, p, t) - k3 = copy(du) - - linsolve_tmp = du + dtd1 * dT - - k1 = W \ linsolve_tmp - u = uprev + a21 * k1 - du = f(u, p, t + c2 * dt) - - linsolve_tmp = du + dtd2 * dT + dtC21 * k1 - - k2 = W \ linsolve_tmp - - linsolve_tmp = k3 + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - - k3 = W \ linsolve_tmp - u = uprev + a41 * k1 + a42 * k2 + a43 * k3 - du = f(u, p, t + dt) - - linsolve_tmp = du + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - - k4 = W \ linsolve_tmp - - linsolve_tmp = du + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - - k5 = W \ linsolve_tmp - - @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab - k₁ = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 - k₂ = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 - #k₃ = h2_21 * k1 + h2_22 * k2 + h2_23 * k3 + h2_24 * k4 + h2_25 * k5 - copyat_or_push!(k, 1, k₁) - copyat_or_push!(k, 2, k₂) - #copyat_or_push!(k, 3, k₃) - end - nothing -end - -function _ode_addsteps!( - k, t, uprev, u, dt, f, p, cache::Union{Rodas23WCache, Rodas3PCache}, - always_calc_begin = false, allow_calc_end = true, - force_calc_end = false) - if length(k) < 2 || always_calc_begin - @unpack du, du1, du2, tmp, k1, k2, k3, k4, k5, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst, weight = cache - @unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab - - # Assignments - sizeu = size(u) - uidx = eachindex(uprev) - mass_matrix = f.mass_matrix - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtgamma = dt * gamma - - @.. broadcast=false linsolve_tmp=@muladd fsalfirst + dtgamma * dT - - ### Jacobian does not need to be re-evaluated after an event - ### Since it's unchanged - jacobian2W!(W, mass_matrix, dtgamma, J) - - linsolve = cache.linsolve - - linres = dolinsolve(cache, linsolve; A = W, b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck1 = _vec(k1) - - @.. broadcast=false veck1=-vecu - @.. broadcast=false tmp=uprev + a21 * k1 - f(du, tmp, p, t + c2 * dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 - else - @.. broadcast=false du1=dtC21 * k1 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck2 = _vec(k2) - @.. broadcast=false veck2=-vecu - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=fsalfirst + dtd3 * dT + - (dtC31 * k1 + dtC32 * k2) - else - @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=fsalfirst + dtd3 * dT + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck3 = _vec(k3) - @.. broadcast=false veck3=-vecu - @.. broadcast=false tmp=uprev + a41 * k1 + a42 * k2 + a43 * k3 - f(du, tmp, p, t + dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - else - @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck4 = _vec(k4) - @.. broadcast=false veck4=-vecu - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + - dtC53 * k3) - else - @.. broadcast=false du1=dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck5 = _vec(k5) - @.. broadcast=false veck5=-vecu - @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35, h2_21, h2_22, h2_23, h2_24, h2_25 = cache.tab - @.. broadcast=false du=h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 - copyat_or_push!(k, 1, copy(du)) - - @.. broadcast=false du=h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 - copyat_or_push!(k, 2, copy(du)) - end - nothing -end - -function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache, +function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::RosenbrockCombinedConstantCache, always_calc_begin = false, allow_calc_end = true, force_calc_end = false) - if length(k) < 2 || always_calc_begin + if length(k) < size(cache.tab.H, 1) || always_calc_begin (; tf, uf) = cache (; A, C, gamma, c, d, H) = cache.tab @@ -328,7 +74,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache, W = 1 / dtgamma - J end - num_stages = size(A, 1) + num_stages = size(A,1) du = f(u, p, t) linsolve_tmp = @.. du + dtd[1] * dT k1 = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) @@ -359,17 +105,14 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache, ks = Base.setindex(ks, _reshape(W \ _vec(linsolve_tmp), axes(uprev)), stage) end - k1 = zero(ks[1]) - k2 = zero(ks[1]) - H = cache.tab.H - # Last stage doesn't affect ks - for i in 1:(num_stages - 1) - k1 = @.. k1 + H[1, i] * ks[i] - k2 = @.. k2 + H[2, i] * ks[i] + for j in 1:size(H, 1) + kj = zero(ks[1]) + # Last stage doesn't affect ks + for i in 1:(num_stages - 1) + kj = @.. kj + H[j, i] * ks[i] + end + copyat_or_push!(k, j, kj) end - - copyat_or_push!(k, 1, k1) - copyat_or_push!(k, 2, k2) end nothing end @@ -385,6 +128,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::RosenbrockCache, sizeu = size(u) uidx = eachindex(uprev) mass_matrix = f.mass_matrix + tmp = ks[end] # Precalculations dtC = C ./ dt @@ -428,349 +172,14 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::RosenbrockCache, @.. $(_vec(ks[stage])) = -linres.u end - copyat_or_push!(k, 1, zero(du)) - copyat_or_push!(k, 2, zero(du)) - # Last stage doesn't affect ks - for i in 1:(length(ks) - 1) - @.. k[1] += H[1, i] * _vec(ks[i]) - @.. k[2] += H[2, i] * _vec(ks[i]) - end - end - nothing -end - -function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rosenbrock5ConstantCache, - always_calc_begin = false, allow_calc_end = true, - force_calc_end = false) - if length(k) < 3 || always_calc_begin - @unpack tf, uf = cache - @unpack a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, C71, C72, C73, C74, C75, C76, C81, C82, C83, C84, C85, C86, C87, gamma, d1, d2, d3, d4, d5, c2, c3, c4, c5 = cache.tab - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - dtC61 = C61 / dt - dtC62 = C62 / dt - dtC63 = C63 / dt - dtC64 = C64 / dt - dtC65 = C65 / dt - dtC71 = C71 / dt - dtC72 = C72 / dt - dtC73 = C73 / dt - dtC74 = C74 / dt - dtC75 = C75 / dt - dtC76 = C76 / dt - dtC81 = C81 / dt - dtC82 = C82 / dt - dtC83 = C83 / dt - dtC84 = C84 / dt - dtC85 = C85 / dt - dtC86 = C86 / dt - dtC87 = C87 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 - dtd5 = dt * d5 - dtgamma = dt * gamma - mass_matrix = f.mass_matrix - - # Time derivative - tf.u = uprev - # if cache.autodiff isa AutoForwardDiff - # dT = ForwardDiff.derivative(tf, t) - # else - dT = FiniteDiff.finite_difference_derivative(tf, t, dir = sign(dt)) - # end - - # Jacobian - uf.t = t - if uprev isa AbstractArray - J = ForwardDiff.jacobian(uf, uprev) - W = mass_matrix / dtgamma - J - else - J = ForwardDiff.derivative(uf, uprev) - W = 1 / dtgamma - J - end - - du = f(uprev, p, t) - - linsolve_tmp = du + dtd1 * dT - - k1 = W \ linsolve_tmp - u = uprev + a21 * k1 - du = f(u, p, t + c2 * dt) - - linsolve_tmp = du + dtd2 * dT + dtC21 * k1 - - k2 = W \ linsolve_tmp - u = uprev + a31 * k1 + a32 * k2 - du = f(u, p, t + c3 * dt) - - linsolve_tmp = du + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - - k3 = W \ linsolve_tmp - u = uprev + a41 * k1 + a42 * k2 + a43 * k3 - du = f(u, p, t + c4 * dt) - - linsolve_tmp = du + dtd4 * dT + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - - k4 = W \ linsolve_tmp - u = uprev + a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4 - du = f(u, p, t + c5 * dt) - - linsolve_tmp = du + dtd5 * dT + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - - k5 = W \ linsolve_tmp - u = uprev + a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5 - du = f(u, p, t + dt) - - linsolve_tmp = du + (dtC61 * k1 + dtC62 * k2 + dtC63 * k3 + dtC64 * k4 + dtC65 * k5) - - k6 = W \ linsolve_tmp - u = u + k6 - du = f(u, p, t + dt) - - linsolve_tmp = du + - (dtC71 * k1 + dtC72 * k2 + dtC73 * k3 + dtC74 * k4 + dtC75 * k5 + - dtC76 * k6) - - k7 = W \ linsolve_tmp - - u = u + k7 - du = f(u, p, t + dt) - - linsolve_tmp = du + - (dtC81 * k1 + dtC82 * k2 + dtC83 * k3 + dtC84 * k4 + dtC85 * k5 + - dtC86 * k6 + dtC87 * k7) - - k8 = W \ linsolve_tmp - - @unpack h21, h22, h23, h24, h25, h26, h27, h28, h31, h32, h33, h34, h35, h36, h37, h38, h41, h42, h43, h44, h45, h46, h47, h48 = cache.tab - k₁ = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 + h26 * k6 + h27 * k7 + - h28 * k8 - k₂ = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 + h36 * k6 + h37 * k7 + - h38 * k8 - k₃ = h41 * k1 + h42 * k2 + h43 * k3 + h44 * k4 + h45 * k5 + h46 * k6 + h47 * k7 + - h48 * k8 - copyat_or_push!(k, 1, k₁) - copyat_or_push!(k, 2, k₂) - copyat_or_push!(k, 3, k₃) - end - nothing -end - -function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rosenbrock5Cache, - always_calc_begin = false, allow_calc_end = true, - force_calc_end = false) - if length(k) < 3 || always_calc_begin - @unpack du, du1, du2, tmp, k1, k2, k3, k4, k5, k6, k7, k8, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst, weight = cache - @unpack a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, C71, C72, C73, C74, C75, C76, C81, C82, C83, C84, C85, C86, C87, gamma, d1, d2, d3, d4, d5, c2, c3, c4, c5 = cache.tab - - # Assignments - sizeu = size(u) - uidx = eachindex(uprev) - mass_matrix = f.mass_matrix - tmp = k8 # integrator.tmp === linsolve_tmp, aliasing fails due to linsolve mutation - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - dtC61 = C61 / dt - dtC62 = C62 / dt - dtC63 = C63 / dt - dtC64 = C64 / dt - dtC65 = C65 / dt - dtC71 = C71 / dt - dtC72 = C72 / dt - dtC73 = C73 / dt - dtC74 = C74 / dt - dtC75 = C75 / dt - dtC76 = C76 / dt - dtC81 = C81 / dt - dtC82 = C82 / dt - dtC83 = C83 / dt - dtC84 = C84 / dt - dtC85 = C85 / dt - dtC86 = C86 / dt - dtC87 = C87 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 - dtd5 = dt * d5 - dtgamma = dt * gamma - - @.. broadcast=false linsolve_tmp=@muladd fsalfirst + dtgamma * dT - - ### Jacobian does not need to be re-evaluated after an event - ### Since it's unchanged - jacobian2W!(W, mass_matrix, dtgamma, J) - - linsolve = cache.linsolve - - linres = dolinsolve(cache, linsolve; A = W, b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck1 = _vec(k1) - - @.. broadcast=false veck1=-vecu - @.. broadcast=false tmp=uprev + a21 * k1 - f(du, tmp, p, t + c2 * dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 - else - @.. broadcast=false du1=dtC21 * k1 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck2 = _vec(k2) - @.. broadcast=false veck2=-vecu - @.. broadcast=false tmp=uprev + a31 * k1 + a32 * k2 - f(du, tmp, p, t + c3 * dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - else - @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck3 = _vec(k3) - @.. broadcast=false veck3=-vecu - @.. broadcast=false tmp=uprev + a41 * k1 + a42 * k2 + a43 * k3 - f(du, tmp, p, t + c4 * dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + - (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - else - @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck4 = _vec(k4) - @.. broadcast=false veck4=-vecu - @.. broadcast=false tmp=uprev + a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4 - f(du, tmp, p, t + c5 * dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd5 * dT + - (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + - dtC53 * k3) - else - @.. broadcast=false du1=dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + dtd5 * dT + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck5 = _vec(k5) - @.. broadcast=false veck5=-vecu - @.. broadcast=false tmp=uprev + a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5 - f(du, tmp, p, t + dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC62 * k2 + dtC64 * k4 + dtC61 * k1 + - dtC63 * k3 + dtC65 * k5) - else - @.. broadcast=false du1=dtC62 * k2 + dtC64 * k4 + dtC61 * k1 + dtC63 * k3 + - dtC65 * k5 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + du2 - end - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck6 = _vec(k6) - @.. broadcast=false veck6=-vecu - @.. broadcast=false tmp+=k6 - f(du, tmp, p, t + dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC71 * k1 + dtC72 * k2 + dtC73 * k3 + - dtC74 * k4 + dtC75 * k5 + dtC76 * k6) - else - @.. broadcast=false du1=dtC72 * k2 + dtC74 * k4 + dtC71 * k1 + dtC73 * k3 + - dtC75 * k5 + dtC76 * k6 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck7 = _vec(k7) - @.. broadcast=false veck7=-vecu - @.. broadcast=false tmp+=k7 - f(du, tmp, p, t + dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC81 * k1 + dtC82 * k2 + dtC83 * k3 + - dtC84 * k4 + dtC85 * k5 + dtC86 * k6 + - dtC87 * k7) - else - @.. broadcast=false du1=dtC81 * k1 + dtC82 * k2 + dtC83 * k3 + dtC84 * k4 + - dtC85 * k5 + dtC86 * k6 + dtC87 * k7 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + du2 + for j in 1:size(H, 1) + copyat_or_push!(k, j, zero(du)) + # Last stage doesn't affect ks + for i in 1:(length(ks) - 1) + @.. k[j] += H[j, i] * _vec(ks[i]) + end end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck8 = _vec(k8) - @.. broadcast=false veck8=-vecu - - # https://github.com/SciML/OrdinaryDiffEq.jl/issues/2055 - tmp = linsolve_tmp - - @unpack h21, h22, h23, h24, h25, h26, h27, h28, h31, h32, h33, h34, h35, h36, h37, h38, h41, h42, h43, h44, h45, h46, h47, h48 = cache.tab - @.. broadcast=false tmp=h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 + - h26 * k6 + h27 * k7 + h28 * k8 - copyat_or_push!(k, 1, copy(tmp)) - - @.. broadcast=false tmp=h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 + - h36 * k6 + h37 * k7 + h38 * k8 - copyat_or_push!(k, 2, copy(tmp)) - - @.. broadcast=false tmp=h41 * k1 + h42 * k2 + h43 * k3 + h44 * k4 + h45 * k5 + - h46 * k6 + h47 * k7 + h48 * k8 - copyat_or_push!(k, 3, copy(tmp)) end nothing end diff --git a/lib/OrdinaryDiffEqRosenbrock/test/ode_rosenbrock_tests.jl b/lib/OrdinaryDiffEqRosenbrock/test/ode_rosenbrock_tests.jl index d2e5d0608f..d384937120 100644 --- a/lib/OrdinaryDiffEqRosenbrock/test/ode_rosenbrock_tests.jl +++ b/lib/OrdinaryDiffEqRosenbrock/test/ode_rosenbrock_tests.jl @@ -609,7 +609,7 @@ import LinearSolve prob = prob_ode_linear - dts = (1 / 2) .^ (6:-1:3) + dts = (1 / 2) .^ (5:-1:2) sim = test_convergence(dts, prob, Rodas5(), dense_errors = true) @test sim.𝒪est[:final]≈5 atol=testTol @test sim.𝒪est[:L2]≈5 atol=testTol @@ -630,7 +630,6 @@ import LinearSolve prob = prob_ode_linear - dts = (1 / 2) .^ (5:-1:2) sim = test_convergence(dts, prob, Rodas5P(), dense_errors = true) #@test sim.𝒪est[:final]≈5 atol=testTol #-- observed order > 6 @test sim.𝒪est[:L2]≈5 atol=testTol