diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 8e088e7a0b56..42514537e28d 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -25,10 +25,9 @@ use crate::Result; /// These macros are used to determine continuation during transforming traversals. macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ - #[allow(clippy::redundant_closure_call)] $F_DOWN? .transform_children(|n| n.map_children($F_CHILD))? - .transform_parent(|n| $F_UP(n)) + .transform_parent($F_UP) }}; } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index f15c1c218db6..9e48c7b8a6f2 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -67,7 +67,7 @@ use datafusion_common::{ alias::AliasGenerator, config::{ConfigExtension, TableOptions}, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, + tree_node::{TreeNodeRecursion, TreeNodeVisitor}, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; @@ -2298,7 +2298,7 @@ impl SQLOptions { /// Return an error if the [`LogicalPlan`] has any nodes that are /// incompatible with this [`SQLOptions`]. pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> { - plan.visit(&mut BadPlanVisitor::new(self))?; + plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?; Ok(()) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 3d40dcae0e4b..4f55bbfe3f4d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -34,8 +34,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, - split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, @@ -45,16 +44,19 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TransformedResult, TreeNode, TreeNodeIterator, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ - aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, + aggregate_functional_dependencies, internal_err, map_until_stop_and_collect, + plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, + FunctionalDependence, FunctionalDependencies, ParamValues, Result, TableReference, + UnnestOptions, }; // backwards compatibility use crate::display::PgJsonVisitor; +use crate::tree_node::expr::transform_option_vec; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -248,9 +250,9 @@ impl LogicalPlan { /// DataFusion's optimizer attempts to optimize them away. pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(|e| { exprs.push(e.clone()); - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -261,13 +263,13 @@ impl LogicalPlan { /// logical plan nodes and all its descendant nodes. pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(|e| { find_out_reference_exprs(e).into_iter().for_each(|e| { if !exprs.contains(&e) { exprs.push(e) } }); - Ok(()) as Result<(), DataFusionError> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -282,60 +284,81 @@ impl LogicalPlan { exprs } - /// Calls `f` on all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children. + #[deprecated(since = "37.0.0", note = "Use `apply_expressions` instead")] pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, { + let mut err = Ok(()); + self.apply_expressions(|e| { + if let Err(e) = f(e) { + // save the error for later (it may not be a DataFusionError + err = Err(e); + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + // The closure always returns OK, so this will always too + .expect("no way to return error during recursion"); + + err + } + + /// Calls `f` on all expressions (non-recursively) in the current + /// logical plan node. This does not include expressions in any + /// children. + pub fn apply_expressions Result>( + &self, + mut f: F, + ) -> Result { match self { LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().try_for_each(f) - } - LogicalPlan::Values(Values { values, .. }) => { - values.iter().flatten().try_for_each(f) + expr.iter().apply_until_stop(f) } + LogicalPlan::Values(Values { values, .. }) => values + .iter() + .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { - Partitioning::Hash(expr, _) => expr.iter().try_for_each(f), - Partitioning::DistributeBy(expr) => expr.iter().try_for_each(f), - Partitioning::RoundRobinBatch(_) => Ok(()), + Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { + expr.iter().apply_until_stop(f) + } + Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().try_for_each(f) + window_expr.iter().apply_until_stop(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. - }) => group_expr.iter().chain(aggr_expr.iter()).try_for_each(f), + }) => group_expr + .iter() + .chain(aggr_expr.iter()) + .apply_until_stop(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). LogicalPlan::Join(Join { on, filter, .. }) => { on.iter() + // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... // it not ideal to create an expr here to analyze them, but could cache it on the Join itself .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .try_for_each(|e| f(&e))?; - - if let Some(filter) = filter.as_ref() { - f(filter) - } else { - Ok(()) - } + .apply_until_stop(|e| f(&e))? + .visit_sibling(|| filter.iter().apply_until_stop(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().try_for_each(f), + LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().try_for_each(f) + extension.node.expressions().iter().apply_until_stop(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().try_for_each(f) + filters.iter().apply_until_stop(f) } LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) @@ -348,8 +371,8 @@ impl LogicalPlan { })) => on_expr .iter() .chain(select_expr.iter()) - .chain(sort_expr.clone().unwrap_or(vec![]).iter()) - .try_for_each(f), + .chain(sort_expr.iter().flatten()) + .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) @@ -366,10 +389,225 @@ impl LogicalPlan { | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Ok(()), + | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), } } + pub fn map_expressions Result>>( + self, + mut f: F, + ) -> Result> { + Ok(match self { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), + LogicalPlan::Values(Values { schema, values }) => values + .into_iter() + .map_until_stop_and_collect(|value| { + value.into_iter().map_until_stop_and_collect(&mut f) + })? + .update_data(|values| LogicalPlan::Values(Values { schema, values })), + LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? + .update_data(|predicate| { + LogicalPlan::Filter(Filter { predicate, input }) + }), + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) => match partitioning_scheme { + Partitioning::Hash(expr, usize) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| Partitioning::Hash(expr, usize)), + Partitioning::DistributeBy(expr) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(Partitioning::DistributeBy), + Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), + } + .update_data(|partitioning_scheme| { + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) + }), + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) => window_expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|window_expr| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) => map_until_stop_and_collect!( + group_expr.into_iter().map_until_stop_and_collect(&mut f), + aggr_expr, + aggr_expr.into_iter().map_until_stop_and_collect(&mut f) + )? + .update_data(|(group_expr, aggr_expr)| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }), + + // There are two part of expression for join, equijoin(on) and non-equijoin(filter). + // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. + // 2. the second part is non-equijoin(filter). + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) => map_until_stop_and_collect!( + on.into_iter().map_until_stop_and_collect( + |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) + ), + filter, + filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }) + )? + .update_data(|(on, filter)| { + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) + }), + LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Extension(Extension { node }) => { + // would be nice to avoid this copy -- maybe can + // update extension to just observer Exprs + node.expressions() + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|exprs| { + LogicalPlan::Extension(Extension { + node: UserDefinedLogicalNode::from_template( + node.as_ref(), + exprs.as_slice(), + node.inputs() + .into_iter() + .cloned() + .collect::>() + .as_slice(), + ), + }) + }) + } + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) => filters + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|filters| { + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) + }), + LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + }) => f(Expr::Column(column))?.map_data(|column| match column { + Expr::Column(column) => Ok(LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + })), + _ => internal_err!("Transformation should return Column"), + })?, + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) => map_until_stop_and_collect!( + on_expr.into_iter().map_until_stop_and_collect(&mut f), + select_expr, + select_expr.into_iter().map_until_stop_and_collect(&mut f), + sort_expr, + transform_option_vec(sort_expr, &mut f) + )? + .update_data(|(on_expr, select_expr, sort_expr)| { + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) + }), + // plans without expressions + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Union(_) + | LogicalPlan::Distinct(Distinct::All(_)) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Prepare(_) => Transformed::no(self), + }) + } + /// returns all inputs of this `LogicalPlan` node. Does not /// include inputs to inputs, or subqueries. pub fn inputs(&self) -> Vec<&LogicalPlan> { @@ -417,7 +655,7 @@ impl LogicalPlan { pub fn using_columns(&self) -> Result>, DataFusionError> { let mut using_columns: Vec> = vec![]; - self.apply(&mut |plan| { + self.apply_with_subqueries(&mut |plan| { if let LogicalPlan::Join(Join { join_constraint: JoinConstraint::Using, on, @@ -1079,57 +1317,178 @@ impl LogicalPlan { } } +/// This macro is used to determine continuation during combined transforming +/// traversals. +macro_rules! handle_transform_recursion { + ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_parent($F_UP) + }}; +} + +macro_rules! handle_transform_recursion_down { + ($F_DOWN:expr, $F_CHILD:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD)) + }}; +} + +macro_rules! handle_transform_recursion_up { + ($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $SELF + .map_subqueries($F_CHILD)? + .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_parent(|n| $F_UP(n)) + }}; +} + impl LogicalPlan { - /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> Result<()> - where - F: FnMut(&Self) -> Result, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.apply(op)?; - } - _ => {} + pub fn visit_with_subqueries>( + &self, + visitor: &mut V, + ) -> Result { + visitor + .f_down(self)? + .visit_children(|| { + self.apply_subqueries(|c| c.visit_with_subqueries(visitor)) + })? + .visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))? + .visit_parent(|| visitor.f_up(self)) + } + + pub fn rewrite_with_subqueries>( + self, + rewriter: &mut R, + ) -> Result> { + handle_transform_recursion!( + rewriter.f_down(self), + |c| c.rewrite_with_subqueries(rewriter), + |n| rewriter.f_up(n) + ) + } + + pub fn apply_with_subqueries Result>( + &self, + f: &mut F, + ) -> Result { + f(self)? + .visit_children(|| self.apply_subqueries(|c| c.apply_with_subqueries(f)))? + .visit_sibling(|| self.apply_children(|c| c.apply_with_subqueries(f))) + } + + pub fn transform_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + self.transform_up_with_subqueries(f) + } + + pub fn transform_down_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + handle_transform_recursion_down!(f(self), |c| c.transform_down_with_subqueries(f)) + } + + pub fn transform_down_mut_with_subqueries< + F: FnMut(Self) -> Result>, + >( + self, + f: &mut F, + ) -> Result> { + handle_transform_recursion_down!(f(self), |c| c + .transform_down_mut_with_subqueries(f)) + } + + pub fn transform_up_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + handle_transform_recursion_up!(self, |c| c.transform_up_with_subqueries(f), f) + } + + pub fn transform_up_mut_with_subqueries< + F: FnMut(Self) -> Result>, + >( + self, + f: &mut F, + ) -> Result> { + handle_transform_recursion_up!(self, |c| c.transform_up_mut_with_subqueries(f), f) + } + + pub fn transform_down_up_with_subqueries< + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + >( + self, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> { + handle_transform_recursion!( + f_down(self), + |c| c.transform_down_up_with_subqueries(f_down, f_up), + f_up + ) + } + + fn apply_subqueries Result>( + &self, + mut f: F, + ) -> Result { + self.apply_expressions(|expr| { + expr.apply(&mut |expr| match expr { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + f(&LogicalPlan::Subquery(subquery.clone())) } - Ok::<(), DataFusionError>(()) + _ => Ok(TreeNodeRecursion::Continue), }) - })?; - Ok(()) + }) } - /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> Result<()> - where - V: TreeNodeVisitor, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the visitor sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.visit(v)?; - } - _ => {} + fn map_subqueries Result>>( + self, + mut f: F, + ) -> Result> { + self.map_expressions(|expr| { + expr.transform_down_mut(&mut |expr| match expr { + Expr::Exists(Exists { subquery, negated }) => { + f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::Exists(Exists { subquery, negated })) + } + _ => internal_err!("Transformation should return Subquery"), + }) } - Ok::<(), DataFusionError>(()) + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => Ok(Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + })), + _ => internal_err!("Transformation should return Subquery"), + }), + Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? + .map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::ScalarSubquery(subquery)) + } + _ => internal_err!("Transformation should return Subquery"), + }), + _ => Ok(Transformed::no(expr)), }) - })?; - Ok(()) + }) } /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, @@ -1165,8 +1524,8 @@ impl LogicalPlan { ) -> Result>, DataFusionError> { let mut param_types: HashMap> = HashMap::new(); - self.apply(&mut |plan| { - plan.inspect_expressions(|expr| { + self.apply_with_subqueries(&mut |plan| { + plan.apply_expressions(|expr| { expr.apply(&mut |expr| { if let Expr::Placeholder(Placeholder { id, data_type }) = expr { let prev = param_types.get(id); @@ -1183,13 +1542,10 @@ impl LogicalPlan { } } Ok(TreeNodeRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(param_types) + }) + }) + }) + .map(|_| param_types) } /// Return an Expr with all placeholders replaced with their @@ -1257,7 +1613,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = false; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1300,7 +1656,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = true; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1320,7 +1676,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = PgJsonVisitor::new(f); visitor.with_schema(true); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1369,12 +1725,16 @@ impl LogicalPlan { visitor.start_graph()?; visitor.pre_visit_plan("LogicalPlan")?; - self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; + self.0 + .visit_with_subqueries(&mut visitor) + .map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; visitor.set_with_schema(true); visitor.pre_visit_plan("Detailed LogicalPlan")?; - self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; + self.0 + .visit_with_subqueries(&mut visitor) + .map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; visitor.end_graph()?; @@ -2908,7 +3268,7 @@ digraph { fn visit_order() { let mut visitor = OkVisitor::default(); let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -2984,7 +3344,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -3000,7 +3360,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -3051,7 +3411,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor).unwrap_err(); + let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); assert_eq!( "This feature is not implemented: Error in pre_visit", res.strip_backtrace() @@ -3069,7 +3429,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor).unwrap_err(); + let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); assert_eq!( "This feature is not implemented: Error in post_visit", res.strip_backtrace() diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 97331720ce7d..85097f6249e1 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -412,7 +412,7 @@ where } /// &mut transform a Option<`Vec` of `Expr`s> -fn transform_option_vec( +pub fn transform_option_vec( ove: Option>, f: &mut F, ) -> Result>>> diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 7a6b1005fede..482fc96b519b 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -20,58 +20,11 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; use datafusion_common::Result; impl TreeNode for LogicalPlan { - fn apply Result>( - &self, - f: &mut F, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::apply_subqueries`] before visiting its children - f(self)?.visit_children(|| { - self.apply_subqueries(f)?; - self.apply_children(|n| n.apply(f)) - }) - } - - /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke - /// [`LogicalPlan::visit`]. - /// - /// For example, for a logical plan like: - /// - /// ```text - /// Projection: id - /// Filter: state Eq Utf8(\"CO\")\ - /// CsvScan: employee.csv projection=Some([0, 3])"; - /// ``` - /// - /// The sequence of visit operations would be: - /// ```text - /// visitor.pre_visit(Projection) - /// visitor.pre_visit(Filter) - /// visitor.pre_visit(CsvScan) - /// visitor.post_visit(CsvScan) - /// visitor.post_visit(Filter) - /// visitor.post_visit(Projection) - /// ``` - fn visit>( - &self, - visitor: &mut V, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::visit_subqueries`] before visiting its children - visitor - .f_down(self)? - .visit_children(|| { - self.visit_subqueries(visitor)?; - self.apply_children(|n| n.visit(visitor)) - })? - .visit_parent(|| visitor.f_up(self)) - } - fn apply_children Result>( &self, f: F, @@ -85,8 +38,8 @@ impl TreeNode for LogicalPlan { ) -> Result> { let new_children = self .inputs() - .iter() - .map(|&c| c.clone()) + .into_iter() + .cloned() .map_until_stop_and_collect(f)?; // Propagate up `new_children.transformed` and `new_children.tnr` // along with the node containing transformed children. diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index b446fe2f320e..d0b83d24299b 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -155,8 +155,8 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.apply(&mut |plan: &LogicalPlan| { - plan.inspect_expressions(|expr| { + plan.apply_with_subqueries(&mut |plan: &LogicalPlan| { + plan.apply_expressions(|expr| { // recursively look for subqueries expr.apply(&mut |expr| { match expr { @@ -168,11 +168,8 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { _ => {} }; Ok(TreeNodeRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(()) + }) + }) + }) + .map(|_| ()) } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 038361c3ee8c..79375e52da1f 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -283,7 +283,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { let mut exprs = vec![]; - inner_plan.apply(&mut |plan| { + inner_plan.apply_with_subqueries(&mut |plan| { if let LogicalPlan::Filter(Filter { predicate, .. }) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 4143d52a053e..a8e323ff429f 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -21,7 +21,7 @@ use std::{ num::NonZeroUsize, }; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_expr::LogicalPlan; /// Non-unique identifier of a [`LogicalPlan`]. @@ -73,7 +73,7 @@ impl LogicalPlanSignature { /// Get total number of [`LogicalPlan`]s in the plan. fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; - plan.apply(&mut |_plan| { + plan.apply_with_subqueries(&mut |_plan| { node_number += 1; Ok(TreeNodeRecursion::Continue) })