Skip to content

Commit

Permalink
feat(substrait): add set operations to consumer, update substrait to …
Browse files Browse the repository at this point in the history
…`0.45.0` (#12863)

* feat(substrait): add set operations to consumer

* add missing intersect all test, change distinct to is_all

* upgrade substrait crate to 0.45
  • Loading branch information
tokoko authored Oct 17, 2024
1 parent 56946b4 commit e63abe7
Show file tree
Hide file tree
Showing 11 changed files with 1,136 additions and 13 deletions.
2 changes: 1 addition & 1 deletion datafusion/substrait/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object_store = { workspace = true }
pbjson-types = "0.7"
# TODO use workspace version
prost = "0.13"
substrait = { version = "0.42", features = ["serde"] }
substrait = { version = "0.45", features = ["serde"] }
url = { workspace = true }

[dev-dependencies]
Expand Down
120 changes: 109 additions & 11 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,65 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
(accum_join_keys, nulls_equal_nulls, join_filter)
}

/// Builds a single union plan out of the Substrait relations in `rels`.
///
/// Every relation is converted with [`from_substrait_rel`] in slice order and
/// folded left to right onto the first one. When `is_all` is `true` the inputs
/// are combined with `LogicalPlanBuilder::union`; otherwise
/// `LogicalPlanBuilder::union_distinct` is used.
///
/// # Errors
///
/// Returns the first error raised while converting an input relation or while
/// combining it with the plan accumulated so far.
///
/// # Panics
///
/// Panics if `rels` is empty, because the first element is indexed directly.
/// The call sites visible in this file check the input count before calling.
async fn union_rels(
    rels: &[Rel],
    ctx: &SessionContext,
    extensions: &Extensions,
    is_all: bool,
) -> Result<LogicalPlan> {
    let mut union_builder =
        LogicalPlanBuilder::from(from_substrait_rel(ctx, &rels[0], extensions).await?);

    for input in &rels[1..] {
        let rel_plan = from_substrait_rel(ctx, input, extensions).await?;

        // Propagate a failed union immediately instead of carrying a `Result`
        // into the next iteration, where the following input would be
        // converted before the earlier failure is reported.
        union_builder = if is_all {
            union_builder.union(rel_plan)?
        } else {
            union_builder.union_distinct(rel_plan)?
        };
    }

    union_builder.build()
}

/// Intersects the Substrait relations in `rels`, left to right.
///
/// The first relation is converted with [`from_substrait_rel`] and used as the
/// starting plan. Each following relation is converted in turn and combined
/// with the accumulated plan through `LogicalPlanBuilder::intersect`, with
/// `is_all` forwarded unchanged.
///
/// # Errors
///
/// Returns the first error raised by a conversion or by
/// `LogicalPlanBuilder::intersect`.
///
/// # Panics
///
/// Panics if `rels` is empty, because the first element is indexed directly.
async fn intersect_rels(
    rels: &[Rel],
    ctx: &SessionContext,
    extensions: &Extensions,
    is_all: bool,
) -> Result<LogicalPlan> {
    let mut acc = from_substrait_rel(ctx, &rels[0], extensions).await?;

    for rhs in rels.iter().skip(1) {
        let right = from_substrait_rel(ctx, rhs, extensions).await?;
        acc = LogicalPlanBuilder::intersect(acc, right, is_all)?;
    }

    Ok(acc)
}

/// Subtracts every later relation in `rels` from the first one, left to right.
///
/// The first relation is converted with [`from_substrait_rel`] and used as the
/// starting plan. Each following relation is converted in turn and removed
/// from the accumulated plan through `LogicalPlanBuilder::except`, with
/// `is_all` forwarded unchanged.
///
/// # Errors
///
/// Returns the first error raised by a conversion or by
/// `LogicalPlanBuilder::except`.
///
/// # Panics
///
/// Panics if `rels` is empty, because the first element is indexed directly.
async fn except_rels(
    rels: &[Rel],
    ctx: &SessionContext,
    extensions: &Extensions,
    is_all: bool,
) -> Result<LogicalPlan> {
    let mut acc = from_substrait_rel(ctx, &rels[0], extensions).await?;

    for rhs in rels.iter().skip(1) {
        let right = from_substrait_rel(ctx, rhs, extensions).await?;
        acc = LogicalPlanBuilder::except(acc, right, is_all)?;
    }

    Ok(acc)
}

/// Convert Substrait Plan to DataFusion LogicalPlan
pub async fn from_substrait_plan(
ctx: &SessionContext,
Expand Down Expand Up @@ -494,6 +553,7 @@ fn make_renamed_schema(
}

/// Convert Substrait Rel to DataFusion DataFrame
#[allow(deprecated)]
#[async_recursion]
pub async fn from_substrait_rel(
ctx: &SessionContext,
Expand Down Expand Up @@ -877,27 +937,65 @@ pub async fn from_substrait_rel(
Ok(set_op) => match set_op {
set_rel::SetOp::UnionAll => {
if !set.inputs.is_empty() {
let mut union_builder = Ok(LogicalPlanBuilder::from(
from_substrait_rel(ctx, &set.inputs[0], extensions).await?,
));
for input in &set.inputs[1..] {
union_builder = union_builder?
.union(from_substrait_rel(ctx, input, extensions).await?);
}
union_builder?.build()
union_rels(&set.inputs, ctx, extensions, true).await
} else {
not_impl_err!("Union relation requires at least one input")
}
}
set_rel::SetOp::UnionDistinct => {
if !set.inputs.is_empty() {
union_rels(&set.inputs, ctx, extensions, false).await
} else {
not_impl_err!("Union relation requires at least one input")
}
}
set_rel::SetOp::IntersectionPrimary => {
if set.inputs.len() == 2 {
if set.inputs.len() >= 2 {
LogicalPlanBuilder::intersect(
from_substrait_rel(ctx, &set.inputs[0], extensions).await?,
from_substrait_rel(ctx, &set.inputs[1], extensions).await?,
union_rels(&set.inputs[1..], ctx, extensions, true).await?,
false,
)
} else {
not_impl_err!("Primary Intersect relation with more than two inputs isn't supported")
not_impl_err!(
"Primary Intersect relation requires at least two inputs"
)
}
}
set_rel::SetOp::IntersectionMultiset => {
if set.inputs.len() >= 2 {
intersect_rels(&set.inputs, ctx, extensions, false).await
} else {
not_impl_err!(
"Multiset Intersect relation requires at least two inputs"
)
}
}
set_rel::SetOp::IntersectionMultisetAll => {
if set.inputs.len() >= 2 {
intersect_rels(&set.inputs, ctx, extensions, true).await
} else {
not_impl_err!(
"MultisetAll Intersect relation requires at least two inputs"
)
}
}
set_rel::SetOp::MinusPrimary => {
if set.inputs.len() >= 2 {
except_rels(&set.inputs, ctx, extensions, false).await
} else {
not_impl_err!(
"Primary Minus relation requires at least two inputs"
)
}
}
set_rel::SetOp::MinusPrimaryAll => {
if set.inputs.len() >= 2 {
except_rels(&set.inputs, ctx, extensions, true).await
} else {
not_impl_err!(
"PrimaryAll Minus relation requires at least two inputs"
)
}
}
_ => not_impl_err!("Unsupported set operator: {set_op:?}"),
Expand Down
12 changes: 11 additions & 1 deletion datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ pub fn to_substrait_extended_expr(
}

/// Convert DataFusion LogicalPlan to Substrait Rel
#[allow(deprecated)]
pub fn to_substrait_rel(
plan: &LogicalPlan,
ctx: &SessionContext,
Expand Down Expand Up @@ -227,6 +228,7 @@ pub fn to_substrait_rel(
advanced_extension: None,
read_type: Some(ReadType::VirtualTable(VirtualTable {
values: vec![],
expressions: vec![],
})),
}))),
}))
Expand Down Expand Up @@ -263,7 +265,10 @@ pub fn to_substrait_rel(
best_effort_filter: None,
projection: None,
advanced_extension: None,
read_type: Some(ReadType::VirtualTable(VirtualTable { values })),
read_type: Some(ReadType::VirtualTable(VirtualTable {
values,
expressions: vec![],
})),
}))),
}))
}
Expand Down Expand Up @@ -359,6 +364,7 @@ pub fn to_substrait_rel(
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
grouping_expressions: vec![],
groupings,
measures,
advanced_extension: None,
Expand All @@ -377,8 +383,10 @@ pub fn to_substrait_rel(
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
grouping_expressions: vec![],
groupings: vec![Grouping {
grouping_expressions: grouping,
expression_references: vec![],
}],
measures: vec![],
advanced_extension: None,
Expand Down Expand Up @@ -764,6 +772,7 @@ pub fn operator_to_name(op: Operator) -> &'static str {
}
}

