Skip to content

Commit

Permalink
Avoid copies in CountWildcardRule via TreeNode API
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Apr 12, 2024
1 parent 952c98e commit ee2003d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 179 deletions.
195 changes: 18 additions & 177 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,12 @@ use std::sync::Arc;

use crate::analyzer::AnalyzerRule;

use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::Expr::ScalarSubquery;
use datafusion_expr::{
aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan,
LogicalPlanBuilder, Projection, Sort, Subquery,
};
use datafusion_expr::{lit, Expr, LogicalPlan};

/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
Expand All @@ -47,7 +40,8 @@ impl CountWildcardRule {

impl AnalyzerRule for CountWildcardRule {
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
plan.transform_down(&analyze_internal).data()
plan.transform_down_with_subqueries(&analyze_internal)
.data()
}

fn name(&self) -> &str {
Expand All @@ -56,172 +50,19 @@ impl AnalyzerRule for CountWildcardRule {
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let mut rewriter = CountWildcardRewriter {};
match plan {
LogicalPlan::Window(window) => {
let window_expr = window
.window_expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;

Ok(Transformed::yes(
LogicalPlanBuilder::from((*window.input).clone())
.window(window_expr)?
.build()?,
))
}
LogicalPlan::Aggregate(agg) => {
let aggr_expr = agg
.aggr_expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;

Ok(Transformed::yes(LogicalPlan::Aggregate(
Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?,
)))
}
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
let sort_expr = expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::yes(LogicalPlan::Sort(Sort {
expr: sort_expr,
input,
fetch,
})))
}
LogicalPlan::Projection(projection) => {
let projection_expr = projection
.expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::yes(LogicalPlan::Projection(
Projection::try_new(projection_expr, projection.input)?,
)))
}
LogicalPlan::Filter(Filter {
predicate, input, ..
}) => {
let predicate = rewrite_preserving_name(predicate, &mut rewriter)?;
Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
predicate, input,
)?)))
}

_ => Ok(Transformed::no(plan)),
}
}

struct CountWildcardRewriter {}

impl TreeNodeRewriter for CountWildcardRewriter {
type Node = Expr;

fn f_up(&mut self, old_expr: Expr) -> Result<Transformed<Expr>> {
Ok(match old_expr.clone() {
Expr::WindowFunction(expr::WindowFunction {
fun:
expr::WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args,
partition_by,
order_by,
window_frame,
null_treatment,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Transformed::yes(Expr::WindowFunction(expr::WindowFunction {
fun: expr::WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args: vec![lit(COUNT_STAR_EXPANSION)],
partition_by,
order_by,
window_frame,
null_treatment,
}))
}

_ => Transformed::no(old_expr),
},
Expr::AggregateFunction(AggregateFunction {
func_def:
AggregateFunctionDefinition::BuiltIn(
aggregate_function::AggregateFunction::Count,
),
args,
distinct,
filter,
order_by,
null_treatment,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Transformed::yes(Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Count,
vec![lit(COUNT_STAR_EXPANSION)],
distinct,
filter,
order_by,
null_treatment,
)))
}
_ => Transformed::no(old_expr),
},

ScalarSubquery(Subquery {
subquery,
outer_ref_columns,
}) => subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)?
.update_data(|new_plan| {
ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
})
}),
Expr::InSubquery(InSubquery {
expr,
subquery,
negated,
}) => subquery
.subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)?
.update_data(|new_plan| {
Expr::InSubquery(InSubquery::new(
expr,
Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
},
negated,
))
}),
Expr::Exists(expr::Exists { subquery, negated }) => subquery
.subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)?
.update_data(|new_plan| {
Expr::Exists(expr::Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
},
negated,
})
}),
_ => Transformed::no(old_expr),
})
}
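// Rewrite every expression of this plan node: each expression's name is
// saved via `NamePreserver` before the rewrite and restored afterwards, and
// any unqualified `Expr::Wildcard` found while walking the expression
// bottom-up is replaced with `lit(COUNT_STAR_EXPANSION)`.
//
// NOTE(review): the closure below matches *every* unqualified wildcard, not
// only one that is the sole argument of a `Count` aggregate or window
// function -- confirm that no other wildcard can reach this rule.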
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
let transformed_expr = expr.transform_up(&|expr| {
if let Expr::Wildcard { qualifier: None } = expr {
Ok(Transformed::yes(lit(COUNT_STAR_EXPANSION)))
} else {
Ok(Transformed::no(expr))
}
})?;
transformed_expr.map_data(|data| original_name.restore(data))
})
}

#[cfg(test)]
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/analyzer/function_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl ApplyFunctionRewrites {
let original_name = name_preserver.save(&expr)?;

// recursively transform the expression, applying the rewrites at each step
let result = expr.transform_up(&|expr| {
let transformed_expr = expr.transform_up(&|expr| {
let mut result = Transformed::no(expr);
for rewriter in self.function_rewrites.iter() {
result = result.transform_data(|expr| {
Expand All @@ -74,7 +74,7 @@ impl ApplyFunctionRewrites {
Ok(result)
})?;

result.map_data(|expr| original_name.restore(expr))
transformed_expr.map_data(|expr| original_name.restore(expr))
})
}
}
Expand Down

0 comments on commit ee2003d

Please sign in to comment.