Skip to content

Commit

Permalink
Introduce Sum UDAF (apache#10651)
Browse files Browse the repository at this point in the history
* move accumulate

Signed-off-by: jayzhan211 <[email protected]>

* move prim_op

Signed-off-by: jayzhan211 <[email protected]>

* move test to slt

Signed-off-by: jayzhan211 <[email protected]>

* remove sum distinct

Signed-off-by: jayzhan211 <[email protected]>

* move sum aggregate

Signed-off-by: jayzhan211 <[email protected]>

* fix args

Signed-off-by: jayzhan211 <[email protected]>

* add sum

Signed-off-by: jayzhan211 <[email protected]>

* merge fix

Signed-off-by: jayzhan211 <[email protected]>

* fix sum sig

Signed-off-by: jayzhan211 <[email protected]>

* todo: wait ahash merge

Signed-off-by: jayzhan211 <[email protected]>

* rebase

Signed-off-by: jayzhan211 <[email protected]>

* disable ordering req by default

Signed-off-by: jayzhan211 <[email protected]>

* check arg count

Signed-off-by: jayzhan211 <[email protected]>

* rm old workflow

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

* fix failed test

Signed-off-by: jayzhan211 <[email protected]>

* doc and fmt

Signed-off-by: jayzhan211 <[email protected]>

* check udaf first

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

* fix ci

Signed-off-by: jayzhan211 <[email protected]>

* fix ci

Signed-off-by: jayzhan211 <[email protected]>

* fix ci

Signed-off-by: jayzhan211 <[email protected]>

* fix err msg AGAIN

Signed-off-by: jayzhan211 <[email protected]>

* rm sum in builtin test, which is covered in sql

Signed-off-by: jayzhan211 <[email protected]>

* proto for window with udaf

Signed-off-by: jayzhan211 <[email protected]>

* fix slt

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

* fix err msg

Signed-off-by: jayzhan211 <[email protected]>

* fix exprfn

Signed-off-by: jayzhan211 <[email protected]>

* fix ci

Signed-off-by: jayzhan211 <[email protected]>

* fix ci

Signed-off-by: jayzhan211 <[email protected]>

* rename first/last to lowercase

Signed-off-by: jayzhan211 <[email protected]>

* skip sum

Signed-off-by: jayzhan211 <[email protected]>

* fix firstvalue

Signed-off-by: jayzhan211 <[email protected]>

* clippy

Signed-off-by: jayzhan211 <[email protected]>

* add doc

Signed-off-by: jayzhan211 <[email protected]>

* rm has_ordering_req

Signed-off-by: jayzhan211 <[email protected]>

* default hard req

Signed-off-by: jayzhan211 <[email protected]>

* insensitive for sum

Signed-off-by: jayzhan211 <[email protected]>

* cleanup duplicate code

Signed-off-by: jayzhan211 <[email protected]>

* Re-introduce check

---------

Signed-off-by: jayzhan211 <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
  • Loading branch information
2 people authored and findepi committed Jul 16, 2024
1 parent 042da6a commit 95d1b62
Show file tree
Hide file tree
Showing 41 changed files with 888 additions and 299 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ impl AggregateUDFImpl for GeoMeanUdaf {
true
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(GeometricMeanGroupsAccumulator::new()))
}
}
Expand Down
6 changes: 5 additions & 1 deletion datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ impl AggregateUDFImpl for BetterAvgUdaf {
true
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!("should not get here");
}

// we override method, to return new expression which would substitute
// user defined function call
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ use datafusion_expr::{
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null, sum};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::median;
use datafusion_functions_aggregate::expr_fn::sum;

use async_trait::async_trait;

Expand Down Expand Up @@ -1593,9 +1594,8 @@ mod tests {
use datafusion_common::{Constraint, Constraints};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
array_agg, cast, count_distinct, create_udf, expr, lit, sum,
BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame,
WindowFunctionDefinition,
array_agg, cast, count_distinct, create_udf, expr, lit, BuiltInWindowFunction,
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition,
};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
Expand Down
5 changes: 2 additions & 3 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2257,9 +2257,8 @@ mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, TableReference};
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_expr::{
col, lit, sum, LogicalPlanBuilder, UserDefinedLogicalNodeCore,
};
use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore};
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_physical_expr::EquivalenceProperties;

