Skip to content

Commit

Permalink
passes lint but does not have tests
Browse files Browse the repository at this point in the history
  • Loading branch information
buraksenn committed Nov 10, 2024
1 parent 66f01c4 commit 17909a8
Show file tree
Hide file tree
Showing 16 changed files with 821 additions and 176 deletions.
30 changes: 28 additions & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1946,8 +1946,9 @@ mod tests {
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
cast, create_udf, lit, ExprFunctionExt, ScalarFunctionImplementation, Volatility,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
cast, create_udf, lit, BuiltInWindowFunction, ExprFunctionExt,
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_functions_window::expr_fn::row_number;
Expand Down Expand Up @@ -2171,6 +2172,31 @@ mod tests {
Ok(())
}

/// Checks that selecting a built-in window expression through the DataFrame
/// API produces the same logical plan as the equivalent SQL query.
#[tokio::test]
async fn select_with_window_exprs() -> Result<()> {
    // build plan using Table API
    let t = test_table().await?;
    let first_row = Expr::WindowFunction(WindowFunction::new(
        WindowFunctionDefinition::BuiltInWindowFunction(
            BuiltInWindowFunction::FirstValue,
        ),
        vec![col("aggregate_test_100.c1")],
    ))
    .partition_by(vec![col("aggregate_test_100.c2")])
    // the test already returns `Result`, so report a builder failure as an
    // error instead of panicking
    .build()?;
    let t2 = t.select(vec![col("c1"), first_row])?;

    // build the reference plan from SQL
    let sql_plan = create_plan(
        "select c1, first_value(c1) over (partition by c2) from aggregate_test_100",
    )
    .await?;

    assert_same_plan(&t2.plan, &sql_plan);
    Ok(())
}

#[tokio::test]
async fn select_with_periods() -> Result<()> {
// define data with a column name that has a "." in it:
Expand Down
27 changes: 27 additions & 0 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ use test_utils::add_empty_batches;
use datafusion::functions_window::row_number::row_number_udwf;
use datafusion_common::HashMap;
use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf};
use datafusion_functions_window::nth_value::{
first_value_udwf, last_value_udwf, nth_value_udwf,
};
use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use rand::distributions::Alphanumeric;
Expand Down Expand Up @@ -414,6 +417,30 @@ fn get_random_function(
),
);
}
window_fn_map.insert(
"first_value",
(
WindowFunctionDefinition::WindowUDF(first_value_udwf()),
vec![arg.clone()],
),
);
window_fn_map.insert(
"last_value",
(
WindowFunctionDefinition::WindowUDF(last_value_udwf()),
vec![arg.clone()],
),
);
window_fn_map.insert(
"nth_value",
(
WindowFunctionDefinition::WindowUDF(nth_value_udwf()),
vec![
arg.clone(),
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
],
),
);

let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
Expand Down
105 changes: 104 additions & 1 deletion datafusion/expr/src/built_in_window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,115 @@

//! Built-in functions module contains all the built-in functions definitions.
use std::fmt;
use std::str::FromStr;

use crate::type_coercion::functions::data_types;
use crate::utils;
use crate::{Signature, Volatility};
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};

use arrow::datatypes::DataType;

use strum_macros::EnumIter;

/// Displays a built-in window function as the string returned by its `name()`.
impl fmt::Display for BuiltInWindowFunction {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Write the name verbatim; width/precision flags on `f` are not applied.
        f.write_str(self.name())
    }
}

/// A [window function] built in to DataFusion
///
/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
// Every variant listed here needs an arm in `name`, `return_type` and
// `signature`, and a mapping in the `FromStr` impl, otherwise those matches
// stop being exhaustive and the display/parse round trip breaks.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum BuiltInWindowFunction {
    /// Returns value evaluated at the row that is the first row of the window frame
    FirstValue,
    /// Returns value evaluated at the row that is the last row of the window frame
    LastValue,
    /// Returns value evaluated at the row that is the nth row of the window frame (counting from 1); returns null if no such row
    NthValue,
}

impl BuiltInWindowFunction {
pub fn name(&self) -> &str {
use BuiltInWindowFunction::*;
match self {
FirstValue => "first_value",
LastValue => "last_value",
NthValue => "NTH_VALUE",
}
}
}

