Skip to content

Commit

Permalink
Add count APIs for set and map (#629)
Browse files Browse the repository at this point in the history
This PR adds host and device `count` APIs for `static_set` and
`static_map`
  • Loading branch information
PointKernel authored Oct 30, 2024
1 parent 317c273 commit 6c57bb1
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 0 deletions.
16 changes: 16 additions & 0 deletions include/cuco/detail/static_map/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,22 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt>
static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::size_type
static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::count(
InputIt first, InputIt last, cuda::stream_ref stream) const
{
return impl_->count(first, last, ref(op::count), stream);
}

template <class Key,
class T,
class Extent,
Expand Down
54 changes: 54 additions & 0 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -1367,5 +1367,59 @@ class operator_impl<
}
};

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
class operator_impl<
op::count_tag,
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
using base_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using size_type = typename base_type::size_type;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;

public:
/**
* @brief Counts the occurrence of a given key contained in map
*
* @tparam ProbeKey Input type
*
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
__device__ size_type count(ProbeKey const& key) const noexcept
{
auto const& ref_ = static_cast<ref_type const&>(*this);
return ref_.impl_.count(key);
}

/**
* @brief Counts the occurrence of a given key contained in map
*
* @tparam ProbeKey Probe key type
*
* @param group The Cooperative Group used to perform group count
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
__device__ size_type count(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key) const noexcept
{
auto const& ref_ = static_cast<ref_type const&>(*this);
return ref_.impl_.count(group, key);
}
};
} // namespace detail
} // namespace cuco
15 changes: 15 additions & 0 deletions include/cuco/detail/static_set/static_set.inl
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,21 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt>
static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::size_type
static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::count(
InputIt first, InputIt last, cuda::stream_ref stream) const
{
return impl_->count(first, last, ref(op::count), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down
52 changes: 52 additions & 0 deletions include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -629,5 +629,57 @@ class operator_impl<op::find_tag,
}
};

template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
class operator_impl<op::count_tag,
static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
using base_type = static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type = static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using size_type = typename base_type::size_type;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;

public:
/**
* @brief Counts the occurrence of a given key contained in set
*
* @tparam ProbeKey Probe key type
*
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
__device__ size_type count(ProbeKey const& key) const noexcept
{
auto const& ref_ = static_cast<ref_type const&>(*this);
return ref_.impl_.count(key);
}

/**
* @brief Counts the occurrence of a given key contained in set
*
* @tparam ProbeKey Probe key type
*
* @param group The Cooperative Group used to perform group count
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
__device__ size_type count(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key) const noexcept
{
auto const& ref_ = static_cast<ref_type const&>(*this);
return ref_.impl_.count(group, key);
}
};
} // namespace detail
} // namespace cuco
16 changes: 16 additions & 0 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,22 @@ class static_map {
CallbackOp&& callback_op,
cuda::stream_ref stream = {}) const noexcept;

/**
* @brief Counts the occurrences of keys in `[first, last)` contained in the map
*
* @note This function synchronizes the given stream.
*
* @tparam Input Device accessible input iterator
*
* @param first Beginning of the sequence of keys to count
* @param last End of the sequence of keys to count
* @param stream CUDA stream used for count
*
* @return The sum of total occurrences of all keys in `[first, last)`
*/
template <typename InputIt>
size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const;

/**
* @brief Retrieves all of the keys and their associated values.
*
Expand Down
16 changes: 16 additions & 0 deletions include/cuco/static_set.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,22 @@ class static_set {
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief Counts the occurrences of keys in `[first, last)` contained in the set
*
* @note This function synchronizes the given stream.
*
* @tparam Input Device accessible input iterator
*
* @param first Beginning of the sequence of keys to count
* @param last End of the sequence of keys to count
* @param stream CUDA stream used for count
*
* @return The sum of total occurrences of all keys in `[first, last)`
*/
template <typename InputIt>
size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const;

/**
* @brief Retrieves the matched key in the set corresponding to all probe keys in the range
* `[first, last)`
Expand Down
4 changes: 4 additions & 0 deletions tests/static_map/unique_sequence_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ void test_unique_sequence(Map& map, size_type num_keys)
{
REQUIRE(map.size() == 0);

REQUIRE(map.count(keys_begin, keys_begin + num_keys) == 0);

map.contains(keys_begin, keys_begin + num_keys, d_contained.begin());
REQUIRE(cuco::test::none_of(d_contained.begin(), d_contained.end(), thrust::identity{}));
}
Expand Down Expand Up @@ -97,6 +99,8 @@ void test_unique_sequence(Map& map, size_type num_keys)

SECTION("All inserted keys should be contained.")
{
REQUIRE(map.count(keys_begin, keys_begin + num_keys) == num_keys);

map.contains(keys_begin, keys_begin + num_keys, d_contained.begin());
REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{}));
}
Expand Down
4 changes: 4 additions & 0 deletions tests/static_set/retrieve_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ void test_unique_sequence(Set& set, std::size_t num_keys)
{
REQUIRE(set.size() == 0);

REQUIRE(set.count(iter, iter + num_keys) == 0);

auto const [probe_end, matched_end] =
set.retrieve(iter, iter + num_keys, keys.begin(), matched_keys.begin());
REQUIRE(std::distance(keys.begin(), probe_end) == 0);
Expand All @@ -55,6 +57,8 @@ void test_unique_sequence(Set& set, std::size_t num_keys)

SECTION("All inserted key/value pairs should be contained.")
{
REQUIRE(set.count(iter, iter + num_keys) == num_keys);

auto const [probe_end, matched_end] =
set.retrieve(iter, iter + num_keys, keys.begin(), matched_keys.begin());
thrust::sort(keys.begin(), probe_end);
Expand Down

0 comments on commit 6c57bb1

Please sign in to comment.