From 4922ca3eab0a4479b6360ec9ebe86fb0cf1b37b9 Mon Sep 17 00:00:00 2001 From: Grieve Date: Wed, 27 Nov 2024 15:22:51 +0800 Subject: [PATCH 01/30] chore: update deps version --- wren-core-py/Cargo.lock | 42 ++++++++++++++++--------------------- wren-core-py/Cargo.toml | 6 +++--- wren-core-py/poetry.lock | 36 +++++++++++++++---------------- wren-core-py/pyproject.toml | 2 +- 4 files changed, 40 insertions(+), 46 deletions(-) diff --git a/wren-core-py/Cargo.lock b/wren-core-py/Cargo.lock index 23e07dfd9..5fc11d3a7 100644 --- a/wren-core-py/Cargo.lock +++ b/wren-core-py/Cargo.lock @@ -1398,12 +1398,6 @@ version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -2168,15 +2162,15 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.21.2" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" +checksum = "f54b3d09cbdd1f8c20650b28e7b09e338881482f4aa908a5f61a00c98fba2690" dependencies = [ "cfg-if", "indoc", "libc", "memoffset", - "parking_lot", + "once_cell", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -2186,9 +2180,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.21.2" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" +checksum = "3015cf985888fe66cfb63ce0e321c603706cd541b7aec7ddd35c281390af45d8" dependencies = [ "once_cell", "python3-dll-a", @@ -2197,9 +2191,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.21.2" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" +checksum = "6fca7cd8fd809b5ac4eefb89c1f98f7a7651d3739dfb341ca6980090f554c270" dependencies = [ "libc", "pyo3-build-config", @@ -2207,9 +2201,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.21.2" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" +checksum = "34e657fa5379a79151b6ff5328d9216a84f55dc93b17b08e7c3609a969b73aa0" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2219,11 +2213,11 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.21.2" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" +checksum = "295548d5ffd95fd1981d2d3cf4458831b21d60af046b729b6fd143b0ba7aee2f" dependencies = [ - "heck 0.4.1", + "heck", "proc-macro2", "pyo3-build-config", "quote", @@ -2498,7 +2492,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn", @@ -2564,7 +2558,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", @@ -2620,18 +2614,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.68" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.68" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" dependencies = [ "proc-macro2", "quote", diff --git a/wren-core-py/Cargo.toml b/wren-core-py/Cargo.toml index b0c6fe181..3abe94dee 100644 --- a/wren-core-py/Cargo.toml +++ b/wren-core-py/Cargo.toml @@ -9,14 +9,14 @@ name = "wren_core_py" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.21.2", features = [ +pyo3 = { version = "0.23.2", features = [ "extension-module", "generate-import-lib", ] } wren-core = { path = "../wren-core/core" } base64 = "0.22.1" serde_json = "1.0.117" -thiserror = "1.0" +thiserror = "2.0.3" csv = "1.3.0" serde = { version = "1.0.210", features = ["derive"] } env_logger = "0.11.5" @@ -24,4 +24,4 @@ log = "0.4.22" tokio = "1.40.0" [build-dependencies] -pyo3-build-config = "0.21.2" +pyo3-build-config = "0.23.2" diff --git a/wren-core-py/poetry.lock b/wren-core-py/poetry.lock index cb8adf32a..0b8dcd523 100644 --- a/wren-core-py/poetry.lock +++ b/wren-core-py/poetry.lock @@ -24,24 +24,24 @@ files = [ [[package]] name = "maturin" -version = "1.7.4" +version = "1.7.5" description = "Build and publish crates with pyo3, cffi and uniffi bindings as well as rust binaries as python packages" optional = false python-versions = ">=3.7" files = [ - {file = "maturin-1.7.4-py3-none-linux_armv6l.whl", hash = "sha256:eb7b7753b733ae302c08f80bca7b0c3fda1eea665c2b1922c58795f35a54c833"}, - {file = "maturin-1.7.4-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0182a9638399c8835afd39d2aeacf56908e37cba3f7abb15816b9df6774fab81"}, - {file = "maturin-1.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:41a29c5b23f3ebdfe7633637e3de256579a1b2700c04cd68c16ed46934440c5a"}, - {file = "maturin-1.7.4-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:23fae44e345a2da5cb391ae878726fb793394826e2f97febe41710bd4099460e"}, - {file = "maturin-1.7.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:8b441521c151f0dbe70ed06fb1feb29b855d787bda038ff4330ca962e5d56641"}, - {file = "maturin-1.7.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:7ccb66d0c5297cf06652c5f72cb398f447d3a332eccf5d1e73b3fe14dbc9498c"}, - {file = "maturin-1.7.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:71f668f19e719048605dbca6a1f4d0dc03b987c922ad9c4bf5be03b9b278e4c3"}, - {file = "maturin-1.7.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:c179fcb2b494f19186781b667320e43d95b3e71fcb1c98fffad9ef6bd6e276b3"}, - {file = "maturin-1.7.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd5b4b95286f2f376437340f8a4908f4761587212170263084455be8099099a7"}, - {file = "maturin-1.7.4-py3-none-win32.whl", hash = "sha256:35487a424467d1fda4567cbb02d21f09febb10eda22f5fd647b130bc0767dc61"}, - {file = "maturin-1.7.4-py3-none-win_amd64.whl", hash = "sha256:f70c1c8ec9bd4749a53c0f3ae8fdbb326ce45be4f1c5551985ee25a6d7150328"}, - {file = "maturin-1.7.4-py3-none-win_arm64.whl", hash = "sha256:f3d38a6d0c7fd7b04bec30dd470b2173cf9bd184ab6220c1acaf49df6b48faf5"}, - {file = "maturin-1.7.4.tar.gz", hash = "sha256:2b349d742a07527d236f0b4b6cab26f53ebecad0ceabfc09ec4c6a396e3176f9"}, + {file = "maturin-1.7.5-py3-none-linux_armv6l.whl", hash = "sha256:e31c4d25b56346c7872417d58cca81e52387a37469cdb79f7225bae9ad75daf9"}, + {file = "maturin-1.7.5-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:e773ade7a1383c24eaf6b665340a91278c80ab544c18687aa69e9661b289cf48"}, + {file = "maturin-1.7.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0d2d04ab5f47c1bc2b075a5d8255d9a72921e8dceebf9f9e9884f09d67f7cdd6"}, + {file = "maturin-1.7.5-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:742cd76a50104fdd832b010a205199e9b02333879f750c0cfca6c93e9472623f"}, + {file = "maturin-1.7.5-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:9044e5e2eb68bbf8ad86c4ffeab365b78b54bf342ba346dc93775531d3a4e647"}, + {file = "maturin-1.7.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:5563d61cfa2fcd7d1552022df6566300f229fa3aed62020c93a750fa3dca9a99"}, + {file = "maturin-1.7.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:c1002ca9a23c45123af752d353f6b221151a6eab2b5b65d57a79298b7d8ca6d4"}, + {file = "maturin-1.7.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:76a78284a96c24cd2d0ac3eac865315b4b0be7a443463fd5b3ebea3c6f147703"}, + {file = "maturin-1.7.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c38e585555be525ebc2602ea7189c7ef3e1c3001c94893e5bc71f934468ff124"}, + {file = "maturin-1.7.5-py3-none-win32.whl", hash = "sha256:f6c80fa7d67f58fd2cecbcdf309e2c3c5cd6f965216191de73af6cf947ef2ab8"}, + {file = "maturin-1.7.5-py3-none-win_amd64.whl", hash = "sha256:c441fe54945fe8077f17cb116834980391169cf712b63631d8380c8c3de781a1"}, + {file = "maturin-1.7.5-py3-none-win_arm64.whl", hash = "sha256:71cbcfd4a74aac3eafe99a1cd73d83af8049f572986ff4e0e5e4d8fec9c66a93"}, + {file = "maturin-1.7.5.tar.gz", hash = "sha256:f05ccbdfe96ad58d70dba9c3eed090726db8ccbaf07ec03852113ca2fec6d84b"}, ] [package.extras] @@ -50,13 +50,13 @@ zig = ["ziglang (>=0.10.0,<0.13.0)"] [[package]] name = "packaging" -version = "24.1" +version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, - {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, ] [[package]] @@ -97,4 +97,4 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "005c7c4dcf4dac1015756ca7eadaabdd0e10d9b755a30f7872c2a3939c0ab3c1" +content-hash = "8e45af48010bff32480c1523a46b1b627116738ad479336edb4cd5cdd2a5f795" diff --git a/wren-core-py/pyproject.toml b/wren-core-py/pyproject.toml index 7e998acb1..2f097d7f8 100644 --- a/wren-core-py/pyproject.toml +++ b/wren-core-py/pyproject.toml @@ -10,7 +10,7 @@ authors = ["Canner "] [tool.poetry.dependencies] python = ">=3.11,<3.12" -maturin = "1.7.4" +maturin = "1.7.5" [tool.poetry.group.dev.dependencies] pytest = "8.3.3" From 44dd1c164265a0273eed41eba074048f17972a1a Mon Sep 17 00:00:00 2001 From: Grieve Date: Wed, 27 Nov 2024 15:26:09 +0800 Subject: [PATCH 02/30] chore: remove unused import --- wren-core-py/tests/test_modeling_core.py | 29 ++++++++++++++++-------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index f94357334..4d8969182 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -1,7 +1,6 @@ import base64 import json -import wren_core from wren_core import SessionContext manifest = { @@ -25,31 +24,38 @@ manifest_str = base64.b64encode(json.dumps(manifest).encode("utf-8")).decode("utf-8") + def test_session_context(): session_context = SessionContext(manifest_str, None) sql = "SELECT * FROM my_catalog.my_schema.customer" rewritten_sql = session_context.transform_sql(sql) assert ( - rewritten_sql - == 'SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey AS c_custkey, customer.c_name AS c_name FROM main.customer) AS customer' + rewritten_sql + == "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey AS c_custkey, customer.c_name AS c_name FROM main.customer) AS customer" ) session_context = SessionContext(manifest_str, "tests/functions.csv") sql = "SELECT add_two(c_custkey) FROM my_catalog.my_schema.customer" rewritten_sql = session_context.transform_sql(sql) assert ( - rewritten_sql - == 'SELECT add_two(customer.c_custkey) FROM (SELECT customer.c_custkey FROM (SELECT customer.c_custkey AS c_custkey FROM main.customer) AS customer) AS customer' + rewritten_sql + == "SELECT add_two(customer.c_custkey) FROM (SELECT customer.c_custkey FROM (SELECT customer.c_custkey AS c_custkey FROM main.customer) AS customer) AS customer" ) + def test_read_function_list(): path = "tests/functions.csv" session_context = SessionContext(manifest_str, path) functions = session_context.get_available_functions() assert len(functions) == 271 - rewritten_sql = session_context.transform_sql("SELECT add_two(c_custkey) FROM my_catalog.my_schema.customer") - assert rewritten_sql == 'SELECT add_two(customer.c_custkey) FROM (SELECT customer.c_custkey FROM (SELECT customer.c_custkey AS c_custkey FROM main.customer) AS customer) AS customer' + rewritten_sql = session_context.transform_sql( + "SELECT add_two(c_custkey) FROM my_catalog.my_schema.customer" + ) + assert ( + rewritten_sql + == "SELECT add_two(customer.c_custkey) FROM (SELECT customer.c_custkey FROM (SELECT customer.c_custkey AS c_custkey FROM main.customer) AS customer) AS customer" + ) session_context = SessionContext(manifest_str, None) functions = session_context.get_available_functions() @@ -59,7 +65,9 @@ def test_read_function_list(): def test_get_available_functions(): session_context = SessionContext(manifest_str, "tests/functions.csv") functions = session_context.get_available_functions() - add_two = next(filter(lambda x: x["name"] == "add_two", map(lambda x: x.to_dict(), functions))) + add_two = next( + filter(lambda x: x["name"] == "add_two", map(lambda x: x.to_dict(), functions)) + ) assert add_two["name"] == "add_two" assert add_two["function_type"] == "scalar" assert add_two["description"] == "Adds two numbers together." @@ -67,9 +75,10 @@ def test_get_available_functions(): assert add_two["param_names"] == "f1,f2" assert add_two["param_types"] == "int,int" - max_if = next(filter(lambda x: x["name"] == "max_if", map(lambda x: x.to_dict(), functions))) + max_if = next( + filter(lambda x: x["name"] == "max_if", map(lambda x: x.to_dict(), functions)) + ) assert max_if["name"] == "max_if" assert max_if["function_type"] == "window" assert max_if["param_names"] is None assert max_if["param_types"] is None - From 8d63cbdbb2beed5767c884e99e0abfe42bf814fa Mon Sep 17 00:00:00 2001 From: Grieve Date: Wed, 27 Nov 2024 16:39:32 +0800 Subject: [PATCH 03/30] chore: adjust for pyo3 0.23.2 --- wren-core-py/src/context.rs | 1 + wren-core-py/src/remote_functions.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index 32ce7bd67..adeb3e361 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -67,6 +67,7 @@ impl PySessionContext { /// if `mdl_base64` is provided, the session context will be created with the given MDL. Otherwise, an empty MDL will be created. /// if `remote_functions_path` is provided, the session context will be created with the remote functions defined in the CSV file. #[new] + #[pyo3(signature = (mdl_base64=None, remote_functions_path=None))] pub fn new( mdl_base64: Option<&str>, remote_functions_path: Option<&str>, diff --git a/wren-core-py/src/remote_functions.rs b/wren-core-py/src/remote_functions.rs index 2d96f0fe0..f7980fca6 100644 --- a/wren-core-py/src/remote_functions.rs +++ b/wren-core-py/src/remote_functions.rs @@ -38,7 +38,7 @@ pub struct PyRemoteFunction { #[pymethods] impl PyRemoteFunction { pub fn to_dict(&self, py: Python) -> PyObject { - let dict = PyDict::new_bound(py); + let dict = PyDict::new(py); dict.set_item("function_type", self.function_type.clone()) .unwrap(); dict.set_item("name", self.name.clone()).unwrap(); From 9065118959df8c6f680e3008d5e5b5600e76200d Mon Sep 17 00:00:00 2001 From: Grieve Date: Fri, 29 Nov 2024 15:40:31 +0800 Subject: [PATCH 04/30] feat(extract-mdl): implement new function to extract used models --- wren-core-py/src/context.rs | 136 ++++++++++++ wren-core-py/src/lib.rs | 2 + wren-core-py/src/manifest.rs | 269 +++++++++++++++++++++++ wren-core-py/tests/test_modeling_core.py | 84 +++++++ wren-core/core/src/mdl/manifest.rs | 4 +- wren-core/core/src/mdl/mod.rs | 14 +- 6 files changed, 506 insertions(+), 3 deletions(-) create mode 100644 wren-core-py/src/manifest.rs diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index adeb3e361..728377f2a 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -16,6 +16,7 @@ // under the License. use crate::errors::CoreError; +use crate::manifest::PyManifest; use crate::remote_functions::PyRemoteFunction; use base64::prelude::BASE64_STANDARD; use base64::Engine; @@ -195,6 +196,39 @@ impl PySessionContext { }); Ok(builder.values().cloned().collect()) } + + /// parse the given SQL and return the list of used table name. + pub fn resolve_used_table_names(&self, sql: &str) -> Result, CoreError> { + let mdl = self.mdl.wren_mdl(); + self.ctx + .state() + .sql_to_statement(sql, "generic") + .map_err(CoreError::from) + .and_then(|stmt| { + self.ctx + .state() + .resolve_table_references(&stmt) + .map_err(CoreError::from) + }) + .map(|tables| { + tables + .iter() + .filter(|t| { + t.catalog().unwrap_or_default() == mdl.catalog() + && t.schema().unwrap_or_default() == mdl.schema() + }) + .map(|t| t.table().to_string()) + .collect() + }) + } + + /// Given a used dataset list, extract manifest by removing unused datasets. + /// If a model is related to another dataset, both datasets will be kept. + /// The relationship between of them will be kept as well. + /// A dataset could be model, view. + pub fn extract_manifest(&self, used_datasets: Vec) -> PyResult { + Ok(extractor::extract_manifest(self, &used_datasets)) + } } impl PySessionContext { @@ -244,3 +278,105 @@ impl PySessionContext { } } } + +mod extractor { + use crate::context::PySessionContext; + use crate::manifest::PyManifest; + use std::collections::HashSet; + use std::sync::Arc; + use wren_core::mdl::manifest::{Model, Relationship, View}; + use wren_core::mdl::WrenMDL; + + pub fn extract_manifest( + ctx: &PySessionContext, + used_datasets: &[String], + ) -> PyManifest { + let mdl = Arc::clone(&ctx.mdl).wren_mdl(); + let used_models = extract_models(&mdl, used_datasets); + let (used_views, models_of_views) = extract_views(&ctx, &mdl, used_datasets); + let used_relationships = extract_relationships(&mdl, used_datasets); + PyManifest { + catalog: mdl.catalog().to_string(), + schema: mdl.schema().to_string(), + models: [used_models, models_of_views].concat(), + relationships: used_relationships, + metrics: mdl.metrics().to_vec(), + views: used_views, + } + } + + fn extract_models(mdl: &Arc, used_datasets: &[String]) -> Vec> { + let mut used_set: HashSet = used_datasets.iter().cloned().collect(); + let mut stack: Vec = used_datasets.to_vec(); + while let Some(dataset_name) = stack.pop() { + if let Some(model) = mdl.get_model(&dataset_name) { + model + .columns + .iter() + .filter_map(|col| { + col.relationship + .as_ref() + .and_then(|rel_name| mdl.get_relationship(rel_name)) + }) + .flat_map(|rel| rel.models.clone()) + .filter(|related| used_set.insert(related.clone())) + .for_each(|related| stack.push(related)); + } + } + mdl.models() + .iter() + .filter(|model| used_set.contains(model.name())) + .cloned() + .collect() + } + + fn extract_views( + ctx: &PySessionContext, + mdl: &Arc, + used_datasets: &[String], + ) -> (Vec>, Vec>) { + let used_set: HashSet<&str> = used_datasets.iter().map(String::as_str).collect(); + let stack: Vec<&str> = used_datasets.iter().map(String::as_str).collect(); + let models = stack + .iter() + .filter_map(|&dataset_name| { + mdl.get_view(dataset_name).and_then(|view| { + ctx.resolve_used_table_names(&view.statement) + .ok() + .map(|used_tables| extract_models(mdl, &used_tables)) + }) + }) + .flatten() + .collect::>(); + let views = mdl + .views() + .iter() + .filter(|view| used_set.contains(view.name())) + .cloned() + .collect(); + + (views, models) + } + + fn extract_relationships( + mdl: &Arc, + used_datasets: &[String], + ) -> Vec> { + let mut used_set: HashSet = used_datasets.iter().cloned().collect(); + let mut stack: Vec = used_datasets.to_vec(); + while let Some(dataset_name) = stack.pop() { + if let Some(relationship) = mdl.get_relationship(&dataset_name) { + for model in &relationship.models { + if used_set.insert(model.clone()) { + stack.push(model.clone()); + } + } + } + } + mdl.relationships() + .iter() + .filter(|rel| rel.models.iter().any(|model| used_set.contains(model))) + .cloned() + .collect() + } +} diff --git a/wren-core-py/src/lib.rs b/wren-core-py/src/lib.rs index b7f7a207a..30834798c 100644 --- a/wren-core-py/src/lib.rs +++ b/wren-core-py/src/lib.rs @@ -4,6 +4,7 @@ use remote_functions::PyRemoteFunction; pub mod context; mod errors; +mod manifest; pub mod remote_functions; #[pymodule] @@ -12,5 +13,6 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> { env_logger::init(); m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs new file mode 100644 index 000000000..13771c01d --- /dev/null +++ b/wren-core-py/src/manifest.rs @@ -0,0 +1,269 @@ +use pyo3::{pyclass, pymethods, PyResult}; +use serde::{Deserialize, Serialize}; +use std::iter::Iterator; +use std::sync::Arc; +use wren_core::mdl::manifest::{ + Column, JoinType, Metric, Model, Relationship, TimeGrain, TimeUnit, View, +}; + +#[pyclass(name = "Manifest")] +#[derive(Serialize, Deserialize)] +pub struct PyManifest { + pub catalog: String, + pub schema: String, + pub models: Vec>, + pub relationships: Vec>, + pub metrics: Vec>, + pub views: Vec>, +} + +#[pymethods] +impl PyManifest { + #[getter] + fn catalog(&self) -> PyResult { + Ok(self.catalog.clone()) + } + + #[getter] + fn schema(&self) -> PyResult { + Ok(self.schema.clone()) + } + + #[getter] + fn models(&self) -> PyResult> { + Ok(self + .models + .iter() + .map(|m| PyModel::from(m.as_ref())) + .collect()) + } + + #[getter] + fn relationships(&self) -> PyResult> { + Ok(self + .relationships + .iter() + .map(|r| PyRelationship::from(r.as_ref())) + .collect()) + } + + #[getter] + fn metrics(&self) -> PyResult> { + Ok(self + .metrics + .iter() + .map(|m| PyMetric::from(m.as_ref())) + .collect()) + } + + #[getter] + fn views(&self) -> PyResult> { + Ok(self + .views + .iter() + .map(|v| PyView::from(v.as_ref())) + .collect()) + } +} + +#[pyclass(name = "Model")] +#[derive(Serialize, Deserialize)] +pub struct PyModel { + #[pyo3(get)] + pub name: String, + #[pyo3(get)] + pub ref_sql: Option, + #[pyo3(get)] + pub base_object: Option, + #[pyo3(get)] + pub table_reference: Option, + pub columns: Vec>, + #[pyo3(get)] + pub primary_key: Option, + #[pyo3(get)] + pub cached: bool, + #[pyo3(get)] + pub refresh_time: Option, +} + +#[pymethods] +impl PyModel { + #[getter] + fn columns(&self) -> PyResult> { + Ok(self + .columns + .iter() + .map(|c| PyColumn::from(c.as_ref())) + .collect()) + } +} + +impl From<&Model> for PyModel { + fn from(model: &Model) -> Self { + Self { + name: model.name.clone(), + ref_sql: model.ref_sql.clone(), + base_object: model.base_object.clone(), + table_reference: model.table_reference.clone(), + columns: model.columns.clone(), + primary_key: model.primary_key.clone(), + cached: model.cached, + refresh_time: model.refresh_time.clone(), + } + } +} + +#[pyclass(name = "Column")] +#[derive(Serialize, Deserialize)] +pub struct PyColumn { + #[pyo3(get)] + pub name: String, + #[pyo3(get)] + pub r#type: String, + #[pyo3(get)] + pub relationship: Option, + #[pyo3(get)] + pub is_calculated: bool, + #[pyo3(get)] + pub not_null: bool, + #[pyo3(get)] + pub expression: Option, + #[pyo3(get)] + pub is_hidden: bool, +} + +impl From<&Column> for PyColumn { + fn from(column: &Column) -> Self { + Self { + name: column.name.clone(), + r#type: column.r#type.clone(), + relationship: column.relationship.clone(), + is_calculated: column.is_calculated, + not_null: column.not_null, + expression: column.expression.clone(), + is_hidden: column.is_hidden, + } + } +} + +#[pyclass(name = "Relationship")] +#[derive(Serialize, Deserialize)] +pub struct PyRelationship { + #[pyo3(get)] + pub name: String, + #[pyo3(get)] + pub models: Vec, + pub join_type: JoinType, + #[pyo3(get)] + pub condition: String, +} + +#[pymethods] +impl PyRelationship { + #[getter] + fn join_type(&self) -> PyResult { + Ok(PyJoinType::from(&self.join_type)) + } +} + +impl From<&Relationship> for PyRelationship { + fn from(relationship: &Relationship) -> Self { + Self { + name: relationship.name.clone(), + models: relationship.models.clone(), + join_type: relationship.join_type, + condition: relationship.condition.clone(), + } + } +} + +#[pyclass(name = "JoinType", eq)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] +pub enum PyJoinType { + #[serde(alias = "one_to_one")] + OneToOne, + #[serde(alias = "one_to_many")] + OneToMany, + #[serde(alias = "many_to_one")] + ManyToOne, + #[serde(alias = "many_to_many")] + ManyToMany, +} + +impl From<&JoinType> for PyJoinType { + fn from(join_type: &JoinType) -> Self { + match join_type { + JoinType::OneToOne => PyJoinType::OneToOne, + JoinType::OneToMany => PyJoinType::OneToMany, + JoinType::ManyToOne => PyJoinType::ManyToOne, + JoinType::ManyToMany => PyJoinType::ManyToMany, + } + } +} + +#[pyclass(name = "Metric")] +#[derive(Serialize, Deserialize)] +pub struct PyMetric { + #[pyo3(get)] + pub name: String, + #[pyo3(get)] + pub base_object: String, + pub dimension: Vec>, + pub measure: Vec>, + pub time_grain: Vec, + #[pyo3(get)] + pub cached: bool, + #[pyo3(get)] + pub refresh_time: Option, +} + +impl From<&Metric> for PyMetric { + fn from(metric: &Metric) -> Self { + Self { + name: metric.name.clone(), + base_object: metric.base_object.clone(), + dimension: metric.dimension.clone(), + measure: metric.measure.clone(), + time_grain: metric.time_grain.clone(), + cached: metric.cached, + refresh_time: metric.refresh_time.clone(), + } + } +} + +#[pyclass(name = "TimeGrain")] +#[derive(Serialize, Deserialize)] +pub struct PyTimeGrain { + pub name: String, + pub ref_column: String, + pub date_parts: Vec, +} + +#[pyclass(name = "TimeUnit", eq)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] +pub enum PyTimeUnit { + Year, + Month, + Day, + Hour, + Minute, + Second, +} + +#[pyclass(name = "View")] +#[derive(Serialize, Deserialize)] +pub struct PyView { + #[pyo3(get)] + pub name: String, + #[pyo3(get)] + pub statement: String, +} + +impl From<&View> for PyView { + fn from(view: &View) -> Self { + Self { + name: view.name.clone(), + statement: view.statement.clone(), + } + } +} diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 4d8969182..54aed8c72 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -1,6 +1,7 @@ import base64 import json +import pytest from wren_core import SessionContext manifest = { @@ -16,9 +17,61 @@ "columns": [ {"name": "c_custkey", "type": "integer"}, {"name": "c_name", "type": "varchar"}, + {"name": "orders", "type": "orders", "relationship": "orders_customer"}, ], "primaryKey": "c_custkey", }, + { + "name": "orders", + "tableReference": { + "schema": "main", + "table": "orders", + }, + "columns": [ + {"name": "o_orderkey", "type": "integer"}, + {"name": "o_custkey", "type": "integer"}, + {"name": "o_orderdate", "type": "date"}, + { + "name": "lineitems", + "type": "Lineitem", + "relationship": "orders_lineitem", + }, + ], + "primaryKey": "o_orderkey", + }, + { + "name": "lineitem", + "tableReference": { + "schema": "main", + "table": "lineitem", + }, + "columns": [ + {"name": "l_orderkey", "type": "integer"}, + {"name": "l_quantity", "type": "decimal"}, + {"name": "l_extendedprice", "type": "decimal"}, + ], + "primaryKey": "l_orderkey", + }, + ], + "relationships": [ + { + "name": "orders_customer", + "models": ["orders", "customer"], + "joinType": "MANY_TO_ONE", + "condition": "orders.custkey = customer.custkey", + }, + { + "name": "orders_lineitem", + "models": ["orders", "lineitem"], + "joinType": "ONE_TO_MANY", + "condition": "orders.orderkey = lineitem.orderkey", + }, + ], + "views": [ + { + "name": "customer_view", + "statement": "SELECT * FROM my_catalog.my_schema.customer", + }, ], } @@ -82,3 +135,34 @@ def test_get_available_functions(): assert max_if["function_type"] == "window" assert max_if["param_names"] is None assert max_if["param_types"] is None + + +@pytest.mark.parametrize( + ("sql", "expected"), + [ + ("SELECT * FROM my_catalog.my_schema.customer", ["customer"]), + ( + "SELECT * FROM my_catalog.my_schema.customer JOIN my_catalog.my_schema.orders ON customer.custkey = orders.custkey", + ["customer", "orders"], + ), + ("SELECT * FROM my_catalog.my_schema.customer_view", ["customer_view"]), + ], +) +def test_resolve_used_table_names(sql, expected): + tables = SessionContext(manifest_str, None).resolve_used_table_names(sql) + assert tables == expected + + +@pytest.mark.parametrize( + ("dataset", "expected_models"), + [ + (["customer"], ["customer", "orders", "lineitem"]), + (["customer_view"], ["customer", "orders", "lineitem"]), + (["orders"], ["orders", "lineitem"]), + (["lineitem"], ["lineitem"]), + ], +) +def test_extract_manifest(dataset, expected_models): + extracted_manifest = SessionContext(manifest_str, None).extract_manifest(dataset) + assert len(extracted_manifest.models) == len(expected_models) + assert [m.name for m in extracted_manifest.models] == expected_models diff --git a/wren-core/core/src/mdl/manifest.rs b/wren-core/core/src/mdl/manifest.rs index 4ccd89006..38d6cb221 100644 --- a/wren-core/core/src/mdl/manifest.rs +++ b/wren-core/core/src/mdl/manifest.rs @@ -229,7 +229,7 @@ impl Metric { } } -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] #[serde(rename_all = "camelCase")] pub struct TimeGrain { pub name: String, @@ -237,7 +237,7 @@ pub struct TimeGrain { pub date_parts: Vec, } -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] pub enum TimeUnit { Year, Month, diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 10ffd41c6..94cdd6f31 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -6,7 +6,7 @@ use crate::mdl::function::{ ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType, RemoteFunction, }; -use crate::mdl::manifest::{Column, Manifest, Model, View}; +use crate::mdl::manifest::{Column, Manifest, Metric, Model, View}; use crate::DataFusionError; use datafusion::arrow::datatypes::Field; use datafusion::common::internal_datafusion_err; @@ -269,6 +269,18 @@ impl WrenMDL { &self.manifest.models } + pub fn views(&self) -> &[Arc] { + &self.manifest.views + } + + pub fn relationships(&self) -> &[Arc] { + &self.manifest.relationships + } + + pub fn metrics(&self) -> &[Arc] { + &self.manifest.metrics + } + pub fn get_model(&self, name: &str) -> Option> { self.manifest .models From ed3a4d7ac662445aa2e7d5d32336d72cff033d19 Mon Sep 17 00:00:00 2001 From: Grieve Date: Fri, 29 Nov 2024 15:41:16 +0800 Subject: [PATCH 05/30] chore: adjust pyo3 setting --- wren-core-py/Cargo.lock | 10 ---------- wren-core-py/Cargo.toml | 9 +++++---- wren-core-py/justfile | 4 ++-- wren-core-py/pyproject.toml | 1 + 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/wren-core-py/Cargo.lock b/wren-core-py/Cargo.lock index 5fc11d3a7..962c2c12f 100644 --- a/wren-core-py/Cargo.lock +++ b/wren-core-py/Cargo.lock @@ -2185,7 +2185,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3015cf985888fe66cfb63ce0e321c603706cd541b7aec7ddd35c281390af45d8" dependencies = [ "once_cell", - "python3-dll-a", "target-lexicon", ] @@ -2224,15 +2223,6 @@ dependencies = [ "syn", ] -[[package]] -name = "python3-dll-a" -version = "0.2.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd0b78171a90d808b319acfad166c4790d9e9759bbc14ac8273fe133673dd41b" -dependencies = [ - "cc", -] - [[package]] name = "quote" version = "1.0.37" diff --git a/wren-core-py/Cargo.toml b/wren-core-py/Cargo.toml index 3abe94dee..1e77fc2fa 100644 --- a/wren-core-py/Cargo.toml +++ b/wren-core-py/Cargo.toml @@ -9,10 +9,7 @@ name = "wren_core_py" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.23.2", features = [ - "extension-module", - "generate-import-lib", -] } +pyo3 = { version = "0.23.2", features = ["extension-module"] } wren-core = { path = "../wren-core/core" } base64 = "0.22.1" serde_json = "1.0.117" @@ -25,3 +22,7 @@ tokio = "1.40.0" [build-dependencies] pyo3-build-config = "0.23.2" + +[features] +extension-module = ["pyo3/extension-module"] +default = ["extension-module"] diff --git a/wren-core-py/justfile b/wren-core-py/justfile index d10453b40..509cc05bd 100644 --- a/wren-core-py/justfile +++ b/wren-core-py/justfile @@ -5,11 +5,11 @@ install: poetry install --no-root build *args: - poetry run maturin build {{args}} + poetry run maturin build {{ args }} develop: poetry run maturin develop test: develop - cargo test + cargo test --no-default-features poetry run pytest diff --git a/wren-core-py/pyproject.toml b/wren-core-py/pyproject.toml index 2f097d7f8..c8efa218b 100644 --- a/wren-core-py/pyproject.toml +++ b/wren-core-py/pyproject.toml @@ -20,6 +20,7 @@ module-name = "wren_core" include = [{ path = "Cargo.lock", format = "sdist" }] exclude = ["tests/**", "target/**"] locked = true +features = ["pyo3/extension-module"] [build-system] requires = ["maturin>=1.0,<2.0"] From c48eecac062f323ac9586176a39f64e82ffa0404 Mon Sep 17 00:00:00 2001 From: Grieve Date: Fri, 29 Nov 2024 18:07:47 +0800 Subject: [PATCH 06/30] chore: pull out resolve_used_table_names and extract_manifest --- wren-core-py/src/context.rs | 147 +---------------------- wren-core-py/src/extractor.rs | 143 ++++++++++++++++++++++ wren-core-py/src/lib.rs | 4 + wren-core-py/src/manifest.rs | 23 +++- wren-core-py/tests/test_modeling_core.py | 22 +++- wren-core/core/src/mdl/manifest.rs | 2 +- 6 files changed, 189 insertions(+), 152 deletions(-) create mode 100644 wren-core-py/src/extractor.rs diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index 728377f2a..f17b77a81 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -16,10 +16,8 @@ // under the License. use crate::errors::CoreError; -use crate::manifest::PyManifest; +use crate::manifest::to_manifest; use crate::remote_functions::PyRemoteFunction; -use base64::prelude::BASE64_STANDARD; -use base64::Engine; use log::debug; use pyo3::{pyclass, pymethods, PyErr, PyResult}; use std::collections::hash_map::Entry; @@ -32,7 +30,6 @@ use wren_core::mdl::function::{ ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType, RemoteFunction, }; -use wren_core::mdl::manifest::Manifest; use wren_core::{mdl, AggregateUDF, AnalyzedWrenMDL, ScalarUDF, WindowUDF}; /// The Python wrapper for the Wren Core session context. @@ -90,12 +87,7 @@ impl PySessionContext { }); }; - let mdl_json_bytes = BASE64_STANDARD - .decode(mdl_base64) - .map_err(CoreError::from)?; - let mdl_json = String::from_utf8(mdl_json_bytes).map_err(CoreError::from)?; - let manifest = - serde_json::from_str::(&mdl_json).map_err(CoreError::from)?; + let manifest = to_manifest(mdl_base64)?; let Ok(analyzed_mdl) = AnalyzedWrenMDL::analyze(manifest) else { return Err(CoreError::new("Failed to analyze manifest").into()); @@ -196,39 +188,6 @@ impl PySessionContext { }); Ok(builder.values().cloned().collect()) } - - /// parse the given SQL and return the list of used table name. - pub fn resolve_used_table_names(&self, sql: &str) -> Result, CoreError> { - let mdl = self.mdl.wren_mdl(); - self.ctx - .state() - .sql_to_statement(sql, "generic") - .map_err(CoreError::from) - .and_then(|stmt| { - self.ctx - .state() - .resolve_table_references(&stmt) - .map_err(CoreError::from) - }) - .map(|tables| { - tables - .iter() - .filter(|t| { - t.catalog().unwrap_or_default() == mdl.catalog() - && t.schema().unwrap_or_default() == mdl.schema() - }) - .map(|t| t.table().to_string()) - .collect() - }) - } - - /// Given a used dataset list, extract manifest by removing unused datasets. - /// If a model is related to another dataset, both datasets will be kept. - /// The relationship between of them will be kept as well. - /// A dataset could be model, view. - pub fn extract_manifest(&self, used_datasets: Vec) -> PyResult { - Ok(extractor::extract_manifest(self, &used_datasets)) - } } impl PySessionContext { @@ -278,105 +237,3 @@ impl PySessionContext { } } } - -mod extractor { - use crate::context::PySessionContext; - use crate::manifest::PyManifest; - use std::collections::HashSet; - use std::sync::Arc; - use wren_core::mdl::manifest::{Model, Relationship, View}; - use wren_core::mdl::WrenMDL; - - pub fn extract_manifest( - ctx: &PySessionContext, - used_datasets: &[String], - ) -> PyManifest { - let mdl = Arc::clone(&ctx.mdl).wren_mdl(); - let used_models = extract_models(&mdl, used_datasets); - let (used_views, models_of_views) = extract_views(&ctx, &mdl, used_datasets); - let used_relationships = extract_relationships(&mdl, used_datasets); - PyManifest { - catalog: mdl.catalog().to_string(), - schema: mdl.schema().to_string(), - models: [used_models, models_of_views].concat(), - relationships: used_relationships, - metrics: mdl.metrics().to_vec(), - views: used_views, - } - } - - fn extract_models(mdl: &Arc, used_datasets: &[String]) -> Vec> { - let mut used_set: HashSet = used_datasets.iter().cloned().collect(); - let mut stack: Vec = used_datasets.to_vec(); - while let Some(dataset_name) = stack.pop() { - if let Some(model) = mdl.get_model(&dataset_name) { - model - .columns - .iter() - .filter_map(|col| { - col.relationship - .as_ref() - .and_then(|rel_name| mdl.get_relationship(rel_name)) - }) - .flat_map(|rel| rel.models.clone()) - .filter(|related| used_set.insert(related.clone())) - .for_each(|related| stack.push(related)); - } - } - mdl.models() - .iter() - .filter(|model| used_set.contains(model.name())) - .cloned() - .collect() - } - - fn extract_views( - ctx: &PySessionContext, - mdl: &Arc, - used_datasets: &[String], - ) -> (Vec>, Vec>) { - let used_set: HashSet<&str> = used_datasets.iter().map(String::as_str).collect(); - let stack: Vec<&str> = used_datasets.iter().map(String::as_str).collect(); - let models = stack - .iter() - .filter_map(|&dataset_name| { - mdl.get_view(dataset_name).and_then(|view| { - ctx.resolve_used_table_names(&view.statement) - .ok() - .map(|used_tables| extract_models(mdl, &used_tables)) - }) - }) - .flatten() - .collect::>(); - let views = mdl - .views() - .iter() - .filter(|view| used_set.contains(view.name())) - .cloned() - .collect(); - - (views, models) - } - - fn extract_relationships( - mdl: &Arc, - used_datasets: &[String], - ) -> Vec> { - let mut used_set: HashSet = used_datasets.iter().cloned().collect(); - let mut stack: Vec = used_datasets.to_vec(); - while let Some(dataset_name) = stack.pop() { - if let Some(relationship) = mdl.get_relationship(&dataset_name) { - for model in &relationship.models { - if used_set.insert(model.clone()) { - stack.push(model.clone()); - } - } - } - } - mdl.relationships() - .iter() - .filter(|rel| rel.models.iter().any(|model| used_set.contains(model))) - .cloned() - .collect() - } -} diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs new file mode 100644 index 000000000..6c6d59ccf --- /dev/null +++ b/wren-core-py/src/extractor.rs @@ -0,0 +1,143 @@ +use crate::errors::CoreError; +use crate::manifest::{to_manifest, PyManifest}; +use pyo3::pyfunction; +use std::collections::HashSet; +use std::sync::Arc; +use wren_core::mdl::manifest::{Manifest, Model, Relationship, View}; +use wren_core::mdl::WrenMDL; + +/// parse the given SQL and return the list of used table name. +#[pyfunction] +#[pyo3(name = "resolve_used_table_names", signature = (mdl_base64, sql), text_signature = "(mdl_base64: str, sql: str)")] +pub fn py_resolve_used_table_names( + mdl_base64: &str, + sql: &str, +) -> Result, CoreError> { + let manifest = to_manifest(mdl_base64)?; + resolve_used_table_names(manifest, sql) +} + +fn resolve_used_table_names( + manifest: Manifest, + sql: &str, +) -> Result, CoreError> { + let mdl = WrenMDL::new_ref(manifest); + let ctx_state = wren_core::SessionContext::new().state(); + ctx_state + .sql_to_statement(sql, "generic") + .map_err(CoreError::from) + .and_then(|stmt| { + ctx_state + .resolve_table_references(&stmt) + .map_err(CoreError::from) + }) + .map(|tables| { + tables + .iter() + .filter(|t| { + t.catalog().unwrap_or_default() == mdl.catalog() + && t.schema().unwrap_or_default() == mdl.schema() + }) + .map(|t| t.table().to_string()) + .collect() + }) +} + +/// Given a used dataset list, extract manifest by removing unused datasets. +/// If a model is related to another dataset, both datasets will be kept. +/// The relationship between of them will be kept as well. +/// A dataset could be model, view. +#[pyfunction] +#[pyo3(signature = (mdl_base64, used_datasets), text_signature = "(mdl_base64: str, used_datasets: list[str])")] +pub fn extract_manifest( + mdl_base64: &str, + used_datasets: Vec, +) -> Result { + let manifest = to_manifest(mdl_base64)?; + let mdl = WrenMDL::new_ref(manifest); + let used_models = extract_models(&mdl, &used_datasets); + let (used_views, models_of_views) = extract_views(&mdl, &used_datasets); + let used_relationships = extract_relationships(&mdl, &used_datasets); + Ok(PyManifest { + catalog: mdl.catalog().to_string(), + schema: mdl.schema().to_string(), + models: [used_models, models_of_views].concat(), + relationships: used_relationships, + metrics: mdl.metrics().to_vec(), + views: used_views, + }) +} + +fn extract_models(mdl: &Arc, used_datasets: &[String]) -> Vec> { + let mut used_set: HashSet = used_datasets.iter().cloned().collect(); + let mut stack: Vec = used_datasets.to_vec(); + while let Some(dataset_name) = stack.pop() { + if let Some(model) = mdl.get_model(&dataset_name) { + model + .columns + .iter() + .filter_map(|col| { + col.relationship + .as_ref() + .and_then(|rel_name| mdl.get_relationship(rel_name)) + }) + .flat_map(|rel| rel.models.clone()) + .filter(|related| used_set.insert(related.clone())) + .for_each(|related| stack.push(related)); + } + } + mdl.models() + .iter() + .filter(|model| used_set.contains(model.name())) + .cloned() + .collect() +} + +fn extract_views( + mdl: &Arc, + used_datasets: &[String], +) -> (Vec>, Vec>) { + let used_set: HashSet<&str> = used_datasets.iter().map(String::as_str).collect(); + let stack: Vec<&str> = used_datasets.iter().map(String::as_str).collect(); + let models = stack + .iter() + .filter_map(|&dataset_name| { + mdl.get_view(dataset_name).and_then(|view| { + resolve_used_table_names(mdl.manifest.clone(), view.statement.as_str()) + .ok() + .map(|used_tables| extract_models(mdl, &used_tables)) + }) + }) + .flatten() + .collect::>(); + let views = mdl + .views() + .iter() + .filter(|view| used_set.contains(view.name())) + .cloned() + .collect(); + + (views, models) +} + +fn extract_relationships( + mdl: &Arc, + used_datasets: &[String], +) -> Vec> { + let mut used_set: HashSet = used_datasets.iter().cloned().collect(); + let mut stack: Vec = used_datasets.to_vec(); + while let Some(dataset_name) = stack.pop() { + if let Some(relationship) = mdl.get_relationship(&dataset_name) { + for model in &relationship.models { + if used_set.insert(model.clone()) { + stack.push(model.clone()); + } + } + } + } + mdl.relationships() + .iter() + .filter(|rel| rel.models.iter().any(|model| used_set.contains(model))) + .cloned() + .collect() +} diff --git a/wren-core-py/src/lib.rs b/wren-core-py/src/lib.rs index 30834798c..05c02c2e1 100644 --- a/wren-core-py/src/lib.rs +++ b/wren-core-py/src/lib.rs @@ -4,6 +4,7 @@ use remote_functions::PyRemoteFunction; pub mod context; mod errors; +mod extractor; mod manifest; pub mod remote_functions; @@ -14,5 +15,8 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_function(wrap_pyfunction!(extractor::py_resolve_used_table_names, m)?)?; + m.add_function(wrap_pyfunction!(extractor::extract_manifest, m)?)?; + m.add_function(wrap_pyfunction!(manifest::to_json_base64, m)?)?; Ok(()) } diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs index 13771c01d..18625c21c 100644 --- a/wren-core-py/src/manifest.rs +++ b/wren-core-py/src/manifest.rs @@ -1,13 +1,30 @@ -use pyo3::{pyclass, pymethods, PyResult}; +use crate::errors::CoreError; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use pyo3::{pyclass, pyfunction, pymethods, PyResult}; use serde::{Deserialize, Serialize}; use std::iter::Iterator; use std::sync::Arc; use wren_core::mdl::manifest::{ - Column, JoinType, Metric, Model, Relationship, TimeGrain, TimeUnit, View, + Column, JoinType, Manifest, Metric, Model, Relationship, TimeGrain, TimeUnit, View, }; +#[pyfunction] +pub fn to_json_base64(mdl: PyManifest) -> Result { + let mdl_json = serde_json::to_string(&mdl)?; + let mdl_base64 = BASE64_STANDARD.encode(mdl_json.as_bytes()); + Ok(mdl_base64) +} + +pub fn to_manifest(mdl_base64: &str) -> Result { + let decoded_bytes = BASE64_STANDARD.decode(mdl_base64)?; + let mdl_json = String::from_utf8(decoded_bytes)?; + let manifest = serde_json::from_str::(&mdl_json)?; + Ok(manifest) +} + #[pyclass(name = "Manifest")] -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub struct PyManifest { pub catalog: String, pub schema: String, diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 54aed8c72..047b4892a 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -1,8 +1,14 @@ import base64 import json +from contextlib import nullcontext as does_not_raise import pytest -from wren_core import SessionContext +from wren_core import ( + SessionContext, + resolve_used_table_names, + extract_manifest, + to_json_base64, +) manifest = { "catalog": "my_catalog", @@ -149,7 +155,7 @@ def test_get_available_functions(): ], ) def test_resolve_used_table_names(sql, expected): - tables = SessionContext(manifest_str, None).resolve_used_table_names(sql) + tables = resolve_used_table_names(manifest_str, sql) assert tables == expected @@ -163,6 +169,16 @@ def test_resolve_used_table_names(sql, expected): ], ) def test_extract_manifest(dataset, expected_models): - extracted_manifest = SessionContext(manifest_str, None).extract_manifest(dataset) + extracted_manifest = extract_manifest(manifest_str, dataset) assert len(extracted_manifest.models) == len(expected_models) assert [m.name for m in extracted_manifest.models] == expected_models + + +def test_to_json_base64(): + extracted_manifest = extract_manifest(manifest_str, ["customer"]) + base64_str = to_json_base64(extracted_manifest) + with does_not_raise(): + json_str = base64.b64decode(base64_str) + decoded_manifest = json.loads(json_str) + assert decoded_manifest["catalog"] == "my_catalog" + assert len(decoded_manifest["models"]) == 3 diff --git a/wren-core/core/src/mdl/manifest.rs b/wren-core/core/src/mdl/manifest.rs index 38d6cb221..1589f4a4b 100644 --- a/wren-core/core/src/mdl/manifest.rs +++ b/wren-core/core/src/mdl/manifest.rs @@ -6,7 +6,7 @@ use serde_with::serde_as; use serde_with::NoneAsEmptyString; /// This is the main struct that holds all the information about the manifest -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] pub struct Manifest { pub catalog: String, pub schema: String, From 49617062725678ac23293549c2846d6680b28290 Mon Sep 17 00:00:00 2001 From: Grieve Date: Fri, 29 Nov 2024 18:08:44 +0800 Subject: [PATCH 07/30] feat: merge new rust func into ibis-server --- ibis-server/README.md | 2 +- ibis-server/app/mdl/context.py | 8 -------- ibis-server/app/mdl/core.py | 22 ++++++++++++++++++++++ ibis-server/app/mdl/rewriter.py | 18 +++++++++++++++--- ibis-server/app/routers/v3/connector.py | 5 ++--- ibis-server/docs/development.md | 4 ++-- ibis-server/tests/mdl/test_context.py | 2 +- 7 files changed, 43 insertions(+), 18 deletions(-) delete mode 100644 ibis-server/app/mdl/context.py create mode 100644 ibis-server/app/mdl/core.py diff --git a/ibis-server/README.md b/ibis-server/README.md index 1d4c6d8b3..ea7afc581 100644 --- a/ibis-server/README.md +++ b/ibis-server/README.md @@ -71,7 +71,7 @@ vim .env ``` Install the dependencies ```bash -just install +just install && just install-core ``` Run the server ```bash diff --git a/ibis-server/app/mdl/context.py b/ibis-server/app/mdl/context.py deleted file mode 100644 index 100765367..000000000 --- a/ibis-server/app/mdl/context.py +++ /dev/null @@ -1,8 +0,0 @@ -from functools import cache - - -@cache -def get_session_context(manifest_str: str, function_path: str): - from wren_core import SessionContext - - return SessionContext(manifest_str, function_path) diff --git a/ibis-server/app/mdl/core.py b/ibis-server/app/mdl/core.py new file mode 100644 index 000000000..0835fcb03 --- /dev/null +++ b/ibis-server/app/mdl/core.py @@ -0,0 +1,22 @@ +from functools import cache + +import wren_core + + +@cache +def get_session_context( + manifest_str: str | None, function_path: str +) -> wren_core.SessionContext: + return wren_core.SessionContext(manifest_str, function_path) + + +def resolve_used_table_names(manifest_str: str, sql: str) -> list[str]: + return wren_core.resolve_used_table_names(manifest_str, sql) + + +def extract_manifest(manifest_str: str, datasets: list[str]) -> dict: + return wren_core.extract_manifest(manifest_str, datasets) + + +def to_json_base64(manifest): + return wren_core.to_json_base64(manifest) diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index fc63f1a7a..9e6015608 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -6,7 +6,12 @@ from loguru import logger from app.config import get_config -from app.mdl.context import get_session_context +from app.mdl.core import ( + extract_manifest, + get_session_context, + resolve_used_table_names, + to_json_base64, +) from app.model import InternalServerError, UnprocessableEntityError from app.model.data_source import DataSource @@ -59,6 +64,9 @@ def __init__(self, manifest_str: str): def rewrite(self, sql: str) -> str: try: + tables = resolve_used_table_names(self.manifest_str, sql) + manifest = extract_manifest(self.manifest_str, tables) + manifest_str = to_json_base64(manifest) r = httpx.request( method="GET", url=f"{wren_engine_endpoint}/v2/mdl/dry-plan", @@ -66,7 +74,7 @@ def rewrite(self, sql: str) -> str: "Content-Type": "application/json", "Accept": "application/json", }, - content=orjson.dumps({"manifestStr": self.manifest_str, "sql": sql}), + content=orjson.dumps({"manifestStr": manifest_str, "sql": sql}), ) return r.raise_for_status().text.replace("\n", " ") except httpx.ConnectError as e: @@ -84,7 +92,11 @@ def __init__(self, manifest_str: str, function_path: str): def rewrite(self, sql: str) -> str: try: - session_context = get_session_context(self.manifest_str, self.function_path) + tables = resolve_used_table_names(self.manifest_str, sql) + manifest = extract_manifest(self.manifest_str, tables) + session_context = get_session_context( + to_json_base64(manifest), self.function_path + ) return session_context.transform_sql(sql) except Exception as e: raise RewriteError(str(e)) diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index d14366cc2..46bf2ba29 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -5,6 +5,7 @@ from app.config import get_config from app.dependencies import verify_query_dto +from app.mdl.core import get_session_context from app.mdl.rewriter import Rewriter from app.model import ( DryPlanDTO, @@ -60,10 +61,8 @@ def validate(data_source: DataSource, rule_name: str, dto: ValidateDTO) -> Respo @router.get("/{data_source}/functions") def functions(data_source: DataSource) -> Response: - from wren_core import SessionContext - file_path = get_config().get_remote_function_list_path(data_source) - session_context = SessionContext(None, file_path) + session_context = get_session_context(None, file_path) func_list = [f.to_dict() for f in session_context.get_available_functions()] return JSONResponse(func_list) diff --git a/ibis-server/docs/development.md b/ibis-server/docs/development.md index 3c1fe2802..71a056a39 100644 --- a/ibis-server/docs/development.md +++ b/ibis-server/docs/development.md @@ -27,8 +27,8 @@ This installs the pre-commit hooks. ## Start the server To get the application running: 1. Execute `just install` to install the dependencies -2. Create a `.env` file and fill in the required environment variables (see [Environment Variables](#Environment-Variables)) -3. If you want to use `wren_core`, you need to install the core by `just install-core`. After you modify the core, you can update it by `just update-core`. +2. Execute `just install-core` to Install the core. If you modify the core, you can update it by `just update-core`. +3. Create a `.env` file and fill in the required environment variables (see [Environment Variables](#Environment-Variables)) To start the server: - Execute `just run` to start the server diff --git a/ibis-server/tests/mdl/test_context.py b/ibis-server/tests/mdl/test_context.py index a5708dc61..cfc703495 100644 --- a/ibis-server/tests/mdl/test_context.py +++ b/ibis-server/tests/mdl/test_context.py @@ -2,7 +2,7 @@ import orjson -from app.mdl.context import get_session_context +from app.mdl.core import get_session_context from tests.conftest import file_path From 441d724c09118e0dda8f515d595ebfdc58ce3bd3 Mon Sep 17 00:00:00 2001 From: Grieve Date: Fri, 29 Nov 2024 18:25:15 +0800 Subject: [PATCH 08/30] feat: merge two rust functions be a class --- ibis-server/app/mdl/core.py | 8 +--- ibis-server/app/mdl/rewriter.py | 13 +++--- wren-core-py/src/extractor.rs | 59 +++++++++++++++--------- wren-core-py/src/lib.rs | 3 +- wren-core-py/tests/test_modeling_core.py | 9 ++-- 5 files changed, 50 insertions(+), 42 deletions(-) diff --git a/ibis-server/app/mdl/core.py b/ibis-server/app/mdl/core.py index 0835fcb03..e2329f929 100644 --- a/ibis-server/app/mdl/core.py +++ b/ibis-server/app/mdl/core.py @@ -10,12 +10,8 @@ def get_session_context( return wren_core.SessionContext(manifest_str, function_path) -def resolve_used_table_names(manifest_str: str, sql: str) -> list[str]: - return wren_core.resolve_used_table_names(manifest_str, sql) - - -def extract_manifest(manifest_str: str, datasets: list[str]) -> dict: - return wren_core.extract_manifest(manifest_str, datasets) +def get_extractor(manifest_str: str) -> wren_core.Extractor: + return wren_core.Extractor(manifest_str) def to_json_base64(manifest): diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index 9e6015608..b02055f8a 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -7,9 +7,8 @@ from app.config import get_config from app.mdl.core import ( - extract_manifest, + get_extractor, get_session_context, - resolve_used_table_names, to_json_base64, ) from app.model import InternalServerError, UnprocessableEntityError @@ -64,8 +63,9 @@ def __init__(self, manifest_str: str): def rewrite(self, sql: str) -> str: try: - tables = resolve_used_table_names(self.manifest_str, sql) - manifest = extract_manifest(self.manifest_str, tables) + extractor = get_extractor(self.manifest_str) + tables = extractor.resolve_used_table_names(sql) + manifest = extractor.extract_manifest(tables) manifest_str = to_json_base64(manifest) r = httpx.request( method="GET", @@ -92,8 +92,9 @@ def __init__(self, manifest_str: str, function_path: str): def rewrite(self, sql: str) -> str: try: - tables = resolve_used_table_names(self.manifest_str, sql) - manifest = extract_manifest(self.manifest_str, tables) + extractor = get_extractor(self.manifest_str) + tables = extractor.resolve_used_table_names(sql) + manifest = extractor.extract_manifest(tables) session_context = get_session_context( to_json_base64(manifest), self.function_path ) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index 6c6d59ccf..f5f425ed9 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -1,27 +1,48 @@ use crate::errors::CoreError; use crate::manifest::{to_manifest, PyManifest}; -use pyo3::pyfunction; +use pyo3::{pyclass, pymethods}; use std::collections::HashSet; use std::sync::Arc; -use wren_core::mdl::manifest::{Manifest, Model, Relationship, View}; +use wren_core::mdl::manifest::{Model, Relationship, View}; use wren_core::mdl::WrenMDL; -/// parse the given SQL and return the list of used table name. -#[pyfunction] -#[pyo3(name = "resolve_used_table_names", signature = (mdl_base64, sql), text_signature = "(mdl_base64: str, sql: str)")] -pub fn py_resolve_used_table_names( - mdl_base64: &str, - sql: &str, -) -> Result, CoreError> { - let manifest = to_manifest(mdl_base64)?; - resolve_used_table_names(manifest, sql) +#[pyclass] +#[derive(Clone)] +#[pyo3(name = "Extractor")] +pub struct PyExtractor { + mdl: Arc, +} + +#[pymethods] +impl PyExtractor { + #[new] + pub fn new(mdl_base64: &str) -> Self { + let manifest = to_manifest(mdl_base64).unwrap(); + let mdl = WrenMDL::new_ref(manifest); + Self { mdl } + } + + /// parse the given SQL and return the list of used table name. + pub fn resolve_used_table_names(&self, sql: &str) -> Result, CoreError> { + resolve_used_table_names(&self.mdl, sql) + } + + /// Given a used dataset list, extract manifest by removing unused datasets. + /// If a model is related to another dataset, both datasets will be kept. + /// The relationship between of them will be kept as well. + /// A dataset could be model, view. + pub fn extract_manifest( + &self, + used_datasets: Vec, + ) -> Result { + extract_manifest(&self.mdl, used_datasets) + } } fn resolve_used_table_names( - manifest: Manifest, + mdl: &Arc, sql: &str, ) -> Result, CoreError> { - let mdl = WrenMDL::new_ref(manifest); let ctx_state = wren_core::SessionContext::new().state(); ctx_state .sql_to_statement(sql, "generic") @@ -43,18 +64,10 @@ fn resolve_used_table_names( }) } -/// Given a used dataset list, extract manifest by removing unused datasets. -/// If a model is related to another dataset, both datasets will be kept. -/// The relationship between of them will be kept as well. -/// A dataset could be model, view. -#[pyfunction] -#[pyo3(signature = (mdl_base64, used_datasets), text_signature = "(mdl_base64: str, used_datasets: list[str])")] pub fn extract_manifest( - mdl_base64: &str, + mdl: &Arc, used_datasets: Vec, ) -> Result { - let manifest = to_manifest(mdl_base64)?; - let mdl = WrenMDL::new_ref(manifest); let used_models = extract_models(&mdl, &used_datasets); let (used_views, models_of_views) = extract_views(&mdl, &used_datasets); let used_relationships = extract_relationships(&mdl, &used_datasets); @@ -103,7 +116,7 @@ fn extract_views( .iter() .filter_map(|&dataset_name| { mdl.get_view(dataset_name).and_then(|view| { - resolve_used_table_names(mdl.manifest.clone(), view.statement.as_str()) + resolve_used_table_names(mdl, view.statement.as_str()) .ok() .map(|used_tables| extract_models(mdl, &used_tables)) }) diff --git a/wren-core-py/src/lib.rs b/wren-core-py/src/lib.rs index 05c02c2e1..f8d00125e 100644 --- a/wren-core-py/src/lib.rs +++ b/wren-core-py/src/lib.rs @@ -15,8 +15,7 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_function(wrap_pyfunction!(extractor::py_resolve_used_table_names, m)?)?; - m.add_function(wrap_pyfunction!(extractor::extract_manifest, m)?)?; + m.add_class::()?; m.add_function(wrap_pyfunction!(manifest::to_json_base64, m)?)?; Ok(()) } diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 047b4892a..8f808ab6b 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -5,8 +5,7 @@ import pytest from wren_core import ( SessionContext, - resolve_used_table_names, - extract_manifest, + Extractor, to_json_base64, ) @@ -155,7 +154,7 @@ def test_get_available_functions(): ], ) def test_resolve_used_table_names(sql, expected): - tables = resolve_used_table_names(manifest_str, sql) + tables = Extractor(manifest_str).resolve_used_table_names(sql) assert tables == expected @@ -169,13 +168,13 @@ def test_resolve_used_table_names(sql, expected): ], ) def test_extract_manifest(dataset, expected_models): - extracted_manifest = extract_manifest(manifest_str, dataset) + extracted_manifest = Extractor(manifest_str).extract_manifest(dataset) assert len(extracted_manifest.models) == len(expected_models) assert [m.name for m in extracted_manifest.models] == expected_models def test_to_json_base64(): - extracted_manifest = extract_manifest(manifest_str, ["customer"]) + extracted_manifest = Extractor(manifest_str).extract_manifest(["customer"]) base64_str = to_json_base64(extracted_manifest) with does_not_raise(): json_str = base64.b64decode(base64_str) From f21d1a572e1ad233857196a4426458d2a85c97e7 Mon Sep 17 00:00:00 2001 From: Grieve Date: Mon, 2 Dec 2024 18:39:30 +0800 Subject: [PATCH 09/30] fix: make table without catalog or schema not filtered --- wren-core-py/src/extractor.rs | 4 ++-- wren-core-py/tests/test_modeling_core.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index f5f425ed9..356eae126 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -56,8 +56,8 @@ fn resolve_used_table_names( tables .iter() .filter(|t| { - t.catalog().unwrap_or_default() == mdl.catalog() - && t.schema().unwrap_or_default() == mdl.schema() + t.catalog().map_or(true, |catalog| catalog == mdl.catalog()) + && t.schema().map_or(true, |schema| schema == mdl.schema()) }) .map(|t| t.table().to_string()) .collect() diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 8f808ab6b..5b162e6b5 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -145,6 +145,9 @@ def test_get_available_functions(): @pytest.mark.parametrize( ("sql", "expected"), [ + ("SELECT * FROM customer", ["customer"]), + ("SELECT * FROM not_my_catalog.my_schema.customer", []), + ("SELECT * FROM my_catalog.not_my_schema.customer", []), ("SELECT * FROM my_catalog.my_schema.customer", ["customer"]), ( "SELECT * FROM my_catalog.my_schema.customer JOIN my_catalog.my_schema.orders ON customer.custkey = orders.custkey", From c3c3c9cb7d7aef751540d1e739023ca2863b60e5 Mon Sep 17 00:00:00 2001 From: Grieve Date: Mon, 2 Dec 2024 18:40:05 +0800 Subject: [PATCH 10/30] chore: remove redundant '&' --- wren-core-py/src/extractor.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index 356eae126..3636ba5ee 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -68,9 +68,9 @@ pub fn extract_manifest( mdl: &Arc, used_datasets: Vec, ) -> Result { - let used_models = extract_models(&mdl, &used_datasets); - let (used_views, models_of_views) = extract_views(&mdl, &used_datasets); - let used_relationships = extract_relationships(&mdl, &used_datasets); + let used_models = extract_models(mdl, &used_datasets); + let (used_views, models_of_views) = extract_views(mdl, &used_datasets); + let used_relationships = extract_relationships(mdl, &used_datasets); Ok(PyManifest { catalog: mdl.catalog().to_string(), schema: mdl.schema().to_string(), From 7cc2f01de46c966af5c20f795b887ddc48eb3395 Mon Sep 17 00:00:00 2001 From: Grieve Date: Mon, 2 Dec 2024 18:40:46 +0800 Subject: [PATCH 11/30] chore: make struct can debug --- wren-core-py/src/manifest.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs index 18625c21c..44c0c204a 100644 --- a/wren-core-py/src/manifest.rs +++ b/wren-core-py/src/manifest.rs @@ -24,7 +24,7 @@ pub fn to_manifest(mdl_base64: &str) -> Result { } #[pyclass(name = "Manifest")] -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct PyManifest { pub catalog: String, pub schema: String, @@ -84,7 +84,7 @@ impl PyManifest { } #[pyclass(name = "Model")] -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct PyModel { #[pyo3(get)] pub name: String, @@ -121,7 +121,7 @@ impl From<&Model> for PyModel { name: model.name.clone(), ref_sql: model.ref_sql.clone(), base_object: model.base_object.clone(), - table_reference: model.table_reference.clone(), + table_reference: Some(String::from(model.table_reference())), columns: model.columns.clone(), primary_key: model.primary_key.clone(), cached: model.cached, @@ -131,7 +131,7 @@ impl From<&Model> for PyModel { } #[pyclass(name = "Column")] -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct PyColumn { #[pyo3(get)] pub name: String, @@ -164,7 +164,7 @@ impl From<&Column> for PyColumn { } #[pyclass(name = "Relationship")] -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct PyRelationship { #[pyo3(get)] pub name: String, @@ -195,7 +195,7 @@ impl From<&Relationship> for PyRelationship { } #[pyclass(name = "JoinType", eq)] -#[derive(Serialize, Deserialize, PartialEq, Eq)] +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] pub enum PyJoinType { #[serde(alias = "one_to_one")] OneToOne, @@ -219,7 +219,7 @@ impl From<&JoinType> for PyJoinType { } #[pyclass(name = "Metric")] -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct PyMetric { #[pyo3(get)] pub name: String, @@ -249,7 +249,7 @@ impl From<&Metric> for PyMetric { } #[pyclass(name = "TimeGrain")] -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct PyTimeGrain { pub name: String, pub ref_column: String, @@ -257,7 +257,7 @@ pub struct PyTimeGrain { } #[pyclass(name = "TimeUnit", eq)] -#[derive(Serialize, Deserialize, PartialEq, Eq)] +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] pub enum PyTimeUnit { Year, Month, @@ -268,7 +268,7 @@ pub enum PyTimeUnit { } #[pyclass(name = "View")] -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct PyView { #[pyo3(get)] pub name: String, From 1fafd47892fe2cc7a9c5e938c34a6d1d1b6c14f3 Mon Sep 17 00:00:00 2001 From: Grieve Date: Mon, 2 Dec 2024 18:41:20 +0800 Subject: [PATCH 12/30] feat: return null when table ref is None --- wren-core/core/src/mdl/manifest.rs | 34 +++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/wren-core/core/src/mdl/manifest.rs b/wren-core/core/src/mdl/manifest.rs index 1589f4a4b..cce4d4837 100644 --- a/wren-core/core/src/mdl/manifest.rs +++ b/wren-core/core/src/mdl/manifest.rs @@ -122,7 +122,7 @@ mod table_reference { }; table_ref.serialize(serializer) } else { - TableReference::default().serialize(serializer) + serializer.serialize_none() } } } @@ -258,3 +258,35 @@ impl View { &self.name } } + +#[cfg(test)] +mod tests { + use crate::mdl::manifest::table_reference; + use serde_json::Serializer; + + #[test] + fn test_table_reference_serialize() { + [ + ( + Some("catalog.schema.table".to_string()), + r#"{"catalog":"catalog","schema":"schema","table":"table"}"#, + ), + ( + Some("schema.table".to_string()), + r#"{"catalog":null,"schema":"schema","table":"table"}"#, + ), + ( + Some("table".to_string()), + r#"{"catalog":null,"schema":null,"table":"table"}"#, + ), + (None, "null"), + ] + .iter() + .for_each(|(table_ref, expected)| { + let mut buf = Vec::new(); + table_reference::serialize(table_ref, &mut Serializer::new(&mut buf)) + .unwrap(); + assert_eq!(String::from_utf8(buf).unwrap(), *expected); + }); + } +} From 8d54cfdcdeb3d0526a3be8e9201a7636e9a9a19d Mon Sep 17 00:00:00 2001 From: Grieve Date: Tue, 3 Dec 2024 14:32:37 +0800 Subject: [PATCH 13/30] chore: make deserialize can receive null --- wren-core/core/src/mdl/manifest.rs | 37 ++++++++++++------------------ 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/wren-core/core/src/mdl/manifest.rs b/wren-core/core/src/mdl/manifest.rs index cce4d4837..0a5ac0922 100644 --- a/wren-core/core/src/mdl/manifest.rs +++ b/wren-core/core/src/mdl/manifest.rs @@ -60,28 +60,21 @@ mod table_reference { where D: Deserializer<'de>, { - let TableReference { - catalog, - schema, - table, - } = TableReference::deserialize(deserializer)?; - let mut result = String::new(); - if let Some(catalog) = catalog.filter(|c| !c.is_empty()) { - result.push_str(&catalog); - result.push('.'); - } - if let Some(schema) = schema.filter(|s| !s.is_empty()) { - result.push_str(&schema); - result.push('.'); - } - if let Some(table) = table.filter(|t| !t.is_empty()) { - result.push_str(&table); - } - if result.is_empty() { - Ok(None) - } else { - Ok(Some(result)) - } + Ok(Option::deserialize(deserializer)? + .map( + |TableReference { + catalog, + schema, + table, + }| { + [catalog, schema, table] + .into_iter() + .filter_map(|s| s.filter(|x| !x.is_empty())) + .collect::>() + .join(".") + }, + ) + .filter(|s| !s.is_empty())) } pub fn serialize( From f181d128ddbc9d22175bc96c2cce49e709aa59f8 Mon Sep 17 00:00:00 2001 From: Grieve Date: Tue, 3 Dec 2024 11:53:46 +0800 Subject: [PATCH 14/30] fix: wren core raise CoreError instead of pyo3_runtime.PanicException --- wren-core-py/src/extractor.rs | 6 +++--- wren-core-py/tests/test_modeling_core.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index 3636ba5ee..26f8c4baa 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -16,10 +16,10 @@ pub struct PyExtractor { #[pymethods] impl PyExtractor { #[new] - pub fn new(mdl_base64: &str) -> Self { - let manifest = to_manifest(mdl_base64).unwrap(); + pub fn new(mdl_base64: &str) -> Result { + let manifest = to_manifest(mdl_base64)?; let mdl = WrenMDL::new_ref(manifest); - Self { mdl } + Ok(Self { mdl }) } /// parse the given SQL and return the list of used table name. diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 5b162e6b5..814a679a8 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -142,6 +142,29 @@ def test_get_available_functions(): assert max_if["param_types"] is None +@pytest.mark.parametrize( + ("value", "expected_error", "error_message"), + [ + ( + None, + TypeError, + "argument 'mdl_base64': 'NoneType' object cannot be converted to 'PyString'", + ), + ("xxx", Exception, "Base64 decode error: Invalid padding"), + ("{}", Exception, "Base64 decode error: Invalid symbol 123, offset 0."), + ( + "", + Exception, + "Serde JSON error: EOF while parsing a value at line 1 column 0", + ), + ], +) +def test_extractor_with_invalid_manifest(value, expected_error, error_message): + with pytest.raises(expected_error) as e: + Extractor(value) + assert str(e.value) == error_message + + @pytest.mark.parametrize( ("sql", "expected"), [ From 5a8c94956d6a751b89b24b709bba46959d495570 Mon Sep 17 00:00:00 2001 From: Grieve Date: Tue, 3 Dec 2024 11:56:12 +0800 Subject: [PATCH 15/30] chore: adjust style --- ibis-server/app/mdl/rewriter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index b02055f8a..4edd39229 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -95,9 +95,8 @@ def rewrite(self, sql: str) -> str: extractor = get_extractor(self.manifest_str) tables = extractor.resolve_used_table_names(sql) manifest = extractor.extract_manifest(tables) - session_context = get_session_context( - to_json_base64(manifest), self.function_path - ) + manifest_str = to_json_base64(manifest) + session_context = get_session_context(manifest_str, self.function_path) return session_context.transform_sql(sql) except Exception as e: raise RewriteError(str(e)) From c3c6d58fddf2b9ef72766958f432a3a9ce52d24a Mon Sep 17 00:00:00 2001 From: Grieve Date: Tue, 3 Dec 2024 12:14:39 +0800 Subject: [PATCH 16/30] chore: add format --- wren-core-py/justfile | 8 +++ wren-core-py/pyproject.toml | 92 ++++++++++++++++++++++++ wren-core-py/tests/__init__.py | 0 wren-core-py/tests/test_modeling_core.py | 10 +-- 4 files changed, 103 insertions(+), 7 deletions(-) create mode 100644 wren-core-py/tests/__init__.py diff --git a/wren-core-py/justfile b/wren-core-py/justfile index 509cc05bd..16bfd0be3 100644 --- a/wren-core-py/justfile +++ b/wren-core-py/justfile @@ -13,3 +13,11 @@ develop: test: develop cargo test --no-default-features poetry run pytest + +alias fmt := format + +format: + cargo fmt + poetry run ruff format . + poetry run ruff check --fix . + taplo fmt diff --git a/wren-core-py/pyproject.toml b/wren-core-py/pyproject.toml index c8efa218b..702a82fc4 100644 --- a/wren-core-py/pyproject.toml +++ b/wren-core-py/pyproject.toml @@ -14,6 +14,7 @@ maturin = "1.7.5" [tool.poetry.group.dev.dependencies] pytest = "8.3.3" +ruff = "0.8.0" [tool.maturin] module-name = "wren_core" @@ -25,3 +26,94 @@ features = ["pyo3/extension-module"] [build-system] requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" + +[tool.ruff] +line-length = 88 +target-version = "py311" +exclude = ["tools/"] + +[tool.ruff.lint] +select = [ + "C4", # comprehensions + "D", # pydocstyle + "E", # pycodestyle + "EXE", # flake8-executable + "F", # pyflakes + "FA", # flake8-future-annotations + "G", # flake8-logging-format + "FLY", # flynt (format string conversion) + "I", # isort + "ICN", # flake8-import-conventions + "INP", # flake8-no-pep420 (implicit namespace packages) + "ISC", # flake8-implicit-str-concat + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # pylint + "RET", # flake8-return + "RUF", # ruff-specific rules + "SIM", # flake8-simplify + "T10", # flake8-debugger + "T20", # flake8-print + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "YTT", # flake8-2020 +] +ignore = [ + "B008", # do not perform function calls in argument defaults + "B028", # required stacklevel argument to warn + "B904", # raise from e or raise from None in exception handlers + "B905", # zip-without-explicit-strict + "C408", # dict(...) as literal + "C901", # too complex + "D100", # public module + "D101", # public class + "D102", # public method + "D103", # public function + "D104", # public package + "D105", # magic methods + "D106", # nested class + "D107", # init + "D202", # blank lines after function docstring + "D203", # blank line before class docstring + "D213", # Multi-line docstring summary should start at the second line + "D401", # Imperative mood + "D402", # First line should not be the function's signature + "D413", # Blank line required after last section + "E501", # line-too-long, this is automatically enforced by ruff format + "E731", # lambda-assignment + "ISC001", # single line implicit string concat, handled by ruff format + "PGH003", # blanket-type-ignore + "PLC0105", # covariant type parameters should have a _co suffix + "PLR0124", # name compared with self, e.g., a == a + "PLR0911", # too many return statements + "PLR0912", # too many branches + "PLR0913", # too many arguments + "PLR0915", # too many statements + "PLR2004", # forces everything to be a constant + "PLW2901", # overwriting loop variable + "RET504", # unnecessary-assign, these are useful for debugging + "RET505", # superfluous-else-return, stylistic choice + "RET506", # superfluous-else-raise, stylistic choice + "RET507", # superfluous-else-continue, stylistic choice + "RET508", # superfluous-else-break, stylistic choice + "RUF005", # splat instead of concat + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` + "S101", # ignore "Use of `assert` detected" + "SIM102", # nested ifs + "SIM108", # convert everything to ternary operator + "SIM114", # combine `if` branches using logical `or` operator + "SIM116", # dictionary instead of `if` statements + "SIM117", # nested with statements + "SIM118", # remove .keys() calls from dictionaries + "SIM300", # yoda conditions + "UP007", # Optional[str] -> str | None + "UP038", # non-pep604-isinstance, results in slower code + "W191", # indentation contains tabs +] +# none of these codes will be automatically fixed by ruff +unfixable = [ + "T201", # print statements + "F401", # unused imports + "RUF100", # unused noqa comments + "F841", # unused variables +] diff --git a/wren-core-py/tests/__init__.py b/wren-core-py/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 814a679a8..880a6d290 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -4,8 +4,8 @@ import pytest from wren_core import ( - SessionContext, Extractor, + SessionContext, to_json_base64, ) @@ -123,9 +123,7 @@ def test_read_function_list(): def test_get_available_functions(): session_context = SessionContext(manifest_str, "tests/functions.csv") functions = session_context.get_available_functions() - add_two = next( - filter(lambda x: x["name"] == "add_two", map(lambda x: x.to_dict(), functions)) - ) + add_two = next(x.to_dict() for x in functions if x["name"] == "add_two") assert add_two["name"] == "add_two" assert add_two["function_type"] == "scalar" assert add_two["description"] == "Adds two numbers together." @@ -133,9 +131,7 @@ def test_get_available_functions(): assert add_two["param_names"] == "f1,f2" assert add_two["param_types"] == "int,int" - max_if = next( - filter(lambda x: x["name"] == "max_if", map(lambda x: x.to_dict(), functions)) - ) + max_if = next(x.to_dict() for x in functions if x["name"] == "max_if") assert max_if["name"] == "max_if" assert max_if["function_type"] == "window" assert max_if["param_names"] is None From e8e6ec6297be453300e6b58f42d6a2c61d66fabb Mon Sep 17 00:00:00 2001 From: Grieve Date: Tue, 3 Dec 2024 12:15:57 +0800 Subject: [PATCH 17/30] chore: lock poetry --- wren-core-py/poetry.lock | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/wren-core-py/poetry.lock b/wren-core-py/poetry.lock index 0b8dcd523..b42f47a1a 100644 --- a/wren-core-py/poetry.lock +++ b/wren-core-py/poetry.lock @@ -94,7 +94,34 @@ pluggy = ">=1.5,<2" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "ruff" +version = "0.8.0" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.8.0-py3-none-linux_armv6l.whl", hash = "sha256:fcb1bf2cc6706adae9d79c8d86478677e3bbd4ced796ccad106fd4776d395fea"}, + {file = "ruff-0.8.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:295bb4c02d58ff2ef4378a1870c20af30723013f441c9d1637a008baaf928c8b"}, + {file = "ruff-0.8.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7b1f1c76b47c18fa92ee78b60d2d20d7e866c55ee603e7d19c1e991fad933a9a"}, + {file = "ruff-0.8.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb0d4f250a7711b67ad513fde67e8870109e5ce590a801c3722580fe98c33a99"}, + {file = "ruff-0.8.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0e55cce9aa93c5d0d4e3937e47b169035c7e91c8655b0974e61bb79cf398d49c"}, + {file = "ruff-0.8.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f4cd64916d8e732ce6b87f3f5296a8942d285bbbc161acee7fe561134af64f9"}, + {file = "ruff-0.8.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c5c1466be2a2ebdf7c5450dd5d980cc87c8ba6976fb82582fea18823da6fa362"}, + {file = "ruff-0.8.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2dabfd05b96b7b8f2da00d53c514eea842bff83e41e1cceb08ae1966254a51df"}, + {file = "ruff-0.8.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:facebdfe5a5af6b1588a1d26d170635ead6892d0e314477e80256ef4a8470cf3"}, + {file = "ruff-0.8.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87a8e86bae0dbd749c815211ca11e3a7bd559b9710746c559ed63106d382bd9c"}, + {file = "ruff-0.8.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:85e654f0ded7befe2d61eeaf3d3b1e4ef3894469cd664ffa85006c7720f1e4a2"}, + {file = "ruff-0.8.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:83a55679c4cb449fa527b8497cadf54f076603cc36779b2170b24f704171ce70"}, + {file = "ruff-0.8.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:812e2052121634cf13cd6fddf0c1871d0ead1aad40a1a258753c04c18bb71bbd"}, + {file = "ruff-0.8.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:780d5d8523c04202184405e60c98d7595bdb498c3c6abba3b6d4cdf2ca2af426"}, + {file = "ruff-0.8.0-py3-none-win32.whl", hash = "sha256:5fdb6efecc3eb60bba5819679466471fd7d13c53487df7248d6e27146e985468"}, + {file = "ruff-0.8.0-py3-none-win_amd64.whl", hash = "sha256:582891c57b96228d146725975fbb942e1f30a0c4ba19722e692ca3eb25cc9b4f"}, + {file = "ruff-0.8.0-py3-none-win_arm64.whl", hash = "sha256:ba93e6294e9a737cd726b74b09a6972e36bb511f9a102f1d9a7e1ce94dd206a6"}, + {file = "ruff-0.8.0.tar.gz", hash = "sha256:a7ccfe6331bf8c8dad715753e157457faf7351c2b69f62f32c165c2dbcbacd44"}, +] + [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "8e45af48010bff32480c1523a46b1b627116738ad479336edb4cd5cdd2a5f795" +content-hash = "f763127adce2d84b1766e5e0e5c910e434d63458aa6eda747f0da1c57e215fae" From fdec2382b2e48b15bb4439e0726ba1595462d775 Mon Sep 17 00:00:00 2001 From: Grieve Date: Tue, 3 Dec 2024 12:56:56 +0800 Subject: [PATCH 18/30] chore: let PyRemoteFunction with getter --- wren-core-py/src/remote_functions.rs | 6 ++++++ wren-core-py/tests/test_modeling_core.py | 26 ++++++++++++------------ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/wren-core-py/src/remote_functions.rs b/wren-core-py/src/remote_functions.rs index f7980fca6..552b94242 100644 --- a/wren-core-py/src/remote_functions.rs +++ b/wren-core-py/src/remote_functions.rs @@ -25,13 +25,19 @@ use wren_core::mdl::function::FunctionType; #[pyclass(name = "RemoteFunction")] #[derive(Serialize, Deserialize, Clone)] pub struct PyRemoteFunction { + #[pyo3(get)] pub function_type: String, + #[pyo3(get)] pub name: String, + #[pyo3(get)] pub return_type: Option, /// It's a comma separated string of parameter names + #[pyo3(get)] pub param_names: Option, /// It's a comma separated string of parameter types + #[pyo3(get)] pub param_types: Option, + #[pyo3(get)] pub description: Option, } diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 880a6d290..23b4bc21d 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -123,19 +123,19 @@ def test_read_function_list(): def test_get_available_functions(): session_context = SessionContext(manifest_str, "tests/functions.csv") functions = session_context.get_available_functions() - add_two = next(x.to_dict() for x in functions if x["name"] == "add_two") - assert add_two["name"] == "add_two" - assert add_two["function_type"] == "scalar" - assert add_two["description"] == "Adds two numbers together." - assert add_two["return_type"] == "int" - assert add_two["param_names"] == "f1,f2" - assert add_two["param_types"] == "int,int" - - max_if = next(x.to_dict() for x in functions if x["name"] == "max_if") - assert max_if["name"] == "max_if" - assert max_if["function_type"] == "window" - assert max_if["param_names"] is None - assert max_if["param_types"] is None + add_two = next(f for f in functions if f.name == "add_two") + assert add_two.name == "add_two" + assert add_two.function_type == "scalar" + assert add_two.description == "Adds two numbers together." + assert add_two.return_type == "int" + assert add_two.param_names == "f1,f2" + assert add_two.param_types == "int,int" + + max_if = next(f for f in functions if f.name == "max_if") + assert max_if.name == "max_if" + assert max_if.function_type == "window" + assert max_if.param_names is None + assert max_if.param_types is None @pytest.mark.parametrize( From f217fd84400c09c3d1a7780ba620d195fe0e6773 Mon Sep 17 00:00:00 2001 From: Grieve Date: Tue, 3 Dec 2024 17:31:55 +0800 Subject: [PATCH 19/30] test: add test in Rust --- wren-core-py/Cargo.lock | 80 ++++++++++++++++- wren-core-py/Cargo.toml | 3 + wren-core-py/src/errors.rs | 2 +- wren-core-py/src/extractor.rs | 163 +++++++++++++++++++++++++++++++++- wren-core-py/src/manifest.rs | 53 +++++++++++ 5 files changed, 295 insertions(+), 6 deletions(-) diff --git a/wren-core-py/Cargo.lock b/wren-core-py/Cargo.lock index 962c2c12f..1de75dcd3 100644 --- a/wren-core-py/Cargo.lock +++ b/wren-core-py/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -1314,6 +1314,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -2151,6 +2157,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "proc-macro-crate" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.89" @@ -2300,6 +2315,42 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + +[[package]] +name = "rstest" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a2c585be59b6b5dd66a9d2084aa1d8bd52fbdb806eafdeffb52791147862035" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "825ea780781b15345a146be27eaefb05085e337e869bff01b4306a4fd4a9ad5a" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -2719,6 +2770,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" + +[[package]] +name = "toml_edit" +version = "0.22.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +dependencies = [ + "indexmap 2.6.0", + "toml_datetime", + "winnow", +] + [[package]] name = "tracing" version = "0.1.40" @@ -3015,6 +3083,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +dependencies = [ + "memchr", +] + [[package]] name = "wren-core" version = "0.1.0" @@ -3044,6 +3121,7 @@ dependencies = [ "log", "pyo3", "pyo3-build-config", + "rstest", "serde", "serde_json", "thiserror", diff --git a/wren-core-py/Cargo.toml b/wren-core-py/Cargo.toml index 1e77fc2fa..a5bab9c5e 100644 --- a/wren-core-py/Cargo.toml +++ b/wren-core-py/Cargo.toml @@ -20,6 +20,9 @@ env_logger = "0.11.5" log = "0.4.22" tokio = "1.40.0" +[dev-dependencies] +rstest = "0.23.0" + [build-dependencies] pyo3-build-config = "0.23.2" diff --git a/wren-core-py/src/errors.rs b/wren-core-py/src/errors.rs index 9ad301f85..56dc403eb 100644 --- a/wren-core-py/src/errors.rs +++ b/wren-core-py/src/errors.rs @@ -4,7 +4,7 @@ use pyo3::PyErr; use std::string::FromUtf8Error; use thiserror::Error; -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] #[error("{message}")] pub struct CoreError { message: String, diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index 26f8c4baa..636c8441e 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -16,10 +16,14 @@ pub struct PyExtractor { #[pymethods] impl PyExtractor { #[new] - pub fn new(mdl_base64: &str) -> Result { - let manifest = to_manifest(mdl_base64)?; - let mdl = WrenMDL::new_ref(manifest); - Ok(Self { mdl }) + #[pyo3(signature = (mdl_base64=None))] + pub fn new(mdl_base64: Option<&str>) -> Result { + mdl_base64 + .ok_or_else(|| CoreError::new("Expected a valid base64 encoded string for the model definition, but got None.")) + .and_then(to_manifest) + .map(|manifest| Self { + mdl: WrenMDL::new_ref(manifest), + }) } /// parse the given SQL and return the list of used table name. @@ -154,3 +158,154 @@ fn extract_relationships( .cloned() .collect() } + +#[cfg(test)] +mod tests { + use crate::extractor::PyExtractor; + use base64::prelude::BASE64_STANDARD; + use base64::Engine; + use rstest::{fixture, rstest}; + use std::iter::Iterator; + + #[fixture] + pub fn mdl_base64() -> String { + let mdl_json = r#" + { + "catalog": "my_catalog", + "schema": "my_schema", + "models": [ + { + "name": "customer", + "tableReference": { + "schema": "main", + "table": "customer" + }, + "columns": [ + {"name": "c_custkey", "type": "integer"}, + {"name": "orders", "type": "orders", "relationship": "orders_customer"} + ], + "primaryKey": "c_custkey" + }, + { + "name": "orders", + "tableReference": { + "schema": "main", + "table": "orders" + }, + "columns": [ + {"name": "o_orderkey", "type": "integer"}, + {"name": "o_custkey", "type": "integer"}, + { + "name": "lineitems", + "type": "Lineitem", + "relationship": "orders_lineitem" + } + ], + "primaryKey": "o_orderkey" + }, + { + "name": "lineitem", + "tableReference": { + "schema": "main", + "table": "lineitem" + }, + "columns": [ + {"name": "l_orderkey", "type": "integer"} + ], + "primaryKey": "l_orderkey" + } + ], + "relationships": [ + { + "name": "orders_customer", + "models": ["orders", "customer"], + "joinType": "MANY_TO_ONE", + "condition": "orders.custkey = customer.custkey" + }, + { + "name": "orders_lineitem", + "models": ["orders", "lineitem"], + "joinType": "ONE_TO_MANY", + "condition": "orders.orderkey = lineitem.orderkey" + } + ], + "views": [ + { + "name": "customer_view", + "statement": "SELECT * FROM my_catalog.my_schema.customer" + } + ] + }"#; + BASE64_STANDARD.encode(mdl_json.as_bytes()) + } + + #[fixture] + pub fn extractor(mdl_base64: String) -> PyExtractor { + PyExtractor::new(Option::from(mdl_base64.as_str())).unwrap() + } + + #[rstest] + #[case( + None, + "Expected a valid base64 encoded string for the model definition, but got None." + )] + #[case(Some("xxx"), "Base64 decode error: Invalid padding")] + #[case(Some("{}"), "Base64 decode error: Invalid symbol 123, offset 0.")] + #[case( + Some(""), + "Serde JSON error: EOF while parsing a value at line 1 column 0" + )] + fn test_extractor_with_invalid_manifest( + #[case] value: Option<&str>, + #[case] error_message: &str, + ) { + let result = PyExtractor::new(value); + + match result { + Err(err) => { + assert_eq!(err.to_string(), error_message); + } + Ok(_) => panic!("Expected an error but got Ok"), + } + } + + #[rstest] + #[case("SELECT * FROM customer", vec!["customer"])] + #[case("SELECT * FROM not_my_catalog.my_schema.customer", vec![])] + #[case("SELECT * FROM my_catalog.not_my_schema.customer", vec![])] + #[case("SELECT * FROM my_catalog.my_schema.customer", vec!["customer"])] + #[case("SELECT * FROM my_catalog.my_schema.customer JOIN my_catalog.my_schema.orders ON customer.custkey = orders.custkey", vec!["customer", "orders"])] + #[case("SELECT * FROM my_catalog.my_schema.customer_view", vec!["customer_view"])] + fn test_resolve_used_table_names( + extractor: PyExtractor, + #[case] sql: &str, + #[case] expected: Vec<&str>, + ) { + assert_eq!(extractor.resolve_used_table_names(sql).unwrap(), expected); + } + + #[rstest] + #[case(vec!["customer"], vec!["customer", "orders", "lineitem"])] + #[case(vec!["customer_view"], vec!["customer", "orders", "lineitem"])] + #[case(vec!["orders"], vec!["orders", "lineitem"])] + #[case(vec!["lineitem"], vec!["lineitem"])] + fn test_extract_manifest( + extractor: PyExtractor, + #[case] dataset: Vec<&str>, + #[case] expected_models: Vec<&str>, + ) { + let dataset_strings: Vec = + dataset.iter().map(|s| s.to_string()).collect(); + let extracted_manifest = extractor.extract_manifest(dataset_strings).unwrap(); + + assert_eq!(extracted_manifest.models.len(), expected_models.len()); + assert_eq!( + extracted_manifest + .models + .iter() + .map(|m| m.name.clone()) + .collect::>(), + expected_models + ); + } +} diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs index 44c0c204a..b2187da20 100644 --- a/wren-core-py/src/manifest.rs +++ b/wren-core-py/src/manifest.rs @@ -284,3 +284,56 @@ impl From<&View> for PyView { } } } + +#[cfg(test)] +mod tests { + use crate::manifest::{to_json_base64, to_manifest, PyManifest}; + use rstest::rstest; + use std::sync::Arc; + use wren_core::mdl::manifest::Model; + + #[rstest] + fn test_manifest_to_json_base64() { + let py_manifest = PyManifest { + catalog: "catalog".to_string(), + schema: "schema".to_string(), + models: vec![ + Arc::from(Model { + name: "model_1".to_string(), + ref_sql: "SELECT * FROM table".to_string().into(), + base_object: None, + table_reference: None, + columns: vec![], + primary_key: None, + cached: false, + refresh_time: None, + }), + Arc::from(Model { + name: "model_2".to_string(), + ref_sql: None, + base_object: None, + table_reference: "catalog.schema.table".to_string().into(), + columns: vec![], + primary_key: None, + cached: false, + refresh_time: None, + }), + ], + relationships: vec![], + metrics: vec![], + views: vec![], + }; + let base64_str = to_json_base64(py_manifest).unwrap(); + let manifest = to_manifest(&base64_str).unwrap(); + assert_eq!(manifest.catalog, "catalog"); + assert_eq!(manifest.schema, "schema"); + assert_eq!(manifest.models.len(), 2); + assert_eq!(manifest.models[0].name, "model_1"); + assert_eq!( + manifest.models[0].ref_sql, + Some("SELECT * FROM table".to_string()) + ); + assert_eq!(manifest.models[1].name(), "model_2"); + assert_eq!(manifest.models[1].table_reference(), "catalog.schema.table"); + } +} From efae0eda5463032d9868525525ea562a51da9661 Mon Sep 17 00:00:00 2001 From: Grieve Date: Tue, 3 Dec 2024 17:42:46 +0800 Subject: [PATCH 20/30] test: update expected --- wren-core-py/tests/test_modeling_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 23b4bc21d..f506069e7 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -143,8 +143,8 @@ def test_get_available_functions(): [ ( None, - TypeError, - "argument 'mdl_base64': 'NoneType' object cannot be converted to 'PyString'", + Exception, + "Expected a valid base64 encoded string for the model definition, but got None.", ), ("xxx", Exception, "Base64 decode error: Invalid padding"), ("{}", Exception, "Base64 decode error: Invalid symbol 123, offset 0."), From 2c7e787afbdfc4f445a030ea48db65b4bef31509 Mon Sep 17 00:00:00 2001 From: Grieve Date: Wed, 4 Dec 2024 11:47:08 +0800 Subject: [PATCH 21/30] chore: rename class from `Extractor` to `ManifestExtractor` --- ibis-server/app/mdl/core.py | 6 ++-- ibis-server/app/mdl/rewriter.py | 10 +++--- wren-core-py/src/extractor.rs | 45 +++++++++++------------- wren-core-py/src/lib.rs | 2 +- wren-core-py/tests/test_modeling_core.py | 12 +++---- 5 files changed, 35 insertions(+), 40 deletions(-) diff --git a/ibis-server/app/mdl/core.py b/ibis-server/app/mdl/core.py index e2329f929..27df4d10e 100644 --- a/ibis-server/app/mdl/core.py +++ b/ibis-server/app/mdl/core.py @@ -10,9 +10,9 @@ def get_session_context( return wren_core.SessionContext(manifest_str, function_path) -def get_extractor(manifest_str: str) -> wren_core.Extractor: - return wren_core.Extractor(manifest_str) +def get_manifest_extractor(manifest_str: str) -> wren_core.ManifestExtractor: + return wren_core.ManifestExtractor(manifest_str) -def to_json_base64(manifest): +def to_json_base64(manifest) -> str: return wren_core.to_json_base64(manifest) diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index 4edd39229..0a2fa073d 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -7,7 +7,7 @@ from app.config import get_config from app.mdl.core import ( - get_extractor, + get_manifest_extractor, get_session_context, to_json_base64, ) @@ -63,9 +63,9 @@ def __init__(self, manifest_str: str): def rewrite(self, sql: str) -> str: try: - extractor = get_extractor(self.manifest_str) + extractor = get_manifest_extractor(self.manifest_str) tables = extractor.resolve_used_table_names(sql) - manifest = extractor.extract_manifest(tables) + manifest = extractor.extract_by(tables) manifest_str = to_json_base64(manifest) r = httpx.request( method="GET", @@ -92,9 +92,9 @@ def __init__(self, manifest_str: str, function_path: str): def rewrite(self, sql: str) -> str: try: - extractor = get_extractor(self.manifest_str) + extractor = get_manifest_extractor(self.manifest_str) tables = extractor.resolve_used_table_names(sql) - manifest = extractor.extract_manifest(tables) + manifest = extractor.extract_by(tables) manifest_str = to_json_base64(manifest) session_context = get_session_context(manifest_str, self.function_path) return session_context.transform_sql(sql) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index 636c8441e..674cce9a1 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -8,13 +8,13 @@ use wren_core::mdl::WrenMDL; #[pyclass] #[derive(Clone)] -#[pyo3(name = "Extractor")] -pub struct PyExtractor { +#[pyo3(name = "ManifestExtractor")] +pub struct PyManifestExtractor { mdl: Arc, } #[pymethods] -impl PyExtractor { +impl PyManifestExtractor { #[new] #[pyo3(signature = (mdl_base64=None))] pub fn new(mdl_base64: Option<&str>) -> Result { @@ -35,7 +35,7 @@ impl PyExtractor { /// If a model is related to another dataset, both datasets will be kept. /// The relationship between of them will be kept as well. /// A dataset could be model, view. - pub fn extract_manifest( + pub fn extract_by( &self, used_datasets: Vec, ) -> Result { @@ -161,7 +161,7 @@ fn extract_relationships( #[cfg(test)] mod tests { - use crate::extractor::PyExtractor; + use crate::extractor::PyManifestExtractor; use base64::prelude::BASE64_STANDARD; use base64::Engine; use rstest::{fixture, rstest}; @@ -240,8 +240,8 @@ mod tests { } #[fixture] - pub fn extractor(mdl_base64: String) -> PyExtractor { - PyExtractor::new(Option::from(mdl_base64.as_str())).unwrap() + pub fn extractor(mdl_base64: String) -> PyManifestExtractor { + PyManifestExtractor::new(Option::from(mdl_base64.as_str())).unwrap() } #[rstest] @@ -259,9 +259,7 @@ mod tests { #[case] value: Option<&str>, #[case] error_message: &str, ) { - let result = PyExtractor::new(value); - - match result { + match PyManifestExtractor::new(value) { Err(err) => { assert_eq!(err.to_string(), error_message); } @@ -277,7 +275,7 @@ mod tests { #[case("SELECT * FROM my_catalog.my_schema.customer JOIN my_catalog.my_schema.orders ON customer.custkey = orders.custkey", vec!["customer", "orders"])] #[case("SELECT * FROM my_catalog.my_schema.customer_view", vec!["customer_view"])] fn test_resolve_used_table_names( - extractor: PyExtractor, + extractor: PyManifestExtractor, #[case] sql: &str, #[case] expected: Vec<&str>, ) { @@ -285,25 +283,22 @@ mod tests { } #[rstest] - #[case(vec!["customer"], vec!["customer", "orders", "lineitem"])] - #[case(vec!["customer_view"], vec!["customer", "orders", "lineitem"])] - #[case(vec!["orders"], vec!["orders", "lineitem"])] - #[case(vec!["lineitem"], vec!["lineitem"])] + #[case(&["customer"], &["customer", "orders", "lineitem"])] + #[case(&["customer_view"], &["customer", "orders", "lineitem"])] + #[case(&["orders"], &["orders", "lineitem"])] + #[case(&["lineitem"], &["lineitem"])] fn test_extract_manifest( - extractor: PyExtractor, - #[case] dataset: Vec<&str>, - #[case] expected_models: Vec<&str>, + extractor: PyManifestExtractor, + #[case] dataset: &[&str], + #[case] expected_models: &[&str], ) { - let dataset_strings: Vec = - dataset.iter().map(|s| s.to_string()).collect(); - let extracted_manifest = extractor.extract_manifest(dataset_strings).unwrap(); - - assert_eq!(extracted_manifest.models.len(), expected_models.len()); assert_eq!( - extracted_manifest + extractor + .extract_by(dataset.iter().map(|s| s.to_string()).collect()) + .unwrap() .models .iter() - .map(|m| m.name.clone()) + .map(|m| m.name.as_str()) .collect::>(), expected_models ); diff --git a/wren-core-py/src/lib.rs b/wren-core-py/src/lib.rs index f8d00125e..30b81df6d 100644 --- a/wren-core-py/src/lib.rs +++ b/wren-core-py/src/lib.rs @@ -15,7 +15,7 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(manifest::to_json_base64, m)?)?; Ok(()) } diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index f506069e7..d8e0ac34b 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -4,7 +4,7 @@ import pytest from wren_core import ( - Extractor, + ManifestExtractor, SessionContext, to_json_base64, ) @@ -157,7 +157,7 @@ def test_get_available_functions(): ) def test_extractor_with_invalid_manifest(value, expected_error, error_message): with pytest.raises(expected_error) as e: - Extractor(value) + ManifestExtractor(value) assert str(e.value) == error_message @@ -176,7 +176,7 @@ def test_extractor_with_invalid_manifest(value, expected_error, error_message): ], ) def test_resolve_used_table_names(sql, expected): - tables = Extractor(manifest_str).resolve_used_table_names(sql) + tables = ManifestExtractor(manifest_str).resolve_used_table_names(sql) assert tables == expected @@ -189,14 +189,14 @@ def test_resolve_used_table_names(sql, expected): (["lineitem"], ["lineitem"]), ], ) -def test_extract_manifest(dataset, expected_models): - extracted_manifest = Extractor(manifest_str).extract_manifest(dataset) +def test_extract_by(dataset, expected_models): + extracted_manifest = ManifestExtractor(manifest_str).extract_by(dataset) assert len(extracted_manifest.models) == len(expected_models) assert [m.name for m in extracted_manifest.models] == expected_models def test_to_json_base64(): - extracted_manifest = Extractor(manifest_str).extract_manifest(["customer"]) + extracted_manifest = ManifestExtractor(manifest_str).extract_by(["customer"]) base64_str = to_json_base64(extracted_manifest) with does_not_raise(): json_str = base64.b64decode(base64_str) From 49240c63ee9c7b46a371a3ec530d4d0873e72db2 Mon Sep 17 00:00:00 2001 From: Grieve Date: Thu, 5 Dec 2024 15:21:12 +0800 Subject: [PATCH 22/30] chore: add comment --- wren-core-py/src/manifest.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs index b2187da20..e5d738649 100644 --- a/wren-core-py/src/manifest.rs +++ b/wren-core-py/src/manifest.rs @@ -9,6 +9,7 @@ use wren_core::mdl::manifest::{ Column, JoinType, Manifest, Metric, Model, Relationship, TimeGrain, TimeUnit, View, }; +/// Convert a manifest to a JSON string and then encode it as base64. #[pyfunction] pub fn to_json_base64(mdl: PyManifest) -> Result { let mdl_json = serde_json::to_string(&mdl)?; @@ -16,6 +17,7 @@ pub fn to_json_base64(mdl: PyManifest) -> Result { Ok(mdl_base64) } +/// Convert a base64 encoded JSON string to a manifest object. pub fn to_manifest(mdl_base64: &str) -> Result { let decoded_bytes = BASE64_STANDARD.decode(mdl_base64)?; let mdl_json = String::from_utf8(decoded_bytes)?; From 7b36d529bc12d19d172833f38b7d4333c8661dcb Mon Sep 17 00:00:00 2001 From: Grieve Date: Thu, 5 Dec 2024 15:21:36 +0800 Subject: [PATCH 23/30] chore: adjust test command --- wren-core-py/justfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wren-core-py/justfile b/wren-core-py/justfile index 16bfd0be3..84083fe53 100644 --- a/wren-core-py/justfile +++ b/wren-core-py/justfile @@ -10,8 +10,9 @@ build *args: develop: poetry run maturin develop -test: develop +test: cargo test --no-default-features + just develop poetry run pytest alias fmt := format From 06fc2dc9740739f053472e01bdcd86aa65c31bf3 Mon Sep 17 00:00:00 2001 From: Grieve Date: Thu, 5 Dec 2024 15:21:50 +0800 Subject: [PATCH 24/30] chore: address comment --- wren-core-py/src/extractor.rs | 229 ++++++++++++++++++---------------- wren-core-py/src/manifest.rs | 13 ++ 2 files changed, 136 insertions(+), 106 deletions(-) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index 674cce9a1..cf7da5355 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -39,14 +39,11 @@ impl PyManifestExtractor { &self, used_datasets: Vec, ) -> Result { - extract_manifest(&self.mdl, used_datasets) + extract_manifest(&self.mdl, &used_datasets) } } -fn resolve_used_table_names( - mdl: &Arc, - sql: &str, -) -> Result, CoreError> { +fn resolve_used_table_names(mdl: &WrenMDL, sql: &str) -> Result, CoreError> { let ctx_state = wren_core::SessionContext::new().state(); ctx_state .sql_to_statement(sql, "generic") @@ -68,24 +65,25 @@ fn resolve_used_table_names( }) } -pub fn extract_manifest( - mdl: &Arc, - used_datasets: Vec, +fn extract_manifest( + mdl: &WrenMDL, + used_datasets: &[String], ) -> Result { - let used_models = extract_models(mdl, &used_datasets); - let (used_views, models_of_views) = extract_views(mdl, &used_datasets); - let used_relationships = extract_relationships(mdl, &used_datasets); + let extracted_models = extract_models(mdl, used_datasets); + let (used_views, models_of_views) = extract_views(mdl, used_datasets); + let used_models = [extracted_models, models_of_views].concat(); + let used_relationships = extract_relationships(mdl, &used_models); Ok(PyManifest { catalog: mdl.catalog().to_string(), schema: mdl.schema().to_string(), - models: [used_models, models_of_views].concat(), + models: used_models, relationships: used_relationships, metrics: mdl.metrics().to_vec(), views: used_views, }) } -fn extract_models(mdl: &Arc, used_datasets: &[String]) -> Vec> { +fn extract_models(mdl: &WrenMDL, used_datasets: &[String]) -> Vec> { let mut used_set: HashSet = used_datasets.iter().cloned().collect(); let mut stack: Vec = used_datasets.to_vec(); while let Some(dataset_name) = stack.pop() { @@ -111,7 +109,7 @@ fn extract_models(mdl: &Arc, used_datasets: &[String]) -> Vec, + mdl: &WrenMDL, used_datasets: &[String], ) -> (Vec>, Vec>) { let used_set: HashSet<&str> = used_datasets.iter().map(String::as_str).collect(); @@ -138,23 +136,13 @@ fn extract_views( } fn extract_relationships( - mdl: &Arc, - used_datasets: &[String], + mdl: &WrenMDL, + used_models: &[Arc], ) -> Vec> { - let mut used_set: HashSet = used_datasets.iter().cloned().collect(); - let mut stack: Vec = used_datasets.to_vec(); - while let Some(dataset_name) = stack.pop() { - if let Some(relationship) = mdl.get_relationship(&dataset_name) { - for model in &relationship.models { - if used_set.insert(model.clone()) { - stack.push(model.clone()); - } - } - } - } + let model_names: Vec<_> = used_models.iter().map(|m| m.name.as_str()).collect(); mdl.relationships() .iter() - .filter(|rel| rel.models.iter().any(|model| used_set.contains(model))) + .filter(|rel| rel.models.iter().any(|m| model_names.contains(&m.as_str()))) .cloned() .collect() } @@ -162,81 +150,65 @@ fn extract_relationships( #[cfg(test)] mod tests { use crate::extractor::PyManifestExtractor; - use base64::prelude::BASE64_STANDARD; - use base64::Engine; + use crate::manifest::{to_json_base64, PyManifest}; use rstest::{fixture, rstest}; use std::iter::Iterator; + use wren_core::mdl::builder::{ + ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, ViewBuilder, + }; + use wren_core::mdl::manifest::JoinType; #[fixture] pub fn mdl_base64() -> String { - let mdl_json = r#" - { - "catalog": "my_catalog", - "schema": "my_schema", - "models": [ - { - "name": "customer", - "tableReference": { - "schema": "main", - "table": "customer" - }, - "columns": [ - {"name": "c_custkey", "type": "integer"}, - {"name": "orders", "type": "orders", "relationship": "orders_customer"} - ], - "primaryKey": "c_custkey" - }, - { - "name": "orders", - "tableReference": { - "schema": "main", - "table": "orders" - }, - "columns": [ - {"name": "o_orderkey", "type": "integer"}, - {"name": "o_custkey", "type": "integer"}, - { - "name": "lineitems", - "type": "Lineitem", - "relationship": "orders_lineitem" - } - ], - "primaryKey": "o_orderkey" - }, - { - "name": "lineitem", - "tableReference": { - "schema": "main", - "table": "lineitem" - }, - "columns": [ - {"name": "l_orderkey", "type": "integer"} - ], - "primaryKey": "l_orderkey" - } - ], - "relationships": [ - { - "name": "orders_customer", - "models": ["orders", "customer"], - "joinType": "MANY_TO_ONE", - "condition": "orders.custkey = customer.custkey" - }, - { - "name": "orders_lineitem", - "models": ["orders", "lineitem"], - "joinType": "ONE_TO_MANY", - "condition": "orders.orderkey = lineitem.orderkey" - } - ], - "views": [ - { - "name": "customer_view", - "statement": "SELECT * FROM my_catalog.my_schema.customer" - } - ] - }"#; - BASE64_STANDARD.encode(mdl_json.as_bytes()) + let customer = ModelBuilder::new("customer") + .table_reference("main.customer") + .column(ColumnBuilder::new("c_custkey", "integer").build()) + .column( + ColumnBuilder::new("orders", "orders") + .relationship("customer_orders") + .build(), + ) + .build(); + let orders = ModelBuilder::new("orders") + .table_reference("main.orders") + .column(ColumnBuilder::new("o_orderkey", "integer").build()) + .column(ColumnBuilder::new("o_custkey", "integer").build()) + .column( + ColumnBuilder::new("lineitems", "Lineitem") + .relationship("orders_lineitem") + .build(), + ) + .build(); + let lineitem = ModelBuilder::new("lineitem") + .table_reference("main.lineitem") + .column(ColumnBuilder::new("l_orderkey", "integer").build()) + .build(); + let c_o_relationship = RelationshipBuilder::new("customer_orders") + .model("customer") + .model("orders") + .join_type(JoinType::OneToMany) + .condition("customer.custkey = orders.custkey") + .build(); + let o_l_relationship = RelationshipBuilder::new("orders_lineitem") + .model("orders") + .model("lineitem") + .join_type(JoinType::OneToMany) + .condition("orders.orderkey = lineitem.orderkey") + .build(); + let c_view = ViewBuilder::new("customer_view") + .statement("SELECT * FROM my_catalog.my_schema.customer") + .build(); + let manifest = ManifestBuilder::new() + .catalog("my_catalog") + .schema("my_schema") + .model(customer) + .model(orders) + .model(lineitem) + .relationship(c_o_relationship) + .relationship(o_l_relationship) + .view(c_view) + .build(); + to_json_base64(PyManifest::from(&manifest)).unwrap() } #[fixture] @@ -268,16 +240,17 @@ mod tests { } #[rstest] - #[case("SELECT * FROM customer", vec!["customer"])] - #[case("SELECT * FROM not_my_catalog.my_schema.customer", vec![])] - #[case("SELECT * FROM my_catalog.not_my_schema.customer", vec![])] - #[case("SELECT * FROM my_catalog.my_schema.customer", vec!["customer"])] - #[case("SELECT * FROM my_catalog.my_schema.customer JOIN my_catalog.my_schema.orders ON customer.custkey = orders.custkey", vec!["customer", "orders"])] - #[case("SELECT * FROM my_catalog.my_schema.customer_view", vec!["customer_view"])] + #[case("SELECT * FROM customer", &["customer"])] + #[case("SELECT * FROM not_my_catalog.my_schema.customer", &[])] + #[case("SELECT * FROM my_catalog.not_my_schema.customer", &[])] + #[case("SELECT * FROM my_catalog.my_schema.customer", &["customer"])] + #[case("SELECT * FROM my_catalog.my_schema.customer JOIN my_catalog.my_schema.orders ON customer.custkey = orders.custkey", &["customer", "orders"])] + #[case("SELECT * FROM my_catalog.my_schema.customer_view", &["customer_view"])] + #[case("WITH t1 as (select * from customer) select * from t1", &["customer"])] fn test_resolve_used_table_names( extractor: PyManifestExtractor, #[case] sql: &str, - #[case] expected: Vec<&str>, + #[case] expected: &[&str], ) { assert_eq!(extractor.resolve_used_table_names(sql).unwrap(), expected); } @@ -287,7 +260,7 @@ mod tests { #[case(&["customer_view"], &["customer", "orders", "lineitem"])] #[case(&["orders"], &["orders", "lineitem"])] #[case(&["lineitem"], &["lineitem"])] - fn test_extract_manifest( + fn test_extract_manifest_for_models( extractor: PyManifestExtractor, #[case] dataset: &[&str], #[case] expected_models: &[&str], @@ -303,4 +276,48 @@ mod tests { expected_models ); } + + #[rstest] + #[case(&["customer"], &["customer_orders", "orders_lineitem"])] + #[case(&["customer_view"], &["customer_orders", "orders_lineitem"])] + #[case(&["orders"], &["customer_orders", "orders_lineitem"])] + #[case(&["lineitem"], &["orders_lineitem"])] + fn test_extract_manifest_for_relationships( + extractor: PyManifestExtractor, + #[case] dataset: &[&str], + #[case] expected_relationships: &[&str], + ) { + assert_eq!( + extractor + .extract_by(dataset.iter().map(|s| s.to_string()).collect()) + .unwrap() + .relationships + .iter() + .map(|r| r.name.as_str()) + .collect::>(), + expected_relationships + ); + } + + #[rstest] + #[case(&["customer_view"], &["customer_view"])] + #[case(&["customer"], &[])] + #[case(&["orders"], &[])] + #[case(&["lineitem"], &[])] + fn test_extract_manifest_for_view( + extractor: PyManifestExtractor, + #[case] dataset: &[&str], + #[case] expected_views: &[&str], + ) { + assert_eq!( + extractor + .extract_by(dataset.iter().map(|s| s.to_string()).collect()) + .unwrap() + .views + .iter() + .map(|v| v.name.as_str()) + .collect::>(), + expected_views + ); + } } diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs index e5d738649..b272895c3 100644 --- a/wren-core-py/src/manifest.rs +++ b/wren-core-py/src/manifest.rs @@ -85,6 +85,19 @@ impl PyManifest { } } +impl From<&Manifest> for PyManifest { + fn from(manifest: &Manifest) -> Self { + Self { + catalog: manifest.catalog.clone(), + schema: manifest.schema.clone(), + models: manifest.models.clone(), + relationships: manifest.relationships.clone(), + metrics: manifest.metrics.clone(), + views: manifest.views.clone(), + } + } +} + #[pyclass(name = "Model")] #[derive(Serialize, Deserialize, Debug)] pub struct PyModel { From 754ac5915d15aefecf513cf56aeb796ff58993d1 Mon Sep 17 00:00:00 2001 From: Grieve Date: Thu, 5 Dec 2024 15:32:16 +0800 Subject: [PATCH 25/30] chore: add test case --- wren-core-py/src/extractor.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index cf7da5355..4eede382b 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -247,6 +247,7 @@ mod tests { #[case("SELECT * FROM my_catalog.my_schema.customer JOIN my_catalog.my_schema.orders ON customer.custkey = orders.custkey", &["customer", "orders"])] #[case("SELECT * FROM my_catalog.my_schema.customer_view", &["customer_view"])] #[case("WITH t1 as (select * from customer) select * from t1", &["customer"])] + #[case("WITH customer as (select * from customer) select * from customer", &["customer"])] fn test_resolve_used_table_names( extractor: PyManifestExtractor, #[case] sql: &str, From 278f4355c15d81619139d39da281404df7d6c2d5 Mon Sep 17 00:00:00 2001 From: Grieve Date: Thu, 5 Dec 2024 15:34:05 +0800 Subject: [PATCH 26/30] chore: add test case --- wren-core-py/src/extractor.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index 4eede382b..c3bde3689 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -248,6 +248,8 @@ mod tests { #[case("SELECT * FROM my_catalog.my_schema.customer_view", &["customer_view"])] #[case("WITH t1 as (select * from customer) select * from t1", &["customer"])] #[case("WITH customer as (select * from customer) select * from customer", &["customer"])] + #[case("SELECT * from (select * from customer) as t1", &["customer"])] + #[case("SELECT * from (select * from customer) as customer", &["customer"])] fn test_resolve_used_table_names( extractor: PyManifestExtractor, #[case] sql: &str, From cce9360b2f0b155bffde6e12345de86f561c7719 Mon Sep 17 00:00:00 2001 From: Grieve Date: Thu, 5 Dec 2024 15:43:59 +0800 Subject: [PATCH 27/30] chore: split test command --- wren-core-py/README.md | 12 +++++++++--- wren-core-py/justfile | 7 +++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/wren-core-py/README.md b/wren-core-py/README.md index c51405f72..995a5e005 100644 --- a/wren-core-py/README.md +++ b/wren-core-py/README.md @@ -1,18 +1,24 @@ # Wren Core in Python + Here is a dependency package for Python. It is a wrapper for the Rust package [wren-core](../wren-core). The Rust package is compiled to a Python package and can be used in Python. ## Developer Guide ### Environment Setup + - Install [Rust](https://www.rust-lang.org/tools/install) and [Cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html) - Install [Python](https://www.python.org/downloads/) and [pipx](https://pipx.pypa.io/) - Install [poetry](https://github.com/python-poetry/poetry) - Install [casey/just](https://github.com/casey/just) ### Test and build -- Execute `just install` to create python venv and install dependencies. -- Execute `just test` to test Rust and Python. + +- Execute `just install` to create Python venv and install dependencies. +- **Important**: Before testing Python, you need to build the Rust package by running `just develop`. +- Use `just test-r` to test Rust only, and `just test-py` to test Python only. +- Use `just test` to test Rust and Python. - Execute `just build` to build the Python package. You can find the wheel in the `target/wheels/` directory. ### Coding Style -Format with rustfmt via `cargo fmt` + +Format via `just format` diff --git a/wren-core-py/justfile b/wren-core-py/justfile index 84083fe53..243987567 100644 --- a/wren-core-py/justfile +++ b/wren-core-py/justfile @@ -10,11 +10,14 @@ build *args: develop: poetry run maturin develop -test: +test-r: cargo test --no-default-features - just develop + +test-py: poetry run pytest +test: test-r test-py + alias fmt := format format: From 352f614acd40e8f0ccc63b6fdc8e4798b4888863 Mon Sep 17 00:00:00 2001 From: Grieve Date: Mon, 9 Dec 2024 10:41:58 +0800 Subject: [PATCH 28/30] chore: rename command to `test-rs` --- wren-core-py/README.md | 2 +- wren-core-py/justfile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/wren-core-py/README.md b/wren-core-py/README.md index 995a5e005..3329fe32d 100644 --- a/wren-core-py/README.md +++ b/wren-core-py/README.md @@ -15,7 +15,7 @@ Here is a dependency package for Python. It is a wrapper for the Rust package [w - Execute `just install` to create Python venv and install dependencies. - **Important**: Before testing Python, you need to build the Rust package by running `just develop`. -- Use `just test-r` to test Rust only, and `just test-py` to test Python only. +- Use `just test-rs` to test Rust only, and `just test-py` to test Python only. - Use `just test` to test Rust and Python. - Execute `just build` to build the Python package. You can find the wheel in the `target/wheels/` directory. diff --git a/wren-core-py/justfile b/wren-core-py/justfile index 243987567..b92ce90ee 100644 --- a/wren-core-py/justfile +++ b/wren-core-py/justfile @@ -10,13 +10,13 @@ build *args: develop: poetry run maturin develop -test-r: +test-rs: cargo test --no-default-features test-py: poetry run pytest -test: test-r test-py +test: test-rs test-py alias fmt := format From afe6497ebaab4051ec2ff6d9cfd78797dbdbd8d4 Mon Sep 17 00:00:00 2001 From: Grieve Date: Mon, 9 Dec 2024 13:07:35 +0800 Subject: [PATCH 29/30] chore: use HashMap instead of HashSet --- wren-core-py/src/extractor.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index c3bde3689..6a5498cc5 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -1,7 +1,8 @@ use crate::errors::CoreError; use crate::manifest::{to_manifest, PyManifest}; use pyo3::{pyclass, pymethods}; -use std::collections::HashSet; +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use wren_core::mdl::manifest::{Model, Relationship, View}; use wren_core::mdl::WrenMDL; @@ -84,7 +85,8 @@ fn extract_manifest( } fn extract_models(mdl: &WrenMDL, used_datasets: &[String]) -> Vec> { - let mut used_set: HashSet = used_datasets.iter().cloned().collect(); + let mut used_set: HashMap = + used_datasets.iter().map(|s| (s.clone(), 0)).collect(); let mut stack: Vec = used_datasets.to_vec(); while let Some(dataset_name) = stack.pop() { if let Some(model) = mdl.get_model(&dataset_name) { @@ -97,13 +99,18 @@ fn extract_models(mdl: &WrenMDL, used_datasets: &[String]) -> Vec> { .and_then(|rel_name| mdl.get_relationship(rel_name)) }) .flat_map(|rel| rel.models.clone()) - .filter(|related| used_set.insert(related.clone())) - .for_each(|related| stack.push(related)); + .for_each(|related| { + if let Entry::Vacant(vacant) = used_set.entry(related) { + let key = vacant.key().clone(); + vacant.insert(0); + stack.push(key); + } + }); } } mdl.models() .iter() - .filter(|model| used_set.contains(model.name())) + .filter(|model| used_set.contains_key(model.name())) .cloned() .collect() } From e6c846717c05c6acb7d594f4a17c854741361d5f Mon Sep 17 00:00:00 2001 From: Grieve Date: Mon, 9 Dec 2024 13:07:52 +0800 Subject: [PATCH 30/30] chore: remove unused stack --- wren-core-py/src/extractor.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index 6a5498cc5..deac71963 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -120,8 +120,7 @@ fn extract_views( used_datasets: &[String], ) -> (Vec>, Vec>) { let used_set: HashSet<&str> = used_datasets.iter().map(String::as_str).collect(); - let stack: Vec<&str> = used_datasets.iter().map(String::as_str).collect(); - let models = stack + let models = used_set .iter() .filter_map(|&dataset_name| { mdl.get_view(dataset_name).and_then(|view| {