diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2c71cb80d755..bcf803573cdf 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1946,12 +1946,12 @@ mod tests { use datafusion_common_runtime::SpawnedTask; use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - cast, create_udf, lit, BuiltInWindowFunction, ExprFunctionExt, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + cast, create_udf, lit, ExprFunctionExt, ScalarFunctionImplementation, Volatility, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_functions_window::expr_fn::row_number; + use datafusion_functions_window::nth_value::first_value_udwf; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; use sqlparser::ast::NullTreatment; @@ -2177,9 +2177,7 @@ mod tests { // build plan using Table API let t = test_table().await?; let first_row = Expr::WindowFunction(WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::FirstValue, - ), + WindowFunctionDefinition::WindowUDF(first_value_udwf()), vec![col("aggregate_test_100.c1")], )) .partition_by(vec![col("aggregate_test_100.c2")]) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 701b2768531b..3faa8192f3eb 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -65,7 +65,6 @@ pub mod type_coercion; pub mod utils; pub mod var_provider; pub mod window_frame; -pub mod window_function; pub mod window_state; pub use built_in_window_function::BuiltInWindowFunction; diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs deleted file mode 100644 index be2b6575e2e9..000000000000 --- a/datafusion/expr/src/window_function.rs +++ /dev/null @@ -1,26 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; - -/// Create an expression to represent the `nth_value` window function -pub fn nth_value(arg: Expr, n: i64) -> Expr { - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::NthValue, - vec![arg, n.lit()], - )) -} diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index de6e25bd454f..9f8e54a0423b 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -23,8 +23,6 @@ //! [DataFusion]: https://crates.io/crates/datafusion //! 
-extern crate core; - use std::sync::Arc; use log::debug; diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index a86714aaf93f..7f3d1cb07bca 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -27,7 +27,7 @@ use std::sync::OnceLock; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; use datafusion_expr::window_state::WindowAggState; use datafusion_expr::{ @@ -215,7 +215,11 @@ impl WindowUDFImpl for NthValue { } let n = - match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)? + match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1) + .map_err(|_e| { + exec_datafusion_err!( + "Expected a signed integer literal for the second argument of nth_value") + })? .map(get_signed_integer) { Some(Ok(n)) => { diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs deleted file mode 100644 index 6ec3a23fc586..000000000000 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ /dev/null @@ -1,415 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions for `FIRST_VALUE`, `LAST_VALUE`, and `NTH_VALUE` -//! functions that can be evaluated at run time during query execution. 
- -use std::any::Any; -use std::cmp::Ordering; -use std::ops::Range; -use std::sync::Arc; - -use crate::window::window_expr::{NthValueKind, NthValueState}; -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; - -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::window_state::WindowAggState; -use datafusion_expr::PartitionEvaluator; - -/// nth_value expression -#[derive(Debug)] -pub struct NthValue { - name: String, - expr: Arc, - /// Output data type - data_type: DataType, - kind: NthValueKind, - ignore_nulls: bool, -} - -impl NthValue { - /// Create a new FIRST_VALUE window aggregate function - pub fn first( - name: impl Into, - expr: Arc, - data_type: DataType, - ignore_nulls: bool, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - kind: NthValueKind::First, - ignore_nulls, - } - } - - /// Create a new LAST_VALUE window aggregate function - pub fn last( - name: impl Into, - expr: Arc, - data_type: DataType, - ignore_nulls: bool, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - kind: NthValueKind::Last, - ignore_nulls, - } - } - - /// Create a new NTH_VALUE window aggregate function - pub fn nth( - name: impl Into, - expr: Arc, - data_type: DataType, - n: i64, - ignore_nulls: bool, - ) -> Result { - Ok(Self { - name: name.into(), - expr, - data_type, - kind: NthValueKind::Nth(n), - ignore_nulls, - }) - } - - /// Get the NTH_VALUE kind - pub fn get_kind(&self) -> NthValueKind { - self.kind - } -} - -impl BuiltInWindowFunctionExpr for NthValue { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - let state = NthValueState { - finalized_result: None, - kind: self.kind, - }; - Ok(Box::new(NthValueEvaluator { - state, - ignore_nulls: self.ignore_nulls, - })) - } - - fn reverse_expr(&self) -> Option> { - let reversed_kind = match self.kind { - NthValueKind::First => NthValueKind::Last, - NthValueKind::Last => NthValueKind::First, - NthValueKind::Nth(idx) => NthValueKind::Nth(-idx), - }; - Some(Arc::new(Self { - name: self.name.clone(), - expr: Arc::clone(&self.expr), - data_type: self.data_type.clone(), - kind: reversed_kind, - ignore_nulls: self.ignore_nulls, - })) - } -} - -/// Value evaluator for nth_value functions -#[derive(Debug)] -pub(crate) struct NthValueEvaluator { - state: NthValueState, - ignore_nulls: bool, -} - -impl PartitionEvaluator for NthValueEvaluator { - /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING), - /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we - /// can memoize the result. Once result is calculated, it will always stay - /// same. Hence, we do not need to keep past data as we process the entire - /// dataset. 
- fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { - let out = &state.out_col; - let size = out.len(); - let mut buffer_size = 1; - // Decide if we arrived at a final result yet: - let (is_prunable, is_reverse_direction) = match self.state.kind { - NthValueKind::First => { - let n_range = - state.window_frame_range.end - state.window_frame_range.start; - (n_range > 0 && size > 0, false) - } - NthValueKind::Last => (true, true), - NthValueKind::Nth(n) => { - let n_range = - state.window_frame_range.end - state.window_frame_range.start; - match n.cmp(&0) { - Ordering::Greater => { - (n_range >= (n as usize) && size > (n as usize), false) - } - Ordering::Less => { - let reverse_index = (-n) as usize; - buffer_size = reverse_index; - // Negative index represents reverse direction. - (n_range >= reverse_index, true) - } - Ordering::Equal => (false, false), - } - } - }; - // Do not memoize results when nulls are ignored. - if is_prunable && !self.ignore_nulls { - if self.state.finalized_result.is_none() && !is_reverse_direction { - let result = ScalarValue::try_from_array(out, size - 1)?; - self.state.finalized_result = Some(result); - } - state.window_frame_range.start = - state.window_frame_range.end.saturating_sub(buffer_size); - } - Ok(()) - } - - fn evaluate( - &mut self, - values: &[ArrayRef], - range: &Range, - ) -> Result { - if let Some(ref result) = self.state.finalized_result { - Ok(result.clone()) - } else { - // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1. - let arr = &values[0]; - let n_range = range.end - range.start; - if n_range == 0 { - // We produce None if the window is empty. - return ScalarValue::try_from(arr.data_type()); - } - - // Extract valid indices if ignoring nulls. - let valid_indices = if self.ignore_nulls { - // Calculate valid indices, inside the window frame boundaries - let slice = arr.slice(range.start, n_range); - let valid_indices = slice - .nulls() - .map(|nulls| { - nulls - .valid_indices() - // Add offset `range.start` to valid indices, to point correct index in the original arr. - .map(|idx| idx + range.start) - .collect::>() - }) - .unwrap_or_default(); - if valid_indices.is_empty() { - return ScalarValue::try_from(arr.data_type()); - } - Some(valid_indices) - } else { - None - }; - match self.state.kind { - NthValueKind::First => { - if let Some(valid_indices) = &valid_indices { - ScalarValue::try_from_array(arr, valid_indices[0]) - } else { - ScalarValue::try_from_array(arr, range.start) - } - } - NthValueKind::Last => { - if let Some(valid_indices) = &valid_indices { - ScalarValue::try_from_array( - arr, - valid_indices[valid_indices.len() - 1], - ) - } else { - ScalarValue::try_from_array(arr, range.end - 1) - } - } - NthValueKind::Nth(n) => { - match n.cmp(&0) { - Ordering::Greater => { - // SQL indices are not 0-based. 
- let index = (n as usize) - 1; - if index >= n_range { - // Outside the range, return NULL: - ScalarValue::try_from(arr.data_type()) - } else if let Some(valid_indices) = valid_indices { - if index >= valid_indices.len() { - return ScalarValue::try_from(arr.data_type()); - } - ScalarValue::try_from_array(&arr, valid_indices[index]) - } else { - ScalarValue::try_from_array(arr, range.start + index) - } - } - Ordering::Less => { - let reverse_index = (-n) as usize; - if n_range < reverse_index { - // Outside the range, return NULL: - ScalarValue::try_from(arr.data_type()) - } else if let Some(valid_indices) = valid_indices { - if reverse_index > valid_indices.len() { - return ScalarValue::try_from(arr.data_type()); - } - let new_index = - valid_indices[valid_indices.len() - reverse_index]; - ScalarValue::try_from_array(&arr, new_index) - } else { - ScalarValue::try_from_array( - arr, - range.start + n_range - reverse_index, - ) - } - } - Ordering::Equal => ScalarValue::try_from(arr.data_type()), - } - } - } - } - } - - fn supports_bounded_execution(&self) -> bool { - true - } - - fn uses_window_frame(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::Column; - use arrow::{array::*, datatypes::*}; - use datafusion_common::cast::as_int32_array; - - fn test_i32_result(expr: NthValue, expected: Int32Array) -> Result<()> { - let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); - let values = vec![arr]; - let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; - let mut ranges: Vec> = vec![]; - for i in 0..8 { - ranges.push(Range { - start: 0, - end: i + 1, - }) - } - let mut evaluator = expr.create_evaluator()?; - let values = expr.evaluate_args(&batch)?; - let result = ranges - .iter() - .map(|range| evaluator.evaluate(&values, range)) - .collect::>>()?; - let result = ScalarValue::iter_to_array(result.into_iter())?; - let result = as_int32_array(&result)?; - assert_eq!(expected, *result); - Ok(()) - } - - #[test] - fn first_value() -> Result<()> { - let first_value = NthValue::first( - "first_value".to_owned(), - Arc::new(Column::new("arr", 0)), - DataType::Int32, - false, - ); - test_i32_result(first_value, Int32Array::from(vec![1; 8]))?; - Ok(()) - } - - #[test] - fn last_value() -> Result<()> { - let last_value = NthValue::last( - "last_value".to_owned(), - Arc::new(Column::new("arr", 0)), - DataType::Int32, - false, - ); - test_i32_result( - last_value, - Int32Array::from(vec![ - Some(1), - Some(-2), - Some(3), - Some(-4), - Some(5), - Some(-6), - Some(7), - Some(8), - ]), - )?; - Ok(()) - } - - #[test] - fn nth_value_1() -> Result<()> { - let nth_value = NthValue::nth( - "nth_value".to_owned(), - Arc::new(Column::new("arr", 0)), - DataType::Int32, - 1, - false, - )?; - test_i32_result(nth_value, Int32Array::from(vec![1; 8]))?; - Ok(()) - } - - #[test] - fn nth_value_2() -> Result<()> { - let nth_value = NthValue::nth( - "nth_value".to_owned(), - Arc::new(Column::new("arr", 0)), - DataType::Int32, - 2, - false, - )?; - test_i32_result( - nth_value, - Int32Array::from(vec![ - None, - Some(-2), - Some(-2), - Some(-2), - Some(-2), - Some(-2), - Some(-2), - Some(-2), - ]), - )?; - Ok(()) - } -} diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index a9f9b22fafda..64fd0f49a233 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -70,6 +70,7 @@ tokio = 
{ workspace = true } [dev-dependencies] criterion = { version = "0.5", features = ["async_futures"] } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-window = { workspace = true } rstest = { workspace = true } rstest_reuse = "0.7.0" tokio = { workspace = true, features = [ diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 61c5e9584b8d..0a898bd852bb 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1180,9 +1180,13 @@ mod tests { WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_physical_expr::expressions::{col, Column}; + use datafusion_functions_window::nth_value::first_value_udwf; + use datafusion_functions_window::nth_value::last_value_udwf; + use datafusion_functions_window::nth_value::nth_value_udwf; + use datafusion_physical_expr::expressions::{col, lit, Column}; use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; + use crate::memory::MemoryExec; use futures::future::Shared; use futures::{pin_mut, ready, FutureExt, Stream, StreamExt}; use itertools::Itertools; @@ -1501,124 +1505,123 @@ mod tests { // - feed BoundedWindowAggExec with batch stream data. // - Window frame should contain UNBOUNDED PRECEDING. // It hard to ensure these conditions are met, from the sql query. - // #[tokio::test] - // async fn test_window_nth_value_bounded_memoize() -> Result<()> { - // let config = SessionConfig::new().with_target_partitions(1); - // let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); - // - // let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - // // Create a new batch of data to insert into the table - // let batch = RecordBatch::try_new( - // Arc::clone(&schema), - // vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], - // )?; - // - // let memory_exec = MemoryExec::try_new( - // &[vec![batch.clone(), batch.clone(), batch.clone()]], - // Arc::clone(&schema), - // None, - // ) - // .map(|e| Arc::new(e) as Arc)?; - // let col_a = col("a", &schema)?; - // // let nth_value_func1 = WindowFunctionDefinition::WindowUDF(nth_value_udwf()) - // // // NthValue::nth( - // // // "nth_value(-1)", - // // // Arc::clone(&col_a), - // // // DataType::Int32, - // // // 1, - // // // false, - // // // )? - // // // .reverse_expr() - // // .unwrap(); - // // let nth_value_func2 = NthValue::nth( - // // "nth_value(-2)", - // // Arc::clone(&col_a), - // // DataType::Int32, - // // 2, - // // false, - // // )? 
- // // .reverse_expr() - // // .unwrap(); - // // let last_value_func = Arc::new(NthValue::last( - // // "last", - // // Arc::clone(&col_a), - // // DataType::Int32, - // // false, - // // )) as _; - // // let window_exprs = vec![ - // // LAST_VALUE(a) - // // Arc::new(BuiltInWindowExpr::new( - // // last_value_func, - // // &[], - // // &LexOrdering::default(), - // // Arc::new(WindowFrame::new_bounds( - // // WindowFrameUnits::Rows, - // // WindowFrameBound::Preceding(ScalarValue::UInt64(None)), - // // WindowFrameBound::CurrentRow, - // // )), - // // )) as _, - // // // NTH_VALUE(a, -1) - // // Arc::new(BuiltInWindowExpr::new( - // // nth_value_func1, - // // &[], - // // &LexOrdering::default(), - // // Arc::new(WindowFrame::new_bounds( - // // WindowFrameUnits::Rows, - // // WindowFrameBound::Preceding(ScalarValue::UInt64(None)), - // // WindowFrameBound::CurrentRow, - // // )), - // // )) as _, - // // // NTH_VALUE(a, -2) - // // Arc::new(BuiltInWindowExpr::new( - // // nth_value_func2, - // // &[], - // // &LexOrdering::default(), - // // Arc::new(WindowFrame::new_bounds( - // // WindowFrameUnits::Rows, - // // WindowFrameBound::Preceding(ScalarValue::UInt64(None)), - // // WindowFrameBound::CurrentRow, - // // )), - // // )) as _, - // // ]; - // let physical_plan = BoundedWindowAggExec::try_new( - // window_exprs, - // memory_exec, - // vec![], - // InputOrderMode::Sorted, - // ) - // .map(|e| Arc::new(e) as Arc)?; - // - // let batches = collect(physical_plan.execute(0, task_ctx)?).await?; - // - // let expected = vec![ - // "BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - // " MemoryExec: partitions=1, partition_sizes=[3]", - // ]; - // // Get string representation of the plan - // let actual = get_plan_string(&physical_plan); - // assert_eq!( - // expected, actual, - // "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - // ); - // - // let expected = [ - // "+---+------+---------------+---------------+", - // "| a | last | nth_value(-1) | nth_value(-2) |", - // "+---+------+---------------+---------------+", - // "| 1 | 1 | 1 | |", - // "| 2 | 2 | 2 | 1 |", - // "| 3 | 3 | 3 | 2 |", - // "| 1 | 1 | 1 | 3 |", - // "| 2 | 2 | 2 | 1 |", - // "| 3 | 3 | 3 | 2 |", - // "| 1 | 1 | 1 | 3 |", - // "| 2 | 2 | 2 | 1 |", - // "| 3 | 3 | 3 | 2 |", - // "+---+------+---------------+---------------+", - // ]; - // assert_batches_eq!(expected, &batches); - // Ok(()) - // } + #[tokio::test] + async fn test_window_nth_value_bounded_memoize() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(1); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + // Create a new batch of data to insert into the 
table + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], + )?; + + let memory_exec = MemoryExec::try_new( + &[vec![batch.clone(), batch.clone(), batch.clone()]], + Arc::clone(&schema), + None, + ) + .map(|e| Arc::new(e) as Arc)?; + let col_a = col("a", &schema)?; + let nth_value_func1 = NthValue::nth( + "nth_value(-1)", + Arc::clone(&col_a), + DataType::Int32, + 1, + false, + )? + .reverse_expr() + .unwrap(); + let nth_value_func2 = NthValue::nth( + "nth_value(-2)", + Arc::clone(&col_a), + DataType::Int32, + 2, + false, + )? + .reverse_expr() + .unwrap(); + let last_value_func = Arc::new(NthValue::last( + "last", + Arc::clone(&col_a), + DataType::Int32, + false, + )) as _; + let window_exprs = vec![ + // LAST_VALUE(a) + Arc::new(BuiltInWindowExpr::new( + last_value_func, + &[], + &LexOrdering::default(), + Arc::new(WindowFrame::new_bounds( + WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameBound::CurrentRow, + )), + )) as _, + // NTH_VALUE(a, -1) + Arc::new(BuiltInWindowExpr::new( + nth_value_func1, + &[], + &LexOrdering::default(), + Arc::new(WindowFrame::new_bounds( + WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameBound::CurrentRow, + )), + )) as _, + // NTH_VALUE(a, -2) + Arc::new(BuiltInWindowExpr::new( + nth_value_func2, + &[], + &LexOrdering::default(), + Arc::new(WindowFrame::new_bounds( + WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameBound::CurrentRow, + )), + )) as _, + ]; + let physical_plan = BoundedWindowAggExec::try_new( + window_exprs, + memory_exec, + vec![], + InputOrderMode::Sorted, + ) + .map(|e| Arc::new(e) as Arc)?; + + let batches = collect(physical_plan.execute(0, task_ctx)?).await?; + + let expected = vec![ + "BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " MemoryExec: partitions=1, partition_sizes=[3]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = [ + "+---+------+---------------+---------------+", + "| a | last | nth_value(-1) | nth_value(-2) |", + "+---+------+---------------+---------------+", + "| 1 | 1 | 1 | |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "+---+------+---------------+---------------+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } // This test, tests whether most recent row guarantee by the input batch of the `BoundedWindowAggExec` // helps 
`BoundedWindowAggExec` to generate low latency result in the `Linear` mode. diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index aab63dd8bd66..88939b5bccf4 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -269,6 +269,35 @@ fn roundtrip_window() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let _window_frame = WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::Int64(None)), + WindowFrameBound::CurrentRow, + ); + + // let udwf_window_expr = Expr::WindowFunction(WindowFunction::new( + // WindowFunctionDefinition::WindowUDF(first_value_udwf()), + + // let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( + // Arc::new(NthValue::first( + // "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + // col("a", &schema)?, + // DataType::Int64, + // false, + // )), + // &[col("b", &schema)?], + // &LexOrdering{ + // inner: vec![PhysicalSortExpr { + // expr: col("a", &schema)?, + // options: SortOptions { + // descending: false, + // nulls_first: false, + // }, + // }] + // }, + // Arc::new(window_frame), + // )); + let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( AggregateExprBuilder::new( avg_udaf(), diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index db180d2cf14d..4911a18dc019 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -136,6 +136,21 @@ from aggregate_test_100 order by c9 +statement error DataFusion error: Error during planning: Invalid function 'nth_vlue'.\nDid you mean 'nth_value'? +SELECT + NTH_VLUE(c4, 2) OVER() + FROM aggregate_test_100 + ORDER BY c9 + LIMIT 5; + +statement error DataFusion error: Error during planning: Invalid function 'frst_value'.\nDid you mean 'first_value'? +SELECT + FRST_VALUE(c4, 2) OVER() + FROM aggregate_test_100 + ORDER BY c9 + LIMIT 5; + + query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Int64 type create table foo as values (1), ('foo'); diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 5bbe5cfc172a..8e3559a32684 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4892,7 +4892,7 @@ DROP TABLE t1; statement ok CREATE TABLE t1(v1 BIGINT); -query error DataFusion error: This feature is not implemented: There is only support Literal types for field at idx: 1 in Window Function +query error DataFusion error: Execution error: Expected a signed integer literal for the second argument of nth_value SELECT NTH_VALUE('+Inf'::Double, v1) OVER (PARTITION BY v1) FROM t1; statement ok
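Note for callers migrating off the removed datafusion_expr::window_function::nth_value helper: a minimal sketch of the equivalent expression built through the user-defined window function, mirroring the first_value_udwf() usage added in dataframe/mod.rs above. The column name "c1" and the index 2 are illustrative only; this is not taken from the patch itself.

use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{col, lit, Expr, WindowFunctionDefinition};
use datafusion_functions_window::nth_value::nth_value_udwf;

// Sketch: NTH_VALUE(c1, 2) as a window expression backed by the UDWF
// registration rather than the removed BuiltInWindowFunction::NthValue.
fn nth_value_example() -> Expr {
    Expr::WindowFunction(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(nth_value_udwf()),
        // The second argument must be a signed integer literal, as enforced
        // by the new error path in functions-window/src/nth_value.rs.
        vec![col("c1"), lit(2)],
    ))
}

As in the dataframe test updated above, the resulting expression can then be extended with partition_by/order_by/window_frame via ExprFunctionExt before calling build().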