Skip to content

Commit

Permalink
More efficient PERK implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielDoehring committed Nov 27, 2024
1 parent 0046123 commit ea1c1d3
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 96 deletions.
4 changes: 2 additions & 2 deletions examples/structured_1d_dgsem/elixir_burgers_perk3.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Convex and ECOS are imported because they are used for finding the optimal time step and optimal
# monomial coefficients in the stability polynomial of P-ERK time integrators.
# monomial coefficients in the stability polynomial of PERK time integrators.
using Convex, ECOS

# NLsolve is imported to solve the system of nonlinear equations to find the coefficients
# in the Butcher tableau in the third order P-ERK time integrator.
# in the Butcher tableau in the third order PERK time integrator.
using NLsolve

using OrdinaryDiffEq
Expand Down
2 changes: 1 addition & 1 deletion examples/tree_1d_dgsem/elixir_advection_perk2.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

# Convex and ECOS are imported because they are used for finding the optimal time step and optimal
# monomial coefficients in the stability polynomial of P-ERK time integrators.
# monomial coefficients in the stability polynomial of PERK time integrators.
using Convex, ECOS

using OrdinaryDiffEq
Expand Down
2 changes: 1 addition & 1 deletion ext/TrixiConvexECOSExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Package extension for adding Convex-based features to Trixi.jl
module TrixiConvexECOSExt

# Required for coefficient optimization in P-ERK scheme integrators
# Required for coefficient optimization in PERK scheme integrators
if isdefined(Base, :get_extension)
using Convex: MOI, solve!, Variable, minimize, evaluate
using ECOS: Optimizer
Expand Down
2 changes: 1 addition & 1 deletion ext/TrixiNLsolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
module TrixiNLsolveExt

