diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 4349415e2c..df02b82e5c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -300,6 +300,7 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W) tabtype(::Rodas23W) = Rodas23WTableau tabtype(::ROS3P) = ROS3PTableau tabtype(::Rodas3) = Rodas3Tableau +tabtype(::Rodas3P) = Rodas3PTableau tabtype(::Rodas4) = Rodas4Tableau tabtype(::Rodas42) = Rodas42Tableau tabtype(::Rodas4P) = Rodas4PTableau @@ -310,7 +311,7 @@ tabtype(::Rodas5Pr) = Rodas5PTableau tabtype(::Rodas5Pe) = Rodas5PTableau function alg_cache( - alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, + alg::Union{ROS3P, Rodas3, Rodas3P, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, @@ -353,7 +354,7 @@ function alg_cache( end function alg_cache( - alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, + alg::Union{ROS3P, Rodas3, Rodas3P, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 9b2d292e10..b9cf1df1db 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -495,7 +495,7 @@ end end if integrator.opts.adaptive - utilde = uprev + utilde = zero(u) for i in 1:num_stages utilde = @.. utilde + btilde[i] * ks[i] end @@ -592,14 +592,10 @@ end OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) du1 .= 0 - if mass_matrix === I - for i in 1:(stage - 1) - @.. du1 += dtC[stage, i] * ks[i] - end - else - for i in 1:(stage - 1) - @.. du1 += dtC[stage, i] * ks[i] - end + for i in 1:(stage - 1) + @.. du1 += dtC[stage, i] * ks[i] + end + if mass_matrix !== I mul!(_vec(du2), mass_matrix, _vec(du1)) du1 .= du2 end @@ -617,7 +613,7 @@ end step_limiter!(u, integrator, p, t + dt) if integrator.opts.adaptive - utilde .= 0 + @.. utilde = 0 * u for i in 1:num_stages @.. utilde += btilde[i] * ks[i] end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl index 08710eaed1..fa20a973c7 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl @@ -51,17 +51,17 @@ function ROS3PTableau(T, T2) btilde = T[2.113248654051871, 1, 0.4226497308103742] c = T2[0, 1, 1] d = T[0.7886751345948129, -0.2113248654051871, -1.077350269189626] - H = zeros(T, 3, 3) + H = zeros(T, 2, 3) RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas3Tableau(T, T2) gamma = convert(T, 1 // 2) A = T[ - 0 0 0 - 0 0 0 - 2 0 0 - 2 0 1 + 0 0 0 0 + 0 0 0 0 + 2 0 0 0 + 2 0 1 0 ] C = T[ 0 0 0 @@ -71,31 +71,32 @@ function Rodas3Tableau(T, T2) ] b = T[2, 0, 1, 1] btilde = T[0, 0, 0, 1] - c = T[0, 1, 1] + c = T[0, 0, 1, 1] d = T[1 // 2, 3 // 2, 0, 0] - H = zeros(T, 3, 3) + H = zeros(T, 2, 4) RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas3PTableau(T, T2) gamma = convert(T, 1 // 3) A = T[ - 0 0 0 0 - 4 // 3 0 0 0 - 4 // 3 0 0 0 - 2.90625 3.375 0.40625 0 + 0 0 0 0 0 + 4 // 3 0 0 0 0 + 4 // 3 0 0 0 0 + 2.90625 3.375 0.40625 0 0 + 2.90625 3.375 0.40625 0 0 ] C = T[ 0 0 0 0 - 4.0 0 0 0 + -4.0 0 0 0 8.25 6.75 0 0 - 1.21875 5.0625 1.96875 0 - 4.03125 15.1875 4.03125 6.0 + 1.21875 -5.0625 -1.96875 0 + 4.03125 -15.1875 -4.03125 6.0 ] - b = A[end, :] - btilde = T[0, 0, 0, 1] - c = T2[0, 4 // 9, 1] - d = T[1 // 3, 1 // 9, 1] + b = T[2.90625, 3.375, 0.40625, 1, 0] + btilde = T[0, 0, 0, 1, -1] + c = T2[0, 4 // 9, 4 // 9, 1, 1] + d = T[1 // 3, -1 // 9, 1, 0, 0] H = T[ 1.78125 6.75 0.15625 6 1 4.21875 15.1875 3.09375 9 0 @@ -104,6 +105,10 @@ function Rodas3PTableau(T, T2) RodasTableau(A, C, b, btilde, gamma, c, d, H)#, h2_2) end +function Rodas23WTableau(T, T2) + tab = Rodas3PTableau(T, T2) + RodasTableau(tab.A, tab.C, tab.btilde, tab.b, tab.gamma, tab.c, tab.d, tab.H)#, h2_2) +end @ROS2(:tableau) @ROS23(:tableau)