Skip to content

Commit

Permalink
Support correct output column names and struct field names when consu…
Browse files Browse the repository at this point in the history
…ming/producing Substrait (apache#10829)

* produce flattened list of names including inner struct fields

* add a (failing) test

* rename output columns (incl. inner struct fields) according to the given list of names

* fix a test

* add column names project to the new TPC-H test and fix case

(assert_eq gives nicer error messages than assert)
  • Loading branch information
Blizzara authored and findepi committed Jul 16, 2024
1 parent 31f5578 commit 12941e4
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 41 deletions.
133 changes: 128 additions & 5 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{
DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
Expand All @@ -29,12 +29,13 @@ use url::Url;
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr,
LogicalPlan, Operator, ScalarUDF, Values,
aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF,
Values,
};

use datafusion::logical_expr::{
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion::prelude::JoinType;
Expand Down Expand Up @@ -225,6 +226,7 @@ pub async fn from_substrait_plan(
None => not_impl_err!("Cannot parse empty extension"),
})
.collect::<Result<HashMap<_, _>>>()?;

// Parse relations
match plan.relations.len() {
1 => {
Expand All @@ -234,7 +236,29 @@ pub async fn from_substrait_plan(
Ok(from_substrait_rel(ctx, rel, &function_extension).await?)
},
plan_rel::RelType::Root(root) => {
Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?)
let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?;
if root.names.is_empty() {
// Backwards compatibility for plans missing names
return Ok(plan);
}
let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?;
if renamed_schema.equivalent_names_and_types(plan.schema()) {
// Nothing to do if the schema is already equivalent
return Ok(plan);
}

match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)),
LogicalPlan::Aggregate(a) => {
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?;
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
},
// There are probably more plans where we could bake things in, can add them later as needed.
// Otherwise, add a new Project to handle the renaming.
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?))
}
}
},
None => plan_err!("Cannot parse plan relation: None")
Expand Down Expand Up @@ -284,6 +308,105 @@ pub fn extract_projection(
}
}

fn rename_expressions(
exprs: impl IntoIterator<Item = Expr>,
input_schema: &DFSchema,
new_schema: DFSchemaRef,
) -> Result<Vec<Expr>> {
exprs
.into_iter()
.zip(new_schema.fields())
.map(|(old_expr, new_field)| {
if &old_expr.get_type(input_schema)? == new_field.data_type() {
// Alias column if needed
old_expr.alias_if_changed(new_field.name().into())
} else {
// Use Cast to rename inner struct fields + alias column if needed
Expr::Cast(Cast::new(
Box::new(old_expr),
new_field.data_type().to_owned(),
))
.alias_if_changed(new_field.name().into())
}
})
.collect()
}

fn make_renamed_schema(
schema: &DFSchemaRef,
dfs_names: &Vec<String>,
) -> Result<DFSchemaRef> {
fn rename_inner_fields(
dtype: &DataType,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<DataType> {
match dtype {
DataType::Struct(fields) => {
let fields = fields
.iter()
.map(|f| {
let name = next_struct_field_name(0, dfs_names, name_idx)?;
Ok((**f).to_owned().with_name(name).with_data_type(
rename_inner_fields(f.data_type(), dfs_names, name_idx)?,
))
})
.collect::<Result<_>>()?;
Ok(DataType::Struct(fields))
}
DataType::List(inner) => Ok(DataType::List(FieldRef::new(
(**inner).to_owned().with_data_type(rename_inner_fields(
inner.data_type(),
dfs_names,
name_idx,
)?),
))),
DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new(
(**inner).to_owned().with_data_type(rename_inner_fields(
inner.data_type(),
dfs_names,
name_idx,
)?),
))),
_ => Ok(dtype.to_owned()),
}
}

let mut name_idx = 0;

