Skip to content

Commit

Permalink
Rosenbrock23, Rosenbrock32 unified
Browse files Browse the repository at this point in the history
  • Loading branch information
ParamThakkar123 committed Aug 8, 2024
1 parent 8d63077 commit 00ac93c
Showing 1 changed file with 53 additions and 107 deletions.
160 changes: 53 additions & 107 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function initialize!(integrator,
integrator.k[2] = zero(integrator.fsalfirst)
end

@muladd function perform_step!(integrator, cache::Rosenbrock23Cache, repeat_step = false)
@muladd function perform_step!(integrator, cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, repeat_step = false)
@unpack t, dt, uprev, u, f, p, opts = integrator
@unpack k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, tmp, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache
@unpack c₃₂, d = cache.tab
Expand Down Expand Up @@ -72,8 +72,15 @@ end
f(f₁, u, p, t + dto2)
integrator.stats.nf += 1

if mass_matrix === I
copyto!(tmp, k₁)
if cache isa Rosenbrock23Cache
if mass_matrix === I
copyto!(tmp, k₁)
end
end
if cache isa Rosenbrock32Cache
if mass_matrix === I
tmp .= k₁
end
else
mul!(_vec(tmp), mass_matrix, _vec(k₁))
end
Expand All @@ -90,7 +97,49 @@ end
@.. broadcast=false k₂+=k₁
@.. broadcast=false u=uprev + dt * k₂
stage_limiter!(u, integrator, p, t + dt)
step_limiter!(u, integrator, p, t + dt)

if cache isa Rosenbrock23Cache
step_limiter!(u, integrator, p, t + dt)
end

if cache isa Rosenbrock32Cache
f(fsallast, tmp, p, t + dt)
integrator.stats.nf += 1

if mass_matrix === I
@.. broadcast=false linsolve_tmp=fsallast - c₃₂ * (k₂ - f₁) - 2(k₁ - fsalfirst) +
dt * dT
else
@.. broadcast=false du2=c₃₂ * k₂ + 2k₁
mul!(_vec(du1), mass_matrix, _vec(du2))
@.. broadcast=false linsolve_tmp=fsallast - du1 + c₃₂ * f₁ + 2fsalfirst + dt * dT
end

linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
vecu = _vec(linres.u)
veck3 = _vec(k₃)

@.. broadcast=false veck3=-vecu
integrator.stats.nsolve += 1

@.. broadcast=false u=uprev + dto6 * (k₁ + 4k₂ + k₃)
step_limiter!(u, integrator, p, t + dt)

if integrator.opts.adaptive
@.. broadcast=false tmp=dto6 * (k₁ - 2 * k₂ + k₃)
calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)

if mass_matrix !== I
@.. broadcast=false atmp=ifelse(cache.algebraic_vars, fsallast, false) /
integrator.opts.abstol
integrator.EEst += integrator.opts.internalnorm(atmp, t)
end
end
cache.linsolve = linres.cache
end


if integrator.opts.adaptive
f(fsallast, u, p, t + dt)
Expand Down Expand Up @@ -137,109 +186,6 @@ end
cache.linsolve = linres.cache
end

@muladd function perform_step!(integrator, cache::Rosenbrock32Cache, repeat_step = false)
@unpack t, dt, uprev, u, f, p, opts = integrator
@unpack k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, tmp, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache
@unpack c₃₂, d = cache.tab

# Assignments
sizeu = size(u)
mass_matrix = integrator.f.mass_matrix

# Precalculations
γ = dt * d
dto2 = dt / 2
dto6 = dt / 6

if repeat_step
f(integrator.fsalfirst, uprev, p, t)
integrator.stats.nf += 1
end

calc_rosenbrock_differentiation!(integrator, cache, γ, γ, repeat_step, false)

calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)

if repeat_step
linres = dolinsolve(
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
else
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
solverdata = (; gamma = γ))
end

vecu = _vec(linres.u)
veck₁ = _vec(k₁)

@.. broadcast=false veck₁=-vecu
integrator.stats.nsolve += 1

@.. broadcast=false u=uprev + dto2 * k₁
stage_limiter!(u, integrator, p, t + dto2)
f(f₁, u, p, t + dto2)
integrator.stats.nf += 1

if mass_matrix === I
tmp .= k₁
else
mul!(_vec(tmp), mass_matrix, _vec(k₁))
end

@.. broadcast=false linsolve_tmp=f₁ - tmp

linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
vecu = _vec(linres.u)
veck2 = _vec(k₂)

@.. broadcast=false veck2=-vecu
integrator.stats.nsolve += 1

@.. broadcast=false k₂+=k₁
@.. broadcast=false tmp=uprev + dt * k₂
stage_limiter!(u, integrator, p, t + dt)
f(fsallast, tmp, p, t + dt)
integrator.stats.nf += 1

if mass_matrix === I
@.. broadcast=false linsolve_tmp=fsallast - c₃₂ * (k₂ - f₁) - 2(k₁ - fsalfirst) +
dt * dT
else
@.. broadcast=false du2=c₃₂ * k₂ + 2k₁
mul!(_vec(du1), mass_matrix, _vec(du2))
@.. broadcast=false linsolve_tmp=fsallast - du1 + c₃₂ * f₁ + 2fsalfirst + dt * dT
end

linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
vecu = _vec(linres.u)
veck3 = _vec(k₃)

@.. broadcast=false veck3=-vecu
integrator.stats.nsolve += 1

@.. broadcast=false u=uprev + dto6 * (k₁ + 4k₂ + k₃)

step_limiter!(u, integrator, p, t + dt)

if integrator.opts.adaptive
@.. broadcast=false tmp=dto6 * (k₁ - 2 * k₂ + k₃)
calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)

if mass_matrix !== I
@.. broadcast=false atmp=ifelse(cache.algebraic_vars, fsallast, false) /
integrator.opts.abstol
integrator.EEst += integrator.opts.internalnorm(atmp, t)
end
end
cache.linsolve = linres.cache
end

@muladd function perform_step!(integrator, cache::Rosenbrock23ConstantCache,
repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
Expand Down

0 comments on commit 00ac93c

Please sign in to comment.