Skip to content

Commit

Permalink
Fix bugs in insert_and_find (#389)
Browse files Browse the repository at this point in the history
This PR fixes bugs in the new `insert_and_find` implementation where the
function should not return before the payload is updated. It also
migrates the related tests to test the new map.

---------

Co-authored-by: Daniel Juenger <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 18, 2023
1 parent 9018f69 commit 9966541
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 27 deletions.
79 changes: 76 additions & 3 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ class open_addressing_ref_impl {
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
#if __CUDA_ARCH__ < 700
// Spinning to ensure that the write to the value part took place requires
// independent thread scheduling introduced with the Volta architecture.
static_assert(
cuco::detail::is_packable<value_type>(),
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const key = this->extract_key(value);
auto probing_iter = probing_scheme_(key, storage_ref_.window_extent());
Expand All @@ -397,21 +404,36 @@ class open_addressing_ref_impl {
auto* window_ptr = (storage_ref_.data() + *probing_iter)->data();

// If the key is already in the container, return false
if (eq_res == detail::equal_result::EQUAL) { return {iterator{&window_ptr[i]}, false}; }
if (eq_res == detail::equal_result::EQUAL) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second);
}
return {iterator{&window_ptr[i]}, false};
}
if (eq_res == detail::equal_result::EMPTY or
cuco::detail::bitwise_compare(this->extract_key(window_slots[i]),
this->erased_key_sentinel())) {
switch ([&]() {
auto const res = [&]() {
if constexpr (sizeof(value_type) <= 8) {
return packed_cas(window_ptr + i, window_slots[i], value);
} else {
return cas_dependent_write(window_ptr + i, window_slots[i], value);
}
}()) {
}();
switch (res) {
case insert_result::SUCCESS: {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second);
}
return {iterator{&window_ptr[i]}, true};
}
case insert_result::DUPLICATE: {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload((window_ptr + i)->second, this->empty_slot_sentinel_.second);
}
return {iterator{&window_ptr[i]}, false};
}
default: continue;
Expand Down Expand Up @@ -441,6 +463,14 @@ class open_addressing_ref_impl {
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, Value const& value) noexcept
{
#if __CUDA_ARCH__ < 700
// Spinning to ensure that the write to the value part took place requires
// independent thread scheduling introduced with the Volta architecture.
static_assert(
cuco::detail::is_packable<value_type>(),
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const key = this->extract_key(value);
auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent());

Expand Down Expand Up @@ -475,6 +505,13 @@ class open_addressing_ref_impl {
if (group_finds_equal) {
auto const src_lane = __ffs(group_finds_equal) - 1;
auto const res = group.shfl(reinterpret_cast<intptr_t>(slot_ptr), src_lane);
if (group.thread_rank() == src_lane) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second);
}
}
group.sync();
return {iterator{reinterpret_cast<value_type*>(res)}, false};
}

Expand All @@ -494,9 +531,23 @@ class open_addressing_ref_impl {

switch (group.shfl(status, src_lane)) {
case insert_result::SUCCESS: {
if (group.thread_rank() == src_lane) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second);
}
}
group.sync();
return {iterator{reinterpret_cast<value_type*>(res)}, true};
}
case insert_result::DUPLICATE: {
if (group.thread_rank() == src_lane) {
if constexpr (has_payload) {
// wait to ensure that the write to the value part also took place
this->wait_for_payload(slot_ptr->second, this->empty_slot_sentinel_.second);
}
}
group.sync();
return {iterator{reinterpret_cast<value_type*>(res)}, false};
}
default: continue;
Expand Down Expand Up @@ -1010,6 +1061,7 @@ class open_addressing_ref_impl {
// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
atomic_store(&address->second, static_cast<mapped_type>(thrust::get<1>(desired)));

return insert_result::SUCCESS;
}

Expand Down Expand Up @@ -1054,6 +1106,27 @@ class open_addressing_ref_impl {
}
}

/**
 * @brief Spins until the payload of the given slot no longer equals the sentinel value.
 *
 * @note A relaxed atomic load is sufficient here: the caller only needs to observe that the
 * writer has published *some* non-sentinel payload; `bitwise_compare` against the sentinel is
 * the publication signal.
 *
 * @tparam T Slot payload type
 *
 * @param slot Payload location to poll
 * @param sentinel Sentinel payload value indicating the write has not yet taken place
 */
template <typename T>
__device__ void wait_for_payload(T& slot, T const& sentinel) const noexcept
{
  auto payload_ref = cuda::atomic_ref<T, Scope>{slot};
  // TODO exponential backoff strategy
  while (true) {
    T const observed = payload_ref.load(cuda::std::memory_order_relaxed);
    // Stop spinning once the observed payload differs bitwise from the sentinel
    if (not cuco::detail::bitwise_compare(observed, sentinel)) { break; }
  }
}

// TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper
value_type empty_slot_sentinel_; ///< Sentinel value indicating an empty slot
detail::equal_wrapper<key_type, key_equal> predicate_; ///< Key equality binary callable
Expand Down
104 changes: 80 additions & 24 deletions tests/static_map/insert_and_find_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,52 +28,108 @@

static constexpr int Iters = 10'000;

template <typename View>
__global__ void parallel_sum(View v)
template <typename Ref>
__global__ void parallel_sum(Ref v)
{
for (int i = 0; i < Iters; i++) {
#if __CUDA_ARCH__ < 700
if constexpr (cuco::detail::is_packable<View::value_type>())
if constexpr (cuco::detail::is_packable<Ref::value_type>())
#endif
{
auto [iter, inserted] = v.insert_and_find(thrust::make_pair(i, 1));
// for debugging...
// if (iter->second < 0) {
// asm("trap;");
// }
if (!inserted) { iter->second += 1; }
auto constexpr cg_size = Ref::cg_size;
if constexpr (cg_size == 1) {
auto [iter, inserted] = v.insert_and_find(cuco::pair{i, 1});
// for debugging...
// if (iter->second < 0) {
// asm("trap;");
// }
if (!inserted) {
auto ref =
cuda::atomic_ref<typename Ref::mapped_type, cuda::thread_scope_device>{iter->second};
ref.fetch_add(1);
}
} else {
auto const tile =
cooperative_groups::tiled_partition<cg_size>(cooperative_groups::this_thread_block());
auto [iter, inserted] = v.insert_and_find(tile, cuco::pair{i, 1});
if (!inserted and tile.thread_rank() == 0) {
auto ref =
cuda::atomic_ref<typename Ref::mapped_type, cuda::thread_scope_device>{iter->second};
ref.fetch_add(1);
}
}
}
#if __CUDA_ARCH__ < 700
else {
v.insert(thrust::make_pair(i, gridDim.x * blockDim.x));
auto constexpr cg_size = Ref::cg_size;
if constexpr (cg_size == 1) {
v.insert(cuco::pair{i, gridDim.x * blockDim.x});
} else {
auto const tile =
cooperative_groups::tiled_partition<cg_size>(cooperative_groups::this_thread_block());
v.insert(tile, cuco::pair{i, gridDim.x * blockDim.x / cg_size});
}
}
#endif
}
}

TEMPLATE_TEST_CASE_SIG("Parallel insert-or-update",
"",
((typename Key, typename Value), Key, Value),
(int32_t, int32_t),
(int32_t, int64_t),
(int64_t, int32_t),
(int64_t, int64_t))
TEMPLATE_TEST_CASE_SIG(
"static_map insert_and_find tests",
"",
((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize),
Key,
Value,
Probe,
CGSize),
(int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
(int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
(int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
(int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
(int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2))
{
cuco::empty_key<Key> empty_key_sentinel{-1};
cuco::empty_value<Value> empty_value_sentinel{-1};
cuco::static_map<Key, Value> m(10 * Iters, empty_key_sentinel, empty_value_sentinel);
using probe =
std::conditional_t<Probe == cuco::test::probe_sequence::linear_probing,
cuco::experimental::linear_probing<CGSize, cuco::murmurhash3_32<Key>>,
cuco::experimental::double_hashing<CGSize,
cuco::murmurhash3_32<Key>,
cuco::murmurhash3_32<Key>>>;

auto map = cuco::experimental::static_map<Key,
Value,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<Key>,
probe,
cuco::cuda_allocator<std::byte>,
cuco::experimental::storage<2>>{
10 * Iters, cuco::empty_key<Key>{-1}, cuco::empty_value<Value>{-1}};

static constexpr int Blocks = 1024;
static constexpr int Threads = 128;
parallel_sum<<<Blocks, Threads>>>(m.get_device_mutable_view());

parallel_sum<<<Blocks, Threads>>>(
map.ref(cuco::experimental::op::insert, cuco::experimental::op::insert_and_find));
CUCO_CUDA_TRY(cudaDeviceSynchronize());

thrust::device_vector<Key> d_keys(Iters);
thrust::device_vector<Value> d_values(Iters);

thrust::sequence(thrust::device, d_keys.begin(), d_keys.end());
m.find(d_keys.begin(), d_keys.end(), d_values.begin());
map.find(d_keys.begin(), d_keys.end(), d_values.begin());

REQUIRE(cuco::test::all_of(
d_values.begin(), d_values.end(), [] __device__(Value v) { return v == Blocks * Threads; }));
REQUIRE(cuco::test::all_of(d_values.begin(), d_values.end(), [] __device__(Value v) {
return v == (Blocks * Threads) / CGSize;
}));
}

0 comments on commit 9966541

Please sign in to comment.