Skip to content

Commit

Permalink
Merge pull request #2446 from oscardssmith/os/fix-oop-matrix-u
Browse files Browse the repository at this point in the history
`_reshape` and `_vec` appropriately in more places
  • Loading branch information
ChrisRackauckas authored Sep 1, 2024
2 parents fc1a214 + 2bc1ff3 commit 04a7c58
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
25 changes: 13 additions & 12 deletions lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -758,16 +758,17 @@ function perform_step!(integrator, cache::QNDFConstantCache{max_order},
α₀ = 1
β₀ = inv((1 - κ) * γₖ[k])
if u isa Number
u₀ = sum(D[1:k]) + uprev
u₀ = sum(view(D, 1:k)) + uprev
ϕ = zero(u)
for i in 1:k
ϕ += γₖ[i] * D[i]
end
else
u₀ = reshape(sum(D[:, 1:k], dims = 2) .+ uprev, size(u))
u₀ = _reshape(sum(view(D, :, 1:k), dims = 2), axes(u)) .+ uprev
ϕ = zero(u)
for i in 1:k
ϕ = @.. ϕ + γₖ[i] * D[:, i]
D_row = _reshape(view(D, :, i), axes(u))
ϕ = @.. ϕ + γₖ[i] * D_row
end
end
markfirststage!(nlsolver)
Expand Down Expand Up @@ -802,14 +803,14 @@ function perform_step!(integrator, cache::QNDFConstantCache{max_order},
end
integrator.EEst = error_constant(integrator, k) * internalnorm(atmp, t)
if k > 1
@views atmpm1 = calculate_residuals(D[:, k], uprev, u, integrator.opts.abstol,
integrator.opts.reltol,
integrator.opts.internalnorm, t)
@views atmpm1 = calculate_residuals(_reshape(view(D, :, k), axes(u)),
uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
cache.EEst1 = error_constant(integrator, k - 1) * internalnorm(atmpm1, t)
end
if k < max_order
@views atmpp1 = calculate_residuals(D[:, k + 2], uprev, u, abstol, reltol,
internalnorm, t)
@views atmpp1 = calculate_residuals(_reshape(view(D, :, k + 2), axes(u)),
uprev, u, abstol, reltol, internalnorm, t)
cache.EEst2 = error_constant(integrator, k + 1) * internalnorm(atmpp1, t)
end
end
Expand Down Expand Up @@ -925,13 +926,13 @@ function perform_step!(integrator, cache::QNDFCache{max_order},
integrator.EEst = error_constant(integrator, k) * internalnorm(atmp, t)
if k > 1
@views calculate_residuals!(
atmpm1, reshape(D[:, k], size(u)), uprev, u, abstol,
atmpm1, _reshape(D[:, k], axes(u)), uprev, u, abstol,
reltol, internalnorm, t)
cache.EEst1 = error_constant(integrator, k - 1) * internalnorm(atmpm1, t)
end
if k < max_order
@views calculate_residuals!(
atmpp1, reshape(D[:, k + 2], size(u)), uprev, u, abstol,
atmpp1, _reshape(D[:, k + 2], axes(u)), uprev, u, abstol,
reltol, internalnorm, t)
cache.EEst2 = error_constant(integrator, k + 1) * internalnorm(atmpp1, t)
end
Expand Down Expand Up @@ -1112,7 +1113,7 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order},
end
tmp = -uprev * bdf_coeffs[k, 2]
for i in 1:(k - 1)
@views tmp = @.. tmp - u_corrector[:, i] * bdf_coeffs[k, i + 2]
tmp = @.. tmp - $(_reshape(view(u_corrector, :, i), axes(u))) * bdf_coeffs[k, i + 2]
end
end

Expand Down Expand Up @@ -1169,7 +1170,7 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order},
terk *= abs(dt^(k))
else
for i in 2:(k + 1)
@views terk = @.. terk + fd_weights[i, k + 1] * u_history[:, i - 1]
terk = @.. terk + fd_weights[i, k + 1] * $(_reshape(view(u_history, :, i - 1), axes(u)))
end
terk *= abs(dt^(k))
end
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqBDF/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ function choose_order!(alg::FBDF, integrator,
terk_tmp = similar(u)
@.. terk_tmp = fd_weights[k - 2, 1] * _vec(u)
for i in 2:(k - 2)
@.. @views terk_tmp += fd_weights[i, k - 2] * u_history[:, i - 1]
@.. terk_tmp += fd_weights[i, k - 2] * $(_reshape(view(u_history, :, i - 1), axes(u)))
end
@.. terk_tmp *= abs(dt^(k - 2))
end
Expand Down
22 changes: 14 additions & 8 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ end

rhs1 = @. fw1 - αdt * Mw1 + βdt * Mw2
rhs2 = @. fw2 - βdt * Mw1 - αdt * Mw2
dw12 = LU1 \ (@. rhs1 + rhs2 * im)
dw12 = _reshape(LU1 \ _vec(@. rhs1 + rhs2 * im), axes(u))
integrator.stats.nsolve += 1
dw1 = real(dw12)
dw2 = imag(dw12)
Expand Down Expand Up @@ -450,8 +450,8 @@ end
rhs1 = @.. broadcast=false fw1-γdt * Mw1
rhs2 = @.. broadcast=false fw2 - αdt * Mw2+βdt * Mw3
rhs3 = @.. broadcast=false fw3 - βdt * Mw2-αdt * Mw3
dw1 = LU1 \ rhs1
dw23 = LU2 \ (@.. broadcast=false rhs2+rhs3 * im)
dw1 = _reshape(LU1 \ _vec(rhs1), axes(u))
dw23 = _reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), axes(u))
integrator.stats.nsolve += 2
dw2 = real(dw23)
dw3 = imag(dw23)
Expand Down Expand Up @@ -507,7 +507,10 @@ end
tmp = @.. broadcast=false e1dt*z1+e2dt*z2+e3dt*z3
mass_matrix != I && (tmp = mass_matrix * tmp)
utilde = @.. broadcast=false integrator.fsalfirst+tmp
alg.smooth_est && (utilde = LU1 \ utilde; integrator.stats.nsolve += 1)
if alg.smooth_est
utilde = _reshape(LU1 \ _vec(utilde), axes(u))
integrator.stats.nsolve += 1
end
# RadauIIA5 needs a transformed rtol and atol see
# https://github.com/luchr/ODEInterface.jl/blob/0bd134a5a358c4bc13e0fb6a90e27e4ee79e0115/src/radau5.f#L399-L421
atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
Expand Down Expand Up @@ -899,9 +902,9 @@ end
rhs3 = @.. broadcast=false fw3 - β1dt * Mw2-α1dt * Mw3
rhs4 = @.. broadcast=false fw4 - α2dt * Mw4+β2dt * Mw5
rhs5 = @.. broadcast=false fw5 - β2dt * Mw4-α2dt * Mw5
dw1 = LU1 \ rhs1
dw23 = LU2 \ (@.. broadcast=false rhs2+rhs3 * im)
dw45 = LU3 \ (@.. broadcast=false rhs4+rhs5 * im)
dw1 = _reshape(LU1 \ _vec(rhs1), axes(u))
dw23 = _reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), axes(u))
dw45 = _reshape(LU3 \ _vec(@.. broadcast=false rhs4+rhs5 * im), axes(u))
integrator.stats.nsolve += 3
dw2 = real(dw23)
dw3 = imag(dw23)
Expand Down Expand Up @@ -969,7 +972,10 @@ end
tmp = @.. broadcast=false e1dt*z1+e2dt*z2+e3dt*z3+e4dt*z4+e5dt*z5
mass_matrix != I && (tmp = mass_matrix * tmp)
utilde = @.. broadcast=false integrator.fsalfirst+tmp
alg.smooth_est && (utilde = LU1 \ utilde; integrator.stats.nsolve += 1)
if alg.smooth_est
utilde = _reshape(LU1 \ _vec(utilde), axes(u))
integrator.stats.nsolve += 1
end
atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
integrator.EEst = internalnorm(atmp, t)

Expand Down

0 comments on commit 04a7c58

Please sign in to comment.