From cb170235e8ac4cd83910c8c7bc64abcb1070dfbf Mon Sep 17 00:00:00 2001 From: Yunsong Wang <yunsongw@nvidia.com> Date: Fri, 6 Oct 2023 14:50:54 -0700 Subject: [PATCH] Update attempt_insert to incorporate with erase --- .../open_addressing_ref_impl.cuh | 105 +++++++++--------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 5136ca67e..73037f424 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -286,6 +286,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::AVAILABLE) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + slot_content, value)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; @@ -341,6 +342,7 @@ class open_addressing_ref_impl { auto const status = (group.thread_rank() == src_lane) ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + window_slots[src_lane], value) : insert_result::CONTINUE; @@ -389,9 +391,9 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::AVAILABLE) { switch ([&]() { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(window_ptr + i, value); + return packed_cas(window_ptr + i, window_slots[i], value); } else { - return cas_dependent_write(window_ptr + i, value); + return cas_dependent_write(window_ptr + i, window_slots[i], value); } }()) { case insert_result::SUCCESS: { @@ -464,9 +466,9 @@ class open_addressing_ref_impl { auto const status = [&]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, value); + return packed_cas(slot_ptr, window_slots[src_lane], value); } else { - return cas_dependent_write(slot_ptr, value); + return cas_dependent_write(slot_ptr, window_slots[src_lane], value); } }(); @@ -485,25 +487,19 @@ class open_addressing_ref_impl { } } - template <typename Value, typename Predicate> - __device__ bool erase(Value const& value, Predicate const& predicate) noexcept + template <typename Value> + __device__ bool erase(Value const& value) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - auto const key = [&]() { - if constexpr (this->has_payload) { - return value.first; - } else { - return value; - } - }(); + auto const key = this->extract_key(value); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); while (true) { auto const window_slots = storage_ref_[*probing_iter]; for (auto& slot_content : window_slots) { - auto const eq_res = predicate(slot_content, key); + auto const eq_res = this->predicate_(this->extract_key(slot_content), key); // Key doesn't exist, return false if (eq_res == detail::equal_result::AVAILABLE) { return false; } @@ -517,10 +513,9 @@ class open_addressing_ref_impl { return this->erased_key_sentinel(); } }(); - switch (attempt_insert<this->has_payload>( - (storage_ref_.data() + *probing_iter)->data() + intra_window_index, - erased_slot, - predicate)) { + switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + slot_content, + erased_slot)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; } @@ -821,16 +816,19 @@ class open_addressing_ref_impl { * * @tparam Value Input type which is implicitly convertible to 'value_type' * - * @param slot Pointer to the slot in memory - * @param value Element to insert + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert * * @return Result of this operation, i.e., success/continue/duplicate */ template <typename Value> - [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, - Value const& value) noexcept + [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* address, + value_type const& expected, + Value const& desired) noexcept { - auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast<value_type>(value)); + auto old = + compare_and_swap(address, this->empty_slot_sentinel_, static_cast<value_type>(desired)); auto* old_ptr = reinterpret_cast<value_type*>(&old); auto const inserted = [&]() { if constexpr (this->has_payload) { @@ -848,10 +846,10 @@ class open_addressing_ref_impl { auto const res = [&]() { if constexpr (this->has_payload) { // If it's a map implementation, compare keys only - return this->predicate_.equal_to(old_ptr->first, value.first); + return this->predicate_.equal_to(old_ptr->first, desired.first); } else { // If it's a set implementation, compare the whole slot content - return this->predicate_.equal_to(*old_ptr, value); + return this->predicate_.equal_to(*old_ptr, desired); } }(); return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE @@ -864,23 +862,26 @@ class open_addressing_ref_impl { * * @tparam Value Input type which is implicitly convertible to 'value_type' * - * @param slot Pointer to the slot in memory - * @param value Element to insert + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert * * @return Result of this operation, i.e., success/continue/duplicate */ template <typename Value> - [[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* slot, - Value const& value) noexcept + [[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* address, + value_type const& expected, + Value const& desired) noexcept { using mapped_type = decltype(this->empty_slot_sentinel_.second); auto const expected_key = this->empty_slot_sentinel_.first; auto const expected_payload = this->empty_slot_sentinel_.second; - auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first)); - auto old_payload = - compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second)); + auto old_key = + compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first)); + auto old_payload = compare_and_swap( + &address->second, expected_payload, static_cast<mapped_type>(desired.second)); auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key); auto* old_payload_ptr = reinterpret_cast<mapped_type*>(&old_payload); @@ -888,17 +889,17 @@ class open_addressing_ref_impl { // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { - old_payload = - compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second)); + old_payload = compare_and_swap( + &address->second, expected_payload, static_cast<mapped_type>(desired.second)); } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { - atomic_store(&slot->second, expected_payload); + atomic_store(&address->second, expected_payload); } // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -910,32 +911,34 @@ class open_addressing_ref_impl { * * @tparam Value Input type which is implicitly convertible to 'value_type' * - * @param slot Pointer to the slot in memory - * @param value Element to insert + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert * * @return Result of this operation, i.e., success/continue/duplicate */ template <typename Value> - [[nodiscard]] __device__ constexpr insert_result cas_dependent_write(value_type* slot, - Value const& value) noexcept + [[nodiscard]] __device__ constexpr insert_result cas_dependent_write( + value_type* address, value_type const& expected, Value const& desired) noexcept { using mapped_type = decltype(this->empty_slot_sentinel_.second); auto const expected_key = this->empty_slot_sentinel_.first; - auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first)); + auto old_key = + compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first)); auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key); // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { - atomic_store(&slot->second, static_cast<mapped_type>(value.second)); + atomic_store(&address->second, static_cast<mapped_type>(desired.second)); return insert_result::SUCCESS; } // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -950,22 +953,24 @@ class open_addressing_ref_impl { * * @tparam Value Input type which is implicitly convertible to 'value_type' * - * @param slot Pointer to the slot in memory - * @param value Element to insert + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert * * @return Result of this operation, i.e., success/continue/duplicate */ template <typename Value> - [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, - Value const& value) noexcept + [[nodiscard]] __device__ insert_result attempt_insert(value_type* address, + value_type const& expected, + Value const& desired) noexcept { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot, value); + return packed_cas(address, expected, desired); } else { #if (_CUDA_ARCH__ < 700) - return cas_dependent_write(slot, value); + return cas_dependent_write(address, expected, desired); #else - return back_to_back_cas(slot, value); + return back_to_back_cas(address, expected, desired); #endif } }