Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify expression transformer in Parquet predicate pushdown with ast::tree #17587

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 24 additions & 27 deletions cpp/src/io/parquet/predicate_pushdown.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ class stats_expression_converter : public ast::detail::expression_transformer {
*/
std::reference_wrapper<ast::expression const> visit(ast::literal const& expr) override
{
// NOTE(review): this is a rendered diff without +/- markers. The hunk header for this
// region (7 lines before, 6 after) suggests the assignment below is the line removed
// by the change — confirm against the PR before relying on it.
_stats_expr = std::reference_wrapper<ast::expression const>(expr);
// The literal is handed back to the caller unchanged.
return expr;
}

Expand All @@ -278,7 +277,6 @@ class stats_expression_converter : public ast::detail::expression_transformer {
"Statistics AST supports only left table");
CUDF_EXPECTS(expr.get_column_index() < _num_columns,
"Column index cannot be more than number of columns in the table");
_stats_expr = std::reference_wrapper<ast::expression const>(expr);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These expressions don't need to be pushed to the `ast::tree`, since we will be filtering with the new columns in the StatsAST table.

return expr;
}

Expand Down Expand Up @@ -307,6 +305,9 @@ class stats_expression_converter : public ast::detail::expression_transformer {
CUDF_EXPECTS(dynamic_cast<ast::literal const*>(&operands[1].get()) != nullptr,
"Second operand of binary operation with column reference must be a literal");
v->accept(*this);
// Push literal into the ast::tree
auto const& literal =
_stats_expr.push(*dynamic_cast<ast::literal const*>(&operands[1].get()));
mhaseeb123 marked this conversation as resolved.
Show resolved Hide resolved
auto const col_index = v->get_column_index();
switch (op) {
/* transform to stats conditions. op(col, literal)
Expand All @@ -318,48 +319,46 @@ class stats_expression_converter : public ast::detail::expression_transformer {
col1 <= val --> vmin <= val
*/
case ast_operator::EQUAL: {
auto const& vmin = _col_ref.emplace_back(col_index * 2);
auto const& vmax = _col_ref.emplace_back(col_index * 2 + 1);
auto const& op1 =
_operators.emplace_back(ast_operator::LESS_EQUAL, vmin, operands[1].get());
auto const& op2 =
_operators.emplace_back(ast_operator::GREATER_EQUAL, vmax, operands[1].get());
_operators.emplace_back(ast::ast_operator::LOGICAL_AND, op1, op2);
auto const& vmin = _stats_expr.push(ast::column_reference{col_index * 2});
auto const& vmax = _stats_expr.push(ast::column_reference{col_index * 2 + 1});
_stats_expr.push(ast::operation{
ast::ast_operator::LOGICAL_AND,
_stats_expr.push(ast::operation{ast_operator::GREATER_EQUAL, vmax, literal}),
_stats_expr.push(ast::operation{ast_operator::LESS_EQUAL, vmin, literal})});
break;
}
case ast_operator::NOT_EQUAL: {
auto const& vmin = _col_ref.emplace_back(col_index * 2);
auto const& vmax = _col_ref.emplace_back(col_index * 2 + 1);
auto const& op1 = _operators.emplace_back(ast_operator::NOT_EQUAL, vmin, vmax);
auto const& op2 =
_operators.emplace_back(ast_operator::NOT_EQUAL, vmax, operands[1].get());
_operators.emplace_back(ast_operator::LOGICAL_OR, op1, op2);
auto const& vmin = _stats_expr.push(ast::column_reference{col_index * 2});
auto const& vmax = _stats_expr.push(ast::column_reference{col_index * 2 + 1});
_stats_expr.push(ast::operation{
ast_operator::LOGICAL_OR,
_stats_expr.push(ast::operation{ast_operator::NOT_EQUAL, vmin, vmax}),
_stats_expr.push(ast::operation{ast_operator::NOT_EQUAL, vmax, literal})});
break;
}
case ast_operator::LESS: [[fallthrough]];
case ast_operator::LESS_EQUAL: {
auto const& vmin = _col_ref.emplace_back(col_index * 2);
_operators.emplace_back(op, vmin, operands[1].get());
auto const& vmin = _stats_expr.push(ast::column_reference{col_index * 2});
_stats_expr.push(ast::operation{op, vmin, literal});
break;
}
case ast_operator::GREATER: [[fallthrough]];
case ast_operator::GREATER_EQUAL: {
auto const& vmax = _col_ref.emplace_back(col_index * 2 + 1);
_operators.emplace_back(op, vmax, operands[1].get());
auto const& vmax = _stats_expr.push(ast::column_reference{col_index * 2 + 1});
_stats_expr.push(ast::operation{op, vmax, literal});
break;
}
default: CUDF_FAIL("Unsupported operation in Statistics AST");
};
} else {
auto new_operands = visit_operands(operands);
if (cudf::ast::detail::ast_operator_arity(op) == 2) {
_operators.emplace_back(op, new_operands.front(), new_operands.back());
_stats_expr.push(ast::operation{op, new_operands.front(), new_operands.back()});
} else if (cudf::ast::detail::ast_operator_arity(op) == 1) {
_operators.emplace_back(op, new_operands.front());
_stats_expr.push(ast::operation{op, new_operands.front()});
}
}
_stats_expr = std::reference_wrapper<ast::expression const>(_operators.back());
return std::reference_wrapper<ast::expression const>(_operators.back());
return _stats_expr.back();
}

/**
Expand All @@ -369,7 +368,7 @@ class stats_expression_converter : public ast::detail::expression_transformer {
*/
[[nodiscard]] std::reference_wrapper<ast::expression const> get_stats_expr() const
{
// NOTE(review): this is a rendered diff without +/- markers. The hunk header for this
// region (7 lines before, 7 after) suggests the first return is the old line, which
// unwraps the former std::optional member, and the second return is its replacement,
// which reads from the ast::tree member (presumably the most recently pushed
// expression — confirm against the ast::tree API).
return _stats_expr.value().get();
return _stats_expr.back();
}

private:
Expand All @@ -383,10 +382,8 @@ class stats_expression_converter : public ast::detail::expression_transformer {
}
return transformed_operands;
}
std::optional<std::reference_wrapper<ast::expression const>> _stats_expr;
ast::tree _stats_expr;
size_type _num_columns;
std::list<ast::column_reference> _col_ref;
std::list<ast::operation> _operators;
};
} // namespace

Expand Down
Loading