Skip to content

Commit

Permalink
Rodas5Pr, Rodas5Pe and minor bug fixes of Rodas3P
Browse files Browse the repository at this point in the history
Concerning  issue:
SciML#2054
  • Loading branch information
gstein3m committed Apr 17, 2024
1 parent 1c2a1d9 commit 2669546
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ export MagnusMidpoint, LinearExponential, MagnusLeapfrog, LieEuler, CayleyEuler,
export Rosenbrock23, Rosenbrock32, RosShamp4, Veldd4, Velds4, GRK4T, GRK4A,
Ros4LStab, ROS3P, Rodas3, Rodas23W, Rodas3P, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P,
RosenbrockW6S4OS, ROS34PW1a, ROS34PW1b, ROS34PW2, ROS34PW3, ROS34PRw, ROS3PRL, ROS3PRL2,
ROS2PR, ROS2S, ROS3, ROS3PR, Scholz4_7
ROS2PR, ROS2S, ROS3, ROS3PR, Scholz4_7, Rodas5Pe, Rodas5Pr

export LawsonEuler, NorsettEuler, ETD1, ETDRK2, ETDRK3, ETDRK4, HochOst4, Exp4, EPIRK4s3A,
EPIRK4s3B,
Expand Down
5 changes: 5 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ isfsal(alg::Rodas3P) = false
isfsal(alg::Rodas23W) = false
isfsal(alg::Rodas5) = false
isfsal(alg::Rodas5P) = false
isfsal(alg::Rodas5Pe) = false
isfsal(alg::Rodas5Pr) = false
isfsal(alg::Rodas4) = false
isfsal(alg::Rodas42) = false
isfsal(alg::Rodas4P) = false
Expand Down Expand Up @@ -652,6 +654,8 @@ alg_order(alg::Rodas4P) = 4
alg_order(alg::Rodas4P2) = 4
alg_order(alg::Rodas5) = 5
alg_order(alg::Rodas5P) = 5
alg_order(alg::Rodas5Pe) = 5
alg_order(alg::Rodas5Pr) = 5

alg_order(alg::AB3) = 3
alg_order(alg::AB4) = 4
Expand Down Expand Up @@ -1021,6 +1025,7 @@ isstandard(alg::VCABM) = true
isWmethod(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
isWmethod(alg::Rosenbrock23) = true
isWmethod(alg::Rosenbrock32) = true
isWmethod(alg::Rodas23W) = true
isWmethod(alg::ROS2S) = true
isWmethod(alg::ROS34PW1a) = true
isWmethod(alg::ROS34PW1b) = true
Expand Down
2 changes: 2 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2982,6 +2982,8 @@ for Alg in [
:Rodas4P2,
:Rodas5,
:Rodas5P,
:Rodas5Pe,
:Rodas5Pr,
]
@eval begin
struct $Alg{CS, AD, F, P, FDT, ST, CJ} <:
Expand Down
4 changes: 2 additions & 2 deletions src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
constvalue(tTypeNoUnits)), J, W, linsolve)
end

