Skip to content

Commit

Permalink
Avoid copies in CountWildcardRule via TreeNode API
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Apr 12, 2024
1 parent 952c98e commit 692fc11
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 181 deletions.
211 changes: 32 additions & 179 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::analyzer::AnalyzerRule;

use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::Expr::ScalarSubquery;
use datafusion_expr::{
aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan,
LogicalPlanBuilder, Projection, Sort, Subquery,
};
use datafusion_expr::{lit, Expr, LogicalPlan};

/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
Expand All @@ -47,181 +38,43 @@ impl CountWildcardRule {

impl AnalyzerRule for CountWildcardRule {
/// Applies `analyze_internal` to `plan` from the top down and returns the
/// rewritten plan. The `ConfigOptions` argument is ignored.
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
plan.transform_down(&analyze_internal).data()
plan.transform_down_with_subqueries(&analyze_internal)
.data()
}

/// Returns the fixed name of this rule, `"count_wildcard_rule"`.
fn name(&self) -> &str {
"count_wildcard_rule"
}
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let mut rewriter = CountWildcardRewriter {};
match plan {
LogicalPlan::Window(window) => {
let window_expr = window
.window_expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;

Ok(Transformed::yes(
LogicalPlanBuilder::from((*window.input).clone())
.window(window_expr)?
.build()?,
))
}
LogicalPlan::Aggregate(agg) => {
let aggr_expr = agg
.aggr_expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;

Ok(Transformed::yes(LogicalPlan::Aggregate(
Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?,
)))
}
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
let sort_expr = expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::yes(LogicalPlan::Sort(Sort {
expr: sort_expr,
input,
fetch,
})))
}
LogicalPlan::Projection(projection) => {
let projection_expr = projection
.expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::yes(LogicalPlan::Projection(
Projection::try_new(projection_expr, projection.input)?,
)))
}
LogicalPlan::Filter(Filter {
predicate, input, ..
}) => {
let predicate = rewrite_preserving_name(predicate, &mut rewriter)?;
Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
predicate, input,
)?)))
}

_ => Ok(Transformed::no(plan)),
}
/// Returns `true` only when `expr` is a wildcard that carries no qualifier
/// (i.e. `Expr::Wildcard { qualifier: None }`); every other expression,
/// including a qualified wildcard, yields `false`.
fn is_wildcard(expr: &Expr) -> bool {
    match expr {
        Expr::Wildcard { qualifier } => qualifier.is_none(),
        _ => false,
    }
}

/// Expression rewriter that replaces the wildcard argument of a `Count`
/// aggregate / window function with `lit(COUNT_STAR_EXPANSION)` and re-runs
/// `analyze_internal` on the plans of scalar, `IN` and `EXISTS` subqueries.
struct CountWildcardRewriter {}

impl TreeNodeRewriter for CountWildcardRewriter {
type Node = Expr;

/// Rewrites a single expression node on the way back up the tree.
///
/// Returns `Transformed::yes` with the rewritten node when a `Count` over an
/// unqualified wildcard is found, the result of analyzing the inner plan for
/// subquery expressions, and `Transformed::no(old_expr)` otherwise.
fn f_up(&mut self, old_expr: Expr) -> Result<Transformed<Expr>> {
// `old_expr` is cloned so the untouched original can still be returned
// from the arms that do not rewrite anything.
Ok(match old_expr.clone() {
// `Count` used as a window function with exactly one argument.
Expr::WindowFunction(expr::WindowFunction {
fun:
expr::WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args,
partition_by,
order_by,
window_frame,
null_treatment,
}) if args.len() == 1 => match args[0] {
// Only an unqualified wildcard argument is replaced; all other
// fields of the window function are carried over unchanged.
Expr::Wildcard { qualifier: None } => {
Transformed::yes(Expr::WindowFunction(expr::WindowFunction {
fun: expr::WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args: vec![lit(COUNT_STAR_EXPANSION)],
partition_by,
order_by,
window_frame,
null_treatment,
}))
}

_ => Transformed::no(old_expr),
},
// Built-in `Count` aggregate with exactly one argument.
Expr::AggregateFunction(AggregateFunction {
func_def:
AggregateFunctionDefinition::BuiltIn(
aggregate_function::AggregateFunction::Count,
),
args,
distinct,
filter,
order_by,
null_treatment,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Transformed::yes(Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Count,
vec![lit(COUNT_STAR_EXPANSION)],
distinct,
filter,
order_by,
null_treatment,
)))
}
_ => Transformed::no(old_expr),
},

// Subquery expressions: the inner plan is cloned, analyzed with
// `analyze_internal`, and wrapped back into the same expression kind.
ScalarSubquery(Subquery {
subquery,
outer_ref_columns,
}) => subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)?
.update_data(|new_plan| {
ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
})
}),
Expr::InSubquery(InSubquery {
expr,
subquery,
negated,
}) => subquery
.subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)?
.update_data(|new_plan| {
Expr::InSubquery(InSubquery::new(
expr,
Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
},
negated,
))
}),
Expr::Exists(expr::Exists { subquery, negated }) => subquery
.subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)?
.update_data(|new_plan| {
Expr::Exists(expr::Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
},
negated,
})
}),
// Anything else is left as it was.
_ => Transformed::no(old_expr),
})
}
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
let transformed_expr = expr.transform_up(&|expr| match expr {
Expr::WindowFunction(mut window_function)
if window_function.args.len() == 1
&& is_wildcard(&window_function.args[0]) =>
{
window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::WindowFunction(window_function)))
}
Expr::AggregateFunction(mut aggregate_function)
if aggregate_function.args.len() == 1
&& is_wildcard(&aggregate_function.args[0]) =>
{
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::AggregateFunction(
aggregate_function,
)))
}
_ => Ok(Transformed::no(expr)),
})?;
transformed_expr.map_data(|data| original_name.restore(data))
})
}

#[cfg(test)]
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/analyzer/function_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl ApplyFunctionRewrites {
let original_name = name_preserver.save(&expr)?;

// recursively transform the expression, applying the rewrites at each step
let result = expr.transform_up(&|expr| {
let transformed_expr = expr.transform_up(&|expr| {
let mut result = Transformed::no(expr);
for rewriter in self.function_rewrites.iter() {
result = result.transform_data(|expr| {
Expand All @@ -74,7 +74,7 @@ impl ApplyFunctionRewrites {
Ok(result)
})?;

result.map_data(|expr| original_name.restore(expr))
transformed_expr.map_data(|expr| original_name.restore(expr))
})
}
}
Expand Down

0 comments on commit 692fc11

Please sign in to comment.