Skip to content

Commit

Permalink
test: add test in Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
grieve54706 committed Dec 3, 2024
1 parent fdec238 commit f217fd8
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 6 deletions.
80 changes: 79 additions & 1 deletion wren-core-py/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions wren-core-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ env_logger = "0.11.5"
log = "0.4.22"
tokio = "1.40.0"

[dev-dependencies]
rstest = "0.23.0"

[build-dependencies]
pyo3-build-config = "0.23.2"

Expand Down
2 changes: 1 addition & 1 deletion wren-core-py/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::PyErr;
use std::string::FromUtf8Error;
use thiserror::Error;

#[derive(Error, Debug)]
#[derive(Error, Debug, PartialEq)]
#[error("{message}")]
pub struct CoreError {
message: String,
Expand Down
163 changes: 159 additions & 4 deletions wren-core-py/src/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ pub struct PyExtractor {
#[pymethods]
impl PyExtractor {
#[new]
pub fn new(mdl_base64: &str) -> Result<Self, CoreError> {
let manifest = to_manifest(mdl_base64)?;
let mdl = WrenMDL::new_ref(manifest);
Ok(Self { mdl })
#[pyo3(signature = (mdl_base64=None))]
pub fn new(mdl_base64: Option<&str>) -> Result<Self, CoreError> {
mdl_base64
.ok_or_else(|| CoreError::new("Expected a valid base64 encoded string for the model definition, but got None."))
.and_then(to_manifest)
.map(|manifest| Self {
mdl: WrenMDL::new_ref(manifest),
})
}

/// parse the given SQL and return the list of used table name.
Expand Down Expand Up @@ -154,3 +158,154 @@ fn extract_relationships(
.cloned()
.collect()
}

#[cfg(test)]
mod tests {
use crate::extractor::PyExtractor;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use rstest::{fixture, rstest};
use std::iter::Iterator;

#[fixture]
pub fn mdl_base64() -> String {
let mdl_json = r#"
{
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "customer",
"tableReference": {
"schema": "main",
"table": "customer"
},
"columns": [
{"name": "c_custkey", "type": "integer"},
{"name": "orders", "type": "orders", "relationship": "orders_customer"}
],
"primaryKey": "c_custkey"
},
{
"name": "orders",
"tableReference": {
"schema": "main",
"table": "orders"
},
"columns": [
{"name": "o_orderkey", "type": "integer"},
{"name": "o_custkey", "type": "integer"},
{
"name": "lineitems",
"type": "Lineitem",
"relationship": "orders_lineitem"
}
],
"primaryKey": "o_orderkey"
},
{
"name": "lineitem",
"tableReference": {
"schema": "main",
"table": "lineitem"
},
"columns": [
{"name": "l_orderkey", "type": "integer"}
],
"primaryKey": "l_orderkey"
}
],
"relationships": [
{
"name": "orders_customer",
"models": ["orders", "customer"],
"joinType": "MANY_TO_ONE",
"condition": "orders.custkey = customer.custkey"
},
{
"name": "orders_lineitem",
"models": ["orders", "lineitem"],
"joinType": "ONE_TO_MANY",
"condition": "orders.orderkey = lineitem.orderkey"
}
],
"views": [
{
"name": "customer_view",
"statement": "SELECT * FROM my_catalog.my_schema.customer"
}
]
}"#;
BASE64_STANDARD.encode(mdl_json.as_bytes())
}

#[fixture]
pub fn extractor(mdl_base64: String) -> PyExtractor {
PyExtractor::new(Option::from(mdl_base64.as_str())).unwrap()
}

#[rstest]
#[case(
None,
"Expected a valid base64 encoded string for the model definition, but got None."
)]
#[case(Some("xxx"), "Base64 decode error: Invalid padding")]
#[case(Some("{}"), "Base64 decode error: Invalid symbol 123, offset 0.")]
#[case(
Some(""),
"Serde JSON error: EOF while parsing a value at line 1 column 0"
)]
fn test_extractor_with_invalid_manifest(
#[case] value: Option<&str>,
#[case] error_message: &str,
) {
let result = PyExtractor::new(value);

match result {
Err(err) => {
assert_eq!(err.to_string(), error_message);
}
Ok(_) => panic!("Expected an error but got Ok"),
}
}

