Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Sep 21, 2024
1 parent b9dcd23 commit cb77846
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 30 deletions.
5 changes: 3 additions & 2 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W)
tabtype(::Rodas23W) = Rodas23WTableau
tabtype(::ROS3P) = ROS3PTableau
tabtype(::Rodas3) = Rodas3Tableau
tabtype(::Rodas3P) = Rodas3PTableau
tabtype(::Rodas4) = Rodas4Tableau
tabtype(::Rodas42) = Rodas42Tableau
tabtype(::Rodas4P) = Rodas4PTableau
Expand All @@ -310,7 +311,7 @@ tabtype(::Rodas5Pr) = Rodas5PTableau
tabtype(::Rodas5Pe) = Rodas5PTableau

function alg_cache(
alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
alg::Union{ROS3P, Rodas3, Rodas3P, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
Expand Down Expand Up @@ -353,7 +354,7 @@ function alg_cache(
end

function alg_cache(
alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
alg::Union{ROS3P, Rodas3, Rodas3P, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
Expand Down
16 changes: 6 additions & 10 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ end
end

if integrator.opts.adaptive
utilde = uprev
utilde = zero(u)
for i in 1:num_stages
utilde = @.. utilde + btilde[i] * ks[i]
end
Expand Down Expand Up @@ -592,14 +592,10 @@ end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

du1 .= 0
if mass_matrix === I
for i in 1:(stage - 1)
@.. du1 += dtC[stage, i] * ks[i]
end
else
for i in 1:(stage - 1)
@.. du1 += dtC[stage, i] * ks[i]
end
for i in 1:(stage - 1)
@.. du1 += dtC[stage, i] * ks[i]
end
if mass_matrix !== I
mul!(_vec(du2), mass_matrix, _vec(du1))
du1 .= du2
end
Expand All @@ -617,7 +613,7 @@ end
step_limiter!(u, integrator, p, t + dt)

if integrator.opts.adaptive
utilde .= 0
@.. utilde = 0 * u
for i in 1:num_stages
@.. utilde += btilde[i] * ks[i]
end
Expand Down
41 changes: 23 additions & 18 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ function ROS3PTableau(T, T2)
btilde = T[2.113248654051871, 1, 0.4226497308103742]
c = T2[0, 1, 1]
d = T[0.7886751345948129, -0.2113248654051871, -1.077350269189626]
H = zeros(T, 3, 3)
H = zeros(T, 2, 3)
RodasTableau(A, C, b, btilde, gamma, c, d, H)
end

function Rodas3Tableau(T, T2)
gamma = convert(T, 1 // 2)
A = T[
0 0 0
0 0 0
2 0 0
2 0 1
0 0 0 0
0 0 0 0
2 0 0 0
2 0 1 0
]
C = T[
0 0 0
Expand All @@ -71,31 +71,32 @@ function Rodas3Tableau(T, T2)
]
b = T[2, 0, 1, 1]
btilde = T[0, 0, 0, 1]
c = T[0, 1, 1]
c = T[0, 0, 1, 1]
d = T[1 // 2, 3 // 2, 0, 0]
H = zeros(T, 3, 3)
H = zeros(T, 2, 4)
RodasTableau(A, C, b, btilde, gamma, c, d, H)
end

function Rodas3PTableau(T, T2)
gamma = convert(T, 1 // 3)
A = T[
0 0 0 0
4 // 3 0 0 0
4 // 3 0 0 0
2.90625 3.375 0.40625 0
0 0 0 0 0
4 // 3 0 0 0 0
4 // 3 0 0 0 0
2.90625 3.375 0.40625 0 0
2.90625 3.375 0.40625 0 0
]
C = T[
0 0 0 0
4.0 0 0 0
-4.0 0 0 0
8.25 6.75 0 0
1.21875 5.0625 1.96875 0
4.03125 15.1875 4.03125 6.0
1.21875 -5.0625 -1.96875 0
4.03125 -15.1875 -4.03125 6.0
]
b = A[end, :]
btilde = T[0, 0, 0, 1]
c = T2[0, 4 // 9, 1]
d = T[1 // 3, 1 // 9, 1]
b = T[2.90625, 3.375, 0.40625, 1, 0]
btilde = T[0, 0, 0, 1, -1]
c = T2[0, 4 // 9, 4 // 9, 1, 1]
d = T[1 // 3, -1 // 9, 1, 0, 0]
H = T[
1.78125 6.75 0.15625 6 1
4.21875 15.1875 3.09375 9 0
Expand All @@ -104,6 +105,10 @@ function Rodas3PTableau(T, T2)
RodasTableau(A, C, b, btilde, gamma, c, d, H)#, h2_2)
end

function Rodas23WTableau(T, T2)
tab = Rodas3PTableau(T, T2)
RodasTableau(tab.A, tab.C, tab.btilde, tab.b, tab.gamma, tab.c, tab.d, tab.H)#, h2_2)
end
@ROS2(:tableau)

@ROS23(:tableau)
Expand Down

0 comments on commit cb77846

Please sign in to comment.