diff --git a/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl index 50e605e409..cfc14219c5 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl @@ -76,4 +76,4 @@ end @generated function pick_static_chunksize(::Val{chunksize}) where {chunksize} x = ForwardDiff.pickchunksize(chunksize) :(Val{$x}()) -end +end \ No newline at end of file diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 7951f1fc72..e351ab54c3 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -12,7 +12,7 @@ struct StaticWOperator{isinv, T, F} <: AbstractSciMLOperator{T} # doing to how StaticArrays and StaticArraysCore are split up StaticArrays.LU(LowerTriangular(W), UpperTriangular(W), SVector{n}(1:n)) else - lu(W, check = false) + lu(W, check=false) end # when constructing W for the first time for the type # inv(W) can be singular @@ -938,28 +938,28 @@ function LinearSolve.init_cacheval( end for alg in [LinearSolve.AppleAccelerateLUFactorization, - LinearSolve.BunchKaufmanFactorization, - LinearSolve.CHOLMODFactorization, - LinearSolve.CholeskyFactorization, - LinearSolve.CudaOffloadFactorization, - LinearSolve.DiagonalFactorization, - LinearSolve.FastLUFactorization, - LinearSolve.FastQRFactorization, - LinearSolve.GenericFactorization, - LinearSolve.GenericLUFactorization, - LinearSolve.KLUFactorization, - LinearSolve.LDLtFactorization, - LinearSolve.LUFactorization, - LinearSolve.MKLLUFactorization, - LinearSolve.MetalLUFactorization, - LinearSolve.NormalBunchKaufmanFactorization, - LinearSolve.NormalCholeskyFactorization, - LinearSolve.QRFactorization, - LinearSolve.RFLUFactorization, - LinearSolve.SVDFactorization, - LinearSolve.SimpleLUFactorization, - LinearSolve.SparspakFactorization, - LinearSolve.UMFPACKFactorization] + LinearSolve.BunchKaufmanFactorization, + LinearSolve.CHOLMODFactorization, + LinearSolve.CholeskyFactorization, + LinearSolve.CudaOffloadFactorization, + LinearSolve.DiagonalFactorization, + LinearSolve.FastLUFactorization, + LinearSolve.FastQRFactorization, + LinearSolve.GenericFactorization, + LinearSolve.GenericLUFactorization, + LinearSolve.KLUFactorization, + LinearSolve.LDLtFactorization, + LinearSolve.LUFactorization, + LinearSolve.MKLLUFactorization, + LinearSolve.MetalLUFactorization, + LinearSolve.NormalBunchKaufmanFactorization, + LinearSolve.NormalCholeskyFactorization, + LinearSolve.QRFactorization, + LinearSolve.RFLUFactorization, + LinearSolve.SVDFactorization, + LinearSolve.SimpleLUFactorization, + LinearSolve.SparspakFactorization, + LinearSolve.UMFPACKFactorization] @eval function LinearSolve.init_cacheval(alg::$alg, A::WOperator, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) @@ -1003,4 +1003,4 @@ function resize_J_W!(cache, integrator, i) end nothing -end +end \ No newline at end of file diff --git a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl index 27bde3e1b7..652b7ea6af 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl @@ -12,7 +12,7 @@ function DiffEqBase.interp_summary(::Type{cacheType}, dense::Bool) where { cacheType <: Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache, - Rodas4Cache, Rodas23WCache, Rodas3PCache}} + RosenbrockCache, Rodas23WCache, Rodas3PCache}} dense ? "specialized 3rd order \"free\" stiffness-aware interpolation" : "1st order linear" end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 6a26cf2dba..4b6841a1bc 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -12,6 +12,40 @@ end # Shampine's Low-order Rosenbrocks +mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType, + TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <: RosenbrockMutableCache + u::uType + uprev::uType + dense::Vector{rateType} + du::rateType + du1::rateType + du2::rateType + ks::Vector{rateType} + fsalfirst::rateType + fsallast::rateType + dT::rateType + J::JType + W::WType + tmp::rateType + atmp::uNoUnitsType + weight::uNoUnitsType + tab::TabType + tf::TFType + uf::UFType + linsolve_tmp::rateType + linsolve::F + jac_config::JCType + grad_config::GCType + reltol::RTolType + alg::A + step_limiter!::StepLimiter + stage_limiter!::StageLimiter +end +function full_cache(c::RosenbrockCache) + return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2, + c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp] +end + @cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType, TabType, TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache @@ -677,89 +711,12 @@ struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConsta autodiff::AD end -@cache mutable struct Rodas4Cache{uType, rateType, uNoUnitsType, JType, WType, TabType, - TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <: - RosenbrockMutableCache - u::uType - uprev::uType - dense1::rateType - dense2::rateType - du::rateType - du1::rateType - du2::rateType - k1::rateType - k2::rateType - k3::rateType - k4::rateType - k5::rateType - k6::rateType - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - reltol::RTolType - alg::A - step_limiter!::StepLimiter - stage_limiter!::StageLimiter -end - -function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - dense1 = zero(rate_prototype) - dense2 = zero(rate_prototype) - du = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - k5 = zero(rate_prototype) - k6 = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rodas4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) +tabtype(::Rodas4) = Rodas4Tableau +tabtype(::Rodas42) = Rodas42Tableau +tabtype(::Rodas4P) = Rodas4PTableau +tabtype(::Rodas4P2) = Rodas4P2Tableau - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4, - k5, k6, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, - alg.stage_limiter!) -end - -function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits}, +function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} @@ -769,189 +726,60 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits}, linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) Rodas4ConstantCache(tf, uf, - Rodas4Tableau(constvalue(uBottomEltypeNoUnits), + tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)), J, W, linsolve, alg_autodiff(alg)) end -function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - dense1 = zero(rate_prototype) - dense2 = zero(rate_prototype) - du = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - k5 = zero(rate_prototype) - k6 = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rodas42Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) +function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2}, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4, - k5, k6, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, - alg.stage_limiter!) -end - -function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rodas4ConstantCache(tf, uf, - Rodas42Tableau(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve, - alg_autodiff(alg)) -end - -function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - dense1 = zero(rate_prototype) - dense2 = zero(rate_prototype) + # Initialize vectors + dense = [zero(rate_prototype) for _ in 1:2] + ks = [zero(rate_prototype) for _ in 1:6] du = zero(rate_prototype) du1 = zero(rate_prototype) du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - k5 = zero(rate_prototype) - k6 = zero(rate_prototype) - fsalfirst = zero(rate_prototype) - fsallast = zero(rate_prototype) - dT = zero(rate_prototype) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) - tmp = zero(rate_prototype) - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - weight = similar(u, uEltypeNoUnits) - recursivefill!(weight, false) - tab = Rodas4PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - tf = TimeGradientWrapper(f, uprev, p) - uf = UJacobianWrapper(f, t, p) - linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) - grad_config = build_grad_config(alg, f, tf, du1, t) - jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4, - k5, k6, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, - alg.stage_limiter!) -end - -function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rodas4ConstantCache(tf, uf, - Rodas4PTableau(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve, - alg_autodiff(alg)) -end - -function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - dense1 = zero(rate_prototype) - dense2 = zero(rate_prototype) - du = zero(rate_prototype) - du1 = zero(rate_prototype) - du2 = zero(rate_prototype) - k1 = zero(rate_prototype) - k2 = zero(rate_prototype) - k3 = zero(rate_prototype) - k4 = zero(rate_prototype) - k5 = zero(rate_prototype) - k6 = zero(rate_prototype) + # Initialize other variables fsalfirst = zero(rate_prototype) fsallast = zero(rate_prototype) dT = zero(rate_prototype) + + # Build J and W matrices J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true)) + + # Temporary and helper variables tmp = zero(rate_prototype) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) weight = similar(u, uEltypeNoUnits) recursivefill!(weight, false) - tab = Rodas4P2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) + Pl, Pr = wrapprecs( alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + nothing)..., weight, tmp) + linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) + Pl = Pl, Pr = Pr, + assumptions = LinearSolve.OperatorAssumptions(true)) + grad_config = build_grad_config(alg, f, tf, du1, t) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) - Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4, - k5, k6, - fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, - linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!, - alg.stage_limiter!) -end -function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, p, calck, - ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tf = TimeDerivativeWrapper(f, u, p) - uf = UDerivativeWrapper(f, t, p) - J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false)) - linprob = nothing #LinearProblem(W,copy(u); u0=copy(u)) - linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true) - Rodas4ConstantCache(tf, uf, - Rodas4P2Tableau(constvalue(uBottomEltypeNoUnits), - constvalue(tTypeNoUnits)), J, W, linsolve, - alg_autodiff(alg)) + # Return the cache struct with vectors + RosenbrockCache( + u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast, + dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp, + linsolve, jac_config, grad_config, reltol, alg, + alg.step_limiter!, alg.stage_limiter!) end ################################################################################ diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl index 6f50205cf1..4d7e659410 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl @@ -5,7 +5,7 @@ ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, Rosenbrock23C Rodas23WConstantCache, Rodas3PConstantCache, Rodas23WCache, Rodas3PCache, Rodas4ConstantCache, Rosenbrock5ConstantCache, - Rodas4Cache, Rosenbrock5Cache} + RosenbrockCache, Rosenbrock5Cache} function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::ROSENBROCKS_WITH_INTERPOLATIONS, @@ -135,14 +135,14 @@ From MATLAB ODE Suite by Shampine end @muladd function _ode_interpolant( - Θ, dt, y₀, y₁, k, cache::Union{Rodas4Cache, Rodas23WCache, Rodas3PCache}, + Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCache, Rodas23WCache, Rodas3PCache}, idxs::Nothing, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ @inbounds @.. broadcast=false Θ1 * y₀+Θ * (y₁ + Θ1 * (k[1] + Θ * k[2])) end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, + cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ @@ -150,7 +150,7 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, + cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs::Nothing, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ @@ -159,7 +159,7 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, + cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs, T::Type{Val{0}}, differential_vars) Θ1 = 1 - Θ @@ -176,14 +176,14 @@ end end @muladd function _ode_interpolant( - Θ, dt, y₀, y₁, k, cache::Union{Rodas4Cache, Rodas23WCache, Rodas3PCache}, + Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCache, Rodas23WCache, Rodas3PCache}, idxs::Nothing, T::Type{Val{1}}, differential_vars) @inbounds @.. broadcast=false (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁)/dt end @muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, + cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs, T::Type{Val{1}}, differential_vars) @.. broadcast=false (k[1][idxs] + @@ -192,7 +192,7 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, + cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs::Nothing, T::Type{Val{1}}, differential_vars) @.. broadcast=false out=(k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) / @@ -201,7 +201,7 @@ end end @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, - cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, + cache::Union{Rodas4ConstantCache, RosenbrockCache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache}, idxs, T::Type{Val{1}}, differential_vars) @views @.. broadcast=false out=(k[1][idxs] + diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index a724abbc52..51c938ee39 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -1207,31 +1207,13 @@ function initialize!(integrator, cache::Rodas4ConstantCache) end @muladd function perform_step!(integrator, cache::Rodas4ConstantCache, repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack tf, uf = cache - @unpack a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, gamma, c2, c3, c4, d1, d2, d3, d4 = cache.tab + (;t, dt, uprev, u, f, p) = integrator + (;tf, uf) = cache + (;A, C, gamma, c, d, H) = cache.tab # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - dtC61 = C61 / dt - dtC62 = C62 / dt - dtC63 = C63 / dt - dtC64 = C64 / dt - dtC65 = C65 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 + dtC = C ./ dt + dtd = dt .* d dtgamma = dt * gamma mass_matrix = integrator.f.mass_matrix @@ -1246,105 +1228,78 @@ end return nothing end + # Initialize ks + num_stages = size(A,1) du = f(uprev, p, t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - linsolve_tmp = du + dtd1 * dT - + linsolve_tmp = @.. du + dtd[1] * dT k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a21 * k1 - du = f(u, p, t + c2 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd2 * dT + dtC21 * k1 - else - linsolve_tmp = du + dtd2 * dT + mass_matrix * (dtC21 * k1) - end - - k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a31 * k1 + a32 * k2 - du = f(u, p, t + c3 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - else - linsolve_tmp = du + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2) - end - - k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a41 * k1 + a42 * k2 + a43 * k3 - du = f(u, p, t + c4 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - linsolve_tmp = du + dtd4 * dT + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - else - linsolve_tmp = du + dtd4 * dT + mass_matrix * (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - end - - k4 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4 - du = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + # constant number for type stability make sure this is greater than num_stages + ks = ntuple(Returns(k1), 10) + # Loop for stages + for stage in 2:num_stages + u = uprev + for i in 1:stage-1 + u = @.. u + A[stage, i] * ks[i] + end - if mass_matrix === I - linsolve_tmp = du + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - else - linsolve_tmp = du + - mass_matrix * (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - end + du = f(u, p, t + c[stage] * dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - k5 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = u + k5 - du = f(u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + # Compute linsolve_tmp for current stage + linsolve_tmp = zero(du) + if mass_matrix === I + for i in 1:stage-1 + linsolve_tmp = @.. linsolve_tmp + dtC[stage, i] * ks[i] + end + else + for i in 1:stage-1 + linsolve_tmp = @.. linsolve_tmp + dtC[stage, i] * ks[i] + end + linsolve_tmp = mass_matrix * linsolve_tmp + end + linsolve_tmp = @.. du + dtd[stage] * dT + linsolve_tmp - if mass_matrix === I - linsolve_tmp = du + (dtC61 * k1 + dtC62 * k2 + dtC65 * k5 + dtC64 * k4 + dtC63 * k3) - else - linsolve_tmp = du + - mass_matrix * - (dtC61 * k1 + dtC62 * k2 + dtC65 * k5 + dtC64 * k4 + dtC63 * k3) + ks = Base.setindex(ks, _reshape(W \ -_vec(linsolve_tmp), axes(uprev)), stage) + integrator.stats.nsolve += 1 end - - k6 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = u + k6 + #@show ks + u = u .+ ks[num_stages] if integrator.opts.adaptive - atmp = calculate_residuals(k6, uprev, u, integrator.opts.abstol, + atmp = calculate_residuals(ks[num_stages], uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) integrator.EEst = integrator.opts.internalnorm(atmp, t) end if integrator.opts.calck - @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35 = cache.tab - integrator.k[1] = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 - integrator.k[2] = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 + for j in eachindex(integrator.k) + integrator.k[j] = zero(integrator.k[1]) + end + for i in 1:num_stages + for j in eachindex(integrator.k) + integrator.k[j] = @.. integrator.k[j] + H[j, i] * ks[i] + end + end end + integrator.u = u return nothing end -function initialize!(integrator, cache::Rodas4Cache) +function initialize!(integrator, cache::RosenbrockCache) + dense = cache.dense + dense1, dense2 = dense[1], dense[2] integrator.kshortsize = 2 - @unpack dense1, dense2 = cache resize!(integrator.k, integrator.kshortsize) integrator.k[1] = dense1 integrator.k[2] = dense2 end -@muladd function perform_step!(integrator, cache::Rodas4Cache, repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack du, du1, du2, dT, J, W, uf, tf, k1, k2, k3, k4, k5, k6, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache - @unpack a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, gamma, c2, c3, c4, d1, d2, d3, d4 = cache.tab + +@muladd function perform_step!(integrator, cache::RosenbrockCache, repeat_step = false) + (;t, dt, uprev, u, f, p) = integrator + (;du, du1, du2, dT, J, W, uf, tf, ks, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter!) = cache + (;A, C, gamma, c, d, H) = cache.tab # Assignments sizeu = size(u) @@ -1352,32 +1307,14 @@ end mass_matrix = integrator.f.mass_matrix # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - dtC61 = C61 / dt - dtC62 = C62 / dt - dtC63 = C63 / dt - dtC64 = C64 / dt - dtC65 = C65 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 + dtC = C .* inv(dt) + dtd = dt .* d dtgamma = dt * gamma - f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation! + f(cache.fsalfirst, uprev, p, t) OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true) + calc_rosenbrock_differentiation!(integrator, cache, dtd[1], dtgamma, repeat_step, true) calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, integrator.opts.abstol, integrator.opts.reltol, @@ -1394,114 +1331,56 @@ end solverdata = (; gamma = dtgamma)) end - @.. broadcast=false $(_vec(k1))=-linres.u - - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a21 * k1 - stage_limiter!(u, integrator, p, t + c2 * dt) - f(du, u, p, t + c2 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 - else - @.. broadcast=false du1=dtC21 * k1 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - @.. broadcast=false $(_vec(k2))=-linres.u - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a31 * k1 + a32 * k2 - stage_limiter!(u, integrator, p, t + c3 * dt) - f(du, u, p, t + c3 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - else - @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - @.. broadcast=false $(_vec(k3))=-linres.u - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a41 * k1 + a42 * k2 + a43 * k3 - stage_limiter!(u, integrator, p, t + c4 * dt) - f(du, u, p, t + c4 * dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + - (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - else - @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - @.. broadcast=false $(_vec(k4))=-linres.u + @.. $(_vec(ks[1]))=-linres.u integrator.stats.nsolve += 1 - @.. broadcast=false u=uprev + a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4 - stage_limiter!(u, integrator, p, t + dt) - f(du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + - (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - else - @.. broadcast=false du1=dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + du2 - end + for stage in 2:length(ks) + u .= uprev + for i in 1:stage-1 + @.. u += A[stage, i] * ks[i] + end - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - @.. broadcast=false $(_vec(k5))=-linres.u - integrator.stats.nsolve += 1 + stage_limiter!(u, integrator, p, t + c[stage] * dt) + f(du, u, p, t + c[stage] * dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - u .+= k5 - f(du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + du1 .= 0 + if mass_matrix === I + for i in 1:stage-1 + @.. du1 += dtC[stage, i] * ks[i] + end + else + for i in 1:stage-1 + @.. du1 += dtC[stage, i] * ks[i] + end + mul!(_vec(du2), mass_matrix, _vec(du1)) + du1 .= du2 + end + @.. linsolve_tmp = du + dtd[stage] * dT + du1 - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC61 * k1 + dtC62 * k2 + dtC65 * k5 + - dtC64 * k4 + dtC63 * k3) - else - @.. broadcast=false du1=dtC61 * k1 + dtC62 * k2 + dtC65 * k5 + dtC64 * k4 + - dtC63 * k3 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + du2 + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + @.. $(_vec(ks[stage]))=-linres.u + integrator.stats.nsolve += 1 end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - @.. broadcast=false $(_vec(k6))=-linres.u - integrator.stats.nsolve += 1 - - u .+= k6 + u .+= ks[end] step_limiter!(u, integrator, p, t + dt) if integrator.opts.adaptive - calculate_residuals!(atmp, k6, uprev, u, integrator.opts.abstol, + calculate_residuals!(atmp, ks[end], uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) integrator.EEst = integrator.opts.internalnorm(atmp, t) end if integrator.opts.calck - @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35 = cache.tab - @.. broadcast=false integrator.k[1]=h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + - h25 * k5 - @.. broadcast=false integrator.k[2]=h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + - h35 * k5 + for j in eachindex(integrator.k) + integrator.k[j] .= 0 + end + for i in eachindex(ks) + for j in eachindex(integrator.k) + @.. integrator.k[j] += H[j, i] * ks[i] + end + end end cache.linsolve = linres.cache end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl index 493c67111a..19e7a989d7 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl @@ -223,161 +223,66 @@ end @Rosenbrock4(:tableau) struct RodasTableau{T, T2} - a21::T - a31::T - a32::T - a41::T - a42::T - a43::T - a51::T - a52::T - a53::T - a54::T - C21::T - C31::T - C32::T - C41::T - C42::T - C43::T - C51::T - C52::T - C53::T - C54::T - C61::T - C62::T - C63::T - C64::T - C65::T + A::Matrix{T} + C::Matrix{T} gamma::T - c2::T2 - c3::T2 - c4::T2 - d1::T - d2::T - d3::T - d4::T - h21::T - h22::T - h23::T - h24::T - h25::T - h31::T - h32::T - h33::T - h34::T - h35::T + c::Vector{T2} + d::Vector{T} + H::Matrix{T} end + function Rodas4Tableau(T, T2) gamma = convert(T, 1 // 4) #BET2P=0.0317D0 #BET3P=0.0635D0 #BET4P=0.3438D0 - a21 = convert(T, 1.544000000000000) - a31 = convert(T, 0.9466785280815826) - a32 = convert(T, 0.2557011698983284) - a41 = convert(T, 3.314825187068521) - a42 = convert(T, 2.896124015972201) - a43 = convert(T, 0.9986419139977817) - a51 = convert(T, 1.221224509226641) - a52 = convert(T, 6.019134481288629) - a53 = convert(T, 12.53708332932087) - a54 = -convert(T, 0.6878860361058950) - C21 = -convert(T, 5.668800000000000) - C31 = -convert(T, 2.430093356833875) - C32 = -convert(T, 0.2063599157091915) - C41 = -convert(T, 0.1073529058151375) - C42 = -convert(T, 9.594562251023355) - C43 = -convert(T, 20.47028614809616) - C51 = convert(T, 7.496443313967647) - C52 = -convert(T, 10.24680431464352) - C53 = -convert(T, 33.99990352819905) - C54 = convert(T, 11.70890893206160) - C61 = convert(T, 8.083246795921522) - C62 = -convert(T, 7.981132988064893) - C63 = -convert(T, 31.52159432874371) - C64 = convert(T, 16.31930543123136) - C65 = -convert(T, 6.058818238834054) - - c2 = convert(T2, 0.386) - c3 = convert(T2, 0.21) - c4 = convert(T2, 0.63) - - d1 = convert(T, 0.2500000000000000) - d2 = -convert(T, 0.1043000000000000) - d3 = convert(T, 0.1035000000000000) - d4 = -convert(T, 0.03620000000000023) - - h21 = convert(T, 10.12623508344586) - h22 = -convert(T, 7.487995877610167) - h23 = -convert(T, 34.80091861555747) - h24 = -convert(T, 7.992771707568823) - h25 = convert(T, 1.025137723295662) - h31 = -convert(T, 0.6762803392801253) - h32 = convert(T, 6.087714651680015) - h33 = convert(T, 16.43084320892478) - h34 = convert(T, 24.76722511418386) - h35 = -convert(T, 6.594389125716872) - - RodasTableau(a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, - C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, - gamma, c2, c3, c4, d1, d2, d3, d4, - h21, h22, h23, h24, h25, h31, h32, h33, h34, h35) + A = T[ + 0 0 0 0 0 0 + 1.544 0 0 0 0 0 + 0.9466785280815826 0.2557011698983284 0 0 0 0 + 3.314825187068521 2.896124015972201 0.9986419139977817 0 0 0 + 1.221224509226641 6.019134481288629 12.53708332932087 -0.6878860361058950 0 0 + 1.221224509226641 6.019134481288629 12.53708332932087 -0.6878860361058950 1 0 + ] + C = T[ + 0 0 0 0 0 + -5.6688 0 0 0 0 + -2.430093356833875 -0.2063599157091915 0 0 0 + -0.1073529058151375 -9.594562251023355 -20.47028614809616 0 0 + 7.496443313967647 -10.24680431464352 -33.99990352819905 11.70890893206160 0 + 8.083246795921522 -7.981132988064893 -31.52159432874371 16.31930543123136 -6.058818238834054 + ] + c = T2[0, 0.386, 0.21, 0.63, 1, 1] + d = T[0.25, -0.1043, 0.1035, -0.0362, 0, 0] + H = T[10.12623508344586 -7.487995877610167 -34.80091861555747 -7.992771707568823 1.025137723295662 0 + -0.6762803392801253 6.087714651680015 16.43084320892478 24.76722511418386 -6.594389125716872 0] + RodasTableau(A, C, gamma, c, d, H) end function Rodas42Tableau(T, T2) gamma = convert(T, 1 // 4) - #BET2P=0.0317D0 - #BET3P=0.0047369D0 - #BET4P=0.3438D0 - a21 = convert(T, 1.402888400000000) - a31 = convert(T, 0.6581212688557198) - a32 = -convert(T, 1.320936088384301) - a41 = convert(T, 7.131197445744498) - a42 = convert(T, 16.02964143958207) - a43 = -convert(T, 5.561572550509766) - a51 = convert(T, 22.73885722420363) - a52 = convert(T, 67.38147284535289) - a53 = -convert(T, 31.21877493038560) - a54 = convert(T, 0.7285641833203814) - C21 = -convert(T, 5.104353600000000) - C31 = -convert(T, 2.899967805418783) - C32 = convert(T, 4.040399359702244) - C41 = -convert(T, 32.64449927841361) - C42 = -convert(T, 99.35311008728094) - C43 = convert(T, 49.99119122405989) - C51 = -convert(T, 76.46023087151691) - C52 = -convert(T, 278.5942120829058) - C53 = convert(T, 153.9294840910643) - C54 = convert(T, 10.97101866258358) - C61 = -convert(T, 76.29701586804983) - C62 = -convert(T, 294.2795630511232) - C63 = convert(T, 162.0029695867566) - C64 = convert(T, 23.65166903095270) - C65 = -convert(T, 7.652977706771382) - c2 = convert(T2, 0.3507221) - c3 = convert(T2, 0.2557041) - c4 = convert(T2, 0.6817790) - d1 = convert(T, 0.2500000000000000) - d2 = -convert(T, 0.06902209999999998) - d3 = -convert(T, 0.0009671999999999459) - d4 = -convert(T, 0.08797900000000025) - - h21 = -convert(T, 38.71940424117216) - h22 = -convert(T, 135.8025833007622) - h23 = convert(T, 64.51068857505875) - h24 = -convert(T, 4.192663174613162) - h25 = -convert(T, 2.531932050335060) - h31 = -convert(T, 14.99268484949843) - h32 = -convert(T, 76.30242396627033) - h33 = convert(T, 58.65928432851416) - h34 = convert(T, 16.61359034616402) - h35 = -convert(T, 0.6758691794084156) - - RodasTableau(a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, - C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, - gamma, c2, c3, c4, d1, d2, d3, d4, - h21, h22, h23, h24, h25, h31, h32, h33, h34, h35) + A = T[ + 0.0 0 0 0 0 0 + 1.4028884 0 0 0 0 0 + 0.6581212688557198 -1.320936088384301 0 0 0 0 + 7.131197445744498 16.02964143958207 -5.561572550509766 0 0 0 + 22.73885722420363 67.38147284535289 -31.21877493038560 0.7285641833203814 0 0 + 22.73885722420363 67.38147284535289 -31.21877493038560 0.7285641833203814 1 0 + ] + C = T[ + 0 0 0 0 0 + -5.1043536 0 0 0 0 + -2.899967805418783 4.040399359702244 0 0 0 + -32.64449927841361 -99.35311008728094 49.99119122405989 0 0 + -76.46023087151691 -278.5942120829058 153.9294840910643 10.97101866258358 0 + -76.29701586804983 -294.2795630511232 162.0029695867566 23.65166903095270 -7.652977706771382 + ] + c = T2[0, 0.3507221, 0.2557041, 0.681779, 1, 1] + d = T[0.25, -0.0690221, -0.0009672, -0.087979, 0, 0] + H = T[-38.71940424117216 -135.8025833007622 64.51068857505875 -4.192663174613162 -2.531932050335060 0 + -14.99268484949843 -76.30242396627033 58.65928432851416 16.61359034616402 -0.6758691794084156 0] + RodasTableau(A, C, gamma, c, d, H) end function Rodas4PTableau(T, T2) @@ -385,108 +290,52 @@ function Rodas4PTableau(T, T2) #BET2P=0.D0 #BET3P=c3*c3*(c3/6.d0-GAMMA/2.d0)/(GAMMA*GAMMA) #BET4P=0.3438D0 - a21 = convert(T, 3) - a31 = convert(T, 1.831036793486759) - a32 = convert(T, 0.4955183967433795) - a41 = convert(T, 2.304376582692669) - a42 = -convert(T, 0.05249275245743001) - a43 = -convert(T, 1.176798761832782) - a51 = -convert(T, 7.170454962423024) - a52 = -convert(T, 4.741636671481785) - a53 = -convert(T, 16.31002631330971) - a54 = -convert(T, 1.062004044111401) - C21 = -convert(T, 12) - C31 = -convert(T, 8.791795173947035) - C32 = -convert(T, 2.207865586973518) - C41 = convert(T, 10.81793056857153) - C42 = convert(T, 6.780270611428266) - C43 = convert(T, 19.53485944642410) - C51 = convert(T, 34.19095006749676) - C52 = convert(T, 15.49671153725963) - C53 = convert(T, 54.74760875964130) - C54 = convert(T, 14.16005392148534) - C61 = convert(T, 34.62605830930532) - C62 = convert(T, 15.30084976114473) - C63 = convert(T, 56.99955578662667) - C64 = convert(T, 18.40807009793095) - C65 = -convert(T, 5.714285714285717) - c2 = convert(T2, 3 * gamma) - c3 = convert(T2, 0.21) - c4 = convert(T2, 0.63) - d1 = convert(T, 0.2500000000000000) - d2 = convert(T, -0.5000000000000000) - d3 = convert(T, -0.0235040000000000) - d4 = convert(T, -0.0362000000000000) - - h21 = convert(T, 25.09876703708589) - h22 = convert(T, 11.62013104361867) - h23 = convert(T, 28.49148307714626) - h24 = -convert(T, 5.664021568594133) - h25 = convert(T, 0) - h31 = convert(T, 1.638054557396973) - h32 = -convert(T, 0.7373619806678748) - h33 = convert(T, 8.477918219238990) - h34 = convert(T, 15.99253148779520) - h35 = -convert(T, 1.882352941176471) - - RodasTableau(a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, - C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, - gamma, c2, c3, c4, d1, d2, d3, d4, - h21, h22, h23, h24, h25, h31, h32, h33, h34, h35) + A = T[ + 0 0 0 0 0 0 + 3 0 0 0 0 0 + 1.831036793486759 0.4955183967433795 0 0 0 0 + 2.304376582692669 -0.05249275245743001 -1.176798761832782 0 0 0 + -7.170454962423024 -4.741636671481785 -16.31002631330971 -1.062004044111401 0 0 + -7.170454962423024 -4.741636671481785 -16.31002631330971 -1.062004044111401 1 0 + ] + C = T[ + 0 0 0 0 0 + -12 0 0 0 0 + -8.791795173947035 -2.207865586973518 0 0 0 + 10.81793056857153 6.780270611428266 19.53485944642410 0 0 + 34.19095006749676 15.49671153725963 54.74760875964130 14.16005392148534 0 + 34.62605830930532 15.30084976114473 56.99955578662667 18.40807009793095 -5.714285714285717 + ] + c = T2[0, 0.75, 0.21, 0.63, 1, 1] + d = T[0.25, -0.5, -0.023504, -0.0362, 0, 0] + H = T[25.09876703708589 11.62013104361867 28.49148307714626 -5.664021568594133 0 0 + 1.638054557396973 -0.7373619806678748 8.477918219238990 15.99253148779520 -1.882352941176471 0] + RodasTableau(A, C, gamma, c, d, H) end function Rodas4P2Tableau(T, T2) gamma = convert(T, 1 // 4) - a21 = convert(T, 3.000000000000000) - a31 = convert(T, 0.906377755268814) - a32 = -convert(T, 0.189707390391685) - a41 = convert(T, 3.758617027739064) - a42 = convert(T, 1.161741776019525) - a43 = -convert(T, 0.849258085312803) - a51 = convert(T, 7.089566927282776) - a52 = convert(T, 4.573591406461604) - a53 = -convert(T, 8.423496976860259) - a54 = -convert(T, 0.959280113459775) - - C21 = convert(T, -12.00000000000000) - C31 = convert(T, -6.354581592719008) - C32 = convert(T, 0.338972550544623) - C41 = convert(T, -8.575016317114033) - C42 = convert(T, -7.606483992117508) - C43 = convert(T, 12.224997650124820) - C51 = convert(T, -5.888975457523102) - C52 = convert(T, -8.157396617841821) - C53 = convert(T, 24.805546872612922) - C54 = convert(T, 12.790401512796979) - C61 = convert(T, -4.408651676063871) - C62 = convert(T, -6.692003137674639) - C63 = convert(T, 24.625568527593117) - C64 = convert(T, 16.627521966636085) - C65 = convert(T, -5.714285714285718) - - c2 = convert(T2, 0.750000000000000) - c3 = convert(T2, 0.321448134013046) - c4 = convert(T2, 0.519745732277726) - d1 = convert(T, 0.250000000000000) - d2 = convert(T, -0.500000000000000) - d3 = convert(T, -0.189532918363016) - d4 = convert(T, 0.085612108792769) - - h21 = convert(T, -5.323528268423303) - h22 = convert(T, -10.042123754867493) - h23 = convert(T, 17.175254928256965) - h24 = convert(T, -5.079931171878093) - h25 = convert(T, -0.016185991706112) - h31 = convert(T, 6.984505741529879) - h32 = convert(T, 6.914061169603662) - h33 = convert(T, -0.849178943070653) - h34 = convert(T, 18.104410789349338) - h35 = convert(T, -3.516963011559032) - - RodasTableau(a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, - C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, - gamma, c2, c3, c4, d1, d2, d3, d4, - h21, h22, h23, h24, h25, h31, h32, h33, h34, h35) + A = T[ + 0 0 0 0 0 0 + 3 0 0 0 0 0 + 0.906377755268814 -0.189707390391685 0 0 0 0 + 3.758617027739064 1.161741776019525 -0.849258085312803 0 0 0 + 7.089566927282776 4.573591406461604 -8.423496976860259 -0.959280113459775 0 0 + 7.089566927282776 4.573591406461604 -8.423496976860259 -0.959280113459775 1 0 + ] + C = T[ + 0 0 0 0 0 + -12 0 0 0 0 + -6.354581592719008 0.338972550544623 0 0 0 + -8.575016317114033 -7.606483992117508 12.224997650124820 0 0 + -5.888975457523102 -8.157396617841821 24.805546872612922 12.790401512796979 0 + -4.408651676063871 -6.692003137674639 24.625568527593117 16.627521966636085 -5.714285714285718 + ] + c = T2[0, 0.75, 0.321448134013046, 0.519745732277726, 1, 1] + d = T[0.25, -0.5, -0.189532918363016, 0.085612108792769, 0, 0] + H = [-5.323528268423303 -10.042123754867493 17.175254928256965 -5.079931171878093 -0.016185991706112 0 + 6.984505741529879 6.914061169603662 -0.849178943070653 18.104410789349338 -3.516963011559032 0] + RodasTableau(A, C, gamma, c, d, H) end struct Rodas5Tableau{T, T2} diff --git a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl index 6e77125e14..57a643afc6 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl @@ -291,30 +291,12 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache, always_calc_begin = false, allow_calc_end = true, force_calc_end = false) if length(k) < 2 || always_calc_begin - @unpack tf, uf = cache - @unpack a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, gamma, c2, c3, c4, d1, d2, d3, d4 = cache.tab + (;tf, uf) = cache + (;A, C, gamma, c, d, H) = cache.tab # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - dtC61 = C61 / dt - dtC62 = C62 / dt - dtC63 = C63 / dt - dtC64 = C64 / dt - dtC65 = C65 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 + dtC = C ./ dt + dtd = dt .* d dtgamma = dt * gamma mass_matrix = f.mass_matrix @@ -335,52 +317,60 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4ConstantCache, J = ForwardDiff.derivative(uf, uprev) W = 1 / dtgamma - J end + + + num_stages = size(A,1) + du = f(u, p, t) + linsolve_tmp = @.. du + dtd[1] * dT + k1 = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) + # constant number for type stability make sure this is greater than num_stages + ks = ntuple(Returns(k1), 10) + # Last stage doesn't affect ks + for stage in 2:num_stages-1 + u = uprev + for i in 1:stage-1 + u = @.. u + A[stage, i] * ks[i] + end + + du = f(u, p, t + c[stage] * dt) + + # Compute linsolve_tmp for current stage + linsolve_tmp = zero(du) + if mass_matrix === I + for i in 1:stage-1 + linsolve_tmp = @.. linsolve_tmp + dtC[stage, i] * ks[i] + end + else + for i in 1:stage-1 + linsolve_tmp = @.. linsolve_tmp + dtC[stage, i] * ks[i] + end + linsolve_tmp = mass_matrix * linsolve_tmp + end + linsolve_tmp = @.. du + dtd[stage] * dT + linsolve_tmp + ks = Base.setindex(ks, _reshape(W \ _vec(linsolve_tmp), axes(uprev)), stage) + end - du = f(uprev, p, t) - - linsolve_tmp = du + dtd1 * dT - - k1 = W \ linsolve_tmp - u = uprev + a21 * k1 - du = f(u, p, t + c2 * dt) - - linsolve_tmp = du + dtd2 * dT + dtC21 * k1 - - k2 = W \ linsolve_tmp - u = uprev + a31 * k1 + a32 * k2 - du = f(u, p, t + c3 * dt) - - linsolve_tmp = du + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - - k3 = W \ linsolve_tmp - u = uprev + a41 * k1 + a42 * k2 + a43 * k3 - du = f(u, p, t + c4 * dt) - - linsolve_tmp = du + dtd4 * dT + (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - - k4 = W \ linsolve_tmp - u = uprev + a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4 - du = f(u, p, t + dt) - - linsolve_tmp = du + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3) - - k5 = W \ linsolve_tmp + k1 = zero(ks[1]) + k2 = zero(ks[1]) + H = cache.tab.H + # Last stage doesn't affect ks + for i in 1:num_stages-1 + k1 = @.. k1 + H[1, i] * ks[i] + k2 = @.. k2 + H[2, i] * ks[i] + end - @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35 = cache.tab - k₁ = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 - k₂ = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 - copyat_or_push!(k, 1, k₁) - copyat_or_push!(k, 2, k₂) + copyat_or_push!(k, 1, k1) + copyat_or_push!(k, 2, k2) end nothing end -function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4Cache, +function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::RosenbrockCache, always_calc_begin = false, allow_calc_end = true, force_calc_end = false) if length(k) < 2 || always_calc_begin - @unpack du, du1, du2, tmp, k1, k2, k3, k4, k5, k6, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst, weight = cache - @unpack a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, C61, C62, C63, C64, C65, gamma, c2, c3, c4, d1, d2, d3, d4 = cache.tab + (;du, du1, du2, tmp, ks, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst, weight) = cache + (;A, C, gamma, c, d , H) = cache.tab # Assignments sizeu = size(u) @@ -388,114 +378,52 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rodas4Cache, mass_matrix = f.mass_matrix # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - dtC51 = C51 / dt - dtC52 = C52 / dt - dtC53 = C53 / dt - dtC54 = C54 / dt - dtC61 = C61 / dt - dtC62 = C62 / dt - dtC63 = C63 / dt - dtC64 = C64 / dt - dtC65 = C65 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 + dtC = C ./ dt + dtd = dt .* d dtgamma = dt * gamma - @.. broadcast=false linsolve_tmp=@muladd fsalfirst + dtgamma * dT + @.. linsolve_tmp=@muladd fsalfirst + dtgamma * dT - ### Jacobian does not need to be re-evaluated after an event - ### Since it's unchanged + # Jacobian does not need to be re-evaluated after an event since it's unchanged jacobian2W!(W, mass_matrix, dtgamma, J, true) linsolve = cache.linsolve - linres = dolinsolve(cache, linsolve; A = W, b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck1 = _vec(k1) - - @.. broadcast=false veck1=-vecu - @.. broadcast=false tmp=uprev + a21 * k1 - f(du, tmp, p, t + c2 * dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 - else - @.. broadcast=false du1=dtC21 * k1 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + dtd2 * dT + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck2 = _vec(k2) - @.. broadcast=false veck2=-vecu - @.. broadcast=false tmp=uprev + a31 * k1 + a32 * k2 - f(du, tmp, p, t + c3 * dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + (dtC31 * k1 + dtC32 * k2) - else - @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + du2 - end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck3 = _vec(k3) - @.. broadcast=false veck3=-vecu - @.. broadcast=false tmp=uprev + a41 * k1 + a42 * k2 + a43 * k3 - f(du, tmp, p, t + c4 * dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + - (dtC41 * k1 + dtC42 * k2 + dtC43 * k3) - else - @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + du2 + linres = dolinsolve(cache, linsolve; A = W, b = _vec(linsolve_tmp), reltol = cache.reltol) + @.. $(_vec(ks[1]))=-linres.u + # Last stage doesn't affect ks + for stage in 2:length(ks)-1 + tmp .= uprev + for i in 1:stage-1 + @.. tmp += A[stage, i] * _vec(ks[i]) + end + f(du, tmp, p, t + c[stage] * dt) + + if mass_matrix === I + @.. linsolve_tmp = du + dtd[stage] * dT + for i in 1:stage-1 + @.. linsolve_tmp += dtC[stage, i] * _vec(ks[i]) + end + else + du1 .= du + for i in 1:stage-1 + @.. du1 += dtC[stage, i] * _vec(ks[i]) + end + mul!(_vec(du2), mass_matrix, _vec(du1)) + @.. linsolve_tmp = du + dtd[stage] * dT + du2 + end + + linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), reltol = cache.reltol) + @.. $(_vec(ks[stage]))=-linres.u end - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck4 = _vec(k4) - @.. broadcast=false veck4=-vecu - @.. broadcast=false tmp=uprev + a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4 - f(du, tmp, p, t + dt) - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + (dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + - dtC53 * k3) - else - @.. broadcast=false du1=dtC52 * k2 + dtC54 * k4 + dtC51 * k1 + dtC53 * k3 - mul!(du2, mass_matrix, du1) - @.. broadcast=false linsolve_tmp=du + du2 + copyat_or_push!(k, 1, zero(du)) + copyat_or_push!(k, 2, zero(du)) + # Last stage doesn't affect ks + for i in 1:length(ks)-1 + @.. k[1] += H[1, i] * _vec(ks[i]) + @.. k[2] += H[2, i] * _vec(ks[i]) end - - linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp), - reltol = cache.reltol) - vecu = _vec(linres.u) - veck5 = _vec(k5) - @.. broadcast=false veck5=-vecu - @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35 = cache.tab - @.. broadcast=false k6=h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 - copyat_or_push!(k, 1, copy(k6)) - - @.. broadcast=false k6=h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 - copyat_or_push!(k, 2, copy(k6)) end nothing end