diff --git a/.gitignore b/.gitignore index a6d408d..8593ffc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,5 @@ .DS_Store *.tgz -/.cargo/ /packages/core/README.md -/packages/wasm/wbg/ -/target/ -bindings/ dist/ node_modules/ diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 5557bca..78dc8eb 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -1,8 +1 @@ -{ - "recommendations": [ - "esbenp.prettier-vscode", - "rust-lang.rust-analyzer", - "tamasfe.even-better-toml", - "timonwong.shellcheck" - ] -} +{ "recommendations": ["esbenp.prettier-vscode", "timonwong.shellcheck"] } diff --git a/.vscode/settings.json b/.vscode/settings.json index 205f93a..8553edb 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -5,15 +5,11 @@ // Prettier-supported languages here: https://prettier.io/docs/en/index.html "[css]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, "[html]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, - "[javascript]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, "[json]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, "[jsonc]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, "[markdown]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, "[typescript]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, - "[typescriptreact]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, "[yaml]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, - "editor.formatOnSave": true, - "evenBetterToml.formatter.alignComments": false, - "rust-analyzer.check.command": "clippy" + "editor.formatOnSave": true } diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b3cb5e7..a90a799 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,7 +7,6 @@ Make sure to have these tools installed: - [Git][] - [Make][] - [Node][] -- [Rust][] ## Setup @@ -61,9 +60,8 @@ Sometimes old build artifacts can hide errors. To clean your build: make clean ``` -This doesn't clean everything; it keeps around downloaded files and Rust's -`target` directory. You should be able to run `make all` right after it without -an Internet connection. +This doesn't clean everything; it keeps around downloaded files. You should be +able to run `make all` right after it without an Internet connection. ## Site @@ -82,4 +80,3 @@ make site-deps && npm run --workspace=@rose-lang/site dev -- --host [git]: https://git-scm.com/downloads [make]: https://en.wikipedia.org/wiki/Make_(software) [node]: https://nodejs.org/en/download -[rust]: https://www.rust-lang.org/tools/install diff --git a/Cargo.lock b/Cargo.lock deleted file mode 100644 index f9bc8c5..0000000 --- a/Cargo.lock +++ /dev/null @@ -1,464 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "Inflector" -version = "0.11.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" - -[[package]] -name = "bumpalo" -version = "3.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" - -[[package]] -name = "by_address" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf8dba2868114ed769a1f2590fc9ae5eb331175b44313b6c9b922f8f7ca813d0" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "console_error_panic_hook" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" -dependencies = [ - "cfg-if", - "wasm-bindgen", -] - -[[package]] -name = "console_log" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be8aed40e4edbf4d3b4431ab260b63fdc40f5780a4766824329ea0f1eefe3c0f" -dependencies = [ - "log", - "web-sys", -] - -[[package]] -name = "darling" -version = "0.14.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.14.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "darling_macro" -version = "0.14.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" -dependencies = [ - "darling_core", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "enumset" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19be8061a06ab6f3a6cf21106c873578bf01bd42ad15e0311a9c76161cb1c753" -dependencies = [ - "enumset_derive", -] - -[[package]] -name = "enumset_derive" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03e7b551eba279bf0fa88b83a46330168c1560a52a94f5126f892f0b364ab3e0" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "equivalent" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88bffebc5d80432c9b140ee17875ff173a8ab62faad5b257da912bd2f6c1c0a1" - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - -[[package]] -name = "hashbrown" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" - -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - -[[package]] -name = "indexmap" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" -dependencies = [ - "equivalent", - "hashbrown", -] - -[[package]] -name = "js-sys" -version = "0.3.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" -dependencies = [ - "wasm-bindgen", -] - -[[package]] -name = "leb128" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" - -[[package]] -name = "log" -version = "0.4.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "once_cell" -version = "1.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" - -[[package]] -name = "proc-macro2" -version = "1.0.56" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rose" -version = "0.4.5" -dependencies = [ - "enumset", -] - -[[package]] -name = "rose-autodiff" -version = "0.4.5" -dependencies = [ - "rose", -] - -[[package]] -name = "rose-interp" -version = "0.4.5" -dependencies = [ - "enumset", - "indexmap", - "rose", - "serde", - "thiserror", - "ts-rs", -] - -[[package]] -name = "rose-transpose" -version = "0.4.5" -dependencies = [ - "enumset", - "indexmap", - "rose", -] - -[[package]] -name = "rose-wasm" -version = "0.4.5" -dependencies = [ - "by_address", - "indexmap", - "rose", - "wasm-encoder", -] - -[[package]] -name = "rose-web" -version = "0.4.5" -dependencies = [ - "by_address", - "console_error_panic_hook", - "console_log", - "enumset", - "indexmap", - "js-sys", - "rose", - "rose-autodiff", - "rose-interp", - "rose-transpose", - "rose-wasm", - "serde", - "serde-wasm-bindgen", - "wasm-bindgen", -] - -[[package]] -name = "serde" -version = "1.0.163" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde-wasm-bindgen" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3b4c031cd0d9014307d82b8abf653c0290fbdaeb4c02d00c63cf52f728628bf" -dependencies = [ - "js-sys", - "serde", - "wasm-bindgen", -] - -[[package]] -name = "serde_derive" -version = "1.0.163" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.14", -] - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "syn" -version = "2.0.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcf316d5356ed6847742d036f8a39c3b8435cac10bd528a4bd461928a6ab34d5" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "termcolor" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "thiserror" -version = "1.0.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.14", -] - -[[package]] -name = "ts-rs" -version = "6.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4added4070a4fdf9df03457206cd2e4b12417c8560a2954d91ffcbe60177a56a" -dependencies = [ - "thiserror", - "ts-rs-macros", -] - -[[package]] -name = "ts-rs-macros" -version = "6.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f807fdb3151fee75df7485b901a89624358cd07a67a8fb1a5831bf5a07681ff" -dependencies = [ - "Inflector", - "proc-macro2", - "quote", - "syn 1.0.109", - "termcolor", -] - -[[package]] -name = "unicode-ident" -version = "1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" - -[[package]] -name = "wasm-bindgen" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" -dependencies = [ - "cfg-if", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn 2.0.14", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.14", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.87" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" - -[[package]] -name = "wasm-encoder" -version = "0.33.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b39de0723a53d3c8f54bed106cfbc0d06b3e4d945c5c5022115a61e3b29183ae" -dependencies = [ - "leb128", -] - -[[package]] -name = "web-sys" -version = "0.3.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-util" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" -dependencies = [ - "winapi", -] - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/Cargo.toml b/Cargo.toml deleted file mode 100644 index d428e36..0000000 --- a/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[workspace] -members = ["crates/*"] -resolver = "2" - -[profile.release] -codegen-units = 1 -lto = true - -# https://github.com/johnthagen/min-sized-rust -[profile.web] -inherits = "release" -opt-level = "z" -strip = true diff --git a/Makefile b/Makefile index 9abc29f..57d60e7 100644 --- a/Makefile +++ b/Makefile @@ -2,43 +2,18 @@ build: packages # run all tests -test: test-rust test-js +test: test-js # run other checks check: prettier - cargo fmt --check - cargo clippy # delete build artifacts, but not dependencies or downloaded files clean: - git clean -Xdf crates packages -e '!node_modules' + git clean -Xdf packages -e '!node_modules' # do everything all: build test check -### Rust - -# install additional Rust stuff that we need -rust: - cargo install --root=.cargo --version=0.2.87 wasm-bindgen-cli - -# export TypeScript bindings from Rust types -bindings: - cargo test export_bindings_ - -# compile Rust to WebAssembly -wbg: rust - cargo build --package=rose-web --target=wasm32-unknown-unknown --release - cargo build --package=rose-web --no-default-features -Z build-std=std,panic_abort -Z build-std-features=panic_immediate_abort --target wasm32-unknown-unknown --profile web - .cargo/bin/wasm-bindgen --target=web --out-dir=packages/wasm/wbg target/wasm32-unknown-unknown/release/rose_web.wasm - .cargo/bin/wasm-bindgen --target=web --out-dir=packages/wasm/dist/wbg target/wasm32-unknown-unknown/web/rose_web.wasm - -# run Rust tests -test-rust: - cargo test --quiet - -### JavaScript - # fetch JavaScript dependencies npm: npm i @@ -48,7 +23,7 @@ prettier: npm npx prettier --check . # build `packages/` -packages: core site wasm +packages: core site # run JavaScript tests test-js: test-core test-site @@ -56,12 +31,12 @@ test-js: test-core test-site ## `packages/core` # build -core: npm wasm +core: npm cp README.md packages/core npm run --workspace=rose build # test -test-core: npm wasm +test-core: npm npm run --workspace=rose test -- run --no-threads ## `packages/site` @@ -75,10 +50,3 @@ site: site-deps # test test-site: site-deps npm run --workspace=@rose-lang/site test -- run --no-threads - -## `packages/wasm` - -# build -wasm: npm bindings wbg - npm run --workspace=@rose-lang/wasm build - node bindings.js diff --git a/README.md b/README.md index 57aca98..eef9731 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@

icon by Aaron Weiss / CC BY 4.0

npm license Build

