Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Sep 27, 2023
1 parent 359f5ae commit 134a52f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 37 deletions.
4 changes: 2 additions & 2 deletions include/cuco/detail/common_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ __global__ void insert_if_n(InputIterator first,

while (idx < n) {
if (pred(*(stencil + idx))) {
typename Ref::value_type const insert_element{*(first + idx)};
auto const insert_element{*(first + idx)};
if constexpr (CGSize == 1) {
if (ref.insert(insert_element)) { thread_num_successes++; };
} else {
Expand Down Expand Up @@ -134,7 +134,7 @@ __global__ void insert_if_n(

while (idx < n) {
if (pred(*(stencil + idx))) {
typename Ref::value_type const insert_element{*(first + idx)};
auto const insert_element{*(first + idx)};
if constexpr (CGSize == 1) {
ref.insert(insert_element);
} else {
Expand Down
14 changes: 8 additions & 6 deletions include/cuco/detail/equal_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,16 @@ struct equal_wrapper {
/**
* @brief Equality check with the given equality callable.
*
* @tparam U Right-hand side Element type
* @tparam LHS Left-hand side Element type
* @tparam RHS Right-hand side Element type
*
* @param lhs Left-hand side element to check equality
* @param rhs Right-hand side element to check equality
*
* @return `EQUAL` if `lhs` and `rhs` are equivalent. `UNEQUAL` otherwise.
*/
template <typename U>
__device__ constexpr equal_result equal_to(T const& lhs, U const& rhs) const noexcept
template <typename LHS, typename RHS>
__device__ constexpr equal_result equal_to(LHS const& lhs, RHS const& rhs) const noexcept
{
return equal_(lhs, rhs) ? equal_result::EQUAL : equal_result::UNEQUAL;
}
Expand All @@ -75,15 +76,16 @@ struct equal_wrapper {
* first then perform a equality check with the given `equal_` callable, i.e., `equal_(lhs, rhs)`.
* @note Container (like set or map) keys MUST be always on the left-hand side.
*
* @tparam U Right-hand side Element type
* @tparam LHS Left-hand side Element type
* @tparam RHS Right-hand side Element type
*
* @param lhs Left-hand side element to check equality
* @param rhs Right-hand side element to check equality
*
* @return Three way equality comparison result
*/
template <typename U>
__device__ constexpr equal_result operator()(T const& lhs, U const& rhs) const noexcept
template <typename LHS, typename RHS>
__device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept
{
return cuco::detail::bitwise_compare(lhs, empty_sentinel_) ? equal_result::EMPTY
: this->equal_to(lhs, rhs);
Expand Down
57 changes: 34 additions & 23 deletions include/cuco/detail/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,16 @@ class open_addressing_ref_impl {
* @brief Inserts an element.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param value The element to insert
* @param predicate Predicate used to compare slot content against `key`
*
* @return True if the given element is successfully inserted
*/
template <bool HasPayload, typename Predicate>
__device__ bool insert(value_type const& value, Predicate const& predicate) noexcept
template <bool HasPayload, typename Value, typename Predicate>
__device__ bool insert(Value const& value, Predicate const& predicate) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

Expand Down Expand Up @@ -202,6 +203,7 @@ class open_addressing_ref_impl {
* @brief Inserts an element.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param group The Cooperative Group used to perform group insert
Expand All @@ -210,9 +212,9 @@ class open_addressing_ref_impl {
*
* @return True if the given element is successfully inserted
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Value, typename Predicate>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value,
Value const& value,
Predicate const& predicate) noexcept
{
auto const key = [&]() {
Expand Down Expand Up @@ -275,6 +277,7 @@ class open_addressing_ref_impl {
* not.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param value The element to insert
Expand All @@ -283,8 +286,8 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <bool HasPayload, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value,
template <bool HasPayload, typename Value, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value,
Predicate const& predicate) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
Expand All @@ -309,7 +312,8 @@ class open_addressing_ref_impl {
if (eq_res == detail::equal_result::EQUAL) { return {iterator{&window_ptr[i]}, false}; }
if (eq_res == detail::equal_result::EMPTY) {
switch ([&]() {
if constexpr (sizeof(value_type) <= 8) {
if constexpr ((sizeof(value_type) <= 8) and
cuda::std::is_convertible_v<Value, value_type>) {
return packed_cas<HasPayload>(window_ptr + i, value, predicate);
} else {
return cas_dependent_write(window_ptr + i, value, predicate);
Expand Down Expand Up @@ -337,6 +341,7 @@ class open_addressing_ref_impl {
* not.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param group The Cooperative Group used to perform group insert_and_find
Expand All @@ -346,10 +351,10 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Value, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value,
Value const& value,
Predicate const& predicate) noexcept
{
auto const key = [&]() {
Expand Down Expand Up @@ -712,6 +717,7 @@ class open_addressing_ref_impl {
* @brief Inserts the specified element with one single CAS operation.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -720,12 +726,12 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Value, typename Predicate>
[[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot,
value_type const& value,
Value const& value,
Predicate const& predicate) noexcept
{
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value);
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast<value_type>(value));
auto* old_ptr = reinterpret_cast<value_type*>(&old);
auto const inserted = [&]() {
if constexpr (HasPayload) {
Expand Down Expand Up @@ -757,6 +763,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts the specified element with two back-to-back CAS operations.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -765,15 +772,16 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Predicate>
template <typename Value, typename Predicate>
[[nodiscard]] __device__ constexpr insert_result back_to_back_cas(
value_type* slot, value_type const& value, Predicate const& predicate) noexcept
value_type* slot, Value const& value, Predicate const& predicate) noexcept
{
auto const expected_key = this->empty_slot_sentinel_.first;
auto const expected_payload = this->empty_slot_sentinel_.second;

auto old_key = compare_and_swap(&slot->first, expected_key, value.first);
auto old_payload = compare_and_swap(&slot->second, expected_payload, value.second);
auto old_key = compare_and_swap(&slot->first, expected_key, value.first); // TODO static_cast?
auto old_payload =
compare_and_swap(&slot->second, expected_payload, value.second); // TODO static_cast?

using mapped_type = decltype(expected_payload);

Expand All @@ -783,7 +791,8 @@ class open_addressing_ref_impl {
// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
old_payload = compare_and_swap(&slot->second, expected_payload, value.second);
old_payload =
compare_and_swap(&slot->second, expected_payload, value.second); // TODO static_cast?
}
return insert_result::SUCCESS;
} else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
Expand All @@ -802,6 +811,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts the specified element with CAS-dependent write operations.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -810,13 +820,13 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Predicate>
template <typename Value, typename Predicate>
[[nodiscard]] __device__ constexpr insert_result cas_dependent_write(
value_type* slot, value_type const& value, Predicate const& predicate) noexcept
value_type* slot, Value const& value, Predicate const& predicate) noexcept
{
auto const expected_key = this->empty_slot_sentinel_.first;

auto old_key = compare_and_swap(&slot->first, expected_key, value.first);
auto old_key = compare_and_swap(&slot->first, expected_key, value.first); // TODO static_cast?

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

Expand All @@ -842,6 +852,7 @@ class open_addressing_ref_impl {
* type and presence of other operator mixins.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -850,12 +861,12 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Value, typename Predicate>
[[nodiscard]] __device__ insert_result attempt_insert(value_type* slot,
value_type const& value,
Value const& value,
Predicate const& predicate) noexcept
{
if constexpr (sizeof(value_type) <= 8) {
if constexpr ((sizeof(value_type) <= 8) and cuda::std::is_convertible_v<Value, value_type>) {
return packed_cas<HasPayload>(slot, value, predicate);
} else {
#if (_CUDA_ARCH__ < 700)
Expand Down
15 changes: 9 additions & 6 deletions tests/static_set/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct key_pair {
// Device equality operator is mandatory due to libcudacxx bug:
// https://github.com/NVIDIA/libcudacxx/issues/223
__device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; }

__device__ operator T() const noexcept { return a; }
};

// probe key type
Expand Down Expand Up @@ -75,14 +77,15 @@ struct custom_key_equal {
template <typename LHS, typename RHS>
__device__ bool operator()(LHS const& lhs, RHS const& rhs) const
{
return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a;
return thrust::raw_reference_cast(lhs) == thrust::raw_reference_cast(rhs).a;
}
};

TEMPLATE_TEST_CASE_SIG(
"Heterogeneous lookup", "", ((typename T, int CGSize), T, CGSize), (int32_t, 1), (int32_t, 2))
{
using Key = key_pair<T>;
using Key = T;
using InsertKey = key_pair<T>;
using ProbeKey = key_triplet<T>;
using probe_type = cuco::experimental::double_hashing<CGSize, custom_hasher, custom_hasher>;

Expand All @@ -98,15 +101,15 @@ TEMPLATE_TEST_CASE_SIG(
probe_type>{
capacity, cuco::empty_key<Key>{sentinel_key}, custom_key_equal{}, probe};

auto insert_pairs = thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return Key{i}; });
auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
auto insert_keys = thrust::make_transform_iterator(
thrust::counting_iterator<int>(0), [] __device__(auto i) { return InsertKey(i); });
auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator<int>(0),
[] __device__(auto i) { return ProbeKey(i); });

SECTION("All inserted keys should be contained")
{
thrust::device_vector<bool> contained(num);
my_set.insert(insert_pairs, insert_pairs + num);
my_set.insert(insert_keys, insert_keys + num);
my_set.contains(probe_keys, probe_keys + num, contained.begin());
REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{}));
}
Expand Down

0 comments on commit 134a52f

Please sign in to comment.