Skip to content

Commit

Permalink
Fix: regularize order bys when consuming from substrait (apache#14125)
Browse files Browse the repository at this point in the history
* Fix: regularize order bys when consuming from substrait

* Add window_function_with_range_unit_and_no_order_by test

* Fix typo in comment
  • Loading branch information
gabotechs authored Jan 16, 2025
1 parent 05f4e5a commit 9403448
Showing 1 changed file with 47 additions and 8 deletions.
55 changes: 47 additions & 8 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2194,7 +2194,8 @@ pub async fn from_window_function(
)
}?;

let order_by = from_substrait_sorts(consumer, &window.sorts, input_schema).await?;
let mut order_by =
from_substrait_sorts(consumer, &window.sorts, input_schema).await?;

let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| {
plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type)
Expand All @@ -2212,17 +2213,21 @@ pub async fn from_window_function(
}
}
};
let window_frame = datafusion::logical_expr::WindowFrame::new_bounds(
bound_units,
from_substrait_bound(&window.lower_bound, true)?,
from_substrait_bound(&window.upper_bound, false)?,
);

window_frame.regularize_order_bys(&mut order_by)?;

Ok(Expr::WindowFunction(expr::WindowFunction {
fun,
args: from_substrait_func_args(consumer, &window.arguments, input_schema).await?,
partition_by: from_substrait_rex_vec(consumer, &window.partitions, input_schema)
.await?,
order_by,
window_frame: datafusion::logical_expr::WindowFrame::new_bounds(
bound_units,
from_substrait_bound(&window.lower_bound, true)?,
from_substrait_bound(&window.upper_bound, false)?,
),
window_frame,
null_treatment: None,
}))
}
Expand Down Expand Up @@ -3271,18 +3276,21 @@ impl BuiltinExprBuilder {
mod test {
use crate::extensions::Extensions;
use crate::logical_plan::consumer::{
from_substrait_literal_without_names, DefaultSubstraitConsumer,
from_substrait_literal_without_names, from_substrait_rex,
DefaultSubstraitConsumer,
};
use arrow_buffer::IntervalMonthDayNano;
use datafusion::common::DFSchema;
use datafusion::error::Result;
use datafusion::execution::SessionState;
use datafusion::prelude::SessionContext;
use datafusion::prelude::{Expr, SessionContext};
use datafusion::scalar::ScalarValue;
use std::sync::OnceLock;
use substrait::proto::expression::literal::{
interval_day_to_second, IntervalCompound, IntervalDayToSecond,
IntervalYearToMonth, LiteralType,
};
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::expression::Literal;

static TEST_SESSION_STATE: OnceLock<SessionState> = OnceLock::new();
Expand Down Expand Up @@ -3328,4 +3336,35 @@ mod test {

Ok(())
}

#[tokio::test]
async fn window_function_with_range_unit_and_no_order_by() -> Result<()> {
let substrait = substrait::proto::Expression {
rex_type: Some(substrait::proto::expression::RexType::WindowFunction(
substrait::proto::expression::WindowFunction {
function_reference: 0,
bounds_type: BoundsType::Range as i32,
sorts: vec![],
..Default::default()
},
)),
};

let mut consumer = test_consumer();

// Just registering a single function (index 0) so that the plan
// does not throw a "function not found" error.
let mut extensions = Extensions::default();
extensions.register_function("count".to_string());
consumer.extensions = &extensions;

match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? {
Expr::WindowFunction(window_function) => {
assert_eq!(window_function.order_by.len(), 1)
}
_ => panic!("expr was not a WindowFunction"),
};

Ok(())
}
}

0 comments on commit 9403448

Please sign in to comment.