Skip to content

Commit

Permalink
Convert BuiltInWindowFunction::{Lead, Lag} to a user defined window…
Browse files Browse the repository at this point in the history
… function (apache#12857)

* Move `lead-lag` to `functions-window` package

* Builds with warnings

* Adds `PartitionEvaluatorArgs`

* Extracts `shift_offset` from input expressions

* Computes shift offset

* Get default value from input expression

* Implements `partition_evaluator`

* Fixes compiler warnings

* Comments out failing tests

* Fixes `cargo test` errors and warnings

* Minor: taplo formatting

* Delete code

* Define `lead`, `lag` user-defined window functions

* Fixes `cargo build` errors

* Export udwf and expression public APIs

* Mark result field as nullable

* Delete `return_type` tests for `lead` and `lag`

* Disables test: window function case insensitive

* Fixes: lowercase name in logical plan

* Reverts to old methods for computing `shift_offset`, `default_value`

* Implements expression reversal

* Fixes: lowercase name in logical plans

* Fixes: doc test compilation errors
Fixes: doc test build errors

* Temporarily quiet clippy errors

* Fixes proto definition

* Minor: fixes formatting

* Fixes: doc tests

* Uses macro for defining `lag_udwf()` and `lead_udwf()`

* Fixes: window fuzz test cases

* Copies doc comments verbatim from `BuiltInWindowFunction` enum

* Deletes from window function case insensitive test

* Deletes `BuiltInWindowFunction` expression APIs

* Delete from `create_built_in_window_expr`

* Deletes proto serialization

* Delete from `BuiltInWindowFunction` enum

* Deletes test for finding built-in window function

* Fixes build errors + deletes redundant code

* Deletes more code

* Delete unnecessary structs

* Refactors shift offset computation

* Passes range unit test

* Fixes: clippy::get-first error

* Rewrite unit tests for WindowUDF

* Fixes: unit test for lag with default value

* Consistent input expressions and data types in unit tests

* Minor: fixes formatting

* Restore original helper method for unit tests

* Revert "Refactors shift offset computation"

This reverts commit 000ceb7.

* Moves helper functions into `functions-window-common` package

* Uses common helper functions in `{lead, lag}`

* Minor: formatting

* Revert "Moves helper functions into `functions-window-common` package"

This reverts commit ab8a83c.

* Moves common functions to utils

* Minor: formatting fixes

* Update lowercase names in explain output

* Adds doc for `lead()` and `lag()` expression functions

* Add doc for `WindowShiftKind::shift_offset`

* Remove `arrow` dev dependency

* Minor: formatting

* Update inner doc comment

* Serialize 1 or more window function arguments

* Adds logical plan roundtrip test cases

* Refactor: readability of unit tests

* Minor: rename variable bindings

* Minor: copy edit

* Revert "Remove `arrow` dev dependency"

This reverts commit 3eb0985.

* Move null argument handling helper to utils

* Disable failing sqllogic tests for handling NULL input

* Revert "Disable failing sqllogic tests for handling NULL input"

This reverts commit 270a203.

* Fixes: incorrect NULL handling in `lead`/`lag` window function

* Adds more test cases

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
jcsherin and alamb authored Oct 18, 2024
1 parent 700b07f commit efe5708
Show file tree
Hide file tree
Showing 24 changed files with 520 additions and 407 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 5 additions & 8 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;

use datafusion::functions_window::row_number::row_number_udwf;
use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf};
use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf};
use hashbrown::HashMap;
use rand::distributions::Alphanumeric;
Expand Down Expand Up @@ -197,7 +198,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
// )
(
// Window function
WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lag),
WindowFunctionDefinition::WindowUDF(lag_udwf()),
// its name
"LAG",
// no argument
Expand All @@ -211,7 +212,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
// )
(
// Window function
WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
WindowFunctionDefinition::WindowUDF(lead_udwf()),
// its name
"LEAD",
// no argument
Expand Down Expand Up @@ -393,9 +394,7 @@ fn get_random_function(
window_fn_map.insert(
"lead",
(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::Lead,
),
WindowFunctionDefinition::WindowUDF(lead_udwf()),
vec![
arg.clone(),
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
Expand All @@ -406,9 +405,7 @@ fn get_random_function(
window_fn_map.insert(
"lag",
(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::Lag,
),
WindowFunctionDefinition::WindowUDF(lag_udwf()),
vec![
arg.clone(),
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
Expand Down
32 changes: 3 additions & 29 deletions datafusion/expr/src/built_in_window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::str::FromStr;

use crate::type_coercion::functions::data_types;
use crate::utils;
use crate::{Signature, TypeSignature, Volatility};
use crate::{Signature, Volatility};
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};

use arrow::datatypes::DataType;
Expand All @@ -44,17 +44,7 @@ pub enum BuiltInWindowFunction {
CumeDist,
/// Integer ranging from 1 to the argument value, dividing the partition as equally as possible
Ntile,
/// Returns value evaluated at the row that is offset rows before the current row within the partition;
/// If there is no such row, instead return default (which must be of the same type as value).
/// Both offset and default are evaluated with respect to the current row.
/// If omitted, offset defaults to 1 and default to null
Lag,
/// Returns value evaluated at the row that is offset rows after the current row within the partition;
/// If there is no such row, instead return default (which must be of the same type as value).
/// Both offset and default are evaluated with respect to the current row.
/// If omitted, offset defaults to 1 and default to null
Lead,
/// Returns value evaluated at the row that is the first row of the window frame
/// returns value evaluated at the row that is the first row of the window frame
FirstValue,
/// Returns value evaluated at the row that is the last row of the window frame
LastValue,
Expand All @@ -68,8 +58,6 @@ impl BuiltInWindowFunction {
match self {
CumeDist => "CUME_DIST",
Ntile => "NTILE",
Lag => "LAG",
Lead => "LEAD",
FirstValue => "first_value",
LastValue => "last_value",
NthValue => "NTH_VALUE",
Expand All @@ -83,8 +71,6 @@ impl FromStr for BuiltInWindowFunction {
Ok(match name.to_uppercase().as_str() {
"CUME_DIST" => BuiltInWindowFunction::CumeDist,
"NTILE" => BuiltInWindowFunction::Ntile,
"LAG" => BuiltInWindowFunction::Lag,
"LEAD" => BuiltInWindowFunction::Lead,
"FIRST_VALUE" => BuiltInWindowFunction::FirstValue,
"LAST_VALUE" => BuiltInWindowFunction::LastValue,
"NTH_VALUE" => BuiltInWindowFunction::NthValue,
Expand Down Expand Up @@ -117,9 +103,7 @@ impl BuiltInWindowFunction {
match self {
BuiltInWindowFunction::Ntile => Ok(DataType::UInt64),
BuiltInWindowFunction::CumeDist => Ok(DataType::Float64),
BuiltInWindowFunction::Lag
| BuiltInWindowFunction::Lead
| BuiltInWindowFunction::FirstValue
BuiltInWindowFunction::FirstValue
| BuiltInWindowFunction::LastValue
| BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()),
}
Expand All @@ -130,16 +114,6 @@ impl BuiltInWindowFunction {
// Note: The physical expression must accept the type returned by this function or the execution panics.
match self {
BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable),
BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
Signature::one_of(
vec![
TypeSignature::Any(1),
TypeSignature::Any(2),
TypeSignature::Any(3),
],
Volatility::Immutable,
)
}
BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
Signature::any(1, Volatility::Immutable)
}
Expand Down
38 changes: 0 additions & 38 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2560,30 +2560,6 @@ mod test {
Ok(())
}

/// Checks that the `lead` window function, looked up by name, reports a
/// return type equal to the data type of its single input.
#[test]
fn test_lead_return_type() -> Result<()> {
    let fun = find_df_window_func("lead").unwrap();

    // Same inputs, in the same order: a string type, then a float type.
    for expected in [DataType::Utf8, DataType::Float64] {
        let observed = fun.return_type(&[expected.clone()], &[true], "")?;
        assert_eq!(expected, observed);
    }

    Ok(())
}

/// Checks that the `lag` window function, looked up by name, reports a
/// return type equal to the data type of its single input.
#[test]
fn test_lag_return_type() -> Result<()> {
    let fun = find_df_window_func("lag").unwrap();

    // Same inputs, in the same order: a string type, then a float type.
    for expected in [DataType::Utf8, DataType::Float64] {
        let observed = fun.return_type(&[expected.clone()], &[true], "")?;
        assert_eq!(expected, observed);
    }

    Ok(())
}

#[test]
fn test_nth_value_return_type() -> Result<()> {
let fun = find_df_window_func("nth_value").unwrap();
Expand Down Expand Up @@ -2621,8 +2597,6 @@ mod test {
let names = vec![
"cume_dist",
"ntile",
"lag",
"lead",
"first_value",
"last_value",
"nth_value",
Expand Down Expand Up @@ -2660,18 +2634,6 @@ mod test {
built_in_window_function::BuiltInWindowFunction::LastValue
))
);
assert_eq!(
find_df_window_func("LAG"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_window_function::BuiltInWindowFunction::Lag
))
);
assert_eq!(
find_df_window_func("LEAD"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_window_function::BuiltInWindowFunction::Lead
))
);
assert_eq!(find_df_window_func("not_exist"), None)
}

Expand Down
23 changes: 23 additions & 0 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ use crate::{
Signature,
};
use datafusion_common::{not_impl_err, Result};
use datafusion_functions_window_common::expr::ExpressionArgs;
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

/// Logical representation of a user-defined window function (UDWF)
/// A UDWF is different from a UDF in that it is stateful across batches.
Expand Down Expand Up @@ -149,6 +151,12 @@ impl WindowUDF {
self.inner.simplify()
}

/// Returns the expressions that are passed to the [`PartitionEvaluator`].
///
/// This forwards `expr_args` unchanged to the `expressions` method of
/// `self.inner` and returns its result as-is.
///
/// See [`WindowUDFImpl::expressions`] for more details.
pub fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
self.inner.expressions(expr_args)
}
/// Return a `PartitionEvaluator` for evaluating this window function
pub fn partition_evaluator_factory(
&self,
Expand Down Expand Up @@ -302,6 +310,14 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
/// types are accepted and the function's Volatility.
fn signature(&self) -> &Signature;

/// Returns the expressions that are passed to the [`PartitionEvaluator`].
///
/// The default implementation forwards only the first input expression.
/// When there are no input expressions it returns an empty vector.
fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
    match expr_args.input_exprs().first() {
        Some(first_expr) => vec![Arc::clone(first_expr)],
        None => vec![],
    }
}

/// Invoke the function, returning the [`PartitionEvaluator`] instance
fn partition_evaluator(
&self,
Expand Down Expand Up @@ -480,6 +496,13 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
self.inner.signature()
}

/// Returns the expressions that are passed to the [`PartitionEvaluator`].
///
/// Delegates to the wrapped implementation so that an aliased function
/// evaluates the same expressions as the function it wraps. Re-implementing
/// the trait default here would silently discard any `expressions`
/// override provided by `self.inner`.
fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
    self.inner.expressions(expr_args)
}

fn partition_evaluator(
&self,
partition_evaluator_args: PartitionEvaluatorArgs,
Expand Down
34 changes: 0 additions & 34 deletions datafusion/expr/src/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_common::ScalarValue;

use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal};

/// Create an expression to represent the `cume_dist` window function
Expand All @@ -29,38 +27,6 @@ pub fn ntile(arg: Expr) -> Expr {
Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg]))
}

/// Create an expression to represent the `lag` window function
///
/// The resulting window function always receives three arguments: `arg`,
/// the shift offset and the default value. A `None` for `shift_offset` or
/// `default_value` is passed along as a `ScalarValue::Null` literal.
pub fn lag(
    arg: Expr,
    shift_offset: Option<i64>,
    default_value: Option<ScalarValue>,
) -> Expr {
    let offset_expr = match shift_offset {
        Some(offset) => offset.lit(),
        None => ScalarValue::Null.lit(),
    };
    let default_expr = match default_value {
        Some(value) => value.lit(),
        None => ScalarValue::Null.lit(),
    };

    let args = vec![arg, offset_expr, default_expr];
    Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Lag, args))
}

/// Create an expression to represent the `lead` window function
///
/// The resulting window function always receives three arguments: `arg`,
/// the shift offset and the default value. A `None` for `shift_offset` or
/// `default_value` is passed along as a `ScalarValue::Null` literal.
pub fn lead(
    arg: Expr,
    shift_offset: Option<i64>,
    default_value: Option<ScalarValue>,
) -> Expr {
    let offset_expr = match shift_offset {
        Some(offset) => offset.lit(),
        None => ScalarValue::Null.lit(),
    };
    let default_expr = match default_value {
        Some(value) => value.lit(),
        None => ScalarValue::Null.lit(),
    };

    let args = vec![arg, offset_expr, default_expr];
    Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Lead, args))
}

