From a04ba65e33e15ce65072d6cc02c6525dd0d9a99e Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Wed, 23 Oct 2024 16:59:16 +0800 Subject: [PATCH] infer the source column from the manifest --- wren-modeling-rs/core/src/mdl/mod.rs | 210 ++++++++++++++++++++++++++- 1 file changed, 202 insertions(+), 8 deletions(-) diff --git a/wren-modeling-rs/core/src/mdl/mod.rs b/wren-modeling-rs/core/src/mdl/mod.rs index 65e4716af..80751a9f1 100644 --- a/wren-modeling-rs/core/src/mdl/mod.rs +++ b/wren-modeling-rs/core/src/mdl/mod.rs @@ -5,12 +5,17 @@ use crate::mdl::function::{ RemoteFunction, }; use crate::mdl::manifest::{Column, Manifest, Model, View}; +use datafusion::arrow::datatypes::Field; +use datafusion::common::internal_datafusion_err; use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::context::SessionState; use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS; use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion::prelude::SessionContext; +use datafusion::sql::parser::DFParser; +use datafusion::sql::sqlparser::ast::{Expr, Ident}; +use datafusion::sql::sqlparser::dialect::dialect_from_str; use datafusion::sql::unparser::dialect::{Dialect, IntervalStyle}; use datafusion::sql::unparser::Unparser; use datafusion::sql::TableReference; @@ -38,7 +43,7 @@ pub struct AnalyzedWrenMDL { impl AnalyzedWrenMDL { pub fn analyze(manifest: Manifest) -> Result { - let wren_mdl = Arc::new(WrenMDL::new_and_register_table_ref(manifest)); + let wren_mdl = Arc::new(WrenMDL::infer_and_register_remote_table(manifest)); let lineage = Arc::new(lineage::Lineage::new(&wren_mdl)?); Ok(AnalyzedWrenMDL { wren_mdl, lineage }) } @@ -140,7 +145,7 @@ impl WrenMDL { /// Create a WrenMDL from a manifest and register the table reference of the model as a remote table. 
/// All the column without expression will be considered a column - pub fn new_and_register_table_ref(manifest: Manifest) -> Self { + pub fn infer_and_register_remote_table(manifest: Manifest) -> Self { let mut mdl = WrenMDL::new(manifest); let sources: Vec<_> = mdl .models() @@ -150,12 +155,7 @@ impl WrenMDL { let fields: Vec<_> = model .columns .iter() - .filter(|column| { - !column.is_calculated - && column.expression.is_none() - && column.relationship.is_none() - }) - .map(|column| column.to_field()) + .filter_map(|column| Self::infer_source_column(column)) .collect(); let schema = Arc::new(datafusion::arrow::datatypes::Schema::new(fields)); let datasource = WrenDataSource::new_with_schema(schema); @@ -168,6 +168,55 @@ impl WrenMDL { mdl } + /// Infer the source column from the column expression. + /// + /// If the column is calculated or has a relationship, it's not a source column. + /// If the column has no expression, it's a source column. + /// If the column has an expression, it will try to infer the source column from the expression. + /// If the expression is a simple column reference, it's the source column name. + /// If the expression is a complex expression, it can't be inferred. 
+ /// + fn infer_source_column(column: &Column) -> Option { + if column.is_calculated || column.relationship.is_some() { + return None; + } + + if let Some(expression) = column.expression() { + let expr = WrenMDL::sql_to_expr(expression).ok()?; + // if the column is a simple column reference, we can infer the column name + Self::collect_one_column(&expr).map(|name| { + Field::new( + name.value.clone(), + map_data_type(&column.r#type), + column.no_null, + ) + }) + } else { + Some(column.to_field()) + } + } + + fn sql_to_expr(sql: &str) -> Result { + let dialect = dialect_from_str("generic").ok_or_else(|| { + internal_datafusion_err!("Failed to create dialect from generic") + })?; + + let expr = DFParser::parse_sql_into_expr_with_dialect(sql, dialect.as_ref())?; + Ok(expr) + } + + /// Collect the last identifier of the expression + /// e.g. "a"."b"."c" -> c + /// e.g. "a" -> a + /// others -> None + fn collect_one_column(expr: &Expr) -> Option<&Ident> { + match expr { + Expr::CompoundIdentifier(idents) => idents.last(), + Expr::Identifier(ident) => Some(ident), + _ => None, + } + } + pub fn register_table(&mut self, name: String, table: Arc) { self.register_tables.insert(name, table); } @@ -543,6 +592,138 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_unicode_remote_column_name() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_batch("artist", artist())?; + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("artist") + .table_reference("artist") + .column(ColumnBuilder::new("名字", "string").build()) + .column( + ColumnBuilder::new("name_append", "string") + .expression(r#""名字" || "名字""#) + .build(), + ) + .column( + ColumnBuilder::new("group", "string") + .expression(r#""組別""#) + .build(), + ) + .column( + ColumnBuilder::new("subscribe", "int") + .expression(r#""訂閱數""#) + .build(), + ) + .column( + ColumnBuilder::new("subscribe_plus", "int") + .expression(r#""訂閱數" + 1"#) + .build(), + ) + 
.build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = r#"select * from wren.test.artist"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + assert_eq!(actual, + "SELECT artist.\"名字\", artist.name_append, artist.\"group\", artist.subscribe_plus, artist.subscribe FROM \ + (SELECT artist.\"名字\" AS \"名字\", artist.\"名字\" || artist.\"名字\" AS name_append, \ + artist.\"組別\" AS \"group\", CAST(artist.\"訂閱數\" AS BIGINT) + 1 AS subscribe_plus, artist.\"訂閱數\" AS subscribe FROM artist) \ + AS artist"); + ctx.sql(&actual).await?.show().await?; + + let sql = r#"select group from wren.test.artist"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + assert_eq!(actual, + "SELECT artist.\"group\" FROM (SELECT artist.\"group\" FROM (SELECT artist.\"組別\" AS \"group\" FROM artist) AS artist) AS artist"); + ctx.sql(&actual).await?.show().await?; + + let sql = r#"select subscribe_plus from wren.test.artist"#; + let actual = mdl::transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + assert_eq!(actual, + "SELECT artist.subscribe_plus FROM (SELECT artist.subscribe_plus FROM (SELECT CAST(artist.\"訂閱數\" AS BIGINT) + 1 AS subscribe_plus FROM artist) AS artist) AS artist"); + ctx.sql(&actual).await?.show().await + } + + #[tokio::test] + async fn test_invalid_infer_remote_table() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_batch("artist", artist())?; + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("artist") + .table_reference("artist") + .column( + ColumnBuilder::new("name_append", "string") + .expression(r#""名字" || "名字""#) + .build(), + ) + .column( + ColumnBuilder::new("lower_name", "string") + .expression(r#"lower("名字")"#) + .build(), + ) + .build(), + ) + 
.build(); + + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = r#"select name_append from wren.test.artist"#; + let _ = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await + .map_err(|e| { + assert_eq!( + e.to_string(), + "ModelAnalyzeRule\ncaused by\nSchema error: No field named \"名字\"." + ) + }); + + let sql = r#"select lower_name from wren.test.artist"#; + let _ = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await + .map_err(|e| { + assert_eq!( + e.to_string(), + "ModelAnalyzeRule\ncaused by\nSchema error: No field named \"名字\"." + ) + }); + Ok(()) + } + async fn assert_sql_valid_executable(sql: &str) -> Result<()> { let ctx = SessionContext::new(); // To roundtrip testing, we should register the mock table for the planned sql. @@ -600,4 +781,17 @@ mod test { ]) .unwrap() } + + fn artist() -> RecordBatch { + let name: ArrayRef = + Arc::new(StringArray::from_iter_values(["Ina", "Azki", "Kaela"])); + let group: ArrayRef = Arc::new(StringArray::from_iter_values(["EN", "JP", "ID"])); + let subscribe: ArrayRef = Arc::new(Int64Array::from(vec![100, 200, 300])); + RecordBatch::try_from_iter(vec![ + ("名字", name), + ("組別", group), + ("訂閱數", subscribe), + ]) + .unwrap() + } }