Skip to content

Commit

Permalink
Avoid copies in CountWildcardRule via TreeNode API (#10066)
Browse files Browse the repository at this point in the history
* Avoid copies in `CountWildcardRule` via TreeNode API
  • Loading branch information
alamb authored Apr 15, 2024
1 parent 4e9f2d5 commit a165b7f
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 179 deletions.
241 changes: 64 additions & 177 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,17 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::analyzer::AnalyzerRule;

use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::Expr::ScalarSubquery;
use datafusion_expr::{
aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan,
LogicalPlanBuilder, Projection, Sort, Subquery,
use datafusion_expr::expr::{
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};

/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
Expand All @@ -47,181 +41,62 @@ impl CountWildcardRule {

impl AnalyzerRule for CountWildcardRule {
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
plan.transform_down(&analyze_internal).data()
plan.transform_down_with_subqueries(&analyze_internal)
.data()
}

fn name(&self) -> &str {
"count_wildcard_rule"
}
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let mut rewriter = CountWildcardRewriter {};
match plan {
LogicalPlan::Window(window) => {
let window_expr = window
.window_expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;

Ok(Transformed::yes(
LogicalPlanBuilder::from((*window.input).clone())
.window(window_expr)?
.build()?,
))
}
LogicalPlan::Aggregate(agg) => {
let aggr_expr = agg
.aggr_expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;

Ok(Transformed::yes(LogicalPlan::Aggregate(
Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?,
)))
}
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
let sort_expr = expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::yes(LogicalPlan::Sort(Sort {
expr: sort_expr,
input,
fetch,
})))
}
LogicalPlan::Projection(projection) => {
let projection_expr = projection
.expr
.iter()
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::yes(LogicalPlan::Projection(
Projection::try_new(projection_expr, projection.input)?,
)))
}
LogicalPlan::Filter(Filter {
predicate, input, ..
}) => {
let predicate = rewrite_preserving_name(predicate, &mut rewriter)?;
Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
predicate, input,
)?)))
}

_ => Ok(Transformed::no(plan)),
}
fn is_wildcard(expr: &Expr) -> bool {
matches!(expr, Expr::Wildcard { qualifier: None })
}

struct CountWildcardRewriter {}

impl TreeNodeRewriter for CountWildcardRewriter {
type Node = Expr;

fn f_up(&mut self, old_expr: Expr) -> Result<Transformed<Expr>> {
Ok(match old_expr.clone() {
Expr::WindowFunction(expr::WindowFunction {
fun:
expr::WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args,
partition_by,
order_by,
window_frame,
null_treatment,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Transformed::yes(Expr::WindowFunction(expr::WindowFunction {
fun: expr::WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args: vec![lit(COUNT_STAR_EXPANSION)],
partition_by,
order_by,
window_frame,
null_treatment,
}))
}
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
matches!(
&aggregate_function.func_def,
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::Count,
)
) && aggregate_function.args.len() == 1
&& is_wildcard(&aggregate_function.args[0])
}

_ => Transformed::no(old_expr),
},
Expr::AggregateFunction(AggregateFunction {
func_def:
AggregateFunctionDefinition::BuiltIn(
aggregate_function::AggregateFunction::Count,
),
args,
distinct,
filter,
order_by,
null_treatment,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Transformed::yes(Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Count,
vec![lit(COUNT_STAR_EXPANSION)],
distinct,
filter,
order_by,
null_treatment,
)))
}
_ => Transformed::no(old_expr),
},
fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
matches!(
&window_function.fun,
WindowFunctionDefinition::AggregateFunction(
datafusion_expr::aggregate_function::AggregateFunction::Count,
)
) && window_function.args.len() == 1
&& is_wildcard(&window_function.args[0])
}

ScalarSubquery(Subquery {
subquery,
outer_ref_columns,
}) => subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)?
.update_data(|new_plan| {
ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
})
}),
Expr::InSubquery(InSubquery {
expr,
subquery,
negated,
}) => subquery
.subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)?
.update_data(|new_plan| {
Expr::InSubquery(InSubquery::new(
expr,
Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
},
negated,
))
}),
Expr::Exists(expr::Exists { subquery, negated }) => subquery
.subquery
.as_ref()
.clone()
.transform_down(&analyze_internal)?
.update_data(|new_plan| {
Expr::Exists(expr::Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
},
negated,
})
}),
_ => Transformed::no(old_expr),
})
}
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
let transformed_expr = expr.transform_up(&|expr| match expr {
Expr::WindowFunction(mut window_function)
if is_count_star_window_aggregate(&window_function) =>
{
window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::WindowFunction(window_function)))
}
Expr::AggregateFunction(mut aggregate_function)
if is_count_star_aggregate(&aggregate_function) =>
{
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::AggregateFunction(
aggregate_function,
)))
}
_ => Ok(Transformed::no(expr)),
})?;
transformed_expr.map_data(|data| original_name.restore(data))
})
}

#[cfg(test)]
Expand All @@ -233,9 +108,10 @@ mod tests {
use datafusion_expr::expr::Sort;
use datafusion_expr::{
col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder,
max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr,
max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use std::sync::Arc;

fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
assert_analyzed_plan_eq_display_indent(
Expand Down Expand Up @@ -381,6 +257,17 @@ mod tests {
assert_plan_eq(&plan, expected)
}

#[test]
fn test_count_wildcard_on_non_count_aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
let err = LogicalPlanBuilder::from(table_scan)
.aggregate(Vec::<Expr>::new(), vec![sum(wildcard())])
.unwrap_err()
.to_string();
assert!(err.contains("Error during planning: No function matches the given name and argument types 'SUM(Null)'."), "{err}");
Ok(())
}

#[test]
fn test_count_wildcard_on_nesting() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/analyzer/function_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl ApplyFunctionRewrites {
let original_name = name_preserver.save(&expr)?;

// recursively transform the expression, applying the rewrites at each step
let result = expr.transform_up(&|expr| {
let transformed_expr = expr.transform_up(&|expr| {
let mut result = Transformed::no(expr);
for rewriter in self.function_rewrites.iter() {
result = result.transform_data(|expr| {
Expand All @@ -74,7 +74,7 @@ impl ApplyFunctionRewrites {
Ok(result)
})?;

result.map_data(|expr| original_name.restore(expr))
transformed_expr.map_data(|expr| original_name.restore(expr))
})
}
}
Expand Down

0 comments on commit a165b7f

Please sign in to comment.