Skip to content

Commit

Permalink
Feature support testing for IAllreduce and IReduce for GPU backends
Browse files Browse the repository at this point in the history
  • Loading branch information
Keluaa committed Nov 18, 2024
1 parent 9620a4b commit 311629f
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 15 deletions.
28 changes: 28 additions & 0 deletions test/mpi_support_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
include("common.jl")

MPI.Init()

# Those MPI calls may be unsupported features (e.g. for GPU backends), and will raise SIGSEGV
# (or a similar signal) when called, which cannot be handled in Julia in a portable way.
# This script is therefore meant to be run in a separate (MPI) process, whose exit status
# tells the caller whether the operation named in ARGS[1] is usable.

op = ARGS[1]
if op == "IAllreduce"
# IAllreduce is unsupported for CUDA with OpenMPI + UCX
# See https://docs.open-mpi.org/en/main/tuning-apps/networking/cuda.html#which-mpi-apis-do-not-work-with-cuda-aware-ucx
send_arr = ArrayType(zeros(Int, 1))
recv_arr = ArrayType{Int}(undef, 1)
synchronize()
req = MPI.IAllreduce!(send_arr, recv_arr, +, MPI.COMM_WORLD)
MPI.Wait(req)

elseif op == "IReduce"
# IReduce is unsupported for CUDA with OpenMPI + UCX (same limitation as IAllreduce)
# See https://docs.open-mpi.org/en/main/tuning-apps/networking/cuda.html#which-mpi-apis-do-not-work-with-cuda-aware-ucx
send_arr = ArrayType(zeros(Int, 1))
recv_arr = ArrayType{Int}(undef, 1)
synchronize()
req = MPI.IReduce!(send_arr, recv_arr, +, MPI.COMM_WORLD; root=0)
MPI.Wait(req)

else
error("unknown test: $op")
end
13 changes: 13 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ if Sys.isunix()
include("mpiexecjl.jl")
end

# Probe whether the MPI operation named `mpi_op` is usable with the current backend by
# running `mpi_support_test.jl` in a fresh `mpiexec` launch with `n` ranks. Unsupported
# operations may die with a signal (e.g. SIGSEGV), which can only be detected portably
# through the exit status of a subprocess. Returns `true` when the probe succeeds;
# otherwise warns and returns `false`.
function is_mpi_operation_supported(mpi_op, n=nprocs)
    script = joinpath(@__DIR__, "mpi_support_test.jl")
    probe = `$(mpiexec()) -n $n $(Base.julia_cmd()) --startup-file=no $script $mpi_op`
    ok = success(run(ignorestatus(probe)))
    if !ok
        @warn "$mpi_op is unsupported with $backend_name"
    end
    return ok
end

if ArrayType != Array # we expect that only GPU backends can have unsupported features
# `ENV` entries are stored as strings, so the `Bool` results become "true"/"false";
# the test files read them back with `get(ENV, ..., "true") == "true"`.
ENV["JULIA_MPI_TEST_IALLREDUCE"] = is_mpi_operation_supported("IAllreduce")
ENV["JULIA_MPI_TEST_IREDUCE"] = is_mpi_operation_supported("IReduce")
end

excludefiles = split(get(ENV,"JULIA_MPI_TEST_EXCLUDE",""),',')

testdir = @__DIR__
Expand Down
19 changes: 13 additions & 6 deletions test/test_allreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ else
operators = [MPI.SUM, +, (x,y) -> 2x+y-x]
end

# Whether IAllreduce works on this backend; set by runtests.jl when probing GPU
# backends. Defaults to supported when the variable is absent (e.g. CPU runs).
iallreduce_supported = get(ENV, "JULIA_MPI_TEST_IALLREDUCE", "true") == "true"


for T = [Int]
for dims = [1, 2, 3]
send_arr = ArrayType(zeros(T, Tuple(3 for i in 1:dims)))
Expand Down Expand Up @@ -46,16 +49,20 @@ for T = [Int]

# Nonblocking
recv_arr = ArrayType{T}(undef, size(send_arr))
req = MPI.IAllreduce!(send_arr, recv_arr, op, MPI.COMM_WORLD)
MPI.Wait(req)
@test recv_arr == comm_size .* send_arr
if iallreduce_supported
req = MPI.IAllreduce!(send_arr, recv_arr, op, MPI.COMM_WORLD)
MPI.Wait(req)
end
@test recv_arr == comm_size .* send_arr skip=!iallreduce_supported

# Nonblocking (IN_PLACE)
recv_arr = copy(send_arr)
synchronize()
req = MPI.IAllreduce!(recv_arr, op, MPI.COMM_WORLD)
MPI.Wait(req)
@test recv_arr == comm_size .* send_arr
if iallreduce_supported
req = MPI.IAllreduce!(recv_arr, op, MPI.COMM_WORLD)
MPI.Wait(req)
end
@test recv_arr == comm_size .* send_arr skip=!iallreduce_supported
end
end
end
Expand Down
26 changes: 17 additions & 9 deletions test/test_reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ const can_do_closures =
Sys.ARCH !== :aarch64 &&
!startswith(string(Sys.ARCH), "arm")

# Whether IReduce works on this backend; set by runtests.jl when probing GPU
# backends. Defaults to supported when the variable is absent (e.g. CPU runs).
ireduce_supported = get(ENV, "JULIA_MPI_TEST_IREDUCE", "true") == "true"

using DoubleFloats

MPI.Init()
Expand Down Expand Up @@ -119,18 +121,22 @@ for T = [Int]

# Nonblocking
recv_arr = ArrayType{T}(undef, size(send_arr))
req = MPI.IReduce!(send_arr, recv_arr, op, MPI.COMM_WORLD; root=root)
MPI.Wait(req)
if ireduce_supported
req = MPI.IReduce!(send_arr, recv_arr, op, MPI.COMM_WORLD; root=root)
MPI.Wait(req)
end
if isroot
@test recv_arr == sz .* send_arr
@test recv_arr == sz .* send_arr skip=!ireduce_supported
end

# Nonblocking (IN_PLACE)
recv_arr = copy(send_arr)
req = MPI.IReduce!(recv_arr, op, MPI.COMM_WORLD; root=root)
MPI.Wait(req)
if ireduce_supported
req = MPI.IReduce!(recv_arr, op, MPI.COMM_WORLD; root=root)
MPI.Wait(req)
end
if isroot
@test recv_arr == sz .* send_arr
@test recv_arr == sz .* send_arr skip=!ireduce_supported
end
end
end
Expand All @@ -148,10 +154,12 @@ else
end

recv_arr = isroot ? zeros(eltype(send_arr), size(send_arr)) : nothing
req = MPI.IReduce!(send_arr, recv_arr, +, MPI.COMM_WORLD; root=root)
MPI.Wait(req)
if ireduce_supported
req = MPI.IReduce!(send_arr, recv_arr, +, MPI.COMM_WORLD; root=root)
MPI.Wait(req)
end
if rank == root
@test recv_arr [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
@test recv_arr [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64) skip=!ireduce_supported
end

MPI.Barrier( MPI.COMM_WORLD )
Expand Down

0 comments on commit 311629f

Please sign in to comment.