From 95cc4cac11ab671994ab31566a44d51db79125b1 Mon Sep 17 00:00:00 2001 From: Katie Hyatt Date: Fri, 29 Apr 2016 16:51:39 -0700 Subject: [PATCH] Wrapped ireduce, iscan, iexscan --- deps/CMakeLists.txt | 3 +++ deps/gen_functions.c | 3 +++ src/mpi-base.jl | 54 +++++++++++++++++++++++++++++++++++++++++ src/win_mpiconstants.jl | 3 +++ test/test_exscan.jl | 5 ++++ test/test_reduce.jl | 6 +++++ test/test_scan.jl | 4 +++ 7 files changed, 78 insertions(+) diff --git a/deps/CMakeLists.txt b/deps/CMakeLists.txt index a23fa9808..ad5be2429 100644 --- a/deps/CMakeLists.txt +++ b/deps/CMakeLists.txt @@ -55,12 +55,15 @@ FortranCInterface_HEADER(jlmpi_f2c.h MACRO_NAMESPACE "JLMPI_" SYMBOLS MPI_IALLTOALLV MPI_IBARRIER MPI_IBCAST + MPI_IEXSCAN MPI_IGATHER MPI_IGATHERV MPI_INIT MPI_INITIALIZED MPI_IPROBE MPI_IRECV + MPI_IREDUCE + MPI_ISCAN MPI_ISCATTER MPI_ISCATTERV MPI_ISEND diff --git a/deps/gen_functions.c b/deps/gen_functions.c index 04668abb8..91fb6188a 100644 --- a/deps/gen_functions.c +++ b/deps/gen_functions.c @@ -43,12 +43,15 @@ int main(int argc, char *argv[]) { printf(" :MPI_IALLTOALLV => \"%s\",\n", STRING(MPI_IALLTOALLV)); printf(" :MPI_IBARRIER => \"%s\",\n", STRING(MPI_IBARRIER)); printf(" :MPI_IBCAST => \"%s\",\n", STRING(MPI_IBCAST)); + printf(" :MPI_IEXSCAN => \"%s\",\n", STRING(MPI_IEXSCAN)); printf(" :MPI_IGATHER => \"%s\",\n", STRING(MPI_IGATHER)); printf(" :MPI_IGATHERV => \"%s\",\n", STRING(MPI_IGATHERV)); printf(" :MPI_INIT => \"%s\",\n", STRING(MPI_INIT)); printf(" :MPI_INITIALIZED => \"%s\",\n", STRING(MPI_INITIALIZED)); printf(" :MPI_IPROBE => \"%s\",\n", STRING(MPI_IPROBE)); printf(" :MPI_IRECV => \"%s\",\n", STRING(MPI_IRECV)); + printf(" :MPI_IREDUCE => \"%s\",\n", STRING(MPI_IREDUCE)); + printf(" :MPI_ISCAN => \"%s\",\n", STRING(MPI_ISCAN)); printf(" :MPI_ISCATTER => \"%s\",\n", STRING(MPI_ISCATTER)); printf(" :MPI_ISCATTERV => \"%s\",\n", STRING(MPI_ISCATTERV)); printf(" :MPI_ISEND => \"%s\",\n", STRING(MPI_ISEND)); diff --git a/src/mpi-base.jl b/src/mpi-base.jl index a14efc052..dd9ab4bdf 100644 --- a/src/mpi-base.jl +++ b/src/mpi-base.jl @@ -547,6 +547,30 @@ function Reduce{T}(object::T, op::Op, root::Integer, comm::Comm) isroot ? recvbuf[1] : nothing end +function Ireduce{T}(sendbuf::MPIBuffertype{T}, count::Integer, + op::Op, root::Integer, comm::Comm) + rval = Ref{Cint}() + isroot = Comm_rank(comm) == root + recvbuf = Array(T, isroot ? count : 0) + ccall(MPI_IREDUCE, Void, + (Ptr{T}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, + Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), + sendbuf, recvbuf, &count, &mpitype(T), &op.val, &root, &comm.val, + rval, &0) + Request(rval[], sendbuf), isroot ? recvbuf : nothing +end + +function Ireduce{T}(sendbuf::Array{T}, op::Op, root::Integer, comm::Comm) + Ireduce(sendbuf, length(sendbuf), op, root, comm) +end + +function Ireduce{T}(object::T, op::Op, root::Integer, comm::Comm) + isroot = Comm_rank(comm) == root + sendbuf = T[object] + req, recvbuf = Ireduce(sendbuf, op, root, comm) + req, isroot ? recvbuf[1] : nothing +end + function Scatter{T}(sendbuf::MPIBuffertype{T}, count::Integer, root::Integer, comm::Comm) recvbuf = Array(T, count) @@ -775,6 +799,21 @@ function Scan{T}(object::T, op::Op, comm::Comm) Scan(sendbuf,1,op,comm) end +function Iscan{T}(sendbuf::MPIBuffertype{T}, count::Integer, + op::Op, comm::Comm) + recvbuf = Array(T, count) + rval = Ref{Cint}() + ccall(MPI_ISCAN, Void, + (Ptr{T}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), + sendbuf, recvbuf, &count, &mpitype(T), &op.val, &comm.val, rval, &0) + Request(rval[], sendbuf), recvbuf +end + +function Iscan{T}(object::T, op::Op, comm::Comm) + sendbuf = T[object] + Iscan(sendbuf,1,op,comm) +end + function ExScan{T}(sendbuf::MPIBuffertype{T}, count::Integer, op::Op, comm::Comm) recvbuf = Array(T, count) @@ -789,6 +828,21 @@ function ExScan{T}(object::T, op::Op, comm::Comm) ExScan(sendbuf,1,op,comm) end +function IExScan{T}(sendbuf::MPIBuffertype{T}, count::Integer, + op::Op, comm::Comm) + recvbuf = Array(T, count) + rval = Ref{Cint}() + ccall(MPI_IEXSCAN, Void, + (Ptr{T}, Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), + sendbuf, recvbuf, &count, &mpitype(T), &op.val, &comm.val, rval, &0) + Request(rval[], sendbuf), recvbuf +end + +function IExScan{T}(object::T, op::Op, comm::Comm) + sendbuf = T[object] + IExScan(sendbuf,1,op,comm) +end + # Conversion between C and Fortran Comm handles: if HAVE_MPI_COMM_C2F # use MPI_Comm_f2c and MPI_Comm_c2f diff --git a/src/win_mpiconstants.jl b/src/win_mpiconstants.jl index bf8c03448..c58a6f7e7 100644 --- a/src/win_mpiconstants.jl +++ b/src/win_mpiconstants.jl @@ -56,6 +56,7 @@ const MPI_FINALIZE = (:MPI_FINALIZE, "msmpi.dll") const MPI_BCAST = (:MPI_BCAST, "msmpi.dll") const MPI_IBCAST = (:MPI_IBCAST, "msmpi.dll") const MPI_REDUCE = (:MPI_REDUCE, "msmpi.dll") +const MPI_IREDUCE = (:MPI_IREDUCE, "msmpi.dll") const MPI_IRECV = (:MPI_IRECV, "msmpi.dll") const MPI_RECV = (:MPI_RECV, "msmpi.dll") const MPI_ISEND = (:MPI_ISEND, "msmpi.dll") @@ -74,7 +75,9 @@ const MPI_SCATTER = (:MPI_SCATTER, "msmpi.dll") const MPI_SCATTERV = (:MPI_SCATTERV, "msmpi.dll") const MPI_SEND = (:MPI_SEND, "msmpi.dll") const MPI_SCAN = (:MPI_SCAN, "msmpi.dll") +const MPI_ISCAN = (:MPI_ISCAN, "msmpi.dll") const MPI_EXSCAN = (:MPI_EXSCAN, "msmpi.dll") +const MPI_IEXSCAN = (:MPI_IEXSCAN, "msmpi.dll") const MPI_GATHER = (:MPI_GATHER, "msmpi.dll") const MPI_GATHERV = (:MPI_GATHERV, "msmpi.dll") const MPI_IGATHER = (:MPI_GATHER, "msmpi.dll") diff --git a/test/test_exscan.jl b/test/test_exscan.jl index 514296685..d6b489729 100644 --- a/test/test_exscan.jl +++ b/test/test_exscan.jl @@ -16,6 +16,11 @@ for typ in typs if rank > 0 @test_approx_eq B[1] factorial(rank) end + req, C = MPI.IExScan(val, MPI.PROD, comm) + MPI.Wait!(req) + if rank > 0 + @test_approx_eq C[1] factorial(rank) + end end MPI.Finalize() diff --git a/test/test_reduce.jl b/test/test_reduce.jl index fc49977d5..ef1cccb20 100644 --- a/test/test_reduce.jl +++ b/test/test_reduce.jl @@ -23,4 +23,10 @@ sum_mesg = MPI.Reduce(mesg, MPI.SUM, root, comm) sum_mesg = rank == root ? sum_mesg : size*mesg @test isapprox(norm(sum_mesg-size*mesg), 0.0) +mesg = collect(1.0:5.0) +req, sum_mesg = MPI.Ireduce(mesg, MPI.SUM, root, comm) +MPI.Wait!(req) +sum_mesg = rank == root ? sum_mesg : size*mesg +@test isapprox(norm(sum_mesg-size*mesg), 0.0) + MPI.Finalize() diff --git a/test/test_scan.jl b/test/test_scan.jl index f903cf1e0..fe8067b5d 100644 --- a/test/test_scan.jl +++ b/test/test_scan.jl @@ -14,6 +14,10 @@ for typ in typs val = convert(typ,rank + 1) B = MPI.Scan(val, MPI.PROD, comm) @test_approx_eq B[1] factorial(val) + val = convert(typ,rank + 1) + req, B = MPI.Iscan(val, MPI.PROD, comm) + MPI.Wait!(req) + @test_approx_eq B[1] factorial(val) end MPI.Finalize()