impl FromStr for BuiltInWindowFunction {
    type Err = DataFusionError;

    /// Parses a built-in window function from its name.
    ///
    /// The name is upper-cased before matching, so the lookup is
    /// case-insensitive.
    ///
    /// # Errors
    /// Returns a planning error when `name` matches no built-in window function.
    fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
        let upper = name.to_uppercase();
        match upper.as_str() {
            "FIRST_VALUE" => Ok(BuiltInWindowFunction::FirstValue),
            "LAST_VALUE" => Ok(BuiltInWindowFunction::LastValue),
            "NTH_VALUE" => Ok(BuiltInWindowFunction::NthValue),
            _ => plan_err!("There is no built-in window function named {name}"),
        }
    }
}

/// Returns the datatype of the built-in window function
impl BuiltInWindowFunction {
pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.

// Verify that this is a valid set of data types for this function
data_types(input_expr_types, &self.signature())
// Original errors are all related to wrong function signature
// Aggregate them for better error message
.map_err(|_| {
plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
&format!("{self}"),
self.signature(),
input_expr_types,
)
)
})?;

match self {
BuiltInWindowFunction::FirstValue
| BuiltInWindowFunction::LastValue
| BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()),
}
}

/// The signatures supported by the built-in window function `fun`.
pub fn signature(&self) -> Signature {
// Note: The physical expression must accept the type returned by this function or the execution panics.
match self {
BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
Signature::any(1, Volatility::Immutable)
}
BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable),
}
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use strum::IntoEnumIterator;

    // Round-trip check for BuiltInWindowFunction's Display and FromStr
    // implementations: every variant is rendered to a string and parsed back,
    // and the parsed value must equal the variant we started from. Function
    // suggestion relies on this property as well.
    // See https://github.com/apache/datafusion/issues/8082
    #[test]
    fn test_display_and_from_str() {
        for variant in BuiltInWindowFunction::iter() {
            let rendered = variant.to_string();
            let reparsed: BuiltInWindowFunction = rendered.parse().unwrap();
            assert_eq!(reparsed, variant);
        }
    }
}
26 changes: 21 additions & 5 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
use crate::utils::expr_to_columns;
use crate::Volatility;
use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
use crate::{
udaf, BuiltInWindowFunction, ExprSchemable, Operator, Signature, WindowFrame,
WindowUDF,
};

