Skip to content

Commit

Permalink
Fixes: incorrect NULL handling in lead/lag window function
Browse files Browse the repository at this point in the history
  • Loading branch information
jcsherin committed Oct 17, 2024
1 parent 93569ed commit 814a9b7
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 70 deletions.
23 changes: 23 additions & 0 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ use crate::{
Signature,
};
use datafusion_common::{not_impl_err, Result};
use datafusion_functions_window_common::expr::ExpressionArgs;
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

/// Logical representation of a user-defined window function (UDWF)
/// A UDWF is different from a UDF in that it is stateful across batches.
Expand Down Expand Up @@ -149,6 +151,12 @@ impl WindowUDF {
self.inner.simplify()
}

/// Expressions that are passed to the [`PartitionEvaluator`].
///
/// See [`WindowUDFImpl::expressions`] for more details.
pub fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
self.inner.expressions(expr_args)
}
/// Return a `PartitionEvaluator` for evaluating this window function
pub fn partition_evaluator_factory(
&self,
Expand Down Expand Up @@ -302,6 +310,14 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
/// types are accepted and the function's Volatility.
fn signature(&self) -> &Signature;

/// Returns the expressions that are passed to the [`PartitionEvaluator`].
fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
expr_args
.input_exprs()
.first()
.map_or(vec![], |expr| vec![Arc::clone(expr)])
}

/// Invoke the function, returning the [`PartitionEvaluator`] instance
fn partition_evaluator(
&self,
Expand Down Expand Up @@ -480,6 +496,13 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
self.inner.signature()
}

fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
expr_args
.input_exprs()
.first()
.map_or(vec![], |expr| vec![Arc::clone(expr)])
}

fn partition_evaluator(
&self,
partition_evaluator_args: PartitionEvaluatorArgs,
Expand Down
64 changes: 64 additions & 0 deletions datafusion/functions-window-common/src/expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion_common::arrow::datatypes::DataType;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::sync::Arc;

/// Arguments passed to user-defined window function
#[derive(Debug, Default)]
pub struct ExpressionArgs<'a> {
/// The expressions passed as arguments to the user-defined window
/// function.
input_exprs: &'a [Arc<dyn PhysicalExpr>],
/// The corresponding data types of expressions passed as arguments
/// to the user-defined window function.
input_types: &'a [DataType],
}

impl<'a> ExpressionArgs<'a> {
/// Create an instance of [`ExpressionArgs`].
///
/// # Arguments
///
/// * `input_exprs` - The expressions passed as arguments
/// to the user-defined window function.
/// * `input_types` - The data types corresponding to the
/// arguments to the user-defined window function.
///
pub fn new(
input_exprs: &'a [Arc<dyn PhysicalExpr>],
input_types: &'a [DataType],
) -> Self {
Self {
input_exprs,
input_types,
}
}

/// Returns the expressions passed as arguments to the user-defined
/// window function.
pub fn input_exprs(&self) -> &'a [Arc<dyn PhysicalExpr>] {
self.input_exprs
}

/// Returns the [`DataType`]s corresponding to the input expressions
/// to the user-defined window function.
pub fn input_types(&self) -> &'a [DataType] {
self.input_types
}
}
1 change: 1 addition & 0 deletions datafusion/functions-window-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
//! Common user-defined window functionality for [DataFusion]
//!
//! [DataFusion]: <https://crates.io/crates/datafusion>
pub mod expr;
pub mod field;
pub mod partition;
107 changes: 87 additions & 20 deletions datafusion/functions-window/src/lead_lag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@

//! `lead` and `lag` window function implementations
use crate::utils::{
get_casted_value, get_scalar_value_from_args, get_signed_integer,
rewrite_null_expr_and_data_type,
};
use crate::utils::{get_scalar_value_from_args, get_signed_integer};
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_common::arrow::datatypes::Field;
Expand All @@ -29,8 +26,10 @@ use datafusion_expr::{
Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, Volatility,
WindowUDFImpl,
};
use datafusion_functions_window_common::expr::ExpressionArgs;
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::any::Any;
use std::cmp::min;
use std::collections::VecDeque;
Expand Down Expand Up @@ -161,6 +160,17 @@ impl WindowUDFImpl for WindowShift {
&self.signature
}

/// Handles the case where `NULL` expression is passed as an
/// argument to `lead`/`lag`. The type is refined depending
/// on the default value argument.
///
/// For more details see: <https://github.com/apache/datafusion/issues/12717>
fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
parse_expr(expr_args.input_exprs(), expr_args.input_types())
.into_iter()
.collect::<Vec<_>>()
}

fn partition_evaluator(
&self,
partition_evaluator_args: PartitionEvaluatorArgs,
Expand All @@ -177,18 +187,9 @@ impl WindowUDFImpl for WindowShift {
offset
}
})?;
let return_type = partition_evaluator_args
.input_types()
.first()
.unwrap_or(&DataType::Null);
// See https://github.com/apache/datafusion/pull/12811
let (_expr, return_type) = rewrite_null_expr_and_data_type(
let default_value = parse_default_value(
partition_evaluator_args.input_exprs(),
return_type,
)?;
let default_value = get_casted_value(
get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 2)?,
&return_type,
partition_evaluator_args.input_types(),
)?;

