Skip to content

Commit

Permalink
Implement automatically serializing scatter (#758)
Browse files Browse the repository at this point in the history
* Implement automatically serializing `scatter`
* scatter: optimize serialization
  • Loading branch information
lukas-weber authored Jul 30, 2023
1 parent 48dd78c commit bfaa7ca
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/src/reference/collective.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ MPI.Neighbor_allgatherv!
```@docs
MPI.Scatter!
MPI.Scatter
MPI.scatter
MPI.Scatterv!
```

Expand Down
46 changes: 46 additions & 0 deletions src/collective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,52 @@ Scatter(sendbuf, T, comm; root::Integer=Cint(0)) =
Scatter(sendbuf, ::Type{T}, root::Integer, comm::Comm) where {T} =
Scatter!(sendbuf, Ref{T}(), root, comm)[]

"""
scatter(objs::Union{AbstractVector, Nothing}, comm::Comm; root::Integer=0)
Sends the `j`-th element of `objs` in the `root` process to rank `j-1` and returns it. On `root`, `objs` is expected to be a `Comm_size(comm)`-element vector. On the other ranks, it is ignored and can be `nothing`.
This method can handle arbitrary data.
# See also
- [`Scatter!`](@ref)
"""
function scatter(objs::Union{AbstractVector, Nothing}, comm::Comm; root::Integer=0)
isroot = Comm_rank(comm) == root

if isroot
if length(objs) != Comm_size(comm)
throw(ArgumentError("Length of argument objs ($(length(objs))) != number of ranks in comm ($(Comm_size(comm)))."))
end

sendbuffer = IOBuffer()
counts = Vector{Int64}(undef, length(objs))

last_pos = 0
for (i, obj) in enumerate(objs)
Serialization.serialize(sendbuffer, i == root + 1 ? nothing : obj)
counts[i] = position(sendbuffer) - last_pos
last_pos = position(sendbuffer)
end

count = Scatter(counts, Int64, comm; root = root)
sendbuf = VBuffer(take!(sendbuffer), counts)

Scatterv!(sendbuf, IN_PLACE, comm; root = root)
return objs[root + 1]
else
count = Scatter(nothing, Int64, comm; root = root)

data = Array{UInt8}(undef, count)
recvbuf = Buffer(data)

Scatterv!(nothing, recvbuf, comm; root = root)
return MPI.deserialize(recvbuf.data)
end
end


"""
Scatterv!(sendbuf, recvbuf, comm::Comm; root::Integer=0)
Expand Down
17 changes: 17 additions & 0 deletions test/test_scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,31 @@ for T in MPITestTypes
end
@test Array(B)[1] == T(rank+1)

B = MPI.scatter(A, comm; root = root)
@test B == T(rank+1)

# Test throwing
if isroot
B = ArrayType{T}(undef, 0)
@test_throws DivideError MPI.Scatter!(A, B, comm; root=root)
B = ArrayType{T}(undef, 8)
@test_throws AssertionError MPI.Scatter!(A, B, comm; root=root)

wrong_length = ArrayType{T}(undef, size-1)
@test_throws ArgumentError MPI.scatter(wrong_length, comm; root=root)
end
end


objs = ["test", 1, Array{Int}, [1,"test"]]
objs_sized = [objs[mod1(i, length(objs))] for i = 1:size]

B = MPI.scatter(objs_sized, comm; root = root)
@test B == objs_sized[rank+1]
objs_gathered = MPI.gather(B, comm; root = root)
if isroot
@test objs_gathered == objs_sized
end

MPI.Finalize()
@test MPI.Finalized()

0 comments on commit bfaa7ca

Please sign in to comment.