Skip to content

Commit

Permalink
ensure schema from table matches substrait's schema when consuming su…
Browse files Browse the repository at this point in the history
…bstrait
  • Loading branch information
Blizzara committed Sep 4, 2024
1 parent 7c7f485 commit 8d0a0f9
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 130 deletions.
226 changes: 137 additions & 89 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ use crate::variation_const::{
};
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::dataframe::DataFrame;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Expand All @@ -75,7 +76,7 @@ use substrait::proto::expression::literal::{
IntervalDayToSecond, IntervalYearToMonth, UserDefined,
};
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction};
use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use substrait::proto::{
aggregate_function::AggregationInvocation,
Expand Down Expand Up @@ -252,7 +253,7 @@ pub async fn from_substrait_plan(
/// parse projection
pub fn extract_projection(
t: LogicalPlan,
projection: &::core::option::Option<expression::MaskExpression>,
projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
match projection {
Some(MaskExpression { select, .. }) => match &select.as_ref() {
Expand Down Expand Up @@ -286,6 +287,39 @@ pub fn extract_projection(
}
}

fn ensure_schema_match(
table: DataFrame,
expected: &NamedStruct,
extensions: &Extensions,
) -> Result<DataFrame> {
// Substrait schema contains no qualifiers, so we need to strip them from the DF schema as well
// for the comparison to work
let df_schema = table.schema().to_owned().strip_qualifiers();
let substrait_schema = from_substrait_named_struct(expected, extensions)?;
if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
return Ok(table);
}
let selected_columns = substrait_schema
.fields()
.iter()
.map(|f| {
let df_field = df_schema.field_with_unqualified_name(f.name())?;
if DFSchema::datatype_is_logically_equal(df_field.data_type(), f.data_type()) {
Ok(col(format!("\"{}\"", df_field.name())))
} else {
substrait_err!(
"Field {} in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).",
f.name(),
f.data_type(),
df_field.data_type()
)
}
})
.collect::<Result<_>>()?;

table.select(selected_columns)
}

/// Ensure the expressions have the right name(s) according to the new schema.
/// This includes the top-level (column) name, which will be renamed through aliasing if needed,
/// as well as nested names (if the expression produces any struct types), which will be renamed
Expand Down Expand Up @@ -638,111 +672,125 @@ pub async fn from_substrait_rel(
let (left, right) = requalify_sides_if_needed(left, right)?;
left.cross_join(right.build()?)?.build()
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
let table_reference = match nt.names.len() {
0 => {
return plan_err!("No table name found in NamedTable");
}
1 => TableReference::Bare {
table: nt.names[0].clone().into(),
},
2 => TableReference::Partial {
schema: nt.names[0].clone().into(),
table: nt.names[1].clone().into(),
},
_ => TableReference::Full {
catalog: nt.names[0].clone().into(),
schema: nt.names[1].clone().into(),
table: nt.names[2].clone().into(),
},
};
let t = ctx.table(table_reference).await?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
}
Some(ReadType::VirtualTable(vt)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Virtual Table")
})?;
Some(RelType::Read(read)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for a Read")
})?;

let schema = from_substrait_named_struct(base_schema, extensions)?;
match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
let table_reference = match nt.names.len() {
0 => {
return plan_err!("No table name found in NamedTable");
}
1 => TableReference::Bare {
table: nt.names[0].clone().into(),
},
2 => TableReference::Partial {
schema: nt.names[0].clone().into(),
table: nt.names[1].clone().into(),
},
_ => TableReference::Full {
catalog: nt.names[0].clone().into(),
schema: nt.names[1].clone().into(),
table: nt.names[2].clone().into(),
},
};
let t = ctx.table(table_reference).await?;

if vt.values.is_empty() {
return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema,
}));
let t = ensure_schema_match(t, base_schema, extensions)?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
}
Some(ReadType::VirtualTable(vt)) => {
let schema = from_substrait_named_struct(base_schema, extensions)?;

if vt.values.is_empty() {
return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema,
}));
}

