From cb170235e8ac4cd83910c8c7bc64abcb1070dfbf Mon Sep 17 00:00:00 2001
From: Yunsong Wang <yunsongw@nvidia.com>
Date: Fri, 6 Oct 2023 14:50:54 -0700
Subject: [PATCH] Update attempt_insert to integrate with erase

---
 .../open_addressing_ref_impl.cuh              | 105 +++++++++---------
 1 file changed, 55 insertions(+), 50 deletions(-)

diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
index 5136ca67e..73037f424 100644
--- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
+++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
@@ -286,6 +286,7 @@ class open_addressing_ref_impl {
         if (eq_res == detail::equal_result::AVAILABLE) {
           auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
           switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
+                                 slot_content,
                                  value)) {
             case insert_result::CONTINUE: continue;
             case insert_result::SUCCESS: return true;
@@ -341,6 +342,7 @@ class open_addressing_ref_impl {
         auto const status =
           (group.thread_rank() == src_lane)
             ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
+                             window_slots[src_lane],
                              value)
             : insert_result::CONTINUE;
 
@@ -389,9 +391,9 @@ class open_addressing_ref_impl {
         if (eq_res == detail::equal_result::AVAILABLE) {
           switch ([&]() {
             if constexpr (sizeof(value_type) <= 8) {
-              return packed_cas(window_ptr + i, value);
+              return packed_cas(window_ptr + i, window_slots[i], value);
             } else {
-              return cas_dependent_write(window_ptr + i, value);
+              return cas_dependent_write(window_ptr + i, window_slots[i], value);
             }
           }()) {
             case insert_result::SUCCESS: {
@@ -464,9 +466,9 @@ class open_addressing_ref_impl {
         auto const status   = [&]() {
           if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; }
           if constexpr (sizeof(value_type) <= 8) {
-            return packed_cas(slot_ptr, value);
+            return packed_cas(slot_ptr, window_slots[src_lane], value);
           } else {
-            return cas_dependent_write(slot_ptr, value);
+            return cas_dependent_write(slot_ptr, window_slots[src_lane], value);
           }
         }();
 
@@ -485,25 +487,19 @@ class open_addressing_ref_impl {
     }
   }
 
-  template <typename Value, typename Predicate>
-  __device__ bool erase(Value const& value, Predicate const& predicate) noexcept
+  template <typename Value>
+  __device__ bool erase(Value const& value) noexcept
   {
     static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
 
-    auto const key = [&]() {
-      if constexpr (this->has_payload) {
-        return value.first;
-      } else {
-        return value;
-      }
-    }();
+    auto const key    = this->extract_key(value);
     auto probing_iter = probing_scheme_(key, storage_ref_.window_extent());
 
     while (true) {
       auto const window_slots = storage_ref_[*probing_iter];
 
       for (auto& slot_content : window_slots) {
-        auto const eq_res = predicate(slot_content, key);
+        auto const eq_res = this->predicate_(this->extract_key(slot_content), key);
 
         // Key doesn't exist, return false
         if (eq_res == detail::equal_result::AVAILABLE) { return false; }
@@ -517,10 +513,9 @@ class open_addressing_ref_impl {
               return this->erased_key_sentinel();
             }
           }();
-          switch (attempt_insert<this->has_payload>(
-            (storage_ref_.data() + *probing_iter)->data() + intra_window_index,
-            erased_slot,
-            predicate)) {
+          switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
+                                 slot_content,
+                                 erased_slot)) {
             case insert_result::CONTINUE: continue;
             case insert_result::SUCCESS: return true;
           }
@@ -821,16 +816,19 @@ class open_addressing_ref_impl {
    *
    * @tparam Value Input type which is implicitly convertible to 'value_type'
    *
-   * @param slot Pointer to the slot in memory
-   * @param value Element to insert
+   * @param address Pointer to the slot in memory
+   * @param expected Element to compare against
+   * @param desired Element to insert
    *
    * @return Result of this operation, i.e., success/continue/duplicate
    */
   template <typename Value>
-  [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot,
-                                                              Value const& value) noexcept
+  [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* address,
+                                                              value_type const& expected,
+                                                              Value const& desired) noexcept
   {
-    auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast<value_type>(value));
+    auto old =
+      compare_and_swap(address, this->empty_slot_sentinel_, static_cast<value_type>(desired));
     auto* old_ptr       = reinterpret_cast<value_type*>(&old);
     auto const inserted = [&]() {
       if constexpr (this->has_payload) {
@@ -848,10 +846,10 @@ class open_addressing_ref_impl {
       auto const res = [&]() {
         if constexpr (this->has_payload) {
           // If it's a map implementation, compare keys only
-          return this->predicate_.equal_to(old_ptr->first, value.first);
+          return this->predicate_.equal_to(old_ptr->first, desired.first);
         } else {
           // If it's a set implementation, compare the whole slot content
-          return this->predicate_.equal_to(*old_ptr, value);
+          return this->predicate_.equal_to(*old_ptr, desired);
         }
       }();
       return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE
@@ -864,23 +862,26 @@ class open_addressing_ref_impl {
    *
    * @tparam Value Input type which is implicitly convertible to 'value_type'
    *
-   * @param slot Pointer to the slot in memory
-   * @param value Element to insert
+   * @param address Pointer to the slot in memory
+   * @param expected Element to compare against
+   * @param desired Element to insert
    *
    * @return Result of this operation, i.e., success/continue/duplicate
    */
   template <typename Value>
-  [[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* slot,
-                                                                    Value const& value) noexcept
+  [[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* address,
+                                                                    value_type const& expected,
+                                                                    Value const& desired) noexcept
   {
     using mapped_type = decltype(this->empty_slot_sentinel_.second);
 
     auto const expected_key     = this->empty_slot_sentinel_.first;
     auto const expected_payload = this->empty_slot_sentinel_.second;
 
-    auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));
-    auto old_payload =
-      compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second));
+    auto old_key =
+      compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
+    auto old_payload = compare_and_swap(
+      &address->second, expected_payload, static_cast<mapped_type>(desired.second));
 
     auto* old_key_ptr     = reinterpret_cast<key_type*>(&old_key);
     auto* old_payload_ptr = reinterpret_cast<mapped_type*>(&old_payload);
@@ -888,17 +889,17 @@ class open_addressing_ref_impl {
     // if key success
     if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
       while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
-        old_payload =
-          compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second));
+        old_payload = compare_and_swap(
+          &address->second, expected_payload, static_cast<mapped_type>(desired.second));
       }
       return insert_result::SUCCESS;
     } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
-      atomic_store(&slot->second, expected_payload);
+      atomic_store(&address->second, expected_payload);
     }
 
     // Our key was already present in the slot, so our key is a duplicate
     // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
-    if (this->predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL) {
+    if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
       return insert_result::DUPLICATE;
     }
 
@@ -910,32 +911,34 @@ class open_addressing_ref_impl {
    *
    * @tparam Value Input type which is implicitly convertible to 'value_type'
    *
-   * @param slot Pointer to the slot in memory
-   * @param value Element to insert
+   * @param address Pointer to the slot in memory
+   * @param expected Element to compare against
+   * @param desired Element to insert
    *
    * @return Result of this operation, i.e., success/continue/duplicate
    */
   template <typename Value>
-  [[nodiscard]] __device__ constexpr insert_result cas_dependent_write(value_type* slot,
-                                                                       Value const& value) noexcept
+  [[nodiscard]] __device__ constexpr insert_result cas_dependent_write(
+    value_type* address, value_type const& expected, Value const& desired) noexcept
   {
     using mapped_type = decltype(this->empty_slot_sentinel_.second);
 
     auto const expected_key = this->empty_slot_sentinel_.first;
 
-    auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));
+    auto old_key =
+      compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
 
     auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);
 
     // if key success
     if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
-      atomic_store(&slot->second, static_cast<mapped_type>(value.second));
+      atomic_store(&address->second, static_cast<mapped_type>(desired.second));
       return insert_result::SUCCESS;
     }
 
     // Our key was already present in the slot, so our key is a duplicate
     // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
-    if (this->predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL) {
+    if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
       return insert_result::DUPLICATE;
     }
 
@@ -950,22 +953,26 @@ class open_addressing_ref_impl {
    *
    * @tparam Value Input type which is implicitly convertible to 'value_type'
    *
-   * @param slot Pointer to the slot in memory
-   * @param value Element to insert
+   * @param address Pointer to the slot in memory
+   * @param expected Element to compare against
+   * @param desired Element to insert
    *
    * @return Result of this operation, i.e., success/continue/duplicate
    */
   template <typename Value>
-  [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot,
-                                                        Value const& value) noexcept
+  [[nodiscard]] __device__ insert_result attempt_insert(value_type* address,
+                                                        value_type const& expected,
+                                                        Value const& desired) noexcept
   {
     if constexpr (sizeof(value_type) <= 8) {
-      return packed_cas(slot, value);
+      return packed_cas(address, expected, desired);
     } else {
-#if (_CUDA_ARCH__ < 700)
-      return cas_dependent_write(slot, value);
+      // NOTE: the macro is `__CUDA_ARCH__` (two leading underscores). An undefined identifier
+      // evaluates to 0 in `#if`, so a misspelled name would always select the first branch.
+#if (__CUDA_ARCH__ < 700)
+      return cas_dependent_write(address, expected, desired);
 #else
-      return back_to_back_cas(slot, value);
+      return back_to_back_cas(address, expected, desired);
 #endif
     }
   }