From e37c12d7001124f77b566147bb2a3e18d0107e44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20J=C3=BCnger?= <2955913+sleeepyjack@users.noreply.github.com> Date: Thu, 12 Oct 2023 19:27:34 +0200 Subject: [PATCH] Enable heterogeneous insert for static_map (#381) --- .../open_addressing_ref_impl.cuh | 29 +++-- include/cuco/detail/{ => pair}/pair.inl | 8 ++ include/cuco/detail/pair/tuple_helpers.inl | 118 ++++++++++++++++++ include/cuco/detail/static_map/kernels.cuh | 11 +- .../cuco/detail/static_map/static_map_ref.inl | 38 ++++-- include/cuco/pair.cuh | 7 +- tests/static_map/heterogeneous_lookup_test.cu | 55 ++++---- tests/static_set/heterogeneous_lookup_test.cu | 7 +- 8 files changed, 218 insertions(+), 55 deletions(-) rename include/cuco/detail/{ => pair}/pair.inl (89%) create mode 100644 include/cuco/detail/pair/tuple_helpers.inl diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 565af7aa4..683bf94b1 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -24,6 +24,7 @@ #include #include +#include #include @@ -865,9 +866,9 @@ class open_addressing_ref_impl { Value const& value) const noexcept { if constexpr (this->has_payload) { - return value.first; + return thrust::get<0>(thrust::raw_reference_cast(value)); } else { - return value; + return thrust::raw_reference_cast(value); } } @@ -886,7 +887,7 @@ class open_addressing_ref_impl { [[nodiscard]] __host__ __device__ constexpr auto const& extract_payload( Value const& value) const noexcept { - return value.second; + return thrust::get<1>(thrust::raw_reference_cast(value)); } /** @@ -952,10 +953,10 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; auto const expected_payload = expected.second; - auto old_key = - compare_and_swap(&address->first, expected_key, static_cast(desired.first)); + auto old_key = compare_and_swap( + &address->first, expected_key, static_cast(thrust::get<0>(desired))); auto old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(desired.second)); + &address->second, expected_payload, static_cast(thrust::get<1>(desired))); auto* old_key_ptr = reinterpret_cast(&old_key); auto* old_payload_ptr = reinterpret_cast(&old_payload); @@ -964,7 +965,7 @@ class open_addressing_ref_impl { if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(desired.second)); + &address->second, expected_payload, static_cast(thrust::get<1>(desired))); } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { @@ -973,7 +974,9 @@ class open_addressing_ref_impl { // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, + thrust::get<0>(thrust::raw_reference_cast(desired))) == + detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -999,20 +1002,22 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; - auto old_key = - compare_and_swap(&address->first, expected_key, static_cast(desired.first)); + auto old_key = compare_and_swap( + &address->first, expected_key, static_cast(thrust::get<0>(desired))); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { - atomic_store(&address->second, static_cast(desired.second)); + atomic_store(&address->second, static_cast(thrust::get<1>(desired))); return insert_result::SUCCESS; } // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, + thrust::get<0>(thrust::raw_reference_cast(desired))) == + detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } diff --git a/include/cuco/detail/pair.inl b/include/cuco/detail/pair/pair.inl similarity index 89% rename from include/cuco/detail/pair.inl rename to include/cuco/detail/pair/pair.inl index 56d16e4fb..3279a915d 100644 --- a/include/cuco/detail/pair.inl +++ b/include/cuco/detail/pair/pair.inl @@ -49,3 +49,11 @@ __host__ __device__ constexpr bool operator==(cuco::pair const& lhs, } } // namespace cuco + +namespace thrust { +#include +} // namespace thrust + +namespace cuda::std { +#include +} // namespace cuda::std diff --git a/include/cuco/detail/pair/tuple_helpers.inl b/include/cuco/detail/pair/tuple_helpers.inl new file mode 100644 index 000000000..29be199ee --- /dev/null +++ b/include/cuco/detail/pair/tuple_helpers.inl @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +template +struct tuple_size> : integral_constant { +}; + +template +struct tuple_size> : tuple_size> { +}; + +template +struct tuple_size> : tuple_size> { +}; + +template +struct tuple_size> : tuple_size> { +}; + +template +struct tuple_element> { + using type = void; +}; + +template +struct tuple_element<0, cuco::pair> { + using type = T1; +}; + +template +struct tuple_element<1, cuco::pair> { + using type = T2; +}; + +template +struct tuple_element<0, const cuco::pair> : tuple_element<0, cuco::pair> { +}; + +template +struct tuple_element<1, const cuco::pair> : tuple_element<1, cuco::pair> { +}; + +template +struct tuple_element<0, volatile cuco::pair> : tuple_element<0, cuco::pair> { +}; + +template +struct tuple_element<1, volatile cuco::pair> : tuple_element<1, cuco::pair> { +}; + +template +struct tuple_element<0, const volatile cuco::pair> : tuple_element<0, cuco::pair> { +}; + +template +struct tuple_element<1, const volatile cuco::pair> : tuple_element<1, cuco::pair> { +}; + +template +__host__ __device__ constexpr auto get(cuco::pair& p) -> + typename tuple_element>::type& +{ + static_assert(I < 2); + if constexpr (I == 0) { + return p.first; + } else { + return p.second; + } +} + +template +__host__ __device__ constexpr auto get(cuco::pair&& p) -> + typename tuple_element>::type&& +{ + static_assert(I < 2); + if constexpr (I == 0) { + return std::move(p.first); + } else { + return std::move(p.second); + } +} + +template +__host__ __device__ constexpr auto get(cuco::pair const& p) -> + typename tuple_element>::type const& +{ + static_assert(I < 2); + if constexpr (I == 0) { + return p.first; + } else { + return p.second; + } +} + +template +__host__ __device__ constexpr auto get(cuco::pair const&& p) -> + typename tuple_element>::type const&& +{ + static_assert(I < 2); + if constexpr (I == 0) { + return std::move(p.first); + } else { + return std::move(p.second); + } +} \ No newline at end of file diff --git a/include/cuco/detail/static_map/kernels.cuh b/include/cuco/detail/static_map/kernels.cuh index a36095462..f9171ef77 100644 --- a/include/cuco/detail/static_map/kernels.cuh +++ b/include/cuco/detail/static_map/kernels.cuh @@ -21,6 +21,7 @@ #include #include +#include #include @@ -39,7 +40,7 @@ namespace detail { * * @tparam CGSize Number of threads in each CG * @tparam BlockSize Number of threads in each block - * @tparam InputIterator Device accessible input iterator whose `value_type` is + * @tparam InputIt Device accessible input iterator whose `value_type` is * convertible to the `value_type` of the data structure * @tparam Ref Type of non-owning device ref allowing access to storage * @@ -47,14 +48,14 @@ namespace detail { * @param n Number of input elements * @param ref Non-owning container device ref used to access the slot storage */ -template -__global__ void insert_or_assign(InputIterator first, cuco::detail::index_type n, Ref ref) +template +__global__ void insert_or_assign(InputIt first, cuco::detail::index_type n, Ref ref) { auto const loop_stride = cuco::detail::grid_stride() / CGSize; auto idx = cuco::detail::global_thread_id() / CGSize; while (idx < n) { - typename Ref::value_type const insert_pair{*(first + idx)}; + typename std::iterator_traits::value_type const& insert_pair = *(first + idx); if constexpr (CGSize == 1) { ref.insert_or_assign(insert_pair); } else { @@ -100,7 +101,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_ while (idx - thread_idx < n) { // the whole thread block falls into the same iteration if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const& key = *(first + idx); if constexpr (CGSize == 1) { auto const found = ref.find(key); /* diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index afaf8f589..f27f21e76 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -180,10 +180,14 @@ class operator_impl< /** * @brief Inserts an element. * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param value The element to insert + * * @return True if the given element is successfully inserted */ - __device__ bool insert(value_type const& value) noexcept + template + __device__ bool insert(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert(value); @@ -192,12 +196,16 @@ class operator_impl< /** * @brief Inserts an element. * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param group The Cooperative Group used to perform group insert * @param value The element to insert + * * @return True if the given element is successfully inserted */ + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - value_type const& value) noexcept + Value const& value) noexcept { auto& ref_ = static_cast(*this); return ref_.impl_.insert(group, value); @@ -230,9 +238,12 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param value The element to insert */ - __device__ void insert_or_assign(value_type const& value) noexcept + template + __device__ void insert_or_assign(Value const& value) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); @@ -275,11 +286,14 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param group The Cooperative Group used to perform group insert * @param value The element to insert */ + template __device__ void insert_or_assign(cooperative_groups::thread_block_tile const& group, - value_type const& value) noexcept + Value const& value) noexcept { ref_type& ref_ = static_cast(*this); @@ -350,13 +364,15 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param group The Cooperative Group used to perform group insert * @param value The element to insert * * @return Returns `true` if the given `value` is inserted or `value` has a match in the map. */ - __device__ constexpr bool attempt_insert_or_assign(value_type* slot, - value_type const& value) noexcept + template + __device__ constexpr bool attempt_insert_or_assign(value_type* slot, Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto const expected_key = ref_.impl_.empty_slot_sentinel().first; @@ -430,12 +446,15 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param value The element to insert * * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - __device__ thrust::pair insert_and_find(value_type const& value) noexcept + template + __device__ thrust::pair insert_and_find(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(value); @@ -448,14 +467,17 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param group The Cooperative Group used to perform group insert_and_find * @param value The element to insert * * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ + template __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept + cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(group, value); diff --git a/include/cuco/pair.cuh b/include/cuco/pair.cuh index 0a804cc04..d28cae5da 100644 --- a/include/cuco/pair.cuh +++ b/include/cuco/pair.cuh @@ -22,7 +22,7 @@ #include #include -#include +#include #include namespace cuco { @@ -87,7 +87,8 @@ struct alignas(detail::pair_alignment()) pair { */ template ::value>* = nullptr> __host__ __device__ constexpr pair(T const& p) - : pair{std::get<0>(thrust::raw_reference_cast(p)), std::get<1>(thrust::raw_reference_cast(p))} + : pair{cuda::std::get<0>(thrust::raw_reference_cast(p)), + cuda::std::get<1>(thrust::raw_reference_cast(p))} { } @@ -143,4 +144,4 @@ __host__ __device__ constexpr bool operator==(cuco::pair const& lhs, } // namespace cuco -#include +#include diff --git a/tests/static_map/heterogeneous_lookup_test.cu b/tests/static_map/heterogeneous_lookup_test.cu index e842612b1..ed1ace9bd 100644 --- a/tests/static_map/heterogeneous_lookup_test.cu +++ b/tests/static_map/heterogeneous_lookup_test.cu @@ -41,6 +41,8 @@ struct key_pair { // Device equality operator is mandatory due to libcudacxx bug: // https://github.com/NVIDIA/libcudacxx/issues/223 __device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; } + + __device__ explicit operator T() const noexcept { return a; } }; // probe key type @@ -64,61 +66,70 @@ struct key_triplet { // User-defined device hasher struct custom_hasher { template - __device__ uint32_t operator()(CustomKey const& k) + __device__ uint32_t operator()(CustomKey const& k) const { - return thrust::raw_reference_cast(k).a; + return k.a; }; }; // User-defined device key equality struct custom_key_equal { - template - __device__ bool operator()(LHS const& lhs, RHS const& rhs) + template + __device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const { - return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a; + return lhs == rhs.a; } }; -TEMPLATE_TEST_CASE("Heterogeneous lookup", - "", +TEMPLATE_TEST_CASE_SIG("Heterogeneous lookup", + "", + ((typename T, int CGSize), T, CGSize), #if defined(CUCO_HAS_INDEPENDENT_THREADS) // Key type larger than 8B only supported for sm_70 and // up - int64_t, + (int64_t, 1), + (int64_t, 2), #endif - int32_t) + + (int32_t, 1), + (int32_t, 2)) { - using Key = key_pair; - using Value = TestType; - using ProbeKey = key_triplet; + using Key = T; + using Value = T; + using InsertKey = key_pair; + using ProbeKey = key_triplet; + using probe_type = cuco::experimental::double_hashing; auto const sentinel_key = Key{-1}; auto const sentinel_value = Value{-1}; constexpr std::size_t num = 100; constexpr std::size_t capacity = num * 2; - cuco::static_map map{ - capacity, cuco::empty_key{sentinel_key}, cuco::empty_value{sentinel_value}}; + auto const probe = probe_type{custom_hasher{}, custom_hasher{}}; + + auto my_map = cuco::experimental::static_map{capacity, + cuco::empty_key{sentinel_key}, + cuco::empty_value{sentinel_value}, + custom_key_equal{}, + probe}; - auto insert_pairs = - thrust::make_transform_iterator(thrust::counting_iterator(0), - [] __device__(auto i) { return cuco::pair(i, i); }); + auto insert_pairs = thrust::make_transform_iterator( + thrust::counting_iterator(0), + [] __device__(auto i) { return cuco::pair(i, i); }); auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), [] __device__(auto i) { return ProbeKey(i); }); SECTION("All inserted keys-value pairs should be contained") { thrust::device_vector contained(num); - map.insert(insert_pairs, insert_pairs + num, custom_hasher{}, custom_key_equal{}); - map.contains( - probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{}); + my_map.insert(insert_pairs, insert_pairs + num); + my_map.contains(probe_keys, probe_keys + num, contained.begin()); REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); } SECTION("Non-inserted keys-value pairs should not be contained") { thrust::device_vector contained(num); - map.contains( - probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{}); + my_map.contains(probe_keys, probe_keys + num, contained.begin()); REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), thrust::identity{})); } } diff --git a/tests/static_set/heterogeneous_lookup_test.cu b/tests/static_set/heterogeneous_lookup_test.cu index ddc799ed3..2875d5a6a 100644 --- a/tests/static_set/heterogeneous_lookup_test.cu +++ b/tests/static_set/heterogeneous_lookup_test.cu @@ -94,11 +94,8 @@ TEMPLATE_TEST_CASE_SIG( constexpr std::size_t num = 100; constexpr std::size_t capacity = num * 2; auto const probe = probe_type{custom_hasher{}, custom_hasher{}}; - auto my_set = cuco::experimental::static_set, - cuda::thread_scope_device, - custom_key_equal, - probe_type>{ + + auto my_set = cuco::experimental::static_set{ capacity, cuco::empty_key{sentinel_key}, custom_key_equal{}, probe}; auto insert_keys = thrust::make_transform_iterator(