Skip to content

Commit

Permalink
Cleanup conversions
Browse files (browse the repository at this point in the history)
  • Loading branch information
sleeepyjack committed Sep 29, 2023
1 parent c3f95e3 commit 6cbabd2
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
6 changes: 3 additions & 3 deletions include/cuco/detail/common_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ __global__ void insert_if_n(InputIt first,

while (idx < n) {
if (pred(*(stencil + idx))) {
typename std::iterator_traits<InputIt>::value_type const insert_element{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const& insert_element{*(first + idx)};
if constexpr (CGSize == 1) {
if (ref.insert(insert_element)) { thread_num_successes++; };
} else {
Expand Down Expand Up @@ -136,7 +136,7 @@ __global__ void insert_if_n(

while (idx < n) {
if (pred(*(stencil + idx))) {
typename std::iterator_traits<InputIt>::value_type const insert_element{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const& insert_element{*(first + idx)};
if constexpr (CGSize == 1) {
ref.insert(insert_element);
} else {
Expand Down Expand Up @@ -200,7 +200,7 @@ __global__ void contains_if_n(InputIt first,
while (idx - thread_idx < n) { // the whole thread block falls into the same iteration
if constexpr (CGSize == 1) {
if (idx < n) {
auto const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
/*
* The ld.relaxed.gpu instruction causes L1 to flush more frequently, causing increased
* sector stores from L2 to global memory. By writing results to shared memory and then
Expand Down
21 changes: 11 additions & 10 deletions include/cuco/detail/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,7 @@ class open_addressing_ref_impl {
if (eq_res == detail::equal_result::EQUAL) { return {iterator{&window_ptr[i]}, false}; }
if (eq_res == detail::equal_result::EMPTY) {
switch ([&]() {
if constexpr ((sizeof(value_type) <= 8) and
cuda::std::is_convertible_v<Value, value_type>) {
if constexpr (sizeof(value_type) <= 8) {
return packed_cas<HasPayload>(window_ptr + i, value, predicate);
} else {
return cas_dependent_write(window_ptr + i, value, predicate);
Expand Down Expand Up @@ -776,14 +775,14 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ constexpr insert_result back_to_back_cas(
value_type* slot, Value const& value, Predicate const& predicate) noexcept
{
using mapped_type = decltype(this->empty_slot_sentinel_.second);

auto const expected_key = this->empty_slot_sentinel_.first;
auto const expected_payload = this->empty_slot_sentinel_.second;

auto old_key = compare_and_swap(&slot->first, expected_key, value.first); // TODO static_cast?
auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));
auto old_payload =
compare_and_swap(&slot->second, expected_payload, value.second); // TODO static_cast?

using mapped_type = decltype(expected_payload);
compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);
auto* old_payload_ptr = reinterpret_cast<mapped_type*>(&old_payload);
Expand All @@ -792,7 +791,7 @@ class open_addressing_ref_impl {
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
old_payload =
compare_and_swap(&slot->second, expected_payload, value.second); // TODO static_cast?
compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second));
}
return insert_result::SUCCESS;
} else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
Expand Down Expand Up @@ -824,15 +823,17 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ constexpr insert_result cas_dependent_write(
value_type* slot, Value const& value, Predicate const& predicate) noexcept
{
using mapped_type = decltype(this->empty_slot_sentinel_.second);

auto const expected_key = this->empty_slot_sentinel_.first;

auto old_key = compare_and_swap(&slot->first, expected_key, value.first); // TODO static_cast?
auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
atomic_store(&slot->second, value.second);
atomic_store(&slot->second, static_cast<mapped_type>(value.second));
return insert_result::SUCCESS;
}

Expand Down Expand Up @@ -866,7 +867,7 @@ class open_addressing_ref_impl {
Value const& value,
Predicate const& predicate) noexcept
{
if constexpr ((sizeof(value_type) <= 8) and cuda::std::is_convertible_v<Value, value_type>) {
if constexpr (sizeof(value_type) <= 8) {
return packed_cas<HasPayload>(slot, value, predicate);
} else {
#if (_CUDA_ARCH__ < 700)
Expand Down
2 changes: 1 addition & 1 deletion include/cuco/detail/static_set/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_

while (idx - thread_idx < n) { // the whole thread block falls into the same iteration
if (idx < n) {
typename std::iterator_traits<InputIt>::value_type const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
if constexpr (CGSize == 1) {
auto const found = ref.find(key);
/*
Expand Down
2 changes: 1 addition & 1 deletion tests/static_set/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct key_pair {
// https://github.com/NVIDIA/libcudacxx/issues/223
__device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; }

__device__ operator T() const noexcept { return a; }
__device__ explicit operator T() const noexcept { return a; }
};

// probe key type
Expand Down

0 comments on commit 6cbabd2

Please sign in to comment.