From 127339370b1d45343783b889237e0173d0e343a4 Mon Sep 17 00:00:00 2001 From: Andrey Koshchiy Date: Tue, 29 Oct 2024 22:52:56 +0300 Subject: [PATCH] feat(substrait): AggregateRel grouping_expression support --- .../substrait/src/logical_plan/consumer.rs | 77 ++++++++++----- .../substrait/src/logical_plan/producer.rs | 58 ++++++++--- .../tests/cases/roundtrip_logical_plan.rs | 13 +++ ...roject_group_expression_ref.substrait.json | 98 +++++++++++++++++++ 4 files changed, 210 insertions(+), 36 deletions(-) create mode 100644 datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 3d5d7cce5673..378ac411e112 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -33,6 +33,7 @@ use datafusion::logical_expr::{ expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values, }; +use substrait::proto::aggregate_rel::Grouping; use substrait::proto::expression::subquery::set_predicate::PredicateOp; use substrait::proto::expression_reference::ExprType; use url::Url; @@ -652,39 +653,48 @@ pub async fn from_substrait_rel( let input = LogicalPlanBuilder::from( from_substrait_rel(ctx, input, extensions).await?, ); - let mut group_expr = vec![]; - let mut aggr_expr = vec![]; + let mut ref_group_exprs = vec![]; + + for e in &agg.grouping_expressions { + let x = + from_substrait_rex(ctx, e, input.schema(), extensions).await?; + ref_group_exprs.push(x); + } + + let mut group_exprs = vec![]; + let mut aggr_exprs = vec![]; match agg.groupings.len() { 1 => { - for e in &agg.groupings[0].grouping_expressions { - let x = - from_substrait_rex(ctx, e, input.schema(), extensions) - .await?; - group_expr.push(x); - } + group_exprs.extend_from_slice( + &from_substrait_grouping( + ctx, + &agg.groupings[0], + &ref_group_exprs, + input.schema(), + extensions, + ) + .await?, + ); } _ => { let mut grouping_sets = vec![]; for grouping in &agg.groupings { - let mut grouping_set = vec![]; - for e in &grouping.grouping_expressions { - let x = from_substrait_rex( - ctx, - e, - input.schema(), - extensions, - ) - .await?; - grouping_set.push(x); - } + let grouping_set = from_substrait_grouping( + ctx, + grouping, + &ref_group_exprs, + input.schema(), + extensions, + ) + .await?; grouping_sets.push(grouping_set); } // Single-element grouping expression of type Expr::GroupingSet. // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when // parsed by the producer and consumer, since Substrait does not have a type dedicated // to ROLLUP. Only vector of Groupings (grouping sets) is available. - group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets( + group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets( grouping_sets, ))); } @@ -729,9 +739,9 @@ pub async fn from_substrait_rel( "Aggregate without aggregate function is not supported" ), }; - aggr_expr.push(agg_func?.as_ref().clone()); + aggr_exprs.push(agg_func?.as_ref().clone()); } - input.aggregate(group_expr, aggr_expr)?.build() + input.aggregate(group_exprs, aggr_exprs)?.build() } else { not_impl_err!("Aggregate without an input is not valid") } @@ -2571,6 +2581,29 @@ fn from_substrait_null( } } +#[allow(deprecated)] +async fn from_substrait_grouping( + ctx: &SessionContext, + grouping: &Grouping, + expressions: &[Expr], + input_schema: &DFSchemaRef, + extensions: &Extensions, +) -> Result> { + let mut group_exprs = vec![]; + if !grouping.grouping_expressions.is_empty() { + for e in &grouping.grouping_expressions { + let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?; + group_exprs.push(expr); + } + return Ok(group_exprs); + } + for idx in &grouping.expression_references { + let e = &expressions[*idx as usize]; + group_exprs.push(e.clone()); + } + Ok(group_exprs) +} + fn from_substrait_field_reference( field_ref: &FieldReference, input_schema: &DFSchema, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4855af683b7d..07fd39e6dda4 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -361,7 +361,7 @@ pub fn to_substrait_rel( } LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?; - let groupings = to_substrait_groupings( + let (grouping_expressions, groupings) = to_substrait_groupings( ctx, &agg.group_expr, agg.input.schema(), @@ -377,7 +377,7 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), - grouping_expressions: vec![], + grouping_expressions, groupings, measures, advanced_extension: None, @@ -774,14 +774,20 @@ pub fn parse_flat_grouping_exprs( exprs: &[Expr], schema: &DFSchemaRef, extensions: &mut Extensions, + ref_group_exprs: &mut Vec, ) -> Result { - let grouping_expressions = exprs - .iter() - .map(|e| to_substrait_rex(ctx, e, schema, 0, extensions)) - .collect::>>()?; + let mut expression_references = vec![]; + let mut grouping_expressions = vec![]; + + for e in exprs { + let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?; + grouping_expressions.push(rex.clone()); + ref_group_exprs.push(rex); + expression_references.push((ref_group_exprs.len() - 1) as u32); + } Ok(Grouping { grouping_expressions, - expression_references: vec![], + expression_references, }) } @@ -790,8 +796,9 @@ pub fn to_substrait_groupings( exprs: &[Expr], schema: &DFSchemaRef, extensions: &mut Extensions, -) -> Result> { - match exprs.len() { +) -> Result<(Vec, Vec)> { + let mut ref_group_exprs = vec![]; + let groupings = match exprs.len() { 1 => match &exprs[0] { Expr::GroupingSet(gs) => match gs { GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( @@ -799,7 +806,15 @@ pub fn to_substrait_groupings( )), GroupingSet::GroupingSets(sets) => Ok(sets .iter() - .map(|set| parse_flat_grouping_exprs(ctx, set, schema, extensions)) + .map(|set| { + parse_flat_grouping_exprs( + ctx, + set, + schema, + extensions, + &mut ref_group_exprs, + ) + }) .collect::>>()?), GroupingSet::Rollup(set) => { let mut sets: Vec> = vec![vec![]]; @@ -810,19 +825,34 @@ pub fn to_substrait_groupings( .iter() .rev() .map(|set| { - parse_flat_grouping_exprs(ctx, set, schema, extensions) + parse_flat_grouping_exprs( + ctx, + set, + schema, + extensions, + &mut ref_group_exprs, + ) }) .collect::>>()?) } }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, exprs, schema, extensions, + ctx, + exprs, + schema, + extensions, + &mut ref_group_exprs, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, exprs, schema, extensions, + ctx, + exprs, + schema, + extensions, + &mut ref_group_exprs, )?]), - } + }?; + Ok((ref_group_exprs, groupings)) } #[allow(deprecated)] diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 06a047b108bd..7f6cb0fd0868 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -680,6 +680,19 @@ async fn aggregate_wo_projection_consume() -> Result<()> { .await } +#[tokio::test] +async fn aggregate_wo_projection_group_expression_ref_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json"); + + assert_expected_plan_substrait( + proto_plan, + "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\ + \n TableScan: data projection=[a]", + ) + .await +} + #[tokio::test] async fn simple_intersect_consume() -> Result<()> { let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json"); diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json new file mode 100644 index 000000000000..b6f14afd6fa9 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json @@ -0,0 +1,98 @@ +{ + "extensionUris": [ + { + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 185, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "grouping_expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ], + "groupings": [ + { + "expression_references": [0] + } + ], + "measures": [ + { + "measure": { + "functionReference": 185, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + }, + "names": [ + "a", + "countA" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file