diff --git a/include/cuco/detail/common_kernels.cuh b/include/cuco/detail/common_kernels.cuh index 759041bad..871e1f8e8 100644 --- a/include/cuco/detail/common_kernels.cuh +++ b/include/cuco/detail/common_kernels.cuh @@ -76,7 +76,7 @@ __global__ void insert_if_n(InputIterator first, while (idx < n) { if (pred(*(stencil + idx))) { - typename Ref::value_type const insert_element{*(first + idx)}; + auto const insert_element{*(first + idx)}; if constexpr (CGSize == 1) { if (ref.insert(insert_element)) { thread_num_successes++; }; } else { @@ -134,7 +134,7 @@ __global__ void insert_if_n( while (idx < n) { if (pred(*(stencil + idx))) { - typename Ref::value_type const insert_element{*(first + idx)}; + auto const insert_element{*(first + idx)}; if constexpr (CGSize == 1) { ref.insert(insert_element); } else { diff --git a/include/cuco/detail/equal_wrapper.cuh b/include/cuco/detail/equal_wrapper.cuh index d2ded4a33..0c05a4a9c 100644 --- a/include/cuco/detail/equal_wrapper.cuh +++ b/include/cuco/detail/equal_wrapper.cuh @@ -55,15 +55,16 @@ struct equal_wrapper { /** * @brief Equality check with the given equality callable. * - * @tparam U Right-hand side Element type + * @tparam LHS Left-hand side Element type + * @tparam RHS Right-hand side Element type * * @param lhs Left-hand side element to check equality * @param rhs Right-hand side element to check equality * * @return `EQUAL` if `lhs` and `rhs` are equivalent. `UNEQUAL` otherwise. */ - template - __device__ constexpr equal_result equal_to(T const& lhs, U const& rhs) const noexcept + template + __device__ constexpr equal_result equal_to(LHS const& lhs, RHS const& rhs) const noexcept { return equal_(lhs, rhs) ? equal_result::EQUAL : equal_result::UNEQUAL; } @@ -75,15 +76,16 @@ struct equal_wrapper { * first then perform a equality check with the given `equal_` callable, i.e., `equal_(lhs, rhs)`. * @note Container (like set or map) keys MUST be always on the left-hand side. * - * @tparam U Right-hand side Element type + * @tparam LHS Left-hand side Element type + * @tparam RHS Right-hand side Element type * * @param lhs Left-hand side element to check equality * @param rhs Right-hand side element to check equality * * @return Three way equality comparison result */ - template - __device__ constexpr equal_result operator()(T const& lhs, U const& rhs) const noexcept + template + __device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept { return cuco::detail::bitwise_compare(lhs, empty_sentinel_) ? equal_result::EMPTY : this->equal_to(lhs, rhs); diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index cce691c21..1e0dc35b3 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -155,6 +155,7 @@ class open_addressing_ref_impl { * @brief Inserts an element. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param value The element to insert @@ -162,8 +163,8 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template - __device__ bool insert(value_type const& value, Predicate const& predicate) noexcept + template + __device__ bool insert(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); @@ -202,6 +203,7 @@ class open_addressing_ref_impl { * @brief Inserts an element. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert @@ -210,9 +212,9 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { @@ -275,6 +277,7 @@ class open_addressing_ref_impl { * not. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param value The element to insert @@ -283,8 +286,8 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template - __device__ thrust::pair insert_and_find(value_type const& value, + template + __device__ thrust::pair insert_and_find(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); @@ -309,7 +312,8 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EQUAL) { return {iterator{&window_ptr[i]}, false}; } if (eq_res == detail::equal_result::EMPTY) { switch ([&]() { - if constexpr (sizeof(value_type) <= 8) { + if constexpr ((sizeof(value_type) <= 8) and + cuda::std::is_convertible_v) { return packed_cas(window_ptr + i, value, predicate); } else { return cas_dependent_write(window_ptr + i, value, predicate); @@ -337,6 +341,7 @@ class open_addressing_ref_impl { * not. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert_and_find @@ -346,10 +351,10 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { @@ -712,6 +717,7 @@ class open_addressing_ref_impl { * @brief Inserts the specified element with one single CAS operation. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -720,12 +726,12 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { - auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value); + auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast(value)); auto* old_ptr = reinterpret_cast(&old); auto const inserted = [&]() { if constexpr (HasPayload) { @@ -757,6 +763,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with two back-to-back CAS operations. * + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -765,15 +772,16 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result back_to_back_cas( - value_type* slot, value_type const& value, Predicate const& predicate) noexcept + value_type* slot, Value const& value, Predicate const& predicate) noexcept { auto const expected_key = this->empty_slot_sentinel_.first; auto const expected_payload = this->empty_slot_sentinel_.second; - auto old_key = compare_and_swap(&slot->first, expected_key, value.first); - auto old_payload = compare_and_swap(&slot->second, expected_payload, value.second); + auto old_key = compare_and_swap(&slot->first, expected_key, value.first); // TODO static_cast? + auto old_payload = + compare_and_swap(&slot->second, expected_payload, value.second); // TODO static_cast? using mapped_type = decltype(expected_payload); @@ -783,7 +791,8 @@ class open_addressing_ref_impl { // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { - old_payload = compare_and_swap(&slot->second, expected_payload, value.second); + old_payload = + compare_and_swap(&slot->second, expected_payload, value.second); // TODO static_cast? } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { @@ -802,6 +811,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with CAS-dependent write operations. * + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -810,13 +820,13 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result cas_dependent_write( - value_type* slot, value_type const& value, Predicate const& predicate) noexcept + value_type* slot, Value const& value, Predicate const& predicate) noexcept { auto const expected_key = this->empty_slot_sentinel_.first; - auto old_key = compare_and_swap(&slot->first, expected_key, value.first); + auto old_key = compare_and_swap(&slot->first, expected_key, value.first); // TODO static_cast? auto* old_key_ptr = reinterpret_cast(&old_key); @@ -842,6 +852,7 @@ class open_addressing_ref_impl { * type and presence of other operator mixins. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -850,12 +861,12 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { - if constexpr (sizeof(value_type) <= 8) { + if constexpr ((sizeof(value_type) <= 8) and cuda::std::is_convertible_v) { return packed_cas(slot, value, predicate); } else { #if (_CUDA_ARCH__ < 700) diff --git a/tests/static_set/heterogeneous_lookup_test.cu b/tests/static_set/heterogeneous_lookup_test.cu index cbc0efac3..e1c3f5e9e 100644 --- a/tests/static_set/heterogeneous_lookup_test.cu +++ b/tests/static_set/heterogeneous_lookup_test.cu @@ -41,6 +41,8 @@ struct key_pair { // Device equality operator is mandatory due to libcudacxx bug: // https://github.com/NVIDIA/libcudacxx/issues/223 __device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; } + + __device__ operator T() const noexcept { return a; } }; // probe key type @@ -75,14 +77,15 @@ struct custom_key_equal { template __device__ bool operator()(LHS const& lhs, RHS const& rhs) const { - return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a; + return thrust::raw_reference_cast(lhs) == thrust::raw_reference_cast(rhs).a; } }; TEMPLATE_TEST_CASE_SIG( "Heterogeneous lookup", "", ((typename T, int CGSize), T, CGSize), (int32_t, 1), (int32_t, 2)) { - using Key = key_pair; + using Key = T; + using InsertKey = key_pair; using ProbeKey = key_triplet; using probe_type = cuco::experimental::double_hashing; @@ -98,15 +101,15 @@ TEMPLATE_TEST_CASE_SIG( probe_type>{ capacity, cuco::empty_key{sentinel_key}, custom_key_equal{}, probe}; - auto insert_pairs = thrust::make_transform_iterator(thrust::counting_iterator(0), - [] __device__(auto i) { return Key{i}; }); - auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), + auto insert_keys = thrust::make_transform_iterator( + thrust::counting_iterator(0), [] __device__(auto i) { return InsertKey(i); }); + auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), [] __device__(auto i) { return ProbeKey(i); }); SECTION("All inserted keys should be contained") { thrust::device_vector contained(num); - my_set.insert(insert_pairs, insert_pairs + num); + my_set.insert(insert_keys, insert_keys + num); my_set.contains(probe_keys, probe_keys + num, contained.begin()); REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); }