diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 85958223ac97..6e7efaf39e3e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -281,6 +281,15 @@ pub enum LogicalPlan { RecursiveQuery(RecursiveQuery), } +impl Default for LogicalPlan { + fn default() -> Self { + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + }) + } +} + impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 295039af2f19..68339a84649d 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -693,8 +693,9 @@ impl OptimizerRule for PushDownFilter { insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) } LogicalPlan::Projection(projection) => { + let predicates = split_conjunction_owned(filter.predicate.clone()); let (new_projection, keep_predicate) = - rewrite_projection(filter.predicate.clone(), projection)?; + rewrite_projection(predicates, projection)?; if new_projection.transformed { match keep_predicate { None => Ok(new_projection), @@ -709,41 +710,54 @@ impl OptimizerRule for PushDownFilter { } } LogicalPlan::Unnest(mut unnest) => { - // collect all the Expr::Column in predicate recursively - let mut accum: HashSet = HashSet::new(); - expr_to_columns(&filter.predicate, &mut accum)?; + let predicates = split_conjunction_owned(filter.predicate.clone()); + let mut non_unnest_predicates = vec![]; + let mut unnest_predicates = vec![]; + for predicate in predicates { + // collect all the Expr::Column in predicate recursively + let mut accum: HashSet = HashSet::new(); + expr_to_columns(&predicate, &mut accum)?; + + if unnest.exec_columns.iter().any(|c| accum.contains(c)) { + unnest_predicates.push(predicate); + } else { + non_unnest_predicates.push(predicate); + } + } - if unnest.exec_columns.iter().any(|c| accum.contains(c)) { + // Unnest predicates should not be pushed down. + // If no non-unnest predicates exist, early return + if non_unnest_predicates.is_empty() { filter.input = Arc::new(LogicalPlan::Unnest(unnest)); return Ok(Transformed::no(LogicalPlan::Filter(filter))); } - // Unnest is built above Projection, so we only take Projection into consideration - match unwrap_arc(unnest.input) { - LogicalPlan::Projection(projection) => { - let (new_projection, keep_predicate) = - rewrite_projection(filter.predicate.clone(), projection)?; - unnest.input = Arc::new(new_projection.data); - - if new_projection.transformed { - match keep_predicate { - None => Ok(Transformed::yes(LogicalPlan::Unnest(unnest))), - Some(keep_predicate) => Ok(Transformed::yes( - LogicalPlan::Filter(Filter::try_new( - keep_predicate, - Arc::new(LogicalPlan::Unnest(unnest)), - )?), - )), - } - } else { - filter.input = Arc::new(LogicalPlan::Unnest(unnest)); - Ok(Transformed::no(LogicalPlan::Filter(filter))) - } - } - child => { - filter.input = Arc::new(child); - Ok(Transformed::no(LogicalPlan::Filter(filter))) - } + // Push down non-unnest filter predicate + // Unnest + // Unenst Input (Projection) + // -> rewritten to + // Unnest + // Filter + // Unenst Input (Projection) + + let unnest_input = std::mem::take(&mut unnest.input); + + let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new( + conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty. + unnest_input, + )?); + + // Directly assign new filter plan as the new unnest's input. + // The new filter plan will go through another rewrite pass since the rule itself + // is applied recursively to all the child from top to down + let unnest_plan = + insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?; + + match conjunction(unnest_predicates) { + None => Ok(unnest_plan), + Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter( + Filter::try_new(predicate, Arc::new(unnest_plan.data))?, + ))), } } LogicalPlan::Union(ref union) => { @@ -958,6 +972,10 @@ impl OptimizerRule for PushDownFilter { /// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it. /// `remaining_predicate` is any part of the predicate that could not be pushed down /// +/// # Args +/// - predicates: Split predicates like `[foo=5, bar=6]` +/// - projection: The target projection plan to push down the predicates +/// /// # Example /// /// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this: @@ -974,7 +992,7 @@ impl OptimizerRule for PushDownFilter { /// ... /// ``` fn rewrite_projection( - predicate: Expr, + predicates: Vec, projection: Projection, ) -> Result<(Transformed, Option)> { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile @@ -994,7 +1012,7 @@ fn rewrite_projection( let mut push_predicates = vec![]; let mut keep_predicates = vec![]; - for expr in split_conjunction_owned(predicate) { + for expr in predicates { if contain(&expr, &volatile_map) { keep_predicates.push(expr); } else { diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 5029ab170a18..3ca187ddee84 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -67,17 +67,34 @@ select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where 5 2 # Could push the filter (column1 = 2) down below unnest -# https://github.com/apache/datafusion/issues/11016 query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; ---- logical_plan 01)Projection: unnest(v.column2) AS uc2, v.column1 -02)--Filter: unnest(v.column2) > Int64(3) AND v.column1 = Int64(2) +02)--Filter: unnest(v.column2) > Int64(3) 03)----Unnest: lists[unnest(v.column2)] structs[] 04)------Projection: v.column2 AS unnest(v.column2), v.column1 -05)--------TableScan: v projection=[column1, column2] +05)--------Filter: v.column1 = Int64(2) +06)----------TableScan: v projection=[column1, column2] + +query II +select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; +---- +3 2 +4 2 +5 2 +# only non-unnest filter in AND clause could be pushed down +query TT +explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; +---- +logical_plan +01)Projection: unnest(v.column2) AS uc2, v.column1 +02)--Filter: unnest(v.column2) > Int64(3) OR v.column1 = Int64(2) +03)----Unnest: lists[unnest(v.column2)] structs[] +04)------Projection: v.column2 AS unnest(v.column2), v.column1 +05)--------TableScan: v projection=[column1, column2] statement ok drop table v;