Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/array partitioning #187

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 52 additions & 30 deletions include/ygm/container/detail/array_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@
#pragma once
#include <vector>
#include <ygm/comm.hpp>
#include <ygm/container/container_traits.hpp>
#include <ygm/detail/ygm_ptr.hpp>
#include <ygm/detail/ygm_traits.hpp>
#include <ygm/container/container_traits.hpp>

namespace ygm::container::detail {

template <typename Value, typename Index>
class array_impl {
public:
using self_type = array_impl<Value, Index>;
using ptr_type = typename ygm::ygm_ptr<self_type>;
using mapped_type = Value;
using key_type = Index;
using size_type = Index;
using ygm_for_all_types = std::tuple< Index, Value >;
using ygm_container_type = ygm::container::array_tag;
using self_type = array_impl<Value, Index>;
using ptr_type = typename ygm::ygm_ptr<self_type>;
using mapped_type = Value;
using key_type = Index;
using size_type = Index;
using ygm_for_all_types = std::tuple<Index, Value>;
using ygm_container_type = ygm::container::array_tag;

array_impl(ygm::comm &comm, const size_type size)
: m_global_size(size), m_default_value{}, m_comm(comm), pthis(this) {
Expand Down Expand Up @@ -49,18 +49,21 @@ class array_impl {
void resize(const size_type size, const mapped_type &fill_value) {
m_comm.barrier();

m_global_size = size;
m_block_size = size / m_comm.size() + (size % m_comm.size() > 0);
m_global_size = size;
m_small_block_size = size / m_comm.size();
m_large_block_size = m_small_block_size + ((size / m_comm.size()) > 0);
m_comm.cout0(m_small_block_size, " : ", m_large_block_size);

if (m_comm.rank() != m_comm.size() - 1) {
m_local_vec.resize(m_block_size, fill_value);
m_local_vec.resize(
m_small_block_size + (m_comm.rank() < (size % m_comm.size())),
fill_value);

if (m_comm.rank() < (size % m_comm.size())) {
m_local_start_index = m_comm.rank() * m_large_block_size;
} else {
// Last rank may get less data
size_type block_size = m_global_size % m_block_size;
if (block_size == 0) {
block_size = m_block_size;
}
m_local_vec.resize(block_size, fill_value);
m_local_start_index =
(size % m_comm.size()) * m_large_block_size +
(m_comm.rank() - (size % m_comm.size())) * m_small_block_size;
}

m_comm.barrier();
Expand All @@ -81,9 +84,9 @@ class array_impl {
}

template <typename BinaryOp>
void async_binary_op_update_value(const key_type index,
void async_binary_op_update_value(const key_type index,
const mapped_type &value,
const BinaryOp &b) {
const BinaryOp &b) {
ASSERT_RELEASE(index < m_global_size);
auto updater = [](const key_type i, mapped_type &v,
const mapped_type &new_value) {
Expand Down Expand Up @@ -115,7 +118,7 @@ class array_impl {
key_type l_index = parray->local_index(i);
ASSERT_RELEASE(l_index < parray->m_local_vec.size());
mapped_type &l_value = parray->m_local_vec[l_index];
Visitor *vis = nullptr;
Visitor *vis = nullptr;
if constexpr (std::is_invocable<decltype(visitor), const key_type &,
mapped_type &, VisitorArgs &...>() ||
std::is_invocable<decltype(visitor), ptr_type,
Expand Down Expand Up @@ -167,28 +170,47 @@ class array_impl {

const mapped_type &default_value() const { return m_default_value; }

int owner(const key_type index) const { return index / m_block_size; }
int owner(const key_type index) const {
int to_return;
// Owner depends on whether index is before switching to small blocks
if (index < (m_global_size % m_comm.size()) * m_large_block_size) {
to_return = index / m_large_block_size;
} else {
to_return =
(m_global_size % m_comm.size()) +
(index - (m_global_size % m_comm.size()) * m_large_block_size) /
m_small_block_size;
}
ASSERT_RELEASE((to_return >= 0) && (to_return < m_comm.size()));

return to_return;
}

bool is_mine(const key_type index) const {
return owner(index) == m_comm.rank();
}

key_type local_index(const key_type index) {
return index % m_block_size;
key_type to_return = index - m_local_start_index;
ASSERT_RELEASE((to_return >= 0) && (to_return <= m_small_block_size));
return to_return;
}

key_type global_index(const key_type index) {
return m_comm.rank() * m_block_size + index;
key_type to_return;
return m_local_start_index + index;
}

protected:
array_impl() = delete;

size_type m_global_size;
size_type m_block_size;
mapped_type m_default_value;
std::vector<mapped_type> m_local_vec;
ygm::comm &m_comm;
typename ygm::ygm_ptr<self_type> pthis;
size_type m_global_size;
size_type m_small_block_size;
size_type m_large_block_size;
size_type m_local_start_index;
mapped_type m_default_value;
std::vector<mapped_type> m_local_vec;
ygm::comm &m_comm;
typename ygm::ygm_ptr<self_type> pthis;
};
} // namespace ygm::container::detail
20 changes: 17 additions & 3 deletions include/ygm/container/detail/bag.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,24 @@ void bag<Item, Alloc>::rebalance() {
// int to_send[m_comm.size()] = {0};
std::unordered_map<size_t, size_t> to_send;

auto global_size = size();
size_t small_block_size = global_size / m_comm.size();
size_t large_block_size =
global_size / m_comm.size() + ((global_size / m_comm.size()) > 0);

for (size_t i = 0; i < local_size(); i++) {
size_t idx = prefix_val + i;
size_t target_rank = idx / target_size;
size_t idx = prefix_val + i;
size_t target_rank;

// Determine target rank to match partitioning in ygm::container::array
if (idx < (global_size % m_comm.size()) * large_block_size) {
target_rank = idx / large_block_size;
} else {
target_rank = (global_size % m_comm.size()) +
(idx - (global_size % m_comm.size()) * large_block_size) /
small_block_size;
}

if (target_rank != m_comm.rank()) {
to_send[target_rank]++;
}
Expand Down Expand Up @@ -254,4 +269,3 @@ void bag<Item, Alloc>::local_for_all_pair_types(Function fn) {
}

} // namespace ygm::container

17 changes: 10 additions & 7 deletions test/test_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
int main(int argc, char **argv) {
ygm::comm world(&argc, &argv);

// Test basic tagging
// Test basic tagging
{
int size = 64;
ygm::container::array<int> arr(world, size);

static_assert(std::is_same_v< decltype(arr)::self_type, decltype(arr) >);
static_assert(std::is_same_v< decltype(arr)::mapped_type, decltype(size) >);
static_assert(std::is_same_v< decltype(arr)::key_type, size_t >);
static_assert(std::is_same_v< decltype(arr)::size_type, decltype(arr)::key_type >);
static_assert(std::is_same_v< decltype(arr)::ygm_for_all_types,
std::tuple< decltype(arr)::key_type, decltype(arr)::mapped_type > >);
static_assert(std::is_same_v<decltype(arr)::self_type, decltype(arr)>);
static_assert(std::is_same_v<decltype(arr)::mapped_type, decltype(size)>);
static_assert(std::is_same_v<decltype(arr)::key_type, size_t>);
static_assert(
std::is_same_v<decltype(arr)::size_type, decltype(arr)::key_type>);
static_assert(
std::is_same_v<
decltype(arr)::ygm_for_all_types,
std::tuple<decltype(arr)::key_type, decltype(arr)::mapped_type> >);
}

// Test async_set
Expand Down
67 changes: 37 additions & 30 deletions test/test_bag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
int main(int argc, char** argv) {
ygm::comm world(&argc, &argv);


// Test basic tagging
// Test basic tagging
{
ygm::container::bag<std::string> bbag(world);

static_assert(std::is_same_v< decltype(bbag)::self_type, decltype(bbag) >);
static_assert(std::is_same_v< decltype(bbag)::value_type, std::string >);
static_assert(std::is_same_v< decltype(bbag)::size_type, size_t >);
static_assert(std::is_same_v< decltype(bbag)::ygm_for_all_types,
std::tuple< decltype(bbag)::value_type > >);
static_assert(std::is_same_v<decltype(bbag)::self_type, decltype(bbag)>);
static_assert(std::is_same_v<decltype(bbag)::value_type, std::string>);
static_assert(std::is_same_v<decltype(bbag)::size_type, size_t>);
static_assert(std::is_same_v<decltype(bbag)::ygm_for_all_types,
std::tuple<decltype(bbag)::value_type>>);
}

//
Expand Down Expand Up @@ -56,19 +55,21 @@ int main(int argc, char** argv) {
// Test local_shuffle and global_shuffle
{
ygm::container::bag<int> bbag(world);
int num_of_items = 20;
int num_of_items = 20;
if (world.rank0()) {
for (int i = 0; i < num_of_items; i++) {
bbag.async_insert(i);
}
}
int seed = 100;
ygm::default_random_engine<> rng1 = ygm::default_random_engine<>(world, seed);
int seed = 100;
ygm::default_random_engine<> rng1 =
ygm::default_random_engine<>(world, seed);
bbag.local_shuffle(rng1);

ygm::default_random_engine<> rng2 = ygm::default_random_engine<>(world, seed);
ygm::default_random_engine<> rng2 =
ygm::default_random_engine<>(world, seed);
bbag.global_shuffle(rng2);

bbag.local_shuffle();
bbag.global_shuffle();

Expand All @@ -77,7 +78,8 @@ int main(int argc, char** argv) {
auto bag_content = bbag.gather_to_vector(0);
if (world.rank0()) {
for (int i = 0; i < num_of_items; i++) {
if (std::find(bag_content.begin(), bag_content.end(), i) == bag_content.end()) {
if (std::find(bag_content.begin(), bag_content.end(), i) ==
bag_content.end()) {
ASSERT_RELEASE(false);
}
}
Expand Down Expand Up @@ -139,7 +141,7 @@ int main(int argc, char** argv) {
{
ygm::container::bag<std::string> bbag(world);
bbag.async_insert("begin", 0);
bbag.async_insert("end", world.size()-1);
bbag.async_insert("end", world.size() - 1);
bbag.rebalance();
ASSERT_RELEASE(bbag.local_size() == 2);
}
Expand All @@ -148,18 +150,22 @@ int main(int argc, char** argv) {
// Test rebalance with non-standard rebalance sizes
{
ygm::container::bag<std::string> bbag(world);
bbag.async_insert("middle", world.size()/2);
bbag.async_insert("end", world.size()-1);
if (world.rank0())
bbag.async_insert("middle", world.size()/2);
bbag.async_insert("middle", world.size() / 2);
bbag.async_insert("end", world.size() - 1);
if (world.rank0()) bbag.async_insert("middle", world.size() / 2);
bbag.rebalance();

size_t target_size = std::ceil((bbag.size() * 1.0) / world.size());
size_t remainder = bbag.size() % target_size;
if (world.rank() != world.size() - 1)
ASSERT_RELEASE(bbag.local_size() == target_size);
else
ASSERT_RELEASE(bbag.local_size() == remainder);
size_t target_size = std::ceil((bbag.size() * 1.0) / world.size());
size_t remainder = bbag.size() % world.size();
size_t small_block_size = bbag.size() / world.size();
size_t large_block_size =
bbag.size() / world.size() + (bbag.size() % world.size() > 0);

if (world.rank() < remainder) {
ASSERT_RELEASE(bbag.local_size() == large_block_size);
} else {
ASSERT_RELEASE(bbag.local_size() == small_block_size);
}
}

//
Expand All @@ -168,18 +174,19 @@ int main(int argc, char** argv) {
ygm::container::bag<int> bbag(world);
if (world.rank0()) {
for (int i = 0; i < 100; i++) {
bbag.async_insert(i, (i*3) % world.size());
bbag.async_insert(i, (i * 3) % world.size());
}
for (int i = 100; i < 200; i++) {
bbag.async_insert(i, (i*5) % world.size());
bbag.async_insert(i, (i * 5) % world.size());
}
}
bbag.rebalance();
auto v = bbag.gather_to_vector();

auto v = bbag.gather_to_vector();
std::set<int> value_set(v.begin(), v.end());
ASSERT_RELEASE(value_set.size() == 200);
ASSERT_RELEASE(*std::min_element(value_set.begin(), value_set.end()) == 0);
ASSERT_RELEASE(*std::max_element(value_set.begin(), value_set.end()) == 199);
ASSERT_RELEASE(*std::max_element(value_set.begin(), value_set.end()) ==
199);
}
}
}