/// Create an expression to represent the `nth_value` window function
pub fn nth_value(arg: Expr, n: i64) -> Expr {
Expr::WindowFunction(WindowFunction::new(
Expand Down
64 changes: 64 additions & 0 deletions datafusion/functions-window-common/src/expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion_common::arrow::datatypes::DataType;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::sync::Arc;

/// Arguments passed to a user-defined window function.
///
/// Holds two borrowed slices that share the lifetime `'a`: the argument
/// expressions and their data types. The struct does not own either slice.
#[derive(Debug, Default)]
pub struct ExpressionArgs<'a> {
/// The expressions passed as arguments to the user-defined window
/// function.
input_exprs: &'a [Arc<dyn PhysicalExpr>],
/// The corresponding data types of expressions passed as arguments
/// to the user-defined window function.
input_types: &'a [DataType],
}

impl<'a> ExpressionArgs<'a> {
/// Create an instance of [`ExpressionArgs`].
///
/// # Arguments
///
/// * `input_exprs` - The expressions passed as arguments
/// to the user-defined window function.
/// * `input_types` - The data types corresponding to the
/// arguments to the user-defined window function.
///
pub fn new(
input_exprs: &'a [Arc<dyn PhysicalExpr>],
input_types: &'a [DataType],
) -> Self {
Self {
input_exprs,
input_types,
}
}

/// Returns the expressions passed as arguments to the user-defined
/// window function.
pub fn input_exprs(&self) -> &'a [Arc<dyn PhysicalExpr>] {
self.input_exprs
}

/// Returns the [`DataType`]s corresponding to the input expressions
/// to the user-defined window function.
pub fn input_types(&self) -> &'a [DataType] {
self.input_types
}
}
1 change: 1 addition & 0 deletions datafusion/functions-window-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
//! Common user-defined window functionality for [DataFusion]
//!
//! [DataFusion]: <https://crates.io/crates/datafusion>
pub mod expr;
pub mod field;
pub mod partition;
1 change: 1 addition & 0 deletions datafusion/functions-window/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ path = "src/lib.rs"
datafusion-common = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions-window-common = { workspace = true }
datafusion-physical-expr = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
log = { workspace = true }
paste = "1.0.15"
Expand Down
Loading

0 comments on commit efe5708

Please sign in to comment.