Skip to content

Commit

Permalink
Changes on line 1251
Browse files Browse the repository at this point in the history
  • Loading branch information
ParamThakkar123 committed Aug 21, 2024
1 parent 285a16b commit 9821cd2
Showing 1 changed file with 18 additions and 36 deletions.
54 changes: 18 additions & 36 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ end

# Precalculations
dtC = C ./ dt
dtd = dt * d
dtd = dt .* d
dtgamma = dt * gamma

mass_matrix = integrator.f.mass_matrix
Expand All @@ -1229,48 +1229,31 @@ end
end

k = Vector{typeof(uprev)}(undef, 6)
u = uprev

# Stage 1
du = f(uprev, p, t)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
linsolve_tmp = du + dtd[1] * dT
k[1] = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u = uprev + a[2] * k[1]

# Stage 2 to 4
for i in 2:4
du = f(u, p, t + c[i] * dt)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

if mass_matrix === I
linsolve_tmp = du + dtd[i] * dT + sum(dtC[1:i] .* k[1:i])
for i in 1:6
if i == 1
du = f(uprev, p, t)
else
linsolve_tmp = du + dtd[i] * dT + mass_matrix * sum(dtC[1:i] .* k[1:i])
du = f(u, p, t + c[i] * dt)
end

k[i] = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u = uprev + sum(a[(i-1)*i÷2+1:(i-1)*i÷2+i] .* k[1:i])
end

# Stage 5 and 6
for i in 5:6
du = f(u, p, t + dt)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

integrator.stats.nf += 1

if mass_matrix === I
linsolve_tmp = du + sum(dtC[5*i-4:5*i] .* k[1:5])
linsolve_tmp = du + dtd[i] * dT + sum(dtC[i][j] * k[j] for j in 1:i-1; init=0)
else
linsolve_tmp = du + mass_matrix * sum(dtC[5*i-4:5*i] .* k[1:5])
linsolve_tmp = du + dtd[i] * dT + mass_matrix * sum(dtC[i][j] * k[j] for j in 1:i-1; init=0)
end

k[i] = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u += k[i]

if i < 6
u = uprev + sum(a[i+1][j] * k[j] for j in 1:i; init=0)
end
end

u = u + k[6]

if integrator.opts.adaptive
atmp = calculate_residuals(k[6], uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
Expand All @@ -1279,10 +1262,9 @@ end

if integrator.opts.calck
@unpack h21, h22, h23, h24, h25, h31, h32, h33, h34, h35 = cache.tab
integrator.k[1] = h21 * k1 + h22 * k2 + h23 * k3 + h24 * k4 + h25 * k5
integrator.k[2] = h31 * k1 + h32 * k2 + h33 * k3 + h34 * k4 + h35 * k5
integrator.k[1] = h21 * k[1] + h22 * k[2] + h23 * k[3] + h24 * k[4] + h25 * k[5]
integrator.k[2] = h31 * k[1] + h32 * k[2] + h33 * k[3] + h34 * k[4] + h35 * k[5]
end

integrator.u = u
return nothing
end
Expand Down

0 comments on commit 9821cd2

Please sign in to comment.