diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 19a1a6dc83..83a2c18b51 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -1206,7 +1206,7 @@ function initialize!(integrator, cache::RosenbrockConstantCache) integrator.k[2] = zero(integrator.u) end -@muladd function perform_step!(integrator, cache::RosenbrockConstantCache, repeat_step = false) +@muladd function perform_step!(integrator, cache::Rodas4ConstantCache, repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator @unpack tf, uf = cache @unpack a, C, gamma, c, d = cache.tab @@ -1231,33 +1231,31 @@ end k = Vector{typeof(uprev)}(undef, 6) u = uprev - coeffs = [ - (a[2], c[2], dtC[1:1], dtd[1]), - (a[3:4], c[3], dtC[2:3], dtd[2]), - (a[5:7], c[4], dtC[4:6], dtd[3]), - (a[8:11], dtC[7:10], dtd[4]) - ] + # Stage 1 + du = f(uprev, p, t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + linsolve_tmp = du + dtd[1] * dT + k[1] = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) + integrator.stats.nsolve += 1 + u = uprev + a[2] * k[1] - for i in 1:4 - if i == 1 - du = f(uprev, p, t) - else - du = f(u, p, t + coeffs[i][2] * dt) - end + # Stage 2 to 4 + for i in 2:4 + du = f(u, p, t + c[i] * dt) OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) if mass_matrix === I - linsolve_tmp = du + coeffs[i][4] * dT + sum(dtC[1:length(coeffs[i][3])] .* k[1:length(coeffs[i][3])]) + linsolve_tmp = du + dtd[i] * dT + sum(dtC[1:i] .* k[1:i]) else - linsolve_tmp = du + coeffs[i][4] * dT + mass_matrix * sum(dtC[1:length(coeffs[i][3])] .* k[1:length(coeffs[i][3])]) + linsolve_tmp = du + dtd[i] * dT + mass_matrix * sum(dtC[1:i] .* k[1:i]) end k[i] = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) integrator.stats.nsolve += 1 - u = uprev + sum(coeffs[i][1] .* k[1:i]) + u = uprev + sum(a[(i-1)*i÷2+1:(i-1)*i÷2+i] .* k[1:i]) end - # Compute k5 and k6 using the last stages separately + # Stage 5 and 6 for i in 5:6 du = f(u, p, t + dt) OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) @@ -1280,8 +1278,9 @@ end end if integrator.opts.calck - integrator.k[1] = sum(cache.tab.h[1:5] .* k[1:5]) - integrator.k[2] = sum(cache.tab.h[6:10] .* k[1:5]) + @unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35 = cache.tab + integrator.k[1] = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5 + integrator.k[2] = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5 end integrator.u = u