Skip to content

Commit

Permalink
ishmem, working distributed_vector put and get and most of its tests (o…
Browse files Browse the repository at this point in the history
  • Loading branch information
lslusarczyk authored Nov 13, 2023
1 parent 8a031f5 commit 4848345
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 93 deletions.
8 changes: 2 additions & 6 deletions benchmarks/gbench/mhp/mhp-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ int main(int argc, char *argv[]) {
exit(0);
}

std::ofstream *logfile = nullptr;
std::unique_ptr<std::ofstream> logfile;
if (options.count("log")) {
logfile = new std::ofstream(fmt::format("dr.{}.log", comm_rank));
logfile.reset(new std::ofstream(fmt::format("dr.{}.log", comm_rank)));
dr::drlog.set_file(*logfile);
}
dr::drlog.debug("Rank: {}\n", comm_rank);
Expand Down Expand Up @@ -130,10 +130,6 @@ int main(int argc, char *argv[]) {
}
benchmark::Shutdown();

if (logfile) {
delete logfile;
}

dr::mhp::finalize();
MPI_Finalize();

Expand Down
11 changes: 10 additions & 1 deletion include/dr/detail/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

#pragma once

#ifdef DRISHMEM
#include "ishmem.h"
#endif

namespace dr {

class communicator {
Expand Down Expand Up @@ -32,7 +36,12 @@ class communicator {

MPI_Comm mpi_comm() const { return mpi_comm_; }

void barrier() const { MPI_Barrier(mpi_comm_); }
void barrier() const {
#ifdef DRISHMEM
ishmem_barrier_all();
#endif
MPI_Barrier(mpi_comm_);
}

void bcast(void *src, std::size_t count, std::size_t root) const {
MPI_Bcast(src, count, MPI_BYTE, root, mpi_comm_);
Expand Down
17 changes: 16 additions & 1 deletion include/dr/mhp/containers/distributed_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,15 @@ template <typename T> class distributed_vector {
~distributed_vector() {
if (!finalized()) {
fence();
#ifdef DRISHMEM
drlog.debug("calling ishmem_free({})\n", static_cast<void *>(data_));
ishmem_free(data_);
#else
active_wins().erase(win_.mpi_win());
win_.free();
__detail::allocator<T>().deallocate(data_, data_size_);
data_ = nullptr;
#endif
delete halo_;
}
}
Expand Down Expand Up @@ -167,7 +172,14 @@ template <typename T> class distributed_vector {

data_size_ = segment_size_ + hb.prev + hb.next;
if (size_ > 0) {

#ifdef DRISHMEM
data_ = static_cast<T *>(ishmem_malloc(data_size_));
drlog.debug("called ishmem_malloc({}) -> got:{}\n", data_size_,
static_cast<void *>(data_));
#else
data_ = __detail::allocator<T>().allocate(data_size_);
#endif
}

halo_ = new span_halo<T>(default_comm(), data_, data_size_, hb);
Expand All @@ -177,9 +189,10 @@ template <typename T> class distributed_vector {
segments_.emplace_back(this, segment_index++,
std::min(segment_size_, size - i));
}

#ifndef DRISHMEM
win_.create(default_comm(), data_, data_size_ * sizeof(T));
active_wins().insert(win_.mpi_win());
#endif
fence();
}

Expand All @@ -193,7 +206,9 @@ template <typename T> class distributed_vector {
distribution distribution_;
std::size_t size_;
std::vector<dv_segment<distributed_vector>> segments_;
#ifndef DRISHMEM
dr::rma_window win_;
#endif
};

template <typename T> auto &halo(const distributed_vector<T> &dv) {
Expand Down
30 changes: 29 additions & 1 deletion include/dr/mhp/containers/segment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,19 @@ template <typename DV> class dv_segment_iterator {
assert(dv_ != nullptr);
assert(segment_index_ * dv_->segment_size_ + index_ < dv_->size());
auto segment_offset = index_ + dv_->distribution_.halo().prev;
#ifdef DRISHMEM
drlog.debug("calling ishmem_getmem(dst:{}, src:{} (= dv:{} + "
"segm_offset:{}), size:{}, peer:{})\n",
static_cast<void *>(dst),
static_cast<void const *>(this->dv_->data_ + segment_offset),
static_cast<void const *>(this->dv_->data_), segment_offset,
size * sizeof(*dst), segment_index_);
ishmem_getmem(dst, this->dv_->data_ + segment_offset, size * sizeof(*dst),
segment_index_);
#else
dv_->win_.get(dst, size * sizeof(*dst), segment_index_,
segment_offset * sizeof(*dst));
#endif
}

value_type get() const {
Expand All @@ -143,8 +154,19 @@ template <typename DV> class dv_segment_iterator {
auto segment_offset = index_ + dv_->distribution_.halo().prev;
dr::drlog.debug("dv put:: ({}:{}:{})\n", segment_index_, segment_offset,
size);
#ifdef DRISHMEM
drlog.debug("calling ishmem_putmem(dst:{} (= dv:{} + segm_offset:{}), "
"src:{}, size:{}, peer:{})\n",
static_cast<void *>(this->dv_->data_ + segment_offset),
static_cast<void *>(this->dv_->data_), segment_offset,
static_cast<void const *>(dst), size * sizeof(*dst),
segment_index_);
ishmem_putmem(this->dv_->data_ + segment_offset, dst, size * sizeof(*dst),
segment_index_);
#else
dv_->win_.put(dst, size * sizeof(*dst), segment_index_,
segment_offset * sizeof(*dst));
#endif
}

void put(const value_type &value) const { put(&value, 1); }
Expand All @@ -158,7 +180,13 @@ template <typename DV> class dv_segment_iterator {
#ifndef SYCL_LANGUAGE_VERSION
assert(dv_ != nullptr);
#endif
const auto my_process_segment_index = dv_->win_.communicator().rank();
const auto my_process_segment_index =
#ifdef DRISHMEM
ishmem_my_pe();
drlog.debug("called ishmem_my_pe() -> {}\n", my_process_segment_index);
#else
dv_->win_.communicator().rank();
#endif

if (my_process_segment_index == segment_index_)
return dv_->data_ + index_ + dv_->distribution_.halo().prev;
Expand Down
5 changes: 5 additions & 0 deletions include/dr/mhp/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ inline void initialize_mpi() {
}

#ifdef DRISHMEM
drlog.debug("calling ishmem_init()\n");
ishmem_init();
#endif
}
Expand All @@ -89,6 +90,7 @@ inline void finalize_mpi() {
}

#ifdef DRISHMEM
drlog.debug("calling ishmem_finalize()\n");
ishmem_finalize();
#endif
}
Expand All @@ -109,6 +111,9 @@ inline auto use_sycl() { return __detail::gcontext()->use_sycl_; }

inline void fence() {
dr::drlog.debug("fence\n");
#ifdef DRISHMEM
ishmem_barrier_all();
#endif
for (auto win : __detail::gcontext()->wins_) {
MPI_Win_fence(0, win);
}
Expand Down
22 changes: 20 additions & 2 deletions test/gtest/common/distributed_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,31 @@ TYPED_TEST(DistributedVectorAllTypes, StaticAsserts) {
static_assert(dr::distributed_contiguous_range<decltype(dv)>);
}

// gtest support
TYPED_TEST(DistributedVectorAllTypes, getAndPut) {
TypeParam dv(10);

if (comm_rank == 0) {
dv[5] = 13;
}
barrier();

for (std::size_t idx = 0; idx < 10; ++idx) {
auto val = dv[idx];
if (idx == 5) {
EXPECT_EQ(val, 13);
} else {
EXPECT_NE(val, 13);
}
}
}

TYPED_TEST(DistributedVectorAllTypes, Stream) {
Ops1<TypeParam> ops(10);
std::ostringstream os;
os << ops.dist_vec;
EXPECT_EQ(os.str(), "{ 100, 101, 102, 103, 104, 105, 106, 107, 108, 109 }");
}

// gtest support
TYPED_TEST(DistributedVectorAllTypes, Equality) {
Ops1<TypeParam> ops(10);
iota(ops.dist_vec, 100);
Expand Down Expand Up @@ -69,6 +85,7 @@ TEST(DistributedVector, ConstructorFill) {
EXPECT_EQ(local_vec, dist_vec);
}

#ifndef DRISHMEM
TEST(DistributedVector, ConstructorBasicAOS) {
OpsAOS ops(10);
EXPECT_EQ(ops.vec, ops.dist_vec);
Expand All @@ -81,3 +98,4 @@ TEST(DistributedVector, ConstructorFillAOS) {

EXPECT_EQ(local_vec, dist_vec);
}
#endif
2 changes: 2 additions & 0 deletions test/gtest/include/common-tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct AOS_Struct {
int first, second;
};

#ifndef DRISHMEM
struct OpsAOS {

using dist_vec_type = xhp::distributed_vector<AOS_Struct>;
Expand All @@ -39,6 +40,7 @@ inline std::ostream &operator<<(std::ostream &os, const AOS_Struct &st) {
os << "[ " << st.first << " " << st.second << " ]";
return os;
}
#endif

template <typename T> struct Ops1 {
Ops1(std::size_t n) : dist_vec(n), vec(n) {
Expand Down
48 changes: 42 additions & 6 deletions test/gtest/mhp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,55 @@ endforeach()

if(ENABLE_ISHMEM)

foreach(shmem-test-name IN ITEMS mst msts)
add_executable(${shmem-test-name} shmem_basic.cpp) # mhp-tests.cpp
# distributed_vector.cpp
# mst stands for mhp-shmem-test
foreach(shmem-test-name IN ITEMS mst-basic mst)
add_executable(${shmem-test-name} mhp-tests.cpp)
target_link_libraries(
${shmem-test-name} GTest::gtest_main cxxopts DR::mpi ze_loader # fabric
pmi_simple sma)
target_link_libraries(${shmem-test-name}
${CMAKE_BINARY_DIR}/lib/libishmem.a)
target_compile_definitions(${shmem-test-name} PRIVATE DRISHMEM)
add_shmem_test(${shmem-test-name}-1 ${shmem-test-name} 1 --sycl)
add_shmem_test(${shmem-test-name}-2 ${shmem-test-name} 2 --sycl)
add_shmem_test(${shmem-test-name}-1 ${shmem-test-name} 1 --sycl --log)
add_shmem_test(${shmem-test-name}-2 ${shmem-test-name} 2 --sycl --log)
endforeach()
target_compile_definitions(msts PRIVATE STANDALONE_TEST)
target_sources(mst-basic PRIVATE shmem_basic.cpp)
# cmake-format: off
target_sources(
mst
PRIVATE
../common/all.cpp
../common/copy.cpp
../common/counted.cpp
../common/distributed_vector.cpp
../common/drop.cpp
../common/enumerate.cpp
# ../common/exclusive_scan.cpp # segfault
../common/fill.cpp
../common/for_each.cpp
# ../common/inclusive_scan.cpp # doesn't end
../common/iota.cpp
../common/iota_view.cpp
../common/reduce.cpp
# ../common/sort.cpp # not compile
../common/subrange.cpp
../common/sycl_utils.cpp
../common/take.cpp
../common/transform.cpp
../common/transform_view.cpp
../common/zip.cpp
../common/zip_local.cpp
alignment.cpp
communicator.cpp
copy.cpp
distributed_vector.cpp
# halo.cpp # segfault mdstar.cpp # segfault in Mdspan.GridLocalReference
# mhpsort.cpp # not compile
reduce.cpp
stencil.cpp
# segments.cpp slide_view.cpp # not compile wave_kernel.cpp # test fails
)
# cmake-format: on

endif()

Expand Down
8 changes: 2 additions & 6 deletions test/gtest/mhp/mhp-tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,15 @@ int main(int argc, char *argv[]) {
}

dr_init();
std::ofstream *logfile = nullptr;
std::unique_ptr<std::ofstream> logfile;
if (options.count("log")) {
logfile = new std::ofstream(fmt::format("dr.{}.log", comm_rank));
logfile.reset(new std::ofstream(fmt::format("dr.{}.log", comm_rank)));
dr::drlog.set_file(*logfile);
}
dr::drlog.debug("Rank: {}\n", comm_rank);

auto res = RUN_ALL_TESTS();

if (logfile) {
delete logfile;
}

dr::mhp::finalize();
MPI_Finalize();

Expand Down
Loading

0 comments on commit 4848345

Please sign in to comment.