Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ParamThakkar123 committed Aug 21, 2024
1 parent c0eee83 commit f4cb382
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1206,14 +1206,14 @@ function initialize!(integrator, cache::RosenbrockConstantCache)
integrator.k[2] = zero(integrator.u)
end

@muladd function perform_step!(integrator, cache::RosenbrockConstantCache, repeat_step = false)
@muladd function perform_step!(integrator, cache::Rodas4ConstantCache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack tf, uf = cache
@unpack a, C, gamma, c, d = cache.tab

# Precalculations
dtC = C ./ dt
dtd = dt .* d
dtd = dt * d
dtgamma = dt * gamma

mass_matrix = integrator.f.mass_matrix
Expand All @@ -1229,30 +1229,49 @@ end
end

k = Vector{typeof(uprev)}(undef, 6)
u = uprev

coeffs = [
(a[2], c[2], dtC[1:1], dtd[1]),
(a[3:4], c[3], dtC[2:3], dtd[2]),
(a[5:7], c[4], dtC[4:6], dtd[3]),
(a[8:11], dtC[7:10], dtd[4])
]

for i in 1:6
for i in 1:4
if i == 1
du = f(uprev, p, t)
else
du = f(u, p, t + c[i] * dt)
du = f(u, p, t + coeffs[i][2] * dt)
end
integrator.stats.nf += 1
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

if mass_matrix === I
linsolve_tmp = du + dtd[i] * dT + sum(dtC[i][j] * k[j] for j in 1:i-1; init=0)
linsolve_tmp = du + coeffs[i][4] * dT + sum(dtC[1:length(coeffs[i][3])] .* k[1:length(coeffs[i][3])])
else
linsolve_tmp = du + dtd[i] * dT + mass_matrix * sum(dtC[i][j] * k[j] for j in 1:i-1; init=0)
linsolve_tmp = du + coeffs[i][4] * dT + mass_matrix * sum(dtC[1:length(coeffs[i][3])] .* k[1:length(coeffs[i][3])])
end

k[i] = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1

if i < 6
u = uprev + sum(a[i+1][j] * k[j] for j in 1:i; init=0)
end
u = uprev + sum(coeffs[i][1] .* k[1:i])
end

u = u + k[6]
# Compute k5 and k6 using the last stages separately
for i in 5:6
du = f(u, p, t + dt)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

if mass_matrix === I
linsolve_tmp = du + sum(dtC[5*i-4:5*i] .* k[1:5])
else
linsolve_tmp = du + mass_matrix * sum(dtC[5*i-4:5*i] .* k[1:5])
end

k[i] = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u += k[i]
end

if integrator.opts.adaptive
atmp = calculate_residuals(k[6], uprev, u, integrator.opts.abstol,
Expand All @@ -1261,10 +1280,10 @@ end
end

if integrator.opts.calck
@unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35 = cache.tab
integrator.k[1] = h21 * k[1] + h22 * k[2] + h23 * k[3] + h24 * k[4] + h25 * k[5]
integrator.k[2] = h31 * k[1] + h32 * k[2] + h33 * k[3] + h34 * k[4] + h35 * k[5]
integrator.k[1] = sum(cache.tab.h[1:5] .* k[1:5])
integrator.k[2] = sum(cache.tab.h[6:10] .* k[1:5])
end

integrator.u = u
return nothing
end
Expand Down

0 comments on commit f4cb382

Please sign in to comment.