diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b1b510f1792de..511d4e3664f29 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -214,11 +214,34 @@ pub async fn from_substrait_plan( Ok(from_substrait_rel(ctx, rel, &extensions).await?) }, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; + let mut plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); } + // if the names don't match the root plan's schema, we try to add the missing projections via a rebuild + if !plan.inputs().is_empty() { + let plan_schema = plan.schema(); + // we only care about Projection and Aggregate, which both have exactly 1 input + let plan_input_schema = plan.inputs()[0].schema(); + let mut missed_expr: Vec<Expr> = root.names.iter() + .filter_map(|name| { + if !plan_schema.has_column_with_unqualified_name(name) && + plan_input_schema.has_column_with_unqualified_name(name) { + // we can safely unwrap here + let (qualifier, field) = plan_input_schema.qualified_field_with_unqualified_name(name).unwrap(); + Some(Expr::from(Column::from((qualifier, field)))) + } else { + None + } + }).collect(); + if !missed_expr.is_empty() { + if let LogicalPlan::Projection(projection) = plan { + missed_expr.extend(projection.expr); + plan = LogicalPlan::Projection(Projection::try_new(missed_expr, projection.input)?) + } + } + } let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; if renamed_schema.equivalent_names_and_types(plan.schema()) { // Nothing to do if the schema is already equivalent @@ -228,15 +251,36 @@ pub async fn from_substrait_plan( match plan { // If the last node of the plan produces expressions, bake the renames into those expressions. 
// This isn't necessary for correctness, but helps with roundtrip tests. - LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)), + LogicalPlan::Projection(p) => { + Ok(LogicalPlan::Projection( + Projection::try_new( + rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, + p.input + )? + )) + }, + LogicalPlan::Aggregate(a) => { let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?; - Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?)) + Ok(LogicalPlan::Aggregate( + Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)? + )) }, + // There are probably more plans where we could bake things in, can add them later as needed. // Otherwise, add a new Project to handle the renaming. - _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?)) + _ => Ok(LogicalPlan::Projection( + Projection::try_new( + rename_expressions( + plan.schema().columns().iter().map(|c| col(c.to_owned())), + plan.schema(), + &renamed_schema + )?, + Arc::new(plan) + )? 
+ )), } + } }, None => plan_err!("Cannot parse plan relation: None") @@ -358,12 +402,13 @@ fn make_renamed_schema( name_idx, )?), ))), - _ => Ok(dtype.to_owned()), + _ => { + Ok(dtype.to_owned()) + } } } let mut name_idx = 0; - let (qualifiers, fields): (_, Vec<Field>) = schema .iter() .map(|(q, f)| { @@ -390,7 +435,6 @@ fn make_renamed_schema( name_idx, dfs_names.len()); } - DFSchema::from_field_specific_qualified_schema( qualifiers, &Arc::new(Schema::new(fields)), diff --git a/datafusion/substrait/tests/cases/bugs_converage.rs b/datafusion/substrait/tests/cases/bugs_converage.rs new file mode 100644 index 0000000000000..da3d9825f5411 --- /dev/null +++ b/datafusion/substrait/tests/cases/bugs_converage.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for bugs in substrait + +#[cfg(test)] +mod tests { + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::Result; + use datafusion::datasource::MemTable; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + use substrait::proto::Plan; + #[tokio::test] + async fn extra_projection_with_input() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Schema::new(vec![ + Field::new("user_id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + Field::new("paid_for_service", DataType::Boolean, false), + ]); + let memory_table = MemTable::try_new(schema.into(), vec![vec![]]).unwrap(); + ctx.register_table("users", Arc::new(memory_table))?; + let path = "tests/testdata/extra_projection_with_input.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{}", plan); + assert_eq!(plan_str, "Projection: users.user_id, users.name, users.paid_for_service, row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS row_number\ + \n WindowAggr: windowExpr=[[row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: users projection=[user_id, name, paid_for_service]"); + Ok(()) + } +} diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index d3ea7695e4b9e..816790388660c 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+mod bugs_converage; mod consumer_integration; mod function_test; mod logical_plans; diff --git a/datafusion/substrait/tests/testdata/extra_projection_with_input.json b/datafusion/substrait/tests/testdata/extra_projection_with_input.json new file mode 100644 index 0000000000000..41b93a8f2e10f --- /dev/null +++ b/datafusion/substrait/tests/testdata/extra_projection_with_input.json @@ -0,0 +1,113 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "row_number" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "user_id", + "name", + "paid_for_service" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "users" + ] + } + } + }, + "expressions": [ + { + "windowFunction": { + "functionReference": 1, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_FIRST" + } + ], + "upperBound": { + "unbounded": {} + }, + "lowerBound": { + "unbounded": {} + }, + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL" + } + } + ] + } + }, + "names": [ + "user_id", + "name", + "paid_for_service", + "row_number" + ] + } + } + ], + "version": { + "minorNumber": 52, + "producer": "spark-substrait-gateway" + } +} \ No newline at end of file