From 120772aa9c6ed7724478f1ef9dfabc05f72de884 Mon Sep 17 00:00:00 2001 From: ParamThakkar123 Date: Fri, 9 Aug 2024 08:57:37 +0530 Subject: [PATCH] Update --- .../src/rosenbrock_perform_step.jl | 140 +++++++++--------- 1 file changed, 67 insertions(+), 73 deletions(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 930c63883d..26232bff32 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -191,7 +191,7 @@ end cache.linsolve = linres.cache end -@muladd function perform_step!(integrator, cache::Union{Rosenbrock23ConstantCache, Rosenbrock32ConstantCache}, +@muladd function perform_step!(integrator, cache::Union{Rosenbrock23ConstantCache, Rosenbrock32ConstantCache, Rosenbrock33ConstantCache}, repeat_step = false) if cache isa Rosenbrock32ConstantCache @@ -209,6 +209,22 @@ end end end + if cache isa Rosenbrock33ConstantCache + @unpack t, dt, uprev, u, f, p = integrator + @unpack tf, uf = cache + @unpack a21, a31, a32, C21, C31, C32, b1, b2, b3, btilde1, btilde2, btilde3, gamma, c2, c3, d1, d2, d3 = cache.tab + + # Precalculations + dtC21 = C21 / dt + dtC31 = C31 / dt + dtC32 = C32 / dt + + dtd1 = dt * d1 + dtd2 = dt * d2 + dtd3 = dt * d3 + dtgamma = dt * gamma + end + mass_matrix = integrator.f.mass_matrix # Time derivative @@ -220,11 +236,55 @@ end return nothing end + if cache isa Rosenbrock33ConstantCache + linsolve_tmp = integrator.fsalfirst + dtd1 * dT + end + k₁ = _reshape(W \ -_vec((integrator.fsalfirst + γ * dT)), axes(uprev)) integrator.stats.nsolve += 1 + + if cache isa Rosenbrock33ConstantCache + u = uprev + a21 * k1 + du = f(u, p, t + c2 * dt) + end + + f₁ = f(uprev + dto2 * k₁, p, t + dto2) integrator.stats.nf += 1 + if cache isa Rosenbrock33ConstantCache + if mass_matrix === I + linsolve_tmp = du + dtd2 * dT + dtC21 * k1 + else + linsolve_tmp = du + dtd2 * dT + mass_matrix * (dtC21 * k1) + end + + k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) + integrator.stats.nsolve += 1 + u = uprev + a31 * k1 + a32 * k2 + du = f(u, p, t + c3 * dt) + integrator.stats.nf += 1 + + if mass_matrix === I + linsolve_tmp = du + dtd3 * dT + dtC31 * k1 + dtC32 * k2 + else + linsolve_tmp = du + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2) + end + + k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) + integrator.stats.nsolve += 1 + u = uprev + b1 * k1 + b2 * k2 + b3 * k3 + integrator.fsallast = f(u, p, t + dt) + integrator.stats.nf += 1 + + if integrator.opts.adaptive + utilde = btilde1 * k1 + btilde2 * k2 + btilde3 * k3 + atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + end + end + if mass_matrix === I k₂ = _reshape(W \ -_vec(f₁ - k₁), axes(uprev)) + k₁ else @@ -288,6 +348,12 @@ end integrator.EEst += integrator.opts.internalnorm(atmp, t) end end + + if cache isa Rosenbrock33ConstantCache + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.k[1] = k₁ integrator.k[2] = k₂ integrator.u = u @@ -324,78 +390,6 @@ function initialize!(integrator, integrator.stats.nf += 1 end -@muladd function perform_step!(integrator, cache::Rosenbrock33ConstantCache, - repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack tf, uf = cache - @unpack a21, a31, a32, C21, C31, C32, b1, b2, b3, btilde1, btilde2, btilde3, gamma, c2, c3, d1, d2, d3 = cache.tab - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtgamma = dt * gamma - - mass_matrix = integrator.f.mass_matrix - - # Time derivative - dT = calc_tderivative(integrator, cache) - - W = calc_W(integrator, cache, dtgamma, repeat_step, true) - if !issuccess_W(W) - integrator.EEst = 2 - return nothing - end - - linsolve_tmp = integrator.fsalfirst + dtd1 * dT - - k1 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a21 * k1 - du = f(u, p, t + c2 * dt) - integrator.stats.nf += 1 - - if mass_matrix === I - linsolve_tmp = du + dtd2 * dT + dtC21 * k1 - else - linsolve_tmp = du + dtd2 * dT + mass_matrix * (dtC21 * k1) - end - - k2 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + a31 * k1 + a32 * k2 - du = f(u, p, t + c3 * dt) - integrator.stats.nf += 1 - - if mass_matrix === I - linsolve_tmp = du + dtd3 * dT + dtC31 * k1 + dtC32 * k2 - else - linsolve_tmp = du + dtd3 * dT + mass_matrix * (dtC31 * k1 + dtC32 * k2) - end - - k3 = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) - integrator.stats.nsolve += 1 - u = uprev + b1 * k1 + b2 * k2 + b3 * k3 - integrator.fsallast = f(u, p, t + dt) - integrator.stats.nf += 1 - - if integrator.opts.adaptive - utilde = btilde1 * k1 + btilde2 * k2 + btilde3 * k3 - atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return nothing -end - @muladd function perform_step!(integrator, cache::Union{Rosenbrock33Cache, Rosenbrock34Cache}, repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator @unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, k4, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache