Skip to content

Commit

Permalink
Add rewrite hook to PruningPredicate
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 24, 2024
1 parent 6546479 commit 93853b9
Showing 1 changed file with 141 additions and 20 deletions.
161 changes: 141 additions & 20 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,31 @@ pub struct PruningPredicate {
literal_guarantees: Vec<LiteralGuarantee>,
}

/// Hook to handle predicates that DataFusion can not handle, e.g. certain complex expressions
/// or predicates that reference columns that are not in the schema.
///
/// The hook is consulted while a predicate is being rewritten into a pruning
/// predicate: whenever a (sub-)expression cannot be rewritten, the expression
/// returned by [`UnhandledPredicateHook::handle`] is used in its place.
pub trait UnhandledPredicateHook {
    /// Called when a predicate can not be handled by DataFusion's transformation rules
    /// or is referencing a column that is not in the schema.
    ///
    /// `expr` is the expression that could not be handled; the returned
    /// expression is substituted for it in the rewritten predicate.
    fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
}

/// An [`UnhandledPredicateHook`] that ignores the unhandled expression and
/// always answers with the same, pre-configured expression.
#[derive(Debug, Clone)]
struct ConstantUnhandledPredicateHook {
    /// Expression handed back for every unhandled predicate.
    default: Arc<dyn PhysicalExpr>,
}

impl ConstantUnhandledPredicateHook {
    /// Create a hook that returns `default` for every unhandled predicate.
    fn new(default: Arc<dyn PhysicalExpr>) -> Self {
        Self { default }
    }
}

impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
    /// Returns the configured default expression, regardless of which
    /// expression could not be handled.
    fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
        // Only the reference count is bumped; the expression itself is shared.
        Arc::clone(&self.default)
    }
}

impl PruningPredicate {
/// Try to create a new instance of [`PruningPredicate`]
///
Expand All @@ -502,10 +527,33 @@ impl PruningPredicate {
/// See the struct level documentation on [`PruningPredicate`] for more
/// details.
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) -> Result<Self> {
    // Anything that cannot be rewritten becomes literal `true`, which is what
    // the constant hook below hands back for every unhandled predicate.
    let always_true = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));
    let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::new(always_true));

    Self::try_new_with_unhandled_hook(expr, schema, unhandled_hook)
}

/// Try to create a new instance of [`PruningPredicate`] with a custom
/// unhandled hook.
///
/// This is the same as [`PruningPredicate::try_new`] but allows for a custom
/// hook to be used when a predicate can not be handled by DataFusion's
/// transformation rules or is referencing a column that is not in the schema.
///
/// The default hook, used by [`PruningPredicate::try_new`], rewrites every
/// unhandled predicate to a constant `true`.
pub fn try_new_with_unhandled_hook(
expr: Arc<dyn PhysicalExpr>,
schema: SchemaRef,
unhandled_hook: Arc<dyn UnhandledPredicateHook>,
) -> Result<Self> {
// build predicate expression once
let mut required_columns = RequiredColumns::new();
let predicate_expr =
build_predicate_expression(&expr, schema.as_ref(), &mut required_columns);
let predicate_expr = build_predicate_expression(
&expr,
schema.as_ref(),
&mut required_columns,
&unhandled_hook,
);

let literal_guarantees = LiteralGuarantee::analyze(&expr);

Expand Down Expand Up @@ -1323,16 +1371,13 @@ fn build_predicate_expression(
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
required_columns: &mut RequiredColumns,
unhandled_hook: &Arc<dyn UnhandledPredicateHook>,
) -> Arc<dyn PhysicalExpr> {
// Returned for unsupported expressions. Such expressions are
// converted to TRUE.
let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));

// predicate expression can only be a binary expression
let expr_any = expr.as_any();
if let Some(is_null) = expr_any.downcast_ref::<phys_expr::IsNullExpr>() {
return build_is_null_column_expr(is_null.arg(), schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(is_not_null) = expr_any.downcast_ref::<phys_expr::IsNotNullExpr>() {
return build_is_null_column_expr(
Expand All @@ -1341,19 +1386,19 @@ fn build_predicate_expression(
required_columns,
true,
)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(col) = expr_any.downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
// match !col (don't do so recursively)
if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, true)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
Expand Down Expand Up @@ -1382,9 +1427,9 @@ fn build_predicate_expression(
})
.reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _)
.unwrap();
return build_predicate_expression(&change_expr, schema, required_columns);
return build_predicate_expression(&change_expr, schema, required_columns, unhandled_hook);
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}

Expand All @@ -1396,21 +1441,21 @@ fn build_predicate_expression(
bin_expr.right().clone(),
)
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
};