#[rstest]
#[case("SELECT * FROM customer", vec!["customer"])]
#[case("SELECT * FROM not_my_catalog.my_schema.customer", vec![])]
#[case("SELECT * FROM my_catalog.not_my_schema.customer", vec![])]
#[case("SELECT * FROM my_catalog.my_schema.customer", vec!["customer"])]
#[case("SELECT * FROM my_catalog.my_schema.customer JOIN my_catalog.my_schema.orders ON customer.custkey = orders.custkey", vec!["customer", "orders"])]
#[case("SELECT * FROM my_catalog.my_schema.customer_view", vec!["customer_view"])]
fn test_resolve_used_table_names(
extractor: PyExtractor,
#[case] sql: &str,
#[case] expected: Vec<&str>,
) {
assert_eq!(extractor.resolve_used_table_names(sql).unwrap(), expected);
}

#[rstest]
#[case(vec!["customer"], vec!["customer", "orders", "lineitem"])]
#[case(vec!["customer_view"], vec!["customer", "orders", "lineitem"])]
#[case(vec!["orders"], vec!["orders", "lineitem"])]
#[case(vec!["lineitem"], vec!["lineitem"])]
fn test_extract_manifest(
extractor: PyExtractor,
#[case] dataset: Vec<&str>,
#[case] expected_models: Vec<&str>,
) {
let dataset_strings: Vec<String> =
dataset.iter().map(|s| s.to_string()).collect();
let extracted_manifest = extractor.extract_manifest(dataset_strings).unwrap();

assert_eq!(extracted_manifest.models.len(), expected_models.len());
assert_eq!(
extracted_manifest
.models
.iter()
.map(|m| m.name.clone())
.collect::<Vec<_>>(),
expected_models
);
}
}
53 changes: 53 additions & 0 deletions wren-core-py/src/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,56 @@ impl From<&View> for PyView {
}
}
}

#[cfg(test)]
mod tests {
use crate::manifest::{to_json_base64, to_manifest, PyManifest};
use rstest::rstest;
use std::sync::Arc;
use wren_core::mdl::manifest::Model;

#[rstest]
fn test_manifest_to_json_base64() {
let py_manifest = PyManifest {
catalog: "catalog".to_string(),
schema: "schema".to_string(),
models: vec![
Arc::from(Model {
name: "model_1".to_string(),
ref_sql: "SELECT * FROM table".to_string().into(),
base_object: None,
table_reference: None,
columns: vec![],
primary_key: None,
cached: false,
refresh_time: None,
}),
Arc::from(Model {
name: "model_2".to_string(),
ref_sql: None,
base_object: None,
table_reference: "catalog.schema.table".to_string().into(),
columns: vec![],
primary_key: None,
cached: false,
refresh_time: None,
}),
],
relationships: vec![],
metrics: vec![],
views: vec![],
};
let base64_str = to_json_base64(py_manifest).unwrap();
let manifest = to_manifest(&base64_str).unwrap();
assert_eq!(manifest.catalog, "catalog");
assert_eq!(manifest.schema, "schema");
assert_eq!(manifest.models.len(), 2);
assert_eq!(manifest.models[0].name, "model_1");
assert_eq!(
manifest.models[0].ref_sql,
Some("SELECT * FROM table".to_string())
);
assert_eq!(manifest.models[1].name(), "model_2");
assert_eq!(manifest.models[1].table_reference(), "catalog.schema.table");
}
}

0 comments on commit f217fd8

Please sign in to comment.