#[allow(deprecated)]
pub fn parse_flat_grouping_exprs(
ctx: &SessionContext,
exprs: &[Expr],
Expand All @@ -776,6 +785,7 @@ pub fn parse_flat_grouping_exprs(
.collect::<Result<Vec<_>>>()?;
Ok(Grouping {
grouping_expressions,
expression_references: vec![],
})
}

Expand Down
66 changes: 66 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,72 @@ async fn simple_intersect_consume() -> Result<()> {
.await
}

/// Reads the `intersect_primary` Substrait test plan and hands it to
/// `assert_substrait_sql` together with the SQL it is expected to match.
///
/// The expected SQL intersects the first query with a `UNION ALL` of the
/// remaining ones.
#[tokio::test]
async fn primary_intersect_consume() -> Result<()> {
    let expected_sql =
        "SELECT a FROM data INTERSECT (SELECT a FROM data2 UNION ALL SELECT a FROM data2)";
    let plan = read_json("tests/testdata/test_plans/intersect_primary.substrait.json");

    assert_substrait_sql(plan, expected_sql).await
}

/// Reads the `intersect_multiset` Substrait test plan and hands it to
/// `assert_substrait_sql` together with the SQL it is expected to match:
/// a chain of plain `INTERSECT` operations over three queries.
#[tokio::test]
async fn multiset_intersect_consume() -> Result<()> {
    let expected_sql =
        "SELECT a FROM data INTERSECT SELECT a FROM data2 INTERSECT SELECT a FROM data2";
    let plan = read_json("tests/testdata/test_plans/intersect_multiset.substrait.json");

    assert_substrait_sql(plan, expected_sql).await
}

/// Reads the `intersect_multiset_all` Substrait test plan and hands it to
/// `assert_substrait_sql` together with the SQL it is expected to match:
/// a chain of `INTERSECT ALL` operations over three queries.
#[tokio::test]
async fn multiset_intersect_all_consume() -> Result<()> {
    let expected_sql =
        "SELECT a FROM data INTERSECT ALL SELECT a FROM data2 INTERSECT ALL SELECT a FROM data2";
    let plan = read_json("tests/testdata/test_plans/intersect_multiset_all.substrait.json");

    assert_substrait_sql(plan, expected_sql).await
}

/// Reads the `minus_primary` Substrait test plan and hands it to
/// `assert_substrait_sql` together with the SQL it is expected to match:
/// a chain of plain `EXCEPT` operations over three queries.
#[tokio::test]
async fn primary_except_consume() -> Result<()> {
    let expected_sql =
        "SELECT a FROM data EXCEPT SELECT a FROM data2 EXCEPT SELECT a FROM data2";
    let plan = read_json("tests/testdata/test_plans/minus_primary.substrait.json");

    assert_substrait_sql(plan, expected_sql).await
}

/// Reads the `minus_primary_all` Substrait test plan and hands it to
/// `assert_substrait_sql` together with the SQL it is expected to match:
/// a chain of `EXCEPT ALL` operations over three queries.
#[tokio::test]
async fn primary_except_all_consume() -> Result<()> {
    let expected_sql =
        "SELECT a FROM data EXCEPT ALL SELECT a FROM data2 EXCEPT ALL SELECT a FROM data2";
    let plan = read_json("tests/testdata/test_plans/minus_primary_all.substrait.json");

    assert_substrait_sql(plan, expected_sql).await
}

/// Reads the `union_distinct` Substrait test plan and hands it to
/// `assert_substrait_sql` together with the SQL it is expected to match:
/// a plain `UNION` of two queries.
#[tokio::test]
async fn union_distinct_consume() -> Result<()> {
    let expected_sql = "SELECT a FROM data UNION SELECT a FROM data2";
    let plan = read_json("tests/testdata/test_plans/union_distinct.substrait.json");

    assert_substrait_sql(plan, expected_sql).await
}

#[tokio::test]
async fn simple_intersect_table_reuse() -> Result<()> {
// Substrait does currently NOT maintain the alias of the tables.
Expand Down
Loading

0 comments on commit e63abe7

Please sign in to comment.