Skip to content

Commit

Permalink
Enable heterogeneous insert for static_set (#375)
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack authored Oct 2, 2023
1 parent ee9c48a commit fd23a3d
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 54 deletions.
22 changes: 12 additions & 10 deletions include/cuco/detail/common_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include <cooperative_groups.h>

#include <iterator>

namespace cuco {
namespace experimental {
namespace detail {
Expand All @@ -37,7 +39,7 @@ namespace detail {
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIterator Device accessible input iterator whose `value_type` is
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `value_type` of the data structure
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
Expand All @@ -55,12 +57,12 @@ namespace detail {
*/
template <int32_t CGSize,
int32_t BlockSize,
typename InputIterator,
typename InputIt,
typename StencilIt,
typename Predicate,
typename AtomicT,
typename Ref>
__global__ void insert_if_n(InputIterator first,
__global__ void insert_if_n(InputIt first,
cuco::detail::index_type n,
StencilIt stencil,
Predicate pred,
Expand All @@ -76,7 +78,7 @@ __global__ void insert_if_n(InputIterator first,

while (idx < n) {
if (pred(*(stencil + idx))) {
typename Ref::value_type const insert_element{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const& insert_element{*(first + idx)};
if constexpr (CGSize == 1) {
if (ref.insert(insert_element)) { thread_num_successes++; };
} else {
Expand Down Expand Up @@ -106,7 +108,7 @@ __global__ void insert_if_n(InputIterator first,
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIterator Device accessible input iterator whose `value_type` is
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `value_type` of the data structure
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
Expand All @@ -122,19 +124,19 @@ __global__ void insert_if_n(InputIterator first,
*/
template <int32_t CGSize,
int32_t BlockSize,
typename InputIterator,
typename InputIt,
typename StencilIt,
typename Predicate,
typename Ref>
__global__ void insert_if_n(
InputIterator first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref)
InputIt first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref)
{
auto const loop_stride = cuco::detail::grid_stride() / CGSize;
auto idx = cuco::detail::global_thread_id() / CGSize;

while (idx < n) {
if (pred(*(stencil + idx))) {
typename Ref::value_type const insert_element{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const& insert_element{*(first + idx)};
if constexpr (CGSize == 1) {
ref.insert(insert_element);
} else {
Expand Down Expand Up @@ -198,7 +200,7 @@ __global__ void contains_if_n(InputIt first,
while (idx - thread_idx < n) { // the whole thread block falls into the same iteration
if constexpr (CGSize == 1) {
if (idx < n) {
auto const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
/*
* The ld.relaxed.gpu instruction causes L1 to flush more frequently, causing increased
* sector stores from L2 to global memory. By writing results to shared memory and then
Expand All @@ -212,7 +214,7 @@ __global__ void contains_if_n(InputIt first,
} else {
auto const tile = cg::tiled_partition<CGSize>(cg::this_thread_block());
if (idx < n) {
auto const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
auto const found = pred(*(stencil + idx)) ? ref.contains(tile, key) : false;
if (tile.thread_rank() == 0) { *(output_begin + idx) = found; }
}
Expand Down
14 changes: 8 additions & 6 deletions include/cuco/detail/equal_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,16 @@ struct equal_wrapper {
/**
* @brief Equality check with the given equality callable.
*
* @tparam U Right-hand side Element type
* @tparam LHS Left-hand side Element type
* @tparam RHS Right-hand side Element type
*
* @param lhs Left-hand side element to check equality
* @param rhs Right-hand side element to check equality
*
* @return `EQUAL` if `lhs` and `rhs` are equivalent. `UNEQUAL` otherwise.
*/
template <typename U>
__device__ constexpr equal_result equal_to(T const& lhs, U const& rhs) const noexcept
template <typename LHS, typename RHS>
__device__ constexpr equal_result equal_to(LHS const& lhs, RHS const& rhs) const noexcept
{
return equal_(lhs, rhs) ? equal_result::EQUAL : equal_result::UNEQUAL;
}
Expand All @@ -75,15 +76,16 @@ struct equal_wrapper {
* first then perform a equality check with the given `equal_` callable, i.e., `equal_(lhs, rhs)`.
* @note Container (like set or map) keys MUST be always on the left-hand side.
*
* @tparam U Right-hand side Element type
* @tparam LHS Left-hand side Element type
* @tparam RHS Right-hand side Element type
*
* @param lhs Left-hand side element to check equality
* @param rhs Right-hand side element to check equality
*
* @return Three way equality comparison result
*/
template <typename U>
__device__ constexpr equal_result operator()(T const& lhs, U const& rhs) const noexcept
template <typename LHS, typename RHS>
__device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept
{
return cuco::detail::bitwise_compare(lhs, empty_sentinel_) ? equal_result::EMPTY
: this->equal_to(lhs, rhs);
Expand Down
60 changes: 36 additions & 24 deletions include/cuco/detail/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,16 @@ class open_addressing_ref_impl {
* @brief Inserts an element.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param value The element to insert
* @param predicate Predicate used to compare slot content against `key`
*
* @return True if the given element is successfully inserted
*/
template <bool HasPayload, typename Predicate>
__device__ bool insert(value_type const& value, Predicate const& predicate) noexcept
template <bool HasPayload, typename Value, typename Predicate>
__device__ bool insert(Value const& value, Predicate const& predicate) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

Expand Down Expand Up @@ -202,6 +203,7 @@ class open_addressing_ref_impl {
* @brief Inserts an element.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param group The Cooperative Group used to perform group insert
Expand All @@ -210,9 +212,9 @@ class open_addressing_ref_impl {
*
* @return True if the given element is successfully inserted
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Value, typename Predicate>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value,
Value const& value,
Predicate const& predicate) noexcept
{
auto const key = [&]() {
Expand Down Expand Up @@ -275,6 +277,7 @@ class open_addressing_ref_impl {
* not.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param value The element to insert
Expand All @@ -283,8 +286,8 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <bool HasPayload, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value,
template <bool HasPayload, typename Value, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value,
Predicate const& predicate) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
Expand Down Expand Up @@ -337,6 +340,7 @@ class open_addressing_ref_impl {
* not.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param group The Cooperative Group used to perform group insert_and_find
Expand All @@ -346,10 +350,10 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Value, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value,
Value const& value,
Predicate const& predicate) noexcept
{
auto const key = [&]() {
Expand Down Expand Up @@ -712,6 +716,7 @@ class open_addressing_ref_impl {
* @brief Inserts the specified element with one single CAS operation.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -720,12 +725,12 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Value, typename Predicate>
[[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot,
value_type const& value,
Value const& value,
Predicate const& predicate) noexcept
{
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value);
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast<value_type>(value));
auto* old_ptr = reinterpret_cast<value_type*>(&old);
auto const inserted = [&]() {
if constexpr (HasPayload) {
Expand Down Expand Up @@ -757,6 +762,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts the specified element with two back-to-back CAS operations.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -765,25 +771,27 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Predicate>
template <typename Value, typename Predicate>
[[nodiscard]] __device__ constexpr insert_result back_to_back_cas(
value_type* slot, value_type const& value, Predicate const& predicate) noexcept
value_type* slot, Value const& value, Predicate const& predicate) noexcept
{
using mapped_type = decltype(this->empty_slot_sentinel_.second);

auto const expected_key = this->empty_slot_sentinel_.first;
auto const expected_payload = this->empty_slot_sentinel_.second;

auto old_key = compare_and_swap(&slot->first, expected_key, value.first);
auto old_payload = compare_and_swap(&slot->second, expected_payload, value.second);

using mapped_type = decltype(expected_payload);
auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));
auto old_payload =
compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);
auto* old_payload_ptr = reinterpret_cast<mapped_type*>(&old_payload);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
old_payload = compare_and_swap(&slot->second, expected_payload, value.second);
old_payload =
compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second));
}
return insert_result::SUCCESS;
} else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
Expand All @@ -802,6 +810,7 @@ class open_addressing_ref_impl {
/**
* @brief Inserts the specified element with CAS-dependent write operations.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -810,19 +819,21 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Predicate>
template <typename Value, typename Predicate>
[[nodiscard]] __device__ constexpr insert_result cas_dependent_write(
value_type* slot, value_type const& value, Predicate const& predicate) noexcept
value_type* slot, Value const& value, Predicate const& predicate) noexcept
{
using mapped_type = decltype(this->empty_slot_sentinel_.second);

auto const expected_key = this->empty_slot_sentinel_.first;

auto old_key = compare_and_swap(&slot->first, expected_key, value.first);
auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
atomic_store(&slot->second, value.second);
atomic_store(&slot->second, static_cast<mapped_type>(value.second));
return insert_result::SUCCESS;
}

Expand All @@ -842,6 +853,7 @@ class open_addressing_ref_impl {
* type and presence of other operator mixins.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
* @param slot Pointer to the slot in memory
Expand All @@ -850,9 +862,9 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Value, typename Predicate>
[[nodiscard]] __device__ insert_result attempt_insert(value_type* slot,
value_type const& value,
Value const& value,
Predicate const& predicate) noexcept
{
if constexpr (sizeof(value_type) <= 8) {
Expand Down
4 changes: 3 additions & 1 deletion include/cuco/detail/static_set/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include <cooperative_groups.h>

#include <iterator>

namespace cuco {
namespace experimental {
namespace static_set_ns {
Expand Down Expand Up @@ -62,7 +64,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_

while (idx - thread_idx < n) { // the whole thread block falls into the same iteration
if (idx < n) {
auto const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
if constexpr (CGSize == 1) {
auto const found = ref.find(key);
/*
Expand Down
Loading

0 comments on commit fd23a3d

Please sign in to comment.