Ok(Box::new(WindowShiftEvaluator {
Expand All @@ -200,11 +201,9 @@ impl WindowUDFImpl for WindowShift {
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(
field_args.name(),
field_args.get_input_type(0).unwrap(),
true,
))
let return_type = parse_expr_type(field_args.input_types())?;

Ok(Field::new(field_args.name(), return_type, true))
}

fn reverse_expr(&self) -> ReversedUDWF {
Expand All @@ -215,6 +214,74 @@ impl WindowUDFImpl for WindowShift {
}
}

/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to
/// refine it by matching it with the type of the default value.
///
/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null`
/// is refined into `ScalarValue::Boolean(None)`. Only the type is
/// refined, the expression value remains `NULL`.
///
/// When the window function is evaluated with `NULL` expression
/// this guarantees that the type matches with that of the default
/// value.
///
/// For more details see: <https://github.com/apache/datafusion/issues/12717>
fn parse_expr(
input_exprs: &[Arc<dyn PhysicalExpr>],
input_types: &[DataType],
) -> Result<Arc<dyn PhysicalExpr>> {
assert!(!input_exprs.is_empty());
assert!(!input_types.is_empty());

let expr = Arc::clone(input_exprs.first().unwrap());
let expr_type = input_types.first().unwrap();

// Handles the most common case where NULL is unexpected
if !expr_type.is_null() {
return Ok(expr);
}

let default_value = get_scalar_value_from_args(input_exprs, 2)?;
default_value.map_or(Ok(expr), |value| {
ScalarValue::try_from(&value.data_type()).map(|v| {
Arc::new(datafusion_physical_expr::expressions::Literal::new(v))
as Arc<dyn PhysicalExpr>
})
})
}

/// Returns the data type of the default value(if provided) when the
/// expression is `NULL`.
///
/// Otherwise, returns the expression type unchanged.
fn parse_expr_type(input_types: &[DataType]) -> Result<DataType> {
assert!(!input_types.is_empty());
let expr_type = input_types.first().unwrap_or(&DataType::Null);

// Handles the most common case where NULL is unexpected
if !expr_type.is_null() {
return Ok(expr_type.clone());
}

let default_value_type = input_types.get(2).unwrap_or(&DataType::Null);
Ok(default_value_type.clone())
}

/// Handles type coercion and null value refinement for default value
/// argument depending on the data type of the input expression.
fn parse_default_value(
input_exprs: &[Arc<dyn PhysicalExpr>],
input_types: &[DataType],
) -> Result<ScalarValue> {
let expr_type = parse_expr_type(input_types)?;
let unparsed = get_scalar_value_from_args(input_exprs, 2)?;

unparsed
.filter(|v| !v.data_type().is_null())
.map(|v| v.cast_to(&expr_type))
.unwrap_or(ScalarValue::try_from(expr_type))
}

#[derive(Debug)]
struct WindowShiftEvaluator {
shift_offset: i64,
Expand Down
52 changes: 3 additions & 49 deletions datafusion/functions-window/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
// under the License.

use datafusion_common::arrow::datatypes::DataType;
use datafusion_common::{exec_err, DataFusionError, ScalarValue};
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::sync::Arc;

pub(crate) fn get_signed_integer(value: ScalarValue) -> datafusion_common::Result<i64> {
pub(crate) fn get_signed_integer(value: ScalarValue) -> Result<i64> {
if value.is_null() {
return Ok(0);
}
Expand All @@ -33,56 +33,10 @@ pub(crate) fn get_signed_integer(value: ScalarValue) -> datafusion_common::Resul
value.cast_to(&DataType::Int64)?.try_into()
}

pub(crate) fn get_casted_value(
default_value: Option<ScalarValue>,
dtype: &DataType,
) -> datafusion_common::Result<ScalarValue> {
match default_value {
Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
// If None or Null datatype
_ => ScalarValue::try_from(dtype),
}
}

/// Rewrites the NULL expression (1st argument) with an expression
/// which is the same data type as the default value (3rd argument).
/// Also rewrites the return type with the same data type as the
/// default value.
///
/// If a default value is not provided, or it is NULL the original
/// expression (1st argument) and return type is returned without
/// any modifications.
pub(crate) fn rewrite_null_expr_and_data_type(
args: &[Arc<dyn PhysicalExpr>],
expr_type: &DataType,
) -> datafusion_common::Result<(Arc<dyn PhysicalExpr>, DataType)> {
assert!(!args.is_empty());
let expr = Arc::clone(&args[0]);

// The input expression and the return is type is unchanged
// when the input expression is not NULL.
if !expr_type.is_null() {
return Ok((expr, expr_type.clone()));
}

get_scalar_value_from_args(args, 2)?
.and_then(|value| {
ScalarValue::try_from(value.data_type().clone())
.map(|sv| {
Ok((
Arc::new(Literal::new(sv)) as Arc<dyn PhysicalExpr>,
value.data_type().clone(),
))
})
.ok()
})
.unwrap_or(Ok((expr, expr_type.clone())))
}

pub(crate) fn get_scalar_value_from_args(
args: &[Arc<dyn PhysicalExpr>],
index: usize,
) -> datafusion_common::Result<Option<ScalarValue>> {
) -> Result<Option<ScalarValue>> {
Ok(if let Some(field) = args.get(index) {
let tmp = field
.as_any()
Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ mod utils;
mod window_agg_exec;

pub use bounded_window_agg_exec::BoundedWindowAggExec;
use datafusion_functions_window_common::expr::ExpressionArgs;
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use datafusion_physical_expr::expressions::Column;
Expand Down Expand Up @@ -333,7 +334,8 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr {
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
self.args.clone()
self.fun
.expressions(ExpressionArgs::new(&self.args, &self.input_types))
}

fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
Expand Down

0 comments on commit 814a9b7

Please sign in to comment.