From 12941e404b5c02c06faa5a8f2d04af9744cbed4a Mon Sep 17 00:00:00 2001
From: Arttu <Blizzara@users.noreply.github.com>
Date: Tue, 11 Jun 2024 20:35:34 +0200
Subject: [PATCH] Support correct output column names and struct field names
 when consuming/producing Substrait (#10829)

* produce flattened list of names including inner struct fields

* add a (failing) test

* rename output columns (incl. inner struct fields) according to the given list of names

* fix a test

* add column names project to the new TPC-H test and fix case

(assert_eq gives nicer error messages than assert)
---
 .../substrait/src/logical_plan/consumer.rs    | 133 +++++++++++++++++-
 .../substrait/src/logical_plan/producer.rs    |   2 +-
 .../tests/cases/consumer_integration.rs       |  17 +--
 .../substrait/tests/cases/logical_plans.rs    |   2 +-
 .../tests/cases/roundtrip_logical_plan.rs     |  43 +++---
 5 files changed, 156 insertions(+), 41 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 8a483db8c4d62..648a281832e10 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -17,7 +17,7 @@
 
 use async_recursion::async_recursion;
 use datafusion::arrow::datatypes::{
-    DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
+    DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
 };
 use datafusion::common::{
     not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
@@ -29,12 +29,13 @@ use url::Url;
 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
 use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::{
-    aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr,
-    LogicalPlan, Operator, ScalarUDF, Values,
+    aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
+    EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF,
+    Values,
 };
 
 use datafusion::logical_expr::{
-    expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
+    col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
     Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
 };
 use datafusion::prelude::JoinType;
@@ -225,6 +226,7 @@ pub async fn from_substrait_plan(
             None => not_impl_err!("Cannot parse empty extension"),
         })
         .collect::<Result<HashMap<_, _>>>()?;
+
     // Parse relations
     match plan.relations.len() {
         1 => {
@@ -234,7 +236,29 @@ pub async fn from_substrait_plan(
                         Ok(from_substrait_rel(ctx, rel, &function_extension).await?)
                     },
                     plan_rel::RelType::Root(root) => {
-                        Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?)
+                        let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?;
+                        if root.names.is_empty() {
+                            // Backwards compatibility for plans missing names
+                            return Ok(plan);
+                        }
+                        let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?;
+                        if renamed_schema.equivalent_names_and_types(plan.schema()) {
+                            // Nothing to do if the schema is already equivalent
+                            return Ok(plan);
+                        }
+
+                        match plan {
+                            // If the last node of the plan produces expressions, bake the renames into those expressions.
+                            // This isn't necessary for correctness, but helps with roundtrip tests.
+                            LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)),
+                            LogicalPlan::Aggregate(a) => {
+                                let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?;
+                                Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
+                            },
+                            // There are probably more plans where we could bake things in, can add them later as needed.
+                            // Otherwise, add a new Project to handle the renaming.
+                            _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?))
+                        }
                     }
                 },
                 None => plan_err!("Cannot parse plan relation: None")
@@ -284,6 +308,105 @@ pub fn extract_projection(
     }
 }
 
