From 94034487e79c0867f228b6049d7b8e49a55f5ec5 Mon Sep 17 00:00:00 2001 From: Gabriel <45515538+gabotechs@users.noreply.github.com> Date: Thu, 16 Jan 2025 21:18:53 +0100 Subject: [PATCH] Fix: regularize order bys when consuming from substrait (#14125) * Fix: regularize order bys when consuming from substrait * Add window_function_with_range_unit_and_no_order_by test * Fix typo in comment --- .../substrait/src/logical_plan/consumer.rs | 55 ++++++++++++++++--- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9623f12c88dd..5a7d70c5e765 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -2194,7 +2194,8 @@ pub async fn from_window_function( ) }?; - let order_by = from_substrait_sorts(consumer, &window.sorts, input_schema).await?; + let mut order_by = + from_substrait_sorts(consumer, &window.sorts, input_schema).await?; let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| { plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) @@ -2212,17 +2213,21 @@ pub async fn from_window_function( } } }; + let window_frame = datafusion::logical_expr::WindowFrame::new_bounds( + bound_units, + from_substrait_bound(&window.lower_bound, true)?, + from_substrait_bound(&window.upper_bound, false)?, + ); + + window_frame.regularize_order_bys(&mut order_by)?; + Ok(Expr::WindowFunction(expr::WindowFunction { fun, args: from_substrait_func_args(consumer, &window.arguments, input_schema).await?, partition_by: from_substrait_rex_vec(consumer, &window.partitions, input_schema) .await?, order_by, - window_frame: datafusion::logical_expr::WindowFrame::new_bounds( - bound_units, - from_substrait_bound(&window.lower_bound, true)?, - from_substrait_bound(&window.upper_bound, false)?, - ), + window_frame, null_treatment: None, })) } @@ -3271,18 +3276,21 @@ impl 
BuiltinExprBuilder { mod test { use crate::extensions::Extensions; use crate::logical_plan::consumer::{ - from_substrait_literal_without_names, DefaultSubstraitConsumer, + from_substrait_literal_without_names, from_substrait_rex, + DefaultSubstraitConsumer, }; use arrow_buffer::IntervalMonthDayNano; + use datafusion::common::DFSchema; use datafusion::error::Result; use datafusion::execution::SessionState; - use datafusion::prelude::SessionContext; + use datafusion::prelude::{Expr, SessionContext}; use datafusion::scalar::ScalarValue; use std::sync::OnceLock; use substrait::proto::expression::literal::{ interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, LiteralType, }; + use substrait::proto::expression::window_function::BoundsType; use substrait::proto::expression::Literal; static TEST_SESSION_STATE: OnceLock<SessionState> = OnceLock::new(); @@ -3328,4 +3336,35 @@ mod test { Ok(()) } + + #[tokio::test] + async fn window_function_with_range_unit_and_no_order_by() -> Result<()> { + let substrait = substrait::proto::Expression { + rex_type: Some(substrait::proto::expression::RexType::WindowFunction( + substrait::proto::expression::WindowFunction { + function_reference: 0, + bounds_type: BoundsType::Range as i32, + sorts: vec![], + ..Default::default() + }, + )), + }; + + let mut consumer = test_consumer(); + + // Just registering a single function (index 0) so that the plan + // does not throw a "function not found" error. + let mut extensions = Extensions::default(); + extensions.register_function("count".to_string()); + consumer.extensions = &extensions; + + match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { + Expr::WindowFunction(window_function) => { + assert_eq!(window_function.order_by.len(), 1) + } + _ => panic!("expr was not a WindowFunction"), + }; + + Ok(()) + } }