fn make_session_state() -> SessionState {
Expand Down
1 change: 0 additions & 1 deletion datafusion/core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ pub use datafusion_expr::{
Expr,
};
pub use datafusion_functions::expr_fn::*;
pub use datafusion_functions_aggregate::expr_fn::*;
#[cfg(feature = "array_expressions")]
pub use datafusion_functions_array::expr_fn::*;

Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col,
placeholder, scalar_subquery, sum, when, wildcard, AggregateFunction, Expr,
ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
placeholder, scalar_subquery, when, wildcard, AggregateFunction, Expr, ExprSchemable,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::sum;

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
12 changes: 11 additions & 1 deletion datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::type_coercion::aggregates::coerce_types;
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
use datafusion_expr::{
AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;
Expand Down Expand Up @@ -341,7 +343,7 @@ fn get_random_function(
window_fn_map.insert(
"sum",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
WindowFunctionDefinition::AggregateUDF(sum_udaf()),
vec![arg.clone()],
),
);
Expand Down Expand Up @@ -468,6 +470,14 @@ fn get_random_function(
let coerced = coerce_types(f, &[dt], &sig).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}
} else if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn {
if !args.is_empty() {
// Do type coercion for the first argument
let a = args[0].clone();
let dt = a.data_type(schema.as_ref()).unwrap();
let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}
}

(window_fn.clone(), args, fn_name.to_string())
Expand Down
7 changes: 5 additions & 2 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() {
let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
// Note: this query is currently expected to fail; update this test if it ever starts working
let err = execute(&ctx, sql).await.unwrap_err();
assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { inner: AggregateUDF { name: \"time_sum\", signature: Signature { type_signature: Exact([Timestamp(Nanosecond, None)]), volatility: Immutable }, fun: \"<FUNC>\" } }(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING");
assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING");
}

/// Basic query for with a udaf returning a structure
Expand Down Expand Up @@ -729,7 +729,10 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
true
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(self.clone()))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,8 @@ async fn udaf_as_window_func() -> Result<()> {
context.register_udaf(my_acc);

let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table";
let expected = r#"Projection: my_table.a, AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
WindowAggr: windowExpr=[[AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
let expected = r#"Projection: my_table.a, my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
TableScan: my_table"#;

let dataframe = context.sql(sql).await.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/built_in_window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ impl BuiltInWindowFunction {
Ntile => "NTILE",
Lag => "LAG",
Lead => "LEAD",
FirstValue => "FIRST_VALUE",
LastValue => "LAST_VALUE",
FirstValue => "first_value",
LastValue => "last_value",
NthValue => "NTH_VALUE",
}
}
Expand Down
18 changes: 13 additions & 5 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,10 +754,14 @@ impl WindowFunctionDefinition {
impl fmt::Display for WindowFunctionDefinition {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f),
WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f),
WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),
WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f),
WindowFunctionDefinition::AggregateFunction(fun) => {
std::fmt::Display::fmt(fun, f)
}
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
std::fmt::Display::fmt(fun, f)
}
WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Display::fmt(fun, f),
WindowFunctionDefinition::WindowUDF(fun) => std::fmt::Display::fmt(fun, f),
}
}
}
Expand Down Expand Up @@ -2263,7 +2267,11 @@ mod test {
let fun = find_df_window_func(name).unwrap();
let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap();
assert_eq!(fun, fun2);
assert_eq!(fun.to_string(), name.to_uppercase());
if fun.to_string() == "first_value" || fun.to_string() == "last_value" {
assert_eq!(fun.to_string(), name);
} else {
assert_eq!(fun.to_string(), name.to_uppercase());
}
}
Ok(())
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ pub fn max(expr: Expr) -> Expr {
}

/// Create an expression to represent the sum() aggregate function
///
/// TODO: Remove this function and use `sum` from `datafusion_functions_aggregate::expr_fn` instead
pub fn sum(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Sum,
Expand Down
39 changes: 35 additions & 4 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ use crate::expr::{
InSubquery, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
};
use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::data_types_with_scalar_udf;
use crate::{utils, LogicalPlan, Projection, Subquery};
use crate::type_coercion::functions::{
data_types_with_aggregate_udf, data_types_with_scalar_udf,
};
use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
Expand Down Expand Up @@ -158,7 +160,25 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
fun.return_type(&data_types)
match fun {
WindowFunctionDefinition::AggregateUDF(udf) => {
let new_types = data_types_with_aggregate_udf(&data_types, udf).map_err(|err| {
plan_datafusion_err!(
"{} and {}",
err,
utils::generate_signature_error_msg(
fun.name(),
fun.signature().clone(),
&data_types
)
)
})?;
Ok(fun.return_type(&new_types)?)
}
_ => {
fun.return_type(&data_types)
}
}
}
Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => {
let data_types = args
Expand All @@ -170,7 +190,18 @@ impl ExprSchemable for Expr {
fun.return_type(&data_types)
}
AggregateFunctionDefinition::UDF(fun) => {
Ok(fun.return_type(&data_types)?)
let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| {
plan_datafusion_err!(
"{} and {}",
err,
utils::generate_signature_error_msg(
fun.name(),
fun.signature().clone(),
&data_types
)
)
})?;
Ok(fun.return_type(&new_types)?)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ pub struct AccumulatorArgs<'a> {
/// If no `ORDER BY` is specified, `sort_exprs` will be empty.
pub sort_exprs: &'a [Expr],

/// The name of the aggregate expression
pub name: &'a str,

/// Whether the aggregate function is distinct.
///
/// ```sql
Expand All @@ -82,9 +85,6 @@ pub struct AccumulatorArgs<'a> {

/// The number of arguments the aggregate function takes.
pub args_num: usize,

/// The name of the expression
pub name: &'a str,
}

/// [`StateFieldsArgs`] contains information about the fields that an
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub use built_in_window_function::BuiltInWindowFunction;
pub use columnar_value::ColumnarValue;
pub use expr::{
Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet,
Like, TryCast, WindowFunctionDefinition,
Like, Sort as SortExpr, TryCast, WindowFunctionDefinition,
};
pub use expr_fn::*;
pub use expr_schema::ExprSchemable;
Expand Down
Loading

0 comments on commit 95d1b62

Please sign in to comment.