From fdee9bdba8d87ebcebc201232d68c81bb8ad34c6 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 29 Oct 2024 09:38:03 -0700 Subject: [PATCH] Add count APIs for set and map --- include/cuco/detail/static_map/static_map.inl | 16 ++++++ .../cuco/detail/static_map/static_map_ref.inl | 54 +++++++++++++++++++ include/cuco/detail/static_set/static_set.inl | 15 ++++++ .../cuco/detail/static_set/static_set_ref.inl | 52 ++++++++++++++++++ include/cuco/static_map.cuh | 16 ++++++ include/cuco/static_set.cuh | 16 ++++++ tests/static_map/unique_sequence_test.cu | 4 ++ tests/static_set/retrieve_test.cu | 4 ++ 8 files changed, 177 insertions(+) diff --git a/include/cuco/detail/static_map/static_map.inl b/include/cuco/detail/static_map/static_map.inl index e2915e1fd..16da68629 100644 --- a/include/cuco/detail/static_map/static_map.inl +++ b/include/cuco/detail/static_map/static_map.inl @@ -555,6 +555,22 @@ void static_map(callback_op), ref(op::for_each), stream); } +template +template +static_map::size_type +static_map::count( + InputIt first, InputIt last, cuda::stream_ref stream) const +{ + return impl_->count(first, last, ref(op::count), stream); +} + template +class operator_impl< + op::count_tag, + static_map_ref> { + using base_type = static_map_ref; + using ref_type = static_map_ref; + using key_type = typename base_type::key_type; + using value_type = typename base_type::value_type; + using size_type = typename base_type::size_type; + + static constexpr auto cg_size = base_type::cg_size; + static constexpr auto window_size = base_type::window_size; + + public: + /** + * @brief Counts the occurrence of a given key contained in map + * + * @tparam ProbeKey Probe key type + * + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + __device__ size_type count(ProbeKey const& key) const noexcept + { + auto const& ref_ = static_cast(*this); + return ref_.impl_.count(key); + } + + /** + * @brief Counts the occurrence 
of a given key contained in map + * + * @tparam ProbeKey Probe key type + * + * @param group The Cooperative Group used to perform group count + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + __device__ size_type count(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key) const noexcept + { + auto const& ref_ = static_cast(*this); + return ref_.impl_.count(group, key); + } +}; } // namespace detail } // namespace cuco diff --git a/include/cuco/detail/static_set/static_set.inl b/include/cuco/detail/static_set/static_set.inl index 3e0434031..477ef650e 100644 --- a/include/cuco/detail/static_set/static_set.inl +++ b/include/cuco/detail/static_set/static_set.inl @@ -338,6 +338,21 @@ void static_set impl_->find_async(first, last, output_begin, ref(op::find), stream); } +template +template +static_set::size_type +static_set::count( + InputIt first, InputIt last, cuda::stream_ref stream) const +{ + return impl_->count(first, last, ref(op::count), stream); +} + template +class operator_impl> { + using base_type = static_set_ref; + using ref_type = static_set_ref; + using key_type = typename base_type::key_type; + using value_type = typename base_type::value_type; + using size_type = typename base_type::size_type; + + static constexpr auto cg_size = base_type::cg_size; + static constexpr auto window_size = base_type::window_size; + + public: + /** + * @brief Counts the occurrence of a given key contained in the set + * + * @tparam ProbeKey Probe key type + * + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + __device__ size_type count(ProbeKey const& key) const noexcept + { + auto const& ref_ = static_cast(*this); + return ref_.impl_.count(key); + } + + /** + * @brief Counts the occurrence of a given key contained in the set + * + * @tparam ProbeKey Probe key type + * + * @param group The Cooperative Group used to 
perform group count + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + __device__ size_type count(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key) const noexcept + { + auto const& ref_ = static_cast(*this); + return ref_.impl_.count(group, key); + } +}; } // namespace detail } // namespace cuco diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index fc7dc088d..5079f0479 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -835,6 +835,22 @@ class static_map { CallbackOp&& callback_op, cuda::stream_ref stream = {}) const noexcept; + /** + * @brief Counts the occurrences of keys in `[first, last)` contained in the map + * + * @note This function synchronizes the given stream. + * + * @tparam InputIt Device accessible input iterator + * + * @param first Beginning of the sequence of keys to count + * @param last End of the sequence of keys to count + * @param stream CUDA stream used for count + * + * @return The sum of total occurrences of all keys in `[first, last)` + */ + template + size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const; + /** * @brief Retrieves all of the keys and their associated values. * diff --git a/include/cuco/static_set.cuh b/include/cuco/static_set.cuh index eb4b4a242..8da360d75 100644 --- a/include/cuco/static_set.cuh +++ b/include/cuco/static_set.cuh @@ -590,6 +590,22 @@ class static_set { OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** + * @brief Counts the occurrences of keys in `[first, last)` contained in the set + * + * @note This function synchronizes the given stream. 
+ * + * @tparam InputIt Device accessible input iterator + * + * @param first Beginning of the sequence of keys to count + * @param last End of the sequence of keys to count + * @param stream CUDA stream used for count + * + * @return The sum of total occurrences of all keys in `[first, last)` + */ + template + size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const; + /** * @brief Retrieves the matched key in the set corresponding to all probe keys in the range * `[first, last)` diff --git a/tests/static_map/unique_sequence_test.cu b/tests/static_map/unique_sequence_test.cu index 22cfd2d4a..4ab864ab7 100644 --- a/tests/static_map/unique_sequence_test.cu +++ b/tests/static_map/unique_sequence_test.cu @@ -60,6 +60,8 @@ void test_unique_sequence(Map& map, size_type num_keys) { REQUIRE(map.size() == 0); + REQUIRE(map.count(keys_begin, keys_begin + num_keys) == 0); + map.contains(keys_begin, keys_begin + num_keys, d_contained.begin()); REQUIRE(cuco::test::none_of(d_contained.begin(), d_contained.end(), thrust::identity{})); } @@ -97,6 +99,8 @@ void test_unique_sequence(Map& map, size_type num_keys) SECTION("All inserted keys should be contained.") { + REQUIRE(map.count(keys_begin, keys_begin + num_keys) == num_keys); + map.contains(keys_begin, keys_begin + num_keys, d_contained.begin()); REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{})); } diff --git a/tests/static_set/retrieve_test.cu b/tests/static_set/retrieve_test.cu index 013731cbe..7ea88d418 100644 --- a/tests/static_set/retrieve_test.cu +++ b/tests/static_set/retrieve_test.cu @@ -45,6 +45,8 @@ void test_unique_sequence(Set& set, std::size_t num_keys) { REQUIRE(set.size() == 0); + REQUIRE(set.count(iter, iter + num_keys) == 0); + auto const [probe_end, matched_end] = set.retrieve(iter, iter + num_keys, keys.begin(), matched_keys.begin()); REQUIRE(std::distance(keys.begin(), probe_end) == 0); @@ -55,6 +57,8 @@ void test_unique_sequence(Set& set, 
std::size_t num_keys) SECTION("All inserted key/value pairs should be contained.") { + REQUIRE(set.count(iter, iter + num_keys) == num_keys); + auto const [probe_end, matched_end] = set.retrieve(iter, iter + num_keys, keys.begin(), matched_keys.begin()); thrust::sort(keys.begin(), probe_end);