Skip to content

Commit

Permalink
Removes impl from ygm::container::array
Browse files Browse the repository at this point in the history
  • Loading branch information
steiltre committed Dec 21, 2023
1 parent 0baced1 commit 1bddb54
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 256 deletions.
86 changes: 48 additions & 38 deletions include/ygm/container/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,39 @@

#pragma once

#include <ygm/container/detail/array_impl.hpp>
#include <ygm/comm.hpp>
#include <ygm/container/container_traits.hpp>

namespace ygm::container {

template <typename Value, typename Index = size_t>
class array {
public:
using self_type = array<Value, Index>;
using mapped_type = Value;
using key_type = Index;
using size_type = Index;
using ygm_for_all_types = std::tuple< Index, Value >;
using ygm_container_type = ygm::container::array_tag;
using impl_type = detail::array_impl<mapped_type, key_type>;
using self_type = array<Value, Index>;
using mapped_type = Value;
using key_type = Index;
using size_type = Index;
using ygm_for_all_types = std::tuple<Index, Value>;
using ygm_container_type = ygm::container::array_tag;
using ptr_type = typename ygm::ygm_ptr<self_type>;

array() = delete;

array(ygm::comm& comm, const size_type size) : m_impl(comm, size) {}
array(ygm::comm& comm, const size_type size);

array(ygm::comm& comm, const size_type size, const mapped_type& default_value)
: m_impl(comm, size, default_value) {}
array(ygm::comm& comm, const size_type size,
const mapped_type& default_value);

array(const self_type& rhs) : m_impl(rhs.m_impl) {}
array(const self_type& rhs);

void async_set(const key_type index, const mapped_type& value) {
m_impl.async_set(index, value);
}
~array();

void async_set(const key_type index, const mapped_type& value);

template <typename BinaryOp>
void async_binary_op_update_value(const key_type index,
void async_binary_op_update_value(const key_type index,
const mapped_type& value,
const BinaryOp& b) {
m_impl.async_binary_op_update_value(index, value, b);
}
const BinaryOp& b);

void async_bit_and(const key_type index, const mapped_type& value) {
async_binary_op_update_value(index, value, std::bit_and<mapped_type>());
Expand Down Expand Up @@ -78,9 +76,7 @@ class array {
}

template <typename UnaryOp>
void async_unary_op_update_value(const key_type index, const UnaryOp& u) {
m_impl.async_unary_op_update_value(index, u);
}
void async_unary_op_update_value(const key_type index, const UnaryOp& u);

void async_increment(const key_type index) {
async_unary_op_update_value(index,
Expand All @@ -94,32 +90,46 @@ class array {

template <typename Visitor, typename... VisitorArgs>
void async_visit(const key_type index, Visitor visitor,
const VisitorArgs&... args) {
m_impl.async_visit(index, visitor,
std::forward<const VisitorArgs>(args)...);
}
const VisitorArgs&... args);

template <typename Function>
void for_all(Function fn) {
m_impl.for_all(fn);
}
void for_all(Function fn);

size_type size() { return m_impl.size(); }
size_type size();

typename ygm::ygm_ptr<impl_type> get_ygm_ptr() const {
return m_impl.get_ygm_ptr();
}
typename ygm::ygm_ptr<self_type> get_ygm_ptr() const;

int owner(const key_type index) const;

int owner(const key_type index) const { return m_impl.owner(index); }
bool is_mine(const key_type index) const;

bool is_mine(const key_type index) const { return m_impl.is_mine(index); }
ygm::comm& comm();

ygm::comm& comm() { return m_impl.comm(); }
const mapped_type& default_value() const;

const mapped_type& default_value() const { return m_impl.default_value(); }
void resize(const size_type size, const mapped_type& fill_value);

void resize(const size_type size);

private:
impl_type m_impl;
template <typename Function>
void local_for_all(Function fn);

key_type local_index(key_type index);

key_type global_index(key_type index);

private:
size_type m_global_size;
size_type m_small_block_size;
size_type m_large_block_size;
size_type m_local_start_index;
mapped_type m_default_value;
std::vector<mapped_type> m_local_vec;
ygm::comm& m_comm;
typename ygm::ygm_ptr<self_type> pthis;
};

} // namespace ygm::container

#include <ygm/container/detail/array.ipp>
230 changes: 230 additions & 0 deletions include/ygm/container/detail/array.ipp
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
// Copyright 2019-2021 Lawrence Livermore National Security, LLC and other YGM
// Project Developers. See the top-level COPYRIGHT file for details.
//
// SPDX-License-Identifier: MIT

#pragma once

namespace ygm::container {

template <typename Value, typename Index>
array<Value, Index>::array(ygm::comm &comm, const size_type size)
: m_global_size(size), m_default_value{}, m_comm(comm), pthis(this) {
pthis.check(m_comm);

resize(size);
}

template <typename Value, typename Index>
array<Value, Index>::array(ygm::comm &comm, const size_type size,
const mapped_type &dv)
: m_default_value(dv), m_comm(comm), pthis(this) {
pthis.check(m_comm);

resize(size);
}

template <typename Value, typename Index>
array<Value, Index>::array(const self_type &rhs)
: m_default_value(rhs.m_default_value),
m_comm(rhs.m_comm),
m_global_size(rhs.m_global_size),
m_small_block_size(rhs.m_small_block_size),
m_large_block_size(rhs.m_large_block_size),
m_local_start_index(rhs.m_local_start_index),
m_local_vec(rhs.m_local_vec),
pthis(this) {
pthis.check(m_comm);
}

template <typename Value, typename Index>
array<Value, Index>::~array() {
m_comm.barrier();
}

template <typename Value, typename Index>
void array<Value, Index>::resize(const size_type size,
const mapped_type &fill_value) {
m_comm.barrier();

m_global_size = size;
m_small_block_size = size / m_comm.size();
m_large_block_size = m_small_block_size + ((size / m_comm.size()) > 0);

m_local_vec.resize(
m_small_block_size + (m_comm.rank() < (size % m_comm.size())),
fill_value);

if (m_comm.rank() < (size % m_comm.size())) {
m_local_start_index = m_comm.rank() * m_large_block_size;
} else {
m_local_start_index =
(size % m_comm.size()) * m_large_block_size +
(m_comm.rank() - (size % m_comm.size())) * m_small_block_size;
}

m_comm.barrier();
}

template <typename Value, typename Index>
void array<Value, Index>::resize(const size_type size) {
resize(size, m_default_value);
}

template <typename Value, typename Index>
void array<Value, Index>::async_set(const key_type index,
const mapped_type &value) {
ASSERT_RELEASE(index < m_global_size);
auto putter = [](auto parray, const key_type i, const mapped_type &v) {
key_type l_index = parray->local_index(i);
ASSERT_RELEASE(l_index < parray->m_local_vec.size());
parray->m_local_vec[l_index] = v;
};

int dest = owner(index);
m_comm.async(dest, putter, pthis, index, value);
}

template <typename Value, typename Index>
template <typename BinaryOp>
void array<Value, Index>::async_binary_op_update_value(const key_type index,
const mapped_type &value,
const BinaryOp &b) {
ASSERT_RELEASE(index < m_global_size);
auto updater = [](const key_type i, mapped_type &v,
const mapped_type &new_value) {
BinaryOp *binary_op;
v = (*binary_op)(v, new_value);
};

async_visit(index, updater, value);
}
template <typename Value, typename Index>
template <typename UnaryOp>
void array<Value, Index>::async_unary_op_update_value(const key_type index,
const UnaryOp &u) {
ASSERT_RELEASE(index < m_global_size);
auto updater = [](const key_type i, mapped_type &v) {
UnaryOp *u;
v = (*u)(v);
};

async_visit(index, updater);
}

template <typename Value, typename Index>
template <typename Visitor, typename... VisitorArgs>
void array<Value, Index>::async_visit(const key_type index, Visitor visitor,
const VisitorArgs &...args) {
ASSERT_RELEASE(index < m_global_size);
int dest = owner(index);
auto visit_wrapper = [](auto parray, const key_type i,
const VisitorArgs &...args) {
key_type l_index = parray->local_index(i);
ASSERT_RELEASE(l_index < parray->m_local_vec.size());
mapped_type &l_value = parray->m_local_vec[l_index];
Visitor *vis = nullptr;
if constexpr (std::is_invocable<decltype(visitor), const key_type &,
mapped_type &, VisitorArgs &...>() ||
std::is_invocable<decltype(visitor), ptr_type,
const key_type &, mapped_type &,
VisitorArgs &...>()) {
ygm::meta::apply_optional(*vis, std::make_tuple(parray),
std::forward_as_tuple(i, l_value, args...));
} else {
static_assert(
ygm::detail::always_false<>,
"remote array lambda signature must be invocable with (const "
"&key_type, mapped_type&, ...) or (ptr_type, const "
"&key_type, mapped_type&, ...) signatures");
}
};

m_comm.async(dest, visit_wrapper, pthis, index,
std::forward<const VisitorArgs>(args)...);
}

template <typename Value, typename Index>
template <typename Function>
void array<Value, Index>::for_all(Function fn) {
m_comm.barrier();
local_for_all(fn);
}

template <typename Value, typename Index>
template <typename Function>
void array<Value, Index>::local_for_all(Function fn) {
if constexpr (std::is_invocable<decltype(fn), const key_type,
mapped_type &>()) {
for (int i = 0; i < m_local_vec.size(); ++i) {
key_type g_index = global_index(i);
fn(g_index, m_local_vec[i]);
}
} else if constexpr (std::is_invocable<decltype(fn), mapped_type &>()) {
std::for_each(std::begin(m_local_vec), std::end(m_local_vec), fn);
} else {
static_assert(ygm::detail::always_false<>,
"local array lambda must be invocable with (const "
"key_type, mapped_type &) or (mapped_type &) signatures");
}
}

template <typename Value, typename Index>
typename array<Value, Index>::size_type array<Value, Index>::size() {
return m_global_size;
}

template <typename Value, typename Index>
typename array<Value, Index>::ptr_type array<Value, Index>::get_ygm_ptr()
const {
return pthis;
}

template <typename Value, typename Index>
ygm::comm &array<Value, Index>::comm() {
return m_comm;
}

template <typename Value, typename Index>
const typename array<Value, Index>::mapped_type &
array<Value, Index>::default_value() const {
return m_default_value;
}

template <typename Value, typename Index>
int array<Value, Index>::owner(const key_type index) const {
int to_return;
// Owner depends on whether index is before switching to small blocks
if (index < (m_global_size % m_comm.size()) * m_large_block_size) {
to_return = index / m_large_block_size;
} else {
to_return = (m_global_size % m_comm.size()) +
(index - (m_global_size % m_comm.size()) * m_large_block_size) /
m_small_block_size;
}
ASSERT_RELEASE((to_return >= 0) && (to_return < m_comm.size()));

return to_return;
}

template <typename Value, typename Index>
bool array<Value, Index>::is_mine(const key_type index) const {
return owner(index) == m_comm.rank();
}

template <typename Value, typename Index>
typename array<Value, Index>::key_type array<Value, Index>::local_index(
const key_type index) {
key_type to_return = index - m_local_start_index;
ASSERT_RELEASE((to_return >= 0) && (to_return <= m_small_block_size));
return to_return;
}

template <typename Value, typename Index>
typename array<Value, Index>::key_type array<Value, Index>::global_index(
const key_type index) {
key_type to_return;
return m_local_start_index + index;
}

}; // namespace ygm::container
Loading

0 comments on commit 1bddb54

Please sign in to comment.