# Required for finding coefficients in Butcher tableau in the third order of
# P-ERK scheme integrators
# PERK scheme integrators
if isdefined(Base, :get_extension)
using NLsolve: nlsolve
else
Expand Down
48 changes: 9 additions & 39 deletions src/time_integration/paired_explicit_runge_kutta/methods_PERK2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ mutable struct PairedExplicitRK2Integrator{RealT <: Real, uType, Params, Sol, F,
finalstep::Bool # added for convenience
dtchangeable::Bool
force_stepfail::Bool
# PairedExplicitRK2 stages:
# Additional PairedExplicitRK2 stage
k1::uType
k_higher::uType
end

function init(ode::ODEProblem, alg::PairedExplicitRK2;
Expand All @@ -211,9 +210,8 @@ function init(ode::ODEProblem, alg::PairedExplicitRK2;
du = zero(u0)
u_tmp = zero(u0)

# PairedExplicitRK2 stages
# Additional PairedExplicitRK2 stage
k1 = zero(u0)
k_higher = zero(u0)

t0 = first(ode.tspan)
tdir = sign(ode.tspan[end] - ode.tspan[1])
Expand All @@ -226,7 +224,7 @@ function init(ode::ODEProblem, alg::PairedExplicitRK2;
ode.tspan;
kwargs...),
false, true, false,
k1, k_higher)
k1)

# initialize callbacks
if callback isa CallbackSet
Expand Down Expand Up @@ -262,46 +260,18 @@ function step!(integrator::PairedExplicitRK2Integrator)
end

@trixi_timeit timer() "Paired Explicit Runge-Kutta ODE integration step" begin
# k1
integrator.f(integrator.du, integrator.u, prob.p, integrator.t)
@threaded for i in eachindex(integrator.du)
integrator.k1[i] = integrator.du[i] * integrator.dt
end

# Construct current state
@threaded for i in eachindex(integrator.u)
integrator.u_tmp[i] = integrator.u[i] + alg.c[2] * integrator.k1[i]
end
# k2
integrator.f(integrator.du, integrator.u_tmp, prob.p,
integrator.t + alg.c[2] * integrator.dt)

@threaded for i in eachindex(integrator.du)
integrator.k_higher[i] = integrator.du[i] * integrator.dt
end
# First and second stage are identical across all single/standalone PERK methods
PERK_k1!(integrator, prob.p)
PERK_k2!(integrator, prob.p, alg.c)

# Higher stages
for stage in 3:(alg.num_stages)
# Construct current state
@threaded for i in eachindex(integrator.u)
integrator.u_tmp[i] = integrator.u[i] +
alg.a_matrix[stage - 2, 1] *
integrator.k1[i] +
alg.a_matrix[stage - 2, 2] *
integrator.k_higher[i]
end

integrator.f(integrator.du, integrator.u_tmp, prob.p,
integrator.t + alg.c[stage] * integrator.dt)

@threaded for i in eachindex(integrator.du)
integrator.k_higher[i] = integrator.du[i] * integrator.dt
end
PERK_ki!(integrator, prob.p, alg.c, alg.a_matrix, stage)
end

@threaded for i in eachindex(integrator.u)
integrator.u[i] += alg.b1 * integrator.k1[i] +
alg.bS * integrator.k_higher[i]
integrator.u[i] += integrator.dt * (alg.b1 * integrator.k1[i] +
alg.bS * integrator.du[i])
end
end # PairedExplicitRK2 step

Expand Down
72 changes: 24 additions & 48 deletions src/time_integration/paired_explicit_runge_kutta/methods_PERK3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function compute_PairedExplicitRK3_butcher_tableau(num_stages, tspan,
monomial_coeffs, c;
verbose)
end
# Fill A-matrix in P-ERK style
# Fill A-matrix in PERK style
a_matrix = zeros(num_stages - 2, 2)
a_matrix[:, 1] = c[3:end]
a_matrix[:, 1] -= a_unknown
Expand Down Expand Up @@ -91,7 +91,7 @@ function compute_PairedExplicitRK3_butcher_tableau(num_stages,
num_a_coeffs = size(a_coeffs, 1)

@assert num_a_coeffs == a_coeffs_max
# Fill A-matrix in P-ERK style
# Fill A-matrix in PERK style
a_matrix[:, 1] -= a_coeffs
a_matrix[:, 2] = a_coeffs

Expand All @@ -107,7 +107,7 @@ end
verbose = false, cS2 = 1.0f0)
Parameters:
- `num_stages` (`Int`): Number of stages in the paired explicit Runge-Kutta (P-ERK) method.
- `num_stages` (`Int`): Number of stages in the paired explicit Runge-Kutta (PERK) method.
- `base_path_a_coeffs` (`AbstractString`): Path to a file containing some coefficients in the A-matrix in
the Butcher tableau of the Runge Kutta method.
The matrix should be stored in a text file at `joinpath(base_path_a_coeffs, "a_$(num_stages).txt")` and separated by line breaks.
Expand All @@ -122,7 +122,7 @@ end
s is the number of stages, default is 1.0f0.
The following structures and methods provide an implementation of
the third-order paired explicit Runge-Kutta (P-ERK) method
the third-order paired explicit Runge-Kutta (PERK) method
optimized for a certain simulation setup (PDE, IC & BC, Riemann Solver, DG Solver).
The original paper is
- Nasab, Vermeire (2022)
Expand Down Expand Up @@ -258,61 +258,28 @@ function step!(integrator::PairedExplicitRK3Integrator)
end

@trixi_timeit timer() "Paired Explicit Runge-Kutta ODE integration step" begin
# k1
integrator.f(integrator.du, integrator.u, prob.p, integrator.t)
@threaded for i in eachindex(integrator.du)
integrator.k1[i] = integrator.du[i] * integrator.dt
end

# Construct current state
@threaded for i in eachindex(integrator.du)
integrator.u_tmp[i] = integrator.u[i] + alg.c[2] * integrator.k1[i]
end
# k2
integrator.f(integrator.du, integrator.u_tmp, prob.p,
integrator.t + alg.c[2] * integrator.dt)
# First and second stage are identical across all single/standalone PERK methods
PERK_k1!(integrator, prob.p)
PERK_k2!(integrator, prob.p, alg.c)

@threaded for i in eachindex(integrator.du)
integrator.k_higher[i] = integrator.du[i] * integrator.dt
end

# Higher stages
for stage in 3:(alg.num_stages - 1)
# Construct current state
@threaded for i in eachindex(integrator.du)
integrator.u_tmp[i] = integrator.u[i] +
alg.a_matrix[stage - 2, 1] *
integrator.k1[i] +
alg.a_matrix[stage - 2, 2] *
integrator.k_higher[i]
end

integrator.f(integrator.du, integrator.u_tmp, prob.p,
integrator.t + alg.c[stage] * integrator.dt)

@threaded for i in eachindex(integrator.du)
integrator.k_higher[i] = integrator.du[i] * integrator.dt
end
PERK_ki!(integrator, prob.p, alg.c, alg.a_matrix, stage)
end

# Last stage
@threaded for i in eachindex(integrator.du)
integrator.u_tmp[i] = integrator.u[i] +
alg.a_matrix[alg.num_stages - 2, 1] *
integrator.k1[i] +
alg.a_matrix[alg.num_stages - 2, 2] *
integrator.k_higher[i]
# We need to store `du` of the S-1 stage in `k_higher` for the final update:
@threaded for i in eachindex(integrator.u)
integrator.k_higher[i] = integrator.du[i]
end

integrator.f(integrator.du, integrator.u_tmp, prob.p,
integrator.t + alg.c[alg.num_stages] * integrator.dt)
PERK_ki!(integrator, prob.p, alg.c, alg.a_matrix, alg.num_stages)

@threaded for i in eachindex(integrator.u)
# "Own" PairedExplicitRK based on SSPRK33.
# Note that 'k_higher' carries the values of K_{S-1}
# and that we construct 'K_S' "in-place" from 'integrator.du'
integrator.u[i] += (integrator.k1[i] + integrator.k_higher[i] +
4.0 * integrator.du[i] * integrator.dt) / 6.0
integrator.u[i] += integrator.dt *
(integrator.k1[i] + integrator.k_higher[i] +
4.0 * integrator.du[i]) / 6.0
end
end # PairedExplicitRK step timer

Expand All @@ -334,4 +301,13 @@ function step!(integrator::PairedExplicitRK3Integrator)
terminate!(integrator)
end
end

# Resize the storage arrays `u`, `du`, `u_tmp`, `k1` and `k_higher` of a
# `PairedExplicitRK3Integrator` to `new_size`.
# The result of the last `resize!` call (on `k_higher`) is returned.
function Base.resize!(integrator::PairedExplicitRK3Integrator, new_size)
    # Solution, right-hand side, temporary stage state and first stage
    for storage in (integrator.u, integrator.du, integrator.u_tmp, integrator.k1)
        resize!(storage, new_size)
    end

    # Extra storage of the third-order method, handled last
    return resize!(integrator.k_higher, new_size)
end
end # @muladd
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,49 @@ function solve!(integrator::AbstractPairedExplicitRKIntegrator)

@trixi_timeit timer() "main loop" while !integrator.finalstep
step!(integrator)
end # "main loop" timer
end

return TimeIntegratorSolution((first(prob.tspan), integrator.t),
(prob.u0, integrator.u),
integrator.sol.prob)
end

# First stage of a general PERK method: evaluate the right-hand side at the
# current state `integrator.u` and time `integrator.t`, writing the result
# directly into `integrator.k1` (no scaling by `dt` is applied here).
# `p` is forwarded unchanged as the parameter argument of `integrator.f`.
@inline function PERK_k1!(integrator::AbstractPairedExplicitRKIntegrator, p)
    ode_rhs! = integrator.f
    ode_rhs!(integrator.k1, integrator.u, p, integrator.t)
end

# Second stage of a single/standalone PERK method.
# Builds the intermediate state `u_tmp = u + c[2] * dt * k1` and then evaluates
# the right-hand side at `u_tmp` and time `t + c[2] * dt`, storing it in `du`.
# `c` is indexed at position 2; `p` is forwarded unchanged to `integrator.f`.
@inline function PERK_k2!(integrator::AbstractPairedExplicitRKSingleIntegrator, p, c)
    u, u_tmp, k1 = integrator.u, integrator.u_tmp, integrator.k1
    dt = integrator.dt

    # Construct current state
    @threaded for idx in eachindex(integrator.du)
        u_tmp[idx] = u[idx] + c[2] * dt * k1[idx]
    end

    t_stage = integrator.t + c[2] * dt
    integrator.f(integrator.du, u_tmp, p, t_stage)
end

# Stage `stage` of a single/standalone PERK method.
# Builds the intermediate state
#   u_tmp = u + dt * (a_matrix[stage - 2, 1] * k1 + a_matrix[stage - 2, 2] * du)
# where `du` is read with whatever values it holds on entry, and afterwards
# overwrites `du` with the right-hand side evaluated at `u_tmp` and time
# `t + c[stage] * dt`. `p` is forwarded unchanged to `integrator.f`.
@inline function PERK_ki!(integrator::AbstractPairedExplicitRKSingleIntegrator, p, c,
                          a_matrix, stage)
    u, du = integrator.u, integrator.du
    u_tmp, k1 = integrator.u_tmp, integrator.k1
    dt = integrator.dt
    row = stage - 2 # row of `a_matrix` that belongs to this stage

    # Construct current state
    @threaded for idx in eachindex(u)
        u_tmp[idx] = u[idx] +
                     dt * (a_matrix[row, 1] *
                           k1[idx] +
                           a_matrix[row, 2] *
                           du[idx])
    end

    t_stage = integrator.t + c[stage] * dt
    integrator.f(du, u_tmp, p, t_stage)
end

# Used for AMR (Adaptive Mesh Refinement): resizes the integrator's arrays
# `u`, `du`, `u_tmp`, `k1` and `k_higher` to `new_size`.
# NOTE(review): this accesses `integrator.k_higher`, so it presumably requires every
# `AbstractPairedExplicitRKIntegrator` subtype to provide that field — confirm.
function Base.resize!(integrator::AbstractPairedExplicitRKIntegrator, new_size)
# Solution, right-hand side and temporary stage state
resize!(integrator.u, new_size)
resize!(integrator.du, new_size)
resize!(integrator.u_tmp, new_size)

# Stage storage
resize!(integrator.k1, new_size)
resize!(integrator.k_higher, new_size)
end

# get a cache where the RHS can be stored
Expand Down
4 changes: 2 additions & 2 deletions test/test_structured_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ end
save_solution=SaveSolutionCallback(dt = 0.1 + 1.0e-8), # Adding a small epsilon to avoid floating-point precision issues
callbacks=CallbackSet(summary_callback, save_solution,
analysis_callback, alive_callback),
l2=[5.726144786001842e-7],
linf=[3.430730019182704e-6])
l2=[5.727164156769191e-7],
linf=[3.4323776822997587e-6])
# Ensure that we do not have excessive memory allocations
# (e.g., from type instabilities)
let
Expand Down

0 comments on commit ea1c1d3

Please sign in to comment.