diff --git a/include/cuco/detail/common_kernels.cuh b/include/cuco/detail/common_kernels.cuh index 871e1f8e8..ac0334e2f 100644 --- a/include/cuco/detail/common_kernels.cuh +++ b/include/cuco/detail/common_kernels.cuh @@ -23,6 +23,8 @@ #include +#include + namespace cuco { namespace experimental { namespace detail { @@ -37,7 +39,7 @@ namespace detail { * * @tparam CGSize Number of threads in each CG * @tparam BlockSize Number of threads in each block - * @tparam InputIterator Device accessible input iterator whose `value_type` is + * @tparam InputIt Device accessible input iterator whose `value_type` is * convertible to the `value_type` of the data structure * @tparam StencilIt Device accessible random access iterator whose value_type is * convertible to Predicate's argument type @@ -55,12 +57,12 @@ namespace detail { */ template -__global__ void insert_if_n(InputIterator first, +__global__ void insert_if_n(InputIt first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, @@ -76,7 +78,7 @@ __global__ void insert_if_n(InputIterator first, while (idx < n) { if (pred(*(stencil + idx))) { - auto const insert_element{*(first + idx)}; + typename std::iterator_traits::value_type const insert_element{*(first + idx)}; if constexpr (CGSize == 1) { if (ref.insert(insert_element)) { thread_num_successes++; }; } else { @@ -106,7 +108,7 @@ __global__ void insert_if_n(InputIterator first, * * @tparam CGSize Number of threads in each CG * @tparam BlockSize Number of threads in each block - * @tparam InputIterator Device accessible input iterator whose `value_type` is + * @tparam InputIt Device accessible input iterator whose `value_type` is * convertible to the `value_type` of the data structure * @tparam StencilIt Device accessible random access iterator whose value_type is * convertible to Predicate's argument type @@ -122,19 +124,19 @@ __global__ void insert_if_n(InputIterator first, */ template __global__ void insert_if_n( - InputIterator first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref) + InputIt first, cuco::detail::index_type n, StencilIt stencil, Predicate pred, Ref ref) { auto const loop_stride = cuco::detail::grid_stride() / CGSize; auto idx = cuco::detail::global_thread_id() / CGSize; while (idx < n) { if (pred(*(stencil + idx))) { - auto const insert_element{*(first + idx)}; + typename std::iterator_traits::value_type const insert_element{*(first + idx)}; if constexpr (CGSize == 1) { ref.insert(insert_element); } else { @@ -212,7 +214,7 @@ __global__ void contains_if_n(InputIt first, } else { auto const tile = cg::tiled_partition(cg::this_thread_block()); if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const key = *(first + idx); auto const found = pred(*(stencil + idx)) ? ref.contains(tile, key) : false; if (tile.thread_rank() == 0) { *(output_begin + idx) = found; } } diff --git a/include/cuco/detail/static_set/kernels.cuh b/include/cuco/detail/static_set/kernels.cuh index 72744f2b4..a3b0c0e9c 100644 --- a/include/cuco/detail/static_set/kernels.cuh +++ b/include/cuco/detail/static_set/kernels.cuh @@ -24,6 +24,8 @@ #include +#include + namespace cuco { namespace experimental { namespace static_set_ns { @@ -62,7 +64,7 @@ __global__ void find(InputIt first, cuco::detail::index_type n, OutputIt output_ while (idx - thread_idx < n) { // the whole thread block falls into the same iteration if (idx < n) { - auto const key = *(first + idx); + typename std::iterator_traits::value_type const key = *(first + idx); if constexpr (CGSize == 1) { auto const found = ref.find(key); /* diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 4c3853971..8f7b1f86c 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -132,7 +132,8 @@ class operator_impl + __device__ bool insert(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false; @@ -147,8 +148,9 @@ class operator_impl __device__ bool insert(cooperative_groups::thread_block_tile const& group, - value_type const& value) noexcept + Value const& value) noexcept { auto& ref_ = static_cast(*this); auto constexpr has_payload = false; @@ -208,12 +210,15 @@ class operator_impl insert_and_find(value_type const& value) noexcept + template + __device__ thrust::pair insert_and_find(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false; @@ -227,14 +232,17 @@ class operator_impl __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept + cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { ref_type& ref_ = static_cast(*this); auto constexpr has_payload = false; diff --git a/tests/static_set/heterogeneous_lookup_test.cu b/tests/static_set/heterogeneous_lookup_test.cu index e1c3f5e9e..e294203db 100644 --- a/tests/static_set/heterogeneous_lookup_test.cu +++ b/tests/static_set/heterogeneous_lookup_test.cu @@ -68,16 +68,16 @@ struct custom_hasher { template __device__ uint32_t operator()(CustomKey const& k) const { - return thrust::raw_reference_cast(k).a; + return k.a; }; }; // User-defined device key equality struct custom_key_equal { - template - __device__ bool operator()(LHS const& lhs, RHS const& rhs) const + template + __device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const { - return thrust::raw_reference_cast(lhs) == thrust::raw_reference_cast(rhs).a; + return lhs == rhs.a; } };