Skip to content

Commit

Permalink
Fix heterogeneous insert for static_map
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Nov 4, 2023
1 parent 1dd1648 commit b7c114e
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 119 deletions.
84 changes: 48 additions & 36 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <cuco/probing_scheme.cuh>

#include <thrust/distance.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>

#include <cuda/atomic>
Expand Down Expand Up @@ -262,7 +261,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*
Expand Down Expand Up @@ -304,7 +303,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
Expand Down Expand Up @@ -374,7 +373,7 @@ class open_addressing_ref_impl {
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*
Expand Down Expand Up @@ -423,7 +422,7 @@ class open_addressing_ref_impl {
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert_and_find
* @param value The element to insert
Expand Down Expand Up @@ -500,7 +499,7 @@ class open_addressing_ref_impl {
/**
* @brief Erases an element.
*
* @tparam ProbeKey Input type which is implicitly convertible to 'key_type'
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param value The element to erase
*
Expand Down Expand Up @@ -540,7 +539,7 @@ class open_addressing_ref_impl {
/**
* @brief Erases an element.
*
* @tparam ProbeKey Input type which is implicitly convertible to 'key_type'
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param group The Cooperative Group used to perform group erase
* @param value The element to erase
Expand Down Expand Up @@ -600,7 +599,7 @@ class open_addressing_ref_impl {
* @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns
* false.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param key The key to search for
*
Expand Down Expand Up @@ -633,7 +632,7 @@ class open_addressing_ref_impl {
* @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns
* false.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param group The Cooperative Group used to perform group contains
* @param key The key to search for
Expand Down Expand Up @@ -673,7 +672,7 @@ class open_addressing_ref_impl {
* @note Returns a un-incrementable input iterator to the element whose key is equivalent to
* `key`. If no such element exists, returns `end()`.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param key The key to search for
*
Expand Down Expand Up @@ -710,7 +709,7 @@ class open_addressing_ref_impl {
* @note Returns a un-incrementable input iterator to the element whose key is equivalent to
* `key`. If no such element exists, returns `end()`.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
Expand Down Expand Up @@ -845,7 +844,7 @@ class open_addressing_ref_impl {
/**
* @brief Extracts the key from a given value type.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The input value
*
Expand All @@ -856,7 +855,7 @@ class open_addressing_ref_impl {
Value const& value) const noexcept
{
if constexpr (this->has_payload) {
return thrust::get<0>(thrust::raw_reference_cast(value));
return thrust::raw_reference_cast(value).first;
} else {
return thrust::raw_reference_cast(value);
}
Expand All @@ -867,7 +866,7 @@ class open_addressing_ref_impl {
*
* @note This function is only available if `this->has_payload == true`
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The input value
*
Expand All @@ -877,7 +876,26 @@ class open_addressing_ref_impl {
[[nodiscard]] __host__ __device__ constexpr auto const& extract_payload(
Value const& value) const noexcept
{
return thrust::get<1>(thrust::raw_reference_cast(value));
return thrust::raw_reference_cast(value).second;
}

/**
* @brief Converts the given type to the container's native `value_type`.
*
* @tparam T Input type which is convertible to 'value_type'
*
* @param value The input value
*
* @return The converted object
*/
template <typename T>
[[nodiscard]] __host__ __device__ constexpr value_type native_value(T const& value) const noexcept
{
if constexpr (this->has_payload) {
return {static_cast<key_type>(this->extract_key(value)), this->extract_payload(value)};
} else {
return static_cast<value_type>(value);
}
}

/**
Expand All @@ -897,7 +915,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts the specified element with one single CAS operation.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param address Pointer to the slot in memory
* @param expected Element to compare against
Expand All @@ -910,7 +928,7 @@ class open_addressing_ref_impl {
value_type const& expected,
Value const& desired) noexcept
{
auto old = compare_and_swap(address, expected, static_cast<value_type>(desired));
auto old = compare_and_swap(address, expected, this->native_value(desired));
auto* old_ptr = reinterpret_cast<value_type*>(&old);
if (cuco::detail::bitwise_compare(this->extract_key(*old_ptr), this->extract_key(expected))) {
return insert_result::SUCCESS;
Expand All @@ -925,7 +943,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts the specified element with two back-to-back CAS operations.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param address Pointer to the slot in memory
* @param expected Element to compare against
Expand All @@ -943,19 +961,17 @@ class open_addressing_ref_impl {
auto const expected_key = expected.first;
auto const expected_payload = expected.second;

auto old_key = compare_and_swap(
&address->first, expected_key, static_cast<key_type>(thrust::get<0>(desired)));
auto old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(thrust::get<1>(desired)));
auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
auto old_payload = compare_and_swap(&address->second, expected_payload, desired.second);

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);
auto* old_payload_ptr = reinterpret_cast<mapped_type*>(&old_payload);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(thrust::get<1>(desired)));
old_payload = compare_and_swap(&address->second, expected_payload, desired.second);
}
return insert_result::SUCCESS;
} else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
Expand All @@ -964,9 +980,7 @@ class open_addressing_ref_impl {

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr,
thrust::get<0>(thrust::raw_reference_cast(desired))) ==
detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand All @@ -976,7 +990,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts the specified element with CAS-dependent write operations.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param address Pointer to the slot in memory
* @param expected Element to compare against
Expand All @@ -992,22 +1006,20 @@ class open_addressing_ref_impl {

auto const expected_key = expected.first;

auto old_key = compare_and_swap(
&address->first, expected_key, static_cast<key_type>(thrust::get<0>(desired)));
auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
atomic_store(&address->second, static_cast<mapped_type>(thrust::get<1>(desired)));
atomic_store(&address->second, desired.second);
return insert_result::SUCCESS;
}

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr,
thrust::get<0>(thrust::raw_reference_cast(desired))) ==
detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand All @@ -1020,7 +1032,7 @@ class open_addressing_ref_impl {
* @note Dispatches the correct implementation depending on the container
* type and presence of other operator mixins.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param address Pointer to the slot in memory
* @param expected Element to compare against
Expand Down Expand Up @@ -1053,7 +1065,7 @@ class open_addressing_ref_impl {
* @note `stable` indicates that the payload will only be updated once from the sentinel value to
* the desired value, meaning there can be no ABA situations.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param address Pointer to the slot in memory
* @param expected Element to compare against
Expand Down
Loading

0 comments on commit b7c114e

Please sign in to comment.