Skip to content

Commit

Permalink
Refactor mixed_semi_join using cuco::static_set (rapidsai#16230)
Browse files Browse the repository at this point in the history
This PR refactors `mixed_semi_join` by replacing the legacy **cuco** `static_map` with the newer `static_set`.
Contributes to rapidsai#12261.

Authors:
  - Srinivas Yadav (https://github.com/srinivasyadav18)
  - Muhammad Haseeb (https://github.com/mhaseeb123)

Approvers:
  - Yunsong Wang (https://github.com/PointKernel)
  - Nghia Truong (https://github.com/ttnghia)

URL: rapidsai#16230
  • Loading branch information
srinivasyadav18 authored Sep 18, 2024
1 parent 2a3026d commit e68f55c
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 91 deletions.
6 changes: 0 additions & 6 deletions cpp/src/join/join_common_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <cudf/table/row_operators.cuh>
#include <cudf/table/table_view.hpp>

#include <cuco/static_map.cuh>
#include <cuco/static_multimap.cuh>
#include <cuda/atomic>

Expand Down Expand Up @@ -51,11 +50,6 @@ using mixed_multimap_type =
cudf::detail::cuco_allocator<char>,
cuco::legacy::double_hashing<1, hash_type, hash_type>>;

using semi_map_type = cuco::legacy::static_map<hash_value_type,
size_type,
cuda::thread_scope_device,
cudf::detail::cuco_allocator<char>>;

using row_hash_legacy =
cudf::row_hasher<cudf::hashing::detail::default_hash, cudf::nullate::DYNAMIC>;

Expand Down
33 changes: 33 additions & 0 deletions cpp/src/join/mixed_join_common_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <rmm/device_uvector.hpp>

#include <cub/cub.cuh>
#include <cuco/static_set.cuh>

namespace cudf {
namespace detail {
Expand Down Expand Up @@ -160,6 +161,38 @@ struct pair_expression_equality : public expression_equality<has_nulls> {
}
};

/**
 * @brief Row-index equality functor built from two `row_equality` comparators.
 *
 * Two row indices are reported equal only when both wrapped comparators agree
 * that the rows are equal.
 */
struct double_row_equality_comparator {
  row_equality const equality_comparator;
  row_equality const conditional_comparator;

  __device__ bool operator()(size_type lhs_row_index, size_type rhs_row_index) const noexcept
  {
    // Wrap the raw indices once in the strongly typed index wrappers passed to the comparators.
    auto const lhs = experimental::row::lhs_index_type{lhs_row_index};
    auto const rhs = experimental::row::rhs_index_type{rhs_row_index};

    // The conditional comparator is consulted only when the first comparison succeeds.
    if (not equality_comparator(lhs, rhs)) { return false; }
    return conditional_comparator(lhs, rhs);
  }
};

// Size (in threads) of the CUDA cooperative group used with the hash set.
// It is passed as the group-size parameter of the probing scheme below.
auto constexpr DEFAULT_MIXED_JOIN_CG_SIZE = 4;

// The hash set type used by mixed_semi_join for the build table.
// Keys are `size_type` values, compared with `double_row_equality_comparator`
// and located via linear probing with `row_hash` as the hash function.
using hash_set_type = cuco::static_set<size_type,
cuco::extent<size_t>,
cuda::thread_scope_device,
double_row_equality_comparator,
cuco::linear_probing<DEFAULT_MIXED_JOIN_CG_SIZE, row_hash>,
cudf::detail::cuco_allocator<char>,
cuco::storage<1>>;

// The reference type used by the mixed_semi_join kernels for probing.
// It is instantiated with `cuco::contains_tag` only, i.e. for membership queries.
using hash_set_ref_type = hash_set_type::ref_type<cuco::contains_tag>;

} // namespace detail

} // namespace cudf
35 changes: 18 additions & 17 deletions cpp/src/join/mixed_join_kernels_semi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ CUDF_KERNEL void __launch_bounds__(block_size)
table_device_view right_table,
table_device_view probe,
table_device_view build,
row_hash const hash_probe,
row_equality const equality_probe,
cudf::detail::semi_map_type::device_view hash_table_view,
hash_set_ref_type set_ref,
cudf::device_span<bool> left_table_keep_mask,
cudf::ast::detail::expression_device_view device_expression_data)
{
auto constexpr cg_size = hash_set_ref_type::cg_size;

auto const tile =
cooperative_groups::tiled_partition<cg_size>(cooperative_groups::this_thread_block());

// Normally the casting of a shared memory array is used to create multiple
// arrays of different types from the shared memory buffer, but here it is
// used to circumvent conflicts between arrays of different types between
Expand All @@ -52,24 +56,24 @@ CUDF_KERNEL void __launch_bounds__(block_size)
cudf::ast::detail::IntermediateDataType<has_nulls>* intermediate_storage =
reinterpret_cast<cudf::ast::detail::IntermediateDataType<has_nulls>*>(raw_intermediate_storage);
auto thread_intermediate_storage =
&intermediate_storage[threadIdx.x * device_expression_data.num_intermediates];

cudf::size_type const left_num_rows = left_table.num_rows();
cudf::size_type const right_num_rows = right_table.num_rows();
auto const outer_num_rows = left_num_rows;
&intermediate_storage[tile.meta_group_rank() * device_expression_data.num_intermediates];

cudf::size_type outer_row_index = threadIdx.x + blockIdx.x * block_size;
cudf::size_type const outer_num_rows = left_table.num_rows();
auto const outer_row_index = cudf::detail::grid_1d::global_thread_id<block_size>() / cg_size;

auto evaluator = cudf::ast::detail::expression_evaluator<has_nulls>(
left_table, right_table, device_expression_data);

if (outer_row_index < outer_num_rows) {
// Make sure to swap_tables here as hash_set will use probe table as the left one.
auto constexpr swap_tables = true;
// Build the equality functor (row equality plus the conditional expression) used to probe the set.
auto equality = single_expression_equality<has_nulls>{
evaluator, thread_intermediate_storage, false, equality_probe};
evaluator, thread_intermediate_storage, swap_tables, equality_probe};

left_table_keep_mask[outer_row_index] =
hash_table_view.contains(outer_row_index, hash_probe, equality);
auto const set_ref_equality = set_ref.with_key_eq(equality);
auto const result = set_ref_equality.contains(tile, outer_row_index);
if (tile.thread_rank() == 0) left_table_keep_mask[outer_row_index] = result;
}
}

Expand All @@ -78,9 +82,8 @@ void launch_mixed_join_semi(bool has_nulls,
table_device_view right_table,
table_device_view probe,
table_device_view build,
row_hash const hash_probe,
row_equality const equality_probe,
cudf::detail::semi_map_type::device_view hash_table_view,
hash_set_ref_type set_ref,
cudf::device_span<bool> left_table_keep_mask,
cudf::ast::detail::expression_device_view device_expression_data,
detail::grid_1d const config,
Expand All @@ -94,9 +97,8 @@ void launch_mixed_join_semi(bool has_nulls,
right_table,
probe,
build,
hash_probe,
equality_probe,
hash_table_view,
set_ref,
left_table_keep_mask,
device_expression_data);
} else {
Expand All @@ -106,9 +108,8 @@ void launch_mixed_join_semi(bool has_nulls,
right_table,
probe,
build,
hash_probe,
equality_probe,
hash_table_view,
set_ref,
left_table_keep_mask,
device_expression_data);
}
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/join/mixed_join_kernels_semi.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ namespace detail {
* @param[in] right_table The right table
* @param[in] probe The table with which to probe the hash table for matches.
* @param[in] build The table with which the hash table was built.
* @param[in] hash_probe The hasher used for the probe table.
* @param[in] equality_probe The equality comparator used when probing the hash table.
* @param[in] hash_table_view The hash table built from `build`.
* @param[in] set_ref Reference to the hash set built from `build`, used for `contains` lookups.
* @param[out] left_table_keep_mask The result of the join operation with "true" element indicating
* the corresponding index from left table is present in output
* @param[in] device_expression_data Container of device data required to evaluate the desired
Expand All @@ -58,9 +57,8 @@ void launch_mixed_join_semi(bool has_nulls,
table_device_view right_table,
table_device_view probe,
table_device_view build,
row_hash const hash_probe,
row_equality const equality_probe,
cudf::detail::semi_map_type::device_view hash_table_view,
hash_set_ref_type set_ref,
cudf::device_span<bool> left_table_keep_mask,
cudf::ast::detail::expression_device_view device_expression_data,
detail::grid_1d const config,
Expand Down
90 changes: 26 additions & 64 deletions cpp/src/join/mixed_join_semi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,45 +46,6 @@
namespace cudf {
namespace detail {

namespace {
/**
 * @brief Device functor mapping a row index to the (key, value) pair stored in the hash map.
 */
struct make_pair_function_semi {
  __device__ __forceinline__ cudf::detail::pair_type operator()(size_type i) const noexcept
  {
    // Only membership of the row index is ever queried, so the mapped value is a
    // placeholder and carries no meaning.
    auto const key = static_cast<hash_value_type>(i);
    return cuco::make_pair(key, 0);
  }
};

/**
 * @brief Row-index equality functor built from two `row_equality` comparators.
 *
 * Two row indices are reported equal only when both wrapped comparators agree
 * that the rows are equal.
 */
class double_row_equality {
 public:
  double_row_equality(row_equality equality_comparator, row_equality conditional_comparator)
    : _equality_comparator{equality_comparator}, _conditional_comparator{conditional_comparator}
  {
  }

  __device__ bool operator()(size_type lhs_row_index, size_type rhs_row_index) const noexcept
  {
    // Wrap the raw indices once in the strongly typed index wrappers passed to the comparators.
    auto const lhs = experimental::row::lhs_index_type{lhs_row_index};
    auto const rhs = experimental::row::rhs_index_type{rhs_row_index};

    // The conditional comparator is consulted only when the first comparison succeeds.
    if (not _equality_comparator(lhs, rhs)) { return false; }
    return _conditional_comparator(lhs, rhs);
  }

 private:
  row_equality _equality_comparator;
  row_equality _conditional_comparator;
};

} // namespace

std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
table_view const& left_equality,
table_view const& right_equality,
Expand All @@ -96,7 +57,7 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_EXPECTS((join_type != join_kind::INNER_JOIN) && (join_type != join_kind::LEFT_JOIN) &&
CUDF_EXPECTS((join_type != join_kind::INNER_JOIN) and (join_type != join_kind::LEFT_JOIN) and
(join_type != join_kind::FULL_JOIN),
"Inner, left, and full joins should use mixed_join.");

Expand Down Expand Up @@ -137,7 +98,7 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
// output column and follow the null-supporting expression evaluation code
// path.
auto const has_nulls = cudf::nullate::DYNAMIC{
cudf::has_nulls(left_equality) || cudf::has_nulls(right_equality) ||
cudf::has_nulls(left_equality) or cudf::has_nulls(right_equality) or
binary_predicate.may_evaluate_null(left_conditional, right_conditional, stream)};

auto const parser = ast::detail::expression_parser{
Expand All @@ -156,27 +117,20 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
auto right_conditional_view = table_device_view::create(right_conditional, stream);

auto const preprocessed_build =
experimental::row::equality::preprocessed_table::create(build, stream);
cudf::experimental::row::equality::preprocessed_table::create(build, stream);
auto const preprocessed_probe =
experimental::row::equality::preprocessed_table::create(probe, stream);
cudf::experimental::row::equality::preprocessed_table::create(probe, stream);
auto const row_comparator =
cudf::experimental::row::equality::two_table_comparator{preprocessed_probe, preprocessed_build};
cudf::experimental::row::equality::two_table_comparator{preprocessed_build, preprocessed_probe};
auto const equality_probe = row_comparator.equal_to<false>(has_nulls, compare_nulls);

semi_map_type hash_table{
compute_hash_table_size(build.num_rows()),
cuco::empty_key{std::numeric_limits<hash_value_type>::max()},
cuco::empty_value{cudf::detail::JoinNoneValue},
cudf::detail::cuco_allocator<char>{rmm::mr::polymorphic_allocator<char>{}, stream},
stream.value()};

// Create hash table containing all keys found in right table
// TODO: To add support for nested columns we will need to flatten in many
// places. However, this probably isn't worth adding any time soon since we
// won't be able to support AST conditions for those types anyway.
auto const build_nulls = cudf::nullate::DYNAMIC{cudf::has_nulls(build)};
auto const row_hash_build = cudf::experimental::row::hash::row_hasher{preprocessed_build};
auto const hash_build = row_hash_build.device_hasher(build_nulls);

// Since we may see multiple rows that are identical in the equality tables
// but differ in the conditional tables, the equality comparator used for
// insertion must account for both sets of tables. An alternative solution
Expand All @@ -191,39 +145,48 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
auto const equality_build_equality =
row_comparator_build.equal_to<false>(build_nulls, compare_nulls);
auto const preprocessed_build_condtional =
experimental::row::equality::preprocessed_table::create(right_conditional, stream);
cudf::experimental::row::equality::preprocessed_table::create(right_conditional, stream);
auto const row_comparator_conditional_build =
cudf::experimental::row::equality::two_table_comparator{preprocessed_build_condtional,
preprocessed_build_condtional};
auto const equality_build_conditional =
row_comparator_conditional_build.equal_to<false>(build_nulls, compare_nulls);
double_row_equality equality_build{equality_build_equality, equality_build_conditional};
make_pair_function_semi pair_func_build{};

auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func_build);
hash_set_type row_set{
{compute_hash_table_size(build.num_rows())},
cuco::empty_key{JoinNoneValue},
{equality_build_equality, equality_build_conditional},
{row_hash_build.device_hasher(build_nulls)},
{},
{},
cudf::detail::cuco_allocator<char>{rmm::mr::polymorphic_allocator<char>{}, stream},
{stream.value()}};

auto iter = thrust::make_counting_iterator(0);

// skip rows that are null here.
if ((compare_nulls == null_equality::EQUAL) or (not nullable(build))) {
hash_table.insert(iter, iter + right_num_rows, hash_build, equality_build, stream.value());
row_set.insert(iter, iter + right_num_rows, stream.value());
} else {
thrust::counting_iterator<cudf::size_type> stencil(0);
auto const [row_bitmask, _] =
cudf::detail::bitmask_and(build, stream, cudf::get_current_device_resource_ref());
row_is_valid pred{static_cast<bitmask_type const*>(row_bitmask.data())};

// insert valid rows
hash_table.insert_if(
iter, iter + right_num_rows, stencil, pred, hash_build, equality_build, stream.value());
row_set.insert_if(iter, iter + right_num_rows, stencil, pred, stream.value());
}

auto hash_table_view = hash_table.get_device_view();

detail::grid_1d const config(outer_num_rows, DEFAULT_JOIN_BLOCK_SIZE);
auto const shmem_size_per_block = parser.shmem_per_thread * config.num_threads_per_block;
auto const shmem_size_per_block =
parser.shmem_per_thread *
cuco::detail::int_div_ceil(config.num_threads_per_block, hash_set_type::cg_size);

auto const row_hash = cudf::experimental::row::hash::row_hasher{preprocessed_probe};
auto const hash_probe = row_hash.device_hasher(has_nulls);

hash_set_ref_type const row_set_ref = row_set.ref(cuco::contains).with_hash_function(hash_probe);

// Vector used to indicate indices from left/probe table which are present in output
auto left_table_keep_mask = rmm::device_uvector<bool>(probe.num_rows(), stream);

Expand All @@ -232,9 +195,8 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
*right_conditional_view,
*probe_view,
*build_view,
hash_probe,
equality_probe,
hash_table_view,
row_set_ref,
cudf::device_span<bool>(left_table_keep_mask),
parser.device_expression_data,
config,
Expand Down
30 changes: 30 additions & 0 deletions cpp/tests/join/mixed_join_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,21 @@ TYPED_TEST(MixedLeftSemiJoinTest, BasicEquality)
{1});
}

TYPED_TEST(MixedLeftSemiJoinTest, MixedLeftSemiJoinGatherMap)
{
  // Conditional predicate: column 0 of the LEFT reference > column 0 of the RIGHT reference.
  auto const lhs_column = cudf::ast::column_reference(0, cudf::ast::table_reference::LEFT);
  auto const rhs_column = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT);
  auto lhs_greater_than_rhs =
    cudf::ast::operation(cudf::ast::ast_operator::GREATER, lhs_column, rhs_column);

  // NOTE(review): the {0} / {1} arguments presumably select the equality and conditional
  // columns of both inputs; confirm against the fixture's `test` helper.
  // The final argument is the expected result handed to that helper.
  this->test({{2, 3, 9, 0, 1, 7, 4, 6, 5, 8}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}},
             {{6, 5, 9, 8, 10, 32}, {0, 1, 2, 3, 4, 5}, {7, 8, 9, 0, 1, 2}},
             {0},
             {1},
             lhs_greater_than_rhs,
             {2, 7, 8});
}

TYPED_TEST(MixedLeftSemiJoinTest, BasicEqualityDuplicates)
{
this->test({{0, 1, 2, 1}, {3, 4, 5, 6}, {10, 20, 30, 40}},
Expand Down Expand Up @@ -900,3 +915,18 @@ TYPED_TEST(MixedLeftAntiJoinTest, AsymmetricLeftLargerEquality)
left_zero_eq_right_zero,
{0, 1, 3});
}

TYPED_TEST(MixedLeftAntiJoinTest, MixedLeftAntiJoinGatherMap)
{
  // Conditional predicate: column 0 of the LEFT reference > column 0 of the RIGHT reference.
  auto const lhs_column = cudf::ast::column_reference(0, cudf::ast::table_reference::LEFT);
  auto const rhs_column = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT);
  auto lhs_greater_than_rhs =
    cudf::ast::operation(cudf::ast::ast_operator::GREATER, lhs_column, rhs_column);

  // NOTE(review): the {0} / {1} arguments presumably select the equality and conditional
  // columns of both inputs; confirm against the fixture's `test` helper.
  // The final argument is the expected result handed to that helper.
  this->test({{2, 3, 9, 0, 1, 7, 4, 6, 5, 8}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}},
             {{6, 5, 9, 8, 10, 32}, {0, 1, 2, 3, 4, 5}, {7, 8, 9, 0, 1, 2}},
             {0},
             {1},
             lhs_greater_than_rhs,
             {0, 1, 3, 4, 5, 6, 9});
}

0 comments on commit e68f55c

Please sign in to comment.