Skip to content

Commit

Permalink
Merge pull request #70 from cwpearson/fix/remove-traits
Browse files Browse the repository at this point in the history
Move most things out of `Traits`
  • Loading branch information
cwpearson authored Jun 5, 2024
2 parents fe1b55b + 0f2e288 commit 2968558
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 50 deletions.
4 changes: 2 additions & 2 deletions src/KokkosComm_pack_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ template <KokkosView View>
struct PackTraits<View> {
using packer_type = Impl::Packer::DeepCopy<View>;

static bool needs_unpack(const View &v) { return !Traits<View>::is_contiguous(v); }
static bool needs_pack(const View &v) { return !Traits<View>::is_contiguous(v); }
static bool needs_unpack(const View &v) { return !KokkosComm::is_contiguous(v); }
static bool needs_pack(const View &v) { return !KokkosComm::is_contiguous(v); }
};

} // namespace KokkosComm
43 changes: 32 additions & 11 deletions src/KokkosComm_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,45 @@ struct Traits {
/*! \brief This can be specialized to do custom behavior for a particular view*/
template <KokkosView View>
struct Traits<View> {
// product of extents is span
static bool is_contiguous(const View &v) { return v.span_is_contiguous(); }

static auto data_handle(const View &v) { return v.data(); }

using non_const_packed_view_type =
Kokkos::View<typename View::non_const_data_type, typename View::array_layout, typename View::memory_space>;
using packed_view_type =
Kokkos::View<typename View::data_type, typename View::array_layout, typename View::memory_space>;
};

static size_t span(const View &v) { return v.span(); }
template <KokkosView View>
auto data_handle(const View &v) {
return v.data();
}

static size_t extent(const View &v, const int i) { return v.extent(i); }
static size_t stride(const View &v, const int i) { return v.stride(i); }
template <KokkosView View>
auto span(const View &v) {
return v.span();
}

static constexpr bool is_reference_counted() { return true; }
// true iff product of extents is span
template <KokkosView View>
bool is_contiguous(const View &v) {
return v.span_is_contiguous();
}

static constexpr size_t rank() { return View::rank; }
};
template <KokkosView View>
constexpr size_t rank() {
return View::rank;
}

template <KokkosView View>
size_t extent(const View &v, const int i) {
return v.extent(i);
}
template <KokkosView View>
size_t stride(const View &v, const int i) {
return v.stride(i);
}

