Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Base.min / Base.max in MPI reductions #2054

Merged
merged 11 commits into from
Sep 13, 2024
5 changes: 5 additions & 0 deletions src/auxiliary/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ end
# when using `@fastmath`, which we also get from
# [Fortran](https://godbolt.org/z/Yrsa1js7P)
# or [C++](https://godbolt.org/z/674G7Pccv).
#
# Note however that such a custom reimplementation can cause incompatibilities with other
# packages. Currently we are affected by an issue with MPI.jl on ARM, see
# https://github.com/trixi-framework/Trixi.jl/issues/1922
# The workaround is to resort to Base.min / Base.max when using MPI reductions.
"""
Trixi.max(x, y, ...)

Expand Down
3 changes: 2 additions & 1 deletion src/callbacks_step/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ function (analysis_callback::AnalysisCallback)(io, du, u, u_ode, t, semi)
res = maximum(abs, view(du, v, ..))
if mpi_isparallel()
# TODO: Debugging, here is a type instability
global_res = MPI.Reduce!(Ref(res), max, mpi_root(), mpi_comm())
# Base.max needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
global_res = MPI.Reduce!(Ref(res), Base.max, mpi_root(), mpi_comm())
if mpi_isroot()
res::eltype(du) = global_res[]
end
Expand Down
3 changes: 2 additions & 1 deletion src/callbacks_step/analysis_dg2d_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ function calc_error_norms(func, u, t, analyzer,
global_l2_error = Vector(l2_error)
global_linf_error = Vector(linf_error)
MPI.Reduce!(global_l2_error, +, mpi_root(), mpi_comm())
MPI.Reduce!(global_linf_error, max, mpi_root(), mpi_comm())
# Base.max needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
MPI.Reduce!(global_linf_error, Base.max, mpi_root(), mpi_comm())
total_volume = MPI.Reduce(volume, +, mpi_root(), mpi_comm())
if mpi_isroot()
l2_error = convert(typeof(l2_error), global_l2_error)
Expand Down
3 changes: 2 additions & 1 deletion src/callbacks_step/analysis_dg3d_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ function calc_error_norms(func, u, t, analyzer,
global_l2_error = Vector(l2_error)
global_linf_error = Vector(linf_error)
MPI.Reduce!(global_l2_error, +, mpi_root(), mpi_comm())
MPI.Reduce!(global_linf_error, max, mpi_root(), mpi_comm())
# Base.max needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
MPI.Reduce!(global_linf_error, Base.max, mpi_root(), mpi_comm())
total_volume = MPI.Reduce(volume, +, mpi_root(), mpi_comm())
if mpi_isroot()
l2_error = convert(typeof(l2_error), global_l2_error)
Expand Down
18 changes: 12 additions & 6 deletions src/callbacks_step/stepsize_dg2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ function max_dt(u, t, mesh::ParallelTreeMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -70,7 +71,8 @@ function max_dt(u, t, mesh::ParallelTreeMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand Down Expand Up @@ -154,7 +156,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -170,7 +173,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -186,7 +190,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -202,7 +207,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{2},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand Down
12 changes: 8 additions & 4 deletions src/callbacks_step/stepsize_dg3d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{3},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -146,7 +147,8 @@ function max_dt(u, t, mesh::ParallelP4estMesh{3},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -162,7 +164,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{3},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand All @@ -178,7 +181,8 @@ function max_dt(u, t, mesh::ParallelT8codeMesh{3},
typeof(constant_speed), typeof(equations), typeof(dg),
typeof(cache)},
u, t, mesh, constant_speed, equations, dg, cache)
dt = MPI.Allreduce!(Ref(dt), min, mpi_comm())[]
# Base.min needed, see comment in src/auxiliary/math.jl
benegee marked this conversation as resolved.
Show resolved Hide resolved
dt = MPI.Allreduce!(Ref(dt), Base.min, mpi_comm())[]

return dt
end
Expand Down
Loading