diff --git a/src/impl/KokkosComm_irecv.hpp b/src/impl/KokkosComm_irecv.hpp index d55262c2..4e5ffc61 100644 --- a/src/impl/KokkosComm_irecv.hpp +++ b/src/impl/KokkosComm_irecv.hpp @@ -33,8 +33,6 @@ template void irecv(RecvView &rv, int src, int tag, MPI_Comm comm, MPI_Request &req) { Kokkos::Tools::pushRegion("KokkosComm::Impl::irecv"); - using KCT = KokkosComm::Traits; - if (KokkosComm::is_contiguous(rv)) { using RecvScalar = typename RecvView::value_type; MPI_Irecv(KokkosComm::data_handle(rv), KokkosComm::span(rv), mpi_type_v, src, tag, comm, &req); @@ -44,4 +42,13 @@ void irecv(RecvView &rv, int src, int tag, MPI_Comm comm, MPI_Request &req) { Kokkos::Tools::popRegion(); } + +template +KokkosComm::Req irecv(RecvView &rv, int src, int tag, MPI_Comm comm) { + Kokkos::Tools::pushRegion("KokkosComm::Impl::irecv"); + KokkosComm::Req req; + irecv(rv, src, tag, comm, req.mpi_req()); + return req; +} + } // namespace KokkosComm::Impl diff --git a/unit_tests/test_isendirecv.cpp b/unit_tests/test_isendirecv.cpp index d76bf83a..aac549ac 100644 --- a/unit_tests/test_isendirecv.cpp +++ b/unit_tests/test_isendirecv.cpp @@ -85,10 +85,9 @@ void test_2d(const View2D &a) { KokkosComm::Req req = KokkosComm::isend(Kokkos::DefaultExecutionSpace(), a, dst, 0, MPI_COMM_WORLD); req.wait(); } else if (1 == rank) { - int src = 0; - MPI_Request req; - KokkosComm::irecv(a, src, 0, MPI_COMM_WORLD, req); - MPI_Wait(&req, MPI_STATUS_IGNORE); + int src = 0; + KokkosComm::Req req = KokkosComm::irecv(a, src, 0, MPI_COMM_WORLD); + req.wait(); int errs; Kokkos::parallel_reduce( policy, KOKKOS_LAMBDA(int i, int j, int &lsum) { lsum += a(i, j) != Scalar(i * a.extent(0) + j); }, errs);