diff --git a/deps/CMakeLists.txt b/deps/CMakeLists.txt index ad5be2429..b13a88b45 100644 --- a/deps/CMakeLists.txt +++ b/deps/CMakeLists.txt @@ -51,6 +51,7 @@ FortranCInterface_HEADER(jlmpi_f2c.h MACRO_NAMESPACE "JLMPI_" SYMBOLS MPI_GET_PROCESSOR_NAME MPI_IALLGATHER MPI_IALLGATHERV + MPI_IALLREDUCE MPI_IALLTOALL MPI_IALLTOALLV MPI_IBARRIER diff --git a/deps/gen_functions.c b/deps/gen_functions.c index 91fb6188a..6cc4a67c8 100644 --- a/deps/gen_functions.c +++ b/deps/gen_functions.c @@ -39,6 +39,7 @@ int main(int argc, char *argv[]) { STRING(MPI_GET_PROCESSOR_NAME)); printf(" :MPI_IALLGATHER => \"%s\",\n", STRING(MPI_IALLGATHER)); printf(" :MPI_IALLGATHERV => \"%s\",\n", STRING(MPI_IALLGATHERV)); + printf(" :MPI_IALLREDUCE => \"%s\",\n", STRING(MPI_IALLREDUCE)); printf(" :MPI_IALLTOALL => \"%s\",\n", STRING(MPI_IALLTOALL)); printf(" :MPI_IALLTOALLV => \"%s\",\n", STRING(MPI_IALLTOALLV)); printf(" :MPI_IBARRIER => \"%s\",\n", STRING(MPI_IBARRIER)); diff --git a/src/mpi-base.jl b/src/mpi-base.jl index dd9ab4bdf..7c2c2b251 100644 --- a/src/mpi-base.jl +++ b/src/mpi-base.jl @@ -571,6 +571,45 @@ function Ireduce{T}(object::T, op::Op, root::Integer, comm::Comm) req, isroot ? recvbuf[1] : nothing end +function Allreduce{T}(sendbuf::MPIBuffertype{T}, count::Integer, op::Op, comm::Comm) + recvbuf = Array(T, count) + ccall(MPI_ALLREDUCE, Void, + (Ptr{T}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), + sendbuf, recvbuf, &count, &mpitype(T), &op.val, &comm.val, &0) + recvbuf +end + +function Allreduce{T}(sendbuf::Array{T}, op::Op, comm::Comm) + Allreduce(sendbuf, length(sendbuf), op, comm) +end + +function Allreduce{T}(object::T, op::Op, comm::Comm) + sendbuf = T[object] + recvbuf = Allreduce(sendbuf, op, comm) + recvbuf[1] +end + +function Iallreduce{T}(sendbuf::MPIBuffertype{T}, count::Integer, op::Op, comm::Comm) + rval = Ref{Cint}() + recvbuf = Array(T, count) + ccall(MPI_IALLREDUCE, Void, + (Ptr{T}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, + Ptr{Cint}, Ptr{Cint}), + sendbuf, recvbuf, &count, &mpitype(T), &op.val, &comm.val, + rval, &0) + Request(rval[], sendbuf), recvbuf +end + +function Iallreduce{T}(sendbuf::Array{T}, op::Op, comm::Comm) + Iallreduce(sendbuf, length(sendbuf), op, comm) +end + +function Iallreduce{T}(object::T, op::Op, comm::Comm) + sendbuf = T[object] + req, recvbuf = Iallreduce(sendbuf, op, comm) + req, recvbuf[1] +end + function Scatter{T}(sendbuf::MPIBuffertype{T}, count::Integer, root::Integer, comm::Comm) recvbuf = Array(T, count) diff --git a/src/win_mpiconstants.jl b/src/win_mpiconstants.jl index c58a6f7e7..f1a59deaf 100644 --- a/src/win_mpiconstants.jl +++ b/src/win_mpiconstants.jl @@ -65,6 +65,8 @@ const MPI_ALLGATHER = (:MPI_ALLGATHER, "msmpi.dll") const MPI_ALLGATHERV = (:MPI_ALLGATHERV, "msmpi.dll") const MPI_IALLGATHER = (:MPI_ALLGATHER, "msmpi.dll") const MPI_IALLGATHERV = (:MPI_ALLGATHERV, "msmpi.dll") +const MPI_ALLREDUCE = (:MPI_ALLREDUCE, "msmpi.dll") +const MPI_IALLREDUCE = (:MPI_IALLREDUCE, "msmpi.dll") const MPI_ALLTOALL = (:MPI_ALLTOALL, "msmpi.dll") const MPI_ALLTOALLV = (:MPI_ALLTOALLV, "msmpi.dll") const MPI_IALLTOALL = (:MPI_IALLTOALL, "msmpi.dll") diff --git a/test/test_allreduce.jl b/test/test_allreduce.jl new file mode 100644 index 000000000..5d808f682 --- /dev/null +++ b/test/test_allreduce.jl @@ -0,0 +1,30 @@ +using Base.Test + +using MPI + +MPI.Init() + +comm = MPI.COMM_WORLD +size = MPI.Comm_size(comm) +rank = MPI.Comm_rank(comm) + +val = sum(0:size-1) +@test MPI.Allreduce(rank, MPI.SUM, comm) == val + +val = size-1 +@test MPI.Allreduce(rank, MPI.MAX, comm) == val + +val = 0 +@test MPI.Allreduce(rank, MPI.MIN, comm) == val + +mesg = collect(1.0:5.0) +sum_mesg = MPI.Allreduce(mesg, MPI.SUM, comm) +@test isapprox(norm(sum_mesg-size*mesg), 0.0) + +mesg = collect(1.0:5.0) +req, sum_mesg = MPI.Iallreduce(mesg, MPI.SUM, comm) +MPI.Wait!(req) +sum_mesg = sum_mesg +@test isapprox(norm(sum_mesg-size*mesg), 0.0) + +MPI.Finalize()