Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Adapt.jl to change storage and element type #2212

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Michael Schlottke-Lakemper <[email protected]>", "Gregor
version = "0.9.8-DEV"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down Expand Up @@ -64,6 +65,7 @@ TrixiMakieExt = "Makie"
TrixiNLsolveExt = "NLsolve"

[compat]
Adapt = "3.7, 4.0"
Accessors = "0.1.12"
CodeTracking = "1.0.5"
ConstructionBase = "1.3"
Expand Down
1 change: 1 addition & 0 deletions src/Trixi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import SciMLBase: get_du, get_tmp_cache, u_modified!,

using DelimitedFiles: readdlm
using Downloads: Downloads
import Adapt
using CodeTracking: CodeTracking
using ConstructionBase: ConstructionBase
using DiffEqCallbacks: PeriodicCallback, PeriodicCallbackAffect
Expand Down
22 changes: 22 additions & 0 deletions src/auxiliary/vector_of_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

# Wraps a Vector of Arrays, forwards `getindex` to the underlying Vector.
# Implements `Adapt.adapt_structure` to allow offloading to the GPU which is
# not possible for a plain Vector of Arrays.
struct VecOfArrays{T <: AbstractArray}
arrays::Vector{T}
end
Base.getindex(v::VecOfArrays, i::Int) = Base.getindex(v.arrays, i)
Base.IndexStyle(v::VecOfArrays) = Base.IndexStyle(v.arrays)
Base.size(v::VecOfArrays) = Base.size(v.arrays)
Base.length(v::VecOfArrays) = Base.length(v.arrays)
Base.eltype(v::VecOfArrays{T}) where {T} = T
function Adapt.adapt_structure(to, v::VecOfArrays)
return [Adapt.adapt(to, arr) for arr in v.arrays] |> VecOfArrays
end
end # @muladd
36 changes: 31 additions & 5 deletions src/semidiscretization/semidiscretization_hyperbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ mutable struct SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,

function SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
BoundaryConditions, SourceTerms, Solver,
Cache}(mesh::Mesh, equations::Equations,
Cache}(mesh::Mesh,
equations::Equations,
initial_condition::InitialCondition,
boundary_conditions::BoundaryConditions,
source_terms::SourceTerms,
solver::Solver,
cache::Cache) where {Mesh, Equations,
cache::Cache,
performance_counter::PerformanceCounter) where {Mesh, Equations,
InitialCondition,
BoundaryConditions,
SourceTerms,
Solver,
Cache}
Comment on lines +40 to 45
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
performance_counter::PerformanceCounter) where {Mesh, Equations,
InitialCondition,
BoundaryConditions,
SourceTerms,
Solver,
Cache}
performance_counter::PerformanceCounter) where {
Mesh,
Equations,
InitialCondition,
BoundaryConditions,
SourceTerms,
Solver,
Cache
}

performance_counter = PerformanceCounter()

new(mesh, equations, initial_condition, boundary_conditions, source_terms,
solver, cache, performance_counter)
end
Expand Down Expand Up @@ -74,14 +74,16 @@ function SemidiscretizationHyperbolic(mesh, equations, initial_condition, solver

check_periodicity_mesh_boundary_conditions(mesh, _boundary_conditions)

performance_counter = PerformanceCounter()

SemidiscretizationHyperbolic{typeof(mesh), typeof(equations),
typeof(initial_condition),
typeof(_boundary_conditions), typeof(source_terms),
typeof(solver), typeof(cache)}(mesh, equations,
initial_condition,
_boundary_conditions,
source_terms, solver,
cache)
cache, performance_counter)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
cache, performance_counter)
cache,
performance_counter)

end

# Create a new semidiscretization but change some parameters compared to the input.
Expand All @@ -103,6 +105,30 @@ function remake(semi::SemidiscretizationHyperbolic; uEltype = real(semi.solver),
source_terms, boundary_conditions, uEltype)
end

function Adapt.adapt_structure(to, semi::SemidiscretizationHyperbolic)
if !(typeof(semi.mesh) <: P4estMesh)
error("Adapt.adapt is only supported for semidiscretizations based on P4estMesh")
end

mesh = semi.mesh
equations = Adapt.adapt_structure(to, semi.equations)
initial_condition = Adapt.adapt_structure(to, semi.initial_condition)
boundary_conditions = Adapt.adapt_structure(to, semi.boundary_conditions)
source_terms = Adapt.adapt_structure(to, semi.source_terms)
solver = Adapt.adapt_structure(to, semi.solver)
cache = Adapt.adapt_structure(to, semi.cache)
performance_counter = semi.performance_counter

SemidiscretizationHyperbolic{typeof(mesh), typeof(equations),
typeof(initial_condition),
typeof(boundary_conditions), typeof(source_terms),
typeof(solver), typeof(cache)}(mesh, equations,
initial_condition,
boundary_conditions,
source_terms, solver,
cache, performance_counter)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
cache, performance_counter)
cache,
performance_counter)

end

# general fallback
function digest_boundary_conditions(boundary_conditions, mesh, solver, cache)
boundary_conditions
Expand Down
35 changes: 35 additions & 0 deletions src/solvers/dgsem/basis_lobatto_legendre.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,31 @@ In particular, not the nodes themselves are returned.

@inline get_nodes(basis::LobattoLegendreBasis) = basis.nodes

function Adapt.adapt_structure(to, basis::LobattoLegendreBasis)
# Do not adapt SVector fields, i.e. nodes, weights and inverse_weights
(; nodes, weights, inverse_weights) = basis
inverse_vandermonde_legendre = Adapt.adapt_structure(to,
basis.inverse_vandermonde_legendre)
boundary_interpolation = basis.boundary_interpolation
derivative_matrix = Adapt.adapt_structure(to, basis.derivative_matrix)
derivative_split = Adapt.adapt_structure(to, basis.derivative_split)
derivative_split_transpose = Adapt.adapt_structure(to,
basis.derivative_split_transpose)
derivative_dhat = Adapt.adapt_structure(to, basis.derivative_dhat)
return LobattoLegendreBasis{real(basis), nnodes(basis), typeof(basis.nodes),
typeof(inverse_vandermonde_legendre),
typeof(boundary_interpolation),
typeof(derivative_matrix)}(nodes,
weights,
inverse_weights,
inverse_vandermonde_legendre,
boundary_interpolation,
derivative_matrix,
derivative_split,
derivative_split_transpose,
derivative_dhat)
end

"""
integrate(f, u, basis::LobattoLegendreBasis)

Expand Down Expand Up @@ -216,6 +241,16 @@ end

@inline polydeg(mortar::LobattoLegendreMortarL2) = nnodes(mortar) - 1

function Adapt.adapt_structure(to, mortar::LobattoLegendreMortarL2)
forward_upper = Adapt.adapt_structure(to, mortar.forward_upper)
forward_lower = Adapt.adapt_structure(to, mortar.forward_lower)
reverse_upper = Adapt.adapt_structure(to, mortar.reverse_upper)
reverse_lower = Adapt.adapt_structure(to, mortar.reverse_lower)
return LobattoLegendreMortarL2{real(mortar), nnodes(mortar), typeof(forward_upper),
typeof(reverse_upper)}(forward_upper, forward_lower,
reverse_upper, reverse_lower)
end

# TODO: We can create EC mortars along the lines of the following implementation.
# abstract type AbstractMortarEC{RealT} <: AbstractMortar{RealT} end

Expand Down
Loading
Loading