Skip to content

Commit

Permalink
Reordering APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Nov 7, 2024
1 parent d6dd669 commit 8d5bf12
Showing 1 changed file with 70 additions and 70 deletions.
140 changes: 70 additions & 70 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -821,76 +821,6 @@ class open_addressing_ref_impl {
}
}

/**
* @brief Counts the occurrence of a given key contained in the container
*
* @tparam ProbeKey Probe key type
*
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
[[nodiscard]] __device__ size_type count(ProbeKey const& key) const noexcept
{
if constexpr (not allows_duplicates) {
return static_cast<size_type>(this->contains(key));
} else {
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
size_type count = 0;

while (true) {
// TODO atomic_ref::load if insert operator is present
auto const bucket_slots = storage_ref_[*probing_iter];

for (auto& slot_content : bucket_slots) {
switch (
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot_content))) {
case detail::equal_result::EMPTY: return count;
case detail::equal_result::EQUAL: ++count; break;
default: continue;
}
}
++probing_iter;
}
}
}

/**
* @brief Counts the occurrence of a given key contained in the container
*
* @tparam ProbeKey Probe key type
*
* @param group The Cooperative Group used to perform group count
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
[[nodiscard]] __device__ size_type count(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
size_type count = 0;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];

auto const state = [&]() {
auto res = detail::equal_result::UNEQUAL;
for (auto& slot : bucket_slots) {
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot));
if (res == detail::equal_result::EMPTY) { return res; }
count += static_cast<size_type>(res);
}
return res;
}();

if (group.any(state == detail::equal_result::EMPTY)) { return count; }
++probing_iter;
}
}

/**
* @brief Finds an element in the container with key equivalent to the probe key.
*
Expand Down Expand Up @@ -978,6 +908,76 @@ class open_addressing_ref_impl {
}
}

/**
* @brief Counts the occurrence of a given key contained in the container
*
* @tparam ProbeKey Probe key type
*
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
[[nodiscard]] __device__ size_type count(ProbeKey const& key) const noexcept
{
if constexpr (not allows_duplicates) {
return static_cast<size_type>(this->contains(key));
} else {
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
size_type count = 0;

while (true) {
// TODO atomic_ref::load if insert operator is present
auto const bucket_slots = storage_ref_[*probing_iter];

for (auto& slot_content : bucket_slots) {
switch (
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot_content))) {
case detail::equal_result::EMPTY: return count;
case detail::equal_result::EQUAL: ++count; break;
default: continue;
}
}
++probing_iter;
}
}
}

/**
* @brief Counts the occurrence of a given key contained in the container
*
* @tparam ProbeKey Probe key type
*
* @param group The Cooperative Group used to perform group count
* @param key The key to count for
*
* @return Number of occurrences found by the current thread
*/
template <typename ProbeKey>
[[nodiscard]] __device__ size_type count(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
size_type count = 0;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];

auto const state = [&]() {
auto res = detail::equal_result::UNEQUAL;
for (auto& slot : bucket_slots) {
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot));
if (res == detail::equal_result::EMPTY) { return res; }
count += static_cast<size_type>(res);
}
return res;
}();

if (group.any(state == detail::equal_result::EMPTY)) { return count; }
++probing_iter;
}
}

/**
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
* input_probe_end)`.
Expand Down

0 comments on commit 8d5bf12

Please sign in to comment.