-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Avoid copies in CountWildcardRule
via TreeNode API
#10066
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,23 +15,17 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use std::sync::Arc; | ||
|
||
use crate::analyzer::AnalyzerRule; | ||
|
||
use crate::utils::NamePreserver; | ||
use datafusion_common::config::ConfigOptions; | ||
use datafusion_common::tree_node::{ | ||
Transformed, TransformedResult, TreeNode, TreeNodeRewriter, | ||
}; | ||
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; | ||
use datafusion_common::Result; | ||
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; | ||
use datafusion_expr::expr_rewriter::rewrite_preserving_name; | ||
use datafusion_expr::utils::COUNT_STAR_EXPANSION; | ||
use datafusion_expr::Expr::ScalarSubquery; | ||
use datafusion_expr::{ | ||
aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, | ||
LogicalPlanBuilder, Projection, Sort, Subquery, | ||
use datafusion_expr::expr::{ | ||
AggregateFunction, AggregateFunctionDefinition, WindowFunction, | ||
}; | ||
use datafusion_expr::utils::COUNT_STAR_EXPANSION; | ||
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; | ||
|
||
/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. | ||
/// | ||
|
@@ -47,181 +41,62 @@ impl CountWildcardRule { | |
|
||
impl AnalyzerRule for CountWildcardRule { | ||
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> { | ||
plan.transform_down(&analyze_internal).data() | ||
plan.transform_down_with_subqueries(&analyze_internal) | ||
.data() | ||
} | ||
|
||
fn name(&self) -> &str { | ||
"count_wildcard_rule" | ||
} | ||
} | ||
|
||
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> { | ||
let mut rewriter = CountWildcardRewriter {}; | ||
match plan { | ||
LogicalPlan::Window(window) => { | ||
let window_expr = window | ||
.window_expr | ||
.iter() | ||
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed this clone (and the ones below) |
||
.collect::<Result<Vec<_>>>()?; | ||
|
||
Ok(Transformed::yes( | ||
LogicalPlanBuilder::from((*window.input).clone()) | ||
.window(window_expr)? | ||
.build()?, | ||
)) | ||
} | ||
LogicalPlan::Aggregate(agg) => { | ||
let aggr_expr = agg | ||
.aggr_expr | ||
.iter() | ||
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
Ok(Transformed::yes(LogicalPlan::Aggregate( | ||
Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?, | ||
))) | ||
} | ||
LogicalPlan::Sort(Sort { expr, input, fetch }) => { | ||
let sort_expr = expr | ||
.iter() | ||
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) | ||
.collect::<Result<Vec<_>>>()?; | ||
Ok(Transformed::yes(LogicalPlan::Sort(Sort { | ||
expr: sort_expr, | ||
input, | ||
fetch, | ||
}))) | ||
} | ||
LogicalPlan::Projection(projection) => { | ||
let projection_expr = projection | ||
.expr | ||
.iter() | ||
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) | ||
.collect::<Result<Vec<_>>>()?; | ||
Ok(Transformed::yes(LogicalPlan::Projection( | ||
Projection::try_new(projection_expr, projection.input)?, | ||
))) | ||
} | ||
LogicalPlan::Filter(Filter { | ||
predicate, input, .. | ||
}) => { | ||
let predicate = rewrite_preserving_name(predicate, &mut rewriter)?; | ||
Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( | ||
predicate, input, | ||
)?))) | ||
} | ||
|
||
_ => Ok(Transformed::no(plan)), | ||
} | ||
fn is_wildcard(expr: &Expr) -> bool { | ||
matches!(expr, Expr::Wildcard { qualifier: None }) | ||
} | ||
|
||
struct CountWildcardRewriter {} | ||
|
||
impl TreeNodeRewriter for CountWildcardRewriter { | ||
type Node = Expr; | ||
|
||
fn f_up(&mut self, old_expr: Expr) -> Result<Transformed<Expr>> { | ||
Ok(match old_expr.clone() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here is another clone that is avoided |
||
Expr::WindowFunction(expr::WindowFunction { | ||
fun: | ||
expr::WindowFunctionDefinition::AggregateFunction( | ||
aggregate_function::AggregateFunction::Count, | ||
), | ||
args, | ||
partition_by, | ||
order_by, | ||
window_frame, | ||
null_treatment, | ||
}) if args.len() == 1 => match args[0] { | ||
Expr::Wildcard { qualifier: None } => { | ||
Transformed::yes(Expr::WindowFunction(expr::WindowFunction { | ||
fun: expr::WindowFunctionDefinition::AggregateFunction( | ||
aggregate_function::AggregateFunction::Count, | ||
), | ||
args: vec![lit(COUNT_STAR_EXPANSION)], | ||
partition_by, | ||
order_by, | ||
window_frame, | ||
null_treatment, | ||
})) | ||
} | ||
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { | ||
matches!( | ||
&aggregate_function.func_def, | ||
AggregateFunctionDefinition::BuiltIn( | ||
datafusion_expr::aggregate_function::AggregateFunction::Count, | ||
) | ||
) && aggregate_function.args.len() == 1 | ||
&& is_wildcard(&aggregate_function.args[0]) | ||
} | ||
|
||
_ => Transformed::no(old_expr), | ||
}, | ||
Expr::AggregateFunction(AggregateFunction { | ||
func_def: | ||
AggregateFunctionDefinition::BuiltIn( | ||
aggregate_function::AggregateFunction::Count, | ||
), | ||
args, | ||
distinct, | ||
filter, | ||
order_by, | ||
null_treatment, | ||
}) if args.len() == 1 => match args[0] { | ||
Expr::Wildcard { qualifier: None } => { | ||
Transformed::yes(Expr::AggregateFunction(AggregateFunction::new( | ||
aggregate_function::AggregateFunction::Count, | ||
vec![lit(COUNT_STAR_EXPANSION)], | ||
distinct, | ||
filter, | ||
order_by, | ||
null_treatment, | ||
))) | ||
} | ||
_ => Transformed::no(old_expr), | ||
}, | ||
fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { | ||
matches!( | ||
&window_function.fun, | ||
WindowFunctionDefinition::AggregateFunction( | ||
datafusion_expr::aggregate_function::AggregateFunction::Count, | ||
) | ||
) && window_function.args.len() == 1 | ||
&& is_wildcard(&window_function.args[0]) | ||
} | ||
|
||
ScalarSubquery(Subquery { | ||
subquery, | ||
outer_ref_columns, | ||
}) => subquery | ||
.as_ref() | ||
.clone() | ||
.transform_down(&analyze_internal)? | ||
.update_data(|new_plan| { | ||
ScalarSubquery(Subquery { | ||
subquery: Arc::new(new_plan), | ||
outer_ref_columns, | ||
}) | ||
}), | ||
Expr::InSubquery(InSubquery { | ||
expr, | ||
subquery, | ||
negated, | ||
}) => subquery | ||
.subquery | ||
.as_ref() | ||
.clone() | ||
.transform_down(&analyze_internal)? | ||
.update_data(|new_plan| { | ||
Expr::InSubquery(InSubquery::new( | ||
expr, | ||
Subquery { | ||
subquery: Arc::new(new_plan), | ||
outer_ref_columns: subquery.outer_ref_columns, | ||
}, | ||
negated, | ||
)) | ||
}), | ||
Expr::Exists(expr::Exists { subquery, negated }) => subquery | ||
.subquery | ||
.as_ref() | ||
.clone() | ||
.transform_down(&analyze_internal)? | ||
.update_data(|new_plan| { | ||
Expr::Exists(expr::Exists { | ||
subquery: Subquery { | ||
subquery: Arc::new(new_plan), | ||
outer_ref_columns: subquery.outer_ref_columns, | ||
}, | ||
negated, | ||
}) | ||
}), | ||
_ => Transformed::no(old_expr), | ||
}) | ||
} | ||
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> { | ||
let name_preserver = NamePreserver::new(&plan); | ||
plan.map_expressions(|expr| { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not only does this logic avoid the clone ( |
||
let original_name = name_preserver.save(&expr)?; | ||
let transformed_expr = expr.transform_up(&|expr| match expr { | ||
Expr::WindowFunction(mut window_function) | ||
if is_count_star_window_aggregate(&window_function) => | ||
{ | ||
window_function.args = vec![lit(COUNT_STAR_EXPANSION)]; | ||
Ok(Transformed::yes(Expr::WindowFunction(window_function))) | ||
} | ||
Expr::AggregateFunction(mut aggregate_function) | ||
if is_count_star_aggregate(&aggregate_function) => | ||
{ | ||
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; | ||
Ok(Transformed::yes(Expr::AggregateFunction( | ||
aggregate_function, | ||
))) | ||
} | ||
_ => Ok(Transformed::no(expr)), | ||
})?; | ||
transformed_expr.map_data(|data| original_name.restore(data)) | ||
}) | ||
} | ||
|
||
#[cfg(test)] | ||
|
@@ -233,9 +108,10 @@ mod tests { | |
use datafusion_expr::expr::Sort; | ||
use datafusion_expr::{ | ||
col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, | ||
max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, | ||
max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr, | ||
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, | ||
}; | ||
use std::sync::Arc; | ||
|
||
fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { | ||
assert_analyzed_plan_eq_display_indent( | ||
|
@@ -381,6 +257,17 @@ mod tests { | |
assert_plan_eq(&plan, expected) | ||
} | ||
|
||
#[test] | ||
fn test_count_wildcard_on_non_count_aggregate() -> Result<()> { | ||
let table_scan = test_table_scan()?; | ||
let err = LogicalPlanBuilder::from(table_scan) | ||
.aggregate(Vec::<Expr>::new(), vec![sum(wildcard())]) | ||
.unwrap_err() | ||
.to_string(); | ||
assert!(err.contains("Error during planning: No function matches the given name and argument types 'SUM(Null)'."), "{err}"); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_count_wildcard_on_nesting() -> Result<()> { | ||
let table_scan = test_table_scan()?; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,7 +64,7 @@ impl ApplyFunctionRewrites { | |
let original_name = name_preserver.save(&expr)?; | ||
|
||
// recursively transform the expression, applying the rewrites at each step | ||
let result = expr.transform_up(&|expr| { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. drive by cleanup based on @crepererum 's comment on #10038 (comment) |
||
let transformed_expr = expr.transform_up(&|expr| { | ||
let mut result = Transformed::no(expr); | ||
for rewriter in self.function_rewrites.iter() { | ||
result = result.transform_data(|expr| { | ||
|
@@ -74,7 +74,7 @@ impl ApplyFunctionRewrites { | |
Ok(result) | ||
})?; | ||
|
||
result.map_data(|expr| original_name.restore(expr)) | ||
transformed_expr.map_data(|expr| original_name.restore(expr)) | ||
}) | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By using the wonderful
transform_down_with_subqueries
from @peter-toth this rule can avoid having to recuse into subqueries directly by itself