diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 273766edac34..080ec074d3c3 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -15,23 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use crate::analyzer::AnalyzerRule; +use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRewriter, -}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::Expr::ScalarSubquery; -use datafusion_expr::{ - aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, - LogicalPlanBuilder, Projection, Sort, Subquery, +use datafusion_expr::expr::{ + AggregateFunction, AggregateFunctionDefinition, WindowFunction, }; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -47,7 +41,8 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal).data() + plan.transform_down_with_subqueries(&analyze_internal) + .data() } fn name(&self) -> &str { @@ -55,173 +50,53 @@ impl AnalyzerRule for CountWildcardRule { } } -fn analyze_internal(plan: LogicalPlan) -> Result> { - let mut rewriter = CountWildcardRewriter {}; - match plan { - LogicalPlan::Window(window) => { - let window_expr = window - .window_expr - .iter() - .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) - .collect::>>()?; - - Ok(Transformed::yes( - LogicalPlanBuilder::from((*window.input).clone()) - .window(window_expr)? - .build()?, - )) - } - LogicalPlan::Aggregate(agg) => { - let aggr_expr = agg - .aggr_expr - .iter() - .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) - .collect::>>()?; - - Ok(Transformed::yes(LogicalPlan::Aggregate( - Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?, - ))) - } - LogicalPlan::Sort(Sort { expr, input, fetch }) => { - let sort_expr = expr - .iter() - .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) - .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Sort(Sort { - expr: sort_expr, - input, - fetch, - }))) - } - LogicalPlan::Projection(projection) => { - let projection_expr = projection - .expr - .iter() - .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) - .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Projection( - Projection::try_new(projection_expr, projection.input)?, - ))) - } - LogicalPlan::Filter(Filter { - predicate, input, .. - }) => { - let predicate = rewrite_preserving_name(predicate, &mut rewriter)?; - Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( - predicate, input, - )?))) - } - - _ => Ok(Transformed::no(plan)), - } +fn is_wildcard(expr: &Expr) -> bool { + matches!(expr, Expr::Wildcard { qualifier: None }) } -struct CountWildcardRewriter {} - -impl TreeNodeRewriter for CountWildcardRewriter { - type Node = Expr; - - fn f_up(&mut self, old_expr: Expr) -> Result> { - Ok(match old_expr.clone() { - Expr::WindowFunction(expr::WindowFunction { - fun: - expr::WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Transformed::yes(Expr::WindowFunction(expr::WindowFunction { - fun: expr::WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args: vec![lit(COUNT_STAR_EXPANSION)], - partition_by, - order_by, - window_frame, - null_treatment, - })) - } +fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { + matches!( + &aggregate_function.func_def, + AggregateFunctionDefinition::BuiltIn( + datafusion_expr::aggregate_function::AggregateFunction::Count, + ) + ) && aggregate_function.args.len() == 1 + && is_wildcard(&aggregate_function.args[0]) +} - _ => Transformed::no(old_expr), - }, - Expr::AggregateFunction(AggregateFunction { - func_def: - AggregateFunctionDefinition::BuiltIn( - aggregate_function::AggregateFunction::Count, - ), - args, - distinct, - filter, - order_by, - null_treatment, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Transformed::yes(Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![lit(COUNT_STAR_EXPANSION)], - distinct, - filter, - order_by, - null_treatment, - ))) - } - _ => Transformed::no(old_expr), - }, +fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { + matches!( + &window_function.fun, + WindowFunctionDefinition::AggregateFunction( + datafusion_expr::aggregate_function::AggregateFunction::Count, + ) + ) && window_function.args.len() == 1 + && is_wildcard(&window_function.args[0]) +} - ScalarSubquery(Subquery { - subquery, - outer_ref_columns, - }) => subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)? - .update_data(|new_plan| { - ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }) - }), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)? - .update_data(|new_plan| { - Expr::InSubquery(InSubquery::new( - expr, - Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - )) - }), - Expr::Exists(expr::Exists { subquery, negated }) => subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)? - .update_data(|new_plan| { - Expr::Exists(expr::Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }) - }), - _ => Transformed::no(old_expr), - }) - } +fn analyze_internal(plan: LogicalPlan) -> Result> { + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + let transformed_expr = expr.transform_up(&|expr| match expr { + Expr::WindowFunction(mut window_function) + if is_count_star_window_aggregate(&window_function) => + { + window_function.args = vec![lit(COUNT_STAR_EXPANSION)]; + Ok(Transformed::yes(Expr::WindowFunction(window_function))) + } + Expr::AggregateFunction(mut aggregate_function) + if is_count_star_aggregate(&aggregate_function) => + { + aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; + Ok(Transformed::yes(Expr::AggregateFunction( + aggregate_function, + ))) + } + _ => Ok(Transformed::no(expr)), + })?; + transformed_expr.map_data(|data| original_name.restore(data)) + }) } #[cfg(test)] @@ -233,9 +108,10 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::{ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, - max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, + max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; + use std::sync::Arc; fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_analyzed_plan_eq_display_indent( @@ -381,6 +257,17 @@ mod tests { assert_plan_eq(&plan, expected) } + #[test] + fn test_count_wildcard_on_non_count_aggregate() -> Result<()> { + let table_scan = test_table_scan()?; + let err = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![sum(wildcard())]) + .unwrap_err() + .to_string(); + assert!(err.contains("Error during planning: No function matches the given name and argument types 'SUM(Null)'."), "{err}"); + Ok(()) + } + #[test] fn test_count_wildcard_on_nesting() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index deb493e09953..4dd3222a32cf 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -64,7 +64,7 @@ impl ApplyFunctionRewrites { let original_name = name_preserver.save(&expr)?; // recursively transform the expression, applying the rewrites at each step - let result = expr.transform_up(&|expr| { + let transformed_expr = expr.transform_up(&|expr| { let mut result = Transformed::no(expr); for rewriter in self.function_rewrites.iter() { result = result.transform_data(|expr| { @@ -74,7 +74,7 @@ impl ApplyFunctionRewrites { Ok(result) })?; - result.map_data(|expr| original_name.restore(expr)) + transformed_expr.map_data(|expr| original_name.restore(expr)) }) } }