Skip to content

Commit

Permalink
minor: revert parsing precedence between Aggr and UDAF (#7682)
Browse files Browse the repository at this point in the history
* minor: revert parsing precedence between Aggr and UDAF

Signed-off-by: Ruihang Xia <[email protected]>

* add unit test

Signed-off-by: Ruihang Xia <[email protected]>

---------

Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored Sep 29, 2023
1 parent 70cded6 commit 2d6e768
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 13 deletions.
37 changes: 33 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,37 @@ async fn test_udaf_returning_struct_subquery() {
assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
}

#[tokio::test]
async fn test_udaf_shadows_builtin_fn() {
let TestContext {
mut ctx,
test_state,
} = TestContext::new();
let sql = "SELECT sum(arrow_cast(time, 'Int64')) from t";

// compute with builtin `sum` aggregator
let expected = [
"+-------------+",
"| SUM(t.time) |",
"+-------------+",
"| 19000 |",
"+-------------+",
];
assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());

// Register `TimeSum` with name `sum`. This will shadow the builtin one
let sql = "SELECT sum(time) from t";
TimeSum::register(&mut ctx, test_state.clone(), "sum");
let expected = [
"+----------------------------+",
"| sum(t.time) |",
"+----------------------------+",
"| 1970-01-01T00:00:00.000019 |",
"+----------------------------+",
];
assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
}

async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}
Expand Down Expand Up @@ -214,7 +245,7 @@ impl TestContext {
// Tell DataFusion about the "first" function
FirstSelector::register(&mut ctx);
// Tell DataFusion about the "time_sum" function
TimeSum::register(&mut ctx, Arc::clone(&test_state));
TimeSum::register(&mut ctx, Arc::clone(&test_state), "time_sum");

Self { ctx, test_state }
}
Expand Down Expand Up @@ -281,7 +312,7 @@ impl TimeSum {
Self { sum: 0, test_state }
}

fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) {
fn register(ctx: &mut SessionContext, test_state: Arc<TestState>, name: &str) {
let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None);

// Returns the same type as its input
Expand All @@ -301,8 +332,6 @@ impl TimeSum {
let accumulator: AccumulatorFactoryFunction =
Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));

let name = "time_sum";

let time_sum =
AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type);

Expand Down
18 changes: 9 additions & 9 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
return Ok(expr);
}
} else {
// User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function
if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) {
let args =
self.function_args_to_expr(function.args, schema, planner_context)?;
return Ok(Expr::AggregateUDF(expr::AggregateUDF::new(
fm, args, None, None,
)));
}

// next, aggregate built-ins
if let Ok(fun) = AggregateFunction::from_str(&name) {
let distinct = function.distinct;
Expand All @@ -141,15 +150,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)));
};

// User defined aggregate functions (UDAF)
if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) {
let args =
self.function_args_to_expr(function.args, schema, planner_context)?;
return Ok(Expr::AggregateUDF(expr::AggregateUDF::new(
fm, args, None, None,
)));
}

// Special case arrow_cast (as its type is dependent on its argument value)
if name == ARROW_CAST_NAME {
let args =
Expand Down

0 comments on commit 2d6e768

Please sign in to comment.