diff --git a/include/cuco/detail/probing_scheme/probing_scheme_impl.inl b/include/cuco/detail/probing_scheme/probing_scheme_impl.inl index 047ec7987..ca39b2cb3 100644 --- a/include/cuco/detail/probing_scheme/probing_scheme_impl.inl +++ b/include/cuco/detail/probing_scheme/probing_scheme_impl.inl @@ -95,7 +95,7 @@ __host__ __device__ constexpr linear_probing::linear_probing(Hash template template -__host__ __device__ constexpr auto linear_probing::with_hash_function( +__host__ __device__ constexpr auto linear_probing::rebind_hash_function( NewHash const& hash) const noexcept { return linear_probing{hash}; @@ -143,28 +143,20 @@ __host__ __device__ constexpr double_hashing::double_hashi template __host__ __device__ constexpr double_hashing::double_hashing( - cuco::pair const& hash) + cuda::std::tuple const& hash) : hash1_{hash.first}, hash2_{hash.second} { } -template -template -__host__ __device__ constexpr auto double_hashing::with_hash_function( - NewHash1 const& hash1, NewHash2 const& hash2) const noexcept -{ - return double_hashing{hash1, hash2}; -} - template template -__host__ __device__ constexpr auto double_hashing::with_hash_function( +__host__ __device__ constexpr auto double_hashing::rebind_hash_function( NewHash const& hash) const { static_assert(cuco::is_tuple_like::value, "The given hasher must be a tuple-like object"); - auto const [hash1, hash2] = cuco::pair{hash}; + auto const [hash1, hash2] = cuda::std::tuple{hash}; using hash1_type = cuda::std::decay_t; using hash2_type = cuda::std::decay_t; return double_hashing{hash1, hash2}; diff --git a/include/cuco/detail/static_map/kernels.cuh b/include/cuco/detail/static_map/kernels.cuh index bf2aced70..4e4f396db 100644 --- a/include/cuco/detail/static_map/kernels.cuh +++ b/include/cuco/detail/static_map/kernels.cuh @@ -206,7 +206,7 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem( ref.probing_scheme(), {}, storage}; - auto shared_map_ref = std::move(shared_map).with(cuco::op::insert_or_apply); + auto shared_map_ref = shared_map.rebind_operators(cuco::op::insert_or_apply); shared_map_ref.initialize(block); block.sync(); @@ -262,4 +262,4 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem( } } } -} // namespace cuco::static_map_ns::detail \ No newline at end of file +} // namespace cuco::static_map_ns::detail diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 3756a641b..989880904 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -296,11 +296,17 @@ template template -auto static_map_ref::with( - NewOperators...) && noexcept +__host__ __device__ constexpr auto +static_map_ref::with_operators( + NewOperators...) const noexcept { return static_map_ref{ - std::move(*this)}; + cuco::empty_key{this->empty_key_sentinel()}, + cuco::empty_value{this->empty_value_sentinel()}, + this->key_eq(), + this->probing_scheme(), + {}, + this->storage_ref()}; } template template -__host__ __device__ auto constexpr static_map_ref::with_operators(NewOperators...) - const noexcept +__host__ __device__ constexpr auto +static_map_ref::rebind_operators( + NewOperators...) const noexcept { return static_map_ref{ cuco::empty_key{this->empty_key_sentinel()}, cuco::empty_value{this->empty_value_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; +} + +template +template +__host__ __device__ constexpr auto +static_map_ref::rebind_key_eq( + NewKeyEqual const& key_equal) const noexcept +{ + return static_map_ref{ + cuco::empty_key{this->empty_key_sentinel()}, + cuco::empty_value{this->empty_value_sentinel()}, + key_equal, + this->probing_scheme(), + {}, + this->storage_ref()}; +} + +template +template +__host__ __device__ constexpr auto +static_map_ref:: + rebind_hash_function(NewHash const& hash) const +{ + auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash); + return static_map_ref, + StorageRef, + Operators...>{cuco::empty_key{this->empty_key_sentinel()}, + cuco::empty_value{this->empty_value_sentinel()}, + this->key_eq(), + probing_scheme, + {}, + this->storage_ref()}; } template cuco::empty_value{this->empty_value_sentinel()}, cuco::erased_key{this->erased_key_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), scope, storage_ref_type{this->window_extent(), memory_to_use}}; } diff --git a/include/cuco/detail/static_multimap/static_multimap_ref.inl b/include/cuco/detail/static_multimap/static_multimap_ref.inl index 3bbf90e3a..bfceb75b6 100644 --- a/include/cuco/detail/static_multimap/static_multimap_ref.inl +++ b/include/cuco/detail/static_multimap/static_multimap_ref.inl @@ -21,6 +21,7 @@ #include #include +#include #include @@ -294,11 +295,22 @@ template template -auto static_multimap_ref::with( - NewOperators...) && noexcept +__host__ __device__ auto constexpr static_multimap_ref< + Key, + T, + Scope, + KeyEqual, + ProbingScheme, + StorageRef, + Operators...>::with_operators(NewOperators...) const noexcept { return static_multimap_ref{ - std::move(*this)}; + cuco::empty_key{this->empty_key_sentinel()}, + cuco::empty_value{this->empty_value_sentinel()}, + this->key_eq(), + this->probing_scheme(), + {}, + impl_.storage_ref()}; } template ::with_operators(NewOperators...) const noexcept + Operators...>::rebind_operators(NewOperators...) const noexcept { return static_multimap_ref{ cuco::empty_key{this->empty_key_sentinel()}, @@ -324,7 +336,55 @@ __host__ __device__ auto constexpr static_multimap_ref< this->key_eq(), impl_.probing_scheme(), {}, - impl_.storage_ref()}; + this->storage_ref()}; +} + +template +template +__host__ __device__ constexpr auto +static_multimap_ref:: + rebind_key_eq(NewKeyEqual const& key_equal) const noexcept +{ + return static_multimap_ref{ + cuco::empty_key{this->empty_key_sentinel()}, + cuco::empty_value{this->empty_value_sentinel()}, + key_equal, + this->probing_scheme(), + {}, + this->storage_ref()}; +} + +template +template +__host__ __device__ constexpr auto +static_multimap_ref:: + rebind_hash_function(NewHash const& hash) const +{ + auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash); + return static_multimap_ref, + StorageRef, + Operators...>{cuco::empty_key{this->empty_key_sentinel()}, + cuco::empty_value{this->empty_value_sentinel()}, + this->key_eq(), + probing_scheme, + {}, + this->storage_ref()}; } template +class operator_impl< + op::for_each_tag, + static_multimap_ref> { + using base_type = static_multimap_ref; + using ref_type = + static_multimap_ref; + + static constexpr auto cg_size = base_type::cg_size; + + public: + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Unary callback functor or device lambda + * + * @param key The key to search for + * @param callback_op Function to call on every element found + */ + template + __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(key, cuda::std::forward(callback_op)); + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Unary callback functor or device lambda + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(group, key, cuda::std::forward(callback_op)); + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key and can additionally perform work that requires synchronizing the Cooperative Group + * performing this operation. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @note The `sync_op` function can be used to perform work that requires synchronizing threads in + * `group` inbetween probing steps, where the number of probing steps performed between + * synchronization points is capped by `window_size * cg_size`. The functor will be called right + * after the current probing window has been traversed. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Unary callback functor or device lambda + * @tparam SyncOp Functor or device lambda which accepts the current `group` object + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + * @param sync_op Function that is allowed to synchronize `group` inbetween probing windows + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op, + SyncOp&& sync_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each( + group, key, cuda::std::forward(callback_op), cuda::std::forward(sync_op)); + } +}; + template ProbeHash const& probe_hash, cuda::stream_ref stream) const { - return impl_->count(first, - last, - ref(op::count).with_key_eq(probe_key_equal).with_hash_function(probe_hash), - stream); + return impl_->count( + first, + last, + ref(op::count).rebind_key_eq(probe_key_equal).rebind_hash_function(probe_hash), + stream); } template return impl_->count_outer( first, last, - ref(op::count).with_key_eq(probe_key_equal).with_hash_function(probe_hash), + ref(op::count).rebind_key_eq(probe_key_equal).rebind_hash_function(probe_hash), stream); } diff --git a/include/cuco/detail/static_multiset/static_multiset_ref.inl b/include/cuco/detail/static_multiset/static_multiset_ref.inl index aa25cdf70..1cc212f14 100644 --- a/include/cuco/detail/static_multiset/static_multiset_ref.inl +++ b/include/cuco/detail/static_multiset/static_multiset_ref.inl @@ -251,11 +251,16 @@ template template -auto static_multiset_ref::with( - NewOperators...) && noexcept +__host__ __device__ constexpr auto +static_multiset_ref::with_operators( + NewOperators...) const noexcept { return static_multiset_ref{ - std::move(*this)}; + cuco::empty_key{this->empty_key_sentinel()}, + this->key_eq(), + this->probing_scheme(), + {}, + this->storage_ref()}; } template template __host__ __device__ constexpr auto -static_multiset_ref::with_operators( - NewOperators...) const noexcept +static_multiset_ref:: + rebind_operators(NewOperators...) const noexcept { return static_multiset_ref{ cuco::empty_key{this->empty_key_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template template __host__ __device__ constexpr auto -static_multiset_ref::with_key_eq( +static_multiset_ref::rebind_key_eq( NewKeyEqual const& key_equal) const noexcept { return static_multiset_ref{ cuco::empty_key{this->empty_key_sentinel()}, key_equal, - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template __host__ __device__ constexpr auto static_multiset_ref:: - with_hash_function(NewHash const& hash) const + rebind_hash_function(NewHash const& hash) const { - auto const probing_scheme = this->impl_.probing_scheme().with_hash_function(hash); + auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash); return static_multiset_ref, StorageRef, Operators...>{cuco::empty_key{this->empty_key_sentinel()}, - this->impl_.key_eq(), + this->key_eq(), probing_scheme, {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } namespace detail { diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 7e2882a0a..a70df3d76 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -248,11 +248,16 @@ template template -auto static_set_ref::with( - NewOperators...) && noexcept +__host__ __device__ constexpr auto +static_set_ref::with_operators( + NewOperators...) const noexcept { return static_set_ref{ - std::move(*this)}; + cuco::empty_key{this->empty_key_sentinel()}, + this->key_eq(), + this->probing_scheme(), + {}, + this->storage_ref()}; } template template __host__ __device__ constexpr auto -static_set_ref::with_operators( +static_set_ref::rebind_operators( NewOperators...) const noexcept { return static_set_ref{ cuco::empty_key{this->empty_key_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template template __host__ __device__ constexpr auto -static_set_ref::with_key_eq( +static_set_ref::rebind_key_eq( NewKeyEqual const& key_equal) const noexcept { return static_set_ref{ cuco::empty_key{this->empty_key_sentinel()}, key_equal, - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template template __host__ __device__ constexpr auto -static_set_ref::with_hash_function( +static_set_ref::rebind_hash_function( NewHash const& hash) const { - auto const probing_scheme = this->impl_.probing_scheme().with_hash_function(hash); + auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash); return static_set_ref, StorageRef, Operators...>{cuco::empty_key{this->empty_key_sentinel()}, - this->impl_.key_eq(), + this->key_eq(), probing_scheme, {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template ::m cuco::empty_key{this->empty_key_sentinel()}, cuco::erased_key{this->erased_key_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), scope, storage_ref_type{this->window_extent(), memory_to_use}}; } diff --git a/include/cuco/probing_scheme.cuh b/include/cuco/probing_scheme.cuh index 4885ad63d..4032daadb 100644 --- a/include/cuco/probing_scheme.cuh +++ b/include/cuco/probing_scheme.cuh @@ -62,7 +62,7 @@ class linear_probing : private detail::probing_scheme_base { * @return Copy of the current probing method */ template - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function( + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function( NewHash const& hash) const noexcept; /** @@ -143,23 +143,7 @@ class double_hashing : private detail::probing_scheme_base { * * @param hash Hasher tuple */ - __host__ __device__ constexpr double_hashing(cuco::pair const& hash); - - /** - *@brief Makes a copy of the current probing method with the given hasher - * - * @tparam NewHash1 First new hasher type - * @tparam NewHash2 Second new hasher type - * - * @param hash1 First hasher - * @param hash2 second hasher - * - * @return Copy of the current probing method - */ - template - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function(NewHash1 const& hash1, - NewHash2 const& hash2 = { - 1}) const noexcept; + __host__ __device__ constexpr double_hashing(cuda::std::tuple const& hash); /** *@brief Makes a copy of the current probing method with the given hasher @@ -174,7 +158,7 @@ class double_hashing : private detail::probing_scheme_base { */ template ::value>> - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function(NewHash const& hash) const; + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; /** * @brief Operator to return a probing iterator diff --git a/include/cuco/static_map_ref.cuh b/include/cuco/static_map_ref.cuh index 1da1e501a..e3399a93e 100644 --- a/include/cuco/static_map_ref.cuh +++ b/include/cuco/static_map_ref.cuh @@ -227,11 +227,7 @@ class static_map_ref [[nodiscard]] __host__ __device__ constexpr auto probing_scheme() const noexcept; /** - * @brief Creates a reference with new operators from the current object. - * - * @deprecated This function is deprecated. Use the new `with_operators` instead. - * - * Note that this function uses move semantics and thus invalidates the current object. + * @brief Creates a reference with new operators from the current object * * @warning Using two or more reference objects to the same container but with * a different operator set at the same time results in undefined behavior. @@ -243,24 +239,47 @@ class static_map_ref * @return `*this` with `NewOperators...` */ template - [[nodiscard]] __host__ __device__ auto with(NewOperators... ops) && noexcept; + [[nodiscard]] __host__ __device__ constexpr auto with_operators( + NewOperators... ops) const noexcept; /** - * @brief Creates a reference with new operators from the current object - * - * @warning Using two or more reference objects to the same container but with - * a different operator set at the same time results in undefined behavior. + * @brief Creates a copy of the current non-owning reference using the given operators * * @tparam NewOperators List of `cuco::op::*_tag` types * - * @param ops List of operators, e.g., `cuco::insert` + * @param ops List of operators, e.g., `cuco::op::insert` * - * @return `*this` with `NewOperators...` + * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_operators( + [[nodiscard]] __host__ __device__ constexpr auto rebind_operators( NewOperators... ops) const noexcept; + /** + * @brief Makes a copy of the current device reference with the given key comparator + * + * @tparam NewKeyEqual The new key equal type + * + * @param key_equal New key comparator + * + * @return Copy of the current device ref + */ + template + [[nodiscard]] __host__ __device__ constexpr auto rebind_key_eq( + NewKeyEqual const& key_equal) const noexcept; + + /** + * @brief Makes a copy of the current device reference with the given hasher + * + * @tparam NewHash The new hasher type + * + * @param hash New hasher + * + * @return Copy of the current device ref + */ + template + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; + /** * @brief Makes a copy of the current device reference using non-owned memory * diff --git a/include/cuco/static_multimap_ref.cuh b/include/cuco/static_multimap_ref.cuh index b23925b86..ddcf77dba 100644 --- a/include/cuco/static_multimap_ref.cuh +++ b/include/cuco/static_multimap_ref.cuh @@ -226,11 +226,7 @@ class static_multimap_ref [[nodiscard]] __host__ __device__ constexpr auto probing_scheme() const noexcept; /** - * @brief Creates a reference with new operators from the current object. - * - * @deprecated This function is deprecated. Use the new `with_operators` instead. - * - * Note that this function uses move semantics and thus invalidates the current object. + * @brief Creates a reference with new operators from the current object * * @warning Using two or more reference objects to the same container but with * a different operator set at the same time results in undefined behavior. @@ -242,24 +238,47 @@ class static_multimap_ref * @return `*this` with `NewOperators...` */ template - [[nodiscard]] __host__ __device__ auto with(NewOperators... ops) && noexcept; + [[nodiscard]] __host__ __device__ constexpr auto with_operators( + NewOperators... ops) const noexcept; /** - * @brief Creates a reference with new operators from the current object - * - * @warning Using two or more reference objects to the same container but with - * a different operator set at the same time results in undefined behavior. + * @brief Creates a copy of the current non-owning reference using the given operators * * @tparam NewOperators List of `cuco::op::*_tag` types * - * @param ops List of operators, e.g., `cuco::insert` + * @param ops List of operators, e.g., `cuco::op::insert` * - * @return `*this` with `NewOperators...` + * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_operators( + [[nodiscard]] __host__ __device__ constexpr auto rebind_operators( NewOperators... ops) const noexcept; + /** + * @brief Makes a copy of the current device reference with the given key comparator + * + * @tparam NewKeyEqual The new key equal type + * + * @param key_equal New key comparator + * + * @return Copy of the current device ref + */ + template + [[nodiscard]] __host__ __device__ constexpr auto rebind_key_eq( + NewKeyEqual const& key_equal) const noexcept; + + /** + * @brief Makes a copy of the current device reference with the given hasher + * + * @tparam NewHash The new hasher type + * + * @param hash New hasher + * + * @return Copy of the current device ref + */ + template + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; + /** * @brief Makes a copy of the current device reference using non-owned memory * diff --git a/include/cuco/static_multiset_ref.cuh b/include/cuco/static_multiset_ref.cuh index bf0588f2f..25c98b874 100644 --- a/include/cuco/static_multiset_ref.cuh +++ b/include/cuco/static_multiset_ref.cuh @@ -206,11 +206,7 @@ class static_multiset_ref [[nodiscard]] __host__ __device__ constexpr auto probing_scheme() const noexcept; /** - * @brief Creates a reference with new operators from the current object. - * - * @deprecated This function is deprecated. Use the new `with_operators` instead. - * - * Note that this function uses move semantics and thus invalidates the current object. + * @brief Creates a reference with new operators from the current object * * @warning Using two or more reference objects to the same container but with * a different operator set at the same time results in undefined behavior. @@ -222,26 +218,24 @@ class static_multiset_ref * @return `*this` with `NewOperators...` */ template - [[nodiscard]] __host__ __device__ auto with(NewOperators... ops) && noexcept; + [[nodiscard]] __host__ __device__ constexpr auto with_operators( + NewOperators... ops) const noexcept; /** - * @brief Creates a reference with new operators from the current object - * - * @warning Using two or more reference objects to the same container but with - * a different operator set at the same time results in undefined behavior. + * @brief Creates a copy of the current non-owning reference using the given operators * * @tparam NewOperators List of `cuco::op::*_tag` types * - * @param ops List of operators, e.g., `cuco::insert` + * @param ops List of operators, e.g., `cuco::op::insert` * - * @return `*this` with `NewOperators...` + * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_operators( + [[nodiscard]] __host__ __device__ constexpr auto rebind_operators( NewOperators... ops) const noexcept; /** - * @brief Makes a copy of the current device reference with given key comparator + * @brief Makes a copy of the current device reference with the given key comparator * * @tparam NewKeyEqual The new key equal type * @@ -250,11 +244,11 @@ class static_multiset_ref * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_key_eq( + [[nodiscard]] __host__ __device__ constexpr auto rebind_key_eq( NewKeyEqual const& key_equal) const noexcept; /** - * @brief Makes a copy of the current device reference with given hasher + * @brief Makes a copy of the current device reference with the given hasher * * @tparam NewHash The new hasher type * @@ -263,7 +257,7 @@ class static_multiset_ref * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function(NewHash const& hash) const; + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; private: impl_type impl_; diff --git a/include/cuco/static_set_ref.cuh b/include/cuco/static_set_ref.cuh index 1271cb756..fa0098ebf 100644 --- a/include/cuco/static_set_ref.cuh +++ b/include/cuco/static_set_ref.cuh @@ -204,11 +204,7 @@ class static_set_ref [[nodiscard]] __host__ __device__ constexpr auto probing_scheme() const noexcept; /** - * @brief Creates a reference with new operators from the current object. - * - * @deprecated This function is deprecated. Use the new `with_operators` instead. - * - * Note that this function uses move semantics and thus invalidates the current object. + * @brief Creates a reference with new operators from the current object * * @warning Using two or more reference objects to the same container but with * a different operator set at the same time results in undefined behavior. @@ -220,26 +216,24 @@ class static_set_ref * @return `*this` with `NewOperators...` */ template - [[nodiscard]] __host__ __device__ auto with(NewOperators... ops) && noexcept; + [[nodiscard]] __host__ __device__ constexpr auto with_operators( + NewOperators... ops) const noexcept; /** - * @brief Creates a reference with new operators from the current object - * - * @warning Using two or more reference objects to the same container but with - * a different operator set at the same time results in undefined behavior. + * @brief Creates a copy of the current non-owning reference using the given operators * * @tparam NewOperators List of `cuco::op::*_tag` types * - * @param ops List of operators, e.g., `cuco::insert` + * @param ops List of operators, e.g., `cuco::op::insert` * - * @return `*this` with `NewOperators...` + * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_operators( + [[nodiscard]] __host__ __device__ constexpr auto rebind_operators( NewOperators... ops) const noexcept; /** - * @brief Makes a copy of the current device reference with given key comparator + * @brief Makes a copy of the current device reference with the given key comparator * * @tparam NewKeyEqual The new key equal type * @@ -248,11 +242,11 @@ class static_set_ref * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_key_eq( + [[nodiscard]] __host__ __device__ constexpr auto rebind_key_eq( NewKeyEqual const& key_equal) const noexcept; /** - * @brief Makes a copy of the current device reference with given hasher + * @brief Makes a copy of the current device reference with the given hasher * * @tparam NewHash The new hasher type * @@ -261,7 +255,7 @@ class static_set_ref * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function(NewHash const& hash) const; + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; /** * @brief Makes a copy of the current device reference using non-owned memory diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 737ddf32e..be88c524d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -117,7 +117,8 @@ ConfigureTest(STATIC_MULTIMAP_TEST static_multimap/insert_if_test.cu static_multimap/multiplicity_test.cu static_multimap/non_match_test.cu - static_multimap/pair_function_test.cu) + static_multimap/pair_function_test.cu + static_multimap/for_each_test.cu) ################################################################################################### # - dynamic_bitset tests -------------------------------------------------------------------------- diff --git a/tests/static_multimap/for_each_test.cu b/tests/static_multimap/for_each_test.cu new file mode 100644 index 000000000..f7290707d --- /dev/null +++ b/tests/static_multimap/for_each_test.cu @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +template +CUCO_KERNEL void for_each_check_scalar(Ref ref, + InputIt first, + std::size_t n, + std::size_t multiplicity, + AtomicErrorCounter* error_counter) +{ + static_assert(Ref::cg_size == 1, "Scalar test must have cg_size==1"); + auto const loop_stride = cuco::detail::grid_stride(); + auto idx = cuco::detail::global_thread_id(); + + while (idx < n) { + auto const& key = *(first + idx); + std::size_t matches = 0; + ref.for_each(key, [&] __device__(auto const slot) { + auto const [slot_key, slot_value] = slot; + if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) { matches++; } + }); + if (matches != multiplicity) { error_counter->fetch_add(1, cuda::memory_order_relaxed); } + idx += loop_stride; + } +} + +template +CUCO_KERNEL void for_each_check_cooperative(Ref ref, + InputIt first, + std::size_t n, + std::size_t multiplicity, + AtomicErrorCounter* error_counter) +{ + auto const loop_stride = cuco::detail::grid_stride() / Ref::cg_size; + auto idx = cuco::detail::global_thread_id() / Ref::cg_size; + ; + + while (idx < n) { + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + auto const& key = *(first + idx); + std::size_t thread_matches = 0; + if constexpr (Synced) { + ref.for_each( + tile, + key, + [&] __device__(auto const slot) { + auto const [slot_key, slot_value] = slot; + if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) { + thread_matches++; + } + }, + [] __device__(auto const& group) { group.sync(); }); + } else { + ref.for_each(tile, key, [&] __device__(auto const slot) { + auto const [slot_key, slot_value] = slot; + if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) { + thread_matches++; + } + }); + } + auto const tile_matches = + cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus()); + if (tile_matches != multiplicity and tile.thread_rank() == 0) { + error_counter->fetch_add(1, cuda::memory_order_relaxed); + } + idx += loop_stride; + } +} + +TEMPLATE_TEST_CASE_SIG( + "static_multimap for_each tests", + "", + ((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize), + (int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, cuco::test::probe_sequence::linear_probing, 2)) +{ + constexpr size_t num_unique_keys{400}; + constexpr size_t key_multiplicity{5}; + constexpr size_t num_keys{num_unique_keys * key_multiplicity}; + + using probe = std::conditional_t>, + cuco::double_hashing>>; + + auto set = cuco::experimental::static_multimap{num_keys, + cuco::empty_key{-1}, + cuco::empty_value{-1}, + {}, + probe{}, + {}, + cuco::storage<2>{}}; + + auto unique_keys_begin = thrust::counting_iterator(0); + auto gen_duplicate_keys = cuda::proclaim_return_type( + [] __device__(auto const& k) { return static_cast(k % num_unique_keys); }); + auto keys_begin = thrust::make_transform_iterator(unique_keys_begin, gen_duplicate_keys); + + auto const pairs_begin = thrust::make_transform_iterator( + keys_begin, cuda::proclaim_return_type>([] __device__(auto i) { + return cuco::pair{i, i}; + })); + + set.insert(pairs_begin, pairs_begin + num_keys); + + using error_counter_type = cuda::atomic; + error_counter_type* error_counter; + CUCO_CUDA_TRY(cudaMallocHost(&error_counter, sizeof(error_counter_type))); + new (error_counter) error_counter_type{0}; + + auto const grid_size = cuco::detail::grid_size(num_unique_keys, CGSize); + auto const block_size = cuco::detail::default_block_size(); + + // test scalar for_each + if constexpr (CGSize == 1) { + for_each_check_scalar<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + error_counter->store(0); + } + + // test CG for_each + for_each_check_cooperative<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + error_counter->store(0); + + // test synchronized CG for_each + for_each_check_cooperative<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + + CUCO_CUDA_TRY(cudaFreeHost(error_counter)); +} \ No newline at end of file diff --git a/tests/static_multiset/custom_count_test.cu b/tests/static_multiset/custom_count_test.cu index f92b91aad..f5ade5eeb 100644 --- a/tests/static_multiset/custom_count_test.cu +++ b/tests/static_multiset/custom_count_test.cu @@ -61,21 +61,27 @@ void test_custom_count(Set& set, size_type num_keys) { using Key = typename Set::key_type; + auto const hash = []() { + if constexpr (cuco::is_double_hashing::value) { + return cuda::std::tuple{custom_hash{}, custom_hash{}}; + } else { + return custom_hash{}; + } + }(); + auto query_begin = thrust::make_transform_iterator( thrust::make_counting_iterator(0), cuda::proclaim_return_type([] __device__(auto i) { return static_cast(i * XXX); })); SECTION("Count of empty set should be zero.") { - auto const count = - set.count(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == 0); } SECTION("Outer count of empty set should be the same as input size.") { - auto const count = - set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == num_keys); } @@ -84,15 +90,13 @@ void test_custom_count(Set& set, size_type num_keys) SECTION("Count of n unique keys should be n.") { - auto const count = - set.count(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == num_keys); } SECTION("Outer count of n unique keys should be n.") { - auto const count = - set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == num_keys); } @@ -102,15 +106,13 @@ void test_custom_count(Set& set, size_type num_keys) SECTION("Count of a key whose multiplicity equals n should be n.") { - auto const count = - set.count(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == num_keys); } SECTION("Outer count of a key whose multiplicity equals n should be n + input_size - 1.") { - auto const count = - set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == 2 * num_keys - 1); } }