template <KokkosView View>
constexpr bool is_reference_counted() {
return true;
}

} // namespace KokkosComm
16 changes: 7 additions & 9 deletions src/impl/KokkosComm_allgather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,21 @@ template <KokkosView SendView, KokkosView RecvView>
void allgather(const SendView &sv, const RecvView &rv, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::allgather");

using ST = KokkosComm::Traits<SendView>;
using RT = KokkosComm::Traits<RecvView>;
using SendScalar = typename SendView::value_type;
using RecvScalar = typename RecvView::value_type;

static_assert(ST::rank() <= 1, "allgather for SendView::rank > 1 not supported");
static_assert(RT::rank() <= 1, "allgather for RecvView::rank > 1 not supported");
static_assert(KokkosComm::rank<SendView>() <= 1, "allgather for SendView::rank > 1 not supported");
static_assert(KokkosComm::rank<RecvView>() <= 1, "allgather for RecvView::rank > 1 not supported");

if (!ST::is_contiguous(sv)) {
if (!KokkosComm::is_contiguous(sv)) {
throw std::runtime_error("low-level allgather requires contiguous send view");
}
if (!RT::is_contiguous(rv)) {
if (!KokkosComm::is_contiguous(rv)) {
throw std::runtime_error("low-level allgather requires contiguous recv view");
}
const int count = ST::span(sv); // all ranks send/recv same count
MPI_Allgather(ST::data_handle(sv), count, mpi_type_v<SendScalar>, RT::data_handle(rv), count, mpi_type_v<RecvScalar>,
comm);
const int count = KokkosComm::span(sv); // all ranks send/recv same count
MPI_Allgather(KokkosComm::data_handle(sv), count, mpi_type_v<SendScalar>, KokkosComm::data_handle(rv), count,
mpi_type_v<RecvScalar>, comm);

Kokkos::Tools::popRegion();
}
Expand Down
19 changes: 8 additions & 11 deletions src/impl/KokkosComm_alltoall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,34 +49,32 @@ void alltoall(const ExecSpace &space, const SendView &sv, const size_t sendCount
const size_t recvCount, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall");

using ST = KokkosComm::Traits<SendView>;
using RT = KokkosComm::Traits<RecvView>;
using SendScalar = typename SendView::value_type;
using RecvScalar = typename RecvView::value_type;

static_assert(ST::rank() <= 1, "alltoall for SendView::rank > 1 not supported");
static_assert(RT::rank() <= 1, "alltoall for RecvView::rank > 1 not supported");
static_assert(KokkosComm::rank<SendView>() <= 1, "alltoall for SendView::rank > 1 not supported");
static_assert(KokkosComm::rank<RecvView>() <= 1, "alltoall for RecvView::rank > 1 not supported");

if (KokkosComm::PackTraits<SendView>::needs_pack(sv) || KokkosComm::PackTraits<RecvView>::needs_pack(rv)) {
throw std::runtime_error("alltoall for non-contiguous views not implemented");
} else {
int size;
MPI_Comm_size(comm, &size);

if (sendCount * size > ST::extent(sv, 0)) {
if (sendCount * size > KokkosComm::extent(sv, 0)) {
std::stringstream ss;
ss << "alltoall sendCount * communicator size (" << sendCount << " * " << size
<< ") is greater than send view size";
throw std::runtime_error(ss.str());
}
if (recvCount * size > RT::extent(rv, 0)) {
if (recvCount * size > KokkosComm::extent(rv, 0)) {
std::stringstream ss;
ss << "alltoall recvCount * communicator size (" << recvCount << " * " << size
<< ") is greater than recv view size";
throw std::runtime_error(ss.str());
}

MPI_Alltoall(ST::data_handle(sv), sendCount, mpi_type_v<SendScalar>, RT::data_handle(rv), recvCount,
MPI_Alltoall(KokkosComm::data_handle(sv), sendCount, mpi_type_v<SendScalar>, KokkosComm::data_handle(rv), recvCount,
mpi_type_v<RecvScalar>, comm);
}

Expand All @@ -88,25 +86,24 @@ template <KokkosExecutionSpace ExecSpace, KokkosView RecvView>
void alltoall(const ExecSpace &space, const RecvView &rv, const size_t recvCount, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall");

using RT = KokkosComm::Traits<RecvView>;
using RecvScalar = typename RecvView::value_type;

static_assert(RT::rank() <= 1, "alltoall for RecvView::rank > 1 not supported");
static_assert(RecvView::rank <= 1, "alltoall for RecvView::rank > 1 not supported");

if (KokkosComm::PackTraits<RecvView>::needs_pack(rv)) {
throw std::runtime_error("alltoall for non-contiguous views not implemented");
} else {
int size;
MPI_Comm_size(comm, &size);

if (recvCount * size > RT::extent(rv, 0)) {
if (recvCount * size > KokkosComm::extent(rv, 0)) {
std::stringstream ss;
ss << "alltoall recvCount * communicator size (" << recvCount << " * " << size
<< ") is greater than recv view size";
throw std::runtime_error(ss.str());
}

MPI_Alltoall(MPI_IN_PLACE, 0 /*ignored*/, MPI_BYTE /*ignored*/, RT::data_handle(rv), recvCount,
MPI_Alltoall(MPI_IN_PLACE, 0 /*ignored*/, MPI_BYTE /*ignored*/, KokkosComm::data_handle(rv), recvCount,
mpi_type_v<RecvScalar>, comm);
}

Expand Down
4 changes: 2 additions & 2 deletions src/impl/KokkosComm_irecv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ void irecv(RecvView &rv, int src, int tag, MPI_Comm comm, MPI_Request &req) {

using KCT = KokkosComm::Traits<RecvView>;

if (KCT::is_contiguous(rv)) {
if (KokkosComm::is_contiguous(rv)) {
using RecvScalar = typename RecvView::value_type;
MPI_Irecv(KCT::data_handle(rv), KCT::span(rv), mpi_type_v<RecvScalar>, src, tag, comm, &req);
MPI_Irecv(KokkosComm::data_handle(rv), KokkosComm::span(rv), mpi_type_v<RecvScalar>, src, tag, comm, &req);
} else {
throw std::runtime_error("Only contiguous irecv viewsupported");
}
Expand Down
11 changes: 6 additions & 5 deletions src/impl/KokkosComm_isend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ void isend(const SendView &sv, int dest, int tag, MPI_Comm comm, MPI_Request &re
Kokkos::Tools::pushRegion("KokkosComm::Impl::isend");
using KCT = typename KokkosComm::Traits<SendView>;

if (KCT::is_contiguous(sv)) {
if (KokkosComm::is_contiguous(sv)) {
using SendScalar = typename SendView::non_const_value_type;
MPI_Isend(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>, dest, tag, comm, &req);
MPI_Isend(KokkosComm::data_handle(sv), KokkosComm::span(sv), mpi_type_v<SendScalar>, dest, tag, comm, &req);
} else {
throw std::runtime_error("only contiguous views supported for low-level isend");
}
Expand Down Expand Up @@ -77,13 +77,14 @@ KokkosComm::Req isend(const ExecSpace &space, const SendView &sv, int dest, int

MpiArgs args = Packer::pack(space, sv);
space.fence();
mpi_isend_fn(KCT::data_handle(args.view), args.count, args.datatype, dest, tag, comm, &req.mpi_req());
mpi_isend_fn(KokkosComm::data_handle(args.view), args.count, args.datatype, dest, tag, comm, &req.mpi_req());
req.keep_until_wait(args.view);
} else {
using SendScalar = typename SendView::value_type;
space.fence(); // can't issue isend until work in space is complete
mpi_isend_fn(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>, dest, tag, comm, &req.mpi_req());
if (KCT::is_reference_counted()) {
mpi_isend_fn(KokkosComm::data_handle(sv), KokkosComm::span(sv), mpi_type_v<SendScalar>, dest, tag, comm,
&req.mpi_req());
if (KokkosComm::is_reference_counted<SendView>()) {
req.keep_until_wait(sv);
}
}
Expand Down
12 changes: 7 additions & 5 deletions src/impl/KokkosComm_packer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,22 @@ struct MpiArgs {
template <KokkosView View>
struct DeepCopy {
using non_const_packed_view_type =
Kokkos::View<typename View::non_const_data_type, Kokkos::LayoutRight, typename View::memory_space>;
Kokkos::View<typename View::non_const_data_type, Kokkos::LayoutLeft, typename View::memory_space>;
using args_type = MpiArgs<non_const_packed_view_type>;

template <KokkosExecutionSpace ExecSpace>
static args_type allocate_packed_for(const ExecSpace &space, const std::string &label, const View &src) {
using KCT = KokkosComm::Traits<View>;

if constexpr (KCT::rank() == 1) {
if constexpr (KokkosComm::rank<View>() == 1) {
non_const_packed_view_type packed(Kokkos::view_alloc(space, Kokkos::WithoutInitializing, label), src.extent(0));
return args_type(packed, MPI_PACKED, KCT::span(packed) * sizeof(typename non_const_packed_view_type::value_type));
} else if constexpr (KCT::rank() == 2) {
return args_type(packed, MPI_PACKED,
KokkosComm::span(packed) * sizeof(typename non_const_packed_view_type::value_type));
} else if constexpr (KokkosComm::rank<View>() == 2) {
non_const_packed_view_type packed(Kokkos::view_alloc(space, Kokkos::WithoutInitializing, label), src.extent(0),
src.extent(1));
return args_type(packed, MPI_PACKED, KCT::span(packed) * sizeof(typename non_const_packed_view_type::value_type));
return args_type(packed, MPI_PACKED,
KokkosComm::span(packed) * sizeof(typename non_const_packed_view_type::value_type));
} else {
static_assert(std::is_void_v<View>, "allocate_packed_for for rank >= 2 views unimplemented");
}
Expand Down
6 changes: 3 additions & 3 deletions src/impl/KokkosComm_recv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ void recv(const RecvView &rv, int src, int tag, MPI_Comm comm, MPI_Status *statu
Kokkos::Tools::pushRegion("KokkosComm::Impl::recv");
using KCT = KokkosComm::Traits<RecvView>;

if (KCT::is_contiguous(rv)) {
if (KokkosComm::is_contiguous(rv)) {
using ScalarType = typename RecvView::non_const_value_type;
MPI_Recv(KCT::data_handle(rv), KCT::span(rv), mpi_type_v<ScalarType>, src, tag, comm, status);
MPI_Recv(KokkosComm::data_handle(rv), KokkosComm::span(rv), mpi_type_v<ScalarType>, src, tag, comm, status);
} else {
throw std::runtime_error("only contiguous views supported for low-level recv");
}
Expand All @@ -54,7 +54,7 @@ void recv(const ExecSpace &space, RecvView &rv, int src, int tag, MPI_Comm comm)

Args args = Packer::allocate_packed_for(space, "packed", rv);
space.fence(); // make sure allocation is complete before recv
MPI_Recv(KCT::data_handle(args.view), args.count, args.datatype, src, tag, comm, MPI_STATUS_IGNORE);
MPI_Recv(KokkosComm::data_handle(args.view), args.count, args.datatype, src, tag, comm, MPI_STATUS_IGNORE);
Packer::unpack_into(space, rv, args.view);
} else {
using RecvScalar = typename RecvView::value_type;
Expand Down
4 changes: 2 additions & 2 deletions src/impl/KokkosComm_send.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ void send(const SendView &sv, int dest, int tag, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::send");
using KCT = typename KokkosComm::Traits<SendView>;

if (KCT::is_contiguous(sv)) {
if (KokkosComm::is_contiguous(sv)) {
using SendScalar = typename SendView::non_const_value_type;
MPI_Send(KCT::data_handle(sv), KCT::span(sv), mpi_type_v<SendScalar>, dest, tag, comm);
MPI_Send(KokkosComm::data_handle(sv), KokkosComm::span(sv), mpi_type_v<SendScalar>, dest, tag, comm);
} else {
throw std::runtime_error("only contiguous views supported for low-level send");
}
Expand Down

0 comments on commit 2968558

Please sign in to comment.