Skip to content

Commit

Permalink
Fix tuple handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Oct 12, 2023
1 parent c15d326 commit eed2c54
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
29 changes: 17 additions & 12 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <thrust/distance.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>

#include <cuda/atomic>

Expand Down Expand Up @@ -865,9 +866,9 @@ class open_addressing_ref_impl {
Value const& value) const noexcept
{
if constexpr (this->has_payload) {
return value.first;
return thrust::get<0>(thrust::raw_reference_cast(value));
} else {
return value;
return thrust::raw_reference_cast(value);
}
}

Expand All @@ -886,7 +887,7 @@ class open_addressing_ref_impl {
[[nodiscard]] __host__ __device__ constexpr auto const& extract_payload(
Value const& value) const noexcept
{
return value.second;
return thrust::get<1>(thrust::raw_reference_cast(value));
}

/**
Expand Down Expand Up @@ -952,10 +953,10 @@ class open_addressing_ref_impl {
auto const expected_key = expected.first;
auto const expected_payload = expected.second;

auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
auto old_key = compare_and_swap(
&address->first, expected_key, static_cast<key_type>(thrust::get<0>(desired)));
auto old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(desired.second));
&address->second, expected_payload, static_cast<mapped_type>(thrust::get<1>(desired)));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);
auto* old_payload_ptr = reinterpret_cast<mapped_type*>(&old_payload);
Expand All @@ -964,7 +965,7 @@ class open_addressing_ref_impl {
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(desired.second));
&address->second, expected_payload, static_cast<mapped_type>(thrust::get<1>(desired)));
}
return insert_result::SUCCESS;
} else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
Expand All @@ -973,7 +974,9 @@ class open_addressing_ref_impl {

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr,
thrust::get<0>(thrust::raw_reference_cast(desired))) ==
detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand All @@ -999,20 +1002,22 @@ class open_addressing_ref_impl {

auto const expected_key = expected.first;

auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
auto old_key = compare_and_swap(
&address->first, expected_key, static_cast<key_type>(thrust::get<0>(desired)));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
atomic_store(&address->second, static_cast<mapped_type>(desired.second));
atomic_store(&address->second, static_cast<mapped_type>(thrust::get<1>(desired)));
return insert_result::SUCCESS;
}

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr,
thrust::get<0>(thrust::raw_reference_cast(desired))) ==
detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,11 @@ __host__ __device__ constexpr bool operator==(cuco::pair<T1, T2> const& lhs,
}

} // namespace cuco

namespace thrust {
#include <cuco/detail/pair/tuple_helpers.inl>
} // namespace thrust

namespace cuda::std {
#include <cuco/detail/pair/tuple_helpers.inl>
} // namespace cuda::std
7 changes: 4 additions & 3 deletions include/cuco/pair.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <thrust/device_reference.h>
#include <thrust/tuple.h>

#include <tuple>
#include <cuda/std/tuple>
#include <type_traits>

namespace cuco {
Expand Down Expand Up @@ -87,7 +87,8 @@ struct alignas(detail::pair_alignment<First, Second>()) pair {
*/
template <typename T, std::enable_if_t<detail::is_std_pair_like<T>::value>* = nullptr>
__host__ __device__ constexpr pair(T const& p)
: pair{std::get<0>(thrust::raw_reference_cast(p)), std::get<1>(thrust::raw_reference_cast(p))}
: pair{cuda::std::get<0>(thrust::raw_reference_cast(p)),
cuda::std::get<1>(thrust::raw_reference_cast(p))}
{
}

Expand Down Expand Up @@ -143,4 +144,4 @@ __host__ __device__ constexpr bool operator==(cuco::pair<T1, T2> const& lhs,

} // namespace cuco

#include <cuco/detail/pair.inl>
#include <cuco/detail/pair/pair.inl>

0 comments on commit eed2c54

Please sign in to comment.