From 224881ad34353bc77b8c2056d36ce7609f7a7136 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Fri, 29 Sep 2023 23:42:55 +0000 Subject: [PATCH] Enable the same thing for static_map --- include/cuco/detail/static_map/kernels.cuh | 11 ++-- .../cuco/detail/static_map/static_map_ref.inl | 52 +++++++++++----- tests/static_map/heterogeneous_lookup_test.cu | 61 ++++++++++++------- 3 files changed, 81 insertions(+), 43 deletions(-) diff --git a/include/cuco/detail/static_map/kernels.cuh b/include/cuco/detail/static_map/kernels.cuh index a36095462..f8aee369e 100644 --- a/include/cuco/detail/static_map/kernels.cuh +++ b/include/cuco/detail/static_map/kernels.cuh @@ -21,6 +21,7 @@ #include #include +#include #include @@ -39,7 +40,7 @@ namespace detail { * * @tparam CGSize Number of threads in each CG * @tparam BlockSize Number of threads in each block - * @tparam InputIterator Device accessible input iterator whose `value_type` is + * @tparam InputIt Device accessible input iterator whose `value_type` is * convertible to the `value_type` of the data structure * @tparam Ref Type of non-owning device ref allowing access to storage * @@ -47,14 +48,14 @@ namespace detail { * @param n Number of input elements * @param ref Non-owning container device ref used to access the slot storage */ -template -__global__ void insert_or_assign(InputIterator first, cuco::detail::index_type n, Ref ref) +template +__global__ void insert_or_assign(InputIt first, cuco::detail::index_type n, Ref ref) { auto const loop_stride = cuco::detail::grid_stride() / CGSize; auto idx = cuco::detail::global_thread_id() / CGSize; while (idx < n) { - typename Ref::value_type const insert_pair{*(first + idx)}; + typename std::iterator_traits::value_type const& insert_pair{*(first + idx)}; if constexpr (CGSize == 1) { ref.insert_or_assign(insert_pair); } else { @@ -100,7 +101,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_ while (idx - thread_idx < n) { // the whole thread block falls into the same iteration if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const& key = *(first + idx); if constexpr (CGSize == 1) { auto const found = ref.find(key); /* diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 28b3ffaf2..9c6f2ae19 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -243,10 +243,13 @@ class operator_impl< /** * @brief Inserts an element. * ++ * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param value The element to insert * @return True if the given element is successfully inserted */ - __device__ bool insert(value_type const& value) noexcept + template + __device__ bool insert(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = true; @@ -256,12 +259,15 @@ class operator_impl< /** * @brief Inserts an element. * ++ * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param group The Cooperative Group used to perform group insert * @param value The element to insert * @return True if the given element is successfully inserted */ + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - value_type const& value) noexcept + Value const& value) noexcept { auto& ref_ = static_cast(*this); auto constexpr has_payload = true; @@ -282,7 +288,8 @@ class operator_impl< using base_type = static_map_ref; using ref_type = static_map_ref; using key_type = typename base_type::key_type; - using value_type = typename base_type::value_type; + using value_type = typename base_type::value_type; + using mapped_type = typename base_type::mapped_type; static constexpr auto cg_size = base_type::cg_size; static constexpr auto window_size = base_type::window_size; @@ -295,14 +302,17 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * ++ * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param value The element to insert */ - __device__ void insert_or_assign(value_type const& value) noexcept + template + __device__ void insert_or_assign(Value const& value) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); ref_type& ref_ = static_cast(*this); - auto const key = value.first; + auto const key = value.first; // TODO can we use auto here? auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(key, storage_ref.window_extent()); @@ -318,7 +328,7 @@ class operator_impl< auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - value.second); + static_cast(value.second)); return; } if (eq_res == detail::equal_result::EMPTY) { @@ -339,15 +349,18 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * ++ * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param group The Cooperative Group used to perform group insert * @param value The element to insert */ + template __device__ void insert_or_assign(cooperative_groups::thread_block_tile const& group, - value_type const& value) noexcept + Value const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto const key = value.first; + auto const key = value.first; // TODO auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(group, key, storage_ref.window_extent()); @@ -375,7 +388,7 @@ class operator_impl< if (group.thread_rank() == src_lane) { ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - value.second); + static_cast(value.second)); } group.sync(); return; @@ -406,25 +419,28 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * ++ * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param group The Cooperative Group used to perform group insert * @param value The element to insert * * @return Returns `true` if the given `value` is inserted or `value` has a match in the map. */ - __device__ constexpr bool attempt_insert_or_assign(value_type* slot, - value_type const& value) noexcept + template + __device__ constexpr bool attempt_insert_or_assign(value_type* slot, Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto const expected_key = ref_.impl_.empty_slot_sentinel().first; - auto old_key = ref_.impl_.compare_and_swap(&slot->first, expected_key, value.first); + auto old_key = + ref_.impl_.compare_and_swap(&slot->first, expected_key, static_cast(value.first)); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success or key was already present in the map if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key) or (ref_.predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL)) { // Update payload - ref_.impl_.atomic_store(&slot->second, value.second); + ref_.impl_.atomic_store(&slot->second, static_cast(value.second)); return true; } return false; @@ -485,12 +501,15 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * ++ * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param value The element to insert * * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - __device__ thrust::pair insert_and_find(value_type const& value) noexcept + template + __device__ thrust::pair insert_and_find(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = true; @@ -504,14 +523,17 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * ++ * @tparam Value Input type which is implicitly convertible to 'value_type' + * * @param group The Cooperative Group used to perform group insert_and_find * @param value The element to insert * * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ + template __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept + cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = true; diff --git a/tests/static_map/heterogeneous_lookup_test.cu b/tests/static_map/heterogeneous_lookup_test.cu index e842612b1..a909eb6c0 100644 --- a/tests/static_map/heterogeneous_lookup_test.cu +++ b/tests/static_map/heterogeneous_lookup_test.cu @@ -41,6 +41,8 @@ struct key_pair { // Device equality operator is mandatory due to libcudacxx bug: // https://github.com/NVIDIA/libcudacxx/issues/223 __device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; } + + __device__ explicit operator T() const noexcept { return a; } }; // probe key type @@ -64,61 +66,74 @@ struct key_triplet { // User-defined device hasher struct custom_hasher { template - __device__ uint32_t operator()(CustomKey const& k) + __device__ uint32_t operator()(CustomKey const& k) const { - return thrust::raw_reference_cast(k).a; + return k.a; }; }; // User-defined device key equality struct custom_key_equal { - template - __device__ bool operator()(LHS const& lhs, RHS const& rhs) + template + __device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const { - return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a; + return lhs == rhs.a; } }; -TEMPLATE_TEST_CASE("Heterogeneous lookup", - "", +TEMPLATE_TEST_CASE_SIG("Heterogeneous lookup", + "", + ((typename T, int CGSize), T, CGSize), #if defined(CUCO_HAS_INDEPENDENT_THREADS) // Key type larger than 8B only supported for sm_70 and // up - int64_t, + (int64_t, 1), + (int64_t, 2), #endif - int32_t) + (int32_t, 1), + (int32_t, 2)) { - using Key = key_pair; - using Value = TestType; - using ProbeKey = key_triplet; + using Key = T; + using Value = T; + using InsertKey = key_pair; + using ProbeKey = key_triplet; + using probe_type = cuco::experimental::double_hashing; auto const sentinel_key = Key{-1}; auto const sentinel_value = Value{-1}; constexpr std::size_t num = 100; constexpr std::size_t capacity = num * 2; - cuco::static_map map{ - capacity, cuco::empty_key{sentinel_key}, cuco::empty_value{sentinel_value}}; - - auto insert_pairs = - thrust::make_transform_iterator(thrust::counting_iterator(0), - [] __device__(auto i) { return cuco::pair(i, i); }); + auto const probe = probe_type{custom_hasher{}, custom_hasher{}}; + + auto map = cuco::experimental::static_map, + cuda::thread_scope_device, + custom_key_equal, + probe_type>{capacity, + cuco::empty_key{sentinel_key}, + cuco::empty_value{sentinel_value}, + custom_key_equal{}, + probe}; + + auto insert_pairs = thrust::make_transform_iterator( + thrust::counting_iterator(0), + [] __device__(auto i) { return cuco::pair(i, i); }); auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), [] __device__(auto i) { return ProbeKey(i); }); SECTION("All inserted keys-value pairs should be contained") { thrust::device_vector contained(num); - map.insert(insert_pairs, insert_pairs + num, custom_hasher{}, custom_key_equal{}); - map.contains( - probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{}); + map.insert(insert_pairs, insert_pairs + num); + map.contains(probe_keys, probe_keys + num, contained.begin()); REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); } SECTION("Non-inserted keys-value pairs should not be contained") { thrust::device_vector contained(num); - map.contains( - probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{}); + map.contains(probe_keys, probe_keys + num, contained.begin()); REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), thrust::identity{})); } }