Skip to content

Commit

Permalink
Add host find_if APIs for all hash tables (#638)
Browse files Browse the repository at this point in the history
This migrates the existing OA `find` kernel to `find_if_n` and adds
`find_if(_async)` APIs for all hash tables.
  • Loading branch information
PointKernel authored Nov 12, 2024
1 parent d39d59a commit eb9319b
Show file tree
Hide file tree
Showing 14 changed files with 588 additions and 31 deletions.
30 changes: 23 additions & 7 deletions include/cuco/detail/open_addressing/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -329,19 +329,33 @@ struct find_buffer<Container, cuda::std::void_t<typename Container::mapped_type>
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize The size of the thread block
* @tparam InputIt Device accessible input iterator
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
* @tparam OutputIt Device accessible output iterator
* @tparam Ref Type of non-owning device ref allowing access to storage
*
* @param first Beginning of the sequence of keys
* @param n Number of keys to query
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
* @param output_begin Beginning of the sequence of matched payloads retrieved for each key
* @param ref Non-owning container device ref used to access the slot storage
*/
template <int32_t CGSize, int32_t BlockSize, typename InputIt, typename OutputIt, typename Ref>
CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first,
cuco::detail::index_type n,
OutputIt output_begin,
Ref ref)
template <int32_t CGSize,
int32_t BlockSize,
typename InputIt,
typename StencilIt,
typename Predicate,
typename OutputIt,
typename Ref>
CUCO_KERNEL __launch_bounds__(BlockSize) void find_if_n(InputIt first,
cuco::detail::index_type n,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
Ref ref)
{
namespace cg = cooperative_groups;

Expand Down Expand Up @@ -382,7 +396,7 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first,
* synchronizing before writing back to global, we no longer rely on L1, preventing the
* increase in sector stores from L2 to global and improving performance.
*/
output_buffer[thread_idx] = output(found);
output_buffer[thread_idx] = pred(*(stencil + idx)) ? output(found) : sentinel;
}
block.sync();
if (idx < n) { *(output_begin + idx) = output_buffer[thread_idx]; }
Expand All @@ -392,7 +406,9 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first,
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
auto const found = ref.find(tile, key);

if (tile.thread_rank() == 0) { *(output_begin + idx) = output(found); }
if (tile.thread_rank() == 0) {
*(output_begin + idx) = pred(*(stencil + idx)) ? output(found) : sentinel;
}
}
}
idx += loop_stride;
Expand Down
48 changes: 46 additions & 2 deletions include/cuco/detail/open_addressing/open_addressing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -565,15 +565,59 @@ class open_addressing_impl {
OutputIt output_begin,
Ref container_ref,
cuda::stream_ref stream) const noexcept
{
auto const always_true = thrust::constant_iterator<bool>{true};

this->find_if_async(
first, last, always_true, thrust::identity{}, output_begin, container_ref, stream);
}

/**
* @brief For all keys in the range `[first, last)`, asynchronously finds
* a match with its key equivalent to the query key.
*
* @note If `pred( *(stencil + i) )` is true, stores the payload of the
* matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )`
* is false, stores `empty_value_sentienl` to `(output_begin + i)`.
*
* @tparam InputIt Device accessible input iterator
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
* @tparam OutputIt Device accessible output iterator
* @tparam Ref Type of non-owning device container ref allowing access to storage
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param output_begin Beginning of the sequence of matches retrieved for each key
* @param container_ref Non-owning device container ref used to access the slot storage
* @param stream Stream used for executing the kernels
*/
template <typename InputIt,
typename StencilIt,
typename Predicate,
typename OutputIt,
typename Ref>
void find_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
Ref container_ref,
cuda::stream_ref stream) const noexcept
{
auto const num_keys = cuco::detail::distance(first, last);
if (num_keys == 0) { return; }

auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);

detail::find<cg_size, cuco::detail::default_block_size()>
detail::find_if_n<cg_size, cuco::detail::default_block_size()>
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
first, num_keys, output_begin, container_ref);
first, num_keys, stencil, pred, output_begin, container_ref);
}

/**
Expand Down
41 changes: 41 additions & 0 deletions include/cuco/detail/static_map/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,47 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if(
InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const
{
this->find_if_async(first, last, stencil, pred, output_begin, stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if_async(
InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const
{
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
Expand Down
41 changes: 41 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap.inl
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,47 @@ void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator,
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if(
InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const
{
this->find_if_async(first, last, stencil, pred, output_begin, stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
find_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const
{
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
Expand Down
39 changes: 39 additions & 0 deletions include/cuco/detail/static_multiset/static_multiset.inl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,45 @@ void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Sto
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if(
InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const
{
this->find_if_async(first, last, stencil, pred, output_begin, stream);
stream.wait();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
find_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const
{
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down
39 changes: 39 additions & 0 deletions include/cuco/detail/static_set/static_set.inl
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,45 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if(
InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const
{
this->find_if_async(first, last, stencil, pred, output_begin, stream);
stream.wait();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if_async(
InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const
{
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down
64 changes: 64 additions & 0 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,70 @@ class static_map {
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, finds a match with its key equivalent to the
* query key.
*
* @note If `pred( *(stencil + i) )` is true, stores the payload of the
* matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )`
* is false, always stores the `empty_value_sentienl` to `(output_begin + i)`.
* @note This function synchronizes the given stream. For asynchronous execution use
* `find_if_async`.
*
* @tparam InputIt Device accessible input iterator
* @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to
* Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
* @tparam OutputIt Device accessible output iterator
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param output_begin Beginning of the sequence of matches retrieved for each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void find_if(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, asynchronously finds
* a match with its key equivalent to the query key.
*
* @note If `pred( *(stencil + i) )` is true, stores the payload of the
* matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )`
* is false, always stores the `empty_value_sentienl` to `(output_begin + i)`.
*
* @tparam InputIt Device accessible input iterator
* @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to
* Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
* @tparam OutputIt Device accessible output iterator
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param output_begin Beginning of the sequence of matches retrieved for each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void find_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief Applies the given function object `callback_op` to the copy of every filled slot in the
* container
Expand Down
Loading

0 comments on commit eb9319b

Please sign in to comment.