From 58300e58e0760350386140746d5e2c8e9b65b387 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 9 Oct 2023 13:43:52 -0700 Subject: [PATCH] Fix several bugs: split AVAILABLE into EMPTY/ERASED and index window slots by intra-window index on insert --- include/cuco/detail/equal_wrapper.cuh | 8 +- .../open_addressing_ref_impl.cuh | 73 ++++++++++++------- .../cuco/detail/static_map/static_map_ref.inl | 8 +- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/include/cuco/detail/equal_wrapper.cuh b/include/cuco/detail/equal_wrapper.cuh index 1f2611c48..e42caa401 100644 --- a/include/cuco/detail/equal_wrapper.cuh +++ b/include/cuco/detail/equal_wrapper.cuh @@ -26,7 +26,7 @@ namespace detail { /** * @brief Enum of equality comparison results. */ -enum class equal_result : int32_t { UNEQUAL = 0, AVAILABLE = 1, EQUAL = 2 }; +enum class equal_result : int32_t { UNEQUAL = 0, EMPTY = 1, EQUAL = 2, ERASED = 3 }; /** * @brief Key equality wrapper. @@ -92,10 +92,8 @@ struct equal_wrapper { template __device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept { - return cuco::detail::bitwise_compare(lhs, empty_sentinel_) or - cuco::detail::bitwise_compare(lhs, erased_sentinel_) - ? equal_result::AVAILABLE - : this->equal_to(lhs, rhs); + return cuco::detail::bitwise_compare(lhs, empty_sentinel_) ? 
equal_result::EMPTY + : this->equal_to(lhs, rhs); } }; diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 28f11fe3a..759594ef8 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -283,7 +283,9 @@ class open_addressing_ref_impl { // If the key is already in the container, return false if (eq_res == detail::equal_result::EQUAL) { return false; } - if (eq_res == detail::equal_result::AVAILABLE) { + if (eq_res == detail::equal_result::EMPTY or + cuco::detail::bitwise_compare(this->extract_key(slot_content), + this->erased_key_sentinel())) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, slot_content, @@ -321,11 +323,19 @@ class open_addressing_ref_impl { auto const [state, intra_window_index] = [&]() { for (auto i = 0; i < window_size; ++i) { switch (this->predicate_(this->extract_key(window_slots[i]), key)) { - case detail::equal_result::AVAILABLE: - return window_probing_results{detail::equal_result::AVAILABLE, i}; + case detail::equal_result::EMPTY: + return window_probing_results{detail::equal_result::EMPTY, i}; case detail::equal_result::EQUAL: return window_probing_results{detail::equal_result::EQUAL, i}; - default: continue; + default: { + continue; + if (cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), + this->erased_key_sentinel())) { + return window_probing_results{detail::equal_result::ERASED, i}; + } else { + continue; + } + } } } // returns dummy index `-1` for UNEQUAL @@ -335,14 +345,14 @@ class open_addressing_ref_impl { // If the key is already in the container, return false if (group.any(state == detail::equal_result::EQUAL)) { return false; } - auto const group_contains_available = group.ballot(state == 
detail::equal_result::AVAILABLE); - + auto const group_contains_available = + group.ballot(state == detail::equal_result::EMPTY or state == detail::equal_result::ERASED); if (group_contains_available) { auto const src_lane = __ffs(group_contains_available) - 1; auto const status = (group.thread_rank() == src_lane) ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, - window_slots[src_lane], + window_slots[intra_window_index], value) : insert_result::CONTINUE; @@ -388,7 +398,9 @@ class open_addressing_ref_impl { // If the key is already in the container, return false if (eq_res == detail::equal_result::EQUAL) { return {iterator{&window_ptr[i]}, false}; } - if (eq_res == detail::equal_result::AVAILABLE) { + if (eq_res == detail::equal_result::EMPTY or + cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), + this->erased_key_sentinel())) { switch ([&]() { if constexpr (sizeof(value_type) <= 8) { return packed_cas(window_ptr + i, window_slots[i], value); @@ -438,11 +450,19 @@ class open_addressing_ref_impl { auto const [state, intra_window_index] = [&]() { for (auto i = 0; i < window_size; ++i) { switch (this->predicate_(this->extract_key(window_slots[i]), key)) { - case detail::equal_result::AVAILABLE: - return window_probing_results{detail::equal_result::AVAILABLE, i}; + case detail::equal_result::EMPTY: + return window_probing_results{detail::equal_result::EMPTY, i}; case detail::equal_result::EQUAL: return window_probing_results{detail::equal_result::EQUAL, i}; - default: continue; + default: { + continue; + if (cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), + this->erased_key_sentinel())) { + return window_probing_results{detail::equal_result::ERASED, i}; + } else { + continue; + } + } } } // returns dummy index `-1` for UNEQUAL @@ -459,16 +479,17 @@ class open_addressing_ref_impl { return {iterator{reinterpret_cast(res)}, false}; } - auto const group_contains_available = group.ballot(state == 
detail::equal_result::AVAILABLE); + auto const group_contains_available = + group.ballot(state == detail::equal_result::EMPTY or state == detail::equal_result::ERASED); if (group_contains_available) { auto const src_lane = __ffs(group_contains_available) - 1; auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); - auto const status = [&]() { + auto const status = [&, target_idx = intra_window_index]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, window_slots[src_lane], value); + return packed_cas(slot_ptr, window_slots[target_idx], value); } else { - return cas_dependent_write(slot_ptr, window_slots[src_lane], value); + return cas_dependent_write(slot_ptr, window_slots[target_idx], value); } }(); @@ -510,7 +531,7 @@ class open_addressing_ref_impl { auto const eq_res = this->predicate_(this->extract_key(slot_content), key); // Key doesn't exist, return false - if (eq_res == detail::equal_result::AVAILABLE) { return false; } + if (eq_res == detail::equal_result::EMPTY) { return false; } // Key exists, return true if successfully deleted if (eq_res == detail::equal_result::EQUAL) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); @@ -549,8 +570,8 @@ class open_addressing_ref_impl { auto const [state, intra_window_index] = [&]() { for (auto i = 0; i < window_size; ++i) { switch (this->predicate_(this->extract_key(window_slots[i]), key)) { - case detail::equal_result::AVAILABLE: - return window_probing_results{detail::equal_result::AVAILABLE, i}; + case detail::equal_result::EMPTY: + return window_probing_results{detail::equal_result::EMPTY, i}; case detail::equal_result::EQUAL: return window_probing_results{detail::equal_result::EQUAL, i}; default: continue; @@ -575,7 +596,7 @@ class open_addressing_ref_impl { case insert_result::DUPLICATE: return false; default: continue; } - } else if (group.any(state == 
detail::equal_result::AVAILABLE)) { + } else if (group.any(state == detail::equal_result::EMPTY)) { // Key doesn't exist, return false return false; } else { @@ -609,7 +630,7 @@ class open_addressing_ref_impl { for (auto& slot_content : window_slots) { switch (this->predicate_(this->extract_key(slot_content), key)) { case detail::equal_result::UNEQUAL: continue; - case detail::equal_result::AVAILABLE: return false; + case detail::equal_result::EMPTY: return false; case detail::equal_result::EQUAL: return true; } } @@ -642,7 +663,7 @@ class open_addressing_ref_impl { auto const state = [&]() { for (auto& slot : window_slots) { switch (this->predicate_(this->extract_key(slot), key)) { - case detail::equal_result::AVAILABLE: return detail::equal_result::AVAILABLE; + case detail::equal_result::EMPTY: return detail::equal_result::EMPTY; case detail::equal_result::EQUAL: return detail::equal_result::EQUAL; default: continue; } @@ -651,7 +672,7 @@ class open_addressing_ref_impl { }(); if (group.any(state == detail::equal_result::EQUAL)) { return true; } - if (group.any(state == detail::equal_result::AVAILABLE)) { return false; } + if (group.any(state == detail::equal_result::EMPTY)) { return false; } ++probing_iter; } @@ -681,7 +702,7 @@ class open_addressing_ref_impl { for (auto i = 0; i < window_size; ++i) { switch (this->predicate_(this->extract_key(window_slots[i]), key)) { - case detail::equal_result::AVAILABLE: { + case detail::equal_result::EMPTY: { return this->end(); } case detail::equal_result::EQUAL: { @@ -719,8 +740,8 @@ class open_addressing_ref_impl { auto const [state, intra_window_index] = [&]() { for (auto i = 0; i < window_size; ++i) { switch (this->predicate_(this->extract_key(window_slots[i]), key)) { - case detail::equal_result::AVAILABLE: - return window_probing_results{detail::equal_result::AVAILABLE, i}; + case detail::equal_result::EMPTY: + return window_probing_results{detail::equal_result::EMPTY, i}; case detail::equal_result::EQUAL: return 
window_probing_results{detail::equal_result::EQUAL, i}; default: continue; @@ -741,7 +762,7 @@ class open_addressing_ref_impl { } // Find an empty slot, meaning that the probe key isn't present in the container - if (group.any(state == detail::equal_result::AVAILABLE)) { return this->end(); } + if (group.any(state == detail::equal_result::EMPTY)) { return this->end(); } ++probing_iter; } diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 5b7a075a4..20d031c0d 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -256,7 +256,7 @@ class operator_impl< value.second); return; } - if (eq_res == detail::equal_result::AVAILABLE) { + if (eq_res == detail::equal_result::EMPTY) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); if (attempt_insert_or_assign( (storage_ref.data() + *probing_iter)->data() + intra_window_index, value)) { @@ -293,8 +293,8 @@ class operator_impl< auto const [state, intra_window_index] = [&]() { for (auto i = 0; i < window_size; ++i) { switch (ref_.impl_.predicate()(window_slots[i].first, key)) { - case detail::equal_result::AVAILABLE: - return detail::window_probing_results{detail::equal_result::AVAILABLE, i}; + case detail::equal_result::EMPTY: + return detail::window_probing_results{detail::equal_result::EMPTY, i}; case detail::equal_result::EQUAL: return detail::window_probing_results{detail::equal_result::EQUAL, i}; default: continue; @@ -316,7 +316,7 @@ class operator_impl< return; } - auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE); + auto const group_contains_available = group.ballot(state == detail::equal_result::EMPTY); if (group_contains_available) { auto const src_lane = __ffs(group_contains_available) - 1; auto const status =