From ea1c1d3b54e87000e749d7f453a95f0e034f59fd Mon Sep 17 00:00:00 2001 From: Daniel_Doehring Date: Wed, 27 Nov 2024 10:16:24 +0100 Subject: [PATCH] More efficient PERK implementation --- .../elixir_burgers_perk3.jl | 4 +- .../tree_1d_dgsem/elixir_advection_perk2.jl | 2 +- ext/TrixiConvexECOSExt.jl | 2 +- ext/TrixiNLsolveExt.jl | 2 +- .../methods_PERK2.jl | 48 +++---------- .../methods_PERK3.jl | 72 +++++++------------ .../paired_explicit_runge_kutta.jl | 32 ++++++++- test/test_structured_1d.jl | 4 +- 8 files changed, 70 insertions(+), 96 deletions(-) diff --git a/examples/structured_1d_dgsem/elixir_burgers_perk3.jl b/examples/structured_1d_dgsem/elixir_burgers_perk3.jl index bf91fde74ea..5ff3aff3678 100644 --- a/examples/structured_1d_dgsem/elixir_burgers_perk3.jl +++ b/examples/structured_1d_dgsem/elixir_burgers_perk3.jl @@ -1,9 +1,9 @@ # Convex and ECOS are imported because they are used for finding the optimal time step and optimal -# monomial coefficients in the stability polynomial of P-ERK time integrators. +# monomial coefficients in the stability polynomial of PERK time integrators. using Convex, ECOS # NLsolve is imported to solve the system of nonlinear equations to find the coefficients -# in the Butcher tableau in the third order P-ERK time integrator. +# in the Butcher tableau in the third order PERK time integrator. using NLsolve using OrdinaryDiffEq diff --git a/examples/tree_1d_dgsem/elixir_advection_perk2.jl b/examples/tree_1d_dgsem/elixir_advection_perk2.jl index 69b587b42ec..86b173d6564 100644 --- a/examples/tree_1d_dgsem/elixir_advection_perk2.jl +++ b/examples/tree_1d_dgsem/elixir_advection_perk2.jl @@ -1,6 +1,6 @@ # Convex and ECOS are imported because they are used for finding the optimal time step and optimal -# monomial coefficients in the stability polynomial of P-ERK time integrators. +# monomial coefficients in the stability polynomial of PERK time integrators. using Convex, ECOS using OrdinaryDiffEq diff --git a/ext/TrixiConvexECOSExt.jl b/ext/TrixiConvexECOSExt.jl index 9b897436dbb..a83ac0a524f 100644 --- a/ext/TrixiConvexECOSExt.jl +++ b/ext/TrixiConvexECOSExt.jl @@ -1,7 +1,7 @@ # Package extension for adding Convex-based features to Trixi.jl module TrixiConvexECOSExt -# Required for coefficient optimization in P-ERK scheme integrators +# Required for coefficient optimization in PERK scheme integrators if isdefined(Base, :get_extension) using Convex: MOI, solve!, Variable, minimize, evaluate using ECOS: Optimizer diff --git a/ext/TrixiNLsolveExt.jl b/ext/TrixiNLsolveExt.jl index fa188d04c71..2e6bca0ba7b 100644 --- a/ext/TrixiNLsolveExt.jl +++ b/ext/TrixiNLsolveExt.jl @@ -2,7 +2,7 @@ module TrixiNLsolveExt # Required for finding coefficients in Butcher tableau in the third order of -# P-ERK scheme integrators +# PERK scheme integrators if isdefined(Base, :get_extension) using NLsolve: nlsolve else diff --git a/src/time_integration/paired_explicit_runge_kutta/methods_PERK2.jl b/src/time_integration/paired_explicit_runge_kutta/methods_PERK2.jl index 2451680a505..35208abe405 100644 --- a/src/time_integration/paired_explicit_runge_kutta/methods_PERK2.jl +++ b/src/time_integration/paired_explicit_runge_kutta/methods_PERK2.jl @@ -200,9 +200,8 @@ mutable struct PairedExplicitRK2Integrator{RealT <: Real, uType, Params, Sol, F, finalstep::Bool # added for convenience dtchangeable::Bool force_stepfail::Bool - # PairedExplicitRK2 stages: + # Additional PairedExplicitRK2 stage k1::uType - k_higher::uType end function init(ode::ODEProblem, alg::PairedExplicitRK2; @@ -211,9 +210,8 @@ function init(ode::ODEProblem, alg::PairedExplicitRK2; du = zero(u0) u_tmp = zero(u0) - # PairedExplicitRK2 stages + # Additional PairedExplicitRK2 stage k1 = zero(u0) - k_higher = zero(u0) t0 = first(ode.tspan) tdir = sign(ode.tspan[end] - ode.tspan[1]) @@ -226,7 +224,7 @@ function init(ode::ODEProblem, alg::PairedExplicitRK2; ode.tspan; kwargs...), false, true, false, - k1, k_higher) + k1) # initialize callbacks if callback isa CallbackSet @@ -262,46 +260,18 @@ function step!(integrator::PairedExplicitRK2Integrator) end @trixi_timeit timer() "Paired Explicit Runge-Kutta ODE integration step" begin - # k1 - integrator.f(integrator.du, integrator.u, prob.p, integrator.t) - @threaded for i in eachindex(integrator.du) - integrator.k1[i] = integrator.du[i] * integrator.dt - end - - # Construct current state - @threaded for i in eachindex(integrator.u) - integrator.u_tmp[i] = integrator.u[i] + alg.c[2] * integrator.k1[i] - end - # k2 - integrator.f(integrator.du, integrator.u_tmp, prob.p, - integrator.t + alg.c[2] * integrator.dt) - - @threaded for i in eachindex(integrator.du) - integrator.k_higher[i] = integrator.du[i] * integrator.dt - end + # First and second stage are identical across all single/standalone PERK methods + PERK_k1!(integrator, prob.p) + PERK_k2!(integrator, prob.p, alg.c) # Higher stages for stage in 3:(alg.num_stages) - # Construct current state - @threaded for i in eachindex(integrator.u) - integrator.u_tmp[i] = integrator.u[i] + - alg.a_matrix[stage - 2, 1] * - integrator.k1[i] + - alg.a_matrix[stage - 2, 2] * - integrator.k_higher[i] - end - - integrator.f(integrator.du, integrator.u_tmp, prob.p, - integrator.t + alg.c[stage] * integrator.dt) - - @threaded for i in eachindex(integrator.du) - integrator.k_higher[i] = integrator.du[i] * integrator.dt - end + PERK_ki!(integrator, prob.p, alg.c, alg.a_matrix, stage) end @threaded for i in eachindex(integrator.u) - integrator.u[i] += alg.b1 * integrator.k1[i] + - alg.bS * integrator.k_higher[i] + integrator.u[i] += integrator.dt * (alg.b1 * integrator.k1[i] + + alg.bS * integrator.du[i]) end end # PairedExplicitRK2 step diff --git a/src/time_integration/paired_explicit_runge_kutta/methods_PERK3.jl b/src/time_integration/paired_explicit_runge_kutta/methods_PERK3.jl index 02bc3eeba36..25d4e06286e 100644 --- a/src/time_integration/paired_explicit_runge_kutta/methods_PERK3.jl +++ b/src/time_integration/paired_explicit_runge_kutta/methods_PERK3.jl @@ -59,7 +59,7 @@ function compute_PairedExplicitRK3_butcher_tableau(num_stages, tspan, monomial_coeffs, c; verbose) end - # Fill A-matrix in P-ERK style + # Fill A-matrix in PERK style a_matrix = zeros(num_stages - 2, 2) a_matrix[:, 1] = c[3:end] a_matrix[:, 1] -= a_unknown @@ -91,7 +91,7 @@ function compute_PairedExplicitRK3_butcher_tableau(num_stages, num_a_coeffs = size(a_coeffs, 1) @assert num_a_coeffs == a_coeffs_max - # Fill A-matrix in P-ERK style + # Fill A-matrix in PERK style a_matrix[:, 1] -= a_coeffs a_matrix[:, 2] = a_coeffs @@ -107,7 +107,7 @@ end verbose = false, cS2 = 1.0f0) Parameters: - - `num_stages` (`Int`): Number of stages in the paired explicit Runge-Kutta (P-ERK) method. + - `num_stages` (`Int`): Number of stages in the paired explicit Runge-Kutta (PERK) method. - `base_path_a_coeffs` (`AbstractString`): Path to a file containing some coefficients in the A-matrix in the Butcher tableau of the Runge Kutta method. The matrix should be stored in a text file at `joinpath(base_path_a_coeffs, "a_$(num_stages).txt")` and separated by line breaks. @@ -122,7 +122,7 @@ end s is the number of stages, default is 1.0f0. The following structures and methods provide an implementation of -the third-order paired explicit Runge-Kutta (P-ERK) method +the third-order paired explicit Runge-Kutta (PERK) method optimized for a certain simulation setup (PDE, IC & BC, Riemann Solver, DG Solver). The original paper is - Nasab, Vermeire (2022) @@ -258,61 +258,28 @@ function step!(integrator::PairedExplicitRK3Integrator) end @trixi_timeit timer() "Paired Explicit Runge-Kutta ODE integration step" begin - # k1 - integrator.f(integrator.du, integrator.u, prob.p, integrator.t) - @threaded for i in eachindex(integrator.du) - integrator.k1[i] = integrator.du[i] * integrator.dt - end - - # Construct current state - @threaded for i in eachindex(integrator.du) - integrator.u_tmp[i] = integrator.u[i] + alg.c[2] * integrator.k1[i] - end - # k2 - integrator.f(integrator.du, integrator.u_tmp, prob.p, - integrator.t + alg.c[2] * integrator.dt) + # First and second stage are identical across all single/standalone PERK methods + PERK_k1!(integrator, prob.p) + PERK_k2!(integrator, prob.p, alg.c) - @threaded for i in eachindex(integrator.du) - integrator.k_higher[i] = integrator.du[i] * integrator.dt - end - - # Higher stages for stage in 3:(alg.num_stages - 1) - # Construct current state - @threaded for i in eachindex(integrator.du) - integrator.u_tmp[i] = integrator.u[i] + - alg.a_matrix[stage - 2, 1] * - integrator.k1[i] + - alg.a_matrix[stage - 2, 2] * - integrator.k_higher[i] - end - - integrator.f(integrator.du, integrator.u_tmp, prob.p, - integrator.t + alg.c[stage] * integrator.dt) - - @threaded for i in eachindex(integrator.du) - integrator.k_higher[i] = integrator.du[i] * integrator.dt - end + PERK_ki!(integrator, prob.p, alg.c, alg.a_matrix, stage) end - # Last stage - @threaded for i in eachindex(integrator.du) - integrator.u_tmp[i] = integrator.u[i] + - alg.a_matrix[alg.num_stages - 2, 1] * - integrator.k1[i] + - alg.a_matrix[alg.num_stages - 2, 2] * - integrator.k_higher[i] + # We need to store `du` of the S-1 stage in `k_higher` for the final update: + @threaded for i in eachindex(integrator.u) + integrator.k_higher[i] = integrator.du[i] end - integrator.f(integrator.du, integrator.u_tmp, prob.p, - integrator.t + alg.c[alg.num_stages] * integrator.dt) + PERK_ki!(integrator, prob.p, alg.c, alg.a_matrix, alg.num_stages) @threaded for i in eachindex(integrator.u) # "Own" PairedExplicitRK based on SSPRK33. # Note that 'k_higher' carries the values of K_{S-1} # and that we construct 'K_S' "in-place" from 'integrator.du' - integrator.u[i] += (integrator.k1[i] + integrator.k_higher[i] + - 4.0 * integrator.du[i] * integrator.dt) / 6.0 + integrator.u[i] += integrator.dt * + (integrator.k1[i] + integrator.k_higher[i] + + 4.0 * integrator.du[i]) / 6.0 end end # PairedExplicitRK step timer @@ -334,4 +301,13 @@ function step!(integrator::PairedExplicitRK3Integrator) terminate!(integrator) end end + +function Base.resize!(integrator::PairedExplicitRK3Integrator, new_size) + resize!(integrator.u, new_size) + resize!(integrator.du, new_size) + resize!(integrator.u_tmp, new_size) + + resize!(integrator.k1, new_size) + resize!(integrator.k_higher, new_size) +end end # @muladd diff --git a/src/time_integration/paired_explicit_runge_kutta/paired_explicit_runge_kutta.jl b/src/time_integration/paired_explicit_runge_kutta/paired_explicit_runge_kutta.jl index c606326738f..84bc54e829a 100644 --- a/src/time_integration/paired_explicit_runge_kutta/paired_explicit_runge_kutta.jl +++ b/src/time_integration/paired_explicit_runge_kutta/paired_explicit_runge_kutta.jl @@ -108,13 +108,42 @@ function solve!(integrator::AbstractPairedExplicitRKIntegrator) @trixi_timeit timer() "main loop" while !integrator.finalstep step!(integrator) - end # "main loop" timer + end return TimeIntegratorSolution((first(prob.tspan), integrator.t), (prob.u0, integrator.u), integrator.sol.prob) end +# Function that computes the first stage of a general PERK method +@inline function PERK_k1!(integrator::AbstractPairedExplicitRKIntegrator, p) + integrator.f(integrator.k1, integrator.u, p, integrator.t) +end + +@inline function PERK_k2!(integrator::AbstractPairedExplicitRKSingleIntegrator, p, c) + @threaded for i in eachindex(integrator.du) + integrator.u_tmp[i] = integrator.u[i] + c[2] * integrator.dt * integrator.k1[i] + end + + integrator.f(integrator.du, integrator.u_tmp, p, + integrator.t + c[2] * integrator.dt) +end + +@inline function PERK_ki!(integrator::AbstractPairedExplicitRKSingleIntegrator, p, c, + a_matrix, stage) + # Construct current state + @threaded for i in eachindex(integrator.u) + integrator.u_tmp[i] = integrator.u[i] + + integrator.dt * (a_matrix[stage - 2, 1] * + integrator.k1[i] + + a_matrix[stage - 2, 2] * + integrator.du[i]) + end + + integrator.f(integrator.du, integrator.u_tmp, p, + integrator.t + c[stage] * integrator.dt) +end + # used for AMR (Adaptive Mesh Refinement) function Base.resize!(integrator::AbstractPairedExplicitRKIntegrator, new_size) resize!(integrator.u, new_size) @@ -122,7 +151,6 @@ function Base.resize!(integrator::AbstractPairedExplicitRKIntegrator, new_size) resize!(integrator.u_tmp, new_size) resize!(integrator.k1, new_size) - resize!(integrator.k_higher, new_size) end # get a cache where the RHS can be stored diff --git a/test/test_structured_1d.jl b/test/test_structured_1d.jl index 3061257d508..25d8bfe0a64 100644 --- a/test/test_structured_1d.jl +++ b/test/test_structured_1d.jl @@ -82,8 +82,8 @@ end save_solution=SaveSolutionCallback(dt = 0.1 + 1.0e-8), # Adding a small epsilon to avoid floating-point precision issues callbacks=CallbackSet(summary_callback, save_solution, analysis_callback, alive_callback), - l2=[5.726144786001842e-7], - linf=[3.430730019182704e-6]) + l2=[5.727164156769191e-7], + linf=[3.4323776822997587e-6]) # Ensure that we do not have excessive memory allocations # (e.g., from type instabilities) let