diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 565af7aa4..683bf94b1 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -24,6 +24,7 @@ #include #include +#include #include @@ -865,9 +866,9 @@ class open_addressing_ref_impl { Value const& value) const noexcept { if constexpr (this->has_payload) { - return value.first; + return thrust::get<0>(thrust::raw_reference_cast(value)); } else { - return value; + return thrust::raw_reference_cast(value); } } @@ -886,7 +887,7 @@ class open_addressing_ref_impl { [[nodiscard]] __host__ __device__ constexpr auto const& extract_payload( Value const& value) const noexcept { - return value.second; + return thrust::get<1>(thrust::raw_reference_cast(value)); } /** @@ -952,10 +953,10 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; auto const expected_payload = expected.second; - auto old_key = - compare_and_swap(&address->first, expected_key, static_cast(desired.first)); + auto old_key = compare_and_swap( + &address->first, expected_key, static_cast(thrust::get<0>(desired))); auto old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(desired.second)); + &address->second, expected_payload, static_cast(thrust::get<1>(desired))); auto* old_key_ptr = reinterpret_cast(&old_key); auto* old_payload_ptr = reinterpret_cast(&old_payload); @@ -964,7 +965,7 @@ class open_addressing_ref_impl { if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(desired.second)); + &address->second, expected_payload, static_cast(thrust::get<1>(desired))); } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { @@ -973,7 +974,9 @@ class open_addressing_ref_impl { // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, + thrust::get<0>(thrust::raw_reference_cast(desired))) == + detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -999,20 +1002,22 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; - auto old_key = - compare_and_swap(&address->first, expected_key, static_cast(desired.first)); + auto old_key = compare_and_swap( + &address->first, expected_key, static_cast(thrust::get<0>(desired))); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { - atomic_store(&address->second, static_cast(desired.second)); + atomic_store(&address->second, static_cast(thrust::get<1>(desired))); return insert_result::SUCCESS; } // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, + thrust::get<0>(thrust::raw_reference_cast(desired))) == + detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } diff --git a/include/cuco/detail/pair.inl b/include/cuco/detail/pair/pair.inl similarity index 89% rename from include/cuco/detail/pair.inl rename to include/cuco/detail/pair/pair.inl index 56d16e4fb..3279a915d 100644 --- a/include/cuco/detail/pair.inl +++ b/include/cuco/detail/pair/pair.inl @@ -49,3 +49,11 @@ __host__ __device__ constexpr bool operator==(cuco::pair const& lhs, } } // namespace cuco + +namespace thrust { +#include +} // namespace thrust + +namespace cuda::std { +#include +} // namespace cuda::std diff --git a/include/cuco/pair.cuh b/include/cuco/pair.cuh index 0a804cc04..d28cae5da 100644 --- a/include/cuco/pair.cuh +++ b/include/cuco/pair.cuh @@ -22,7 +22,7 @@ #include #include -#include +#include #include namespace cuco { @@ -87,7 +87,8 @@ struct alignas(detail::pair_alignment()) pair { */ template ::value>* = nullptr> __host__ __device__ constexpr pair(T const& p) - : pair{std::get<0>(thrust::raw_reference_cast(p)), std::get<1>(thrust::raw_reference_cast(p))} + : pair{cuda::std::get<0>(thrust::raw_reference_cast(p)), + cuda::std::get<1>(thrust::raw_reference_cast(p))} { } @@ -143,4 +144,4 @@ __host__ __device__ constexpr bool operator==(cuco::pair const& lhs, } // namespace cuco -#include +#include