From 8431723f33b5db42cc7199486b2ef27c8b35981a Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Mon, 24 Jun 2024 14:45:41 +0800 Subject: [PATCH] Introduce the calculation for TO_MANY relationship (#626) * upgrade to datafusion v39.0.0 * fix the 39.0.0 change * introduce to many calculation * add an example * stop panic and use DataFusionError instead * change the sqllogictests flow * stop panic and unwrap directly * fix count(*) aggregation for model and registered physical table * fmt and clippy * add one-one-many test case * fmt and clippy * fix test * update cargo lock * fmt and fix test --- wren-modeling-py/Cargo.lock | 495 ++++-------- wren-modeling-py/src/errors.rs | 2 +- wren-modeling-py/src/lib.rs | 26 +- wren-modeling-py/tests/test_modeling_core.py | 2 +- wren-modeling-rs/Cargo.toml | 3 +- wren-modeling-rs/core/Cargo.toml | 1 - .../core/src/logical_plan/analyze/plan.rs | 741 +++++++++++++----- .../core/src/logical_plan/analyze/rule.rs | 82 +- .../core/src/logical_plan/context_provider.rs | 87 +- .../core/src/logical_plan/utils.rs | 32 +- wren-modeling-rs/core/src/mdl/builder.rs | 4 +- wren-modeling-rs/core/src/mdl/lineage.rs | 326 ++++---- wren-modeling-rs/core/src/mdl/manifest.rs | 19 +- wren-modeling-rs/core/src/mdl/mod.rs | 57 +- wren-modeling-rs/core/src/mdl/utils.rs | 77 +- wren-modeling-rs/core/tests/data/mdl.json | 44 ++ .../sqllogictest/bin/sqllogictests.rs | 7 +- .../sqllogictest/src/engine/runner.rs | 10 +- wren-modeling-rs/sqllogictest/src/lib.rs | 2 +- .../sqllogictest/src/test_context.rs | 74 +- .../sqllogictest/test_sql_files/model.slt | 5 + wren-modeling-rs/wren-example/Cargo.toml | 1 - .../wren-example/examples/datafusion-apply.rs | 21 +- .../wren-example/examples/plan-sql.rs | 2 +- .../examples/to-many-calculation.rs | 248 ++++++ 25 files changed, 1440 insertions(+), 928 deletions(-) create mode 100644 wren-modeling-rs/wren-example/examples/to-many-calculation.rs diff --git a/wren-modeling-py/Cargo.lock b/wren-modeling-py/Cargo.lock index 
ef2e139d9..c2c3a18bf 100644 --- a/wren-modeling-py/Cargo.lock +++ b/wren-modeling-py/Cargo.lock @@ -139,9 +139,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219d05930b81663fd3b32e3bde8ce5bff3c4d23052a99f11a8fa50a3b47b2658" +checksum = "7ae9728f104939be6d8d9b368a354b4929b0569160ea1641f0721b55a861ce38" dependencies = [ "arrow-arith", "arrow-array", @@ -160,9 +160,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0272150200c07a86a390be651abdd320a2d12e84535f0837566ca87ecd8f95e0" +checksum = "a7029a5b3efbeafbf4a12d12dc16b8f9e9bff20a410b8c25c5d28acc089e1043" dependencies = [ "arrow-array", "arrow-buffer", @@ -175,9 +175,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8010572cf8c745e242d1b632bd97bd6d4f40fefed5ed1290a8f433abaa686fea" +checksum = "d33238427c60271710695f17742f45b1a5dc5bcfc5c15331c25ddfe7abf70d97" dependencies = [ "ahash", "arrow-buffer", @@ -192,9 +192,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d0a2432f0cba5692bf4cb757469c66791394bac9ec7ce63c1afe74744c37b27" +checksum = "fe9b95e825ae838efaf77e366c00d3fc8cca78134c9db497d6bda425f2e7b7c1" dependencies = [ "bytes", "half", @@ -203,9 +203,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9abc10cd7995e83505cc290df9384d6e5412b207b79ce6bdff89a10505ed2cba" +checksum = "87cf8385a9d5b5fcde771661dd07652b79b9139fea66193eda6a88664400ccab" dependencies 
= [ "arrow-array", "arrow-buffer", @@ -224,9 +224,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95cbcba196b862270bf2a5edb75927380a7f3a163622c61d40cbba416a6305f2" +checksum = "cea5068bef430a86690059665e40034625ec323ffa4dd21972048eebb0127adc" dependencies = [ "arrow-array", "arrow-buffer", @@ -243,9 +243,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2742ac1f6650696ab08c88f6dd3f0eb68ce10f8c253958a18c943a68cd04aec5" +checksum = "cb29be98f987bcf217b070512bb7afba2f65180858bca462edf4a39d84a23e10" dependencies = [ "arrow-buffer", "arrow-schema", @@ -255,9 +255,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a42ea853130f7e78b9b9d178cb4cd01dee0f78e64d96c2949dc0a915d6d9e19d" +checksum = "ffc68f6523970aa6f7ce1dc9a33a7d9284cfb9af77d4ad3e617dbe5d79cc6ec8" dependencies = [ "arrow-array", "arrow-buffer", @@ -270,9 +270,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaafb5714d4e59feae964714d724f880511500e3569cc2a94d02456b403a2a49" +checksum = "2041380f94bd6437ab648e6c2085a045e45a0c44f91a1b9a4fe3fed3d379bfb1" dependencies = [ "arrow-array", "arrow-buffer", @@ -290,9 +290,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3e6b61e3dc468f503181dccc2fc705bdcc5f2f146755fa5b56d0a6c5943f412" +checksum = "fcb56ed1547004e12203652f12fe12e824161ff9d1e5cf2a7dc4ff02ba94f413" dependencies = [ "arrow-array", "arrow-buffer", @@ -305,9 +305,9 @@ dependencies = [ [[package]] name = 
"arrow-row" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "848ee52bb92eb459b811fb471175ea3afcf620157674c8794f539838920f9228" +checksum = "575b42f1fc588f2da6977b94a5ca565459f5ab07b60545e17243fb9a7ed6d43e" dependencies = [ "ahash", "arrow-array", @@ -320,15 +320,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02d9483aaabe910c4781153ae1b6ae0393f72d9ef757d38d09d450070cf2e528" +checksum = "32aae6a60458a2389c0da89c9de0b7932427776127da1a738e2efc21d32f3393" [[package]] name = "arrow-select" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "849524fa70e0e3c5ab58394c770cb8f514d0122d20de08475f7b472ed8075830" +checksum = "de36abaef8767b4220d7b4a8c2fe5ffc78b47db81b03d77e2136091c3ba39102" dependencies = [ "ahash", "arrow-array", @@ -340,9 +340,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9373cb5a021aee58863498c37eb484998ef13377f69989c6c5ccfbd258236cdb" +checksum = "e435ada8409bcafc910bc3e0077f532a4daa20e99060a496685c0e3e53cc2597" dependencies = [ "arrow-array", "arrow-buffer", @@ -381,7 +381,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -465,9 +465,9 @@ dependencies = [ [[package]] name = "brotli" -version = "3.5.0" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -476,9 +476,9 @@ dependencies = [ [[package]] name = 
"brotli-decompressor" -version = "2.5.1" +version = "4.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -554,9 +554,9 @@ dependencies = [ [[package]] name = "chrono-tz" -version = "0.8.6" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" dependencies = [ "chrono", "chrono-tz-build", @@ -565,9 +565,9 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.2.1" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +checksum = "0c088aee841df9c3041febbb73934cfc39708749bf96dc827e3359cd39ef11b1" dependencies = [ "parse-zoneinfo", "phf", @@ -693,9 +693,9 @@ dependencies = [ [[package]] name = "datafusion" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05fb4eeeb7109393a0739ac5b8fd892f95ccef691421491c85544f7997366f68" +checksum = "2f92d2d7a9cba4580900b32b009848d9eb35f1028ac84cdd6ddcf97612cd0068" dependencies = [ "ahash", "arrow", @@ -717,6 +717,7 @@ dependencies = [ "datafusion-functions-array", "datafusion-optimizer", "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-sql", "flate2", @@ -731,6 +732,7 @@ dependencies = [ "object_store", "parking_lot", "parquet", + "paste", "pin-project-lite", "rand", "sqlparser", @@ -745,9 +747,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"741aeac15c82f239f2fc17deccaab19873abbd62987be20023689b15fa72fa09" +checksum = "effed030d2c1667eb1e11df5372d4981eaf5d11a521be32220b3985ae5ba6971" dependencies = [ "ahash", "arrow", @@ -756,6 +758,7 @@ dependencies = [ "arrow-schema", "chrono", "half", + "hashbrown", "instant", "libc", "num_cpus", @@ -766,18 +769,18 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e8ddfb8d8cb51646a30da0122ecfffb81ca16919ae9a3495a9e7468bdcd52b8" +checksum = "d0091318129dad1359f08e4c6c71f855163c35bba05d1dbf983196f727857894" dependencies = [ "tokio", ] [[package]] name = "datafusion-execution" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282122f90b20e8f98ebfa101e4bf20e718fd2684cf81bef4e8c6366571c64404" +checksum = "8385aba84fc4a06d3ebccfbcbf9b4f985e80c762fac634b49079f7cc14933fb1" dependencies = [ "arrow", "chrono", @@ -796,13 +799,14 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5478588f733df0dfd87a62671c7478f590952c95fa2fa5c137e3ff2929491e22" +checksum = "ebb192f0055d2ce64e38ac100abc18e4e6ae9734d3c28eee522bbbd6a32108a3" dependencies = [ "ahash", "arrow", "arrow-array", + "arrow-buffer", "chrono", "datafusion-common", "paste", @@ -814,9 +818,9 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4afd261cea6ac9c3ca1192fd5e9f940596d8e9208c5b1333f4961405db53185" +checksum = "27c081ae5b7edd712b92767fb8ed5c0e32755682f8075707666cd70835807c0b" dependencies = [ "arrow", "base64", @@ -841,11 +845,13 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "38.0.0" +version = "39.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b36a6c4838ab94b5bf8f7a96ce6ce059d805c5d1dcaa6ace49e034eb65cd999" +checksum = "feb28a4ea52c28a26990646986a27c4052829a2a2572386258679e19263f8b78" dependencies = [ + "ahash", "arrow", + "arrow-schema", "datafusion-common", "datafusion-execution", "datafusion-expr", @@ -857,9 +863,9 @@ dependencies = [ [[package]] name = "datafusion-functions-array" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5fdd200a6233f48d3362e7ccb784f926f759100e44ae2137a5e2dcb986a59c4" +checksum = "89b17c02a74cdc87380a56758ec27e7d417356bf806f33062700908929aedb8a" dependencies = [ "arrow", "arrow-array", @@ -877,9 +883,9 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54f2820938810e8a2d71228fd6f59f33396aebc5f5f687fcbf14de5aab6a7e1a" +checksum = "12172f2a6c9eb4992a51e62d709eeba5dedaa3b5369cce37ff6c2260e100ba76" dependencies = [ "arrow", "async-trait", @@ -896,9 +902,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9adf8eb12716f52ddf01e09eb6c94d3c9b291e062c05c91b839a448bddba2ff8" +checksum = "7a3fce531b623e94180f6cd33d620ef01530405751b6ddd2fd96250cdbd78e2e" dependencies = [ "ahash", "arrow", @@ -927,20 +933,21 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d5472c3230584c150197b3f2c23f2392b9dc54dbfb62ad41e7e36447cfce4be" +checksum = "046400b6a2cc3ed57a7c576f5ae6aecc77804ac8e0186926b278b189305b2a77" dependencies = [ "arrow", "datafusion-common", "datafusion-expr", + "rand", ] [[package]] name = "datafusion-physical-plan" -version = "38.0.0" +version = 
"39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18ae750c38389685a8b62e5b899bbbec488950755ad6d218f3662d35b800c4fe" +checksum = "4aed47f5a2ad8766260befb375b201592e86a08b260256e168ae4311426a2bff" dependencies = [ "ahash", "arrow", @@ -972,9 +979,9 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "38.0.0" +version = "39.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "befc67a3cdfbfa76853f43b10ac27337821bb98e519ab6baf431fcc0bcfcafdb" +checksum = "7fa92bb1fd15e46ce5fb6f1c85f3ac054592560f294429a28e392b5f9cd4255e" dependencies = [ "arrow", "arrow-array", @@ -982,6 +989,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "log", + "regex", "sqlparser", "strum", ] @@ -997,17 +1005,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "displaydoc" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", -] - [[package]] name = "doc-comment" version = "0.3.3" @@ -1073,9 +1070,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flatbuffers" -version = "23.5.26" +version = "24.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dac53e22462d78c16d64a1cd22371b54cc3fe94aa15e7886a2fa6e5d1ab8640" +checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" dependencies = [ "bitflags 1.3.2", "rustc_version", @@ -1156,7 +1153,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -1296,134 +1293,14 @@ dependencies = [ "cc", ] -[[package]] -name = "icu_collections" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" -dependencies = [ - "displaydoc", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_locid" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" -dependencies = [ - "displaydoc", - "litemap", - "tinystr", - "writeable", - "zerovec", -] - -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - -[[package]] -name = "icu_normalizer" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_normalizer_data", - "icu_properties", - "icu_provider", - "smallvec", - "utf16_iter", - "utf8_iter", - "write16", - "zerovec", -] - -[[package]] -name = "icu_normalizer_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" - -[[package]] -name = "icu_properties" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f8ac670d7422d7f76b32e17a5db556510825b29ec9154f235977c9caba61036" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_locid_transform", - "icu_properties_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_properties_data" -version = 
"1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" - -[[package]] -name = "icu_provider" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_provider_macros", - "stable_deref_trait", - "tinystr", - "writeable", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", -] - [[package]] name = "idna" -version = "1.0.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4716a3a0933a1d01c2f72450e89596eb51dd34ef3c211ccd875acdf1f8fe47ed" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ - "icu_normalizer", - "icu_properties", - "smallvec", - "utf8_iter", + "unicode-bidi", + "unicode-normalization", ] [[package]] @@ -1587,12 +1464,6 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" -[[package]] -name = "litemap" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" - [[package]] name = "lock_api" version = "0.4.12" @@ -1656,9 +1527,9 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" 
dependencies = [ "adler", ] @@ -1758,9 +1629,9 @@ dependencies = [ [[package]] name = "object_store" -version = "0.9.1" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8718f8b65fdf67a45108d1548347d4af7d71fb81ce727bbf9e3b2535e079db3" +checksum = "fbebfd32c213ba1907fa7a9c9138015a8de2b43e30c5aa45b18f7deb46786ad6" dependencies = [ "async-trait", "bytes", @@ -1817,9 +1688,9 @@ dependencies = [ [[package]] name = "parquet" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "096795d4f47f65fd3ee1ec5a98b77ab26d602f2cc785b0e4be5443add17ecc32" +checksum = "29c3b5322cc1bbf67f11c079c42be41a55949099b78732f7dba9e15edde40eab" dependencies = [ "ahash", "arrow-array", @@ -1848,6 +1719,7 @@ dependencies = [ "tokio", "twox-hash", "zstd", + "zstd-sys", ] [[package]] @@ -1961,9 +1833,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.85" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -2016,7 +1888,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -2029,7 +1901,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -2082,9 +1954,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" dependencies = [ "bitflags 2.5.0", ] @@ -2202,7 +2074,7 @@ checksum = 
"500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -2278,9 +2150,9 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "sqlparser" -version = "0.45.0" +version = "0.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7bbffee862a796d67959a89859d6b1046bb5016d63e23835ad0da182777bbe0" +checksum = "295e9930cd7a97e58ca2a070541a3ca502b17f5d1fa7157376d0fabd85324f25" dependencies = [ "log", "sqlparser_derive", @@ -2294,15 +2166,9 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] -[[package]] -name = "stable_deref_trait" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" - [[package]] name = "static_assertions" version = "1.1.0" @@ -2328,14 +2194,14 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = "0d0208408ba0c3df17ed26eb06992cb1a1268d41b2c0e12e65203fbe3972cee5" [[package]] name = "syn" @@ -2350,26 +2216,15 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.66" +version = "2.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +checksum = "ff8655ed1d86f3af4ee3fd3263786bc14245ad17c4c7e85ba7187fb3ae028c90" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] -[[package]] -name = "synstructure" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", -] - [[package]] name = "target-lexicon" version = "0.12.14" @@ -2405,7 +2260,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -2429,15 +2284,20 @@ dependencies = [ ] [[package]] -name = "tinystr" -version = "0.7.6" +name = "tinyvec" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" dependencies = [ - "displaydoc", - "zerovec", + "tinyvec_macros", ] +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.38.0" @@ -2459,7 +2319,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -2494,7 +2354,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -2522,12 +2382,27 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-segmentation" version = "1.11.0" @@ -2548,27 +2423,15 @@ checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" [[package]] name = "url" -version = "2.5.1" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c25da092f0a868cdf09e8674cd3b7ef3a7d92a24253e663a2fb85e2496de56" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - -[[package]] -name = "utf8_iter" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" - [[package]] name = "utf8parse" version = "0.2.2" @@ -2627,7 +2490,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", "wasm-bindgen-shared", ] @@ -2649,7 +2512,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2765,7 +2628,6 @@ checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" name = "wren-core" version = "0.1.0" dependencies = [ - "arrow-schema", "async-trait", "datafusion", "env_logger", @@ -2789,18 +2651,6 @@ dependencies = [ "wren-core", ] -[[package]] -name = "write16" -version = "1.0.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - -[[package]] -name = "writeable" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" - [[package]] name = "xz2" version = "0.1.7" @@ -2810,30 +2660,6 @@ dependencies = [ "lzma-sys", ] -[[package]] -name = "yoke" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" -dependencies = [ - "serde", - "stable_deref_trait", - "yoke-derive", - "zerofrom", -] - -[[package]] -name = "yoke-derive" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", - "synstructure", -] - [[package]] name = "zerocopy" version = "0.7.34" @@ -2851,75 +2677,32 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", -] - -[[package]] -name = "zerofrom" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" -dependencies = [ - "zerofrom-derive", -] - -[[package]] -name = "zerofrom-derive" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", - "synstructure", -] - -[[package]] -name = "zerovec" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2cc8827d6c0994478a15c53f374f46fbd41bea663d809b14744bc42e6b109c" -dependencies = [ - "yoke", - "zerofrom", - "zerovec-derive", -] 
- -[[package]] -name = "zerovec-derive" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97cf56601ee5052b4417d90c8755c6683473c926039908196cf35d99f893ebe7" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] name = "zstd" -version = "0.13.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d789b1514203a1120ad2429eae43a7bd32b90976a7bb8a05f7ec02fa88cc23a" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" dependencies = [ "zstd-safe", ] [[package]] name = "zstd-safe" -version = "7.1.0" +version = "7.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd99b45c6bc03a018c8b8a86025678c87e55526064e38f9df301989dce7ec0a" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.10+zstd.1.5.6" +version = "2.0.9+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" dependencies = [ "cc", "pkg-config", diff --git a/wren-modeling-py/src/errors.rs b/wren-modeling-py/src/errors.rs index b8b86b26c..2adfe37c5 100644 --- a/wren-modeling-py/src/errors.rs +++ b/wren-modeling-py/src/errors.rs @@ -1,7 +1,7 @@ -use std::string::FromUtf8Error; use base64::DecodeError; use pyo3::exceptions::PyException; use pyo3::PyErr; +use std::string::FromUtf8Error; use thiserror::Error; #[derive(Error, Debug)] diff --git a/wren-modeling-py/src/lib.rs b/wren-modeling-py/src/lib.rs index 0b0c5e108..71d43b873 100644 --- a/wren-modeling-py/src/lib.rs +++ b/wren-modeling-py/src/lib.rs @@ -4,8 +4,8 @@ use base64::prelude::*; use pyo3::prelude::*; use wren_core::mdl; -use wren_core::mdl::AnalyzedWrenMDL; use 
wren_core::mdl::manifest::Manifest; +use wren_core::mdl::AnalyzedWrenMDL; use crate::errors::CoreError; @@ -13,10 +13,15 @@ mod errors; #[pyfunction] fn transform_sql(mdl_base64: &str, sql: &str) -> Result { - let mdl_json_bytes = BASE64_STANDARD.decode(mdl_base64).map_err(CoreError::from)?; + let mdl_json_bytes = BASE64_STANDARD + .decode(mdl_base64) + .map_err(CoreError::from)?; let mdl_json = String::from_utf8(mdl_json_bytes).map_err(CoreError::from)?; let manifest = serde_json::from_str::(&mdl_json)?; - let analyzed_mdl = AnalyzedWrenMDL::analyze(manifest); + + let Ok(analyzed_mdl) = AnalyzedWrenMDL::analyze(manifest) else { + return Err(CoreError::new("Failed to analyze manifest")); + }; match mdl::transform_sql(Arc::new(analyzed_mdl), sql) { Ok(transformed_sql) => Ok(transformed_sql), Err(e) => Err(CoreError::new(&e.to_string())), @@ -32,8 +37,8 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> { #[cfg(test)] mod tests { - use base64::Engine; use base64::prelude::BASE64_STANDARD; + use base64::Engine; use serde_json::Value; use crate::transform_sql; @@ -58,11 +63,12 @@ mod tests { }"#; let v: Value = serde_json::from_str(data).unwrap(); let mdl_base64: String = BASE64_STANDARD.encode(v.to_string().as_bytes()); - let transformed_sql = transform_sql( - &mdl_base64, - "SELECT * FROM my_catalog.my_schema.customer", - ) - .unwrap(); - assert_eq!(transformed_sql, r##"SELECT "customer"."custkey", "customer"."name" FROM (SELECT "customer"."custkey", "customer"."name" FROM (SELECT "main"."customer"."c_custkey" AS "custkey", "main"."customer"."c_name" AS "name" FROM "main"."customer") AS "customer") AS "customer""##); + let transformed_sql = + transform_sql(&mdl_base64, "SELECT * FROM my_catalog.my_schema.customer") + .unwrap(); + assert_eq!( + transformed_sql, + r#"SELECT customer.custkey, customer."name" FROM (SELECT customer.custkey, customer."name" FROM (SELECT main.customer.c_custkey AS custkey, main.customer.c_name AS "name" FROM main.customer) AS 
customer) AS customer"# + ); } } diff --git a/wren-modeling-py/tests/test_modeling_core.py b/wren-modeling-py/tests/test_modeling_core.py index be6318a6c..b72587c74 100644 --- a/wren-modeling-py/tests/test_modeling_core.py +++ b/wren-modeling-py/tests/test_modeling_core.py @@ -27,5 +27,5 @@ def test_transform_sql(): rewritten_sql = wren_core.transform_sql(manifest_str, sql) assert ( rewritten_sql - == 'SELECT "customer"."custkey", "customer"."name" FROM (SELECT "customer"."custkey", "customer"."name" FROM (SELECT "main"."customer"."c_custkey" AS "custkey", "main"."customer"."c_name" AS "name" FROM "main"."customer") AS "customer") AS "customer"' + == 'SELECT customer.custkey, customer."name" FROM (SELECT customer.custkey, customer."name" FROM (SELECT main.customer.c_custkey AS custkey, main.customer.c_name AS "name" FROM main.customer) AS customer) AS customer' ) diff --git a/wren-modeling-rs/Cargo.toml b/wren-modeling-rs/Cargo.toml index 49cefedd8..281f874c0 100644 --- a/wren-modeling-rs/Cargo.toml +++ b/wren-modeling-rs/Cargo.toml @@ -13,9 +13,8 @@ rust-version = "1.78" version = "0.1.0" [workspace.dependencies] -arrow-schema = { version = "51.0.0", default-features = false } async-trait = "0.1.80" -datafusion = { version = "38.0.0" } +datafusion = { version = "39.0.0" } env_logger = "0.11.3" log = { version = "0.4.14" } petgraph = "0.6.5" diff --git a/wren-modeling-rs/core/Cargo.toml b/wren-modeling-rs/core/Cargo.toml index 570fe795b..35b044e20 100644 --- a/wren-modeling-rs/core/Cargo.toml +++ b/wren-modeling-rs/core/Cargo.toml @@ -13,7 +13,6 @@ name = "wren_core" path = "src/lib.rs" [dependencies] -arrow-schema = { workspace = true } async-trait = { workspace = true } datafusion = { workspace = true } env_logger = { workspace = true } diff --git a/wren-modeling-rs/core/src/logical_plan/analyze/plan.rs b/wren-modeling-rs/core/src/logical_plan/analyze/plan.rs index 36207635f..973f525ad 100644 --- a/wren-modeling-rs/core/src/logical_plan/analyze/plan.rs +++ 
b/wren-modeling-rs/core/src/logical_plan/analyze/plan.rs @@ -1,15 +1,20 @@ +use datafusion::arrow::datatypes::Field; use std::cmp::Ordering; -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, VecDeque}; use std::fmt; use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::sync::Arc; -use arrow_schema::Field; -use datafusion::common::{Column, DFSchema, DFSchemaRef, TableReference}; +use datafusion::common::{ + internal_err, not_impl_err, plan_err, Column, DFSchema, DFSchemaRef, TableReference, +}; +use datafusion::error::Result; +use datafusion::logical_expr::utils::find_aggregate_exprs; use datafusion::logical_expr::{ col, Expr, Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNodeCore, }; +use petgraph::graph::NodeIndex; use petgraph::Graph; use crate::logical_plan::analyze::plan::RelationChain::Start; @@ -19,7 +24,8 @@ use crate::mdl; use crate::mdl::lineage::DatasetLink; use crate::mdl::manifest::{JoinType, Model}; use crate::mdl::utils::{ - create_remote_expr_for_model, create_wren_expr_for_model, is_dag, + create_remote_expr_for_model, create_wren_calculated_field_expr, + create_wren_expr_for_model, is_dag, }; use crate::mdl::{AnalyzedWrenMDL, ColumnReference, Dataset}; @@ -41,20 +47,21 @@ impl ModelPlanNode { required_fields: Vec, original_table_scan: Option, analyzed_wren_mdl: Arc, - ) -> Self { + ) -> Result { let mut required_exprs_buffer = BTreeSet::new(); let mut directed_graph: Graph = Graph::new(); let mut model_required_fields: HashMap> = HashMap::new(); + let mut required_calculation: Vec = vec![]; + let mut fields = VecDeque::new(); let model_ref = TableReference::full( analyzed_wren_mdl.wren_mdl().catalog(), analyzed_wren_mdl.wren_mdl().schema(), model.name(), ); - let fields = model - .get_physical_columns() - .iter() - .filter(|column| { + + let required_columns = + model.get_physical_columns().into_iter().filter(|column| { required_fields.iter().any(|expr| { if let 
Expr::Column(column_expr) = expr { column_expr.name.as_str() == column.name() @@ -62,172 +69,213 @@ impl ModelPlanNode { false } }) - }) - .map(|column| { - if column.is_calculated { - if column.expression.is_some() { - let column_rf = analyzed_wren_mdl - .wren_mdl() - .get_column_reference(&from_qualified_name( - &analyzed_wren_mdl.wren_mdl(), - model.name(), - column.name(), - )); - let expr = mdl::utils::create_wren_calculated_field_expr( - column_rf.clone(), - Arc::clone(&analyzed_wren_mdl), - ); - let expr_plan = expr.alias(column.name()); - required_exprs_buffer.insert(OrdExpr::new(expr_plan)); - } - else { - panic!("Only support calculated field with expression") + }); + + for column in required_columns { + if column.is_calculated { + let expr = if column.expression.is_some() { + let column_rf = analyzed_wren_mdl.wren_mdl().get_column_reference( + &from_qualified_name( + &analyzed_wren_mdl.wren_mdl(), + model.name(), + column.name(), + ), + ); + let Some(column_rf) = column_rf else { + return plan_err!("Column reference not found for {:?}", column); }; + let expr = create_wren_calculated_field_expr( + column_rf, + Arc::clone(&analyzed_wren_mdl), + )?; + let expr_plan = expr.alias(column.name()); + expr_plan + } else { + return plan_err!("Only support calculated field with expression"); + }; + + let qualified_column = from_qualified_name( + &analyzed_wren_mdl.wren_mdl(), + model.name(), + column.name(), + ); + + let Some(column_graph) = analyzed_wren_mdl + .lineage() + .required_dataset_topo + .get(&qualified_column) + else { + return plan_err!( + "Required dataset not found for {}", + qualified_column + ); + }; - let qualified_column = from_qualified_name( - &analyzed_wren_mdl.wren_mdl(), - model.name(), + if !find_aggregate_exprs(&[expr.clone()]).is_empty() { + // The calculation column is provided by the CalculationPlanNode. 
+ required_exprs_buffer.insert(OrdExpr::new(col(format!( + "{}.{}", column.name(), + column.name() + )))); + + let column_rf = analyzed_wren_mdl + .wren_mdl() + .get_column_reference(&qualified_column); + let mut partial_model_required_fields = HashMap::new(); + let _ = collect_model_required_fields( + qualified_column, + Arc::clone(&analyzed_wren_mdl), + &mut partial_model_required_fields, ); - match analyzed_wren_mdl - .lineage() - .required_dataset_topo - .get(&qualified_column) - { - Some(column_graph) => { - merge_graph(&mut directed_graph, column_graph); - } - None => { - panic!("Column {} not found in the lineage", qualified_column) - } - } - analyzed_wren_mdl - .lineage() - .required_fields_map - .get(&qualified_column) - .unwrap() - .iter() - .for_each(|c| { - let Some(relation_ref) = &c.relation else { - panic!("Source dataset not found for {}", c) - }; - let ColumnReference {dataset, column} = analyzed_wren_mdl.wren_mdl().get_column_reference(c); - if let Dataset::Model(m) = dataset { - if column.is_calculated { - let expr_plan = if let Some(expression) = &column.expression { - create_wren_expr_for_model( - expression, - Arc::clone(&m), - Arc::clone(&analyzed_wren_mdl)) - } else { - panic!("Only support calculated field with expression") - }.alias(column.name.clone()); - model_required_fields - .entry(relation_ref.clone()) - .or_default() - .insert(OrdExpr::new(expr_plan)); - } - else { - let expr_plan = get_remote_column_exp( - &column, - Arc::clone(&m), - Arc::clone(&analyzed_wren_mdl)); - model_required_fields - .entry(relation_ref.clone()) - .or_default() - .insert(OrdExpr::new(expr_plan)); - } - } - else { - panic!("Only support model as source dataset") - }; - }); + let mut iter = column_graph.node_indices(); + + let start = iter.next().unwrap(); + let source_required_fields = partial_model_required_fields + .get(&model_ref) + .map(|c| c.iter().cloned().map(|c| c.expr).collect()) + .unwrap_or_default(); + let source = 
column_graph.node_weight(start).unwrap(); + + let source_chain = RelationChain::source( + source, + source_required_fields, + Arc::clone(&analyzed_wren_mdl), + )?; + + let partial_chain = RelationChain::with_chain( + source_chain, + start, + iter, + column_graph.clone(), + &partial_model_required_fields, + Arc::clone(&analyzed_wren_mdl), + )?; + + let Some(column_rf) = column_rf else { + return plan_err!("Column reference not found for {:?}", column); + }; + let calculation = CalculationPlanNode::new( + column_rf, + expr, + partial_chain, + Arc::clone(&analyzed_wren_mdl), + )?; + required_calculation.push(calculation); } else { - let expr_plan = get_remote_column_exp(column, Arc::clone(&model), Arc::clone(&analyzed_wren_mdl)); - model_required_fields - .entry(model_ref.clone()) - .or_default() - .insert(OrdExpr::new(expr_plan.clone())); - let expr_plan = Expr::Column(Column::from_qualified_name(format!("{}.{}", model_ref.table(), column.name()))); - required_exprs_buffer.insert(OrdExpr::new(expr_plan.clone())); + required_exprs_buffer.insert(OrdExpr::new(expr.clone())); + merge_graph(&mut directed_graph, column_graph)?; + let _ = collect_model_required_fields( + qualified_column, + Arc::clone(&analyzed_wren_mdl), + &mut model_required_fields, + ); } - ( - Some(TableReference::bare(model.name())), - Arc::new(Field::new( - column.name(), - map_data_type(&column.r#type), - column.no_null, - )), - ) - }) - .collect(); + } else { + let expr_plan = get_remote_column_exp( + &column, + Arc::clone(&model), + Arc::clone(&analyzed_wren_mdl), + )?; + model_required_fields + .entry(model_ref.clone()) + .or_default() + .insert(OrdExpr::new(expr_plan.clone())); + let expr_plan = Expr::Column(Column::from_qualified_name(format!( + "{}.{}", + model_ref.table(), + column.name() + ))); + required_exprs_buffer.insert(OrdExpr::new(expr_plan.clone())); + } + fields.push_front(( + Some(TableReference::bare(model.name())), + Arc::new(Field::new( + column.name(), + 
map_data_type(&column.r#type)?, + column.no_null, + )), + )); + } directed_graph.add_node(Dataset::Model(Arc::clone(&model))); if !is_dag(&directed_graph) { - panic!("cyclic dependency detected: {}", model.name()); + return plan_err!("cyclic dependency detected: {}", model.name()); } let schema_ref = DFSchemaRef::new( - DFSchema::new_with_metadata(fields, HashMap::new()) + DFSchema::new_with_metadata(fields.into_iter().collect(), HashMap::new()) .expect("create schema failed"), ); let mut iter = directed_graph.node_indices(); - let mut start = iter.next().unwrap(); - let source = directed_graph.node_weight(start).unwrap(); - let source_required_fields: Vec = model_required_fields + let Some(start) = iter.next() else { + return internal_err!("Model not found"); + }; + let Some(source) = directed_graph.node_weight(start) else { + return internal_err!("Dataset not found"); + }; + let mut source_required_fields: Vec = model_required_fields .get(&model_ref) .map(|c| c.iter().cloned().map(|c| c.expr).collect()) .unwrap_or_default(); - let mut relation_chain = RelationChain::source( - source, - source_required_fields, - Arc::clone(&analyzed_wren_mdl), - ); + let mut calculate_iter = required_calculation.iter(); - for next in iter { - let target = directed_graph.node_weight(next).unwrap(); - let Some(link_index) = directed_graph.find_edge(start, next) else { - break; + let source_chain = + if !source_required_fields.is_empty() || required_fields.is_empty() { + if required_fields.is_empty() { + source_required_fields.insert(0, Expr::Wildcard { qualifier: None }); + } + RelationChain::source( + source, + source_required_fields, + Arc::clone(&analyzed_wren_mdl), + )? 
+ } else { + let Some(first_calculation) = calculate_iter.next() else { + return plan_err!("Calculation not found and no any required field"); + }; + Start(LogicalPlan::Extension(Extension { + node: Arc::new(first_calculation.clone()), + })) }; - let link = directed_graph.edge_weight(link_index).unwrap(); - let target_ref = TableReference::full( - analyzed_wren_mdl.wren_mdl().catalog(), - analyzed_wren_mdl.wren_mdl().schema(), - target.name(), + + let mut relation_chain = RelationChain::with_chain( + source_chain, + start, + iter, + directed_graph, + &model_required_fields, + Arc::clone(&analyzed_wren_mdl), + )?; + + for calculation_plan in calculate_iter { + let target_ref = + TableReference::bare(calculation_plan.calculation.column.name()); + let Some(join_key) = model.primary_key() else { + return plan_err!( + "Model {} should have primary key for calculation", + model.name() + ); + }; + relation_chain = RelationChain::Chain( + LogicalPlan::Extension(Extension { + node: Arc::new(calculation_plan.clone()), + }), + JoinType::OneToOne, + format!( + "{}.{} = {}.{}", + model_ref.table(), + join_key, + target_ref.table(), + join_key, + ), + Box::new(relation_chain), ); - match target { - Dataset::Model(target_model) => { - relation_chain = RelationChain::Chain( - LogicalPlan::Extension(Extension { - node: Arc::new(ModelSourceNode::new( - Arc::clone(target_model), - model_required_fields - .get(&target_ref) - .unwrap() - .iter() - .cloned() - .map(|c| c.expr) - .collect(), - Arc::clone(&analyzed_wren_mdl), - None, - )), - }), - link.join_type, - link.condition.clone(), - Box::new(relation_chain), - ); - } - _ => { - unimplemented!("Only support model as source dataset") - } - } - start = next; } - - Self { + Ok(Self { model_name: model.name.clone(), required_exprs: required_exprs_buffer .into_iter() @@ -236,25 +284,81 @@ impl ModelPlanNode { relation_chain: Box::new(relation_chain), schema_ref, original_table_scan, - } + }) + } +} + +fn collect_model_required_fields( + 
qualified_column: Column, + analyzed_wren_mdl: Arc, + model_required_fields: &mut HashMap>, +) -> Result<()> { + let Some(set) = analyzed_wren_mdl + .lineage() + .required_fields_map + .get(&qualified_column) + else { + return plan_err!("Required fields not found for {}", qualified_column); + }; + + for c in set { + let Some(relation_ref) = &c.relation else { + return plan_err!("Source dataset not found for {}", c); + }; + let Some(ColumnReference { dataset, column }) = + analyzed_wren_mdl.wren_mdl().get_column_reference(c) + else { + return plan_err!("Column reference not found for {}", c); + }; + if let Dataset::Model(m) = dataset { + if column.is_calculated { + let expr_plan = if let Some(expression) = &column.expression { + create_wren_expr_for_model( + expression, + Arc::clone(&m), + Arc::clone(&analyzed_wren_mdl), + )? + } else { + return plan_err!("Only support calculated field with expression"); + } + .alias(column.name.clone()); + model_required_fields + .entry(relation_ref.clone()) + .or_default() + .insert(OrdExpr::new(expr_plan)); + } else { + let expr_plan = get_remote_column_exp( + &column, + Arc::clone(&m), + Arc::clone(&analyzed_wren_mdl), + )?; + model_required_fields + .entry(relation_ref.clone()) + .or_default() + .insert(OrdExpr::new(expr_plan)); + } + } else { + return plan_err!("Only support model as source dataset"); + }; } + Ok(()) } fn get_remote_column_exp( column: &mdl::manifest::Column, model: Arc, analyzed_wren_mdl: Arc, -) -> Expr { +) -> Result { let expr = if let Some(expression) = &column.expression { - create_remote_expr_for_model(expression, model, analyzed_wren_mdl) + create_remote_expr_for_model(expression, model, analyzed_wren_mdl)? } else { - create_remote_expr_for_model(&column.name, model, analyzed_wren_mdl) + create_remote_expr_for_model(&column.name, model, analyzed_wren_mdl)? 
}; - expr.alias(column.name.clone()) + Ok(expr.alias(column.name.clone())) } #[derive(Eq, PartialEq, Debug, Hash, Clone)] -struct OrdExpr { +pub struct OrdExpr { expr: Expr, } @@ -285,7 +389,7 @@ impl From for Expr { fn merge_graph( graph: &mut Graph, new_graph: &Graph, -) { +) -> Result<()> { let mut node_map = HashMap::new(); for node in new_graph.node_indices() { let new_node = graph.add_node(new_graph[node].clone()); @@ -293,11 +397,14 @@ fn merge_graph( } for edge in new_graph.edge_indices() { - let (source, target) = new_graph.edge_endpoints(edge).unwrap(); + let Some((source, target)) = new_graph.edge_endpoints(edge) else { + return internal_err!("Edge not found"); + }; let source = node_map.get(&source).unwrap(); let target = node_map.get(&target).unwrap(); graph.add_edge(*source, *target, new_graph[edge].clone()); } + Ok(()) } /// RelationChain is a chain of models that are connected by the relationship. @@ -305,7 +412,7 @@ fn merge_graph( /// The physical layout will be looked like: /// (((Model3, Model2), Model1), Nil) #[derive(Eq, PartialEq, Debug, Hash, Clone)] -pub(crate) enum RelationChain { +pub enum RelationChain { Chain(LogicalPlan, JoinType, String, Box), Start(LogicalPlan), } @@ -315,23 +422,78 @@ impl RelationChain { dataset: &Dataset, required_fields: Vec, analyzed_wren_mdl: Arc, - ) -> Self { + ) -> Result { match dataset { - Dataset::Model(source_model) => Start(LogicalPlan::Extension(Extension { - node: Arc::new(ModelSourceNode::new( - Arc::clone(source_model), - required_fields, - analyzed_wren_mdl, - None, - )), - })), + Dataset::Model(source_model) => { + Ok(Start(LogicalPlan::Extension(Extension { + node: Arc::new(ModelSourceNode::new( + Arc::clone(source_model), + required_fields, + analyzed_wren_mdl, + None, + )?), + }))) + } _ => { - unimplemented!("Only support model as source dataset") + not_impl_err!("Only support model as source dataset") + } + } + } + + pub fn with_chain( + source: Self, + mut start: NodeIndex, + iter: impl 
Iterator, + directed_graph: Graph, + model_required_fields: &HashMap>, + analyzed_wren_mdl: Arc, + ) -> Result { + let mut relation_chain = source; + + for next in iter { + let target = directed_graph.node_weight(next).unwrap(); + let Some(link_index) = directed_graph.find_edge(start, next) else { + break; + }; + let link = directed_graph.edge_weight(link_index).unwrap(); + let target_ref = TableReference::full( + analyzed_wren_mdl.wren_mdl().catalog(), + analyzed_wren_mdl.wren_mdl().schema(), + target.name(), + ); + match target { + Dataset::Model(target_model) => { + relation_chain = RelationChain::Chain( + LogicalPlan::Extension(Extension { + node: Arc::new(ModelSourceNode::new( + Arc::clone(target_model), + model_required_fields + .get(&target_ref) + .unwrap() + .iter() + .cloned() + .map(|c| c.expr) + .collect(), + Arc::clone(&analyzed_wren_mdl), + None, + )?), + }), + link.join_type, + link.condition.clone(), + Box::new(relation_chain), + ); + } + _ => return plan_err!("Only support model as source dataset"), } + start = next; } + Ok(relation_chain) } - pub(crate) fn plan(&mut self, rule: ModelGenerationRule) -> Option { + pub(crate) fn plan( + &mut self, + rule: ModelGenerationRule, + ) -> Result> { match self { RelationChain::Chain(plan, _, condition, ref mut next) => { let left = rule @@ -344,10 +506,9 @@ impl RelationChain { .map(|c| col(c.flat_name())) .collect(); let join_condition = join_keys[0].clone().eq(join_keys[1].clone()); - let Some(right) = next.plan(rule) else { - panic!("Nil relation chain") + let Some(right) = next.plan(rule)? 
else { + return plan_err!("Nil relation chain"); }; - let mut required_exprs = BTreeSet::new(); // collect the output calculated fields match plan { @@ -384,16 +545,32 @@ impl RelationChain { .for_each(|c| { required_exprs.insert(OrdExpr::new(c)); }); + } else if let Some(calculation_plan) = + plan.node.as_any().downcast_ref::() + { + UserDefinedLogicalNodeCore::schema(calculation_plan) + .fields() + .iter() + .map(|field| { + col(format!( + "{}.{}", + calculation_plan.calculation.column.name(), + field.name() + )) + }) + .for_each(|c| { + required_exprs.insert(OrdExpr::new(c)); + }); } else { - panic!("Invalid extension plan node") + return plan_err!("Invalid extension plan node"); } } - _ => panic!(""), + _ => return internal_err!("Invalid plan node"), }; // collect the column of the left table for index in 0..left.schema().fields().len() { let (Some(table_rf), f) = left.schema().qualified_field(index) else { - panic!("Field not found") + return plan_err!("Field not found"); }; let qualified_name = format!("{}.{}", table_rf, f.name()); required_exprs.insert(OrdExpr::new(col(qualified_name))); @@ -403,7 +580,7 @@ impl RelationChain { for index in 0..right.schema().fields().len() { let (Some(table_rf), f) = right.schema().qualified_field(index) else { - panic!("Field not found") + return plan_err!("Field not found"); }; let qualified_name = format!("{}.{}", table_rf, f.name()); required_exprs.insert(OrdExpr::new(col(qualified_name))); @@ -414,7 +591,7 @@ impl RelationChain { .map(|expr| expr.expr.clone()) .collect(); - Some( + Ok(Some( LogicalPlanBuilder::from(left) .join_on( right, @@ -426,13 +603,13 @@ impl RelationChain { .unwrap() .build() .unwrap(), - ) + )) } - Start(plan) => Some( + Start(plan) => Ok(Some( rule.generate_model_internal(plan.clone()) .expect("Failed to generate model plan") .data, - ), + )), } } } @@ -462,14 +639,18 @@ impl UserDefinedLogicalNodeCore for ModelPlanNode { write!(f, "Model: name={}", self.model_name) } - fn from_template(&self, 
_: &[Expr], _: &[LogicalPlan]) -> Self { - ModelPlanNode { + fn with_exprs_and_inputs( + &self, + _: Vec, + _: Vec, + ) -> datafusion::common::Result { + Ok(ModelPlanNode { model_name: self.model_name.clone(), required_exprs: self.required_exprs.clone(), relation_chain: self.relation_chain.clone(), schema_ref: self.schema_ref.clone(), original_table_scan: self.original_table_scan.clone(), - } + }) } } @@ -490,53 +671,89 @@ impl ModelSourceNode { required_exprs: Vec, analyzed_wren_mdl: Arc, original_table_scan: Option, - ) -> Self { + ) -> Result { let mut required_exprs_buffer = BTreeSet::new(); - let fields = required_exprs - .iter() - .map(|field| { - let column = model - .get_physical_columns() - .into_iter() - .find(|column| match field { - Expr::Column(c) => c.name.as_str() == column.name(), - Expr::Alias(alias) => alias.name.as_str() == column.name(), - _ => panic!("Invalid field expression"), - }) - .unwrap_or_else(|| panic!("Field not found {}", field)); + let mut fields_buffer = BTreeSet::new(); + for expr in required_exprs.iter() { + if let Expr::Wildcard { qualifier } = expr { + let model = if let Some(model) = qualifier { + let Some(model) = analyzed_wren_mdl.wren_mdl.get_model(model) else { + return plan_err!("Model not found {}", &model); + }; + model + } else { + Arc::clone(&model) + }; + for column in model.get_physical_columns().into_iter() { + // skip the calculated field + if column.is_calculated { + continue; + } + fields_buffer.insert(( + Some(TableReference::bare(model.name())), + Arc::new(Field::new( + column.name(), + map_data_type(&column.r#type)?, + column.no_null, + )), + )); + required_exprs_buffer.insert(OrdExpr::new(get_remote_column_exp( + &column, + Arc::clone(&model), + Arc::clone(&analyzed_wren_mdl), + )?)); + } + } else { + let Some(column) = + model + .get_physical_columns() + .into_iter() + .find(|column| match expr { + Expr::Column(c) => c.name.as_str() == column.name(), + Expr::Alias(alias) => alias.name.as_str() == 
column.name(), + _ => false, + }) + else { + return plan_err!("Field not found {}", expr); + }; if column.is_calculated { - panic!("should not use calculated field in source plan") + return plan_err!("should not use calculated field in source plan"); } else { let expr_plan = get_remote_column_exp( &column, Arc::clone(&model), Arc::clone(&analyzed_wren_mdl), - ); + )?; required_exprs_buffer.insert(OrdExpr::new(expr_plan.clone())); } - ( + + fields_buffer.insert(( Some(TableReference::bare(model.name())), Arc::new(Field::new( column.name(), - map_data_type(&column.r#type), + map_data_type(&column.r#type)?, column.no_null, )), - ) - }) - .collect(); + )); + } + } + let fields = fields_buffer.into_iter().collect::>(); let schema_ref = DFSchemaRef::new( DFSchema::new_with_metadata(fields, HashMap::new()) .expect("create schema failed"), ); - - ModelSourceNode { + let required_exprs = required_exprs_buffer + .into_iter() + .map(|e| e.expr) + .collect::>(); + Ok(ModelSourceNode { model_name: model.name().to_string(), required_exprs, schema_ref, original_table_scan, - } + }) } } @@ -565,12 +782,116 @@ impl UserDefinedLogicalNodeCore for ModelSourceNode { write!(f, "ModelSource: name={}", self.model_name) } - fn from_template(&self, _: &[Expr], _: &[LogicalPlan]) -> Self { - ModelSourceNode { + fn with_exprs_and_inputs( + &self, + _: Vec, + _: Vec, + ) -> datafusion::common::Result { + Ok(ModelSourceNode { model_name: self.model_name.clone(), required_exprs: self.required_exprs.clone(), schema_ref: self.schema_ref.clone(), original_table_scan: self.original_table_scan.clone(), - } + }) + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct CalculationPlanNode { + pub calculation: ColumnReference, + pub relation_chain: RelationChain, + pub dimensions: Vec, + pub measures: Vec, + schema_ref: DFSchemaRef, +} + +impl CalculationPlanNode { + pub fn new( + calculation: ColumnReference, + calculation_expr: Expr, + relation_chain: RelationChain, + analyzed_wren_mdl: 
Arc, + ) -> Result { + let Some(model) = calculation.dataset.try_as_model() else { + return plan_err!("Only support model as source dataset"); + }; + let Some(pk_column) = model.primary_key().and_then(|pk| model.get_column(pk)) + else { + return plan_err!("Primary key not found"); + }; + + // include calculation column and join key (pk) + let output_field = vec![ + Arc::new(Field::new( + calculation.column.name(), + map_data_type(&calculation.column.r#type)?, + calculation.column.no_null, + )), + Arc::new(Field::new( + pk_column.name(), + map_data_type(&pk_column.r#type)?, + pk_column.no_null, + )), + ] + .into_iter() + .map(|f| (Some(TableReference::bare(model.name())), f)) + .collect(); + let dimensions = vec![create_wren_expr_for_model( + &pk_column.name, + Arc::clone(&model), + Arc::clone(&analyzed_wren_mdl), + )? + .alias(pk_column.name())]; + let schema_ref = DFSchemaRef::new( + DFSchema::new_with_metadata(output_field, HashMap::new()) + .expect("create schema failed"), + ); + Ok(Self { + calculation, + relation_chain, + dimensions, + measures: vec![calculation_expr], + schema_ref, + }) + } +} + +impl UserDefinedLogicalNodeCore for CalculationPlanNode { + fn name(&self) -> &str { + "Calculation" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema_ref + } + + fn expressions(&self) -> Vec { + self.schema_ref + .fields() + .iter() + .map(|field| col(field.name())) + .collect() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "Calculation: name={}", self.calculation.column.name) + } + + fn with_exprs_and_inputs( + &self, + _: Vec, + _: Vec, + ) -> datafusion::common::Result { + Ok(CalculationPlanNode { + calculation: self.calculation.clone(), + relation_chain: self.relation_chain.clone(), + dimensions: self.dimensions.clone(), + measures: self.measures.clone(), + schema_ref: self.schema_ref.clone(), + }) } } diff --git 
a/wren-modeling-rs/core/src/logical_plan/analyze/rule.rs b/wren-modeling-rs/core/src/logical_plan/analyze/rule.rs index d7a7f45de..040cfc394 100644 --- a/wren-modeling-rs/core/src/logical_plan/analyze/rule.rs +++ b/wren-modeling-rs/core/src/logical_plan/analyze/rule.rs @@ -4,14 +4,16 @@ use std::sync::Arc; use datafusion::common::config::ConfigOptions; use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion::common::Result; +use datafusion::common::{plan_err, Result}; use datafusion::logical_expr::logical_plan::tree_node::unwrap_arc; -use datafusion::logical_expr::{col, utils, Extension}; +use datafusion::logical_expr::{col, ident, utils, Extension}; use datafusion::logical_expr::{Expr, Join, LogicalPlan, LogicalPlanBuilder, TableScan}; use datafusion::optimizer::analyzer::AnalyzerRule; use datafusion::sql::TableReference; -use crate::logical_plan::analyze::plan::{ModelPlanNode, ModelSourceNode}; +use crate::logical_plan::analyze::plan::{ + CalculationPlanNode, ModelPlanNode, ModelSourceNode, +}; use crate::logical_plan::utils::create_remote_table_source; use crate::mdl::manifest::Model; use crate::mdl::{AnalyzedWrenMDL, WrenMDL}; @@ -85,7 +87,7 @@ impl ModelAnalyzeRule { field, Some(LogicalPlan::TableScan(table_scan.clone())), Arc::clone(&self.analyzed_wren_mdl), - )), + )?), }); used_columns.borrow_mut().clear(); Ok(Transformed::yes(model)) @@ -112,7 +114,7 @@ impl ModelAnalyzeRule { Arc::clone(&self.analyzed_wren_mdl), table_scan, buffer.iter().cloned().collect(), - ), + )?, ignore => ignore, }; @@ -121,7 +123,7 @@ impl ModelAnalyzeRule { Arc::clone(&self.analyzed_wren_mdl), table_scan, buffer.iter().cloned().collect(), - ), + )?, ignore => ignore, }; buffer.clear(); @@ -155,22 +157,22 @@ fn analyze_table_scan( analyzed_wren_mdl: Arc, table_scan: TableScan, required_field: Vec, -) -> LogicalPlan { +) -> Result { if belong_to_mdl(&analyzed_wren_mdl.wren_mdl(), table_scan.table_name.clone()) { - 
LogicalPlan::TableScan(table_scan) + Ok(LogicalPlan::TableScan(table_scan)) } else { let table_name = table_scan.table_name.table(); if let Some(model) = analyzed_wren_mdl.wren_mdl.get_model(table_name) { - LogicalPlan::Extension(Extension { + Ok(LogicalPlan::Extension(Extension { node: Arc::new(ModelPlanNode::new( model, required_field, Some(LogicalPlan::TableScan(table_scan.clone())), Arc::clone(&analyzed_wren_mdl), - )), - }) + )?), + })) } else { - LogicalPlan::TableScan(table_scan) + Ok(LogicalPlan::TableScan(table_scan)) } } } @@ -210,13 +212,9 @@ impl ModelGenerationRule { if let Some(model_plan) = extension.node.as_any().downcast_ref::() { - let source_plan = - model_plan - .relation_chain - .clone() - .plan(ModelGenerationRule::new(Arc::clone( - &self.analyzed_wren_mdl, - ))); + let source_plan = model_plan.relation_chain.clone().plan( + ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)), + )?; let model: Arc = Arc::clone( &self @@ -225,13 +223,12 @@ impl ModelGenerationRule { .get_model(&model_plan.model_name) .expect("Model not found"), ); - let result = match source_plan { Some(plan) => LogicalPlanBuilder::from(plan) .project(model_plan.required_exprs.clone())? .build()?, _ => { - panic!("Failed to generate source plan") + return plan_err!("Failed to generate source plan"); } }; // calculated field scope @@ -288,6 +285,39 @@ impl ModelGenerationRule { .alias(model.name.clone())? 
.build()?; Ok(Transformed::yes(result)) + } else if let Some(calculation_plan) = extension + .node + .as_any() + .downcast_ref::( + ) { + let source_plan = calculation_plan.relation_chain.clone().plan( + ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)), + )?; + + if let Expr::Alias(alias) = calculation_plan.measures[0].clone() { + let measure: Expr = *alias.expr.clone(); + let name = alias.name.clone(); + let ident = ident(measure.to_string()).alias(name); + let project = vec![calculation_plan.dimensions[0].clone(), ident]; + let result = match source_plan { + Some(plan) => LogicalPlanBuilder::from(plan) + .aggregate( + calculation_plan.dimensions.clone(), + vec![measure], + )? + .project(project)? + .build()?, + _ => { + return plan_err!("Failed to generate source plan"); + } + }; + let alias = LogicalPlanBuilder::from(result) + .alias(calculation_plan.calculation.column.name())? + .build()?; + Ok(Transformed::yes(alias)) + } else { + return plan_err!("measures should have an alias"); + } } else { Ok(Transformed::no(LogicalPlan::Extension(extension))) } @@ -403,20 +433,20 @@ mod test { .build(), ) .build(); - let analyzed_wren_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)); + let analyzed_wren_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); // [RemoveWrenPrefixRule] only remove the prefix of identifiers, so that the table name in // the expected result will have the schema prefix. 
let tests = vec![ ("select wrenai.default.a.c1, wrenai.default.a.c2 from wrenai.default.a", - r#"SELECT "a"."c1", "a"."c2" FROM "default"."a""#), + r#"SELECT a.c1, a.c2 FROM wrenai."default".a"#), ("select wrenai.default.a.c1, wrenai.default.a.c2 from wrenai.default.a where wrenai.default.a.c1 = 1", - r#"SELECT "a"."c1", "a"."c2" FROM "default"."a" WHERE ("a"."c1" = 1)"#), + r#"SELECT a.c1, a.c2 FROM wrenai."default".a WHERE (a.c1 = 1)"#), ("select wrenai.default.a.c1 + 1 from wrenai.default.a", - r#"SELECT ("a"."c1" + 1) FROM "default"."a""#) + r#"SELECT (a.c1 + 1) FROM wrenai."default".a"#) ]; - let context_provider = WrenContextProvider::new(&analyzed_wren_mdl.wren_mdl); + let context_provider = WrenContextProvider::new(&analyzed_wren_mdl.wren_mdl)?; let sql_to_rel = SqlToRel::new(&context_provider); let dialect = GenericDialect {}; let analyzer = Analyzer::with_rules(vec![Arc::new(RemoveWrenPrefixRule::new( diff --git a/wren-modeling-rs/core/src/logical_plan/context_provider.rs b/wren-modeling-rs/core/src/logical_plan/context_provider.rs index 24d0d3c57..09350505d 100644 --- a/wren-modeling-rs/core/src/logical_plan/context_provider.rs +++ b/wren-modeling-rs/core/src/logical_plan/context_provider.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, sync::Arc}; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::DefaultTableSource; use datafusion::logical_expr::builder::LogicalTableSource; use datafusion::{ @@ -23,29 +23,37 @@ pub struct WrenContextProvider { } impl WrenContextProvider { - pub fn new(mdl: &WrenMDL) -> Self { + pub fn new(mdl: &WrenMDL) -> Result { let mut tables = HashMap::new(); - mdl.manifest.models.iter().for_each(|model| { + // register model table + for model in mdl.manifest.models.iter() { tables.insert( format!("{}.{}.{}", mdl.catalog(), mdl.schema(), model.name()), - create_table_source(model), + create_table_source(model)?, ); - }); - 
Self { + } + // register physical table + for (name, table) in mdl.register_tables.iter() { + tables.insert( + name.clone(), + Arc::new(DefaultTableSource::new(table.clone())), + ); + } + Ok(Self { tables, options: Default::default(), - } + }) } - pub fn new_bare(mdl: &WrenMDL) -> Self { + pub fn new_bare(mdl: &WrenMDL) -> Result { let mut tables = HashMap::new(); - mdl.manifest.models.iter().for_each(|model| { - tables.insert(model.name().to_string(), create_table_source(model)); - }); - Self { + for model in mdl.manifest.models.iter() { + tables.insert(model.name().to_string(), create_table_source(model)?); + } + Ok(Self { tables, options: Default::default(), - } + }) } } @@ -78,15 +86,15 @@ impl ContextProvider for WrenContextProvider { &self.options } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { Vec::new() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { Vec::new() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { Vec::new() } } @@ -99,7 +107,7 @@ pub struct RemoteContextProvider { } impl RemoteContextProvider { - pub fn new(mdl: &WrenMDL) -> Self { + pub fn new(mdl: &WrenMDL) -> Result { let tables = mdl .manifest .models @@ -109,15 +117,15 @@ impl RemoteContextProvider { let datasource = if let Some(table_provider) = remove_provider { Arc::new(DefaultTableSource::new(table_provider)) } else { - create_remote_table_source(model, mdl) + create_remote_table_source(model, mdl)? 
}; - (model.table_reference.clone(), datasource) + Ok((model.table_reference.clone(), datasource)) }) - .collect::>(); - Self { + .collect::>>()?; + Ok(Self { tables, options: Default::default(), - } + }) } } @@ -149,35 +157,43 @@ impl ContextProvider for RemoteContextProvider { &self.options } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { Vec::new() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { Vec::new() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { Vec::new() } } -fn create_remote_table_source(model: &Model, wren_mdl: &WrenMDL) -> Arc { +fn create_remote_table_source( + model: &Model, + wren_mdl: &WrenMDL, +) -> Result> { if let Some(table_provider) = wren_mdl.get_table(&model.table_reference) { - Arc::new(DefaultTableSource::new(table_provider)) + Ok(Arc::new(DefaultTableSource::new(table_provider))) } else { let schema = create_schema(model.get_physical_columns()); - Arc::new(LogicalTableSource::new(schema)) + Ok(Arc::new(LogicalTableSource::new(schema?))) } } -fn create_schema(columns: Vec>) -> SchemaRef { +fn create_schema(columns: Vec>) -> Result { let fields: Vec = columns .iter() .filter(|c| !c.is_calculated) .flat_map(|column| { if column.expression.is_none() { - let data_type = map_data_type(&column.r#type); + let data_type = if let Ok(data_type) = map_data_type(&column.r#type) { + data_type + } else { + // TODO optimize to use Datafusion's error type + unimplemented!("Unsupported data type: {}", column.r#type) + }; vec![Field::new(&column.name, data_type, column.no_null)] } else { utils::collect_identifiers(column.expression.as_ref().unwrap()) @@ -191,7 +207,10 @@ fn create_schema(columns: Vec>) -> SchemaRef { } }) .collect(); - SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())) + Ok(SchemaRef::new(Schema::new_with_metadata( + fields, + HashMap::new(), + ))) } pub(crate) struct DynamicContextProvider { @@ -229,15 +248,15 @@ impl ContextProvider for DynamicContextProvider { 
self.delegate.options() } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { Vec::new() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { Vec::new() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { Vec::new() } } diff --git a/wren-modeling-rs/core/src/logical_plan/utils.rs b/wren-modeling-rs/core/src/logical_plan/utils.rs index c485d109a..ae9eb4045 100644 --- a/wren-modeling-rs/core/src/logical_plan/utils.rs +++ b/wren-modeling-rs/core/src/logical_plan/utils.rs @@ -1,7 +1,9 @@ +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use datafusion::common::not_impl_err; use std::{collections::HashMap, sync::Arc}; -use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datafusion::datasource::DefaultTableSource; +use datafusion::error::Result; use datafusion::logical_expr::{builder::LogicalTableSource, TableSource}; use petgraph::dot::{Config, Dot}; use petgraph::Graph; @@ -12,32 +14,36 @@ use crate::mdl::{ Dataset, WrenMDL, }; -pub fn map_data_type(data_type: &str) -> DataType { - match data_type { +pub fn map_data_type(data_type: &str) -> Result { + let result = match data_type { "integer" => DataType::Int32, "bigint" => DataType::Int64, "varchar" => DataType::Utf8, "double" => DataType::Float64, "timestamp" => DataType::Timestamp(TimeUnit::Nanosecond, None), "date" => DataType::Date32, - _ => unimplemented!("{}", &data_type), - } + _ => return not_impl_err!("Unsupported data type: {}", &data_type), + }; + Ok(result) } -pub fn create_table_source(model: &Model) -> Arc { - let schema = create_schema(model.get_physical_columns()); - Arc::new(LogicalTableSource::new(schema)) +pub fn create_table_source(model: &Model) -> Result> { + let schema = create_schema(model.get_physical_columns())?; + Ok(Arc::new(LogicalTableSource::new(schema))) } -pub fn create_schema(columns: Vec>) -> SchemaRef { +pub fn create_schema(columns: Vec>) -> Result { let fields: Vec = columns .iter() 
.map(|column| { - let data_type = map_data_type(&column.r#type); - Field::new(&column.name, data_type, column.no_null) + let data_type = map_data_type(&column.r#type)?; + Ok(Field::new(&column.name, data_type, column.no_null)) }) - .collect(); - SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())) + .collect::>>()?; + Ok(SchemaRef::new(Schema::new_with_metadata( + fields, + HashMap::new(), + ))) } pub fn create_remote_table_source(model: &Model, mdl: &WrenMDL) -> Arc { diff --git a/wren-modeling-rs/core/src/mdl/builder.rs b/wren-modeling-rs/core/src/mdl/builder.rs index 8e0f338ff..1a2a4f03f 100644 --- a/wren-modeling-rs/core/src/mdl/builder.rs +++ b/wren-modeling-rs/core/src/mdl/builder.rs @@ -78,7 +78,7 @@ impl ModelBuilder { base_object: "".to_string(), table_reference: "".to_string(), columns: vec![], - primary_key: "".to_string(), + primary_key: None, cached: false, refresh_time: "".to_string(), properties: vec![], @@ -107,7 +107,7 @@ impl ModelBuilder { } pub fn primary_key(mut self, primary_key: &str) -> Self { - self.model.primary_key = primary_key.to_string(); + self.model.primary_key = Some(primary_key.to_string()); self } diff --git a/wren-modeling-rs/core/src/mdl/lineage.rs b/wren-modeling-rs/core/src/mdl/lineage.rs index df4e959be..88a6d100f 100644 --- a/wren-modeling-rs/core/src/mdl/lineage.rs +++ b/wren-modeling-rs/core/src/mdl/lineage.rs @@ -1,9 +1,9 @@ -use core::panic; use std::collections::{HashMap, HashSet}; use std::fmt::Display; use std::sync::Arc; -use datafusion::common::Column; +use datafusion::common::{internal_err, plan_err, Column}; +use datafusion::error::Result; use datafusion::sql::TableReference; use petgraph::Graph; @@ -21,17 +21,17 @@ pub struct Lineage { } impl Lineage { - pub fn new(mdl: &WrenMDL) -> Self { + pub fn new(mdl: &WrenMDL) -> Result { let source_columns_map = Lineage::collect_source_columns(mdl); let RequiredInfo { required_fields_map, required_dataset_topo, - } = Lineage::collect_required_fields(mdl, 
&source_columns_map); - Lineage { + } = Lineage::collect_required_fields(mdl, &source_columns_map)?; + Ok(Lineage { source_columns_map, required_fields_map, required_dataset_topo, - } + }) } fn collect_source_columns(mdl: &WrenMDL) -> HashMap> { @@ -73,168 +73,160 @@ impl Lineage { fn collect_required_fields( mdl: &WrenMDL, source_colums_map: &HashMap>, - ) -> RequiredInfo { + ) -> Result { let mut required_fields_map: HashMap> = HashMap::new(); let mut required_dataset_topo: HashMap> = HashMap::new(); - source_colums_map - .iter() - .for_each(|(column, source_columns)| { - let Some(relation) = column.clone().relation else { - return; - }; - let mut current_relation = match relation { - TableReference::Bare { table } => { - TableReference::full(mdl.catalog(), mdl.schema(), table) - } - TableReference::Partial { schema, table } => { - TableReference::full(mdl.catalog(), schema, table) - } - TableReference::Full { - catalog, - schema, - table, - } => TableReference::full(catalog, schema, table), - }; - - let column_ref = mdl.get_column_reference(column); - if !column_ref.column.is_calculated - || column_ref.column.relationship.is_some() - { - return; + for (column, source_columns) in source_colums_map.iter() { + let Some(relation) = column.clone().relation else { + return internal_err!("relation not found: {}", column); + }; + let mut current_relation = match relation { + TableReference::Bare { table } => { + TableReference::full(mdl.catalog(), mdl.schema(), table) } + TableReference::Partial { schema, table } => { + TableReference::full(mdl.catalog(), schema, table) + } + TableReference::Full { + catalog, + schema, + table, + } => TableReference::full(catalog, schema, table), + }; - let mut directed_graph: Graph = Graph::new(); - let mut node_index_map = HashMap::new(); - let mut left_vertex = *node_index_map - .entry(column_ref.dataset.clone()) - .or_insert_with(|| { - directed_graph.add_node(column_ref.dataset.clone()) - }); + let Some(column_ref) = 
mdl.get_column_reference(column) else { + return internal_err!("column not found: {}", column); + }; - source_columns.iter().for_each(|source_column| { - let mut expr_parts = to_expr_queue(source_column.clone()); - while !expr_parts.is_empty() { - let ident = expr_parts.pop_front().unwrap(); - let source_column_ref = mdl.get_column_reference(&Column::new( - Some(current_relation.clone()), - ident.clone(), - )); - match source_column_ref.dataset { - Dataset::Model(_) => { - match source_column_ref.column.relationship.clone() { - Some(rs) => { - if let Some(rs_rf) = mdl.get_relationship(&rs) { - let related_model_name = rs_rf - .models - .iter() - .find(|m| m != ¤t_relation.table()) - .cloned() - .unwrap(); - if related_model_name - != source_column_ref.column.r#type - { - panic!( - "invalid relationship type: {}", - source_column - ); - } - - utils::collect_identifiers(&rs_rf.condition) - .iter() - .cloned() - .for_each(|ident| { - required_fields_map - .entry(column.clone()) - .or_default() - .insert( - Column::from_qualified_name( - format!( - "{}.{}.{}", - mdl.catalog(), - mdl.schema(), - ident.flat_name() - ), - ), - ); - }); - - let related_model = mdl - .get_model(&related_model_name) - .unwrap(); - - let right_vertex = *node_index_map - .entry(Dataset::Model(Arc::clone( - &related_model, - ))) - .or_insert_with(|| { - directed_graph.add_node( - Dataset::Model(Arc::clone( - &related_model, - )), - ) - }); - directed_graph.add_edge( - left_vertex, - right_vertex, - get_dataset_link_revers_if_need( - source_column_ref.dataset.clone(), - rs_rf, - ), - ); - - current_relation = TableReference::full( - mdl.catalog(), - mdl.schema(), - related_model_name, - ); - - left_vertex = right_vertex; - } else { - panic!( - "relationship not found: {}", - source_column - ); - } - } - None => { - if !expr_parts.is_empty() { - panic!( - "invalid relationship chain: {}", - source_column - ); - } - let value = Column::new( - Some(current_relation.clone()), - 
source_column_ref.column.name().to_string(), + // Only analyze the calculated field and the relationship field + if !column_ref.column.is_calculated + || column_ref.column.relationship.is_some() + { + continue; + } + + let mut directed_graph: Graph = Graph::new(); + let mut node_index_map = HashMap::new(); + let mut left_vertex = *node_index_map + .entry(column_ref.dataset.clone()) + .or_insert_with(|| directed_graph.add_node(column_ref.dataset.clone())); + + for source_column in source_columns.iter() { + let mut expr_parts = to_expr_queue(source_column.clone()); + while !expr_parts.is_empty() { + let ident = expr_parts.pop_front().unwrap(); + let Some(source_column_ref) = mdl.get_column_reference(&Column::new( + Some(current_relation.clone()), + ident.clone(), + )) else { + return plan_err!("source column not found: {}", ident); + }; + match source_column_ref.dataset { + Dataset::Model(_) => { + if let Some(rs) = + source_column_ref.column.relationship.clone() + { + if let Some(rs_rf) = mdl.get_relationship(&rs) { + let related_model_name = rs_rf + .models + .iter() + .find(|m| m != ¤t_relation.table()) + .cloned() + .unwrap(); + if related_model_name + != source_column_ref.column.r#type + { + return plan_err!( + "invalid relationship type: {}", + source_column ); - if source_column_ref.column.is_calculated { - todo!( - "calculated source column not supported" - ) - } - required_fields_map - .entry(column.clone()) - .or_default() - .insert(value); } + + utils::collect_identifiers(&rs_rf.condition) + .iter() + .cloned() + .for_each(|ident| { + required_fields_map + .entry(column.clone()) + .or_default() + .insert(Column::from_qualified_name( + format!( + "{}.{}.{}", + mdl.catalog(), + mdl.schema(), + ident.flat_name() + ), + )); + }); + + let related_model = + mdl.get_model(&related_model_name).unwrap(); + + let right_vertex = *node_index_map + .entry(Dataset::Model(Arc::clone(&related_model))) + .or_insert_with(|| { + directed_graph.add_node(Dataset::Model( + 
Arc::clone(&related_model), + )) + }); + directed_graph.add_edge( + left_vertex, + right_vertex, + get_dataset_link_revers_if_need( + source_column_ref.dataset.clone(), + rs_rf, + ), + ); + + current_relation = TableReference::full( + mdl.catalog(), + mdl.schema(), + related_model_name, + ); + + left_vertex = right_vertex; + } else { + return plan_err!( + "relationship not found: {}", + source_column + ); } - } - Dataset::Metric(_) => { - todo!("Metric dataset not supported"); + } else { + if !expr_parts.is_empty() { + return plan_err!( + "invalid relationship chain: {}", + source_column + ); + } + let value = Column::new( + Some(current_relation.clone()), + source_column_ref.column.name().to_string(), + ); + if source_column_ref.column.is_calculated { + todo!("calculated source column not supported") + } + required_fields_map + .entry(column.clone()) + .or_default() + .insert(value); } } + Dataset::Metric(_) => { + todo!("Metric dataset not supported"); + } } - }); - if !utils::is_dag(&directed_graph) { - panic!("cyclic dependency detected: {}", column); } - required_dataset_topo.insert(column.clone(), directed_graph); - }); - RequiredInfo { + } + if !utils::is_dag(&directed_graph) { + return plan_err!("cyclic dependency detected: {}", column); + } + required_dataset_topo.insert(column.clone(), directed_graph); + } + Ok(RequiredInfo { required_fields_map, required_dataset_topo, - } + }) } } @@ -289,6 +281,7 @@ mod test { }; use datafusion::common::Column; + use datafusion::error::Result; use datafusion::sql::TableReference; use crate::mdl::builder::{ @@ -298,18 +291,17 @@ mod test { use crate::mdl::{Dataset, WrenMDL}; #[test] - fn test_collect_source_columns() { + fn test_collect_source_columns() -> Result<()> { let test_data: PathBuf = [env!("CARGO_MANIFEST_DIR"), "tests", "data", "mdl.json"] .iter() .collect(); let mdl_json = fs::read_to_string(path::Path::new(test_data.as_path())) .expect("Unable to read file"); - let manifest = - 
serde_json::from_str::(&mdl_json).unwrap(); + let manifest = serde_json::from_str::(&mdl_json).unwrap(); let wren_mdl = WrenMDL::new(manifest); - let lineage = crate::mdl::lineage::Lineage::new(&wren_mdl); - assert_eq!(lineage.source_columns_map.len(), 9); + let lineage = crate::mdl::lineage::Lineage::new(&wren_mdl)?; + assert_eq!(lineage.source_columns_map.len(), 13); assert_eq!( lineage .source_columns_map @@ -357,21 +349,21 @@ mod test { name: "customer.name".to_string() } ); + Ok(()) } #[test] - fn test_collect_required_fields() { + fn test_collect_required_fields() -> Result<()> { let test_data: PathBuf = [env!("CARGO_MANIFEST_DIR"), "tests", "data", "mdl.json"] .iter() .collect(); let mdl_json = fs::read_to_string(path::Path::new(test_data.as_path())) .expect("Unable to read file"); - let manifest = - serde_json::from_str::(&mdl_json).unwrap(); + let manifest = serde_json::from_str::(&mdl_json).unwrap(); let wren_mdl = WrenMDL::new(manifest); - let lineage = crate::mdl::lineage::Lineage::new(&wren_mdl); - assert_eq!(lineage.required_fields_map.len(), 4); + let lineage = crate::mdl::lineage::Lineage::new(&wren_mdl)?; + assert_eq!(lineage.required_fields_map.len(), 5); assert_eq!( lineage .required_fields_map @@ -416,14 +408,15 @@ mod test { ]); assert_eq!(customer_name.len(), 3); - assert_eq!(*customer_name, expected,); + assert_eq!(*customer_name, expected); + Ok(()) } #[test] - fn test_required_dataset_topo() { + fn test_required_dataset_topo() -> Result<()> { let manifest = testing_manifest(); let wren_mdl = WrenMDL::new(manifest); - let lineage = crate::mdl::lineage::Lineage::new(&wren_mdl); + let lineage = crate::mdl::lineage::Lineage::new(&wren_mdl)?; assert_eq!(lineage.required_dataset_topo.len(), 2); let customer_name = lineage .required_dataset_topo @@ -451,6 +444,7 @@ mod test { let edge = customer_name.edge_weight(second_edge).unwrap(); assert_eq!(edge.join_type, JoinType::OneToOne); assert_eq!(edge.condition, "b.b1 = c.b1"); + Ok(()) } fn 
testing_manifest() -> Manifest { diff --git a/wren-modeling-rs/core/src/mdl/manifest.rs b/wren-modeling-rs/core/src/mdl/manifest.rs index 72e787197..0ee29075c 100644 --- a/wren-modeling-rs/core/src/mdl/manifest.rs +++ b/wren-modeling-rs/core/src/mdl/manifest.rs @@ -30,7 +30,7 @@ pub struct Model { pub table_reference: String, pub columns: Vec>, #[serde(default)] - pub primary_key: String, + pub primary_key: Option, #[serde(default)] pub cached: bool, #[serde(default)] @@ -53,6 +53,17 @@ impl Model { pub fn name(&self) -> &str { &self.name } + + pub fn get_column(&self, column_name: &str) -> Option> { + self.columns + .iter() + .find(|c| c.name == column_name) + .map(Arc::clone) + } + + pub fn primary_key(&self) -> Option<&str> { + self.primary_key.as_deref() + } } #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] @@ -99,6 +110,12 @@ pub enum JoinType { ManyToMany, } +impl JoinType { + pub fn is_to_one(&self) -> bool { + matches!(self, JoinType::OneToOne | JoinType::ManyToOne) + } +} + impl Display for JoinType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/wren-modeling-rs/core/src/mdl/mod.rs b/wren-modeling-rs/core/src/mdl/mod.rs index 41e11cb09..98d4ac003 100644 --- a/wren-modeling-rs/core/src/mdl/mod.rs +++ b/wren-modeling-rs/core/src/mdl/mod.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use datafusion::{ config::ConfigOptions, - error::DataFusionError, + error::Result, optimizer::analyzer::Analyzer, sql::{ planner::SqlToRel, @@ -36,33 +36,33 @@ pub struct AnalyzedWrenMDL { } impl AnalyzedWrenMDL { - pub fn analyze(manifest: Manifest) -> Self { + pub fn analyze(manifest: Manifest) -> Result { let wren_mdl = Arc::new(WrenMDL::new(manifest)); - let lineage = Arc::new(lineage::Lineage::new(&wren_mdl)); - AnalyzedWrenMDL { wren_mdl, lineage } + let lineage = Arc::new(lineage::Lineage::new(&wren_mdl)?); + Ok(AnalyzedWrenMDL { wren_mdl, lineage }) } pub fn analyze_with_tables( manifest: 
Manifest, register_tables: HashMap>, - ) -> Self { + ) -> Result { let mut wren_mdl = WrenMDL::new(manifest); for (name, table) in register_tables { wren_mdl.register_table(name, table); } - let lineage = lineage::Lineage::new(&wren_mdl); - AnalyzedWrenMDL { + let lineage = lineage::Lineage::new(&wren_mdl)?; + Ok(AnalyzedWrenMDL { wren_mdl: Arc::new(wren_mdl), lineage: Arc::new(lineage), - } + }) } pub fn wren_mdl(&self) -> Arc { Arc::clone(&self.wren_mdl) } - pub fn lineage(&self) -> Arc { - Arc::clone(&self.lineage) + pub fn lineage(&self) -> &lineage::Lineage { + &self.lineage } } @@ -182,18 +182,12 @@ impl WrenMDL { pub fn get_column_reference( &self, column: &datafusion::common::Column, - ) -> ColumnReference { - self.qualified_references - .get(column) - .unwrap_or_else(|| panic!("column {} not found", column)) - .clone() + ) -> Option { + self.qualified_references.get(column).cloned() } } /// Transform the SQL based on the MDL -pub fn transform_sql( - analyzed_mdl: Arc, - sql: &str, -) -> Result { +pub fn transform_sql(analyzed_mdl: Arc, sql: &str) -> Result { info!("wren-core received SQL: {}", sql); // parse the SQL @@ -202,7 +196,7 @@ pub fn transform_sql( let statement = &ast[0]; // create a logical query plan - let context_provider = WrenContextProvider::new(&analyzed_mdl.wren_mdl); + let context_provider = WrenContextProvider::new(&analyzed_mdl.wren_mdl)?; let sql_to_rel = SqlToRel::new(&context_provider); let plan = match sql_to_rel.sql_statement_to_plan(statement.clone()) { Ok(plan) => plan, @@ -244,7 +238,7 @@ pub fn transform_sql( pub fn decision_point_analyze(_wren_mdl: Arc, _sql: &str) {} /// Cheap clone of the ColumnReference -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct ColumnReference { pub dataset: Dataset, pub column: Arc, @@ -255,10 +249,6 @@ impl ColumnReference { ColumnReference { dataset, column } } - pub fn get_column(&self) -> Arc { - Arc::clone(&self.column) - } - pub fn get_qualified_name(&self) -> 
String { format!("{}.{}", self.dataset.name(), self.column.name) } @@ -277,6 +267,13 @@ impl Dataset { Dataset::Metric(metric) => metric.name(), } } + + pub fn try_as_model(&self) -> Option> { + match self { + Dataset::Model(model) => Some(Arc::clone(model)), + _ => None, + } + } } impl Display for Dataset { @@ -313,7 +310,7 @@ mod test { .collect(); let mdl_json = fs::read_to_string(test_data.as_path())?; let mdl = serde_json::from_str::(&mdl_json)?; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); let tests: Vec<&str> = vec![ "select orderkey + orderkey from test.test.orders", @@ -321,14 +318,16 @@ mod test { "select orders.orderkey from test.test.orders left join test.test.customer on (orders.custkey = customer.custkey) where orders.totalprice > 10", "select orderkey, sum(totalprice) from test.test.orders group by 1", "select orderkey, count(*) from test.test.orders where orders.totalprice > 10 group by 1", + "select totalcost from test.test.profile", // TODO: support calculated without relationship // "select orderkey_plus_custkey from orders", ]; for sql in tests { - println!("{}", sql); + println!("Original: {}", sql); let actual = mdl::transform_sql(Arc::clone(&analyzed_mdl), sql)?; - plan_sql(&actual, Arc::clone(&analyzed_mdl))?; + let after_roundtrip = plan_sql(&actual, Arc::clone(&analyzed_mdl))?; + println!("After roundtrip: {}", after_roundtrip); } Ok(()) @@ -339,7 +338,7 @@ mod test { let ast = Parser::parse_sql(&dialect, sql).unwrap(); let statement = &ast[0]; - let context_provider = RemoteContextProvider::new(&analyzed_mdl.wren_mdl()); + let context_provider = RemoteContextProvider::new(&analyzed_mdl.wren_mdl())?; let sql_to_rel = SqlToRel::new(&context_provider); let rels = sql_to_rel.sql_statement_to_plan(statement.clone())?; // show the planned sql diff --git a/wren-modeling-rs/core/src/mdl/utils.rs b/wren-modeling-rs/core/src/mdl/utils.rs index e0411b798..edf806222 
100644 --- a/wren-modeling-rs/core/src/mdl/utils.rs +++ b/wren-modeling-rs/core/src/mdl/utils.rs @@ -2,7 +2,9 @@ use std::collections::{BTreeSet, VecDeque}; use std::ops::ControlFlow; use std::sync::Arc; -use datafusion::common::Column; +use datafusion::common::{internal_err, plan_err, Column}; +use datafusion::error::Result; +use datafusion::logical_expr::logical_plan::tree_node::unwrap_arc; use datafusion::logical_expr::{Expr, LogicalPlan}; use datafusion::sql::planner::SqlToRel; use datafusion::sql::sqlparser::ast::Expr::{CompoundIdentifier, Identifier}; @@ -61,20 +63,22 @@ pub fn collect_identifiers(expr: &String) -> BTreeSet { pub fn create_wren_calculated_field_expr( column_rf: ColumnReference, analyzed_wren_mdl: Arc, -) -> Expr { +) -> Result { if !column_rf.column.is_calculated { - panic!("Column is not calculated: {}", column_rf.column.name) + return plan_err!("Column is not calculated: {}", column_rf.column.name); } let qualified_col = from_qualified_name( &analyzed_wren_mdl.wren_mdl, column_rf.dataset.name(), column_rf.column.name(), ); - let required_fields = analyzed_wren_mdl + let Some(required_fields) = analyzed_wren_mdl .lineage .required_fields_map .get(&qualified_col) - .unwrap_or_else(|| panic!("Required fields not found for {}", qualified_col)); + else { + return plan_err!("Required fields not found for {}", qualified_col); + }; // collect all required models. let models = required_fields @@ -105,17 +109,24 @@ pub fn create_wren_calculated_field_expr( }); debug!("Statement: {:?}", statement.to_string()); // Create the expression only has the table prefix. We don't need the catalog and schema prefix when planning. 
- let context_provider = WrenContextProvider::new_bare(&analyzed_wren_mdl.wren_mdl); + let context_provider = WrenContextProvider::new_bare(&analyzed_wren_mdl.wren_mdl)?; let sql_to_rel = SqlToRel::new(&context_provider); let plan = match sql_to_rel.sql_statement_to_plan(statement.clone()) { Ok(plan) => plan, - Err(e) => panic!("Error creating plan: {}", e), + Err(e) => return plan_err!("Error creating plan: {}", e), }; - match plan { - LogicalPlan::Projection(projection) => projection.expr[0].clone(), - _ => unreachable!("Unexpected plan type: {:?}", plan), - } + let result = match plan { + LogicalPlan::Projection(projection) => { + if let LogicalPlan::Aggregate(aggregation) = unwrap_arc(projection.input) { + aggregation.aggr_expr[0].clone() + } else { + projection.expr[0].clone() + } + } + _ => return internal_err!("Unexpected plan type: {:?}", plan), + }; + Ok(result) } /// Create the Logical Expr for the remote column. @@ -124,8 +135,8 @@ pub(crate) fn create_remote_expr_for_model( expr: &String, model: Arc, analyzed_wren_mdl: Arc, -) -> Expr { - let context_provider = RemoteContextProvider::new(&analyzed_wren_mdl.wren_mdl); +) -> Result { + let context_provider = RemoteContextProvider::new(&analyzed_wren_mdl.wren_mdl)?; create_expr_for_model( expr, model, @@ -139,8 +150,8 @@ pub(crate) fn create_wren_expr_for_model( expr: &String, model: Arc, analyzed_wren_mdl: Arc, -) -> Expr { - let context_provider = WrenContextProvider::new(&analyzed_wren_mdl.wren_mdl); +) -> Result { + let context_provider = WrenContextProvider::new(&analyzed_wren_mdl.wren_mdl)?; let wrapped = format!( "select {} from {}.{}.{}", expr, @@ -153,14 +164,11 @@ pub(crate) fn create_wren_expr_for_model( debug!("Statement: {:?}", statement.to_string()); let sql_to_rel = SqlToRel::new(&context_provider); - let plan = match sql_to_rel.sql_statement_to_plan(statement.clone()) { - Ok(plan) => plan, - Err(e) => panic!("Error creating plan: {}", e), - }; + let plan = 
sql_to_rel.sql_statement_to_plan(statement.clone())?; match plan { - LogicalPlan::Projection(projection) => projection.expr[0].clone(), - _ => unreachable!("Unexpected plan type: {:?}", plan), + LogicalPlan::Projection(projection) => Ok(projection.expr[0].clone()), + _ => internal_err!("Unexpected plan type: {:?}", plan), } } @@ -169,20 +177,16 @@ pub(crate) fn create_expr_for_model( expr: &String, model: Arc, context_provider: DynamicContextProvider, -) -> Expr { +) -> Result { let wrapped = format!("select {} from {}", expr, &model.table_reference); let parsed = Parser::parse_sql(&GenericDialect {}, &wrapped).unwrap(); let statement = &parsed[0]; let sql_to_rel = SqlToRel::new(&context_provider); - let plan = match sql_to_rel.sql_statement_to_plan(statement.clone()) { - Ok(plan) => plan, - Err(e) => panic!("Error creating plan: {}", e), - }; - + let plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; match plan { - LogicalPlan::Projection(projection) => projection.expr[0].clone(), - _ => unreachable!("Unexpected plan type: {:?}", plan), + LogicalPlan::Projection(projection) => Ok(projection.expr[0].clone()), + _ => internal_err!("Unexpected plan type: {:?}", plan), } } @@ -195,16 +199,17 @@ mod tests { use crate::logical_plan::utils::from_qualified_name; use crate::mdl::manifest::Manifest; use crate::mdl::AnalyzedWrenMDL; + use datafusion::error::Result; #[test] - fn test_create_wren_expr() { + fn test_create_wren_expr() -> Result<()> { let test_data: PathBuf = [env!("CARGO_MANIFEST_DIR"), "tests", "data", "mdl.json"] .iter() .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); let column_rf = analyzed_mdl .wren_mdl @@ -218,19 +223,20 @@ mod tests { let expr = super::create_wren_calculated_field_expr( column_rf.clone(), analyzed_mdl.clone(), - ); + )?; 
assert_eq!(expr.to_string(), "customer.name"); + Ok(()) } #[test] - fn test_create_wren_expr_non_relationship() { + fn test_create_wren_expr_non_relationship() -> Result<()> { let test_data: PathBuf = [env!("CARGO_MANIFEST_DIR"), "tests", "data", "mdl.json"] .iter() .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); let column_rf = analyzed_mdl .wren_mdl @@ -244,7 +250,8 @@ mod tests { let expr = super::create_wren_calculated_field_expr( column_rf.clone(), analyzed_mdl.clone(), - ); + )?; assert_eq!(expr.to_string(), "orders.orderkey + orders.custkey"); + Ok(()) } } diff --git a/wren-modeling-rs/core/tests/data/mdl.json b/wren-modeling-rs/core/tests/data/mdl.json index 280e124f0..38180cfdc 100644 --- a/wren-modeling-rs/core/tests/data/mdl.json +++ b/wren-modeling-rs/core/tests/data/mdl.json @@ -21,6 +21,44 @@ "type": "integer", "expression": "custkey + 1", "isCalculated": true + }, + { + "name": "orders", + "type": "orders", + "relationship": "CustomerOrders" + } + ], + "primaryKey": "custkey" + }, + { + "name": "profile", + "tableReference": "profile", + "columns": [ + { + "name": "custkey", + "type": "integer", + "expression": "p_custkey" + }, + { + "name": "phone", + "type": "varchar", + "expression": "p_phone" + }, + { + "name": "sex", + "type": "varchar", + "expression": "p_sex" + }, + { + "name": "customer", + "type": "customer", + "relationship": "CustomerProfile" + }, + { + "name": "totalcost", + "type": "integer", + "isCalculated": true, + "expression": "sum(customer.orders.totalprice)" } ], "primaryKey": "custkey" @@ -77,6 +115,12 @@ "models": ["customer", "orders"], "joinType": "one_to_many", "condition": "customer.custkey = orders.custkey" + }, + { + "name" : "CustomerProfile", + "models": ["customer", "profile"], + "joinType": "one_to_one", + 
"condition": "customer.custkey = profile.custkey" } ] } diff --git a/wren-modeling-rs/sqllogictest/bin/sqllogictests.rs b/wren-modeling-rs/sqllogictest/bin/sqllogictests.rs index 14d7ce8f9..904f4d40d 100644 --- a/wren-modeling-rs/sqllogictest/bin/sqllogictests.rs +++ b/wren-modeling-rs/sqllogictest/bin/sqllogictests.rs @@ -18,6 +18,7 @@ use std::ffi::OsStr; use std::fs; use std::path::{Path, PathBuf}; +use std::sync::Arc; #[cfg(target_family = "windows")] use std::thread; @@ -140,10 +141,11 @@ async fn run_test_file(test_file: TestFile) -> Result<()> { info!("Skipping: {}", path.display()); return Ok(()); }; + let test_ctx = Arc::new(test_ctx); setup_scratch_dir(&relative_path)?; let mut runner = sqllogictest::Runner::new(|| async { Ok(DataFusion::new( - test_ctx.session_ctx().clone(), + Arc::clone(&test_ctx), relative_path.clone(), )) }); @@ -166,10 +168,11 @@ async fn run_complete_file(test_file: TestFile) -> Result<()> { info!("Skipping: {}", path.display()); return Ok(()); }; + let test_ctx = Arc::new(test_ctx); setup_scratch_dir(&relative_path)?; let mut runner = sqllogictest::Runner::new(|| async { Ok(DataFusion::new( - test_ctx.session_ctx().clone(), + Arc::clone(&test_ctx), relative_path.clone(), )) }); diff --git a/wren-modeling-rs/sqllogictest/src/engine/runner.rs b/wren-modeling-rs/sqllogictest/src/engine/runner.rs index c3470f755..84c82a825 100644 --- a/wren-modeling-rs/sqllogictest/src/engine/runner.rs +++ b/wren-modeling-rs/sqllogictest/src/engine/runner.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; use std::{path::PathBuf, time::Duration}; +use crate::TestContext; use async_trait::async_trait; use datafusion::arrow::record_batch::RecordBatch; use datafusion::prelude::SessionContext; use log::info; use sqllogictest::DBOutput; +use wren_core::mdl::transform_sql; use super::{ error::Result, @@ -31,12 +34,12 @@ use super::{ }; pub struct DataFusion { - ctx: SessionContext, + ctx: Arc, relative_path: PathBuf, } impl DataFusion { - pub fn new(ctx: SessionContext, relative_path: PathBuf) -> Self { + pub fn new(ctx: Arc, relative_path: PathBuf) -> Self { Self { ctx, relative_path } } } @@ -52,7 +55,8 @@ impl sqllogictest::AsyncDB for DataFusion { self.relative_path.display(), sql ); - run_query(&self.ctx, sql).await + let sql = transform_sql(Arc::clone(self.ctx.analyzed_wren_mdl()), sql)?; + run_query(self.ctx.session_ctx(), sql).await } /// Engine name of current database. diff --git a/wren-modeling-rs/sqllogictest/src/lib.rs b/wren-modeling-rs/sqllogictest/src/lib.rs index b9728a820..641b11945 100644 --- a/wren-modeling-rs/sqllogictest/src/lib.rs +++ b/wren-modeling-rs/sqllogictest/src/lib.rs @@ -1,4 +1,4 @@ pub mod engine; -mod test_context; +pub mod test_context; pub use test_context::TestContext; diff --git a/wren-modeling-rs/sqllogictest/src/test_context.rs b/wren-modeling-rs/sqllogictest/src/test_context.rs index 6273d6b6d..91cd1c512 100644 --- a/wren-modeling-rs/sqllogictest/src/test_context.rs +++ b/wren-modeling-rs/sqllogictest/src/test_context.rs @@ -34,7 +34,6 @@ use datafusion::{ }; use log::info; use tempfile::TempDir; -use wren_core::logical_plan::analyze::rule::{ModelAnalyzeRule, ModelGenerationRule}; use wren_core::logical_plan::utils::create_schema; use wren_core::mdl::builder::{ @@ -51,14 +50,16 @@ const TEST_RESOURCES: &str = "tests/resources"; pub struct TestContext { /// Context for running queries ctx: SessionContext, + analyzed_wren_mdl: Arc, /// Temporary directory created and cleared at the end of the test test_dir: 
Option, } impl TestContext { - pub fn new(ctx: SessionContext) -> Self { + pub fn new(ctx: SessionContext, analyzed_wren_mdl: Arc) -> Self { Self { ctx, + analyzed_wren_mdl, test_dir: None, } } @@ -106,6 +107,10 @@ impl TestContext { pub fn session_ctx(&self) -> &SessionContext { &self.ctx } + + pub fn analyzed_wren_mdl(&self) -> &Arc { + &self.analyzed_wren_mdl + } } pub async fn register_ecommerce_table(ctx: &SessionContext) -> Result { @@ -119,11 +124,13 @@ pub async fn register_ecommerce_table(ctx: &SessionContext) -> Result Result { +async fn register_ecommerce_mdl( + ctx: &SessionContext, +) -> Result<(SessionContext, Arc)> { let manifest = ManifestBuilder::new() .model( ModelBuilder::new("customers") @@ -178,6 +185,18 @@ async fn register_ecommerce_mdl(ctx: &SessionContext) -> Result .expression("customers.state") .build(), ) + .column( + ColumnBuilder::new("order_items", "order_items") + .relationship("orders_order_items") + .build(), + ) + .column( + ColumnBuilder::new("totalprice", "double") + .expression("sum(order_items.price)") + .calculated(true) + .build(), + ) + .primary_key("order_id") .build(), ) .relationship( @@ -231,22 +250,27 @@ async fn register_ecommerce_mdl(ctx: &SessionContext) -> Result let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables( manifest, register_tables, - )); - let new_state = ctx - .state() - .add_analyzer_rule(Arc::new(ModelAnalyzeRule::new(Arc::clone(&analyzed_mdl)))) - .add_analyzer_rule(Arc::new(ModelGenerationRule::new(Arc::clone( - &analyzed_mdl, - )))) - // TODO: disable optimize_projections rule - // There are some conflict with the optimize rule, [datafusion::optimizer::optimize_projections::OptimizeProjections] - .with_optimizer_rules(vec![]); - let ctx = SessionContext::new_with_state(new_state); - register_table_with_mdl(&ctx, Arc::clone(&analyzed_mdl.wren_mdl)).await; - Ok(ctx) + )?); + // let new_state = ctx + // .state() + // .add_analyzer_rule(Arc::new(ModelAnalyzeRule:: + // + // 
new(Arc::clone(&analyzed_mdl)))) + // .add_analyzer_rule(Arc::new(ModelGenerationRule::new(Arc::clone( + // &analyzed_mdl, + // )))) + // // TODO: disable optimize_projections rule + // // There are some conflict with the optimize rule, [datafusion::optimizer::optimize_projections::OptimizeProjections] + // .with_optimizer_rules(vec![]); + // let ctx = SessionContext::new_with_state(new_state); + let _ = register_table_with_mdl(ctx, Arc::clone(&analyzed_mdl.wren_mdl)).await; + Ok((ctx.to_owned(), analyzed_mdl)) } -pub async fn register_table_with_mdl(ctx: &SessionContext, wren_mdl: Arc) { +pub async fn register_table_with_mdl( + ctx: &SessionContext, + wren_mdl: Arc, +) -> Result<()> { let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); @@ -256,16 +280,16 @@ pub async fn register_table_with_mdl(ctx: &SessionContext, wren_mdl: Arc) -> Self { - let schema = create_schema(model.get_physical_columns().clone()); - Self { schema } + pub fn new(model: Arc) -> Result { + let schema = create_schema(model.get_physical_columns().clone())?; + Ok(Self { schema }) } } diff --git a/wren-modeling-rs/sqllogictest/test_sql_files/model.slt b/wren-modeling-rs/sqllogictest/test_sql_files/model.slt index e823149cd..74f15666f 100644 --- a/wren-modeling-rs/sqllogictest/test_sql_files/model.slt +++ b/wren-modeling-rs/sqllogictest/test_sql_files/model.slt @@ -20,4 +20,9 @@ select count(*) from wrenai.default.order_items; query B select cnt1 = cnt2 from (select count(*) as cnt1 from (select customer_state from wrenai.default.order_items)), (select count(*) as cnt2 from datafusion.public.order_items) limit 1; ---- +true + +query B +select actual = expected from (select totalprice as actual from wrenai.default.orders where order_id = '76754c0e642c8f99a8c3fcb8a14ac700'), (select sum(price) as expected from datafusion.public.order_items where order_id = '76754c0e642c8f99a8c3fcb8a14ac700') limit 1; +---- true \ No newline at end of file diff --git 
a/wren-modeling-rs/wren-example/Cargo.toml b/wren-modeling-rs/wren-example/Cargo.toml index cab95f0be..3e32eafdb 100644 --- a/wren-modeling-rs/wren-example/Cargo.toml +++ b/wren-modeling-rs/wren-example/Cargo.toml @@ -11,7 +11,6 @@ version.workspace = true publish = false [dev-dependencies] -arrow-schema = { workspace = true } async-trait = { workspace = true } datafusion = { workspace = true } env_logger = { workspace = true } diff --git a/wren-modeling-rs/wren-example/examples/datafusion-apply.rs b/wren-modeling-rs/wren-example/examples/datafusion-apply.rs index 3846d438a..b4097e810 100644 --- a/wren-modeling-rs/wren-example/examples/datafusion-apply.rs +++ b/wren-modeling-rs/wren-example/examples/datafusion-apply.rs @@ -2,8 +2,8 @@ use std::any::Any; use std::collections::HashMap; use std::sync::Arc; -use arrow_schema::SchemaRef; use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; use datafusion::catalog::schema::MemorySchemaProvider; use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider}; use datafusion::datasource::{TableProvider, TableType}; @@ -86,7 +86,8 @@ async fn main() -> Result<()> { order_items_provider, ), ]); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?); let new_state = ctx .state() @@ -99,7 +100,7 @@ async fn main() -> Result<()> { )))) // There is some conflict with optimize_projections. 
Disable optimization rule temporarily, .with_optimizer_rules(vec![]); - register_table_with_mdl(&ctx, Arc::clone(&analyzed_mdl.wren_mdl)).await; + register_table_with_mdl(&ctx, Arc::clone(&analyzed_mdl.wren_mdl)).await?; let new_ctx = SessionContext::new_with_state(new_state); let sql = "select * from wrenai.default.order_items"; // create a plan to run a SQL query @@ -184,7 +185,10 @@ fn init_manifest() -> Manifest { .build() } -pub async fn register_table_with_mdl(ctx: &SessionContext, wren_mdl: Arc) { +pub async fn register_table_with_mdl( + ctx: &SessionContext, + wren_mdl: Arc, +) -> Result<()> { let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); @@ -194,7 +198,7 @@ pub async fn register_table_with_mdl(ctx: &SessionContext, wren_mdl: Arc) -> Self { - let schema = create_schema(model.get_physical_columns()); - Self { schema } + pub fn new(model: Arc) -> Result { + let schema = create_schema(model.get_physical_columns())?; + Ok(Self { schema }) } } diff --git a/wren-modeling-rs/wren-example/examples/plan-sql.rs b/wren-modeling-rs/wren-example/examples/plan-sql.rs index cf5dca8ed..0bc3a6ffa 100644 --- a/wren-modeling-rs/wren-example/examples/plan-sql.rs +++ b/wren-modeling-rs/wren-example/examples/plan-sql.rs @@ -8,7 +8,7 @@ use wren_core::mdl::{transform_sql, AnalyzedWrenMDL}; #[tokio::main] async fn main() -> datafusion::common::Result<()> { let manifest = init_manifest(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); let sql = "select * from wrenai.default.order_items_model"; println!("Original SQL: \n{}", sql); diff --git a/wren-modeling-rs/wren-example/examples/to-many-calculation.rs b/wren-modeling-rs/wren-example/examples/to-many-calculation.rs new file mode 100644 index 000000000..39fc26016 --- /dev/null +++ b/wren-modeling-rs/wren-example/examples/to-many-calculation.rs @@ -0,0 +1,248 @@ +use std::any::Any; +use 
std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::catalog::schema::MemorySchemaProvider; +use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider}; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::{CsvReadOptions, SessionContext}; + +use wren_core::logical_plan::utils::create_schema; +use wren_core::mdl::builder::{ + ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, +}; +use wren_core::mdl::manifest::{JoinType, Manifest, Model}; +use wren_core::mdl::{transform_sql, AnalyzedWrenMDL, WrenMDL}; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + let manifest = init_manifest(); + + // register the table + let ctx = SessionContext::new(); + ctx.register_csv( + "orders", + "sqllogictest/tests/resources/ecommerce/orders.csv", + CsvReadOptions::new(), + ) + .await?; + let provider = ctx + .catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("orders") + .await? + .unwrap(); + + ctx.register_csv( + "customers", + "sqllogictest/tests/resources/ecommerce/customers.csv", + CsvReadOptions::new(), + ) + .await?; + let customers_provider = ctx + .catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("customers") + .await? + .unwrap(); + + ctx.register_csv( + "order_items", + "sqllogictest/tests/resources/ecommerce/order_items.csv", + CsvReadOptions::new(), + ) + .await?; + let order_items_provider = ctx + .catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("order_items") + .await? 
+ .unwrap(); + + let register = HashMap::from([ + ("datafusion.public.orders".to_string(), provider), + ( + "datafusion.public.customers".to_string(), + customers_provider, + ), + ( + "datafusion.public.order_items".to_string(), + order_items_provider, + ), + ]); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?); + + let transformed = transform_sql( + Arc::clone(&analyzed_mdl), + "select totalprice from wrenai.default.orders", + ) + .unwrap(); + register_table_with_mdl(&ctx, Arc::clone(&analyzed_mdl.wren_mdl)).await?; + let df = ctx.sql(&transformed).await?; + df.show().await?; + Ok(()) +} + +fn init_manifest() -> Manifest { + ManifestBuilder::new() + .model( + ModelBuilder::new("customers") + .table_reference("datafusion.public.customers") + .column(ColumnBuilder::new("city", "varchar").build()) + .column(ColumnBuilder::new("id", "varchar").build()) + .column(ColumnBuilder::new("state", "varchar").build()) + .primary_key("id") + .build(), + ) + .model( + ModelBuilder::new("order_items") + .table_reference("datafusion.public.order_items") + .column(ColumnBuilder::new("freight_value", "double").build()) + .column(ColumnBuilder::new("id", "bigint").build()) + .column(ColumnBuilder::new("item_number", "bigint").build()) + .column(ColumnBuilder::new("order_id", "varchar").build()) + .column(ColumnBuilder::new("price", "double").build()) + .column(ColumnBuilder::new("product_id", "varchar").build()) + .column(ColumnBuilder::new("shipping_limit_date", "varchar").build()) + .column( + ColumnBuilder::new("orders", "orders") + .relationship("orders_order_items") + .build(), + ) + .column( + ColumnBuilder::new("customer_state", "varchar") + .calculated(true) + .expression("orders.customers.state") + .build(), + ) + .primary_key("id") + .build(), + ) + .model( + ModelBuilder::new("orders") + .table_reference("datafusion.public.orders") + .column(ColumnBuilder::new("approved_timestamp", "varchar").build()) + 
.column(ColumnBuilder::new("customer_id", "varchar").build()) + .column(ColumnBuilder::new("delivered_carrier_date", "varchar").build()) + .column(ColumnBuilder::new("estimated_delivery_date", "varchar").build()) + .column(ColumnBuilder::new("order_id", "varchar").build()) + .column(ColumnBuilder::new("purchase_timestamp", "varchar").build()) + .column( + ColumnBuilder::new("order_items", "order_items") + .relationship("orders_order_items") + .build(), + ) + .column( + ColumnBuilder::new("totalprice", "double") + .expression("sum(order_items.price)") + .calculated(true) + .build(), + ) + .primary_key("order_id") + .column( + ColumnBuilder::new("customers", "customers") + .relationship("orders_customer") + .build(), + ) + .column( + ColumnBuilder::new("customer_state", "varchar") + .calculated(true) + .expression("customers.state") + .build(), + ) + .build(), + ) + .relationship( + RelationshipBuilder::new("orders_customer") + .model("orders") + .model("customers") + .join_type(JoinType::ManyToOne) + .condition("orders.customer_id = customers.id") + .build(), + ) + .relationship( + RelationshipBuilder::new("orders_order_items") + .model("orders") + .model("order_items") + .join_type(JoinType::ManyToOne) + .condition("orders.order_id = order_items.order_id") + .build(), + ) + .build() +} + +pub async fn register_table_with_mdl( + ctx: &SessionContext, + wren_mdl: Arc, +) -> Result<()> { + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + + catalog + .register_schema(&wren_mdl.manifest.schema, Arc::new(schema)) + .unwrap(); + ctx.register_catalog(&wren_mdl.manifest.catalog, Arc::new(catalog)); + + for model in wren_mdl.manifest.models.iter() { + let table = WrenDataSource::new(Arc::clone(model))?; + ctx.register_table( + format!( + "{}.{}.{}", + &wren_mdl.manifest.catalog, &wren_mdl.manifest.schema, &model.name + ), + Arc::new(table), + )?; + } + Ok(()) +} + +struct WrenDataSource { + schema: SchemaRef, +} + +impl WrenDataSource 
{ + pub fn new(model: Arc) -> Result { + let schema = create_schema(model.get_physical_columns())?; + Ok(Self { schema }) + } +} + +#[async_trait] +impl TableProvider for WrenDataSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::View + } + + async fn scan( + &self, + _state: &SessionState, + _projection: Option<&Vec>, + // filters and limit can be used here to inject some push-down operations if needed + _filters: &[Expr], + _limit: Option, + ) -> Result> { + unreachable!("WrenDataSource should be replaced before physical planning") + } +}