use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cse::HashNode;
Expand Down Expand Up @@ -698,6 +701,9 @@ impl AggregateFunction {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
/// Defines which implementation of an aggregate function DataFusion should call.
pub enum WindowFunctionDefinition {
/// A built-in window function
BuiltInWindowFunction(BuiltInWindowFunction),
/// A user defined aggregate function
AggregateUDF(Arc<crate::AggregateUDF>),
/// A user defined aggregate function
Expand All @@ -713,6 +719,9 @@ impl WindowFunctionDefinition {
display_name: &str,
) -> Result<DataType> {
match self {
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
fun.return_type(input_expr_types)
}
WindowFunctionDefinition::AggregateUDF(fun) => {
fun.return_type(input_expr_types)
}
Expand All @@ -725,6 +734,7 @@ impl WindowFunctionDefinition {
/// The signatures supported by the function `fun`.
pub fn signature(&self) -> Signature {
match self {
WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(),
WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(),
WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(),
}
Expand All @@ -733,6 +743,7 @@ impl WindowFunctionDefinition {
/// Function's name for display
pub fn name(&self) -> &str {
match self {
WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.name(),
WindowFunctionDefinition::WindowUDF(fun) => fun.name(),
WindowFunctionDefinition::AggregateUDF(fun) => fun.name(),
}
Expand All @@ -742,12 +753,19 @@ impl WindowFunctionDefinition {
impl Display for WindowFunctionDefinition {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
WindowFunctionDefinition::BuiltInWindowFunction(fun) => Display::fmt(fun, f),
WindowFunctionDefinition::AggregateUDF(fun) => Display::fmt(fun, f),
WindowFunctionDefinition::WindowUDF(fun) => Display::fmt(fun, f),
}
}
}

impl From<BuiltInWindowFunction> for WindowFunctionDefinition {
fn from(value: BuiltInWindowFunction) -> Self {
Self::BuiltInWindowFunction(value)
}
}

impl From<Arc<crate::AggregateUDF>> for WindowFunctionDefinition {
fn from(value: Arc<crate::AggregateUDF>) -> Self {
Self::AggregateUDF(value)
Expand All @@ -773,11 +791,9 @@ impl From<Arc<WindowUDF>> for WindowFunctionDefinition {
/// ```
/// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt};
/// # use datafusion_expr::expr::WindowFunction;
/// use datafusion_expr::test::function_stub::count_udaf;
/// use datafusion_expr::WindowFunctionDefinition::{AggregateUDF};
/// // Create COUNT(a) OVER (PARTITION BY b ORDER BY c)
/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c)
/// let expr = Expr::WindowFunction(
/// WindowFunction::new(AggregateUDF(count_udaf()), vec![col("a")])
/// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")])
/// )
/// .partition_by(vec![col("b")])
/// .order_by(vec![col("c").sort(true, true)])
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,12 @@ impl Expr {
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
match fun {
WindowFunctionDefinition::BuiltInWindowFunction(window_fun) => {
let return_type = window_fun.return_type(&data_types)?;
let nullable =
!["RANK", "NTILE", "CUME_DIST"].contains(&window_fun.name());
Ok((return_type, nullable))
}
WindowFunctionDefinition::AggregateUDF(udaf) => {
let new_types = data_types_with_aggregate_udf(&data_types, udaf)
.map_err(|err| {
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub mod type_coercion;
pub mod utils;
pub mod var_provider;
pub mod window_frame;
pub mod window_function;
pub mod window_state;

pub use built_in_window_function::BuiltInWindowFunction;
Expand Down
26 changes: 26 additions & 0 deletions datafusion/expr/src/window_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal};

/// Create an expression to represent the `nth_value` window function
///
/// The resulting window function receives two arguments: `arg` and the
/// literal value of `n`.
pub fn nth_value(arg: Expr, n: i64) -> Expr {
    let args = vec![arg, n.lit()];
    let window_fn = WindowFunction::new(BuiltInWindowFunction::NthValue, args);
    Expr::WindowFunction(window_fn)
}
27 changes: 23 additions & 4 deletions datafusion/functions-window/src/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,51 @@ use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
use datafusion_expr::window_state::WindowAggState;
use datafusion_expr::{
Documentation, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
Volatility, WindowUDFImpl,
};
use datafusion_functions_window_common::field;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use field::WindowUDFFieldArgs;

define_udwf_and_expr!(
get_or_init_udwf!(
First,
first_value,
"returns the first value in the window frame",
NthValue::first
);
define_udwf_and_expr!(
get_or_init_udwf!(
Last,
last_value,
"returns the last value in the window frame",
NthValue::last
);
define_udwf_and_expr!(
get_or_init_udwf!(
NthValue,
nth_value,
"returns the nth value in the window frame",
NthValue::nth
);

/// Create an expression to represent the `first_value` window function
///
/// Calls the UDWF returned by `first_value_udwf()` with `arg` as its only
/// argument.
pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
    let args = vec![arg];
    first_value_udwf().call(args)
}

/// Create an expression to represent the `last_value` window function
///
/// Calls the UDWF returned by `last_value_udwf()` with `arg` as its only
/// argument.
pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
    let args = vec![arg];
    last_value_udwf().call(args)
}

/// Create an expression to represent the `nth_value` window function
///
/// `n` is passed as the second argument: its literal value when `Some`,
/// otherwise a null literal.
pub fn nth_value(arg: datafusion_expr::Expr, n: Option<i64>) -> datafusion_expr::Expr {
    let n_lit = match n {
        Some(position) => position.lit(),
        None => ScalarValue::Null.lit(),
    };
    nth_value_udwf().call(vec![arg, n_lit])
}

/// Tag to differentiate special use cases of the NTH_VALUE built-in window function.
#[derive(Debug, Copy, Clone)]
pub enum NthValueKind {
Expand Down
Loading

0 comments on commit 17909a8

Please sign in to comment.