From 431df45e1a368a5db1c0844fb48b78f538765aa0 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 20 Apr 2024 08:31:44 +0800 Subject: [PATCH] support all args for first/last expr_fn Signed-off-by: jayzhan211 --- datafusion/core/src/prelude.rs | 1 + .../functions-aggregate/src/first_last.rs | 21 +------------------ datafusion/functions-aggregate/src/lib.rs | 1 + datafusion/functions-aggregate/src/macros.rs | 21 ++++++++++++------- .../src/replace_distinct_aggregate.rs | 10 ++------- .../tests/cases/roundtrip_logical_plan.rs | 4 ++-- 6 files changed, 20 insertions(+), 38 deletions(-) diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index d82a5a2cc1a1..0d8d06f49bc3 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -39,6 +39,7 @@ pub use datafusion_expr::{ Expr, }; pub use datafusion_functions::expr_fn::*; +pub use datafusion_functions_aggregate::expr_fn::*; #[cfg(feature = "array_expressions")] pub use datafusion_functions_array::expr_fn::*; diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 54db16c8ac2e..4cca1a5e5787 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{AggregateFunction, Sort}; +use datafusion_expr::expr::Sort; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; @@ -42,7 +42,6 @@ use std::fmt::Debug; make_udaf_function!( FirstValue, first_value, - value, "Returns the first value in a group of values.", first_value_udaf ); @@ -50,28 +49,10 @@ make_udaf_function!( make_udaf_function!( LastValue, last_value, - value, "Returns the last 
value in a group of values.", last_value_udaf ); -pub fn create_first_value_expr( - args: Vec<Expr>, - distinct: bool, - filter: Option<Box<Expr>>, - order_by: Option<Vec<Expr>>, - null_treatment: Option<NullTreatment>, -) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( - first_value_udaf(), - args, - distinct, - filter, - order_by, - null_treatment, - )) -} - pub struct FirstValue { signature: Signature, aliases: Vec<String>, diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index b4dae934f3dc..0b35a2a1575f 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -66,6 +66,7 @@ use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { pub use super::first_last::first_value; + pub use super::first_last::last_value; } /// Registers all enabled packages with a [`FunctionRegistry`] diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index d24c60f93270..04f9fecb8b19 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -16,19 +16,24 @@ // under the License. macro_rules! make_udaf_function { - ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { paste::paste!
{ // "fluent expr_fn" style function #[doc = $DOC] - pub fn $EXPR_FN($($arg: Expr),*) -> Expr { + pub fn $EXPR_FN( + args: Vec<Expr>, + distinct: bool, + filter: Option<Box<Expr>>, + order_by: Option<Vec<Expr>>, + null_treatment: Option<NullTreatment> + ) -> Expr { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), - vec![$($arg),*], - // TODO: Support arguments for `expr` API - false, - None, - None, - None, + args, + distinct, + filter, + order_by, + null_treatment, )) } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index cb7ecf0669ae..dddc0ea371b0 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -23,7 +23,7 @@ use datafusion_common::{Column, Result}; use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{col, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; -use datafusion_functions_aggregate::first_last::create_first_value_expr; +use datafusion_functions_aggregate::expr_fn::first_value; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -90,13 +90,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { let aggr_expr = select_expr .iter() .map(|e| { - create_first_value_expr( - vec![e.clone()], - false, - None, - sort_expr.clone(), - None, - ) + first_value(vec![e.clone()], false, None, sort_expr.clone(), None) }) .collect::<Vec<Expr>>(); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index eee15008fbbb..6f11f90ded7a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -30,7 +30,6 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use
datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::functions_aggregate::expr_fn::first_value; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -613,7 +612,8 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), - first_value(lit(1)), + first_value(vec![lit(1)], true, None, None, None), + last_value(vec![lit(1)], true, None, None, None), ]; // ensure expressions created with the expr api can be round tripped