Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support duplicate column names in Joins in Substrait consumer #11049

Merged
merged 5 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 61 additions & 17 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,14 @@ pub async fn from_substrait_plan(
match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)),
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)),
LogicalPlan::Aggregate(a) => {
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?;
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?;
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
},
// There are probably more plans where we could bake things in, can add them later as needed.
// Otherwise, add a new Project to handle the renaming.
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?))
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?))
}
}
},
Expand Down Expand Up @@ -308,34 +308,46 @@ pub fn extract_projection(
}
}

/// Ensure the expressions have the right name(s) according to the new schema.
/// This includes the top-level (column) name, which will be renamed through aliasing if needed,
/// as well as nested names (if the expression produces any struct types), which will be renamed
/// through casting if needed.
fn rename_expressions(
exprs: impl IntoIterator<Item = Expr>,
input_schema: &DFSchema,
new_schema: DFSchemaRef,
new_schema: &DFSchema,
) -> Result<Vec<Expr>> {
exprs
.into_iter()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As rename_expressions is getting more feature-full perhaps we can add some doc comments explaining what it is doing (I think it is basically renaming expressions from substrait to match DataFusion's expectations?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, added some docs in 87829ec

.zip(new_schema.fields())
.map(|(old_expr, new_field)| {
if &old_expr.get_type(input_schema)? == new_field.data_type() {
// Alias column if needed
old_expr.alias_if_changed(new_field.name().into())
} else {
// Use Cast to rename inner struct fields + alias column if needed
// Check if type (i.e. nested struct field names) match, use Cast to rename if needed
let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() {
Expr::Cast(Cast::new(
Box::new(old_expr),
new_field.data_type().to_owned(),
))
.alias_if_changed(new_field.name().into())
} else {
old_expr
};
// Alias column if needed to fix the top-level name
match &new_expr {
// If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier
Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr),
_ => new_expr.alias_if_changed(new_field.name().to_owned()),
}
})
.collect()
}

/// Produce a version of the given schema with names matching the given list of names.
/// Substrait doesn't deal with column (incl. nested struct field) names within the schema,
/// but it does give us the list of expected names at the end of the plan, so we use this
/// to rename the schema to match the expected names.
fn make_renamed_schema(
schema: &DFSchemaRef,
dfs_names: &Vec<String>,
) -> Result<DFSchemaRef> {
) -> Result<DFSchema> {
fn rename_inner_fields(
dtype: &DataType,
dfs_names: &Vec<String>,
Expand Down Expand Up @@ -401,10 +413,10 @@ fn make_renamed_schema(
dfs_names.len());
}

Ok(Arc::new(DFSchema::from_field_specific_qualified_schema(
Ok(DFSchema::from_field_specific_qualified_schema(
qualifiers,
&Arc::new(Schema::new(fields)),
)?))
)?)
}

/// Convert Substrait Rel to DataFusion DataFrame
Expand Down Expand Up @@ -594,6 +606,8 @@ pub async fn from_substrait_rel(
let right = LogicalPlanBuilder::from(
from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?,
);
let (left, right) = requalify_sides_if_needed(left, right)?;

let join_type = from_substrait_jointype(join.r#type)?;
// The join condition expression needs full input schema and not the output schema from join since we lose columns from
// certain join types such as semi and anti joins
Expand Down Expand Up @@ -627,13 +641,15 @@ pub async fn from_substrait_rel(
}
}
Some(RelType::Cross(cross)) => {
let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
let left = LogicalPlanBuilder::from(
from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?,
);
let right =
let right = LogicalPlanBuilder::from(
from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions)
.await?;
left.cross_join(right)?.build()
.await?,
);
let (left, right) = requalify_sides_if_needed(left, right)?;
left.cross_join(right.build()?)?.build()
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
Expand Down Expand Up @@ -846,6 +862,34 @@ pub async fn from_substrait_rel(
}
}

/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise
/// conflict with the columns from the other.
/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For
/// Substrait the names don't matter since it only refers to columns by indices, however DataFusion
/// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names).
fn requalify_sides_if_needed(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps you could add some of the excellent context from this PR's description as docstrings here to help future readers understand what is going on

This is not strictly necessary as I think the comments in the function explain it, but docstrings would still be nice

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a short version in 87829ec

left: LogicalPlanBuilder,
right: LogicalPlanBuilder,
) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder)> {
let left_cols = left.schema().columns();
let right_cols = right.schema().columns();
if left_cols.iter().any(|l| {
right_cols.iter().any(|r| {
l == r || (l.name == r.name && (l.relation == None || r.relation == None))
})
}) {
// These names have no connection to the original plan, but they'll make the columns
// (mostly) unique. There may be cases where this still causes duplicates, if either left
// or right side itself contains duplicate names with different qualifiers.
Ok((
left.alias(TableReference::bare("left"))?,
right.alias(TableReference::bare("right"))?,
))
} else {
Ok((left, right))
}
}

