Skip to content

Commit

Permalink
Fix static_set heterogeneous insert
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Sep 28, 2023
1 parent 134a52f commit c3f95e3
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 18 deletions.
20 changes: 11 additions & 9 deletions include/cuco/detail/common_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include <cooperative_groups.h>

#include <iterator>

namespace cuco {
namespace experimental {
namespace detail {
Expand All @@ -37,7 +39,7 @@ namespace detail {
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIterator Device accessible input iterator whose `value_type` is
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `value_type` of the data structure
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
Expand All @@ -55,12 +57,12 @@ namespace detail {
*/
template <int32_t CGSize,
int32_t BlockSize,
typename InputIterator,
typename InputIt,
typename StencilIt,
typename Predicate,
typename AtomicT,
typename Ref>
__global__ void insert_if_n(InputIterator first,
__global__ void insert_if_n(InputIt first,
cuco::detail::index_type n,
StencilIt stencil,
Predicate pred,
Expand All @@ -76,7 +78,7 @@ __global__ void insert_if_n(InputIterator first,

while (idx < n) {
if (pred(*(stencil + idx))) {
auto const insert_element{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const insert_element{*(first + idx)};
if constexpr (CGSize == 1) {
if (ref.insert(insert_element)) { thread_num_successes++; };
} else {
Expand Down Expand Up @@ -106,7 +108,7 @@ __global__ void insert_if_n(InputIterator first,
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIterator Device accessible input iterator whose `value_type` is
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `value_type` of the data structure
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
Expand All @@ -122,19 +124,19 @@ __global__ void insert_if_n(InputIterator first,
*/
template <int32_t CGSize,
int32_t BlockSize,
typename InputIterator,
typename InputIt,
typename StencilIt,
typename Predicate,
typename Ref>
__global__ void insert_if_n(
InputIterator first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref)
InputIt first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref)
{
auto const loop_stride = cuco::detail::grid_stride() / CGSize;
auto idx = cuco::detail::global_thread_id() / CGSize;

while (idx < n) {
if (pred(*(stencil + idx))) {
auto const insert_element{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const insert_element{*(first + idx)};
if constexpr (CGSize == 1) {
ref.insert(insert_element);
} else {
Expand Down Expand Up @@ -212,7 +214,7 @@ __global__ void contains_if_n(InputIt first,
} else {
auto const tile = cg::tiled_partition<CGSize>(cg::this_thread_block());
if (idx < n) {
auto const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const key = *(first + idx);
auto const found = pred(*(stencil + idx)) ? ref.contains(tile, key) : false;
if (tile.thread_rank() == 0) { *(output_begin + idx) = found; }
}
Expand Down
4 changes: 3 additions & 1 deletion include/cuco/detail/static_set/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include <cooperative_groups.h>

#include <iterator>

namespace cuco {
namespace experimental {
namespace static_set_ns {
Expand Down Expand Up @@ -62,7 +64,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_

while (idx - thread_idx < n) { // the whole thread block falls into the same iteration
if (idx < n) {
auto const key = *(first + idx);
typename std::iterator_traits<InputIt>::value_type const key = *(first + idx);
if constexpr (CGSize == 1) {
auto const found = ref.find(key);
/*
Expand Down
16 changes: 12 additions & 4 deletions include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ class operator_impl<op::insert_tag,
*
* @return True if the given element is successfully inserted
*/
__device__ bool insert(value_type const& value) noexcept
template <typename Value>
__device__ bool insert(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
Expand All @@ -147,8 +148,9 @@ class operator_impl<op::insert_tag,
*
* @return True if the given element is successfully inserted
*/
template <typename Value>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
Value const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
Expand Down Expand Up @@ -208,12 +210,15 @@ class operator_impl<op::insert_and_find_tag,
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to `value_type`
*
* @param value The element to insert
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value) noexcept
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
Expand All @@ -227,14 +232,17 @@ class operator_impl<op::insert_and_find_tag,
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to `value_type`
*
* @param group The Cooperative Group used to perform group insert_and_find
* @param value The element to insert
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
cooperative_groups::thread_block_tile<cg_size> const& group, Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
Expand Down
8 changes: 4 additions & 4 deletions tests/static_set/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,16 @@ struct custom_hasher {
template <typename CustomKey>
__device__ uint32_t operator()(CustomKey const& k) const
{
return thrust::raw_reference_cast(k).a;
return k.a;
};
};

// User-defined device key equality
struct custom_key_equal {
template <typename LHS, typename RHS>
__device__ bool operator()(LHS const& lhs, RHS const& rhs) const
template <typename SlotKey, typename InputKey>
__device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const
{
return thrust::raw_reference_cast(lhs) == thrust::raw_reference_cast(rhs).a;
return lhs == rhs.a;
}
};

Expand Down

0 comments on commit c3f95e3

Please sign in to comment.