Skip to content

Commit

Permalink
Fix tuple handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Oct 12, 2023
1 parent c15d326 commit 03f4a89
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 15 deletions.
29 changes: 17 additions & 12 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <thrust/distance.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>

#include <cuda/atomic>

Expand Down Expand Up @@ -865,9 +866,9 @@ class open_addressing_ref_impl {
Value const& value) const noexcept
{
if constexpr (this->has_payload) {
return value.first;
return thrust::get<0>(thrust::raw_reference_cast(value));
} else {
return value;
return thrust::raw_reference_cast(value);
}
}

Expand All @@ -886,7 +887,7 @@ class open_addressing_ref_impl {
[[nodiscard]] __host__ __device__ constexpr auto const& extract_payload(
Value const& value) const noexcept
{
return value.second;
return thrust::get<1>(thrust::raw_reference_cast(value));
}

/**
Expand Down Expand Up @@ -952,10 +953,10 @@ class open_addressing_ref_impl {
auto const expected_key = expected.first;
auto const expected_payload = expected.second;

auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
auto old_key = compare_and_swap(
&address->first, expected_key, static_cast<key_type>(thrust::get<0>(desired)));
auto old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(desired.second));
&address->second, expected_payload, static_cast<mapped_type>(thrust::get<1>(desired)));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);
auto* old_payload_ptr = reinterpret_cast<mapped_type*>(&old_payload);
Expand All @@ -964,7 +965,7 @@ class open_addressing_ref_impl {
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(desired.second));
&address->second, expected_payload, static_cast<mapped_type>(thrust::get<1>(desired)));
}
return insert_result::SUCCESS;
} else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
Expand All @@ -973,7 +974,9 @@ class open_addressing_ref_impl {

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr,
thrust::get<0>(thrust::raw_reference_cast(desired))) ==
detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand All @@ -999,20 +1002,22 @@ class open_addressing_ref_impl {

auto const expected_key = expected.first;

auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
auto old_key = compare_and_swap(
&address->first, expected_key, static_cast<key_type>(thrust::get<0>(desired)));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
atomic_store(&address->second, static_cast<mapped_type>(desired.second));
atomic_store(&address->second, static_cast<mapped_type>(thrust::get<1>(desired)));
return insert_result::SUCCESS;
}

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr,
thrust::get<0>(thrust::raw_reference_cast(desired))) ==
detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,11 @@ __host__ __device__ constexpr bool operator==(cuco::pair<T1, T2> const& lhs,
}

} // namespace cuco

namespace thrust {
#include <cuco/detail/pair/tuple_helpers.inl>
} // namespace thrust

namespace cuda::std {
#include <cuco/detail/pair/tuple_helpers.inl>
} // namespace cuda::std
118 changes: 118 additions & 0 deletions include/cuco/detail/pair/tuple_helpers.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

template <typename T1, typename T2>
struct tuple_size<cuco::pair<T1, T2>> : integral_constant<size_t, 2> {
};

template <typename T1, typename T2>
struct tuple_size<const cuco::pair<T1, T2>> : tuple_size<cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_size<volatile cuco::pair<T1, T2>> : tuple_size<cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_size<const volatile cuco::pair<T1, T2>> : tuple_size<cuco::pair<T1, T2>> {
};

template <std::size_t I, typename T1, typename T2>
struct tuple_element<I, cuco::pair<T1, T2>> {
using type = void;
};

template <typename T1, typename T2>
struct tuple_element<0, cuco::pair<T1, T2>> {
using type = T1;
};

template <typename T1, typename T2>
struct tuple_element<1, cuco::pair<T1, T2>> {
using type = T2;
};

template <typename T1, typename T2>
struct tuple_element<0, const cuco::pair<T1, T2>> : tuple_element<0, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<1, const cuco::pair<T1, T2>> : tuple_element<1, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<0, volatile cuco::pair<T1, T2>> : tuple_element<0, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<1, volatile cuco::pair<T1, T2>> : tuple_element<1, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<0, const volatile cuco::pair<T1, T2>> : tuple_element<0, cuco::pair<T1, T2>> {
};

template <typename T1, typename T2>
struct tuple_element<1, const volatile cuco::pair<T1, T2>> : tuple_element<1, cuco::pair<T1, T2>> {
};

template <std::size_t I, typename T1, typename T2>
__host__ __device__ constexpr auto get(cuco::pair<T1, T2>& p) ->
typename tuple_element<I, cuco::pair<T1, T2>>::type&
{
static_assert(I < 2);
if constexpr (I == 0) {
return p.first;
} else {
return p.second;
}
}

template <std::size_t I, typename T1, typename T2>
__host__ __device__ constexpr auto get(cuco::pair<T1, T2>&& p) ->
typename tuple_element<I, cuco::pair<T1, T2>>::type&&
{
static_assert(I < 2);
if constexpr (I == 0) {
return std::move(p.first);
} else {
return std::move(p.second);
}
}

template <std::size_t I, typename T1, typename T2>
__host__ __device__ constexpr auto get(cuco::pair<T1, T2> const& p) ->
typename tuple_element<I, cuco::pair<T1, T2>>::type const&
{
static_assert(I < 2);
if constexpr (I == 0) {
return p.first;
} else {
return p.second;
}
}

template <std::size_t I, typename T1, typename T2>
__host__ __device__ constexpr auto get(cuco::pair<T1, T2> const&& p) ->
typename tuple_element<I, cuco::pair<T1, T2>>::type const&&
{
static_assert(I < 2);
if constexpr (I == 0) {
return std::move(p.first);
} else {
return std::move(p.second);
}
}
7 changes: 4 additions & 3 deletions include/cuco/pair.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <thrust/device_reference.h>
#include <thrust/tuple.h>

#include <tuple>
#include <cuda/std/tuple>
#include <type_traits>

namespace cuco {
Expand Down Expand Up @@ -87,7 +87,8 @@ struct alignas(detail::pair_alignment<First, Second>()) pair {
*/
template <typename T, std::enable_if_t<detail::is_std_pair_like<T>::value>* = nullptr>
__host__ __device__ constexpr pair(T const& p)
: pair{std::get<0>(thrust::raw_reference_cast(p)), std::get<1>(thrust::raw_reference_cast(p))}
: pair{cuda::std::get<0>(thrust::raw_reference_cast(p)),
cuda::std::get<1>(thrust::raw_reference_cast(p))}
{
}

Expand Down Expand Up @@ -143,4 +144,4 @@ __host__ __device__ constexpr bool operator==(cuco::pair<T1, T2> const& lhs,

} // namespace cuco

#include <cuco/detail/pair.inl>
#include <cuco/detail/pair/pair.inl>

0 comments on commit 03f4a89

Please sign in to comment.