Skip to content

Commit

Permalink
fix union serialisation order in proto (#13709)
Browse files Browse the repository at this point in the history
* fix union serialisation order in proto

* clippy

* address comments
  • Loading branch information
onursatici authored Dec 12, 2024
1 parent 36a1361 commit 1c73a23
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 12 deletions.
19 changes: 7 additions & 12 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,23 +737,18 @@ impl AsLogicalPlan for LogicalPlanNode {
builder.build()
}
LogicalPlanType::Union(union) => {
let mut input_plans: Vec<LogicalPlan> = union
.inputs
.iter()
.map(|i| i.try_into_logical_plan(ctx, extension_codec))
.collect::<Result<_>>()?;

if input_plans.len() < 2 {
if union.inputs.len() < 2 {
return Err( DataFusionError::Internal(String::from(
"Protobuf deserialization error, Union was require at least two input.",
)));
}
let (first, rest) = union.inputs.split_first().unwrap();
let mut builder = LogicalPlanBuilder::from(
first.try_into_logical_plan(ctx, extension_codec)?,
);

let first = input_plans.pop().ok_or_else(|| DataFusionError::Internal(String::from(
"Protobuf deserialization error, Union was require at least two input.",
)))?;
let mut builder = LogicalPlanBuilder::from(first);
for plan in input_plans {
for i in rest {
let plan = i.try_into_logical_plan(ctx, extension_codec)?;
builder = builder.union(plan)?;
}
builder.build()
Expand Down
34 changes: 34 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use arrow::datatypes::{
};
use arrow::util::pretty::pretty_format_batches;
use datafusion::datasource::file_format::json::JsonFormatFactory;
use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion;
use datafusion::optimizer::Optimizer;
use datafusion_common::parsers::CompressionTypeVariant;
use prost::Message;
use std::any::Any;
Expand Down Expand Up @@ -2555,3 +2557,35 @@ async fn roundtrip_recursive_query() {
format!("{}", pretty_format_batches(&output_round_trip).unwrap())
);
}

// Serialises an optimized plan for a nested UNION query to bytes, reads it
// back in a fresh session, and checks that the flattened round-tripped plan
// renders identically to the original plan.
#[tokio::test]
async fn roundtrip_union_query() -> Result<()> {
    let sql = "SELECT a FROM t1
UNION (SELECT a from t1 UNION SELECT a from t2)";

    // Session used to plan the query; both tables read the same CSV file.
    let planning_ctx = SessionContext::new();
    for table in ["t1", "t2"] {
        planning_ctx
            .register_csv(table, "tests/testdata/test.csv", CsvReadOptions::default())
            .await?;
    }
    let expected = planning_ctx.sql(sql).await?.into_optimized_plan()?;

    let encoded = logical_plan_to_bytes(&expected)?;

    // Decode in a separate session with the same two tables registered.
    let decoding_ctx = SessionContext::new();
    for table in ["t1", "t2"] {
        decoding_ctx
            .register_csv(table, "tests/testdata/test.csv", CsvReadOptions::default())
            .await?;
    }
    let decoded = logical_plan_from_bytes(&encoded, &decoding_ctx)?;

    // proto deserialisation only supports 2-way union, so the decoded plan
    // contains nested unions; run the flatten-unions optimizer rule on it
    // so that it can be compared with the original plan.
    let flatten_unions =
        Optimizer::with_rules(vec![Arc::new(EliminateNestedUnion::new())]);
    let flattened =
        flatten_unions.optimize(decoded, &(decoding_ctx.state()), |_x, _y| {})?;

    assert_eq!(
        expected.display_indent_schema().to_string(),
        flattened.display_indent_schema().to_string(),
    );
    Ok(())
}

0 comments on commit 1c73a23

Please sign in to comment.