diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 3760328934bc..4ec7d8bdb997 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -29,14 +29,20 @@ use std::{ use arrow::array::AsArray; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; -use arrow_schema::{DataType, Field}; +use arrow_schema::{DataType, Field, Schema}; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, + PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl, }; -use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_functions_window_common::{ + expr::ExpressionArgs, field::WindowUDFFieldArgs, +}; +use datafusion_physical_expr::{ + expressions::{col, lit}, + PhysicalExpr, +}; /// A query with a window function evaluated over the entire partition const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ @@ -645,3 +651,120 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect(); Arc::new(array) } + +#[derive(Debug)] +struct VariadicWindowUDF { + signature: Signature, +} + +impl VariadicWindowUDF { + fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Any(0), + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for VariadicWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "variadic_window_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _: PartitionEvaluatorArgs, + ) -> Result> { + unimplemented!("unnecessary for testing"); + } + + fn field(&self, _: WindowUDFFieldArgs) -> Result { + unimplemented!("unnecessary for testing"); + } +} + +#[test] +// Fixes: default implementation of `WindowUDFImpl::expressions` +// returns all input expressions to the user-defined window +// function unmodified. +// +// See: https://github.com/apache/datafusion/pull/13169 +fn test_default_expressions() -> Result<()> { + let udwf = WindowUDF::from(VariadicWindowUDF::new()); + + let field_a = Field::new("a", DataType::Int32, false); + let field_b = Field::new("b", DataType::Float32, false); + let field_c = Field::new("c", DataType::Boolean, false); + let schema = Schema::new(vec![field_a, field_b, field_c]); + + let test_cases = vec![ + // + // Zero arguments + // + vec![], + // + // Single argument + // + vec![col("a", &schema)?], + vec![lit(1)], + // + // Two arguments + // + vec![col("a", &schema)?, col("b", &schema)?], + vec![col("a", &schema)?, lit(2)], + vec![lit(false), col("a", &schema)?], + // + // Three arguments + // + vec![col("a", &schema)?, col("b", &schema)?, col("c", &schema)?], + vec![col("a", &schema)?, col("b", &schema)?, lit(false)], + vec![col("a", &schema)?, lit(0.5), col("c", &schema)?], + vec![lit(3), col("b", &schema)?, col("c", &schema)?], + ]; + + for input_exprs in &test_cases { + let input_types = input_exprs + .iter() + .map(|expr: &Arc| expr.data_type(&schema).unwrap()) + .collect::>(); + let expr_args = ExpressionArgs::new(input_exprs, &input_types); + + let ret_exprs = udwf.expressions(expr_args); + + // Verify same number of input expressions are returned + assert_eq!( + input_exprs.len(), + ret_exprs.len(), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + + // Compares each returned expression with original input expressions + for (expected, actual) in input_exprs.iter().zip(&ret_exprs) { + assert_eq!( + format!("{expected:?}"), + format!("{actual:?}"), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + } + } + Ok(()) +} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 6ab94c1e841a..124625280670 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -312,10 +312,7 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// Returns the expressions that are passed to the [`PartitionEvaluator`]. fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { - expr_args - .input_exprs() - .first() - .map_or(vec![], |expr| vec![Arc::clone(expr)]) + expr_args.input_exprs().into() } /// Invoke the function, returning the [`PartitionEvaluator`] instance