Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
ParamThakkar123 committed Aug 9, 2024
1 parent 329e323 commit 120772a
Showing 1 changed file with 67 additions and 73 deletions.
140 changes: 67 additions & 73 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ end
cache.linsolve = linres.cache
end

@muladd function perform_step!(integrator, cache::Union{Rosenbrock23ConstantCache, Rosenbrock32ConstantCache},
@muladd function perform_step!(integrator, cache::Union{Rosenbrock23ConstantCache, Rosenbrock32ConstantCache, Rosenbrock33ConstantCache},
repeat_step = false)

if cache isa Rosenbrock32ConstantCache
Expand All @@ -209,6 +209,22 @@ end
end
end

if cache isa Rosenbrock33ConstantCache
@unpack t, dt, uprev, u, f, p = integrator
@unpack tf, uf = cache
@unpack a21, a31, a32, C21, C31, C32, b1, b2, b3, btilde1, btilde2, btilde3, gamma, c2, c3, d1, d2, d3 = cache.tab

# Precalculations
dtC21 = C21 / dt
dtC31 = C31 / dt
dtC32 = C32 / dt

dtd1 = dt * d1
dtd2 = dt * d2
dtd3 = dt * d3
dtgamma = dt * gamma
end

mass_matrix = integrator.f.mass_matrix

# Time derivative
Expand All @@ -220,11 +236,55 @@ end
return nothing
end

if cache isa Rosenbrock33ConstantCache
linsolve_tmp = integrator.fsalfirst + dtd1 * dT
end

k₁ = _reshape(W \ -_vec((integrator.fsalfirst + γ * dT)), axes(uprev))
integrator.stats.nsolve += 1

if cache isa Rosenbrock33ConstantCache
u = uprev + a21 * k1
du = f(u, p, t + c2 * dt)
end


f₁ = f(uprev + dto2 * k₁, p, t + dto2)
integrator.stats.nf += 1

if cache isa Rosenbrock33ConstantCache
if mass_matrix === I
linsolve_tmp = du + dtd2 * dT + dtC21 * k1
else
linsolve_tmp = du + dtd2 * dT + mass_matrix * (dtC21 * k1)
end

k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u = uprev + a31 * k1 + a32 * k2
du = f(u, p, t + c3 * dt)
integrator.stats.nf += 1

if mass_matrix === I
linsolve_tmp = du + dtd3 * dT + dtC31 * k1 + dtC32 * k2
else
linsolve_tmp = du + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2)
end

k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u = uprev + b1 * k1 + b2 * k2 + b3 * k3
integrator.fsallast = f(u, p, t + dt)
integrator.stats.nf += 1

if integrator.opts.adaptive
utilde = btilde1 * k1 + btilde2 * k2 + btilde3 * k3
atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
end

if mass_matrix === I
k₂ = _reshape(W \ -_vec(f₁ - k₁), axes(uprev)) + k₁
else
Expand Down Expand Up @@ -288,6 +348,12 @@ end
integrator.EEst += integrator.opts.internalnorm(atmp, t)
end
end

if cache isa Rosenbrock33ConstantCache
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
end

integrator.k[1] = k₁
integrator.k[2] = k₂
integrator.u = u
Expand Down Expand Up @@ -324,78 +390,6 @@ function initialize!(integrator,
integrator.stats.nf += 1
end

@muladd function perform_step!(integrator, cache::Rosenbrock33ConstantCache,
repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack tf, uf = cache
@unpack a21, a31, a32, C21, C31, C32, b1, b2, b3, btilde1, btilde2, btilde3, gamma, c2, c3, d1, d2, d3 = cache.tab

# Precalculations
dtC21 = C21 / dt
dtC31 = C31 / dt
dtC32 = C32 / dt

dtd1 = dt * d1
dtd2 = dt * d2
dtd3 = dt * d3
dtgamma = dt * gamma

mass_matrix = integrator.f.mass_matrix

# Time derivative
dT = calc_tderivative(integrator, cache)

W = calc_W(integrator, cache, dtgamma, repeat_step, true)
if !issuccess_W(W)
integrator.EEst = 2
return nothing
end

linsolve_tmp = integrator.fsalfirst + dtd1 * dT

k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u = uprev + a21 * k1
du = f(u, p, t + c2 * dt)
integrator.stats.nf += 1

if mass_matrix === I
linsolve_tmp = du + dtd2 * dT + dtC21 * k1
else
linsolve_tmp = du + dtd2 * dT + mass_matrix * (dtC21 * k1)
end

k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u = uprev + a31 * k1 + a32 * k2
du = f(u, p, t + c3 * dt)
integrator.stats.nf += 1

if mass_matrix === I
linsolve_tmp = du + dtd3 * dT + dtC31 * k1 + dtC32 * k2
else
linsolve_tmp = du + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2)
end

k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
integrator.stats.nsolve += 1
u = uprev + b1 * k1 + b2 * k2 + b3 * k3
integrator.fsallast = f(u, p, t + dt)
integrator.stats.nf += 1

if integrator.opts.adaptive
utilde = btilde1 * k1 + btilde2 * k2 + btilde3 * k3
atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end

integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
integrator.u = u
return nothing
end

@muladd function perform_step!(integrator, cache::Union{Rosenbrock33Cache, Rosenbrock34Cache}, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, k4, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache
Expand Down

0 comments on commit 120772a

Please sign in to comment.