Skip to content

Commit

Permalink
New methods Rodas23W/Rodas3P with error control for interpolation
Browse files Browse the repository at this point in the history
See issue
SciML#2054
  • Loading branch information
gstein3m committed Dec 22, 2023
1 parent 027741a commit 81ce811
Show file tree
Hide file tree
Showing 11 changed files with 1,182 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ export MagnusMidpoint, LinearExponential, MagnusLeapfrog, LieEuler, CayleyEuler,
MagnusAdapt4, RKMK2, RKMK4, LieRK4, CG2, CG3, CG4a

export Rosenbrock23, Rosenbrock32, RosShamp4, Veldd4, Velds4, GRK4T, GRK4A,
Ros4LStab, ROS3P, Rodas3, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P,
Ros4LStab, ROS3P, Rodas3, Rodas23W, Rodas3P, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P,
RosenbrockW6S4OS, ROS34PW1a, ROS34PW1b, ROS34PW2, ROS34PW3

export LawsonEuler, NorsettEuler, ETD1, ETDRK2, ETDRK3, ETDRK4, HochOst4, Exp4, EPIRK4s3A,
Expand Down
4 changes: 4 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ isfsal(tab::DiffEqBase.ExplicitRKTableau) = tab.fsal

# isfsal(alg::CompositeAlgorithm) = isfsal(alg.algs[alg.current])
isfsal(alg::FunctionMap) = false
isfsal(alg::Rodas3P) = false
isfsal(alg::Rodas23W) = false
isfsal(alg::Rodas5) = false
isfsal(alg::Rodas5P) = false
isfsal(alg::Rodas4) = false
Expand Down Expand Up @@ -620,9 +622,11 @@ alg_order(alg::Feagin14) = 14
alg_order(alg::PFRK87) = 8

alg_order(alg::Rosenbrock23) = 2
alg_order(alg::Rodas23W) = 3
alg_order(alg::Rosenbrock32) = 3
alg_order(alg::ROS3P) = 3
alg_order(alg::Rodas3) = 3
alg_order(alg::Rodas3P) = 3
alg_order(alg::ROS34PW1a) = 3
alg_order(alg::ROS34PW1b) = 3
alg_order(alg::ROS34PW2) = 3
Expand Down
13 changes: 12 additions & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2910,6 +2910,11 @@ Scientific Computing, 18 (1), pp. 1-22.
- Kaps, P. & Rentrop, Generalized Runge-Kutta methods of order four with stepsize control
for stiff ordinary differential equations. P. Numer. Math. (1979) 33: 55. doi:10.1007/BF01396495
#### Rodas23W, Rodas3P
- Steinebach G., Rodas23W / Rodas32P - a Rosenbrock-type method for DAEs with additional error estimate for dense output and Julia implementation,
in progress
#### Rodas4P
- Steinebach G. Order-reduction of ROW-methods for DAEs and method of lines
Expand All @@ -2921,10 +2926,14 @@ Scientific Computing, 18 (1), pp. 1-22.
Differential-Algebraic Equations Forum. Springer, Cham. https://doi.org/10.1007/978-3-030-53905-4_6
#### Rodas5
- Di Marzo G. RODAS5(4) – Méthodes de Rosenbrock d’ordre 5(4) adaptées aux problemes
différentiels-algébriques. MSc mathematics thesis, Faculty of Science,
University of Geneva, Switzerland.
#### Rodas5P
- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package.
In: BIT Numerical Mathematics, 63(2), 2023
=#

for Alg in [
Expand All @@ -2942,6 +2951,8 @@ for Alg in [
:GRK4T,
:GRK4A,
:Ros4LStab,
:Rodas23W,
:Rodas3P,
:Rodas4,
:Rodas42,
:Rodas4P,
Expand Down
209 changes: 209 additions & 0 deletions src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,215 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W)

### Rodas methods

struct Rodas23WConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
autodiff::AD
end

struct Rodas3PConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
autodiff::AD
end

@cache mutable struct Rodas23WCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A} <:
RosenbrockMutableCache
u::uType
uprev::uType
dense1::rateType
dense2::rateType
dense3::rateType
du::rateType
du1::rateType
du2::rateType
k1::rateType
k2::rateType
k3::rateType
k4::rateType
k5::rateType
fsalfirst::rateType
fsallast::rateType
dT::rateType
J::JType
W::WType
tmp::rateType
atmp::uNoUnitsType
weight::uNoUnitsType
tab::TabType
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve::F
jac_config::JCType
grad_config::GCType
reltol::RTolType
alg::A
end

