From b0850246351fd2483225a2a8db4534a75441f000 Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Tue, 29 Oct 2024 11:20:13 -0500 Subject: [PATCH 1/5] fix: default UDWFImpl::expressions returns all expressions --- datafusion/expr/src/udwf.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 6ab94c1e841a..124625280670 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -312,10 +312,7 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// Returns the expressions that are passed to the [`PartitionEvaluator`]. fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> { - expr_args - .input_exprs() - .first() - .map_or(vec![], |expr| vec![Arc::clone(expr)]) + expr_args.input_exprs().into() } /// Invoke the function, returning the [`PartitionEvaluator`] instance From 3541c349f6308c01f5f0649c86964302459d432d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 29 Oct 2024 20:16:17 -0400 Subject: [PATCH 2/5] Add unit test to check for window function inputs --- .../user_defined_window_functions.rs | 73 ++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 3760328934bc..20dbc69abce9 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -35,8 +35,11 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; -use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_functions_window_common::{ + expr::ExpressionArgs, field::WindowUDFFieldArgs, +}; +use datafusion_physical_expr::expressions::lit; /// A query with a 
window function evaluated over the entire partition const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ @@ -645,3 +648,71 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect(); Arc::new(array) } + +#[derive(Debug)] +struct ThreeArgWindowUDF { + signature: Signature, +} + +impl ThreeArgWindowUDF { + fn new() -> Self { + Self { + signature: Signature::uniform( + 3, + vec![DataType::Int32, DataType::Boolean, DataType::Float32], + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for ThreeArgWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "three_arg_window_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _: PartitionEvaluatorArgs, + ) -> Result<Box<dyn PartitionEvaluator>> { + todo!() + } + + fn field(&self, _: WindowUDFFieldArgs) -> Result<Field> { + todo!() + } +} + +#[test] +fn test_input_expressions() -> Result<()> { + let udwf = WindowUDF::from(ThreeArgWindowUDF::new()); + + let input_exprs = vec![lit(1), lit(false), lit(0.5)]; // Vec<Arc<dyn PhysicalExpr>> + let input_types = [DataType::Int32, DataType::Boolean, DataType::Float32]; // Vec<DataType> + let actual = udwf.expressions(ExpressionArgs::new(&input_exprs, &input_types)); + + assert_eq!(actual.len(), 3); + + assert_eq!( + format!("{:?}", actual.first().unwrap()), + format!("{:?}", input_exprs.first().unwrap()), + ); + assert_eq!( + format!("{:?}", actual.get(1).unwrap()), + format!("{:?}", input_exprs.get(1).unwrap()) + ); + assert_eq!( + format!("{:?}", actual.get(2).unwrap()), + format!("{:?}", input_exprs.get(2).unwrap()) + ); + + Ok(()) +} From f8fa38c6a12f0db226aaab6ded68d56d18caa385 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 30 Oct 2024 12:49:33 -0400 Subject: [PATCH 3/5] Add unit test to catch errors in udwf with multiple column arguments --- .../user_defined_window_functions.rs | 118 +++++++++++++----- 1 file changed, 86 
insertions(+), 32 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 20dbc69abce9..83368f5921b0 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -29,17 +29,20 @@ use std::{ use arrow::array::AsArray; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; -use arrow_schema::{DataType, Field}; +use arrow_schema::{DataType, Field, Schema}; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, + PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_functions_window_common::{ expr::ExpressionArgs, field::WindowUDFFieldArgs, }; -use datafusion_physical_expr::expressions::lit; +use datafusion_physical_expr::{ + expressions::{col, lit}, + PhysicalExpr, +}; /// A query with a window function evaluated over the entire partition const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ @@ -650,29 +653,33 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { } #[derive(Debug)] -struct ThreeArgWindowUDF { +struct VariadicWindowUDF { signature: Signature, } -impl ThreeArgWindowUDF { +impl VariadicWindowUDF { fn new() -> Self { Self { - signature: Signature::uniform( - 3, - vec![DataType::Int32, DataType::Boolean, DataType::Float32], + signature: Signature::one_of( + vec![ + TypeSignature::Any(0), + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], Volatility::Immutable, ), } } } -impl WindowUDFImpl for ThreeArgWindowUDF { +impl WindowUDFImpl for VariadicWindowUDF { fn as_any(&self) -> &dyn Any { self } fn name(&self) -> 
&str { - "three_arg_window_udf" + "variadic_window_udf" } fn signature(&self) -> &Signature { @@ -683,36 +690,83 @@ impl WindowUDFImpl for ThreeArgWindowUDF { &self, _: PartitionEvaluatorArgs, ) -> Result<Box<dyn PartitionEvaluator>> { - todo!() + unimplemented!("unnecessary for testing"); } fn field(&self, _: WindowUDFFieldArgs) -> Result<Field> { - todo!() + unimplemented!("unnecessary for testing"); } } #[test] -fn test_input_expressions() -> Result<()> { - let udwf = WindowUDF::from(ThreeArgWindowUDF::new()); - - let input_exprs = vec![lit(1), lit(false), lit(0.5)]; // Vec<Arc<dyn PhysicalExpr>> - let input_types = [DataType::Int32, DataType::Boolean, DataType::Float32]; // Vec<DataType> - let actual = udwf.expressions(ExpressionArgs::new(&input_exprs, &input_types)); - - assert_eq!(actual.len(), 3); +// Fixes: default implementation of `WindowUDFImpl::expressions` +// returns all input expressions to the user-defined window +// function unmodified. +// +// See: https://github.com/apache/datafusion/pull/13169 +fn test_default_expressions() -> Result<()> { + let udwf = WindowUDF::from(VariadicWindowUDF::new()); + + let field_a = Field::new("a", DataType::Int32, false); + let field_b = Field::new("b", DataType::Float32, false); + let field_c = Field::new("c", DataType::Boolean, false); + let schema = Schema::new(vec![field_a, field_b, field_c]); + + let test_cases = vec![ + // + // Zero arguments + // + vec![], + // + // Single argument + // + vec![col("a", &schema)?], + vec![lit(1)], + // + // Two arguments + // + vec![col("a", &schema)?, col("b", &schema)?], + vec![col("a", &schema)?, lit(2)], + vec![lit(false), col("a", &schema)?], + // + // Three arguments + // + vec![col("a", &schema)?, col("b", &schema)?, col("c", &schema)?], + vec![col("a", &schema)?, col("b", &schema)?, lit(false)], + vec![col("a", &schema)?, lit(0.5), col("c", &schema)?], + vec![lit(3), col("b", &schema)?, col("c", &schema)?], + ]; - assert_eq!( - format!("{:?}", actual.first().unwrap()), - format!("{:?}", input_exprs.first().unwrap()), - ); - 
assert_eq!( - format!("{:?}", actual.get(1).unwrap()), - format!("{:?}", input_exprs.get(1).unwrap()) - ); - assert_eq!( - format!("{:?}", actual.get(2).unwrap()), - format!("{:?}", input_exprs.get(2).unwrap()) - ); + for input_exprs in &test_cases { + let input_types = input_exprs + .iter() + .map(|expr: &std::sync::Arc<dyn PhysicalExpr>| { + expr.data_type(&schema).unwrap() + }) + .collect::<Vec<_>>(); + let expr_args = ExpressionArgs::new(input_exprs, &input_types); + + let ret_exprs = udwf.expressions(expr_args); + + // Verify same number of input expressions are returned + assert_eq!( + input_exprs.len(), + ret_exprs.len(), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + // Compares each returned expression with original input expressions + for (expected, actual) in input_exprs.iter().zip(&ret_exprs) { + assert_eq!( + format!("{expected:?}"), + format!("{actual:?}"), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + } + } Ok(()) } From acc3c9d6069d8ba5b68043798a27ba23029c7d8a Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Wed, 30 Oct 2024 12:37:01 -0500 Subject: [PATCH 4/5] remove unnecessary qualification from user defined window test --- .../core/tests/user_defined/user_defined_window_functions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 83368f5921b0..f3ebfb4ec296 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -740,7 +740,7 @@ fn test_default_expressions() -> Result<()> { for input_exprs in &test_cases { let input_types = input_exprs .iter() - .map(|expr: &std::sync::Arc<dyn PhysicalExpr>| { + .map(|expr: &Arc<dyn PhysicalExpr>| { expr.data_type(&schema).unwrap() }) .collect::<Vec<_>>(); From 41de745087caadcfa68778fff29bff5a6eb44069 Mon Sep 17 00:00:00 2001 
From: Michael-J-Ward Date: Wed, 30 Oct 2024 12:40:01 -0500 Subject: [PATCH 5/5] cargo fmt --- .../core/tests/user_defined/user_defined_window_functions.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index f3ebfb4ec296..4ec7d8bdb997 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -740,9 +740,7 @@ fn test_default_expressions() -> Result<()> { for input_exprs in &test_cases { let input_types = input_exprs .iter() - .map(|expr: &Arc<dyn PhysicalExpr>| { - expr.data_type(&schema).unwrap() - }) + .map(|expr: &Arc<dyn PhysicalExpr>| expr.data_type(&schema).unwrap()) .collect::<Vec<_>>(); let expr_args = ExpressionArgs::new(input_exprs, &input_types);