Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
ParamThakkar123 committed Aug 8, 2024
1 parent a006653 commit 8d63077
Showing 1 changed file with 88 additions and 146 deletions.
234 changes: 88 additions & 146 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,25 +495,50 @@ end
return nothing
end

@muladd function perform_step!(integrator, cache::Rosenbrock33Cache, repeat_step = false)
@muladd function perform_step!(integrator, cache::Union{Rosenbrock33Cache, Rosenbrock34Cache}, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache
@unpack a21, a31, a32, C21, C31, C32, b1, b2, b3, btilde1, btilde2, btilde3, gamma, c2, c3, d1, d2, d3 = cache.tab
@unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, k4, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache
@unpack a21, a31, a32, a41, a42, a43, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab

# Assignments
mass_matrix = integrator.f.mass_matrix
sizeu = size(u)
utilde = du
if cache isa Rosenbrock34Cache
uidx = eachindex(integrator.uprev)
sizeu = size(u)
mass_matrix = integrator.f.mass_matrix
utilde = du
end
if cache isa Rosenbrock33Cache
mass_matrix = integrator.f.mass_matrix
sizeu = size(u)
utilde = du
end

# Precalculations
dtC21 = C21 / dt
dtC31 = C31 / dt
dtC32 = C32 / dt
if cache isa Rosenbrock33Cache
dtC21 = C21 / dt
dtC31 = C31 / dt
dtC32 = C32 / dt

dtd1 = dt * d1
dtd2 = dt * d2
dtd3 = dt * d3
dtgamma = dt * gamma
dtd1 = dt * d1
dtd2 = dt * d2
dtd3 = dt * d3
dtgamma = dt * gamma
end

if cache isa Rosenbrock34Cache
dtC21 = C21 / dt
dtC31 = C31 / dt
dtC32 = C32 / dt
dtC41 = C41 / dt
dtC42 = C42 / dt
dtC43 = C43 / dt

dtd1 = dt * d1
dtd2 = dt * d2
dtd3 = dt * d3
dtd4 = dt * d4
dtgamma = dt * gamma
end

calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)

Expand All @@ -538,10 +563,20 @@ end
@.. broadcast=false veck1=-vecu
integrator.stats.nsolve += 1

@.. broadcast=false u=uprev + a21 * k1
stage_limiter!(u, integrator, p, t + c2 * dt)
f(du, u, p, t + c2 * dt)
integrator.stats.nf += 1
#=
a21 == 0 and c2 == 0
so du = integrator.fsalfirst!
@.. broadcast=false u = uprev + a21*k1
f(du, u, p, t+c2*dt)
=#

if cache isa Rosenbrock33Cache
@.. broadcast=false u=uprev + a21 * k1
stage_limiter!(u, integrator, p, t + c2 * dt)
f(du, u, p, t + c2 * dt)
integrator.stats.nf += 1
end

if mass_matrix === I
@.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1
Expand All @@ -552,11 +587,13 @@ end
end

linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
vecu = _vec(linres.u)
if cache isa Rosenbrock33Cache
vecu = _vec(linres.u)
end

veck2 = _vec(k2)

@.. broadcast=false veck2=-vecu

integrator.stats.nsolve += 1

@.. broadcast=false u=uprev + a31 * k1 + a32 * k2
Expand All @@ -573,18 +610,44 @@ end
end

linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
vecu = _vec(linres.u)
if cache isa Rosenbrock33Cache
vecu = _vec(linres.u)
end
veck3 = _vec(k3)

@.. broadcast=false veck3=-vecu

integrator.stats.nsolve += 1

@.. broadcast=false u=uprev + b1 * k1 + b2 * k2 + b3 * k3

step_limiter!(u, integrator, p, t + dt)

f(fsallast, u, p, t + dt)
if cache isa Rosenbrock33Cache
f(fsallast, u, p, t + dt)
end

if cache isa Rosenbrock34Cache
f(du, u, p, t + dt) #-- c4 = 1
end

