diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 565af7aa4..683bf94b1 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -24,6 +24,7 @@ #include #include +#include #include @@ -865,9 +866,9 @@ class open_addressing_ref_impl { Value const& value) const noexcept { if constexpr (this->has_payload) { - return value.first; + return thrust::get<0>(thrust::raw_reference_cast(value)); } else { - return value; + return thrust::raw_reference_cast(value); } } @@ -886,7 +887,7 @@ class open_addressing_ref_impl { [[nodiscard]] __host__ __device__ constexpr auto const& extract_payload( Value const& value) const noexcept { - return value.second; + return thrust::get<1>(thrust::raw_reference_cast(value)); } /** @@ -952,10 +953,10 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; auto const expected_payload = expected.second; - auto old_key = - compare_and_swap(&address->first, expected_key, static_cast(desired.first)); + auto old_key = compare_and_swap( + &address->first, expected_key, static_cast(thrust::get<0>(desired))); auto old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(desired.second)); + &address->second, expected_payload, static_cast(thrust::get<1>(desired))); auto* old_key_ptr = reinterpret_cast(&old_key); auto* old_payload_ptr = reinterpret_cast(&old_payload); @@ -964,7 +965,7 @@ class open_addressing_ref_impl { if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(desired.second)); + &address->second, expected_payload, static_cast(thrust::get<1>(desired))); } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { @@ -973,7 +974,9 @@ class open_addressing_ref_impl { // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, + thrust::get<0>(thrust::raw_reference_cast(desired))) == + detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -999,20 +1002,22 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; - auto old_key = - compare_and_swap(&address->first, expected_key, static_cast(desired.first)); + auto old_key = compare_and_swap( + &address->first, expected_key, static_cast(thrust::get<0>(desired))); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { - atomic_store(&address->second, static_cast(desired.second)); + atomic_store(&address->second, static_cast(thrust::get<1>(desired))); return insert_result::SUCCESS; } // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, + thrust::get<0>(thrust::raw_reference_cast(desired))) == + detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } diff --git a/include/cuco/detail/pair.inl b/include/cuco/detail/pair/pair.inl similarity index 89% rename from include/cuco/detail/pair.inl rename to include/cuco/detail/pair/pair.inl index 56d16e4fb..3279a915d 100644 --- a/include/cuco/detail/pair.inl +++ b/include/cuco/detail/pair/pair.inl @@ -49,3 +49,11 @@ __host__ __device__ constexpr bool operator==(cuco::pair const& lhs, } } // namespace cuco + +namespace thrust { +#include +} // namespace thrust + +namespace cuda::std { +#include +} // namespace cuda::std diff --git a/include/cuco/detail/pair/tuple_helpers.inl b/include/cuco/detail/pair/tuple_helpers.inl new file mode 100644 index 000000000..29be199ee --- /dev/null +++ b/include/cuco/detail/pair/tuple_helpers.inl @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +template +struct tuple_size> : integral_constant { +}; + +template +struct tuple_size> : tuple_size> { +}; + +template +struct tuple_size> : tuple_size> { +}; + +template +struct tuple_size> : tuple_size> { +}; + +template +struct tuple_element> { + using type = void; +}; + +template +struct tuple_element<0, cuco::pair> { + using type = T1; +}; + +template +struct tuple_element<1, cuco::pair> { + using type = T2; +}; + +template +struct tuple_element<0, const cuco::pair> : tuple_element<0, cuco::pair> { +}; + +template +struct tuple_element<1, const cuco::pair> : tuple_element<1, cuco::pair> { +}; + +template +struct tuple_element<0, volatile cuco::pair> : tuple_element<0, cuco::pair> { +}; + +template +struct tuple_element<1, volatile cuco::pair> : tuple_element<1, cuco::pair> { +}; + +template +struct tuple_element<0, const volatile cuco::pair> : tuple_element<0, cuco::pair> { +}; + +template +struct tuple_element<1, const volatile cuco::pair> : tuple_element<1, cuco::pair> { +}; + +template +__host__ __device__ constexpr auto get(cuco::pair& p) -> + typename tuple_element>::type& +{ + static_assert(I < 2); + if constexpr (I == 0) { + return p.first; + } else { + return p.second; + } +} + +template +__host__ __device__ constexpr auto get(cuco::pair&& p) -> + typename tuple_element>::type&& +{ + static_assert(I < 2); + if constexpr (I == 0) { + return std::move(p.first); + } else { + return std::move(p.second); + } +} + +template +__host__ __device__ constexpr auto get(cuco::pair const& p) -> + typename tuple_element>::type const& +{ + static_assert(I < 2); + if constexpr (I == 0) { + return p.first; + } else { + return p.second; + } +} + +template +__host__ __device__ constexpr auto get(cuco::pair const&& p) -> + typename tuple_element>::type const&& +{ + static_assert(I < 2); + if constexpr (I == 0) { + return std::move(p.first); + } else { + return std::move(p.second); + } +} \ No newline at end of file diff --git a/include/cuco/pair.cuh b/include/cuco/pair.cuh index 0a804cc04..d28cae5da 100644 --- a/include/cuco/pair.cuh +++ b/include/cuco/pair.cuh @@ -22,7 +22,7 @@ #include #include -#include +#include #include namespace cuco { @@ -87,7 +87,8 @@ struct alignas(detail::pair_alignment()) pair { */ template ::value>* = nullptr> __host__ __device__ constexpr pair(T const& p) - : pair{std::get<0>(thrust::raw_reference_cast(p)), std::get<1>(thrust::raw_reference_cast(p))} + : pair{cuda::std::get<0>(thrust::raw_reference_cast(p)), + cuda::std::get<1>(thrust::raw_reference_cast(p))} { } @@ -143,4 +144,4 @@ __host__ __device__ constexpr bool operator==(cuco::pair const& lhs, } // namespace cuco -#include +#include