diff --git a/src/main/cpp/src/xxhash64.cu b/src/main/cpp/src/xxhash64.cu index 375e4a19b..4ce47d869 100644 --- a/src/main/cpp/src/xxhash64.cu +++ b/src/main/cpp/src/xxhash64.cu @@ -34,6 +34,8 @@ namespace { using hash_value_type = int64_t; using half_size_type = int32_t; +constexpr int MAX_NESTED_DEPTH = 8; + constexpr __device__ inline int64_t rotate_bits_left_signed(hash_value_type h, int8_t r) { return (h << r) | (h >> (64 - r)) & ~(-1 << r); @@ -295,9 +297,8 @@ hash_value_type __device__ inline XXHash_64::operator()( * handling, refer to the SparkXXHash64 functor. * * @tparam Nullate A cudf::nullate type describing whether to check for nulls. - * @tparam MAX_NESTED_DEPTH The maximum nested depth allowed in the input table. */ -template +template class device_row_hasher { public: device_row_hasher(Nullate nulls, cudf::table_device_view const& t, hash_value_type seed) @@ -367,26 +368,27 @@ class device_row_hasher { return hasher.template operator()(col, row_index); } - class col_stack_element { + struct col_stack_frame { private: - cudf::column_device_view column; // current column - int child_idx; // index of the child column to process next, initialized as 0 + cudf::column_device_view _column; // the column to process + int _idx_to_process; // the index of child or element to process next public: - __device__ col_stack_element() = + __device__ col_stack_frame() = delete; // Because the default constructor of `cudf::column_device_view` is deleted - __device__ col_stack_element(cudf::column_device_view col) : column(col), child_idx(0) {} + __device__ col_stack_frame(cudf::column_device_view col) + : _column(std::move(col)), _idx_to_process(0) + { + } - __device__ int get_and_inc_child_idx() { return child_idx++; } + __device__ int get_and_inc_idx_to_process() { return _idx_to_process++; } - __device__ int cur_child_idx() { return child_idx; } + __device__ int get_idx_to_process() { return _idx_to_process; } - __device__ cudf::column_device_view get_column() { return column; } + __device__ cudf::column_device_view get_column() { return _column; } }; - typedef col_stack_element* col_stack_element_ptr; - /** * @brief Functor to compute hash value for nested columns. * @@ -403,9 +405,9 @@ class device_row_hasher { * next struct element would be pushed into the stack. * - If the current column is a primitive column, it computes the hash value. * - * For example, consider the following nested column: `List>`. - * list_of_struct_column = [ [(1, 2.0), (3, 4.0)] ], which has only one item - * [(1, 2.0), (3, 4.0)]. + * For example, consider that the input column is of type `List>`. + * Assume that the element at `row_index` is: [(1, 2.0), (3, 4.0)]. + * The sliced column is noted as L1 here. * * L1 List> * | @@ -446,44 +448,44 @@ class device_row_hasher { { hash_value_type ret = _seed; cudf::column_device_view curr_col = col.slice(row_index, 1); - // The default constructor of `col_stack_element` is deleted, so it can not allocate an array - // of `col_stack_element` directly. - // Instead leverage the byte array to create the col_stack_element array. - uint8_t stack_wrapper[MAX_NESTED_DEPTH * sizeof(col_stack_element)]; - col_stack_element_ptr col_stack = reinterpret_cast(stack_wrapper); - int stack_size = 0; + // The default constructor of `col_stack_frame` is deleted, so it can not allocate an array + // of `col_stack_frame` directly. + // Instead leverage the byte array to create the col_stack_frame array. + alignas(col_stack_frame) char stack_wrapper[sizeof(col_stack_frame) * MAX_NESTED_DEPTH]; + auto col_stack = reinterpret_cast(stack_wrapper); + int stack_size = 0; - col_stack[stack_size++] = col_stack_element(curr_col); + col_stack[stack_size++] = col_stack_frame(curr_col); while (stack_size > 0) { - col_stack_element& element = col_stack[stack_size - 1]; - curr_col = element.get_column(); + col_stack_frame& top = col_stack[stack_size - 1]; + curr_col = top.get_column(); // Replace list column with its most inner non-list child if (curr_col.type().id() == cudf::type_id::LIST) { do { curr_col = cudf::detail::lists_column_device_view(curr_col).get_sliced_child(); } while (curr_col.type().id() == cudf::type_id::LIST); - col_stack[stack_size - 1] = col_stack_element(curr_col); + col_stack[stack_size - 1] = col_stack_frame(curr_col); continue; } if (curr_col.type().id() == cudf::type_id::STRUCT) { if (curr_col.size() <= 1) { // struct element // All child columns processed, pop the element - if (element.cur_child_idx() == curr_col.num_child_columns()) { + if (top.get_idx_to_process() == curr_col.num_child_columns()) { --stack_size; } else { // Push the next child column into the stack - col_stack[stack_size++] = col_stack_element( - cudf::detail::structs_column_device_view(curr_col).get_sliced_child( - element.get_and_inc_child_idx())); + col_stack[stack_size++] = + col_stack_frame(cudf::detail::structs_column_device_view(curr_col).get_sliced_child( + top.get_and_inc_idx_to_process())); } } else { // struct column - if (element.cur_child_idx() == curr_col.size()) { + if (top.get_idx_to_process() == curr_col.size()) { --stack_size; } else { col_stack[stack_size++] = - col_stack_element(curr_col.slice(element.get_and_inc_child_idx(), 1)); + col_stack_frame(curr_col.slice(top.get_and_inc_idx_to_process(), 1)); } } } else { // Primitive column @@ -507,7 +509,7 @@ class device_row_hasher { hash_value_type const _seed; }; -void check_nested_depth(cudf::table_view const& input, int max_nested_depth) +void check_nested_depth(cudf::table_view const& input) { using column_checker_fn_t = std::function; @@ -528,11 +530,11 @@ void check_nested_depth(cudf::table_view const& input, int max_nested_depth) for (auto i = 0; i < input.num_columns(); i++) { cudf::column_view const& col = input.column(i); - CUDF_EXPECTS(get_nested_depth(col) <= max_nested_depth, + CUDF_EXPECTS(get_nested_depth(col) <= MAX_NESTED_DEPTH, "The " + std::to_string(i) + "-th column exceeds the maximum allowed nested depth. " + "Current depth: " + std::to_string(get_nested_depth(col)) + ", " + - "Maximum allowed depth: " + std::to_string(max_nested_depth)); + "Maximum allowed depth: " + std::to_string(MAX_NESTED_DEPTH)); } } @@ -553,9 +555,8 @@ std::unique_ptr xxhash64(cudf::table_view const& input, // Return early if there's nothing to hash if (input.num_columns() == 0 || input.num_rows() == 0) { return output; } - // Nested depth cannot exceed 8 - constexpr int max_nested_depth = 8; - check_nested_depth(input, max_nested_depth); + + check_nested_depth(input); bool const nullable = has_nested_nulls(input); auto const input_view = cudf::table_device_view::create(input, stream); @@ -565,7 +566,7 @@ std::unique_ptr xxhash64(cudf::table_view const& input, thrust::tabulate(rmm::exec_policy(stream), output_view.begin(), output_view.end(), - device_row_hasher(nullable, *input_view, seed)); + device_row_hasher(nullable, *input_view, seed)); return output; } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java index ae12c25e0..874cb84b5 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java @@ -533,6 +533,36 @@ void testXXHash64ListOfStruct() { } } + @Test + void testXXHash64NestedDepthExceedsLimit() { + try (ColumnVector nestedIntListCV = ColumnVector.fromLists( + new ListType(true, new ListType(true, new BasicType(true, DType.INT32))), + Arrays.asList(Arrays.asList(null, null), null), + Arrays.asList(Collections.singletonList(0), Collections.singletonList(-2), Collections.singletonList(3)), + Arrays.asList(null, Collections.singletonList(Integer.MAX_VALUE)), + Arrays.asList(Collections.singletonList(5), Arrays.asList(-6, null)), + Arrays.asList(Collections.singletonList(Integer.MIN_VALUE), null), + null); + ColumnVector integers = ColumnVector.fromBoxedInts( + 0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); + ColumnVector doubles = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0, + POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + ColumnVector floats = ColumnVector.fromBoxedFloats(0f, 100f, -100f, + NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); + ColumnVector bools = ColumnVector.fromBoxedBooleans( + true, false, null, false, true, null); + ColumnView structs1 = ColumnView.makeStructView(nestedIntListCV, integers); + ColumnView structs2 = ColumnView.makeStructView(structs1, doubles); + ColumnView structs3 = ColumnView.makeStructView(structs2, bools); + ColumnView structs4 = ColumnView.makeStructView(structs3); + ColumnView structs5 = ColumnView.makeStructView(structs4, floats); + ColumnView structs6 = ColumnView.makeStructView(structs5); + ColumnView structs7 = ColumnView.makeStructView(structs6); + ColumnView nestedResult = ColumnView.makeStructView(structs7);) { + assertThrows(CudfException.class, () -> Hash.xxhash64(new ColumnView[]{nestedResult})); + } + } + @Test void testHiveHashBools() { try (ColumnVector v0 = ColumnVector.fromBoxedBooleans(true, false, null);