if cache isa Rosenbrock34Cache
if mass_matrix === I
@.. broadcast=false linsolve_tmp=du + dtd4 * dT + dtC41 * k1 + dtC42 * k2 +
dtC43 * k3
else
@.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3
mul!(_vec(du2), mass_matrix, _vec(du1))
@.. broadcast=false linsolve_tmp=du + dtd4 * dT + du2
end
linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
veck4 = _vec(k4)
@.. broadcast=false veck4=-vecu
integrator.stats.nsolve += 1

@.. broadcast=false u=uprev + b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4

step_limiter!(u, integrator, p, t + dt)

f(fsallast, u, p, t + dt)
end

integrator.stats.nf += 1

if integrator.opts.adaptive
Expand Down Expand Up @@ -685,127 +748,6 @@ end
return nothing
end

@muladd function perform_step!(integrator, cache::Rosenbrock34Cache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, k4, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache
@unpack a21, a31, a32, a41, a42, a43, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab

# Assignments
uidx = eachindex(integrator.uprev)
sizeu = size(u)
mass_matrix = integrator.f.mass_matrix
utilde = du

# Precalculations
dtC21 = C21 / dt
dtC31 = C31 / dt
dtC32 = C32 / dt
dtC41 = C41 / dt
dtC42 = C42 / dt
dtC43 = C43 / dt

dtd1 = dt * d1
dtd2 = dt * d2
dtd3 = dt * d3
dtd4 = dt * d4
dtgamma = dt * gamma

calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)

calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = dtgamma))
end

vecu = _vec(linres.u)
veck1 = _vec(k1)

@.. broadcast=false veck1=-vecu
integrator.stats.nsolve += 1

#=
a21 == 0 and c2 == 0
so du = integrator.fsalfirst!
@.. broadcast=false u = uprev + a21*k1
f(du, u, p, t+c2*dt)
=#

if mass_matrix === I
@.. broadcast=false linsolve_tmp=fsalfirst + dtd2 * dT + dtC21 * k1
else
@.. broadcast=false du1=dtC21 * k1
mul!(_vec(du2), mass_matrix, _vec(du1))
@.. broadcast=false linsolve_tmp=fsalfirst + dtd2 * dT + du2
end

linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
veck2 = _vec(k2)
@.. broadcast=false veck2=-vecu
integrator.stats.nsolve += 1

@.. broadcast=false u=uprev + a31 * k1 + a32 * k2
stage_limiter!(u, integrator, p, t + c3 * dt)
f(du, u, p, t + c3 * dt)
integrator.stats.nf += 1

if mass_matrix === I
@.. broadcast=false linsolve_tmp=du + dtd3 * dT + dtC31 * k1 + dtC32 * k2
else
@.. broadcast=false du1=dtC31 * k1 + dtC32 * k2
mul!(_vec(du2), mass_matrix, _vec(du1))
@.. broadcast=false linsolve_tmp=du + dtd3 * dT + du2
end

linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
veck3 = _vec(k3)
@.. broadcast=false veck3=-vecu
integrator.stats.nsolve += 1
@.. broadcast=false u=uprev + a41 * k1 + a42 * k2 + a43 * k3
stage_limiter!(u, integrator, p, t + dt)
f(du, u, p, t + dt) #-- c4 = 1
integrator.stats.nf += 1

if mass_matrix === I
@.. broadcast=false linsolve_tmp=du + dtd4 * dT + dtC41 * k1 + dtC42 * k2 +
dtC43 * k3
else
@.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3
mul!(_vec(du2), mass_matrix, _vec(du1))
@.. broadcast=false linsolve_tmp=du + dtd4 * dT + du2
end

linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
veck4 = _vec(k4)
@.. broadcast=false veck4=-vecu
integrator.stats.nsolve += 1

@.. broadcast=false u=uprev + b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4

step_limiter!(u, integrator, p, t + dt)

f(fsallast, u, p, t + dt)
integrator.stats.nf += 1

if integrator.opts.adaptive
@.. broadcast=false utilde=btilde1 * k1 + btilde2 * k2 + btilde3 * k3 + btilde4 * k4
calculate_residuals!(atmp, utilde, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
cache.linsolve = linres.cache
end

################################################################################

#### ROS2 type method
Expand Down

0 comments on commit 8d63077

Please sign in to comment.