From 8d0a0f95e09340cfe0463bf1cfc9511c5eb6b569 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Wed, 4 Sep 2024 16:07:19 +0200 Subject: [PATCH] ensure schema from table matches substrait's schema when consuming substrait --- .../substrait/src/logical_plan/consumer.rs | 226 +++++++++++------- .../tests/cases/consumer_integration.rs | 56 ++--- .../substrait/tests/cases/logical_plans.rs | 10 +- .../test_plans/select_not_bool.substrait.json | 2 +- .../test_plans/select_window.substrait.json | 12 +- .../substrait/tests/testdata/tpch/nation.csv | 2 +- .../substrait/tests/testdata/tpch/region.csv | 2 +- 7 files changed, 180 insertions(+), 130 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 21bef3c2c98e..d7057d2d399c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -53,6 +53,7 @@ use crate::variation_const::{ }; use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::dataframe::DataFrame; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, @@ -75,7 +76,7 @@ use substrait::proto::expression::literal::{ IntervalDayToSecond, IntervalYearToMonth, UserDefined, }; use substrait::proto::expression::subquery::SubqueryType; -use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; +use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; use substrait::proto::{ aggregate_function::AggregationInvocation, @@ -252,7 +253,7 @@ pub async fn from_substrait_plan( /// parse projection pub fn extract_projection( t: LogicalPlan, - projection: &::core::option::Option, + projection: &Option, ) -> Result { match projection { Some(MaskExpression { select, .. }) => match &select.as_ref() { @@ -286,6 +287,39 @@ pub fn extract_projection( } } +fn ensure_schema_match( + table: DataFrame, + expected: &NamedStruct, + extensions: &Extensions, +) -> Result { + // Substrait schema contains no qualifiers, so we need to strip them from the DF schema as well + // for the comparison to work + let df_schema = table.schema().to_owned().strip_qualifiers(); + let substrait_schema = from_substrait_named_struct(expected, extensions)?; + if df_schema.logically_equivalent_names_and_types(&substrait_schema) { + return Ok(table); + } + let selected_columns = substrait_schema + .fields() + .iter() + .map(|f| { + let df_field = df_schema.field_with_unqualified_name(f.name())?; + if DFSchema::datatype_is_logically_equal(df_field.data_type(), f.data_type()) { + Ok(col(format!("\"{}\"", df_field.name()))) + } else { + substrait_err!( + "Field {} in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", + f.name(), + f.data_type(), + df_field.data_type() + ) + } + }) + .collect::>()?; + + table.select(selected_columns) +} + /// Ensure the expressions have the right name(s) according to the new schema. /// This includes the top-level (column) name, which will be renamed through aliasing if needed, /// as well as nested names (if the expression produces any struct types), which will be renamed @@ -638,111 +672,125 @@ pub async fn from_substrait_rel( let (left, right) = requalify_sides_if_needed(left, right)?; left.cross_join(right.build()?)?.build() } - Some(RelType::Read(read)) => match &read.as_ref().read_type { - Some(ReadType::NamedTable(nt)) => { - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; - let t = ctx.table(table_reference).await?; - let t = t.into_optimized_plan()?; - extract_projection(t, &read.projection) - } - Some(ReadType::VirtualTable(vt)) => { - let base_schema = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Virtual Table") - })?; + Some(RelType::Read(read)) => { + let base_schema = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for a Read") + })?; - let schema = from_substrait_named_struct(base_schema, extensions)?; + match &read.as_ref().read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + let t = ctx.table(table_reference).await?; - if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema, - })); + let t = ensure_schema_match(t, base_schema, extensions)?; + let t = t.into_optimized_plan()?; + extract_projection(t, &read.projection) } + Some(ReadType::VirtualTable(vt)) => { + let schema = from_substrait_named_struct(base_schema, extensions)?; + + if vt.values.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema, + })); + } - let values = vt - .values - .iter() - .map(|row| { - let mut name_idx = 0; - let lits = row - .fields - .iter() - .map(|lit| { - name_idx += 1; // top-level names are provided through schema - Ok(Expr::Literal(from_substrait_literal( - lit, - extensions, - &base_schema.names, - &mut name_idx, - )?)) - }) - .collect::>()?; - if name_idx != base_schema.names.len() { - return substrait_err!( + let values = vt + .values + .iter() + .map(|row| { + let mut name_idx = 0; + let lits = row + .fields + .iter() + .map(|lit| { + name_idx += 1; // top-level names are provided through schema + Ok(Expr::Literal(from_substrait_literal( + lit, + extensions, + &base_schema.names, + &mut name_idx, + )?)) + }) + .collect::>()?; + if name_idx != base_schema.names.len() { + return substrait_err!( "Names list must match exactly to nested schema, but found {} uses for {} names", name_idx, base_schema.names.len() ); - } - Ok(lits) - }) - .collect::>()?; + } + Ok(lits) + }) + .collect::>()?; - Ok(LogicalPlan::Values(Values { schema, values })) - } - Some(ReadType::LocalFiles(lf)) => { - fn extract_filename(name: &str) -> Option { - let corrected_url = - if name.starts_with("file://") && !name.starts_with("file:///") { + if read.projection.is_some() { + return substrait_err!( + "Projection is not supported for Virtual Table" + ); + } + Ok(LogicalPlan::Values(Values { schema, values })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = if name.starts_with("file://") + && !name.starts_with("file:///") + { name.replacen("file://", "file:///", 1) } else { name.to_string() }; - Url::parse(&corrected_url).ok().and_then(|url| { - let path = url.path(); - std::path::Path::new(path) - .file_name() - .map(|filename| filename.to_string_lossy().to_string()) - }) - } + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } - // we could use the file name to check the original table provider - // TODO: currently does not support multiple local files - let filename: Option = - lf.items.first().and_then(|x| match x.path_type.as_ref() { - Some(UriFile(name)) => extract_filename(name), - _ => None, - }); + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); - if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!("Only single file reads are supported"); + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + let t = ctx.table(table_reference).await?; + // TODO: requires renaming all the CSV fields to be capitalized + // let t = ensure_schema_match(t, base_schema, extensions)?; + let t = t.into_optimized_plan()?; + extract_projection(t, &read.projection) + } + _ => { + not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) } - let name = filename.unwrap(); - // directly use unwrap here since we could determine it is a valid one - let table_reference = TableReference::Bare { table: name.into() }; - let t = ctx.table(table_reference).await?; - let t = t.into_optimized_plan()?; - extract_projection(t, &read.projection) } - _ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type), - }, + } Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { Ok(set_op) => match set_op { set_rel::SetOp::UnionAll => { diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 0a86d27e013c..dad24559a06f 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -92,22 +92,22 @@ mod tests { let plan_str = format!("{}", plan); assert_eq!( plan_str, - "Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\ + "Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\ \n Limit: skip=0, fetch=100\ - \n Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.n_name ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\ - \n Projection: FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\ - \n Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_regionkey = FILENAME_PLACEHOLDER_4.r_regionkey AND FILENAME_PLACEHOLDER_4.r_name = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = ()\ + \n Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.N_NAME ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\ + \n Projection: FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\ + \n Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.N_NATIONKEY AND FILENAME_PLACEHOLDER_3.N_REGIONKEY = FILENAME_PLACEHOLDER_4.R_REGIONKEY AND FILENAME_PLACEHOLDER_4.R_NAME = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = ()\ \n Subquery:\ \n Aggregate: groupBy=[[]], aggr=[[min(FILENAME_PLACEHOLDER_5.ps_supplycost)]]\ \n Projection: FILENAME_PLACEHOLDER_5.ps_supplycost\ - \n Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.n_nationkey AND FILENAME_PLACEHOLDER_7.n_regionkey = FILENAME_PLACEHOLDER_8.r_regionkey AND FILENAME_PLACEHOLDER_8.r_name = CAST(Utf8(\"EUROPE\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.N_NATIONKEY AND FILENAME_PLACEHOLDER_7.N_REGIONKEY = FILENAME_PLACEHOLDER_8.R_REGIONKEY AND FILENAME_PLACEHOLDER_8.R_NAME = CAST(Utf8(\"EUROPE\") AS Utf8)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_5 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n TableScan: FILENAME_PLACEHOLDER_6 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_7 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_8 projection=[r_regionkey, r_name, r_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_7 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: FILENAME_PLACEHOLDER_8 projection=[R_REGIONKEY, R_NAME, R_COMMENT]\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ @@ -115,8 +115,8 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_0 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ \n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_4 projection=[r_regionkey, r_name, r_comment]" + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: FILENAME_PLACEHOLDER_4 projection=[R_REGIONKEY, R_NAME, R_COMMENT]" ); Ok(()) } @@ -196,11 +196,11 @@ mod tests { let plan = from_substrait_plan(&ctx, &proto).await?; let plan_str = format!("{}", plan); - assert_eq!(plan_str, "Projection: NATION.n_name AS N_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE\ + assert_eq!(plan_str, "Projection: NATION.N_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE\ \n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST\ - \n Aggregate: groupBy=[[NATION.n_name]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ - \n Projection: NATION.n_name, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ - \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_2.l_suppkey = FILENAME_PLACEHOLDER_3.s_suppkey AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.s_nationkey AND FILENAME_PLACEHOLDER_3.s_nationkey = NATION.n_nationkey AND NATION.n_regionkey = REGION.r_regionkey AND REGION.r_name = CAST(Utf8(\"ASIA\") AS Utf8) AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1995-01-01\") AS Date32)\ + \n Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ + \n Projection: NATION.N_NAME, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_2.l_suppkey = FILENAME_PLACEHOLDER_3.s_suppkey AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.s_nationkey AND FILENAME_PLACEHOLDER_3.s_nationkey = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = CAST(Utf8(\"ASIA\") AS Utf8) AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1995-01-01\") AS Date32)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ @@ -210,8 +210,8 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ \n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ \n TableScan: FILENAME_PLACEHOLDER_3 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: NATION projection=[n_nationkey, n_name, n_regionkey, n_comment]\ - \n TableScan: REGION projection=[r_regionkey, r_name, r_comment]"); + \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: REGION projection=[R_REGIONKEY, R_NAME, R_COMMENT]"); Ok(()) } @@ -255,19 +255,19 @@ mod tests { let plan = from_substrait_plan(&ctx, &proto).await?; let plan_str = format!("{}", plan); - assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.c_custkey AS C_CUSTKEY, FILENAME_PLACEHOLDER_0.c_name AS C_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE, FILENAME_PLACEHOLDER_0.c_acctbal AS C_ACCTBAL, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.c_address AS C_ADDRESS, FILENAME_PLACEHOLDER_0.c_phone AS C_PHONE, FILENAME_PLACEHOLDER_0.c_comment AS C_COMMENT\ + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.c_custkey AS C_CUSTKEY, FILENAME_PLACEHOLDER_0.c_name AS C_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE, FILENAME_PLACEHOLDER_0.c_acctbal AS C_ACCTBAL, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address AS C_ADDRESS, FILENAME_PLACEHOLDER_0.c_phone AS C_PHONE, FILENAME_PLACEHOLDER_0.c_comment AS C_COMMENT\ \n Limit: skip=0, fetch=20\ \n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST\ - \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount), FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_0.c_comment\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ - \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ - \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1993-10-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_2.l_returnflag = Utf8(\"R\") AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey\ + \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount), FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_0.c_comment\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ + \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1993-10-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_2.l_returnflag = Utf8(\"R\") AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.N_NATIONKEY\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ \n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } @@ -297,20 +297,20 @@ mod tests { \n Projection: sum(FILENAME_PLACEHOLDER_3.ps_supplycost * FILENAME_PLACEHOLDER_3.ps_availqty) * Decimal128(Some(1000000),11,10)\ \n Aggregate: groupBy=[[]], aggr=[[sum(FILENAME_PLACEHOLDER_3.ps_supplycost * FILENAME_PLACEHOLDER_3.ps_availqty)]]\ \n Projection: FILENAME_PLACEHOLDER_3.ps_supplycost * CAST(FILENAME_PLACEHOLDER_3.ps_availqty AS Decimal128(19, 0))\ - \n Filter: FILENAME_PLACEHOLDER_3.ps_suppkey = FILENAME_PLACEHOLDER_4.s_suppkey AND FILENAME_PLACEHOLDER_4.s_nationkey = FILENAME_PLACEHOLDER_5.n_nationkey AND FILENAME_PLACEHOLDER_5.n_name = CAST(Utf8(\"JAPAN\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_3.ps_suppkey = FILENAME_PLACEHOLDER_4.s_suppkey AND FILENAME_PLACEHOLDER_4.s_nationkey = FILENAME_PLACEHOLDER_5.N_NATIONKEY AND FILENAME_PLACEHOLDER_5.N_NAME = CAST(Utf8(\"JAPAN\") AS Utf8)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_3 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n TableScan: FILENAME_PLACEHOLDER_4 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_5 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_5 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.ps_partkey]], aggr=[[sum(FILENAME_PLACEHOLDER_0.ps_supplycost * FILENAME_PLACEHOLDER_0.ps_availqty)]]\ \n Projection: FILENAME_PLACEHOLDER_0.ps_partkey, FILENAME_PLACEHOLDER_0.ps_supplycost * CAST(FILENAME_PLACEHOLDER_0.ps_availqty AS Decimal128(19, 0))\ - \n Filter: FILENAME_PLACEHOLDER_0.ps_suppkey = FILENAME_PLACEHOLDER_1.s_suppkey AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_2.n_nationkey AND FILENAME_PLACEHOLDER_2.n_name = CAST(Utf8(\"JAPAN\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_0.ps_suppkey = FILENAME_PLACEHOLDER_1.s_suppkey AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_2.N_NATIONKEY AND FILENAME_PLACEHOLDER_2.N_NAME = CAST(Utf8(\"JAPAN\") AS Utf8)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_2 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } @@ -498,7 +498,7 @@ mod tests { assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.s_name AS S_NAME, FILENAME_PLACEHOLDER_0.s_address AS S_ADDRESS\ \n Sort: FILENAME_PLACEHOLDER_0.s_name ASC NULLS LAST\ \n Projection: FILENAME_PLACEHOLDER_0.s_name, FILENAME_PLACEHOLDER_0.s_address\ - \n Filter: CAST(FILENAME_PLACEHOLDER_0.s_suppkey IN () AS Boolean) AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_1.n_nationkey AND FILENAME_PLACEHOLDER_1.n_name = CAST(Utf8(\"CANADA\") AS Utf8)\ + \n Filter: CAST(FILENAME_PLACEHOLDER_0.s_suppkey IN () AS Boolean) AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_1.N_NATIONKEY AND FILENAME_PLACEHOLDER_1.N_NAME = CAST(Utf8(\"CANADA\") AS Utf8)\ \n Subquery:\ \n Projection: FILENAME_PLACEHOLDER_2.ps_suppkey\ \n Filter: CAST(FILENAME_PLACEHOLDER_2.ps_partkey IN () AS Boolean) AND CAST(FILENAME_PLACEHOLDER_2.ps_availqty AS Decimal128(19, 1)) > ()\ @@ -515,7 +515,7 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_1 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } @@ -543,7 +543,7 @@ mod tests { \n Sort: count(Int64(1)) DESC NULLS FIRST, FILENAME_PLACEHOLDER_0.s_name ASC NULLS LAST\ \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.s_name]], aggr=[[count(Int64(1))]]\ \n Projection: FILENAME_PLACEHOLDER_0.s_name\ - \n Filter: FILENAME_PLACEHOLDER_0.s_suppkey = FILENAME_PLACEHOLDER_1.l_suppkey AND FILENAME_PLACEHOLDER_2.o_orderkey = FILENAME_PLACEHOLDER_1.l_orderkey AND FILENAME_PLACEHOLDER_2.o_orderstatus = Utf8(\"F\") AND FILENAME_PLACEHOLDER_1.l_receiptdate > FILENAME_PLACEHOLDER_1.l_commitdate AND EXISTS () AND NOT EXISTS () AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_name = CAST(Utf8(\"SAUDI ARABIA\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_0.s_suppkey = FILENAME_PLACEHOLDER_1.l_suppkey AND FILENAME_PLACEHOLDER_2.o_orderkey = FILENAME_PLACEHOLDER_1.l_orderkey AND FILENAME_PLACEHOLDER_2.o_orderstatus = Utf8(\"F\") AND FILENAME_PLACEHOLDER_1.l_receiptdate > FILENAME_PLACEHOLDER_1.l_commitdate AND EXISTS () AND NOT EXISTS () AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_3.N_NATIONKEY AND FILENAME_PLACEHOLDER_3.N_NAME = CAST(Utf8(\"SAUDI ARABIA\") AS Utf8)\ \n Subquery:\ \n Filter: FILENAME_PLACEHOLDER_4.l_orderkey = FILENAME_PLACEHOLDER_4.l_tax AND FILENAME_PLACEHOLDER_4.l_suppkey != FILENAME_PLACEHOLDER_4.l_linestatus\ \n TableScan: FILENAME_PLACEHOLDER_4 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ @@ -555,7 +555,7 @@ mod tests { \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\n TableScan: FILENAME_PLACEHOLDER_2 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index f6a2b5036c80..7afcc712c024 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -39,14 +39,15 @@ mod tests { // File generated with substrait-java's Isthmus: // ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)" + // and column names lowercased let proto = read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); let plan = from_substrait_plan(&ctx, &proto).await?; assert_eq!( format!("{}", plan), - "Projection: NOT DATA.a AS EXPR$0\ - \n TableScan: DATA projection=[a, b, c, d, e, f]" + "Projection: NOT DATA.d AS EXPR$0\ + \n TableScan: DATA projection=[d]" ); Ok(()) } @@ -64,7 +65,8 @@ mod tests { let ctx = create_context().await?; // File generated with substrait-java's Isthmus: - // ./isthmus-cli/build/graal/isthmus "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" -c "create table data (d int, part int, ord int)" + // ./isthmus-cli/build/graal/isthmus "select sum(a) OVER (PARTITION BY b ORDER BY c ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" -c "create table data (a bigint, b float, c date)" + // and column names lowercased let proto = read_json("tests/testdata/test_plans/select_window.substrait.json"); let plan = from_substrait_plan(&ctx, &proto).await?; @@ -73,7 +75,7 @@ mod tests { format!("{}", plan), "Projection: sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ \n WindowAggr: windowExpr=[[sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: DATA projection=[a, b, c, d, e, f]" + \n TableScan: DATA projection=[a, b, c]" ); Ok(()) } diff --git a/datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json index e52cf87d5028..2fc927a3a48e 100644 --- a/datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json +++ b/datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json @@ -34,7 +34,7 @@ }, "baseSchema": { "names": [ - "D" + "d" ], "struct": { "types": [ diff --git a/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json index 3082c4258f83..561827d8ee86 100644 --- a/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json +++ b/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json @@ -34,26 +34,26 @@ }, "baseSchema": { "names": [ - "D", - "PART", - "ORD" + "a", + "b", + "c" ], "struct": { "types": [ { - "i32": { + "i64": { "typeVariationReference": 0, "nullability": "NULLABILITY_NULLABLE" } }, { - "i32": { + "fp64": { "typeVariationReference": 0, "nullability": "NULLABILITY_NULLABLE" } }, { - "i32": { + "date": { "typeVariationReference": 0, "nullability": "NULLABILITY_NULLABLE" } diff --git a/datafusion/substrait/tests/testdata/tpch/nation.csv b/datafusion/substrait/tests/testdata/tpch/nation.csv index fdf7421467d3..a88d1c0d31e7 100644 --- a/datafusion/substrait/tests/testdata/tpch/nation.csv +++ b/datafusion/substrait/tests/testdata/tpch/nation.csv @@ -1,2 +1,2 @@ -n_nationkey,n_name,n_regionkey,n_comment +N_NATIONKEY,N_NAME,N_REGIONKEY,N_COMMENT 0,ALGERIA,0, haggle. carefully final deposits detect slyly agai \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/region.csv b/datafusion/substrait/tests/testdata/tpch/region.csv index 6c3fb4524355..d29c39ab8543 100644 --- a/datafusion/substrait/tests/testdata/tpch/region.csv +++ b/datafusion/substrait/tests/testdata/tpch/region.csv @@ -1,2 +1,2 @@ -r_regionkey,r_name,r_comment +R_REGIONKEY,R_NAME,R_COMMENT 0,AFRICA,lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to \ No newline at end of file