Skip to content

Commit

Permalink
Add CG insert_and_find tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Nov 13, 2023
1 parent 845e977 commit d295ecd
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 29 deletions.
34 changes: 34 additions & 0 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,12 @@ class open_addressing_ref_impl {
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
#if __CUDA_ARCH__ < 700
// Spinning to ensure that the write to the value part took place requires
// independent thread scheduling introduced with the Volta architecture.
static_assert(cuco::detail::is_packable<value_type>(),
"insert_and_find is not supported for unpackable data on pre-Volta GPUs.");
#endif

auto const key = this->extract_key(value);
auto probing_iter = probing_scheme_(key, storage_ref_.window_extent());
Expand All @@ -399,6 +405,7 @@ class open_addressing_ref_impl {
// If the key is already in the container, return false
if (eq_res == detail::equal_result::EQUAL) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second);
}
return {iterator{&window_ptr[i]}, false};
Expand All @@ -418,12 +425,14 @@ class open_addressing_ref_impl {
switch (res) {
case insert_result::SUCCESS: {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second);
}
return {iterator{&window_ptr[i]}, true};
}
case insert_result::DUPLICATE: {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second);
}
return {iterator{&window_ptr[i]}, false};
Expand Down Expand Up @@ -455,6 +464,13 @@ class open_addressing_ref_impl {
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, Value const& value) noexcept
{
#if __CUDA_ARCH__ < 700
// Spinning to ensure that the write to the value part took place requires
// independent thread scheduling introduced with the Volta architecture.
static_assert(cuco::detail::is_packable<value_type>(),
"insert_and_find is not supported for unpackable data on pre-Volta GPUs.");
#endif

auto const key = this->extract_key(value);
auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent());

Expand Down Expand Up @@ -489,6 +505,12 @@ class open_addressing_ref_impl {
if (group_finds_equal) {
auto const src_lane = __ffs(group_finds_equal) - 1;
auto const res = group.shfl(reinterpret_cast<intptr_t>(slot_ptr), src_lane);
if (group.thread_rank() == src_lane) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second);
}
}
return {iterator{reinterpret_cast<value_type*>(res)}, false};
}

Expand All @@ -508,9 +530,21 @@ class open_addressing_ref_impl {

switch (group.shfl(status, src_lane)) {
case insert_result::SUCCESS: {
if (group.thread_rank() == src_lane) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second);
}
}
return {iterator{reinterpret_cast<value_type*>(res)}, true};
}
case insert_result::DUPLICATE: {
if (group.thread_rank() == src_lane) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second);
}
}
return {iterator{reinterpret_cast<value_type*>(res)}, false};
}
default: continue;
Expand Down
96 changes: 67 additions & 29 deletions tests/static_map/insert_and_find_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#include <catch2/catch_template_test_macros.hpp>

static constexpr int Iters = 10;
static constexpr int Iters = 10'000;

template <typename Ref>
__global__ void parallel_sum(Ref v)
Expand All @@ -36,15 +36,27 @@ __global__ void parallel_sum(Ref v)
if constexpr (cuco::detail::is_packable<Ref::value_type>())
#endif
{
auto [iter, inserted] = v.insert_and_find(cuco::pair{i, 1});
// for debugging...
// if (iter->second < 0) {
// asm("trap;");
// }
if (!inserted) {
auto ref =
cuda::atomic_ref<typename Ref::mapped_type, cuda::thread_scope_device>{iter->second};
ref.fetch_add(1);
auto constexpr cg_size = Ref::cg_size;
if constexpr (cg_size == 1) {
auto [iter, inserted] = v.insert_and_find(cuco::pair{i, 1});
// for debugging...
// if (iter->second < 0) {
// asm("trap;");
// }
if (!inserted) {
auto ref =
cuda::atomic_ref<typename Ref::mapped_type, cuda::thread_scope_device>{iter->second};
ref.fetch_add(1);
}
} else {
auto const tile =
cooperative_groups::tiled_partition<cg_size>(cooperative_groups::this_thread_block());
auto [iter, inserted] = v.insert_and_find(tile, cuco::pair{i, 1});
if (!inserted and tile.thread_rank() == 0) {
auto ref =
cuda::atomic_ref<typename Ref::mapped_type, cuda::thread_scope_device>{iter->second};
ref.fetch_add(1);
}
}
}
#if __CUDA_ARCH__ < 700
Expand All @@ -55,36 +67,62 @@ __global__ void parallel_sum(Ref v)
}
}

TEMPLATE_TEST_CASE_SIG("Parallel insert-or-update",
"",
((typename Key, typename Value), Key, Value),
(int32_t, int32_t),
(int32_t, int64_t),
(int64_t, int32_t),
(int64_t, int64_t))
TEMPLATE_TEST_CASE_SIG(
"static_map insert_and_find tests",
"",
((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize),
Key,
Value,
Probe,
CGSize),
(int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
(int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
(int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
(int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
(int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2))
{
cuco::empty_key<Key> empty_key_sentinel{-1};
cuco::empty_value<Value> empty_value_sentinel{-1};
auto m = cuco::experimental::static_map{
10 * Iters,
empty_key_sentinel,
empty_value_sentinel,
thrust::equal_to<Key>{},
cuco::experimental::linear_probing<1, cuco::murmurhash3_32<Key>>{}};
using probe =
std::conditional_t<Probe == cuco::test::probe_sequence::linear_probing,
cuco::experimental::linear_probing<CGSize, cuco::murmurhash3_32<Key>>,
cuco::experimental::double_hashing<CGSize,
cuco::murmurhash3_32<Key>,
cuco::murmurhash3_32<Key>>>;

auto map = cuco::experimental::static_map<Key,
Value,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<Key>,
probe,
cuco::cuda_allocator<std::byte>,
cuco::experimental::storage<2>>{
10 * Iters, cuco::empty_key<Key>{-1}, cuco::empty_value<Value>{-1}};

static constexpr int Blocks = 1024;
static constexpr int Threads = 128;

parallel_sum<<<Blocks, Threads>>>(
m.ref(cuco::experimental::op::insert, cuco::experimental::op::insert_and_find));
map.ref(cuco::experimental::op::insert, cuco::experimental::op::insert_and_find));
CUCO_CUDA_TRY(cudaDeviceSynchronize());

thrust::device_vector<Key> d_keys(Iters);
thrust::device_vector<Value> d_values(Iters);

thrust::sequence(thrust::device, d_keys.begin(), d_keys.end());
m.find(d_keys.begin(), d_keys.end(), d_values.begin());
map.find(d_keys.begin(), d_keys.end(), d_values.begin());

REQUIRE(cuco::test::all_of(
d_values.begin(), d_values.end(), [] __device__(Value v) { return v == Blocks * Threads; }));
REQUIRE(cuco::test::all_of(d_values.begin(), d_values.end(), [] __device__(Value v) {
return v == (Blocks * Threads) / CGSize;
}));
}

0 comments on commit d295ecd

Please sign in to comment.