From fd23a3dcb2c2e53c228e629014b1c625e9e6515b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= <2955913+sleeepyjack@users.noreply.github.com> Date: Mon, 2 Oct 2023 18:05:58 +0200 Subject: [PATCH] Enable heterogeneous insert for static_set (#375) --- include/cuco/detail/common_kernels.cuh | 22 +++---- include/cuco/detail/equal_wrapper.cuh | 14 +++-- .../cuco/detail/open_addressing_ref_impl.cuh | 60 +++++++++++-------- include/cuco/detail/static_set/kernels.cuh | 4 +- .../cuco/detail/static_set/static_set_ref.inl | 20 +++++-- tests/static_set/heterogeneous_lookup_test.cu | 21 ++++--- 6 files changed, 87 insertions(+), 54 deletions(-) diff --git a/include/cuco/detail/common_kernels.cuh b/include/cuco/detail/common_kernels.cuh index 759041bad..223f20609 100644 --- a/include/cuco/detail/common_kernels.cuh +++ b/include/cuco/detail/common_kernels.cuh @@ -23,6 +23,8 @@ #include +#include + namespace cuco { namespace experimental { namespace detail { @@ -37,7 +39,7 @@ namespace detail { * * @tparam CGSize Number of threads in each CG * @tparam BlockSize Number of threads in each block - * @tparam InputIterator Device accessible input iterator whose `value_type` is + * @tparam InputIt Device accessible input iterator whose `value_type` is * convertible to the `value_type` of the data structure * @tparam StencilIt Device accessible random access iterator whose value_type is * convertible to Predicate's argument type @@ -55,12 +57,12 @@ namespace detail { */ template -__global__ void insert_if_n(InputIterator first, +__global__ void insert_if_n(InputIt first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, @@ -76,7 +78,7 @@ __global__ void insert_if_n(InputIterator first, while (idx < n) { if (pred(*(stencil + idx))) { - typename Ref::value_type const insert_element{*(first + idx)}; + typename std::iterator_traits::value_type const& insert_element{*(first + idx)}; if constexpr (CGSize == 1) { if (ref.insert(insert_element)) { thread_num_successes++; }; } else { @@ -106,7 +108,7 @@ __global__ void insert_if_n(InputIterator first, * * @tparam CGSize Number of threads in each CG * @tparam BlockSize Number of threads in each block - * @tparam InputIterator Device accessible input iterator whose `value_type` is + * @tparam InputIt Device accessible input iterator whose `value_type` is * convertible to the `value_type` of the data structure * @tparam StencilIt Device accessible random access iterator whose value_type is * convertible to Predicate's argument type @@ -122,19 +124,19 @@ __global__ void insert_if_n(InputIterator first, */ template __global__ void insert_if_n( - InputIterator first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref) + InputIt first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref) { auto const loop_stride = cuco::detail::grid_stride() / CGSize; auto idx = cuco::detail::global_thread_id() / CGSize; while (idx < n) { if (pred(*(stencil + idx))) { - typename Ref::value_type const insert_element{*(first + idx)}; + typename std::iterator_traits::value_type const& insert_element{*(first + idx)}; if constexpr (CGSize == 1) { ref.insert(insert_element); } else { @@ -198,7 +200,7 @@ __global__ void contains_if_n(InputIt first, while (idx - thread_idx < n) { // the whole thread block falls into the same iteration if constexpr (CGSize == 1) { if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const& key = *(first + idx); /* * The ld.relaxed.gpu instruction causes L1 to flush more frequently, causing increased * sector stores from L2 to global memory. By writing results to shared memory and then @@ -212,7 +214,7 @@ __global__ void contains_if_n(InputIt first, } else { auto const tile = cg::tiled_partition(cg::this_thread_block()); if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const& key = *(first + idx); auto const found = pred(*(stencil + idx)) ? ref.contains(tile, key) : false; if (tile.thread_rank() == 0) { *(output_begin + idx) = found; } } diff --git a/include/cuco/detail/equal_wrapper.cuh b/include/cuco/detail/equal_wrapper.cuh index d2ded4a33..0c05a4a9c 100644 --- a/include/cuco/detail/equal_wrapper.cuh +++ b/include/cuco/detail/equal_wrapper.cuh @@ -55,15 +55,16 @@ struct equal_wrapper { /** * @brief Equality check with the given equality callable. * - * @tparam U Right-hand side Element type + * @tparam LHS Left-hand side Element type + * @tparam RHS Right-hand side Element type * * @param lhs Left-hand side element to check equality * @param rhs Right-hand side element to check equality * * @return `EQUAL` if `lhs` and `rhs` are equivalent. `UNEQUAL` otherwise. */ - template - __device__ constexpr equal_result equal_to(T const& lhs, U const& rhs) const noexcept + template + __device__ constexpr equal_result equal_to(LHS const& lhs, RHS const& rhs) const noexcept { return equal_(lhs, rhs) ? equal_result::EQUAL : equal_result::UNEQUAL; } @@ -75,15 +76,16 @@ struct equal_wrapper { * first then perform a equality check with the given `equal_` callable, i.e., `equal_(lhs, rhs)`. * @note Container (like set or map) keys MUST be always on the left-hand side. * - * @tparam U Right-hand side Element type + * @tparam LHS Left-hand side Element type + * @tparam RHS Right-hand side Element type * * @param lhs Left-hand side element to check equality * @param rhs Right-hand side element to check equality * * @return Three way equality comparison result */ - template - __device__ constexpr equal_result operator()(T const& lhs, U const& rhs) const noexcept + template + __device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept { return cuco::detail::bitwise_compare(lhs, empty_sentinel_) ? equal_result::EMPTY : this->equal_to(lhs, rhs); diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index cce691c21..3967cffa3 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -155,6 +155,7 @@ class open_addressing_ref_impl { * @brief Inserts an element. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param value The element to insert @@ -162,8 +163,8 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template - __device__ bool insert(value_type const& value, Predicate const& predicate) noexcept + template + __device__ bool insert(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); @@ -202,6 +203,7 @@ class open_addressing_ref_impl { * @brief Inserts an element. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert @@ -210,9 +212,9 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { @@ -275,6 +277,7 @@ class open_addressing_ref_impl { * not. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param value The element to insert @@ -283,8 +286,8 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template - __device__ thrust::pair insert_and_find(value_type const& value, + template + __device__ thrust::pair insert_and_find(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); @@ -337,6 +340,7 @@ class open_addressing_ref_impl { * not. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert_and_find @@ -346,10 +350,10 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { @@ -712,6 +716,7 @@ class open_addressing_ref_impl { * @brief Inserts the specified element with one single CAS operation. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -720,12 +725,12 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { - auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value); + auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast(value)); auto* old_ptr = reinterpret_cast(&old); auto const inserted = [&]() { if constexpr (HasPayload) { @@ -757,6 +762,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with two back-to-back CAS operations. * + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -765,17 +771,18 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result back_to_back_cas( - value_type* slot, value_type const& value, Predicate const& predicate) noexcept + value_type* slot, Value const& value, Predicate const& predicate) noexcept { + using mapped_type = decltype(this->empty_slot_sentinel_.second); + auto const expected_key = this->empty_slot_sentinel_.first; auto const expected_payload = this->empty_slot_sentinel_.second; - auto old_key = compare_and_swap(&slot->first, expected_key, value.first); - auto old_payload = compare_and_swap(&slot->second, expected_payload, value.second); - - using mapped_type = decltype(expected_payload); + auto old_key = compare_and_swap(&slot->first, expected_key, static_cast(value.first)); + auto old_payload = + compare_and_swap(&slot->second, expected_payload, static_cast(value.second)); auto* old_key_ptr = reinterpret_cast(&old_key); auto* old_payload_ptr = reinterpret_cast(&old_payload); @@ -783,7 +790,8 @@ class open_addressing_ref_impl { // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { - old_payload = compare_and_swap(&slot->second, expected_payload, value.second); + old_payload = + compare_and_swap(&slot->second, expected_payload, static_cast(value.second)); } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { @@ -802,6 +810,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with CAS-dependent write operations. * + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -810,19 +819,21 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result cas_dependent_write( - value_type* slot, value_type const& value, Predicate const& predicate) noexcept + value_type* slot, Value const& value, Predicate const& predicate) noexcept { + using mapped_type = decltype(this->empty_slot_sentinel_.second); + auto const expected_key = this->empty_slot_sentinel_.first; - auto old_key = compare_and_swap(&slot->first, expected_key, value.first); + auto old_key = compare_and_swap(&slot->first, expected_key, static_cast(value.first)); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { - atomic_store(&slot->second, value.second); + atomic_store(&slot->second, static_cast(value.second)); return insert_result::SUCCESS; } @@ -842,6 +853,7 @@ class open_addressing_ref_impl { * type and presence of other operator mixins. * * @tparam HasPayload Boolean indicating it's a set or map implementation + * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -850,9 +862,9 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, - value_type const& value, + Value const& value, Predicate const& predicate) noexcept { if constexpr (sizeof(value_type) <= 8) { diff --git a/include/cuco/detail/static_set/kernels.cuh b/include/cuco/detail/static_set/kernels.cuh index 72744f2b4..15d725f68 100644 --- a/include/cuco/detail/static_set/kernels.cuh +++ b/include/cuco/detail/static_set/kernels.cuh @@ -24,6 +24,8 @@ #include +#include + namespace cuco { namespace experimental { namespace static_set_ns { @@ -62,7 +64,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_ while (idx - thread_idx < n) { // the whole thread block falls into the same iteration if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const& key = *(first + idx); if constexpr (CGSize == 1) { auto const found = ref.find(key); /* diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 4c3853971..3dbda9bbf 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -128,11 +128,14 @@ class operator_impl + __device__ bool insert(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false; @@ -142,13 +145,16 @@ class operator_impl __device__ bool insert(cooperative_groups::thread_block_tile const& group, - value_type const& value) noexcept + Value const& value) noexcept { auto& ref_ = static_cast(*this); auto constexpr has_payload = false; @@ -208,12 +214,15 @@ class operator_impl insert_and_find(value_type const& value) noexcept + template + __device__ thrust::pair insert_and_find(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false; @@ -227,14 +236,17 @@ class operator_impl __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept + cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false; diff --git a/tests/static_set/heterogeneous_lookup_test.cu b/tests/static_set/heterogeneous_lookup_test.cu index cbc0efac3..ddc799ed3 100644 --- a/tests/static_set/heterogeneous_lookup_test.cu +++ b/tests/static_set/heterogeneous_lookup_test.cu @@ -41,6 +41,8 @@ struct key_pair { // Device equality operator is mandatory due to libcudacxx bug: // https://github.com/NVIDIA/libcudacxx/issues/223 __device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; } + + __device__ explicit operator T() const noexcept { return a; } }; // probe key type @@ -66,23 +68,24 @@ struct custom_hasher { template __device__ uint32_t operator()(CustomKey const& k) const { - return thrust::raw_reference_cast(k).a; + return k.a; }; }; // User-defined device key equality struct custom_key_equal { - template - __device__ bool operator()(LHS const& lhs, RHS const& rhs) const + template + __device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const { - return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a; + return lhs == rhs.a; } }; TEMPLATE_TEST_CASE_SIG( "Heterogeneous lookup", "", ((typename T, int CGSize), T, CGSize), (int32_t, 1), (int32_t, 2)) { - using Key = key_pair; + using Key = T; + using InsertKey = key_pair; using ProbeKey = key_triplet; using probe_type = cuco::experimental::double_hashing; @@ -98,15 +101,15 @@ TEMPLATE_TEST_CASE_SIG( probe_type>{ capacity, cuco::empty_key{sentinel_key}, custom_key_equal{}, probe}; - auto insert_pairs = thrust::make_transform_iterator(thrust::counting_iterator(0), - [] __device__(auto i) { return Key{i}; }); - auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), + auto insert_keys = thrust::make_transform_iterator( + thrust::counting_iterator(0), [] __device__(auto i) { return InsertKey(i); }); + auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), [] __device__(auto i) { return ProbeKey(i); }); SECTION("All inserted keys should be contained") { thrust::device_vector contained(num); - my_set.insert(insert_pairs, insert_pairs + num); + my_set.insert(insert_keys, insert_keys + num); my_set.contains(probe_keys, probe_keys + num, contained.begin()); REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); }