From ee2003d4d4720507d3e6db5565e702b24037d016 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 12 Apr 2024 14:53:27 -0400 Subject: [PATCH] Avoid copies in `CountWildcardRule` via TreeNode API --- .../src/analyzer/count_wildcard_rule.rs | 195 ++---------------- .../src/analyzer/function_rewrite.rs | 4 +- 2 files changed, 20 insertions(+), 179 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 273766edac34b..7b641517c0b3a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -19,19 +19,12 @@ use std::sync::Arc; use crate::analyzer::AnalyzerRule; +use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRewriter, -}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::Expr::ScalarSubquery; -use datafusion_expr::{ - aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, - LogicalPlanBuilder, Projection, Sort, Subquery, -}; +use datafusion_expr::{lit, Expr, LogicalPlan}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -47,7 +40,8 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal).data() + plan.transform_down_with_subqueries(&analyze_internal) + .data() } fn name(&self) -> &str { @@ -56,172 +50,19 @@ impl AnalyzerRule for CountWildcardRule { } fn analyze_internal(plan: LogicalPlan) -> Result> { - let mut rewriter = CountWildcardRewriter {}; - match plan { - LogicalPlan::Window(window) => { - let window_expr = window - .window_expr - .iter() - .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) - .collect::>>()?; - - Ok(Transformed::yes( - LogicalPlanBuilder::from((*window.input).clone()) - .window(window_expr)? - .build()?, - )) - } - LogicalPlan::Aggregate(agg) => { - let aggr_expr = agg - .aggr_expr - .iter() - .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) - .collect::>>()?; - - Ok(Transformed::yes(LogicalPlan::Aggregate( - Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?, - ))) - } - LogicalPlan::Sort(Sort { expr, input, fetch }) => { - let sort_expr = expr - .iter() - .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) - .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Sort(Sort { - expr: sort_expr, - input, - fetch, - }))) - } - LogicalPlan::Projection(projection) => { - let projection_expr = projection - .expr - .iter() - .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) - .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Projection( - Projection::try_new(projection_expr, projection.input)?, - ))) - } - LogicalPlan::Filter(Filter { - predicate, input, .. - }) => { - let predicate = rewrite_preserving_name(predicate, &mut rewriter)?; - Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( - predicate, input, - )?))) - } - - _ => Ok(Transformed::no(plan)), - } -} - -struct CountWildcardRewriter {} - -impl TreeNodeRewriter for CountWildcardRewriter { - type Node = Expr; - - fn f_up(&mut self, old_expr: Expr) -> Result> { - Ok(match old_expr.clone() { - Expr::WindowFunction(expr::WindowFunction { - fun: - expr::WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Transformed::yes(Expr::WindowFunction(expr::WindowFunction { - fun: expr::WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args: vec![lit(COUNT_STAR_EXPANSION)], - partition_by, - order_by, - window_frame, - null_treatment, - })) - } - - _ => Transformed::no(old_expr), - }, - Expr::AggregateFunction(AggregateFunction { - func_def: - AggregateFunctionDefinition::BuiltIn( - aggregate_function::AggregateFunction::Count, - ), - args, - distinct, - filter, - order_by, - null_treatment, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Transformed::yes(Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![lit(COUNT_STAR_EXPANSION)], - distinct, - filter, - order_by, - null_treatment, - ))) - } - _ => Transformed::no(old_expr), - }, - - ScalarSubquery(Subquery { - subquery, - outer_ref_columns, - }) => subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)? - .update_data(|new_plan| { - ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }) - }), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)? - .update_data(|new_plan| { - Expr::InSubquery(InSubquery::new( - expr, - Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - )) - }), - Expr::Exists(expr::Exists { subquery, negated }) => subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)? - .update_data(|new_plan| { - Expr::Exists(expr::Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }) - }), - _ => Transformed::no(old_expr), - }) - } + //let mut rewriter = CountWildcardRewriter {}; + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + let transformed_expr = expr.transform_up(&|expr| { + if let Expr::Wildcard { qualifier: None } = expr { + Ok(Transformed::yes(lit(COUNT_STAR_EXPANSION))) + } else { + Ok(Transformed::no(expr)) + } + })?; + transformed_expr.map_data(|data| original_name.restore(data)) + }) } #[cfg(test)] diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index deb493e09953c..4dd3222a32cfe 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -64,7 +64,7 @@ impl ApplyFunctionRewrites { let original_name = name_preserver.save(&expr)?; // recursively transform the expression, applying the rewrites at each step - let result = expr.transform_up(&|expr| { + let transformed_expr = expr.transform_up(&|expr| { let mut result = Transformed::no(expr); for rewriter in self.function_rewrites.iter() { result = result.transform_data(|expr| { @@ -74,7 +74,7 @@ impl ApplyFunctionRewrites { Ok(result) })?; - result.map_data(|expr| original_name.restore(expr)) + transformed_expr.map_data(|expr| original_name.restore(expr)) }) } }