function alg_cache(alg::Rodas5P, u, rate_prototype, ::Type{uEltypeNoUnits},
function alg_cache(alg::Union{Rodas5P,Rodas5Pe,Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Expand Down Expand Up @@ -1073,7 +1073,7 @@ function alg_cache(alg::Rodas5P, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve, jac_config, grad_config, reltol, alg)
end

function alg_cache(alg::Rodas5P, u, rate_prototype, ::Type{uEltypeNoUnits},
function alg_cache(alg::Union{Rodas5P,Rodas5Pe,Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Expand Down
2 changes: 1 addition & 1 deletion src/dense/stiff_addsteps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Union{Rodas23WCache{<:A
always_calc_begin = false, allow_calc_end = true,
force_calc_end = false)
if length(k) < 2 || always_calc_begin
@unpack du, du1, du2, tmp, k1, k2, k3, k4, k5, k6, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst = cache
@unpack du, du1, du2, tmp, k1, k2, k3, k4, k5, dT, J, W, uf, tf, linsolve_tmp, jac_config, fsalfirst = cache
@unpack a21, a41, a42, a43, C21, C31, C32, C41, C42, C43, C51, C52, C53, C54, gamma, c2, c3, d1, d2, d3 = cache.tab

# Assignments
Expand Down
6 changes: 3 additions & 3 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ end
J = f.jac(uprev, p, t)
if J isa StaticArray &&
integrator.alg isa
Union{Rosenbrock23, Rodas23W, Rodas3P,Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P}
Union{Rosenbrock23, Rodas23W, Rodas3P,Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
W = W_transform ? J - mass_matrix * inv(dtgamma) :
dtgamma * J - mass_matrix
else
Expand All @@ -773,7 +773,7 @@ end
W_full
elseif len !== nothing &&
integrator.alg isa
Union{Rosenbrock23, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P}
Union{Rosenbrock23, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
StaticWOperator(W_full)
else
DiffEqBase.default_factorize(W_full)
Expand Down Expand Up @@ -920,7 +920,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
len = StaticArrayInterface.known_length(typeof(J))
if len !== nothing &&
alg isa
Union{Rosenbrock23, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P}
Union{Rosenbrock23, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
StaticWOperator(J, false)
else
ArrayInterface.lu_instance(J)
Expand Down
2 changes: 1 addition & 1 deletion src/integrators/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ end
return if isdefined(integrator, :fsallast) &&
!(integrator.alg isa
Union{Rosenbrock23, Rosenbrock32, Rodas23W, Rodas3P,Rodas4, Rodas4P, Rodas4P2, Rodas5,
Rodas5P})
Rodas5P, Rodas5Pe, Rodas5Pr})
# Special stiff interpolations do not store the right value in fsallast
out .= integrator.fsallast
else
Expand Down
105 changes: 100 additions & 5 deletions src/perform_step/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1044,8 +1044,13 @@ end
integrator.k[2] = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5
integrator.k[3] = h2_21 * k1 + h2_22 * k2 + h2_23 * k3 + h2_24 * k4 + h2_25 * k5
if integrator.opts.adaptive
calculate_interpoldiff!(k1, k2, uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3])
atmp = calculate_residuals(k2, uprev, k1, integrator.opts.abstol,
if isa(linsolve_tmp,AbstractFloat)
u_int, u_diff = calculate_interpoldiff(uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3])
else
u_int = linsolve_tmp; u_diff = copy(linsolve_tmp)
calculate_interpoldiff!(u_int, u_diff, uprev, du, u, integrator.k[1], integrator.k[2], integrator.k[3])
end
atmp = calculate_residuals(u_diff, uprev, u_int, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
EEst = max(EEst,integrator.opts.internalnorm(atmp, t)) #-- role of t unclear
end
Expand Down Expand Up @@ -1412,6 +1417,29 @@ end
cache.linsolve = linres.cache
end

function calculate_interpoldiff(uprev, up2, up3, c_koeff, d_koeff, c2_koeff)
u_int = 0.0; u_diff = 0.0
for i in eachindex(up2)
a1 = up3[i] + c_koeff[i] - up2[i] - c2_koeff[i]; a2 = d_koeff[i] - c_koeff[i] + c2_koeff[i]; a3 = -d_koeff[i]
dis = a2^2 - 3*a1*a3
u_int = up3[i]; u_diff = 0.0
if dis > 0.0 #-- Min/Max occurs
tau1 = (-a2 - sqrt(dis))/(3*a3); tau2 = (-a2 + sqrt(dis))/(3*a3)
if tau1 > tau2 tau1,tau2 = tau2,tau1; end
for tau in (tau1,tau2)
if (tau > 0.0) && (tau < 1.0)
y_tau = (1 - tau)*uprev[i] + tau*(up3[i] + (1 - tau)*(c_koeff[i] + tau*d_koeff[i]))
dy_tau = ((a3*tau + a2)*tau + a1)*tau
if abs(dy_tau) > abs(u_diff[i])
u_diff = dy_tau; u_int = y_tau
end
end
end
end
end
return u_int, u_diff
end

function calculate_interpoldiff!(u_int, u_diff, uprev, up2, up3, c_koeff, d_koeff, c2_koeff)
for i in eachindex(up2)
a1 = up3[i] + c_koeff[i] - up2[i] - c2_koeff[i]; a2 = d_koeff[i] - c_koeff[i] + c2_koeff[i]; a3 = -d_koeff[i]
Expand Down Expand Up @@ -2125,9 +2153,14 @@ end
k8 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u = u + k8
linsolve_tmp = k8

if integrator.opts.adaptive
atmp = calculate_residuals(k8, uprev, u, integrator.opts.abstol,
if (integrator.alg isa Rodas5Pe)
linsolve_tmp = 0.2606326497975715*k1 - 0.005158627295444251*k2 + 1.3038988631109731*k3 + 1.235000722062074*k4 +
- 0.7931985603795049*k5 - 1.005448461135913*k6 - 0.18044626132120234*k7 + 0.17051519239113755*k8
end
atmp = calculate_residuals(linsolve_tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
Expand All @@ -2140,6 +2173,20 @@ end
h37 * k7 + h38 * k8
integrator.k[3] = h41 * k1 + h42 * k2 + h43 * k3 + h44 * k4 + h45 * k5 + h46 * k6 +
h47 * k7 + h48 * k8
if (integrator.alg isa Rodas5Pr) && integrator.opts.adaptive && (integrator.EEst < 1.0)
k2 = 0.5*(uprev + u + 0.5 * (integrator.k[1] + 0.5 * (integrator.k[2] + 0.5 * integrator.k[3])))
du1 = (integrator.k[1] + 0.5*(-2*integrator.k[1] + 2*integrator.k[2] +
0.5*(-3*integrator.k[2] + 3*integrator.k[3] - 2*integrator.k[3])) - uprev + u) / dt
du = f(k2, p, t + dt/2)
integrator.stats.nf += 1
if mass_matrix === I
du2 = du1 - du
else
du2 = mass_matrix*du1 - du
end
EEst = norm(du2) / (integrator.opts.abstol + integrator.opts.reltol*norm(k2))
integrator.EEst = max(EEst,integrator.EEst)
end
end

integrator.u = u
Expand Down Expand Up @@ -2354,10 +2401,15 @@ end
@.. broadcast=false veck8=-vecu
integrator.stats.nsolve += 1

du .= k8
u .+= k8

if integrator.opts.adaptive
calculate_residuals!(atmp, k8, uprev, u, integrator.opts.abstol,
if (integrator.alg isa Rodas5Pe)
du = 0.2606326497975715*k1 - 0.005158627295444251*k2 + 1.3038988631109731*k3 + 1.235000722062074*k4 +
- 0.7931985603795049*k5 - 1.005448461135913*k6 - 0.18044626132120234*k7 + 0.17051519239113755*k8
end
calculate_residuals!(atmp, du, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
Expand All @@ -2370,6 +2422,21 @@ end
h35 * k5 + h36 * k6 + h37 * k7 + h38 * k8
@.. broadcast=false integrator.k[3]=h41 * k1 + h42 * k2 + h43 * k3 + h44 * k4 +
h45 * k5 + h46 * k6 + h47 * k7 + h48 * k8
if (integrator.alg isa Rodas5Pr) && integrator.opts.adaptive && (integrator.EEst < 1.0)
k2 = 0.5*(uprev + u + 0.5 * (integrator.k[1] + 0.5 * (integrator.k[2] + 0.5 * integrator.k[3])))
du1 = (integrator.k[1] + 0.5*(-2*integrator.k[1] + 2*integrator.k[2] +
0.5*(-3*integrator.k[2] + 3*integrator.k[3] - 2*integrator.k[3])) - uprev + u) / dt
f(du, k2, p, t + dt/2)
integrator.stats.nf += 1
if mass_matrix === I
du2 = du1 - du
else
mul!(_vec(du2), mass_matrix, _vec(du1))
du2 = du2 - du
end
EEst = norm(du2) / (integrator.opts.abstol + integrator.opts.reltol*norm(k2))
integrator.EEst = max(EEst,integrator.EEst)
end
end
cache.linsolve = linres.cache
end
Expand Down Expand Up @@ -2653,10 +2720,17 @@ end

@inbounds @simd ivdep for i in eachindex(u)
u[i] += k8[i]
du[i] = k8[i]
end

if integrator.opts.adaptive
calculate_residuals!(atmp, k8, uprev, u, integrator.opts.abstol,
if (integrator.alg isa Rodas5Pe)
@inbounds @simd ivdep for i in eachindex(u)
du[i] = 0.2606326497975715*k1[i] - 0.005158627295444251*k2[i] + 1.3038988631109731*k3[i] + 1.235000722062074*k4[i] +
- 0.7931985603795049*k5[i] - 1.005448461135913*k6[i] - 0.18044626132120234*k7[i] + 0.17051519239113755*k8[i]
end
end
calculate_residuals!(atmp, du, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
Expand All @@ -2670,6 +2744,27 @@ end
h35 * k5[i] + h36 * k6[i] + h37 * k7[i] + h38 * k8[i]
integrator.k[3][i] = h41 * k1[i] + h42 * k2[i] + h43 * k3[i] + h44 * k4[i] +
h45 * k5[i] + h46 * k6[i] + h47 * k7[i] + h48 * k8[i]
if (integrator.alg isa Rodas5Pr)
k2[i] = 0.5*(uprev[i] + u[i] + 0.5 * (integrator.k[1][i] + 0.5 * (integrator.k[2][i] + 0.5 * integrator.k[3][i])))
du1[i] = (integrator.k[1][i] + 0.5*(-2*integrator.k[1][i] + 2*integrator.k[2][i] +
0.5*(-3*integrator.k[2][i] + 3*integrator.k[3][i] - 2*integrator.k[3][i])) - uprev[i] + u[i]) / dt
end
end
if integrator.opts.adaptive && (integrator.EEst < 1.0) && (integrator.alg isa Rodas5Pr)
f(du, k2, p, t + dt/2)
integrator.stats.nf += 1
if mass_matrix === I
@inbounds @simd ivdep for i in eachindex(u)
du2[i] = du1[i] - du[i]
end
else
mul!(_vec(du2), mass_matrix, _vec(du1))
@inbounds @simd ivdep for i in eachindex(u)
du2[i] = du2[i] - du[i]
end
end
EEst = norm(du2) / (integrator.opts.abstol + integrator.opts.reltol*norm(k2))
integrator.EEst = max(EEst,integrator.EEst)
end
end
cache.linsolve = linres.cache
Expand Down

0 comments on commit 2669546

Please sign in to comment.