From 814a9b7ec3d5460bf927ac80e6c5cddb61b5aea1 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Thu, 17 Oct 2024 21:47:25 +0530 Subject: [PATCH] Fixes: incorrect NULL handling in `lead`/`lag` window function --- datafusion/expr/src/udwf.rs | 23 ++++ .../functions-window-common/src/expr.rs | 64 +++++++++++ datafusion/functions-window-common/src/lib.rs | 1 + datafusion/functions-window/src/lead_lag.rs | 107 ++++++++++++++---- datafusion/functions-window/src/utils.rs | 52 +-------- datafusion/physical-plan/src/windows/mod.rs | 4 +- 6 files changed, 181 insertions(+), 70 deletions(-) create mode 100644 datafusion/functions-window-common/src/expr.rs diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 6d8f2be97e02..6ab94c1e841a 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -34,8 +34,10 @@ use crate::{ Signature, }; use datafusion_common::{not_impl_err, Result}; +use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. @@ -149,6 +151,12 @@ impl WindowUDF { self.inner.simplify() } + /// Expressions that are passed to the [`PartitionEvaluator`]. + /// + /// See [`WindowUDFImpl::expressions`] for more details. + pub fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + self.inner.expressions(expr_args) + } /// Return a `PartitionEvaluator` for evaluating this window function pub fn partition_evaluator_factory( &self, @@ -302,6 +310,14 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; + /// Returns the expressions that are passed to the [`PartitionEvaluator`]. + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) + } + /// Invoke the function, returning the [`PartitionEvaluator`] instance fn partition_evaluator( &self, @@ -480,6 +496,13 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { self.inner.signature() } + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) + } + fn partition_evaluator( &self, partition_evaluator_args: PartitionEvaluatorArgs, diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs new file mode 100644 index 000000000000..1d99fe7acf15 --- /dev/null +++ b/datafusion/functions-window-common/src/expr.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +/// Arguments passed to user-defined window function +#[derive(Debug, Default)] +pub struct ExpressionArgs<'a> { + /// The expressions passed as arguments to the user-defined window + /// function. + input_exprs: &'a [Arc], + /// The corresponding data types of expressions passed as arguments + /// to the user-defined window function. + input_types: &'a [DataType], +} + +impl<'a> ExpressionArgs<'a> { + /// Create an instance of [`ExpressionArgs`]. + /// + /// # Arguments + /// + /// * `input_exprs` - The expressions passed as arguments + /// to the user-defined window function. + /// * `input_types` - The data types corresponding to the + /// arguments to the user-defined window function. + /// + pub fn new( + input_exprs: &'a [Arc], + input_types: &'a [DataType], + ) -> Self { + Self { + input_exprs, + input_types, + } + } + + /// Returns the expressions passed as arguments to the user-defined + /// window function. + pub fn input_exprs(&self) -> &'a [Arc] { + self.input_exprs + } + + /// Returns the [`DataType`]s corresponding to the input expressions + /// to the user-defined window function. + pub fn input_types(&self) -> &'a [DataType] { + self.input_types + } +} diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs index 53f9eb1c9ac6..da8d096da562 100644 --- a/datafusion/functions-window-common/src/lib.rs +++ b/datafusion/functions-window-common/src/lib.rs @@ -18,5 +18,6 @@ //! Common user-defined window functionality for [DataFusion] //! //! [DataFusion]: +pub mod expr; pub mod field; pub mod partition; diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index 9e9bb3f6231c..f81521099751 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -17,10 +17,7 @@ //! `lead` and `lag` window function implementations -use crate::utils::{ - get_casted_value, get_scalar_value_from_args, get_signed_integer, - rewrite_null_expr_and_data_type, -}; +use crate::utils::{get_scalar_value_from_args, get_signed_integer}; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; @@ -29,8 +26,10 @@ use datafusion_expr::{ Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, Volatility, WindowUDFImpl, }; +use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::min; use std::collections::VecDeque; @@ -161,6 +160,17 @@ impl WindowUDFImpl for WindowShift { &self.signature } + /// Handles the case where `NULL` expression is passed as an + /// argument to `lead`/`lag`. The type is refined depending + /// on the default value argument. + /// + /// For more details see: + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + parse_expr(expr_args.input_exprs(), expr_args.input_types()) + .into_iter() + .collect::>() + } + fn partition_evaluator( &self, partition_evaluator_args: PartitionEvaluatorArgs, @@ -177,18 +187,9 @@ impl WindowUDFImpl for WindowShift { offset } })?; - let return_type = partition_evaluator_args - .input_types() - .first() - .unwrap_or(&DataType::Null); - // See https://github.com/apache/datafusion/pull/12811 - let (_expr, return_type) = rewrite_null_expr_and_data_type( + let default_value = parse_default_value( partition_evaluator_args.input_exprs(), - return_type, - )?; - let default_value = get_casted_value( - get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 2)?, - &return_type, + partition_evaluator_args.input_types(), )?; Ok(Box::new(WindowShiftEvaluator { @@ -200,11 +201,9 @@ impl WindowUDFImpl for WindowShift { } fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new( - field_args.name(), - field_args.get_input_type(0).unwrap(), - true, - )) + let return_type = parse_expr_type(field_args.input_types())?; + + Ok(Field::new(field_args.name(), return_type, true)) } fn reverse_expr(&self) -> ReversedUDWF { @@ -215,6 +214,74 @@ impl WindowUDFImpl for WindowShift { } } +/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to +/// refine it by matching it with the type of the default value. +/// +/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null` +/// is refined into `ScalarValue::Boolean(None)`. Only the type is +/// refined, the expression value remains `NULL`. +/// +/// When the window function is evaluated with `NULL` expression +/// this guarantees that the type matches with that of the default +/// value. +/// +/// For more details see: +fn parse_expr( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result> { + assert!(!input_exprs.is_empty()); + assert!(!input_types.is_empty()); + + let expr = Arc::clone(input_exprs.first().unwrap()); + let expr_type = input_types.first().unwrap(); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr); + } + + let default_value = get_scalar_value_from_args(input_exprs, 2)?; + default_value.map_or(Ok(expr), |value| { + ScalarValue::try_from(&value.data_type()).map(|v| { + Arc::new(datafusion_physical_expr::expressions::Literal::new(v)) + as Arc + }) + }) +} + +/// Returns the data type of the default value(if provided) when the +/// expression is `NULL`. +/// +/// Otherwise, returns the expression type unchanged. +fn parse_expr_type(input_types: &[DataType]) -> Result { + assert!(!input_types.is_empty()); + let expr_type = input_types.first().unwrap_or(&DataType::Null); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr_type.clone()); + } + + let default_value_type = input_types.get(2).unwrap_or(&DataType::Null); + Ok(default_value_type.clone()) +} + +/// Handles type coercion and null value refinement for default value +/// argument depending on the data type of the input expression. +fn parse_default_value( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result { + let expr_type = parse_expr_type(input_types)?; + let unparsed = get_scalar_value_from_args(input_exprs, 2)?; + + unparsed + .filter(|v| !v.data_type().is_null()) + .map(|v| v.cast_to(&expr_type)) + .unwrap_or(ScalarValue::try_from(expr_type)) +} + #[derive(Debug)] struct WindowShiftEvaluator { shift_offset: i64, diff --git a/datafusion/functions-window/src/utils.rs b/datafusion/functions-window/src/utils.rs index f470a32fbafe..69f68aa78f2c 100644 --- a/datafusion/functions-window/src/utils.rs +++ b/datafusion/functions-window/src/utils.rs @@ -16,12 +16,12 @@ // under the License. use datafusion_common::arrow::datatypes::DataType; -use datafusion_common::{exec_err, DataFusionError, ScalarValue}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; -pub(crate) fn get_signed_integer(value: ScalarValue) -> datafusion_common::Result { +pub(crate) fn get_signed_integer(value: ScalarValue) -> Result { if value.is_null() { return Ok(0); } @@ -33,56 +33,10 @@ pub(crate) fn get_signed_integer(value: ScalarValue) -> datafusion_common::Resul value.cast_to(&DataType::Int64)?.try_into() } -pub(crate) fn get_casted_value( - default_value: Option, - dtype: &DataType, -) -> datafusion_common::Result { - match default_value { - Some(v) if !v.data_type().is_null() => v.cast_to(dtype), - // If None or Null datatype - _ => ScalarValue::try_from(dtype), - } -} - -/// Rewrites the NULL expression (1st argument) with an expression -/// which is the same data type as the default value (3rd argument). -/// Also rewrites the return type with the same data type as the -/// default value. -/// -/// If a default value is not provided, or it is NULL the original -/// expression (1st argument) and return type is returned without -/// any modifications. -pub(crate) fn rewrite_null_expr_and_data_type( - args: &[Arc], - expr_type: &DataType, -) -> datafusion_common::Result<(Arc, DataType)> { - assert!(!args.is_empty()); - let expr = Arc::clone(&args[0]); - - // The input expression and the return is type is unchanged - // when the input expression is not NULL. - if !expr_type.is_null() { - return Ok((expr, expr_type.clone())); - } - - get_scalar_value_from_args(args, 2)? - .and_then(|value| { - ScalarValue::try_from(value.data_type().clone()) - .map(|sv| { - Ok(( - Arc::new(Literal::new(sv)) as Arc, - value.data_type().clone(), - )) - }) - .ok() - }) - .unwrap_or(Ok((expr, expr_type.clone()))) -} - pub(crate) fn get_scalar_value_from_args( args: &[Arc], index: usize, -) -> datafusion_common::Result> { +) -> Result> { Ok(if let Some(field) = args.get(index) { let tmp = field .as_any() diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index df8a25b03377..adf61f27bc6f 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -48,6 +48,7 @@ mod utils; mod window_agg_exec; pub use bounded_window_agg_exec::BoundedWindowAggExec; +use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr::expressions::Column; @@ -333,7 +334,8 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn expressions(&self) -> Vec> { - self.args.clone() + self.fun + .expressions(ExpressionArgs::new(&self.args, &self.input_types)) } fn create_evaluator(&self) -> Result> {