diff --git a/include/cuco/detail/open_addressing/kernels.cuh b/include/cuco/detail/open_addressing/kernels.cuh index b0457c071..67314462a 100644 --- a/include/cuco/detail/open_addressing/kernels.cuh +++ b/include/cuco/detail/open_addressing/kernels.cuh @@ -329,19 +329,33 @@ struct find_buffer * @tparam CGSize Number of threads in each CG * @tparam BlockSize The size of the thread block * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose value_type is + * convertible to Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` + * and argument type is convertible from `std::iterator_traits::value_type` * @tparam OutputIt Device accessible output iterator * @tparam Ref Type of non-owning device ref allowing access to storage * * @param first Beginning of the sequence of keys * @param n Number of keys to query + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + n)` * @param output_begin Beginning of the sequence of matched payloads retrieved for each key * @param ref Non-owning container device ref used to access the slot storage */ -template -CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first, - cuco::detail::index_type n, - OutputIt output_begin, - Ref ref) +template +CUCO_KERNEL __launch_bounds__(BlockSize) void find_if_n(InputIt first, + cuco::detail::index_type n, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + Ref ref) { namespace cg = cooperative_groups; @@ -382,7 +396,7 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first, * synchronizing before writing back to global, we no longer rely on L1, preventing the * increase in sector stores from L2 to global and improving performance. */ - output_buffer[thread_idx] = output(found); + output_buffer[thread_idx] = pred(*(stencil + idx)) ? output(found) : sentinel; } block.sync(); if (idx < n) { *(output_begin + idx) = output_buffer[thread_idx]; } @@ -392,7 +406,9 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first, typename std::iterator_traits::value_type const& key = *(first + idx); auto const found = ref.find(tile, key); - if (tile.thread_rank() == 0) { *(output_begin + idx) = output(found); } + if (tile.thread_rank() == 0) { + *(output_begin + idx) = pred(*(stencil + idx)) ? output(found) : sentinel; + } } } idx += loop_stride; diff --git a/include/cuco/detail/open_addressing/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh index a6fd9b3c1..f8d36b556 100644 --- a/include/cuco/detail/open_addressing/open_addressing_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_impl.cuh @@ -565,15 +565,59 @@ class open_addressing_impl { OutputIt output_begin, Ref container_ref, cuda::stream_ref stream) const noexcept + { + auto const always_true = thrust::constant_iterator{true}; + + this->find_if_async( + first, last, always_true, thrust::identity{}, output_begin, container_ref, stream); + } + + /** + * @brief For all keys in the range `[first, last)`, asynchronously finds + * a match with its key equivalent to the query key. + * + * @note If `pred( *(stencil + i) )` is true, stores the payload of the + * matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )` + * is false, stores `empty_value_sentienl` to `(output_begin + i)`. + * + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose value_type is + * convertible to Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator + * @tparam Ref Type of non-owning device container ref allowing access to storage + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of matches retrieved for each key + * @param container_ref Non-owning device container ref used to access the slot storage + * @param stream Stream used for executing the kernels + */ + template + void find_if_async(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + Ref container_ref, + cuda::stream_ref stream) const noexcept { auto const num_keys = cuco::detail::distance(first, last); if (num_keys == 0) { return; } auto const grid_size = cuco::detail::grid_size(num_keys, cg_size); - detail::find + detail::find_if_n <<>>( - first, num_keys, output_begin, container_ref); + first, num_keys, stencil, pred, output_begin, container_ref); } /** diff --git a/include/cuco/detail/static_map/static_map.inl b/include/cuco/detail/static_map/static_map.inl index 8b49c35d4..7c69263d2 100644 --- a/include/cuco/detail/static_map/static_map.inl +++ b/include/cuco/detail/static_map/static_map.inl @@ -491,6 +491,47 @@ void static_mapfind_async(first, last, output_begin, ref(op::find), stream); } +template +template +void static_map::find_if( + InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const +{ + this->find_if_async(first, last, stencil, pred, output_begin, stream); + stream.wait(); +} + +template +template +void static_map::find_if_async( + InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const +{ + impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream); +} + template find_async(first, last, output_begin, ref(op::find), stream); } +template +template +void static_multimap::find_if( + InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const +{ + this->find_if_async(first, last, stencil, pred, output_begin, stream); + stream.wait(); +} + +template +template +void static_multimap:: + find_if_async(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const +{ + impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream); +} + template count(first, last, ref(op::count), stream); } +template +template +std::pair +static_multimap::retrieve_all( + KeyOut keys_out, ValueOut values_out, cuda::stream_ref stream) const +{ + auto const zipped_out_begin = thrust::make_zip_iterator(thrust::make_tuple(keys_out, values_out)); + auto const zipped_out_end = impl_->retrieve_all(zipped_out_begin, stream); + auto const num = std::distance(zipped_out_begin, zipped_out_end); + + return std::make_pair(keys_out + num, values_out + num); +} + template find_async(first, last, output_begin, ref(op::find), stream); } +template +template +void static_multiset::find_if( + InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const +{ + this->find_if_async(first, last, stencil, pred, output_begin, stream); + stream.wait(); +} + +template +template +void static_multiset:: + find_if_async(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const +{ + impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream); +} + template return impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream); } +template +template +OutputIt +static_multiset::retrieve_all( + OutputIt output_begin, cuda::stream_ref stream) const +{ + return impl_->retrieve_all(output_begin, stream); +} + template impl_->find_async(first, last, output_begin, ref(op::find), stream); } +template +template +void static_set::find_if( + InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const +{ + this->find_if_async(first, last, stencil, pred, output_begin, stream); + stream.wait(); +} + +template +template +void static_set::find_if_async( + InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const +{ + impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream); +} + template std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of matches retrieved for each key + * @param stream Stream used for executing the kernels + */ + template + void find_if(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + + /** + * @brief For all keys in the range `[first, last)`, asynchronously finds + * a match with its key equivalent to the query key. + * + * @note If `pred( *(stencil + i) )` is true, stores the payload of the + * matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )` + * is false, always stores the `empty_value_sentienl` to `(output_begin + i)`. + * + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to + * Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of matches retrieved for each key + * @param stream Stream used for executing the kernels + */ + template + void find_if_async(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + /** * @brief Applies the given function object `callback_op` to the copy of every filled slot in the * container @@ -852,7 +916,7 @@ class static_map { size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const; /** - * @brief Retrieves all of the keys and their associated values. + * @brief Retrieves all of the keys and their associated values contained in the map * * @note This API synchronizes the given stream. * @note The order in which keys are returned is implementation defined and not guaranteed to be diff --git a/include/cuco/static_multimap.cuh b/include/cuco/static_multimap.cuh index bf3e5998f..21454dc44 100644 --- a/include/cuco/static_multimap.cuh +++ b/include/cuco/static_multimap.cuh @@ -518,6 +518,70 @@ class static_multimap { OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** + * @brief For all keys in the range `[first, last)`, finds a match with its key equivalent to the + * query key. + * + * @note If `pred( *(stencil + i) )` is true, stores the payload of the + * matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )` + * is false, always stores the `empty_value_sentienl` to `(output_begin + i)`. + * @note This function synchronizes the given stream. For asynchronous execution use + * `find_if_async`. + * + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to + * Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of matches retrieved for each key + * @param stream Stream used for executing the kernels + */ + template + void find_if(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + + /** + * @brief For all keys in the range `[first, last)`, asynchronously finds + * a match with its key equivalent to the query key. + * + * @note If `pred( *(stencil + i) )` is true, stores the payload of the + * matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )` + * is false, always stores the `empty_value_sentienl` to `(output_begin + i)`. + * + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to + * Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of matches retrieved for each key + * @param stream Stream used for executing the kernels + */ + template + void find_if_async(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + /** * @brief Applies the given function object `callback_op` to the copy of every filled slot in the * container @@ -602,6 +666,31 @@ class static_multimap { template size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const; + /** + * @brief Retrieves all of the keys and their associated values contained in the multimap + * + * @note This API synchronizes the given stream. + * @note The order in which keys are returned is implementation defined and not guaranteed to be + * consistent between subsequent calls to `retrieve_all`. + * @note Behavior is undefined if the range beginning at `keys_out` or `values_out` is smaller + * than the return value of `size()`. + * + * @tparam KeyOut Device accessible random access output iterator whose `value_type` is + * convertible from `key_type`. + * @tparam ValueOut Device accesible random access output iterator whose `value_type` is + * convertible from `mapped_type`. + * + * @param keys_out Beginning output iterator for keys + * @param values_out Beginning output iterator for associated values + * @param stream CUDA stream used for this operation + * + * @return Pair of iterators indicating the last elements in the output + */ + template + std::pair retrieve_all(KeyOut keys_out, + ValueOut values_out, + cuda::stream_ref stream = {}) const; + /** * @brief Regenerates the container. * diff --git a/include/cuco/static_multiset.cuh b/include/cuco/static_multiset.cuh index 9ecbde9b7..f7fd961fa 100644 --- a/include/cuco/static_multiset.cuh +++ b/include/cuco/static_multiset.cuh @@ -482,6 +482,70 @@ class static_multiset { OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** + * @brief For all keys in the range `[first, last)`, finds a match with its key equivalent to the + * query key. + * + * @note If `pred( *(stencil + i) )` is true, stores the payload of the + * matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )` + * is false, always stores the `empty_value_sentienl` to `(output_begin + i)`. + * @note This function synchronizes the given stream. For asynchronous execution use + * `find_if_async`. + * + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to + * Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of matches retrieved for each key + * @param stream Stream used for executing the kernels + */ + template + void find_if(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + + /** + * @brief For all keys in the range `[first, last)`, asynchronously finds + * a match with its key equivalent to the query key. + * + * @note If `pred( *(stencil + i) )` is true, stores the payload of the + * matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )` + * is false, always stores the `empty_value_sentienl` to `(output_begin + i)`. + * + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to + * Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of matches retrieved for each key + * @param stream Stream used for executing the kernels + */ + template + void find_if_async(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + /** * @brief Applies the given function object `callback_op` to the copy of every filled slot in the * container @@ -738,6 +802,26 @@ class static_multiset { OutputMatchIt output_match, cuda::stream_ref stream = {}) const; + /** + * @brief Retrieves all keys contained in the multiset + * + * @note This API synchronizes the given stream. + * @note The order in which keys are returned is implementation defined and not guaranteed to be + * consistent between subsequent calls to `retrieve_all`. + * @note Behavior is undefined if the range beginning at `output_begin` is smaller than the return + * value of `size()`. + * + * @tparam OutputIt Device accessible random access output iterator whose `value_type` is + * convertible from the container's `key_type`. + * + * @param output_begin Beginning output iterator for keys + * @param stream CUDA stream used for this operation + * + * @return Iterator indicating the end of the output + */ + template + OutputIt retrieve_all(OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** * @brief Regenerates the container. * diff --git a/include/cuco/static_set.cuh b/include/cuco/static_set.cuh index ce2f799b0..827f42068 100644 --- a/include/cuco/static_set.cuh +++ b/include/cuco/static_set.cuh @@ -590,6 +590,70 @@ class static_set { OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** + * @brief For all keys in the range `[first, last)`, finds a match with its key equivalent to the + * query key. + * + * @note If `pred( *(stencil + i) )` is true, stores the payload of the + * matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )` + * is false, always stores the `empty_value_sentienl` to `(output_begin + i)`. + * @note This function synchronizes the given stream. For asynchronous execution use + * `find_if_async`. + * + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to + * Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of matches retrieved for each key + * @param stream Stream used for executing the kernels + */ + template + void find_if(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + + /** + * @brief For all keys in the range `[first, last)`, asynchronously finds + * a match with its key equivalent to the query key. + * + * @note If `pred( *(stencil + i) )` is true, stores the payload of the + * matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )` + * is false, always stores the `empty_value_sentienl` to `(output_begin + i)`. + * + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to + * Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of matches retrieved for each key + * @param stream Stream used for executing the kernels + */ + template + void find_if_async(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + /** * @brief Applies the given function object `callback_op` to the copy of every filled slot in the * container diff --git a/tests/static_map/unique_sequence_test.cu b/tests/static_map/unique_sequence_test.cu index 4ab864ab7..8e8d6c9ad 100644 --- a/tests/static_map/unique_sequence_test.cu +++ b/tests/static_map/unique_sequence_test.cu @@ -26,7 +26,6 @@ #include #include #include -#include #include #include @@ -34,17 +33,15 @@ using size_type = int32_t; +int32_t constexpr SENTINEL = -1; + template void test_unique_sequence(Map& map, size_type num_keys) { using Key = typename Map::key_type; using Value = typename Map::mapped_type; - thrust::device_vector d_keys(num_keys); - - thrust::sequence(thrust::device, d_keys.begin(), d_keys.end()); - - auto keys_begin = d_keys.begin(); + auto keys_begin = thrust::counting_iterator{0}; auto pairs_begin = thrust::make_transform_iterator( thrust::make_counting_iterator(0), cuda::proclaim_return_type>( @@ -128,6 +125,27 @@ void test_unique_sequence(Map& map, size_type num_keys) REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); } + SECTION("Conditional find should return valid values on even inputs.") + { + auto found_results = thrust::device_vector(num_keys); + auto gold_fn = cuda::proclaim_return_type([] __device__(auto const& i) { + return i % 2 == 0 ? static_cast(i) : Value{SENTINEL}; + }); + + map.find_if(keys_begin, + keys_begin + num_keys, + thrust::counting_iterator{0}, + is_even, + found_results.begin()); + + REQUIRE(cuco::test::equal( + found_results.begin(), + found_results.end(), + thrust::make_transform_iterator(thrust::counting_iterator{0}, gold_fn), + cuda::proclaim_return_type( + [] __device__(auto const& found, auto const& gold) { return found == gold; }))); + } + SECTION("All inserted key-values should be properly retrieved") { thrust::device_vector d_values(num_keys); @@ -188,7 +206,7 @@ TEMPLATE_TEST_CASE_SIG( probe, cuco::cuda_allocator, cuco::storage<2>>{ - extent_type{}, cuco::empty_key{-1}, cuco::empty_value{-1}}; + extent_type{}, cuco::empty_key{SENTINEL}, cuco::empty_value{SENTINEL}}; REQUIRE(map.capacity() == gold_capacity); diff --git a/tests/static_multimap/find_test.cu b/tests/static_multimap/find_test.cu index 51456b088..3fe0ae8bc 100644 --- a/tests/static_multimap/find_test.cu +++ b/tests/static_multimap/find_test.cu @@ -28,6 +28,9 @@ using size_type = int32_t; +int32_t constexpr KEY_SENTINEL = -1; +int32_t constexpr VAL_SENTINEL = -2; + template void test_multimap_find(Map& map, size_type num_keys) { @@ -70,6 +73,29 @@ void test_multimap_find(Map& map, size_type num_keys) REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); } + + SECTION("Conditional find should return valid values on even inputs.") + { + auto found_results = thrust::device_vector(num_keys); + auto is_even = + cuda::proclaim_return_type([] __device__(auto const& i) { return i % 2 == 0; }); + auto gold_fn = cuda::proclaim_return_type([] __device__(auto const& i) { + return i % 2 == 0 ? static_cast(i) * 2 : Value{VAL_SENTINEL}; + }); + + map.find_if(keys_begin, + keys_begin + num_keys, + thrust::counting_iterator{0}, + is_even, + found_results.begin()); + + REQUIRE(cuco::test::equal( + found_results.begin(), + found_results.end(), + thrust::make_transform_iterator(thrust::counting_iterator{0}, gold_fn), + cuda::proclaim_return_type( + [] __device__(auto const& found, auto const& gold) { return found == gold; }))); + } } TEMPLATE_TEST_CASE_SIG( @@ -100,7 +126,7 @@ TEMPLATE_TEST_CASE_SIG( probe, cuco::cuda_allocator, cuco::storage<2>>{ - num_keys, cuco::empty_key{-1}, cuco::empty_value{-2}}; + num_keys, cuco::empty_key{KEY_SENTINEL}, cuco::empty_value{VAL_SENTINEL}}; test_multimap_find(map, num_keys); } diff --git a/tests/static_multiset/find_test.cu b/tests/static_multiset/find_test.cu index b0945ab90..6379b60fb 100644 --- a/tests/static_multiset/find_test.cu +++ b/tests/static_multiset/find_test.cu @@ -29,16 +29,14 @@ using size_type = int32_t; +int32_t constexpr SENTINEL = -1; + template void test_unique_sequence(Set& set, size_type num_keys) { using Key = typename Set::key_type; - thrust::device_vector d_keys(num_keys); - - thrust::sequence(d_keys.begin(), d_keys.end()); - - auto keys_begin = d_keys.begin(); + auto keys_begin = thrust::counting_iterator{0}; thrust::device_vector d_contained(num_keys); auto zip_equal = cuda::proclaim_return_type( @@ -66,6 +64,28 @@ void test_unique_sequence(Set& set, size_type num_keys) REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); } + + SECTION("Conditional find should return valid values on even inputs.") + { + auto found_results = thrust::device_vector(num_keys); + auto is_even = + cuda::proclaim_return_type([] __device__(auto const& i) { return i % 2 == 0; }); + auto gold_fn = cuda::proclaim_return_type( + [] __device__(auto const& i) { return i % 2 == 0 ? static_cast(i) : Key{SENTINEL}; }); + + set.find_if(keys_begin, + keys_begin + num_keys, + thrust::counting_iterator{0}, + is_even, + found_results.begin()); + + REQUIRE(cuco::test::equal( + found_results.begin(), + found_results.end(), + thrust::make_transform_iterator(thrust::counting_iterator{0}, gold_fn), + cuda::proclaim_return_type( + [] __device__(auto const& found, auto const& gold) { return found == gold; }))); + } } TEMPLATE_TEST_CASE_SIG( @@ -87,8 +107,8 @@ TEMPLATE_TEST_CASE_SIG( cuco::linear_probing>, cuco::double_hashing>>; - auto set = - cuco::static_multiset{num_keys, cuco::empty_key{-1}, {}, probe{}, {}, cuco::storage<2>{}}; + auto set = cuco::static_multiset{ + num_keys, cuco::empty_key{SENTINEL}, {}, probe{}, {}, cuco::storage<2>{}}; test_unique_sequence(set, num_keys); } diff --git a/tests/static_set/unique_sequence_test.cu b/tests/static_set/unique_sequence_test.cu index 199c7cbff..0cd2924a9 100644 --- a/tests/static_set/unique_sequence_test.cu +++ b/tests/static_set/unique_sequence_test.cu @@ -24,7 +24,6 @@ #include #include #include -#include #include #include @@ -32,16 +31,14 @@ using size_type = int32_t; +int32_t constexpr SENTINEL = -1; + template void test_unique_sequence(Set& set, size_type num_keys) { using Key = typename Set::key_type; - thrust::device_vector d_keys(num_keys); - - thrust::sequence(thrust::device, d_keys.begin(), d_keys.end()); - - auto keys_begin = d_keys.begin(); + auto keys_begin = thrust::counting_iterator{0}; thrust::device_vector d_contained(num_keys); auto zip_equal = cuda::proclaim_return_type( @@ -116,6 +113,26 @@ void test_unique_sequence(Set& set, size_type num_keys) REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); } + + SECTION("Conditional find should return valid values on even inputs.") + { + auto found_results = thrust::device_vector(num_keys); + auto gold_fn = cuda::proclaim_return_type( + [] __device__(auto const& i) { return i % 2 == 0 ? static_cast(i) : Key{SENTINEL}; }); + + set.find_if(keys_begin, + keys_begin + num_keys, + thrust::counting_iterator{0}, + is_even, + found_results.begin()); + + REQUIRE(cuco::test::equal( + found_results.begin(), + found_results.end(), + thrust::make_transform_iterator(thrust::counting_iterator{0}, gold_fn), + cuda::proclaim_return_type( + [] __device__(auto const& found, auto const& gold) { return found == gold; }))); + } } TEMPLATE_TEST_CASE_SIG( @@ -141,7 +158,7 @@ TEMPLATE_TEST_CASE_SIG( cuco::double_hashing>>; auto set = - cuco::static_set{num_keys, cuco::empty_key{-1}, {}, probe{}, {}, cuco::storage<2>{}}; + cuco::static_set{num_keys, cuco::empty_key{SENTINEL}, {}, probe{}, {}, cuco::storage<2>{}}; REQUIRE(set.capacity() == gold_capacity);