diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index b2e8268aa332..349968c3fa2f 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,11 +23,11 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. +use crate::{expr::Sort, lit}; +use arrow::datatypes::DataType; use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::{expr::Sort, lit}; - use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; @@ -119,9 +119,9 @@ impl TryFrom for WindowFrame { type Error = DataFusionError; fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.try_into()?; + let start_bound = WindowFrameBound::try_parse(value.start_bound, &value.units)?; let end_bound = match value.end_bound { - Some(value) => value.try_into()?, + Some(bound) => WindowFrameBound::try_parse(bound, &value.units)?, None => WindowFrameBound::CurrentRow, }; @@ -138,6 +138,7 @@ impl TryFrom for WindowFrame { )? } }; + let units = value.units.into(); Ok(Self::new_bounds(units, start_bound, end_bound)) } @@ -334,17 +335,18 @@ impl WindowFrameBound { } } -impl TryFrom for WindowFrameBound { - type Error = DataFusionError; - - fn try_from(value: ast::WindowFrameBound) -> Result { +impl WindowFrameBound { + fn try_parse( + value: ast::WindowFrameBound, + units: &ast::WindowFrameUnits, + ) -> Result { Ok(match value { ast::WindowFrameBound::Preceding(Some(v)) => { - Self::Preceding(convert_frame_bound_to_scalar_value(*v)?) + Self::Preceding(convert_frame_bound_to_scalar_value(*v, units)?) } ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null), ast::WindowFrameBound::Following(Some(v)) => { - Self::Following(convert_frame_bound_to_scalar_value(*v)?) + Self::Following(convert_frame_bound_to_scalar_value(*v, units)?) 
} ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null), ast::WindowFrameBound::CurrentRow => Self::CurrentRow, @@ -352,33 +354,65 @@ impl TryFrom for WindowFrameBound { } } -pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result { - Ok(ScalarValue::Utf8(Some(match v { - ast::Expr::Value(ast::Value::Number(value, false)) - | ast::Expr::Value(ast::Value::SingleQuotedString(value)) => value, - ast::Expr::Interval(ast::Interval { - value, - leading_field, - .. - }) => { - let result = match *value { - ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, - e => { - return sql_err!(ParserError(format!( - "INTERVAL expression cannot be {e:?}" - ))); +fn convert_frame_bound_to_scalar_value( + v: ast::Expr, + units: &ast::WindowFrameUnits, +) -> Result { + match units { + // For ROWS and GROUPS we are sure that the ScalarValue must be a non-negative integer ... + ast::WindowFrameUnits::Rows | ast::WindowFrameUnits::Groups => match v { + ast::Expr::Value(ast::Value::Number(value, false)) => { + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + }, + ast::Expr::Interval(ast::Interval { + value, + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }) => { + let value = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + } + _ => plan_err!( + "Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers" + ), + }, + // ... instead for RANGE it could be anything depending on the type of the ORDER BY clause, + // so we use a ScalarValue::Utf8. + ast::WindowFrameUnits::Range => Ok(ScalarValue::Utf8(Some(match v { + ast::Expr::Value(ast::Value::Number(value, false)) => value, + ast::Expr::Interval(ast::Interval { + value, + leading_field, + .. 
+ }) => { + let result = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + if let Some(leading_field) = leading_field { + format!("{result} {leading_field}") + } else { + result } - }; - if let Some(leading_field) = leading_field { - format!("{result} {leading_field}") - } else { - result } - } - _ => plan_err!( - "Invalid window frame: frame offsets must be non negative integers" - )?, - }))) + _ => plan_err!( + "Invalid window frame: frame offsets for RANGE must be either a numeric value, a string value or an interval" + )?, + }))), + } } impl fmt::Display for WindowFrameBound { @@ -479,8 +513,91 @@ mod tests { ast::Expr::Value(ast::Value::Number("1".to_string(), false)), )))), }; - let result = WindowFrame::try_from(window_frame); - assert!(result.is_ok()); + + let window_frame = WindowFrame::try_from(window_frame)?; + assert_eq!(window_frame.units, WindowFrameUnits::Rows); + assert_eq!( + window_frame.start_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) + ); + assert_eq!( + window_frame.end_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) + ); + + Ok(()) + } + + macro_rules! test_bound { + ($unit:ident, $value:expr, $expected:expr) => { + let preceding = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(preceding, WindowFrameBound::Preceding($expected)); + let following = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(following, WindowFrameBound::Following($expected)); + }; + } + + macro_rules! 
test_bound_err { + ($unit:ident, $value:expr, $expected:expr) => { + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + }; + } + + #[test] + fn test_window_frame_bound_creation() -> Result<()> { + // Unbounded + test_bound!(Rows, None, ScalarValue::Null); + test_bound!(Groups, None, ScalarValue::Null); + test_bound!(Range, None, ScalarValue::Null); + + // Number + let number = Some(Box::new(ast::Expr::Value(ast::Value::Number( + "42".to_string(), + false, + )))); + test_bound!(Rows, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!(Groups, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("42".to_string())) + ); + + // Interval + let number = Some(Box::new(ast::Expr::Interval(ast::Interval { + value: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + "1".to_string(), + ))), + leading_field: Some(ast::DateTimeField::Day), + fractional_seconds_precision: None, + last_field: None, + leading_precision: None, + }))); + test_bound_err!(Rows, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound_err!(Groups, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("1 DAY".to_string())) + ); + Ok(()) } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 36b72233b5af..33eea1a661c6 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ 
b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -696,20 +696,20 @@ fn coerce_window_frame( expressions: &[Sort], ) -> Result { let mut window_frame = window_frame; - let current_types = expressions - .iter() - .map(|s| s.expr.get_type(schema)) - .collect::>>()?; let target_type = match window_frame.units { WindowFrameUnits::Range => { - if let Some(col_type) = current_types.first() { + let current_types = expressions + .first() + .map(|s| s.expr.get_type(schema)) + .transpose()?; + if let Some(col_type) = current_types { if col_type.is_numeric() - || is_utf8_or_large_utf8(col_type) + || is_utf8_or_large_utf8(&col_type) || matches!(col_type, DataType::Null) { col_type - } else if is_datetime(col_type) { - &DataType::Interval(IntervalUnit::MonthDayNano) + } else if is_datetime(&col_type) { + DataType::Interval(IntervalUnit::MonthDayNano) } else { return internal_err!( "Cannot run range queries on datatype: {col_type:?}" @@ -719,10 +719,11 @@ fn coerce_window_frame( return internal_err!("ORDER BY column cannot be empty"); } } - WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, + WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64, }; - window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?; - window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?; + window_frame.start_bound = + coerce_frame_bound(&target_type, window_frame.start_bound)?; + window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?; Ok(window_frame) } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 9ed084eec249..8e25c1c5b1cd 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1061,7 +1061,7 @@ fn test_aggregation_to_sql() { FROM person GROUP BY id, first_name;"#, r#"SELECT person.id, person.first_name, -sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY 
person.first_name ROWS BETWEEN '5' PRECEDING AND '2' FOLLOWING) AS moving_sum, +sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1, rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 95d850795772..4a2d9e1d6864 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -2208,7 +2208,7 @@ physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2] 02)--SortExec: TopK(fetch=5), expr=[c9@2 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum1, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@4 as sum2, c9@1 as c9] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, 
metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: true }], mode=[Sorted] 05)--------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING] 06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 07)------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST], preserve_partitioning=[false] @@ -2378,17 +2378,41 @@ SELECT c9, rn1 FROM (SELECT c9, # invalid window frame. 
null as preceding -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between null preceding and current row) from (select 1 a) x # invalid window frame. null as preceding -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between null preceding and current row) from (select 1 a) x # invalid window frame. negative as following -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between current row and -1 following) from (select 1 a) x +# invalid window frame. null as preceding +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers +select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x + +# invalid window frame. null as preceding +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers +select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x + +# invalid window frame. 
negative as following +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers +select row_number() over (order by a groups between current row and -1 following) from (select 1 a) x + +# interval for rows +query I +select row_number() over (rows between '1' preceding and current row) from (select 1 a) x +---- +1 + +# interval for groups +query I +select row_number() over (order by a groups between '1' preceding and current row) from (select 1 a) x +---- +1 + # This test shows that ordering satisfy considers ordering equivalences, # and can simplify (reduce expression size) multi expression requirements during normalization # For the example below, requirement rn1 ASC, c9 DESC should be simplified to the rn1 ASC. diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4105dc1876db..4855af683b7d 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1718,98 +1718,43 @@ fn make_substrait_like_expr( } } +/// Converts an integer frame offset into the `i64` that Substrait bounds carry. +/// +/// Returns `None` for null or non-integer scalars and for `UInt64` offsets above +/// `i64::MAX`; `to_substrait_bound` emits an unbounded bound in that case. +fn to_substrait_bound_offset(value: &ScalarValue) -> Option<i64> { + match value { + ScalarValue::UInt8(Some(v)) => Some(*v as i64), + ScalarValue::UInt16(Some(v)) => Some(*v as i64), + ScalarValue::UInt32(Some(v)) => Some(*v as i64), + // A plain `as i64` cast would wrap offsets above `i64::MAX` to negative values. + ScalarValue::UInt64(Some(v)) => i64::try_from(*v).ok(), + ScalarValue::Int8(Some(v)) => Some(*v as i64), + ScalarValue::Int16(Some(v)) => Some(*v as i64), + ScalarValue::Int32(Some(v)) => Some(*v as i64), + ScalarValue::Int64(Some(v)) => Some(*v), + _ => None, + } +} + fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { match bound { WindowFrameBound::CurrentRow => Bound { kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), }, - WindowFrameBound::Preceding(s) => match s { - ScalarValue::UInt8(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), + 
WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), }, - ScalarValue::UInt16(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt32(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt64(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int8(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int16(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int32(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int64(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v, - })), - }, - _ => Bound { + None => Bound { kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), }, }, - WindowFrameBound::Following(s) => match s { - ScalarValue::UInt8(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt16(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt32(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt64(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int8(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int16(Some(v)) => Bound { - 
kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int32(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int64(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v, - })), + WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), }, - _ => Bound { + None => Bound { kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), }, },