From faa781091ca89a32423d393741040298fbcde8e7 Mon Sep 17 00:00:00 2001 From: Eyal Soha Date: Wed, 27 Sep 2023 05:33:13 +0000 Subject: [PATCH] WIP --- include/cuco/detail/equal_wrapper.cuh | 4 ++-- include/cuco/detail/open_addressing_ref_impl.cuh | 14 +++++++------- include/cuco/detail/static_set/static_set_ref.inl | 3 ++- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/include/cuco/detail/equal_wrapper.cuh b/include/cuco/detail/equal_wrapper.cuh index d2ded4a33..efa4b2cc8 100644 --- a/include/cuco/detail/equal_wrapper.cuh +++ b/include/cuco/detail/equal_wrapper.cuh @@ -62,8 +62,8 @@ struct equal_wrapper { * * @return `EQUAL` if `lhs` and `rhs` are equivalent. `UNEQUAL` otherwise. */ - template - __device__ constexpr equal_result equal_to(T const& lhs, U const& rhs) const noexcept + template + __device__ constexpr equal_result equal_to(local_T const& lhs, U const& rhs) const noexcept { return equal_(lhs, rhs) ? equal_result::EQUAL : equal_result::UNEQUAL; } diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index cce691c21..ced089c35 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -283,8 +283,8 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template - __device__ thrust::pair insert_and_find(value_type const& value, + template + __device__ thrust::pair insert_and_find(local_value_type const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); @@ -346,10 +346,10 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, - value_type const& value, + local_value_type const& value, Predicate const& predicate) noexcept { auto const key = [&]() { @@ -720,12 +720,12 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, - value_type const& value, + local_value_type const& value, Predicate const& predicate) noexcept { - auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value); + auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast(value)); auto* old_ptr = reinterpret_cast(&old); auto const inserted = [&]() { if constexpr (HasPayload) { diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 4c3853971..0c9bef6b1 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -233,8 +233,9 @@ class operator_impl __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept + cooperative_groups::thread_block_tile const& group, local_value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false;