From 886e8accdaa85d7b3dca45340b955437786a9b6a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 24 Jul 2024 19:54:57 -0400 Subject: [PATCH] Consistent API to set parameters of aggregate and window functions (`AggregateExt` --> `ExprFunctionExt`) (#11550) * Moving over AggregateExt to ExprFunctionExt and adding in function settings for window functions * Switch WindowFrame to only need the window function definition and arguments. Other parameters will be set via the ExprFuncBuilder * Changing null_treatment to take an option, but this is mostly for code cleanliness and not strictly required * Moving functions in ExprFuncBuilder over to be explicitly implementing ExprFunctionExt trait so we can guarantee a consistent user experience no matter which they call on the Expr and which on the builder * Apply cargo fmt * Add deprecated trait AggregateExt so that users get a warning but still builds * Window helper functions should return Expr * Update documentation to show window function example * Add license info * Update comments that are no longer applicable * Remove first_value and last_value since these are already implemented in the aggregate functions * Update to use WindowFunction::new to set additional parameters for order_by using ExprFunctionExt * Apply cargo fmt * Fix up clippy * fix doc example * fmt * doc tweaks * more doc tweaks * fix up links * fix integration test * fix anothr doc example --------- Co-authored-by: Tim Saucer Co-authored-by: Andrew Lamb --- datafusion-examples/examples/advanced_udwf.rs | 12 +- datafusion-examples/examples/expr_api.rs | 4 +- datafusion-examples/examples/simple_udwf.rs | 12 +- datafusion/core/src/dataframe/mod.rs | 13 +- datafusion/core/tests/dataframe/mod.rs | 22 +- datafusion/core/tests/expr_api/mod.rs | 2 +- datafusion/expr/src/expr.rs | 85 ++++-- datafusion/expr/src/expr_fn.rs | 279 +++++++++++++++++- datafusion/expr/src/lib.rs | 3 +- datafusion/expr/src/tree_node.rs | 17 +- datafusion/expr/src/udaf.rs | 177 +---------- datafusion/expr/src/udwf.rs | 47 ++- datafusion/expr/src/utils.rs | 89 +++--- datafusion/expr/src/window_function.rs | 99 +++++++ .../functions-aggregate/src/first_last.rs | 4 +- .../src/analyzer/count_wildcard_rule.rs | 18 +- .../optimizer/src/analyzer/type_coercion.rs | 21 +- .../optimizer/src/optimize_projections/mod.rs | 17 +- .../src/replace_distinct_aggregate.rs | 2 +- .../simplify_expressions/expr_simplifier.rs | 24 +- .../src/single_distinct_to_groupby.rs | 2 +- .../proto/src/logical_plan/from_proto.rs | 46 +-- .../tests/cases/roundtrip_logical_plan.rs | 77 ++--- datafusion/sql/src/expr/function.rs | 25 +- datafusion/sql/src/unparser/expr.rs | 2 +- docs/source/user-guide/expressions.md | 2 +- 26 files changed, 657 insertions(+), 444 deletions(-) create mode 100644 datafusion/expr/src/window_function.rs diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 11fb6f6ccc48..ec0318a561b9 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -216,12 +216,12 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it.call( - vec![col("speed")], // smooth_it(speed) - vec![col("car")], // PARTITION BY car - vec![col("time").sort(true, true)], // ORDER BY time ASC - WindowFrame::new(None), - ); + let window_expr = smooth_it + .call(vec![col("speed")]) // smooth_it(speed) + .partition_by(vec![col("car")]) // PARTITION BY car + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + .window_frame(WindowFrame::new(None)) + .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; // print the results diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index a48171c625a8..0eb823302acf 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -33,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator}; +use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. /// @@ -95,7 +95,7 @@ fn expr_fn_demo() -> Result<()> { let agg = first_value.call(vec![col("price")]); assert_eq!(agg.to_string(), "first_value(price)"); - // You can use the AggregateExt trait to create more complex aggregates + // You can use the ExprFunctionExt trait to create more complex aggregates // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) let agg = first_value .call(vec![col("price")]) diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index 563f02cee6a6..22dfbbbf0c3a 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -118,12 +118,12 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it.call( - vec![col("speed")], // smooth_it(speed) - vec![col("car")], // PARTITION BY car - vec![col("time").sort(true, true)], // ORDER BY time ASC - WindowFrame::new(None), - ); + let window_expr = smooth_it + .call(vec![col("speed")]) // smooth_it(speed) + .partition_by(vec![col("car")]) // PARTITION BY car + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + .window_frame(WindowFrame::new(None)) + .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; // print the results diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index fb28b5c1ab47..ea437cc99a33 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,8 +1696,8 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, - Volatility, WindowFrame, WindowFunctionDefinition, + cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, + ScalarFunctionImplementation, Volatility, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; @@ -1867,11 +1867,10 @@ mod tests { BuiltInWindowFunction::FirstValue, ), vec![col("aggregate_test_100.c1")], - vec![col("aggregate_test_100.c2")], - vec![], - WindowFrame::new(None), - None, - )); + )) + .partition_by(vec![col("aggregate_test_100.c2")]) + .build() + .unwrap(); let t2 = t.select(vec![col("c1"), first_row])?; let plan = t2.plan.clone(); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index bc01ada1e04b..d83a47ceb069 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -55,8 +55,8 @@ use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum}; @@ -183,15 +183,15 @@ async fn test_count_wildcard_on_window() -> Result<()> { .select(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], - WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap()])? .explain(false, false)? .collect() .await?; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 37d06355d2d3..051d65652633 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -21,7 +21,7 @@ use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field}; use datafusion::prelude::*; use datafusion_common::{assert_contains, DFSchema, ScalarValue}; -use datafusion_expr::AggregateExt; +use datafusion_expr::ExprFunctionExt; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 452c05be34f4..68d5504eea48 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,8 +28,8 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::{ - aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator, - Signature, + aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, + ExprSchemable, Operator, Signature, WindowFrame, WindowUDF, }; use crate::{window_frame, Volatility}; @@ -60,6 +60,10 @@ use sqlparser::ast::NullTreatment; /// use the fluent APIs in [`crate::expr_fn`] such as [`col`] and [`lit`], or /// methods such as [`Expr::alias`], [`Expr::cast_to`], and [`Expr::Like`]). /// +/// See also [`ExprFunctionExt`] for creating aggregate and window functions. +/// +/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt +/// /// # Schema Access /// /// See [`ExprSchemable::get_type`] to access the [`DataType`] and nullability @@ -283,15 +287,17 @@ pub enum Expr { /// This expression is guaranteed to have a fixed type. TryCast(TryCast), /// A sort expression, that can be used to sort values. + /// + /// See [Expr::sort] for more details Sort(Sort), /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), /// Calls an aggregate function with arguments, and optional /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. /// - /// See also [`AggregateExt`] to set these fields. + /// See also [`ExprFunctionExt`] to set these fields. /// - /// [`AggregateExt`]: crate::udaf::AggregateExt + /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), @@ -641,9 +647,9 @@ impl AggregateFunctionDefinition { /// Aggregate function /// -/// See also [`AggregateExt`] to set these fields on `Expr` +/// See also [`ExprFunctionExt`] to set these fields on `Expr` /// -/// [`AggregateExt`]: crate::udaf::AggregateExt +/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function @@ -769,7 +775,52 @@ impl fmt::Display for WindowFunctionDefinition { } } +impl From for WindowFunctionDefinition { + fn from(value: aggregate_function::AggregateFunction) -> Self { + Self::AggregateFunction(value) + } +} + +impl From for WindowFunctionDefinition { + fn from(value: BuiltInWindowFunction) -> Self { + Self::BuiltInWindowFunction(value) + } +} + +impl From> for WindowFunctionDefinition { + fn from(value: Arc) -> Self { + Self::AggregateUDF(value) + } +} + +impl From> for WindowFunctionDefinition { + fn from(value: Arc) -> Self { + Self::WindowUDF(value) + } +} + /// Window function +/// +/// Holds the actual actual function to call [`WindowFunction`] as well as its +/// arguments (`args`) and the contents of the `OVER` clause: +/// +/// 1. `PARTITION BY` +/// 2. `ORDER BY` +/// 3. Window frame (e.g. `ROWS 1 PRECEDING AND 1 FOLLOWING`) +/// +/// # Example +/// ``` +/// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt}; +/// # use datafusion_expr::expr::WindowFunction; +/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c) +/// let expr = Expr::WindowFunction( +/// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")]) +/// ) +/// .partition_by(vec![col("b")]) +/// .order_by(vec![col("b").sort(true, true)]) +/// .build() +/// .unwrap(); +/// ``` #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { /// Name of the function @@ -787,22 +838,16 @@ pub struct WindowFunction { } impl WindowFunction { - /// Create a new Window expression - pub fn new( - fun: WindowFunctionDefinition, - args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: window_frame::WindowFrame, - null_treatment: Option, - ) -> Self { + /// Create a new Window expression with the specified argument an + /// empty `OVER` clause + pub fn new(fun: impl Into, args: Vec) -> Self { Self { - fun, + fun: fun.into(), args, - partition_by, - order_by, - window_frame, - null_treatment, + partition_by: Vec::default(), + order_by: Vec::default(), + window_frame: WindowFrame::new(None), + null_treatment: None, } } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9187e8352205..1f51cded2239 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,7 +19,7 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - Placeholder, TryCast, Unnest, + Placeholder, TryCast, Unnest, WindowFunction, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, @@ -30,12 +30,15 @@ use crate::{ AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; -use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; +use crate::{ + AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, +}; use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{Column, Result, ScalarValue}; +use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::Debug; use std::ops::Not; @@ -664,6 +667,276 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) } +/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] +/// +/// Adds methods to [`Expr`] that make it easy to set optional options +/// such as `ORDER BY`, `FILTER` and `DISTINCT` +/// +/// # Example +/// ```no_run +/// # use datafusion_common::Result; +/// # use datafusion_expr::test::function_stub::count; +/// # use sqlparser::ast::NullTreatment; +/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col}; +/// # use datafusion_expr::window_function::percent_rank; +/// # // first_value is an aggregate function in another crate +/// # fn first_value(_arg: Expr) -> Expr { +/// unimplemented!() } +/// # fn main() -> Result<()> { +/// // Create an aggregate count, filtering on column y > 5 +/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?; +/// +/// // Find the first value in an aggregate sorted by column y +/// // equivalent to: +/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)` +/// let sort_expr = col("y").sort(true, true); +/// let agg = first_value(col("x")) +/// .order_by(vec![sort_expr]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// +/// // Create a window expression for percent rank partitioned on column a +/// // equivalent to: +/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)` +/// let window = percent_rank() +/// .partition_by(vec![col("a")]) +/// .order_by(vec![col("b").sort(true, true)]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// # Ok(()) +/// # } +/// ``` +pub trait ExprFunctionExt { + /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] + fn order_by(self, order_by: Vec) -> ExprFuncBuilder; + /// Add `FILTER ` + fn filter(self, filter: Expr) -> ExprFuncBuilder; + /// Add `DISTINCT` + fn distinct(self) -> ExprFuncBuilder; + /// Add `RESPECT NULLS` or `IGNORE NULLS` + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder; + /// Add `PARTITION BY` + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; + /// Add appropriate window frame conditions + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder; +} + +#[derive(Debug, Clone)] +pub enum ExprFuncKind { + Aggregate(AggregateFunction), + Window(WindowFunction), +} + +/// Implementation of [`ExprFunctionExt`]. +/// +/// See [`ExprFunctionExt`] for usage and examples +#[derive(Debug, Clone)] +pub struct ExprFuncBuilder { + fun: Option, + order_by: Option>, + filter: Option, + distinct: bool, + null_treatment: Option, + partition_by: Option>, + window_frame: Option, +} + +impl ExprFuncBuilder { + /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`] + fn new(fun: Option) -> Self { + Self { + fun, + order_by: None, + filter: None, + distinct: false, + null_treatment: None, + partition_by: None, + window_frame: None, + } + } + + /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] + /// + /// # Errors: + /// + /// Returns an error if this builder [`ExprFunctionExt`] was used with an + /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] + pub fn build(self) -> Result { + let Self { + fun, + order_by, + filter, + distinct, + null_treatment, + partition_by, + window_frame, + } = self; + + let Some(fun) = fun else { + return plan_err!( + "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction" + ); + }; + + if let Some(order_by) = &order_by { + for expr in order_by.iter() { + if !matches!(expr, Expr::Sort(_)) { + return plan_err!( + "ORDER BY expressions must be Expr::Sort, found {expr:?}" + ); + } + } + } + + let fun_expr = match fun { + ExprFuncKind::Aggregate(mut udaf) => { + udaf.order_by = order_by; + udaf.filter = filter.map(Box::new); + udaf.distinct = distinct; + udaf.null_treatment = null_treatment; + Expr::AggregateFunction(udaf) + } + ExprFuncKind::Window(mut udwf) => { + let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); + udwf.order_by = order_by.unwrap_or_default(); + udwf.partition_by = partition_by.unwrap_or_default(); + udwf.window_frame = + window_frame.unwrap_or(WindowFrame::new(has_order_by)); + udwf.null_treatment = null_treatment; + Expr::WindowFunction(udwf) + } + }; + + Ok(fun_expr) + } +} + +impl ExprFunctionExt for ExprFuncBuilder { + /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] + fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { + self.order_by = Some(order_by); + self + } + + /// Add `FILTER ` + fn filter(mut self, filter: Expr) -> ExprFuncBuilder { + self.filter = Some(filter); + self + } + + /// Add `DISTINCT` + fn distinct(mut self) -> ExprFuncBuilder { + self.distinct = true; + self + } + + /// Add `RESPECT NULLS` or `IGNORE NULLS` + fn null_treatment( + mut self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { + self.null_treatment = null_treatment.into(); + self + } + + fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { + self.partition_by = Some(partition_by); + self + } + + fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { + self.window_frame = Some(window_frame); + self + } +} + +impl ExprFunctionExt for Expr { + fn order_by(self, order_by: Vec) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.order_by = Some(order_by); + } + builder + } + fn filter(self, filter: Expr) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.filter = Some(filter); + builder + } + _ => ExprFuncBuilder::new(None), + } + } + fn distinct(self) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.distinct = true; + builder + } + _ => ExprFuncBuilder::new(None), + } + } + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.null_treatment = null_treatment.into(); + } + builder + } + + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.partition_by = Some(partition_by); + builder + } + _ => ExprFuncBuilder::new(None), + } + } + + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.window_frame = Some(window_frame); + builder + } + _ => ExprFuncBuilder::new(None), + } + } +} + #[cfg(test)] mod test { use super::*; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e1943c890e7c..0a5cf4653a22 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -60,6 +60,7 @@ pub mod type_coercion; pub mod utils; pub mod var_provider; pub mod window_frame; +pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; @@ -86,7 +87,7 @@ pub use signature::{ }; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF}; +pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f1df8609f903..a97b9f010f79 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -22,7 +22,7 @@ use crate::expr::{ Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; -use crate::Expr; +use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, @@ -294,14 +294,13 @@ impl TreeNode for Expr { transform_vec(order_by, &mut f) )? .update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new( - fun, - new_args, - new_partition_by, - new_order_by, - window_frame, - null_treatment, - )) + Expr::WindowFunction(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() }), Expr::AggregateFunction(AggregateFunction { args, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 2851ca811e0c..8867a478f790 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -24,9 +24,8 @@ use std::sync::Arc; use std::vec; use arrow::datatypes::{DataType, Field}; -use sqlparser::ast::NullTreatment; -use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; use crate::expr::AggregateFunction; use crate::function::{ @@ -655,177 +654,3 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } - -/// Extensions for configuring [`Expr::AggregateFunction`] -/// -/// Adds methods to [`Expr`] that make it easy to set optional aggregate options -/// such as `ORDER BY`, `FILTER` and `DISTINCT` -/// -/// # Example -/// ```no_run -/// # use datafusion_common::Result; -/// # use datafusion_expr::{AggregateUDF, col, Expr, lit}; -/// # use sqlparser::ast::NullTreatment; -/// # fn count(arg: Expr) -> Expr { todo!{} } -/// # fn first_value(arg: Expr) -> Expr { todo!{} } -/// # fn main() -> Result<()> { -/// use datafusion_expr::AggregateExt; -/// -/// // Create COUNT(x FILTER y > 5) -/// let agg = count(col("x")) -/// .filter(col("y").gt(lit(5))) -/// .build()?; -/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) -/// let sort_expr = col("y").sort(true, true); -/// let agg = first_value(col("x")) -/// .order_by(vec![sort_expr]) -/// .null_treatment(NullTreatment::IgnoreNulls) -/// .build()?; -/// # Ok(()) -/// # } -/// ``` -pub trait AggregateExt { - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(self, order_by: Vec) -> AggregateBuilder; - /// Add `FILTER ` - fn filter(self, filter: Expr) -> AggregateBuilder; - /// Add `DISTINCT` - fn distinct(self) -> AggregateBuilder; - /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder; -} - -/// Implementation of [`AggregateExt`]. -/// -/// See [`AggregateExt`] for usage and examples -#[derive(Debug, Clone)] -pub struct AggregateBuilder { - udaf: Option, - order_by: Option>, - filter: Option, - distinct: bool, - null_treatment: Option, -} - -impl AggregateBuilder { - /// Create a new `AggregateBuilder`, see [`AggregateExt`] - - fn new(udaf: Option) -> Self { - Self { - udaf, - order_by: None, - filter: None, - distinct: false, - null_treatment: None, - } - } - - /// Updates and returns the in progress [`Expr::AggregateFunction`] - /// - /// # Errors: - /// - /// Returns an error of this builder [`AggregateExt`] was used with an - /// `Expr` variant other than [`Expr::AggregateFunction`] - pub fn build(self) -> Result { - let Self { - udaf, - order_by, - filter, - distinct, - null_treatment, - } = self; - - let Some(mut udaf) = udaf else { - return plan_err!( - "AggregateExt can only be used with Expr::AggregateFunction" - ); - }; - - if let Some(order_by) = &order_by { - for expr in order_by.iter() { - if !matches!(expr, Expr::Sort(_)) { - return plan_err!( - "ORDER BY expressions must be Expr::Sort, found {expr:?}" - ); - } - } - } - - udaf.order_by = order_by; - udaf.filter = filter.map(Box::new); - udaf.distinct = distinct; - udaf.null_treatment = null_treatment; - Ok(Expr::AggregateFunction(udaf)) - } - - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - pub fn order_by(mut self, order_by: Vec) -> AggregateBuilder { - self.order_by = Some(order_by); - self - } - - /// Add `FILTER ` - pub fn filter(mut self, filter: Expr) -> AggregateBuilder { - self.filter = Some(filter); - self - } - - /// Add `DISTINCT` - pub fn distinct(mut self) -> AggregateBuilder { - self.distinct = true; - self - } - - /// Add `RESPECT NULLS` or `IGNORE NULLS` - pub fn null_treatment(mut self, null_treatment: NullTreatment) -> AggregateBuilder { - self.null_treatment = Some(null_treatment); - self - } -} - -impl AggregateExt for Expr { - fn order_by(self, order_by: Vec) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.order_by = Some(order_by); - builder - } - _ => AggregateBuilder::new(None), - } - } - fn filter(self, filter: Expr) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.filter = Some(filter); - builder - } - _ => AggregateBuilder::new(None), - } - } - fn distinct(self) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.distinct = true; - builder - } - _ => AggregateBuilder::new(None), - } - } - fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.null_treatment = Some(null_treatment); - builder - } - _ => AggregateBuilder::new(None), - } - } -} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 1a6b21e3dd29..5abce013dfb6 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -28,9 +28,10 @@ use arrow::datatypes::DataType; use datafusion_common::Result; +use crate::expr::WindowFunction; use crate::{ function::WindowFunctionSimplification, Expr, PartitionEvaluator, - PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, + PartitionEvaluatorFactory, ReturnTypeFunction, Signature, }; /// Logical representation of a user-defined window function (UDWF) @@ -123,28 +124,19 @@ impl WindowUDF { Self::new_from_impl(AliasedWindowUDFImpl::new(Arc::clone(&self.inner), aliases)) } - /// creates a [`Expr`] that calls the window function given - /// the `partition_by`, `order_by`, and `window_frame` definition + /// creates a [`Expr`] that calls the window function with default + /// values for `order_by`, `partition_by`, `window_frame`. /// - /// This utility allows using the UDWF without requiring access to - /// the registry, such as with the DataFrame API. - pub fn call( - &self, - args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: WindowFrame, - ) -> Expr { + /// See [`ExprFunctionExt`] for details on setting these values. + /// + /// This utility allows using a user defined window function without + /// requiring access to the registry, such as with the DataFrame API. + /// + /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt + pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(crate::expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: None, - }) + Expr::WindowFunction(WindowFunction::new(fun, args)) } /// Returns this function's name @@ -210,7 +202,7 @@ where /// # use std::any::Any; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; /// #[derive(Debug, Clone)] /// struct SmoothIt { @@ -244,12 +236,13 @@ where /// let smooth_it = WindowUDF::from(SmoothIt::new()); /// /// // Call the function `add_one(col)` -/// let expr = smooth_it.call( -/// vec![col("speed")], // smooth_it(speed) -/// vec![col("car")], // PARTITION BY car -/// vec![col("time").sort(true, true)], // ORDER BY time ASC -/// WindowFrame::new(None), -/// ); +/// // smooth_it(speed) OVER (PARTITION BY car ORDER BY time ASC) +/// let expr = smooth_it.call(vec![col("speed")]) +/// .partition_by(vec![col("car")]) +/// .order_by(vec![col("time").sort(true, true)]) +/// .window_frame(WindowFrame::new(None)) +/// .build() +/// .unwrap(); /// ``` pub trait WindowUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 889aa0952e51..2ef1597abfd1 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1253,8 +1253,8 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::sum_udaf, AggregateFunction, Cast, WindowFrame, - WindowFunctionDefinition, + test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1270,34 +1270,18 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1317,35 +1301,32 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], - vec![], - vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + )) + .order_by(vec![ + name_desc.clone(), + age_asc.clone(), + created_at_desc.clone(), + ]) + .build() + .unwrap(); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1373,26 +1354,26 @@ mod tests { Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), + )) + .order_by(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), + ]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(), Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], - vec![], - vec![ - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), + )) + .order_by(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), + Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), + ]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(), ]; let expected = vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs new file mode 100644 index 000000000000..5e81464d39c2 --- /dev/null +++ b/datafusion/expr/src/window_function.rs @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::ScalarValue; + +use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; + +/// Create an expression to represent the `row_number` window function +pub fn row_number() -> Expr { + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::RowNumber, + vec![], + )) +} + +/// Create an expression to represent the `rank` window function +pub fn rank() -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Rank, vec![])) +} + +/// Create an expression to represent the `dense_rank` window function +pub fn dense_rank() -> Expr { + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::DenseRank, + vec![], + )) +} + +/// Create an expression to represent the `percent_rank` window function +pub fn percent_rank() -> Expr { + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::PercentRank, + vec![], + )) +} + +/// Create an expression to represent the `cume_dist` window function +pub fn cume_dist() -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![])) +} + +/// Create an expression to represent the `ntile` window function +pub fn ntile(arg: Expr) -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) +} + +/// Create an expression to represent the `lag` window function +pub fn lag( + arg: Expr, + shift_offset: Option, + default_value: Option, +) -> Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::Lag, + vec![arg, shift_offset_lit, default_lit], + )) +} + +/// Create an expression to represent the `lead` window function +pub fn lead( + arg: Expr, + shift_offset: Option, + default_value: Option, +) -> Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::Lead, + vec![arg, shift_offset_lit, default_lit], + )) +} + +/// Create an expression to represent the `nth_value` window function +pub fn nth_value(arg: Expr, n: i64) -> Expr { + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::NthValue, + vec![arg, n.lit()], + )) +} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index ba11f7e91e07..8969937d377c 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,8 +31,8 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, - TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, + Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fa8aeb86ed31..338268e299da 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -101,6 +101,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, @@ -223,15 +224,14 @@ mod tests { .window(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], - WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build()?])? .project(vec![count(wildcard())])? .build()?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 50fb1b8193ce..75dbb4d1adcd 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -47,8 +47,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - type_coercion, AggregateFunction, AggregateUDF, Expr, ExprSchemable, LogicalPlan, - Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, + type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, + LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -466,14 +467,14 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => args, }; - Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - )))) + Ok(Transformed::yes( + Expr::WindowFunction(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build()?, + )) } Expr::Alias(_) | Expr::Column(_) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 58c1ae297b02..16abf93f3807 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -806,7 +806,7 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; - use datafusion_expr::AggregateExt; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, @@ -815,7 +815,7 @@ mod tests { lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, - Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFrame, + Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; @@ -1919,19 +1919,14 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.a")], - vec![col("test.b")], - vec![], - WindowFrame::new(None), - None, - )); + )) + .partition_by(vec![col("test.b")]) + .build() + .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.b")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let col1 = col(max1.display_name()?); let col2 = col(max2.display_name()?); diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index fcd33be618f7..430517121f2a 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -23,7 +23,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder}; +use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 56556f387d1b..38dfbb3ed551 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3855,15 +3855,9 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( - udwf, - vec![], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + let window_function_expr = Expr::WindowFunction( + datafusion_expr::expr::WindowFunction::new(udwf, vec![]), + ); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -3871,15 +3865,9 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( - udwf, - vec![], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + let window_function_expr = Expr::WindowFunction( + datafusion_expr::expr::WindowFunction::new(udwf, vec![]), + ); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index f2b4abdd6cbd..d776e6598cbe 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -354,7 +354,7 @@ mod tests { use super::*; use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; - use datafusion_expr::AggregateExt; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index aea8e454a31c..7b717add3311 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -25,6 +25,7 @@ use datafusion_common::{ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; +use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, @@ -299,7 +300,6 @@ pub fn parse_expr( ) })?; // TODO: support proto for null treatment - let null_treatment = None; regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { @@ -314,11 +314,12 @@ pub fn parse_expr( "expr", codec, )?], - partition_by, - order_by, - window_frame, - None, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) @@ -335,11 +336,12 @@ pub fn parse_expr( built_in_function, ), args, - partition_by, - order_by, - window_frame, - null_treatment, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { @@ -354,11 +356,12 @@ pub fn parse_expr( Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, - partition_by, - order_by, - window_frame, - None, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { @@ -373,11 +376,12 @@ pub fn parse_expr( Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, - partition_by, - order_by, - window_frame, - None, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 25223c3731be..7a4de4f61a38 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,7 +60,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateFunction, AggregateUDF, ColumnarValue, + Accumulator, AggregateFunction, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, @@ -2073,11 +2073,12 @@ fn roundtrip_window() { datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, false)]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( @@ -2085,11 +2086,12 @@ fn roundtrip_window() { datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(false, true)]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 3. with window_frame with row numbers let range_number_frame = WindowFrame::new_bounds( @@ -2103,11 +2105,12 @@ fn roundtrip_window() { datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], - vec![col("col1")], - vec![col("col2")], - range_number_frame, - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(false, false)]) + .window_frame(range_number_frame) + .build() + .unwrap(); // 4. test with AggregateFunction let row_number_frame = WindowFrame::new_bounds( @@ -2119,11 +2122,12 @@ fn roundtrip_window() { let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, true)]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); // 5. test with AggregateUDF #[derive(Debug)] @@ -2168,11 +2172,12 @@ fn roundtrip_window() { let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, true)]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); ctx.register_udaf(dummy_agg); // 6. test with WindowUDF @@ -2244,20 +2249,20 @@ fn roundtrip_window() { let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, true)]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], - vec![], - vec![], - row_number_frame.clone(), - None, - )); + )) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); ctx.register_udwf(dummy_window_udf); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 0c4b125e76d0..fd759c161381 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -24,7 +24,8 @@ use datafusion_common::{ use datafusion_expr::planner::PlannerResult; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, + expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + WindowFunctionDefinition, }; use datafusion_expr::{ expr::{ScalarFunction, Unnest}, @@ -329,20 +330,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(aggregate_fun), args, - partition_by, - order_by, - window_frame, - null_treatment, )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() } _ => Expr::WindowFunction(expr::WindowFunction::new( fun, self.function_args_to_expr(args, schema, planner_context)?, - partition_by, - order_by, - window_frame, - null_treatment, - )), + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap(), }; return Ok(expr); } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index f4ea44f37d78..3f7a85da276b 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1507,7 +1507,7 @@ mod tests { table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_expr::{interval_month_day_nano_lit, AggregateExt}; + use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 6e693a0e7087..60036e440ffb 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -308,7 +308,7 @@ select log(-1), log(0), sqrt(-1); ## Aggregate Function Builder -You can also use the `AggregateExt` trait to more easily build Aggregate arguments `Expr`. +You can also use the `ExprFunctionExt` trait to more easily build Aggregate arguments `Expr`. See `datafusion-examples/examples/expr_api.rs` for example usage.