From 8d63077f862672fd2e22d69e38825f48f4f1fadb Mon Sep 17 00:00:00 2001 From: ParamThakkar123 Date: Thu, 8 Aug 2024 22:08:59 +0530 Subject: [PATCH] Update --- .../src/rosenbrock_perform_step.jl | 234 +++++++----------- 1 file changed, 88 insertions(+), 146 deletions(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index cc96b2447d..b8ba095a2e 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -495,25 +495,50 @@ end return nothing end -@muladd function perform_step!(integrator, cache::Rosenbrock33Cache, repeat_step = false) +@muladd function perform_step!(integrator, cache::Union{Rosenbrock33Cache, Rosenbrock34Cache}, repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator - @unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache - @unpack a21, a31, a32, C21, C31, C32, b1, b2, b3, btilde1, btilde2, btilde3, gamma, c2, c3, d1, d2, d3 = cache.tab + @unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, k4, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache + @unpack a21, a31, a32, a41, a42, a43, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab # Assignments - mass_matrix = integrator.f.mass_matrix - sizeu = size(u) - utilde = du + if cache isa Rosenbrock34Cache + uidx = eachindex(integrator.uprev) + sizeu = size(u) + mass_matrix = integrator.f.mass_matrix + utilde = du + end + if cache isa Rosenbrock33Cache + mass_matrix = integrator.f.mass_matrix + sizeu = size(u) + utilde = du + end # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt + if cache isa Rosenbrock33Cache + dtC21 = C21 / dt + dtC31 = C31 / dt + dtC32 = C32 / dt - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtgamma = dt * gamma + dtd1 = dt * d1 + dtd2 = dt * d2 + dtd3 = dt * d3 + dtgamma = dt * gamma + end + + if cache isa Rosenbrock34Cache + dtC21 = C21 / dt + dtC31 = C31 / dt + dtC32 = C32 / dt + dtC41 = C41 / dt + dtC42 = C42 / dt + dtC43 = C43 / dt + + dtd1 = dt * d1 + dtd2 = dt * d2 + dtd3 = dt * d3 + dtd4 = dt * d4 + dtgamma = dt * gamma + end calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true) @@ -538,10 +563,20 @@ end @.. broadcast=false veck1=-vecu integrator.stats.nsolve += 1 - @.. broadcast=false u=uprev + a21 * k1 - stage_limiter!(u, integrator, p, t + c2 * dt) - f(du, u, p, t + c2 * dt) - integrator.stats.nf += 1 + #= + a21 == 0 and c2 == 0 + so du = integrator.fsalfirst! + @.. broadcast=false u = uprev + a21*k1 + + f(du, u, p, t+c2*dt) + =# + + if cache isa Rosenbrock33Cache + @.. broadcast=false u=uprev + a21 * k1 + stage_limiter!(u, integrator, p, t + c2 * dt) + f(du, u, p, t + c2 * dt) + integrator.stats.nf += 1 + end if mass_matrix === I @.. broadcast=false linsolve_tmp=du + dtd2 * dT + dtC21 * k1 @@ -552,11 +587,13 @@ end end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) + if cache isa Rosenbrock33Cache + vecu = _vec(linres.u) + end + veck2 = _vec(k2) @.. broadcast=false veck2=-vecu - integrator.stats.nsolve += 1 @.. broadcast=false u=uprev + a31 * k1 + a32 * k2 @@ -573,18 +610,44 @@ end end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) + if cache isa Rosenbrock33Cache + vecu = _vec(linres.u) + end veck3 = _vec(k3) - @.. broadcast=false veck3=-vecu - integrator.stats.nsolve += 1 - @.. broadcast=false u=uprev + b1 * k1 + b2 * k2 + b3 * k3 - step_limiter!(u, integrator, p, t + dt) - f(fsallast, u, p, t + dt) + if cache isa Rosenbrock33Cache + f(fsallast, u, p, t + dt) + end + + if cache isa Rosenbrock34Cache + f(du, u, p, t + dt) #-- c4 = 1 + end + + if cache isa Rosenbrock34Cache + if mass_matrix === I + @.. broadcast=false linsolve_tmp=du + dtd4 * dT + dtC41 * k1 + dtC42 * k2 + + dtC43 * k3 + else + @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 + mul!(_vec(du2), mass_matrix, _vec(du1)) + @.. broadcast=false linsolve_tmp=du + dtd4 * dT + du2 + end + linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) + veck4 = _vec(k4) + @.. broadcast=false veck4=-vecu + integrator.stats.nsolve += 1 + + @.. broadcast=false u=uprev + b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 + + step_limiter!(u, integrator, p, t + dt) + + f(fsallast, u, p, t + dt) + end + integrator.stats.nf += 1 if integrator.opts.adaptive @@ -685,127 +748,6 @@ end return nothing end -@muladd function perform_step!(integrator, cache::Rosenbrock34Cache, repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - @unpack du, du1, du2, fsalfirst, fsallast, k1, k2, k3, k4, dT, J, W, uf, tf, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter! = cache - @unpack a21, a31, a32, a41, a42, a43, C21, C31, C32, C41, C42, C43, b1, b2, b3, b4, btilde1, btilde2, btilde3, btilde4, gamma, c2, c3, d1, d2, d3, d4 = cache.tab - - # Assignments - uidx = eachindex(integrator.uprev) - sizeu = size(u) - mass_matrix = integrator.f.mass_matrix - utilde = du - - # Precalculations - dtC21 = C21 / dt - dtC31 = C31 / dt - dtC32 = C32 / dt - dtC41 = C41 / dt - dtC42 = C42 / dt - dtC43 = C43 / dt - - dtd1 = dt * d1 - dtd2 = dt * d2 - dtd3 = dt * d3 - dtd4 = dt * d4 - dtgamma = dt * gamma - - calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true) - - calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t) - - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end - - vecu = _vec(linres.u) - veck1 = _vec(k1) - - @.. broadcast=false veck1=-vecu - integrator.stats.nsolve += 1 - - #= - a21 == 0 and c2 == 0 - so du = integrator.fsalfirst! - @.. broadcast=false u = uprev + a21*k1 - - f(du, u, p, t+c2*dt) - =# - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=fsalfirst + dtd2 * dT + dtC21 * k1 - else - @.. broadcast=false du1=dtC21 * k1 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=fsalfirst + dtd2 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck2 = _vec(k2) - @.. broadcast=false veck2=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + a31 * k1 + a32 * k2 - stage_limiter!(u, integrator, p, t + c3 * dt) - f(du, u, p, t + c3 * dt) - integrator.stats.nf += 1 - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + dtC31 * k1 + dtC32 * k2 - else - @.. broadcast=false du1=dtC31 * k1 + dtC32 * k2 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd3 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck3 = _vec(k3) - @.. broadcast=false veck3=-vecu - integrator.stats.nsolve += 1 - @.. broadcast=false u=uprev + a41 * k1 + a42 * k2 + a43 * k3 - stage_limiter!(u, integrator, p, t + dt) - f(du, u, p, t + dt) #-- c4 = 1 - integrator.stats.nf += 1 - - if mass_matrix === I - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + dtC41 * k1 + dtC42 * k2 + - dtC43 * k3 - else - @.. broadcast=false du1=dtC41 * k1 + dtC42 * k2 + dtC43 * k3 - mul!(_vec(du2), mass_matrix, _vec(du1)) - @.. broadcast=false linsolve_tmp=du + dtd4 * dT + du2 - end - - linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - veck4 = _vec(k4) - @.. broadcast=false veck4=-vecu - integrator.stats.nsolve += 1 - - @.. broadcast=false u=uprev + b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 - - step_limiter!(u, integrator, p, t + dt) - - f(fsallast, u, p, t + dt) - integrator.stats.nf += 1 - - if integrator.opts.adaptive - @.. broadcast=false utilde=btilde1 * k1 + btilde2 * k2 + btilde3 * k3 + btilde4 * k4 - calculate_residuals!(atmp, utilde, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - cache.linsolve = linres.cache -end - ################################################################################ #### ROS2 type method