Skip to content

Commit

Permalink
Remove physical sort parameters on aggregate window functions (#12009)
Browse files Browse the repository at this point in the history
* Remove order_by on aggregate window functions, since the ordering is already handled by the window function itself

* Add unit test for window functions using udaf with ordering

* Resolve clippy warning
  • Loading branch information
timsaucer authored Aug 15, 2024
1 parent 19ad53d commit 9d1cf74
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 2 deletions.
89 changes: 88 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1703,13 +1703,16 @@ mod tests {
use arrow::array::{self, Int32Array};
use datafusion_common::{Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt,
ScalarFunctionImplementation, Volatility, WindowFunctionDefinition,
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
use sqlparser::ast::NullTreatment;

// Get string representation of the plan
async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) {
Expand Down Expand Up @@ -2355,6 +2358,90 @@ mod tests {
Ok(())
}

/// Runs several aggregate UDAFs as window functions over the rows of the test
/// table where `c1 == "a"`, and compares the collected output with a fixed table.
///
/// Every window expression aggregates `c3`, ignores nulls, orders by
/// (`c2`, `c3`) and uses a `ROWS` frame whose start bound is a null
/// `Preceding` value and whose end bound is `1 PRECEDING`.
#[tokio::test]
async fn window_using_aggregates() -> Result<()> {
    // Source rows: only those whose `c1` column equals "a".
    let filtered = test_table().await?.filter(col("c1").eq(lit("a")))?;

    // Each aggregate UDAF is paired with the alias used for its output column.
    let udafs = [
        (
            datafusion_functions_aggregate::first_last::first_value_udaf(),
            "first_value",
        ),
        (
            datafusion_functions_aggregate::first_last::last_value_udaf(),
            "last_val",
        ),
        (
            datafusion_functions_aggregate::approx_distinct::approx_distinct_udaf(),
            "approx_distinct",
        ),
        (
            datafusion_functions_aggregate::approx_median::approx_median_udaf(),
            "approx_median",
        ),
        (
            datafusion_functions_aggregate::median::median_udaf(),
            "median",
        ),
        (datafusion_functions_aggregate::min_max::max_udaf(), "max"),
        (datafusion_functions_aggregate::min_max::min_udaf(), "min"),
    ];

    // One aliased window expression per UDAF, followed by the plain `c2` and
    // `c3` columns so every output row can be identified.
    let mut projection = Vec::new();
    for (udaf, label) in udafs {
        let sort_keys = vec![col("c2").sort(true, true), col("c3").sort(true, true)];
        // The frame ends one row before the current row, so the current row
        // itself never contributes to its own aggregate.
        let frame = WindowFrame::new_bounds(
            WindowFrameUnits::Rows,
            WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
        );
        let call = WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(udaf),
            vec![col("c3")],
        );
        let windowed = Expr::WindowFunction(call)
            .null_treatment(NullTreatment::IgnoreNulls)
            .order_by(sort_keys)
            .window_frame(frame)
            .build()
            .unwrap();
        projection.push(windowed.alias(label));
    }
    projection.push(col("c2"));
    projection.push(col("c3"));

    let batches: Vec<RecordBatch> = filtered.select(projection)?.collect().await?;

    let expected = [
        "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
        "| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |",
        "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
        "| | | | | | | | 1 | -85 |",
        "| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |",
        "| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |",
        "| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |",
        "| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |",
        "| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |",
        "| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |",
        "| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |",
        "| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |",
        "| -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 |",
        "| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |",
        "| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |",
        "| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |",
        "| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |",
        "| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |",
        "| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |",
        "| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |",
        "| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |",
        "| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |",
        "| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |",
        "| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |",
        "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
    ];
    assert_batches_sorted_eq!(expected, &batches);

    Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/10346
#[tokio::test]
async fn test_select_over_aggregate_schema() -> Result<()> {
Expand Down
1 change: 0 additions & 1 deletion datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ pub fn create_window_expr(
let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec())
.schema(Arc::new(input_schema.clone()))
.alias(name)
.order_by(order_by.to_vec())
.with_ignore_nulls(ignore_nulls)
.build()?;
window_expr_from_aggregate_expr(
Expand Down

0 comments on commit 9d1cf74

Please sign in to comment.