Skip to content

Commit

Permalink
reduce allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaLampert committed Aug 21, 2024
1 parent b6fc314 commit 719a03d
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 40 deletions.
8 changes: 5 additions & 3 deletions src/equations/equations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,14 @@ terms are present.
function energy_total_modified(q_global, equations::AbstractShallowWaterEquations, cache)
# `q_global` is an `ArrayPartition` of the primitive variables at all nodes
@assert nvariables(equations) == length(q_global.x)
# tmp1 is always in the cache
@unpack tmp1 = cache

e = similar(q_global.x[begin])
for i in eachindex(q_global.x[begin])
e[i] = energy_total(get_node_vars(q_global, equations, i), equations)
tmp1[i] = energy_total(get_node_vars(q_global, equations, i), equations)
end

return e
return tmp1
end

varnames(::typeof(energy_total_modified), equations) = ("e_modified",)
Expand Down Expand Up @@ -458,6 +459,7 @@ function solve_system_matrix!(dv, system_matrix, rhs,
cholesky!(factorization, system_matrix; check = false)
if issuccess(factorization)
scale_by_mass_matrix!(rhs, D1)
# see https://github.com/JoshuaLampert/DispersiveShallowWater.jl/issues/122
dv .= factorization \ rhs
else
# The factorization may fail if the time step is too large
Expand Down
2 changes: 1 addition & 1 deletion src/equations/serre_green_naghdi_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ function create_cache(mesh,
end

function assemble_system_matrix!(cache, h, D1, D1mat,
equations::SerreGreenNaghdiEquations1D{BathymetryFlat})
::SerreGreenNaghdiEquations1D{BathymetryFlat})
(; M_h, M_h3_3) = cache

@.. M_h = h
Expand Down
108 changes: 72 additions & 36 deletions src/equations/svaerd_kalisch_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,28 @@ function create_cache(mesh, equations::SvaerdKalischEquations1D,
h = ones(RealT, nnodes(mesh))
hv = zero(h)
b = zero(h)
eta_x = zero(h)
v_x = zero(h)
alpha_eta_x_x = zero(h)
y_x = zero(h)
v_y_x = zero(h)
yv_x = zero(h)
vy = zero(h)
vy_x = zero(h)
y_v_x = zero(h)
h_v_x = zero(h)
hv2_x = zero(h)
v_xx = zero(h)
gamma_v_xx_x = zero(h)
gamma_v_x_xx = zero(h)
alpha_hat = sqrt.(equations.alpha * sqrt.(g * D) .* D .^ 2)
beta_hat = equations.beta * D .^ 3
gamma_hat = equations.gamma * sqrt.(g * D) .* D .^ 3
tmp2 = zero(h)
M = mass_matrix(D1)
M_h = zero(h)
M_beta = scale_by_mass_matrix!(beta_hat, D1)
M_beta = copy(beta_hat)
scale_by_mass_matrix!(M_beta, D1)
if D1 isa PeriodicDerivativeOperator ||
D1 isa UniformPeriodicCoupledOperator
D1_central = D1
Expand All @@ -222,14 +236,19 @@ function create_cache(mesh, equations::SvaerdKalischEquations1D,
@error "unknown type of first-derivative operator: $(typeof(D1))"
end
factorization = cholesky(system_matrix)
return (; factorization = factorization, minus_MD1betaD1 = minus_MD1betaD1, D = D,
h = h, hv = hv, b = b, v_x = v_x,
alpha_hat = alpha_hat, beta_hat = beta_hat, gamma_hat = gamma_hat,
tmp2 = tmp2, D1_central = D1_central, M = M, D1 = D1, M_h = M_h)
cache = (; factorization, minus_MD1betaD1, D, h, hv, b, eta_x, v_x,
alpha_eta_x_x, y_x, v_y_x, yv_x, vy, vy_x, y_v_x, h_v_x, hv2_x, v_xx, gamma_v_xx_x, gamma_v_x_xx,
alpha_hat, beta_hat, gamma_hat, tmp2, D1_central, M, D1, M_h)
if D1 isa PeriodicUpwindOperators
eta_x_upwind = zero(h)
v_x_upwind = zero(h)
cache = (; cache..., eta_x_upwind, v_x_upwind)
end
return cache
end

function assemble_system_matrix!(cache, h, D1, D1mat,
equations::SvaerdKalischEquations1D)
::SvaerdKalischEquations1D)
(; M_h, minus_MD1betaD1) = cache

@.. M_h = h
Expand All @@ -250,10 +269,11 @@ end
# - Joshua Lampert and Hendrik Ranocha (2024)
# Structure-Preserving Numerical Methods for Two Nonlinear Systems of Dispersive Wave Equations
# [DOI: 10.48550/arXiv.2402.16669](https://doi.org/10.48550/arXiv.2402.16669)
# TODO: Simplify for the case of flat bathymetry and use higher-order operators
function rhs!(dq, q, t, mesh, equations::SvaerdKalischEquations1D,
initial_condition, ::BoundaryConditionPeriodic, source_terms,
solver, cache)
@unpack D, h, hv, b, alpha_hat, gamma_hat, tmp1, tmp2, D1_central, M, D1 = cache
@unpack D, h, hv, b, eta_x, v_x, alpha_eta_x_x, y_x, v_y_x, yv_x, vy, vy_x, y_v_x, h_v_x, hv2_x, v_xx, gamma_v_xx_x, gamma_v_x_xx, alpha_hat, gamma_hat, tmp1, tmp2, D1_central, M, D1 = cache

g = gravity_constant(equations)
eta, v = q.x
Expand All @@ -267,54 +287,72 @@ function rhs!(dq, q, t, mesh, equations::SvaerdKalischEquations1D,

if D1 isa PeriodicDerivativeOperator ||
D1 isa UniformPeriodicCoupledOperator
D1eta = D1_central * eta
D1v = D1_central * v
tmp1 = alpha_hat .* (D1_central * (alpha_hat .* D1eta))
vD1y = v .* (D1_central * tmp1)
D1vy = D1_central * (v .* tmp1)
yD1v = tmp1 .* D1v
mul!(eta_x, D1_central, eta)
mul!(v_x, D1_central, v)
@.. tmp1 = alpha_hat * eta_x
mul!(alpha_eta_x_x, D1_central, tmp1)
@.. tmp1 = alpha_hat * alpha_eta_x_x
mul!(y_x, D1_central, tmp1)
@.. v_y_x = v * y_x
@.. vy = v * tmp1
mul!(vy_x, D1_central, vy)
@.. y_v_x = tmp1 * v_x
@.. tmp2 = tmp1 - hv
mul!(deta, D1_central, tmp2)
elseif D1 isa PeriodicUpwindOperators
D1eta = D1_central * eta
D1v = D1_central * v
tmp1 = alpha_hat .* (D1.minus * (alpha_hat .* (D1.plus * eta)))
vD1y = v .* (D1.minus * tmp1)
D1vy = D1.minus * (v .* tmp1)
yD1v = tmp1 .* (D1.plus * v)
deta[:] = D1.minus * tmp1 - D1_central * hv
@unpack eta_x_upwind, v_x_upwind = cache
mul!(eta_x, D1_central, eta)
mul!(v_x, D1_central, v)
mul!(eta_x_upwind, D1.plus, eta)
@.. tmp1 = alpha_hat * eta_x_upwind
mul!(alpha_eta_x_x, D1_central, tmp1)
@.. tmp1 = alpha_hat * alpha_eta_x_x
mul!(y_x, D1.minus, tmp1)
@.. v_y_x = v * y_x
@.. vy = v * tmp1
mul!(vy_x, D1.minus, vy)
mul!(v_x_upwind, D1.plus, v)
@.. y_v_x = tmp1 * v_x_upwind
# deta[:] = D1.minus * tmp1 - D1_central * hv
mul!(deta, D1.minus, tmp1)
mul!(deta, D1_central, hv, -1.0, 1.0)
else
@error "unknown type of first derivative operator: $(typeof(D1))"
end
end

# split form
@trixi_timeit timer() "dv hyperbolic" begin
D1_hv = D1_central * hv
D1_hv2 = D1_central * (hv .* v)
D1_gamma_hat_D2_v = D1_central * (gamma_hat .* (solver.D2 * v))
D2_gamma_hat_D1_v = solver.D2 * (gamma_hat .* D1v)
@.. dv = -(0.5 * (D1_hv2 + hv * D1v - v * D1_hv) +
g * h * D1eta +
0.5 * (vD1y - D1vy - yD1v) -
0.5 * D1_gamma_hat_D2_v -
0.5 * D2_gamma_hat_D1_v)
mul!(h_v_x, D1_central, hv)
@.. tmp1 = hv * v
mul!(hv2_x, D1_central, tmp1)
mul!(v_xx, solver.D2, v)
@.. tmp1 = gamma_hat * v_xx
mul!(gamma_v_xx_x, D1_central, tmp1)
@.. tmp1 = gamma_hat * v_x
mul!(gamma_v_x_xx, solver.D2, tmp1)
@.. dv = -(0.5 * (hv2_x + hv * v_x - v * h_v_x) +
g * h * eta_x +
0.5 * (v_y_x - vy_x - y_v_x) -
0.5 * gamma_v_xx_x -
0.5 * gamma_v_x_xx)
end

# no split form
# dv[:] = -(D1_central * (hv .* v) - v .* (D1_central * hv)+
# g * h .* D1eta +
# vD1y - D1vy -
# 0.5 * D1_central * (gamma_hat .* (solver.D2 * v)) -
# 0.5 * solver.D2 * (gamma_hat .* D1v))
# g * h .* eta_x +
# v_y_x - vy_x -
# 0.5 * gamma_v_xx_x -
# 0.5 * gamma_v_x_xx)

@trixi_timeit timer() "source terms" calc_sources!(dq, q, t, source_terms, equations,
solver)
@trixi_timeit timer() "assemble system matrix" begin
system_matrix = assemble_system_matrix!(cache, h, D1, D1_central, equations)
end
@trixi_timeit timer() "solving elliptic system" begin
solve_system_matrix!(dv, system_matrix, dv,
tmp1 .= dv
solve_system_matrix!(dv, system_matrix, tmp1,
equations, D1, cache)
end

Expand Down Expand Up @@ -352,8 +390,6 @@ end
return equations.eta0 - D
end

@inline entropy(u, equations::SvaerdKalischEquations1D) = energy_total(u, equations)

# The modified entropy/energy takes the whole `q` for every point in space
"""
energy_total_modified(q_global, equations::SvaerdKalischEquations1D, cache)
Expand Down

0 comments on commit 719a03d

Please sign in to comment.