From 7dc1e512f0c07ac39601a459f43b31adedcb08cd Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 30 Aug 2024 14:19:36 -0400 Subject: [PATCH 1/4] reshape and _vec appropriately in more places --- lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl | 21 +++++++++--------- lib/OrdinaryDiffEqBDF/src/controllers.jl | 2 +- .../src/firk_perform_step.jl | 22 ++++++++++++------- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl index c410bf6a36..b0178688c3 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl @@ -758,16 +758,17 @@ function perform_step!(integrator, cache::QNDFConstantCache{max_order}, α₀ = 1 β₀ = inv((1 - κ) * γₖ[k]) if u isa Number - u₀ = sum(D[1:k]) + uprev + u₀ = sum(view(D, 1:k)) + uprev ϕ = zero(u) for i in 1:k ϕ += γₖ[i] * D[i] end else - u₀ = reshape(sum(D[:, 1:k], dims = 2) .+ uprev, size(u)) + u₀ = reshape(sum(view(D, :, 1:k), dims = 2), size(u)) .+ uprev ϕ = zero(u) for i in 1:k - ϕ = @.. ϕ + γₖ[i] * D[:, i] + D_row = reshape(view(D, :, i), size(u)) + ϕ = @.. ϕ + γₖ[i] * D_row end end markfirststage!(nlsolver) @@ -802,14 +803,14 @@ function perform_step!(integrator, cache::QNDFConstantCache{max_order}, end integrator.EEst = error_constant(integrator, k) * internalnorm(atmp, t) if k > 1 - @views atmpm1 = calculate_residuals(D[:, k], uprev, u, integrator.opts.abstol, - integrator.opts.reltol, - integrator.opts.internalnorm, t) + @views atmpm1 = calculate_residuals(reshape(view(D, :, k), size(u)), + uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) cache.EEst1 = error_constant(integrator, k - 1) * internalnorm(atmpm1, t) end if k < max_order - @views atmpp1 = calculate_residuals(D[:, k + 2], uprev, u, abstol, reltol, - internalnorm, t) + @views atmpp1 = calculate_residuals(reshape(view(D, :, k + 2), size(u)), + uprev, u, abstol, reltol, internalnorm, t) cache.EEst2 = error_constant(integrator, k + 1) * internalnorm(atmpp1, t) end end @@ -1112,7 +1113,7 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order}, end tmp = -uprev * bdf_coeffs[k, 2] for i in 1:(k - 1) - @views tmp = @.. tmp - u_corrector[:, i] * bdf_coeffs[k, i + 2] + tmp = @.. tmp - $(reshape(view(u_corrector, :, i), size(u))) * bdf_coeffs[k, i + 2] end end @@ -1169,7 +1170,7 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order}, terk *= abs(dt^(k)) else for i in 2:(k + 1) - @views terk = @.. terk + fd_weights[i, k + 1] * u_history[:, i - 1] + terk = @.. terk + fd_weights[i, k + 1] * $(reshape(view(u_history, :, i - 1), size(u))) end terk *= abs(dt^(k)) end diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index fc13dbd91a..089b1b6110 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -217,7 +217,7 @@ function choose_order!(alg::FBDF, integrator, terk_tmp = similar(u) @.. terk_tmp = fd_weights[k - 2, 1] * _vec(u) for i in 2:(k - 2) - @.. @views terk_tmp += fd_weights[i, k - 2] * u_history[:, i - 1] + @.. terk_tmp += fd_weights[i, k - 2] * $(reshape(view(u_history, :, i - 1), size(u))) end @.. terk_tmp *= abs(dt^(k - 2)) end diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl index bb0009cdd7..671b14b8f4 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl @@ -169,7 +169,7 @@ end rhs1 = @. fw1 - αdt * Mw1 + βdt * Mw2 rhs2 = @. fw2 - βdt * Mw1 - αdt * Mw2 - dw12 = LU1 \ (@. rhs1 + rhs2 * im) + dw12 = reshape(LU1 \ _vec(@. rhs1 + rhs2 * im), size(u)) integrator.stats.nsolve += 1 dw1 = real(dw12) dw2 = imag(dw12) @@ -450,8 +450,8 @@ end rhs1 = @.. broadcast=false fw1-γdt * Mw1 rhs2 = @.. broadcast=false fw2 - αdt * Mw2+βdt * Mw3 rhs3 = @.. broadcast=false fw3 - βdt * Mw2-αdt * Mw3 - dw1 = LU1 \ rhs1 - dw23 = LU2 \ (@.. broadcast=false rhs2+rhs3 * im) + dw1 = reshape(LU1 \ _vec(rhs1), size(u)) + dw23 = reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), size(u)) integrator.stats.nsolve += 2 dw2 = real(dw23) dw3 = imag(dw23) @@ -507,7 +507,10 @@ end tmp = @.. broadcast=false e1dt*z1+e2dt*z2+e3dt*z3 mass_matrix != I && (tmp = mass_matrix * tmp) utilde = @.. broadcast=false integrator.fsalfirst+tmp - alg.smooth_est && (utilde = LU1 \ utilde; integrator.stats.nsolve += 1) + if alg.smooth_est + utilde = reshape(LU1 \ _vec(utilde), size(u)) + integrator.stats.nsolve += 1 + end # RadauIIA5 needs a transformed rtol and atol see # https://github.com/luchr/ODEInterface.jl/blob/0bd134a5a358c4bc13e0fb6a90e27e4ee79e0115/src/radau5.f#L399-L421 atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t) @@ -899,9 +902,9 @@ end rhs3 = @.. broadcast=false fw3 - β1dt * Mw2-α1dt * Mw3 rhs4 = @.. broadcast=false fw4 - α2dt * Mw4+β2dt * Mw5 rhs5 = @.. broadcast=false fw5 - β2dt * Mw4-α2dt * Mw5 - dw1 = LU1 \ rhs1 - dw23 = LU2 \ (@.. broadcast=false rhs2+rhs3 * im) - dw45 = LU3 \ (@.. broadcast=false rhs4+rhs5 * im) + dw1 = reshape(LU1 \ _vec(rhs1), size(u)) + dw23 = reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), size(u)) + dw45 = reshape(LU3 \ _vec(@.. broadcast=false rhs4+rhs5 * im), size(u)) integrator.stats.nsolve += 3 dw2 = real(dw23) dw3 = imag(dw23) @@ -969,7 +972,10 @@ end tmp = @.. broadcast=false e1dt*z1+e2dt*z2+e3dt*z3+e4dt*z4+e5dt*z5 mass_matrix != I && (tmp = mass_matrix * tmp) utilde = @.. broadcast=false integrator.fsalfirst+tmp - alg.smooth_est && (utilde = LU1 \ utilde; integrator.stats.nsolve += 1) + if alg.smooth_est + utilde = reshape(LU1 \ _vec(utilde), size(u)) + integrator.stats.nsolve += 1 + end atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t) integrator.EEst = internalnorm(atmp, t) From 5f77fd4836d75bcf7d3b3470afa5a39874adedf0 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Sat, 31 Aug 2024 23:05:10 -0400 Subject: [PATCH 2/4] use _reshape --- lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl index b0178688c3..4f527bac7a 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl @@ -764,10 +764,10 @@ function perform_step!(integrator, cache::QNDFConstantCache{max_order}, ϕ += γₖ[i] * D[i] end else - u₀ = reshape(sum(view(D, :, 1:k), dims = 2), size(u)) .+ uprev + u₀ = _reshape(sum(view(D, :, 1:k), dims = 2), axes(u)) .+ uprev ϕ = zero(u) for i in 1:k - D_row = reshape(view(D, :, i), size(u)) + D_row = _reshape(view(D, :, i), axes(u)) ϕ = @.. ϕ + γₖ[i] * D_row end end @@ -803,13 +803,13 @@ function perform_step!(integrator, cache::QNDFConstantCache{max_order}, end integrator.EEst = error_constant(integrator, k) * internalnorm(atmp, t) if k > 1 - @views atmpm1 = calculate_residuals(reshape(view(D, :, k), size(u)), + @views atmpm1 = calculate_residuals(_reshape(view(D, :, k), axes(u)), uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) cache.EEst1 = error_constant(integrator, k - 1) * internalnorm(atmpm1, t) end if k < max_order - @views atmpp1 = calculate_residuals(reshape(view(D, :, k + 2), size(u)), + @views atmpp1 = calculate_residuals(_reshape(view(D, :, k + 2), axes(u)), uprev, u, abstol, reltol, internalnorm, t) cache.EEst2 = error_constant(integrator, k + 1) * internalnorm(atmpp1, t) end @@ -926,13 +926,13 @@ function perform_step!(integrator, cache::QNDFCache{max_order}, integrator.EEst = error_constant(integrator, k) * internalnorm(atmp, t) if k > 1 @views calculate_residuals!( - atmpm1, reshape(D[:, k], size(u)), uprev, u, abstol, + atmpm1, _reshape(D[:, k], axes(u)), uprev, u, abstol, reltol, internalnorm, t) cache.EEst1 = error_constant(integrator, k - 1) * internalnorm(atmpm1, t) end if k < max_order @views calculate_residuals!( - atmpp1, reshape(D[:, k + 2], size(u)), uprev, u, abstol, + atmpp1, _reshape(D[:, k + 2], axes(u)), uprev, u, abstol, reltol, internalnorm, t) cache.EEst2 = error_constant(integrator, k + 1) * internalnorm(atmpp1, t) end @@ -1113,7 +1113,7 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order}, end tmp = -uprev * bdf_coeffs[k, 2] for i in 1:(k - 1) - tmp = @.. tmp - $(reshape(view(u_corrector, :, i), size(u))) * bdf_coeffs[k, i + 2] + tmp = @.. tmp - $(_reshape(view(u_corrector, :, i), axes(u))) * bdf_coeffs[k, i + 2] end end @@ -1170,7 +1170,7 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order}, terk *= abs(dt^(k)) else for i in 2:(k + 1) - terk = @.. terk + fd_weights[i, k + 1] * $(reshape(view(u_history, :, i - 1), size(u))) + terk = @.. terk + fd_weights[i, k + 1] * $(_reshape(view(u_history, :, i - 1), axes(u))) end terk *= abs(dt^(k)) end From cbfdf8808d0ec41410052f1d9d80ef9385892612 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Sat, 31 Aug 2024 23:05:39 -0400 Subject: [PATCH 3/4] use _reshape --- lib/OrdinaryDiffEqBDF/src/controllers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 089b1b6110..e992aaeb1d 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -217,7 +217,7 @@ function choose_order!(alg::FBDF, integrator, terk_tmp = similar(u) @.. terk_tmp = fd_weights[k - 2, 1] * _vec(u) for i in 2:(k - 2) - @.. terk_tmp += fd_weights[i, k - 2] * $(reshape(view(u_history, :, i - 1), size(u))) + @.. terk_tmp += fd_weights[i, k - 2] * $(_reshape(view(u_history, :, i - 1), axes(u))) end @.. terk_tmp *= abs(dt^(k - 2)) end From 2bc1ff3c9f10386ea96723aa384a9d2cfa9d9581 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Sat, 31 Aug 2024 23:06:46 -0400 Subject: [PATCH 4/4] use _reshape --- lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl index 671b14b8f4..0432ae1103 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl @@ -169,7 +169,7 @@ end rhs1 = @. fw1 - αdt * Mw1 + βdt * Mw2 rhs2 = @. fw2 - βdt * Mw1 - αdt * Mw2 - dw12 = reshape(LU1 \ _vec(@. rhs1 + rhs2 * im), size(u)) + dw12 = _reshape(LU1 \ _vec(@. rhs1 + rhs2 * im), axes(u)) integrator.stats.nsolve += 1 dw1 = real(dw12) dw2 = imag(dw12) @@ -450,8 +450,8 @@ end rhs1 = @.. broadcast=false fw1-γdt * Mw1 rhs2 = @.. broadcast=false fw2 - αdt * Mw2+βdt * Mw3 rhs3 = @.. broadcast=false fw3 - βdt * Mw2-αdt * Mw3 - dw1 = reshape(LU1 \ _vec(rhs1), size(u)) - dw23 = reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), size(u)) + dw1 = _reshape(LU1 \ _vec(rhs1), axes(u)) + dw23 = _reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), axes(u)) integrator.stats.nsolve += 2 dw2 = real(dw23) dw3 = imag(dw23) @@ -508,7 +508,7 @@ end mass_matrix != I && (tmp = mass_matrix * tmp) utilde = @.. broadcast=false integrator.fsalfirst+tmp if alg.smooth_est - utilde = reshape(LU1 \ _vec(utilde), size(u)) + utilde = _reshape(LU1 \ _vec(utilde), axes(u)) integrator.stats.nsolve += 1 end # RadauIIA5 needs a transformed rtol and atol see @@ -902,9 +902,9 @@ end rhs3 = @.. broadcast=false fw3 - β1dt * Mw2-α1dt * Mw3 rhs4 = @.. broadcast=false fw4 - α2dt * Mw4+β2dt * Mw5 rhs5 = @.. broadcast=false fw5 - β2dt * Mw4-α2dt * Mw5 - dw1 = reshape(LU1 \ _vec(rhs1), size(u)) - dw23 = reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), size(u)) - dw45 = reshape(LU3 \ _vec(@.. broadcast=false rhs4+rhs5 * im), size(u)) + dw1 = _reshape(LU1 \ _vec(rhs1), axes(u)) + dw23 = _reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), axes(u)) + dw45 = _reshape(LU3 \ _vec(@.. broadcast=false rhs4+rhs5 * im), axes(u)) integrator.stats.nsolve += 3 dw2 = real(dw23) dw3 = imag(dw23) @@ -973,7 +973,7 @@ end mass_matrix != I && (tmp = mass_matrix * tmp) utilde = @.. broadcast=false integrator.fsalfirst+tmp if alg.smooth_est - utilde = reshape(LU1 \ _vec(utilde), size(u)) + utilde = _reshape(LU1 \ _vec(utilde), axes(u)) integrator.stats.nsolve += 1 end atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)