Skip to content

Commit

Permalink
Fix several bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Oct 9, 2023
1 parent a1b5987 commit 58300e5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 35 deletions.
8 changes: 3 additions & 5 deletions include/cuco/detail/equal_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace detail {
/**
* @brief Enum of equality comparison results.
*/
enum class equal_result : int32_t { UNEQUAL = 0, AVAILABLE = 1, EQUAL = 2 };
enum class equal_result : int32_t { UNEQUAL = 0, EMPTY = 1, EQUAL = 2, ERASED = 3 };

/**
* @brief Key equality wrapper.
Expand Down Expand Up @@ -92,10 +92,8 @@ struct equal_wrapper {
template <typename LHS, typename RHS>
__device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept
{
return cuco::detail::bitwise_compare(lhs, empty_sentinel_) or
cuco::detail::bitwise_compare(lhs, erased_sentinel_)
? equal_result::AVAILABLE
: this->equal_to(lhs, rhs);
return cuco::detail::bitwise_compare(lhs, empty_sentinel_) ? equal_result::EMPTY
: this->equal_to(lhs, rhs);
}
};

Expand Down
73 changes: 47 additions & 26 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ class open_addressing_ref_impl {

// If the key is already in the container, return false
if (eq_res == detail::equal_result::EQUAL) { return false; }
if (eq_res == detail::equal_result::AVAILABLE) {
if (eq_res == detail::equal_result::EMPTY or
cuco::detail::bitwise_compare(this->extract_key(slot_content),
this->erased_key_sentinel())) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
slot_content,
Expand Down Expand Up @@ -321,11 +323,19 @@ class open_addressing_ref_impl {
auto const [state, intra_window_index] = [&]() {
for (auto i = 0; i < window_size; ++i) {
switch (this->predicate_(this->extract_key(window_slots[i]), key)) {
case detail::equal_result::AVAILABLE:
return window_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EMPTY:
return window_probing_results{detail::equal_result::EMPTY, i};
case detail::equal_result::EQUAL:
return window_probing_results{detail::equal_result::EQUAL, i};
default: continue;
default: {
  // The slot neither matched the probe key nor was reported EMPTY by the predicate.
  // If its key is bitwise-equal to the erased-key sentinel, report it as ERASED together
  // with its intra-window index; any other slot is skipped.
  // This check must execute before `continue`, otherwise ERASED is never produced.
  if (cuco::detail::bitwise_compare(this->extract_key(window_slots[i]),
                                    this->erased_key_sentinel())) {
    return window_probing_results{detail::equal_result::ERASED, i};
  }
  continue;
}
}
}
// returns dummy index `-1` for UNEQUAL
Expand All @@ -335,14 +345,14 @@ class open_addressing_ref_impl {
// If the key is already in the container, return false
if (group.any(state == detail::equal_result::EQUAL)) { return false; }

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);

auto const group_contains_available =
group.ballot(state == detail::equal_result::EMPTY or state == detail::equal_result::ERASED);
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const status =
(group.thread_rank() == src_lane)
? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
window_slots[src_lane],
window_slots[intra_window_index],
value)
: insert_result::CONTINUE;

Expand Down Expand Up @@ -388,7 +398,9 @@ class open_addressing_ref_impl {

// If the key is already in the container, return false
if (eq_res == detail::equal_result::EQUAL) { return {iterator{&window_ptr[i]}, false}; }
if (eq_res == detail::equal_result::AVAILABLE) {
if (eq_res == detail::equal_result::EMPTY or
cuco::detail::bitwise_compare(this->extract_key(window_slots[i]),
this->erased_key_sentinel())) {
switch ([&]() {
if constexpr (sizeof(value_type) <= 8) {
return packed_cas(window_ptr + i, window_slots[i], value);
Expand Down Expand Up @@ -438,11 +450,19 @@ class open_addressing_ref_impl {
auto const [state, intra_window_index] = [&]() {
for (auto i = 0; i < window_size; ++i) {
switch (this->predicate_(this->extract_key(window_slots[i]), key)) {
case detail::equal_result::AVAILABLE:
return window_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EMPTY:
return window_probing_results{detail::equal_result::EMPTY, i};
case detail::equal_result::EQUAL:
return window_probing_results{detail::equal_result::EQUAL, i};
default: continue;
default: {
  // The slot neither matched the probe key nor was reported EMPTY by the predicate.
  // If its key is bitwise-equal to the erased-key sentinel, report it as ERASED together
  // with its intra-window index; any other slot is skipped.
  // This check must execute before `continue`, otherwise ERASED is never produced.
  if (cuco::detail::bitwise_compare(this->extract_key(window_slots[i]),
                                    this->erased_key_sentinel())) {
    return window_probing_results{detail::equal_result::ERASED, i};
  }
  continue;
}
}
}
// returns dummy index `-1` for UNEQUAL
Expand All @@ -459,16 +479,17 @@ class open_addressing_ref_impl {
return {iterator{reinterpret_cast<value_type*>(res)}, false};
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available =
group.ballot(state == detail::equal_result::EMPTY or state == detail::equal_result::ERASED);
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const res = group.shfl(reinterpret_cast<intptr_t>(slot_ptr), src_lane);
auto const status = [&]() {
auto const status = [&, target_idx = intra_window_index]() {
if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; }
if constexpr (sizeof(value_type) <= 8) {
return packed_cas(slot_ptr, window_slots[src_lane], value);
return packed_cas(slot_ptr, window_slots[target_idx], value);
} else {
return cas_dependent_write(slot_ptr, window_slots[src_lane], value);
return cas_dependent_write(slot_ptr, window_slots[target_idx], value);
}
}();

Expand Down Expand Up @@ -510,7 +531,7 @@ class open_addressing_ref_impl {
auto const eq_res = this->predicate_(this->extract_key(slot_content), key);

// Key doesn't exist, return false
if (eq_res == detail::equal_result::AVAILABLE) { return false; }
if (eq_res == detail::equal_result::EMPTY) { return false; }
// Key exists, return true if successfully deleted
if (eq_res == detail::equal_result::EQUAL) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
Expand Down Expand Up @@ -549,8 +570,8 @@ class open_addressing_ref_impl {
auto const [state, intra_window_index] = [&]() {
for (auto i = 0; i < window_size; ++i) {
switch (this->predicate_(this->extract_key(window_slots[i]), key)) {
case detail::equal_result::AVAILABLE:
return window_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EMPTY:
return window_probing_results{detail::equal_result::EMPTY, i};
case detail::equal_result::EQUAL:
return window_probing_results{detail::equal_result::EQUAL, i};
default: continue;
Expand All @@ -575,7 +596,7 @@ class open_addressing_ref_impl {
case insert_result::DUPLICATE: return false;
default: continue;
}
} else if (group.any(state == detail::equal_result::AVAILABLE)) {
} else if (group.any(state == detail::equal_result::EMPTY)) {
// Key doesn't exist, return false
return false;
} else {
Expand Down Expand Up @@ -609,7 +630,7 @@ class open_addressing_ref_impl {
for (auto& slot_content : window_slots) {
switch (this->predicate_(this->extract_key(slot_content), key)) {
case detail::equal_result::UNEQUAL: continue;
case detail::equal_result::AVAILABLE: return false;
case detail::equal_result::EMPTY: return false;
case detail::equal_result::EQUAL: return true;
}
}
Expand Down Expand Up @@ -642,7 +663,7 @@ class open_addressing_ref_impl {
auto const state = [&]() {
for (auto& slot : window_slots) {
switch (this->predicate_(this->extract_key(slot), key)) {
case detail::equal_result::AVAILABLE: return detail::equal_result::AVAILABLE;
case detail::equal_result::EMPTY: return detail::equal_result::EMPTY;
case detail::equal_result::EQUAL: return detail::equal_result::EQUAL;
default: continue;
}
Expand All @@ -651,7 +672,7 @@ class open_addressing_ref_impl {
}();

if (group.any(state == detail::equal_result::EQUAL)) { return true; }
if (group.any(state == detail::equal_result::AVAILABLE)) { return false; }
if (group.any(state == detail::equal_result::EMPTY)) { return false; }

++probing_iter;
}
Expand Down Expand Up @@ -681,7 +702,7 @@ class open_addressing_ref_impl {

for (auto i = 0; i < window_size; ++i) {
switch (this->predicate_(this->extract_key(window_slots[i]), key)) {
case detail::equal_result::AVAILABLE: {
case detail::equal_result::EMPTY: {
return this->end();
}
case detail::equal_result::EQUAL: {
Expand Down Expand Up @@ -719,8 +740,8 @@ class open_addressing_ref_impl {
auto const [state, intra_window_index] = [&]() {
for (auto i = 0; i < window_size; ++i) {
switch (this->predicate_(this->extract_key(window_slots[i]), key)) {
case detail::equal_result::AVAILABLE:
return window_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EMPTY:
return window_probing_results{detail::equal_result::EMPTY, i};
case detail::equal_result::EQUAL:
return window_probing_results{detail::equal_result::EQUAL, i};
default: continue;
Expand All @@ -741,7 +762,7 @@ class open_addressing_ref_impl {
}

// Find an empty slot, meaning that the probe key isn't present in the container
if (group.any(state == detail::equal_result::AVAILABLE)) { return this->end(); }
if (group.any(state == detail::equal_result::EMPTY)) { return this->end(); }

++probing_iter;
}
Expand Down
8 changes: 4 additions & 4 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ class operator_impl<
value.second);
return;
}
if (eq_res == detail::equal_result::AVAILABLE) {
if (eq_res == detail::equal_result::EMPTY) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
if (attempt_insert_or_assign(
(storage_ref.data() + *probing_iter)->data() + intra_window_index, value)) {
Expand Down Expand Up @@ -293,8 +293,8 @@ class operator_impl<
auto const [state, intra_window_index] = [&]() {
for (auto i = 0; i < window_size; ++i) {
switch (ref_.impl_.predicate()(window_slots[i].first, key)) {
case detail::equal_result::AVAILABLE:
return detail::window_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EMPTY:
return detail::window_probing_results{detail::equal_result::EMPTY, i};
case detail::equal_result::EQUAL:
return detail::window_probing_results{detail::equal_result::EQUAL, i};
default: continue;
Expand All @@ -316,7 +316,7 @@ class operator_impl<
return;
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available = group.ballot(state == detail::equal_result::EMPTY);
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const status =
Expand Down

0 comments on commit 58300e5

Please sign in to comment.