diff --git a/docs/src/reference/pointtopoint.md b/docs/src/reference/pointtopoint.md index b54efcc9d..adff1172d 100644 --- a/docs/src/reference/pointtopoint.md +++ b/docs/src/reference/pointtopoint.md @@ -3,7 +3,11 @@ ## Types ```@docs +MPI.AbstractRequest MPI.Request +MPI.UnsafeRequest +MPI.MultiRequest +MPI.UnsafeMultiRequest MPI.RequestSet MPI.Status ``` @@ -59,6 +63,7 @@ MPI.Waitsome ### Probe/Cancel ```@docs +MPI.isnull MPI.Cancel! MPI.Iprobe MPI.Probe diff --git a/src/collective.jl b/src/collective.jl index ba535ced5..c546e7a6a 100644 --- a/src/collective.jl +++ b/src/collective.jl @@ -17,7 +17,7 @@ function Barrier(comm::Comm) end """ - Ibarrier(comm::Comm) + Ibarrier(comm::Comm[, req::AbstractRequest = Request()]) Blocks until `comm` is synchronized. @@ -28,8 +28,8 @@ If `comm` is an intercommunicator, then it blocks until all members of the other # External links $(_doc_external("MPI_Ibarrier")) """ -function Ibarrier(comm::Comm) - req = Request() +function Ibarrier(comm::Comm, req::AbstractRequest = Request()) + @assert isnull(req) # int MPI_Ibarrier(MPI_Comm comm, MPI_Req req) API.MPI_Ibarrier(comm, req) return req diff --git a/src/nonblocking.jl b/src/nonblocking.jl index bc1087019..88240e953 100644 --- a/src/nonblocking.jl +++ b/src/nonblocking.jl @@ -75,44 +75,282 @@ Get_tag(status::Status) = Int(status.tag) Get_error(status::Status) = Int(status.error) """ - MPI.Request + MPI.AbstractRequest -An MPI Request object, representing a non-blocking communication. This also contains a +An abstract type for Julia objects wrapping MPI Request objects, which +represent non-blocking MPI communication operations. The following +implementations are provided in MPI.jl: + +- [`Request`](@ref): this is the default request type. +- [`UnsafeRequest`](@ref): similar to `Request`, but does not maintain a + reference to the underlying communication buffer. 
+- `MultiRequestItem`: created by calling `getindex` on a [`MultiRequest`](@ref) + / [`UnsafeMultiRequest`](@ref) object, which efficiently stores a collection + of requests. + +# How request objects are used + +A request object can be passed to non-blocking communication operations, such as +[`MPI.Isend`](@ref) and [`MPI.Irecv!`](@ref). If no object is provided, then an +[`MPI.Request`](@ref) is used. + +The status of a Request can be checked by the [`Wait`](@ref) and [`Test`](@ref) +functions or their multiple-request variants, which will deallocate the request +once it is determined to be complete. + +Alternatively, it will be deallocated by calling `MPI.free` or at finalization, +meaning that it is safe to ignore the request objects if the status of the +communication can be checked by other means. + +In certain cases, the operation can also be cancelled by [`Cancel!`](@ref). + +# Implementing new request types + +Subtypes `R <: AbstractRequest` should define the methods for the following +functions: + +- C conversion functions to `MPI_Request` and `Ptr{MPI_Request}`: + - `Base.cconvert(::Type{MPI_Request}, req::R)` / + `Base.unsafe_convert(::Type{MPI_Request}, req::R)` + - `Base.cconvert(::Type{Ptr{MPI_Request}}, req::R)` / + `Base.unsafe_convert(::Type{Ptr{MPI_Request}}, req::R)` +- `setbuffer!(req::R, val)`: keep a reference to the communication buffer `val`. + If `val == nothing`, then clear the reference. +""" +abstract type AbstractRequest end + +""" + isnull(req::AbstractRequest) + +Is `req` a null request. 
+""" +function isnull(req::AbstractRequest) + creq = Base.cconvert(MPI_Request, req) + GC.@preserve creq Base.unsafe_convert(MPI_Request, creq) == API.MPI_REQUEST_NULL[] +end + +function Base.show(io::IO, req::AbstractRequest) + if get(io, :typeinfo, Any) != typeof(req) + print(io, typeof(req), ": ") + end + if isnull(req) + print(io, "null request") + else + ref_flag = Ref(Cint(0)) + ref_status = Ref(STATUS_ZERO) + API.MPI_Request_get_status(req, ref_flag, ref_status) + if ref_flag[] != 0 + status = ref_status[] + if status.source == API.MPI_ANY_SOURCE[] && status.tag == API.MPI_ANY_TAG[] && status.error == API.MPI_SUCCESS[] + print(io, "inactive request") + else + print(io, "completed request, source = ", status.source, ", tag = ", status.tag) + if status.error != API.MPI_SUCCESS[] + print(io, ", error = ", status.error) + end + end + else + print(io, "incomplete request") + end + end +end + +function free(req::AbstractRequest) + if !isnull(req) && !MPI.Finalized() + # int MPI_Request_free(MPI_Request *req) + API.MPI_Request_free(req) + end + setbuffer!(req, nothing) + return nothing +end + + +""" + MPI.Request() + +The default MPI Request object, representing a non-blocking communication. This also contains a reference to the buffer used in the communication to ensure it isn't garbage-collected during communication. -The status of a Request can be checked by the [`Wait`](@ref) and [`Test`](@ref) functions -or their multiple-request variants, which will deallocate the request once it is -determined to be complete. Alternatively, it will be deallocated at finalization, meaning -that it is safe to ignore the request objects if the status of the communication can be -checked by other means. - -See also [`Cancel!`](@ref). +See [`AbstractRequest`](@ref) for more information. 
""" -mutable struct Request +mutable struct Request <: AbstractRequest val::MPI_Request buffer end -Base.:(==)(a::Request, b::Request) = a.val == b.val +function Request() + req = Request(API.MPI_REQUEST_NULL[], nothing) + return finalizer(free, req) +end + +setbuffer!(req::Request, val) = req.buffer = val + Base.cconvert(::Type{MPI_Request}, request::Request) = request Base.unsafe_convert(::Type{MPI_Request}, request::Request) = request.val Base.unsafe_convert(::Type{Ptr{MPI_Request}}, request::Request) = convert(Ptr{MPI_Request}, pointer_from_objref(request)) + const REQUEST_NULL = Request(API.MPI_REQUEST_NULL[], nothing) add_load_time_hook!(() -> REQUEST_NULL.val = API.MPI_REQUEST_NULL[]) -Request() = Request(REQUEST_NULL.val, nothing) -isnull(req::Request) = req == REQUEST_NULL +""" + MPI.UnsafeRequest() -function free(req::Request) - if req != REQUEST_NULL && !MPI.Finalized() - # int MPI_Request_free(MPI_Request *req) - API.MPI_Request_free(req) +Similar to [`MPI.Request`](@ref), but does not maintain a reference to the +underlying communication buffer. This may have improve performance by reducing +memory allocations. + +!!! warning + + The user should ensure that another reference to the communication buffer is + maintained so that it is not cleaned up by the garbage collector + before the communication operation is complete. + + For example + ```julia + buf = MPI.Buffer(zeros(10)) + GC.@preserve buf begin + req = MPI.Isend(buf, comm, UnsafeRequest(); rank=1) + # ... 
+ MPI.Wait(req) +    end +    ``` +""" +mutable struct UnsafeRequest <: AbstractRequest + val::MPI_Request +end + +function UnsafeRequest() + req = UnsafeRequest(API.MPI_REQUEST_NULL[]) + return finalizer(free, req) +end +setbuffer!(req::UnsafeRequest, val) = nothing + +Base.cconvert(::Type{MPI_Request}, request::UnsafeRequest) = request +Base.unsafe_convert(::Type{MPI_Request}, request::UnsafeRequest) = request.val +Base.unsafe_convert(::Type{Ptr{MPI_Request}}, request::UnsafeRequest) = convert(Ptr{MPI_Request}, pointer_from_objref(request)) + + +# abstract element type to work around lack of cyclic type definitions +# https://github.com/JuliaLang/julia/issues/269 +abstract type AbstractMultiRequest <: AbstractVector{AbstractRequest} +end + + +""" + MPI.MultiRequest(n::Integer=0) + +A collection of MPI Requests. This is useful when operating on multiple MPI +requests at the same time. `MultiRequest` objects can be passed directly to +[`MPI.Waitall`](@ref), [`MPI.Testall`](@ref), etc. + +`req[i]` will return a `MultiRequestItem` which +adheres to the [`AbstractRequest`](@ref) interface. + +# Usage +```julia +reqs = MPI.MultiRequest(n) +for i = 1:n + MPI.Isend(buf, comm, reqs[i]; dest=dest[i]) +end +MPI.Waitall(reqs) +``` +""" +struct MultiRequest <: AbstractMultiRequest + vals::Vector{MPI_Request} + buffers::Vector{Any} +end +MultiRequest(n::Integer=0) = + MultiRequest(MPI_Request[API.MPI_REQUEST_NULL[] for _ = 1:n], Any[nothing for _ = 1:n]) + +function update!(reqs::MultiRequest, i::Integer) + if isnull(reqs[i]) + setbuffer!(reqs[i], nothing) + end + return nothing +end +function update!(reqs::MultiRequest) + foreach(i -> update!(reqs, i), 1:length(reqs)) + return nothing +end + + +""" + MPI.UnsafeMultiRequest(n::Integer=0) + +Similar to [`MPI.MultiRequest`](@ref), except that it does not maintain +references to the underlying communication buffers. The same caveats apply as +[`MPI.UnsafeRequest`](@ref). 
+""" +struct UnsafeMultiRequest <: AbstractMultiRequest + vals::Vector{MPI_Request} +end +UnsafeMultiRequest(n::Integer=0) = + UnsafeMultiRequest(MPI_Request[API.MPI_REQUEST_NULL[] for _ = 1:n]) +update!(reqs::UnsafeMultiRequest, i::Integer) = nothing +update!(reqs::UnsafeMultiRequest) = nothing + +struct MultiRequestItem{MR <: AbstractMultiRequest} <: AbstractRequest + multireq::MR + idx::Int +end + +Base.eltype(::Type{MR}) where {MR<:AbstractMultiRequest} = MultiRequestItem{MR} +Base.length(mreq::AbstractMultiRequest) = length(mreq.vals) +Base.size(mreq::AbstractMultiRequest) = (length(mreq),) +Base.@propagate_inbounds function Base.getindex(mreq::AbstractMultiRequest,i::Integer) + @boundscheck checkbounds(mreq.vals,i) + MultiRequestItem(mreq, i) +end + +function free(mreq::AbstractMultiRequest) + for req in mreq + free(req) end - req.buffer = nothing return nothing end +Base.cconvert(::Type{MPI_Request}, req::MultiRequestItem) = req +Base.unsafe_convert(::Type{MPI_Request}, req::MultiRequestItem) = @inbounds req.multireq.vals[req.idx] +Base.unsafe_convert(::Type{Ptr{MPI_Request}}, req::MultiRequestItem) = + convert(Ptr{MPI_Request}, pointer(req.multireq.vals, req.idx)) + +setbuffer!(req::MultiRequestItem{MultiRequest}, val) = + @inbounds req.multireq.buffers[req.idx] = val +setbuffer!(req::MultiRequestItem{UnsafeMultiRequest}, val) = + nothing + + +function Base.resize!(mreq::MultiRequest, n::Integer) + m = length(mreq) + # free any requests being removed + for i = n+1:m + free(mreq[i]) + end + resize!(mreq.vals, n) + resize!(mreq.buffers, n) + for i = m+1:n + # initialize + req = mreq[i] + req.val = API.MPI_REQUEST_NULL[] + req.buffer = nothing + end + return mreq +end +function Base.resize!(mreq::UnsafeMultiRequest, n::Integer) + m = length(mreq) + # free any requests being removed + for i = n+1:m + free(mreq[i]) + end + resize!(mreq.vals, n) + for i = m+1:n + # initialize + req = mreq[i] + req.val = API.MPI_REQUEST_NULL[] + end + return mreq +end + """ 
Probe(comm::Comm; @@ -194,8 +432,8 @@ Get_count(stat::Status, ::Type{T}) where {T} = Get_count(stat, Datatype(T)) """ - Wait(req::Request) - status = Wait(req::Request, Status) + Wait(req::AbstractRequest) + status = Wait(req::AbstractRequest, Status) Block until the request `req` is complete and deallocated. @@ -204,22 +442,24 @@ The `Status` argument returns the [`Status`](@ref) of the completed request. # External links $(_doc_external("MPI_Wait")) """ -function Wait(req::Request, status::Union{Ref{Status}, Nothing}=nothing) +function Wait(req::AbstractRequest, status::Union{Ref{Status}, Nothing}=nothing) # int MPI_Wait(MPI_Request *request, MPI_Status *status) API.MPI_Wait(req, something(status, API.MPI_STATUS_IGNORE[])) # only clear the buffer for non-persistent requests - isnull(req) && (req.buffer = nothing) + if isnull(req) + setbuffer!(req, nothing) + end return nothing end -function Wait(req::Request, ::Type{Status}) +function Wait(req::AbstractRequest, ::Type{Status}) status = Ref(STATUS_ZERO) Wait(req, status) return status[] end """ - flag = Test(req::Request) - flag, status = Test(req::Request, Status) + flag = Test(req::AbstractRequest) + flag, status = Test(req::AbstractRequest, Status) Check if the request `req` is complete. If so, the request is deallocated and `flag = true` is returned. Otherwise `flag = false`. 
@@ -228,14 +468,16 @@ The `Status` argument additionally returns the [`Status`](@ref) of the completed # External links $(_doc_external("MPI_Test")) """ -function Test(req::Request, status::Union{Ref{Status}, Nothing}=nothing) +function Test(req::AbstractRequest, status::Union{Ref{Status}, Nothing}=nothing) flag = Ref{Cint}() # int MPI_Test(MPI_Request *request, int *flag, MPI_Status *status) API.MPI_Test(req, flag, something(status, API.MPI_STATUS_IGNORE[])) - isnull(req) && (req.buffer = nothing) + if isnull(req) + setbuffer!(req, nothing) + end return flag[] != 0 end -function Test(req::Request, ::Type{Status}) +function Test(req::AbstractRequest, ::Type{Status}) status = Ref(STATUS_ZERO) flag = Test(req, status) return flag, status[] @@ -249,6 +491,8 @@ end A wrapper for an array of `Request`s that can be used to reduce intermediate memory allocations in [`Waitall`](@ref), [`Testall`](@ref), [`Waitany`](@ref), [`Testany`](@ref), [`Waitsome`](@ref) or [`Testsome`](@ref). + +Consider using a [`MultiRequest`](@ref) or [`UnsafeMultiRequest`](@ref) instead. 
""" struct RequestSet <: AbstractVector{Request} requests::Vector{Request} @@ -300,7 +544,7 @@ The optional `statuses` or `Status` argument can be used to obtain the return # External links $(_doc_external("MPI_Waitall")) """ -function Waitall(reqs::RequestSet, statuses::Union{AbstractVector{Status},Nothing}=nothing) +function Waitall(reqs::Union{AbstractMultiRequest, RequestSet}, statuses::Union{AbstractVector{Status},Nothing}=nothing) n = length(reqs) n == 0 && return nothing @assert isnothing(statuses) || length(statuses) >= n @@ -310,7 +554,7 @@ function Waitall(reqs::RequestSet, statuses::Union{AbstractVector{Status},Nothin update!(reqs) return nothing end -function Waitall(reqs::RequestSet, ::Type{Status}) +function Waitall(reqs::Union{AbstractMultiRequest, RequestSet}, ::Type{Status}) statuses = Array{Status}(undef, length(reqs)) Waitall(reqs, statuses) return statuses @@ -335,7 +579,7 @@ The optional `statuses` or `Status` argument can be used to obtain the return # External links $(_doc_external("MPI_Testall")) """ -function Testall(reqs::RequestSet, statuses::Union{AbstractVector{Status},Nothing}=nothing) +function Testall(reqs::Union{AbstractMultiRequest, RequestSet}, statuses::Union{AbstractVector{Status},Nothing}=nothing) n = length(reqs) flag = Ref{Cint}() @assert isnothing(statuses) || length(statuses) >= n @@ -345,7 +589,7 @@ function Testall(reqs::RequestSet, statuses::Union{AbstractVector{Status},Nothin update!(reqs) return flag[] != 0 end -function Testall(reqs::RequestSet, ::Type{Status}) +function Testall(reqs::Union{AbstractMultiRequest, RequestSet}, ::Type{Status}) statuses = Array{Status}(undef, length(reqs)) flag = Testall(reqs, statuses) return flag, statuses @@ -371,7 +615,7 @@ of the request. 
# External links $(_doc_external("MPI_Waitany")) """ -function Waitany(reqs::RequestSet, status::Union{Ref{Status}, Nothing}=nothing) +function Waitany(reqs::Union{AbstractMultiRequest, RequestSet}, status::Union{Ref{Status}, Nothing}=nothing) ref_idx = Ref{Cint}() n = length(reqs) # int MPI_Waitany(int count, MPI_Request array_of_requests[], int *index, @@ -383,7 +627,7 @@ function Waitany(reqs::RequestSet, status::Union{Ref{Status}, Nothing}=nothing) update!(reqs, i) return i end -function Waitany(reqs::RequestSet, ::Type{Status}) +function Waitany(reqs::Union{AbstractMultiRequest, RequestSet}, ::Type{Status}) status = Ref(STATUS_ZERO) i = Waitany(reqs, status) return i, status[] @@ -411,7 +655,7 @@ The optional `status` argument can be used to obtain the return [`Status`](@ref) # External links $(_doc_external("MPI_Testany")) """ -function Testany(reqs::RequestSet, status::Union{Ref{Status}, Nothing}=nothing) +function Testany(reqs::Union{AbstractMultiRequest, RequestSet}, status::Union{Ref{Status}, Nothing}=nothing) ref_idx = Ref{Cint}() rflag = Ref{Cint}() n = length(reqs) @@ -425,7 +669,7 @@ function Testany(reqs::RequestSet, status::Union{Ref{Status}, Nothing}=nothing) update!(reqs, i) return flag, i end -function Testany(reqs::RequestSet, ::Type{Status}) +function Testany(reqs::Union{AbstractMultiRequest, RequestSet}, ::Type{Status}) status = Ref(STATUS_ZERO) flag, i = Testany(reqs, status) return flag, i, status[] @@ -450,7 +694,7 @@ The optional `statuses` argument can be used to obtain the return # External links $(_doc_external("MPI_Waitsome")) """ -function Waitsome(reqs::RequestSet, statuses::Union{AbstractVector{Status},Nothing}=nothing) +function Waitsome(reqs::Union{AbstractMultiRequest, RequestSet}, statuses::Union{AbstractVector{Status},Nothing}=nothing) ref_nout = Ref{Cint}() n = length(reqs) idxs = Vector{Cint}(undef, n) @@ -466,7 +710,7 @@ function Waitsome(reqs::RequestSet, statuses::Union{AbstractVector{Status},Nothi update!(reqs) return 
[Int(idxs[i]) + 1 for i = 1:nout] end -function Waitsome(reqs::RequestSet, ::Type{Status}) +function Waitsome(reqs::Union{AbstractMultiRequest, RequestSet}, ::Type{Status}) statuses = Array{Status}(undef, length(reqs)) inds = Waitsome(reqs, statuses) resize!(statuses, isnothing(inds) ? 0 : length(inds)) @@ -491,7 +735,7 @@ The optional `statuses` argument can be used to obtain the return # External links $(_doc_external("MPI_Testsome")) """ -function Testsome(reqs::RequestSet, statuses::Union{AbstractVector{Status},Nothing}=nothing) +function Testsome(reqs::Union{AbstractMultiRequest, RequestSet}, statuses::Union{AbstractVector{Status},Nothing}=nothing) ref_nout = Ref{Cint}() n = length(reqs) idxs = Vector{Cint}(undef, n) @@ -507,7 +751,7 @@ function Testsome(reqs::RequestSet, statuses::Union{AbstractVector{Status},Nothi update!(reqs) return [Int(idxs[i]) + 1 for i = 1:nout] end -function Testsome(reqs::RequestSet, ::Type{Status}) +function Testsome(reqs::Union{AbstractMultiRequest, RequestSet}, ::Type{Status}) statuses = Array{Status}(undef, length(reqs)) inds = Testsome(reqs, statuses) resize!(statuses, isnothing(inds) ? 0 : length(inds)) @@ -525,7 +769,7 @@ request is not deallocated, and can still be queried using the test or wait func # External links $(_doc_external("MPI_Cancel")) """ -function Cancel!(req::Request) +function Cancel!(req::AbstractRequest) # int MPI_Cancel(MPI_Request *request) API.MPI_Cancel(req) nothing diff --git a/src/pointtopoint.jl b/src/pointtopoint.jl index ddffc0ca8..e762edb03 100644 --- a/src/pointtopoint.jl +++ b/src/pointtopoint.jl @@ -44,46 +44,45 @@ function send(obj, dest::Integer, tag::Integer, comm::Comm) end """ - Isend(data, comm::Comm; dest::Integer, tag::Integer=0) + Isend(data, comm::Comm[, req::AbstractRequest = Request()]; dest::Integer, tag::Integer=0) Starts a nonblocking send of `data` to MPI rank `dest` of communicator `comm` using with the message tag `tag`. 
`data` can be a [`Buffer`](@ref), or any object for which [`Buffer_send`](@ref) is defined. -Returns the [`Request`](@ref) object for the nonblocking send. +Returns the [`AbstractRequest`](@ref) object for the nonblocking send. # External links $(_doc_external("MPI_Isend")) """ -Isend(data, comm::Comm; dest::Integer, tag::Integer=0) = - Isend(data, dest, tag, comm) +Isend(data, comm::Comm, req::AbstractRequest=Request(); dest::Integer, tag::Integer=0) = + Isend(data, dest, tag, comm, req) -function Isend(buf::Buffer, dest::Integer, tag::Integer, comm::Comm) - req = Request() +function Isend(buf::Buffer, dest::Integer, tag::Integer, comm::Comm, req::AbstractRequest=Request()) + @assert isnull(req) # int MPI_Isend(const void* buf, int count, MPI_Datatype datatype, int dest, # int tag, MPI_Comm comm, MPI_Request *request) API.MPI_Isend(buf.data, buf.count, buf.datatype, dest, tag, comm, req) - req.buffer = buf - finalizer(free, req) + setbuffer!(req, buf) return req end -Isend(data, dest::Integer, tag::Integer, comm::Comm) = - Isend(Buffer_send(data), dest, tag, comm) +Isend(data, dest::Integer, tag::Integer, comm::Comm, req::AbstractRequest=Request()) = + Isend(Buffer_send(data), dest, tag, comm, req) """ - isend(obj, comm::Comm; dest::Integer, tag::Integer=0) + isend(obj, comm::Comm[, req::AbstractRequest = Request()]; dest::Integer, tag::Integer=0) Starts a nonblocking send of using a serialized version of `obj` to MPI rank `dest` of communicator `comm` using with the message tag `tag`. Returns the commication `Request` for the nonblocking send. 
""" -isend(data, comm::Comm; dest::Integer, tag::Integer=0) = - isend(data, dest, tag, comm) -function isend(obj, dest::Integer, tag::Integer, comm::Comm) +isend(data, comm::Comm, req::AbstractRequest = Request(); dest::Integer, tag::Integer=0) = + isend(data, dest, tag, comm, req) +function isend(obj, dest::Integer, tag::Integer, comm::Comm, req::AbstractRequest = Request()) buf = MPI.serialize(obj) - Isend(buf, dest, tag, comm) + Isend(buf, dest, tag, comm, req) end """ @@ -183,7 +182,7 @@ end """ - req = Irecv!(recvbuf, comm::Comm; + req = Irecv!(recvbuf, comm::Comm[, req::AbstractRequest = Request()]; source::Integer=MPI.ANY_SOURCE, tag::Integer=MPI.ANY_TAG) Starts a nonblocking receive into the buffer `data` from MPI rank `source` of communicator @@ -191,24 +190,23 @@ Starts a nonblocking receive into the buffer `data` from MPI rank `source` of co `data` can be a [`Buffer`](@ref), or any object for which `Buffer(data)` is defined. -Returns the [`Request`](@ref) object for the nonblocking receive. +Returns the [`AbstractRequest`](@ref) object for the nonblocking receive. 
# External links $(_doc_external("MPI_Irecv")) """ -Irecv!(recvbuf, comm::Comm; source::Integer=API.MPI_ANY_SOURCE[], tag::Integer=API.MPI_ANY_TAG[]) = - Irecv!(recvbuf, source, tag, comm) -function Irecv!(buf::Buffer, source::Integer, tag::Integer, comm::Comm) - req = Request() +Irecv!(recvbuf, comm::Comm, req::AbstractRequest=Request(); source::Integer=API.MPI_ANY_SOURCE[], tag::Integer=API.MPI_ANY_TAG[]) = + Irecv!(recvbuf, source, tag, comm, req) +function Irecv!(buf::Buffer, source::Integer, tag::Integer, comm::Comm, req::AbstractRequest=Request()) + @assert isnull(req) # int MPI_Irecv(void* buf, int count, MPI_Datatype datatype, int source, # int tag, MPI_Comm comm, MPI_Request *request) API.MPI_Irecv(buf.data, buf.count, buf.datatype, source, tag, comm, req) - req.buffer = buf - finalizer(free, req) + setbuffer!(req, buf) return req end -Irecv!(data, source::Integer, tag::Integer, comm::Comm) = - Irecv!(Buffer(data), source, tag, comm) +Irecv!(data, source::Integer, tag::Integer, comm::Comm, req::AbstractRequest=Request()) = + Irecv!(Buffer(data), source, tag, comm, req) """ @@ -249,51 +247,51 @@ end # persistent requests """ - Send_init(buf, comm; dest, tag=0)::MPI.Request + Send_init(buf, comm::MPI.Comm[, req::AbstractRequest = Request()]; + dest, tag=0) -Allocate a persistent send request, returning a [`Request`](@ref) object. Use +Allocate a persistent send request, returning a [`AbstractRequest`](@ref) object. Use [`Start`](@ref) or [`Startall`](@ref) to start the communication operation, and `free` to deallocate the request. 
# External links $(_doc_external("MPI_Send_init")) """ -Send_init(buf, comm::Comm; dest::Integer, tag::Integer=0) = - Send_init(buf, dest, tag, comm) -function Send_init(buf::Buffer, dest::Integer, tag::Integer, comm::Comm) - req = Request() +Send_init(buf, comm::Comm, req::AbstractRequest=Request(); dest::Integer, tag::Integer=0) = + Send_init(buf, dest, tag, comm, req) +function Send_init(buf::Buffer, dest::Integer, tag::Integer, comm::Comm, req::AbstractRequest=Request()) + @assert isnull(req) API.MPI_Send_init(buf.data, buf.count, buf.datatype, dest, tag, comm, req) - req.buffer = buf - finalizer(free, req) + setbuffer!(req, buf) return req end -Send_init(buf, dest::Integer, tag::Integer, comm::Comm) = - Send_init(Buffer(buf), dest, tag, comm) +Send_init(buf, dest::Integer, tag::Integer, comm::Comm, req::AbstractRequest=Request()) = + Send_init(Buffer(buf), dest, tag, comm, req) """ - Recv_init(buf, comm; source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG)::MPI.Request + Recv_init(buf, comm::MPI.Comm[, req::AbstractRequest = Request()]; + source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG) -Allocate a persistent receive request, returning a [`Request`](@ref) object. Use +Allocate a persistent receive request, returning a [`AbstractRequest`](@ref) object. Use [`Start`](@ref) or [`Startall`](@ref) to start the communication operation, and `free` to deallocate the request. 
# External links $(_doc_external("MPI_Recv_init")) """ -Recv_init(buf, comm::Comm; source=API.MPI_ANY_SOURCE[], tag=API.MPI_ANY_TAG[]) = - Recv_init(buf, source, tag, comm) -function Recv_init(buf::Buffer, source::Integer, tag::Integer, comm::Comm) - req = Request() +Recv_init(buf, comm::Comm, req::AbstractRequest=Request(); source=API.MPI_ANY_SOURCE[], tag=API.MPI_ANY_TAG[]) = + Recv_init(buf, source, tag, comm, req) +function Recv_init(buf::Buffer, source::Integer, tag::Integer, comm::Comm, req::AbstractRequest=Request()) + @assert isnull(req) API.MPI_Recv_init(buf.data, buf.count, buf.datatype, source, tag, comm, req) - req.buffer = buf - finalizer(free, req) + setbuffer!(req, buf) return req end -Recv_init(buf, source::Integer, tag::Integer, comm::Comm) = - Recv_init(Buffer(buf), source, tag, comm) +Recv_init(buf, source::Integer, tag::Integer, comm::Comm, req::AbstractRequest=Request()) = + Recv_init(Buffer(buf), source, tag, comm, req) """ - Start(request::Request) + Start(request::AbstractRequest) Start a persistent communication request created by [`Send_init`](@ref) or [`Recv_init`](@ref). Call [`Wait`](@ref) to complete the request. @@ -301,7 +299,7 @@ Start a persistent communication request created by [`Send_init`](@ref) or # External links $(_doc_external("MPI_Start")) """ -function Start(req::Request) +function Start(req::AbstractRequest) API.MPI_Start(req) return nothing end @@ -316,7 +314,7 @@ or [`Recv_init`](@ref). Call [`Waitall`](@ref) to complete the requests. 
$(_doc_external("MPI_Startall")) """ Startall(reqs::AbstractVector{Request}) = Startall(RequestSet(reqs)) -function Startall(reqs::RequestSet) +function Startall(reqs::Union{AbstractMultiRequest, RequestSet}) API.MPI_Startall(length(reqs), reqs.vals) update!(reqs) return nothing diff --git a/test/test_persistent.jl b/test/test_persistent.jl index 50b19c2c9..fcb6d67b8 100644 --- a/test/test_persistent.jl +++ b/test/test_persistent.jl @@ -14,17 +14,19 @@ recv_mesg_expected = ArrayType{Float64}(undef,N) fill!(send_mesg, rank) synchronize() -sendreq = MPI.Send_init(send_mesg, comm; tag=7, dest=mod(rank + 1, size)) +sendreq = MPI.Send_init(send_mesg, comm; tag=7, dest=mod(rank + 1, size)) recvreq = MPI.Recv_init(recv_mesg, comm; tag=7, source=mod(rank - 1, size)) @test MPI.Test(sendreq) @test MPI.Test(recvreq) -for i = 1:3 +@test Base.sprint(show, sendreq) == Base.sprint(show, recvreq) == "MPI.Request: inactive request" +for i = 1:3 MPI.Start(recvreq) @test !MPI.Test(recvreq) + @test Base.sprint(show, recvreq) == "MPI.Request: incomplete request" MPI.Barrier(comm) @@ -34,6 +36,7 @@ for i = 1:3 @test MPI.Test(sendreq) @test MPI.Test(recvreq) + @test Base.sprint(show, sendreq) == Base.sprint(show, recvreq) == "MPI.Request: inactive request" copyto!(send_mesg, recv_mesg) @@ -56,3 +59,31 @@ for i = 4:6 synchronize() @test recv_mesg == recv_mesg_expected end + +MPI.free(sendreq) +MPI.free(recvreq) + +@test MPI.isnull(sendreq) +@test MPI.isnull(recvreq) + + +reqs = MPI.MultiRequest(2) +MPI.Send_init(send_mesg, comm, reqs[1]; tag=8, dest=mod(rank + 1, size)) +MPI.Recv_init(recv_mesg, comm, reqs[2]; tag=8, source=mod(rank - 1, size)) +@test MPI.Testall(reqs) + +for i = 7:9 + MPI.Startall(reqs) + MPI.Waitall(reqs) + @test MPI.Testall(reqs) + + copyto!(send_mesg, recv_mesg) + + fill!(recv_mesg_expected, mod(rank-i, size)) + synchronize() + @test recv_mesg == recv_mesg_expected +end + +MPI.free(reqs) +@test MPI.Testall(reqs) +@test all(MPI.isnull, reqs) diff --git 
a/test/test_test.jl b/test/test_test.jl index dc9c80e7f..696913a8a 100644 --- a/test/test_test.jl +++ b/test/test_test.jl @@ -53,5 +53,32 @@ MPI.Waitall(reqs) @test isnothing(inds) @test isempty(stats) +reqs = MPI.MultiRequest(2) +MPI.Irecv!(recv_mesg, comm, reqs[1]; source=src, tag=src+32) +@test !MPI.Testall(reqs) + +MPI.Barrier(comm) +MPI.Isend(send_mesg, comm, reqs[2]; dest=dst, tag=rank+32, ) + +inds = MPI.Waitsome(reqs) +@test !isempty(inds) + +MPI.Waitall(reqs) +@test MPI.Testall(reqs) + +reqs = MPI.UnsafeMultiRequest(2) +GC.@preserve send_mesg recv_mesg begin + MPI.Irecv!(recv_mesg, comm, reqs[1]; source=src, tag=src+32) + @test !MPI.Testall(reqs) + + MPI.Barrier(comm) + MPI.Isend(send_mesg, comm, reqs[2]; dest=dst, tag=rank+32, ) + inds = MPI.Waitsome(reqs) + @test !isempty(inds) + + MPI.Waitall(reqs) + @test MPI.Testall(reqs) +end + MPI.Finalize() @test MPI.Finalized() diff --git a/test/test_wait.jl b/test/test_wait.jl index c3521d115..1aa2061bf 100644 --- a/test/test_wait.jl +++ b/test/test_wait.jl @@ -20,14 +20,14 @@ recv_check = zeros(Int, nsends) for i = 1:nsends idx = MPI.Waitany(send_reqs) - @test send_reqs[idx] == MPI.REQUEST_NULL + @test MPI.isnull(send_reqs[idx]) send_check[idx] += 1 end @test send_check == ones(Int, nsends) for i = 1:nsends idx = MPI.Waitany(recv_reqs) - @test recv_reqs[idx] == MPI.REQUEST_NULL + @test MPI.isnull(recv_reqs[idx]) recv_check[idx] += 1 end @test recv_check == ones(Int, nsends)