Skip to content

Commit

Permalink
Merge pull request #39 from cwpearson/mpi/allgather
Browse files Browse the repository at this point in the history
mpi: contiguous MPI_Allgather
  • Loading branch information
cwpearson authored May 29, 2024
2 parents 799dabf + cd8f4ff commit fe1b55b
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 0 deletions.
16 changes: 16 additions & 0 deletions docs/api/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ Collective
:tparam RecvView: A Kokkos::View to recv
:tparam ExecSpace: A Kokkos execution space to operate in


.. cpp:function:: template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView> \
void KokkosComm::allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_Comm comm)

MPI_Allgather wrapper

:param space: The execution space to operate in
:param sv: The data to send
:param rv: The view to receive into
:param comm: the MPI communicator
:tparam SendView: A Kokkos::View to send. Contiguous and rank less than 2.
:tparam RecvView: A Kokkos::View to recv. Contiguous and rank 1.
:tparam ExecSpace: A Kokkos execution space to operate in

If ``sv`` is a rank-0 view, the value from the jth rank will be placed in index j of ``rv``.

Related Types
-------------

Expand Down
6 changes: 6 additions & 0 deletions src/KokkosComm_collective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "KokkosComm_concepts.hpp"
#include "KokkosComm_alltoall.hpp"
#include "KokkosComm_reduce.hpp"
#include "KokkosComm_allgather.hpp"

namespace KokkosComm {

Expand All @@ -29,4 +30,9 @@ void reduce(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_
return Impl::reduce(space, sv, rv, op, root, comm);
}

template <KokkosView SendView, KokkosView RecvView, KokkosExecutionSpace ExecSpace>
void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_Comm comm) {
return Impl::allgather(space, sv, rv, comm);
}

} // namespace KokkosComm
70 changes: 70 additions & 0 deletions src/impl/KokkosComm_allgather.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#pragma once

#include <Kokkos_Core.hpp>

#include "KokkosComm_pack_traits.hpp"
#include "KokkosComm_traits.hpp"

// impl
#include "KokkosComm_include_mpi.hpp"
#include "KokkosComm_types.hpp"

namespace KokkosComm::Impl {

template <KokkosView SendView, KokkosView RecvView>
void allgather(const SendView &sv, const RecvView &rv, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::allgather");

using ST = KokkosComm::Traits<SendView>;
using RT = KokkosComm::Traits<RecvView>;
using SendScalar = typename SendView::value_type;
using RecvScalar = typename RecvView::value_type;

static_assert(ST::rank() <= 1, "allgather for SendView::rank > 1 not supported");
static_assert(RT::rank() <= 1, "allgather for RecvView::rank > 1 not supported");

if (!ST::is_contiguous(sv)) {
throw std::runtime_error("low-level allgather requires contiguous send view");
}
if (!RT::is_contiguous(rv)) {
throw std::runtime_error("low-level allgather requires contiguous recv view");
}
const int count = ST::span(sv); // all ranks send/recv same count
MPI_Allgather(ST::data_handle(sv), count, mpi_type_v<SendScalar>, RT::data_handle(rv), count, mpi_type_v<RecvScalar>,
comm);

Kokkos::Tools::popRegion();
}

template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::allgather");
using SPT = KokkosComm::PackTraits<SendView>;
using RPT = KokkosComm::PackTraits<RecvView>;

if (SPT::needs_pack(sv) || RPT::needs_pack(rv)) {
throw std::runtime_error("allgather for non-contiguous views not implemented");
} else {
space.fence(); // work in space may have been used to produce send view data
allgather(sv, rv, comm);
}

Kokkos::Tools::popRegion();
}
} // namespace KokkosComm::Impl
1 change: 1 addition & 0 deletions unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ add_executable(test-main test_main.cpp
test_barrier.cpp
test_alltoall.cpp
test_reduce.cpp
test_allgather.cpp
)
target_link_libraries(test-main KokkosComm::KokkosComm gtest)
if(KOKKOSCOMM_ENABLE_TESTS)
Expand Down
82 changes: 82 additions & 0 deletions unit_tests/test_allgather.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <gtest/gtest.h>

#include "KokkosComm.hpp"

template <typename T>
class Allgather : public testing::Test {
public:
using Scalar = T;
};

using ScalarTypes = ::testing::Types<int, int64_t, float, double, Kokkos::complex<float>, Kokkos::complex<double>>;
TYPED_TEST_SUITE(Allgather, ScalarTypes);

TYPED_TEST(Allgather, 0D) {
using TestScalar = typename TestFixture::Scalar;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);

const int nContrib = 10;

Kokkos::View<TestScalar> sv("sv");
Kokkos::View<TestScalar *> rv("rv", size);

// fill send buffer
Kokkos::parallel_for(
sv.extent(0), KOKKOS_LAMBDA(const int i) { sv() = rank; });

KokkosComm::allgather(Kokkos::DefaultExecutionSpace(), sv, rv, MPI_COMM_WORLD);

int errs;
Kokkos::parallel_reduce(
rv.extent(0), KOKKOS_LAMBDA(const int &src, int &lsum) { lsum += rv(src) != src; }, errs);
EXPECT_EQ(errs, 0);
}

TYPED_TEST(Allgather, 1D_contig) {
using TestScalar = typename TestFixture::Scalar;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);

const int nContrib = 10;

Kokkos::View<TestScalar *> sv("sv", nContrib);
Kokkos::View<TestScalar *> rv("rv", size * nContrib);

// fill send buffer
Kokkos::parallel_for(
sv.extent(0), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; });

KokkosComm::allgather(Kokkos::DefaultExecutionSpace(), sv, rv, MPI_COMM_WORLD);

int errs;
Kokkos::parallel_reduce(
rv.extent(0),
KOKKOS_LAMBDA(const int &i, int &lsum) {
const int src = i / nContrib;
const int j = i % nContrib;
lsum += rv(i) != src + j;
},
errs);
EXPECT_EQ(errs, 0);
}

0 comments on commit fe1b55b

Please sign in to comment.