-Rose is an automatic differentiation engine for the web, inspired by [JAX][]. +Rose is a differentiable programming for the web, inspired by [Dex][]. ## Installation @@ -33,59 +33,7 @@ bun add rose ## Usage -This example defines custom derivatives for the builtin JavaScript logarithm and -power functions, then computes the output, gradient, and Hessian for the power -function applied with base 2 and exponent 3: - -```js -import { Dual, Real, Vec, add, compile, div, fn, mul, opaque, vjp } from "rose"; - -const log = opaque([Real], Real, Math.log); -log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: log(x), du: div(dx, x) }; -}); - -const pow = opaque([Real, Real], Real, Math.pow); -pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => { - const z = pow(x, y); - return { re: z, du: mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z) }; -}); - -const Vec2 = Vec(2, Real); -const Mat2 = Vec(2, Vec2); - -const f = fn([Vec2], Real, ([x, y]) => pow(x, y)); -const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1)); -const h = fn([Vec2], Mat2, (v) => { - const { grad } = vjp(g)(v); - return [grad([1, 0]), grad([0, 1])]; -}); - -const funcs = await Promise.all([compile(f), compile(g), compile(h)]); -console.log(funcs.map((func) => func([2, 3]))); -``` - -### With Vite - -If you are using [Vite][] then you will need to also install the -[vite-plugin-top-level-await][] package, because Rose internally uses [top-level -`await`][], which Vite does not directly support. You must also include the -following in your Vite config: - -```js -import { defineConfig } from "vite"; -import topLevelAwait from "vite-plugin-top-level-await"; - -export default defineConfig({ - // the plugin described above - plugins: [topLevelAwait()], - - // Vite bundles external dependencies by default in development mode, but that - // process does not include assets; this option disables that particular kind - // of bundling for Rose so that it can use its internal WebAssembly module - optimizeDeps: { exclude: ["rose"] }, -}); -``` +TODO ## Contributing @@ -95,13 +43,10 @@ See [`CONTRIBUTING.md`][]. Rose is licensed under the [MIT License][]. +[Dex]: https://github.com/google-research/dex-lang [`CONTRIBUTING.md`]: https://github.com/rose-lang/rose/blob/main/CONTRIBUTING.md [Bun]: https://bun.sh/ -[JAX]: http://jax.readthedocs.io/ [MIT License]: https://github.com/rose-lang/rose/blob/main/LICENSE [npm]: https://docs.npmjs.com/downloading-and-installing-node-js-and-npm [pnpm]: https://pnpm.io/installation -[top-level `await`]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/await#top_level_await -[vite-plugin-top-level-await]: https://www.npmjs.com/package/vite-plugin-top-level-await -[Vite]: https://vitejs.dev/ [Yarn]: https://classic.yarnpkg.com/lang/en/docs/install/ diff --git a/bindings.js b/bindings.js deleted file mode 100644 index 550cf35..0000000 --- a/bindings.js +++ /dev/null @@ -1,44 +0,0 @@ -import fs from "fs/promises"; -import path from "path"; -import prettier from "prettier"; - -const start_dir = "crates"; -const dest_dir = "packages/wasm/dist/bindings"; - -const crates = await fs.readdir(start_dir, { withFileTypes: true }); -for (const crate of crates) { - if (!crate.isDirectory()) continue; - const bindings_dir = path.join(start_dir, crate.name, "bindings"); - - const stat = await fs.stat(bindings_dir).catch(() => null); - if (!stat || !stat.isDirectory()) { - continue; // bindings directory does not exist for this crate, so skip it - } - - // Create the destination directory - const dest_folder = path.join(dest_dir, crate.name); - await fs.mkdir(dest_folder, { recursive: true }); - - // Read all .ts files in the bindings directory - const files = (await fs.readdir(bindings_dir)).filter((f) => - f.endsWith(".ts"), - ); - for (const file of files) { - const tsfile = path.join(bindings_dir, file); - - // Get the base filename, without the .ts extension - const base_name = path.basename(tsfile, ".ts"); - - // Read the content of the .ts file - const ts_content = await fs.readFile(tsfile, "utf-8"); - - // Add .js to import lines - const updated_content = ts_content.replace(/^(import .*)";$/gm, '$1.js";'); - - // Write the updated content to the new location with .d.ts extension - await fs.writeFile( - path.join(dest_folder, `${base_name}.d.ts`), - await prettier.format(updated_content, { parser: "typescript" }), - ); - } -} diff --git a/crates/autodiff/Cargo.toml b/crates/autodiff/Cargo.toml deleted file mode 100644 index 84478ea..0000000 --- a/crates/autodiff/Cargo.toml +++ /dev/null @@ -1,8 +0,0 @@ -[package] -name = "rose-autodiff" -version = "0.4.5" -publish = false -edition = "2021" - -[dependencies] -rose = { path = "../core" } diff --git a/crates/autodiff/src/lib.rs b/crates/autodiff/src/lib.rs deleted file mode 100644 index 1d2a9eb..0000000 --- a/crates/autodiff/src/lib.rs +++ /dev/null @@ -1,410 +0,0 @@ -use rose::{id, Binop, Expr, Func, Instr, Ty, Unop}; - -// see docstring of `pub fn jvp` below - -const REAL: id::Ty = id::ty(0); -const DUAL: id::Ty = id::ty(1); - -// JS frontend requires field names to be alphabetized, and `"du"` comes before `"re"` -const DU: id::Member = id::member(0); -const RE: id::Member = id::member(1); - -fn map(t: id::Ty) -> id::Ty { - id::ty(t.ty() + 2) -} - -struct Autodiff<'a> { - old_types: &'a [Ty], - old_vars: &'a [id::Ty], - new_vars: &'a mut Vec, - unpacked: &'a mut [Option<(id::Var, id::Var)>], - dual_zero: id::Var, - code: Vec, -} - -impl Autodiff<'_> { - fn set(&mut self, t: id::Ty, expr: Expr) -> id::Var { - let var = id::var(self.new_vars.len()); - self.new_vars.push(t); - self.code.push(Instr { var, expr }); - var - } - - fn real(&mut self, expr: Expr) -> id::Var { - self.set(REAL, expr) - } - - fn dual(&mut self, expr: Expr) -> id::Var { - self.set(DUAL, expr) - } - - fn unpack(&mut self, var: id::Var) { - let i = var.var(); - if self.unpacked[i].is_none() { - if let Ty::F64 = self.old_types[self.old_vars[i].ty()] { - let x = self.real(Expr::Member { - tuple: var, - member: RE, - }); - let dx = self.dual(Expr::Member { - tuple: var, - member: DU, - }); - self.unpacked[i] = Some((x, dx)) - } - } - } - - fn get(&self, var: id::Var) -> (id::Var, id::Var) { - self.unpacked[var.var()].unwrap() - } - - fn pack(&mut self, var: id::Var, x: id::Var, dx: id::Var) { - self.unpacked[var.var()] = Some((x, dx)); - self.code.push(Instr { - var, - expr: Expr::Tuple { - members: [dx, x].into(), // alphabetical order - }, - }) - } - - fn child(&mut self, orig: &[Instr]) -> Box<[Instr]> { - Autodiff { - old_types: self.old_types, - old_vars: self.old_vars, - new_vars: self.new_vars, - unpacked: self.unpacked, - dual_zero: self.dual_zero, - code: vec![], - } - .block(orig) - } - - fn block(mut self, orig: &[Instr]) -> Box<[Instr]> { - for Instr { var, expr } in orig { - self.instr(*var, expr); - self.unpack(*var); - } - self.code.into() - } - - fn instr(&mut self, var: id::Var, expr: &Expr) { - match expr { - // boring cases - Expr::Unit => self.code.push(Instr { - var, - expr: Expr::Unit, - }), - &Expr::Bool { val } => self.code.push(Instr { - var, - expr: Expr::Bool { val }, - }), - &Expr::Fin { val } => self.code.push(Instr { - var, - expr: Expr::Fin { val }, - }), - Expr::Array { elems } => self.code.push(Instr { - var, - expr: Expr::Array { - elems: elems.clone(), - }, - }), - Expr::Tuple { members } => self.code.push(Instr { - var, - expr: Expr::Tuple { - members: members.clone(), - }, - }), - &Expr::Index { array, index } => self.code.push(Instr { - var, - expr: Expr::Index { array, index }, - }), - &Expr::Member { tuple, member } => self.code.push(Instr { - var, - expr: Expr::Member { tuple, member }, - }), - &Expr::Slice { array, index } => self.code.push(Instr { - var, - expr: Expr::Slice { array, index }, - }), - &Expr::Field { tuple, member } => self.code.push(Instr { - var, - expr: Expr::Field { tuple, member }, - }), - &Expr::Select { cond, then, els } => self.code.push(Instr { - var, - expr: Expr::Select { cond, then, els }, - }), - &Expr::Accum { shape } => self.code.push(Instr { - var, - expr: Expr::Accum { shape }, - }), - &Expr::Add { accum, addend } => self.code.push(Instr { - var, - expr: Expr::Add { accum, addend }, - }), - &Expr::Resolve { var: container } => self.code.push(Instr { - var, - expr: Expr::Resolve { var: container }, - }), - - // less boring cases - Expr::Call { id, generics, args } => self.code.push(Instr { - var, - expr: Expr::Call { - id: *id, - generics: generics.iter().copied().map(map).collect(), - args: args.clone(), - }, - }), - Expr::For { arg, body, ret } => { - let body = self.child(body); - self.code.push(Instr { - var, - expr: Expr::For { - arg: *arg, - body, - ret: *ret, - }, - }) - } - - // interesting cases - &Expr::F64 { val } => { - let x = self.real(Expr::F64 { val }); - let dx = self.dual_zero; - self.pack(var, x, dx) - } - &Expr::Unary { op, arg } => match op { - // boring case - Unop::Not => self.code.push(Instr { - var, - expr: Expr::Unary { op: Unop::Not, arg }, - }), - - // interesting cases - Unop::Neg => { - let (x, dx) = self.get(arg); - let y = self.real(Expr::Unary { - op: Unop::Neg, - arg: x, - }); - let dy = self.dual(Expr::Unary { - op: Unop::Neg, - arg: dx, - }); - self.pack(var, y, dy) - } - Unop::Abs => { - let (x, dx) = self.get(arg); - let y = self.real(Expr::Unary { - op: Unop::Abs, - arg: x, - }); - let sign = self.real(Expr::Unary { - op: Unop::Sign, - arg: x, - }); - let dy = self.dual(Expr::Binary { - op: Binop::Mul, - left: dx, - right: sign, - }); - self.pack(var, y, dy) - } - Unop::Sign | Unop::Ceil | Unop::Floor | Unop::Trunc => { - let (x, _) = self.get(arg); - let y = self.real(Expr::Unary { op, arg: x }); - let dy = self.dual_zero; - self.pack(var, y, dy) - } - Unop::Sqrt => { - let (x, dx) = self.get(arg); - let y = self.real(Expr::Unary { - op: Unop::Sqrt, - arg: x, - }); - let z = self.real(Expr::Binary { - op: Binop::Add, - left: y, - right: y, - }); - let dy = self.dual(Expr::Binary { - op: Binop::Div, - left: dx, - right: z, - }); - self.pack(var, y, dy) - } - }, - &Expr::Binary { op, left, right } => match op { - // boring cases - Binop::And | Binop::Or | Binop::Iff | Binop::Xor => self.code.push(Instr { - var, - expr: Expr::Binary { op, left, right }, - }), - - // less boring cases - Binop::Neq | Binop::Lt | Binop::Leq | Binop::Eq | Binop::Gt | Binop::Geq => { - let (x, _) = self.get(left); - let (y, _) = self.get(right); - self.code.push(Instr { - var, - expr: Expr::Binary { - op, - left: x, - right: y, - }, - }) - } - - // interesting cases - Binop::Add => { - let (x, dx) = self.get(left); - let (y, dy) = self.get(right); - let z = self.real(Expr::Binary { - op: Binop::Add, - left: x, - right: y, - }); - let dz = self.dual(Expr::Binary { - op: Binop::Add, - left: dx, - right: dy, - }); - self.pack(var, z, dz) - } - Binop::Sub => { - let (x, dx) = self.get(left); - let (y, dy) = self.get(right); - let z = self.real(Expr::Binary { - op: Binop::Sub, - left: x, - right: y, - }); - let dz = self.dual(Expr::Binary { - op: Binop::Sub, - left: dx, - right: dy, - }); - self.pack(var, z, dz) - } - Binop::Mul => { - let (x, dx) = self.get(left); - let (y, dy) = self.get(right); - let z = self.real(Expr::Binary { - op: Binop::Mul, - left: x, - right: y, - }); - let a = self.dual(Expr::Binary { - op: Binop::Mul, - left: dx, - right: y, - }); - let b = self.dual(Expr::Binary { - op: Binop::Mul, - left: dy, - right: x, - }); - let dz = self.dual(Expr::Binary { - op: Binop::Add, - left: a, - right: b, - }); - self.pack(var, z, dz) - } - Binop::Div => { - let (x, dx) = self.get(left); - let (y, dy) = self.get(right); - let z = self.real(Expr::Binary { - op: Binop::Div, - left: x, - right: y, - }); - let a = self.real(Expr::Binary { - op: Binop::Div, - left: z, - right: y, - }); - let b = self.dual(Expr::Binary { - op: Binop::Div, - left: dx, - right: y, - }); - let c = self.dual(Expr::Binary { - op: Binop::Mul, - left: dy, - right: a, - }); - let dz = self.dual(Expr::Binary { - op: Binop::Sub, - left: b, - right: c, - }); - self.pack(var, z, dz) - } - }, - } - } -} - -/// Return a function that computes the Jacobian-vector product of this function. -/// -/// The first two types in the new function are the nonlinear and linear `F64` types, respectively. -/// Every type from the original function is then mapped over directly in a one-to-one fashion, with -/// indices shifted by two as necessary. Instances of the `F64` type from the original function are -/// replaced with a `Tuple` type whose members are the linear and nonlinear `F64` types, -/// respectively (note that this member order does not match the order of the types themselves). -pub fn jvp(f: &Func) -> Func { - let mut types = vec![Ty::F64, Ty::F64]; - types.extend(f.types.iter().map(|ty| match ty { - // boring cases - Ty::Unit => Ty::Unit, - Ty::Bool => Ty::Bool, - &Ty::Fin { size } => Ty::Fin { size }, - &Ty::Generic { id } => Ty::Generic { id }, - - // less boring cases - &Ty::Ref { inner } => Ty::Ref { inner: map(inner) }, - &Ty::Array { index, elem } => Ty::Array { - index: map(index), - elem: map(elem), - }, - Ty::Tuple { members } => Ty::Tuple { - members: members.iter().copied().map(map).collect(), - }, - - // interesting case - Ty::F64 => Ty::Tuple { - members: [DUAL, REAL].into(), // alphabetical order - }, - })); - let mut vars: Vec<_> = f.vars.iter().copied().map(map).collect(); - let dual_zero = id::var(vars.len()); - vars.push(DUAL); - let mut ad = Autodiff { - old_types: &f.types, - old_vars: &f.vars, - new_vars: &mut vars, - unpacked: &mut vec![None; f.vars.len()], - dual_zero, - code: vec![Instr { - var: dual_zero, - expr: Expr::F64 { val: 0. }, - }], - }; - for ¶m in f.params.iter() { - ad.unpack(param); - } - let body = ad.block(&f.body); - Func { - generics: f.generics.clone(), - types: types.into(), - vars: vars.into(), - params: f.params.clone(), - ret: f.ret, - body, - } -} diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml deleted file mode 100644 index 34cba31..0000000 --- a/crates/core/Cargo.toml +++ /dev/null @@ -1,8 +0,0 @@ -[package] -name = "rose" -version = "0.4.5" -publish = false -edition = "2021" - -[dependencies] -enumset = "1" diff --git a/crates/core/src/id.rs b/crates/core/src/id.rs deleted file mode 100644 index e8cbcd0..0000000 --- a/crates/core/src/id.rs +++ /dev/null @@ -1,69 +0,0 @@ -/// Index of a member in a tuple. -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct Member(usize); - -pub const fn member(id: usize) -> Member { - Member(id) -} - -impl Member { - pub const fn member(self) -> usize { - self.0 - } -} - -/// Index of an uninstantiated function reference in a definition context. -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct Func(usize); - -pub const fn func(id: usize) -> Func { - Func(id) -} - -impl Func { - pub const fn func(self) -> usize { - self.0 - } -} - -/// Index of a generic type parameter in a definition context. -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct Generic(usize); - -pub const fn generic(id: usize) -> Generic { - Generic(id) -} - -impl Generic { - pub const fn generic(self) -> usize { - self.0 - } -} - -/// Index of a type in a definition context. -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct Ty(usize); - -pub const fn ty(id: usize) -> Ty { - Ty(id) -} - -impl Ty { - pub const fn ty(self) -> usize { - self.0 - } -} - -/// Index of a local variable in a function definition context. -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct Var(usize); - -pub const fn var(id: usize) -> Var { - Var(id) -} - -impl Var { - pub const fn var(self) -> usize { - self.0 - } -} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs deleted file mode 100644 index 478123e..0000000 --- a/crates/core/src/lib.rs +++ /dev/null @@ -1,226 +0,0 @@ -pub mod id; - -use enumset::{EnumSet, EnumSetType}; - -/// A type constraint. -#[allow(clippy::derived_hash_with_manual_eq)] // `PartialEq` impl comes from enumset; should be fine -#[derive(Debug, EnumSetType, Hash)] -pub enum Constraint { - /// Not a `Ref`. - Value, - /// Can be the `index` type of an `Array`. - Index, -} - -/// A type. -#[derive(Clone, Debug, Eq, Hash, PartialEq)] -pub enum Ty { - Unit, - Bool, - F64, - /// A nonnegative integer less than `size`. Satisfies `Constraint::Index`. - Fin { - size: usize, - }, - Generic { - id: id::Generic, - }, - Ref { - inner: id::Ty, - }, - Array { - /// Must satisfy `Constraint::Index`. - index: id::Ty, - elem: id::Ty, - }, - Tuple { - members: Box<[id::Ty]>, - }, -} - -/// A function definition. -#[derive(Debug)] -pub struct Func { - /// Generic type parameters. - pub generics: Box<[EnumSet]>, - /// Types used in this function definition. - pub types: Box<[Ty]>, - /// Local variable types. - pub vars: Box<[id::Ty]>, - /// Parameter variables. - pub params: Box<[id::Var]>, - /// Return variable. - pub ret: id::Var, - /// Function body. - pub body: Box<[Instr]>, -} - -/// Resolves `id::Func`s. -pub trait Refs<'a> { - /// See `Node`. - type Opaque; - - /// Resolve `id` to a function node. - fn get(&self, id: id::Func) -> Option> - where - Self: Sized; -} - -/// A node in a graph of functions. -#[derive(Clone, Debug, Copy)] -pub enum Node<'a, O, T: Refs<'a, Opaque = O>> { - /// A function with an explicit body. - Transparent { - /// To traverse the graph by resolving functions called by this one. - refs: T, - /// The signature and definition of this function. - def: &'a Func, - }, - /// A function with an opaque body. - Opaque { - /// Generic type parameters. - generics: &'a [EnumSet], - /// Types used in this function's signature. - types: &'a [Ty], - /// Parameter types. - params: &'a [id::Ty], - /// Return type. - ret: id::Ty, - /// Definition of this function; semantics may vary. - def: O, - }, -} - -#[derive(Debug)] -pub struct Instr { - pub var: id::Var, - pub expr: Expr, -} - -#[derive(Debug)] -pub enum Expr { - Unit, - Bool { - val: bool, - }, - F64 { - val: f64, - }, - Fin { - val: usize, - }, - - Array { - elems: Box<[id::Var]>, - }, - Tuple { - members: Box<[id::Var]>, - }, - - Index { - array: id::Var, - index: id::Var, - }, - Member { - tuple: id::Var, - member: id::Member, - }, - - Slice { - /// Must actually be a `Ref` of an array, not just an array. - array: id::Var, - index: id::Var, - }, - Field { - /// Must actually be a `Ref` of a tuple, not just a tuple. - tuple: id::Var, - member: id::Member, - }, - - Unary { - op: Unop, - arg: id::Var, - }, - Binary { - op: Binop, - left: id::Var, - right: id::Var, - }, - Select { - /// Must be of type `Bool`. - cond: id::Var, - then: id::Var, - els: id::Var, - }, - - Call { - id: id::Func, - generics: Box<[id::Ty]>, - args: Box<[id::Var]>, - }, - For { - /// Type must satisfy `Constraint::Index`. - arg: id::Var, - body: Box<[Instr]>, - /// Variable from `body` holding an array element. - ret: id::Var, - }, - - /// Start a scope for an accumulator `Ref`. - Accum { - /// Topology of the `Ref`. - shape: id::Var, - }, - - /// Accumulate into an accumulator `Ref`. Returns `Unit`. - Add { - /// The `Ref`, which must be in scope. - accum: id::Var, - /// Must be of the `Ref`'s inner type. - addend: id::Var, - }, - - /// Consume a `Ref` to get its contained value. - Resolve { - /// The `Ref`, which must be in scope. - var: id::Var, - }, -} - -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -pub enum Unop { - // `Bool` -> `Bool` - Not, - - // `F64` -> `F64` - Neg, - Abs, - Sign, - Ceil, - Floor, - Trunc, - Sqrt, -} - -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -pub enum Binop { - // `Bool` -> `Bool` -> `Bool` - And, - Or, - Iff, - Xor, - - // `F64` -> `F64` -> `Bool` - Neq, - Lt, - Leq, - Eq, - Gt, - Geq, - - // `F64` -> `F64` -> `F64` - Add, - Sub, - Mul, - Div, -} diff --git a/crates/interp/Cargo.toml b/crates/interp/Cargo.toml deleted file mode 100644 index efbdaf5..0000000 --- a/crates/interp/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -name = "rose-interp" -version = "0.4.5" -publish = false -edition = "2021" - -[dependencies] -enumset = "1" -indexmap = "2" -rose = { path = "../core" } -serde = { version = "1", features = ["derive", "rc"], optional = true } -thiserror = "1" - -[dev-dependencies] -ts-rs = "6" - -[features] -default = ["serde"] -serde = ["dep:serde"] diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs deleted file mode 100644 index 185c629..0000000 --- a/crates/interp/src/lib.rs +++ /dev/null @@ -1,509 +0,0 @@ -use indexmap::IndexSet; -use rose::{id, Binop, Expr, Func, Node, Refs, Ty, Unop}; -use std::{cell::Cell, convert::Infallible, rc::Rc}; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -#[cfg(test)] -use ts_rs::TS; - -#[cfg_attr(test, derive(TS), ts(export))] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Clone, Debug, PartialEq)] -pub enum Val { - Unit, - Bool(bool), - F64(Cell), - Fin(usize), - Ref(Rc, Option), - Array(Vals), // assume all indices are `Fin` - Tuple(Vals), -} - -pub type Vals = Rc>; // TODO: change to `Rc<[Val]>` https://github.com/rose-lang/rose/issues/63 - -pub fn vals(v: [Val; N]) -> Vals { - Rc::new(v.to_vec()) -} - -pub fn collect_vals(it: impl Iterator) -> Vals { - Rc::new(it.collect()) -} - -pub fn val_f64(x: f64) -> Val { - Val::F64(Cell::new(x)) -} - -impl Val { - fn bool(&self) -> bool { - match self { - &Val::Bool(x) => x, - _ => unreachable!(), - } - } - - fn f64(&self) -> f64 { - match self { - Val::F64(x) => x.get(), - _ => unreachable!(), - } - } - - fn fin(&self) -> usize { - match self { - &Val::Fin(i) => i, - _ => unreachable!(), - } - } - - fn get(&self, i: usize) -> &Self { - match self { - Val::Array(x) => &x[i], - Val::Tuple(x) => &x[i], - _ => unreachable!(), - } - } - - fn slice(&self, i: usize) -> Self { - match self { - Val::Ref(x, None) => Val::Ref(Rc::clone(x), Some(i)), - Val::Ref(x, Some(j)) => Val::Ref(Rc::new(x.get(*j).clone()), Some(i)), - _ => unreachable!(), - } - } - - fn inner(&self) -> &Self { - match self { - Val::Ref(x, i) => match i { - None => x.as_ref(), - &Some(j) => x.get(j), - }, - _ => unreachable!(), - } - } - - /// Return a zero value with this value's topology. - fn zero(&self) -> Self { - match self { - Self::Unit => Self::Unit, - &Self::Bool(x) => Self::Bool(x), - Self::F64(_) => Self::F64(Cell::new(0.)), - &Self::Fin(x) => Self::Fin(x), - Self::Ref(..) => unreachable!(), - Self::Array(x) => Self::Array(collect_vals(x.iter().map(|x| x.zero()))), - Self::Tuple(x) => Self::Tuple(collect_vals(x.iter().map(|x| x.zero()))), - } - } - - /// Add `x` to this value, which must represent a mutable `Ref` type. - fn add(&self, x: &Self) { - match (self, x) { - (Self::Unit, Self::Unit) - | (Self::Bool(_), Self::Bool(_)) - | (Self::Fin(_), Self::Fin(_)) => {} - (Self::F64(a), Self::F64(b)) => a.set(a.get() + b.get()), - (Self::Array(a), Self::Array(b)) => { - for (a, b) in a.iter().zip(b.iter()) { - a.add(b); - } - } - (Self::Tuple(a), Self::Tuple(b)) => { - for (a, b) in a.iter().zip(b.iter()) { - a.add(b); - } - } - _ => unreachable!(), - } - } -} - -/// Resolve `ty` via `generics` and `types`, then return its ID in `typemap`, inserting if need be. -/// -/// This is meant to be used to pull all the types from a callee into a broader context. The -/// `generics` are the IDs of all the types provided as generic type parameters for the callee. The -/// `types are the IDs of all the types that have been pulled in so far. -fn resolve(typemap: &mut IndexSet, generics: &[id::Ty], types: &[id::Ty], ty: &Ty) -> id::Ty { - let resolved = match ty { - Ty::Generic { id } => return generics[id.generic()], - - Ty::Unit => Ty::Unit, - Ty::Bool => Ty::Bool, - Ty::F64 => Ty::F64, - &Ty::Fin { size } => Ty::Fin { size }, - - Ty::Ref { inner } => Ty::Ref { - inner: types[inner.ty()], - }, - Ty::Array { index, elem } => Ty::Array { - index: types[index.ty()], - elem: types[elem.ty()], - }, - Ty::Tuple { members } => Ty::Tuple { - members: members.iter().map(|&x| types[x.ty()]).collect(), - }, - }; - let (i, _) = typemap.insert_full(resolved); - id::ty(i) -} - -/// An opaque function that can be called by the interpreter. -pub trait Opaque { - fn call(&self, types: &IndexSet, generics: &[id::Ty], args: &[Val]) -> Val; -} - -impl Opaque for Infallible { - fn call(&self, _: &IndexSet, _: &[id::Ty], _: &[Val]) -> Val { - match *self {} - } -} - -/// basically, the `'a` lifetime is for the graph of functions, and the `'b` lifetime is just for -/// this particular instance of interpretation -struct Interpreter<'a, 'b, O, T: Refs<'a, Opaque = O>> { - typemap: &'b mut IndexSet, - refs: T, - def: &'a Func, - types: Vec, - vars: Vec>, -} - -impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { - fn new(typemap: &'b mut IndexSet, refs: T, def: &'a Func, generics: &'b [id::Ty]) -> Self { - let mut types = vec![]; - for ty in def.types.iter() { - types.push(resolve(typemap, generics, &types, ty)); - } - Self { - typemap, - refs, - def, - types, - vars: vec![None; def.vars.len()], - } - } - - fn get(&self, var: id::Var) -> &Val { - self.vars[var.var()].as_ref().unwrap() - } - - fn expr(&mut self, expr: &Expr) -> Val { - match expr { - Expr::Unit => Val::Unit, - &Expr::Bool { val } => Val::Bool(val), - &Expr::F64 { val } => val_f64(val), - &Expr::Fin { val } => Val::Fin(val), - - Expr::Array { elems } => { - Val::Array(collect_vals(elems.iter().map(|&x| self.get(x).clone()))) - } - Expr::Tuple { members } => { - Val::Tuple(collect_vals(members.iter().map(|&x| self.get(x).clone()))) - } - - &Expr::Index { array, index } => match (self.get(array), self.get(index)) { - (Val::Array(v), &Val::Fin(i)) => v[i].clone(), - _ => unreachable!(), - }, - &Expr::Member { tuple, member } => match self.get(tuple) { - Val::Tuple(x) => x[member.member()].clone(), - _ => unreachable!(), - }, - - &Expr::Slice { array, index } => self.get(array).slice(self.get(index).fin()), - &Expr::Field { tuple, member } => self.get(tuple).slice(member.member()), - - &Expr::Unary { op, arg } => { - let x = self.get(arg); - match op { - Unop::Not => Val::Bool(!x.bool()), - - Unop::Neg => val_f64(-x.f64()), - Unop::Abs => val_f64(x.f64().abs()), - Unop::Sign => val_f64(x.f64().signum()), - Unop::Ceil => val_f64(x.f64().ceil()), - Unop::Floor => val_f64(x.f64().floor()), - Unop::Trunc => val_f64(x.f64().trunc()), - Unop::Sqrt => val_f64(x.f64().sqrt()), - } - } - &Expr::Binary { op, left, right } => { - let x = self.get(left); - let y = self.get(right); - match op { - Binop::And => Val::Bool(x.bool() && y.bool()), - Binop::Or => Val::Bool(x.bool() || y.bool()), - Binop::Iff => Val::Bool(x.bool() == y.bool()), - Binop::Xor => Val::Bool(x.bool() != y.bool()), - - Binop::Neq => Val::Bool(x.f64() != y.f64()), - Binop::Lt => Val::Bool(x.f64() < y.f64()), - Binop::Leq => Val::Bool(x.f64() <= y.f64()), - Binop::Eq => Val::Bool(x.f64() == y.f64()), - Binop::Gt => Val::Bool(x.f64() > y.f64()), - Binop::Geq => Val::Bool(x.f64() >= y.f64()), - - Binop::Add => val_f64(x.f64() + y.f64()), - Binop::Sub => val_f64(x.f64() - y.f64()), - Binop::Mul => val_f64(x.f64() * y.f64()), - Binop::Div => val_f64(x.f64() / y.f64()), - } - } - &Expr::Select { cond, then, els } => { - if self.get(cond).bool() { - self.get(then).clone() - } else { - self.get(els).clone() - } - } - - Expr::Call { id, generics, args } => { - let resolved: Vec = generics.iter().map(|id| self.types[id.ty()]).collect(); - let vals = args.iter().map(|id| self.vars[id.var()].clone().unwrap()); - call(self.refs.get(*id).unwrap(), self.typemap, &resolved, vals) - } - Expr::For { arg, body, ret } => { - let n = match self.typemap[self.types[self.def.vars[arg.var()].ty()].ty()] { - Ty::Fin { size } => size, - _ => unreachable!(), - }; - Val::Array(collect_vals( - (0..n).map(|i| self.block(*arg, body, *ret, Val::Fin(i)).clone()), - )) - } - - &Expr::Accum { shape } => Val::Ref(Rc::new(self.get(shape).zero()), None), - - &Expr::Add { accum, addend } => { - self.get(accum).inner().add(self.get(addend)); - Val::Unit - } - - &Expr::Resolve { var } => self.get(var).inner().clone(), - } - } - - fn block(&mut self, param: id::Var, body: &[rose::Instr], ret: id::Var, arg: Val) -> &Val { - self.vars[param.var()] = Some(arg); - for instr in body.iter() { - self.vars[instr.var.var()] = Some(self.expr(&instr.expr)); - } - self.vars[ret.var()].as_ref().unwrap() - } -} - -/// Assumes `generics` and `arg` are valid. -fn call<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>>( - f: Node<'a, O, T>, - types: &'b mut IndexSet, - generics: &'b [id::Ty], - args: impl Iterator, -) -> Val { - match f { - Node::Transparent { refs, def } => { - let mut interp = Interpreter::new(types, refs, def, generics); - for (var, arg) in def.params.iter().zip(args) { - interp.vars[var.var()] = Some(arg.clone()); - } - for instr in def.body.iter() { - interp.vars[instr.var.var()] = Some(interp.expr(&instr.expr)); - } - interp.vars[def.ret.var()].as_ref().unwrap().clone() - } - Node::Opaque { - generics: _, - types: _, - params: _, - ret: _, - def, - } => { - let vals: Box<[Val]> = args.collect(); - def.call(types, generics, &vals) - } - } -} - -#[derive(Debug, thiserror::Error)] -pub enum Error {} - -/// Guaranteed not to panic if `f` is valid. -pub fn interp<'a, O: Opaque, T: Refs<'a, Opaque = O>>( - f: Node<'a, O, T>, - mut types: IndexSet, - generics: &'a [id::Ty], - args: impl Iterator, -) -> Result { - // TODO: check that `generics` and `arg` are valid - Ok(call(f, &mut types, generics, args)) -} - -#[cfg(test)] -mod tests { - use super::*; - use rose::{Func, Instr}; - - type CustomRef<'a> = &'a dyn Fn(&IndexSet, &[id::Ty], &[Val]) -> Val; - type CustomBox = Box, &[id::Ty], &[Val]) -> Val>; - - struct Custom<'a> { - f: CustomRef<'a>, - } - - impl Opaque for Custom<'_> { - fn call(&self, types: &IndexSet, generics: &[id::Ty], args: &[Val]) -> Val { - (self.f)(types, generics, args) - } - } - - struct FuncInSlice<'a> { - custom: &'a [CustomBox], - funcs: &'a [Func], - id: id::Func, - } - - impl<'a> Refs<'a> for FuncInSlice<'a> { - type Opaque = Custom<'a>; - - fn get(&self, id: id::Func) -> Option, Self>> { - if id.func() < self.id.func() { - node(self.custom, self.funcs, id) - } else { - None - } - } - } - - fn node<'a>( - custom: &'a [CustomBox], - funcs: &'a [Func], - id: id::Func, - ) -> Option, FuncInSlice<'a>>> { - let n = custom.len(); - let i = id.func(); - if i < n { - Some(Node::Opaque { - generics: &[], - types: &[], - params: &[], - ret: id::ty(0), - def: Custom { f: &custom[i] }, - }) - } else { - funcs.get(i - n).map(|def| Node::Transparent { - refs: FuncInSlice { custom, funcs, id }, - def, - }) - } - } - - #[test] - fn test_two_plus_two() { - let funcs = vec![Func { - generics: vec![].into(), - types: vec![Ty::F64].into(), - vars: vec![id::ty(0), id::ty(0), id::ty(0)].into(), - params: vec![id::var(0), id::var(1)].into(), - ret: id::var(2), - body: vec![Instr { - var: id::var(2), - expr: Expr::Binary { - op: Binop::Add, - left: id::var(0), - right: id::var(1), - }, - }] - .into(), - }]; - let answer = interp( - node(&[], &funcs, id::func(0)).unwrap(), - IndexSet::new(), - &[], - [val_f64(2.), val_f64(2.)].into_iter(), - ) - .unwrap(); - assert_eq!(answer, val_f64(4.)); - } - - #[test] - fn test_nested_call() { - let funcs = vec![ - Func { - generics: vec![].into(), - types: vec![Ty::F64].into(), - vars: vec![id::ty(0)].into(), - params: vec![].into(), - ret: id::var(0), - body: vec![Instr { - var: id::var(0), - expr: Expr::F64 { val: 42. }, - }] - .into(), - }, - Func { - generics: vec![].into(), - types: vec![Ty::F64].into(), - vars: vec![id::ty(0), id::ty(0)].into(), - params: vec![].into(), - ret: id::var(1), - body: vec![ - Instr { - var: id::var(0), - expr: Expr::Call { - id: id::func(0), - generics: vec![].into(), - args: vec![].into(), - }, - }, - Instr { - var: id::var(1), - expr: Expr::Binary { - op: Binop::Mul, - left: id::var(0), - right: id::var(0), - }, - }, - ] - .into(), - }, - ]; - let answer = interp( - node(&[], &funcs, id::func(1)).unwrap(), - IndexSet::new(), - &[], - [].into_iter(), - ) - .unwrap(); - assert_eq!(answer, val_f64(1764.)); - } - - #[test] - fn test_custom() { - let custom: [CustomBox; 1] = [Box::new(|_, _, args| { - Val::F64(Cell::new(args[0].f64().powf(args[1].f64()))) - })]; - let funcs = [Func { - generics: [].into(), - types: [Ty::F64].into(), - vars: [id::ty(0), id::ty(0), id::ty(0)].into(), - params: [id::var(0), id::var(1)].into(), - ret: id::var(2), - body: [Instr { - var: id::var(2), - expr: Expr::Call { - id: id::func(0), - generics: [].into(), - args: [id::var(0), id::var(1)].into(), - }, - }] - .into(), - }]; - let answer = interp( - node(&custom, &funcs, id::func(1)).unwrap(), - IndexSet::new(), - &[], - [val_f64(std::f64::consts::E), val_f64(std::f64::consts::PI)].into_iter(), - ) - .unwrap(); - assert_eq!(answer, val_f64(23.140692632779263)); - } -} diff --git a/crates/transpose/Cargo.toml b/crates/transpose/Cargo.toml deleted file mode 100644 index ad6c9af..0000000 --- a/crates/transpose/Cargo.toml +++ /dev/null @@ -1,10 +0,0 @@ -[package] -name = "rose-transpose" -version = "0.4.5" -publish = false -edition = "2021" - -[dependencies] -enumset = "1" -indexmap = "2" -rose = { path = "../core" } diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs deleted file mode 100644 index 40c3e35..0000000 --- a/crates/transpose/src/lib.rs +++ /dev/null @@ -1,1273 +0,0 @@ -use indexmap::{indexset, IndexSet}; -use rose::{id, Binop, Expr, Func, Instr, Ty, Unop}; -use std::mem::{replace, swap, take}; - -/// By convention, the first type in a function to be transposed must be the nonlinear `F64`. -const REAL: id::Ty = id::ty(0); - -/// By convention, the first type in a function to be transposed must be the linear `F64`. -const DUAL: id::Ty = id::ty(1); - -/// Return true iff `t` is the type ID of a linear type in a function to be transposed. -/// -/// In this module, "primitive" specifically means a linear or nonlinear `F64` type, and -/// specifically excludes other types that might be considered primitive, such as `Unit` or `Bool`. -fn is_primitive(t: id::Ty) -> bool { - t == REAL || t == DUAL -} - -/// By convention, the first member in a type for the dual numbers must be the linear part. -const DU: id::Member = id::member(0); - -/// By convention, the second member in a type for the dual numbers must be the nonlinear part. -const RE: id::Member = id::member(1); - -/// The source of a primitive variable or a component of a dual number variable. -/// -/// The value `None` means that this is the original source, whereas `Some` means that it is an -/// alias of a the given primitive variable or a component of the given dual number variable. -#[derive(Clone, Copy)] -struct Src(Option); - -impl Src { - /// Return the source for a variable derived from `self`. - /// - /// The variable ID `x` represents `self`, not the new source being returned. - fn derive(self, x: id::Var) -> Self { - match self.0 { - None => Self(Some(x)), - _ => self, - } - } -} - -/// Linear variables in the backward pass for a variable from the original function. -struct Lin { - /// Accumulator variable. - acc: id::Var, - - /// Resolved cotangent variable. - cot: id::Var, -} - -/// A block under construction for both the forward pass and the backward pass. -struct Block { - /// Instructions in this forward pass block, in order. - fwd: Vec, - - /// Variable IDs for intermediate values to be saved at the end of this forward pass block. - inter_mem: Vec, - - /// Variable ID for the intermediate values tuple in this backward pass block. - inter_tup: id::Var, - - /// Instructions at the beginning of this backward pass block, in order. - bwd_nonlin: Vec, - - /// Instructions at the end of this backward pass block, in reverse order. - bwd_lin: Vec, -} - -/// The forward pass and backward pass of a transposed function under construction. -struct Transpose<'a> { - /// The function being transposed, which is usually a forward-mode derivative. - f: &'a Func, - - deps: &'a [(&'a [Ty], id::Ty)], - - /// Mapped versions of `f.types`. - /// - /// The only reason this is useful is to easily check whether a type is the dual number type by - /// looking it up here to see if it's equal to `F64`. - mapped_types: Vec, - - /// Additional types, shared between the forward and backward passes. - /// - /// This starts out empty: at first we only have `mapped_types`, but later we'll add more types - /// for tuples and arrays of intermediate values that are shared between the two passes, and - /// also for new reference types that are only used in the backward pass. These `types` will - /// later all be appended onto `mapped_types`, so any type indices referencing them should be - /// offset by `mapped_types.len()`. - types: IndexSet, - - /// Type ID for `Unit`. - /// - /// We could get this every time by looking up in `types`, but it's easier to just always put it - /// in at the beginning to save ourselves the repeated hash lookups. - unit: id::Ty, - - /// Types of variables in the forward pass. - /// - /// This starts out as a clone of `f.vars`, but more variables can be added for dealing with - /// intermediate values. - fwd_vars: Vec, - - /// Types of variables in the backward pass. - /// - /// This starts out with `Some` type for every variable from `f.vars`, but more variables can be - /// added for intermediate values, accumulators, and cotangents. The only time a variable's type - /// here is `None` is for tuples of intermediate values; see the `inter_tup` field in `Block`. - bwd_vars: Vec>, - - /// A variable of type `F64` defined at the beginning of the backward pass. - /// - /// Every accumulator must be initialized using a concrete variable to dictate its topology. For - /// most variables, we keep around the original nonlinear value and use that as the shape, but - /// this doesn't work for raw linear `F64` variables, which might be part of some lower-level - /// mathematical calculation that is not clearly attached to any value at the dual number level - /// or higher. All those values have the same shape, though, so in those cases we just use this - /// one dummy variable as the shape. - real_shape: id::Var, - - /// Sources of primitive variables from the original function. - prims: Box<[Option]>, - - /// Sources for dual number variables from the original function. - /// - /// The sources are for the nonlinear part and the linear part, respectively; note that this - /// disagrees with the order for tuple members dictated by `DU` and `RE`. - duals: Box<[Option<(Src, Src)>]>, - - /// Accumulator variables for variables from the original function. - accums: Box<[Option]>, - - /// Cotangent variables for variables from the original function. - cotans: Box<[Option]>, - - /// Stack of pending unreversed instructions for the backward pass. - /// - /// In general, we keep track of reversed instructions in the `bwd_lin` field of `block`; those - /// will go after the unreversed `bwd_nonlin` instructions. When we enter a new scope via - /// `Expr::For`, we start an entirely new `block`, so even though those inner `bwd_nonlin` - /// instructions may end up interleaved with our current `bwd_lin` instructions, that's fine - /// because they're going in a separate instruction list anyway. - /// - /// But for `Expr::Accum` and `Expr::Resolve`, we're introducing a new scope without actually - /// starting a new `block`. In that case, we still need for all the instructions we put in - /// `bwd_nonlin` during this scope to go before all our `bwd_lin` instructions from the scope, - /// but we also need them to go after any `bwd_lin` instructions we add after the scope ends. - /// So, what we do is push `bwd_nonlin` onto this `stack` when we enter the scope via - /// `Expr::Accum`, and then when we exit the scope via `Expr::Resolve`, we pop it off, reverse - /// it, and append it to `bwd_lin`. Then when we finally finish the actual block, the stack - /// should be empty, so we just reverse `bwd_lin` and append it to `bwd_nonlin` as normal. - stack: Vec>, - - /// The current block under construction. - block: Block, -} - -impl<'a> Transpose<'a> { - /// Return the ID for `ty`, adding it to `types` if it isn't already there. - fn ty(&mut self, ty: Ty) -> id::Ty { - let (i, _) = self.types.insert_full(ty); - id::ty(self.f.types.len() + i) - } - - fn translate(&mut self, generics: &[id::Ty], types: &[id::Ty], ty: &rose::Ty) -> id::Ty { - self.ty(match ty { - Ty::Unit => Ty::Unit, - Ty::Bool => Ty::Bool, - Ty::F64 => Ty::F64, - &Ty::Fin { size } => Ty::Fin { size }, - Ty::Generic { id } => return generics[id.generic()], - Ty::Ref { inner } => Ty::Ref { - inner: types[inner.ty()], - }, - Ty::Array { index, elem } => Ty::Array { - index: types[index.ty()], - elem: types[elem.ty()], - }, - Ty::Tuple { members } => Ty::Tuple { - members: members.iter().map(|&member| types[member.ty()]).collect(), - }, - }) - } - - /// Return the source of a variable that is the nonlinear part of `x`. - fn re(&self, x: id::Var) -> Src { - let (src, _) = self.duals[x.var()].unwrap(); - src.derive(x) - } - - /// Return the source of a variable that is the linear part of `x`. - fn du(&self, x: id::Var) -> Src { - let (_, src) = self.duals[x.var()].unwrap(); - src.derive(x) - } - - /// Return the source variable for `x`, which has a primitive type. - fn get_prim(&self, x: id::Var) -> id::Var { - match self.prims[x.var()].unwrap() { - Src(None) => x, - Src(Some(y)) => y, - } - } - - /// Return the source variable for the nonlinear part of `x`, which has a non-primitive type. - /// - /// Every non-primitive variable whose type is not the dual numbers is considered original. - fn get_re(&self, x: id::Var) -> id::Var { - match self.duals[x.var()] { - None | Some((Src(None), _)) => x, - Some((Src(Some(y)), _)) => y, - } - } - - /// Return the source variable for the linear part of `x`, which has a non-primitive type. - /// - /// Every non-primitive variable whose type is not the dual numbers is considered original. - fn get_du(&self, x: id::Var) -> id::Var { - match self.duals[x.var()] { - None | Some((_, Src(None))) => x, - Some((_, Src(Some(y)))) => y, - } - } - - /// Return the accumulator variable for `x`, which has a primitive type. - fn get_prim_accum(&self, x: id::Var) -> id::Var { - self.accums[self.get_prim(x).var()].unwrap() - } - - /// Return the accumulator variable for the linear part of `x`, which has a non-primitive type. - fn get_accum(&self, x: id::Var) -> id::Var { - self.accums[self.get_du(x).var()].unwrap() - } - - /// Return the cotangent variable for the linear part of `x`, which has a non-primitive type. - fn get_cotan(&self, x: id::Var) -> id::Var { - self.cotans[self.get_du(x).var()].unwrap() - } - - /// Return the ID for a new variable with type ID `t` in the forward pass. - fn fwd_var(&mut self, t: id::Ty) -> id::Var { - let var = id::var(self.fwd_vars.len()); - self.fwd_vars.push(t); - var - } - - /// Return the ID for a new variable with type ID `t` in the backward pass. - /// - /// `t` should be `None` iff it is a tuple of intermediate values. - fn bwd_var(&mut self, t: Option) -> id::Var { - let var = id::var(self.bwd_vars.len()); - self.bwd_vars.push(t); - var - } - - /// Include `var` in the intermediate values tuple for the current block. - fn keep(&mut self, var: id::Var) { - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Member { - tuple: self.block.inter_tup, - member: id::member(self.block.inter_mem.len()), - }, - }); - self.block.inter_mem.push(var); - } - - /// Create a non-primitive accumulator for `shape`; return it along with its eventual cotangent. - fn accum(&mut self, shape: id::Var) -> Lin { - let t_cot = self.f.vars[shape.var()]; - let t_acc = self.ty(Ty::Ref { inner: t_cot }); - let acc = self.bwd_var(Some(t_acc)); - let cot = self.bwd_var(Some(t_cot)); - self.block.bwd_nonlin.push(Instr { - var: acc, - expr: Expr::Accum { shape }, - }); - self.accums[shape.var()] = Some(acc); - self.cotans[shape.var()] = Some(cot); - Lin { acc, cot } - } - - /// Create a primitive accumulator for the given `tangent`, using `self.real_shape`. - fn calc(&mut self, tangent: id::Var) -> Lin { - let t_cot = self.f.vars[tangent.var()]; - let t_acc = self.ty(Ty::Ref { inner: t_cot }); - let acc = self.bwd_var(Some(t_acc)); - let cot = self.bwd_var(Some(t_cot)); - self.block.bwd_nonlin.push(Instr { - var: acc, - expr: Expr::Accum { - shape: self.real_shape, - }, - }); - self.accums[tangent.var()] = Some(acc); - self.cotans[tangent.var()] = Some(cot); - Lin { acc, cot } - } - - /// Resolve the given accumulator. - fn resolve(&mut self, lin: Lin) { - self.block.bwd_lin.push(Instr { - var: lin.cot, - expr: Expr::Resolve { var: lin.acc }, - }) - } - - /// Process `block` and return the type and forward variable for the intermediate values tuple. - fn block(&mut self, block: &[Instr]) -> (id::Ty, id::Var) { - for instr in block.iter() { - self.instr(instr.var, &instr.expr); - } - let vars = take(&mut self.block.inter_mem); - let t = self.ty(Ty::Tuple { - members: vars.iter().map(|&x| self.fwd_vars[x.var()]).collect(), - }); - let var = self.fwd_var(t); - self.block.fwd.push(Instr { - var, - expr: Expr::Tuple { - members: vars.into(), - }, - }); - self.bwd_vars[self.block.inter_tup.var()] = Some(t); - (t, var) - } - - /// Process the instruction with the given `var` and `expr`. - fn instr(&mut self, var: id::Var, expr: &Expr) { - match expr { - Expr::Unit => { - self.block.fwd.push(Instr { - var, - expr: Expr::Unit, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Unit, - }); - let lin = self.accum(var); - self.resolve(lin); - } - &Expr::Bool { val } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Bool { val }, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Bool { val }, - }); - let lin = self.accum(var); - self.resolve(lin); - } - &Expr::F64 { val } => { - match self.f.vars[var.var()] { - DUAL => { - let lin = self.calc(var); - self.resolve(lin); - } - _ => { - self.block.fwd.push(Instr { - var, - expr: Expr::F64 { val }, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::F64 { val }, - }); - let lin = self.accum(var); - self.resolve(lin); - } - } - self.prims[var.var()] = Some(Src(None)); - } - &Expr::Fin { val } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Fin { val }, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Fin { val }, - }); - let lin = self.accum(var); - self.resolve(lin); - } - - Expr::Array { elems } => { - let t = match self.f.types[self.f.vars[var.var()].ty()] { - Ty::Array { index, elem: _ } => index, - _ => panic!(), - }; - self.block.fwd.push(Instr { - var, - expr: Expr::Array { - elems: elems.iter().map(|&elem| self.get_re(elem)).collect(), - }, - }); - self.keep(var); - let lin = self.accum(var); - for (i, &elem) in elems.iter().enumerate() { - let accum = self.get_accum(elem); - let index = self.bwd_var(Some(t)); - let addend = self.bwd_var(Some(self.f.vars[elem.var()])); - let unit = self.bwd_var(Some(self.unit)); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { accum, addend }, - }); - self.block.bwd_lin.push(Instr { - var: addend, - expr: Expr::Index { - array: lin.cot, - index, - }, - }); - self.block.bwd_lin.push(Instr { - var: index, - expr: Expr::Fin { val: i }, - }); - } - self.resolve(lin); - } - Expr::Tuple { members } => match self.mapped_types[self.f.vars[var.var()].ty()] { - Ty::F64 => { - let x = members[RE.member()]; - let dx = members[DU.member()]; - self.duals[var.var()] = Some(( - self.prims[x.var()].unwrap().derive(x), - self.prims[dx.var()].unwrap().derive(dx), - )); - } - _ => { - self.block.fwd.push(Instr { - var, - expr: Expr::Tuple { - members: members.iter().map(|&member| self.get_re(member)).collect(), - }, - }); - self.keep(var); - let lin = self.accum(var); - for (i, &member) in members.iter().enumerate() { - let accum = self.get_accum(member); - let addend = self.bwd_var(Some(self.f.vars[member.var()])); - let unit = self.bwd_var(Some(self.unit)); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { accum, addend }, - }); - self.block.bwd_lin.push(Instr { - var: addend, - expr: Expr::Member { - tuple: lin.cot, - member: id::member(i), - }, - }); - } - self.resolve(lin); - } - }, - - &Expr::Index { array, index } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Index { array, index }, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Index { array, index }, - }); - let arr_acc = self.get_accum(array); - let t_acc = self.ty(Ty::Ref { - inner: self.f.vars[var.var()], - }); - let acc = self.bwd_var(Some(t_acc)); - self.accums[var.var()] = Some(acc); - self.block.bwd_nonlin.push(Instr { - var: acc, - expr: Expr::Slice { - array: arr_acc, - index, - }, - }); - if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src(None), Src(None))); - } - } - &Expr::Member { tuple, member } => { - let t = self.f.vars[var.var()]; - match t { - REAL => self.prims[var.var()] = Some(self.re(tuple)), - DUAL => self.prims[var.var()] = Some(self.du(tuple)), - _ => { - self.block.fwd.push(Instr { - var, - expr: Expr::Member { tuple, member }, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Member { tuple, member }, - }); - let tup_acc = self.get_accum(tuple); - let t_acc = self.ty(Ty::Ref { - inner: self.f.vars[var.var()], - }); - let acc = self.bwd_var(Some(t_acc)); - self.accums[var.var()] = Some(acc); - self.block.bwd_nonlin.push(Instr { - var: acc, - expr: Expr::Field { - tuple: tup_acc, - member, - }, - }); - if let Ty::F64 = self.mapped_types[t.ty()] { - self.duals[var.var()] = Some((Src(None), Src(None))); - } - } - } - } - - &Expr::Slice { array, index } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Slice { array, index }, - }); - - let t_cot = match &self.f.types[self.f.vars[var.var()].ty()] { - &Ty::Ref { inner } => inner, - _ => panic!(), - }; - let cot = self.bwd_var(Some(t_cot)); - self.block.bwd_nonlin.push(Instr { - var: cot, - expr: Expr::Index { - array: self.get_cotan(array), - index, - }, - }); - self.cotans[var.var()] = Some(cot); - if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src(None), Src(None))); - } - } - &Expr::Field { tuple, member } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Field { tuple, member }, - }); - - let t_cot = match &self.f.types[self.f.vars[var.var()].ty()] { - &Ty::Ref { inner } => inner, - _ => panic!(), - }; - let cot = self.bwd_var(Some(t_cot)); - self.block.bwd_nonlin.push(Instr { - var: cot, - expr: Expr::Member { - tuple: self.get_cotan(tuple), - member, - }, - }); - self.cotans[var.var()] = Some(cot); - if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src(None), Src(None))); - } - } - - &Expr::Unary { op, arg } => { - match self.f.vars[var.var()] { - DUAL => match op { - Unop::Not - | Unop::Abs - | Unop::Sign - | Unop::Ceil - | Unop::Floor - | Unop::Trunc - | Unop::Sqrt => panic!(), - Unop::Neg => { - let lin = self.calc(var); - let res = self.bwd_var(Some(DUAL)); - let unit = self.bwd_var(Some(self.unit)); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { - accum: self.get_prim_accum(arg), - addend: res, - }, - }); - self.block.bwd_lin.push(Instr { - var: res, - expr: Expr::Unary { - op: Unop::Neg, - arg: lin.cot, - }, - }); - self.resolve(lin); - } - }, - _ => { - let x = match op { - Unop::Not => arg, - Unop::Neg - | Unop::Abs - | Unop::Sign - | Unop::Ceil - | Unop::Floor - | Unop::Trunc - | Unop::Sqrt => self.get_prim(arg), - }; - self.block.fwd.push(Instr { - var, - expr: Expr::Unary { op, arg: x }, - }); - self.keep(var); - let lin = self.accum(var); - self.resolve(lin); - } - } - self.prims[var.var()] = Some(Src(None)); - } - &Expr::Binary { op, left, right } => { - match self.f.vars[var.var()] { - DUAL => { - let lin = self.calc(var); - match op { - Binop::And - | Binop::Or - | Binop::Iff - | Binop::Xor - | Binop::Neq - | Binop::Lt - | Binop::Leq - | Binop::Eq - | Binop::Gt - | Binop::Geq => panic!(), - Binop::Add => { - let a = self.bwd_var(Some(self.unit)); - let b = self.bwd_var(Some(self.unit)); - self.block.bwd_lin.push(Instr { - var: a, - expr: Expr::Add { - accum: self.get_prim_accum(left), - addend: lin.cot, - }, - }); - self.block.bwd_lin.push(Instr { - var: b, - expr: Expr::Add { - accum: self.get_prim_accum(right), - addend: lin.cot, - }, - }); - } - Binop::Sub => { - let res = self.bwd_var(Some(DUAL)); - let a = self.bwd_var(Some(self.unit)); - let b = self.bwd_var(Some(self.unit)); - self.block.bwd_lin.push(Instr { - var: a, - expr: Expr::Add { - accum: self.get_prim_accum(left), - addend: lin.cot, - }, - }); - self.block.bwd_lin.push(Instr { - var: b, - expr: Expr::Add { - accum: self.get_prim_accum(right), - addend: res, - }, - }); - self.block.bwd_lin.push(Instr { - var: res, - expr: Expr::Unary { - op: Unop::Neg, - arg: lin.cot, - }, - }); - } - Binop::Mul | Binop::Div => { - let res = self.bwd_var(Some(DUAL)); - let unit = self.bwd_var(Some(self.unit)); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { - accum: self.get_prim_accum(left), - addend: res, - }, - }); - self.block.bwd_lin.push(Instr { - var: res, - expr: Expr::Binary { - op, - left: lin.cot, - right: self.get_prim(right), - }, - }); - } - } - self.resolve(lin); - } - _ => { - let (a, b) = match op { - Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right), - Binop::Neq - | Binop::Lt - | Binop::Leq - | Binop::Eq - | Binop::Gt - | Binop::Geq - | Binop::Add - | Binop::Sub - | Binop::Mul - | Binop::Div => (self.get_prim(left), self.get_prim(right)), - }; - self.block.fwd.push(Instr { - var, - expr: Expr::Binary { - op, - left: a, - right: b, - }, - }); - self.keep(var); - let lin = self.accum(var); - self.resolve(lin); - } - } - self.prims[var.var()] = Some(Src(None)); - } - &Expr::Select { cond, then, els } => { - let t = self.f.vars[var.var()]; - - self.block.fwd.push(Instr { - var, - expr: Expr::Select { - cond, - then: self.get_re(then), - els: self.get_re(els), - }, - }); - - match &self.f.types[t.ty()] { - &Ty::Ref { inner } => { - let cot = self.bwd_var(Some(inner)); - self.block.bwd_nonlin.push(Instr { - var: cot, - expr: Expr::Select { - cond, - then: self.get_cotan(then), - els: self.get_cotan(els), - }, - }); - self.cotans[var.var()] = Some(cot); - } - _ => { - self.keep(var); - if t == REAL { - self.prims[var.var()] = Some(Src(None)); - } else { - let lin = self.accum(var); - let acc_then = self.get_accum(then); - let acc_els = self.get_accum(els); - let t_acc = self.ty(Ty::Ref { - inner: self.f.vars[var.var()], - }); - let acc = self.bwd_var(Some(t_acc)); - let unit = self.bwd_var(Some(self.unit)); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { - accum: acc, - addend: lin.cot, - }, - }); - self.block.bwd_lin.push(Instr { - var: acc, - expr: Expr::Select { - cond, - then: acc_then, - els: acc_els, - }, - }); - self.resolve(lin); - } - } - } - - if let Ty::F64 = self.mapped_types[t.ty()] { - self.duals[var.var()] = Some((Src(None), Src(None))); - } - } - - Expr::Call { id, generics, args } => match self.f.vars[var.var()] { - REAL => { - self.block.fwd.push(Instr { - var, - expr: Expr::Call { - id: *id, - generics: generics.clone(), - args: args.iter().map(|&arg| self.get_prim(arg)).collect(), - }, - }); - self.keep(var); - self.prims[var.var()] = Some(Src(None)); - } - _ => { - let (dep_types, t) = self.deps[id.func()]; - let mut types = vec![]; - for ty in dep_types { - types.push(self.translate(generics, &types, ty)); - } - let t_tup = types[t.ty()]; - - let t_bundle = self.ty(Ty::Tuple { - members: [self.f.vars[var.var()], t_tup].into(), - }); - let bundle = self.fwd_var(t_bundle); - self.block.fwd.push(Instr { - var: bundle, - expr: Expr::Call { - id: *id, - generics: generics.clone(), - args: args.iter().map(|&arg| self.get_re(arg)).collect(), - }, - }); - - self.block.fwd.push(Instr { - var, - expr: Expr::Member { - tuple: bundle, - member: id::member(0), - }, - }); - self.keep(var); - - let inter_fwd = self.fwd_var(t_tup); - let inter_bwd = self.bwd_var(Some(t_tup)); - self.block.fwd.push(Instr { - var: inter_fwd, - expr: Expr::Member { - tuple: bundle, - member: id::member(1), - }, - }); - self.block.bwd_nonlin.push(Instr { - var: inter_bwd, - expr: Expr::Member { - tuple: self.block.inter_tup, - member: id::member(self.block.inter_mem.len()), - }, - }); - self.block.inter_mem.push(inter_fwd); - - let lin = self.accum(var); - let unit = self.bwd_var(Some(self.unit)); - let mut args: Vec<_> = args - .iter() - .map(|&arg| match self.f.types[self.f.vars[arg.var()].ty()] { - Ty::Ref { .. } => self.get_cotan(arg), - _ => self.get_accum(arg), - }) - .collect(); - args.push(lin.cot); - args.push(inter_bwd); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Call { - id: *id, - generics: generics.clone(), - args: args.into(), - }, - }); - self.resolve(lin); - - if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src(None), Src(None))); - } - } - }, - Expr::For { arg, body, ret } => { - let t_index = self.f.vars[arg.var()]; - let t_elem = self.f.vars[ret.var()]; - - let mut block = Block { - fwd: vec![], - inter_mem: vec![], - inter_tup: self.bwd_var(None), - bwd_nonlin: vec![], - bwd_lin: vec![], - }; - swap(&mut self.block, &mut block); - let (t_inter, fwd_inter) = self.block(body); - swap(&mut self.block, &mut block); - - let t_bundle = self.ty(Ty::Tuple { - members: [t_elem, t_inter].into(), - }); - let bundle = self.fwd_var(t_bundle); - block.fwd.push(Instr { - var: bundle, - expr: Expr::Tuple { - members: [self.get_re(*ret), fwd_inter].into(), - }, - }); - let t_arr_bundle = self.ty(Ty::Array { - index: t_index, - elem: t_bundle, - }); - let arr_bundle = self.fwd_var(t_arr_bundle); - self.block.fwd.push(Instr { - var: arr_bundle, - expr: Expr::For { - arg: *arg, - body: block.fwd.into(), - ret: bundle, - }, - }); - let fst_index = self.fwd_var(t_index); - let fst_bundle = self.fwd_var(t_bundle); - let elem = self.fwd_var(t_elem); - self.block.fwd.push(Instr { - var, - expr: Expr::For { - arg: fst_index, - body: [ - Instr { - var: fst_bundle, - expr: Expr::Index { - array: arr_bundle, - index: fst_index, - }, - }, - Instr { - var: elem, - expr: Expr::Member { - tuple: fst_bundle, - member: id::member(0), - }, - }, - ] - .into(), - ret: elem, - }, - }); - self.keep(var); - let t_arr_inter = self.ty(Ty::Array { - index: t_index, - elem: t_inter, - }); - let arr_inter = self.fwd_var(t_arr_inter); - let snd_index = self.fwd_var(t_index); - let snd_bundle = self.fwd_var(t_bundle); - let inter = self.fwd_var(t_inter); - self.block.fwd.push(Instr { - var: arr_inter, - expr: Expr::For { - arg: snd_index, - body: [ - Instr { - var: snd_bundle, - expr: Expr::Index { - array: arr_bundle, - index: snd_index, - }, - }, - Instr { - var: inter, - expr: Expr::Member { - tuple: snd_bundle, - member: id::member(1), - }, - }, - ] - .into(), - ret: inter, - }, - }); - - let arr_inter_bwd = self.bwd_var(Some(t_arr_inter)); - self.block.bwd_nonlin.push(Instr { - var: arr_inter_bwd, - expr: Expr::Member { - tuple: self.block.inter_tup, - member: id::member(self.block.inter_mem.len()), - }, - }); - self.block.inter_mem.push(arr_inter); - - let lin = self.accum(var); - let bwd_acc = self.get_accum(*ret); - let bwd_cot = self.bwd_var(Some(t_elem)); - let mut bwd_body = vec![ - Instr { - var: bwd_cot, - expr: Expr::Index { - array: lin.cot, - index: *arg, - }, - }, - Instr { - var: block.inter_tup, - expr: Expr::Index { - array: arr_inter_bwd, - index: *arg, - }, - }, - ]; - bwd_body.append(&mut block.bwd_nonlin); - let unit = self.bwd_var(Some(self.unit)); - bwd_body.push(Instr { - var: unit, - expr: Expr::Add { - accum: bwd_acc, - addend: bwd_cot, - }, - }); - block.bwd_lin.reverse(); - bwd_body.append(&mut block.bwd_lin); - let bwd_ret = self.bwd_var(Some(self.unit)); - bwd_body.push(Instr { - var: bwd_ret, - expr: Expr::Unit, - }); - let t_arr_unit = self.ty(Ty::Array { - index: t_index, - elem: self.unit, - }); - let arr_unit = self.bwd_var(Some(t_arr_unit)); - self.block.bwd_lin.push(Instr { - var: arr_unit, - expr: Expr::For { - arg: *arg, - body: bwd_body.into(), - ret: bwd_ret, - }, - }); - self.resolve(lin); - } - - &Expr::Accum { shape } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Accum { - shape: self.get_re(shape), - }, - }); - - let cot = self.bwd_var(Some(self.f.vars[shape.var()])); - self.cotans[var.var()] = Some(cot); - self.stack.push(take(&mut self.block.bwd_nonlin)); - } - - &Expr::Add { accum, addend } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Add { - accum, - addend: self.get_re(addend), - }, - }); - - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Unit, - }); - let lin = self.accum(var); - let unit = self.bwd_var(Some(self.unit)); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { - accum: self.get_accum(addend), - addend: self.get_cotan(accum), - }, - }); - self.resolve(lin); - } - - &Expr::Resolve { var: accum } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Resolve { var: accum }, - }); - - let mut bwd_nonlin = replace(&mut self.block.bwd_nonlin, self.stack.pop().unwrap()); - bwd_nonlin.reverse(); - self.block.bwd_lin.append(&mut bwd_nonlin); - let acc = self.bwd_var(Some(self.f.vars[accum.var()])); - self.block.bwd_lin.push(Instr { - var: self.get_cotan(accum), - expr: Expr::Resolve { var: acc }, - }); - self.keep(var); - self.block.bwd_nonlin.push(Instr { - var: acc, - expr: Expr::Accum { shape: var }, - }); - self.accums[var.var()] = Some(acc); - if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src(None), Src(None))); - } - } - } - } -} - -/// Return the forward and backward pass for the transpose of `f`. -pub fn transpose(f: &Func, deps: &[(&[Ty], id::Ty)]) -> (Func, Func) { - let mut bwd_vars: Vec<_> = f.vars.iter().map(|&t| Some(t)).collect(); - let real_shape = id::var(bwd_vars.len()); - bwd_vars.push(Some(DUAL)); - let inter_tup = id::var(bwd_vars.len()); - bwd_vars.push(None); - - let mut tp = Transpose { - f, - deps, - mapped_types: f - .types - .iter() - .enumerate() - .map(|(i, ty)| match ty { - Ty::Unit => Ty::Unit, - Ty::Bool => Ty::Bool, - Ty::F64 => { - if !is_primitive(id::ty(i)) { - panic!() - } - Ty::F64 - } - &Ty::Fin { size } => Ty::Fin { size }, - &Ty::Generic { id } => Ty::Generic { id }, - &Ty::Ref { inner } => { - if is_primitive(inner) { - panic!() - } - Ty::Ref { inner } - } - &Ty::Array { index, elem } => { - if is_primitive(elem) { - panic!() - } - Ty::Array { index, elem } - } - Ty::Tuple { members } => { - if members.iter().any(|&t| is_primitive(t)) { - Ty::F64 - } else { - Ty::Tuple { - members: members.clone(), - } - } - } - }) - .collect(), - types: indexset! { Ty::Unit }, - unit: id::ty(f.types.len()), - fwd_vars: f.vars.to_vec(), - bwd_vars, - real_shape, - prims: vec![None; f.vars.len()].into(), - duals: vec![None; f.vars.len()].into(), - accums: vec![None; f.vars.len()].into(), - cotans: vec![None; f.vars.len()].into(), - stack: vec![], - block: Block { - fwd: vec![], - inter_tup, - inter_mem: vec![], - bwd_nonlin: vec![], - bwd_lin: vec![], - }, - }; - - let mut bwd_params: Vec<_> = f - .params - .iter() - .map(|¶m| { - let t = f.vars[param.var()]; - match &f.types[t.ty()] { - &Ty::Ref { inner } => { - let cot = tp.bwd_var(Some(inner)); - tp.cotans[param.var()] = Some(cot); - cot - } - _ => { - let t_acc = tp.ty(Ty::Ref { inner: t }); - tp.keep(param); - let acc = tp.bwd_var(Some(t_acc)); - if let Ty::F64 = tp.mapped_types[t.ty()] { - tp.duals[param.var()] = Some((Src(None), Src(None))); - } - tp.accums[param.var()] = Some(acc); - acc - } - } - }) - .collect(); - - let (t_intermediates, fwd_inter) = tp.block(&f.body); - let fwd_ret = tp.get_re(f.ret); - let bwd_acc = tp.get_accum(f.ret); - - let mut bwd_types = tp.mapped_types; - bwd_types.extend(tp.types.into_iter()); - - let mut fwd_types = bwd_types.clone(); - let t_bundle = id::ty(fwd_types.len()); - fwd_types.push(Ty::Tuple { - members: [f.vars[f.ret.var()], t_intermediates].into(), - }); - let mut fwd_vars = tp.fwd_vars; - let fwd_bundle = id::var(fwd_vars.len()); - fwd_vars.push(t_bundle); - let mut fwd_body = tp.block.fwd; - fwd_body.push(Instr { - var: fwd_bundle, - expr: Expr::Tuple { - members: [fwd_ret, fwd_inter].into(), - }, - }); - - let mut bwd_vars: Vec<_> = tp.bwd_vars.into_iter().map(|t| t.unwrap()).collect(); - let bwd_cot = id::var(bwd_vars.len()); - bwd_vars.push(f.vars[f.ret.var()]); - let bwd_unit = id::var(bwd_vars.len()); - bwd_vars.push(tp.unit); - bwd_params.push(bwd_cot); - bwd_params.push(tp.block.inter_tup); - let mut bwd_body = vec![Instr { - var: tp.real_shape, - expr: Expr::F64 { val: 0. }, - }]; - bwd_body.append(&mut tp.block.bwd_nonlin); - bwd_body.push(Instr { - var: bwd_unit, - expr: Expr::Add { - accum: bwd_acc, - addend: bwd_cot, - }, - }); - let mut bwd_lin = tp.block.bwd_lin; - bwd_lin.reverse(); - bwd_body.append(&mut bwd_lin); - let bwd_ret = id::var(bwd_vars.len()); // separate var, because `bwd_unit` might not be in scope - bwd_vars.push(tp.unit); - bwd_body.push(Instr { - var: bwd_ret, - expr: Expr::Unit, - }); - - ( - Func { - generics: f.generics.clone(), - types: fwd_types.into(), - vars: fwd_vars.into(), - params: f.params.clone(), - ret: fwd_bundle, - body: fwd_body.into(), - }, - Func { - generics: f.generics.clone(), - types: bwd_types.into(), - vars: bwd_vars.into(), - params: bwd_params.into(), - ret: bwd_ret, - body: bwd_body.into(), - }, - ) -} diff --git a/crates/wasm/Cargo.toml b/crates/wasm/Cargo.toml deleted file mode 100644 index 226e515..0000000 --- a/crates/wasm/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "rose-wasm" -version = "0.4.5" -publish = false -edition = "2021" - -[dependencies] -by_address = "1" -indexmap = "2" -rose = { path = "../core" } -wasm-encoder = "0.33" diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs deleted file mode 100644 index 4107cba..0000000 --- a/crates/wasm/src/lib.rs +++ /dev/null @@ -1,1413 +0,0 @@ -use by_address::ByAddress; -use indexmap::{map::Entry, IndexMap, IndexSet}; -use rose::{id, Binop, Expr, Func, Instr, Node, Refs, Ty, Unop}; -use std::{ - hash::Hash, - mem::{replace, take}, -}; -use wasm_encoder::{ - BlockType, CodeSection, EntityType, ExportSection, Function, FunctionSection, ImportSection, - Instruction, MemArg, MemorySection, MemoryType, Module, TypeSection, ValType, -}; - -/// Resolve `ty` via `generics` and `types`, then return its ID in `typemap`, inserting if need be. -/// -/// This is meant to be used to pull all the types from a callee into a broader context. The -/// `generics` are the IDs of all the types provided as generic type parameters for the callee. The -/// `types are the IDs of all the types that have been pulled in so far. -fn resolve(typemap: &mut IndexSet, generics: &[id::Ty], types: &[id::Ty], ty: &Ty) -> id::Ty { - let resolved = match ty { - Ty::Generic { id } => return generics[id.generic()], - - Ty::Unit => Ty::Unit, - Ty::Bool => Ty::Bool, - Ty::F64 => Ty::F64, - &Ty::Fin { size } => Ty::Fin { size }, - - Ty::Ref { inner } => Ty::Ref { - inner: types[inner.ty()], - }, - Ty::Array { index, elem } => Ty::Array { - index: types[index.ty()], - elem: types[elem.ty()], - }, - Ty::Tuple { members } => Ty::Tuple { - members: members.iter().map(|&x| types[x.ty()]).collect(), - }, - }; - let (i, _) = typemap.insert_full(resolved); - id::ty(i) -} - -/// An index of opaque functions. -/// -/// Each key holds the opaque function itself followed by the generic parameters used for this -/// particular instance. The value is the resolved type signature of the function according to a -/// global type index. -type Imports = IndexMap<(O, Box<[id::Ty]>), (Box<[id::Ty]>, id::Ty)>; - -/// Liveness and aliasing analysis results for a variable. -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -enum Var { - /// The variable is dead. - Dead, - - /// The variable is live, and is not an alias of another variable. - Live, - - /// The variable is live, and is an alias of another variable. - Alias(id::Var), -} - -/// Collection of liveness and aliasing analysis results for a whole function. -struct Vars(Box<[Var]>); - -impl Vars { - /// Return the current analysis for `x`. - fn get(&self, x: id::Var) -> Var { - self.0[x.var()] - } - - /// Mark `x` as live, if it's not already marked as an alias. - fn live(&mut self, x: id::Var) { - if let Var::Dead = self.get(x) { - self.0[x.var()] = Var::Live - } - } - - /// If `y` is live or an alias then mark `x` as live. - fn follow(&mut self, x: id::Var, y: id::Var) { - match self.get(y) { - Var::Dead => {} - Var::Alias(_) | Var::Live => self.live(x), - } - } - - /// Mark `x` as an alias of `y`. - fn alias(&mut self, x: id::Var, y: id::Var) { - self.0[x.var()] = Var::Alias(y); - } -} - -/// An index of transparent functions. -/// -/// Each key holds a reference to the function itself followed by the generic parameters used for -/// this particular instance. The value holds the function's immediate callees (see `rose::Refs`) -/// followed by a mapping from the function's own type indices to resolved type indices in a -/// global type index. -type Funcs<'a, T> = IndexMap<(ByAddress<&'a Func>, Box<[id::Ty]>), (T, Box<[id::Ty]>, Vars)>; - -/// Computes a topological sort of a call graph via depth-first search. -struct Topsort<'a, O, T> { - /// All types seen so far. - types: IndexSet, - - /// All opaque functions seen so far. - imports: Imports, - - /// All transparent functions seen so far, in topological sorted order. - funcs: Funcs<'a, T>, -} - -impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Topsort<'a, O, T> { - /// Return the resolved type for the variable `x` in the function `f`. - /// - /// The `types` argument is the resolved type ID for each of `f.types` in `self.types`. - fn var_ty(&self, f: &'a Func, types: &[id::Ty], x: id::Var) -> &Ty { - &self.types[types[f.vars[x.var()].ty()].ty()] - } - - /// Search in the given `block` of `f`, using `refs` to resolve immediate function calls. - /// - /// The `types` argument is the resolved type ID for each of `f.types` in `self.types`. - fn block(&mut self, refs: &T, f: &'a Func, types: &[id::Ty], vars: &mut Vars, block: &[Instr]) { - for instr in block.iter().rev() { - match &instr.expr { - Expr::Unit | Expr::Bool { .. } | Expr::F64 { .. } | Expr::Fin { .. } => {} - Expr::Array { elems } => { - for &elem in elems.iter() { - vars.follow(elem, instr.var); - } - } - Expr::Tuple { members } => { - for &member in members.iter() { - vars.follow(member, instr.var); - } - } - &Expr::Index { array, index } => { - vars.follow(array, instr.var); - vars.follow(index, instr.var); - } - &Expr::Member { tuple, .. } => vars.follow(tuple, instr.var), - &Expr::Slice { index, .. } => { - vars.live(instr.var); - vars.follow(index, instr.var); - } - &Expr::Field { .. } => vars.live(instr.var), - &Expr::Unary { arg, .. } => vars.follow(arg, instr.var), - &Expr::Binary { left, right, .. } => { - vars.follow(left, instr.var); - vars.follow(right, instr.var); - } - &Expr::Select { cond, then, els } => { - if let Ty::Ref { .. } = self.var_ty(f, types, instr.var) { - vars.live(instr.var); // we simply consider all accumulators to be live - } - vars.follow(then, instr.var); - vars.follow(els, instr.var); - vars.follow(cond, instr.var); - } - Expr::Call { id, generics, args } => { - for arg in args.iter() { - vars.live(*arg); // args always live because callee might have side effects - } - let gens = generics.iter().map(|t| types[t.ty()]).collect(); - match refs.get(*id).unwrap() { - Node::Transparent { refs, def } => { - let key = (ByAddress(def), gens); - if !self.funcs.contains_key(&key) { - let (_, gens) = key; // get back `gens` to please the borrow checker - self.func(refs, def, gens); - } - } - Node::Opaque { def, .. } => { - let resolved = ( - args.iter().map(|x| types[f.vars[x.var()].ty()]).collect(), - types[f.vars[instr.var.var()].ty()], - ); - match self.imports.entry((def, gens)) { - Entry::Occupied(entry) => { - // we should never see the same exact opaque function with the - // same generic type parameters but multiple different type - // signatures - assert_eq!(entry.get(), &resolved); - } - Entry::Vacant(entry) => { - entry.insert(resolved); - } - } - } - } - } - Expr::For { arg, body, ret } => { - // we consider loop collections live because their bodies may have side effects - vars.live(instr.var); - vars.live(*ret); - self.block(refs, f, types, vars, body); - vars.live(*arg); // we need the index to compile the loop itself - } - &Expr::Accum { shape } => { - vars.live(instr.var); // we simply consider all accumulators to be live - match self.var_ty(f, types, shape) { - Ty::F64 => {} - Ty::Unit - | Ty::Bool - | Ty::Fin { .. } - | Ty::Array { .. } - | Ty::Tuple { .. } => { - vars.live(shape) // therefore the shape is always live too - } - Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), - } - } - &Expr::Add { addend, .. } => vars.live(addend), - &Expr::Resolve { var } => vars.alias(instr.var, var), - } - } - } - - /// Search from `def` with the given `generics`, using `refs` to resolve immediate calls. - fn func(&mut self, refs: T, def: &'a Func, generics: Box<[id::Ty]>) { - let mut types = vec![]; - for ty in def.types.iter() { - types.push(resolve(&mut self.types, &generics, &types, ty)); - } - let mut vars = Vars(vec![Var::Dead; def.vars.len()].into()); - vars.live(def.ret); - self.block(&refs, def, &types, &mut vars, &def.body); - for ¶m in def.params.iter() { - if let Ty::Ref { .. } = self.var_ty(def, &types, param) { - vars.live(param); - } - } - let prev = self - .funcs - .insert((ByAddress(def), generics), (refs, types.into(), vars)); - // we're doing depth-first search on a DAG, so even if we wait until this last moment to - // mark the node as visited, we still can't have seen it already - assert!(prev.is_none()); - } -} - -/// A WebAssembly memory offset or size. -type Size = u32; - -/// Convert a `usize` to a `Size`. -/// -/// This will always succeed if the compiler itself is running inside WebAssembly. -fn u_size(x: usize) -> Size { - x.try_into().unwrap() -} - -/// Round up `size` to the nearest multiple of `align`. -fn aligned(size: Size, align: Size) -> Size { - (size + align - 1) & !(align - 1) -} - -/// The layout of a type in memory. -#[derive(Clone, Copy)] -enum Layout { - /// The unit type. Zero-sized. - Unit, - - /// An unsigned 8-bit integer. - U8, - - /// An unsigned 16-bit integer. - U16, - - /// An unsigned 32-bit integer. - U32, - - /// A 64-bit floating-point number. - F64, - - /// `Ty::Ref` cannot be stored in memory. - Ref, -} - -impl Layout { - /// Return the size and alignment of this `Layout`, in bytes. - fn size_align(self) -> (Size, Size) { - match self { - Self::Unit => (0, 1), - Self::U8 => (1, 1), - Self::U16 => (2, 2), - Self::U32 => (4, 4), - Self::F64 => (8, 8), - Self::Ref => unreachable!(), - } - } - - /// Return the size of this `Layout`, which is always aligned. - fn size(self) -> Size { - let (size, _) = self.size_align(); - size // no need to use alignment, because every possible `Layout` size is already aligned - } - - /// Emit a load instruction for this layout with the given byte offset. - fn load(self, function: &mut Function, offset: Size) { - let offset = offset.into(); - match self { - Self::Unit => { - function.instruction(&Instruction::Drop); - function.instruction(&Instruction::I32Const(0)); - } - Self::U8 => { - function.instruction(&Instruction::I32Load8U(MemArg { - offset, - align: 0, - memory_index: 0, - })); - } - Self::U16 => { - function.instruction(&Instruction::I32Load16U(MemArg { - offset, - align: 1, - memory_index: 0, - })); - } - Self::U32 => { - function.instruction(&Instruction::I32Load(MemArg { - offset, - align: 2, - memory_index: 0, - })); - } - Self::F64 => { - function.instruction(&Instruction::F64Load(MemArg { - offset, - align: 3, - memory_index: 0, - })); - } - Self::Ref => unreachable!(), - } - } - - /// Emit a store instruction for this layout with the given byte offset. - fn store(self, function: &mut Function, offset: Size) { - let offset = offset.into(); - match self { - Self::Unit => { - function.instruction(&Instruction::Drop); - function.instruction(&Instruction::Drop); - } - Self::U8 => { - function.instruction(&Instruction::I32Store8(MemArg { - offset, - align: 0, - memory_index: 0, - })); - } - Self::U16 => { - function.instruction(&Instruction::I32Store16(MemArg { - offset, - align: 1, - memory_index: 0, - })); - } - Self::U32 => { - function.instruction(&Instruction::I32Store(MemArg { - offset, - align: 2, - memory_index: 0, - })); - } - Self::F64 => { - function.instruction(&Instruction::F64Store(MemArg { - offset, - align: 3, - memory_index: 0, - })); - } - Self::Ref => unreachable!(), - } - } -} - -/// The index of a WebAssembly local. -type Local = u32; - -/// Information about a type that has functions for accumulation. -#[derive(Clone, Copy)] -struct Accum { - /// The ID of the zero function. - zero: u32, - - /// The allocation cost of the zero function. - cost: Size, - - // The ID of the add function, which has no allocation cost. - add: u32, -} - -/// Information about a type that is necessary for code generation. -struct Meta { - /// The type. - ty: Ty, - - /// The layout of the type. - layout: Layout, - - /// Zero and add functions for accumulation, if this type is an array or tuple. - accum: Option, - - /// Offsets of each member of a tuple. - members: Option>, -} - -/// Return the WebAssembly value type used to represent a local for the type with global ID `t`. -/// -/// The second component of the pair is `true` iff a parameter with this result should be switched -/// to a result instead; the only case where this happens is if the type is a `Ref` containing an -/// `F64`. -fn val_type(metas: &[Meta], t: id::Ty) -> (ValType, bool) { - match metas[t.ty()].ty { - Ty::Unit | Ty::Bool | Ty::Fin { .. } | Ty::Array { .. } | Ty::Tuple { .. } => { - (ValType::I32, false) - } - Ty::F64 => (ValType::F64, false), - Ty::Generic { .. } => unreachable!(), - Ty::Ref { inner } => { - let (vt, _) = val_type(metas, inner); - (vt, vt == ValType::F64) - } - } -} - -/// Generates WebAssembly code for a function. -struct Codegen<'a, 'b, O, T> { - /// Metadata about all the types in the global type index. - metas: &'b [Meta], - - /// All opaque functions. - imports: &'b Imports, - - /// The number of opaque functions plus the number of accumulation functions (zeros and adds). - extras: usize, - - /// All transparent functions. - funcs: &'b Funcs<'a, T>, - - /// The allocation cost of each transparent function. - costs: &'b [Size], - - /// To resolve calls. - refs: &'b T, - - /// The definition of the particular function we're generating code for. - def: &'a Func, - - /// Mapping from this function's type indices to type indices in the global type index. - types: &'b [id::Ty], - - /// The WebAssembly local assigned to the stack pointer. - pointer: Local, - - /// The WebAssembly local assigned to each live variable in this function. - locals: &'b [Option], - - /// The amount of memory allocated so far in the current block. - /// - /// This is for the block and not the entire function, because for instance, a loop's total - /// allocation cost depends both on its block's allocation cost and on the number of iterations. - offset: Size, - - /// Stack of pending accumulator instructions to process at the end of each scope. - /// - /// The bottom element is always an empty vector, because after we process the entire function, - /// we call `resolve` again, which always pops this stack. - stack: Vec>, - - /// Pending accumulator instructions to process at the end of this scope, in reverse order. - unresolved: Vec<&'a Instr>, - - /// The WebAssembly function under construction. - wasm: Function, -} - -impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { - /// Return metadata for the type ID `t` in the current function. - /// - /// Do not use this if your type ID is already resolved to refer to the global type index. - fn meta(&self, t: id::Ty) -> &'b Meta { - &self.metas[self.types[t.ty()].ty()] - } - - /// Return true iff `x` is assigned a WebAssembly local. - fn live(&self, x: id::Var) -> bool { - self.locals[x.var()].is_some() - } - - /// Emit an instruction to push the value of `x` onto the stack. - fn get(&mut self, x: id::Var) { - self.wasm - .instruction(&Instruction::LocalGet(self.locals[x.var()].unwrap())); - } - - /// Emit an instruction to pop the top of the stack and store it in `x`. - fn set(&mut self, x: id::Var) { - self.wasm - .instruction(&Instruction::LocalSet(self.locals[x.var()].unwrap())); - } - - /// Emit an instruction to store the stack top in `x` without popping it. - fn tee(&mut self, x: id::Var) { - self.wasm - .instruction(&Instruction::LocalTee(self.locals[x.var()].unwrap())); - } - - /// Emit an instruction to push the current memory allocation pointer onto the stack. - fn pointer(&mut self) { - self.wasm.instruction(&Instruction::LocalGet(self.pointer)); - } - - /// Emit an instruction to push the constant integer value `x` onto the stack. - fn u32_const(&mut self, x: u32) { - self.wasm.instruction(&Instruction::I32Const(x as i32)); - } - - /// Emit instructions to increase the memory allocation pointer by `size` bytes. - fn bump(&mut self, size: Size) { - let aligned = aligned(size, 8); - self.pointer(); - self.u32_const(aligned); - self.wasm.instruction(&Instruction::I32Add); - self.wasm.instruction(&Instruction::LocalSet(self.pointer)); - self.offset += aligned; - } - - /// Emit instruction(s) to load a value with the given `layout` and `offset`. - fn load(&mut self, layout: Layout, offset: Size) { - layout.load(&mut self.wasm, offset) - } - - /// Emit instruction(s) to store a value with the given `layout` and `offset`. - fn store(&mut self, layout: Layout, offset: Size) { - layout.store(&mut self.wasm, offset) - } - - /// Pop the top of the scope stack and process all pending accumulator instructions. - fn resolve(&mut self) { - let unresolved = replace(&mut self.unresolved, self.stack.pop().unwrap()); - for instr in unresolved.into_iter().rev() { - match instr.expr { - Expr::Slice { array, index } => { - let layout = Layout::F64; - self.get(array); - self.get(index); - self.u32_const(layout.size()); - self.wasm.instruction(&Instruction::I32Mul); - self.wasm.instruction(&Instruction::I32Add); - self.get(array); - self.get(index); - self.u32_const(layout.size()); - // TODO: avoid recalculating the offset - self.wasm.instruction(&Instruction::I32Mul); - self.wasm.instruction(&Instruction::I32Add); - self.load(layout, 0); - self.get(instr.var); - self.wasm.instruction(&Instruction::F64Add); - self.store(layout, 0); - } - Expr::Field { tuple, member } => { - let Meta { members, .. } = - self.meta(match self.def.types[self.def.vars[tuple.var()].ty()] { - Ty::Ref { inner } => inner, - _ => unreachable!(), - }); - let offset = members.as_ref().unwrap()[member.member()]; - let layout = Layout::F64; - self.get(tuple); - self.get(tuple); - self.load(layout, offset); - self.get(instr.var); - self.wasm.instruction(&Instruction::F64Add); - self.store(layout, offset); - } - Expr::Select { cond, then, els } => { - self.get(cond); - self.wasm.instruction(&Instruction::If(BlockType::Empty)); - self.get(then); - self.get(instr.var); - self.wasm.instruction(&Instruction::F64Add); - self.set(then); - self.wasm.instruction(&Instruction::Else); - self.get(els); - self.get(instr.var); - self.wasm.instruction(&Instruction::F64Add); - self.set(els); - self.wasm.instruction(&Instruction::End); - } - _ => unreachable!(), - } - } - } - - /// Generate code for the given `block`. - fn block(&mut self, block: &'a [Instr]) { - for instr in block.iter() { - match &instr.expr { - Expr::Unit => {} - &Expr::Bool { val } => { - if self.live(instr.var) { - self.wasm.instruction(&Instruction::I32Const(val.into())); - self.set(instr.var); - } - } - &Expr::F64 { val } => { - if self.live(instr.var) { - self.wasm.instruction(&Instruction::F64Const(val)); - self.set(instr.var); - } - } - &Expr::Fin { val } => { - if self.live(instr.var) { - self.wasm - .instruction(&Instruction::I32Const(val.try_into().unwrap())); - self.set(instr.var); - } - } - Expr::Array { elems } => { - if self.live(instr.var) { - let &Meta { layout, .. } = - self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { - Ty::Array { elem, .. } => elem, - _ => unreachable!(), - }); - let size = layout.size(); - for (i, &elem) in elems.iter().enumerate() { - self.pointer(); - self.get(elem); - self.store(layout, size * u_size(i)); - } - self.pointer(); - self.set(instr.var); - self.bump(size * u_size(elems.len())); - } - } - Expr::Tuple { members } => { - if self.live(instr.var) { - let Meta { members: mems, .. } = self.meta(self.def.vars[instr.var.var()]); - let mut size = 0; - for (&member, &offset) in members.iter().zip(mems.as_ref().unwrap().iter()) - { - let &Meta { layout, .. } = self.meta(self.def.vars[member.var()]); - self.pointer(); - self.get(member); - self.store(layout, offset); - size = size.max(offset + layout.size()); - } - self.pointer(); - self.set(instr.var); - self.bump(size); - } - } - &Expr::Index { array, index } => { - if self.live(instr.var) { - let &Meta { layout, .. } = self.meta(self.def.vars[instr.var.var()]); - let size = layout.size(); - self.get(array); - self.get(index); - self.u32_const(size); - self.wasm.instruction(&Instruction::I32Mul); - self.wasm.instruction(&Instruction::I32Add); - self.load(layout, 0); - self.set(instr.var); - } - } - &Expr::Member { tuple, member } => { - if self.live(instr.var) { - let Meta { members, .. } = self.meta(self.def.vars[tuple.var()]); - let offset = members.as_ref().unwrap()[member.member()]; - let &Meta { layout, .. } = self.meta(self.def.vars[instr.var.var()]); - self.get(tuple); - self.load(layout, offset); - self.set(instr.var); - } - } - &Expr::Slice { array, index } => { - // we simply consider all accumulators to be live - let meta = - self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { - Ty::Ref { inner } => inner, - _ => unreachable!(), - }); - match meta.ty { - Ty::F64 => { - // need to explicitly set to zero because accumulators can be modified - // and this initialization might be inside of a loop - self.wasm.instruction(&Instruction::F64Const(0.)); - self.unresolved.push(instr); - } - _ => { - let size = meta.layout.size(); - self.get(array); - self.get(index); - self.u32_const(size); - self.wasm.instruction(&Instruction::I32Mul); - self.wasm.instruction(&Instruction::I32Add); - self.load(meta.layout, 0); - } - } - self.set(instr.var); - } - &Expr::Field { tuple, member } => { - // we simply consider all accumulators to be live - let meta = - self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { - Ty::Ref { inner } => inner, - _ => unreachable!(), - }); - match meta.ty { - Ty::F64 => { - // need to explicitly set to zero because accumulators can be modified - // and this initialization might be inside of a loop - self.wasm.instruction(&Instruction::F64Const(0.)); - self.unresolved.push(instr); - } - _ => { - let Meta { members, .. } = - self.meta(match self.def.types[self.def.vars[tuple.var()].ty()] { - Ty::Ref { inner } => inner, - _ => unreachable!(), - }); - let offset = members.as_ref().unwrap()[member.member()]; - self.get(tuple); - self.load(meta.layout, offset); - } - } - self.set(instr.var); - } - &Expr::Unary { op, arg } => { - if self.live(instr.var) { - match op { - Unop::Not => { - self.get(arg); - self.wasm.instruction(&Instruction::I32Eqz); - } - Unop::Neg => { - self.get(arg); - self.wasm.instruction(&Instruction::F64Neg); - } - Unop::Abs => { - self.get(arg); - self.wasm.instruction(&Instruction::F64Abs); - } - Unop::Sign => { - // TODO: `f64.const` instructions are always 8 bytes, much larger - // than most instructions; maybe we should just keep this constant - // in a local - self.wasm.instruction(&Instruction::F64Const(1.)); - self.get(arg); - self.wasm.instruction(&Instruction::F64Copysign); - } - Unop::Ceil => { - self.get(arg); - self.wasm.instruction(&Instruction::F64Ceil); - } - Unop::Floor => { - self.get(arg); - self.wasm.instruction(&Instruction::F64Floor); - } - Unop::Trunc => { - self.get(arg); - self.wasm.instruction(&Instruction::F64Trunc); - } - Unop::Sqrt => { - self.get(arg); - self.wasm.instruction(&Instruction::F64Sqrt); - } - } - self.set(instr.var); - } - } - &Expr::Binary { op, left, right } => { - if self.live(instr.var) { - self.get(left); - self.get(right); - match op { - Binop::And => self.wasm.instruction(&Instruction::I32And), - Binop::Or => self.wasm.instruction(&Instruction::I32Or), - Binop::Iff => self.wasm.instruction(&Instruction::I32Eq), - Binop::Xor => self.wasm.instruction(&Instruction::I32Xor), - Binop::Neq => self.wasm.instruction(&Instruction::F64Ne), - Binop::Lt => self.wasm.instruction(&Instruction::F64Lt), - Binop::Leq => self.wasm.instruction(&Instruction::F64Le), - Binop::Eq => self.wasm.instruction(&Instruction::F64Eq), - Binop::Gt => self.wasm.instruction(&Instruction::F64Gt), - Binop::Geq => self.wasm.instruction(&Instruction::F64Ge), - Binop::Add => self.wasm.instruction(&Instruction::F64Add), - Binop::Sub => self.wasm.instruction(&Instruction::F64Sub), - Binop::Mul => self.wasm.instruction(&Instruction::F64Mul), - Binop::Div => self.wasm.instruction(&Instruction::F64Div), - }; - self.set(instr.var); - } - } - &Expr::Select { cond, then, els } => { - match self.def.types[self.def.vars[instr.var.var()].ty()] { - Ty::Ref { inner } if self.meta(inner).ty == Ty::F64 => { - // need to explicitly set to zero because accumulators can be modified - // and this initialization might be inside of a loop - self.wasm.instruction(&Instruction::F64Const(0.)); - self.set(instr.var); - self.unresolved.push(instr); - } - _ if self.live(instr.var) => { - self.get(then); - self.get(els); - self.get(cond); - self.wasm.instruction(&Instruction::Select); - self.set(instr.var); - } - _ => {} - } - } - Expr::Call { id, generics, args } => { - // we simply consider all calls to be live because they could have side effects - let gens = generics - .iter() - .map(|t| self.types[self.def.vars[t.ty()].ty()]) - .collect(); - for &arg in args.iter() { - match self.def.types[self.def.vars[arg.var()].ty()] { - // `F64` accumulators become results, not params - Ty::Ref { inner } if self.meta(inner).ty == Ty::F64 => {} - _ => self.get(arg), - } - } - let i = match self.refs.get(*id).unwrap() { - Node::Transparent { def, .. } => { - self.pointer(); - let j = self.funcs.get_index_of(&(ByAddress(def), gens)).unwrap(); - self.bump(self.costs[j]); - self.extras + j - } - Node::Opaque { def, .. } => { - self.imports.get_index_of(&(def, gens)).unwrap() - } - }; - self.wasm - .instruction(&Instruction::Call(i.try_into().unwrap())); - for &arg in args.iter().rev() { - match &self.def.types[self.def.vars[arg.var()].ty()] { - &Ty::Ref { inner } if self.meta(inner).ty == Ty::F64 => { - // `F64` accumulators became results - self.get(arg); - self.wasm.instruction(&Instruction::F64Add); - self.set(arg); - } - _ => {} - } - } - if self.live(instr.var) { - self.set(instr.var); - } else { - self.wasm.instruction(&Instruction::Drop); - } - } - Expr::For { arg, body, ret } => { - let n = u_size(match self.meta(self.def.vars[arg.var()]).ty { - Ty::Fin { size } => size, - _ => unreachable!(), - }); - let &Meta { layout, .. } = self.meta(self.def.vars[ret.var()]); - let size = layout.size(); - - // we need to set the local now rather than later, because we're going to bump - // the pointer for the array itself and possibly in the loop body, but we still - // need to know this pointer so we can use it to store each element of the array - self.pointer(); - self.set(instr.var); - - // we put the bounds check at the end of the loop, so if it's going to execute - // zero times then we need to make sure not to enter it at all; easiest way is - // to just not emit the loop instructions at all - if n > 0 { - self.bump(size * n); - let offset = take(&mut self.offset); - - self.wasm.instruction(&Instruction::I32Const(0)); - self.set(*arg); - self.wasm.instruction(&Instruction::Loop(BlockType::Empty)); - - self.stack.push(take(&mut self.unresolved)); - self.block(body); - self.resolve(); - - self.get(instr.var); - self.get(*arg); - self.u32_const(size); - self.wasm.instruction(&Instruction::I32Mul); - self.wasm.instruction(&Instruction::I32Add); - self.get(*ret); - self.store(layout, 0); - - self.get(*arg); - self.wasm.instruction(&Instruction::I32Const(1)); - self.wasm.instruction(&Instruction::I32Add); - self.tee(*arg); - self.u32_const(n); - self.wasm.instruction(&Instruction::I32LtU); - self.wasm.instruction(&Instruction::BrIf(0)); - self.wasm.instruction(&Instruction::End); - - self.offset = offset + self.offset * n; - } - } - &Expr::Accum { shape } => { - // we simply consider all accumulators to be live - let meta = self.meta(self.def.vars[shape.var()]); - match &meta.ty { - Ty::Unit | Ty::Bool | Ty::Fin { .. } => self.get(shape), - Ty::F64 => { - // need to explicitly set to zero because accumulators can be modified - // and this initialization might be inside of a loop - self.wasm.instruction(&Instruction::F64Const(0.)); - } - Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), - Ty::Array { .. } | Ty::Tuple { .. } => { - let Accum { zero, cost, .. } = meta.accum.unwrap(); - self.pointer(); - self.get(shape); - self.wasm.instruction(&Instruction::Call(zero)); - self.pointer(); - self.bump(cost); - } - } - self.set(instr.var); - self.stack.push(take(&mut self.unresolved)); - } - &Expr::Add { accum, addend } => { - let meta = self.meta(self.def.vars[addend.var()]); - match &meta.ty { - Ty::Unit | Ty::Bool | Ty::Fin { .. } => {} - Ty::F64 => { - self.get(accum); - self.get(addend); - self.wasm.instruction(&Instruction::F64Add); - self.set(accum); - } - Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), - Ty::Array { .. } | Ty::Tuple { .. } => { - self.get(accum); - self.get(addend); - self.wasm - .instruction(&Instruction::Call(meta.accum.unwrap().add)); - } - } - } - &Expr::Resolve { var } => { - self.resolve(); - assert_eq!(self.locals[instr.var.var()], self.locals[var.var()]); - } - } - } - } -} - -/// A WebAssembly module for a graph of functions. -/// -/// The module exports its memory with name `"m"` and its entrypoint function with name `"f"`. The -/// function takes one parameter in addition to its original parameters, which must be an -/// 8-byte-aligned pointer to the start of the memory region it can use for allocation. The memory -/// is the exact number of pages necessary to accommodate the function's own memory allocation as -/// well as memory allocation for all of its parameters, with each node in each parameter's memory -/// allocation tree being 8-byte aligned. That is, the function's last argument should be just large -/// enough to accommodate those allocations for all the parameters; in that case, no memory will be -/// incorrectly overwritten and no out-of-bounds memory accesses will occur. -pub struct Wasm { - /// The bytes of the WebAssembly module binary. - pub bytes: Vec, - - /// All the opaque functions that the WebAssembly module must import, in order. - /// - /// The module name for each import is the empty string, and the field name is the base-ten - /// representation of its index in this collection. - pub imports: Imports, -} - -/// Compile `f` and all its direct and indirect callees to a WebAssembly module. -pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> Wasm { - let mut topsort = Topsort { - types: IndexSet::new(), - imports: IndexMap::new(), - funcs: IndexMap::new(), - }; - match f { - Node::Transparent { refs, def } => { - topsort.func(refs, def, [].into()); - } - Node::Opaque { - types, - params, - ret, - def, - .. - } => { - // if `f` itself is an opaque function then the graph of all callees has only one node - let mut def_types = vec![]; - for ty in types.iter() { - def_types.push(resolve(&mut topsort.types, &[], &def_types, ty)); - } - topsort.imports.insert( - (def, [].into()), - ( - params.iter().map(|t| def_types[t.ty()]).collect(), - def_types[ret.ty()], - ), - ); - } - } - let Topsort { - types, - imports, - funcs, - } = topsort; - - // we add to this lazily as we generate our imports, functions, and code, after which we'll - // generate the actual function types section right near the end; it doesn't matter as long as - // the order we actually add the sections to the module itself is correct - type Signature = (Box<[ValType]>, Box<[ValType]>); - let mut func_types: IndexSet = IndexSet::new(); - let (accum_type_index, _) = - func_types.insert_full(([ValType::I32, ValType::I32].into(), [].into())); - - let mut import_section = ImportSection::new(); - for (i, (params, ret)) in imports.values().enumerate() { - // short for `ValType` - let vt = |t: id::Ty| match types[t.ty()] { - Ty::F64 => ValType::F64, - _ => unreachable!(), - }; - let (type_index, _) = - func_types.insert_full((params.iter().map(|&t| vt(t)).collect(), [vt(*ret)].into())); - // we reserve type index zero for the type with two `i32` params and no results, which we - // use for accumulation zero and add functions; we don't include that in the `func_types` - // index itself, because that index only holds function types with exactly one result - import_section.import( - "", - &i.to_string(), - EntityType::Function(type_index.try_into().unwrap()), - ); - } - - let mut function_section = FunctionSection::new(); - let mut code_section = CodeSection::new(); - - let mut metas: Vec = vec![]; - let mut extras: usize = imports.len(); - for ty in types.into_iter() { - let (layout, cost, members) = match &ty { - Ty::Unit => (Layout::Unit, None, None), - Ty::Bool => (Layout::U8, None, None), - Ty::F64 => (Layout::F64, None, None), - &Ty::Fin { size } => ( - if size <= 1 { - Layout::Unit - } else if size <= 256 { - Layout::U8 - } else if size <= 65536 { - Layout::U16 - } else { - Layout::U32 - }, - None, - None, - ), - Ty::Generic { .. } => unreachable!(), - Ty::Ref { .. } => (Layout::Ref, None, None), - Ty::Array { index, elem } => { - let n = u_size(match metas[index.ty()].ty { - Ty::Fin { size } => size, - _ => unreachable!(), - }); - let meta = &metas[elem.ty()]; - let size = meta.layout.size(); - - // for both the zero function and the add function, the first parameter is a pointer - // to the accumulator value, and the second parameter is the pointer to the other - // value (the shape for zero, or the addend for add) - - // the first local is a pointer to the end of the accumulator array, used for bounds - // checking; the second local is a memory allocation pointer, used as the - // accumulator pointer for calls to the zero function for elements if this array - // stores composite values - let mut zero = Function::new([(2, ValType::I32)]); - let bound = size * n; // use this instead of padded `total` for bounds checking - let mut total = aligned(bound, 8); - // same as zero, the local is a pointer to the end of the accumulator array, used - // for bounds checking - let mut add = Function::new([(1, ValType::I32)]); - - if n > 0 { - zero.instruction(&Instruction::LocalGet(0)); - zero.instruction(&Instruction::I32Const(bound.try_into().unwrap())); - zero.instruction(&Instruction::I32Add); - zero.instruction(&Instruction::LocalSet(2)); - zero.instruction(&Instruction::LocalGet(0)); - zero.instruction(&Instruction::I32Const(total.try_into().unwrap())); - zero.instruction(&Instruction::I32Add); - zero.instruction(&Instruction::LocalSet(3)); - zero.instruction(&Instruction::Loop(BlockType::Empty)); - - add.instruction(&Instruction::LocalGet(0)); - add.instruction(&Instruction::I32Const(bound.try_into().unwrap())); - add.instruction(&Instruction::I32Add); - add.instruction(&Instruction::LocalSet(2)); - add.instruction(&Instruction::Loop(BlockType::Empty)); - - match &meta.ty { - Ty::Unit => {} - Ty::Bool | Ty::Fin { .. } => { - zero.instruction(&Instruction::LocalGet(0)); - zero.instruction(&Instruction::LocalGet(1)); - meta.layout.load(&mut zero, 0); - meta.layout.store(&mut zero, 0); - } - Ty::F64 => { - zero.instruction(&Instruction::LocalGet(0)); - // TODO: `f64.const` instructions are always 8 bytes, much larger than - // most instructions; maybe we should just keep this constant in a local - zero.instruction(&Instruction::F64Const(0.)); - meta.layout.store(&mut zero, 0); - - add.instruction(&Instruction::LocalGet(0)); - add.instruction(&Instruction::LocalGet(0)); - meta.layout.load(&mut add, 0); - add.instruction(&Instruction::LocalGet(1)); - meta.layout.load(&mut add, 0); - add.instruction(&Instruction::F64Add); - meta.layout.store(&mut add, 0); - } - Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), - Ty::Array { .. } | Ty::Tuple { .. } => { - let accum = meta.accum.unwrap(); - let cost = accum.cost; - - zero.instruction(&Instruction::LocalGet(0)); - zero.instruction(&Instruction::LocalGet(3)); - meta.layout.store(&mut zero, 0); - zero.instruction(&Instruction::LocalGet(3)); - zero.instruction(&Instruction::LocalGet(1)); - meta.layout.load(&mut zero, 0); - zero.instruction(&Instruction::Call(accum.zero)); - zero.instruction(&Instruction::LocalGet(3)); - zero.instruction(&Instruction::I32Const(cost.try_into().unwrap())); - zero.instruction(&Instruction::I32Add); - zero.instruction(&Instruction::LocalSet(3)); - - total += cost * n; - - add.instruction(&Instruction::LocalGet(0)); - meta.layout.load(&mut add, 0); - add.instruction(&Instruction::LocalGet(1)); - meta.layout.load(&mut add, 0); - add.instruction(&Instruction::Call(accum.add)); - } - } - - zero.instruction(&Instruction::LocalGet(1)); - zero.instruction(&Instruction::I32Const(size.try_into().unwrap())); - zero.instruction(&Instruction::I32Add); - zero.instruction(&Instruction::LocalSet(1)); - zero.instruction(&Instruction::LocalGet(0)); - zero.instruction(&Instruction::I32Const(size.try_into().unwrap())); - zero.instruction(&Instruction::I32Add); - zero.instruction(&Instruction::LocalTee(0)); - zero.instruction(&Instruction::LocalGet(2)); - zero.instruction(&Instruction::I32LtU); - zero.instruction(&Instruction::BrIf(0)); - zero.instruction(&Instruction::End); - - add.instruction(&Instruction::LocalGet(1)); - add.instruction(&Instruction::I32Const(size.try_into().unwrap())); - add.instruction(&Instruction::I32Add); - add.instruction(&Instruction::LocalSet(1)); - add.instruction(&Instruction::LocalGet(0)); - add.instruction(&Instruction::I32Const(size.try_into().unwrap())); - add.instruction(&Instruction::I32Add); - add.instruction(&Instruction::LocalTee(0)); - add.instruction(&Instruction::LocalGet(2)); - add.instruction(&Instruction::I32LtU); - add.instruction(&Instruction::BrIf(0)); - add.instruction(&Instruction::End); - } - - zero.instruction(&Instruction::End); - code_section.function(&zero); - - add.instruction(&Instruction::End); - code_section.function(&add); - - (Layout::U32, Some(total), None) - } - Ty::Tuple { members } => { - let mut mems: Vec<_> = members - .iter() - .enumerate() - .map(|(i, t)| { - let Meta { layout, .. } = metas[t.ty()]; - let (size, align) = layout.size_align(); - (i, size, align) - }) - .collect(); - mems.sort_by_key(|&(_, _, align)| align); - let mut offsets = vec![0; members.len()]; - let mut offset = 0; - for (i, s, a) in mems { - offset = aligned(offset, a); - offsets[i] = offset; - offset += s; - } - - // the local is a memory allocation pointer, used as the accumulator pointer for - // calls to the zero function for composite elements of the tuple - let mut zero = Function::new([(1, ValType::I32)]); - let mut total = aligned(offset, 8); - let mut add = Function::new([]); - - for (member, &offset) in members.iter().zip(offsets.iter()) { - let meta = &metas[member.ty()]; - - match &meta.ty { - Ty::Unit => {} - Ty::Bool | Ty::Fin { .. } => { - zero.instruction(&Instruction::LocalGet(0)); - zero.instruction(&Instruction::LocalGet(1)); - meta.layout.load(&mut zero, offset); - meta.layout.store(&mut zero, offset); - } - Ty::F64 => { - zero.instruction(&Instruction::LocalGet(0)); - // TODO: `f64.const` instructions are always 8 bytes, much larger than - // most instructions; maybe we should just keep this constant in a local - zero.instruction(&Instruction::F64Const(0.)); - meta.layout.store(&mut zero, offset); - - add.instruction(&Instruction::LocalGet(0)); - add.instruction(&Instruction::LocalGet(0)); - meta.layout.load(&mut add, offset); - add.instruction(&Instruction::LocalGet(1)); - meta.layout.load(&mut add, offset); - add.instruction(&Instruction::F64Add); - meta.layout.store(&mut add, offset); - } - Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), - Ty::Array { .. } | Ty::Tuple { .. } => { - let accum = meta.accum.unwrap(); - let cost = accum.cost; - - zero.instruction(&Instruction::LocalGet(0)); - zero.instruction(&Instruction::LocalGet(0)); - zero.instruction(&Instruction::I32Const(total.try_into().unwrap())); - zero.instruction(&Instruction::I32Add); - zero.instruction(&Instruction::LocalTee(2)); - meta.layout.store(&mut zero, offset); - zero.instruction(&Instruction::LocalGet(2)); - zero.instruction(&Instruction::LocalGet(1)); - meta.layout.load(&mut zero, offset); - zero.instruction(&Instruction::Call(accum.zero)); - - total += cost; - - add.instruction(&Instruction::LocalGet(0)); - meta.layout.load(&mut add, offset); - add.instruction(&Instruction::LocalGet(1)); - meta.layout.load(&mut add, offset); - add.instruction(&Instruction::Call(accum.add)); - } - } - } - - zero.instruction(&Instruction::End); - code_section.function(&zero); - - add.instruction(&Instruction::End); - code_section.function(&add); - - (Layout::U32, Some(total), Some(offsets.into())) - } - }; - metas.push(Meta { - ty, - layout, - accum: cost.map(|cost| { - let zero = extras.try_into().unwrap(); - function_section.function(accum_type_index.try_into().unwrap()); - let add = (extras + 1).try_into().unwrap(); - function_section.function(accum_type_index.try_into().unwrap()); - extras += 2; - Accum { zero, cost, add } - }), - members, - }); - } - - let mut costs = vec![]; // allocation cost of each function, in bytes - for ((def, _), (refs, def_types, vars)) in funcs.iter() { - let vt = |t: id::Ty| val_type(&metas, def_types[t.ty()]); // short for `ValType` - let mut locals = vec![None; def.vars.len()]; - - let (ret_ty, _) = vt(def.vars[def.ret.var()]); - let mut params = vec![]; - let mut results = vec![ret_ty]; - for param in def.params.iter() { - let (val_ty, result) = vt(def.vars[param.var()]); - if result { - results.push(val_ty); - } else { - locals[param.var()] = Some(params.len().try_into().unwrap()); - params.push(val_ty); - } - } - params.push(ValType::I32); // extra pointer parameter - let num_params: u32 = params.len().try_into().unwrap(); - let (type_index, _) = func_types.insert_full((params.into(), results.into())); - function_section.function(type_index.try_into().unwrap()); - - let mut i32s = 0; - for (i, &t) in def.vars.iter().enumerate() { - if let (None, (ValType::I32, _), Var::Live) = (locals[i], vt(t), vars.get(id::var(i))) { - locals[i] = Some(num_params + i32s); - i32s += 1; - } - } - let mut f64s = 0; - for (i, &t) in def.vars.iter().enumerate() { - if let (None, (ValType::F64, _), Var::Live) = (locals[i], vt(t), vars.get(id::var(i))) { - locals[i] = Some(num_params + i32s + f64s); - f64s += 1; - } - } - for (i, var) in vars.0.iter().enumerate() { - if let Var::Alias(other) = var { - locals[i] = locals[other.var()]; - } - } - - let mut codegen = Codegen { - metas: &metas, - imports: &imports, - extras, - funcs: &funcs, - costs: &costs, - refs, - def, - types: def_types, - pointer: num_params - 1, - locals: &locals, - offset: 0, - stack: vec![vec![]], - unresolved: vec![], - wasm: Function::new([(i32s, ValType::I32), (f64s, ValType::F64)]), - }; - // accumulator result variables are automatically zero: https://stackoverflow.com/a/77170544 - codegen.block(&def.body); - codegen.resolve(); - codegen.get(def.ret); - for ¶m in def.params.iter() { - if let (_, true) = vt(def.vars[param.var()]) { - // return the accumulator variables we moved from params to results - codegen.get(param); - } - } - codegen.wasm.instruction(&Instruction::End); - code_section.function(&codegen.wasm); - costs.push(codegen.offset); - } - - let mut type_section = TypeSection::new(); - for (params, results) in func_types { - type_section.function(params.into_vec(), results.into_vec()); - } - - let mut memory_section = MemorySection::new(); - let page_size = 65536; // https://webassembly.github.io/spec/core/exec/runtime.html#page-size - let cost = funcs.last().map_or(0, |((def, _), (_, def_types, _))| { - def.params - .iter() - .filter_map(|param| metas[def_types[def.vars[param.var()].ty()].ty()].accum) - .map(|accum| accum.cost) - .sum() - }) + costs.last().unwrap_or(&0); - let pages = ((cost + page_size - 1) / page_size).into(); // round up to a whole number of pages - memory_section.memory(MemoryType { - minimum: pages, - maximum: Some(pages), - memory64: false, - shared: false, - }); - - let mut export_section = ExportSection::new(); - export_section.export( - "f", - wasm_encoder::ExportKind::Func, - (extras + funcs.len() - 1).try_into().unwrap(), - ); - export_section.export("m", wasm_encoder::ExportKind::Memory, 0); - - let mut module = Module::new(); - module.section(&type_section); - module.section(&import_section); - module.section(&function_section); - module.section(&memory_section); - module.section(&export_section); - module.section(&code_section); - Wasm { - bytes: module.finish(), - imports, - } -} diff --git a/crates/web/Cargo.toml b/crates/web/Cargo.toml deleted file mode 100644 index 3d813ec..0000000 --- a/crates/web/Cargo.toml +++ /dev/null @@ -1,28 +0,0 @@ -[package] -name = "rose-web" -version = "0.4.5" -publish = false -edition = "2021" - -[lib] -crate-type = ["cdylib"] - -[dependencies] -by_address = "1" -console_error_panic_hook = { version = "0.1", optional = true } -console_log = { version = "1", optional = true } -enumset = "1" -indexmap = "2" -js-sys = "0.3" -rose = { path = "../core" } -rose-autodiff = { path = "../autodiff" } -rose-interp = { path = "../interp", features = ["serde"] } -rose-transpose = { path = "../transpose" } -rose-wasm = { path = "../wasm" } -serde = { version = "1", features = ["derive"] } -serde-wasm-bindgen = "0.4" -wasm-bindgen = "=0.2.87" # Must be this version of wbg - -[features] -default = ["debug"] -debug = ["dep:console_error_panic_hook", "dep:console_log"] diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs deleted file mode 100644 index 34623c8..0000000 --- a/crates/web/src/lib.rs +++ /dev/null @@ -1,1600 +0,0 @@ -#[cfg(feature = "debug")] -mod pprint; - -use by_address::ByAddress; -use enumset::EnumSet; -use indexmap::{IndexMap, IndexSet}; -use rose::id; -use serde::Serialize; -use std::{ - cell::RefCell, - rc::{Rc, Weak}, -}; -use wasm_bindgen::prelude::{wasm_bindgen, JsError, JsValue}; - -#[cfg(feature = "debug")] -#[wasm_bindgen] -pub fn initialize() { - std::panic::set_hook(Box::new(console_error_panic_hook::hook)); - console_log::init().unwrap(); -} - -fn to_js_value(value: &impl Serialize) -> Result { - value.serialize(&serde_wasm_bindgen::Serializer::json_compatible()) -} - -// for regression testing purposes only -#[cfg(feature = "debug")] -#[wasm_bindgen] -pub fn layouts() -> Result { - #[derive(Serialize)] - struct Layout { - size: usize, - align: usize, - } - - fn layout() -> Layout { - Layout { - size: std::mem::size_of::(), - align: std::mem::align_of::(), - } - } - - to_js_value(&[ - ("Expr", layout::()), - ("Func", layout::()), - ("Instr", layout::()), - ("Ty", layout::()), - ("Val", layout::()), - ]) -} - -/// Clone `x` into JavaScript. -fn val_to_js(x: &rose_interp::Val) -> JsValue { - match x { - rose_interp::Val::F64(x) => JsValue::from_f64(x.get()), - _ => todo!(), - } -} - -/// Reference to an opaque function that just points to a JavaScript function as its implementation. -#[derive(Clone, Copy, Eq, Hash, PartialEq)] -struct Opaque<'a> { - f: ByAddress<&'a js_sys::Function>, -} - -impl rose_interp::Opaque for Opaque<'_> { - fn call( - &self, - _: &IndexSet, - _: &[id::Ty], - args: &[rose_interp::Val], - ) -> rose_interp::Val { - let context = &JsValue::UNDEFINED; - // we only support functions with a small number of `F64` parameters that return `F64` - rose_interp::val_f64( - match args.len() { - 0 => self.f.call0(context), - 1 => self.f.call1(context, &val_to_js(&args[0])), - 2 => self - .f - .call2(context, &val_to_js(&args[0]), &val_to_js(&args[1])), - 3 => self.f.call3( - context, - &val_to_js(&args[0]), - &val_to_js(&args[1]), - &val_to_js(&args[2]), - ), - _ => todo!(), - } - .unwrap() - .as_f64() - .unwrap(), - ) - } -} - -/// Essentially an owned version of `rose::Node`. -enum Inner { - Transparent { - deps: Box<[Func]>, - def: rose::Func, - }, - Opaque { - generics: Box<[EnumSet]>, - types: Box<[rose::Ty]>, - params: Box<[id::Ty]>, - ret: id::Ty, - def: js_sys::Function, - }, -} - -/// Reference to a slice of function nodes, representing dependencies of a function. -struct Refs<'a> { - deps: &'a [Func], -} - -impl<'a> rose::Refs<'a> for Refs<'a> { - type Opaque = Opaque<'a>; - - fn get(&self, id: id::Func) -> Option, Self>> { - self.deps.get(id.func()).map(|f| f.node()) - } -} - -struct Pointee { - inner: Inner, - - /// Indices for string keys on tuple types that represent structs. - /// - /// The actual strings are stored in JavaScript. - structs: Box<[Option>]>, - - /// Jacobian-vector product. - jvp: RefCell>>, - - /// Forward pass of the vector-Jacobian product. - fwd: RefCell>>, - - /// Backward pass of the vector-Jacobian product. - bwd: RefCell>>, -} - -/// A node in a reference-counted acyclic digraph of functions. -#[wasm_bindgen] -#[derive(Clone)] -pub struct Func { - rc: Rc, -} - -#[cfg(feature = "debug")] -impl std::fmt::Display for Func { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - pprint::write_graph(f, self.node()) - } -} - -#[wasm_bindgen] -impl Func { - /// Return an opaque function taking `params` `F64` parameters and returning `F64`. - #[wasm_bindgen(constructor)] - pub fn new(params: usize, def: js_sys::Function) -> Self { - Self { - rc: Rc::new(Pointee { - inner: Inner::Opaque { - generics: [].into(), - types: [rose::Ty::F64].into(), - params: vec![id::ty(0); params].into(), - ret: id::ty(0), - def, - }, - structs: [].into(), - jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), - }), - } - } - - /// Construct a function node from the data this `Func` points to. - fn node(&self) -> rose::Node { - let Pointee { inner, .. } = self.rc.as_ref(); - match inner { - Inner::Transparent { deps, def } => rose::Node::Transparent { - refs: Refs { deps }, - def, - }, - Inner::Opaque { - generics, - types, - params, - ret, - def, - } => rose::Node::Opaque { - generics, - types, - params, - ret: *ret, - def: Opaque { f: ByAddress(def) }, - }, - } - } - - #[cfg(feature = "debug")] - #[wasm_bindgen] - pub fn pprint(&self) -> String { - format!("{self}") - } - - /// Return the IDs of this function's parameter types. - #[wasm_bindgen(js_name = "paramTypes")] - pub fn param_types(&self) -> Box<[usize]> { - let Pointee { inner, .. } = self.rc.as_ref(); - match inner { - Inner::Transparent { def, .. } => { - def.params.iter().map(|p| def.vars[p.var()].ty()).collect() - } - Inner::Opaque { params, .. } => params.iter().map(|p| p.ty()).collect(), - } - } - - /// Return the ID of this function's return type. - #[wasm_bindgen(js_name = "retType")] - pub fn ret_type(&self) -> usize { - let Pointee { inner, .. } = self.rc.as_ref(); - match inner { - Inner::Transparent { def, .. } => def.vars[def.ret.var()].ty(), - Inner::Opaque { ret, .. } => ret.ty(), - } - } - - /// Return the number of types defined in this function. - #[wasm_bindgen(js_name = "numTypes")] - pub fn num_types(&self) -> usize { - match &self.rc.as_ref().inner { - Inner::Transparent { def, .. } => def.types.len(), - Inner::Opaque { types, .. } => types.len(), - } - } - - /// Return the type with ID `t`, if it exists. - fn ty(&self, t: usize) -> Option<&rose::Ty> { - match &self.rc.as_ref().inner { - Inner::Transparent { def, .. } => def.types.get(t), - Inner::Opaque { types, .. } => types.get(t), - } - } - - /// Return true iff `t` is the ID of a unit type. - #[wasm_bindgen(js_name = "isUnit")] - pub fn is_unit(&self, t: usize) -> bool { - matches!(self.ty(t), Some(rose::Ty::Unit)) - } - - /// Return true iff `t` is the ID of a boolean type. - #[wasm_bindgen(js_name = "isBool")] - pub fn is_bool(&self, t: usize) -> bool { - matches!(self.ty(t), Some(rose::Ty::Bool)) - } - - /// Return true iff `t` is the ID of a 64-bit floating-point type. - #[wasm_bindgen(js_name = "isF64")] - pub fn is_f64(&self, t: usize) -> bool { - matches!(self.ty(t), Some(rose::Ty::F64)) - } - - /// Return true iff `t` is the ID of a finite integer type. - #[wasm_bindgen(js_name = "isFin")] - pub fn is_fin(&self, t: usize) -> bool { - matches!(self.ty(t), Some(rose::Ty::Fin { .. })) - } - - /// Return true iff `t` is the ID of an array type. - #[wasm_bindgen(js_name = "isArray")] - pub fn is_array(&self, t: usize) -> bool { - matches!(self.ty(t), Some(rose::Ty::Array { .. })) - } - - /// Return true iff `t` is the ID of a struct type. - #[wasm_bindgen(js_name = "isStruct")] - pub fn is_struct(&self, t: usize) -> bool { - self.rc.as_ref().structs[t].is_some() - } - - /// Return the size of the finite integer type with ID `t`. - pub fn size(&self, t: usize) -> usize { - match self.ty(t).unwrap() { - &rose::Ty::Fin { size } => size, - _ => panic!("not a finite integer"), - } - } - - /// Return the ID of the index type for the array type with ID `t`. - pub fn index(&self, t: usize) -> usize { - match self.ty(t).unwrap() { - rose::Ty::Array { index, elem: _ } => index.ty(), - _ => panic!("not an array"), - } - } - - /// Return the ID of the element type for the array type with ID `t`. - pub fn elem(&self, t: usize) -> usize { - match self.ty(t).unwrap() { - rose::Ty::Array { index: _, elem } => elem.ty(), - _ => panic!("not an array"), - } - } - - /// Return the string IDs for the struct type with ID `t`. - pub fn keys(&self, t: usize) -> Box<[usize]> { - self.rc.as_ref().structs[t].as_ref().unwrap().clone() - } - - /// Return the member type IDs for the struct type with ID `t`. - pub fn mems(&self, t: usize) -> Box<[usize]> { - match self.ty(t).unwrap() { - rose::Ty::Tuple { members } => members.iter().map(|m| m.ty()).collect(), - _ => panic!("not a struct"), - } - } - - /// Interpret a function with no generics or parameters. - /// - /// The `args` are Serde-converted to `Vec`, and the return value is - /// Serde-converted from `rose_interp::Val`. - pub fn interp(&self, args: JsValue) -> Result { - let vals: Vec = serde_wasm_bindgen::from_value(args)?; - let ret = rose_interp::interp(self.node(), IndexSet::new(), &[], vals.into_iter())?; - Ok(to_js_value(&ret)?) - } - - /// Compile the call graph subtended by this function to WebAssembly. - pub fn compile(&self) -> Wasm { - let rose_wasm::Wasm { bytes, imports } = rose_wasm::compile(self.node()); - Wasm { - bytes: Some(bytes), - imports: Some( - imports - .into_keys() - .map(|(Opaque { f }, _)| (*f).clone()) - .collect(), - ), - } - } - - /// Set the JVP of this function to `f`. - #[wasm_bindgen(js_name = "setJvp")] - pub fn set_jvp(&self, f: &Func) { - self.rc.as_ref().jvp.replace(Some(Rc::clone(&f.rc))); - } - - /// Return a function that computes the Jacobian-vector product of this function. - /// - /// `re` must be the string ID for the string `"re"` not just in this function, but in every - /// function that this function calls, and so on, transitively. Same for `du` and `"du"`. - pub fn jvp(&self, re: usize, du: usize) -> Self { - let Pointee { - inner, - structs, - jvp, - .. - } = self.rc.as_ref(); - if let Some(rc) = jvp.borrow().as_ref().map(Rc::clone) { - return Self { rc }; - } - let rc = - match inner { - Inner::Transparent { deps, def } => { - let mut structs_jvp = vec![None, None]; - // the first two types are the two new versions of `F64`; all the other types - // are just mapped one-to-one, except that previous versions of `F64` become - // tuples, so for those we use the string IDs we have been given - structs_jvp.extend(structs.iter().enumerate().map(|(i, s)| { - match &def.types[i] { - rose::Ty::F64 => Some([du, re].into()), - _ => s.clone(), - } - })); - Rc::new(Pointee { - inner: Inner::Transparent { - deps: deps.iter().map(|f| f.jvp(re, du)).collect(), - def: rose_autodiff::jvp(def), - }, - structs: structs_jvp.into(), - jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), - }) - } - Inner::Opaque { .. } => panic!("no JVP provided for opaque function"), - }; - jvp.replace(Some(Rc::clone(&rc))); - Self { rc } - } - - /// Return the forward and backward pass of the transpose of this function. - fn transpose_pair(&self) -> (Self, Self) { - let Pointee { - inner, - structs, - fwd, - bwd, - .. - } = self.rc.as_ref(); - if let (Some(rc_fwd), Some(rc_bwd)) = ( - fwd.borrow().as_ref().and_then(|weak| weak.upgrade()), - bwd.borrow().as_ref().and_then(|weak| weak.upgrade()), - ) { - return (Self { rc: rc_fwd }, Self { rc: rc_bwd }); - } - let (rc_fwd, rc_bwd) = match inner { - Inner::Transparent { deps, def } => { - if let rose::Ty::F64 = def.types[def.vars[def.ret.var()].ty()] { - return (self.clone(), self.clone()); - } - let (deps_fwd, deps_bwd): (Vec<_>, Vec<_>) = - deps.iter().map(|f| f.transpose_pair()).unzip(); - let dep_types: Box<_> = deps_fwd - .iter() - .map(|f| match &f.rc.as_ref().inner { - Inner::Transparent { def, .. } => { - (def.types.as_ref(), def.vars[def.ret.var()]) - } - Inner::Opaque { types, ret, .. } => (types.as_ref(), *ret), - }) - .collect(); - let (def_fwd, def_bwd) = rose_transpose::transpose(def, &dep_types); - let structs_fwd = def_fwd - .types - .iter() - .enumerate() - .map(|(i, ty)| match ty { - rose::Ty::F64 => None, - _ => structs.get(i).cloned().flatten(), - }) - .collect(); - let structs_bwd = def_bwd - .types - .iter() - .enumerate() - .map(|(i, ty)| match ty { - rose::Ty::F64 => None, - _ => structs.get(i).cloned().flatten(), - }) - .collect(); - ( - Rc::new(Pointee { - inner: Inner::Transparent { - deps: deps_fwd.into(), - def: def_fwd, - }, - structs: structs_fwd, - jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), - }), - Rc::new(Pointee { - inner: Inner::Transparent { - deps: deps_bwd.into(), - def: def_bwd, - }, - structs: structs_bwd, - jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), - }), - ) - } - Inner::Opaque { .. } => (Rc::clone(&self.rc), (Rc::clone(&self.rc))), - }; - fwd.replace(Some(Rc::downgrade(&rc_fwd))); - bwd.replace(Some(Rc::downgrade(&rc_bwd))); - (Self { rc: rc_fwd }, Self { rc: rc_bwd }) - } - - /// Return the transpose of this function. - /// - /// Assumes that this function has already been computed as the `jvp` of another function. - pub fn transpose(&self) -> Transpose { - let (fwd, bwd) = self.transpose_pair(); - Transpose { - fwd: Some(fwd), - bwd: Some(bwd), - } - } -} - -/// A temporary object to hold a generated WebAssembly module and its imports. -#[wasm_bindgen] -pub struct Wasm { - bytes: Option>, - imports: Option>, -} - -#[wasm_bindgen] -impl Wasm { - /// Return the module binary. - pub fn bytes(&mut self) -> Option> { - self.bytes.take() - } - - /// Return the imports. - pub fn imports(&mut self) -> Option> { - self.imports.take() - } -} - -/// A temporary object to hold the two passes of a transposed function before they are destructured. -#[wasm_bindgen] -pub struct Transpose { - fwd: Option, - bwd: Option, -} - -#[wasm_bindgen] -impl Transpose { - /// Return the forward pass. - pub fn fwd(&mut self) -> Option { - self.fwd.take() - } - - /// Return the backward pass. - pub fn bwd(&mut self) -> Option { - self.bwd.take() - } -} - -/// A type, with key name information in the case of tuples (which thus become structs). -#[derive(Clone, Debug, Eq, Hash, PartialEq)] -enum Ty { - Unit, - Bool, - F64, - T64, - Fin { - size: usize, - }, - Ref { - inner: id::Ty, - }, - Array { - index: id::Ty, - elem: id::Ty, - }, - - /// A tuple type, with additional information about key names that makes it into a struct. - Struct { - /// String IDs for key names, in order; the actual strings are stored in JavaScript. - keys: Option>, - - /// Member types of the underlying tuple. Must be the same length as `keys`. - members: Box<[id::Ty]>, - }, -} - -impl Ty { - /// Split this augmented type into an actual `rose::Ty` and any additional struct information. - fn separate(self) -> (rose::Ty, Option>) { - match self { - Ty::Unit => (rose::Ty::Unit, None), - Ty::Bool => (rose::Ty::Bool, None), - Ty::F64 => (rose::Ty::F64, None), - Ty::T64 => (rose::Ty::F64, None), - Ty::Fin { size } => (rose::Ty::Fin { size }, None), - Ty::Ref { inner } => (rose::Ty::Ref { inner }, None), - Ty::Array { index, elem } => (rose::Ty::Array { index, elem }, None), - Ty::Struct { keys, members } => (rose::Ty::Tuple { members }, keys), - } - } -} - -/// Metadata about a variable while its containing function is still under construction. -enum Extra { - /// Does not depend on any of the function's parameters. - Constant, - - /// Part of the main function body; these are definitions for variables that depend on it. - Parent(Vec), - - /// Depends on a `Parent` variable; others can depend on it only indirectly through its parent. - Child(id::Var), - - /// Is no longer in scope. - Expired, -} - -struct Var { - t: id::Ty, - extra: Extra, -} - -/// A function under construction. -#[wasm_bindgen] -pub struct FuncBuilder { - /// Called functions. More can be added as the function is built. - functions: Vec, - - /// Constraints on generic type parameters. These are fixed when the `FuncBuilder` is started. - generics: Box<[EnumSet]>, - - /// Index of types, with constraints tracked for validation (e.g. array index must be `Index`). - types: IndexMap>, - - /// Variable types, scopes (expired or not?), and dependent definitions. - vars: Vec, - - /// Parameters, in order. Typically added all at once right after the `FuncBuilder` is started. - params: Vec, - - /// Definitions that don't depend on parameters (but may depend on each other), in order. - constants: Vec, -} - -#[wasm_bindgen] -impl FuncBuilder { - /// Start building a function with the given number of `generics`, all constrained as `Index`. - #[wasm_bindgen(constructor)] - pub fn new(generics: usize) -> Self { - let mut types = IndexMap::new(); - types.insert(Ty::F64, EnumSet::only(rose::Constraint::Value)); - types.insert(Ty::T64, EnumSet::only(rose::Constraint::Value)); - Self { - functions: vec![], - generics: vec![EnumSet::only(rose::Constraint::Index); generics].into(), - types, - vars: vec![], - params: vec![], - constants: vec![], - } - } - - /// Assemble this function with return variable `out` and the given `body`. - pub fn finish(mut self, out: usize, body: Block) -> Func { - // We replace `self.params` and `self.constants` with empty vec because we need to satisfy - // the borrow checker when we pass `self` to `body.finish` below; this is OK though, because - // `Block::finish` is guaranteed not to use either `self.params` or `self.constants`. - let params = std::mem::take(&mut self.params).into_boxed_slice(); - let mut code = std::mem::take(&mut self.constants); - for &x in params.iter() { - self.extra(x, &mut code); - } - body.finish(&mut self, &mut code); - let (types, structs): (Vec<_>, Vec<_>) = - self.types.into_keys().map(|ty| ty.separate()).unzip(); - Func { - rc: Rc::new(Pointee { - inner: Inner::Transparent { - deps: self.functions.into(), - def: rose::Func { - generics: self.generics, - types: types.into(), - vars: self.vars.into_iter().map(|x| x.t).collect(), - params, - ret: id::var(out), - body: code.into(), - }, - }, - structs: structs.into(), - jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), - }), - } - } - - /// Finalize `x`, appending its dependencies onto `code` and marking it and them as expired. - /// - /// Must not use `self.params` or `self.constants`. - fn extra(&mut self, x: id::Var, code: &mut Vec) { - match std::mem::replace(&mut self.vars[x.var()].extra, Extra::Expired) { - Extra::Parent(extra) => { - for instr in extra.iter() { - self.vars[instr.var.var()].extra = Extra::Expired; - } - code.extend(extra); - } - Extra::Constant | Extra::Child(_) | Extra::Expired => unreachable!(), - } - } - - /// Should the type with ID `t` be represented as a JavaScript `Symbol`? - /// - /// Values of index types must be symbols so that we can use the standard JavaScript indexing - /// notation (along with `Proxy`, see below) to generate array accessing code. - #[wasm_bindgen(js_name = "isSymbol")] - pub fn is_symbol(&self, t: usize) -> bool { - let (ty, _) = self.types.get_index(t).unwrap(); - matches!(ty, Ty::Fin { .. }) - } - - /// Is the type with ID `t` an array that should be represented as a JavaScript `Proxy`? - /// - /// Values of array types must be proxies so that we can use the standard JavaScript indexing - /// notation (along with `Symbol`, see above) to generate array accessing code. - #[wasm_bindgen(js_name = "isArray")] - pub fn is_array(&self, t: usize) -> bool { - let (ty, _) = self.types.get_index(t).unwrap(); - matches!(ty, Ty::Array { .. }) - } - - /// Is the type with ID `t` a struct that should be represented as a JavaScript `Proxy`? - /// - /// Values of struct types must be proxies so that we can use the standard JavaScript property - /// access notation to generate member accessing code. - #[wasm_bindgen(js_name = "isStruct")] - pub fn is_struct(&self, t: usize) -> bool { - let (ty, _) = self.types.get_index(t).unwrap(); - matches!(ty, Ty::Struct { .. }) - } - - /// Return a reference to the type with ID `t` if it exists, `Err` otherwise. - /// - /// This returns a `Result` with `JsError` because it is a helper method meant to be used by - /// `pub` methods exposed to JavaScript; don't prefer it for more Rusty stuff, because `JsError` - /// doesn't implement `Debug` so you can't easily call `Result::unwrap` here. - fn ty(&self, t: usize) -> Result<&Ty, JsError> { - match self.types.get_index(t) { - None => Err(JsError::new("type does not exist")), - Some((ty, _)) => Ok(ty), - } - } - - /// Return the ID of the index type for the array type with ID `t`. - /// - /// `Err` if `t` is out of range or does not represent an array type. - pub fn index(&self, t: usize) -> Result { - match self.ty(t)? { - &Ty::Array { index, elem: _ } => Ok(index.ty()), - _ => Err(JsError::new("type is not an array")), - } - } - - /// Return the number of elements for the array type with ID `t`. - /// - /// `Err` if `t` is out of range or does not represent an array type, or if its index type is - /// not a fixed size (e.g. if it is a generic type parameter of the function). - pub fn size(&self, t: usize) -> Result { - match self.ty(t)? { - &Ty::Array { index, elem: _ } => { - let (i, _) = self.types.get_index(index.ty()).unwrap(); - match i { - &Ty::Fin { size } => Ok(size), - _ => Err(JsError::new("index type is not a fixed size")), - } - } - _ => Err(JsError::new("type is not an array")), - } - } - - /// Return the ID of the element type for the array type with ID `t`. - /// - /// `Err` if `t` is out of range or does not represent an array type. - pub fn elem(&self, t: usize) -> Result { - match self.ty(t)? { - &Ty::Array { index: _, elem } => Ok(elem.ty()), - _ => Err(JsError::new("type is not an array")), - } - } - - /// Return the string IDs of the keys for the struct type with ID `t`. - /// - /// `Err` if `t` is out of range or does not represent a struct type. - pub fn keys(&self, t: usize) -> Result, JsError> { - match self.ty(t)? { - Ty::Struct { - keys: Some(keys), - members: _, - } => Ok(keys.clone()), - _ => Err(JsError::new("type is not a struct")), - } - } - - /// Return the type IDs of the members for the struct type with ID `t`. - /// - /// `Err` if `t` is out of range or does not represent a struct type. - pub fn members(&self, t: usize) -> Result, JsError> { - match self.ty(t)? { - Ty::Struct { keys: _, members } => Ok(members.iter().map(|t| t.ty()).collect()), - _ => Err(JsError::new("type is not a struct")), - } - } - - /// Return `x` if it exists, is in scope, and has type ID `t`; `Err` otherwise. - pub fn expect(&self, t: usize, x: usize) -> Result { - match self.vars.get(x) { - None => Err(JsError::new("variable does not exist")), - Some(var) => match var.extra { - Extra::Expired => Err(JsError::new("variable is out of scope")), - _ => { - if var.t == id::ty(t) { - Ok(x) - } else { - Err(JsError::new("variable type mismatch")) - } - } - }, - } - } - - /// Return the type ID for `ty`, creating if needed, and marking its constraints as `constrs`. - fn newtype(&mut self, ty: Ty, constrs: EnumSet) -> usize { - let (i, _) = self.types.insert_full(ty, constrs); - i - } - - /// Create a new non-constant, non-child variable with type ID `t`, and return its ID. - /// - /// This method should only be used for variables that are about to be directly defined as part - /// of a `Block`, not for constants or any literals that get attached to other variables. - fn newvar(&mut self, t: id::Ty) -> id::Var { - let id = self.vars.len(); - self.vars.push(Var { - t, - extra: Extra::Parent(vec![]), - }); - id::var(id) - } - - /// Return the ID for the unit type, creating if needed. - #[wasm_bindgen(js_name = "tyUnit")] - pub fn ty_unit(&mut self) -> usize { - self.newtype(Ty::Unit, EnumSet::only(rose::Constraint::Value)) - } - - /// Return the ID for the boolean type, creating if needed. - #[wasm_bindgen(js_name = "tyBool")] - pub fn ty_bool(&mut self) -> usize { - self.newtype(Ty::Bool, EnumSet::only(rose::Constraint::Value)) - } - - /// Return the ID for the 64-bit floating-point type, creating if needed. - #[wasm_bindgen(js_name = "tyF64")] - pub fn ty_f64(&mut self) -> usize { - 0 - } - - /// Return the ID for the 64-bit floating-point tangent type, creating if needed. - #[wasm_bindgen(js_name = "tyT64")] - pub fn ty_t64(&mut self) -> usize { - 1 - } - - /// Return the ID for the type of nonnegative integers less than `size`, creating if needed. - #[wasm_bindgen(js_name = "tyFin")] - pub fn ty_fin(&mut self, size: usize) -> usize { - self.newtype( - Ty::Fin { size }, - rose::Constraint::Value | rose::Constraint::Index, - ) - } - - #[wasm_bindgen(js_name = "tyRef")] - pub fn ty_ref(&mut self, inner: usize) -> usize { - self.newtype( - Ty::Ref { - inner: id::ty(inner), - }, - EnumSet::empty(), - ) - } - - /// Return the ID for the type of arrays with index type `index` and element type `elem`, - /// - /// Assumes `index` and `elem` are valid type IDs. - #[wasm_bindgen(js_name = "tyArray")] - pub fn ty_array(&mut self, index: usize, elem: usize) -> Result { - let (_, constrs) = self.types.get_index(index).unwrap(); - // If we support non-`Value` types then we should also check that `elem` satisfies `Value`. - if constrs.contains(rose::Constraint::Index) { - Ok(self.newtype( - Ty::Array { - index: id::ty(index), - elem: id::ty(elem), - }, - EnumSet::only(rose::Constraint::Value), - )) - } else { - Err(JsError::new("index type cannot be used as an index")) - } - } - - /// Return the ID fr the type of structs with key string IDs `keys` and member type IDs `mems`. - /// - /// Assumes `keys` are valid string IDs and `mems` are valid type IDs. - #[wasm_bindgen(js_name = "tyStruct")] - pub fn ty_struct(&mut self, keys: &[usize], mems: &[usize]) -> usize { - self.newtype( - Ty::Struct { - keys: Some(keys.into()), - members: mems.iter().map(|&t| id::ty(t)).collect(), - }, - EnumSet::only(rose::Constraint::Value), - ) - } - - /// Return the ID of a new variable with type ID `t`. - pub fn bind(&mut self, t: usize) -> usize { - self.newvar(id::ty(t)).var() - } - - /// Append a parameter with type ID `t` and return its variable ID. - pub fn param(&mut self, t: usize) -> usize { - let x = self.newvar(id::ty(t)); - self.params.push(x); - x.var() - } - - /// Append a constant with type ID `t` and definition `expr`, and return its variable ID. - fn constant(&mut self, t: usize, expr: rose::Expr) -> usize { - let x = self.vars.len(); - self.vars.push(Var { - t: id::ty(t), - extra: Extra::Constant, - }); - self.constants.push(rose::Instr { - var: id::var(x), - expr, - }); - x - } - - /// Create a constant variable with the unit type, and return its ID. - /// - /// `Err` if `t` is not the ID of the unit type. - pub fn unit(&mut self, t: usize) -> Result { - if t == self.ty_unit() { - Ok(self.constant(t, rose::Expr::Unit)) - } else { - Err(JsError::new("did not expect null")) - } - } - - /// Create a constant variable with the boolean type and value `val`, and return its ID. - /// - /// `Err` if `t` is not the ID of the boolean type. - pub fn bool(&mut self, t: usize, val: bool) -> Result { - if t == self.ty_bool() { - Ok(self.constant(t, rose::Expr::Bool { val })) - } else { - Err(JsError::new("did not expect boolean")) - } - } - - /// Return the ID of a new numeric constant variable with type `t` and value converted from `x`. - /// - /// `Err` unless `t` is the ID of either the 64-bit floating-point type or a finite nonnegative - /// integer type; or if `t` is an integer type which cannot represent the given value of `x`. - pub fn num(&mut self, t: usize, x: f64) -> Result { - match self.ty(t)? { - Ty::F64 | Ty::T64 => Ok(self.constant(t, rose::Expr::F64 { val: x })), - &Ty::Fin { size } => { - let y = x as usize; - if y as f64 != x { - Err(JsError::new("can't be represented by an unsigned integer")) - } else if y >= size { - Err(JsError::new("out of range")) - } else { - Ok(self.constant(t, rose::Expr::Fin { val: y })) - } - } - _ => Err(JsError::new("type is not numeric")), - } - } - - /// Return the ID of a new variable with type ID `t` and value `expr`, depending on `xs`. - /// - /// Assumes that `t` is a valid type ID and `xs` are all valid variable IDs. If they all have - /// `Extra::Constant` then the new variable is a constant; otherwise, it is attached to - /// whichever parent variable reachable from `xs` has the highest ID. - fn attach(&mut self, t: usize, xs: &[usize], expr: rose::Expr) -> usize { - match xs - .iter() - .filter_map(|&x| match self.vars[x].extra { - Extra::Constant => None, - Extra::Parent(_) => Some(x), - Extra::Child(y) => Some(y.var()), - Extra::Expired => unreachable!(), - }) - .max() - { - None => self.constant(t, expr), - Some(x) => { - let y = self.vars.len(); - self.vars.push(Var { - t: id::ty(t), - extra: Extra::Child(id::var(x)), - }); - match &mut self.vars[x].extra { - Extra::Parent(instrs) => instrs.push(rose::Instr { - var: id::var(y), - expr, - }), - _ => unreachable!(), - } - y - } - } - } - - /// Return the ID of a new array variable with type ID `t` and elements `xs`. - /// - /// Assumes that `t` is a valid type ID and `xs` are all valid variable IDs. If there are no - /// dependencies on parameters then the new array variable is a constant; otherwise, it is - /// attached to whichever parent variable reachable from `xs` has the highest ID. - pub fn array(&mut self, t: usize, xs: &[usize]) -> usize { - let elems = xs.iter().map(|&x| id::var(x)).collect(); - let expr = rose::Expr::Array { elems }; - self.attach(t, xs, expr) - } - - /// Return the ID of a new struct variable with type ID `t` and elements `xs`. - /// - /// Assumes that `t` is a valid type ID and `xs` are all valid variable IDs. If there are no - /// dependencies on parameters then the new struct variable is a constant; otherwise, it is - /// attached to whichever parent variable reachable from `xs` has the highest ID. - pub fn obj(&mut self, t: usize, xs: &[usize]) -> usize { - let members = xs.iter().map(|&x| id::var(x)).collect(); - let expr = rose::Expr::Tuple { members }; - self.attach(t, xs, expr) - } - - /// Resolve `ty` via `generics` and `types`, then return its ID in `typemap`, inserting if need - /// be. - /// - /// This is meant to be used to pull all the types from a callee into a broader context. The - /// `generics` are the IDs of all the types provided as generic type parameters for the callee. - /// The `types are the IDs of all the types that have been pulled in so far. - fn resolve( - &mut self, - generics: &[usize], - strings: &[usize], - structs: &[Option>], - types: &[id::Ty], - t: usize, - ty: &rose::Ty, - ) -> id::Ty { - let (deduped, constrs) = match ty { - rose::Ty::Generic { id } => return id::ty(generics[id.generic()]), - - rose::Ty::Unit => (Ty::Unit, EnumSet::only(rose::Constraint::Value)), - rose::Ty::Bool => (Ty::Bool, EnumSet::only(rose::Constraint::Value)), - rose::Ty::F64 => (Ty::F64, EnumSet::only(rose::Constraint::Value)), - &rose::Ty::Fin { size } => ( - Ty::Fin { size }, - rose::Constraint::Value | rose::Constraint::Index, - ), - - rose::Ty::Ref { inner } => ( - Ty::Ref { - inner: types[inner.ty()], - }, - EnumSet::empty(), - ), - rose::Ty::Array { index, elem } => ( - Ty::Array { - index: types[index.ty()], - elem: types[elem.ty()], - }, - EnumSet::only(rose::Constraint::Value), - ), - rose::Ty::Tuple { members } => ( - Ty::Struct { - keys: structs[t] - .as_ref() - .map(|ss| ss.iter().map(|&s| strings[s]).collect()), - members: members.iter().map(|x| types[x.ty()]).collect(), - }, - EnumSet::only(rose::Constraint::Value), - ), - }; - let (i, _) = self.types.insert_full(deduped, constrs); - id::ty(i) - } - - /// Return the parameter and return type IDs in this function for calling `f` with `generics`. - /// - /// Assumes `generics` are all valid type IDs, and have the right length and constraints. - /// - /// The returned `Vec` is always nonempty, since its last element is the return type; all the - /// other elements are the parameter types. - pub fn ingest(&mut self, f: &Func, strings: &[usize], generics: &[usize]) -> Vec { - let Pointee { inner, structs, .. } = f.rc.as_ref(); - let def = match inner { - Inner::Transparent { def, .. } => def, - Inner::Opaque { params, .. } => { - // we currently only allow opaque functions of few `F64` parameters returning `F64` - let t = self.ty_f64(); - return vec![t; params.len() + 1]; - } - }; - let mut types = vec![]; - // push a corresponding type onto our own `types` for each type in the callee - for (t, ty) in def.types.iter().enumerate() { - types.push(self.resolve(generics, strings, structs, &types, t, ty)); - } - - let mut sig: Vec<_> = def - .params - .iter() - .map(|x| types[def.vars[x.var()].ty()].ty()) - .collect(); - sig.push(types[def.vars[def.ret.var()].ty()].ty()); - sig - } - - /// Return the type IDs for `left` and `right`, checking that they are defined and in scope. - fn get_lr(&self, left: usize, right: usize) -> Result<(id::Ty, id::Ty), JsError> { - let x = self - .vars - .get(left) - .ok_or_else(|| JsError::new("left is undefined"))?; - let y = self - .vars - .get(right) - .ok_or_else(|| JsError::new("right is undefined"))?; - if let Extra::Expired = x.extra { - return Err(JsError::new("left is out of scope")); - } - if let Extra::Expired = y.extra { - return Err(JsError::new("right is out of scope")); - } - Ok((x.t, y.t)) - } -} - -/// A block under construction. -#[wasm_bindgen] -pub struct Block { - code: Vec, -} - -// just to appease Clippy -impl Default for Block { - fn default() -> Self { - Self::new() - } -} - -#[wasm_bindgen] -impl Block { - /// Start building a block. - #[wasm_bindgen(constructor)] - pub fn new() -> Self { - Self { code: vec![] } - } - - /// Pour the contents of this block (including dependent variables) into `code`. - /// - /// Must not use `f.params` or `f.constants`. Marks all variables defined in this block as - /// expired. - fn finish(self, f: &mut FuncBuilder, code: &mut Vec) { - for instr in self.code.into_iter() { - let var = instr.var; - code.push(instr); - f.extra(var, code); - } - } - - /// Define a new variable in this block with type `t` and definition `expr`, and return its ID. - fn instr(&mut self, f: &mut FuncBuilder, t: id::Ty, expr: rose::Expr) -> usize { - let x = f.newvar(t); - self.code.push(rose::Instr { var: x, expr }); - x.var() - } - - /// Add an instruction getting the element of `arr` at index `idx`, and return its variable ID. - /// - /// Assumes `arr` and `idx` are valid variable IDs, that `idx` matches up with `arr`'s `index` - /// type, and that `arr`'s `elem` type is `t`. - pub fn index(&mut self, f: &mut FuncBuilder, t: usize, arr: usize, idx: usize) -> usize { - let array = id::var(arr); - let index = id::var(idx); - self.instr(f, id::ty(t), rose::Expr::Index { array, index }) - } - - /// Add an instruction getting member `mem` of `x`, and return its variable ID. - /// - /// Assumes `x` is a valid variable ID, that `mem` is a valid member ID for `x`'s struct type, - /// and that the type of that member is `t`. - pub fn member(&mut self, f: &mut FuncBuilder, t: usize, x: usize, mem: usize) -> usize { - let tuple = id::var(x); - let member = id::member(mem); - self.instr(f, id::ty(t), rose::Expr::Member { tuple, member }) - } - - /// Return the variable ID for a new boolean negation instruction on `arg`. - /// - /// Assumes `arg` is defined, in scope, and has boolean type. - pub fn not(&mut self, f: &mut FuncBuilder, arg: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Unary { - op: rose::Unop::Not, - arg: id::var(arg), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new absolute value instruction on `arg`. - /// - /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. - pub fn abs(&mut self, f: &mut FuncBuilder, arg: usize) -> usize { - let t = id::ty(f.ty_f64()); - let expr = rose::Expr::Unary { - op: rose::Unop::Abs, - arg: id::var(arg), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new signum instruction on `arg`. - /// - /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. - pub fn sign(&mut self, f: &mut FuncBuilder, arg: usize) -> usize { - let t = id::ty(f.ty_f64()); - let expr = rose::Expr::Unary { - op: rose::Unop::Sign, - arg: id::var(arg), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new ceiling instruction on `arg`. - /// - /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. - pub fn ceil(&mut self, f: &mut FuncBuilder, arg: usize) -> usize { - let t = id::ty(f.ty_f64()); - let expr = rose::Expr::Unary { - op: rose::Unop::Ceil, - arg: id::var(arg), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new floor instruction on `arg`. - /// - /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. - pub fn floor(&mut self, f: &mut FuncBuilder, arg: usize) -> usize { - let t = id::ty(f.ty_f64()); - let expr = rose::Expr::Unary { - op: rose::Unop::Floor, - arg: id::var(arg), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new truncate instruction on `arg`. - /// - /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. - pub fn trunc(&mut self, f: &mut FuncBuilder, arg: usize) -> usize { - let t = id::ty(f.ty_f64()); - let expr = rose::Expr::Unary { - op: rose::Unop::Trunc, - arg: id::var(arg), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new square root instruction on `arg`. - /// - /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. - pub fn sqrt(&mut self, f: &mut FuncBuilder, arg: usize) -> usize { - let t = id::ty(f.ty_f64()); - let expr = rose::Expr::Unary { - op: rose::Unop::Sqrt, - arg: id::var(arg), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new logical conjunction instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have boolean type. - pub fn and(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::And, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new logical disjunction instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have boolean type. - pub fn or(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::Or, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new boolean equality instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have boolean type. - pub fn iff(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::Iff, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new exclusive disjunction instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have boolean type. - pub fn xor(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::Xor, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new "not equal" instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn neq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::Neq, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new "less than" instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn lt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::Lt, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new "less than or equal" instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn leq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::Leq, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new "equal" instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn eq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::Eq, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new "greater than" instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn gt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::Gt, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new "greater than or equal" instruction on `left` and `right`. - /// - /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. - pub fn geq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_bool()); - let expr = rose::Expr::Binary { - op: rose::Binop::Geq, - left: id::var(left), - right: id::var(right), - }; - self.instr(f, t, expr) - } - - /// Return the variable ID for a new instruction using `cond` to choose `then` or `els`. - /// - /// Assumes `cond`, `then`, and `els` are defined and in scope, that `cond` has boolean type, - /// and that `then` and `els` both have type `t`. - pub fn select( - &mut self, - f: &mut FuncBuilder, - cond: usize, - t: usize, - then: usize, - els: usize, - ) -> usize { - let expr = rose::Expr::Select { - cond: id::var(cond), - then: id::var(then), - els: id::var(els), - }; - self.instr(f, id::ty(t), expr) - } - - /// Return the variable ID for a new instruction calling `g` with `generics` and `args`. - /// - /// Assumes that `generics` are all valid type IDs, and have the right length and constraints; - /// that `args` are all valid variable IDs and match up with `g`'s parameter types; and that the - /// return type of `g` matches `t` (all in the context of the given `generics`). - pub fn call( - &mut self, - f: &mut FuncBuilder, - g: &Func, - generics: &[usize], - t: usize, - args: &[usize], - ) -> usize { - // add the function reference to the callee - let id = id::func(f.functions.len()); - f.functions.push(g.clone()); - - let expr = rose::Expr::Call { - id, - generics: generics.iter().map(|&i| id::ty(i)).collect(), - args: args.iter().map(|&x| id::var(x)).collect(), - }; - self.instr(f, id::ty(t), expr) - } - - /// Return the variable ID for a new instruction defining an array elementwise via `body`. - /// - /// Assumes `arg` is defined and in scope; this represents the index variable for the element - /// definition body, so its dependencies are prepended to the body code and it is marked as - /// expired. Also assumes `out` is defined by `body`; this represents the final variable - /// defining each array element. Finally, assumes the type of the array (not the element) - /// matches `t`. - pub fn vec( - &mut self, - f: &mut FuncBuilder, - t: usize, - arg: usize, - body: Self, - out: usize, - ) -> usize { - let arg = id::var(arg); - let mut code = vec![]; - f.extra(arg, &mut code); - body.finish(f, &mut code); - let expr = rose::Expr::For { - arg, - body: code.into(), - ret: id::var(out), - }; - self.instr(f, id::ty(t), expr) - } - - /// Return the variable ID for a new instruction defining an accumulator with the given `shape`. - /// - /// Assumes `shape` is defined and in scope, and that `t` is the ID of a reference type whose - /// inner type is the same as the type of `shape`. - pub fn accum(&mut self, f: &mut FuncBuilder, t: usize, shape: usize) -> usize { - let expr = rose::Expr::Accum { - shape: id::var(shape), - }; - self.instr(f, id::ty(t), expr) - } - - /// Return the variable ID for a new instruction resolving the given accumulator `var`. - /// - /// Assumes `var` is defined and in scope, and that `t` is the inner type of the reference type - /// for `var`. - pub fn resolve(&mut self, f: &mut FuncBuilder, t: usize, var: usize) -> usize { - let expr = rose::Expr::Resolve { var: id::var(var) }; - self.instr(f, id::ty(t), expr) - } - - /// Return the variable ID for a new floating-point negation instruction on `arg`. - /// - /// `Err` if `arg` is undefined or out of scope, or if its type is not `F64` or `T64`. - pub fn neg(&mut self, f: &mut FuncBuilder, arg: usize) -> Result { - let x = f - .vars - .get(arg) - .ok_or_else(|| JsError::new("arg is undefined"))?; - if let Extra::Expired = x.extra { - return Err(JsError::new("arg is out of scope")); - } - let t = x.t; - if !(t.ty() == f.ty_f64() || t.ty() == f.ty_t64()) { - return Err(JsError::new("arg has invalid type")); - } - let expr = rose::Expr::Unary { - op: rose::Unop::Neg, - arg: id::var(arg), - }; - Ok(self.instr(f, t, expr)) - } - - /// Return the variable ID for a new addition instruction on `left` and `right`. - /// - /// `Err` if `left` or `right` is undefined or out of scope, or if their types are not either - /// both `F64` or both `T64`. - pub fn add( - &mut self, - f: &mut FuncBuilder, - left: usize, - right: usize, - ) -> Result { - let (t1, t2) = f.get_lr(left, right)?; - if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2 == t1) { - return Err(JsError::new("left and right have invalid types")); - } - let expr = rose::Expr::Binary { - op: rose::Binop::Add, - left: id::var(left), - right: id::var(right), - }; - Ok(self.instr(f, t1, expr)) - } - - /// Return the variable ID for a new subtraction instruction on `left` and `right`. - /// - /// `Err` if `left` or `right` is undefined or out of scope, or if their types are not either - /// both `F64` or both `T64`. - pub fn sub( - &mut self, - f: &mut FuncBuilder, - left: usize, - right: usize, - ) -> Result { - let (t1, t2) = f.get_lr(left, right)?; - if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2 == t1) { - return Err(JsError::new("left and right have invalid types")); - } - let expr = rose::Expr::Binary { - op: rose::Binop::Sub, - left: id::var(left), - right: id::var(right), - }; - Ok(self.instr(f, t1, expr)) - } - - /// Return the variable ID for a new multiplication instruction on `left` and `right`. - /// - /// `Err` if `left` or `right` is undefined or out of scope, or if `left`'s type is not `F64` or - /// `T64`, or if `right`'s type is not `F64`. - pub fn mul( - &mut self, - f: &mut FuncBuilder, - left: usize, - right: usize, - ) -> Result { - let (t1, t2) = f.get_lr(left, right)?; - if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2.ty() == f.ty_f64()) { - return Err(JsError::new("left and right have invalid types")); - } - let expr = rose::Expr::Binary { - op: rose::Binop::Mul, - left: id::var(left), - right: id::var(right), - }; - Ok(self.instr(f, t1, expr)) - } - - /// Return the variable ID for a new division instruction on `left` and `right`. - /// - /// `Err` if `left` or `right` is undefined or out of scope, or if `left`'s type is not `F64` or - /// `T64`, or if `right`'s type is not `F64`. - pub fn div( - &mut self, - f: &mut FuncBuilder, - left: usize, - right: usize, - ) -> Result { - let (t1, t2) = f.get_lr(left, right)?; - if !((t1.ty() == f.ty_f64() || t1.ty() == f.ty_t64()) && t2.ty() == f.ty_f64()) { - return Err(JsError::new("left and right have invalid types")); - } - let expr = rose::Expr::Binary { - op: rose::Binop::Div, - left: id::var(left), - right: id::var(right), - }; - Ok(self.instr(f, t1, expr)) - } -} diff --git a/crates/web/src/pprint.rs b/crates/web/src/pprint.rs deleted file mode 100644 index df38690..0000000 --- a/crates/web/src/pprint.rs +++ /dev/null @@ -1,299 +0,0 @@ -use by_address::ByAddress; -use enumset::EnumSet; -use indexmap::{IndexMap, IndexSet}; -use rose::{id, Binop, Constraint, Expr, Func, Instr, Node, Refs, Ty, Unop}; -use std::{fmt, hash::Hash}; - -fn write_constraints(f: &mut fmt::Formatter<'_>, constraints: EnumSet) -> fmt::Result { - let mut first = true; - for constraint in constraints.iter() { - if first { - first = false; - } else { - write!(f, " + ")?; - } - write!(f, "{constraint:?}")?; - } - Ok(()) -} - -fn write_generics(f: &mut fmt::Formatter<'_>, generics: &[EnumSet]) -> fmt::Result { - write!(f, "<")?; - let mut first = true; - for (i, &constraints) in generics.iter().enumerate() { - if first { - first = false; - } else { - write!(f, ", ")?; - } - write!(f, "G{i}: ")?; - write_constraints(f, constraints)?; - } - write!(f, ">") -} - -fn write_types(f: &mut fmt::Formatter<'_>, types: &[Ty]) -> fmt::Result { - for (i, ty) in types.iter().enumerate() { - write!(f, " type T{i} = ")?; - match ty { - Ty::Unit | Ty::Bool | Ty::F64 => writeln!(f, "{ty:?}")?, - Ty::Fin { size } => writeln!(f, "{size}")?, - Ty::Generic { id } => writeln!(f, "G{}", id.generic())?, - Ty::Ref { inner } => writeln!(f, "&T{}", inner.ty())?, - Ty::Array { index, elem } => writeln!(f, "[T{}]T{}", index.ty(), elem.ty())?, - Ty::Tuple { members } => { - write!(f, "(")?; - write_elems(f, 'T', members.iter().map(|member| member.ty()))?; - writeln!(f, ")")?; - } - } - } - Ok(()) -} - -fn write_opaque( - f: &mut fmt::Formatter<'_>, - generics: &[EnumSet], - types: &[Ty], - params: &[id::Ty], - ret: id::Ty, -) -> fmt::Result { - write_generics(f, generics)?; - writeln!(f, "{{")?; - write_types(f, types)?; - write!(f, " opaque: (")?; - write_elems(f, 'T', params.iter().map(|t| t.ty()))?; - writeln!(f, ") -> T{}", ret.ty())?; - writeln!(f, "}}") -} - -fn search<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>( - f: &mut fmt::Formatter<'_>, - imports: &mut IndexSet, - funcs: &mut IndexMap, T>, - refs: &T, - block: &[Instr], -) -> fmt::Result { - for instr in block.iter() { - match &instr.expr { - &Expr::Call { id, .. } => match refs.get(id).unwrap() { - Node::Transparent { refs, def } => { - let key = ByAddress(def); - if !funcs.contains_key(&key) { - search(f, imports, funcs, &refs, &def.body)?; - funcs.insert(key, refs); - } - } - Node::Opaque { - generics, - types, - params, - ret, - def, - } => { - let (i, new) = imports.insert_full(def); - if new { - write!(f, "fn f{i} = ")?; - write_opaque(f, generics, types, params, ret)?; - writeln!(f)?; - } - } - }, - Expr::For { body, .. } => search(f, imports, funcs, refs, body)?, - _ => {} - } - } - Ok(()) -} - -fn write_elems( - f: &mut fmt::Formatter<'_>, - prefix: char, - items: impl Iterator, -) -> std::fmt::Result { - let mut first = true; - for item in items { - if first { - first = false; - } else { - write!(f, ", ")?; - } - write!(f, "{}{}", prefix, item)?; - } - Ok(()) -} - -struct Function<'a, 'b, O, T> { - imports: &'b IndexSet, - funcs: &'b IndexMap, T>, - refs: &'b T, - def: &'a Func, -} - -impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> { - fn write_instr(&self, f: &mut fmt::Formatter<'_>, spaces: usize, instr: &Instr) -> fmt::Result { - for _ in 0..spaces { - write!(f, " ")?; - } - let x = instr.var.var(); - write!(f, "let x{}: T{} = ", x, self.def.vars[x].ty())?; - match &instr.expr { - Expr::Unit => writeln!(f, "unit")?, - Expr::Bool { val } => writeln!(f, "{val}")?, - Expr::F64 { val } => writeln!(f, "{val}")?, - Expr::Fin { val } => writeln!(f, "{val}")?, - Expr::Array { elems } => { - write!(f, "[")?; - write_elems(f, 'x', elems.iter().map(|elem| elem.var()))?; - writeln!(f, "]")?; - } - Expr::Tuple { members } => { - write!(f, "(")?; - write_elems(f, 'x', members.iter().map(|member| member.var()))?; - writeln!(f, ")")?; - } - Expr::Index { array, index } => writeln!(f, "x{}[x{}]", array.var(), index.var())?, - Expr::Member { tuple, member } => writeln!(f, "x{}[{}]", tuple.var(), member.member())?, - Expr::Slice { array, index } => writeln!(f, "&x{}[x{}]", array.var(), index.var())?, - Expr::Field { tuple, member } => writeln!(f, "&x{}[{}]", tuple.var(), member.member())?, - Expr::Unary { op, arg } => match op { - Unop::Not => writeln!(f, "not x{}", arg.var())?, - Unop::Neg => writeln!(f, "-x{}", arg.var())?, - Unop::Abs => writeln!(f, "|x{}|", arg.var())?, - Unop::Sign => writeln!(f, "sign(x{})", arg.var())?, - Unop::Ceil => writeln!(f, "ceil(x{})", arg.var())?, - Unop::Floor => writeln!(f, "floor(x{})", arg.var())?, - Unop::Trunc => writeln!(f, "trunc(x{})", arg.var())?, - Unop::Sqrt => writeln!(f, "sqrt(x{})", arg.var())?, - }, - Expr::Binary { op, left, right } => match op { - Binop::And => writeln!(f, "x{} and x{}", left.var(), right.var())?, - Binop::Or => writeln!(f, "x{} or x{}", left.var(), right.var())?, - Binop::Iff => writeln!(f, "x{} iff x{}", left.var(), right.var())?, - Binop::Xor => writeln!(f, "x{} xor x{}", left.var(), right.var())?, - Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?, - Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?, - Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?, - Binop::Eq => writeln!(f, "x{} == x{}", left.var(), right.var())?, - Binop::Gt => writeln!(f, "x{} > x{}", left.var(), right.var())?, - Binop::Geq => writeln!(f, "x{} >= x{}", left.var(), right.var())?, - Binop::Add => writeln!(f, "x{} + x{}", left.var(), right.var())?, - Binop::Sub => writeln!(f, "x{} - x{}", left.var(), right.var())?, - Binop::Mul => writeln!(f, "x{} * x{}", left.var(), right.var())?, - Binop::Div => writeln!(f, "x{} / x{}", left.var(), right.var())?, - }, - Expr::Select { cond, then, els } => { - writeln!(f, "x{} ? x{} : x{}", cond.var(), then.var(), els.var())? - } - Expr::Call { id, generics, args } => { - let i = match self.refs.get(*id).unwrap() { - Node::Transparent { def, .. } => { - self.imports.len() + self.funcs.get_index_of(&ByAddress(def)).unwrap() - } - Node::Opaque { def, .. } => self.imports.get_index_of(&def).unwrap(), - }; - write!(f, "f{i}<")?; - write_elems(f, 'T', generics.iter().map(|generic| generic.ty()))?; - write!(f, ">(")?; - write_elems(f, 'x', args.iter().map(|arg| arg.var()))?; - writeln!(f, ")")?; - } - Expr::For { arg, body, ret } => { - writeln!( - f, - "for x{}: T{} {{", - arg.var(), - self.def.vars[arg.var()].ty() - )?; - self.write_block(f, spaces + 2, body, *ret)?; - for _ in 0..spaces { - write!(f, " ")?; - } - writeln!(f, "}}")? - } - Expr::Accum { shape } => writeln!(f, "accum x{}", shape.var())?, - Expr::Add { accum, addend } => writeln!(f, "x{} += x{}", accum.var(), addend.var())?, - Expr::Resolve { var } => writeln!(f, "resolve x{}", var.var())?, - } - Ok(()) - } - - fn write_block( - &self, - f: &mut fmt::Formatter<'_>, - spaces: usize, - body: &[Instr], - ret: id::Var, - ) -> fmt::Result { - for instr in body.iter() { - self.write_instr(f, spaces, instr)?; - } - for _ in 0..spaces { - write!(f, " ")?; - } - writeln!(f, "x{}", ret.var())?; - Ok(()) - } - - fn write_func(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write_generics(f, &self.def.generics)?; - writeln!(f, "{{")?; - write_types(f, &self.def.types)?; - write!(f, " (")?; - let mut first = true; - for param in self.def.params.iter() { - if first { - first = false; - } else { - write!(f, ", ")?; - } - write!(f, "x{}: T{}", param.var(), self.def.vars[param.var()].ty())?; - } - writeln!(f, ") -> T{} {{", self.def.vars[self.def.ret.var()].ty())?; - self.write_block(f, 4, &self.def.body, self.def.ret)?; - writeln!(f, " }}")?; - writeln!(f, "}}") - } -} - -pub fn write_graph<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>( - f: &mut fmt::Formatter<'_>, - root: Node<'a, O, T>, -) -> fmt::Result { - match root { - Node::Opaque { - generics, - types, - params, - ret, - def: _, - } => { - write!(f, "fn f0 = ")?; - write_opaque(f, generics, types, params, ret) - } - Node::Transparent { refs, def } => { - let mut imports = IndexSet::new(); - let mut funcs = IndexMap::new(); - search(f, &mut imports, &mut funcs, &refs, &def.body)?; - for (i, (def, refs)) in funcs.iter().enumerate() { - write!(f, "fn f{} = ", imports.len() + i)?; - Function { - imports: &imports, - funcs: &funcs, - refs, - def, - } - .write_func(f)?; - writeln!(f)?; - } - write!(f, "fn f{} = ", imports.len() + funcs.len())?; - Function { - imports: &imports, - funcs: &funcs, - refs: &refs, - def, - } - .write_func(f) - } - } -} diff --git a/package-lock.json b/package-lock.json index 8ec0e59..10e16f2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -8,12 +8,10 @@ "packages/*" ], "devDependencies": { - "esbuild": "^0.18", "prettier": "^3", "prettier-plugin-organize-imports": "^3", "typescript": "^5", "vite": "^4", - "vite-plugin-top-level-await": "^1", "vitest": "^0.33" } }, @@ -387,231 +385,16 @@ "integrity": "sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==", "dev": true }, - "node_modules/@rollup/plugin-virtual": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@rollup/plugin-virtual/-/plugin-virtual-3.0.1.tgz", - "integrity": "sha512-fK8O0IL5+q+GrsMLuACVNk2x21g3yaw+sG2qn16SnUd3IlBsQyvWxLMGHmCmXRMecPjGRSZ/1LmZB4rjQm68og==", - "dev": true, - "engines": { - "node": ">=14.0.0" - }, - "peerDependencies": { - "rollup": "^1.20.0||^2.0.0||^3.0.0" - }, - "peerDependenciesMeta": { - "rollup": { - "optional": true - } - } - }, "node_modules/@rose-lang/site": { "resolved": "packages/site", "link": true }, - "node_modules/@rose-lang/wasm": { - "resolved": "packages/wasm", - "link": true - }, "node_modules/@sinclair/typebox": { "version": "0.27.8", "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.27.8.tgz", "integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==", "dev": true }, - "node_modules/@swc/core": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core/-/core-1.3.69.tgz", - "integrity": "sha512-Khc/DE9D5+2tYTHgAIp5DZARbs8kldWg3b0Jp6l8FQLjelcLFmlQWSwKhVZrgv4oIbgZydIp8jInsvTalMHqnQ==", - "dev": true, - "hasInstallScript": true, - "engines": { - "node": ">=10" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/swc" - }, - "optionalDependencies": { - "@swc/core-darwin-arm64": "1.3.69", - "@swc/core-darwin-x64": "1.3.69", - "@swc/core-linux-arm-gnueabihf": "1.3.69", - "@swc/core-linux-arm64-gnu": "1.3.69", - "@swc/core-linux-arm64-musl": "1.3.69", - "@swc/core-linux-x64-gnu": "1.3.69", - "@swc/core-linux-x64-musl": "1.3.69", - "@swc/core-win32-arm64-msvc": "1.3.69", - "@swc/core-win32-ia32-msvc": "1.3.69", - "@swc/core-win32-x64-msvc": "1.3.69" - }, - "peerDependencies": { - "@swc/helpers": "^0.5.0" - }, - "peerDependenciesMeta": { - "@swc/helpers": { - "optional": true - } - } - }, - "node_modules/@swc/core-darwin-arm64": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-darwin-arm64/-/core-darwin-arm64-1.3.69.tgz", - "integrity": "sha512-IjZTf12zIPWkV3D7toaLDoJPSkLhQ4fDH8G6/yCJUI27cBFOI3L8LXqptYmISoN5yYdrcnNpdqdapD09JPuNJg==", - "cpu": [ - "arm64" - ], - "dev": true, - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-darwin-x64": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-darwin-x64/-/core-darwin-x64-1.3.69.tgz", - "integrity": "sha512-/wBO0Rn5oS5dJI/L9kJRkPAdksVwl5H9nleW/NM3A40N98VV8T7h/i1nO051mxIjq0R6qXVGOWFbBoLrPYucJg==", - "cpu": [ - "x64" - ], - "dev": true, - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-linux-arm-gnueabihf": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-linux-arm-gnueabihf/-/core-linux-arm-gnueabihf-1.3.69.tgz", - "integrity": "sha512-NShCjMv6Xn8ckMKBRqmprXvUF14+jXY0TcNKXwjYErzoIUFOnG72M36HxT4QEeAtKZ4Eg4CZFE4zlJ27fDp1gg==", - "cpu": [ - "arm" - ], - "dev": true, - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-linux-arm64-gnu": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-linux-arm64-gnu/-/core-linux-arm64-gnu-1.3.69.tgz", - "integrity": "sha512-VRPOJj4idopSHIj1bOVXX0SgaB18R8yZNunb7eXS5ZcjVxAcdvqyIz3RdQX1zaJFCGzcdPLzBRP32DZWWGE8Ng==", - "cpu": [ - "arm64" - ], - "dev": true, - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-linux-arm64-musl": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-linux-arm64-musl/-/core-linux-arm64-musl-1.3.69.tgz", - "integrity": "sha512-QxeSiZqo5x1X8vq8oUWLibq+IZJcxl9vy0sLUmzdjF2b/Z+qxKP3gutxnb2tzJaHqPVBbEZaILERIGy1qWdumQ==", - "cpu": [ - "arm64" - ], - "dev": true, - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-linux-x64-gnu": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-linux-x64-gnu/-/core-linux-x64-gnu-1.3.69.tgz", - "integrity": "sha512-b+DUlVxYox3BwD3PyTwhLvqtu6TYZtW+S6O0FnttH11o4skHN0XyJ/cUZSI0X2biSmfDsizRDUt1PWPFM+F7SA==", - "cpu": [ - "x64" - ], - "dev": true, - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-linux-x64-musl": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-linux-x64-musl/-/core-linux-x64-musl-1.3.69.tgz", - "integrity": "sha512-QXjsI+f8n9XPZHUvmGgkABpzN4M9kdSbhqBOZmv3o0AsDGNCA4uVowQqgZoPFAqlJTpwHeDmrv5sQ13HN+LOGw==", - "cpu": [ - "x64" - ], - "dev": true, - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-win32-arm64-msvc": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-win32-arm64-msvc/-/core-win32-arm64-msvc-1.3.69.tgz", - "integrity": "sha512-wn7A8Ws1fyviuCUB2Vg6IotiZeuqiO1Mz3d+YDae2EYyNpj1kNHvjBip8GHkfGzZG+jVrvG6NHsDo0KO/pGb8A==", - "cpu": [ - "arm64" - ], - "dev": true, - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-win32-ia32-msvc": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-win32-ia32-msvc/-/core-win32-ia32-msvc-1.3.69.tgz", - "integrity": "sha512-LsFBXtXqxEcVaaOGEZ9X3qdMzobVoJqKv8DnksuDsWcBk+9WCeTz2u/iB+7yZ2HGuPXkCqTRqhFo6FX9aC00kQ==", - "cpu": [ - "ia32" - ], - "dev": true, - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-win32-x64-msvc": { - "version": "1.3.69", - "resolved": "https://registry.npmjs.org/@swc/core-win32-x64-msvc/-/core-win32-x64-msvc-1.3.69.tgz", - "integrity": "sha512-ieBscU0gUgKjaseFI07tAaGqHvKyweNknPeSYEZOasVZUczhD6fK2GRnVREhv2RB2qdKC/VGFBsgRDMgzq1VLw==", - "cpu": [ - "x64" - ], - "dev": true, - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">=10" - } - }, "node_modules/@types/chai": { "version": "4.3.5", "resolved": "https://registry.npmjs.org/@types/chai/-/chai-4.3.5.tgz", @@ -1201,15 +984,6 @@ "integrity": "sha512-TrY6DsjTQQgyS3E3dBaOXf0TpPD8u9FVrVYmKVegJuFw51n/YB9XPt+U6ydzFG5ZIN7+DIjPbNmXoBj9esYhgQ==", "dev": true }, - "node_modules/uuid": { - "version": "9.0.0", - "resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz", - "integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==", - "dev": true, - "bin": { - "uuid": "dist/bin/uuid" - } - }, "node_modules/vite": { "version": "4.4.4", "resolved": "https://registry.npmjs.org/vite/-/vite-4.4.4.tgz", @@ -1288,20 +1062,6 @@ "url": "https://opencollective.com/vitest" } }, - "node_modules/vite-plugin-top-level-await": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/vite-plugin-top-level-await/-/vite-plugin-top-level-await-1.3.1.tgz", - "integrity": "sha512-55M1h4NAwkrpxPNOJIBzKZFihqLUzIgnElLSmPNPMR2Fn9+JHKaNg3sVX1Fq+VgvuBksQYxiD3OnwQAUu7kaPQ==", - "dev": true, - "dependencies": { - "@rollup/plugin-virtual": "^3.0.1", - "@swc/core": "^1.3.10", - "uuid": "^9.0.0" - }, - "peerDependencies": { - "vite": ">=2.8" - } - }, "node_modules/vitest": { "version": "0.33.0", "resolved": "https://registry.npmjs.org/vitest/-/vitest-0.33.0.tgz", @@ -1409,24 +1169,16 @@ }, "packages/core": { "name": "rose", - "version": "0.4.5", - "license": "MIT", - "dependencies": { - "@rose-lang/wasm": "0.4.5" - } + "version": "0.5.0", + "license": "MIT" }, "packages/site": { "name": "@rose-lang/site", - "version": "0.4.5", + "version": "0.5.0", "dependencies": { "highlight.js": "^11", - "rose": "0.4.5" + "rose": "0.5.0" } - }, - "packages/wasm": { - "name": "@rose-lang/wasm", - "version": "0.4.5", - "license": "MIT" } } } diff --git a/package.json b/package.json index d94253e..d5dbde9 100644 --- a/package.json +++ b/package.json @@ -8,9 +8,7 @@ "prettier": "^3", "prettier-plugin-organize-imports": "^3", "typescript": "^5", - "esbuild": "^0.18", "vite": "^4", - "vite-plugin-top-level-await": "^1", "vitest": "^0.33" } } diff --git a/packages/core/package.json b/packages/core/package.json index 51c866c..29e4f46 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -1,6 +1,6 @@ { "name": "rose", - "version": "0.4.5", + "version": "0.5.0", "license": "MIT", "repository": "rose-lang/rose", "type": "module", @@ -9,9 +9,6 @@ "dist", "src" ], - "dependencies": { - "@rose-lang/wasm": "0.4.5" - }, "scripts": { "build": "tsc", "test": "vitest" diff --git a/packages/core/src/impl.test.ts b/packages/core/src/impl.test.ts deleted file mode 100644 index fe142a3..0000000 --- a/packages/core/src/impl.test.ts +++ /dev/null @@ -1,389 +0,0 @@ -import * as wasm from "@rose-lang/wasm"; -import { describe, expect, test } from "vitest"; -import { - Dual, - Fn, - Real, - Vec, - fn, - inner, - mul, - neg, - opaque, - struct, - vjp, -} from "./impl.js"; - -const pprint = (f: Fn): string => f[inner].pprint(); - -test("core IR type layouts", () => { - // these don't matter too much, but it's good to notice if sizes increase - expect(Object.fromEntries(wasm.layouts())).toEqual({ - Expr: { size: 24, align: 8 }, - Func: { size: 44, align: 4 }, - Instr: { size: 32, align: 8 }, - Ty: { size: 12, align: 4 }, - Val: { size: 16, align: 8 }, - }); -}); - -describe("pprint", () => { - test("opaque", () => { - const f = opaque([Real], Real, (x) => x); - const s = pprint(f); - expect(s).toBe( - ` -fn f0 = <>{ - type T0 = F64 - opaque: (T0) -> T0 -} -`.trimStart(), - ); - }); - - test("graph", () => { - const exp = opaque([Real], Real, Math.exp); - const sin = opaque([Real], Real, Math.sin); - const cos = opaque([Real], Real, Math.cos); - - exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = exp(x); - return { re: y, du: mul(dx, y) }; - }); - sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: sin(x), du: mul(dx, cos(x)) }; - }); - cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: cos(x), du: mul(dx, neg(sin(x))) }; - }); - - const Complex = struct({ re: Real, im: Real }); - - const complexp = fn([Complex], Complex, (z) => { - const c = exp(z.re); - return { re: mul(c, cos(z.im)), im: mul(c, sin(z.im)) }; - }); - - const f = fn([Complex, Complex], Vec(2, Complex), (z, w) => { - const { ret, grad } = vjp(complexp)(z); - return [ret, grad(w)]; - }); - - const s = pprint(f); - expect(s).toBe( - ` -fn f0 = <>{ - type T0 = F64 - opaque: (T0) -> T0 -} - -fn f1 = <>{ - type T0 = F64 - opaque: (T0) -> T0 -} - -fn f2 = <>{ - type T0 = F64 - opaque: (T0) -> T0 -} - -fn f3 = <>{ - type T0 = F64 - type T1 = F64 - type T2 = F64 - type T3 = Unit - type T4 = &T2 - type T5 = &T1 - type T6 = (T2, T0) - type T7 = (T2, T6) - (x0: T2) -> T7 { - let x3: T0 = f0<>(x0) - let x6: T6 = (x0, x3) - let x7: T7 = (x3, x6) - x7 - } -} - -fn f4 = <>{ - type T0 = F64 - type T1 = F64 - type T2 = F64 - type T3 = Unit - type T4 = &T2 - type T5 = &T0 - type T6 = &T1 - type T7 = (T2, T0, T0, T0) - type T8 = (T2, T7) - (x0: T2) -> T8 { - let x3: T0 = f1<>(x0) - let x4: T0 = f2<>(x0) - let x5: T0 = -x4 - let x8: T7 = (x0, x3, x4, x5) - let x9: T8 = (x3, x8) - x9 - } -} - -fn f5 = <>{ - type T0 = F64 - type T1 = F64 - type T2 = F64 - type T3 = Unit - type T4 = &T2 - type T5 = &T1 - type T6 = (T2, T0, T0) - type T7 = (T2, T6) - (x0: T2) -> T7 { - let x3: T0 = f2<>(x0) - let x4: T0 = f1<>(x0) - let x7: T6 = (x0, x3, x4) - let x8: T7 = (x3, x7) - x8 - } -} - -fn f6 = <>{ - type T0 = F64 - type T1 = F64 - type T2 = F64 - type T3 = F64 - type T4 = (T2, T2) - type T5 = Unit - type T6 = &T4 - type T7 = &T1 - type T8 = &T2 - type T9 = F64 - type T10 = &T9 - type T11 = (T9, T9) - type T12 = (T9, T11) - type T13 = (T2, T12) - type T14 = (T9, T9, T9, T9) - type T15 = (T9, T14) - type T16 = (T2, T15) - type T17 = &T0 - type T18 = (T9, T9, T9) - type T19 = (T9, T18) - type T20 = (T2, T19) - type T21 = (T4, T2, T12, T2, T15, T0, T2, T19, T0, T4) - type T22 = (T4, T21) - (x0: T4) -> T22 { - let x1: T2 = x0[1] - let x31: T13 = f3<>(x1) - let x2: T2 = x31[0] - let x32: T12 = x31[1] - let x3: T2 = x0[0] - let x33: T16 = f4<>(x3) - let x4: T2 = x33[0] - let x34: T15 = x33[1] - let x19: T0 = x2 * x4 - let x6: T2 = x0[0] - let x35: T20 = f5<>(x6) - let x7: T2 = x35[0] - let x36: T19 = x35[1] - let x27: T0 = x2 * x7 - let x9: T4 = (x27, x19) - let x37: T21 = (x0, x2, x32, x4, x34, x19, x7, x36, x27, x9) - let x38: T22 = (x9, x37) - x38 - } -} - -fn f7 = <>{ - type T0 = F64 - type T1 = F64 - type T2 = F64 - type T3 = Unit - type T4 = &T2 - type T5 = &T1 - type T6 = (T2, T0, T0) - (x9: T4, x14: T2, x8: T6) -> T3 { - let x7: T1 = 0 - let x0: T2 = x8[0] - let x3: T0 = x8[1] - let x4: T0 = x8[2] - let x10: T5 = accum x7 - let x15: T3 = x10 += x14 - let x11: T1 = resolve x10 - let x12: T1 = x11 * x4 - let x13: T3 = x9 += x12 - let x16: T3 = unit - x16 - } -} - -fn f8 = <>{ - type T0 = F64 - type T1 = F64 - type T2 = F64 - type T3 = Unit - type T4 = &T2 - type T5 = &T0 - type T6 = &T1 - type T7 = (T2, T0, T0, T0) - (x10: T4, x17: T2, x9: T7) -> T3 { - let x8: T1 = 0 - let x0: T2 = x9[0] - let x3: T0 = x9[1] - let x4: T0 = x9[2] - let x5: T0 = x9[3] - let x11: T5 = accum x5 - let x13: T6 = accum x8 - let x18: T3 = x13 += x17 - let x14: T1 = resolve x13 - let x15: T1 = x14 * x5 - let x16: T3 = x10 += x15 - let x12: T0 = resolve x11 - let x19: T3 = unit - x19 - } -} - -fn f9 = <>{ - type T0 = F64 - type T1 = F64 - type T2 = F64 - type T3 = Unit - type T4 = &T2 - type T5 = &T1 - type T6 = (T2, T0) - (x8: T4, x13: T2, x7: T6) -> T3 { - let x6: T1 = 0 - let x0: T2 = x7[0] - let x3: T0 = x7[1] - let x9: T5 = accum x6 - let x14: T3 = x9 += x13 - let x10: T1 = resolve x9 - let x11: T1 = x10 * x3 - let x12: T3 = x8 += x11 - let x15: T3 = unit - x15 - } -} - -fn f10 = <>{ - type T0 = F64 - type T1 = F64 - type T2 = F64 - type T3 = F64 - type T4 = (T2, T2) - type T5 = Unit - type T6 = &T4 - type T7 = &T1 - type T8 = &T2 - type T9 = F64 - type T10 = &T9 - type T11 = (T9, T9) - type T12 = (T9, T11) - type T13 = (T2, T12) - type T14 = (T9, T9, T9, T9) - type T15 = (T9, T14) - type T16 = (T2, T15) - type T17 = &T0 - type T18 = (T9, T9, T9) - type T19 = (T9, T18) - type T20 = (T2, T19) - type T21 = (T4, T2, T12, T2, T15, T0, T2, T19, T0, T4) - (x33: T6, x85: T4, x32: T21) -> T5 { - let x31: T1 = 0 - let x0: T4 = x32[0] - let x34: T7 = accum x31 - let x1: T2 = x0[1] - let x36: T8 = &x33[1] - let x2: T2 = x32[1] - let x37: T12 = x32[2] - let x38: T8 = accum x2 - let x3: T2 = x0[0] - let x41: T8 = &x33[0] - let x4: T2 = x32[3] - let x42: T15 = x32[4] - let x43: T8 = accum x4 - let x19: T0 = x32[5] - let x46: T17 = accum x19 - let x48: T7 = accum x31 - let x52: T7 = accum x31 - let x56: T7 = accum x31 - let x6: T2 = x0[0] - let x60: T8 = &x33[0] - let x7: T2 = x32[6] - let x61: T19 = x32[7] - let x62: T8 = accum x7 - let x27: T0 = x32[8] - let x65: T17 = accum x27 - let x67: T7 = accum x31 - let x71: T7 = accum x31 - let x75: T7 = accum x31 - let x9: T4 = x32[9] - let x79: T6 = accum x9 - let x86: T5 = x79 += x85 - let x80: T4 = resolve x79 - let x83: T2 = x80[1] - let x84: T5 = x56 += x83 - let x81: T2 = x80[0] - let x82: T5 = x75 += x81 - let x76: T1 = resolve x75 - let x78: T5 = x71 += x76 - let x77: T5 = x67 += x76 - let x72: T1 = resolve x71 - let x73: T1 = x72 * x2 - let x74: T5 = x62 += x73 - let x68: T1 = resolve x67 - let x69: T1 = x68 * x7 - let x70: T5 = x38 += x69 - let x66: T0 = resolve x65 - let x63: T2 = resolve x62 - let x64: T5 = f7<>(x60, x63, x61) - let x57: T1 = resolve x56 - let x59: T5 = x52 += x57 - let x58: T5 = x48 += x57 - let x53: T1 = resolve x52 - let x54: T1 = x53 * x2 - let x55: T5 = x43 += x54 - let x49: T1 = resolve x48 - let x50: T1 = x49 * x4 - let x51: T5 = x38 += x50 - let x47: T0 = resolve x46 - let x44: T2 = resolve x43 - let x45: T5 = f8<>(x41, x44, x42) - let x39: T2 = resolve x38 - let x40: T5 = f9<>(x36, x39, x37) - let x35: T1 = resolve x34 - let x87: T5 = unit - x87 - } -} - -fn f11 = <>{ - type T0 = F64 - type T1 = F64 - type T2 = (T0, T0) - type T3 = 2 - type T4 = [T3]T2 - type T5 = Unit - type T6 = &T2 - type T7 = &T0 - type T8 = (T0, T0) - type T9 = (T0, T8) - type T10 = (T0, T9) - type T11 = (T0, T0, T0, T0) - type T12 = (T0, T11) - type T13 = (T0, T12) - type T14 = (T0, T0, T0) - type T15 = (T0, T14) - type T16 = (T0, T15) - type T17 = (T2, T0, T9, T0, T12, T0, T0, T15, T0, T2) - type T18 = (T2, T17) - (x0: T2, x1: T2) -> T4 { - let x2: T18 = f6<>(x0) - let x3: T2 = x2[0] - let x4: T17 = x2[1] - let x5: T6 = accum x0 - let x6: T5 = f10<>(x5, x1, x4) - let x7: T2 = resolve x5 - let x8: T4 = [x3, x7] - x8 - } -} -`.trimStart(), - ); - }); -}); diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts deleted file mode 100644 index 99b02e3..0000000 --- a/packages/core/src/impl.ts +++ /dev/null @@ -1,1137 +0,0 @@ -import * as wasm from "@rose-lang/wasm"; -import { Val as RawVal } from "@rose-lang/wasm/interp/Val"; - -/** - * The user-constructed functions we expose in our API hold reference-counted - * pointers to Rust objects living in WebAssembly memory. To avoid memory leaks, - * we use this finalization registry to free those objects once they are no - * longer referenced by any JavaScript objects. This registry takes a - * `wasm.Func` object, so that can't directly be what we hold onto: instead, the - * `Fn` type holds a `wasm.Func` internally, so we register that as the target - * and pass the internal value to be finalized. This does mean that if a user - * reaches in and pulls out the wrapped object, it could become invalid if they - * don't hold onto the wrapper, so that inner object must be treated as private. - */ -const funcs = new FinalizationRegistry((f: wasm.Func) => f.free()); - -/** - * Property key for the internal `wasm.Func` object held by a `Fn`. - * - * This can of course be reached via `Object.getOwnPropertySymbols`, but that - * requires an intentional effort to circumvent the encapsulation; such usage is - * discouraged. - */ -export const inner = Symbol("inner"); - -/** Property key for a `Fn`'s string array for resolving struct key names. */ -const strings = Symbol("strings"); - -interface FnBase { - [inner]: wasm.Func; - [strings]: string[]; -} - -/** An abstract function. */ -export interface Fn extends FnBase { - jvp?: Fn; -} - -/** Adds `f` to the registry, mutates it into a full `Fn`, then returns it. */ -const makeFn = (f: FnBase): Fn => { - funcs.register(f, f[inner]); - Object.defineProperty(f, "jvp", { - set(g: Fn) { - f[inner].setJvp(g[inner]); - }, - }); - return f as Fn; -}; - -/** Property key for a variable ID. */ -const variable = Symbol("variable"); - -/** - * An abstract variable. - * - * The `Context` is implicit; don't use a `Var` outside of its original context. - * - * This class is used for primitive types that can't be indices; index types are - * instead represented by `Symbol`, and non-primitive types are represented by - * `Proxy`. - */ -interface Var { - [variable]: number; -} - -/** - * Construct an abstract variable with the given `id`. - * - * This should not be used for any values that instead needs to be represented - * as a `Symbol` or a `Proxy`. - */ -const newVar = (id: number): Var => ({ [variable]: id }); - -/** An abstract null value. */ -export type Null = null | Var; - -/** An abstract boolean. */ -export type Bool = boolean | Var; - -/** An abstract 64-bit floating point number. */ -export type Real = number | Var; - -/** The zero tangent. */ -const zeroSymbol = Symbol("zero"); - -/** The zero tangent. */ -type Zero = typeof zeroSymbol; - -/** An abstract 64-bit floating point tangent number. */ -export type Tan = Zero | Var; - -/** An abstract natural number, which can be used to index into a vector. */ -type Nat = number | symbol; - -/** The portion of an abstract vector that can be directly indexed. */ -interface VecIndex { - [K: Nat]: T; -} - -/** The portion of an abstract vector that can be destructured. */ -interface VecIter { - [Symbol.iterator](): IterableIterator; -} - -/** An abstract vector, which can be indexed by a natural number. */ -export type Vec = VecIndex & VecIter; - -/** The context for an abstract function under construction. */ -interface Context { - /** Handle to Rust-side information about the function itself. */ - func: wasm.FuncBuilder; - - /** Handle to Rust-side information about the current block being built. */ - block: wasm.Block; - - /** - * Variable IDs for abstract values, indexed by type ID. - * - * For instance, the number `42` may represent a floating-point number, or it - * may represent a natural number that is being used as an index. Similarly, - * an array of numbers could be a vector of floating-point numbers or a vector - * of indices. We need to know the type first before we can interpret it. - */ - variables: Map>; - - /** The actual string represented each string ID. */ - strings: string[]; - - /** Reverse lookup from strings to IDs; also used for deduplication. */ - stringIds: Map; -} - -/** Struct field name for the nonlinear part of a dual number. */ -const re = "re"; - -/** Struct field name for the linear part of a dual number. */ -const du = "du"; - -/** String ID for `"re"`. */ -const reId = 0; - -/** String ID for `"du"`. */ -const duId = 1; - -/** Return a fresh initial string cache for constructing a new function. */ -const initStrings = (): { - strings: string[]; - stringIds: Map; -} => ({ - strings: [re, du], // order matches `reId` and `duId` - stringIds: new Map([ - [re, reId], - [du, duId], - ]), -}); - -/** Return the variable ID map for type ID `t`, creating it if necessary. */ -const typeMap = (ctx: Context, t: number): Map => { - let map = ctx.variables.get(t); - if (map === undefined) { - // the map is keyed by type, so we don't always have an entry for every - // possible type; if we're missing one, just initialize it as an empty map - map = new Map(); - ctx.variables.set(t, map); - } - return map; -}; - -/** - * Return the variable ID for the given abstract value with the given type. - * - * Throws if the value does not match the type. - */ -const valId = (ctx: Context, t: number, x: unknown): number => { - const map = typeMap(ctx, t); - let id = map.get(x); - if (id !== undefined) return ctx.func.expect(t, id); // check if out of scope - - if (typeof x === "boolean") id = ctx.func.bool(t, x); - else if (typeof x === "number") id = ctx.func.num(t, x); - else if (typeof x === "object") { - if (x === null) id = ctx.func.unit(t); - else { - id = (x as any)[variable]; - // if `x` is a `Var` or some sort of `Proxy` that we constructed then the - // `variable` property will be present and it will be a number, but - // otherwise it shouldn't be present; we just get it directly and check - // its type instead of using JavaScript's `in` operator, because - // TypeScript doesn't really like that here, and also because it's - // possible that someone manually reached in and extracted the `variable` - // symbol from somewhere else and put it somewhere it wasn't supposed to - // be; that's one of the reasons we still call `expect` here to - // double-check the type (the other reasons being that a `Var` could be - // accidentally used outside its original `Context`, or inside its - // original context but after it has already dropped out of scope) - if (typeof id === "number") id = ctx.func.expect(t, id); - else { - if (Array.isArray(x)) { - // arrays - const size = ctx.func.size(t); - const elem = ctx.func.elem(t); - if (x.length !== size) throw Error("wrong array size"); - const xs = new Uint32Array(size); - for (let i = 0; i < size; ++i) xs[i] = valId(ctx, elem, x[i]); - id = ctx.func.array(t, xs); - } else { - // structs - const keys = ctx.func.keys(t); - const mems = ctx.func.members(t); - const entries = Object.entries(x).sort(([a], [b]) => compare(a, b)); - if ( - !( - entries.length === keys.length && - // they're sorted - entries.every(([k], i) => k === ctx.strings[keys[i]]) - ) - ) - throw Error("wrong struct keys"); - const ys = new Uint32Array( - entries.map(([, y], i) => valId(ctx, mems[i], y)), - ); - id = ctx.func.obj(t, ys); - } - } - } - } else throw Error("invalid value"); - - map.set(x, id); - return id; -}; - -/** Context for the current function under construction. */ -let context: Context | undefined = undefined; - -/** Return the current context if there is one; throw otherwise. */ -const getCtx = (): Context => { - if (context === undefined) throw Error("no `fn` context found"); - return context; -}; - -/** The null type. */ -export const Null = Symbol("Null"); - -/** The boolean type. */ -export const Bool = Symbol("Bool"); - -/** The 64-bit floating-point type. */ -export const Real = Symbol("Real"); - -/** The 64-bit floating-point tangent type. */ -export const Tan = Symbol("Tan"); - -/** Representation of the null type. */ -export type Nulls = typeof Null; - -/** Representation of the boolean type. */ -export type Bools = typeof Bool; - -/** Representation of the 64-bit floating point type. */ -export type Reals = typeof Real; - -/** Representation of the 64-bit floating point tangent type. */ -export type Tans = typeof Tan; - -/** Representation of a bounded index type (it's just the upper bound). */ -export type Nats = number; - -/** Property key for the index type ID of a vector type. */ -const ind = Symbol("index"); - -/** Property key for the element type ID of a vector type. */ -const elm = Symbol("elem"); - -/** Representation of a vector type. */ -export interface Vecs { - [ind]: K; - [elm]: V; -} - -/** The type of vectors from the index type `K` to the element type `V`. */ -export const Vec = (index: K, elem: V): Vecs => { - return { [ind]: index, [elm]: elem }; -}; - -/** Create a struct type. */ -export const struct = (t: T): T => t; - -/** The 128-bit floating-point dual number type. */ -export const Dual = struct({ re: Real, du: Tan }); - -// TODO: make this locale-independent -const compare = (a: string, b: string): number => a.localeCompare(b); - -/** Return an array of the IDs for the given `strs` in `ctx`. */ -const intern = (ctx: Context, strs: string[]): Uint32Array => - new Uint32Array( - strs.map((s) => { - let i = ctx.stringIds.get(s); - if (i === undefined) { - i = ctx.strings.length; - ctx.strings.push(s); - ctx.stringIds.set(s, i); - } - return i; - }), - ); - -/** Return the type ID for `ty` in `ctx`, creating the type if needed. */ -const tyId = (ctx: Context, ty: unknown): number => { - if (ty === Null) return ctx.func.tyUnit(); - else if (ty === Bool) return ctx.func.tyBool(); - else if (ty === Real) return ctx.func.tyF64(); - else if (ty === Tan) return ctx.func.tyT64(); - else if (typeof ty === "number") return ctx.func.tyFin(ty); - else if (typeof ty === "object" && ty !== null) { - if (ind in ty && elm in ty) - // arrays - return ctx.func.tyArray(tyId(ctx, ty[ind]), tyId(ctx, ty[elm])); - else { - // structs - const entries = Object.entries(ty).sort(([a], [b]) => compare(a, b)); - const strs = intern( - ctx, - entries.map(([s]) => s), - ); - const tys = new Uint32Array(entries.map(([, t]) => tyId(ctx, t))); - return ctx.func.tyStruct(strs, tys); - } - } else throw Error("invalid type"); -}; - -/** - * Return a symbolic JS representation of the variable with the given `id`. - * - * This depends on the type `t` because if the variable is of an index type then - * it must be represented by a `Symbol`, and if it is of a vector type then it - * must be represented by a `Proxy`. - */ -const idVal = (ctx: Context, t: number, id: number): unknown => { - if (ctx.func.isSymbol(t)) { - const sym = Symbol(); - typeMap(ctx, t).set(sym, id); - return sym; - } else if (ctx.func.isArray(t)) return arrayProxy(ctx, t, id); - else if (ctx.func.isStruct(t)) return structProxy(ctx, t, id); - else return newVar(id); -}; - -/** - * Return a `Proxy` vector of type `t` for the variable ID `v`. - * - * The original variable ID `v` can be accessed via the `variable` symbol - * property key. Any string access will be parsed as a literal integer index. - * Any other symbol access will use the `Context`'s `symbols` map. - */ -const arrayProxy = (ctx: Context, t: number, v: number): Vec => { - const index = ctx.func.index(t); - const elem = ctx.func.elem(t); - return new Proxy( - { - *[Symbol.iterator]() { - // we assume the context was already checked - const n = ctx.func.size(t); - for (let i = 0; i < n; ++i) { - // but the context can change between `yield`s - if (getCtx() !== ctx) throw Error("array escaped its context"); - const x = ctx.block.index(ctx.func, elem, v, valId(ctx, index, i)); - yield idVal(ctx, elem, x); - } - }, - }, - { - get: (target: Vec, prop) => { - if (getCtx() !== ctx) throw Error("array escaped its context"); - if (prop === Symbol.iterator) return target[prop]; - if (prop === variable) return v; - const i = typeof prop === "string" ? parseInt(prop, 10) : prop; - const x = ctx.block.index(ctx.func, elem, v, valId(ctx, index, i)); - return idVal(ctx, elem, x); - }, - }, - ); -}; - -/** - * Return a `Proxy` struct of type `t` for the variable ID `x`. - * - * The original variable ID `x` can be accessed via the `variable` symbol - * property key. Any other symbol access will throw an `Error`. Any string - * access will emit a member instruction if the key is valid, throw otherwise. - */ -const structProxy = ( - ctx: Context, - t: number, - x: number, -): { [K: string]: unknown } => { - const keys = ctx.func.keys(t); - const mems = ctx.func.members(t); - const map = new Map(); - for (let i = 0; i < keys.length; ++i) map.set(ctx.strings[keys[i]], i); - return new Proxy( - {}, - { - get: (target, prop) => { - if (getCtx() !== ctx) throw Error("struct escaped its context"); - if (prop === variable) return x; - if (typeof prop === "symbol") throw Error("unexpected symbol"); - const i = map.get(prop); - if (i === undefined) throw Error("unexpected key"); - const t = mems[i]; - const y = ctx.block.member(ctx.func, t, x, i); - return idVal(ctx, t, y); - }, - }, - ); -}; - -/** Insert a call instruction to `f` in the current context. */ -const call = (f: Fn, generics: Uint32Array, args: unknown[]): unknown => { - const ctx = getCtx(); - const strs = intern(ctx, f[strings]); - // we first need to pull in the signature types from `f` so we know how to - // interpret the abstract values for its arguments - const sig = ctx.func.ingest(f[inner], strs, generics); - const ret = sig[sig.length - 1]; // the last type is the return type - // TODO: we should probably check that `args` is the right length - const vars = new Uint32Array(args.map((arg, i) => valId(ctx, sig[i], arg))); - const id = ctx.block.call(ctx.func, f[inner], generics, ret, vars); - return idVal(ctx, ret, id); -}; - -/** - * Map from a type of a type to the type of the symbolic values it represents. - * - * This should be used in any situation where the abstract value is being - * synthesized (e.g. the parameters of the body of `fn` or `vec`, or the - * returned value from a call or `select`). - */ -export type Symbolic = T extends Nulls - ? Null - : T extends Bools - ? Bool - : T extends Reals - ? Real - : T extends Tans - ? Tan - : T extends Nats - ? Nat - : T extends Vecs - ? Vec> - : { [K in keyof T]: Symbolic }; - -/** - * Map from a type of a type to the type of the abstract values it represents. - * - * This should be used in any situation where the abstract value is provided by - * the user (e.g. the returned value in the body of `fn` or `vec`, or the inputs - * to a call or `select). - */ -export type Value = T extends Nulls - ? Null - : T extends Bools - ? Bool - : T extends Reals - ? Real - : T extends Tans - ? Tan - : T extends Nats - ? Nat - : T extends Vecs - ? Vec> | Value[] - : { [K in keyof T]: Value }; - -/** Map from parameter type array to symbolic parameter value type array. */ -type SymbolicParams = { - [K in keyof T]: Symbolic; -}; - -/** Map from parameter type array to abstract parameter value type array. */ -type ValueParams = { - [K in keyof T]: Value; -}; - -/** Construct an abstract function by abstractly interpreting `f` once. */ -export const fn = ( - params: P, - ret: R, - f: (...args: SymbolicParams

) => Value, -): Fn & ((...args: ValueParams

) => Symbolic) => { - // TODO: support closures - if (context !== undefined) - throw Error("can't define a function while defining another function"); - let out: number | undefined = undefined; // function return variable ID - let { strings: strs, stringIds } = initStrings(); - const builder = new wasm.FuncBuilder(0); // TODO: support generics - const body = new wasm.Block(); - try { - const ctx = { - func: builder, - block: body, - variables: new Map(), - strings: strs, - stringIds, - }; - context = ctx; - const args = params.map((ty) => { - const t = tyId(ctx, ty); - // it's important that `map` runs eagerly an in order, because the - // ordering of these calls to `param` must match the order of `params` - return idVal(ctx, t, ctx.func.param(t)); - }) as SymbolicParams

; - const ty = tyId(ctx, ret); - const x = f(...args); - out = valId(ctx, ty, x); - } finally { - context = undefined; - if (out === undefined) { - // if we didn't reach the assignment statement defining `out` then there - // must have been an error, so we won't reach the below call to - // `builder.finish`; thus we must free these handles to avoid memory leaks - body.free(); - builder.free(); - } - } - const func = builder.finish(out, body); - const g: any = (...args: any): any => - // TODO: support generics - call(g, new Uint32Array(), args); - g[inner] = func; - g[strings] = strs; - return makeFn(g) as any; -}; - -/** Construct an opaque function whose implementation runs `f`. */ -export const opaque = ( - params: P, - ret: R, - f: (...args: JsArgs>) => ToJs>, -): Fn & ((...args: ValueParams

) => Symbolic) => { - // TODO: support more complicated signatures for opaque functions - const func = new wasm.Func(params.length, f); - const g: any = (...args: any): any => - // TODO: support generics - call(g, new Uint32Array(), args); - g[inner] = func; - g[strings] = initStrings().strings; // TODO: allow structs in opaque functions - return makeFn(g) as any; -}; - -/** A concrete value. */ -type Js = null | boolean | number | Js[] | { [K: string]: Js }; - -/** Translate from the interpreteer's raw format to a concrete value. */ -const pack = (f: Fn, t: number, x: unknown): RawVal => { - const func = f[inner]; - if (typeof x === "boolean") return { Bool: x }; - else if (typeof x === "number") - return func.isFin(t) ? { Fin: x } : { F64: x }; - else if (typeof x === "object") { - if (x === null) return "Unit"; - else if (Array.isArray(x)) - return { Array: x.map((y) => pack(f, func.elem(t), y)) }; - else { - const keys = func.keys(t); - const mems = func.mems(t); - const vals: RawVal[] = []; - for (let i = 0; i < keys.length; ++i) { - vals.push(pack(f, mems[i], (x as any)[f[strings][keys[i]]])); - } - return { Tuple: vals }; - } - } else throw Error("invalid value"); -}; - -/** Translate a concrete value from the interpreter's raw format. */ -const unpack = (f: Fn, t: number, x: RawVal): Js => { - const func = f[inner]; - if (x === "Unit") return null; - if ("Bool" in x) return x.Bool; - if ("F64" in x) return x.F64; - if ("Fin" in x) return x.Fin; - if ("Ref" in x) throw Error("Ref not supported"); - if ("Array" in x) - return x.Array.map((y: RawVal) => unpack(f, func.elem(t), y)); - else { - const keys = func.keys(t); - const mems = func.mems(t); - return Object.fromEntries( - x.Tuple.map((y: RawVal, i: number) => [ - f[strings][keys[i]], - unpack(f, mems[i], y), - ]), - ); - } -}; - -/** Map from an abstract value type to its corresponding concrete value type. */ -// https://www.typescriptlang.org/docs/handbook/2/conditional-types.html -type ToJs = [T] extends [Null] - ? null - : [T] extends [Bool] - ? boolean - : [T] extends [Real] - ? number - : [T] extends [Nat] - ? number - : { [K in keyof T]: ToJs }; - -/** Map from an abstract value type array to a concrete argument type array. */ -type JsArgs = { - [K in keyof T]: ToJs; -}; - -/** Concretize the abstract function `f` using the interpreter. */ -export const interp = - ( - f: Fn & ((...args: A) => R), - ): ((...args: JsArgs) => ToJs) => - // TODO: support interpreting functions with generics - (...args) => { - const func = f[inner]; - const params = func.paramTypes(); - const vals = args.map((x, i) => pack(f, params[i], x)); - return unpack(f, func.retType(), func.interp(vals)) as ToJs; - }; - -// https://github.com/rose-lang/rose/issues/116 - -// TODO: use something more like an enum -interface Layout { - size: number; - align: number; -} - -/** Round up `size` to the nearest multiple of `align`. */ -const aligned = ({ size, align }: Layout): number => - (size + align - 1) & ~(align - 1); - -/** An aligned `ArrayBuffer` view, or `undefined` for zero-sized types. */ -type View = undefined | Uint8Array | Uint16Array | Uint32Array | Float64Array; - -const getView = (buffer: ArrayBuffer, layout: Layout, offset: number): View => { - // this code assumes that the layout is uniquely determined by its `size` - const { size } = layout; - if (size === 0) return undefined; - else if (size === 1) return new Uint8Array(buffer, offset); - else if (size === 2) return new Uint16Array(buffer, offset); - else if (size === 4) return new Uint32Array(buffer, offset); - else if (size === 8) return new Float64Array(buffer, offset); - else throw Error("unknown layout"); -}; - -/** Memory representation for a type. */ -interface Meta { - /** Layout of an individual value of this type in memory. */ - layout: Layout; - - /** - * Return the Wasm representation of the JS value `x`. - * - * The given byte offset is only used for pointer types. - */ - encode: (x: unknown, offset: number) => number; - - /** Total memory cost of an object of this type, including sub-allocations. */ - cost: number; - - /** Return a JS value represented by the Wasm value `x`. */ - decode: (x: number) => unknown; -} - -/** - * Return enough information to encode and decode Wasm values with type ID `t`. - * - * The given function `f` must have already been compiled to WebAssembly, - * yielding the given `buffer` of sufficient size. The `metas` array should hold - * encoding/decoding information for all types with IDs less than `t`, or - * `undefined` for reference types and non-struct tuple types since those cannot - * appear in user-facing function signatures. - */ -const getMeta = ( - f: Fn, - buffer: ArrayBuffer, - metas: (Meta | undefined)[], - t: number, -): Meta | undefined => { - const func = f[inner]; - if (func.isUnit(t)) { - return { - layout: { size: 0, align: 1 }, - encode: () => 0, - cost: 0, - decode: () => null, - }; - } else if (func.isBool(t)) { - return { - layout: { size: 1, align: 1 }, - encode: (x) => (x ? 1 : 0), - cost: 0, - decode: Boolean, - }; - } else if (func.isF64(t)) { - return { - layout: { size: 8, align: 8 }, - encode: (x) => x as number, - cost: 0, - decode: (x) => x, - }; - } else if (func.isFin(t)) { - const size = func.size(t); - const layout = - size <= 1 - ? { size: 0, align: 1 } - : size <= 256 - ? { size: 1, align: 1 } - : size <= 65536 - ? { size: 2, align: 2 } - : { size: 4, align: 4 }; - return { layout, encode: (x) => x as number, cost: 0, decode: (x) => x }; - } else if (func.isArray(t)) { - const meta = metas[func.elem(t)]; - if (meta === undefined) return undefined; - const { layout, encode, cost, decode } = meta; - const n = func.size(func.index(t)); - const elem = aligned(layout); - const total = aligned({ size: n * elem, align: 8 }); - const view = getView(buffer, layout, 0); - return { - layout: { size: 4, align: 4 }, - encode: - view === undefined - ? (x, offset) => offset - : (x, offset) => { - let child = offset + total; - for (let i = 0; i < n; ++i) { - view[offset / elem + i] = encode((x as unknown[])[i], child); - child += cost; - } - return offset; - }, - cost: total + n * cost, - decode: - view === undefined - ? () => { - const arr: unknown[] = []; - // this code assumes that all values of all zero-sized types can - // be represented by zero - for (let i = 0; i < n; ++i) arr.push(decode(0)); - return arr; - } - : (x) => { - const arr: unknown[] = []; - for (let i = 0; i < n; ++i) arr.push(decode(view[x / elem + i])); - return arr; - }, - }; - } else if (func.isStruct(t)) { - const keys = func.keys(t); - const members = func.mems(t); - const n = keys.length; - const mems: { key: string; meta: Meta; view?: View; child?: number }[] = []; - for (let i = 0; i < n; ++i) { - const meta = metas[members[i]]; - if (meta === undefined) return undefined; - mems.push({ key: f[strings][keys[i]], meta }); - } - mems.sort((a, b) => a.meta.layout.align - b.meta.layout.align); - let cost = 0; - let offset = 0; - for (const mem of mems) { - const { meta } = mem; - mem.child = cost; - cost += meta.cost; - const { layout } = meta; - const { size, align } = layout; - offset = aligned({ size: offset, align }); - mem.view = getView(buffer, layout, offset); - offset += size; - } - const total = aligned({ size: offset, align: 8 }); - return { - layout: { size: 4, align: 4 }, - encode: (x, offset) => { - for (const { key, meta, view, child } of mems) { - // instead of mutating each element of `mems` above to add more data - // and then still having an `if` statement in here, it would be nicer - // to just map over `mems` above to produce an array of closures that - // can be called directly, with the condition on `view === undefined` - // being handled once rather than in every call to `encode` here - if (view !== undefined) { - view[offset / aligned(meta.layout)] = meta.encode( - (x as any)[key], - offset + total + child!, - ); - } - } - return offset; - }, - cost: total + cost, - decode: (x) => { - const obj: any = {}; - for (const { key, meta, view } of mems) { - if (view === undefined) { - // this code assumes that all values of all zero-sized types can be - // represented by zero - obj[key] = meta.decode(0); - } else { - obj[key] = meta.decode(view[x / aligned(meta.layout)]); - } - } - return obj; - }, - }; - } else return undefined; -}; - -/** Concretize the abstract function `f` using the compiler. */ -export const compile = async ( - f: Fn & ((...args: A) => R), -): Promise<(...args: JsArgs) => ToJs> => { - const func = f[inner]; - const res = func.compile(); - const bytes = res.bytes()!; - const imports = res.imports()!; - res.free(); - const instance = await WebAssembly.instantiate( - await WebAssembly.compile(bytes), - { "": Object.fromEntries(imports.map((g, i) => [i.toString(), g])) }, - ); - const { f: g, m } = instance.exports; - const metas: (Meta | undefined)[] = []; - const n = func.numTypes(); - for (let t = 0; t < n; ++t) - metas.push(getMeta(f, (m as WebAssembly.Memory).buffer, metas, t)); - let total = 0; - const params = Array.from(func.paramTypes()).map((t) => { - const { encode, cost } = metas[t]!; - const offset = total; - total += cost; - return { encode, offset }; - }); - const { decode } = metas[func.retType()]!; - return (...args): any => { - const vals = params.map(({ encode, offset }, i) => encode(args[i], offset)); - return decode((g as any)(...vals, total)); - }; -}; - -// https://www.typescriptlang.org/docs/handbook/2/conditional-types.html -type ToJvp = [T] extends [Null] - ? Null - : [T] extends [Bool] - ? Bool - : [T] extends [Real] - ? { re: Real; du: Real } - : [T] extends [Nat] - ? Nat - : { [K in keyof T]: ToJvp }; - -type JvpArgs = { - [K in keyof T]: ToJvp; -}; - -/** Construct a function that computes the Jacobian-vector product of `f`. */ -export const jvp = ( - f: Fn & ((...args: A) => R), -): Fn & ((...args: JvpArgs) => ToJvp) => { - const strs = [...f[strings]]; - const func = f[inner].jvp(reId, duId); - const g: any = (...args: any): any => - // TODO: support generics - call(g, new Uint32Array(), args); - g[inner] = func; - g[strings] = strs; - return makeFn(g) as any; -}; - -/** Construct a closure that computes the vector-Jacobian product of `f`. */ -export const vjp = ( - f: Fn & ((arg: A) => R), -): ((arg: A) => { ret: R; grad: (cot: R) => A }) => { - const g = jvp(f); - const tp = g[inner].transpose(); - const fwdFunc = tp.fwd()!; - const bwdFunc = tp.bwd()!; - tp.free(); - const fwd = makeFn({ [inner]: fwdFunc, [strings]: [...f[strings]] }); - const bwd = makeFn({ [inner]: bwdFunc, [strings]: [...f[strings]] }); - return (arg: A) => { - const ctx = getCtx(); - const strs = intern(ctx, fwd[strings]); - const generics = new Uint32Array(); // TODO: support generics - const [tArg, tBundle] = ctx.func.ingest(fwd[inner], strs, generics); - const [tRet, tInter] = ctx.func.members(tBundle); - const argId = valId(ctx, tArg, arg); - const bundleId = ctx.block.call( - ctx.func, - fwd[inner], - generics, - tBundle, - new Uint32Array([argId]), - ); - const primalId = ctx.block.member(ctx.func, tRet, bundleId, 0); - const interId = ctx.block.member(ctx.func, tInter, bundleId, 1); - const grad = (cot: R) => { - if (getCtx() !== ctx) throw Error("VJP closure escaped its context"); - const cotId = valId(ctx, tRet, cot); - const tRef = ctx.func.tyRef(tArg); - const accId = ctx.block.accum(ctx.func, tRef, argId); - ctx.block.call( - ctx.func, - bwd[inner], - new Uint32Array([]), // TODO: support generics - ctx.func.tyUnit(), - new Uint32Array([accId, cotId, interId]), - ); - return idVal(ctx, tArg, ctx.block.resolve(ctx.func, tArg, accId)) as A; - }; - return { ret: idVal(ctx, tRet, primalId) as R, grad }; - }; -}; - -/** Return the variable ID for the abstract boolean `x`. */ -const boolId = (ctx: Context, x: Bool): number => - valId(ctx, ctx.func.tyBool(), x); - -/** Return the negation of the abstract boolean `p`. */ -export const not = (p: Bool): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.not(ctx.func, boolId(ctx, p))); -}; - -/** Return the conjunction of the abstract booleans `p` and `q`. */ -export const and = (p: Bool, q: Bool): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.and(ctx.func, boolId(ctx, p), boolId(ctx, q))); -}; - -/** Return the disjunction of the abstract booleans `p` and `q`. */ -export const or = (p: Bool, q: Bool): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.or(ctx.func, boolId(ctx, p), boolId(ctx, q))); -}; - -/** Return the biconditional of the abstract booleans `p` and `q`. */ -export const iff = (p: Bool, q: Bool): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.iff(ctx.func, boolId(ctx, p), boolId(ctx, q))); -}; - -/** Return the exclusive disjunction of the abstract booleans `p` and `q`. */ -export const xor = (p: Bool, q: Bool): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.xor(ctx.func, boolId(ctx, p), boolId(ctx, q))); -}; - -/** Return an abstract value selecting between `then` and `els` via `cond`. */ -export const select = ( - cond: Bool, - ty: T, - then: Value, - els: Value, -): Symbolic => { - const ctx = getCtx(); - const t = tyId(ctx, ty); - const p = boolId(ctx, cond); - const a = valId(ctx, t, then); - const b = valId(ctx, t, els); - return idVal(ctx, t, ctx.block.select(ctx.func, p, t, a, b)) as Symbolic; -}; - -/** Return the variable ID for the abstract floating point number `x`. */ -const realId = (ctx: Context, x: Real): number => - valId(ctx, ctx.func.tyF64(), x); - -/** Return the absolute value of the abstract number `x`. */ -export const abs = (x: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.abs(ctx.func, realId(ctx, x))); -}; - -/** Return the signum of the abstract number `x`. */ -export const sign = (x: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.sign(ctx.func, realId(ctx, x))); -}; - -/** Return the ceiling of the abstract number `x`. */ -export const ceil = (x: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.ceil(ctx.func, realId(ctx, x))); -}; - -/** Return the floor of the abstract number `x`. */ -export const floor = (x: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.floor(ctx.func, realId(ctx, x))); -}; - -/** Return the truncation of the abstract number `x`. */ -export const trunc = (x: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.trunc(ctx.func, realId(ctx, x))); -}; - -/** Return the square root of the abstract number `x`. */ -export const sqrt = (x: Real): Real => { - const ctx = getCtx(); - return newVar(ctx.block.sqrt(ctx.func, realId(ctx, x))); -}; - -/** Return an abstract boolean for if `x` is not equal to `y`. */ -export const neq = (x: Real, y: Real): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.neq(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - -/** Return an abstract boolean for if `x` is less than `y`. */ -export const lt = (x: Real, y: Real): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.lt(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - -/** Return an abstract boolean for if `x` is less than or equal to `y`. */ -export const leq = (x: Real, y: Real): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.leq(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - -/** Return an abstract boolean for if `x` is equal to `y`. */ -export const eq = (x: Real, y: Real): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.eq(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - -/** Return an abstract boolean for if `x` is greater than `y`. */ -export const gt = (x: Real, y: Real): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.gt(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - -/** Return an abstract boolean for if `x` is greater than or equal to `y`. */ -export const geq = (x: Real, y: Real): Bool => { - const ctx = getCtx(); - return newVar(ctx.block.geq(ctx.func, realId(ctx, x), realId(ctx, y))); -}; - -/** Return an abstract vector by computing each element via `f`. */ -export const vec = ( - index: I, - elem: T, - f: (i: Symbolic) => Value, -): Vec> => { - const ctx = getCtx(); - const i = tyId(ctx, index); - const e = tyId(ctx, elem); - const t = ctx.func.tyArray(i, e); - const arg = ctx.func.bind(i); - const block = ctx.block; - let out: number | undefined = undefined; - const body = new wasm.Block(); - try { - ctx.block = body; - out = valId(ctx, e, f(idVal(ctx, i, arg) as Symbolic)); - } finally { - if (out === undefined) body.free(); - ctx.block = block; - } - const id = block.vec(ctx.func, t, arg, body, out); - return idVal(ctx, t, id) as Vec>; -}; - -/** Return the variable ID for the abstract number or tangent `x`. */ -const numId = (ctx: Context, x: Real | Tan): number => { - if (typeof x === "object") return (x as any)[variable]; - let t = x === zeroSymbol ? ctx.func.tyT64() : ctx.func.tyF64(); - const map = typeMap(ctx, t); - let id = map.get(x); - if (id !== undefined) return id; // constant, so can't be out of scope - if (!(typeof x === "number")) throw Error("invalid value"); - id = ctx.func.num(t, x); - map.set(x, id); - return id; -}; - -/** Return the zero tangent. */ -export const zero = (): Tan => { - const ctx = getCtx(); - const t = ctx.func.tyT64(); - typeMap(ctx, t).set(zeroSymbol, ctx.func.num(t, 0)); - return zeroSymbol; -}; - -/** Return the negative of the abstract number `x`. */ -export const neg: { - (x: Real): Real; - (x: Tan): Tan; -} = (x: Real | Tan): Var => { - const ctx = getCtx(); - return newVar(ctx.block.neg(ctx.func, numId(ctx, x))); -}; - -/** Return the abstract number `x` plus the abstract number `y`. */ -export const add: { - (x: Real, y: Real): Real; - (x: Tan, y: Tan): Tan; -} = (x: Real | Tan, y: Real | Tan): Var => { - const ctx = getCtx(); - return newVar(ctx.block.add(ctx.func, numId(ctx, x), numId(ctx, y))); -}; - -/** Return the abstract number `x` minus the abstract number `y`. */ -export const sub: { - (x: Real, y: Real): Real; - (x: Tan, y: Tan): Tan; -} = (x: Real | Tan, y: Real | Tan): Var => { - const ctx = getCtx(); - return newVar(ctx.block.sub(ctx.func, numId(ctx, x), numId(ctx, y))); -}; - -/** Return the abstract number `x` times the abstract number `y`. */ -export const mul: { - (x: Real, y: Real): Real; - (x: Tan, y: Real): Tan; -} = (x: Real | Tan, y: Real): Var => { - const ctx = getCtx(); - return newVar(ctx.block.mul(ctx.func, numId(ctx, x), numId(ctx, y))); -}; - -/** Return the abstract number `x` divided by the abstract number `y`. */ -export const div: { - (x: Real, y: Real): Real; - (x: Tan, y: Real): Tan; -} = (x: Real | Tan, y: Real): Var => { - const ctx = getCtx(); - return newVar(ctx.block.div(ctx.func, numId(ctx, x), numId(ctx, y))); -}; diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 4528f64..d2afc6a 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -1,945 +1,6 @@ -import { describe, expect, test } from "vitest"; -import { - Bool, - Dual, - Null, - Real, - Vec, - abs, - add, - and, - ceil, - compile, - div, - floor, - fn, - gt, - iff, - interp, - jvp, - mul, - neg, - not, - opaque, - or, - select, - sign, - sqrt, - struct, - sub, - trunc, - vec, - vjp, - xor, - zero, -} from "./index.js"; +import { expect, test } from "vitest"; +import { add } from "./index.js"; -describe("invalid", () => { - test("undefined", () => { - expect(() => fn([], Real, () => undefined as any)).toThrow("invalid value"); - }); - - test("bigint", () => { - expect(() => fn([], Real, () => 0n as any)).toThrow("invalid value"); - }); - - test("string", () => { - expect(() => fn([], Real, () => "hello" as any)).toThrow("invalid value"); - }); - - test("literal return type", () => { - expect(() => fn([], Real, () => null as any)).toThrow( - "did not expect null", - ); - }); - - test("add argument type", () => { - const two = true as any; - expect(() => fn([], Real, () => add(two, two))).toThrow("invalid value"); - }); - - test("invalid index type", () => { - expect(() => - fn([Vec(Real, Null), Real], Null, (v, i) => v[i as any]), - ).toThrow("index type cannot be used as an index"); - }); - - test("symbolic array dimension", () => { - expect(() => fn([Vec(3, Real)], Vec(2, Real), (v) => v)).toThrow( - "variable type mismatch", - ); - }); - - test("literal array dimension", () => { - expect(() => fn([], Vec(2, Real), () => [1, 2, 3])).toThrow( - "wrong array size", - ); - }); - - test("out of bounds index", () => { - expect(() => fn([Vec(2, Real)], Real, (v) => v[2])).toThrow("out of range"); - }); - - test("access index out of scope", () => { - const n = 2; - expect(() => - fn([], n, () => { - let i: number | symbol = 0; - vec(n, Real, (j) => { - i = j; - return 42; - }); - return i; - }), - ).toThrow("variable is out of scope"); - }); - - test("wrong struct member names", () => { - expect(() => - fn([{ a: Real, b: Real }], { b: Real, c: Real }, (x) => x as any), - ).toThrow("variable type mismatch"); - }); -}); - -describe("valid", () => { - test("return null", () => { - const f = fn([], Null, () => null); - const g = interp(f); - expect(g()).toBe(null); - }); - - test("interp with null", () => { - const f = fn([Null], Null, (x) => x); - const g = interp(f); - expect(g(null)).toBe(null); - }); - - test("not", () => { - const f = fn([Bool], Bool, (p) => not(p)); - const g = interp(f); - expect(g(true)).toBe(false); - expect(g(false)).toBe(true); - }); - - test("2 + 2 = 4", () => { - const f = fn([Real, Real], Real, (x, y) => add(x, y)); - const g = interp(f); - expect(g(2, 2)).toBe(4); - }); - - test("basic arithmetic", () => { - const f = fn([], Real, () => add(2, sub(mul(3, 2), div(2, 1)))); - const g = interp(f); - expect(g()).toBe(6); - }); - - test("absolute value", () => { - const f = fn([Real], Real, (x) => abs(x)); - const g = interp(f); - expect(g(-2)).toBe(2); - expect(g(-0)).toBe(0); - expect(g(0)).toBe(0); - expect(g(2)).toBe(2); - }); - - test("signum", () => { - const f = fn([Real], Real, (x) => sign(x)); - const g = interp(f); - expect(g(-2)).toBe(-1); - expect(g(-0)).toBe(-1); - expect(g(0)).toBe(1); - expect(g(2)).toBe(1); - }); - - test("ceiling", () => { - const f = fn([Real], Real, (x) => ceil(x)); - const g = interp(f); - expect(g(-1.5)).toBe(-1); - expect(g(-1)).toBe(-1); - expect(g(-0.5)).toBe(-0); - expect(g(-0)).toBe(-0); - expect(g(0)).toBe(0); - expect(g(0.5)).toBe(1); - expect(g(1)).toBe(1); - expect(g(1.5)).toBe(2); - }); - - test("floor", () => { - const f = fn([Real], Real, (x) => floor(x)); - const g = interp(f); - expect(g(-1.5)).toBe(-2); - expect(g(-1)).toBe(-1); - expect(g(-0.5)).toBe(-1); - expect(g(-0)).toBe(-0); - expect(g(0)).toBe(0); - expect(g(0.5)).toBe(0); - expect(g(1)).toBe(1); - expect(g(1.5)).toBe(1); - }); - - test("truncate", () => { - const f = fn([Real], Real, (x) => trunc(x)); - const g = interp(f); - expect(g(-1.5)).toBe(-1); - expect(g(-1)).toBe(-1); - expect(g(-0.5)).toBe(-0); - expect(g(-0)).toBe(-0); - expect(g(0)).toBe(0); - expect(g(0.5)).toBe(0); - expect(g(1)).toBe(1); - expect(g(1.5)).toBe(1); - }); - - test("square root", () => { - const f = fn([Real], Real, (x) => sqrt(x)); - const g = interp(f); - expect(g(Math.PI)).toBe(1.7724538509055159); - }); - - test("select", () => { - const f = fn([Bool], Real, (x) => select(x, Real, 1, 2)); - const g = interp(f); - expect(g(false)).toBe(2); - expect(g(true)).toBe(1); - }); - - test("call", () => { - const ifCond = fn([Bool, Real, Real], Real, (p, x, y) => - select(p, Real, x, y), - ); - const f = fn([Real], Real, (x) => ifCond(gt(x, 0), x, 0)); - const relu = interp(f); - expect(relu(-2)).toBe(0); - expect(relu(-0)).toBe(0); - expect(relu(0)).toBe(0); - expect(relu(2)).toBe(2); - }); - - test("empty boolean array", () => { - const f = fn([], Vec(0, Bool), () => []); - const g = interp(f); - expect(g()).toEqual([]); - }); - - test("empty real array", () => { - const f = fn([], Vec(0, Real), () => []); - const g = interp(f); - expect(g()).toEqual([]); - }); - - test("dot product", () => { - const R3 = Vec(3, Real); - const dot = fn([R3, R3], Real, (u, v) => { - const x = mul(u[0], v[0]); - const y = mul(u[1], v[1]); - const z = mul(u[2], v[2]); - return add(add(x, y), z); - }); - const f = interp(dot); - expect(f([1, 3, -5], [4, -2, -1])).toBe(3); - }); - - test("cross product", () => { - const R3 = Vec(3, Real); - const cross = fn([R3, R3], R3, (u, v) => { - const x = sub(mul(u[1], v[2]), mul(u[2], v[1])); - const y = sub(mul(u[2], v[0]), mul(u[0], v[2])); - const z = sub(mul(u[0], v[1]), mul(u[1], v[0])); - return [x, y, z]; - }); - const f = interp(cross); - expect(f([3, -3, 1], [4, 9, 2])).toEqual([-15, -2, 39]); - }); - - test("index array", () => { - const n = 3; - const f = fn([Vec(n, n), Vec(n, Real)], Vec(n, Real), (i, v) => - vec(n, Real, (j) => v[i[j]]), - ); - const g = fn([], Vec(n, Real), () => { - const v = [2, 0, 1]; - return f(v, v); - }); - const h = interp(g); - expect(h()).toEqual([1, 2, 0]); - }); - - test("interp with index value", () => { - const n = 1; - const f = fn([n, Vec(n, Bool)], Bool, (i, v) => v[i]); - const g = interp(f); - expect(g(0, [true])).toBe(true); - }); - - test("matrix multiplication", async () => { - const n = 6; - - const Rn = Vec(n, Real); - - const dot = fn([Rn, Rn], Real, (u, v) => { - const w = vec(n, Real, (i) => mul(u[i], v[i])); - let s = w[0]; - s = add(s, w[1]); - s = add(s, w[2]); - s = add(s, w[3]); - s = add(s, w[4]); - s = add(s, w[5]); - return s; - }); - - const m = 5; - const p = 7; - - const Rp = Vec(p, Real); - - const Rmxn = Vec(m, Rn); - const Rnxp = Vec(n, Rp); - const Rmxp = Vec(m, Rp); - - const mmul = fn([Rmxn, Rnxp], Rmxp, (a, b) => - vec(m, Rp, (i) => { - const u = a[i]; - return vec(p, Real, (j) => { - const v = vec(n, Real, (k) => b[k][j]); - return dot(u, v); - }); - }), - ); - - const f = await compile(mmul); - expect( - f( - [ - [-8, 5, 3, -1, 8, 0], - [-3, -1, 7, -7, 8, 3], - [-4, 5, 5, 5, 8, 6], - [1, -9, 5, 4, 4, 0], - [9, -3, 1, 3, -5, -5], - ], - [ - [-7, 9, 6, -8, 5, 8, -3], - [3, -6, 8, 0, 7, -4, -1], - [-4, 9, 9, 1, -8, -4, 0], - [6, -7, 6, -6, -8, -5, 0], - [8, 5, 3, 0, 6, 3, -7], - [-4, 0, -5, -9, 8, -9, -1], - ], - ), - ).toEqual([ - [117, -28, 37, 73, 27, -67, -37], - [0, 131, 4, 46, 50, -16, -49], - [93, -16, 85, -47, 31, -127, -55], - [2, 100, 15, -27, -106, 16, -22], - [-78, 62, 67, -44, -78, 95, 16], - ]); - }); - - test("singleton array from index", () => { - const One = Vec(1, 1); - const f = fn([], One, () => vec(1, One, (i) => [i])[0]); - const g = interp(f); - expect(g()).toEqual([0]); - }); - - test("struct", () => { - const Pair = struct({ x: Real, y: Real }); - const f = fn([Pair], Real, (p) => sub(p.y, p.x)); - const g = fn([Real, Real], Pair, (x, y) => ({ y, x })); - const h = interp(fn([Real, Real], Real, (x, y) => f(g(x, y)))); - expect(h(3, 5)).toBe(2); - }); - - test("return struct", () => { - const f = fn([], { p: Bool, x: Real }, () => ({ p: true, x: 42 })); - const g = interp(f); - expect(g()).toEqual({ p: true, x: 42 }); - }); - - test("select struct", () => { - const f = fn([], Real, () => { - return select(false, { x: Real }, { x: 3 }, { x: 5 }).x; - }); - const g = interp(f); - expect(g()).toBe(5); - }); - - test("array of structs", () => { - const n = 2; - const Indexed = struct({ i: n, x: Real }); - const f = fn([Vec(n, Real)], Vec(n, Indexed), (v) => - vec(n, Indexed, (i) => ({ i, x: v[i] })), - ); - const g = interp(f); - expect(g([3, 5])).toEqual([ - { i: 0, x: 3 }, - { i: 1, x: 5 }, - ]); - }); - - test("internal array of structs", () => { - const n = 2; - const f = fn([Vec(n, Real)], Vec(n, n), (v) => { - const u = vec(n, { i: n }, (i) => ({ i })); - return vec(n, n, (i) => u[i].i); - }); - const g = interp(f); - expect(g([3, 5])).toEqual([0, 1]); - }); - - test("interp struct arg", () => { - const f = fn([{ x: Real }], Real, (p) => p.x); - const g = interp(f); - expect(g({ x: 42 })).toBe(42); - }); - - test("opaque unary function", () => { - const log = opaque([Real], Real, Math.log); - const f = interp(log); - expect(f(Math.PI)).toBe(1.1447298858494002); - }); - - test("opaque binary function", () => { - const pow = opaque([Real, Real], Real, Math.pow); - const f = interp(pow); - expect(f(Math.E, Math.PI)).toBe(23.140692632779263); - }); - - test("JVP", () => { - const f = fn([Real], Real, (x) => mul(x, x)); - const g = jvp(f); - const h = interp(g); - expect(h({ re: 3, du: 1 })).toEqual({ re: 9, du: 6 }); - }); - - test("JVP with sharing in call graph", async () => { - let f = fn([Real], Real, (x) => x); - for (let i = 0; i < 20; ++i) { - f = fn([Real], Real, (x) => add(f(x), f(x))); - } - const g = await compile(jvp(f)); - expect(g({ re: 2, du: 3 })).toEqual({ re: 2097152, du: 3145728 }); - }); - - test("custom JVP", () => { - const max = fn([Real, Real], Real, (x, y) => select(gt(x, y), Real, x, y)); - const f = fn([Real], Real, (x) => sqrt(x)); - const epsilon = 1e-5; - f.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = f(x); - return { re: y, du: mul(dx, div(1 / 2, max(epsilon, y))) }; - }); - const g = fn([Real, Real], Real, (x, y) => vjp(f)(x).grad(y)); - expect(interp(g)(0, 1)).toBeCloseTo(50000); - }); - - test("custom JVP with zero tangent", () => { - const signum = opaque([Real], Real, Math.sign); - signum.jvp = fn([Dual], Dual, ({ re: x }) => ({ re: sign(x), du: zero() })); - const f = interp(jvp(signum)); - expect(f({ re: 2, du: 1 })).toEqual({ re: 1, du: 0 }); - }); - - test("VJP", () => { - const f = fn([Vec(2, Real)], Real, ([x, y]) => mul(x, y)); - const g = fn([], Vec(3, Real), () => { - const { ret: x, grad } = vjp(f)([2, 3]); - const v = grad(1); - return [x, v[0], v[1]]; - }); - expect(interp(g)()).toEqual([6, 3, 2]); - }); - - test("VJP with sharing in call graph", async () => { - const iterate = ( - n: number, - f: (x: Real) => Real, - g: (x: Real, y: Real) => Real, - ) => { - for (let i = 0; i < n; ++i) { - const f0 = f; - const g0 = g; - g = fn([Real, Real], Real, (x, y) => g0(f0(x), f0(y))); - const g1 = g; - f = fn([Real], Real, (x) => g1(x, x)); - } - return f; - }; - - const f = iterate( - 12, - (x) => sqrt(x), - (x, y) => mul(x, y), - ); - const g = vjp(fn([Real], Real, (x) => f(x))); - const h = fn([Real, Real], Vec(2, Real), (x, y) => { - const { ret, grad } = g(x); - return [ret, grad(y)]; - }); - const v = (await compile(h))(2, 3); - expect(v[0]).toBeCloseTo(2); - expect(v[1]).toBeCloseTo(3); - }); - - test("VJP with struct and select", () => { - const Stuff = struct({ a: Null, b: Bool, c: Real }); - const f = fn([Stuff], Real, ({ b, c }) => - select(or(false, not(b)), Real, c, 2), - ); - const g = fn([Bool, Real], { x: Real, stuff: Stuff }, (b, c) => { - const { ret: x, grad } = vjp(f)({ a: null, b, c }); - return { x, stuff: grad(3) }; - }); - const h = interp(g); - expect(h(true, 5)).toEqual({ x: 2, stuff: { a: null, b: true, c: 0 } }); - expect(h(false, 7)).toEqual({ x: 7, stuff: { a: null, b: false, c: 3 } }); - }); - - test("VJP with logic", () => { - const f = fn([Bool], Bool, (p) => not(p)); - const g = fn([Bool], Bool, (p) => vjp(f)(p).ret); - const h = interp(g); - expect(h(true)).toBe(false); - expect(h(false)).toBe(true); - }); - - test("VJP with select on null", () => { - const f = fn([Null], Null, () => select(true, Null, null, null)); - const g = fn([], Null, () => vjp(f)(null).ret); - const h = interp(g); - expect(h()).toBe(null); - }); - - test("VJP with select on booleans", () => { - const f = fn([Bool], Bool, (p) => select(p, Bool, false, true)); - const g = fn([Bool], Bool, (p) => vjp(f)(p).ret); - const h = interp(g); - expect(h(true)).toBe(false); - expect(h(false)).toBe(true); - }); - - test("VJP with select on indices", () => { - const n = 2; - const f = fn([Bool], n, (p) => select(p, n, 0, 1)); - const g = fn([Bool], n, (p) => vjp(f)(p).ret); - const h = interp(g); - expect(h(true)).toBe(0); - expect(h(false)).toBe(1); - }); - - test("VJP with vector comprehension", () => { - const n = 2; - const f = fn([Vec(n, Real)], Vec(n, Real), (v) => - vec(n, Real, (i) => mul(v[i], v[i])), - ); - const g = fn([Vec(n, Real), Vec(n, Real)], Vec(n, Real), (u, v) => - vjp(f)(u).grad(v), - ); - expect(interp(g)([2, 3], [5, 7])).toEqual([20, 42]); - }); - - test("VJP twice", () => { - const f = fn([Real], Real, (x) => { - const y = mul(x, x); - return mul(x, y); - }); - const g = fn([Real], Real, (x) => vjp(f)(x).grad(1)); - const h = fn([Real], Real, (x) => vjp(g)(x).grad(1)); - expect(interp(h)(10)).toBe(60); - }); - - test("Hessian", () => { - const powi = (x: Real, n: number): Real => { - if (!Number.isInteger(n)) - throw new Error(`exponent is not an integer: ${n}`); - // https://en.wikipedia.org/wiki/Exponentiation_by_squaring - if (n < 0) return powi(div(1, x), -n); - else if (n == 0) return 1; - else if (n == 1) return x; - else if (n % 2 == 0) return powi(mul(x, x), n / 2); - else return mul(x, powi(mul(x, x), (n - 1) / 2)); - }; - const f = fn([Vec(2, Real)], Real, ([x, y]) => - sub(sub(powi(x, 3), mul(2, mul(x, y))), powi(y, 6)), - ); - const g = fn([Vec(2, Real)], Vec(2, Real), (v) => vjp(f)(v).grad(1)); - const h = fn([Vec(2, Real)], Vec(2, Vec(2, Real)), (v) => { - const { grad } = vjp(g)(v); - return [grad([1, 0] as any), grad([0, 1] as any)]; - }); - expect(interp(h)([1, 2])).toEqual([ - [6, -2], - [-2, -480], - ]); - }); - - test("VJP twice with struct", () => { - const Pair = struct({ x: Real, y: Real }); - const f = fn([Pair], Real, ({ x, y }) => mul(x, y)); - const g = fn([Pair], Pair, (p) => vjp(f)(p).grad(1)); - const h = fn([Pair, Pair], Pair, (p, q) => vjp(g)(p).grad(q)); - expect(interp(h)({ x: 2, y: 3 }, { x: 5, y: 7 })).toEqual({ x: 7, y: 5 }); - }); - - test("VJP twice with select", () => { - const Stuff = struct({ p: Bool, x: Real, y: Real, z: Real }); - const f = fn([Stuff], Real, ({ p, x, y, z }) => - mul(z, select(p, Real, x, y)), - ); - const g = fn([Stuff], Stuff, (p) => vjp(f)(p).grad(1)); - const h = fn([Stuff, Stuff], Stuff, (p, q) => vjp(g)(p).grad(q)); - expect( - interp(h)( - { p: true, x: 2, y: 3, z: 5 }, - { p: false, x: 7, y: 11, z: 13 }, - ), - ).toEqual({ p: true, x: 13, y: 0, z: 7 }); - }); - - test("opaque functions with derivatives", async () => { - const grad = (f: any) => fn([Real], Real, (x) => vjp(f)(x).grad(1) as Real); - - const sin = opaque([Real], Real, Math.sin); - const cos = opaque([Real], Real, Math.cos); - - sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: sin(x), du: mul(dx, cos(x)) }; - }); - cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - return { re: cos(x), du: mul(dx, neg(sin(x))) }; - }); - - let f = sin; - expect(interp(f)(1)).toBeCloseTo(Math.sin(1)); - f = grad(f); - expect((await compile(f))(1)).toBeCloseTo(Math.cos(1)); - f = grad(f); - expect(interp(f)(1)).toBeCloseTo(-Math.sin(1)); - - f = cos; - expect((await compile(f))(1)).toBeCloseTo(Math.cos(1)); - f = grad(f); - expect(interp(f)(1)).toBeCloseTo(-Math.sin(1)); - f = grad(f); - expect((await compile(f))(1)).toBeCloseTo(-Math.cos(1)); - }); - - test("compile", async () => { - const f1 = fn([Real], Real, (x) => sqrt(x)); - const f2 = fn([Real, Real], Real, (x, y) => mul(x, f1(y))); - const f3 = fn([Real, Real], Real, (x, y) => mul(f1(x), y)); - const f = fn([Real, Real], Real, (x, y) => sub(f2(x, y), f3(x, y))); - const g = await compile(f); - expect(g(2, 3)).toBeCloseTo(-0.7785390719815313); - }); - - test("compile opaque function", async () => { - const f = opaque([Real], Real, Math.sin); - const g = await compile(f); - expect(g(1)).toBeCloseTo(Math.sin(1)); - }); - - test("compile calls to multiple opaque functions", async () => { - const sin = opaque([Real], Real, Math.sin); - const cos = opaque([Real], Real, Math.cos); - const f = fn([Real], Real, (x) => sub(sin(x), cos(x))); - const g = await compile(f); - expect(g(1)).toBeCloseTo(Math.sin(1) - Math.cos(1)); - }); - - test("compile opaque and transparent calls together", async () => { - const log = opaque([Real], Real, Math.log); - const f = fn([Real], Real, (x) => add(log(x), sqrt(x))); - const g = fn([Real], Real, (x) => add(f(x), x)); - const h = await compile(g); - expect(h(1)).toBeCloseTo(Math.log(1) + Math.sqrt(1) + 1); - }); - - test("compile array", async () => { - const f = fn([Vec(2, Real)], Real, ([x, y]) => mul(x, y)); - const g = fn([Real, Real], Real, (x, y) => f([x, y])); - const h = await compile(g); - expect(h(2, 3)).toBe(6); - }); - - test("compile null array", async () => { - const f = fn([Vec(2, Null)], Null, (v) => v[1]); - const g = fn([], Real, () => { - f([null, null]); - return 42; - }); - const h = await compile(g); - expect(h()).toBe(42); - }); - - test("compile struct", async () => { - const f = fn([{ x: Real, y: Real }], Real, ({ x, y }) => mul(x, y)); - const g = fn([Real, Real], Real, (x, y) => f({ x, y })); - const h = await compile(g); - expect(h(2, 3)).toBe(6); - }); - - test("compile logic", async () => { - const f = fn([Vec(3, Bool)], Bool, ([p, q, r]) => - iff(and(or(p, not(q)), xor(r, q)), or(not(p), and(q, r))), - ); - const g = fn([Bool, Bool, Bool], Real, (p, q, r) => - select(f([p, q, r]), Real, -1, -2), - ); - const h = await compile(g); - expect(h(true, true, true)).toBe(-2); - expect(h(true, true, false)).toBe(-2); - expect(h(true, false, true)).toBe(-2); - expect(h(true, false, false)).toBe(-1); - expect(h(false, true, true)).toBe(-2); - expect(h(false, true, false)).toBe(-2); - expect(h(false, false, true)).toBe(-1); - expect(h(false, false, false)).toBe(-2); - }); - - test("compile signum", async () => { - const f = fn([Real], Real, (x) => sign(x)); - const g = await compile(f); - expect(g(-2)).toBe(-1); - expect(g(-0)).toBe(-1); - expect(g(0)).toBe(1); - expect(g(2)).toBe(1); - }); - - test("compile select", async () => { - const f = fn([Bool, Real, Real], Real, (p, x, y) => select(p, Real, x, y)); - const g = await compile(f); - expect(g(true, 2, 3)).toBe(2); - expect(g(false, 5, 7)).toBe(7); - }); - - test("compile vector comprehension", async () => { - const f = fn([Real, Vec(3, Real)], Vec(3, Real), (c, v) => - vec(3, Real, (i) => mul(c, v[i])), - ); - const g = fn([Real, Real, Real, Real], Real, (c, x, y, z) => { - const v = f(c, [x, y, z]); - return add(add(v[0], v[1]), v[2]); - }); - const h = await compile(g); - expect(h(2, 3, 5, 7)).toBe(30); - }); - - test("compile empty vector comprehension", async () => { - let i = 0; - const f = opaque([], Real, () => { - ++i; - return i; - }); - const g = fn([], Real, () => { - vec(0, Real, () => f()); - return 0; - }); - (await compile(g))(); - expect(i).toEqual(0); - }); - - test("compile VJP", async () => { - const f = fn( - [Vec(2, struct({ p: Bool, x: Real }))], - { p: Vec(2, Bool), x: Vec(2, Real) }, - (v) => ({ - p: vec(2, Bool, (i) => not(v[i].p)), - x: vec(2, Real, (i) => { - const { p, x } = v[i]; - return select(p, Real, mul(x, x), x); - }), - }), - ); - const g = fn([Bool, Real, Bool, Real], Real, (p1, x1, q1, y1) => { - const { ret, grad } = vjp(f)([ - { p: p1, x: x1 }, - { p: q1, x: y1 }, - ]); - const [x2, y2] = ret.x; - const v = grad({ p: [true, false] as any, x: [2, 3] as any }); - const [{ x: x3 }, { x: y3 }] = v; - return mul(sub(x3, y2), sub(y3, x2)); - }); - const h = await compile(g); - expect(h(true, 2, true, 3)).toBe(-14); - expect(h(true, 5, false, 7)).toBe(-286); - expect(h(false, 11, true, 13)).toBe(-11189); - expect(h(false, 17, false, 19)).toBe(238); - }); - - test("compile VJP with call", async () => { - const f = fn([Real], Real, (x) => x); - const g = fn([Real], Real, (x) => f(x)); - const h = fn([Real], Real, (x) => vjp(g)(x).ret); - expect((await compile(h))(1)).toBe(1); - }); - - test("compile VJP with opaque call", async () => { - const exp = opaque([Real], Real, Math.exp); - exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = exp(x); - return { re: y, du: mul(dx, y) }; - }); - const g = fn([Real], Real, (x) => exp(x)); - const h = fn([Real], Real, (x) => vjp(g)(x).ret); - expect((await compile(h))(1)).toBeCloseTo(Math.E); - }); - - test("compile nulls in signature", async () => { - const f = fn([Null], Null, (x) => x); - const g = await compile(f); - expect(g(null)).toBe(null); - }); - - test("compile booleans in signature", async () => { - const f = fn([Bool], Bool, (p) => not(p)); - const g = await compile(f); - expect(g(true)).toBe(false); - expect(g(false)).toBe(true); - }); - - test("compile null arrays in signature", async () => { - const f = fn([Vec(2, Null)], Vec(2, Null), (v) => v); - const g = await compile(f); - expect(g([null, null])).toEqual([null, null]); - }); - - test("compile byte index arrays in signature", async () => { - const n = 256; - const f = fn([Vec(3, n), Vec(3, 3)], Vec(3, n), (v, i) => - vec(3, n, (j) => v[i[j]]), - ); - const g = await compile(f); - expect(g([12, 221, 234], [1, 2, 0])).toEqual([221, 234, 12]); - }); - - test("compile structs in signature", async () => { - const Pair = struct({ x: Real, y: Real }); - const f = fn([Pair], Pair, ({ x, y }) => ({ x: y, y: x })); - const g = await compile(f); - expect(g({ x: 2, y: 3 })).toEqual({ x: 3, y: 2 }); - }); - - test("compile zero-sized struct members in signature", async () => { - const Stuff = struct({ a: Null, b: 0, c: 0, d: Null }); - const f = fn([Stuff], Stuff, ({ a, b, c, d }) => { - return { a: d, b: c, c: b, d: a }; - }); - const g = await compile(f); - const stuff = { a: null, b: 0, c: 0, d: null }; - expect(g(stuff)).toEqual(stuff); - }); - - test("compile nested structs in signature", async () => { - const Pair = struct({ x: Real, y: Real }); - const Stuff = struct({ p: Bool, q: Pair }); - const f = fn([Stuff], Stuff, ({ p, q }) => ({ - p: not(p), - q: { x: q.y, y: q.x }, - })); - const g = await compile(f); - expect(g({ p: true, q: { x: 2, y: 3 } })).toEqual({ - p: false, - q: { x: 3, y: 2 }, - }); - }); - - test("compile big structs in signature", async () => { - const M = 300; - const N = 70000; - const Stuff = struct({ - a: Real, - b: N, - c: Real, - d: Null, - e: Bool, - f: M, - g: N, - h: Null, - i: Bool, - j: Bool, - k: Real, - l: Null, - m: N, - n: N, - o: M, - p: Null, - q: Null, - r: Real, - s: Real, - t: Real, - u: Real, - }); - const f = fn( - [Stuff], - Stuff, - ({ a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u }) => { - return { - a: t, - b: m, - c: k, - d: p, - e, - f, - g: n, - h: l, - i: j, - j: i, - k: a, - l: d, - m: g, - n: b, - o, - p: h, - q, - r: s, - s: c, - t: r, - u: u, - }; - }, - ); - const x = { - a: 0, - b: 1, - c: 2, - d: null, - e: false, - f: 3, - g: 4, - h: null, - i: true, - j: false, - k: 5, - l: null, - m: 6, - n: 7, - o: 8, - p: null, - q: null, - r: 9, - s: 10, - t: 11, - u: 12, - }; - const g = interp(f); - const h = await compile(f); - expect(h(x)).toEqual(g(x)); - }); - - test("compile matrix gradient", async () => { - const T = Vec(1, Vec(1, Real)); - const f = fn([T], Null, () => null); - const g = fn([], T, () => vjp(f)([[0]]).grad(null)); - const h = await compile(g); - expect(h()).toEqual([[0]]); - }); - - test("compile gradient with dynamic index", async () => { - const T = struct({ v: Vec(1, Real), i: 1 }); - const f = fn([T], Real, ({ v, i }) => v[i]); - const g = fn([T], T, (x) => vjp(f)(x).grad(1)); - const h = await compile(g); - expect(h({ v: [2], i: 0 })).toEqual({ v: [1], i: 0 }); - }); +test("add", () => { + expect(add(2, 2)).toBe(4); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 589e750..bc81dd5 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,49 +1 @@ -export { - Bool, - Bools, - Dual, - Fn, - Nats, - Null, - Nulls, - Real, - Reals, - Symbolic, - Tan, - Tans, - Value, - Vec, - Vecs, - abs, - add, - and, - ceil, - compile, - div, - eq, - floor, - fn, - geq, - gt, - iff, - interp, - jvp, - leq, - lt, - mul, - neg, - neq, - not, - opaque, - or, - select, - sign, - sqrt, - struct, - sub, - trunc, - vec, - vjp, - xor, - zero, -} from "./impl.js"; +export const add = (a: number, b: number) => a + b; diff --git a/packages/site/package.json b/packages/site/package.json index c8cf33e..2c420fe 100644 --- a/packages/site/package.json +++ b/packages/site/package.json @@ -1,11 +1,11 @@ { "name": "@rose-lang/site", - "version": "0.4.5", + "version": "0.5.0", "private": true, "type": "module", "dependencies": { "highlight.js": "^11", - "rose": "0.4.5" + "rose": "0.5.0" }, "scripts": { "build": "vite build", diff --git a/packages/site/src/main.ts b/packages/site/src/main.ts index efb151d..02bf3e5 100644 --- a/packages/site/src/main.ts +++ b/packages/site/src/main.ts @@ -1,4 +1,3 @@ -import { Real, Vec, compile, fn, jvp, vec, vjp } from "rose"; import { Expr, parse } from "./parse.js"; type Vec2 = [number, number]; @@ -11,10 +10,9 @@ interface Info { type Func = (x: number, y: number) => Info; -const autodiff = async (root: Expr): Promise => { - const Vec2 = Vec(2, Real); - const f = fn([Vec2], Real, (v) => { - const emit = (e: Expr): Real => { +const autodiff = (root: Expr): Func => { + const f = (v: number[]) => { + const emit = (e: Expr): number => { switch (e.kind) { case "const": return e.val; @@ -27,29 +25,18 @@ const autodiff = async (root: Expr): Promise => { } }; return emit(root); - }); - - const Mat2 = Vec(2, Vec2); - const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1)); - const h = fn([Vec2], Mat2, ([x, y]) => { - const d = jvp(g); - const a = d([ - { re: x, du: 1 }, - { re: y, du: 0 }, - ]); - const b = d([ - { re: x, du: 0 }, - { re: y, du: 1 }, - ]); - return [vec(2, Real, (i) => a[i].du), vec(2, Real, (i) => b[i].du)]; - }); - - return (await compile( - fn([Real, Real], { val: Real, grad: Vec2, hess: Mat2 }, (x, y) => { - const v = [x, y]; - return { val: f(v), grad: g(v), hess: h(v) }; - }), - )) as unknown as Func; + }; + return (x, y) => { + const v = [x, y]; + return { + val: f(v), + grad: [0, 0], + hess: [ + [0, 0], + [0, 0], + ], + }; + }; }; interface Parabola { @@ -234,7 +221,7 @@ const setPoint = (newPoint: Vec2) => { }; const textbox = document.getElementById("textbox") as HTMLInputElement; -const setFunc = async () => { +const setFunc = () => { let root: Expr = { kind: "const", val: NaN }; try { root = parse(textbox.value); @@ -242,12 +229,12 @@ const setFunc = async () => { } catch (e) { textbox.classList.add("error"); } - func = await autodiff(root); + func = autodiff(root); setPoint(point); }; -await setFunc(); +setFunc(); textbox.addEventListener("input", async () => { - await setFunc(); + setFunc(); }); const roseColor = "#C33358"; diff --git a/packages/site/src/math.ts b/packages/site/src/math.ts deleted file mode 100644 index b16f904..0000000 --- a/packages/site/src/math.ts +++ /dev/null @@ -1,142 +0,0 @@ -import { Dual, Real, add, div, fn, mul, neg, opaque, sqrt, sub } from "rose"; - -export const acos = opaque([Real], Real, Math.acos); -export const acosh = opaque([Real], Real, Math.acosh); -export const asin = opaque([Real], Real, Math.asin); -export const asinh = opaque([Real], Real, Math.asinh); -export const atan = opaque([Real], Real, Math.atan); -export const atanh = opaque([Real], Real, Math.atanh); -export const cbrt = opaque([Real], Real, Math.cbrt); -export const cos = opaque([Real], Real, Math.cos); -export const cosh = opaque([Real], Real, Math.cosh); -export const exp = opaque([Real], Real, Math.exp); -export const expm1 = opaque([Real], Real, Math.expm1); -export const log = opaque([Real], Real, Math.log); -export const log10 = opaque([Real], Real, Math.log10); -export const log1p = opaque([Real], Real, Math.log1p); -export const log2 = opaque([Real], Real, Math.log2); -export const pow = opaque([Real, Real], Real, Math.pow); -export const sin = opaque([Real], Real, Math.sin); -export const sinh = opaque([Real], Real, Math.sinh); -export const tan = opaque([Real], Real, Math.tan); -export const tanh = opaque([Real], Real, Math.tanh); - -acos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = acos(x); - const dy = div(dx, neg(sqrt(sub(1, mul(x, x))))); - return { re: y, du: dy }; -}); - -acosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = acosh(x); - const dy = div(dx, mul(sqrt(sub(x, 1)), sqrt(add(x, 1)))); - return { re: y, du: dy }; -}); - -asin.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }) => { - const y = asin(x); - const dy = div(dx, sqrt(sub(1, mul(x, x)))); - return { re: y, du: dy }; -}); - -asinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = asinh(x); - const dy = div(dx, sqrt(add(1, mul(x, x)))); - return { re: y, du: dy }; -}); - -atan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = atan(x); - const dy = div(dx, add(1, mul(x, x))); - return { re: y, du: dy }; -}); - -atanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = atanh(x); - const dy = div(dx, sub(1, mul(x, x))); - return { re: y, du: dy }; -}); - -cbrt.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = cbrt(x); - const dy = mul(dx, div(1 / 3, mul(y, y))); - return { re: y, du: dy }; -}); - -cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = cos(x); - const dy = mul(dx, neg(sin(x))); - return { re: y, du: dy }; -}); - -cosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = cosh(x); - const dy = mul(dx, sinh(x)); - return { re: y, du: dy }; -}); - -exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = exp(x); - const dy = mul(dx, y); - return { re: y, du: dy }; -}); - -expm1.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = expm1(x); - const dy = mul(dx, add(y, 1)); - return { re: y, du: dy }; -}); - -log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = log(x); - const dy = div(dx, x); - return { re: y, du: dy }; -}); - -log10.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = log10(x); - const dy = mul(dx, div(Math.LOG10E, x)); - return { re: y, du: dy }; -}); - -log1p.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = log1p(x); - const dy = div(dx, add(1, x)); - return { re: y, du: dy }; -}); - -log2.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = log2(x); - const dy = mul(dx, div(Math.LOG2E, x)); - return { re: y, du: dy }; -}); - -pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => { - const z = pow(x, y); - const dz = mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z); - return { re: z, du: dz }; -}); - -sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = sin(x); - const dy = mul(dx, cos(x)); - return { re: y, du: dy }; -}); - -sinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = sinh(x); - const dy = mul(dx, cosh(x)); - return { re: y, du: dy }; -}); - -tan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = tan(x); - const dy = mul(dx, add(1, mul(y, y))); - return { re: y, du: dy }; -}); - -tanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { - const y = tanh(x); - const dy = mul(dx, sub(1, mul(y, y))); - return { re: y, du: dy }; -}); diff --git a/packages/site/src/parse.test.ts b/packages/site/src/parse.test.ts index 83923a2..cd60214 100644 --- a/packages/site/src/parse.test.ts +++ b/packages/site/src/parse.test.ts @@ -1,13 +1,12 @@ -import { add } from "rose"; import { expect, test } from "vitest"; import { Expr, parse } from "./parse.js"; test("add", () => { const expected: Expr = { kind: "binary", - f: add, + f: Math.pow, lhs: { kind: "var", idx: 0 }, rhs: { kind: "var", idx: 1 }, }; - expect(parse("x+y")).toEqual(expected); + expect(parse("x^y")).toEqual(expected); }); diff --git a/packages/site/src/parse.ts b/packages/site/src/parse.ts index f43dc43..74c0ff8 100644 --- a/packages/site/src/parse.ts +++ b/packages/site/src/parse.ts @@ -1,69 +1,32 @@ -import { - Real, - abs, - add, - ceil, - div, - floor, - mul, - neg, - sign, - sqrt, - sub, - trunc, -} from "rose"; -import { - acos, - acosh, - asin, - asinh, - atan, - atanh, - cbrt, - cos, - cosh, - exp, - expm1, - log, - log10, - log1p, - log2, - pow, - sin, - sinh, - tan, - tanh, -} from "./math.js"; - const unaries = { - abs, - acos, - acosh, - asin, - asinh, - atan, - atanh, - cbrt, - ceil, - cos, - cosh, - exp, - expm1, - floor, - log, - log10, - log1p, - log2, - sign, - sin, - sinh, - sqrt, - tan, - tanh, - trunc, + abs: Math.abs, + acos: Math.acos, + acosh: Math.acosh, + asin: Math.asin, + asinh: Math.asinh, + atan: Math.atan, + atanh: Math.atanh, + cbrt: Math.cbrt, + ceil: Math.ceil, + cos: Math.cos, + cosh: Math.cosh, + exp: Math.exp, + expm1: Math.expm1, + floor: Math.floor, + log: Math.log, + log10: Math.log10, + log1p: Math.log1p, + log2: Math.log2, + sign: Math.sign, + sin: Math.sin, + sinh: Math.sinh, + sqrt: Math.sqrt, + tan: Math.tan, + tanh: Math.tanh, + trunc: Math.trunc, }; -const unary = (name: string): ((x: Real) => Real) => { +const unary = (name: string): ((x: number) => number) => { if (name in unaries) return unaries[name as keyof typeof unaries]; throw Error(`unknown unary function: ${name}`); }; @@ -127,8 +90,13 @@ function* lex(s: string) { export type Expr = | { kind: "const"; val: number } | { kind: "var"; idx: number } - | { kind: "unary"; f: (x: Real) => Real; arg: Expr } - | { kind: "binary"; f: (x: Real, y: Real) => Real; lhs: Expr; rhs: Expr }; + | { kind: "unary"; f: (x: number) => number; arg: Expr } + | { + kind: "binary"; + f: (x: number, y: number) => number; + lhs: Expr; + rhs: Expr; + }; class Parser { tokens: Token[]; @@ -173,12 +141,12 @@ class Parser { parseFactor(): Expr { if (this.peek().kind === "-") { this.pop(); - return { kind: "unary", f: neg, arg: this.parseFactor() }; + return { kind: "unary", f: (a) => -a, arg: this.parseFactor() }; } const x = this.parseAtom(); if (this.peek().kind === "^") { this.pop(); - return { kind: "binary", f: pow, lhs: x, rhs: this.parseFactor() }; + return { kind: "binary", f: Math.pow, lhs: x, rhs: this.parseFactor() }; } return x; } @@ -188,7 +156,10 @@ class Parser { let tok = this.peek(); while (tok.kind === "*" || tok.kind === "/") { this.pop(); - const f = { "*": mul, "/": div }[tok.kind]; + const f = { + "*": (a: number, b: number) => a * b, + "/": (a: number, b: number) => a / b, + }[tok.kind]; x = { kind: "binary", f, lhs: x, rhs: this.parseFactor() }; tok = this.peek(); } @@ -200,7 +171,10 @@ class Parser { let tok = this.peek(); while (tok.kind === "+" || tok.kind === "-") { this.pop(); - const f = { "+": add, "-": sub }[tok.kind]; + const f = { + "+": (a: number, b: number) => a + b, + "-": (a: number, b: number) => a - b, + }[tok.kind]; x = { kind: "binary", f, lhs: x, rhs: this.parseTerm() }; tok = this.peek(); } diff --git a/packages/site/vite.config.ts b/packages/site/vite.config.ts deleted file mode 100644 index 4a50932..0000000 --- a/packages/site/vite.config.ts +++ /dev/null @@ -1,4 +0,0 @@ -import { defineConfig } from "vite"; -import topLevelAwait from "vite-plugin-top-level-await"; - -export default defineConfig({ plugins: [topLevelAwait()] }); diff --git a/packages/wasm/README.md b/packages/wasm/README.md deleted file mode 100644 index b7d9503..0000000 --- a/packages/wasm/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# `@rose-lang/wasm` - -This package contains the raw WebAssembly bindings for [Rose][]. You almost -certainly do not want to use this directly; check out the main package instead. - -[rose]: https://www.npmjs.com/package/rose diff --git a/packages/wasm/browser.js b/packages/wasm/browser.js deleted file mode 100644 index ae0e9f7..0000000 --- a/packages/wasm/browser.js +++ /dev/null @@ -1,5 +0,0 @@ -import init from "./wbg/rose_web.js"; - -await init(); - -export * from "./wbg/rose_web.js"; diff --git a/packages/wasm/index.js b/packages/wasm/index.js deleted file mode 100644 index 3cc7907..0000000 --- a/packages/wasm/index.js +++ /dev/null @@ -1,7 +0,0 @@ -import init, { initialize } from "./wbg/rose_web.js"; -import bytes from "./wbg/rose_web_bg.wasm"; - -await init(bytes); -initialize(); - -export * from "./wbg/rose_web.js"; diff --git a/packages/wasm/package.json b/packages/wasm/package.json deleted file mode 100644 index ba0f222..0000000 --- a/packages/wasm/package.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "name": "@rose-lang/wasm", - "version": "0.4.5", - "license": "MIT", - "repository": "rose-lang/rose", - "type": "module", - "exports": { - ".": { - "browser": { - "types": "./dist/wbg/rose_web.d.ts", - "default": "./dist/browser.js" - }, - "default": "./dist/index.js" - }, - "./*": "./dist/bindings/*.js" - }, - "files": [ - "dist", - "index.js", - "wbg" - ], - "scripts": { - "build": "cp browser.js dist/ && cp wbg/rose_web.d.ts dist/index.d.ts && esbuild index.js --outfile=dist/index.js --platform=neutral --bundle --loader:.wasm=binary --define:import.meta.url=null --sourcemap" - } -} diff --git a/rust-toolchain.toml b/rust-toolchain.toml deleted file mode 100644 index 6bcb7ea..0000000 --- a/rust-toolchain.toml +++ /dev/null @@ -1,5 +0,0 @@ -[toolchain] -channel = "nightly-2023-06-26" -components = ["rust-src"] -profile = "default" -targets = ["wasm32-unknown-unknown"] diff --git a/tsconfig.json b/tsconfig.json index 023f6d6..d4df82e 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,10 +1,7 @@ { "compilerOptions": { "forceConsistentCasingInFileNames": true, - "lib": [ - "DOM", // https://github.com/microsoft/TypeScript-DOM-lib-generator/issues/826 - "ES2021" // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry - ], + "lib": ["DOM"], // https://github.com/microsoft/TypeScript-DOM-lib-generator/issues/826 "module": "Node16", "noFallthroughCasesInSwitch": true, "skipLibCheck": true,