let values = vt
.values
.iter()
.map(|row| {
let mut name_idx = 0;
let lits = row
.fields
.iter()
.map(|lit| {
name_idx += 1; // top-level names are provided through schema
Ok(Expr::Literal(from_substrait_literal(
lit,
extensions,
&base_schema.names,
&mut name_idx,
)?))
})
.collect::<Result<_>>()?;
if name_idx != base_schema.names.len() {
return substrait_err!(
let values = vt
.values
.iter()
.map(|row| {
let mut name_idx = 0;
let lits = row
.fields
.iter()
.map(|lit| {
name_idx += 1; // top-level names are provided through schema
Ok(Expr::Literal(from_substrait_literal(
lit,
extensions,
&base_schema.names,
&mut name_idx,
)?))
})
.collect::<Result<_>>()?;
if name_idx != base_schema.names.len() {
return substrait_err!(
"Names list must match exactly to nested schema, but found {} uses for {} names",
name_idx,
base_schema.names.len()
);
}
Ok(lits)
})
.collect::<Result<_>>()?;
}
Ok(lits)
})
.collect::<Result<_>>()?;

Ok(LogicalPlan::Values(Values { schema, values }))
}
Some(ReadType::LocalFiles(lf)) => {
fn extract_filename(name: &str) -> Option<String> {
let corrected_url =
if name.starts_with("file://") && !name.starts_with("file:///") {
if read.projection.is_some() {
return substrait_err!(
"Projection is not supported for Virtual Table"
);
}
Ok(LogicalPlan::Values(Values { schema, values }))
}
Some(ReadType::LocalFiles(lf)) => {
fn extract_filename(name: &str) -> Option<String> {
let corrected_url = if name.starts_with("file://")
&& !name.starts_with("file:///")
{
name.replacen("file://", "file:///", 1)
} else {
name.to_string()
};

Url::parse(&corrected_url).ok().and_then(|url| {
let path = url.path();
std::path::Path::new(path)
.file_name()
.map(|filename| filename.to_string_lossy().to_string())
})
}
Url::parse(&corrected_url).ok().and_then(|url| {
let path = url.path();
std::path::Path::new(path)
.file_name()
.map(|filename| filename.to_string_lossy().to_string())
})
}

// we could use the file name to check the original table provider
// TODO: currently does not support multiple local files
let filename: Option<String> =
lf.items.first().and_then(|x| match x.path_type.as_ref() {
Some(UriFile(name)) => extract_filename(name),
_ => None,
});
// we could use the file name to check the original table provider
// TODO: currently does not support multiple local files
let filename: Option<String> =
lf.items.first().and_then(|x| match x.path_type.as_ref() {
Some(UriFile(name)) => extract_filename(name),
_ => None,
});

if lf.items.len() > 1 || filename.is_none() {
return not_impl_err!("Only single file reads are supported");
if lf.items.len() > 1 || filename.is_none() {
return not_impl_err!("Only single file reads are supported");
}
let name = filename.unwrap();
// directly use unwrap here since we could determine it is a valid one
let table_reference = TableReference::Bare { table: name.into() };
let t = ctx.table(table_reference).await?;
// TODO: requires renaming all the CSV fields to be capitalized
// let t = ensure_schema_match(t, base_schema, extensions)?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
}
_ => {
not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type)
}
let name = filename.unwrap();
// directly use unwrap here since we could determine it is a valid one
let table_reference = TableReference::Bare { table: name.into() };
let t = ctx.table(table_reference).await?;
let t = t.into_optimized_plan()?;
extract_projection(t, &read.projection)
}
_ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type),
},
}
Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
Ok(set_op) => match set_op {
set_rel::SetOp::UnionAll => {
Expand Down
Loading

0 comments on commit 8d0a0f9

Please sign in to comment.