From b978cf8236436038a106ed94fb0d7eaa6ba99962 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Tue, 22 Oct 2024 04:57:04 +0200 Subject: [PATCH] Support filter in cross join elimination (#13025) * Support filter in cross join elimination * Support filter in cross join elimination * Support filter in cross join elimination * Support filter in cross join elimination --- .../optimizer/src/eliminate_cross_join.rs | 61 +++++++++++-------- datafusion/sqllogictest/test_files/join.slt | 2 +- 2 files changed, 38 insertions(+), 25 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index bce5c77ca674..8a365fb389be 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -22,13 +22,13 @@ use crate::{OptimizerConfig, OptimizerRule}; use crate::join_key_set::JoinKeySet; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; -use datafusion_expr::{build_join_schema, ExprSchemable, Operator}; +use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator}; #[derive(Default, Debug)] pub struct EliminateCrossJoin; @@ -88,6 +88,7 @@ impl OptimizerRule for EliminateCrossJoin { let plan_schema = Arc::clone(plan.schema()); let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; + let mut all_filters: Vec = vec![]; let parent_predicate = if let LogicalPlan::Filter(filter) = plan { // if input isn't a join that can potentially be rewritten @@ -116,6 +117,7 @@ impl OptimizerRule for EliminateCrossJoin { Arc::unwrap_or_clone(input), &mut possible_join_keys, &mut all_inputs, + &mut all_filters, )?; extract_possible_join_keys(&predicate, &mut possible_join_keys); @@ -130,7 +132,12 @@ impl OptimizerRule for EliminateCrossJoin { if !can_flatten_join_inputs(&plan) { return Ok(Transformed::no(plan)); } - flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?; + flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + &mut all_filters, + )?; None } else { // recursively try to rewrite children @@ -158,6 +165,13 @@ impl OptimizerRule for EliminateCrossJoin { )); } + if !all_filters.is_empty() { + // Add any filters on top - PushDownFilter can push filters down to applicable join + let first = all_filters.swap_remove(0); + let predicate = all_filters.into_iter().fold(first, and); + left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?); + } + let Some(predicate) = parent_predicate else { return Ok(Transformed::yes(left)); }; @@ -206,25 +220,25 @@ fn flatten_join_inputs( plan: LogicalPlan, possible_join_keys: &mut JoinKeySet, all_inputs: &mut Vec, + all_filters: &mut Vec, ) -> Result<()> { match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { - // checked in can_flatten_join_inputs - if join.filter.is_some() { - return internal_err!( - "should not have filter in inner join in flatten_join_inputs" - ); + if let Some(filter) = join.filter { + all_filters.push(filter); } possible_join_keys.insert_all_owned(join.on); flatten_join_inputs( Arc::unwrap_or_clone(join.left), possible_join_keys, all_inputs, + all_filters, )?; flatten_join_inputs( Arc::unwrap_or_clone(join.right), possible_join_keys, all_inputs, + all_filters, )?; } LogicalPlan::CrossJoin(join) => { @@ -232,11 +246,13 @@ fn flatten_join_inputs( Arc::unwrap_or_clone(join.left), possible_join_keys, all_inputs, + all_filters, )?; flatten_join_inputs( Arc::unwrap_or_clone(join.right), possible_join_keys, all_inputs, + all_filters, )?; } _ => { @@ -253,13 +269,7 @@ fn flatten_join_inputs( fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { // can only flatten inner / cross joins match plan { - LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { - // The filter of inner join will lost, skip this rule. - // issue: https://github.com/apache/datafusion/issues/4844 - if join.filter.is_some() { - return false; - } - } + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {} LogicalPlan::CrossJoin(_) => {} _ => return false, }; @@ -467,12 +477,6 @@ mod tests { assert_eq!(&starting_schema, optimized_plan.schema()) } - fn assert_optimization_rule_fails(plan: LogicalPlan) { - let rule = EliminateCrossJoin::new(); - let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); - assert!(!transformed_plan.transformed) - } - #[test] fn eliminate_cross_with_simple_and() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -642,8 +646,7 @@ mod tests { } #[test] - /// See https://github.com/apache/datafusion/issues/7530 - fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> Result<()> { + fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; let t3 = test_table_scan_with_name("t3")?; @@ -660,7 +663,17 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - assert_optimization_rule_fails(plan); + let expected = vec![ + "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" + ]; + + assert_optimized_plan_eq(plan, expected); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index fe9ceaa7907a..39f903a58714 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -1152,7 +1152,7 @@ logical_plan 01)Projection: t1.v0, t1.v1, t5.v2, t5.v3, t5.v4, t0.v0, t0.v1 02)--Inner Join: CAST(t1.v0 AS Float64) = t0.v1 Filter: t0.v1 + CAST(t5.v0 AS Float64) > Float64(0) 03)----Projection: t1.v0, t1.v1, t5.v0, t5.v2, t5.v3, t5.v4 -04)------Inner Join: Using t1.v0 = t5.v0, t1.v1 = t5.v1 +04)------Inner Join: t1.v0 = t5.v0, t1.v1 = t5.v1 05)--------TableScan: t1 projection=[v0, v1] 06)--------TableScan: t5 projection=[v0, v1, v2, v3, v4] 07)----TableScan: t0 projection=[v0, v1]