@cache mutable struct Rodas3PCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A} <:
RosenbrockMutableCache
u::uType
uprev::uType
dense1::rateType
dense2::rateType
dense3::rateType
du::rateType
du1::rateType
du2::rateType
k1::rateType
k2::rateType
k3::rateType
k4::rateType
k5::rateType
fsalfirst::rateType
fsallast::rateType
dT::rateType
J::JType
W::WType
tmp::rateType
atmp::uNoUnitsType
weight::uNoUnitsType
tab::TabType
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve::F
jac_config::JCType
grad_config::GCType
reltol::RTolType
alg::A
end

function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
dense1 = zero(rate_prototype)
dense2 = zero(rate_prototype)
dense3 = zero(rate_prototype)
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = zero(rate_prototype)
k4 = zero(rate_prototype)
k5 = zero(rate_prototype)
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = Rodas3PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))

tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rodas23WCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg)
end

TruncatedStacktraces.@truncate_stacktrace Rodas23WCache 1
function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
dense1 = zero(rate_prototype)
dense2 = zero(rate_prototype)
dense3 = zero(rate_prototype)
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = zero(rate_prototype)
k4 = zero(rate_prototype)
k5 = zero(rate_prototype)
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = Rodas3PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))

tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rodas3PCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg)
end

TruncatedStacktraces.@truncate_stacktrace Rodas3PCache 1

function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tf = TimeDerivativeWrapper(f, u, p)
uf = UDerivativeWrapper(f, t, p)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
Rodas23WConstantCache(tf, uf,
Rodas3PTableau(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve,
alg_autodiff(alg))
end

function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tf = TimeDerivativeWrapper(f, u, p)
uf = UDerivativeWrapper(f, t, p)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
Rodas3PConstantCache(tf, uf,
Rodas3PTableau(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve,
alg_autodiff(alg))
end

### Rodas4 methods

struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqConstantCache
tf::TF
uf::UF
Expand Down
35 changes: 18 additions & 17 deletions src/dense/rosenbrock_interpolants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,34 +113,35 @@ end
"""
From MATLAB ODE Suite by Shampine
"""
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rodas4ConstantCache,
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache},
idxs::Nothing, T::Type{Val{0}})
Θ1 = 1 - Θ
@inbounds Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * k[2]))
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rodas4Cache, idxs::Nothing,
T::Type{Val{0}})
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Union{Rodas4Cache, Rodas23WCache, Rodas3PCache},
idxs::Nothing, T::Type{Val{0}})
Θ1 = 1 - Θ
@inbounds @.. broadcast=false Θ1 * y₀+Θ * (y₁ + Θ1 * (k[1] + Θ * k[2]))
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{Rodas4ConstantCache, Rodas4Cache}, idxs,
T::Type{Val{0}})
cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs, T::Type{Val{0}})
Θ1 = 1 - Θ
@.. broadcast=false Θ1 * y₀[idxs]+Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs]))
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rodas4ConstantCache, Rodas4Cache},
cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs::Nothing, T::Type{Val{0}})
Θ1 = 1 - Θ
@.. broadcast=false out=Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * k[2]))
out
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache::Rodas4Cache{<:Array},
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rodas4Cache{<:Array}, Rodas23WCache{<:Array}, Rodas3PCache{<:Array}},
idxs::Nothing, T::Type{Val{0}})
Θ1 = 1 - Θ
@inbounds @simd ivdep for i in eachindex(out)
Expand All @@ -150,45 +151,45 @@ end
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rodas4ConstantCache, Rodas4Cache}, idxs,
T::Type{Val{0}})
cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs, T::Type{Val{0}})
Θ1 = 1 - Θ
@views @.. broadcast=false out=Θ1 * y₀[idxs] +
Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs]))
out
end

# First Derivative
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rodas4ConstantCache,
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache},
idxs::Nothing, T::Type{Val{1}})
@inbounds (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) / dt
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Rodas4Cache, idxs::Nothing,
T::Type{Val{1}})
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, cache::Union{Rodas4Cache, Rodas23WCache, Rodas3PCache},
idxs::Nothing, T::Type{Val{1}})
@inbounds @.. broadcast=false (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ +
y₁)/dt
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{Rodas4ConstantCache, Rodas4Cache}, idxs,
T::Type{Val{1}})
cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs, T::Type{Val{1}})
@.. broadcast=false (k[1][idxs] +
Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] - 3 * k[2][idxs] * Θ) -
y₀[idxs] + y₁[idxs])/dt
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rodas4ConstantCache, Rodas4Cache},
cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs::Nothing, T::Type{Val{1}})
@.. broadcast=false out=(k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) /
dt
out
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rodas4ConstantCache, Rodas4Cache}, idxs,
T::Type{Val{1}})
cache::Union{Rodas4ConstantCache, Rodas4Cache, Rodas23WConstantCache, Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs, T::Type{Val{1}})
@views @.. broadcast=false out=(k[1][idxs] +
Θ *
(-2 * k[1][idxs] + 2 * k[2][idxs] -
Expand Down
Loading

0 comments on commit 81ce811

Please sign in to comment.