From d39d59ae8b6be942e125b4103ee2500114bb8534 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 11 Nov 2024 17:31:42 -0800 Subject: [PATCH] Add `retrieve_all` for multiset and multimap (#635) This PR adds `retrieve_all` for multiset and multimap --- .../static_multimap/static_multimap.inl | 20 +++++++++++++++ .../static_multiset/static_multiset.inl | 15 +++++++++++ include/cuco/static_map.cuh | 2 +- include/cuco/static_multimap.cuh | 25 +++++++++++++++++++ include/cuco/static_multiset.cuh | 20 +++++++++++++++ 5 files changed, 81 insertions(+), 1 deletion(-) diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index 7236d3175..649cbc749 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -398,6 +398,26 @@ static_multimapcount(first, last, ref(op::count), stream); } +template +template +std::pair +static_multimap::retrieve_all( + KeyOut keys_out, ValueOut values_out, cuda::stream_ref stream) const +{ + auto const zipped_out_begin = thrust::make_zip_iterator(thrust::make_tuple(keys_out, values_out)); + auto const zipped_out_end = impl_->retrieve_all(zipped_out_begin, stream); + auto const num = std::distance(zipped_out_begin, zipped_out_end); + + return std::make_pair(keys_out + num, values_out + num); +} + template return impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream); } +template +template +OutputIt +static_multiset::retrieve_all( + OutputIt output_begin, cuda::stream_ref stream) const +{ + return impl_->retrieve_all(output_begin, stream); +} + template size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const; + /** + * @brief Retrieves all of the keys and their associated values contained in the multimap + * + * @note This API synchronizes the given stream. + * @note The order in which keys are returned is implementation defined and not guaranteed to be + * consistent between subsequent calls to `retrieve_all`. + * @note Behavior is undefined if the range beginning at `keys_out` or `values_out` is smaller + * than the return value of `size()`. + * + * @tparam KeyOut Device accessible random access output iterator whose `value_type` is + * convertible from `key_type`. + * @tparam ValueOut Device accesible random access output iterator whose `value_type` is + * convertible from `mapped_type`. + * + * @param keys_out Beginning output iterator for keys + * @param values_out Beginning output iterator for associated values + * @param stream CUDA stream used for this operation + * + * @return Pair of iterators indicating the last elements in the output + */ + template + std::pair retrieve_all(KeyOut keys_out, + ValueOut values_out, + cuda::stream_ref stream = {}) const; + /** * @brief Regenerates the container. * diff --git a/include/cuco/static_multiset.cuh b/include/cuco/static_multiset.cuh index 9ecbde9b7..044f5d8be 100644 --- a/include/cuco/static_multiset.cuh +++ b/include/cuco/static_multiset.cuh @@ -738,6 +738,26 @@ class static_multiset { OutputMatchIt output_match, cuda::stream_ref stream = {}) const; + /** + * @brief Retrieves all keys contained in the multiset + * + * @note This API synchronizes the given stream. + * @note The order in which keys are returned is implementation defined and not guaranteed to be + * consistent between subsequent calls to `retrieve_all`. + * @note Behavior is undefined if the range beginning at `output_begin` is smaller than the return + * value of `size()`. + * + * @tparam OutputIt Device accessible random access output iterator whose `value_type` is + * convertible from the container's `key_type`. + * + * @param output_begin Beginning output iterator for keys + * @param stream CUDA stream used for this operation + * + * @return Iterator indicating the end of the output + */ + template + OutputIt retrieve_all(OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** * @brief Regenerates the container. *