if op == Operator::And || op == Operator::Or {
let left_expr = build_predicate_expression(&left, schema, required_columns);
let right_expr = build_predicate_expression(&right, schema, required_columns);
let left_expr = build_predicate_expression(&left, schema, required_columns, unhandled_hook);
let right_expr = build_predicate_expression(&right, schema, required_columns, unhandled_hook);
// simplify boolean expression if applicable
let expr = match (&left_expr, op, &right_expr) {
(left, Operator::And, _) if is_always_true(left) => right_expr,
(_, Operator::And, right) if is_always_true(right) => left_expr,
(left, Operator::Or, right)
if is_always_true(left) || is_always_true(right) =>
{
unhandled
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
}
_ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)),
};
Expand All @@ -1424,11 +1469,11 @@ fn build_predicate_expression(
// allow partial failure in predicate expression generation
// this can still produce a useful predicate when multiple conditions are joined using AND
Err(_) => {
return unhandled;
return unhandled_hook.handle(expr)
}
};

build_statistics_expr(&mut expr_builder).unwrap_or(unhandled)
build_statistics_expr(&mut expr_builder).unwrap_or_else(|_| unhandled_hook.handle(expr))
}

fn build_statistics_expr(
Expand Down Expand Up @@ -1583,6 +1628,7 @@ mod tests {
use datafusion_expr::expr::InList;
use datafusion_expr::{cast, is_null, try_cast, Expr};
use datafusion_physical_expr::planner::logical2physical;
use datafusion_physical_expr::expressions as phys_expr;

#[derive(Debug, Default)]
/// Mock statistic provider for tests
Expand Down Expand Up @@ -3397,6 +3443,77 @@ mod tests {
// TODO: add test for other case and op
}

#[test]
fn test_rewrite_expr_to_prunable_custom_unhandled_hook() {
    struct CustomUnhandledHook;

    impl UnhandledPredicateHook for CustomUnhandledHook {
        /// Rewrites `b <op> <literal>` into `c <op> <literal>`, i.e. it renames a
        /// column that doesn't exist in the schema to yet another column that
        /// doesn't exist in the schema. The transformation itself is arbitrary;
        /// the point is that a hook can do whatever it wants. Every other
        /// expression is replaced by literal `true`.
        fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
            let rewritten = expr
                .as_any()
                .downcast_ref::<phys_expr::BinaryExpr>()
                .and_then(|binary| {
                    let column = binary
                        .left()
                        .as_any()
                        .downcast_ref::<phys_expr::Column>()?;
                    let right_is_literal = binary
                        .right()
                        .as_any()
                        .downcast_ref::<phys_expr::Literal>()
                        .is_some();
                    if column.name() != "b" || !right_is_literal {
                        return None;
                    }

                    let renamed: Arc<dyn PhysicalExpr> =
                        Arc::new(phys_expr::Column::new("c", column.index()));
                    Some(Arc::new(phys_expr::BinaryExpr::new(
                        renamed,
                        *binary.op(),
                        Arc::clone(binary.right()),
                    )) as Arc<dyn PhysicalExpr>)
                });

            rewritten.unwrap_or_else(|| {
                Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
                    as Arc<dyn PhysicalExpr>
            })
        }
    }

    /// Builds the physical expression `<name>@1 = 12`.
    fn column_eq_12(name: &str) -> Arc<dyn PhysicalExpr> {
        Arc::new(phys_expr::BinaryExpr::new(
            Arc::new(phys_expr::Column::new(name, 1)),
            Operator::Eq,
            Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))),
        ))
    }

    // None of the columns referenced below ("b", "c", "d") exist in this schema,
    // so every predicate is routed through the unhandled hook.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let handler: Arc<dyn UnhandledPredicateHook> = Arc::new(CustomUnhandledHook);

    // `b = 12` is picked up by the hook and rewritten to `c = 12`
    let actual_expr = build_predicate_expression(
        &column_eq_12("b"),
        &schema,
        &mut RequiredColumns::new(),
        &handler,
    );
    let expected_expr = column_eq_12("c");
    assert_eq!(actual_expr.to_string(), expected_expr.to_string());

    // but other cases do end up as `true`
    let actual_expr = build_predicate_expression(
        &column_eq_12("d"),
        &schema,
        &mut RequiredColumns::new(),
        &handler,
    );
    let expected_expr = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
        as Arc<dyn PhysicalExpr>;
    assert_eq!(actual_expr.to_string(), expected_expr.to_string());
}

#[test]
fn test_rewrite_expr_to_prunable_error() {
// cast string value to numeric value
Expand Down Expand Up @@ -3886,6 +4003,10 @@ mod tests {
required_columns: &mut RequiredColumns,
) -> Arc<dyn PhysicalExpr> {
let expr = logical2physical(expr, schema);
build_predicate_expression(&expr, schema, required_columns)
// return literal true
let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::new(
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))),
)) as _;
build_predicate_expression(&expr, schema, required_columns, &unhandled_hook)
}
}

0 comments on commit 93853b9

Please sign in to comment.