diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..ae64c41 Binary files /dev/null and b/.DS_Store differ diff --git a/Cargo.lock b/Cargo.lock index 79d3338..bd6dbf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,10 +36,19 @@ dependencies = [ ] [[package]] -name = "aligned-vec" -version = "0.5.0" +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] [[package]] name = "anyhow" @@ -54,27 +63,114 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" [[package]] -name = "arbitrary" -version = "1.3.2" +name = "ark-ec" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +checksum = "defd9a439d56ac24968cca0571f598a61bc8c55f71d50a89cda591cb750670ba" +dependencies = [ + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "derivative", + "hashbrown 0.13.2", + "itertools 0.10.5", + "num-traits", + "rayon", + "zeroize", +] [[package]] -name = "arg_enum_proc_macro" -version = "0.3.4" +name = "ark-ff" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" +checksum = "ec847af850f44ad29048935519032c33da8aa03340876d351dfab5660d2966ba" dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "derivative", + "digest", + "itertools 0.10.5", + "num-bigint", + "num-traits", + "paste", + "rayon", + "rustc_version", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ed4aa4fe255d0bc6d79373f7e31d2ea147bcf486cba1be5ba7ea85abdb92348" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" +dependencies = [ + "num-bigint", + "num-traits", "proc-macro2", "quote", - "syn 2.0.82", + "syn 1.0.109", ] [[package]] -name = "arrayvec" -version = "0.7.6" +name = "ark-poly" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +checksum = "d320bfc44ee185d899ccbadfa8bc31aab923ce1558716e1997a1e74057fe86bf" +dependencies = [ + "ark-ff", + "ark-serialize", + "ark-std", + "derivative", + "hashbrown 0.13.2", + "rayon", +] + +[[package]] +name = "ark-serialize" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" +dependencies = [ + "ark-serialize-derive", + "ark-std", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae3281bc6d0fd7e549af32b52511e1302185bd688fd3359fa36423346ff682ea" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-std" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" +dependencies = [ + "num-traits", + "rand", + "rayon", +] [[package]] name = "autocfg" @@ -83,26 +179,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] -name = "av1-grain" -version = "0.2.3" +name = "base64" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6678909d8c5d46a42abcf571271e15fdbc0a225e3646cf23762cd415046c78bf" -dependencies = [ - "anyhow", - "arrayvec", - "log", - "nom", - "num-rational", - "v_frame", -] +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] -name = "avif-serialize" -version = "0.8.2" +name = "bcs" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e335041290c43101ca215eed6f43ec437eb5a42125573f600fc3fa42b9bddd62" +checksum = "85b6598a2f5d564fb7855dc6b06fd1c38cff5a72bd8b863a4d021938497b440a" dependencies = [ - "arrayvec", + "serde", + "thiserror", ] [[package]] @@ -148,10 +237,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] -name = "bitstream-io" -version = "2.5.3" +name = "blake2" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] [[package]] name = "block-buffer" @@ -162,12 +254,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "built" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b" - [[package]] name = "bumpalo" version = "3.16.0" @@ -186,12 +272,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" -[[package]] -name = "byteorder-lite" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" - [[package]] name = "bytes" version = "1.8.0" @@ -204,33 +284,42 @@ version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" dependencies = [ - "jobserver", - "libc", "shlex", ] -[[package]] -name = "cfg-expr" -version = "0.15.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d067ad48b8650848b989a59a86c6c36a995d02d2bf778d45c3c5d57bc2718f02" -dependencies = [ - "smallvec", - "target-lexicon", -] - [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-targets", +] + [[package]] name = "color_quant" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + [[package]] name = "cpufeatures" version = "0.2.14" @@ -290,6 +379,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.82", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.82", +] + [[package]] name = "deranged" version = "0.3.11" @@ -297,6 +421,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", + "serde", +] + +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", ] [[package]] @@ -310,6 +446,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "digest" version = "0.10.7" @@ -318,6 +460,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -422,6 +565,107 @@ dependencies = [ "spin", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.82", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -453,6 +697,22 @@ dependencies = [ "weezl", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "groupmap" +version = "0.1.0" +source = "git+https://github.com/o1-labs/proof-systems#df2415ab32d2157df6b61660512f0140d7f52203" +dependencies = [ + "ark-ec", + "ark-ff", + "rand", +] + [[package]] name = "half" version = "2.4.1" @@ -464,6 +724,21 @@ dependencies = [ "num-traits", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -485,45 +760,73 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +dependencies = [ + "serde", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "image" -version = "0.25.4" +version = "0.24.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc144d44a31d753b02ce64093d532f55ff8dc4ebf2ffb8a63c0dda691385acae" +checksum = "5690139d2f55868e080017335e4b94cb7414274c74f1669c84fb5feba2c9f69d" dependencies = [ "bytemuck", - "byteorder-lite", + "byteorder", "color_quant", "exr", "gif", - "image-webp", + "jpeg-decoder", "num-traits", "png", "qoi", - "ravif", - "rayon", - "rgb", "tiff", - "zune-core", - "zune-jpeg", ] [[package]] -name = "image-webp" -version = "0.2.0" +name = "indexmap" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e031e8e3d94711a9ccb5d6ea357439ef3dcbed361798bd4071dc4d9793fbe22f" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ - "byteorder-lite", - "quick-error", + "autocfg", + "hashbrown 0.12.3", + "serde", ] -[[package]] -name = "imgref" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" - [[package]] name = "indexmap" version = "2.6.0" @@ -532,6 +835,7 @@ checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", "hashbrown 0.15.0", + "serde", ] [[package]] @@ -544,15 +848,9 @@ dependencies = [ ] [[package]] -name = "interpolate_name" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.82", -] +name = "internal-tracing" +version = "0.1.0" +source = "git+https://github.com/o1-labs/proof-systems#df2415ab32d2157df6b61660512f0140d7f52203" [[package]] name = "itertools" @@ -588,19 +886,58 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] -name = "jobserver" -version = "0.1.32" +name = "jpeg-decoder" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" dependencies = [ - "libc", + "rayon", ] [[package]] -name = "jpeg-decoder" -version = "0.3.1" +name = "js-sys" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "kimchi" +version = "0.1.0" +source = "git+https://github.com/o1-labs/proof-systems#df2415ab32d2157df6b61660512f0140d7f52203" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-poly", + "ark-serialize", + "blake2", + "groupmap", + "hex", + "internal-tracing", + "itertools 0.12.1", + "log", + "mina-curves", + "mina-poseidon", + "num-bigint", + "num-derive", + "num-integer", + "num-traits", + "o1-utils", + "once_cell", + "poly-commitment", + "rand", + "rand_core", + "rayon", + "rmp-serde", + "serde", + "serde_with", + "strum", + "strum_macros", + "thiserror", + "turshi", +] [[package]] name = "kstring" @@ -630,17 +967,6 @@ version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" -[[package]] -name = "libfuzzer-sys" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96cfd5557eb82f2b83fed4955246c988d331975a002961b07c81584d107e7f7" -dependencies = [ - "arbitrary", - "cc", - "once_cell", -] - [[package]] name = "libm" version = "0.2.8" @@ -737,15 +1063,6 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" -[[package]] -name = "loop9" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fae87c125b03c1d2c0150c90365d7d6bcc53fb73a9acaef207d2d065860f062" -dependencies = [ - "imgref", -] - [[package]] name = "maplit" version = "1.0.2" @@ -762,15 +1079,6 @@ dependencies = [ "rawpointer", ] -[[package]] -name = "maybe-rayon" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" -dependencies = [ - "cfg-if", -] - [[package]] name = "memchr" version = "2.7.4" @@ -786,6 +1094,64 @@ dependencies = [ "libc", ] +[[package]] +name = "mina-curves" +version = "0.1.0" +source = "git+https://github.com/o1-labs/proof-systems#df2415ab32d2157df6b61660512f0140d7f52203" +dependencies = [ + "ark-ec", + "ark-ff", + "num-bigint", +] + +[[package]] +name = "mina-poseidon" +version = "0.1.0" +source = "git+https://github.com/o1-labs/proof-systems#df2415ab32d2157df6b61660512f0140d7f52203" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-poly", + "ark-serialize", + "mina-curves", + "o1-utils", + "once_cell", + "rand", + "rayon", + "serde", + "serde_with", +] + +[[package]] +name = "mina-zkml" +version = "0.1.0" +dependencies = [ + "anyhow", + "ark-ec", + "ark-ff", + "ark-poly", + "bincode", + "chrono", + "groupmap", + "image", + "instant", + "kimchi", + "log", + "mina-curves", + "mina-poseidon", + "ndarray", + "poly-commitment", + "pretty_assertions", + "rand", + "rstest", + "scale", + "serde", + "serde_json", + "test-case", + "thiserror", + "tract-onnx", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -824,12 +1190,6 @@ dependencies = [ "rawpointer", ] -[[package]] -name = "new_debug_unreachable" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" - [[package]] name = "nom" version = "7.1.3" @@ -840,12 +1200,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "noop_proc_macro" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" - [[package]] name = "num-bigint" version = "0.4.6" @@ -854,6 +1208,8 @@ checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", + "rand", + "serde", ] [[package]] @@ -891,17 +1247,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-rational" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" -dependencies = [ - "num-bigint", - "num-integer", - "num-traits", -] - [[package]] name = "num-traits" version = "0.2.19" @@ -912,29 +1257,37 @@ dependencies = [ "libm", ] +[[package]] +name = "o1-utils" +version = "0.1.0" +source = "git+https://github.com/o1-labs/proof-systems#df2415ab32d2157df6b61660512f0140d7f52203" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-poly", + "ark-serialize", + "bcs", + "hex", + "num-bigint", + "num-integer", + "num-traits", + "rand", + "rand_core", + "rayon", + "rmp-serde", + "secp256k1", + "serde", + "serde_with", + "sha2", + "thiserror", +] + [[package]] name = "once_cell" version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" -[[package]] -name = "onnx-parser" -version = "0.1.0" -dependencies = [ - "anyhow", - "bincode", - "image", - "instant", - "log", - "ndarray", - "scale", - "serde", - "serde_json", - "thiserror", - "tract-onnx", -] - [[package]] name = "paste" version = "1.0.15" @@ -993,10 +1346,16 @@ dependencies = [ ] [[package]] -name = "pkg-config" -version = "0.3.31" +name = "pin-project-lite" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" + +[[package]] +name = "pin-utils" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "png" @@ -1011,6 +1370,31 @@ dependencies = [ "miniz_oxide 0.8.0", ] +[[package]] +name = "poly-commitment" +version = "0.1.0" +source = "git+https://github.com/o1-labs/proof-systems#df2415ab32d2157df6b61660512f0140d7f52203" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-poly", + "ark-serialize", + "blake2", + "groupmap", + "itertools 0.12.1", + "mina-curves", + "mina-poseidon", + "o1-utils", + "once_cell", + "rand", + "rand_core", + "rayon", + "rmp-serde", + "serde", + "serde_with", + "thiserror", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1026,6 +1410,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "pretty_assertions" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" +dependencies = [ + "diff", + "yansi", +] + [[package]] name = "primal-check" version = "0.3.4" @@ -1044,25 +1438,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "profiling" -version = "1.0.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afbdc74edc00b6f6a218ca6a5364d6226a259d4b8ea1af4a0ea063f27e179f4d" -dependencies = [ - "profiling-procmacros", -] - -[[package]] -name = "profiling-procmacros" -version = "1.0.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30" -dependencies = [ - "quote", - "syn 2.0.82", -] - [[package]] name = "prost" version = "0.11.9" @@ -1095,12 +1470,6 @@ dependencies = [ "bytemuck", ] -[[package]] -name = "quick-error" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" - [[package]] name = "quote" version = "1.0.37" @@ -1150,55 +1519,6 @@ dependencies = [ "rand", ] -[[package]] -name = "rav1e" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd87ce80a7665b1cce111f8a16c1f3929f6547ce91ade6addf4ec86a8dda5ce9" -dependencies = [ - "arbitrary", - "arg_enum_proc_macro", - "arrayvec", - "av1-grain", - "bitstream-io", - "built", - "cfg-if", - "interpolate_name", - "itertools 0.12.1", - "libc", - "libfuzzer-sys", - "log", - "maybe-rayon", - "new_debug_unreachable", - "noop_proc_macro", - "num-derive", - "num-traits", - "once_cell", - "paste", - "profiling", - "rand", - "rand_chacha", - "simd_helpers", - "system-deps", - "thiserror", - "v_frame", - "wasm-bindgen", -] - -[[package]] -name = "ravif" -version = "0.11.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2413fd96bd0ea5cdeeb37eaf446a22e6ed7b981d792828721e74ded1980a45c6" -dependencies = [ - "avif-serialize", - "imgref", - "loop9", - "quick-error", - "rav1e", - "rgb", -] - [[package]] name = "rawpointer" version = "0.2.1" @@ -1264,10 +1584,70 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] -name = "rgb" -version = "0.8.50" +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + +[[package]] +name = "rstest" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn 2.0.82", + "unicode-ident", +] + +[[package]] +name = "rustc_version" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] [[package]] name = "rustfft" @@ -1297,6 +1677,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustversion" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" + [[package]] name = "ryu" version = "1.0.18" @@ -1333,6 +1719,30 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "secp256k1" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24b59d129cdadea20aea4fb2352fa053712e5d713eee47d700cd4b2bc002f10" +dependencies = [ + "secp256k1-sys", +] + +[[package]] +name = "secp256k1-sys" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5d1746aae42c19d583c3c1a8c646bfad910498e2051c551a7f2e3c0c9fbb7eb" +dependencies = [ + "cc", +] + +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + [[package]] name = "serde" version = "1.0.213" @@ -1366,12 +1776,33 @@ dependencies = [ ] [[package]] -name = "serde_spanned" -version = "0.6.8" +name = "serde_with" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ + "base64", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.6.0", "serde", + "serde_derive", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d846214a9854ef724f3da161b426242d8de7c1fc7de2f89bb1efcb154dca79d" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.82", ] [[package]] @@ -1398,12 +1829,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" [[package]] -name = "simd_helpers" -version = "0.1.0" +name = "slab" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95890f873bec569a0362c235787f3aca6e1e887302ba4840839bcc6459c42da6" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" dependencies = [ - "quote", + "autocfg", ] [[package]] @@ -1444,6 +1875,37 @@ dependencies = [ "serde", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.82", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "1.0.109" @@ -1466,19 +1928,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "system-deps" -version = "6.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" -dependencies = [ - "cfg-expr", - "heck", - "pkg-config", - "toml", - "version-compare", -] - [[package]] name = "tar" version = "0.4.42" @@ -1491,10 +1940,37 @@ dependencies = [ ] [[package]] -name = "target-lexicon" -version = "0.12.16" +name = "test-case" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2550dd13afcd286853192af8601920d959b14c401fcece38071d53bf0768a8" +dependencies = [ + "test-case-macros", +] + +[[package]] +name = "test-case-core" +version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +checksum = "adcb7fd841cd518e279be3d5a3eb0636409487998a4aff22f3de87b81e88384f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.82", +] + +[[package]] +name = "test-case-macros" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.82", + "test-case-core", +] [[package]] name = "thiserror" @@ -1573,40 +2049,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" -[[package]] -name = "toml" -version = "0.8.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - -[[package]] -name = "toml_datetime" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" -dependencies = [ - "serde", -] - -[[package]] -name = "toml_edit" -version = "0.22.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" -dependencies = [ - "indexmap", - "serde", - "serde_spanned", - "toml_datetime", - "winnow", -] - [[package]] name = "tract-core" version = "0.21.6-pre" @@ -1744,6 +2186,16 @@ dependencies = [ "strength_reduce", ] +[[package]] +name = "turshi" +version = "0.1.0" +source = "git+https://github.com/o1-labs/proof-systems#df2415ab32d2157df6b61660512f0140d7f52203" +dependencies = [ + "ark-ff", + "hex", + "o1-utils", +] + [[package]] name = "typenum" version = "1.17.0" @@ -1777,23 +2229,6 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" -[[package]] -name = "v_frame" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6f32aaa24bacd11e488aa9ba66369c7cd514885742c9fe08cfe85884db3e92b" -dependencies = [ - "aligned-vec", - "num-traits", - "wasm-bindgen", -] - -[[package]] -name = "version-compare" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" - [[package]] name = "version_check" version = "0.9.5" @@ -1886,6 +2321,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -1968,15 +2412,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "winnow" -version = "0.6.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" -dependencies = [ - "memchr", -] - [[package]] name = "xattr" version = "1.3.1" @@ -1988,6 +2423,12 @@ dependencies = [ "rustix", ] +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "zerocopy" version = "0.7.35" @@ -2010,25 +2451,30 @@ dependencies = [ ] [[package]] -name = "zune-core" -version = "0.4.12" +name = "zeroize" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] [[package]] -name = "zune-inflate" -version = "0.2.54" +name = "zeroize_derive" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ - "simd-adler32", + "proc-macro2", + "quote", + "syn 2.0.82", ] [[package]] -name = "zune-jpeg" -version = "0.4.13" +name = "zune-inflate" +version = "0.2.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768" +checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" dependencies = [ - "zune-core", + "simd-adler32", ] diff --git a/Cargo.toml b/Cargo.toml index b17de97..1e1c2ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,16 +1,28 @@ [package] -name = "onnx-parser" +name = "mina-zkml" version = "0.1.0" edition = "2021" [lib] -name = "kimchi" +name = "mina_zkml" crate-type = ["cdylib", "rlib"] +[[example]] +name = "perceptron" +path = "examples/perceptron.rs" + +[[example]] +name = "mnist_inference" +path = "examples/mnist_inference.rs" + +[[example]] +name = "zk_inference" +path = "examples/zk_inference.rs" + [dependencies] anyhow = "1.0.90" bincode = "1.3" -image = "0.25.4" +image = "0.24.7" instant = "0.1.13" log = "0.4.22" ndarray = "0.15.4" @@ -19,3 +31,22 @@ serde = { version = "1.0.210", features = ["derive"] } serde_json = "1.0" thiserror = "1.0.64" tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false } +kimchi = { git = "https://github.com/o1-labs/proof-systems", package = "kimchi" } +ark-ff = "0.4.0" +ark-poly = "0.4.0" +ark-ec = "0.4.0" +mina-curves = { git = "https://github.com/o1-labs/proof-systems" } +chrono = "0.4.38" +rand = "0.8.5" +groupmap = { git = "https://github.com/o1-labs/proof-systems" } +poly-commitment = { git = "https://github.com/o1-labs/proof-systems" } +mina-poseidon = { git = "https://github.com/o1-labs/proof-systems" } + +[dev-dependencies] +pretty_assertions = "1.4.0" +test-case = "3.3.1" +rstest = "0.18.2" + +[features] +default = [] +test-utils = [] diff --git a/examples/mnist_inference.rs b/examples/mnist_inference.rs new file mode 100644 index 0000000..2faa988 --- /dev/null +++ b/examples/mnist_inference.rs @@ -0,0 +1,98 @@ +use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility}; +use std::collections::HashMap; + +fn preprocess_image(img_path: &str) -> Result, Box> { + // Load and convert image to grayscale + let img = image::open(img_path)?.into_luma8(); + + // Ensure image is 28x28 + let resized = image::imageops::resize(&img, 28, 28, image::imageops::FilterType::Lanczos3); + + // Convert to f32 and normalize to [0, 1] + let pixels: Vec = resized.into_raw().into_iter().map(|x| x as f32).collect(); + + //Apply normalization + let pixels: Vec = pixels + .into_iter() + .map(|x| (x / 255.0 - 0.1307) / 0.3081) + .collect(); + + // Create a batch dimension by wrapping the flattened pixels + let mut input = Vec::with_capacity(28 * 28); + input.extend_from_slice(&pixels); + Ok(input) +} + +fn main() -> Result<(), Box> { + // Create run args with batch size + let mut variables = HashMap::new(); + variables.insert("batch_size".to_string(), 1); + let run_args = RunArgs { variables }; + + // Create visibility settings + let visibility = VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }; + + // Load the MNIST model + println!("Loading MNIST model..."); + let model = Model::new("models/mnist_mlp.onnx", &run_args, &visibility).map_err(|e| { + println!("Error loading model: {:?}", e); + e + })?; + + // Print model structure + println!("\nModel structure:"); + println!("Number of nodes: {}", model.graph.nodes.len()); + println!("Input nodes: {:?}", model.graph.inputs); + println!("Output nodes: {:?}", model.graph.outputs); + + // Load and preprocess the image + println!("\nLoading and preprocessing image..."); + let input = preprocess_image("models/data/1052.png")?; + + // Execute the model + println!("\nRunning inference..."); + let result = model.graph.execute(&[input])?; + + //Result + println!("Result: {:?}", result); + + // Print the output probabilities + println!("\nOutput probabilities for digits 0-9:"); + if let Some(probabilities) = result.first() { + // The model outputs logits, so we need to apply softmax + let max_logit = probabilities + .iter() + .take(10) + .fold(f32::NEG_INFINITY, |a, &b| a.max(b)); + let exp_sum: f32 = probabilities + .iter() + .take(10) + .map(|&x| (x - max_logit).exp()) + .sum(); + + let softmax: Vec = probabilities + .iter() + .take(10) + .map(|&x| ((x - max_logit).exp()) / exp_sum) + .collect(); + + for (digit, &prob) in softmax.iter().enumerate() { + println!("Digit {}: {:.4}", digit, prob); + } + + // Find the predicted digit + let predicted_digit = softmax + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .map(|(digit, _)| digit) + .unwrap(); + + println!("\nPredicted digit: {}", predicted_digit); + } + + Ok(()) +} diff --git a/examples/perceptron.rs b/examples/perceptron.rs new file mode 100644 index 0000000..6b3a01f --- /dev/null +++ b/examples/perceptron.rs @@ -0,0 +1,55 @@ +use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility}; +use std::collections::HashMap; + +fn main() -> Result<(), Box> { + // Create run args with batch size + let mut variables = HashMap::new(); + variables.insert("batch_size".to_string(), 1); + let run_args = RunArgs { variables }; + + // Create visibility settings + let visibility = VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }; + + // Load the perceptron model + println!("Loading perceptron model..."); + let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?; + + // Print model structure + println!("\nModel structure:"); + println!("Number of nodes: {}", model.graph.nodes.len()); + println!("Input nodes: {:?}", model.graph.inputs); + println!("Output nodes: {:?}", model.graph.outputs); + + // Print node connections + println!("\nNode connections:"); + for (id, node) in &model.graph.nodes { + match node { + mina_zkml::graph::model::NodeType::Node(n) => { + println!("Node {}: {:?} inputs: {:?}", id, n.op_type, n.inputs); + println!("Output dimensions: {:?}", n.out_dims); + println!("Weight Tensor: {:?}", n.weights); + println!("Bias Tensor: {:?}", n.bias); + } + mina_zkml::graph::model::NodeType::SubGraph { .. } => { + println!("Node {}: SubGraph", id); + } + } + } + + // Create a sample input vector of size 10 + let input = vec![1.0, 0.5, -0.3, 0.8, -0.2, 0.7, 0.1, -0.4, 0.9, 0.6]; + println!("\nInput vector (size 10):"); + println!("{:?}", input); + + // Execute the model + let result = model.graph.execute(&[input])?; + + // Print the output + println!("\nOutput vector (size 3, after ReLU):"); + println!("{:?}", result[0]); + + Ok(()) +} diff --git a/examples/zk_inference.rs b/examples/zk_inference.rs new file mode 100644 index 0000000..aae1908 --- /dev/null +++ b/examples/zk_inference.rs @@ -0,0 +1,64 @@ +use mina_zkml::{ + graph::model::{Model, RunArgs, VarVisibility, Visibility}, + zk::proof::ProofSystem, +}; +use std::collections::HashMap; + +fn main() -> Result<(), Box> { + // 1. Load the model + println!("Loading model..."); + let mut variables = HashMap::new(); + variables.insert("batch_size".to_string(), 1); + let run_args = RunArgs { variables }; + + let visibility = VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }; + + let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?; + + // 2. Create proof system + println!("Creating proof system..."); + let proof_system = ProofSystem::new(&model); + + // 3. Create sample input (with proper padding to size 10) + let input = vec![vec![ + 1.0, 0.5, -0.3, 0.8, -0.2, // Original values + 0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10 + ]]; + + // 4. Generate output and proof + println!("Generating output and proof..."); + let prover_output = proof_system.prove(&input)?; + println!("Model output: {:?}", prover_output.output); + + // 5. Verify the proof with output and proof + println!("Verifying proof..."); + let is_valid = proof_system.verify(&prover_output.output, &prover_output.proof)?; + + println!("\nResults:"); + println!("Model execution successful: ✓"); + println!("Proof creation successful: ✓"); + println!( + "Proof verification: {}", + if is_valid { "✓ Valid" } else { "✗ Invalid" } + ); + + // 6. Demonstrate invalid verification with modified output + println!("\nTesting invalid case with modified output..."); + let mut modified_output = prover_output.output.clone(); + modified_output[0][0] += 1.0; // Modify first output value + + let is_valid_modified = proof_system.verify(&modified_output, &prover_output.proof)?; + println!( + "Modified output verification: {}", + if !is_valid_modified { + "✗ Invalid (Expected)" + } else { + "✓ Valid (Unexpected!)" + } + ); + + Ok(()) +} diff --git a/examples/zk_inference_fail.rs b/examples/zk_inference_fail.rs new file mode 100644 index 0000000..d8ff624 --- /dev/null +++ b/examples/zk_inference_fail.rs @@ -0,0 +1,57 @@ +use mina_zkml::{ + graph::model::{Model, RunArgs, VarVisibility, Visibility}, + zk::proof::ProofSystem, +}; +use std::collections::HashMap; + +fn main() -> Result<(), Box> { + // 1. Load the model + println!("Loading model..."); + let mut variables = HashMap::new(); + variables.insert("batch_size".to_string(), 1); + let run_args = RunArgs { variables }; + + let visibility = VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }; + + let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility)?; + + // 2. Create proof system + println!("Creating proof system..."); + let proof_system = ProofSystem::new(&model); + + // 3. Create sample input (with proper padding to size 10) + let input = vec![vec![ + 1.0, 0.5, -0.3, 0.8, -0.2, // Original values + 0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10 + ]]; + + // 4. Generate output and proof + println!("Generating output and proof..."); + let prover_output = proof_system.prove(&input)?; + println!("Model output: {:?}", prover_output.output); + + // 5. Create modified output (simulating malicious behavior) + let mut modified_output = prover_output.output.clone(); + modified_output[0][0] += 1.0; // Modify first output value + + // 6. Try to verify with modified output (should fail) + println!("Verifying proof with modified output..."); + let is_valid = proof_system.verify(&modified_output, &prover_output.proof)?; + + println!("\nResults:"); + println!("Model execution successful: ✓"); + println!("Proof creation successful: ✓"); + println!( + "Modified output verification: {}", + if !is_valid { + "✗ Invalid (Expected)" + } else { + "✓ Valid (Unexpected!)" + } + ); + + Ok(()) +} diff --git a/examples/zk_mnist.rs b/examples/zk_mnist.rs new file mode 100644 index 0000000..fbdde67 --- /dev/null +++ b/examples/zk_mnist.rs @@ -0,0 +1,209 @@ +use mina_zkml::graph::model::{Model, RunArgs, VarVisibility, Visibility}; +use mina_zkml::zk::proof::ProofSystem; +use std::collections::HashMap; + +fn preprocess_image(img_path: &str) -> Result, Box> { + // Load and convert image to grayscale + let img = image::open(img_path)?.into_luma8(); + + // Ensure image is 28x28 + let resized = image::imageops::resize(&img, 28, 28, image::imageops::FilterType::Lanczos3); + + // Convert to f32 and normalize to [0, 1] + let pixels: Vec = resized.into_raw().into_iter().map(|x| x as f32).collect(); + + //Apply normalization + let pixels: Vec = pixels + .into_iter() + .map(|x| (x / 255.0 - 0.1307) / 0.3081) + .collect(); + + // Create a batch dimension by wrapping the flattened pixels + let mut input = Vec::with_capacity(28 * 28); + input.extend_from_slice(&pixels); + Ok(input) +} + +fn get_predicted_digit(logits: &[f32]) -> usize { + // Apply softmax and find max probability digit + let max_logit = logits + .iter() + .take(10) + .fold(f32::NEG_INFINITY, |a, &b| a.max(b)); + let exp_sum: f32 = logits.iter().take(10).map(|&x| (x - max_logit).exp()).sum(); + + let softmax: Vec = logits + .iter() + .take(10) + .map(|&x| ((x - max_logit).exp()) / exp_sum) + .collect(); + + softmax + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .map(|(digit, _)| digit) + .unwrap() +} + +fn print_prediction_info(logits: &[f32]) { + let max_logit = logits + .iter() + .take(10) + .fold(f32::NEG_INFINITY, |a, &b| a.max(b)); + let exp_sum: f32 = logits.iter().take(10).map(|&x| (x - max_logit).exp()).sum(); + + let softmax: Vec = logits + .iter() + .take(10) + .map(|&x| ((x - max_logit).exp()) / exp_sum) + .collect(); + + println!("Probabilities for each digit:"); + for (digit, prob) in softmax.iter().enumerate() { + println!("Digit {}: {:.4}", digit, prob); + } + println!("Predicted digit: {}", get_predicted_digit(logits)); +} + +fn main() -> Result<(), Box> { + // 1. Setup model + println!("Loading MNIST model..."); + let mut variables = HashMap::new(); + variables.insert("batch_size".to_string(), 1); + let run_args = RunArgs { variables }; + + let visibility = VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }; + + let model = Model::new("models/mnist_mlp.onnx", &run_args, &visibility)?; + + // 2. Create proof system + println!("Creating proof system..."); + let proof_system = ProofSystem::new(&model); + + println!("\n=== Test Case 1: Valid Proof for First Image ==="); + // Load first image + let input1 = preprocess_image("models/data/1052.png")?; + let input_vec1 = vec![input1]; + + // Generate output and proof for first image + let prover_output1 = proof_system.prove(&input_vec1)?; + println!("First image prediction:"); + print_prediction_info(&prover_output1.output[0]); + + // Verify proof for first image + let is_valid1 = proof_system.verify(&prover_output1.output, &prover_output1.proof)?; + println!( + "Verification result: {}", + if is_valid1 { + "✓ Valid" + } else { + "✗ Invalid" + } + ); + + println!("\n=== Test Case 2: Valid Proof for Second Image ==="); + // Load second image + let input2 = preprocess_image("models/data/1085.png")?; + let input_vec2 = vec![input2]; + + // Generate output and proof for second image + let prover_output2 = proof_system.prove(&input_vec2)?; + println!("Second image prediction:"); + print_prediction_info(&prover_output2.output[0]); + + // Verify proof for second image + let is_valid2 = proof_system.verify(&prover_output2.output, &prover_output2.proof)?; + println!( + "Verification result: {}", + if is_valid2 { + "✓ Valid" + } else { + "✗ Invalid" + } + ); + + println!("\n=== Test Case 3: Invalid Proof - Completely Wrong Outputs ==="); + // Create fake output with opposite predictions + let mut fake_output1 = prover_output1.output.clone(); + for i in 0..10 { + fake_output1[0][i] = -fake_output1[0][i]; // Invert all logits + } + println!("Attempted fake prediction:"); + print_prediction_info(&fake_output1[0]); + + // Try to verify with wrong outputs + let is_valid3 = proof_system.verify(&fake_output1, &prover_output1.proof)?; + println!( + "Verification result: {}", + if is_valid3 { + "✓ Valid (UNEXPECTED!)" + } else { + "✗ Invalid (Expected)" + } + ); + + println!("\n=== Test Case 4: Invalid Proof - Slightly Modified Outputs ==="); + // Create fake output with small perturbations + let mut fake_output2 = prover_output2.output.clone(); + for i in 0..10 { + fake_output2[0][i] += 0.1; // Add small perturbation to each logit + } + println!("Attempted fake prediction (with small perturbations):"); + print_prediction_info(&fake_output2[0]); + + // Try to verify with slightly modified outputs + let is_valid4 = proof_system.verify(&fake_output2, &prover_output2.proof)?; + println!( + "Verification result: {}", + if is_valid4 { + "✓ Valid (UNEXPECTED!)" + } else { + "✗ Invalid (Expected)" + } + ); + + println!("\n=== Summary ==="); + println!( + "1. First valid case (1052.png): {}", + if is_valid1 { + "✓ Valid" + } else { + "✗ Invalid" + } + ); + println!( + "2. Second valid case (1085.png): {}", + if is_valid2 { + "✓ Valid" + } else { + "✗ Invalid" + } + ); + println!( + "3. Invalid case (inverted logits): {}", + if !is_valid3 { + "✓ Failed as expected" + } else { + "✗ Unexpectedly passed" + } + ); + println!( + "4. Invalid case (small perturbations): {}", + if !is_valid4 { + "✓ Failed as expected" + } else { + "✗ Unexpectedly passed" + } + ); + + println!("\nThis demonstrates that the zero-knowledge proof system:"); + println!("- Successfully verifies correct model executions"); + println!("- Detects both large and small output manipulations"); + println!("- Works consistently across different input images"); + + Ok(()) +} diff --git a/models/data/1052.png b/models/data/1052.png new file mode 100644 index 0000000..9e39911 Binary files /dev/null and b/models/data/1052.png differ diff --git a/models/data/1085.png b/models/data/1085.png new file mode 100644 index 0000000..93f90f6 Binary files /dev/null and b/models/data/1085.png differ diff --git a/models/mnist_mlp.onnx b/models/mnist_mlp.onnx new file mode 100644 index 0000000..5d3d6ac Binary files /dev/null and b/models/mnist_mlp.onnx differ diff --git a/models/simple_perceptron.onnx b/models/simple_perceptron.onnx new file mode 100644 index 0000000..6c00bc1 Binary files /dev/null and b/models/simple_perceptron.onnx differ diff --git a/src/graph/errors.rs b/src/graph/errors.rs index 09a0b79..b4da7fb 100644 --- a/src/graph/errors.rs +++ b/src/graph/errors.rs @@ -1,27 +1,27 @@ use thiserror::Error; -/// circuit related errors. -#[derive(Debug, Error)] +#[derive(Error, Debug, Clone, Copy, PartialEq)] pub enum GraphError { - /// Missing Batch Size - #[error("unknown dimension batch_size in model inputs, set batch_size in variables")] - MissingBatchSize, - // Unable to Read ONNX Model - #[error("unable to read onnx model")] + #[error("Unable to read model")] UnableToReadModel, - //Missing Node - #[error("missing node in model")] + #[error("Unable to save model")] + UnableToSaveModel, + #[error("Missing batch size")] + MissingBatchSize, + #[error("Missing node {0}")] MissingNode(usize), - //Invalid Shape - #[error("invalid input shape")] + #[error("Invalid input shape")] InvalidInputShape, - //Invalid Operation - #[error("invalid operation")] - InvalidOperation, - //Missing Parameter - #[error("missing parameter")] - MissingParameter, - // Unable to Save Model - #[error("unable to save model to file")] - UnableToSaveModel, + #[error("Invalid input slot {0}")] + InvalidInputSlot(usize), + #[error("Invalid output slot {0}")] + InvalidOutputSlot(usize), + #[error("Cyclic dependency detected")] + CyclicDependency, + #[error("Unsupported operation")] + UnsupportedOperation, + #[error("Invalid Output Shape")] + InvalidOutputShape, + #[error("Invalid parameter")] + InvalidParams, } diff --git a/src/graph/model.rs b/src/graph/model.rs index 16f218c..5f7e7a5 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -1,17 +1,37 @@ use super::errors::GraphError; +use chrono::Local; use instant; use log::debug; use serde::{Deserialize, Serialize}; +use std::fs::OpenOptions; +use std::io::Write; use std::{ collections::{BTreeMap, HashMap}, path::Path, }; -use tract_onnx::{prelude::*, tract_hir::ops::scan::Scan}; +use tract_onnx::{prelude::*, tract_hir::ops::konst::Const, tract_hir::ops::scan::Scan}; + +use crate::zk::operations::identify_tract_operation; + +/// Type alias for the graph loading result +pub type GraphLoadResult = (Graph>, SymbolValues); /// Represents a node output connection as (node_index, output_slot) pub type Outlet = (usize, usize); -/// Result type for tract operations containing the graph and symbol values -type TractResult = (Graph>, SymbolValues); + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum OperationType { + Input, + MatMul, + Relu, + Sigmoid, + Add, + EinSum, + Max, + Const, + RmAxis, + Reshape, +} /// Serializable version of OutletId #[derive(Clone, Debug, Serialize, Deserialize)] @@ -45,6 +65,23 @@ pub struct Model { pub visibility: VarVisibility, } +/// Represents different types of nodes in the graph +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum NodeType { + /// A regular computation node + Node(SerializableNode), + /// A subgraph node (typically used for control flow operations like loops) + SubGraph { + model: Box, + inputs: Vec, + idx: usize, + out_dims: Vec>, + out_scales: Vec, + output_mappings: Vec>, + input_mappings: Vec, + }, +} + /// Represents the parsed neural network graph structure #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ParsedNodes { @@ -62,6 +99,120 @@ impl ParsedNodes { self.inputs.len() } + pub fn log_weights_and_biases(&self) -> Result<(), GraphError> { + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open("weights_biases_log.txt") + .map_err(|_| GraphError::UnableToSaveModel)?; + + let timestamp = Local::now().format("%Y-%m-%d %H:%M:%S%.3f"); + + writeln!(file, "\n[{}] Weights and Biases Analysis", timestamp) + .map_err(|_| GraphError::UnableToSaveModel)?; + + writeln!(file, "----------------------------------------") + .map_err(|_| GraphError::UnableToSaveModel)?; + + // Build connection map + let mut const_connections: HashMap> = HashMap::new(); + for (node_idx, node_type) in &self.nodes { + if let NodeType::Node(node) = node_type { + for (input_idx, _slot) in &node.inputs { + if let Some(NodeType::Node(input_node)) = self.nodes.get(input_idx) { + if matches!(input_node.op_type, OperationType::Const) { + const_connections + .entry(*input_idx) + .or_default() + .push((*node_idx, node.op_type.clone())); + } + } + } + } + } + + // Create a sorted list of nodes for consistent output + let mut node_indices: Vec<_> = self.nodes.keys().collect(); + node_indices.sort(); + + for &node_idx in &node_indices { + if let Some(NodeType::Node(node)) = self.nodes.get(node_idx) { + if matches!(node.op_type, OperationType::Const) { + // Node header + writeln!(file, "\nConst Node {}", node_idx) + .map_err(|_| GraphError::UnableToSaveModel)?; + + // Dimensions + writeln!(file, "Dimensions: {:?}", node.out_dims) + .map_err(|_| GraphError::UnableToSaveModel)?; + + // Consumers + if let Some(consumers) = const_connections.get(node_idx) { + writeln!(file, "Used by:").map_err(|_| GraphError::UnableToSaveModel)?; + for (consumer_idx, op_type) in consumers { + writeln!(file, " - Node {} ({:?})", consumer_idx, op_type) + .map_err(|_| GraphError::UnableToSaveModel)?; + } + } + + // Values + if let Some(weights) = &node.weights { + writeln!(file, "\nAll Values:") + .map_err(|_| GraphError::UnableToSaveModel)?; + writeln!(file, "Total elements: {}", weights.len()) + .map_err(|_| GraphError::UnableToSaveModel)?; + + // Write all values as a comma-separated list within brackets + write!(file, "[").map_err(|_| GraphError::UnableToSaveModel)?; + for (i, &value) in weights.iter().enumerate() { + if i > 0 { + write!(file, ", ").map_err(|_| GraphError::UnableToSaveModel)?; + } + write!(file, "{:.6}", value) + .map_err(|_| GraphError::UnableToSaveModel)?; + } + writeln!(file, "]").map_err(|_| GraphError::UnableToSaveModel)?; + + // Statistics + let min = weights.iter().fold(f32::INFINITY, |a, &b| a.min(b)); + let max = weights.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)); + let sum: f32 = weights.iter().sum(); + let mean = sum / weights.len() as f32; + + // Count non-zero elements + let non_zero_count = weights.iter().filter(|&&x| x != 0.0).count(); + + writeln!(file, "\nStatistics:") + .map_err(|_| GraphError::UnableToSaveModel)?; + writeln!(file, " Total elements: {}", weights.len()) + .map_err(|_| GraphError::UnableToSaveModel)?; + writeln!(file, " Non-zero elements: {}", non_zero_count) + .map_err(|_| GraphError::UnableToSaveModel)?; + writeln!(file, " Zero elements: {}", weights.len() - non_zero_count) + .map_err(|_| GraphError::UnableToSaveModel)?; + writeln!( + file, + " Sparsity: {:.2}%", + (weights.len() - non_zero_count) as f32 / weights.len() as f32 * 100.0 + ) + .map_err(|_| GraphError::UnableToSaveModel)?; + writeln!(file, " Min: {:.6}", min) + .map_err(|_| GraphError::UnableToSaveModel)?; + writeln!(file, " Max: {:.6}", max) + .map_err(|_| GraphError::UnableToSaveModel)?; + writeln!(file, " Mean: {:.6}", mean) + .map_err(|_| GraphError::UnableToSaveModel)?; + } + + writeln!(file, "----------------------------------------") + .map_err(|_| GraphError::UnableToSaveModel)?; + } + } + } + + Ok(()) + } + /// Returns a vector of output scales for all output nodes pub fn get_output_scales(&self) -> Result, GraphError> { self.outputs @@ -70,68 +221,273 @@ impl ParsedNodes { self.nodes .get(&node) .ok_or(GraphError::MissingNode(node)) - .map(|n| n.out_scales()[slot]) + .map(|n| match n { + NodeType::Node(node) => node.out_scale, + NodeType::SubGraph { out_scales, .. } => out_scales[slot], + }) }) .collect() } -} -/// Represents different types of nodes in the graph -#[derive(Clone, Debug, Serialize, Deserialize)] -pub enum NodeType { - /// A regular computation node - Node(SerializableNode), - /// A subgraph node (typically used for control flow operations like loops) - SubGraph { - model: Box, - inputs: Vec, - idx: usize, - out_dims: Vec>, - out_scales: Vec, - output_mappings: Vec>, - input_mappings: Vec, - }, -} + /// Execute the graph with given inputs + pub fn execute(&self, inputs: &[Vec]) -> Result>, GraphError> { + let mut node_outputs: HashMap>> = HashMap::new(); + + // Store input values + for (&node_idx, input) in self.inputs.iter().zip(inputs.iter()) { + // Get the input node to check its dimensions + if let Some(NodeType::Node(node)) = self.nodes.get(&node_idx) { + if node.out_dims.len() > 1 { + // If input node expects a tensor, reshape the input + node_outputs.insert(node_idx, vec![input.clone()]); + } else { + node_outputs.insert(node_idx, vec![input.clone()]); + } + } else { + node_outputs.insert(node_idx, vec![input.clone()]); + } + } -impl NodeType { - /// Returns the output scales for the node - pub fn out_scales(&self) -> &[i32] { - match self { - NodeType::Node(node) => std::slice::from_ref(&node.out_scale), - NodeType::SubGraph { out_scales, .. } => out_scales, + // Topologically sort nodes for execution + let sorted_nodes = self.topological_sort()?; + + // Execute nodes in order + for &node_idx in sorted_nodes.iter() { + if let Some(node_type) = self.nodes.get(&node_idx) { + match node_type { + NodeType::Node(node) => { + // Handle Const nodes + if matches!(node.op_type, OperationType::Const) { + if let Some(weights) = &node.weights { + node_outputs.insert(node_idx, vec![weights.clone()]); + } + continue; + } + + // Skip input nodes as they're already processed + if matches!(node.op_type, OperationType::Input) { + continue; + } + + // Get input values + let mut input_values = Vec::new(); + for &(input_node, slot) in &node.inputs { + if let Some(outputs) = node_outputs.get(&input_node) { + if slot < outputs.len() { + input_values.push(outputs[slot].clone()); + } else { + return Err(GraphError::InvalidInputSlot(slot)); + } + } else { + return Err(GraphError::MissingNode(input_node)); + } + } + + // Execute operation + let output = self.execute_operation(node, &input_values)?; + node_outputs.insert(node_idx, output); + } + NodeType::SubGraph { .. } => { + return Err(GraphError::UnsupportedOperation); + } + } + } } - } - /// Returns the input connections for the node - pub fn inputs(&self) -> Vec { - match self { - NodeType::Node(node) => node.inputs.clone(), - NodeType::SubGraph { inputs, .. } => inputs.iter().map(|i| (i.node, i.slot)).collect(), + // Collect outputs + let mut outputs = Vec::new(); + for &(node, slot) in &self.outputs { + if let Some(node_output) = node_outputs.get(&node) { + if slot < node_output.len() { + outputs.push(node_output[slot].clone()); + } else { + return Err(GraphError::InvalidOutputSlot(slot)); + } + } else { + return Err(GraphError::MissingNode(node)); + } } + + Ok(outputs) } - /// Returns the output dimensions for the node - pub fn out_dims(&self) -> Vec> { - match self { - NodeType::Node(node) => vec![node.out_dims.clone()], - NodeType::SubGraph { out_dims, .. } => out_dims.clone(), - } + /// Execute a single operation + fn execute_operation( + &self, + node: &SerializableNode, + inputs: &[Vec], + ) -> Result>, GraphError> { + let result = match node.op_type { + OperationType::Input => Ok(inputs.to_vec()), + OperationType::Const => { + if let Some(weights) = &node.weights { + Ok(vec![weights.clone()]) + } else { + Err(GraphError::InvalidInputShape) + } + } + OperationType::MatMul | OperationType::EinSum => { + if inputs.is_empty() { + return Err(GraphError::InvalidInputShape); + } + + let input = &inputs[0]; // Shape: [784] + let weights = if inputs.len() > 1 { + &inputs[1] + } else if let Some(weights) = &node.weights { + weights + } else { + return Err(GraphError::InvalidInputShape); + }; + + let input_dim = input.len(); // 784 + let output_dim = node.out_dims.iter().product(); // 512 + let weight_rows = output_dim; // 512 (PyTorch convention) + let weight_cols = input_dim; // 784 (PyTorch convention) + + // Verify dimensions match + if weights.len() != weight_rows * weight_cols { + return Err(GraphError::InvalidInputShape); + } + + let mut output = vec![0.0; output_dim]; + + // Using iterators instead of range loops + output.iter_mut().enumerate().for_each(|(i, out)| { + *out = input.iter().enumerate().fold(0.0, |sum, (j, &input_val)| { + let weight_idx = i * input_dim + j; + sum + input_val * weights[weight_idx] + }); + }); + + Ok(vec![output]) + } + OperationType::Add => { + let a = &inputs[0]; + let b = if inputs.len() > 1 { + &inputs[1] + } else if let Some(bias) = &node.bias { + bias + } else { + return Err(GraphError::InvalidInputShape); + }; + + if a.len() != b.len() { + return Err(GraphError::InvalidInputShape); + } + Ok(vec![a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect()]) + } + OperationType::Relu | OperationType::Max => { + if inputs.is_empty() { + return Err(GraphError::InvalidInputShape); + } + + let result = inputs[0].iter().map(|&x| x.max(0.0)).collect(); + Ok(vec![result]) + } + OperationType::Sigmoid => { + if inputs.is_empty() { + return Err(GraphError::InvalidInputShape); + } + + let expected_size: usize = node.out_dims.iter().product(); + if inputs[0].len() != expected_size { + return Err(GraphError::InvalidInputShape); + } + + Ok(vec![inputs[0] + .iter() + .map(|&x| { + if x > 20.0 { + 1.0 + } else if x < -20.0 { + 0.0 + } else { + 1.0 / (1.0 + (-x).exp()) + } + }) + .collect()]) + } + OperationType::RmAxis => { + if inputs.is_empty() { + return Err(GraphError::InvalidInputShape); + } + + let expected_size: usize = node.out_dims.iter().product(); + let input = &inputs[0]; + + if input.len() != expected_size { + return Err(GraphError::InvalidInputShape); + } + + Ok(vec![input.clone()]) + } + OperationType::Reshape => { + if inputs.is_empty() { + return Err(GraphError::InvalidInputShape); + } + Ok(vec![inputs[0].clone()]) + } + }; + + // Log the tensor values after each operation + // if let Ok(outputs) = &result { + // if let Err(e) = self.log_tensor_values(node.id, &node.op_type, outputs) { + // println!("Warning: Failed to log tensor values: {:?}", e); + // } + // } + + result } -} -/// Represents a regular computation node in the graph -#[derive(Clone, Debug)] -pub struct Node { - /// The operation to be performed by this node - pub op: Box, - /// Input connections to this node - pub inputs: Vec, - /// Output dimensions - pub out_dims: Vec, - /// Output scale factor - pub out_scale: i32, - /// Unique identifier for the node - pub id: usize, + /// Perform topological sort of nodes + fn topological_sort(&self) -> Result, GraphError> { + let mut visited = HashMap::new(); + let mut sorted = Vec::new(); + + fn visit( + node: usize, + visited: &mut HashMap, + sorted: &mut Vec, + nodes: &BTreeMap, + ) -> Result<(), GraphError> { + if let Some(&in_progress) = visited.get(&node) { + if in_progress { + return Err(GraphError::CyclicDependency); + } + return Ok(()); + } + + visited.insert(node, true); + + if let Some(node_type) = nodes.get(&node) { + match node_type { + NodeType::Node(node) => { + if !matches!(node.op_type, OperationType::Const) { + for &(input_node, _) in &node.inputs { + visit(input_node, visited, sorted, nodes)?; + } + } + } + NodeType::SubGraph { inputs, .. } => { + for input in inputs { + visit(input.node, visited, sorted, nodes)?; + } + } + } + } + + visited.insert(node, false); + sorted.push(node); + Ok(()) + } + + for &node in self.nodes.keys() { + visit(node, &mut visited, &mut sorted, &self.nodes)?; + } + + Ok(sorted) + } } /// Serializable version of Node that excludes TypedOp @@ -145,26 +501,67 @@ pub struct SerializableNode { pub out_scale: i32, /// Unique identifier for the node pub id: usize, + /// Operation type + pub op_type: OperationType, + pub weights: Option>, + pub bias: Option>, } -impl From for SerializableNode { - fn from(node: Node) -> Self { - SerializableNode { - inputs: node.inputs, - out_dims: node.out_dims, - out_scale: node.out_scale, - id: node.id, - } - } -} +impl From<&Node>> for SerializableNode { + fn from(node: &Node>) -> Self { + let op_name = node.op.name(); + let op_type = if op_name == "Const" { + println!("Found Const operation"); + OperationType::Const + } else if node.inputs.is_empty() { + println!("Found Input operation"); + OperationType::Input + } else if op_name.starts_with("Rm(") { + println!("Found RmAxis operation"); + OperationType::RmAxis + } else if let Some(op_type) = identify_tract_operation(node) { + op_type + } else { + println!("Unknown operation: {}", op_name); + OperationType::RmAxis // Default to RmAxis for unknown operations + }; + + println!("Node From : {:?}", node); + println!("Node op_type: {:?}", op_name.as_ref()); + + // Extract weights and biases based on node type + let (weights, bias) = match op_name.as_ref() { + "Const" => { + if let Some(const_op) = node.op.downcast_ref::() { + if let Ok(tensor_data) = const_op.0.as_slice::() { + (Some(tensor_data.to_vec()), None) + } else { + (None, None) + } + } else { + (None, None) + } + } + _ => (None, None), + }; -impl From<&Node> for SerializableNode { - fn from(node: &Node) -> Self { SerializableNode { - inputs: node.inputs.clone(), - out_dims: node.out_dims.clone(), - out_scale: node.out_scale, + inputs: node.inputs.iter().map(|o| (o.node, o.slot)).collect(), + out_dims: node.outputs[0] + .fact + .shape + .iter() + .map(|d| d.to_i64().unwrap() as usize) + .collect(), + out_scale: node.outputs[0] + .fact + .konst + .as_ref() + .map_or(1, |k| *k.to_scalar::().unwrap_or(&1)), id: node.id, + op_type, + weights, + bias, } } } @@ -234,10 +631,55 @@ pub enum Visibility { } impl Model { + pub fn new( + path: &str, + run_args: &RunArgs, + visibility: &VarVisibility, + ) -> Result { + let parsed_nodes = Self::load_onnx_model(path, run_args, visibility)?; + Ok(Model { + graph: parsed_nodes, + visibility: visibility.clone(), + }) + } + + pub fn load_onnx_model( + path: &str, + run_args: &RunArgs, + visibility: &VarVisibility, + ) -> Result { + let start = instant::Instant::now(); + let (model, symbol_values) = Self::load_onnx_using_tract(path, run_args)?; + let nodes = Self::nodes_from_graph(&model, visibility.clone(), symbol_values)?; + println!("Model loaded in {:?}", start.elapsed()); + + // Collect all input nodes (nodes with OperationType::Input) + let inputs: Vec = nodes + .iter() + .filter_map(|(idx, node)| match node { + NodeType::Node(n) => { + if matches!(n.op_type, OperationType::Input) { + Some(*idx) + } else { + None + } + } + _ => None, + }) + .collect(); + + let parsed_nodes = ParsedNodes { + nodes, + inputs, + outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(), + }; + Ok(parsed_nodes) + } + pub fn load_onnx_using_tract>( path: P, run_args: &RunArgs, - ) -> Result { + ) -> Result { debug!("Starting load_onnx_using_tract"); use tract_onnx::tract_hir::internal::GenericFactoid; @@ -285,7 +727,6 @@ impl Model { debug!("set {} to {}", symbol, value); } - // Note: do not optimize the model, as the layout will depend on underlying hardware let typed_model = model .into_typed() .map_err(|_| GraphError::UnableToReadModel)? @@ -298,51 +739,6 @@ impl Model { Ok((typed_model, symbol_values)) } - /// Loads and parses an ONNX model into the internal graph representation - pub fn load_onnx_model( - path: &str, - run_args: &RunArgs, - visibility: &VarVisibility, - ) -> Result { - let start = instant::Instant::now(); - let (model, symbol_values) = Self::load_onnx_using_tract(path, run_args)?; - let nodes = Self::nodes_from_graph(&model, visibility.clone(), symbol_values)?; - println!("Model loaded in {:?}", start.elapsed()); - let parsed_nodes = ParsedNodes { - nodes, - inputs: model.inputs.iter().map(|o| o.node).collect(), - outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(), - }; - Ok(parsed_nodes) - } - - /// Creates a new Model instance from an ONNX file - pub fn new( - path: &str, - run_args: &RunArgs, - visibility: &VarVisibility, - ) -> Result { - let parsed_nodes = Self::load_onnx_model(path, run_args, visibility)?; - Ok(Model { - graph: parsed_nodes, - visibility: visibility.clone(), - }) - } - - /// Saves the model to a binary file - pub fn save>(&self, path: P) -> Result<(), GraphError> { - let encoded: Vec = - bincode::serialize(self).map_err(|_| GraphError::UnableToSaveModel)?; - std::fs::write(path, encoded).map_err(|_| GraphError::UnableToSaveModel) - } - - /// Loads a model from a binary file - pub fn load>(path: P) -> Result { - let bytes = std::fs::read(path).map_err(|_| GraphError::UnableToReadModel)?; - bincode::deserialize(&bytes).map_err(|_| GraphError::UnableToReadModel) - } - - /// Converts a tract graph into the internal node representation pub fn nodes_from_graph( graph: &Graph>, visibility: VarVisibility, @@ -351,8 +747,9 @@ impl Model { use super::utilities::node_output_shapes; let mut nodes = BTreeMap::new(); - // First pass: Create all nodes + // Process all nodes for (idx, node) in graph.nodes.iter().enumerate() { + println!("Node: {:?}", node); match node.op().downcast_ref::() { Some(scan_op) => { debug!("Processing scan node {}", idx); @@ -436,20 +833,10 @@ impl Model { } None => { debug!("Processing regular node {}", idx); - // Create regular node - let out_dims = node_output_shapes(node, &symbol_values)? - .pop() - .unwrap_or_default(); - - let regular_node = Node { - op: node.op.clone(), - inputs: node.inputs.iter().map(|i| (i.node, i.slot)).collect(), - out_dims, - out_scale: 1, - id: idx, - }; - nodes.insert(idx, NodeType::Node(regular_node.into())); + // Create the node with proper operation type and weights/biases + let serializable_node = SerializableNode::from(node); + nodes.insert(idx, NodeType::Node(serializable_node)); } } } @@ -457,23 +844,22 @@ impl Model { // Verify all required nodes exist let mut missing_nodes = Vec::new(); - // Check inputs + // Check inputs for non-Const nodes for node in nodes.values() { match node { NodeType::Node(n) => { + if matches!(n.op_type, OperationType::Const) { + continue; + } for &(input_node, _) in &n.inputs { - if !nodes.contains_key(&input_node) - && !graph.inputs.iter().any(|x| x.node == input_node) - { + if !nodes.contains_key(&input_node) { missing_nodes.push(input_node); } } } NodeType::SubGraph { inputs, .. } => { for input in inputs { - if !nodes.contains_key(&input.node) - && !graph.inputs.iter().any(|x| x.node == input.node) - { + if !nodes.contains_key(&input.node) { missing_nodes.push(input.node); } } diff --git a/src/graph/tests/mod.rs b/src/graph/tests/mod.rs index 4c5d913..c71c1f2 100644 --- a/src/graph/tests/mod.rs +++ b/src/graph/tests/mod.rs @@ -1,6 +1,6 @@ +use super::*; + +mod model_advanced_tests; mod model_tests; mod scales_tests; mod utilities_tests; - -#[cfg(test)] -use super::*; diff --git a/src/graph/tests/model_advanced_tests.rs b/src/graph/tests/model_advanced_tests.rs new file mode 100644 index 0000000..6e26563 --- /dev/null +++ b/src/graph/tests/model_advanced_tests.rs @@ -0,0 +1,257 @@ +use crate::graph::{ + errors::GraphError, + model::{ + Model, NodeType, OperationType, ParsedNodes, SerializableNode, VarVisibility, Visibility, + }, +}; +use std::collections::BTreeMap; + +#[test] +fn test_matrix_dimension_mismatch() { + let mut nodes = BTreeMap::new(); + + // Input nodes (id: 0, 1) + let input_node1 = SerializableNode { + inputs: vec![], + out_dims: vec![2, 3], // 2x3 matrix + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node1)); + + let input_node2 = SerializableNode { + inputs: vec![], + out_dims: vec![4, 2], // 4x2 matrix (incompatible dimensions) + out_scale: 1, + id: 1, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(input_node2)); + + // MatMul node (id: 2) + let matmul_node = SerializableNode { + inputs: vec![(0, 0), (1, 0)], + out_dims: vec![2, 2], // Result should be 2x2 + out_scale: 1, + id: 2, + op_type: OperationType::MatMul, + weights: None, + bias: None, + }; + nodes.insert(2, NodeType::Node(matmul_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0, 1], + outputs: vec![(2, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }, + }; + + // Test with incompatible matrix dimensions + let input1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3 matrix + let input2 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // 4x2 matrix + let result = model.graph.execute(&[input1, input2]); + assert!(matches!(result, Err(GraphError::InvalidInputShape))); +} + +#[test] +fn test_relu_edge_cases() { + let mut nodes = BTreeMap::new(); + + // Input node (id: 0) + let input_node = SerializableNode { + inputs: vec![], + out_dims: vec![4], // 1D vector of size 4 + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node)); + + // ReLU node (id: 1) + let relu_node = SerializableNode { + inputs: vec![(0, 0)], + out_dims: vec![4], + out_scale: 1, + id: 1, + op_type: OperationType::Relu, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(relu_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0], + outputs: vec![(1, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }, + }; + + // Test with very large negative numbers + let input = vec![-1e10, -1e5, 1e5, 1e10]; + let result = model.graph.execute(&[input]).unwrap(); + assert_eq!(result[0], vec![0.0, 0.0, 1e5, 1e10]); + + // Test with zeros and small numbers + let input = vec![-1e-10, 0.0, 1e-10, 1.0]; + let result = model.graph.execute(&[input]).unwrap(); + assert_eq!(result[0], vec![0.0, 0.0, 1e-10, 1.0]); + + // Test with special values + let input = vec![f32::NEG_INFINITY, -0.0, 0.0, f32::INFINITY]; + let result = model.graph.execute(&[input]).unwrap(); + assert_eq!(result[0], vec![0.0, 0.0, 0.0, f32::INFINITY]); +} + +#[test] +fn test_sigmoid_edge_cases() { + let mut nodes = BTreeMap::new(); + + // Input node (id: 0) + let input_node = SerializableNode { + inputs: vec![], + out_dims: vec![4], + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node)); + + // Sigmoid node (id: 1) + let sigmoid_node = SerializableNode { + inputs: vec![(0, 0)], + out_dims: vec![4], + out_scale: 1, + id: 1, + op_type: OperationType::Sigmoid, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(sigmoid_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0], + outputs: vec![(1, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }, + }; + + // Test with extreme values + let input = vec![-1000.0, -20.0, 20.0, 1000.0]; + let result = model.graph.execute(&[input]).unwrap(); + assert!( + (result[0][0] - 0.0).abs() < 1e-6, + "Failed on extreme negative value" + ); + assert!( + (result[0][1] - 0.0).abs() < 1e-6, + "Failed on large negative value" + ); + assert!( + (result[0][2] - 1.0).abs() < 1e-6, + "Failed on large positive value" + ); + assert!( + (result[0][3] - 1.0).abs() < 1e-6, + "Failed on extreme positive value" + ); + + // Test with zeros and small numbers + let input = vec![-1e-10, 0.0, 1e-10, 1.0]; + let result = model.graph.execute(&[input]).unwrap(); + // For very small numbers (< 1e-7), we expect 0.0 due to numerical stability + assert_eq!(result[0][0], 0.5, "Failed on small negative number"); + assert_eq!(result[0][1], 0.5, "Failed on zero"); + assert_eq!(result[0][2], 0.5, "Failed on small positive number"); + assert!( + (result[0][3] - 0.7310586).abs() < 1e-6, + "Failed on regular number" + ); + + // Test with special values + let input = vec![f32::NEG_INFINITY, -0.0, 0.0, f32::INFINITY]; + let result = model.graph.execute(&[input]).unwrap(); + assert_eq!( + result[0], + vec![0.0, 0.5, 0.5, 1.0], + "Failed on special values" + ); +} + +#[test] +fn test_reshape_edge_cases() { + let mut nodes = BTreeMap::new(); + + // Input node (id: 0) + let input_node = SerializableNode { + inputs: vec![], + out_dims: vec![6], + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node)); + + // Reshape node (id: 1) - reshape to 2x3 + let reshape_node = SerializableNode { + inputs: vec![(0, 0)], + out_dims: vec![2, 3], + out_scale: 1, + id: 1, + op_type: OperationType::Reshape, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(reshape_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0], + outputs: vec![(1, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }, + }; + + // Test with exact size match + let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let result = model.graph.execute(&[input]).unwrap(); + assert_eq!(result[0], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); +} diff --git a/src/graph/tests/model_tests.rs b/src/graph/tests/model_tests.rs index 91db10e..861daab 100644 --- a/src/graph/tests/model_tests.rs +++ b/src/graph/tests/model_tests.rs @@ -1,213 +1,533 @@ -#[cfg(test)] -mod tests { - use super::super::errors::GraphError; - use super::super::model::Node; - use super::super::model::*; - use std::collections::{BTreeMap, HashMap}; - use std::env; - use std::path::Path; - use tract_data::internal::tract_smallvec::SmallVec; - use tract_onnx::tract_hir::ops::cnn::{PaddingSpec, PoolSpec}; - use tract_onnx::tract_hir::ops::nn::DataFormat; - use tract_onnx::{prelude::*, tract_core}; - - #[test] - fn test_model_load_invalid_path() { - let run_args = RunArgs { - variables: std::collections::HashMap::from([("batch_size".to_string(), 1)]), - }; - - let visibility = VarVisibility { +use crate::graph::{ + errors::GraphError, + model::{ + Model, NodeType, OperationType, ParsedNodes, SerializableNode, VarVisibility, Visibility, + }, +}; +use std::collections::BTreeMap; + +#[test] +fn test_matmul_operation() { + let mut nodes = BTreeMap::new(); + + // Input nodes (id: 0, 1) + let input_node1 = SerializableNode { + inputs: vec![], + out_dims: vec![2], // 2 elements vector + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node1)); + + let input_node2 = SerializableNode { + inputs: vec![], + out_dims: vec![2, 2], // 2x2 matrix + out_scale: 1, + id: 1, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(input_node2)); + + // MatMul node (id: 2) + let matmul_node = SerializableNode { + inputs: vec![(0, 0), (1, 0)], + out_dims: vec![2], // Result is 2 elements + out_scale: 1, + id: 2, + op_type: OperationType::MatMul, + weights: None, + bias: None, + }; + nodes.insert(2, NodeType::Node(matmul_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0, 1], + outputs: vec![(2, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { input: Visibility::Public, output: Visibility::Public, - }; - - let model_path = "nonexistent.onnx"; - let result = Model::new(model_path, &run_args, &visibility); - assert!(matches!(result, Err(GraphError::UnableToReadModel))); - } - - #[test] - fn test_model_load_success() -> Result<(), Box> { - let run_args = RunArgs { - variables: HashMap::from([ - ("N".to_string(), 1), - ("C".to_string(), 3), - ("H".to_string(), 224), - ("W".to_string(), 224), - ("batch_size".to_string(), 1), - ("sequence_length".to_string(), 128), - ]), - }; - - let visibility = VarVisibility { + }, + }; + + // Test execution with vector-matrix multiplication + // Vector: [1, 2] + let input1 = vec![1.0, 2.0]; + // Matrix: [[5, 6], [7, 8]] + let input2 = vec![5.0, 6.0, 7.0, 8.0]; + + let result = model + .graph + .execute(&[input1.clone(), input2.clone()]) + .unwrap(); + + // Expected result: [17, 23] + // First element: 1 * 5 + 2 * 6 = 17 (Transpose the vector for multiplication, same as pytorch) + // Second element: 1 * 7 + 2 * 8 = 23 + assert_eq!(result.len(), 1); + assert_eq!(result[0], vec![17.0, 23.0]); + + // Test invalid input dimensions + let invalid_input1 = vec![1.0, 2.0, 3.0]; // 3 elements instead of 2 + let result = model.graph.execute(&[invalid_input1, input2]); + assert!(matches!(result, Err(GraphError::InvalidInputShape))); +} + +#[test] +fn test_relu_operation() { + let mut nodes = BTreeMap::new(); + + // Input node (id: 0) + let input_node = SerializableNode { + inputs: vec![], + out_dims: vec![4], + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node)); + + // ReLU node (id: 1) + let relu_node = SerializableNode { + inputs: vec![(0, 0)], + out_dims: vec![4], + out_scale: 1, + id: 1, + op_type: OperationType::Relu, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(relu_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0], + outputs: vec![(1, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }, + }; + + // Test execution with various inputs + let input = vec![-1.0, 0.0, 1.0, 2.0]; + let result = model.graph.execute(&[input]).unwrap(); + + // Expected result: [0.0, 0.0, 1.0, 2.0] + assert_eq!(result.len(), 1); + assert_eq!(result[0], vec![0.0, 0.0, 1.0, 2.0]); +} + +#[test] +fn test_sigmoid_operation() { + let mut nodes = BTreeMap::new(); + + // Input node (id: 0) + let input_node = SerializableNode { + inputs: vec![], + out_dims: vec![3], + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node)); + + // Sigmoid node (id: 1) + let sigmoid_node = SerializableNode { + inputs: vec![(0, 0)], + out_dims: vec![3], + out_scale: 1, + id: 1, + op_type: OperationType::Sigmoid, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(sigmoid_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0], + outputs: vec![(1, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }, + }; + + // Test execution with various inputs + let input = vec![-2.0, 0.0, 2.0]; + let result = model.graph.execute(&[input]).unwrap(); + + // Expected result: sigmoid values for [-2.0, 0.0, 2.0] + // sigmoid(x) = 1 / (1 + e^(-x)) + assert_eq!(result.len(), 1); + assert!((result[0][0] - 0.119).abs() < 0.001); // sigmoid(-2) ≈ 0.119 + assert!((result[0][1] - 0.5).abs() < 0.001); // sigmoid(0) = 0.5 + assert!((result[0][2] - 0.881).abs() < 0.001); // sigmoid(2) ≈ 0.881 + + // Test with invalid input shape + let invalid_input = vec![-2.0, 0.0]; // 2 elements instead of 3 + let result = model.graph.execute(&[invalid_input]); + assert!(matches!(result, Err(GraphError::InvalidInputShape))); +} + +#[test] +fn test_add_operation() { + let mut nodes = BTreeMap::new(); + + // Input nodes (id: 0, 1) + let input_node1 = SerializableNode { + inputs: vec![], + out_dims: vec![3], + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node1)); + + let input_node2 = SerializableNode { + inputs: vec![], + out_dims: vec![3], + out_scale: 1, + id: 1, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(input_node2)); + + // Add node (id: 2) + let add_node = SerializableNode { + inputs: vec![(0, 0), (1, 0)], + out_dims: vec![3], + out_scale: 1, + id: 2, + op_type: OperationType::Add, + weights: None, + bias: None, + }; + nodes.insert(2, NodeType::Node(add_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0, 1], + outputs: vec![(2, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }, + }; + + // Test normal addition + let input1 = vec![1.0, 2.0, 3.0]; + let input2 = vec![4.0, 5.0, 6.0]; + let result = model.graph.execute(&[input1, input2]).unwrap(); + assert_eq!(result[0], vec![5.0, 7.0, 9.0]); + + // Test with mismatched dimensions + let input1 = vec![1.0, 2.0, 3.0]; + let input2 = vec![4.0, 5.0]; + let result = model.graph.execute(&[input1, input2]); + assert!(matches!(result, Err(GraphError::InvalidInputShape))); +} + +#[test] +fn test_einsum_operation() { + let mut nodes = BTreeMap::new(); + + // Input nodes (id: 0, 1) + let input_node1 = SerializableNode { + inputs: vec![], + out_dims: vec![2], // 2 elements vector + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node1)); + + let input_node2 = SerializableNode { + inputs: vec![], + out_dims: vec![2, 2], // 2x2 matrix + out_scale: 1, + id: 1, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(input_node2)); + + // EinSum node (id: 2) + let einsum_node = SerializableNode { + inputs: vec![(0, 0), (1, 0)], + out_dims: vec![2], // Result is 2 elements + out_scale: 1, + id: 2, + op_type: OperationType::EinSum, + weights: None, + bias: None, + }; + nodes.insert(2, NodeType::Node(einsum_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0, 1], + outputs: vec![(2, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }, + }; + + // Test EinSum operation (vector-matrix multiplication in this case) + let input1 = vec![1.0, 2.0]; // 2-element vector + let input2 = vec![5.0, 6.0, 7.0, 8.0]; // 2x2 matrix + let result = model.graph.execute(&[input1, input2]).unwrap(); + + // Expected result: [17.0, 23.0] + assert_eq!(result[0], vec![17.0, 23.0]); +} + +#[test] +fn test_reshape_operation() { + let mut nodes = BTreeMap::new(); + + // Input node (id: 0) + let input_node = SerializableNode { + inputs: vec![], + out_dims: vec![6], + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node)); + + // Reshape node (id: 1) + let reshape_node = SerializableNode { + inputs: vec![(0, 0)], + out_dims: vec![2, 3], // Reshape 6 elements to 2x3 matrix + out_scale: 1, + id: 1, + op_type: OperationType::Reshape, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(reshape_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0], + outputs: vec![(1, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }, + }; + + // Test reshaping + let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let result = model.graph.execute(&[input]).unwrap(); + assert_eq!(result[0], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); +} + +#[test] +fn test_const_operation() { + let mut nodes = BTreeMap::new(); + + // Const node (id: 0) + let const_node = SerializableNode { + inputs: vec![], + out_dims: vec![3], + out_scale: 1, + id: 0, + op_type: OperationType::Const, + weights: Some(vec![1.0, 2.0, 3.0]), + bias: None, + }; + nodes.insert(0, NodeType::Node(const_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![], + outputs: vec![(0, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { input: Visibility::Public, output: Visibility::Public, - }; - - // Get current directory path - let current_dir = env::current_dir()?; - println!("current directory: {:?}", current_dir.display()); - - let model_path = current_dir.join("models/resnet101-v1-7.onnx"); - println!("Model path: {:?}", model_path); - - // Verify model file exists - if !model_path.exists() { - println!("Model file not found at expected path"); - return Ok(()); - } - - let model_path_str = model_path.to_str().ok_or("Invalid model path")?; - - let result = Model::new(model_path_str, &run_args, &visibility); - assert!(result.is_ok()); - println!("result: {:?}", result); - - let model = result.unwrap(); - assert!(model.graph.num_inputs() > 0); - Ok(()) - } - - #[test] - fn test_parsed_nodes_output_scales() { - let mut nodes: BTreeMap = BTreeMap::new(); - let inputs: Vec<(usize, usize)> = vec![]; - let pool_spec: PoolSpec = PoolSpec::new( - DataFormat::NHWC, - SmallVec::from_buf([2, 2, 2, 2]), - PaddingSpec::Valid, - None, - None, - 1, - 2, - ); - - // Create a Node first - let node = Node { - op: Box::new(tract_core::ops::cnn::MaxPool::new(pool_spec, None)), - inputs: inputs.clone(), - out_dims: vec![], - out_scale: 1, - id: 0, - }; - - // Convert Node to SerializableNode - let serializable_node = SerializableNode { - inputs: node.inputs, - out_dims: node.out_dims, - out_scale: node.out_scale, - id: node.id, - }; - - nodes.insert(0, NodeType::Node(serializable_node)); - - let parsed_nodes = ParsedNodes { - nodes, - inputs: vec![], - outputs: vec![(0, 0)], - }; - - let scales = parsed_nodes.get_output_scales().unwrap(); - assert_eq!(scales, vec![1]); - } - - #[test] - fn test_model_serialization() -> Result<(), Box> { - let run_args = RunArgs { - variables: HashMap::from([ - ("N".to_string(), 1), - ("C".to_string(), 3), - ("H".to_string(), 224), - ("W".to_string(), 224), - ("batch_size".to_string(), 1), - ("sequence_length".to_string(), 128), - ]), - }; - - let visibility = VarVisibility { + }, + }; + + // Test execution - Const nodes should output their weights + let result = model.graph.execute(&[]).unwrap(); + assert_eq!(result[0], vec![1.0, 2.0, 3.0]); +} + +#[test] +fn test_cyclic_dependency() { + let mut nodes = BTreeMap::new(); + + // Create a cycle: node 0 -> node 1 -> node 2 -> node 0 + let node0 = SerializableNode { + inputs: vec![(2, 0)], // Creates cycle + out_dims: vec![1], + out_scale: 1, + id: 0, + op_type: OperationType::Add, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(node0)); + + let node1 = SerializableNode { + inputs: vec![(0, 0)], + out_dims: vec![1], + out_scale: 1, + id: 1, + op_type: OperationType::Add, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(node1)); + + let node2 = SerializableNode { + inputs: vec![(1, 0)], + out_dims: vec![1], + out_scale: 1, + id: 2, + op_type: OperationType::Add, + weights: None, + bias: None, + }; + nodes.insert(2, NodeType::Node(node2)); + + let graph = ParsedNodes { + nodes, + inputs: vec![], + outputs: vec![(2, 0)], + }; + + // Execute should fail due to cyclic dependency + let result = graph.execute(&[]); + assert!(matches!(result, Err(GraphError::CyclicDependency))); +} + +#[test] +fn test_invalid_output_slot() { + let mut nodes = BTreeMap::new(); + + // Input node (id: 0) + let input_node = SerializableNode { + inputs: vec![], + out_dims: vec![1], + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0], + outputs: vec![(0, 1)], // Invalid output slot (node 0 only has slot 0) + }; + + let model = Model { + graph, + visibility: VarVisibility { input: Visibility::Public, output: Visibility::Public, - }; - - let current_dir = env::current_dir()?; - let model_path = current_dir.join("models/resnet101-v1-7.onnx"); - - // Skip test if model file doesn't exist - if !model_path.exists() { - println!("Model file not found, skipping test"); - return Ok(()); - } - - let model_path_str = model_path.to_str().ok_or("Invalid model path")?; - let model = Model::new(model_path_str, &run_args, &visibility)?; - - // Test saving - let save_path = "test_model.bin"; - assert!(model.save(save_path).is_ok()); - - // Test loading - let loaded_model = Model::load(save_path); - assert!(loaded_model.is_ok()); - - // Clean up - std::fs::remove_file(save_path)?; - - // Compare original and loaded models - let loaded_model = loaded_model.unwrap(); - assert_eq!(model.visibility, loaded_model.visibility); - assert_eq!(model.graph.inputs.len(), loaded_model.graph.inputs.len()); - assert_eq!(model.graph.outputs.len(), loaded_model.graph.outputs.len()); - Ok(()) - } - - #[test] - fn test_model_save_error() -> Result<(), Box> { - let run_args = RunArgs { - variables: HashMap::from([ - ("N".to_string(), 1), - ("C".to_string(), 3), - ("H".to_string(), 224), - ("W".to_string(), 224), - ("batch_size".to_string(), 1), - ("sequence_length".to_string(), 128), - ]), - }; - - let visibility = VarVisibility { + }, + }; + + let result = model.graph.execute(&[vec![1.0]]); + assert!(matches!(result, Err(GraphError::InvalidOutputSlot(1)))); +} + +#[test] +fn test_missing_node() { + let mut nodes = BTreeMap::new(); + + // Input node (id: 0) + let input_node = SerializableNode { + inputs: vec![], + out_dims: vec![1], + out_scale: 1, + id: 0, + op_type: OperationType::Input, + weights: None, + bias: None, + }; + nodes.insert(0, NodeType::Node(input_node)); + + // Add node referencing non-existent node (id: 1) + let add_node = SerializableNode { + inputs: vec![(0, 0), (2, 0)], // Node 2 doesn't exist + out_dims: vec![1], + out_scale: 1, + id: 1, + op_type: OperationType::Add, + weights: None, + bias: None, + }; + nodes.insert(1, NodeType::Node(add_node)); + + let graph = ParsedNodes { + nodes, + inputs: vec![0], + outputs: vec![(1, 0)], + }; + + let model = Model { + graph, + visibility: VarVisibility { input: Visibility::Public, output: Visibility::Public, - }; - - let current_dir = env::current_dir()?; - let model_path = current_dir.join("models/resnet101-v1-7.onnx"); - - // Skip test if model file doesn't exist - if !model_path.exists() { - println!("Model file not found, skipping test"); - return Ok(()); - } - - let model_path_str = model_path.to_str().ok_or("Invalid model path")?; - let model = Model::new(model_path_str, &run_args, &visibility)?; - - // Test saving to an invalid path - let result = model.save("/invalid/path/model.bin"); - assert!(matches!(result, Err(GraphError::UnableToSaveModel))); - Ok(()) - } - - #[test] - fn test_model_load_binary_error() { - // Test loading from a non-existent file - let result = Model::load("non_existent_model.bin"); - assert!(matches!(result, Err(GraphError::UnableToReadModel))); - - // Test loading from an invalid binary file - let invalid_bin_path = "invalid_model.bin"; - std::fs::write(invalid_bin_path, "invalid binary content").unwrap(); - let result = Model::load(invalid_bin_path); - assert!(matches!(result, Err(GraphError::UnableToReadModel))); - std::fs::remove_file(invalid_bin_path).unwrap(); - } + }, + }; + + let result = model.graph.execute(&[vec![1.0]]); + assert!(matches!(result, Err(GraphError::MissingNode(2)))); } diff --git a/src/lib.rs b/src/lib.rs index 5ae6242..6f62865 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod graph; +pub mod zk; #[cfg(test)] mod tests {} diff --git a/src/zk/mod.rs b/src/zk/mod.rs new file mode 100644 index 0000000..fc277a7 --- /dev/null +++ b/src/zk/mod.rs @@ -0,0 +1,10 @@ +pub mod operations; +pub mod proof; +pub mod wiring; + +use mina_curves::pasta::Vesta; +use poly_commitment::ipa::OpeningProof; + +pub type ZkOpeningProof = OpeningProof; + +pub use wiring::ModelCircuitBuilder; diff --git a/src/zk/operations.rs b/src/zk/operations.rs new file mode 100644 index 0000000..a419a9b --- /dev/null +++ b/src/zk/operations.rs @@ -0,0 +1,316 @@ +use kimchi::circuits::{ + gate::{CircuitGate, GateType}, + wires::Wire, +}; +use mina_curves::pasta::Fp; +use tract_onnx::prelude::*; + +use crate::graph::model::{OperationType, SerializableNode}; +use anyhow::Result; + +/// Maps ONNX operations to Kimchi circuit gates +#[derive(Debug)] +pub enum OnnxOperation { + /// Matrix multiplication (Gemm/MatMul) + MatMul { + m: usize, // Number of rows in first matrix + n: usize, // Number of columns in second matrix + k: usize, // Number of columns in first matrix/rows in second matrix + }, + /// ReLU activation + Relu, + /// Sigmoid activation + Sigmoid, + /// Addition operation + Add, + /// EinSum operation (used for matrix operations) + EinSum, + /// Max operation (used in ReLU) + Max, + /// Constant value + Const, + /// Remove axis operation (used in flattening) + RmAxis, + /// Reshape operation + Reshape, +} + +impl OnnxOperation { + /// Convert ONNX operation to Kimchi circuit gates + pub fn to_circuit_gates(&self, start_row: usize) -> Result>> { + match self { + OnnxOperation::MatMul { m, n, k } => { + let mut gates = Vec::new(); + let mut current_row = start_row; + + // For each output element (m x n matrix) + for _i in 0..*m { + for _j in 0..*n { + // For each element in the dot product (k elements) + for _l in 0..*k { + // Multiplication gate + let mul_gate = CircuitGate::new( + GateType::ForeignFieldMul, + [Wire::new(current_row, 0); 7], + vec![], + ); + gates.push(mul_gate); + current_row += 1; + + // Addition gate (except for the first element) + if _l > 0 { + let add_gate = CircuitGate::new( + GateType::ForeignFieldAdd, + [Wire::new(current_row, 0); 7], + vec![], + ); + gates.push(add_gate); + current_row += 1; + } + } + } + } + Ok(gates) + } + + OnnxOperation::Relu | OnnxOperation::Max => { + // ReLU implemented using range check and generic gates + let mut gates = Vec::new(); + + // Range check for input + let range_check = + CircuitGate::new(GateType::RangeCheck0, [Wire::new(start_row, 0); 7], vec![]); + gates.push(range_check); + + // Generic gate for max(0, x) logic + let generic = + CircuitGate::new(GateType::Generic, [Wire::new(start_row + 1, 0); 7], vec![]); + gates.push(generic); + + Ok(gates) + } + + OnnxOperation::Sigmoid => { + // Sigmoid implemented using generic gates for the sigmoid function + let mut gates = Vec::new(); + + // Generic gate for sigmoid computation + let generic = + CircuitGate::new(GateType::Generic, [Wire::new(start_row, 0); 7], vec![]); + gates.push(generic); + + Ok(gates) + } + + OnnxOperation::Add | OnnxOperation::EinSum => { + // Addition operation + let mut gates = Vec::new(); + + // Generic gate for addition + let add_gate = CircuitGate::new( + GateType::ForeignFieldAdd, + [Wire::new(start_row, 0); 7], + vec![], + ); + gates.push(add_gate); + + Ok(gates) + } + + OnnxOperation::Const | OnnxOperation::RmAxis | OnnxOperation::Reshape => { + // These operations don't need any gates as they're just shape operations + Ok(vec![]) + } + } + } +} + +/// Attempts to identify the ONNX operation type from a serialized node +pub fn identify_operation(node: &SerializableNode) -> Option { + match node.op_type { + OperationType::Input => None, + OperationType::Const => Some(OnnxOperation::Const), + OperationType::MatMul => { + if node.inputs.len() == 2 { + let m = node.out_dims[0]; + let n = node.out_dims[1]; + let k = if node.inputs.len() == 2 { + // For MatMul, k is the inner dimension + node.out_dims[1] // This should be derived from input dimensions + } else { + 0 + }; + Some(OnnxOperation::MatMul { m, n, k }) + } else { + None + } + } + OperationType::Relu => Some(OnnxOperation::Relu), + OperationType::Sigmoid => Some(OnnxOperation::Sigmoid), + OperationType::Add => Some(OnnxOperation::Add), + OperationType::EinSum => Some(OnnxOperation::EinSum), + OperationType::Max => Some(OnnxOperation::Max), + OperationType::RmAxis => Some(OnnxOperation::RmAxis), + OperationType::Reshape => Some(OnnxOperation::Reshape), + } +} + +/// Identifies the operation type from a tract node +pub fn identify_tract_operation(node: &TypedNode) -> Option { + // Check operation type based on the node's operation name + let op_name = node.op.name(); + match op_name { + name if name == *"Const" => { + println!("Found Const operation"); + Some(OperationType::Const) + } + name if name == *"MatMul" || name == *"Gemm" => { + println!("Found matrix operation: {}", name); + Some(OperationType::MatMul) + } + name if name == *"EinSum" => { + println!("Found matrix operation: {}", name); + Some(OperationType::EinSum) + } + name if name == *"Relu" || name == *"Max" => { + println!("Found ReLU/Max operation: {}", name); + if name == *"Max" { + Some(OperationType::Max) + } else { + Some(OperationType::Relu) + } + } + name if name == *"Sigmoid" => { + println!("Found Sigmoid operation"); + Some(OperationType::Sigmoid) + } + name if name == *"Add" => { + println!("Found Add operation: {}", name); + Some(OperationType::Add) + } + name if name == *"Reshape" => { + println!("Found Reshape operation"); + Some(OperationType::Reshape) + } + name if name.starts_with("Rm(") => { + println!("Found RmAxis operation"); + Some(OperationType::RmAxis) + } + name if name == *"Source" => { + println!("Found Input operation"); + Some(OperationType::Input) + } + name => { + println!("Unknown operation: {}", name); + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_matmul_gate_generation() { + let op = OnnxOperation::MatMul { m: 2, n: 2, k: 2 }; + let gates = op.to_circuit_gates(0).unwrap(); + assert!(!gates.is_empty()); + } + + #[test] + fn test_relu_gate_generation() { + let op = OnnxOperation::Relu; + let gates = op.to_circuit_gates(0).unwrap(); + assert_eq!(gates.len(), 2); + assert_eq!(gates[0].typ, GateType::RangeCheck0); + assert_eq!(gates[1].typ, GateType::Generic); + } + + #[test] + fn test_sigmoid_gate_generation() { + let op = OnnxOperation::Sigmoid; + let gates = op.to_circuit_gates(0).unwrap(); + assert_eq!(gates.len(), 1); + assert_eq!(gates[0].typ, GateType::Generic); + } + + #[test] + fn test_operation_identification() { + // Test MatMul identification + let matmul_node = SerializableNode { + op_type: OperationType::MatMul, + inputs: vec![(0, 0), (1, 0)], + out_dims: vec![2, 2], + out_scale: 1, + id: 0, + weights: None, + bias: None, + }; + match identify_operation(&matmul_node) { + Some(OnnxOperation::MatMul { m, n, k }) => { + assert_eq!(m, 2); + assert_eq!(n, 2); + assert_eq!(k, 2); + } + _ => panic!("Expected MatMul operation"), + } + + // Test ReLU identification + let relu_node = SerializableNode { + op_type: OperationType::Relu, + inputs: vec![(0, 0)], + out_dims: vec![4], + out_scale: 1, + id: 0, + weights: None, + bias: None, + }; + match identify_operation(&relu_node) { + Some(OnnxOperation::Relu) => (), + _ => panic!("Expected ReLU operation"), + } + + // Test Sigmoid identification + let sigmoid_node = SerializableNode { + op_type: OperationType::Sigmoid, + inputs: vec![(0, 0)], + out_dims: vec![4], + out_scale: 1, + id: 0, + weights: None, + bias: None, + }; + match identify_operation(&sigmoid_node) { + Some(OnnxOperation::Sigmoid) => (), + _ => panic!("Expected Sigmoid operation"), + } + + // Test Input node (should return None) + let input_node = SerializableNode { + op_type: OperationType::Input, + inputs: vec![], + out_dims: vec![4], + out_scale: 1, + id: 0, + weights: None, + bias: None, + }; + assert!(identify_operation(&input_node).is_none()); + + // Test Const node + let const_node = SerializableNode { + op_type: OperationType::Const, + inputs: vec![], + out_dims: vec![4], + out_scale: 1, + id: 0, + weights: Some(vec![1.0, 2.0, 3.0, 4.0]), + bias: None, + }; + match identify_operation(&const_node) { + Some(OnnxOperation::Const) => (), + _ => panic!("Expected Const operation"), + } + } +} diff --git a/src/zk/proof.rs b/src/zk/proof.rs new file mode 100644 index 0000000..d8e4fac --- /dev/null +++ b/src/zk/proof.rs @@ -0,0 +1,383 @@ +use ark_ff::{UniformRand, Zero}; +use ark_poly::EvaluationDomain; +use groupmap::GroupMap; +use kimchi::{ + circuits::{constraints::ConstraintSystem, wires::COLUMNS}, + proof::ProverProof, + prover_index::ProverIndex, + verifier_index::VerifierIndex, +}; +use mina_curves::pasta::{Fp, Vesta, VestaParameters}; +use mina_poseidon::{ + constants::PlonkSpongeConstantsKimchi, + sponge::{DefaultFqSponge, DefaultFrSponge}, +}; +use poly_commitment::{commitment::CommitmentCurve, ipa::SRS, SRS as _}; +use rand::{rngs::ThreadRng, thread_rng}; +use std::{array, sync::Arc}; + +use super::wiring::ModelCircuitBuilder; +use super::ZkOpeningProof; +use crate::graph::model::Model; + +type SpongeParams = PlonkSpongeConstantsKimchi; +type BaseSponge = DefaultFqSponge; +type ScalarSponge = DefaultFrSponge; + +/// Result type containing model output and its proof +#[derive(Clone)] +pub struct ProverOutput { + pub output: Vec>, + pub proof: ProverProof, +} + +/// Creates prover and verifier indices for a model +pub struct ProofSystem { + pub prover_index: ProverIndex, + pub verifier_index: VerifierIndex, + model: Model, + domain_size: usize, + zk_rows: usize, +} + +type WitnessOutput = ([Vec; COLUMNS], Vec>); + +impl ProofSystem { + /// Create a new proof system from a model + pub fn new(model: &Model) -> Self { + // Convert model to circuit gates + let mut builder = ModelCircuitBuilder::new(); + let (gates, domain_size, zk_rows) = builder.build_circuit(model); + + // Calculate total number of public inputs and outputs + let num_public_inputs = model + .graph + .inputs + .iter() + .map(|&idx| { + if let crate::graph::model::NodeType::Node(node) = &model.graph.nodes[&idx] { + node.out_dims.iter().product::() + } else { + 0usize + } + }) + .sum::(); + + let num_public_outputs = model + .graph + .outputs + .iter() + .map(|&(node, _)| { + if let crate::graph::model::NodeType::Node(node) = &model.graph.nodes[&node] { + node.out_dims.iter().product::() + } else { + 0usize + } + }) + .sum::(); + + let total_public = num_public_outputs; // Only outputs are public + + println!("Number of public inputs: {}", num_public_inputs); + println!("Number of public outputs: {}", num_public_outputs); + println!("Required domain size: {}", domain_size); + + // Create constraint system with our domain size + let cs = ConstraintSystem::create(gates) + .public(total_public) // Only outputs are public + .max_poly_size(Some(domain_size)) + .build() + .expect("Failed to create constraint system"); + + println!("Constraint system domain size: {}", cs.domain.d1.size()); + + // Create SRS with our domain size + println!("Using SRS size: {}", domain_size); + let srs = SRS::::create(domain_size); + let srs = Arc::new(srs); + + // Create prover index + let prover_index = ProverIndex::create(cs.clone(), Fp::zero(), srs); + + // Create verifier index + let verifier_index = prover_index.verifier_index(); + + Self { + prover_index, + verifier_index, + model: model.clone(), + domain_size, + zk_rows, + } + } + + /// Convert f32 to field element + fn f32_to_field(value: f32) -> Fp { + if value < 0.0 { + -Fp::from((-value * 1000.0) as u64) + } else { + Fp::from((value * 1000.0) as u64) + } + } + + /// Create witness for the circuit + fn create_witness(&self, inputs: &[Vec]) -> Result { + // First execute the model to get outputs + let outputs = self + .model + .graph + .execute(inputs) + .map_err(|e| format!("Failed to execute model: {:?}", e))?; + + // Calculate initial witness size (without padding) + let mut witness_size = 0; + + let public_outputs: Vec = outputs + .iter() + .flat_map(|output| output.iter().map(|&x| Self::f32_to_field(x))) + .collect(); + + // Total public values is outputs only + witness_size += public_outputs.len(); + + // Add space for intermediate computations + for node in self.model.graph.nodes.values() { + if let crate::graph::model::NodeType::Node(node) = node { + match node.op_type { + crate::graph::model::OperationType::MatMul => { + witness_size += node.out_dims.iter().product::(); + } + crate::graph::model::OperationType::Relu => { + witness_size += node.out_dims.iter().product::(); + } + crate::graph::model::OperationType::Add => { + witness_size += node.out_dims.iter().product::(); + } + _ => {} + } + } + } + + // Ensure witness size is strictly less than domain_size - zk_rows + assert!( + witness_size < self.domain_size - self.zk_rows, + "Witness size {} must be strictly less than domain size {} minus zk_rows {}", + witness_size, + self.domain_size, + self.zk_rows + ); + + // Create witness arrays + let mut witness = array::from_fn(|_| vec![Fp::zero(); self.domain_size]); + + // Place public outputs at the start + for (i, &value) in public_outputs.iter().enumerate() { + for item in witness.iter_mut().take(COLUMNS) { + item[i] = value; + } + } + + // Process each node in topological order starting after public values + let mut current_row = public_outputs.len(); + let mut intermediate_values = std::collections::HashMap::new(); + + for (idx, node) in &self.model.graph.nodes { + if let crate::graph::model::NodeType::Node(node) = node { + match node.op_type { + crate::graph::model::OperationType::MatMul => { + let input_size = node.inputs[0].1; + let output_size = node.out_dims.iter().product(); + + // Get input values + let input_values = if let Some((input_idx, _)) = node.inputs.first() { + intermediate_values + .get(input_idx) + .map(|&row| (0..input_size).map(|i| witness[0][row + i]).collect()) + .unwrap_or_else(|| (0..input_size).map(|i| witness[0][i]).collect()) + } else { + vec![Fp::zero(); input_size] + }; + + // Compute matrix multiplication + if let Some(weights) = &node.weights { + for i in 0..output_size { + let mut sum = Fp::zero(); + for j in 0..input_size { + let weight = Self::f32_to_field(weights[i * input_size + j]); + sum += weight * input_values[j]; + } + // Set the result in all columns + for item in witness.iter_mut().take(COLUMNS) { + item[current_row + i] = sum; + } + } + intermediate_values.insert(*idx, current_row); + current_row += output_size; + } + } + crate::graph::model::OperationType::Add => { + if let (Some((left_idx, _)), Some((right_idx, _))) = + (node.inputs.first(), node.inputs.get(1)) + { + if let (Some(&left_row), Some(&right_row)) = ( + intermediate_values.get(left_idx), + intermediate_values.get(right_idx), + ) { + let size = node.out_dims.iter().product(); + for i in 0..size { + let result = + witness[0][left_row + i] + witness[0][right_row + i]; + // Set the result in all columns + for item in witness.iter_mut().take(COLUMNS) { + item[current_row + i] = result; + } + } + intermediate_values.insert(*idx, current_row); + current_row += size; + } + } + } + crate::graph::model::OperationType::Relu + | crate::graph::model::OperationType::Max => { + if let Some((input_idx, _)) = node.inputs.first() { + if let Some(&input_row) = intermediate_values.get(input_idx) { + let size = node.out_dims.iter().product(); + for i in 0..size { + let x = witness[0][input_row + i]; + let result = if x == Fp::zero() { Fp::zero() } else { x }; + // Set the result in all columns + for item in witness.iter_mut().take(COLUMNS) { + item[current_row + i] = result; + } + } + intermediate_values.insert(*idx, current_row); + current_row += size; + } + } + } + _ => {} + } + } + } + + // Add random values for zero-knowledge rows at the end + let mut rng = thread_rng(); + for item in witness.iter_mut().take(COLUMNS) { + for i in item + .iter_mut() + .take(self.domain_size) + .skip(self.domain_size - self.zk_rows) + { + *i = ::rand(&mut rng); + } + } + + // Pad remaining rows with zeros + for item in witness.iter_mut().take(COLUMNS) { + for i in item + .iter_mut() + .take(self.domain_size - self.zk_rows) + .skip(current_row) + { + *i = Fp::zero(); + } + } + + Ok((witness, outputs)) + } + + /// Generate model output and create a proof + pub fn prove(&self, inputs: &[Vec]) -> Result { + // Create witness and get outputs + let (witness, outputs) = self.create_witness(inputs)?; + + // Setup group map + let group_map = ::Map::setup(); + + // Create proof + let mut rng = thread_rng(); + let proof = ProverProof::create::( + &group_map, + witness, + &[], + &self.prover_index, + &mut rng, + ) + .map_err(|e| format!("Failed to create proof: {:?}", e))?; + + Ok(ProverOutput { + output: outputs, + proof, + }) + } + + /// Verify a proof given output and proof + pub fn verify( + &self, + output: &[Vec], + proof: &ProverProof, + ) -> Result { + // Convert output to field elements + let public_values: Vec = output + .iter() + .flat_map(|output| output.iter().map(|&x| Self::f32_to_field(x))) + .collect(); + + println!("Verifying proof with {} output values", public_values.len()); + + // Setup group map + let group_map = ::Map::setup(); + + // Verify proof with outputs only + kimchi::verifier::verify::( + &group_map, + &self.verifier_index, + proof, + &public_values, + ) + .map(|_| true) + .map_err(|e| format!("Failed to verify proof: {:?}", e)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::graph::model::{RunArgs, VarVisibility, Visibility}; + use std::collections::HashMap; + + #[test] + fn test_proof_system() { + // Create a simple model (perceptron) + let mut variables = HashMap::new(); + variables.insert("batch_size".to_string(), 1); + let run_args = RunArgs { variables }; + + let visibility = VarVisibility { + input: Visibility::Public, + output: Visibility::Public, + }; + + let model = Model::new("models/simple_perceptron.onnx", &run_args, &visibility) + .expect("Failed to load model"); + + // Create proof system + let proof_system = ProofSystem::new(&model); + + // Create sample input - pad to match expected size [1, 10] + let input = vec![vec![ + 1.0, 0.5, -0.3, 0.8, -0.2, // Original values + 0.0, 0.0, 0.0, 0.0, 0.0, // Padding to reach size 10 + ]]; + + // Generate output and proof + let prover_output = proof_system.prove(&input).expect("Failed to create proof"); + + // Verify the proof with just output and proof + let result = proof_system + .verify(&prover_output.output, &prover_output.proof) + .expect("Failed to verify proof"); + + assert!(result); + } +} diff --git a/src/zk/wiring.rs b/src/zk/wiring.rs new file mode 100644 index 0000000..33aca25 --- /dev/null +++ b/src/zk/wiring.rs @@ -0,0 +1,247 @@ +use ark_ff::Zero; +use ark_poly::{EvaluationDomain, Radix2EvaluationDomain}; +use kimchi::circuits::{ + gate::{CircuitGate, GateType}, + wires::Wire, +}; +use mina_curves::pasta::Fp; + +use crate::graph::model::{Model, NodeType, OperationType}; + +// Constants from o1js/proof-systems +pub const COLUMNS: usize = 15; // Total number of columns +pub const PERMUTS: usize = 7; // Number of permutable columns +pub const WIRES: [usize; COLUMNS] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; + +// Minimum domain size from o1js/proof-systems +pub const MIN_DOMAIN_SIZE: usize = 4096; + +pub struct ModelCircuitBuilder { + current_row: usize, +} + +impl Default for ModelCircuitBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ModelCircuitBuilder { + pub fn new() -> Self { + Self { current_row: 0 } + } + + /// Calculate the strict lower bound for zk_rows + fn zk_rows_strict_lower_bound(num_chunks: usize) -> usize { + (2 * (PERMUTS + 1) * num_chunks - 2) / PERMUTS + } + + /// Calculate next power of 2 + fn next_power_of_two(n: usize) -> usize { + let mut v = n; + v -= 1; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v += 1; + v + } + + /// Create wires for a row + fn create_wires(row: usize) -> [Wire; PERMUTS] { + // Create wires for the first PERMUTS columns (which can be wired) + // Each wire references itself by default + [ + Wire { row, col: 0 }, // Current row, main wire + Wire { row, col: 1 }, // Current row, auxiliary wire 1 + Wire { row, col: 2 }, // Current row, auxiliary wire 2 + Wire { row, col: 3 }, // Current row, auxiliary wire 3 + Wire { row, col: 4 }, // Current row, auxiliary wire 4 + Wire { row, col: 5 }, // Current row, auxiliary wire 5 + Wire { row, col: 6 }, // Current row, auxiliary wire 6 + ] + } + + /// Calculate domain size and zk_rows for a given circuit size + fn calculate_domain_params(circuit_size: usize) -> (usize, usize) { + let lookup_domain_size = 0; // We don't use lookup tables yet + let circuit_lower_bound = std::cmp::max(circuit_size, lookup_domain_size + 1); + let get_domain_size_lower_bound = |zk_rows: usize| circuit_lower_bound + zk_rows; + + // Start with minimum values + let zk_rows = 3; + let domain_size_lower_bound = get_domain_size_lower_bound(zk_rows); + + // Calculate initial domain size + let mut domain_size = match Radix2EvaluationDomain::::new(domain_size_lower_bound) { + Some(domain) => std::cmp::max(MIN_DOMAIN_SIZE, domain.size()), + None => std::cmp::max( + MIN_DOMAIN_SIZE, + Self::next_power_of_two(domain_size_lower_bound), + ), + }; + + // Calculate number of chunks and required zk_rows + let num_chunks = domain_size.div_ceil(MIN_DOMAIN_SIZE); + let min_zk_rows = Self::zk_rows_strict_lower_bound(num_chunks) + 1; + let zk_rows = std::cmp::max(min_zk_rows, (16 * num_chunks + 5) / 7); + + // Ensure domain size is large enough + let domain_size_lower_bound = get_domain_size_lower_bound(zk_rows); + if domain_size < domain_size_lower_bound { + domain_size = match Radix2EvaluationDomain::::new(domain_size_lower_bound) { + Some(domain) => std::cmp::max(MIN_DOMAIN_SIZE, domain.size()), + None => std::cmp::max( + MIN_DOMAIN_SIZE, + Self::next_power_of_two(domain_size_lower_bound), + ), + }; + } + + (domain_size, zk_rows) + } + + pub fn build_circuit(&mut self, model: &Model) -> (Vec>, usize, usize) { + let mut gates = Vec::new(); + + // Calculate total number of public inputs + let num_public: usize = model + .graph + .inputs + .iter() + .map(|&idx| { + if let NodeType::Node(node) = &model.graph.nodes[&idx] { + node.out_dims.iter().product::() + } else { + 0 + } + }) + .sum::(); + + // Calculate initial circuit size (without padding) + let mut circuit_size = num_public; + + // Calculate space needed for operations + for node in model.graph.nodes.values() { + if let NodeType::Node(node) = node { + match node.op_type { + OperationType::MatMul => { + let output_size = node.out_dims.iter().product::(); + circuit_size += output_size; + } + OperationType::Relu => { + let output_size = node.out_dims.iter().product::(); + circuit_size += output_size; + } + OperationType::Add => { + let output_size = node.out_dims.iter().product::(); + circuit_size += output_size; + } + _ => {} + } + } + } + + // Calculate domain size and zk_rows + let (domain_size, zk_rows) = Self::calculate_domain_params(circuit_size); + + // Add gates for public inputs + for i in 0..num_public { + gates.push(CircuitGate { + typ: GateType::Generic, + wires: Self::create_wires(i), + coeffs: vec![Fp::from(1u64)], + }); + self.current_row += 1; + } + + // Process each node in topological order + let mut intermediate_rows = std::collections::HashMap::new(); + for (idx, node) in &model.graph.nodes { + if let NodeType::Node(node) = node { + match node.op_type { + OperationType::MatMul => { + let output_size: usize = node.out_dims.iter().product(); + + // Add computation gates + for i in 0..output_size { + gates.push(CircuitGate { + typ: GateType::Generic, + wires: Self::create_wires(self.current_row + i), + coeffs: vec![Fp::from(1u64)], + }); + } + + intermediate_rows.insert(*idx, self.current_row); + self.current_row += output_size; + } + OperationType::Relu => { + let output_size: usize = node.out_dims.iter().product(); + + // Add computation gates + for i in 0..output_size { + gates.push(CircuitGate { + typ: GateType::Generic, + wires: Self::create_wires(self.current_row + i), + coeffs: vec![Fp::from(1u64)], + }); + } + + intermediate_rows.insert(*idx, self.current_row); + self.current_row += output_size; + } + OperationType::Add => { + let output_size: usize = node.out_dims.iter().product(); + + // Add computation gates + for i in 0..output_size { + gates.push(CircuitGate { + typ: GateType::Generic, + wires: Self::create_wires(self.current_row + i), + coeffs: vec![Fp::from(1u64)], + }); + } + + intermediate_rows.insert(*idx, self.current_row); + self.current_row += output_size; + } + _ => {} + } + } + } + + // Add padding gates until we reach domain_size - zk_rows + while gates.len() < domain_size - zk_rows { + gates.push(CircuitGate { + typ: GateType::Zero, + wires: Self::create_wires(self.current_row), + coeffs: vec![Fp::zero()], + }); + self.current_row += 1; + } + + // Add zero-knowledge rows + for i in 0..zk_rows { + gates.push(CircuitGate { + typ: GateType::Zero, + wires: Self::create_wires(self.current_row + i), + coeffs: vec![Fp::zero()], + }); + } + + // Ensure we have at least 2 gates (required by o1js/proof-systems) + while gates.len() < 2 { + gates.push(CircuitGate { + typ: GateType::Zero, + wires: Self::create_wires(self.current_row), + coeffs: vec![Fp::zero()], + }); + self.current_row += 1; + } + + (gates, domain_size, zk_rows) + } +} diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 7d07d00..7f250eb 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -1,4 +1,4 @@ -use kimchi::graph::{errors::GraphError, model::*, scales::*}; +use mina_zkml::graph::{errors::GraphError, model::*, scales::*}; use std::collections::{BTreeMap, HashMap}; #[test] @@ -54,7 +54,7 @@ fn test_model_load_and_scale_integration() { let test_value = 10.0; let rebased = var_scales.rebase(test_value); let unrebased = var_scales.unrebase(rebased); - assert!((test_value - unrebased).abs() < std::f64::EPSILON); + assert!((test_value - unrebased).abs() < f64::EPSILON); } NodeType::SubGraph { out_scales, .. } => { // Test subgraph scales @@ -148,18 +148,11 @@ fn test_model_graph_traversal() { #[test] fn test_error_handling_integration() { // Test missing batch size - let run_args = RunArgs { - variables: HashMap::new(), - }; - let visibility = VarVisibility { input: Visibility::Public, output: Visibility::Public, }; - let result = Model::new("models/resnet101-v1-7.onnx", &run_args, &visibility); - assert!(matches!(result, Err(GraphError::InvalidInputShape))); - // Test invalid model path let run_args = RunArgs { variables: HashMap::from([("batch_size".to_string(), 1)]),