Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add matched recv functions #699

Merged
merged 5 commits into from
Jan 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion docs/src/reference/pointtopoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,21 @@ MPI.Iprobe
MPI.Probe
```

### Persistent requests
## Persistent requests

```@docs
MPI.Send_init
MPI.Recv_init
MPI.Start
MPI.Startall
```

## Matching probes and receives

```@docs
MPI.Message
MPI.Mprobe
MPI.Improbe
MPI.Mrecv!
MPI.Imrecv!
```
135 changes: 133 additions & 2 deletions src/pointtopoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ Returns the deserialized object and optionally the [`Status`](@ref) of the recei
recv(comm::Comm, status=nothing; source::Integer=API.MPI_ANY_SOURCE[], tag::Integer=API.MPI_ANY_TAG[]) =
recv(source, tag, comm, status)
function recv(source::Integer, tag::Integer, comm::Comm, status::Union{Ref{Status}, Nothing})
stat = Probe(comm, Status; source=source, tag=tag)
msg, stat = Mprobe(comm, Status; source=source, tag=tag)
count = Get_count(stat, UInt8)
buf = Array{UInt8}(undef, count)
stat = Recv!(buf, comm, status; source=Get_source(stat), tag=Get_tag(stat))
Mrecv!(buf, msg, status)
return MPI.deserialize(buf)
end
function recv(source::Integer, tag::Integer, comm::Comm, ::Type{Status})
Expand Down Expand Up @@ -319,3 +319,134 @@ function Startall(reqs::Union{AbstractMultiRequest, RequestSet})
update!(reqs)
return nothing
end


"""
MPI.Message

An MPI message handle object, used by matched receive operations. These are
returned by [`MPI.Mprobe`](@ref) and [`MPI.Improbe`](@ref) operations, and must
be received by either [`MPI.Mrecv!`](@ref) or [`MPI.Imrecv!`](@ref).
"""
mutable struct Message
val::MPI_Message
end
Base.unsafe_convert(::Type{Ptr{MPI_Message}}, msg::Message) = convert(Ptr{MPI_Message}, pointer_from_objref(msg))

Message() = Message(API.MPI_MESSAGE_NULL[])
isnull(msg::Message) = msg.val == API.MPI_MESSAGE_NULL[]

"""
ismsg, msg = MPI.Improbe(comm::MPI.Comm;
source::Integer=MPI.ANY_SOURCE, tag::Integer=MPI.ANY_TAG)
ismsg, msg, status = MPI.Improbe(comm::MPI.Comm, MPI.Status;
source::Integer=MPI.ANY_SOURCE, tag::Integer=MPI.ANY_TAG)

Matching non-blocking probe. Similar to [`MPI.Iprobe`](@ref), except that it
also returns `msg`, an [`MPI.Message`](@ref) object.

Checks if there is a message that can be received matching `source`, `tag` and
`comm`. If so, returns `ismsg = true`, and a [`Message`](@ref) objec `msg`,
which must be received by either [`MPI.Mrecv!`](@ref) or [`MPI.Imrecv!`](@ref).
Otherwise `msg` is set to be a null `Message`.

The `Status` argument additionally returns the [`Status`](@ref) of the completed
request.

# External links
$(_doc_external("MPI_Improbe"))
"""
Improbe(comm::Comm, status=nothing; source::Integer=API.MPI_ANY_SOURCE[], tag::Integer=API.MPI_ANY_TAG[]) =
Improbe(source, tag, comm, status)
function Improbe(source::Integer, tag::Integer, comm::Comm, status::Union{Ref{Status}, Nothing})
flag = Ref{Cint}()
msg = Message()
API.MPI_Improbe(source, tag, comm, flag, msg, something(status, API.MPI_STATUS_IGNORE[]))
return flag[] != 0, msg
end
function Improbe(source::Integer, tag::Integer, comm::Comm, ::Type{Status})
status = Ref(STATUS_ZERO)
ismsg, msg = Improbe(source, tag, comm, status)
return ismsg, msg, status[]
end

"""
msg = MPI.Mprobe(comm::MPI.Comm;
source::Integer=MPI.ANY_SOURCE, tag::Integer=MPI.ANY_TAG)
msg, status = MPI.Mprobe(comm::MPI.Comm, MPI.Status;
source::Integer=MPI.ANY_SOURCE, tag::Integer=MPI.ANY_TAG)

Matching blocking probe. Similar to [`MPI.Probe`](@ref), except that it also
returns `msg`, an [`MPI.Message`](@ref) object.

Blocks until a message that can be received matching `source`, `tag` and `comm`,
returning a [`Message`](@ref) objec `msg`, which must be received by either
[`MPI.Mrecv!`](@ref) or [`MPI.Imrecv!`](@ref).

The `Status` argument additionally returns the [`Status`](@ref) of the completed
request.

# External links
$(_doc_external("MPI_Mprobe"))
"""
Mprobe(comm::Comm, status=nothing; source::Integer=API.MPI_ANY_SOURCE[], tag::Integer=API.MPI_ANY_TAG[]) =
Mprobe(source, tag, comm, status)
function Mprobe(source::Integer, tag::Integer, comm::Comm, status::Union{Ref{Status}, Nothing})
msg = Message()
API.MPI_Mprobe(source, tag, comm, msg, something(status, API.MPI_STATUS_IGNORE[]))
return msg
end
function Mprobe(source::Integer, tag::Integer, comm::Comm, ::Type{Status})
status = Ref(STATUS_ZERO)
msg = Mprobe(source, tag, comm, status)
return msg, status[]
end

"""
data = MPI.Mrecv!(recvbuf, msg::MPI.Message)
data, status = MPI.Mrecv!(recvbuf, msg::MPI.Message, MPI.Status)

Completes a blocking receive matched by a matching probe operation into the
buffer `recvbuf`, and the [`Message`](@ref) `msg`.

`recvbuf` can be a [`Buffer`](@ref), or any object for which `Buffer(recvbuf)`
is defined.

Optionally returns the [`Status`](@ref) object of the receive.

# External links
$(_doc_external("MPI_Mrecv"))
"""
function Mrecv!(recvbuf::Buffer, msg::Message, status::Union{Ref{Status},Nothing}=nothing)
API.MPI_Mrecv(recvbuf.data, recvbuf.count, recvbuf.datatype, msg, something(status, API.MPI_STATUS_IGNORE[]))
return recvbuf.data
end
Mrecv!(recvbuf, msg::Message, status::Union{Ref{Status},Nothing}=nothing) =
Mrecv!(Buffer(recvbuf), msg, status)
function Mrecv!(recvbuf,msg::Message, ::Type{Status})
status = Ref(STATUS_ZERO)
data = Mrecv!(recvbuf, msg, status)
return data, status[]
end

"""
req = MPI.Imrecv!(recvbuf, msg::MPI.Message[, req::AbstractRequest=Request()])

Starts a nonblocking receive matched by a matching probe operation into the
buffer `recvbuf`, and the [`Message`](@ref) `msg`.

`recvbuf` can be a [`Buffer`](@ref), or any object for which `Buffer(recvbuf)` is defined.

Returns `req`, an [`AbstractRequest`](@ref) object for the nonblocking receive.

# External links
$(_doc_external("MPI_Imrecv"))
"""
function Imrecv!(buf::Buffer, msg::Message, req::AbstractRequest=Request())
@assert isnull(req)
API.MPI_Imrecv(buf.data, buf.count, buf.datatype, msg, req)
setbuffer!(req, buf)
return req
end
Imrecv!(data, msg::Message, req::AbstractRequest=Request()) =
Imrecv!(Buffer(data), msg, req)
63 changes: 63 additions & 0 deletions test/test_matched.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
include("common.jl")

MPI.Init()

comm = MPI.COMM_WORLD
size = MPI.Comm_size(comm)
rank = MPI.Comm_rank(comm)

dst = mod(rank+1, size)
src = mod(rank-1, size)

N = 32

sendbuf1 = ArrayType(Float64[1 for i = 1:N])
sendbuf2 = ArrayType(Float64[2 for i = 1:N+2])
sendbuf3 = ArrayType(Float64[3 for i = 1:N])
sendbuf4 = ArrayType(Float64[3 for i = 1:N+2])

synchronize()

sendreq1 = MPI.Isend(sendbuf1, comm; dest=dst, tag=rank+32)
sendreq2 = MPI.Isend(sendbuf2, comm; dest=dst, tag=rank+32)
sendreq3 = MPI.Isend(sendbuf3, comm; dest=dst, tag=rank+32)
sendreq4 = MPI.Isend(sendbuf4, comm; dest=dst, tag=rank+32)

recvmsg1 = MPI.Mprobe(comm; source=src)
while true
global recvmsg2
ismsg, recvmsg2 = MPI.Improbe(comm; source=src)
if ismsg
break
end
end
recvmsg3, status3 = MPI.Mprobe(comm, MPI.Status; source=src, tag=src+32)
@test MPI.Get_count(status3, Float64) == N

while true
global status4, recvmsg4
ismsg, recvmsg4, status4 = MPI.Improbe(comm, MPI.Status; source=src)
if ismsg
break
end
end
@test MPI.Get_count(status4, Float64) == N + 2

recvbuf1 = ArrayType{Float64}(undef, N)
recvbuf2 = ArrayType{Float64}(undef, N+2)
recvbuf3 = ArrayType{Float64}(undef, N)
recvbuf4 = ArrayType{Float64}(undef, N+2)

MPI.Mrecv!(recvbuf4, recvmsg4)
MPI.Mrecv!(recvbuf3, recvmsg3)
recvreq2 = MPI.Imrecv!(recvbuf2, recvmsg2)
recvreq1 = MPI.Imrecv!(recvbuf1, recvmsg1)

MPI.Waitall([sendreq1, sendreq2, sendreq3, sendreq4, recvreq1, recvreq2])

@test recvbuf1 == sendbuf1
@test recvbuf2 == sendbuf2
@test recvbuf3 == sendbuf3
@test recvbuf4 == sendbuf4

MPI.Finalize()