Skip to content

Commit

Permalink
Remove AggregateFunctionDefinition (#11803)
Browse files Browse the repository at this point in the history
* Remove �[200~if udf.name() == count => {

* Apply review suggestions
  • Loading branch information
lewiszlw authored Aug 5, 2024
1 parent c8e5996 commit b4069a6
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 199 deletions.
69 changes: 32 additions & 37 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ use datafusion_common::{
};
use datafusion_expr::dml::CopyTo;
use datafusion_expr::expr::{
self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr,
Cast, GroupingSet, InList, Like, TryCast, WindowFunction,
self, AggregateFunction, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like,
TryCast, WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::expr_vec_fmt;
Expand Down Expand Up @@ -223,18 +223,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
create_function_physical_name(&fun.to_string(), false, args, Some(order_by))
}
Expr::AggregateFunction(AggregateFunction {
func_def,
func,
distinct,
args,
filter: _,
order_by,
null_treatment: _,
}) => create_function_physical_name(
func_def.name(),
*distinct,
args,
order_by.as_ref(),
),
}) => {
create_function_physical_name(func.name(), *distinct, args, order_by.as_ref())
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(format!(
"ROLLUP ({})",
Expand Down Expand Up @@ -1817,7 +1814,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
) -> Result<AggregateExprWithOptionalArgs> {
match e {
Expr::AggregateFunction(AggregateFunction {
func_def,
func,
distinct,
args,
filter,
Expand All @@ -1839,36 +1836,34 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;

let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::UDF(fun) => {
let sort_exprs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};
let (agg_expr, filter, order_by) = {
let sort_exprs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};

let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);

let agg_expr = udaf::create_aggregate_expr_with_dfschema(
fun,
&physical_args,
args,
&sort_exprs,
&ordering_reqs,
logical_input_schema,
name,
ignore_nulls,
*distinct,
false,
)?;
let agg_expr = udaf::create_aggregate_expr_with_dfschema(
func,
&physical_args,
args,
&sort_exprs,
&ordering_reqs,
logical_input_schema,
name,
ignore_nulls,
*distinct,
false,
)?;

(agg_expr, filter, physical_sort_exprs)
}
(agg_expr, filter, physical_sort_exprs)
};

Ok((agg_expr, filter, order_by))
Expand Down
34 changes: 9 additions & 25 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,22 +627,6 @@ impl Sort {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
/// Defines which implementation of an aggregate function DataFusion should call.
pub enum AggregateFunctionDefinition {
/// Resolved to a user defined aggregate function
UDF(Arc<crate::AggregateUDF>),
}

impl AggregateFunctionDefinition {
/// Function's name for display
pub fn name(&self) -> &str {
match self {
AggregateFunctionDefinition::UDF(udf) => udf.name(),
}
}
}

/// Aggregate function
///
/// See also [`ExprFunctionExt`] to set these fields on `Expr`
Expand All @@ -651,7 +635,7 @@ impl AggregateFunctionDefinition {
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateFunction {
/// Name of the function
pub func_def: AggregateFunctionDefinition,
pub func: Arc<crate::AggregateUDF>,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
Expand All @@ -666,15 +650,15 @@ pub struct AggregateFunction {
impl AggregateFunction {
/// Create a new AggregateFunction expression with a user-defined function (UDF)
pub fn new_udf(
udf: Arc<crate::AggregateUDF>,
func: Arc<crate::AggregateUDF>,
args: Vec<Expr>,
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
null_treatment: Option<NullTreatment>,
) -> Self {
Self {
func_def: AggregateFunctionDefinition::UDF(udf),
func,
args,
distinct,
filter,
Expand Down Expand Up @@ -1666,14 +1650,14 @@ impl Expr {
func.hash(hasher);
}
Expr::AggregateFunction(AggregateFunction {
func_def,
func,
args: _args,
distinct,
filter: _filter,
order_by: _order_by,
null_treatment,
}) => {
func_def.hash(hasher);
func.hash(hasher);
distinct.hash(hasher);
null_treatment.hash(hasher);
}
Expand Down Expand Up @@ -1870,15 +1854,15 @@ impl fmt::Display for Expr {
Ok(())
}
Expr::AggregateFunction(AggregateFunction {
func_def,
func,
distinct,
ref args,
filter,
order_by,
null_treatment,
..
}) => {
fmt_function(f, func_def.name(), *distinct, args, true)?;
fmt_function(f, func.name(), *distinct, args, true)?;
if let Some(nt) = null_treatment {
write!(f, " {}", nt)?;
}
Expand Down Expand Up @@ -2190,14 +2174,14 @@ fn write_name<W: Write>(w: &mut W, e: &Expr) -> Result<()> {
write!(w, "{window_frame}")?;
}
Expr::AggregateFunction(AggregateFunction {
func_def,
func,
distinct,
args,
filter,
order_by,
null_treatment,
}) => {
write_function_name(w, func_def.name(), *distinct, args)?;
write_function_name(w, func.name(), *distinct, args)?;
if let Some(fe) = filter {
write!(w, " FILTER (WHERE {fe})")?;
};
Expand Down
47 changes: 21 additions & 26 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, InList,
InSubquery, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder,
ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
};
use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::{
Expand Down Expand Up @@ -193,28 +193,24 @@ impl ExprSchemable for Expr {
_ => fun.return_type(&data_types, &nullability),
}
}
Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => {
Expr::AggregateFunction(AggregateFunction { func, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
match func_def {
AggregateFunctionDefinition::UDF(fun) => {
let new_types = data_types_with_aggregate_udf(&data_types, fun)
.map_err(|err| {
plan_datafusion_err!(
"{} {}",
err,
utils::generate_signature_error_msg(
fun.name(),
fun.signature().clone(),
&data_types
)
let new_types = data_types_with_aggregate_udf(&data_types, func)
.map_err(|err| {
plan_datafusion_err!(
"{} {}",
err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&data_types
)
})?;
Ok(fun.return_type(&new_types)?)
}
}
)
})?;
Ok(func.return_type(&new_types)?)
}
Expr::Not(_)
| Expr::IsNull(_)
Expand Down Expand Up @@ -329,13 +325,12 @@ impl ExprSchemable for Expr {
}
}
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::AggregateFunction(AggregateFunction { func_def, .. }) => {
match func_def {
// TODO: UDF should be able to customize nullability
AggregateFunctionDefinition::UDF(udf) if udf.name() == "count" => {
Ok(false)
}
AggregateFunctionDefinition::UDF(_) => Ok(true),
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
// TODO: UDF should be able to customize nullability
if func.name() == "count" {
Ok(false)
} else {
Ok(true)
}
}
Expr::ScalarVariable(_, _)
Expand Down
31 changes: 13 additions & 18 deletions datafusion/expr/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
//! Tree node implementation for logical expr
use crate::expr::{
AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case,
Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort,
TryCast, Unnest, WindowFunction,
AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList,
InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
};
use crate::{Expr, ExprFunctionExt};

Expand Down Expand Up @@ -304,7 +303,7 @@ impl TreeNode for Expr {
}),
Expr::AggregateFunction(AggregateFunction {
args,
func_def,
func,
distinct,
filter,
order_by,
Expand All @@ -316,20 +315,16 @@ impl TreeNode for Expr {
order_by,
transform_option_vec(order_by, &mut f)
)?
.map_data(
|(new_args, new_filter, new_order_by)| match func_def {
AggregateFunctionDefinition::UDF(fun) => {
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
fun,
new_args,
distinct,
new_filter,
new_order_by,
null_treatment,
)))
}
},
)?,
.map_data(|(new_args, new_filter, new_order_by)| {
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
func,
new_args,
distinct,
new_filter,
new_order_by,
null_treatment,
)))
})?,
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
.update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
Expand Down
4 changes: 1 addition & 3 deletions datafusion/functions-nested/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
use datafusion_common::{exec_err, utils::list_ndims, DFSchema, Result};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
expr::AggregateFunctionDefinition,
planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
sqlparser, Expr, ExprSchemable, GetFieldAccess,
};
Expand Down Expand Up @@ -171,6 +170,5 @@ impl ExprPlanner for FieldAccessPlanner {
}

fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool {
let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def;
return udf.name() == "array_agg";
return agg_func.func.name() == "array_agg";
}
8 changes: 3 additions & 5 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::expr::{
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
};
use datafusion_expr::expr::{AggregateFunction, WindowFunction};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};

Expand Down Expand Up @@ -56,10 +54,10 @@ fn is_wildcard(expr: &Expr) -> bool {
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
matches!(aggregate_function,
AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(udf),
func,
args,
..
} if udf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
} if func.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
}

fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
Expand Down
Loading

0 comments on commit b4069a6

Please sign in to comment.