Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use invoke_one when possible #448

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions include/cuco/detail/open_addressing/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ CUCO_KERNEL void insert_if_n(InputIt first,
AtomicT* num_successes,
Ref ref)
{
namespace cg = cooperative_groups;

using BlockReduce = cub::BlockReduce<typename Ref::size_type, BlockSize>;
__shared__ typename BlockReduce::TempStorage temp_storage;
typename Ref::size_type thread_num_successes = 0;
Expand All @@ -84,7 +86,13 @@ CUCO_KERNEL void insert_if_n(InputIt first,
} else {
auto const tile =
cooperative_groups::tiled_partition<CGSize>(cooperative_groups::this_thread_block());
#if defined(CUCO_HAS_CG_INVOKE_ONE)
if (ref.insert(tile, insert_element)) {
cg::invoke_one(tile, [&]() { thread_num_successes++; });
}
#else
if (ref.insert(tile, insert_element) && tile.thread_rank() == 0) { thread_num_successes++; }
#endif
}
}
idx += loop_stride;
Expand All @@ -93,9 +101,15 @@ CUCO_KERNEL void insert_if_n(InputIt first,
// compute number of successfully inserted elements for each block
// and atomically add to the grand total
auto const block_num_successes = BlockReduce(temp_storage).Sum(thread_num_successes);
#if defined(CUCO_HAS_CG_INVOKE_ONE)
cg::invoke_one(cg::this_thread_block(), [&]() {
num_successes->fetch_add(block_num_successes, cuda::std::memory_order_relaxed);
});
#else
if (threadIdx.x == 0) {
num_successes->fetch_add(block_num_successes, cuda::std::memory_order_relaxed);
}
#endif
}

/**
Expand Down Expand Up @@ -248,7 +262,11 @@ CUCO_KERNEL void contains_if_n(InputIt first,
if (idx < n) {
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
auto const found = pred(*(stencil + idx)) ? ref.contains(tile, key) : false;
#if defined(CUCO_HAS_CG_INVOKE_ONE)
cg::invoke_one(tile, [&]() { *(output_begin + idx) = found; });
#else
if (tile.thread_rank() == 0) { *(output_begin + idx) = found; }
#endif
}
}
idx += loop_stride;
Expand All @@ -270,6 +288,7 @@ CUCO_KERNEL void contains_if_n(InputIt first,
template <int32_t BlockSize, typename StorageRef, typename Predicate, typename AtomicT>
CUCO_KERNEL void size(StorageRef storage, Predicate is_filled, AtomicT* count)
{
namespace cg = cooperative_groups;
using size_type = typename StorageRef::size_type;

auto const loop_stride = cuco::detail::grid_stride();
Expand All @@ -290,7 +309,12 @@ CUCO_KERNEL void size(StorageRef storage, Predicate is_filled, AtomicT* count)
using BlockReduce = cub::BlockReduce<size_type, BlockSize>;
__shared__ typename BlockReduce::TempStorage temp_storage;
auto const block_count = BlockReduce(temp_storage).Sum(thread_count);
#if defined(CUCO_HAS_CG_INVOKE_ONE)
cg::invoke_one(cg::this_thread_block(),
[&]() { count->fetch_add(block_count, cuda::std::memory_order_relaxed); });
#else
if (threadIdx.x == 0) { count->fetch_add(block_count, cuda::std::memory_order_relaxed); }
#endif
}

template <int32_t BlockSize, typename ContainerRef, typename Predicate>
Expand All @@ -315,7 +339,11 @@ CUCO_KERNEL void rehash(typename ContainerRef::storage_ref_type storage_ref,
auto const n = storage_ref.num_windows();

while (idx - thread_rank < n) {
#if defined(CUCO_HAS_CG_INVOKE_ONE)
cg::invoke_one(block, [&]() { buffer_size = 0; });
#else
if (thread_rank == 0) { buffer_size = 0; }
#endif
block.sync();

// gather values in shmem buffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,12 @@ class open_addressing_ref_impl {
auto const num_windows = static_cast<size_type>(this->window_extent());
#if defined(CUDA_HAS_CUDA_BARRIER)
__shared__ cuda::barrier<cuda::thread_scope::thread_scope_block> barrier;

#if defined(CUCO_HAS_CG_INVOKE_ONE)
cooperative_groups::invoke_one(g, [&]() { init(&barrier, g.size()); });
#else
if (g.thread_rank() == 0) { init(&barrier, g.size()); }
#endif
g.sync();

cuda::memcpy_async(
Expand Down
6 changes: 6 additions & 0 deletions include/cuco/detail/static_map/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,15 @@ CUCO_KERNEL void find(InputIt first, cuco::detail::index_type n, OutputIt output
auto const tile = cg::tiled_partition<CGSize>(block);
auto const found = ref.find(tile, key);

#if defined(CUCO_HAS_CG_INVOKE_ONE)
cg::invoke_one(tile, [&]() {
*(output_begin + idx) = found == ref.end() ? ref.empty_value_sentinel() : (*found).second;
});
#else
if (tile.thread_rank() == 0) {
*(output_begin + idx) = found == ref.end() ? ref.empty_value_sentinel() : (*found).second;
}
#endif
}
}
idx += loop_stride;
Expand Down
6 changes: 6 additions & 0 deletions include/cuco/detail/static_set/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,15 @@ CUCO_KERNEL void find(InputIt first, cuco::detail::index_type n, OutputIt output
auto const tile = cg::tiled_partition<CGSize>(block);
auto const found = ref.find(tile, key);

#if defined(CUCO_HAS_CG_INVOKE_ONE)
cg::invoke_one(tile, [&]() {
*(output_begin + idx) = found == ref.end() ? ref.empty_key_sentinel() : *found;
});
#else
if (tile.thread_rank() == 0) {
*(output_begin + idx) = found == ref.end() ? ref.empty_key_sentinel() : *found;
}
#endif
Comment on lines 81 to +92
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For context: This is the section I've been focusing on in my investigation.

}
}
idx += loop_stride;
Expand Down
Loading