diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 4bf6eb89e..f1c0194ca 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -821,76 +821,6 @@ class open_addressing_ref_impl { } } - /** - * @brief Counts the occurrence of a given key contained in the container - * - * @tparam ProbeKey Probe key type - * - * @param key The key to count for - * - * @return Number of occurrences found by the current thread - */ - template - [[nodiscard]] __device__ size_type count(ProbeKey const& key) const noexcept - { - if constexpr (not allows_duplicates) { - return static_cast(this->contains(key)); - } else { - auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); - size_type count = 0; - - while (true) { - // TODO atomic_ref::load if insert operator is present - auto const bucket_slots = storage_ref_[*probing_iter]; - - for (auto& slot_content : bucket_slots) { - switch ( - this->predicate_.operator()(key, this->extract_key(slot_content))) { - case detail::equal_result::EMPTY: return count; - case detail::equal_result::EQUAL: ++count; break; - default: continue; - } - } - ++probing_iter; - } - } - } - - /** - * @brief Counts the occurrence of a given key contained in the container - * - * @tparam ProbeKey Probe key type - * - * @param group The Cooperative Group used to perform group count - * @param key The key to count for - * - * @return Number of occurrences found by the current thread - */ - template - [[nodiscard]] __device__ size_type count( - cooperative_groups::thread_block_tile const& group, ProbeKey const& key) const noexcept - { - auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); - size_type count = 0; - - while (true) { - auto const bucket_slots = storage_ref_[*probing_iter]; - - auto const state = [&]() { - auto res = detail::equal_result::UNEQUAL; - for (auto& slot : bucket_slots) { - res = this->predicate_.operator()(key, this->extract_key(slot)); - if (res == detail::equal_result::EMPTY) { return res; } - count += static_cast(res); - } - return res; - }(); - - if (group.any(state == detail::equal_result::EMPTY)) { return count; } - ++probing_iter; - } - } - /** * @brief Finds an element in the container with key equivalent to the probe key. * @@ -978,6 +908,76 @@ class open_addressing_ref_impl { } } + /** + * @brief Counts the occurrence of a given key contained in the container + * + * @tparam ProbeKey Probe key type + * + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + [[nodiscard]] __device__ size_type count(ProbeKey const& key) const noexcept + { + if constexpr (not allows_duplicates) { + return static_cast(this->contains(key)); + } else { + auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + size_type count = 0; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = storage_ref_[*probing_iter]; + + for (auto& slot_content : bucket_slots) { + switch ( + this->predicate_.operator()(key, this->extract_key(slot_content))) { + case detail::equal_result::EMPTY: return count; + case detail::equal_result::EQUAL: ++count; break; + default: continue; + } + } + ++probing_iter; + } + } + } + + /** + * @brief Counts the occurrence of a given key contained in the container + * + * @tparam ProbeKey Probe key type + * + * @param group The Cooperative Group used to perform group count + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + [[nodiscard]] __device__ size_type count( + cooperative_groups::thread_block_tile const& group, ProbeKey const& key) const noexcept + { + auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + size_type count = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const state = [&]() { + auto res = detail::equal_result::UNEQUAL; + for (auto& slot : bucket_slots) { + res = this->predicate_.operator()(key, this->extract_key(slot)); + if (res == detail::equal_result::EMPTY) { return res; } + count += static_cast(res); + } + return res; + }(); + + if (group.any(state == detail::equal_result::EMPTY)) { return count; } + ++probing_iter; + } + } + /** * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin, * input_probe_end)`.