+fn rename_expressions(
+    exprs: impl IntoIterator<Item = Expr>,
+    input_schema: &DFSchema,
+    new_schema: DFSchemaRef,
+) -> Result<Vec<Expr>> {
+    exprs
+        .into_iter()
+        .zip(new_schema.fields())
+        .map(|(old_expr, new_field)| {
+            if &old_expr.get_type(input_schema)? == new_field.data_type() {
+                // Alias column if needed
+                old_expr.alias_if_changed(new_field.name().into())
+            } else {
+                // Use Cast to rename inner struct fields + alias column if needed
+                Expr::Cast(Cast::new(
+                    Box::new(old_expr),
+                    new_field.data_type().to_owned(),
+                ))
+                .alias_if_changed(new_field.name().into())
+            }
+        })
+        .collect()
+}
+
+fn make_renamed_schema(
+    schema: &DFSchemaRef,
+    dfs_names: &Vec<String>,
+) -> Result<DFSchemaRef> {
+    fn rename_inner_fields(
+        dtype: &DataType,
+        dfs_names: &Vec<String>,
+        name_idx: &mut usize,
+    ) -> Result<DataType> {
+        match dtype {
+            DataType::Struct(fields) => {
+                let fields = fields
+                    .iter()
+                    .map(|f| {
+                        let name = next_struct_field_name(0, dfs_names, name_idx)?;
+                        Ok((**f).to_owned().with_name(name).with_data_type(
+                            rename_inner_fields(f.data_type(), dfs_names, name_idx)?,
+                        ))
+                    })
+                    .collect::<Result<_>>()?;
+                Ok(DataType::Struct(fields))
+            }
+            DataType::List(inner) => Ok(DataType::List(FieldRef::new(
+                (**inner).to_owned().with_data_type(rename_inner_fields(
+                    inner.data_type(),
+                    dfs_names,
+                    name_idx,
+                )?),
+            ))),
+            DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new(
+                (**inner).to_owned().with_data_type(rename_inner_fields(
+                    inner.data_type(),
+                    dfs_names,
+                    name_idx,
+                )?),
+            ))),
+            _ => Ok(dtype.to_owned()),
+        }
+    }
+
+    let mut name_idx = 0;
+
+    let (qualifiers, fields): (_, Vec<Field>) = schema
+        .iter()
+        .map(|(q, f)| {
+            let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
+            Ok((
+                q.cloned(),
+                (**f)
+                    .to_owned()
+                    .with_name(name)
+                    .with_data_type(rename_inner_fields(
+                        f.data_type(),
+                        dfs_names,
+                        &mut name_idx,
+                    )?),
+            ))
+        })
+        .collect::<Result<Vec<_>>>()?
+        .into_iter()
+        .unzip();
+
+    if name_idx != dfs_names.len() {
+        return substrait_err!(
+            "Names list must match exactly to nested schema, but found {} uses for {} names",
+            name_idx,
+            dfs_names.len());
+    }
+
+    Ok(Arc::new(DFSchema::from_field_specific_qualified_schema(
+        qualifiers,
+        &Arc::new(Schema::new(fields)),
+    )?))
+}
+
 /// Convert Substrait Rel to DataFusion DataFrame
 #[async_recursion]
 pub async fn from_substrait_rel(
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 6c8be4aa9b121..88dc894eccd28 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -115,7 +115,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result<Box
     let plan_rels = vec![PlanRel {
         rel_type: Some(plan_rel::RelType::Root(RelRoot {
             input: Some(*to_substrait_rel(plan, ctx, &mut extension_info)?),
-            names: plan.schema().field_names(),
+            names: to_substrait_named_struct(plan.schema())?.names,
         })),
     }];
 
diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs
index c2ae5691134a4..e0151ecc3a4f9 100644
--- a/datafusion/substrait/tests/cases/consumer_integration.rs
+++ b/datafusion/substrait/tests/cases/consumer_integration.rs
@@ -43,14 +43,15 @@ mod tests {
 
         let plan = from_substrait_plan(&ctx, &proto).await?;
 
-        assert!(
-            format!("{:?}", plan).eq_ignore_ascii_case(
-                "Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\n  \
-                 Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[SUM(FILENAME_PLACEHOLDER_0.l_quantity), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\n    \
-                 Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\n      \
-                 Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\n        \
-                 TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
-            )
+        let plan_str = format!("{:?}", plan);
+        assert_eq!(
+            plan_str,
+            "Projection: FILENAME_PLACEHOLDER_0.l_returnflag AS L_RETURNFLAG, FILENAME_PLACEHOLDER_0.l_linestatus AS L_LINESTATUS, sum(FILENAME_PLACEHOLDER_0.l_quantity) AS SUM_QTY, sum(FILENAME_PLACEHOLDER_0.l_extendedprice) AS SUM_BASE_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount) AS SUM_DISC_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax) AS SUM_CHARGE, AVG(FILENAME_PLACEHOLDER_0.l_quantity) AS AVG_QTY, AVG(FILENAME_PLACEHOLDER_0.l_extendedprice) AS AVG_PRICE, AVG(FILENAME_PLACEHOLDER_0.l_discount) AS AVG_DISC, COUNT(Int64(1)) AS COUNT_ORDER\
+             \n  Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\
+             \n    Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[sum(FILENAME_PLACEHOLDER_0.l_quantity), sum(FILENAME_PLACEHOLDER_0.l_extendedprice), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\
+             \n      Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\
+             \n        Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\
+             \n          TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
         );
         Ok(())
     }
diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs
index 4d485b7f1238a..994a932c30e0f 100644
--- a/datafusion/substrait/tests/cases/logical_plans.rs
+++ b/datafusion/substrait/tests/cases/logical_plans.rs
@@ -48,7 +48,7 @@ mod tests {
 
         assert_eq!(
             format!("{:?}", plan),
-            "Projection: NOT DATA.a\
+            "Projection: NOT DATA.a AS EXPR$0\
             \n  TableScan: DATA projection=[a, b, c, d, e, f]"
         );
         Ok(())
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index ae148119ad8c1..4e4fa45a15a66 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -162,6 +162,11 @@ async fn wildcard_select() -> Result<()> {
     roundtrip("SELECT * FROM data").await
 }
 
+#[tokio::test]
+async fn select_with_alias() -> Result<()> {
+    roundtrip("SELECT a AS aliased_a FROM data").await
+}
+
 #[tokio::test]
 async fn select_with_filter() -> Result<()> {
     roundtrip("SELECT * FROM data WHERE a > 1").await
@@ -367,9 +372,9 @@ async fn implicit_cast() -> Result<()> {
 async fn aggregate_case() -> Result<()> {
     assert_expected_plan(
         "SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
-        "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\
+        "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]]\
          \n  TableScan: data projection=[a]",
-        false // NULL vs Int64(NULL)
+        true
     )
         .await
 }
@@ -589,32 +594,23 @@ async fn roundtrip_union_all() -> Result<()> {
 
 #[tokio::test]
 async fn simple_intersect() -> Result<()> {
+    // Substrait treats both COUNT(*) and COUNT(1) the same
     assert_expected_plan(
         "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);",
-        "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
+        "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\
          \n  Projection: \
          \n    LeftSemi Join: data.a = data2.a\
          \n      Aggregate: groupBy=[[data.a]], aggr=[[]]\
          \n        TableScan: data projection=[a]\
          \n      TableScan: data2 projection=[a]",
-        false // COUNT(*) vs COUNT(Int64(1))
+        true
     )
         .await
 }
 
 #[tokio::test]
 async fn simple_intersect_table_reuse() -> Result<()> {
-    assert_expected_plan(
-        "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);",
-        "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
-         \n  Projection: \
-         \n    LeftSemi Join: data.a = data.a\
-         \n      Aggregate: groupBy=[[data.a]], aggr=[[]]\
-         \n        TableScan: data projection=[a]\
-         \n      TableScan: data projection=[a]",
-        false // COUNT(*) vs COUNT(Int64(1))
-    )
-        .await
+    roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await
 }
 
 #[tokio::test]
@@ -694,20 +690,14 @@ async fn all_type_literal() -> Result<()> {
 
 #[tokio::test]
 async fn roundtrip_literal_list() -> Result<()> {
-    assert_expected_plan(
-        "SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
-        "Projection: List([[1, 2, 3], [], , []])\
-        \n  TableScan: data projection=[]",
-        false, // "List(..)" vs "make_array(..)"
-    )
-    .await
+    roundtrip("SELECT [[1,2,3], [], NULL, [NULL]] FROM data").await
 }
 
 #[tokio::test]
 async fn roundtrip_literal_struct() -> Result<()> {
     assert_expected_plan(
         "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
-        "Projection: Struct({c0:1,c1:true,c2:})\
+        "Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL)\
         \n  TableScan: data projection=[]",
         false, // "Struct(..)" vs "struct(..)"
     )
@@ -980,12 +970,13 @@ async fn assert_expected_plan(
 
     println!("{proto:?}");
 
-    let plan2str = format!("{plan2:?}");
-    assert_eq!(expected_plan_str, &plan2str);
-
     if assert_schema {
         assert_eq!(plan.schema(), plan2.schema());
     }
+
+    let plan2str = format!("{plan2:?}");
+    assert_eq!(expected_plan_str, &plan2str);
+
     Ok(())
 }