From 16ba23e9aae4bacb986b03ff66f3480a15c1708a Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 12 Sep 2024 15:45:15 -0400 Subject: [PATCH] fix merge --- .../src/rosenbrock_perform_step.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index c4271a493c..57773da365 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -52,7 +52,7 @@ end linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) veck₁ = _vec(k₁) - @.. veck₁ = vecu * neginvdtγ + @.. veck₁ = linres.u * neginvdtγ integrator.stats.nsolve += 1 @.. u=uprev + dto2 * k₁ @@ -69,10 +69,9 @@ end @.. linsolve_tmp = f₁ - tmp linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck₂ = _vec(k₂) - @.. veck₂ = vecu * neginvdtγ + veck₁ + @.. veck₂ = linres.u * neginvdtγ + veck₁ integrator.stats.nsolve += 1 @.. u = uprev + dt * k₂ @@ -95,7 +94,7 @@ end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) veck3 = _vec(k₃) - @.. veck3 = vecu * neginvdtγ + @.. veck3 = linres.u * neginvdtγ integrator.stats.nsolve += 1 @@ -150,9 +149,8 @@ end integrator.opts.internalnorm, t) linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) - veck₁ = _vec(k₁) - @.. veck₁ = vecu * neginvdtγ + @.. veck₁ = linres.u * neginvdtγ integrator.stats.nsolve += 1 @.. broadcast=false u=uprev + dto2 * k₁ @@ -169,9 +167,8 @@ end @.. broadcast=false linsolve_tmp=f₁ - tmp linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck₂ = _vec(k₂) - @.. veck₂ = vecu * neginvdtγ + veck₁ + @.. veck₂ = linres.u * neginvdtγ + veck₁ integrator.stats.nsolve += 1 @.. tmp = uprev + dt * k₂ @@ -191,7 +188,7 @@ end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) veck3 = _vec(k₃) - @.. veck3 = vecu * neginvdtγ + @.. veck3 = linres.u * neginvdtγ integrator.stats.nsolve += 1 @.. broadcast=false u=uprev + dto6 * (k₁ + 4k₂ + k₃)