fn from_substrait_jointype(join_type: i32) -> Result<JoinType> {
if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) {
match substrait_join_type {
Expand Down
98 changes: 59 additions & 39 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,23 @@ async fn roundtrip_outer_join() -> Result<()> {
roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a = data2.a").await
}

#[tokio::test]
async fn roundtrip_self_join() -> Result<()> {
// Substrait does currently NOT maintain the alias of the tables.
// Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide.
// This roundtrip works because we set aliases to what the Substrait consumer will generate.
roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.a = right.a").await?;
roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.b = right.b").await
}

#[tokio::test]
async fn roundtrip_self_implicit_cross_join() -> Result<()> {
// Substrait does currently NOT maintain the alias of the tables.
// Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide.
// This roundtrip works because we set aliases to what the Substrait consumer will generate.
roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right").await
}

#[tokio::test]
async fn roundtrip_arithmetic_ops() -> Result<()> {
roundtrip("SELECT a - a FROM data").await?;
Expand Down Expand Up @@ -610,7 +627,22 @@ async fn simple_intersect() -> Result<()> {

#[tokio::test]
async fn simple_intersect_table_reuse() -> Result<()> {
roundtrip("SELECT count(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await
// Substrait does currently NOT maintain the alias of the tables.
// Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide.
// In this case the aliasing happens at a different point in the plan, so we cannot use roundtrip.
// Schema check works because we set aliases to what the Substrait consumer will generate.
assert_expected_plan(
"SELECT count(1) FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);",
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\
\n Projection: \
\n LeftSemi Join: left.a = right.a\
\n SubqueryAlias: left\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n SubqueryAlias: right\
\n TableScan: data projection=[a]",
true
).await
}

#[tokio::test]
Expand All @@ -628,32 +660,6 @@ async fn qualified_catalog_schema_table_reference() -> Result<()> {
roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await
}

#[tokio::test]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe these tests are fully replaced by the stricter roundtrip_self_join tests I added above.

These tests worked before since they hit the case (1) - the tables have overlapping names but they're both qualified, so the DFSchema::check_names() doesn't mind, and they don't read those overlapping columns in the final result.

They would pass with the changes here as well, they'd just have additional SubqueryAliases.

async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> {
assert_expected_plan(
"SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a",
"Projection: data.b, data.c\
\n Inner Join: data.a = data.a\
\n TableScan: data projection=[a, b]\
\n TableScan: data projection=[a, c]",
false, // "d1" vs "data" field qualifier
)
.await
}

#[tokio::test]
async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> {
assert_expected_plan(
"SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b",
"Projection: data.b, data.c\
\n Inner Join: data.b = data.b\
\n TableScan: data projection=[b]\
\n TableScan: data projection=[b, c]",
false, // "d1" vs "data" field qualifier
)
.await
}

/// Construct a plan that contains several literals of types that are currently supported.
/// This case ignores:
/// - Date64, for this literal is not supported
Expand Down Expand Up @@ -707,20 +713,17 @@ async fn roundtrip_literal_struct() -> Result<()> {
#[tokio::test]
async fn roundtrip_values() -> Result<()> {
// TODO: would be nice to have a struct inside the LargeList, but arrow_cast doesn't support that currently
let values = "(\
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes are unrelated, but I was annoyed about the two assert inside this test while developing and one of them was failing - so I split them out.

There was no real benefit to using a large list of types for the empty relation.

assert_expected_plan(
"VALUES \
(\
1, \
'a', \
[[-213.1, NULL, 5.5, 2.0, 1.0], []], \
arrow_cast([1,2,3], 'LargeList(Int64)'), \
STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \
[STRUCT(STRUCT('a' AS string_field) AS struct_field)]\
)";

// Test LogicalPlan::Values
assert_expected_plan(
format!("VALUES \
{values}, \
(NULL, NULL, NULL, NULL, NULL, NULL)").as_str(),
), \
(NULL, NULL, NULL, NULL, NULL, NULL)",
"Values: \
(\
Int64(1), \
Expand All @@ -731,11 +734,28 @@ async fn roundtrip_values() -> Result<()> {
List([{struct_field: {string_field: a}}])\
), \
(Int64(NULL), Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())",
true)
.await?;
true).await
}

#[tokio::test]
async fn roundtrip_values_empty_relation() -> Result<()> {
roundtrip("SELECT * FROM (VALUES ('a')) LIMIT 0").await
}

// Test LogicalPlan::EmptyRelation
roundtrip(format!("SELECT * FROM (VALUES {values}) LIMIT 0").as_str()).await
#[tokio::test]
async fn roundtrip_values_duplicate_column_join() -> Result<()> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test for the case (2)

// Substrait does currently NOT maintain the alias of the tables.
// Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide.
// This roundtrip works because we set aliases to what the Substrait consumer will generate.
roundtrip(
"SELECT left.column1 as c1, right.column1 as c2 \
FROM \
(VALUES (1)) AS left \
JOIN \
(VALUES (2)) AS right \
ON left.column1 == right.column1",
)
.await
}

/// Construct a plan that cast columns. Only those SQL types are supported for now.
Expand Down
Loading