From 723c565f7a03e3e9a842526cd4cc94bcf6f582e5 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Fri, 17 Nov 2023 17:37:47 -0800 Subject: [PATCH] Fix intermediate type checking in expression parsing (#14445) When parsing expressions, device data references are reused if there are multiple that are identical. Equality is determined by comparing the fields of the reference, but previously the data type was omitted. For column and literal references, this is OK because the `data_index` uniquely identifies the reference. For intermediates, however, the index is not sufficient to disambiguate because an expression could reuse a given location even if the operation produces a different data type. Therefore, the data type must be part of the equality operator. Resolves #14409 Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - David Wendt (https://github.com/davidwendt) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/14445 --- .../cudf/ast/detail/expression_parser.hpp | 4 +-- cpp/tests/ast/transform_tests.cpp | 27 +++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/cpp/include/cudf/ast/detail/expression_parser.hpp b/cpp/include/cudf/ast/detail/expression_parser.hpp index db0abe435b0..a36a831a7aa 100644 --- a/cpp/include/cudf/ast/detail/expression_parser.hpp +++ b/cpp/include/cudf/ast/detail/expression_parser.hpp @@ -67,8 +67,8 @@ struct alignas(8) device_data_reference { bool operator==(device_data_reference const& rhs) const { - return std::tie(data_index, reference_type, table_source) == - std::tie(rhs.data_index, rhs.reference_type, rhs.table_source); + return std::tie(data_index, data_type, reference_type, table_source) == + std::tie(rhs.data_index, rhs.data_type, rhs.reference_type, rhs.table_source); } }; diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp index c0109a40cec..624a781c5b9 100644 --- a/cpp/tests/ast/transform_tests.cpp +++ b/cpp/tests/ast/transform_tests.cpp @@ -316,6 +316,33 @@ TEST_F(TransformTest, ImbalancedTreeArithmetic) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); } +TEST_F(TransformTest, ImbalancedTreeArithmeticDeep) +{ + auto c_0 = column_wrapper{4, 5, 6}; + auto table = cudf::table_view{{c_0}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + + // expression: (c0 < c0) == (c0 < (c0 + c0)) + // {false, false, false} == (c0 < {8, 10, 12}) + // {false, false, false} == {true, true, true} + // {false, false, false} + auto expression_left_subtree = + cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_0); + auto expression_right_inner_subtree = + cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_0); + auto expression_right_subtree = + cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, expression_right_inner_subtree); + + auto expression_tree = cudf::ast::operation( + cudf::ast::ast_operator::EQUAL, expression_left_subtree, expression_right_subtree); + + auto result = cudf::compute_column(table, expression_tree); + auto expected = column_wrapper{false, false, false}; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); +} + TEST_F(TransformTest, MultiLevelTreeComparator) { auto c_0 = column_wrapper{3, 20, 1, 50};