diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index fb6a988b1cc6..c95ba7c2492d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -298,8 +298,8 @@ impl AggregateExec { &original_schema, group_by.expr.len(), )); - let original_schema = Arc::new(original_schema); + let original_schema = Arc::new(original_schema); AggregateExec::try_new_with_schema( mode, group_by, @@ -312,9 +312,16 @@ impl AggregateExec { ) } - /// Create a new hash aggregate execution plan + /// Create a new hash aggregate execution plan with the given schema. + /// This constructor isn't part of the public API, it is used internally + /// by Datafusion to enforce schema consistency during when re-creating + /// `AggregateExec`s inside optimization rules. Schema field names of an + /// `AggregateExec` depends on the names of aggregate expressions. Since + /// a rule may re-write aggregate expressions (e.g. reverse them) during + /// initialization, field names may change inadvertently if one re-creates + /// the schema in such cases. #[allow(clippy::too_many_arguments)] - pub fn try_new_with_schema( + fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, aggr_expr: Vec>, @@ -1151,7 +1158,8 @@ mod tests { lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::{ - AggregateExpr, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalExpr, + PhysicalSortExpr, }; use datafusion_execution::memory_pool::FairSpillPool; @@ -2084,4 +2092,56 @@ mod tests { assert_eq!(res, common_requirement); Ok(()) } + + #[test] + fn test_agg_exec_same_schema() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let sort_expr = vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_desc, + }]; + let sort_expr_reverse = reverse_order_bys(&sort_expr); + let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); + + let aggregates: Vec> = vec![ + Arc::new(FirstValue::new( + col_b.clone(), + "FIRST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr_reverse.clone(), + vec![DataType::Float64], + )), + Arc::new(LastValue::new( + col_b.clone(), + "LAST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr.clone(), + vec![DataType::Float64], + )), + ]; + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates.clone(), + vec![None, None], + blocking_exec.clone(), + schema, + )?); + let new_agg = aggregate_exec + .clone() + .with_new_children(vec![blocking_exec])?; + assert_eq!(new_agg.schema(), aggregate_exec.schema()); + Ok(()) + } }