diff --git a/deps/gen_consts.jl b/deps/gen_consts.jl index f2db4680b..550eb1d37 100644 --- a/deps/gen_consts.jl +++ b/deps/gen_consts.jl @@ -78,6 +78,9 @@ MPI_handle = [ :MPI_Request => [ :MPI_REQUEST_NULL, ], + :MPI_Message => [ + :MPI_MESSAGE_NULL, + ], :MPI_Op => MPI_op_consts, :MPI_Datatype => MPI_datatype_consts ] diff --git a/src/pointtopoint.jl b/src/pointtopoint.jl index 4b350da10..2069d19fe 100644 --- a/src/pointtopoint.jl +++ b/src/pointtopoint.jl @@ -55,8 +55,12 @@ Get_error(status::Status) = Int(status.error) @mpi_handle Request buffer +@mpi_handle Message + const REQUEST_NULL = _Request(MPI_REQUEST_NULL, nothing) Request() = Request(REQUEST_NULL.val, nothing) +const MESSAGE_NULL = _Message(MPI_MESSAGE_NULL) +Message() = Message(MESSAGE_NULL.val) function Probe(src::Integer, tag::Integer, comm::Comm) stat_ref = Ref{Status}() @@ -78,6 +82,34 @@ function Iprobe(src::Integer, tag::Integer, comm::Comm) true, stat_ref[] end +function Improbe(src::Integer, tag::Integer, comm::Comm) + flag = Ref{Cint}() + stat_ref = Ref{Status}() + message = Message() + @mpichk ccall((:MPI_Improbe, libmpi), Cint, + (Cint, Cint, MPI_Comm, Ptr{Cint}, Ptr{MPI_Message}, Ptr{Status}), + src, tag, comm, flag, message, stat_ref) + if flag[] == 0 + return false, nothing, nothing + end + true, message, stat_ref[] +end + +""" + Mprobe(src::Integer, tag::Integer, comm::Comm) where T + +Blocking matched probe for a message on communicator `comm` using the message + tag `tag` +""" +function Mprobe(src::Integer, tag::Integer, comm::Comm) + stat_ref = Ref{Status}() + message = Message() + @mpichk ccall((:MPI_Mprobe, libmpi), Cint, + (Cint, Cint, MPI_Comm, Ptr{MPI_Message}, Ptr{Status}), + src, tag, comm, message, stat_ref) + message, stat_ref[] +end + function Get_count(stat::Status, ::Type{T}) where T count = Ref{Cint}() @mpichk ccall((:MPI_Get_count, libmpi), Cint, @@ -275,13 +307,31 @@ function Recv(::Type{T}, src::Integer, tag::Integer, comm::Comm) where T end function recv(src::Integer, tag::Integer, comm::Comm) - stat = Probe(src, tag, comm) + mess, stat = Mprobe(src, tag, comm) count = Get_count(stat, UInt8) buf = Array{UInt8}(undef, count) - stat = Recv!(buf, Get_source(stat), Get_tag(stat), comm) + stat = Mrecv!(buf, mess) (MPI.deserialize(buf), stat) end +""" + Mrecv!(buf::MPIBuffertype{T}, message::Message) where T + +Starts a blocking receive of a message specified by `message`, obtained by +Mprobe. + +Returns the communication `Request` for the nonblocking receive. +""" +function Mrecv!(buf::MPIBuffertype{T}, message::Message) where T + stat_ref = Ref{Status}() + # int MPI_Irecv(void* buf, int count, MPI_Datatype datatype, int source, + # int tag, MPI_Comm comm, MPI_Request *request) + @mpichk ccall((:MPI_Mrecv, libmpi), Cint, + (Ptr{T}, Cint, MPI_Datatype, Ptr{MPI_Message}, Ptr{Status}), + buf, length(buf), mpitype(T), message, stat_ref) + stat_ref[] +end + """ Irecv!(buf::MPIBuffertype{T}, count::Integer, datatype::Datatype, src::Integer, tag::Integer, comm::Comm) where T @@ -313,7 +363,7 @@ from MPI rank `src` of communicator `comm` using with the message tag `tag` Returns the communication `Request` for the nonblocking receive. """ function Irecv!(buf, count::Integer, - src::Integer, tag::Integer, comm::Comm) + src::Integer, tag::Integer, comm::Comm) Irecv!(buf, count, mpitype(eltype(buf)), src, tag, comm) end @@ -331,13 +381,13 @@ function Irecv!(buf::AbstractArray{T}, src::Integer, tag::Integer, end function irecv(src::Integer, tag::Integer, comm::Comm) - (flag, stat) = Iprobe(src, tag, comm) + (flag, mess, stat) = Improbe(src, tag, comm) if !flag return (false, nothing, nothing) end count = Get_count(stat, UInt8) buf = Array{UInt8}(undef, count) - stat = Recv!(buf, Get_source(stat), Get_tag(stat), comm) + stat = MRecv!(buf, mess) (true, MPI.deserialize(buf), stat) end