Extended T8codeMesh backend to support MPI interface data structures.
Johannes Markert committed Jan 12, 2024
1 parent e7175cf commit 5d906f8
Showing 24 changed files with 860 additions and 156 deletions.
17 changes: 16 additions & 1 deletion examples/t8code_2d_dgsem/elixir_advection_basic.jl
@@ -22,8 +22,23 @@ mesh = T8codeMesh(trees_per_dimension, polydeg = 3,
coordinates_min = coordinates_min, coordinates_max = coordinates_max,
initial_refinement_level = 1)

function my_initial_condition(x, t,
                              equation::LinearScalarAdvectionEquation2D)
    # Store translated coordinate for easy use of exact solution
    x_trans = x - equation.advection_velocity * t

    c = 1.0
    A = 0.5
    L = 2
    f = 1 / L
    omega = 2 * pi * f
    scalar = c + A * sin(omega * sum(x_trans))

    return SVector(scalar)
end

# A semidiscretization collects data structures and functions for the spatial discretization
semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition_convergence_test,
semi = SemidiscretizationHyperbolic(mesh, equations, my_initial_condition,
solver)

###############################################################################
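A hypothetical spot check, not part of the commit: the custom `my_initial_condition` above appears to mirror Trixi's convergence-test profile for linear advection, so after including the elixir (or at least the function definition) it should agree with `initial_condition_convergence_test` at any sample point. The advection velocity is assumed here to be `(0.2, -0.7)`, as in the standard advection elixirs.

using Trixi  # SVector is re-exported by Trixi

equations = LinearScalarAdvectionEquation2D((0.2, -0.7))
x = SVector(0.3, -0.4)
t = 1.5

# Both calls should return the same single-component SVector.
isapprox(my_initial_condition(x, t, equations),
         initial_condition_convergence_test(x, t, equations))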
15 changes: 15 additions & 0 deletions src/auxiliary/mpi.jl
@@ -60,6 +60,21 @@ end
return nothing
end

@inline function mpi_println_serial(args...)
    if mpi_rank() > 0
        MPI.recv(mpi_rank() - 1, 42, mpi_comm())
    end

    println("rank = $(mpi_rank()) | ", args...)
    flush(stdout)

    if mpi_rank() < mpi_nranks() - 1
        MPI.send(undef, mpi_rank() + 1, 42, mpi_comm())
    end

    return nothing
end

"""
ode_norm(u, t)
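The new `mpi_println_serial` helper above serializes output with a simple token-passing scheme: every rank except rank 0 first blocks on a receive from its left neighbor, prints its own line, and then sends a token (tag 42) to its right neighbor, so lines appear strictly in rank order instead of interleaving. A hypothetical usage sketch, not part of the commit, assuming an MPI-parallel Trixi session with a `mesh` already set up:

# One ordered line per rank, handy when debugging the new MPI data structures.
mpi_println_serial("local element count = ", ncells(mesh))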
355 changes: 267 additions & 88 deletions src/auxiliary/t8code.jl

Large diffs are not rendered by default.

25 changes: 13 additions & 12 deletions src/callbacks_step/amr.jl
@@ -726,7 +726,7 @@ function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::P4estMesh,
return has_changed
end

function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::SerialT8codeMesh,
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::T8codeMesh,
equations, dg::DG, cache, semi,
t, iter;
only_refine = false, only_coarsen = false,
@@ -754,21 +754,22 @@ function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::SerialT8codeMesh,
@trixi_timeit timer() "adapt" begin
difference = @trixi_timeit timer() "mesh" trixi_t8_adapt!(mesh, indicators)

@trixi_timeit timer() "solver" adapt!(u_ode, adaptor, mesh, equations, dg,
cache, difference)
end
# Store whether there were any cells coarsened or refined and perform load balancing.
has_changed = any(difference .!= 0)

# Store whether there were any cells coarsened or refined and perform load balancing.
has_changed = any(difference .!= 0)
# Check if mesh changed on other processes
if mpi_isparallel()
has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
end

# TODO: T8codeMesh for MPI not implemented yet.
# Check if mesh changed on other processes
# if mpi_isparallel()
# has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
# end
if has_changed
@trixi_timeit timer() "solver" adapt!(u_ode, adaptor, mesh, equations, dg,
cache, difference)
end
end

if has_changed
# TODO: T8codeMesh for MPI not implemented yet.
# # TODO: T8codeMesh for rebalance/partition not implemented yet.
# if mpi_isparallel() && amr_callback.dynamic_load_balancing
# @trixi_timeit timer() "dynamic load balancing" begin
# global_first_quadrant = unsafe_wrap(Array, mesh.p4est.global_first_quadrant, mpi_nranks() + 1)
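A minimal sketch of the synchronization added above, not part of the commit: after local adaptation, every rank must agree on whether any rank changed its mesh; otherwise some ranks would rebuild their solver caches while others skip the step, and later MPI communication would desynchronize. The local flags are therefore combined with a logical OR across all ranks:

# `difference[i] != 0` marks a local element that was refined or coarsened.
has_changed_local = any(difference .!= 0)
# Reduce with logical OR so a single adapted rank triggers the rebuild everywhere.
has_changed = MPI.Allreduce!(Ref(has_changed_local), |, mpi_comm())[]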
5 changes: 0 additions & 5 deletions src/callbacks_step/amr_dg2d.jl
@@ -383,11 +383,6 @@ end
function adapt!(u_ode::AbstractVector, adaptor, mesh::T8codeMesh{2}, equations,
dg::DGSEM, cache, difference)

# Return early if there is nothing to do.
if !any(difference .!= 0)
return nothing
end

# Number of (local) cells/elements.
old_nelems = nelements(dg, cache)
new_nelems = ncells(mesh)
5 changes: 0 additions & 5 deletions src/callbacks_step/amr_dg3d.jl
@@ -314,11 +314,6 @@ end
function adapt!(u_ode::AbstractVector, adaptor, mesh::T8codeMesh{3}, equations,
dg::DGSEM, cache, difference)

# Return early if there is nothing to do.
if !any(difference .!= 0)
return nothing
end

# Number of (local) cells/elements.
old_nelems = nelements(dg, cache)
new_nelems = ncells(mesh)
4 changes: 2 additions & 2 deletions src/callbacks_step/analysis_dg2d_parallel.jl
@@ -91,7 +91,7 @@ function calc_error_norms_per_element(func, u, t, analyzer,
end

function calc_error_norms(func, u, t, analyzer,
mesh::ParallelP4estMesh{2}, equations,
mesh::Union{ParallelP4estMesh{2},ParallelT8codeMesh{2}}, equations,
initial_condition, dg::DGSEM, cache, cache_analysis)
@unpack vandermonde, weights = analyzer
@unpack node_coordinates, inverse_jacobian = cache.elements
@@ -171,7 +171,7 @@ function integrate_via_indices(func::Func, u,
end

function integrate_via_indices(func::Func, u,
mesh::ParallelP4estMesh{2}, equations,
mesh::Union{ParallelP4estMesh{2},ParallelT8codeMesh{2}}, equations,
dg::DGSEM, cache, args...; normalize = true) where {Func}
@unpack weights = dg.basis

4 changes: 2 additions & 2 deletions src/callbacks_step/analysis_dg3d_parallel.jl
@@ -6,7 +6,7 @@
#! format: noindent

function calc_error_norms(func, u, t, analyzer,
mesh::ParallelP4estMesh{3}, equations,
mesh::Union{ParallelP4estMesh{3},ParallelT8codeMesh{3}}, equations,
initial_condition, dg::DGSEM, cache, cache_analysis)
@unpack vandermonde, weights = analyzer
@unpack node_coordinates, inverse_jacobian = cache.elements
@@ -64,7 +64,7 @@ function calc_error_norms(func, u, t, analyzer,
end

function integrate_via_indices(func::Func, u,
mesh::ParallelP4estMesh{3}, equations,
mesh::Union{ParallelP4estMesh{3},ParallelT8codeMesh{3}}, equations,
dg::DGSEM, cache, args...; normalize = true) where {Func}
@unpack weights = dg.basis

33 changes: 33 additions & 0 deletions src/callbacks_step/stepsize_dg2d.jl
@@ -159,6 +159,22 @@ function max_dt(u, t, mesh::ParallelP4estMesh{2},
return dt
end

function max_dt(u, t, mesh::ParallelT8codeMesh{2},
                constant_speed::False, equations, dg::DG, cache)
    # call the method accepting a general `mesh::T8codeMesh{2}`
    # TODO: MPI, we should improve this; maybe we should dispatch on `u`
    # and create some MPI array type, overloading broadcasting and mapreduce etc.
    # Then, this specific array type should also work well with DiffEq etc.
    dt = invoke(max_dt,
                Tuple{typeof(u), typeof(t), T8codeMesh{2},
                      typeof(constant_speed), typeof(equations), typeof(dg),
                      typeof(cache)},
                u, t, mesh, constant_speed, equations, dg, cache)
    dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]

    return dt
end

function max_dt(u, t, mesh::ParallelP4estMesh{2},
constant_speed::True, equations, dg::DG, cache)
# call the method accepting a general `mesh::P4estMesh{2}`
@@ -174,4 +190,21 @@ function max_dt(u, t, mesh::ParallelP4estMesh{2},

return dt
end

function max_dt(u, t, mesh::ParallelT8codeMesh{2},
                constant_speed::True, equations, dg::DG, cache)
    # call the method accepting a general `mesh::T8codeMesh{2}`
    # TODO: MPI, we should improve this; maybe we should dispatch on `u`
    # and create some MPI array type, overloading broadcasting and mapreduce etc.
    # Then, this specific array type should also work well with DiffEq etc.
    dt = invoke(max_dt,
                Tuple{typeof(u), typeof(t), T8codeMesh{2},
                      typeof(constant_speed), typeof(equations), typeof(dg),
                      typeof(cache)},
                u, t, mesh, constant_speed, equations, dg, cache)
    dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]

    return dt
end

end # @muladd
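All four new `max_dt` methods (2D/3D, constant and non-constant wave speed) follow the same pattern: compute the locally admissible time step by `invoke`-ing the existing serial `T8codeMesh` method, then take the minimum over all ranks, since the global step is limited by the most restrictive rank. A condensed, hypothetical sketch of that pattern (the function name below is not part of the codebase):

function max_dt_parallel_sketch(u, t, mesh::ParallelT8codeMesh{2},
                                constant_speed, equations, dg, cache)
    # Local estimate: bypass dispatch and reuse the serial T8codeMesh{2} method.
    dt_local = invoke(max_dt,
                      Tuple{typeof(u), typeof(t), T8codeMesh{2},
                            typeof(constant_speed), typeof(equations),
                            typeof(dg), typeof(cache)},
                      u, t, mesh, constant_speed, equations, dg, cache)
    # Global estimate: every rank must advance with the same, smallest dt.
    return MPI.Allreduce!(Ref(dt_local), min, mpi_comm())[]
end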
33 changes: 33 additions & 0 deletions src/callbacks_step/stepsize_dg3d.jl
@@ -150,4 +150,37 @@ function max_dt(u, t, mesh::ParallelP4estMesh{3},

return dt
end

function max_dt(u, t, mesh::ParallelT8codeMesh{3},
                constant_speed::False, equations, dg::DG, cache)
    # call the method accepting a general `mesh::T8codeMesh{3}`
    # TODO: MPI, we should improve this; maybe we should dispatch on `u`
    # and create some MPI array type, overloading broadcasting and mapreduce etc.
    # Then, this specific array type should also work well with DiffEq etc.
    dt = invoke(max_dt,
                Tuple{typeof(u), typeof(t), T8codeMesh{3},
                      typeof(constant_speed), typeof(equations), typeof(dg),
                      typeof(cache)},
                u, t, mesh, constant_speed, equations, dg, cache)
    dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]

    return dt
end

function max_dt(u, t, mesh::ParallelT8codeMesh{3},
                constant_speed::True, equations, dg::DG, cache)
    # call the method accepting a general `mesh::T8codeMesh{3}`
    # TODO: MPI, we should improve this; maybe we should dispatch on `u`
    # and create some MPI array type, overloading broadcasting and mapreduce etc.
    # Then, this specific array type should also work well with DiffEq etc.
    dt = invoke(max_dt,
                Tuple{typeof(u), typeof(t), T8codeMesh{3},
                      typeof(constant_speed), typeof(equations), typeof(dg),
                      typeof(cache)},
                u, t, mesh, constant_speed, equations, dg, cache)
    dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]

    return dt
end

end # @muladd
36 changes: 22 additions & 14 deletions src/meshes/t8code_mesh.jl
@@ -32,7 +32,7 @@ mutable struct T8codeMesh{NDIMS, RealT <: Real, IsParallel, NDIMSP2, NNODES} <:
boundary_names,
current_filename) where {NDIMS}

is_parallel = mpi_isparallel()
is_parallel = mpi_isparallel() ? True() : False()

mesh = new{NDIMS, Float64, typeof(is_parallel), NDIMS + 2, length(nodes)}(forest,
is_parallel)
@@ -81,7 +81,7 @@ const ParallelT8codeMesh{NDIMS} = T8codeMesh{NDIMS, <:Real, <:True}
@inline Base.ndims(::T8codeMesh{NDIMS}) where {NDIMS} = NDIMS
@inline Base.real(::T8codeMesh{NDIMS, RealT}) where {NDIMS, RealT} = RealT

@inline ntrees(mesh::T8codeMesh) = Int(t8_forest_get_num_local_trees(mesh.forest))
@inline ntrees(mesh::T8codeMesh) = size(mesh.tree_node_coordinates)[end]
@inline ncells(mesh::T8codeMesh) = Int(t8_forest_get_local_num_elements(mesh.forest))
@inline ninterfaces(mesh::T8codeMesh) = mesh.ninterfaces
@inline nmortars(mesh::T8codeMesh) = mesh.nmortars
@@ -188,21 +188,22 @@ function T8codeMesh(trees_per_dimension; polydeg,
T8code.Libt8.p8est_connectivity_destroy(conn)
end

do_face_ghost = mpi_isparallel()
scheme = t8_scheme_new_default_cxx()
forest = t8_forest_new_uniform(cmesh, scheme, initial_refinement_level, 0, mpi_comm())
forest = t8_forest_new_uniform(cmesh, scheme, initial_refinement_level, do_face_ghost, mpi_comm())

basis = LobattoLegendreBasis(RealT, polydeg)
nodes = basis.nodes

num_trees = t8_cmesh_get_num_trees(cmesh)

tree_node_coordinates = Array{RealT, NDIMS + 2}(undef, NDIMS,
ntuple(_ -> length(nodes), NDIMS)...,
prod(trees_per_dimension))
num_trees)

# Get cell length in reference mesh: Omega_ref = [-1,1]^2.
dx = [2 / n for n in trees_per_dimension]

num_local_trees = t8_cmesh_get_num_local_trees(cmesh)

# Non-periodic boundaries.
boundary_names = fill(Symbol("---"), 2 * NDIMS, prod(trees_per_dimension))

@@ -212,12 +213,12 @@ function T8codeMesh(trees_per_dimension; polydeg,
mapping_ = mapping
end

for itree in 1:num_local_trees
for itree in 1:num_trees
veptr = t8_cmesh_get_tree_vertices(cmesh, itree - 1)
verts = unsafe_wrap(Array, veptr, (3, 1 << NDIMS))

if NDIMS == 2
# Calculate node coordinates of reference mesh.
# Calculate node coordinates of reference mesh for 2D.
cell_x_offset = (verts[1, 1] - 0.5 * (trees_per_dimension[1] - 1)) * dx[1]
cell_y_offset = (verts[2, 1] - 0.5 * (trees_per_dimension[2] - 1)) * dx[2]

@@ -228,6 +229,7 @@ function T8codeMesh(trees_per_dimension; polydeg,
dx[2] * nodes[j] / 2)
end
elseif NDIMS == 3
# Calculate node coordinates of reference mesh for 3D.
cell_x_offset = (verts[1, 1] - 0.5 * (trees_per_dimension[1] - 1)) * dx[1]
cell_y_offset = (verts[2, 1] - 0.5 * (trees_per_dimension[2] - 1)) * dx[2]
cell_z_offset = (verts[3, 1] - 0.5 * (trees_per_dimension[3] - 1)) * dx[3]
@@ -289,17 +291,19 @@ conforming mesh from a `t8_cmesh` data structure.
function T8codeMesh{NDIMS}(cmesh::Ptr{t8_cmesh};
mapping = nothing, polydeg = 1, RealT = Float64,
initial_refinement_level = 0) where {NDIMS}

do_face_ghost = mpi_isparallel()
scheme = t8_scheme_new_default_cxx()
forest = t8_forest_new_uniform(cmesh, scheme, initial_refinement_level, 0, mpi_comm())
forest = t8_forest_new_uniform(cmesh, scheme, initial_refinement_level, do_face_ghost, mpi_comm())

basis = LobattoLegendreBasis(RealT, polydeg)
nodes = basis.nodes

num_local_trees = t8_cmesh_get_num_local_trees(cmesh)
num_trees = t8_cmesh_get_num_trees(cmesh)

tree_node_coordinates = Array{RealT, NDIMS + 2}(undef, NDIMS,
ntuple(_ -> length(nodes), NDIMS)...,
num_local_trees)
num_trees)

nodes_in = [-1.0, 1.0]
matrix = polynomial_interpolation_matrix(nodes_in, nodes)
@@ -308,7 +312,7 @@ function T8codeMesh{NDIMS}(cmesh::Ptr{t8_cmesh};
data_in = Array{RealT, 3}(undef, 2, 2, 2)
tmp1 = zeros(RealT, 2, length(nodes), length(nodes_in))

for itree in 0:(num_local_trees - 1)
for itree in 0:(num_trees - 1)
veptr = t8_cmesh_get_tree_vertices(cmesh, itree)
verts = unsafe_wrap(Array, veptr, (3, 1 << NDIMS))

@@ -339,7 +343,7 @@ function T8codeMesh{NDIMS}(cmesh::Ptr{t8_cmesh};
data_in = Array{RealT, 4}(undef, 3, 2, 2, 2)
tmp1 = zeros(RealT, 3, length(nodes), length(nodes_in), length(nodes_in))

for itree in 0:(num_local_trees - 1)
for itree in 0:(num_trees - 1)
veptr = t8_cmesh_get_tree_vertices(cmesh, itree)
verts = unsafe_wrap(Array, veptr, (3, 1 << NDIMS))

@@ -367,7 +371,7 @@ function T8codeMesh{NDIMS}(cmesh::Ptr{t8_cmesh};
map_node_coordinates!(tree_node_coordinates, mapping)

# There's no simple and generic way to distinguish boundaries. Name all of them :all.
boundary_names = fill(:all, 2 * NDIMS, num_local_trees)
boundary_names = fill(:all, 2 * NDIMS, num_trees)

return T8codeMesh{NDIMS}(forest, tree_node_coordinates, nodes,
boundary_names, "")
@@ -466,3 +470,7 @@ end
function partition!(mesh::T8codeMesh; allow_coarsening = true, weight_fn = C_NULL)
    return nothing
end

function update_ghost_layer!(mesh::ParallelT8codeMesh)
    return nothing
end
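A small illustration, not part of the commit, of what the `is_parallel = mpi_isparallel() ? True() : False()` change above enables: the parallel flag is now a type-level trait, so MPI-specific code can dispatch on the `SerialT8codeMesh`/`ParallelT8codeMesh` aliases instead of branching at runtime. The `describe` helpers below are hypothetical:

describe(::SerialT8codeMesh) = "single-rank T8codeMesh, no ghost layer required"
describe(::ParallelT8codeMesh) = "MPI-distributed T8codeMesh with ghost layer and MPI interfaces"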
4 changes: 2 additions & 2 deletions src/solvers/dgsem_p4est/containers_parallel.jl
@@ -43,7 +43,7 @@ function Base.resize!(mpi_interfaces::P4estMPIInterfaceContainer, capacity)
end

# Create MPI interface container and initialize interface data
function init_mpi_interfaces(mesh::ParallelP4estMesh, equations, basis, elements)
function init_mpi_interfaces(mesh::Union{ParallelP4estMesh,ParallelT8codeMesh}, equations, basis, elements)
NDIMS = ndims(elements)
uEltype = eltype(elements)

@@ -133,7 +133,7 @@ function Base.resize!(mpi_mortars::P4estMPIMortarContainer, capacity)
end

# Create MPI mortar container and initialize MPI mortar data
function init_mpi_mortars(mesh::ParallelP4estMesh, equations, basis, elements)
function init_mpi_mortars(mesh::Union{ParallelP4estMesh,ParallelT8codeMesh}, equations, basis, elements)
NDIMS = ndims(mesh)
RealT = real(mesh)
uEltype = eltype(elements)