diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 6e2cc0cbdbcba..3f9ee683ca184 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -15,11 +15,10 @@ //! [`PushDownFilter`] applies filters as early as possible use indexmap::IndexSet; +use itertools::Itertools; use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use itertools::Itertools; - use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -35,8 +34,8 @@ use datafusion_expr::utils::{ conjunction, expr_to_columns, split_conjunction, split_conjunction_owned, }; use datafusion_expr::{ - and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, - Projection, TableProviderFilterPushDown, + and, build_join_schema, or, BinaryExpr, Distinct, DistinctOn, Expr, Filter, + LogicalPlanBuilder, Operator, Projection, TableProviderFilterPushDown, Volatility, }; use crate::optimizer::ApplyOrder; @@ -628,6 +627,108 @@ fn infer_join_predicates( .collect::>>() } +/// Check whether the given expression can be resolved using only the columns `col_names`. 
+/// This means that if this function returns true: +/// - the table provider can filter the table partition values with this expression +/// - the expression can be marked as `TableProviderFilterPushDown::Exact` once this filtering +/// was performed +fn expr_applicable_for_schema(schema: &DFSchema, expr: &Expr) -> bool { + let mut is_applicable = true; + expr.apply(|expr| match expr { + Expr::Column(column) => { + println!("schema {:?} column {:?}", schema, column); + is_applicable &= schema.has_column(&column); + if is_applicable { + Ok(TreeNodeRecursion::Jump) + } else { + Ok(TreeNodeRecursion::Stop) + } + } + Expr::Literal(_) + | Expr::Alias(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::ScalarVariable(_, _) + | Expr::Not(_) + | Expr::IsNotNull(_) + | Expr::IsNull(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) + | Expr::Negative(_) + | Expr::Cast(_) + | Expr::TryCast(_) + | Expr::BinaryExpr(_) + | Expr::Between(_) + | Expr::Like(_) + | Expr::SimilarTo(_) + | Expr::InList(_) + | Expr::Exists(_) + | Expr::InSubquery(_) + | Expr::ScalarSubquery(_) + | Expr::GroupingSet(_) + | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), + + Expr::ScalarFunction(scalar_function) => { + match scalar_function.func.signature().volatility { + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) + } + } + } + + // TODO other expressions are not handled yet: + // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases + // - Can `Wildcard` be considered as a `Literal`? + // - ScalarVariable could be `applicable`, but that would require access to the context + Expr::AggregateFunction { .. } + | Expr::WindowFunction { .. } + | Expr::Wildcard { .. 
} + | Expr::Unnest { .. } + | Expr::Placeholder(_) => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) + } + }) + .unwrap(); + is_applicable +} + +fn check_if_expr_depends_only_on_distinct_on_columns( + distinct_on: &DistinctOn, + expr: &Expr, +) -> Result { + let distinct_on_cols: HashSet<&Column> = distinct_on + .on_expr + .iter() + .map(|e| e.column_refs()) + .flatten() + .collect(); + let distinct_on_input_schema = distinct_on.input.schema(); + let distinct_on_qualified_fields: Vec<_> = distinct_on_cols + .iter() + .map(|c| distinct_on_input_schema.qualified_field_from_column(c)) + .collect::>>()? + .into_iter() + .collect::>() + .into_iter() + .map(|(table_reference, field)| { + (table_reference.cloned(), Arc::new(field.clone())) + }) + .collect(); + let distinct_on_columns_schema = + DFSchema::new_with_metadata(distinct_on_qualified_fields, Default::default())?; + Ok(expr_applicable_for_schema( + &distinct_on_columns_schema, + &expr, + )) +} + impl OptimizerRule for PushDownFilter { fn name(&self) -> &str { "push_down_filter" @@ -656,6 +757,8 @@ impl OptimizerRule for PushDownFilter { return Ok(Transformed::no(plan)); }; + println!("Filter: {:?}", filter.input); + match Arc::unwrap_or_clone(filter.input) { LogicalPlan::Filter(child_filter) => { let parents_predicates = split_conjunction_owned(filter.predicate); @@ -709,13 +812,33 @@ impl OptimizerRule for PushDownFilter { Expr::Column(Column::new(qualifier.cloned(), field.name())), ); } - let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; - - let new_filter = LogicalPlan::Filter(Filter::try_new( - new_predicate, - Arc::clone(&subquery_alias.input), - )?); - insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) + let new_predicate = + replace_cols_by_name(filter.predicate.clone(), &replace_map)?; + match &subquery_alias.input.as_ref() { + LogicalPlan::Distinct(Distinct::On(distinct_on)) + // If the filter predicate uses columns that are not in the distinct on + // 
    #[test]
    fn distinct_on() -> Result<()> {
        // Filter sits directly above DistinctOn (no alias in between) and only
        // references the distinct key `a`.
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        // filter is on the same subquery as the distinct, so it should be pushed down
        // all the way into the table scan (`full_filters`).
        let expected = "\
        DistinctOn: on_expr=[[test.a]], select_expr=[[a, b]], sort_expr=[[]]\
        \n  TableScan: test, full_filters=[a = Int64(1)]";
        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn subquery_distinct_on_filter_on_distinct_column() -> Result<()> {
        // Same as above, but the DistinctOn is wrapped in a SubqueryAlias; the
        // filter still only references the distinct key `a` (via the alias).
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?
            .alias("test2")?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        // filter is on the distinct column, so it can be pushed down
        let expected = "\
        SubqueryAlias: test2\
        \n  DistinctOn: on_expr=[[test.a]], select_expr=[[a, b]], sort_expr=[[]]\
        \n    TableScan: test, full_filters=[a = Int64(1)]";
        assert_optimized_plan_eq(plan, expected)
    }

    #[test]
    fn subquery_distinct_on_filter_not_on_distinct_column() -> Result<()> {
        // The filter references `b`, which is selected but is NOT part of the
        // DISTINCT ON key, so pushing it below the distinct could change which
        // row each group keeps.
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?
            .alias("test2")?
            .filter(col("b").eq(lit(1i64)))?
            .build()?;
        // filter is not on the distinct column, so it cannot be pushed down
        let expected = "\
        Filter: test2.b = Int64(1)\
        \n  SubqueryAlias: test2\
        \n    DistinctOn: on_expr=[[test.a]], select_expr=[[a, b]], sort_expr=[[]]\
        \n      TableScan: test";
        assert_optimized_plan_eq(plan, expected)
    }