Skip to content

Commit

Permalink
infer the source column from the manifest
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Oct 23, 2024
1 parent 5b92139 commit a04ba65
Showing 1 changed file with 202 additions and 8 deletions.
210 changes: 202 additions & 8 deletions wren-modeling-rs/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ use crate::mdl::function::{
RemoteFunction,
};
use crate::mdl::manifest::{Column, Manifest, Model, View};
use datafusion::arrow::datatypes::Field;
use datafusion::common::internal_datafusion_err;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use datafusion::prelude::SessionContext;
use datafusion::sql::parser::DFParser;
use datafusion::sql::sqlparser::ast::{Expr, Ident};
use datafusion::sql::sqlparser::dialect::dialect_from_str;
use datafusion::sql::unparser::dialect::{Dialect, IntervalStyle};
use datafusion::sql::unparser::Unparser;
use datafusion::sql::TableReference;
Expand Down Expand Up @@ -38,7 +43,7 @@ pub struct AnalyzedWrenMDL {

impl AnalyzedWrenMDL {
/// Build an [`AnalyzedWrenMDL`] from a manifest.
///
/// The manifest is turned into a [`WrenMDL`] through
/// `WrenMDL::infer_and_register_remote_table`, and the lineage is then computed
/// from that MDL.
///
/// # Errors
/// Propagates any error returned by `lineage::Lineage::new`.
pub fn analyze(manifest: Manifest) -> Result<Self> {
    // Only the inferring constructor is called: `manifest` is consumed by value,
    // so it can be handed to exactly one constructor.
    let wren_mdl = Arc::new(WrenMDL::infer_and_register_remote_table(manifest));
    let lineage = Arc::new(lineage::Lineage::new(&wren_mdl)?);
    Ok(AnalyzedWrenMDL { wren_mdl, lineage })
}
Expand Down Expand Up @@ -140,7 +145,7 @@ impl WrenMDL {

/// Create a WrenMDL from a manifest and register the table reference of the model as a remote table.
/// Columns without an expression, and columns whose expression is a plain column reference, are registered as source columns of the remote table.
pub fn new_and_register_table_ref(manifest: Manifest) -> Self {
pub fn infer_and_register_remote_table(manifest: Manifest) -> Self {
let mut mdl = WrenMDL::new(manifest);
let sources: Vec<_> = mdl
.models()
Expand All @@ -150,12 +155,7 @@ impl WrenMDL {
let fields: Vec<_> = model
.columns
.iter()
.filter(|column| {
!column.is_calculated
&& column.expression.is_none()
&& column.relationship.is_none()
})
.map(|column| column.to_field())
.filter_map(|column| Self::infer_source_column(column))
.collect();
let schema = Arc::new(datafusion::arrow::datatypes::Schema::new(fields));
let datasource = WrenDataSource::new_with_schema(schema);
Expand All @@ -168,6 +168,55 @@ impl WrenMDL {
mdl
}

/// Derive the remote-table field that backs a manifest column, if any.
///
/// * A calculated column or a relationship column has no source field: `None`.
/// * A column with no expression maps to itself via `Column::to_field`.
/// * A column whose expression parses to a bare or qualified identifier maps to a
///   field named after the last identifier, typed from the column's declared type.
/// * Any other expression, or one that fails to parse, yields `None`.
fn infer_source_column(column: &Column) -> Option<Field> {
    if column.is_calculated || column.relationship.is_some() {
        return None;
    }

    // No expression: the column itself is the source column.
    let Some(expression) = column.expression() else {
        return Some(column.to_field());
    };

    // Parse failures are deliberately treated as "cannot infer".
    let parsed = WrenMDL::sql_to_expr(expression).ok()?;
    // Only a simple column reference gives us a usable source name.
    let source_ident = Self::collect_one_column(&parsed)?;

    // NOTE(review): `column.no_null` lands in the third argument of `Field::new`,
    // which arrow documents as `nullable` - confirm the intended polarity against
    // `Column::to_field`.
    Some(Field::new(
        source_ident.value.clone(),
        map_data_type(&column.r#type),
        column.no_null,
    ))
}

/// Parse a SQL expression string into a sqlparser [`Expr`] using the "generic" dialect.
///
/// # Errors
/// Returns an internal error when the "generic" dialect cannot be resolved, and
/// propagates the parser's error when `sql` cannot be parsed as an expression.
fn sql_to_expr(sql: &str) -> Result<Expr> {
    let Some(generic) = dialect_from_str("generic") else {
        return Err(internal_datafusion_err!(
            "Failed to create dialect from generic"
        ));
    };

    // `?` converts the parser's error into this function's error type.
    Ok(DFParser::parse_sql_into_expr_with_dialect(
        sql,
        generic.as_ref(),
    )?)
}

/// Return the identifier that names the column when `expr` is a plain column reference.
///
/// * `"a"` (single identifier) -> `a`
/// * `"a"."b"."c"` (compound identifier) -> `c`, the last part
/// * anything else -> `None`
fn collect_one_column(expr: &Expr) -> Option<&Ident> {
    if let Expr::Identifier(single) = expr {
        return Some(single);
    }
    if let Expr::CompoundIdentifier(parts) = expr {
        return parts.last();
    }
    None
}

/// Store `table` under `name` in this MDL's `register_tables` collection.
///
/// Whatever `insert` returns is discarded.
pub fn register_table(&mut self, name: String, table: Arc<dyn TableProvider>) {
    let _ = self.register_tables.insert(name, table);
}
Expand Down Expand Up @@ -543,6 +592,138 @@ mod test {
Ok(())
}

#[tokio::test]
async fn test_unicode_remote_column_name() -> Result<()> {
    // Context holding the physical `artist` batch; the generated SQL is executed against it.
    let ctx = SessionContext::new();
    ctx.register_batch("artist", artist())?;

    // Model mixing a plain unicode column, simple references to unicode source
    // columns, and complex expressions over them.
    let artist_model = ModelBuilder::new("artist")
        .table_reference("artist")
        .column(ColumnBuilder::new("名字", "string").build())
        .column(
            ColumnBuilder::new("name_append", "string")
                .expression(r#""名字" || "名字""#)
                .build(),
        )
        .column(
            ColumnBuilder::new("group", "string")
                .expression(r#""組別""#)
                .build(),
        )
        .column(
            ColumnBuilder::new("subscribe", "int")
                .expression(r#""訂閱數""#)
                .build(),
        )
        .column(
            ColumnBuilder::new("subscribe_plus", "int")
                .expression(r#""訂閱數" + 1"#)
                .build(),
        )
        .build();
    let manifest = ManifestBuilder::new()
        .catalog("wren")
        .schema("test")
        .model(artist_model)
        .build();
    let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);

    // (input SQL, expected transformed SQL)
    let cases = [
        (
            r#"select * from wren.test.artist"#,
            "SELECT artist.\"名字\", artist.name_append, artist.\"group\", artist.subscribe_plus, artist.subscribe FROM \
            (SELECT artist.\"名字\" AS \"名字\", artist.\"名字\" || artist.\"名字\" AS name_append, \
            artist.\"組別\" AS \"group\", CAST(artist.\"訂閱數\" AS BIGINT) + 1 AS subscribe_plus, artist.\"訂閱數\" AS subscribe FROM artist) \
            AS artist",
        ),
        (
            r#"select group from wren.test.artist"#,
            "SELECT artist.\"group\" FROM (SELECT artist.\"group\" FROM (SELECT artist.\"組別\" AS \"group\" FROM artist) AS artist) AS artist",
        ),
        (
            r#"select subscribe_plus from wren.test.artist"#,
            "SELECT artist.subscribe_plus FROM (SELECT artist.subscribe_plus FROM (SELECT CAST(artist.\"訂閱數\" AS BIGINT) + 1 AS subscribe_plus FROM artist) AS artist) AS artist",
        ),
    ];

    for (query, expected) in cases {
        let transformed = transform_sql_with_ctx(
            &SessionContext::new(),
            Arc::clone(&analyzed_mdl),
            &[],
            query,
        )
        .await?;
        assert_eq!(transformed, expected);
        // The transformed SQL must also run against the registered batch.
        ctx.sql(&transformed).await?.show().await?;
    }
    Ok(())
}

/// A model whose columns are all complex expressions exposes no inferable source
/// column, so transforming a query over it must fail with a schema error.
#[tokio::test]
async fn test_invalid_infer_remote_table() -> Result<()> {
    // Registered for parity with the sibling test; the transform below uses its
    // own fresh `SessionContext` and never consults this one.
    let ctx = SessionContext::new();
    ctx.register_batch("artist", artist())?;
    let manifest = ManifestBuilder::new()
        .catalog("wren")
        .schema("test")
        .model(
            ModelBuilder::new("artist")
                .table_reference("artist")
                .column(
                    ColumnBuilder::new("name_append", "string")
                        .expression(r#""名字" || "名字""#)
                        .build(),
                )
                .column(
                    ColumnBuilder::new("lower_name", "string")
                        .expression(r#"lower("名字")"#)
                        .build(),
                )
                .build(),
        )
        .build();

    let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
    for sql in [
        r#"select name_append from wren.test.artist"#,
        r#"select lower_name from wren.test.artist"#,
    ] {
        let outcome = transform_sql_with_ctx(
            &SessionContext::new(),
            Arc::clone(&analyzed_mdl),
            &[],
            sql,
        )
        .await;
        // Fail loudly on success: asserting only inside `map_err` would let an
        // unexpected `Ok` pass without checking anything.
        match outcome {
            Ok(_) => panic!(
                "expected transforming `{}` to fail with a schema error, but it succeeded",
                sql
            ),
            Err(e) => assert_eq!(
                e.to_string(),
                "ModelAnalyzeRule\ncaused by\nSchema error: No field named \"名字\"."
            ),
        }
    }
    Ok(())
}

async fn assert_sql_valid_executable(sql: &str) -> Result<()> {
let ctx = SessionContext::new();
// To roundtrip testing, we should register the mock table for the planned sql.
Expand Down Expand Up @@ -600,4 +781,17 @@ mod test {
])
.unwrap()
}

/// Build the three-row `artist` fixture batch whose column names are unicode:
/// two string columns ("名字", "組別") and one 64-bit integer column ("訂閱數").
fn artist() -> RecordBatch {
    let columns = vec![
        (
            "名字",
            Arc::new(StringArray::from_iter_values(["Ina", "Azki", "Kaela"])) as ArrayRef,
        ),
        (
            "組別",
            Arc::new(StringArray::from_iter_values(["EN", "JP", "ID"])) as ArrayRef,
        ),
        (
            "訂閱數",
            Arc::new(Int64Array::from(vec![100, 200, 300])) as ArrayRef,
        ),
    ];
    RecordBatch::try_from_iter(columns).unwrap()
}
}

0 comments on commit a04ba65

Please sign in to comment.