Optimize computation for single-level lists
Signed-off-by: Yan Feng <[email protected]>
ustcfy committed Nov 19, 2024
1 parent 058fd47 commit e7749c8
Showing 2 changed files with 33 additions and 9 deletions.
39 changes: 31 additions & 8 deletions src/main/cpp/src/hive_hash.cu
@@ -345,10 +345,11 @@ class hive_device_row_hasher {
    *      L2[0]     L2[1]     L2[2]       List<int>
    *        |         |
    *        i1        i2                  Int
-   *       /  \      /  \
-   *   i1[0] i1[1] i2[0] i2[1]            Int
    *
    * Note: L2, i1, i2 are all temporary columns, which will not be pushed onto the stack.
+   * There is an optimization for list columns: if the child column is of primitive type,
+   * the hash value of the list column can be computed directly, which decreases the
+   * required stack depth by one.
    *
    * @tparam T The type of the column.
    * @param col The column device view.
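For context, the shortcut described above computes Hive's standard order-dependent list hash in one pass. A minimal host-side sketch of the fold, assuming HIVE_INIT_HASH is 0 and HIVE_HASH_FACTOR is 31 (the constant names come from the diff; the concrete values and the helper hive_list_hash are assumptions for illustration):

#include <cstdint>
#include <vector>

// Sketch only: fold per-element hashes of one list row into the list's hash.
int32_t hive_list_hash(std::vector<int32_t> const& elem_hashes)
{
  int32_t hash = 0;  // HIVE_INIT_HASH (assumed value)
  for (int32_t const h : elem_hashes) {
    hash = 31 * hash + h;  // HIVE_HASH_FACTOR (assumed value)
  }
  return hash;
}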
@@ -389,13 +390,30 @@
         }
       } else if (curr_col.type().id() == cudf::type_id::LIST) {
         // Get the child column of the list column
-        curr_col = cudf::detail::lists_column_device_view(curr_col).get_sliced_child();
-        if (top.get_idx_to_process() == curr_col.size()) {
+        cudf::column_device_view child_col =
+          cudf::detail::lists_column_device_view(curr_col).get_sliced_child();
+        // If the child column is of primitive type, directly compute the hash value of the list
+        if (child_col.type().id() != cudf::type_id::LIST &&
+            child_col.type().id() != cudf::type_id::STRUCT) {
+          auto single_level_list_hash = cudf::detail::accumulate(
+            thrust::counting_iterator(0),
+            thrust::counting_iterator(child_col.size()),
+            HIVE_INIT_HASH,
+            [child_col, hasher = this->hash_functor] __device__(auto hash, auto element_index) {
+              auto cur_hash = cudf::type_dispatcher<cudf::experimental::dispatch_void_if_nested>(
+                child_col.type(), hasher, child_col, element_index);
+              return HIVE_HASH_FACTOR * hash + cur_hash;
+            });
+          top.update_cur_hash(single_level_list_hash);
           if (--stack_size > 0) { col_stack[stack_size - 1].update_cur_hash(top.get_hash()); }
         } else {
-          // Push the next element into the stack
-          col_stack[stack_size++] =
-            col_stack_frame(curr_col.slice(top.get_and_inc_idx_to_process(), 1));
+          if (top.get_idx_to_process() == child_col.size()) {
+            if (--stack_size > 0) { col_stack[stack_size - 1].update_cur_hash(top.get_hash()); }
+          } else {
+            // Push the next element into the stack
+            col_stack[stack_size++] =
+              col_stack_frame(child_col.slice(top.get_and_inc_idx_to_process(), 1));
+          }
+          }
         }
       } else {
         // There is only one element in the column for primitive types
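The fold above runs sequentially within each row because 31 * hash + cur_hash is order-dependent: reordering elements changes the result, so a commutative parallel reduction would be incorrect. A self-contained sketch demonstrating the property with std::accumulate, using arbitrary example values:

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main()
{
  std::vector<int32_t> hashes{1, 2, 3};
  auto const fold = [](int32_t acc, int32_t h) { return 31 * acc + h; };
  int32_t const forward  = std::accumulate(hashes.begin(), hashes.end(), 0, fold);
  int32_t const backward = std::accumulate(hashes.rbegin(), hashes.rend(), 0, fold);
  assert(forward == 1026);      // ((0 * 31 + 1) * 31 + 2) * 31 + 3
  assert(forward != backward);  // element order matters
  return 0;
}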
@@ -422,7 +440,12 @@ void check_nested_depth(cudf::table_view const& input, int max_nested_depth)
 
   column_checker_fn_t get_nested_depth = [&](cudf::column_view const& col) {
     if (col.type().id() == cudf::type_id::LIST) {
-      return 1 + get_nested_depth(cudf::lists_column_view(col).child());
+      auto const& child_col = cudf::lists_column_view(col).child();
+      if (child_col.type().id() != cudf::type_id::STRUCT &&
+          child_col.type().id() != cudf::type_id::LIST) {
+        return 1;
+      }
+      return 1 + get_nested_depth(child_col);
     } else if (col.type().id() == cudf::type_id::STRUCT) {
       int max_child_depth = 0;
       for (auto child = col.child_begin(); child != col.child_end(); ++child) {
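This keeps the depth check consistent with the traversal: a list over a primitive child no longer consumes a stack frame for the child, so it counts as one level. A toy re-implementation of the rule on a simplified type tree (Node and nested_depth are hypothetical names; the primitive base case returning 1 is an assumption, since that branch is not shown in the diff):

#include <algorithm>
#include <vector>

struct Node {
  enum Kind { PRIMITIVE, LIST, STRUCT } kind;
  std::vector<Node> children;  // exactly one child for LIST, any number for STRUCT
};

int nested_depth(Node const& n)
{
  if (n.kind == Node::LIST) {
    Node const& child = n.children.front();
    if (child.kind == Node::PRIMITIVE) { return 1; }  // single-level list: no extra frame
    return 1 + nested_depth(child);
  }
  if (n.kind == Node::STRUCT) {
    int max_child_depth = 0;
    for (Node const& c : n.children) {
      max_child_depth = std::max(max_child_depth, nested_depth(c));
    }
    return 1 + max_child_depth;
  }
  return 1;  // primitive leaf (assumed)
}

Under this rule a List<int> reports depth 1 instead of 2, which is why the test below needs one more struct wrapper before the limit is exceeded.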
3 changes: 2 additions & 1 deletion src/test/java/com/nvidia/spark/rapids/jni/HashTest.java
@@ -693,7 +693,8 @@ void testHiveHashNestedDepthExceedsLimit() {
          ColumnView structs3 = ColumnView.makeStructView(structs2, bools);
          ColumnView structs4 = ColumnView.makeStructView(structs3);
          ColumnView structs5 = ColumnView.makeStructView(structs4, floats);
-         ColumnView nestedResult = ColumnView.makeStructView(structs5);) {
+         ColumnView structs6 = ColumnView.makeStructView(structs5);
+         ColumnView nestedResult = ColumnView.makeStructView(structs6);) {
       assertThrows(CudfException.class, () -> Hash.hiveHash(new ColumnView[]{nestedResult}));
     }
   }
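Since a list over a primitive now counts one level less toward the limit, the extra structs6 wrapper restores enough depth for hiveHash to throw. For reference, a minimal sketch of the API shape used by the test, with small made-up inputs; it assumes Hash.hiveHash returns a ColumnVector of one hash per row:

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.ColumnView;
import com.nvidia.spark.rapids.jni.Hash;

class HiveHashSketch {
  static void hashShallowStruct() {
    // Each makeStructView call adds one nesting level; hiveHash succeeds
    // while the total depth stays within the library's limit.
    try (ColumnVector ints = ColumnVector.fromInts(1, 2, 3);
         ColumnView s1 = ColumnView.makeStructView(ints);
         ColumnView s2 = ColumnView.makeStructView(s1);
         ColumnVector hashes = Hash.hiveHash(new ColumnView[]{s2})) {
      // hashes holds one hive hash value per input row
    }
  }
}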
