Proof of concept: TrixiMPIArray #1104

Draft

Wants to merge 37 commits into main

Changes from 29 commits

Commits
8042a04  WIP: TrixiMPIArray (ranocha, Mar 30, 2022)
5da9e5c  update TODO notes (ranocha, Mar 30, 2022)
d09b4cc  use TrixiMPIArrays via allocate_coefficients (ranocha, Mar 31, 2022)
3434a58  do not dispatch on TrixiMPIArray for saving solution/restart files (ranocha, Mar 31, 2022)
8135d7f  WIP: experiment with global/local length settings (ranocha, Mar 31, 2022)
987407e  resize! (ranocha, Apr 1, 2022)
e7d3db3  Merge branch 'main' into hr/mpi_arrays (ranocha, Apr 1, 2022)
4e33567  add error-based step size control to tests (ranocha, Apr 1, 2022)
eb1d9b1  SIMD optimizations specialize also on TrixiMPIArrays (ranocha, Apr 1, 2022)
c2e0b86  replace some 1:length by eachindex (ranocha, Apr 1, 2022)
d58b1b5  local_copy for AMR (ranocha, Apr 1, 2022)
23d4520  specialize show (ranocha, Apr 1, 2022)
d8d85b7  clean up (ranocha, Apr 1, 2022)
4df8602  specialize view (ranocha, Apr 1, 2022)
3528a16  clean up (ranocha, Apr 1, 2022)
6f984c0  use global mpi_comm() again instead of mpi_comm(u) (ranocha, Apr 1, 2022)
efa08e7  dispatch on parallel mesh instead of TrixiMPIArray whenever possible (ranocha, Apr 1, 2022)
80f6d59  YAGNI mpi_rank, mpi_size (ranocha, Apr 1, 2022)
d28c888  use accessor function mpi_comm (ranocha, Apr 1, 2022)
5ba7f9e  update comment (ranocha, Apr 1, 2022)
01186f6  Merge branch 'hr/mpi_arrays' of github.com:trixi-framework/Trixi.jl i… (ranocha, Apr 1, 2022)
96e4a3d  fix efa08e7a76f1b823217c0c9981194510bee3caec (ranocha, Apr 1, 2022)
bce6cb7  get rid of local_copy (ranocha, Apr 1, 2022)
0af5ae7  Merge branch 'main' into hr/mpi_arrays (ranocha, Apr 1, 2022)
76ae70f  test P4estMesh in 2D and 3D with MPI and error-based step size control (ranocha, Apr 1, 2022)
82e480a  MPI tests with error-based step size control with reltol as rtol (ranocha, Apr 3, 2022)
5049bbb  specialize broadcasting (ranocha, Apr 4, 2022)
195e1e0  get rid of local_length (ranocha, Apr 4, 2022)
4de9a6f  more tests of TrixiMPIArrays (ranocha, Apr 4, 2022)
5277f86  print test names with error-based step size control (ranocha, Apr 5, 2022)
53fbfbd  export ode_norm, ode_unstable_check (ranocha, Apr 5, 2022)
b549d2a  more comments (ranocha, Apr 5, 2022)
8857b7a  fuse MPI reductions (ranocha, Apr 5, 2022)
923c6ad  clean-up (ranocha, Apr 5, 2022)
e38692b  mark ode_norm, ode_unstable_check as experimental (ranocha, Apr 5, 2022)
126d54d  Merge branch 'main' into hr/mpi_arrays (ranocha, Apr 5, 2022)
cdcf828  Merge branch 'main' into hr/mpi_arrays (ranocha, Apr 5, 2022)
2 changes: 2 additions & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["Michael Schlottke-Lakemper <[email protected]>", "Gregor
version = "0.4.28-pre"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
@@ -40,6 +41,7 @@ TriplotRecipes = "808ab39a-a642-4abf-81ff-4cb34ebbffa3"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
ArrayInterface = "3"
CodeTracking = "1.0.5"
ConstructionBase = "1.3"
EllipsisNotation = "1.0"
2 changes: 1 addition & 1 deletion docs/src/reference-trixi.md
@@ -5,5 +5,5 @@ CurrentModule = Trixi
```

```@autodocs
Modules = [Trixi]
Modules = [Trixi, Trixi.TrixiMPIArrays]
```
3 changes: 2 additions & 1 deletion src/Trixi.jl
@@ -25,6 +25,7 @@ using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, sparse, dropt
# import @reexport now to make it available for further imports/exports
using Reexport: @reexport

using ArrayInterface: static_length
using SciMLBase: CallbackSet, DiscreteCallback,
ODEProblem, ODESolution, ODEFunction
import SciMLBase: get_du, get_tmp_cache, u_modified!,
@@ -39,7 +40,6 @@ using HDF5: h5open, attributes
using IfElse: ifelse
using LinearMaps: LinearMap
using LoopVectorization: LoopVectorization, @turbo, indices
using LoopVectorization.ArrayInterface: static_length
using MPI: MPI
using MuladdMacro: @muladd
using GeometryBasics: GeometryBasics
@@ -99,6 +99,7 @@ include("basic_types.jl")
# Include all top-level source files
include("auxiliary/auxiliary.jl")
include("auxiliary/mpi.jl")
include("auxiliary/mpi_arrays.jl")
include("auxiliary/p4est.jl")
include("equations/equations.jl")
include("meshes/meshes.jl")
228 changes: 228 additions & 0 deletions src/auxiliary/mpi_arrays.jl
@@ -0,0 +1,228 @@

# TODO: MPI. We keep this module inside Trixi for now. When it stabilizes and
# turns out to be generally useful, we can consider moving it to a
# separate package with simple test suite and documentation.
module TrixiMPIArrays

using ArrayInterface: ArrayInterface
using MPI: MPI

using ..Trixi: Trixi, mpi_comm

export TrixiMPIArray, ode_norm, ode_unstable_check


# Dispatch etc.
# The following functions have special dispatch behavior for `TrixiMPIArray`s.
# - `wrap_array`:
# the wrapped array is wrapped again in a `TrixiMPIArray`
# - `wrap_array_native`:
# should not be changed since it should return a plain `Array`
# - `allocate_coefficients`:
# this handles the return type of initialization stuff when setting an IC
# with MPI
# Besides these, we usually dispatch on MPI mesh types such as
# `mesh::ParallelTreeMesh` or `mesh::ParallelP4estMesh`, since this is
# consistent with other dispatches on the mesh type. However, we dispatch on
# `u::TrixiMPIArray` whenever this allows simplifying some code, e.g., because
# we can call a basic function on `parent(u)` and add some MPI stuff on top.
"""
TrixiMPIArray{T, N} <: AbstractArray{T, N}

A thin wrapper of arrays distributed via MPI used in Trixi.jl. The idea is that
these arrays behave as much as possible as plain arrays would in an SPMD-style
distributed MPI setting with the exception of reductions, which are performed
globally. This allows these arrays to be used in ODE solvers such as the ones from
OrdinaryDiffEq.jl, since vector space operations, broadcasting, and reductions
are the only operations required for explicit time integration methods with
fixed step sizes or adaptive step sizes based on CFL or error estimates.

!!! warning "Experimental code"
This code is experimental and may be changed or removed in any future release.
"""
struct TrixiMPIArray{T, N, Parent<:AbstractArray{T, N}} <: AbstractArray{T, N}
u_local::Parent
mpi_comm::MPI.Comm
end

function TrixiMPIArray(u_local::AbstractArray{T, N}) where {T, N}
# TODO: MPI. Hard-coded to MPI.COMM_WORLD for now
mpi_comm = MPI.COMM_WORLD
TrixiMPIArray{T, N, typeof(u_local)}(u_local, mpi_comm)
end
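For illustration, a minimal sketch of how the wrapper might be used on each MPI rank (hypothetical snippet, not part of the diff; the `Trixi.TrixiMPIArray` qualification and the prior `MPI.Init()` call are assumptions, and the constructor above hard-codes `MPI.COMM_WORLD`):

    using MPI, Trixi

    MPI.Init()
    u_local = rand(4)                  # each rank holds only its local data
    u = Trixi.TrixiMPIArray(u_local)   # wraps the local array with MPI.COMM_WORLD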


# `Base.show` with additional helpful information
function Base.show(io::IO, u::TrixiMPIArray)
print(io, "TrixiMPIArray wrapping ", parent(u))
end

function Base.show(io::IO, mime::MIME"text/plain", u::TrixiMPIArray)
print(io, "TrixiMPIArray wrapping ")
show(io, mime, parent(u))
end


# Custom interface and general Base interface not covered by other parts below
Base.parent(u::TrixiMPIArray) = u.u_local
Base.resize!(u::TrixiMPIArray, new_size) = resize!(parent(u), new_size)
function Base.copy(u::TrixiMPIArray)
return TrixiMPIArray(copy(parent(u)), mpi_comm(u))
end

Trixi.mpi_comm(u::TrixiMPIArray) = u.mpi_comm


# Implementation of the abstract array interface of Base
# See https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array
Base.size(u::TrixiMPIArray) = size(parent(u))
Base.getindex(u::TrixiMPIArray, idx) = getindex(parent(u), idx)
Base.setindex!(u::TrixiMPIArray, v, idx) = setindex!(parent(u), v, idx)
Base.IndexStyle(::Type{TrixiMPIArray{T, N, Parent}}) where {T, N, Parent} = IndexStyle(Parent)
function Base.similar(u::TrixiMPIArray, ::Type{S}, dims::NTuple{N, Int}) where {S, N}
return TrixiMPIArray(similar(parent(u), S, dims), mpi_comm(u))
end
Base.axes(u::TrixiMPIArray) = axes(parent(u))


# Implementation of the strided array interface of Base
# See https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays
Base.strides(u::TrixiMPIArray) = strides(parent(u))
function Base.unsafe_convert(::Type{Ptr{T}}, u::TrixiMPIArray{T}) where {T}
return Base.unsafe_convert(Ptr{T}, parent(u))
end
Base.elsize(::Type{TrixiMPIArray{T, N, Parent}}) where {T, N, Parent} = Base.elsize(Parent)


# We need to customize broadcasting since broadcasting expressions allocating
# an output would return plain `Array`s otherwise, losing the MPI information.
# Such allocating broadcasting calls are used for example when determining the
# initial step size in OrdinaryDiffEq.jl.
# However, everything else appears to be fine, i.e., all broadcasting calls
# with a given output storage location work fine. In particular, threaded
# broadcasting with FastBroadcast.jl works fine, e.g., when using threaded RK
# methods such as `SSPRK43(thread=OrdinaryDiffEq.True())`.
# See also
# https://github.com/YingboMa/FastBroadcast.jl
# https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
Base.BroadcastStyle(::Type{<:TrixiMPIArray}) = Broadcast.ArrayStyle{TrixiMPIArray}()

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{TrixiMPIArray}}, ::Type{ElType}) where ElType
# Scan the inputs for the first TrixiMPIArray and use that to create a `similar`
# output array with MPI information
A = find_mpi_array(bc)
similar(A, axes(bc))
end
# `A = find_mpi_array(As)` returns the first TrixiMPIArray among the arguments
find_mpi_array(bc::Base.Broadcast.Broadcasted) = find_mpi_array(bc.args)
find_mpi_array(args::Tuple) = find_mpi_array(find_mpi_array(args[1]), Base.tail(args))
find_mpi_array(x) = x
find_mpi_array(::Tuple{}) = nothing
find_mpi_array(a::TrixiMPIArray, rest) = a
find_mpi_array(::Any, rest) = find_mpi_array(rest)
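As an illustration of the effect (hypothetical snippet assuming `u::TrixiMPIArray`), an allocating broadcast keeps the wrapper type because the `similar` method above locates the first `TrixiMPIArray` among the broadcast arguments:

    v = u .+ one(eltype(u))   # allocating broadcast
    v isa TrixiMPIArray       # true; the communicator of `u` is carried over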


# Implementation of methods from ArrayInterface.jl for use with
# LoopVectorization.jl etc.
# See https://juliaarrays.github.io/ArrayInterface.jl/stable/
ArrayInterface.parent_type(::Type{TrixiMPIArray{T, N, Parent}}) where {T, N, Parent} = Parent


# TODO: MPI. For now, we do not implement specializations of LinearAlgebra
# functions such as `norm` or `dot`. We might revisit this again
# in the future.


# `mapreduce` functionality from Base using global reductions via MPI communication
# for use in, e.g., error-based step size control in OrdinaryDiffEq.jl
function Base.mapreduce(f::F, op::Op, u::TrixiMPIArray; kwargs...) where {F, Op}
local_value = mapreduce(f, op, parent(u); kwargs...)
return MPI.Allreduce(local_value, op, mpi_comm(u))
end
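As a consequence, reductions act globally across all MPI ranks and every rank obtains the same value (illustrative sketch, not part of the diff, assuming `u::TrixiMPIArray`):

    total = sum(u)        # dispatches to the `mapreduce` method above,
                          # i.e., MPI.Allreduce over mpi_comm(u)
    m = maximum(abs, u)   # also reduced globally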


# Default settings of OrdinaryDiffEq etc.
# Interesting options could be
# - ODE_DEFAULT_UNSTABLE_CHECK
# - ODE_DEFAULT_ISOUTOFDOMAIN (disabled by default)
# - ODE_DEFAULT_NORM
# See https://github.com/SciML/DiffEqBase.jl/blob/master/src/common_defaults.jl
#
# Problems and inconsistencies with possible global `length`s of TrixiMPIArrays
#
# A basic question is how to handle `length`. We want `TrixiMPIArray`s to behave
# like regular `Array`s in most code, e.g., when looping over an array (which
# should use `eachindex`). At the same time, we want to be able to use adaptive
# time stepping using error estimates in OrdinaryDiffEq.jl. There, the default
# norm `ODE_DEFAULT_NORM` is the one described in the book of Hairer & Wanner,
Review discussion attached to this comment:

Member: Just thinking - can we avoid this local_length issue if we define a norm function that works in parallel? That might be an alternative to having to remember to use local_length.

A potential downside of local_length - which I just noticed - is that it allows users to write code that works in serial but may fail in spectacularly surprising ways when run in parallel. That is, if someone uses length where local_length is required, it works fine in serial but may cause weird issues in parallel (especially when running with --check-bounds=no).

ranocha (Member, Author, Apr 1, 2022): Yeah, that's the issue of the minimally invasive approach using a global length. However, I would argue that users should rather use eachindex in most cases, which is fine.

Member: No, I agree - eachindex should be used where possible. It does, however, make for difficult-to-understand errors, and the "wrong" use of length might be hard to spot in reviews. I suggest we continue making it work, but then we should revisit this (or at least capture it in an issue).

ranocha (Member, Author, Apr 1, 2022): Another alternative would be to write our own norm function and pass it as solve(ode, alg; kwargs..., internalnorm=our_new_norm_function). However, that requires yet another keyword argument we need to remember.

Member: Right, especially since it would fail very late during initialization (or even worse, just hang) if forgotten. Maybe we need our own trixi_solve that passes some default options to OrdinaryDiffEq.jl's solve?

ranocha (Member, Author): Either this or set up some trixi_default_kwargs()?

Member: That might be better. We don't need to solve this right now, though, do we? Maybe we just copy the current discussion to an issue and deal with it later, once we have some more experience with the new type.

ranocha (Member, Author, Apr 1, 2022): Yeah, sounds good to me. I'll leave this thread open and we can continue the discussion later (#1108).

End of review discussion; the code comment continues below.
# i.e., it includes a division by the global `length` of the array. We could
# specialize `ODE_DEFAULT_NORM` accordingly, but that requires depending on
# DiffEqBase.jl (instead of SciMLBase.jl). Alternatively, we could implement
# this via Requires.jl, but that will prevent precompilation and maybe trigger
# invalidations. Alternatively, could implement our own norm and pass that as
# `internalnorm`. We could also try to use the least intrusive approach for now
# and specialize `length` to return a global length while making sure that all
# local behavior is still working as expected (if using `eachindex` instead of
# `1:length` etc.). This means that we would have
# `eachindex(u) != Base.OneTo(length(u))` for `u::TrixiMPIArray` in general,
# even if `u` and its underlying array use one-based indexing.
# Some consequences are that we need to implement specializations of `show`,
# since the default ones call `length`. However, this doesn't work if not all
# ranks call the same method, e.g., when showing an array only on one rank.
# Moreover, we would need to specialize `copyto!` and probably many other
# functions. Since this can lead to hard-to-find bugs and problems in MPI code,
# we use a more verbose approach. Thus, we let `length` be a local `length` and
# provide a new function `Trixi.ode_norm` to be passed as `internalnorm` of
# OrdinaryDiffEq's `solve` function.
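As a purely hypothetical illustration of the `trixi_default_kwargs()` idea floated in the review discussion above (not part of this PR; the helper name comes from that discussion), such a helper could bundle the MPI-aware defaults so they cannot be forgotten:

    # Hypothetical helper, not implemented in this PR
    trixi_default_kwargs() = (internalnorm=Trixi.ode_norm,
                              unstable_check=Trixi.ode_unstable_check)

    # sol = solve(ode, SSPRK43(); trixi_default_kwargs()..., callback=callbacks)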


# Specialization of `view`. Without these, `view`s of arrays returned by
# `wrap_array` with multiple conserved variables do not always work...
# This may also be related to the use of a global `length`?
Base.view(u::TrixiMPIArray, idx::Vararg{Any,N}) where {N} = view(parent(u), idx...)


"""
ode_norm(u, t)

Implementation of the weighted L2 norm of Hairer and Wanner used for error-based
step size control in OrdinaryDiffEq.jl. This function is aware of
[`TrixiMPIArray`](@ref)s, handling them appropriately with global MPI
communication.

You must pass this function as keyword argument
`internalnorm=Trixi.ode_norm`
of `solve` when using error-based step size control with MPI parallel execution
of Trixi.jl.
"""
ode_norm(u, t) = @fastmath abs(u)
ode_norm(u::AbstractArray, t) = sqrt(sum(abs2, u) / length(u))
function ode_norm(u::TrixiMPIArray, t)
local_sumabs2 = sum(abs2, parent(u))
local_length = length(parent(u))
# TODO: MPI. This could be fused into one call to improve parallel performance.
global_sumabs2 = MPI.Allreduce(local_sumabs2, +, mpi_comm(u))
global_length = MPI.Allreduce(local_length, +, mpi_comm(u))
return sqrt(global_sumabs2 / global_length)
end
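The TODO above could be addressed by packing both local values into a single buffer and issuing one `MPI.Allreduce`; a minimal sketch of that idea (illustrative only, the function name is hypothetical and this is not part of the diff):

    function ode_norm_fused(u::TrixiMPIArray, t)
        # pack the local sum of squares and the local length into one buffer
        local_values = [float(sum(abs2, parent(u))), float(length(parent(u)))]
        global_values = MPI.Allreduce(local_values, +, mpi_comm(u))
        return sqrt(global_values[1] / global_values[2])
    end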


"""
ode_unstable_check(dt, u, semi, t)

Implementation of the basic check for instability used in OrdinaryDiffEq.jl.
Instead of checking something like `any(isnan, u)`, this function just checks
`isnan(dt)`. This helps when using [`TrixiMPIArray`](@ref)s, since no additional
global communication is required and all ranks will return the same result.

You should pass this function as keyword argument
`unstable_check=Trixi.ode_unstable_check`
of `solve` when using error-based step size control with MPI parallel execution
of Trixi.jl.
"""
ode_unstable_check(dt, u, semi, t) = isnan(dt)


end # module

using .TrixiMPIArrays: TrixiMPIArrays, TrixiMPIArray, ode_norm, ode_unstable_check
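Putting the two docstrings above together, an MPI-parallel run with error-based step size control would pass both functions to `solve`. A hedged sketch (here `semi`, `tspan`, and `callbacks` are assumed to be set up as in a usual Trixi.jl elixir, and `SSPRK43` is just one example integrator from OrdinaryDiffEq.jl):

    using OrdinaryDiffEq, Trixi

    ode = semidiscretize(semi, tspan)
    sol = solve(ode, SSPRK43();
                abstol=1.0e-6, reltol=1.0e-6,             # error-based step size control
                internalnorm=Trixi.ode_norm,              # MPI-aware weighted L2 norm
                unstable_check=Trixi.ode_unstable_check,  # rank-consistent instability check
                save_everystep=false, callback=callbacks)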
4 changes: 2 additions & 2 deletions src/callbacks_step/amr.jl
@@ -221,7 +221,7 @@ function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::TreeMesh,
@unpack to_refine, to_coarsen = amr_callback.amr_cache
empty!(to_refine)
empty!(to_coarsen)
for element in 1:length(lambda)
for element in eachindex(lambda)
controller_value = lambda[element]
if controller_value > 0
push!(to_refine, leaf_cell_ids[element])
@@ -282,7 +282,7 @@ end
end

# Extract only those parent cells for which all children should be coarsened
to_coarsen = collect(1:length(parents_to_coarsen))[parents_to_coarsen .== 2^ndims(mesh)]
to_coarsen = collect(eachindex(parents_to_coarsen))[parents_to_coarsen .== 2^ndims(mesh)]

# Finally, coarsen mesh
coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh.tree, to_coarsen)
21 changes: 11 additions & 10 deletions src/callbacks_step/analysis_dg2d_parallel.jl
@@ -136,17 +136,18 @@ function calc_error_norms(func, u, t, analyzer,
end


function integrate_via_indices(func::Func, u,
mesh::ParallelTreeMesh{2}, equations, dg::DGSEM, cache,
# We need to dispatch on `u::TrixiMPIArray` instead of `mesh::ParallelTreeMesh{2}`
# so that we can simply call the generic method on `parent(u)` instead of using
# some `invoke` call.
function integrate_via_indices(func::Func, u::TrixiMPIArray,
mesh::TreeMesh{2}, equations, dg::DGSEM, cache,
args...; normalize=true) where {Func}
# call the method accepting a general `mesh::TreeMesh{2}`
# TODO: MPI, we should improve this; maybe we should dispatch on `u`
# and create some MPI array type, overloading broadcasting and mapreduce etc.
# Then, this specific array type should also work well with DiffEq etc.
local_integral = invoke(integrate_via_indices,
Tuple{typeof(func), typeof(u), TreeMesh{2}, typeof(equations),
typeof(dg), typeof(cache), map(typeof, args)...},
func, u, mesh, equations, dg, cache, args..., normalize=normalize)
# Call the method for the local degrees of freedom and perform a global
# MPI reduction afterwards.
# Note that the simple `TreeMesh` implements an efficient way to compute
# the global volume without requiring communication. This global volume is
# already used when `normalize=true`.
local_integral = integrate_via_indices(func, parent(u), mesh, equations, dg,
cache, args...; normalize)

# OBS! Global results are only calculated on MPI root, all other domains receive `nothing`
global_integral = MPI.Reduce!(Ref(local_integral), +, mpi_root(), mpi_comm())
4 changes: 3 additions & 1 deletion src/callbacks_step/save_restart_dg.jl
@@ -84,7 +84,9 @@ function load_restart_file(mesh::Union{SerialTreeMesh, StructuredMesh, Unstructu
return u_ode
end


# Note that we cannot dispatch on `u::TrixiMPIArray` since we use
# `wrap_array_native` before calling this method, losing the MPI array
# wrapper type.
function save_restart_file(u, time, dt, timestep,
mesh::Union{ParallelTreeMesh, ParallelP4estMesh}, equations, dg::DG, cache,
restart_callback)
4 changes: 3 additions & 1 deletion src/callbacks_step/save_solution_dg.jl
@@ -73,7 +73,9 @@ function save_solution_file(u, time, dt, timestep,
return filename
end


# Note that we cannot dispatch on `u::TrixiMPIArray` since we use
# `wrap_array_native` before calling this method, losing the MPI array
# wrapper type.
function save_solution_file(u, time, dt, timestep,
mesh::Union{ParallelTreeMesh, ParallelP4estMesh}, equations, dg::DG, cache,
solution_callback, element_variables=Dict{Symbol,Any}();