Skip to content

Commit

Permalink
support all args for first/last expr_fn
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Apr 20, 2024
1 parent 8edad9e commit 431df45
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 38 deletions.
1 change: 1 addition & 0 deletions datafusion/core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub use datafusion_expr::{
Expr,
};
pub use datafusion_functions::expr_fn::*;
pub use datafusion_functions_aggregate::expr_fn::*;
#[cfg(feature = "array_expressions")]
pub use datafusion_functions_array::expr_fn::*;

Expand Down
21 changes: 1 addition & 20 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at
use datafusion_common::{
arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::{AggregateFunction, Sort};
use datafusion_expr::expr::Sort;
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
Expand All @@ -42,36 +42,17 @@ use std::fmt::Debug;
make_udaf_function!(
FirstValue,
first_value,
value,
"Returns the first value in a group of values.",
first_value_udaf
);

make_udaf_function!(
LastValue,
last_value,
value,
"Returns the last value in a group of values.",
last_value_udaf
);

/// Builds an [`Expr::AggregateFunction`] that invokes the UDAF returned by
/// `first_value_udaf()`.
///
/// `args`, `distinct`, `filter`, `order_by` and `null_treatment` are passed
/// through unchanged, in that order, to `AggregateFunction::new_udf`.
pub fn create_first_value_expr(
    args: Vec<Expr>,
    distinct: bool,
    filter: Option<Box<Expr>>,
    order_by: Option<Vec<Expr>>,
    null_treatment: Option<NullTreatment>,
) -> Expr {
    // Resolve the UDAF handle first, then assemble the aggregate call around it.
    let udaf = first_value_udaf();
    let aggregate =
        AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, null_treatment);
    Expr::AggregateFunction(aggregate)
}

pub struct FirstValue {
signature: Signature,
aliases: Vec<String>,
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ use std::sync::Arc;
/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
pub use super::first_last::first_value;
pub use super::first_last::last_value;
}

/// Registers all enabled packages with a [`FunctionRegistry`]
Expand Down
21 changes: 13 additions & 8 deletions datafusion/functions-aggregate/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,24 @@
// under the License.

macro_rules! make_udaf_function {
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
paste::paste! {
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN($($arg: Expr),*) -> Expr {
pub fn $EXPR_FN(
args: Vec<Expr>,
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
null_treatment: Option<NullTreatment>
) -> Expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
vec![$($arg),*],
// TODO: Support arguments for `expr` API
false,
None,
None,
None,
args,
distinct,
filter,
order_by,
null_treatment,
))
}

Expand Down
10 changes: 2 additions & 8 deletions datafusion/optimizer/src/replace_distinct_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use datafusion_common::{Column, Result};
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{col, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
use datafusion_functions_aggregate::first_last::create_first_value_expr;
use datafusion_functions_aggregate::expr_fn::first_value;

/// Optimizer that replaces logical [`Distinct`] with a logical [`Aggregate`]
///
Expand Down Expand Up @@ -90,13 +90,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
let aggr_expr = select_expr
.iter()
.map(|e| {
create_first_value_expr(
vec![e.clone()],
false,
None,
sort_expr.clone(),
None,
)
first_value(vec![e.clone()], false, None, sort_expr.clone(), None)
})
.collect::<Vec<Expr>>();

Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use datafusion::datasource::provider::TableProviderFactory;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::functions_aggregate::expr_fn::first_value;
use datafusion::prelude::*;
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::config::{FormatOptions, TableOptions};
Expand Down Expand Up @@ -613,7 +612,8 @@ async fn roundtrip_expr_api() -> Result<()> {
lit(1),
),
array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)),
first_value(lit(1)),
first_value(vec![lit(1)], true, None, None, None),
last_value(vec![lit(1)], true, None, None, None),
];

// ensure expressions created with the expr api can be round tripped
Expand Down

0 comments on commit 431df45

Please sign in to comment.