Skip to content

Commit

Permalink
Use Base.min / Base.max in MPI reductions (#2054)
Browse files Browse the repository at this point in the history
* use Base.min/max in MPI.Allreduce

MPI.jl's reduce currently does not work for custom operators (such as Trixi's
min/max) on ARM

* add comments

* explain workdaround

* typo

* Apply suggestions from code review

Co-authored-by: Hendrik Ranocha <[email protected]>

* switch to macos-latest in mpi tests

* remove arch specification for macos-latest

macos-latest is 14, which is ARM

* readd arch, required by julia-actions/setup-julia

* back to macos-13 and x64

---------

Co-authored-by: Hendrik Ranocha <[email protected]>
  • Loading branch information
benegee and ranocha authored Sep 13, 2024
1 parent 56d5420 commit 148dd67
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 13 deletions.
5 changes: 5 additions & 0 deletions src/auxiliary/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ end
# when using `@fastmath`, which we also get from
# [Fortran](https://godbolt.org/z/Yrsa1js7P)
# or [C++](https://godbolt.org/z/674G7Pccv).
#
# Note however that such a custom reimplementation can cause incompatibilities with other
# packages. Currently we are affected by an issue with MPI.jl on ARM, see
# https://github.com/trixi-framework/Trixi.jl/issues/1922
# The workaround is to resort to Base.min / Base.max when using MPI reductions.
"""
Trixi.max(x, y, ...)
Expand Down
3 changes: 2 additions & 1 deletion src/callbacks_step/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ function (analysis_callback::AnalysisCallback)(io, du, u, u_ode, t, semi)
res = maximum(abs, view(du, v, ..))
if mpi_isparallel()
# TODO: Debugging, here is a type instability
global_res = MPI.Reduce!(Ref(res), max, mpi_root(), mpi_comm())
# Base.max instead of max needed, see comment in src/auxiliary/math.jl
global_res = MPI.Reduce!(Ref(res), Base.max, mpi_root(), mpi_comm())
if mpi_isroot()
res::eltype(du) = global_res[]
end
Expand Down
3 changes: 2 additions & 1 deletion src/callbacks_step/analysis_dg2d_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ function calc_error_norms(func, u, t, analyzer,
global_l2_error = Vector(l2_error)
global_linf_error = Vector(linf_error)
MPI.Reduce!(global_l2_error, +, mpi_root(), mpi_comm())
MPI.Reduce!(global_linf_error, max, mpi_root(), mpi_comm())
# Base.max instead of max needed, see comment in src/auxiliary/math.jl
MPI.Reduce!(global_linf_error, Base.max, mpi_root(), mpi_comm())
total_volume = MPI.Reduce(volume, +, mpi_root(), mpi_comm())
if mpi_isroot()
l2_error = convert(typeof(l2_error), global_l2_error)
Expand Down
3 changes: 2 additions & 1 deletion src/callbacks_step/analysis_dg3d_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ function calc_error_norms(func, u, t, analyzer,
global_l2_error = Vector(l2_error)
global_linf_error = Vector(linf_error)
MPI.Reduce!(global_l2_error, +, mpi_root(), mpi_comm())
MPI.Reduce!(global_linf_error, max, mpi_root(), mpi_comm())
# Base.max instead of max needed, see comment in src/auxiliary/math.jl
MPI.Reduce!(global_linf_error, Base.max, mpi_root(), mpi_comm())
total_volume = MPI.Reduce(volume, +, mpi_root(), mpi_comm())
if mpi_isroot()
l2_error = convert(typeof(l2_error), global_l2_error)
Expand Down
18 changes: 12 additions & 6 deletions src/callbacks_step/stepsize_dg2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ function max_dt(u, t, mesh::ParallelTreeMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -70,7 +71,8 @@ function max_dt(u, t, mesh::ParallelTreeMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand Down Expand Up @@ -154,7 +156,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -170,7 +173,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -186,7 +190,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -202,7 +207,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand Down
12 changes: 8 additions & 4 deletions src/callbacks_step/stepsize_dg3d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{3},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -146,7 +147,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{3},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -162,7 +164,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{3},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -178,7 +181,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{3},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min instead of min needed, see comment in src/auxiliary/math.jl
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand Down

0 comments on commit 148dd67

Please sign in to comment.