Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(substrait): AggregateRel grouping_expressions support #13173

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 55 additions & 22 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use datafusion::logical_expr::{
expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr,
ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values,
};
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use substrait::proto::expression_reference::ExprType;
use url::Url;
Expand Down Expand Up @@ -654,39 +655,48 @@ pub async fn from_substrait_rel(
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
let mut group_expr = vec![];
let mut aggr_expr = vec![];
let mut ref_group_exprs = vec![];

for e in &agg.grouping_expressions {
let x =
from_substrait_rex(ctx, e, input.schema(), extensions).await?;
ref_group_exprs.push(x);
}

let mut group_exprs = vec![];
let mut aggr_exprs = vec![];

match agg.groupings.len() {
1 => {
for e in &agg.groupings[0].grouping_expressions {
let x =
from_substrait_rex(ctx, e, input.schema(), extensions)
.await?;
group_expr.push(x);
}
group_exprs.extend_from_slice(
&from_substrait_grouping(
ctx,
&agg.groupings[0],
&ref_group_exprs,
input.schema(),
extensions,
)
.await?,
);
}
_ => {
let mut grouping_sets = vec![];
for grouping in &agg.groupings {
let mut grouping_set = vec![];
for e in &grouping.grouping_expressions {
let x = from_substrait_rex(
ctx,
e,
input.schema(),
extensions,
)
.await?;
grouping_set.push(x);
}
let grouping_set = from_substrait_grouping(
ctx,
grouping,
&ref_group_exprs,
input.schema(),
extensions,
)
.await?;
grouping_sets.push(grouping_set);
}
// Single-element grouping expression of type Expr::GroupingSet.
// Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when
// parsed by the producer and consumer, since Substrait does not have a type dedicated
// to ROLLUP. Only vector of Groupings (grouping sets) is available.
group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets(
group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets(
grouping_sets,
)));
}
Expand Down Expand Up @@ -744,9 +754,9 @@ pub async fn from_substrait_rel(
"Aggregate without aggregate function is not supported"
),
};
aggr_expr.push(agg_func?.as_ref().clone());
aggr_exprs.push(agg_func?.as_ref().clone());
}
input.aggregate(group_expr, aggr_expr)?.build()
input.aggregate(group_exprs, aggr_exprs)?.build()
} else {
not_impl_err!("Aggregate without an input is not valid")
}
Expand Down Expand Up @@ -2618,6 +2628,29 @@ fn from_substrait_null(
}
}

#[allow(deprecated)]
async fn from_substrait_grouping(
ctx: &SessionContext,
grouping: &Grouping,
expressions: &[Expr],
input_schema: &DFSchemaRef,
extensions: &Extensions,
) -> Result<Vec<Expr>> {
let mut group_exprs = vec![];
if !grouping.grouping_expressions.is_empty() {
for e in &grouping.grouping_expressions {
let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?;
group_exprs.push(expr);
}
return Ok(group_exprs);
}
for idx in &grouping.expression_references {
let e = &expressions[*idx as usize];
group_exprs.push(e.clone());
}
Ok(group_exprs)
}

