From 12941e404b5c02c06faa5a8f2d04af9744cbed4a Mon Sep 17 00:00:00 2001 From: Arttu <Blizzara@users.noreply.github.com> Date: Tue, 11 Jun 2024 20:35:34 +0200 Subject: [PATCH] Support correct output column names and struct field names when consuming/producing Substrait (#10829) * produce flattened list of names including inner struct fields * add a (failing) test * rename output columns (incl. inner struct fields) according to the given list of names * fix a test * add column names project to the new TPC-H test and fix case (assert_eq gives nicer error messages than assert) --- .../substrait/src/logical_plan/consumer.rs | 133 +++++++++++++++++- .../substrait/src/logical_plan/producer.rs | 2 +- .../tests/cases/consumer_integration.rs | 17 +-- .../substrait/tests/cases/logical_plans.rs | 2 +- .../tests/cases/roundtrip_logical_plan.rs | 43 +++--- 5 files changed, 156 insertions(+), 41 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 8a483db8c4d62..648a281832e10 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -17,7 +17,7 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{ - DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, + DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; use datafusion::common::{ not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, @@ -29,12 +29,13 @@ use url::Url; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr, - LogicalPlan, Operator, ScalarUDF, Values, + aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, + EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF, + Values, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, + col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion::prelude::JoinType; @@ -225,6 +226,7 @@ pub async fn from_substrait_plan( None => not_impl_err!("Cannot parse empty extension"), }) .collect::<Result<HashMap<_, _>>>()?; + // Parse relations match plan.relations.len() { 1 => { @@ -234,7 +236,29 @@ pub async fn from_substrait_plan( Ok(from_substrait_rel(ctx, rel, &function_extension).await?) }, plan_rel::RelType::Root(root) => { - Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?) + let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?; + if root.names.is_empty() { + // Backwards compatibility for plans missing names + return Ok(plan); + } + let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; + if renamed_schema.equivalent_names_and_types(plan.schema()) { + // Nothing to do if the schema is already equivalent + return Ok(plan); + } + + match plan { + // If the last node of the plan produces expressions, bake the renames into those expressions. + // This isn't necessary for correctness, but helps with roundtrip tests. 
+ LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)), + LogicalPlan::Aggregate(a) => { + let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?; + Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?)) + }, + // There are probably more plans where we could bake things in, can add them later as needed. + // Otherwise, add a new Project to handle the renaming. + _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?)) + } } }, None => plan_err!("Cannot parse plan relation: None") @@ -284,6 +308,105 @@ pub fn extract_projection( } } +fn rename_expressions( + exprs: impl IntoIterator<Item = Expr>, + input_schema: &DFSchema, + new_schema: DFSchemaRef, +) -> Result<Vec<Expr>> { + exprs + .into_iter() + .zip(new_schema.fields()) + .map(|(old_expr, new_field)| { + if &old_expr.get_type(input_schema)? == new_field.data_type() { + // Alias column if needed + old_expr.alias_if_changed(new_field.name().into()) + } else { + // Use Cast to rename inner struct fields + alias column if needed + Expr::Cast(Cast::new( + Box::new(old_expr), + new_field.data_type().to_owned(), + )) + .alias_if_changed(new_field.name().into()) + } + }) + .collect() +} + +fn make_renamed_schema( + schema: &DFSchemaRef, + dfs_names: &Vec<String>, +) -> Result<DFSchemaRef> { + fn rename_inner_fields( + dtype: &DataType, + dfs_names: &Vec<String>, + name_idx: &mut usize, + ) -> Result<DataType> { + match dtype { + DataType::Struct(fields) => { + let fields = fields + .iter() + .map(|f| { + let name = next_struct_field_name(0, dfs_names, name_idx)?; + Ok((**f).to_owned().with_name(name).with_data_type( + rename_inner_fields(f.data_type(), dfs_names, name_idx)?, + )) + }) + .collect::<Result<_>>()?; + Ok(DataType::Struct(fields)) + } + DataType::List(inner) => Ok(DataType::List(FieldRef::new( + (**inner).to_owned().with_data_type(rename_inner_fields( + inner.data_type(), + dfs_names, + name_idx, + )?), + ))), + DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new( + (**inner).to_owned().with_data_type(rename_inner_fields( + inner.data_type(), + dfs_names, + name_idx, + )?), + ))), + _ => Ok(dtype.to_owned()), + } + } + + let mut name_idx = 0; + + let (qualifiers, fields): (_, Vec<Field>) = schema + .iter() + .map(|(q, f)| { + let name = next_struct_field_name(0, dfs_names, &mut name_idx)?; + Ok(( + q.cloned(), + (**f) + .to_owned() + .with_name(name) + .with_data_type(rename_inner_fields( + f.data_type(), + dfs_names, + &mut name_idx, + )?), + )) + }) + .collect::<Result<Vec<_>>>()? 
+ .into_iter() + .unzip(); + + if name_idx != dfs_names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + dfs_names.len()); + } + + Ok(Arc::new(DFSchema::from_field_specific_qualified_schema( + qualifiers, + &Arc::new(Schema::new(fields)), + )?)) +} + /// Convert Substrait Rel to DataFusion DataFrame #[async_recursion] pub async fn from_substrait_rel( diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 6c8be4aa9b121..88dc894eccd28 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -115,7 +115,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result<Box let plan_rels = vec![PlanRel { rel_type: Some(plan_rel::RelType::Root(RelRoot { input: Some(*to_substrait_rel(plan, ctx, &mut extension_info)?), - names: plan.schema().field_names(), + names: to_substrait_named_struct(plan.schema())?.names, })), }]; diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index c2ae5691134a4..e0151ecc3a4f9 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -43,14 +43,15 @@ mod tests { let plan = from_substrait_plan(&ctx, &proto).await?; - assert!( - format!("{:?}", plan).eq_ignore_ascii_case( - "Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\n \ - Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[SUM(FILENAME_PLACEHOLDER_0.l_quantity), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\n \ - Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\n \ - Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\n \ - TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]" - ) + let plan_str = format!("{:?}", plan); + assert_eq!( + plan_str, + "Projection: FILENAME_PLACEHOLDER_0.l_returnflag AS L_RETURNFLAG, FILENAME_PLACEHOLDER_0.l_linestatus AS L_LINESTATUS, sum(FILENAME_PLACEHOLDER_0.l_quantity) AS SUM_QTY, sum(FILENAME_PLACEHOLDER_0.l_extendedprice) AS SUM_BASE_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount) AS SUM_DISC_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - 
FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax) AS SUM_CHARGE, AVG(FILENAME_PLACEHOLDER_0.l_quantity) AS AVG_QTY, AVG(FILENAME_PLACEHOLDER_0.l_extendedprice) AS AVG_PRICE, AVG(FILENAME_PLACEHOLDER_0.l_discount) AS AVG_DISC, COUNT(Int64(1)) AS COUNT_ORDER\ + \n Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[sum(FILENAME_PLACEHOLDER_0.l_quantity), sum(FILENAME_PLACEHOLDER_0.l_extendedprice), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\ + \n Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\ + \n Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]" ); Ok(()) } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 4d485b7f1238a..994a932c30e0f 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -48,7 +48,7 @@ mod tests { assert_eq!( format!("{:?}", plan), - "Projection: NOT DATA.a\ + "Projection: NOT DATA.a AS EXPR$0\ \n TableScan: DATA projection=[a, b, c, d, e, f]" ); Ok(()) diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ae148119ad8c1..4e4fa45a15a66 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -162,6 +162,11 @@ async fn wildcard_select() -> Result<()> { roundtrip("SELECT * FROM data").await } +#[tokio::test] +async fn select_with_alias() -> Result<()> { + roundtrip("SELECT a AS aliased_a FROM data").await +} + #[tokio::test] async fn select_with_filter() -> Result<()> { roundtrip("SELECT * FROM data WHERE a > 1").await @@ -367,9 +372,9 @@ async fn implicit_cast() -> Result<()> { async fn aggregate_case() -> Result<()> { assert_expected_plan( "SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data", - "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\ + "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]]\ \n TableScan: data projection=[a]", - false // NULL vs Int64(NULL) + true ) .await } @@ -589,32 +594,23 @@ async 
fn roundtrip_union_all() -> Result<()> { #[tokio::test] async fn simple_intersect() -> Result<()> { + // Substrait treats both COUNT(*) and COUNT(1) the same assert_expected_plan( "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\ \n Projection: \ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ \n TableScan: data2 projection=[a]", - false // COUNT(*) vs COUNT(Int64(1)) + true ) .await } #[tokio::test] async fn simple_intersect_table_reuse() -> Result<()> { - assert_expected_plan( - "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Projection: \ - \n LeftSemi Join: data.a = data.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data projection=[a]", - false // COUNT(*) vs COUNT(Int64(1)) - ) - .await + roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await } #[tokio::test] @@ -694,20 +690,14 @@ async fn all_type_literal() -> Result<()> { #[tokio::test] async fn roundtrip_literal_list() -> Result<()> { - assert_expected_plan( - "SELECT [[1,2,3], [], NULL, [NULL]] FROM data", - "Projection: List([[1, 2, 3], [], , []])\ - \n TableScan: data projection=[]", - false, // "List(..)" vs "make_array(..)" - ) - .await + roundtrip("SELECT [[1,2,3], [], NULL, [NULL]] FROM data").await } #[tokio::test] async fn roundtrip_literal_struct() -> Result<()> { assert_expected_plan( "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data", - "Projection: Struct({c0:1,c1:true,c2:})\ + "Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL)\ \n TableScan: data projection=[]", false, // "Struct(..)" vs "struct(..)" ) @@ -980,12 +970,13 @@ async fn assert_expected_plan( println!("{proto:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(expected_plan_str, &plan2str); - if assert_schema { assert_eq!(plan.schema(), plan2.schema()); } + + let plan2str = format!("{plan2:?}"); + assert_eq!(expected_plan_str, &plan2str); + Ok(()) }
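
Not part of the patch: a minimal sketch of how the change above shows up to a user of the crate. It assumes an in-memory table `data` with a single Int32 column `a` registered on the SessionContext (the table name and schema are made up for the example, mirroring the `select_with_alias` roundtrip test added above); `to_substrait_plan` and `from_substrait_plan` are the entry points touched by this patch.

// Minimal sketch (not from the patch): with this change, output column aliases
// survive a producer -> consumer roundtrip, because the producer writes the
// flattened output names (including inner struct fields) into RelRoot.names and
// the consumer renames the consumed plan's output columns to match them.
// The table `data` and its schema are assumptions for this example.
use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use datafusion_substrait::logical_plan::producer::to_substrait_plan;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Register a tiny in-memory table `data(a: Int32)`.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let table = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("data", Arc::new(table))?;

    // The alias only exists in the output column name, not in the expression itself.
    let plan = ctx
        .sql("SELECT a AS aliased_a FROM data")
        .await?
        .into_optimized_plan()?;

    let proto = to_substrait_plan(&plan, &ctx)?;
    let plan2 = from_substrait_plan(&ctx, &proto).await?;

    // Before this patch the roundtripped plan would expose the column as `a`;
    // now the names carried in RelRoot are applied, so `aliased_a` is preserved.
    assert_eq!(plan.schema().field_names(), plan2.schema().field_names());
    Ok(())
}

On the consumer side the renaming is applied as described in the new match in from_substrait_plan: if the consumed plan already ends in a Projection or Aggregate, the renames are folded into its expressions (alias, plus a Cast when inner struct field names differ); otherwise a new Projection over col(...) expressions is appended.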