From 9f160e43edb3da0208597c7922c8a4ecd00f6cd8 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 20 Apr 2024 15:26:03 +0800 Subject: [PATCH] support more args for udaf Signed-off-by: jayzhan211 --- datafusion/functions-aggregate/Cargo.toml | 1 + .../functions-aggregate/src/first_last.rs | 2 +- datafusion/functions-aggregate/src/macros.rs | 23 +++++++++++-------- .../tests/cases/roundtrip_logical_plan.rs | 2 +- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index be354acb4851..f97647565364 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -45,3 +45,4 @@ datafusion-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.14" +sqlparser = { workspace = true } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index d5367ad34163..76827f7e6716 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -39,11 +39,11 @@ use datafusion_physical_expr_common::utils::reverse_order_bys; use std::any::Any; use std::fmt::Debug; use std::sync::Arc; +use sqlparser::ast::NullTreatment; make_udaf_function!( FirstValue, first_value, - value, "Returns the first value in a group of values.", first_value_udaf ); diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index d24c60f93270..c806392daa98 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -16,19 +16,24 @@ // under the License. macro_rules! make_udaf_function { - ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { paste::paste! { // "fluent expr_fn" style function #[doc = $DOC] - pub fn $EXPR_FN($($arg: Expr),*) -> Expr { + pub fn $EXPR_FN( + args: Vec, + distinct: bool, + filter: Option>, + order_by: Option>, + null_treatment: Option + ) -> Expr { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), - vec![$($arg),*], - // TODO: Support arguments for `expr` API - false, - None, - None, - None, + args, + distinct, + filter, + order_by, + null_treatment, )) } @@ -50,4 +55,4 @@ macro_rules! make_udaf_function { } } } -} +} \ No newline at end of file diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index eee15008fbbb..f97559e03af2 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -613,7 +613,7 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), - first_value(lit(1)), + first_value(vec![lit(1)], false, None, None, None), ]; // ensure expressions created with the expr api can be round tripped