fn from_substrait_field_reference(
field_ref: &FieldReference,
input_schema: &DFSchema,
Expand Down
58 changes: 44 additions & 14 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ pub fn to_substrait_rel(
}
LogicalPlan::Aggregate(agg) => {
let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?;
let groupings = to_substrait_groupings(
let (grouping_expressions, groupings) = to_substrait_groupings(
ctx,
&agg.group_expr,
agg.input.schema(),
Expand All @@ -377,7 +377,7 @@ pub fn to_substrait_rel(
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
grouping_expressions: vec![],
grouping_expressions,
groupings,
measures,
advanced_extension: None,
Expand Down Expand Up @@ -771,14 +771,20 @@ pub fn parse_flat_grouping_exprs(
exprs: &[Expr],
schema: &DFSchemaRef,
extensions: &mut Extensions,
ref_group_exprs: &mut Vec<Expression>,
) -> Result<Grouping> {
let grouping_expressions = exprs
.iter()
.map(|e| to_substrait_rex(ctx, e, schema, 0, extensions))
.collect::<Result<Vec<_>>>()?;
let mut expression_references = vec![];
let mut grouping_expressions = vec![];

for e in exprs {
let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?;
grouping_expressions.push(rex.clone());
ref_group_exprs.push(rex);
expression_references.push((ref_group_exprs.len() - 1) as u32);
}
Ok(Grouping {
grouping_expressions,
expression_references: vec![],
expression_references,
})
}

Expand All @@ -787,16 +793,25 @@ pub fn to_substrait_groupings(
exprs: &[Expr],
schema: &DFSchemaRef,
extensions: &mut Extensions,
) -> Result<Vec<Grouping>> {
match exprs.len() {
) -> Result<(Vec<Expression>, Vec<Grouping>)> {
let mut ref_group_exprs = vec![];
let groupings = match exprs.len() {
1 => match &exprs[0] {
Expr::GroupingSet(gs) => match gs {
GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented(
"GroupingSet CUBE is not yet supported".to_string(),
)),
GroupingSet::GroupingSets(sets) => Ok(sets
.iter()
.map(|set| parse_flat_grouping_exprs(ctx, set, schema, extensions))
.map(|set| {
parse_flat_grouping_exprs(
ctx,
set,
schema,
extensions,
&mut ref_group_exprs,
)
})
.collect::<Result<Vec<_>>>()?),
GroupingSet::Rollup(set) => {
let mut sets: Vec<Vec<Expr>> = vec![vec![]];
Expand All @@ -807,19 +822,34 @@ pub fn to_substrait_groupings(
.iter()
.rev()
.map(|set| {
parse_flat_grouping_exprs(ctx, set, schema, extensions)
parse_flat_grouping_exprs(
ctx,
set,
schema,
extensions,
&mut ref_group_exprs,
)
})
.collect::<Result<Vec<_>>>()?)
}
},
_ => Ok(vec![parse_flat_grouping_exprs(
ctx, exprs, schema, extensions,
ctx,
exprs,
schema,
extensions,
&mut ref_group_exprs,
)?]),
},
_ => Ok(vec![parse_flat_grouping_exprs(
ctx, exprs, schema, extensions,
ctx,
exprs,
schema,
extensions,
&mut ref_group_exprs,
)?]),
}
}?;
Ok((ref_group_exprs, groupings))
}

#[allow(deprecated)]
Expand Down
13 changes: 13 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,19 @@ async fn aggregate_wo_projection_consume() -> Result<()> {
.await
}

#[tokio::test]
async fn aggregate_wo_projection_group_expression_ref_consume() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json");

assert_expected_plan_substrait(
proto_plan,
"Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\
\n TableScan: data projection=[a]",
)
.await
}

#[tokio::test]
async fn aggregate_wo_projection_sorted_consume() -> Result<()> {
let proto_plan =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
{
"extensionUris": [
{
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"functionAnchor": 185,
"name": "count:any"
}
}
],
"relations": [
{
"root": {
"input": {
"aggregate": {
"input": {
"read": {
"common": {
"direct": {}
},
"baseSchema": {
"names": [
"a"
],
"struct": {
"types": [
{
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"nullability": "NULLABILITY_NULLABLE"
}
},
"namedTable": {
"names": [
"data"
]
}
}
},
"grouping_expressions": [
{
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
}
],
"groupings": [
{
"expression_references": [0]
}
],
"measures": [
{
"measure": {
"functionReference": 185,
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"outputType": {
"i64": {}
},
"arguments": [
{
"value": {
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
}
}
]
}
}
]
}
},
"names": [
"a",
"countA"
]
}
}
],
"version": {
"minorNumber": 54,
"producer": "subframe"
}
}