let (qualifiers, fields): (_, Vec<Field>) = schema
.iter()
.map(|(q, f)| {
let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
Ok((
q.cloned(),
(**f)
.to_owned()
.with_name(name)
.with_data_type(rename_inner_fields(
f.data_type(),
dfs_names,
&mut name_idx,
)?),
))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
.unzip();

if name_idx != dfs_names.len() {
return substrait_err!(
"Names list must match exactly to nested schema, but found {} uses for {} names",
name_idx,
dfs_names.len());
}

Ok(Arc::new(DFSchema::from_field_specific_qualified_schema(
qualifiers,
&Arc::new(Schema::new(fields)),
)?))
}

/// Convert Substrait Rel to DataFusion DataFrame
#[async_recursion]
pub async fn from_substrait_rel(
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result<Box
let plan_rels = vec![PlanRel {
rel_type: Some(plan_rel::RelType::Root(RelRoot {
input: Some(*to_substrait_rel(plan, ctx, &mut extension_info)?),
names: plan.schema().field_names(),
names: to_substrait_named_struct(plan.schema())?.names,
})),
}];

Expand Down
17 changes: 9 additions & 8 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ mod tests {

let plan = from_substrait_plan(&ctx, &proto).await?;

assert!(
format!("{:?}", plan).eq_ignore_ascii_case(
"Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\n \
Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[SUM(FILENAME_PLACEHOLDER_0.l_quantity), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\n \
Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\n \
Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\n \
TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
)
let plan_str = format!("{:?}", plan);
assert_eq!(
plan_str,
"Projection: FILENAME_PLACEHOLDER_0.l_returnflag AS L_RETURNFLAG, FILENAME_PLACEHOLDER_0.l_linestatus AS L_LINESTATUS, sum(FILENAME_PLACEHOLDER_0.l_quantity) AS SUM_QTY, sum(FILENAME_PLACEHOLDER_0.l_extendedprice) AS SUM_BASE_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount) AS SUM_DISC_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax) AS SUM_CHARGE, AVG(FILENAME_PLACEHOLDER_0.l_quantity) AS AVG_QTY, AVG(FILENAME_PLACEHOLDER_0.l_extendedprice) AS AVG_PRICE, AVG(FILENAME_PLACEHOLDER_0.l_discount) AS AVG_DISC, COUNT(Int64(1)) AS COUNT_ORDER\
\n Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\
\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[sum(FILENAME_PLACEHOLDER_0.l_quantity), sum(FILENAME_PLACEHOLDER_0.l_extendedprice), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\
\n Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\
\n Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\
\n TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
);
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/tests/cases/logical_plans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ mod tests {

assert_eq!(
format!("{:?}", plan),
"Projection: NOT DATA.a\
"Projection: NOT DATA.a AS EXPR$0\
\n TableScan: DATA projection=[a, b, c, d, e, f]"
);
Ok(())
Expand Down
43 changes: 17 additions & 26 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ async fn wildcard_select() -> Result<()> {
roundtrip("SELECT * FROM data").await
}

#[tokio::test]
async fn select_with_alias() -> Result<()> {
roundtrip("SELECT a AS aliased_a FROM data").await
}

#[tokio::test]
async fn select_with_filter() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a > 1").await
Expand Down Expand Up @@ -367,9 +372,9 @@ async fn implicit_cast() -> Result<()> {
async fn aggregate_case() -> Result<()> {
assert_expected_plan(
"SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
"Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\
"Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]]\
\n TableScan: data projection=[a]",
false // NULL vs Int64(NULL)
true
)
.await
}
Expand Down Expand Up @@ -589,32 +594,23 @@ async fn roundtrip_union_all() -> Result<()> {

#[tokio::test]
async fn simple_intersect() -> Result<()> {
// Substrait treats both COUNT(*) and COUNT(1) the same
assert_expected_plan(
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);",
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\
\n Projection: \
\n LeftSemi Join: data.a = data2.a\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n TableScan: data2 projection=[a]",
false // COUNT(*) vs COUNT(Int64(1))
true
)
.await
}

#[tokio::test]
async fn simple_intersect_table_reuse() -> Result<()> {
assert_expected_plan(
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);",
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
\n Projection: \
\n LeftSemi Join: data.a = data.a\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n TableScan: data projection=[a]",
false // COUNT(*) vs COUNT(Int64(1))
)
.await
roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await
}

#[tokio::test]
Expand Down Expand Up @@ -694,20 +690,14 @@ async fn all_type_literal() -> Result<()> {

#[tokio::test]
async fn roundtrip_literal_list() -> Result<()> {
assert_expected_plan(
"SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
"Projection: List([[1, 2, 3], [], , []])\
\n TableScan: data projection=[]",
false, // "List(..)" vs "make_array(..)"
)
.await
roundtrip("SELECT [[1,2,3], [], NULL, [NULL]] FROM data").await
}

#[tokio::test]
async fn roundtrip_literal_struct() -> Result<()> {
assert_expected_plan(
"SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
"Projection: Struct({c0:1,c1:true,c2:})\
"Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL)\
\n TableScan: data projection=[]",
false, // "Struct(..)" vs "struct(..)"
)
Expand Down Expand Up @@ -980,12 +970,13 @@ async fn assert_expected_plan(

println!("{proto:?}");

let plan2str = format!("{plan2:?}");
assert_eq!(expected_plan_str, &plan2str);

if assert_schema {
assert_eq!(plan.schema(), plan2.schema());
}

let plan2str = format!("{plan2:?}");
assert_eq!(expected_plan_str, &plan2str);

Ok(())
}

Expand Down

0 comments on commit 12941e4

Please sign in to comment.