From 3abd1c2303660e65a4dd5d36ea5c0b97f43aa246 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 28 Nov 2024 09:22:52 +0100 Subject: [PATCH 1/2] Fix `LogicalPlan::transform_..._with_subqueries` methods --- datafusion/expr/src/logical_plan/plan.rs | 120 +++++++++++++++++- datafusion/expr/src/logical_plan/tree_node.rs | 38 +++--- 2 files changed, 139 insertions(+), 19 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e9f4f1f80972..f963685bdfbe 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3496,7 +3496,9 @@ mod tests { use crate::logical_plan::table_scan; use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet}; - use datafusion_common::tree_node::{TransformedResult, TreeNodeVisitor}; + use datafusion_common::tree_node::{ + TransformedResult, TreeNodeRewriter, TreeNodeVisitor, + }; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; use crate::test::function_stub::count; @@ -4157,4 +4159,120 @@ digraph { .unwrap(); assert_eq!(limit, new_limit); } + + #[test] + fn test_with_subqueries_jump() { + // The plan contains a `Project` node above a `Filter` node so returning + // `TreeNodeRecursion::Jump` on `Project` should cause not visiting `Filter`. + let plan = test_plan(); + + let mut filter_found = false; + plan.apply_with_subqueries(|plan| { + match plan { + LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump), + LogicalPlan::Filter(..) => filter_found = true, + _ => {} + } + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + assert!(!filter_found); + + struct ProjectJumpVisitor { + filter_found: bool, + } + + impl ProjectJumpVisitor { + fn new() -> Self { + Self { + filter_found: false, + } + } + } + + impl<'n> TreeNodeVisitor<'n> for ProjectJumpVisitor { + type Node = LogicalPlan; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + match node { + LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump), + LogicalPlan::Filter(..) => self.filter_found = true, + _ => {} + } + Ok(TreeNodeRecursion::Continue) + } + } + + let mut visitor = ProjectJumpVisitor::new(); + plan.visit_with_subqueries(&mut visitor).unwrap(); + assert!(!visitor.filter_found); + + let mut filter_found = false; + plan.clone() + .transform_down_with_subqueries(|plan| { + match plan { + LogicalPlan::Projection(..) => { + return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) + } + LogicalPlan::Filter(..) => filter_found = true, + _ => {} + } + Ok(Transformed::no(plan)) + }) + .unwrap(); + assert!(!filter_found); + + let mut filter_found = false; + plan.clone() + .transform_down_up_with_subqueries( + |plan| { + match plan { + LogicalPlan::Projection(..) => { + return Ok(Transformed::new( + plan, + false, + TreeNodeRecursion::Jump, + )) + } + LogicalPlan::Filter(..) => filter_found = true, + _ => {} + } + Ok(Transformed::no(plan)) + }, + |plan| Ok(Transformed::no(plan)), + ) + .unwrap(); + assert!(!filter_found); + + struct ProjectJumpRewriter { + filter_found: bool, + } + + impl ProjectJumpRewriter { + fn new() -> Self { + Self { + filter_found: false, + } + } + } + + impl TreeNodeRewriter for ProjectJumpRewriter { + type Node = LogicalPlan; + + fn f_down(&mut self, node: Self::Node) -> Result> { + match node { + LogicalPlan::Projection(..) => { + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)) + } + LogicalPlan::Filter(..) => self.filter_found = true, + _ => {} + } + Ok(Transformed::no(node)) + } + } + + let mut rewriter = ProjectJumpRewriter::new(); + plan.rewrite_with_subqueries(&mut rewriter).unwrap(); + assert!(!rewriter.filter_found); + } } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 6850c30f4f81..1539b69b4007 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -385,8 +385,10 @@ fn rewrite_extension_inputs Result {{ $F_DOWN? - .transform_children(|n| n.map_subqueries($F_CHILD))? - .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_children(|n| { + n.map_subqueries($F_CHILD)? + .transform_sibling(|n| n.map_children($F_CHILD)) + })? .transform_parent($F_UP) }}; } @@ -675,9 +677,11 @@ impl LogicalPlan { visitor .f_down(self)? .visit_children(|| { - self.apply_subqueries(|c| c.visit_with_subqueries(visitor)) + self.apply_subqueries(|c| c.visit_with_subqueries(visitor))? + .visit_sibling(|| { + self.apply_children(|c| c.visit_with_subqueries(visitor)) + }) })? - .visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))? .visit_parent(|| visitor.f_up(self)) } @@ -710,13 +714,12 @@ impl LogicalPlan { node: &LogicalPlan, f: &mut F, ) -> Result { - f(node)? - .visit_children(|| { - node.apply_subqueries(|c| apply_with_subqueries_impl(c, f)) - })? - .visit_sibling(|| { - node.apply_children(|c| apply_with_subqueries_impl(c, f)) - }) + f(node)?.visit_children(|| { + node.apply_subqueries(|c| apply_with_subqueries_impl(c, f))? + .visit_sibling(|| { + node.apply_children(|c| apply_with_subqueries_impl(c, f)) + }) + }) } apply_with_subqueries_impl(self, &mut f) @@ -746,13 +749,12 @@ impl LogicalPlan { node: LogicalPlan, f: &mut F, ) -> Result> { - f(node)? - .transform_children(|n| { - n.map_subqueries(|c| transform_down_with_subqueries_impl(c, f)) - })? - .transform_sibling(|n| { - n.map_children(|c| transform_down_with_subqueries_impl(c, f)) - }) + f(node)?.transform_children(|n| { + n.map_subqueries(|c| transform_down_with_subqueries_impl(c, f))? + .transform_sibling(|n| { + n.map_children(|c| transform_down_with_subqueries_impl(c, f)) + }) + }) } transform_down_with_subqueries_impl(self, &mut f) From 542c104a9e665422c601b26357f338ab546a8483 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 29 Nov 2024 16:08:03 +0100 Subject: [PATCH 2/2] add subquery tests --- datafusion/expr/src/logical_plan/plan.rs | 32 +++++++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index f963685bdfbe..7c7ad3ac2845 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3494,7 +3494,9 @@ mod tests { use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; - use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet}; + use crate::{ + col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet, + }; use datafusion_common::tree_node::{ TransformedResult, TreeNodeRewriter, TreeNodeVisitor, @@ -4162,9 +4164,31 @@ digraph { #[test] fn test_with_subqueries_jump() { - // The plan contains a `Project` node above a `Filter` node so returning - // `TreeNodeRecursion::Jump` on `Project` should cause not visiting `Filter`. - let plan = test_plan(); + // The test plan contains a `Project` node above a `Filter` node, and the + // `Project` node contains a subquery plan with a `Filter` root node, so returning + // `TreeNodeRecursion::Jump` on `Project` should cause not visiting any of the + // `Filter`s. + let subquery_schema = + Schema::new(vec![Field::new("sub_id", DataType::Int32, false)]); + + let subquery_plan = + table_scan(TableReference::none(), &subquery_schema, Some(vec![0])) + .unwrap() + .filter(col("sub_id").eq(lit(0))) + .unwrap() + .build() + .unwrap(); + + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + let plan = table_scan(TableReference::none(), &schema, Some(vec![0])) + .unwrap() + .filter(col("id").eq(lit(0))) + .unwrap() + .project(vec![col("id"), scalar_subquery(Arc::new(subquery_plan))]) + .unwrap() + .build() + .unwrap(); let mut filter_found = false; plan.apply_with_subqueries(|plan| {