diff --git a/include/cuco/detail/static_multimap/static_multimap_ref.inl b/include/cuco/detail/static_multimap/static_multimap_ref.inl index 3bbf90e3a..3dc429f3b 100644 --- a/include/cuco/detail/static_multimap/static_multimap_ref.inl +++ b/include/cuco/detail/static_multimap/static_multimap_ref.inl @@ -487,6 +487,115 @@ class operator_impl< } }; +template +class operator_impl< + op::for_each_tag, + static_multimap_ref> { + using base_type = static_multimap_ref; + using ref_type = + static_multimap_ref; + + static constexpr auto cg_size = base_type::cg_size; + + public: + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Unary callback functor or device lambda + * + * @param key The key to search for + * @param callback_op Function to call on every element found + */ + template + __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(key, std::forward(callback_op)); + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Unary callback functor or device lambda + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(group, key, std::forward(callback_op)); + } + + /** + * @brief Executes a callback on every element in the container with key equivalent to the probe + * key and can additionally perform work that requires synchronizing the Cooperative Group + * performing this operation. + * + * @note Passes an un-incrementable input iterator to the element whose key is equivalent to + * `key` to the callback. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @note The `sync_op` function can be used to perform work that requires synchronizing threads in + * `group` inbetween probing steps, where the number of probing steps performed between + * synchronization points is capped by `window_size * cg_size`. The functor will be called right + * after the current probing window has been traversed. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Unary callback functor or device lambda + * @tparam SyncOp Functor or device lambda which accepts the current `group` object + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to call on every element found + * @param sync_op Function that is allowed to synchronize `group` inbetween probing windows + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op, + SyncOp&& sync_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each( + group, key, std::forward(callback_op), std::forward(sync_op)); + } +}; + template + +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +template +CUCO_KERNEL void for_each_check_scalar(Ref ref, + InputIt first, + std::size_t n, + std::size_t multiplicity, + AtomicErrorCounter* error_counter) +{ + static_assert(Ref::cg_size == 1, "Scalar test must have cg_size==1"); + auto const loop_stride = cuco::detail::grid_stride(); + auto idx = cuco::detail::global_thread_id(); + + while (idx < n) { + auto const& key = *(first + idx); + std::size_t matches = 0; + ref.for_each(key, [&] __device__(auto const slot) { + auto const [slot_key, slot_value] = slot; + if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) { matches++; } + }); + if (matches != multiplicity) { error_counter->fetch_add(1, cuda::memory_order_relaxed); } + idx += loop_stride; + } +} + +template +CUCO_KERNEL void for_each_check_cooperative(Ref ref, + InputIt first, + std::size_t n, + std::size_t multiplicity, + AtomicErrorCounter* error_counter) +{ + auto const loop_stride = cuco::detail::grid_stride() / Ref::cg_size; + auto idx = cuco::detail::global_thread_id() / Ref::cg_size; + ; + + while (idx < n) { + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + auto const& key = *(first + idx); + std::size_t thread_matches = 0; + if constexpr (Synced) { + ref.for_each( + tile, + key, + [&] __device__(auto const slot) { + auto const [slot_key, slot_value] = slot; + if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) { + thread_matches++; + } + }, + [] __device__(auto const& group) { group.sync(); }); + } else { + ref.for_each(tile, key, [&] __device__(auto const slot) { + auto const [slot_key, slot_value] = slot; + if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) { + thread_matches++; + } + }); + } + auto const tile_matches = + cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus()); + if (tile_matches != multiplicity and tile.thread_rank() == 0) { + error_counter->fetch_add(1, cuda::memory_order_relaxed); + } + idx += loop_stride; + } +} + +TEMPLATE_TEST_CASE_SIG( + "static_multimap for_each tests", + "", + ((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize), + (int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, cuco::test::probe_sequence::linear_probing, 2)) +{ + constexpr size_t num_unique_keys{400}; + constexpr size_t key_multiplicity{5}; + constexpr size_t num_keys{num_unique_keys * key_multiplicity}; + + using probe = std::conditional_t>, + cuco::double_hashing>>; + + auto set = cuco::experimental::static_multimap{num_keys, + cuco::empty_key{-1}, + cuco::empty_value{-1}, + {}, + probe{}, + {}, + cuco::storage<2>{}}; + + auto unique_keys_begin = thrust::counting_iterator(0); + auto gen_duplicate_keys = cuda::proclaim_return_type( + [] __device__(auto const& k) { return static_cast(k % num_unique_keys); }); + auto keys_begin = thrust::make_transform_iterator(unique_keys_begin, gen_duplicate_keys); + + auto const pairs_begin = thrust::make_transform_iterator( + keys_begin, cuda::proclaim_return_type>([] __device__(auto i) { + return cuco::pair{i, i}; + })); + + set.insert(pairs_begin, pairs_begin + num_keys); + + using error_counter_type = cuda::atomic; + error_counter_type* error_counter; + CUCO_CUDA_TRY(cudaMallocHost(&error_counter, sizeof(error_counter_type))); + new (error_counter) error_counter_type{0}; + + auto const grid_size = cuco::detail::grid_size(num_unique_keys, CGSize); + auto const block_size = cuco::detail::default_block_size(); + + // test scalar for_each + if constexpr (CGSize == 1) { + for_each_check_scalar<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + error_counter->store(0); + } + + // test CG for_each + for_each_check_cooperative<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + error_counter->store(0); + + // test synchronized CG for_each + for_each_check_cooperative<<>>( + set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + REQUIRE(error_counter->load() == 0); + + CUCO_CUDA_TRY(cudaFreeHost(error_counter)); +} \ No newline at end of file