Skip to content

Commit

Permalink
Merge pull request #2170 from gstein3m/Rodas5Pe
Browse files Browse the repository at this point in the history
Modifications Rodas5Pe, Rodas5Pr of Rodas5P concerning issue #2054 and minor bug fixes of Rodas3P/23W
  • Loading branch information
ChrisRackauckas authored Apr 29, 2024
2 parents 1040675 + 67656b7 commit 5782438
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ export MagnusMidpoint, LinearExponential, MagnusLeapfrog, LieEuler, CayleyEuler,

export Rosenbrock23, Rosenbrock32, RosShamp4, Veldd4, Velds4, GRK4T, GRK4A,
Ros4LStab, ROS3P, Rodas3, Rodas23W, Rodas3P, Rodas4, Rodas42, Rodas4P, Rodas4P2,
Rodas5, Rodas5P,
Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr,
RosenbrockW6S4OS, ROS34PW1a, ROS34PW1b, ROS34PW2, ROS34PW3, ROS34PRw, ROS3PRL,
ROS3PRL2,
ROS2, ROS2PR, ROS2S, ROS3, ROS3PR, Scholz4_7
Expand Down
5 changes: 5 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ isfsal(alg::Rodas3P) = false
isfsal(alg::Rodas23W) = false
isfsal(alg::Rodas5) = false
isfsal(alg::Rodas5P) = false
isfsal(alg::Rodas5Pr) = false
isfsal(alg::Rodas5Pe) = false
isfsal(alg::Rodas4) = false
isfsal(alg::Rodas42) = false
isfsal(alg::Rodas4P) = false
Expand Down Expand Up @@ -664,6 +666,8 @@ alg_order(alg::Rodas4P) = 4
alg_order(alg::Rodas4P2) = 4
alg_order(alg::Rodas5) = 5
alg_order(alg::Rodas5P) = 5
alg_order(alg::Rodas5Pr) = 5
alg_order(alg::Rodas5Pe) = 5

alg_order(alg::AB3) = 3
alg_order(alg::AB4) = 4
Expand Down Expand Up @@ -1037,6 +1041,7 @@ isstandard(alg::VCABM) = true
isWmethod(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
isWmethod(alg::Rosenbrock23) = true
isWmethod(alg::Rosenbrock32) = true
isWmethod(alg::Rodas23W) = true
isWmethod(alg::ROS2S) = true
isWmethod(alg::ROS34PW1a) = true
isWmethod(alg::ROS34PW1b) = true
Expand Down
9 changes: 8 additions & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2967,6 +2967,11 @@ University of Geneva, Switzerland.
- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package.
In: BIT Numerical Mathematics, 63(2), 2023
#### Rodas23W, Rodas3P, Rodas5Pe, Rodas5Pr
- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications -
Preprint 2024
https://github.com/hbrs-cse/RosenbrockMethods/blob/main/paper/JuliaPaper.pdf
=#

for Alg in [
Expand Down Expand Up @@ -3000,7 +3005,9 @@ for Alg in [
:Rodas4P,
:Rodas4P2,
:Rodas5,
:Rodas5P
:Rodas5P,
:Rodas5Pe,
:Rodas5Pr
]
@eval begin
struct $Alg{CS, AD, F, P, FDT, ST, CJ} <:
Expand Down
4 changes: 2 additions & 2 deletions src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
constvalue(tTypeNoUnits)), J, W, linsolve)
end

function alg_cache(alg::Rodas5P, u, rate_prototype, ::Type{uEltypeNoUnits},
function alg_cache(alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Expand Down Expand Up @@ -1093,7 +1093,7 @@ function alg_cache(alg::Rodas5P, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve, jac_config, grad_config, reltol, alg)
end

function alg_cache(alg::Rodas5P, u, rate_prototype, ::Type{uEltypeNoUnits},
function alg_cache(alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Expand Down
2 changes: 1 addition & 1 deletion src/dense/stiff_addsteps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
always_calc_begin = false, allow_calc_end = true,
force_calc_end = false)
if length(k) < 2 || always_calc_begin
@unpack du, du1, du2, tmp, k1, k2, k3, k4, k5, k6, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst = cache
@unpack du, du1, du2, tmp, k1, k2, k3, k4, k5, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst = cache
@unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab

# Assignments
Expand Down
6 changes: 3 additions & 3 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ end
if J isa StaticArray &&
integrator.alg isa
Union{
Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P}
Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
W = W_transform ? J - mass_matrix * inv(dtgamma) :
dtgamma * J - mass_matrix
else
Expand All @@ -775,7 +775,7 @@ end
W_full
elseif len !== nothing &&
integrator.alg isa
Union{Rosenbrock23, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P}
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
StaticWOperator(W_full)
else
DiffEqBase.default_factorize(W_full)
Expand Down Expand Up @@ -923,7 +923,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
len = StaticArrayInterface.known_length(typeof(J))
if len !== nothing &&
alg isa
Union{Rosenbrock23, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P}
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
StaticWOperator(J, false)
else
ArrayInterface.lu_instance(J)
Expand Down
2 changes: 1 addition & 1 deletion src/integrators/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ end
!(integrator.alg isa
Union{Rosenbrock23, Rosenbrock32, Rodas23W,
Rodas3P, Rodas4, Rodas4P, Rodas4P2, Rodas5,
Rodas5P})
Rodas5P, Rodas5Pe, Rodas5Pr})
# Special stiff interpolations do not store the right value in fsallast
out .= integrator.fsallast
else
Expand Down
125 changes: 111 additions & 14 deletions src/perform_step/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1068,21 +1068,26 @@ end
integrator.k[2] = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5
integrator.k[3] = h2_21 * k1 + h2_22 * k2 + h2_23 * k3 + h2_24 * k4 + h2_25 * k5
if integrator.opts.adaptive
calculate_interpoldiff!(
k1, k2, uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3])
atmp = calculate_residuals(k2, uprev, k1, integrator.opts.abstol,
if isa(linsolve_tmp,AbstractFloat)
u_int, u_diff = calculate_interpoldiff(uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3])
else
u_int = linsolve_tmp
u_diff = linsolve_tmp .+ 0
calculate_interpoldiff!(u_int, u_diff, uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3])
end
atmp = calculate_residuals(u_diff, uprev, u_int, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
EEst = max(EEst, integrator.opts.internalnorm(atmp, t)) #-- role of t unclear
EEst = max(EEst,integrator.opts.internalnorm(atmp, t)) #-- role of t unclear
end
end

if (integrator.alg isa Rodas23W)
k1[:] = u[:]
u[:] = du[:]
du[:] = k1[:]
k1 = u .+ 0
u = du .+ 0
du = k1 .+ 0
if integrator.opts.calck
integrator.k[1][:] = integrator.k[3][:]
integrator.k[2][:] .= 0.0
integrator.k[1] = integrator.k[3] .+ 0
integrator.k[2] = 0*integrator.k[2]
end
end

Expand Down Expand Up @@ -1432,7 +1437,6 @@ end
calculate_residuals!(atmp, du2, uprev, du1, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
EEst = max(EEst, integrator.opts.internalnorm(atmp, t)) #-- role of t unclear
#println(t," ",EEst," ",du2)
end
end

Expand All @@ -1451,11 +1455,40 @@ end
calculate_residuals!(atmp, u - du, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = max(EEst, integrator.opts.internalnorm(atmp, t))
#println(t," ",EEst," ",integrator.EEst)
end
cache.linsolve = linres.cache
end

function calculate_interpoldiff(uprev, up2, up3, c_koeff, d_koeff, c2_koeff)
u_int = 0.0
u_diff = 0.0
a1 = up3 + c_koeff - up2 - c2_koeff
a2 = d_koeff - c_koeff + c2_koeff
a3 = -d_koeff
dis = a2^2 - 3*a1*a3
u_int = up3
u_diff = 0.0
if dis > 0.0 #-- Min/Max occurs
tau1 = (-a2 - sqrt(dis))/(3*a3)
tau2 = (-a2 + sqrt(dis))/(3*a3)
if tau1 > tau2
tau1,tau2 = tau2,tau1
end
for tau in (tau1,tau2)
if (tau > 0.0) && (tau < 1.0)
y_tau = (1 - tau)*uprev +
tau*(up3 + (1 - tau)*(c_koeff + tau*d_koeff))
dy_tau = ((a3*tau + a2)*tau + a1)*tau
if abs(dy_tau) > abs(u_diff)
u_diff = dy_tau
u_int = y_tau
end
end
end
end
return u_int, u_diff
end

function calculate_interpoldiff!(u_int, u_diff, uprev, up2, up3, c_koeff, d_koeff, c2_koeff)
for i in eachindex(up2)
a1 = up3[i] + c_koeff[i] - up2[i] - c2_koeff[i]
Expand Down Expand Up @@ -2179,9 +2212,14 @@ end
k8 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u = u + k8
linsolve_tmp = k8

if integrator.opts.adaptive
atmp = calculate_residuals(k8, uprev, u, integrator.opts.abstol,
if (integrator.alg isa Rodas5Pe)
linsolve_tmp = 0.2606326497975715*k1 - 0.005158627295444251*k2 + 1.3038988631109731*k3 + 1.235000722062074*k4 +
- 0.7931985603795049*k5 - 1.005448461135913*k6 - 0.18044626132120234*k7 + 0.17051519239113755*k8
end
atmp = calculate_residuals(linsolve_tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
Expand All @@ -2194,6 +2232,19 @@ end
h37 * k7 + h38 * k8
integrator.k[3] = h41 * k1 + h42 * k2 + h43 * k3 + h44 * k4 + h45 * k5 + h46 * k6 +
h47 * k7 + h48 * k8
if (integrator.alg isa Rodas5Pr) && integrator.opts.adaptive && (integrator.EEst < 1.0)
k2 = 0.5*(uprev + u + 0.5 * (integrator.k[1] + 0.5 * (integrator.k[2] + 0.5 * integrator.k[3])))
du1 = ( 0.25*(integrator.k[2] + integrator.k[3]) - uprev + u) / dt
du = f(k2, p, t + dt/2)
integrator.stats.nf += 1
if mass_matrix === I
du2 = du1 - du
else
du2 = mass_matrix*du1 - du
end
EEst = norm(du2) / (integrator.opts.abstol + integrator.opts.reltol*norm(k2))
integrator.EEst = max(EEst,integrator.EEst)
end
end

integrator.u = u
Expand Down Expand Up @@ -2409,10 +2460,15 @@ end
@.. broadcast=false veck8=-vecu
integrator.stats.nsolve += 1

du .= k8
u .+= k8

if integrator.opts.adaptive
calculate_residuals!(atmp, k8, uprev, u, integrator.opts.abstol,
if (integrator.alg isa Rodas5Pe)
du = 0.2606326497975715*k1 - 0.005158627295444251*k2 + 1.3038988631109731*k3 + 1.235000722062074*k4 +
- 0.7931985603795049*k5 - 1.005448461135913*k6 - 0.18044626132120234*k7 + 0.17051519239113755*k8
end
calculate_residuals!(atmp, du, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
Expand All @@ -2425,6 +2481,20 @@ end
h35 * k5 + h36 * k6 + h37 * k7 + h38 * k8
@.. broadcast=false integrator.k[3]=h41 * k1 + h42 * k2 + h43 * k3 + h44 * k4 +
h45 * k5 + h46 * k6 + h47 * k7 + h48 * k8
if (integrator.alg isa Rodas5Pr) && integrator.opts.adaptive && (integrator.EEst < 1.0)
k2 = 0.5*(uprev + u + 0.5 * (integrator.k[1] + 0.5 * (integrator.k[2] + 0.5 * integrator.k[3])))
du1 = ( 0.25*(integrator.k[2] + integrator.k[3]) - uprev + u) / dt
f(du, k2, p, t + dt/2)
integrator.stats.nf += 1
if mass_matrix === I
du2 = du1 - du
else
mul!(_vec(du2), mass_matrix, _vec(du1))
du2 = du2 - du
end
EEst = norm(du2) / (integrator.opts.abstol + integrator.opts.reltol*norm(k2))
integrator.EEst = max(EEst,integrator.EEst)
end
end
cache.linsolve = linres.cache
end
Expand Down Expand Up @@ -2709,10 +2779,17 @@ end

@inbounds @simd ivdep for i in eachindex(u)
u[i] += k8[i]
du[i] = k8[i]
end

if integrator.opts.adaptive
calculate_residuals!(atmp, k8, uprev, u, integrator.opts.abstol,
if (integrator.alg isa Rodas5Pe)
@inbounds @simd ivdep for i in eachindex(u)
du[i] = 0.2606326497975715*k1[i] - 0.005158627295444251*k2[i] + 1.3038988631109731*k3[i] + 1.235000722062074*k4[i] +
- 0.7931985603795049*k5[i] - 1.005448461135913*k6[i] - 0.18044626132120234*k7[i] + 0.17051519239113755*k8[i]
end
end
calculate_residuals!(atmp, du, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
Expand All @@ -2726,6 +2803,26 @@ end
h35 * k5[i] + h36 * k6[i] + h37 * k7[i] + h38 * k8[i]
integrator.k[3][i] = h41 * k1[i] + h42 * k2[i] + h43 * k3[i] + h44 * k4[i] +
h45 * k5[i] + h46 * k6[i] + h47 * k7[i] + h48 * k8[i]
if (integrator.alg isa Rodas5Pr)
k2[i] = 0.5*(uprev[i] + u[i] + 0.5 * (integrator.k[1][i] + 0.5 * (integrator.k[2][i] + 0.5 * integrator.k[3][i])))
du1[i] = ( 0.25*(integrator.k[2][i] + integrator.k[3][i]) - uprev[i] + u[i]) / dt
end
end
if integrator.opts.adaptive && (integrator.EEst < 1.0) && (integrator.alg isa Rodas5Pr)
f(du, k2, p, t + dt/2)
integrator.stats.nf += 1
if mass_matrix === I
@inbounds @simd ivdep for i in eachindex(u)
du2[i] = du1[i] - du[i]
end
else
mul!(_vec(du2), mass_matrix, _vec(du1))
@inbounds @simd ivdep for i in eachindex(u)
du2[i] = du2[i] - du[i]
end
end
EEst = norm(du2) / (integrator.opts.abstol + integrator.opts.reltol*norm(k2))
integrator.EEst = max(EEst,integrator.EEst)
end
end
cache.linsolve = linres.cache
Expand Down
Loading

0 comments on commit 5782438

Please sign in to comment.