diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ad0a1f60..5dd86610 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -7,7 +7,9 @@ on: paths: - '.github/workflows/test.yml' - 'src/**/*.rs' + - 'examples/**/*' - 'ort-sys/**/*.rs' + - 'ort-sys/**/dist.txt' - 'build.rs' - 'Cargo.toml' - '.cargo/**/*' @@ -16,7 +18,9 @@ on: paths: - '.github/workflows/test.yml' - 'src/**/*.rs' + - 'examples/**/*' - 'ort-sys/**/*.rs' + - 'ort-sys/**/dist.txt' - 'build.rs' - 'Cargo.toml' - '.cargo/**/*' diff --git a/.gitignore b/.gitignore index bf1af90f..6ab71818 100644 --- a/.gitignore +++ b/.gitignore @@ -186,6 +186,7 @@ WixTools/ # ONNX Runtime downloaded models **/*.onnx **/*.ort +**/*.pbseq !examples/webassembly/**/*.ort !tests/data/*.onnx !tests/data/*.ort @@ -195,3 +196,9 @@ WixTools/ # Glassbench results /glassbench*.db + +# Python virtual environment +.venv* + +# Training checkpoints +tools/train-data/**/checkpoint diff --git a/Cargo.toml b/Cargo.toml index e7b3785e..a3b3efcb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ 'examples/model-info', 'examples/yolov8', 'examples/modnet', + 'examples/training', 'examples/webassembly' ] default-members = [ @@ -22,8 +23,8 @@ exclude = [ 'examples/cudarc' ] [package] name = "voicevox-ort" -description = "A safe Rust wrapper for ONNX Runtime 1.17 - Optimize and Accelerate Machine Learning Inferencing" -version = "2.0.0-rc.2" +description = "A safe Rust wrapper for ONNX Runtime 1.18 - Optimize and accelerate machine learning inference & training" +version = "2.0.0-rc.4" edition = "2021" rust-version = "1.70" license = "MIT OR Apache-2.0" @@ -45,7 +46,7 @@ strip = true codegen-units = 1 [package.metadata.docs.rs] -features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs", "__init-for-voicevox" ] +features = [ "ndarray", "half", "training", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs", "__init-for-voicevox" ] targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"] rustdoc-args = [ "--cfg", "docsrs" ] @@ -55,6 +56,8 @@ name = "ort" [features] default = [ "ndarray", "half", "download-binaries", "copy-dylibs" ] +training = [ "voicevox-ort-sys/training" ] + operator-libraries = [ "libc", "winapi" ] fetch-models = [ "ureq" ] @@ -90,7 +93,7 @@ anyhow = "1.0" ndarray = { version = "0.15", optional = true } thiserror = "1.0" once_cell = "1.19.0" -voicevox-ort-sys = { version = "2.0.0-rc.2", path = "ort-sys" } +voicevox-ort-sys = { version = "2.0.0-rc.4", path = "ort-sys" } libloading = { version = "0.8", optional = true } ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] } diff --git a/README.md b/README.md index 262f3dbd..e3bdcc7d 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,10 @@
Coverage Results Crates.io Open Collective backers and sponsors
- Crates.io ONNX Runtime + Crates.io ONNX Runtime -`ort` is an (unofficial) [ONNX Runtime](https://onnxruntime.ai/) 1.17 wrapper for Rust based on the now inactive [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). ONNX Runtime accelerates ML inference on both CPU & GPU. +`ort` is an (unofficial) [ONNX Runtime](https://onnxruntime.ai/) 1.18 wrapper for Rust based on the now inactive [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). ONNX Runtime accelerates ML inference and training on both CPU & GPU. ## 📖 Documentation - [Guide](https://ort.pyke.io/) @@ -24,11 +24,11 @@ - **[Twitter](https://twitter.com/)** uses `ort` to serve homepage recommendations to hundreds of millions of users. - **[Bloop](https://bloop.ai/)** uses `ort` to power their semantic code search feature. -- **[pyke Diffusers](https://github.com/pykeio/diffusers)** uses `ort` for efficient Stable Diffusion image generation on both CPUs & GPUs. - **[edge-transformers](https://github.com/npc-engine/edge-transformers)** uses `ort` for accelerated transformer model inference at the edge. - **[Ortex](https://github.com/relaypro-open/ortex)** uses `ort` for safe ONNX Runtime bindings in Elixir. - **[Supabase](https://supabase.com/)** uses `ort` to remove cold starts for their edge functions. - **[Lantern](https://github.com/lanterndata/lantern_extras)** uses `ort` to provide embedding model inference inside Postgres. +- **[Magika](https://github.com/google/magika)** uses `ort` for content type detection. ## 🌠 Sponsor `ort` diff --git a/docs/mint.json b/docs/mint.json deleted file mode 100644 index 8a6237ba..00000000 --- a/docs/mint.json +++ /dev/null @@ -1,102 +0,0 @@ -{ - "$schema": "https://mintlify.com/schema.json", - "name": "ort", - "logo": { - "dark": "/assets/banner.png", - "light": "/assets/banner.png" - }, - "favicon": "/assets/icon.png", - "colors": { - "primary": "#F74C00", - "light": "#F74C00", - "background": { - "light": "#FFFFFF", - "dark": "#000000" - }, - "dark": "#F74C00", - "anchors": { - "from": "#F74C00", - "to": "#eb8e65" - } - }, - "tabs": [ - { - "name": "API Reference", - "url": "https://docs.rs/ort/2.0.0-rc.2/ort/" - } - ], - "anchors": [ - { - "name": "Sponsor", - "icon": "hand-holding-heart", - "url": "https://opencollective.com/pyke-osai" - }, - { - "name": "Crates.io", - "icon": "rust", - "url": "https://crates.io/crates/ort" - }, - { - "name": "GitHub", - "icon": "github", - "url": "https://github.com/pykeio/ort" - }, - { - "name": "Discord", - "icon": "discord", - "url": "https://discord.gg/uQtsNu2xMa" - } - ], - "navigation": [ - { - "group": "Get Started", - "pages": [ - "introduction" - ] - }, - { - "group": "Setup", - "pages": [ - "setup/platforms", - "setup/webassembly", - "setup/linking", - "setup/cargo-features" - ] - }, - { - "group": "Fundamentals", - "pages": [ - "fundamentals/environment", - "fundamentals/session", - "fundamentals/value" - ] - }, - { - "group": "Performance", - "pages": [ - "perf/execution-providers", - "perf/io-binding" - ] - }, - { - "group": "Troubleshooting", - "pages": [ - "troubleshooting/precision", - "troubleshooting/performance", - "troubleshooting/compiling" - ] - }, - { - "group": "Migration & versioning", - "pages": [ - "migrating/version-mapping", - "migrating/v2" - ] - } - ], - "footerSocials": { - "website": "https://pyke.io/", - "github": "https://github.com/pykeio/ort", - "discord": "https://discord.gg/uQtsNu2xMa" - } -} diff --git a/docs/next-env.d.ts b/docs/next-env.d.ts new file mode 100644 index 00000000..4f11a03d --- /dev/null +++ b/docs/next-env.d.ts @@ -0,0 +1,5 @@ +/// +/// + +// NOTE: This file should not be edited +// see https://nextjs.org/docs/basic-features/typescript for more information. diff --git a/docs/next.config.mjs b/docs/next.config.mjs new file mode 100644 index 00000000..47e0f5ea --- /dev/null +++ b/docs/next.config.mjs @@ -0,0 +1,11 @@ +import nextra from 'nextra'; + +export default nextra({ + theme: 'nextra-theme-docs', + themeConfig: './theme.config.jsx' +})({ + output: 'export', + images: { + unoptimized: true + } +}); diff --git a/docs/package.json b/docs/package.json new file mode 100644 index 00000000..36fdb98c --- /dev/null +++ b/docs/package.json @@ -0,0 +1,23 @@ +{ + "private": true, + "name": "ort-docs", + "version": "0.0.0", + "scripts": { + "dev": "next dev", + "build": "next build", + "start": "next start" + }, + "dependencies": { + "next": "^14.2.3", + "nextra": "^2.13.4", + "nextra-theme-docs": "^2.13.4", + "react": "^18.3.1", + "react-dom": "^18.3.1" + }, + "devDependencies": { + "@types/node": "20.14.2", + "@types/react": "^18.3.3", + "@types/react-dom": "^18.3.0", + "typescript": "^5.4.5" + } +} diff --git a/docs/pages/_app.mdx b/docs/pages/_app.mdx new file mode 100644 index 00000000..c466f982 --- /dev/null +++ b/docs/pages/_app.mdx @@ -0,0 +1,5 @@ +import font from 'next/font/google'; + +export default function App({ Component, pageProps }) { + return ; +} diff --git a/docs/pages/_meta.json b/docs/pages/_meta.json new file mode 100644 index 00000000..a58afd92 --- /dev/null +++ b/docs/pages/_meta.json @@ -0,0 +1,37 @@ +{ + "-- Links": { + "type": "separator", + "title": "Links" + }, + "link-oc": { + "title": "Sponsor ↗", + "href": "https://opencollective.com/pyke-osai", + "newWindow": true + }, + "link-api": { + "title": "API Reference ↗", + "href": "https://docs.rs/ort/2.0.0-rc.4/ort" + }, + "link-crates": { + "title": "Crates.io ↗", + "href": "https://crates.io/crates/ort", + "newWindow": true + }, + "-- Docs": { + "type": "separator", + "title": "Docs" + }, + "index": "Introduction", + "setup": { + "title": "Setup" + }, + "perf": { + "title": "Performance" + }, + "troubleshooting": { + "title": "Troubleshooting" + }, + "migrating": { + "title": "Migration & versioning" + } +} diff --git a/docs/introduction.mdx b/docs/pages/index.mdx similarity index 55% rename from docs/introduction.mdx rename to docs/pages/index.mdx index 0ff676ba..8034d46c 100644 --- a/docs/introduction.mdx +++ b/docs/pages/index.mdx @@ -2,14 +2,17 @@ title: Introduction --- +import Image from 'next/image'; +import { Callout, Card, Cards, Steps } from 'nextra/components'; +

ort is an open-source Rust binding for ONNX Runtime.

- - These docs are for the latest alpha version of `ort`, `2.0.0-rc.2`. This version is production-ready (just not API stable) and we recommend new & existing projects use it. - + + These docs are for the latest alpha version of `ort`, `2.0.0-rc.4`. This version is production-ready (just not API stable) and we recommend new & existing projects use it. + `ort` makes it easy to deploy your machine learning models to production via [ONNX Runtime](https://onnxruntime.ai/), a hardware-accelerated inference engine. With `ort` + ONNX Runtime, you can run almost any ML model (including ResNet, YOLOv8, BERT, LLaMA) on almost any hardware, often far faster than PyTorch, and with the added bonus of Rust's efficiency. @@ -29,52 +32,54 @@ Converting a neural network to a graph representation like ONNX opens the door t # Getting started - - If you have a [supported platform](/setup/platforms) (and you probably do), installing `ort` couldn't be any simpler! Just add it to your Cargo dependencies: - ```toml - [dependencies] - ort = "2.0.0-rc.2" - ``` - - - Your model will need to be converted to an ONNX graph before you can use it. - - The awesome folks at Hugging Face have [a guide](https://huggingface.co/docs/transformers/serialization) to export 🤗 Transformers models to ONNX with 🤗 Optimum. - - For any PyTorch model: [`torch.onnx`](https://pytorch.org/docs/stable/onnx.html) - - For `scikit-learn` models: [`sklearn-onnx`](https://onnx.ai/sklearn-onnx/) - - For TensorFlow, Keras, TFlite, & TensorFlow.js: [`tf2onnx`](https://github.com/onnx/tensorflow-onnx) - - For PaddlePaddle: [`Paddle2ONNX`](https://github.com/PaddlePaddle/Paddle2ONNX) - - - Once you've got a model, load it via `ort` by creating a [`Session`](/fundamentals/session): - - ```rust - use ort::{GraphOptimizationLevel, Session}; - - let model = Session::builder()? - .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_intra_threads(4)? - .commit_from_file("yolov8m.onnx")?; - ``` - - - Preprocess your inputs, then `run()` the session to perform inference. - - ```rust - let outputs = model.run(ort::inputs!["image" => image]?)?; - let predictions = outputs["output0"].try_extract_tensor::()?; - ... - ``` - - There are some more useful examples [in the `ort` repo](https://github.com/pykeio/ort/tree/main/examples)! - + +### Add ort to your Cargo.toml +If you have a [supported platform](/setup/platforms) (and you probably do), installing `ort` couldn't be any simpler! Just add it to your Cargo dependencies: +```toml +[dependencies] +ort = "2.0.0-rc.4" +``` + +### Convert your model +Your model will need to be converted to an ONNX graph before you can use it. +- The awesome folks at Hugging Face have [a guide](https://huggingface.co/docs/transformers/serialization) to export 🤗 Transformers models to ONNX with 🤗 Optimum. +- For any PyTorch model: [`torch.onnx`](https://pytorch.org/docs/stable/onnx.html) +- For `scikit-learn` models: [`sklearn-onnx`](https://onnx.ai/sklearn-onnx/) +- For TensorFlow, Keras, TFlite, & TensorFlow.js: [`tf2onnx`](https://github.com/onnx/tensorflow-onnx) +- For PaddlePaddle: [`Paddle2ONNX`](https://github.com/PaddlePaddle/Paddle2ONNX) + +### Load your model +Once you've got a model, load it via `ort` by creating a [`Session`](/fundamentals/session): + +```rust +use ort::{GraphOptimizationLevel, Session}; + +let model = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(4)? + .commit_from_file("yolov8m.onnx")?; +``` + +### Perform inference +Preprocess your inputs, then `run()` the session to perform inference. + +```rust +let outputs = model.run(ort::inputs!["image" => image]?)?; +let predictions = outputs["output0"].try_extract_tensor::()?; +... +``` + +There are some more useful examples [in the `ort` repo](https://github.com/pykeio/ort/tree/main/examples)! + # Next steps - - Use [execution providers](/perf/execution-providers) to enable hardware acceleration in your app and unlock the full power of your GPU or NPU. - - - We'd love to see what you've made with `ort`! Show off your project in [GitHub Discussions](https://github.com/pykeio/ort/discussions/categories/show-and-tell) or on our [Discord](https://discord.gg/uQtsNu2xMa). - + +### Unlock more performance with EPs +Use [execution providers](/perf/execution-providers) to enable hardware acceleration in your app and unlock the full power of your GPU or NPU. + +### Show off your project! +We'd love to see what you've made with `ort`! Show off your project in [GitHub Discussions](https://github.com/pykeio/ort/discussions/categories/show-and-tell) or on our [Discord](https://discord.gg/uQtsNu2xMa). + diff --git a/docs/migrating/opsets.mdx b/docs/pages/migrating/opsets.mdx similarity index 100% rename from docs/migrating/opsets.mdx rename to docs/pages/migrating/opsets.mdx diff --git a/docs/migrating/v2.mdx b/docs/pages/migrating/v2.mdx similarity index 97% rename from docs/migrating/v2.mdx rename to docs/pages/migrating/v2.mdx index a2f20201..f1a6af3f 100644 --- a/docs/migrating/v2.mdx +++ b/docs/pages/migrating/v2.mdx @@ -141,13 +141,6 @@ let noise_pred = unet.run(ort::inputs![ ]?)?; ``` -You can also supply `ort::inputs!` your `IoBinding` by specifying `bind =`: -```rust -let binding = model.create_binding()?; -... -let outputs = model.run(ort::inputs![bind = binding]?)?; -``` - ### Tensor creation no longer requires the session's allocator In previous versions, `Value::from_array` took an allocator parameter. The allocator was only used because the string data contained in string tensors had to be cloned into ONNX Runtime-managed memory. However, 99% of users only ever use primitive tensors, so the extra parameter served little purpose. The new `Tensor::from_array` function now takes only an array, and the logic for converting string arrays has been moved to a new function, `DynTensor::from_string_array`. @@ -180,7 +173,7 @@ let l = outputs["latents"].try_extract_tensor::()?; ``` ## Execution providers -Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.2/ort/index.html?search=ExecutionProvider) and the [execution providers reference](/perf/execution-providers) for more information. +Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.4/ort/index.html?search=ExecutionProvider) and the [execution providers reference](/perf/execution-providers) for more information. ```diff -// v1.x diff --git a/docs/migrating/version-mapping.mdx b/docs/pages/migrating/version-mapping.mdx similarity index 92% rename from docs/migrating/version-mapping.mdx rename to docs/pages/migrating/version-mapping.mdx index d2861ce0..c4ac5d43 100644 --- a/docs/migrating/version-mapping.mdx +++ b/docs/pages/migrating/version-mapping.mdx @@ -6,7 +6,7 @@ description: Information about `ort`'s versioning and relation to ONNX Runtime v ## A note on SemVer `ort` versions pre-2.0 were not SemVer compatible. From v2.0 onwards, breaking API changes are accompanied by a **major version update**. -Updates to the version of ONNX Runtime used by `ort` may occur on **minor** version updates, i.e. 2.0 ships with ONNX Runtime 1.17.3, but 2.1 may ship with 1.18.0. ONNX Runtime is generally forward compatible, but in case you require a specific version of ONNX Runtime, you should pin the minor version in your `Cargo.toml` using a [tilde requirement](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#tilde-requirements): +Updates to the version of ONNX Runtime used by `ort` may occur on **minor** version updates, i.e. 2.0 ships with ONNX Runtime 1.18.1, but 2.1 may ship with 1.19.0. ONNX Runtime is generally forward compatible, but in case you require a specific version of ONNX Runtime, you should pin the minor version in your `Cargo.toml` using a [tilde requirement](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#tilde-requirements): ```toml [dependencies] ort = { version = "~2.0", ... } @@ -16,7 +16,7 @@ ort = { version = "~2.0", ... } | **ort** | **ONNX Runtime** | | -------- | ----------------:| -| v2.0.0+ | v1.17.3 | +| v2.0.0+ | v1.18.1 | | v1.16.0-v1.16.2 | v1.16.0 | | v1.15.0-v1.15.5 | v1.15.1 | | v1.14.2-v1.14.8 | v1.14.1 | diff --git a/docs/perf/execution-providers.mdx b/docs/pages/perf/execution-providers.mdx similarity index 84% rename from docs/perf/execution-providers.mdx rename to docs/pages/perf/execution-providers.mdx index 03447b1e..f7f20aaa 100644 --- a/docs/perf/execution-providers.mdx +++ b/docs/pages/perf/execution-providers.mdx @@ -3,6 +3,8 @@ title: Execution providers description: Learn how to enable execution providers to leverage hardware acceleration. --- +import { Callout, Tabs } from 'nextra/components'; + Execution providers (EPs) enable ONNX Runtime to execute ONNX graphs with hardware acceleration. If you have specialized hardware like a GPU or NPU, execution providers can provide a massive performance boost to your `ort` applications. For more information on the intricacies of execution providers, see the [ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/). ONNX Runtime must be compiled with support for each execution provider. pyke provides precompiled binaries for some of the most common EPs, so you won't need to compile ONNX Runtime from source. Below is a table showing available EPs, their support in `ort`, and their binary availability status. @@ -28,12 +30,12 @@ ONNX Runtime must be compiled with support for each execution provider. pyke pro | Microsoft Azure | ❌ | ❌ | ❓ | | Rockchip RKNPU | ❌ | ❌ | ❓ | - + Some EPs supported by ONNX Runtime are not supported by `ort` due to a lack of hardware for testing. If your preferred EP is missing support and you've got the hardware, please [open an issue](https://github.com/pykeio/ort/issues/new)! - + ## Registering execution providers - + To use an execution provider with `ort`, you'll need to enable its respective Cargo feature, e.g. the `cuda` feature to use CUDA, or the `coreml` feature to use CoreML. ```toml Cargo.toml @@ -42,7 +44,7 @@ ONNX Runtime must be compiled with support for each execution provider. pyke pro ``` See [Cargo features](/setup/cargo-features) for the full list of features. - + In order to configure sessions to use certain execution providers, you must **register** them when creating an environment or session. You can do this via the `SessionBuilder::with_execution_providers` method. For example, to register the CUDA execution provider for a session: @@ -81,7 +83,7 @@ fn main() -> anyhow::Result<()> { ``` ## Configuring EPs -EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.2/ort/index.html?search=ExecutionProvider) for the EP structs for more information on which options are supported and what they do. +EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.4/ort/index.html?search=ExecutionProvider) for the EP structs for more information on which options are supported and what they do. ```rust use ort::{CoreMLExecutionProvider, Session}; @@ -105,7 +107,22 @@ fn main() -> anyhow::Result<()> { ## Fallback behavior `ort` will silently fail and fall back to executing on the CPU if all execution providers fail to register. In many cases, though, you'll want to show the user an error message when an EP fails to register, or outright abort the process. -To receive these registration errors, instead use `ExecutionProvider::register` to register an execution provider: +You can configure an EP to return an error on failure by adding `.error_on_failure()` after you `.build()` it. In this example, if CUDA doesn't register successfully, the program will exit with an error at `with_execution_providers`: +```rust +use ort::{CoreMLExecutionProvider, Session}; + +fn main() -> anyhow::Result<()> { + let session = Session::builder()? + .with_execution_providers([ + CUDAExecutionProvider::default().build().error_on_failure() + ])? + .commit_from_file("model.onnx")?; + + Ok(()) +} +``` + +If you require more complex error handling, you can also manually register execution providers via the `ExecutionProvider::register` method: ```rust use ort::{CUDAExecutionProvider, ExecutionProvider, Session}; @@ -167,9 +184,9 @@ fn main() -> anyhow::Result<()> { } ``` - + `ort::init` must come before you create any sessions, otherwise the configuration will not take effect! - + Sessions configured with their own execution providers will *extend* the execution provider defaults, rather than overriding them. @@ -178,35 +195,40 @@ If it seems like the execution provider is not registering properly, or you are ## Notes +### CUDA +`ort` provides binaries for both CUDA 11 and CUDA 12; `ort` will automatically choose which binary to install based on whether CUDA 12 is installed. + +CUDA 11 requires cuDNN 8.x. CUDA 12 requires cuDNN 9.x. Make sure the correct version of cuDNN is installed and available on the `PATH`. + ### CoreML Statically linking to CoreML (the default behavior when using downloaded binaries + the `coreml` Cargo feature) requires an additional Rust flag in order to link properly. You'll need to provide the flag `-C link-arg=-fapple-link-rtlib` to `rustc`. You can do this via an entry in [`.cargo/config.toml`](https://doc.rust-lang.org/cargo/reference/config.html#hierarchical-structure), in a build script, or in an environment variable. - - + + See [Configuration: Hierarchical structure](https://doc.rust-lang.org/cargo/reference/config.html#hierarchical-structure) for more information on where the configuration file can be placed. - ```toml .cargo/config.toml + ```toml filename=".cargo/config.toml" copy [target.aarch64-apple-darwin] rustflags = ["-Clink-arg=-fapple-link-rtlib"] [target.x86_64-apple-darwin] rustflags = ["-Clink-arg=-fapple-link-rtlib"] ``` - - + + Add the following to the `build.rs` script of any **binary** crate that uses `ort`. - ```rust build.rs + ```rust filename="build.rs" copy fn main() { println!("cargo:rustc-link-arg=-fapple-link-rtlib"); } ``` Library crates do not need this flag, and the usage of it in a library crate will not transitively apply to any binary crates dependent on it. - - - ```shell + + + ```shell copy $ RUSTFLAGS="-Clink-arg=-fapple-link-rtlib" cargo build ``` - + diff --git a/docs/perf/io-binding.mdx b/docs/pages/perf/io-binding.mdx similarity index 100% rename from docs/perf/io-binding.mdx rename to docs/pages/perf/io-binding.mdx diff --git a/docs/setup/cargo-features.mdx b/docs/pages/setup/cargo-features.mdx similarity index 95% rename from docs/setup/cargo-features.mdx rename to docs/pages/setup/cargo-features.mdx index 9fc92416..ad495f72 100644 --- a/docs/setup/cargo-features.mdx +++ b/docs/pages/setup/cargo-features.mdx @@ -9,8 +9,8 @@ title: Cargo features - ✅ **`half`**: Enables support for float16 & bfloat16 tensors via the [`half`](https://crates.io/crates/half) crate. ONNX models that are converted to 16-bit precision will typically convert to/from 32-bit floats at the input/output, so you will likely never actually need to interact with a 16-bit tensor on the Rust side. Though, `half` isn't a heavy enough crate to worry about it affecting compile times. - ✅ **`copy-dylibs`**: In case dynamic libraries are used (like with the CUDA execution provider), creates a symlink to them in the relevant places in the `target` folder to make [compile-time dynamic linking](/setup/linking#compile-time-dynamic-linking) work. - ⚒️ **`load-dynamic`**: Enables [runtime dynamic linking](/setup/linking#runtime-loading-with-load-dynamic), which alleviates many of the troubles with compile-time dynamic linking and offers greater flexibility. -- ⚒️ **`fetch-models`**: Enables the [`SessionBuilder::commit_from_url`](https://docs.rs/ort/2.0.0-rc.2/ort/struct.SessionBuilder.html#method.commit_from_url) method, allowing you to quickly download & run a model from a URL. This should only be used for quick testing. -- ⚒️ **`operator-libraries`**: Allows for sessions to load custom operators from dynamic C++ libraries via [`SessionBuilder::with_operator_library`](https://docs.rs/ort/2.0.0-rc.2/ort/struct.SessionBuilder.html#method.with_operator_library). If possible, we recommend [writing your custom ops in Rust instead](/perf/custom-operators). +- ⚒️ **`fetch-models`**: Enables the [`SessionBuilder::commit_from_url`](https://docs.rs/ort/2.0.0-rc.4/ort/struct.SessionBuilder.html#method.commit_from_url) method, allowing you to quickly download & run a model from a URL. This should only be used for quick testing. +- ⚒️ **`operator-libraries`**: Allows for sessions to load custom operators from dynamic C++ libraries via [`SessionBuilder::with_operator_library`](https://docs.rs/ort/2.0.0-rc.4/ort/struct.SessionBuilder.html#method.with_operator_library). If possible, we recommend [writing your custom ops in Rust instead](/perf/custom-operators). ## Execution providers Each [execution provider](/perf/execution-providers) is also gated behind a Cargo feature. diff --git a/docs/pages/setup/linking.mdx b/docs/pages/setup/linking.mdx new file mode 100644 index 00000000..bdb49234 --- /dev/null +++ b/docs/pages/setup/linking.mdx @@ -0,0 +1,109 @@ +--- +title: Linking +description: Here's how `ort` links to ONNX Runtime, and how to configure its behavior. +--- + +import { Callout, Tabs, Steps } from 'nextra/components'; + +`ort` provides its own builds of ONNX Runtime to make your experience as painless as possible, but in some cases, you'll want to use a custom build of ONNX Runtime with `ort`. Luckily, we make this very easy by handling all of the linking configuration automagically. Just point `ort` to the output of ONNX Runtime's build pipeline and it'll Just Work™. + +## Static linking +Most ONNX Runtime compile configurations will support static linking - just run `build.sh` without the `--build_shared_lib` argument. You should prefer static linking if your execution providers support it, as it avoids many issues and follows de facto Rust practices. If you compile both static libraries and dynamic libraries, `ort` will prefer linking to the static libraries. + +To direct `ort` to your statically built binaries, use the `ORT_LIB_LOCATION` environment variable when running `cargo build`. Point it to the location where the static libraries (`.a`/`.lib` files) are compiled to. This will typically be `onnxruntime/build/`. For example: +```shell +$ ORT_LIB_LOCATION=~/onnxruntime/build/Linux cargo build +``` + +For iOS (or for other platforms if you are compiling multiple profiles at once), you'll need to manually specify the profile with the `ORT_LIB_PROFILE` environment variable. If not specified, `ort` will prefer `Release` over `RelWithDebInfo` over `MinSizeRel` over `Debug`. + +## Dynamic linking +Some execution providers unfortunately only support dynamic linking. Dynamic linking doesn't play well with the Rust ecosystem, though `ort` tries to alleviate the pain as much as possible. + +When it comes to dynamic linking, there are two options: `load-dynamic`, or standard compile-time dynamic linking. We recommend `load-dynamic` as it gives more control and is often far less troublesome to work with. + +### Runtime loading with `load-dynamic` +The `load-dynamic` Cargo feature solves a few of the issues with dynamic linking by **loading the library at runtime** rather than **linking at compile time**. This means that the path to the ONNX Runtime library can be configured at runtime, and the executable will not just completely fail to start if the binary couldn't be found. + +To use `load-dynamic`: + + +#### Enable the feature in Cargo.toml +```toml filename="Cargo.toml" +[dependencies] +ort = { version = "2", features = [ "load-dynamic" ] } +``` + +### Point ort to the dylib + + + ```rust main.rs + fn main() -> anyhow::Result<()> { + // Find our custom ONNX Runtime dylib path somehow + // (i.e. resolving it from the root of our program's install folder) + let dylib_path = crate::internal::find_onnxruntime_dylib()?; + // The path should point to the `libonnxruntime` binary, which looks like: + // - on Unix: /etc/.../libonnxruntime.so + // - on Windows: C:\Program Files\...\onnxruntime.dll + + // Initialize ort with the path to the dylib. This **must** be called before any usage of `ort`! + // `init_from` returns an `EnvironmentBuilder` which you can use to further configure the environment + // before `.commit()`ing; see the Environment docs for more information on what you can configure. + ort::init_from(dylib_path).commit()?; + + Ok(()) + } + ``` + + + Set the `ORT_DYLIB_PATH` environment variable to the path to `libonnxruntime.so`/`onnxruntime.dll`. + + ```shell + $ ORT_DYLIB_PATH=../onnxruntime-build/linux-x64/libonnxruntime.so ./mirai + ``` + + + + + +`ORT_DYLIB_PATH` is relative to the executable. Cargo examples and tests are compiled to a different directory than binary crates: `target//examples` and `target//deps` respectively. Keep this in mind if you're going to use relative paths. + +### Compile-time dynamic linking +For compile-time dynamic linking, you'll need to configure your environment in the exact same way as if you were [statically linking](#static-linking). + +Note that the dylibs then have to be placed in a certain location for them to be found by the executable. For Windows, this is either somewhere on the `PATH`, or in the same folder as the executable. On macOS and Linux, they have to be placed somewhere in the `LD_LIBRARY_PATH`, or you can use rpath to configure the executable to search for dylibs in its parent folder. We've had the least issues with rpath, but YMMV. + +To configure rpath, you'll need to: + +#### Enable rpath in Cargo.toml +```toml filename="Cargo.toml" copy +[profile.dev] +rpath = true + +[profile.release] +rpath = true + +# do this for any other profiles +``` + +### Configure the path in the linker args in .cargo/config.toml to be relative to the executable + + + ```toml filename="~/.cargo/config.toml" copy + [target.x86_64-unknown-linux-gnu] + rustflags = [ "-Clink-args=-Wl,-rpath,\\$ORIGIN" ] + + # do this for any other Linux targets as well + ``` + + + ```toml filename="~/.cargo/config.toml" copy + [target.x86_64-apple-darwin] + rustflags = [ "-Clink-args=-Wl,-rpath,@loader_path" ] + + # do this for any other macOS targets as well + ``` + + + + diff --git a/docs/setup/platforms.mdx b/docs/pages/setup/platforms.mdx similarity index 62% rename from docs/setup/platforms.mdx rename to docs/pages/setup/platforms.mdx index 02443452..c7a55c9d 100644 --- a/docs/setup/platforms.mdx +++ b/docs/pages/setup/platforms.mdx @@ -3,7 +3,9 @@ title: Platform support description: ONNX Runtime, and by extension `ort`, supports a wide variety of platforms. For most desktop users, pre-built binaries are available, so setting up `ort` is as simple as adding it to your `Cargo.toml`! --- -Here are the supported platforms and binary availability status, as of v2.0.0-rc.2. +import { Callout } from 'nextra/components'; + +Here are the supported platforms and binary availability status, as of v2.0.0-rc.4. * 🟢 - Supported. Dynamic & static binaries provided by pyke. * 🔷 - Supported. Static binaries provided by pyke. @@ -19,14 +21,18 @@ Here are the supported platforms and binary availability status, as of v2.0.0-rc | **Android** | ❌ | ❌ | ⭕ | ⭕ | ❌ | | **Web** | ❌ | ❌ | ❌ | ❌ | 🔷¶ | -\* Recent version of Windows 10/11 required for pyke binaries.
-† glibc ≥ 2.31 (Ubuntu ≥ 20.04) required for pyke binaries.
-‡ glibc ≥ 2.35 (Ubuntu ≥ 22.04) required for pyke binaries.
-§ macOS ≥ 10.15 required.
-¶ WASM supports a limited subset of ONNX Runtime features. For more info, see [the docs on WebAssembly support](/setup/webassembly). +
+

\* Recent version of Windows 10/11 required for pyke binaries.

+

† glibc ≥ 2.31 (Ubuntu ≥ 20.04) required for pyke binaries.

+

‡ glibc ≥ 2.35 (Ubuntu ≥ 22.04) required for pyke binaries.

+

§ macOS ≥ 10.15 required.

+

¶ WASM supports a limited subset of ONNX Runtime features. For more info, see [the docs on WebAssembly support](/setup/webassembly).

+
If your platform is marked as 🟢 or 🔷, you're in luck! Almost no setup will be required to get `ort` up and running. For platforms marked as ⭕, you'll need to [compile ONNX Runtime from source](https://onnxruntime.ai/docs/build/) and then [link `ort` to your custom binaries](/setup/linking) (but don't worry, we made this setup as simple as possible!) -Certain execution providers may not have binaries available. You can check EP binary support in the [Execution providers](/perf/execution-providers) documentation. + + Certain execution providers may not have binaries available. You can check EP binary support in the [Execution providers](/perf/execution-providers) documentation. + diff --git a/docs/setup/webassembly.mdx b/docs/pages/setup/webassembly.mdx similarity index 83% rename from docs/setup/webassembly.mdx rename to docs/pages/setup/webassembly.mdx index 87b1ccf9..3e5e5cc0 100644 --- a/docs/setup/webassembly.mdx +++ b/docs/pages/setup/webassembly.mdx @@ -5,19 +5,13 @@ description: Deploy ONNX models to the web WebAssembly support in `ort` is currently experimental. If you experience any issues using `ort` in WebAssembly, please [open an issue](https://github.com/pykeio/ort/issues/new). -Development of WASM support is done in a separate branch for now, so you'll have to add `ort` as a Git dependency: -```toml Cargo.toml -[dependencies] -ort = { git = "https://github.com/pykeio/ort.git", branch = "wasm32-unknown-unknown" } -``` - By nature, some features of ONNX Runtime are not available in the web. These include: - **Support for `.onnx` models.** You instead need to [convert `.onnx` models to the `.ort` format](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html). - **Runtime graph optimizations**, aka `SessionBuilder::with_optimization_level`. You can statically optimize the graph using the `.ort` conversion tool, though. - **Loading models with `commit_from_file`/`commit_from_url`.** You can create models from a slice of bytes in memory with `SessionBuilder::commit_from_memory` or `SessionBuilder::commit_from_memory_directly`. Additionally, you'll need to call `ort::wasm::initialize()` at the earliest possible point in your code, before you use any `ort` APIs: -```rust main.rs +```rust filename="main.rs" copy use ort::Session; static MODEL_BYTES: &[u8] = include_bytes!("../model.ort"); diff --git a/docs/troubleshooting/compiling.mdx b/docs/pages/troubleshooting/compiling.mdx similarity index 100% rename from docs/troubleshooting/compiling.mdx rename to docs/pages/troubleshooting/compiling.mdx diff --git a/docs/troubleshooting/performance.mdx b/docs/pages/troubleshooting/performance.mdx similarity index 55% rename from docs/troubleshooting/performance.mdx rename to docs/pages/troubleshooting/performance.mdx index 6bf41128..e407895e 100644 --- a/docs/troubleshooting/performance.mdx +++ b/docs/pages/troubleshooting/performance.mdx @@ -2,53 +2,56 @@ title: 'Troubleshoot: Performance' --- +import { Callout, Tabs, Steps } from 'nextra/components'; + ## Execution providers don't seem to register `ort` is designed to fail gracefully when an execution provider is not available. It logs failure events through [`tracing`](https://crates.io/crates/tracing), thus you'll need a library that subscribes to `tracing` events to see the logs. The simplest way to do this is to use [`tracing-subscriber`](https://crates.io/crates/tracing-subscriber). - - ```toml Cargo.toml - [dependencies] - tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] } + +### Add tracing-subscriber to your dependencies +```toml Cargo.toml +[dependencies] +tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] } +``` + +### Initialize the subscriber in the main function +```rust main.rs +fn main() { + tracing_subscriber::fmt::init(); +} +``` + +### Show debug messages from ort +Set the environment variable `RUST_LOG` to `ort=debug` to see all debug messages from `ort`. + + + ```powershell + $env:RUST_LOG = 'ort=debug'; + cargo run ``` - - - ```rust main.rs - fn main() { - tracing_subscriber::fmt::init(); - } + + + ```cmd + set RUST_LOG=ort=debug + cargo run ``` - - - Set the environment variable `RUST_LOG` to `ort=debug` to see all debug messages from `ort`. - - - ```powershell - $env:RUST_LOG = 'ort=debug'; - cargo run - ``` - - - ```cmd - set RUST_LOG=ort=debug - cargo run - ``` - - - ```shell - RUST_LOG="ort=debug" cargo run - ``` - - - ```shell - RUST_LOG="ort=debug" cargo run - ``` - - - + + + ```shell + RUST_LOG="ort=debug" cargo run + ``` + + + ```shell + RUST_LOG="ort=debug" cargo run + ``` + + + -You can also detect EP regsitration failures programmatically. See [Execution providers: Fallback behavior](/perf/execution-providers#fallback-behavior) for more info. +You can also detect EP regsitration failures programmatically. See [Execution providers: Fallback behavior](/perf/execution-providers#fallback-behavior) for more info. ## Inference is slower than expected There are a few things you could try to improve performance: diff --git a/docs/pnpm-lock.yaml b/docs/pnpm-lock.yaml new file mode 100644 index 00000000..4bebc4f8 --- /dev/null +++ b/docs/pnpm-lock.yaml @@ -0,0 +1,3200 @@ +lockfileVersion: '9.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +importers: + + .: + dependencies: + next: + specifier: ^14.2.3 + version: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + nextra: + specifier: ^2.13.4 + version: 2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + nextra-theme-docs: + specifier: ^2.13.4 + version: 2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(nextra@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: + specifier: ^18.3.1 + version: 18.3.1 + react-dom: + specifier: ^18.3.1 + version: 18.3.1(react@18.3.1) + devDependencies: + '@types/node': + specifier: 20.14.2 + version: 20.14.2 + '@types/react': + specifier: ^18.3.3 + version: 18.3.3 + '@types/react-dom': + specifier: ^18.3.0 + version: 18.3.0 + typescript: + specifier: ^5.4.5 + version: 5.4.5 + +packages: + + '@babel/runtime@7.24.7': + resolution: {integrity: sha512-UwgBRMjJP+xv857DCngvqXI3Iq6J4v0wXmwc6sapg+zyhbwmQX67LUEFrkK5tbyJ30jGuG3ZvWpBiB9LCy1kWw==} + engines: {node: '>=6.9.0'} + + '@braintree/sanitize-url@6.0.4': + resolution: {integrity: sha512-s3jaWicZd0pkP0jf5ysyHUI/RE7MHos6qlToFcGWXVp+ykHOy77OUMrfbgJ9it2C5bow7OIQwYYaHjk9XlBQ2A==} + + '@headlessui/react@1.7.19': + resolution: {integrity: sha512-Ll+8q3OlMJfJbAKM/+/Y2q6PPYbryqNTXDbryx7SXLIDamkF6iQFbriYHga0dY44PvDhvvBWCx1Xj4U5+G4hOw==} + engines: {node: '>=10'} + peerDependencies: + react: ^16 || ^17 || ^18 + react-dom: ^16 || ^17 || ^18 + + '@mdx-js/mdx@2.3.0': + resolution: {integrity: sha512-jLuwRlz8DQfQNiUCJR50Y09CGPq3fLtmtUQfVrj79E0JWu3dvsVcxVIcfhR5h0iXu+/z++zDrYeiJqifRynJkA==} + + '@mdx-js/react@2.3.0': + resolution: {integrity: sha512-zQH//gdOmuu7nt2oJR29vFhDv88oGPmVw6BggmrHeMI+xgEkp1B2dX9/bMBSYtK0dyLX/aOmesKS09g222K1/g==} + peerDependencies: + react: '>=16' + + '@napi-rs/simple-git-android-arm-eabi@0.1.16': + resolution: {integrity: sha512-dbrCL0Pl5KZG7x7tXdtVsA5CO6At5ohDX3myf5xIYn9kN4jDFxsocl8bNt6Vb/hZQoJd8fI+k5VlJt+rFhbdVw==} + engines: {node: '>= 10'} + cpu: [arm] + os: [android] + + '@napi-rs/simple-git-android-arm64@0.1.16': + resolution: {integrity: sha512-xYz+TW5J09iK8SuTAKK2D5MMIsBUXVSs8nYp7HcMi8q6FCRO7yJj96YfP9PvKsc/k64hOyqGmL5DhCzY9Cu1FQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [android] + + '@napi-rs/simple-git-darwin-arm64@0.1.16': + resolution: {integrity: sha512-XfgsYqxhUE022MJobeiX563TJqyQyX4FmYCnqrtJwAfivESVeAJiH6bQIum8dDEYMHXCsG7nL8Ok0Dp8k2m42g==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + + '@napi-rs/simple-git-darwin-x64@0.1.16': + resolution: {integrity: sha512-tkEVBhD6vgRCbeWsaAQqM3bTfpIVGeitamPPRVSbsq8qgzJ5Dx6ZedH27R7KSsA/uao7mZ3dsrNLXbu1Wy5MzA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + + '@napi-rs/simple-git-linux-arm-gnueabihf@0.1.16': + resolution: {integrity: sha512-R6VAyNnp/yRaT7DV1Ao3r67SqTWDa+fNq2LrNy0Z8gXk2wB9ZKlrxFtLPE1WSpWknWtyRDLpRlsorh7Evk7+7w==} + engines: {node: '>= 10'} + cpu: [arm] + os: [linux] + + '@napi-rs/simple-git-linux-arm64-gnu@0.1.16': + resolution: {integrity: sha512-LAGI0opFKw/HBMCV2qIBK3uWSEW9h4xd2ireZKLJy8DBPymX6NrWIamuxYNyCuACnFdPRxR4LaRFy4J5ZwuMdw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@napi-rs/simple-git-linux-arm64-musl@0.1.16': + resolution: {integrity: sha512-I57Ph0F0Yn2KW93ep+V1EzKhACqX0x49vvSiapqIsdDA2PifdEWLc1LJarBolmK7NKoPqKmf6lAKKO9lhiZzkg==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@napi-rs/simple-git-linux-x64-gnu@0.1.16': + resolution: {integrity: sha512-AZYYFY2V7hlcQASPEOWyOa3e1skzTct9QPzz0LiDM3f/hCFY/wBaU2M6NC5iG3d2Kr38heuyFS/+JqxLm5WaKA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@napi-rs/simple-git-linux-x64-musl@0.1.16': + resolution: {integrity: sha512-9TyMcYSBJwjT8jwjY9m24BZbu7ozyWTjsmYBYNtK3B0Um1Ov6jthSNneLVvouQ6x+k3Ow+00TiFh6bvmT00r8g==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@napi-rs/simple-git-win32-arm64-msvc@0.1.16': + resolution: {integrity: sha512-uslJ1WuAHCYJWui6xjsyT47SjX6KOHDtClmNO8hqKz1pmDSNY7AjyUY8HxvD1lK9bDnWwc4JYhikS9cxCqHybw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + + '@napi-rs/simple-git-win32-x64-msvc@0.1.16': + resolution: {integrity: sha512-SoEaVeCZCDF1MP+M9bMSXsZWgEjk4On9GWADO5JOulvzR1bKjk0s9PMHwe/YztR9F0sJzrCxwtvBZowhSJsQPg==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] + + '@napi-rs/simple-git@0.1.16': + resolution: {integrity: sha512-C5wRPw9waqL2jk3jEDeJv+f7ScuO3N0a39HVdyFLkwKxHH4Sya4ZbzZsu2JLi6eEqe7RuHipHL6mC7B2OfYZZw==} + engines: {node: '>= 10'} + + '@next/env@14.2.3': + resolution: {integrity: sha512-W7fd7IbkfmeeY2gXrzJYDx8D2lWKbVoTIj1o1ScPHNzvp30s1AuoEFSdr39bC5sjxJaxTtq3OTCZboNp0lNWHA==} + + '@next/swc-darwin-arm64@14.2.3': + resolution: {integrity: sha512-3pEYo/RaGqPP0YzwnlmPN2puaF2WMLM3apt5jLW2fFdXD9+pqcoTzRk+iZsf8ta7+quAe4Q6Ms0nR0SFGFdS1A==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + + '@next/swc-darwin-x64@14.2.3': + resolution: {integrity: sha512-6adp7waE6P1TYFSXpY366xwsOnEXM+y1kgRpjSRVI2CBDOcbRjsJ67Z6EgKIqWIue52d2q/Mx8g9MszARj8IEA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + + '@next/swc-linux-arm64-gnu@14.2.3': + resolution: {integrity: sha512-cuzCE/1G0ZSnTAHJPUT1rPgQx1w5tzSX7POXSLaS7w2nIUJUD+e25QoXD/hMfxbsT9rslEXugWypJMILBj/QsA==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@next/swc-linux-arm64-musl@14.2.3': + resolution: {integrity: sha512-0D4/oMM2Y9Ta3nGuCcQN8jjJjmDPYpHX9OJzqk42NZGJocU2MqhBq5tWkJrUQOQY9N+In9xOdymzapM09GeiZw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@next/swc-linux-x64-gnu@14.2.3': + resolution: {integrity: sha512-ENPiNnBNDInBLyUU5ii8PMQh+4XLr4pG51tOp6aJ9xqFQ2iRI6IH0Ds2yJkAzNV1CfyagcyzPfROMViS2wOZ9w==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@next/swc-linux-x64-musl@14.2.3': + resolution: {integrity: sha512-BTAbq0LnCbF5MtoM7I/9UeUu/8ZBY0i8SFjUMCbPDOLv+un67e2JgyN4pmgfXBwy/I+RHu8q+k+MCkDN6P9ViQ==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@next/swc-win32-arm64-msvc@14.2.3': + resolution: {integrity: sha512-AEHIw/dhAMLNFJFJIJIyOFDzrzI5bAjI9J26gbO5xhAKHYTZ9Or04BesFPXiAYXDNdrwTP2dQceYA4dL1geu8A==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + + '@next/swc-win32-ia32-msvc@14.2.3': + resolution: {integrity: sha512-vga40n1q6aYb0CLrM+eEmisfKCR45ixQYXuBXxOOmmoV8sYST9k7E3US32FsY+CkkF7NtzdcebiFT4CHuMSyZw==} + engines: {node: '>= 10'} + cpu: [ia32] + os: [win32] + + '@next/swc-win32-x64-msvc@14.2.3': + resolution: {integrity: sha512-Q1/zm43RWynxrO7lW4ehciQVj+5ePBhOK+/K2P7pLFX3JaJ/IZVC69SHidrmZSOkqz7ECIOhhy7XhAFG4JYyHA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] + + '@popperjs/core@2.11.8': + resolution: {integrity: sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==} + + '@swc/counter@0.1.3': + resolution: {integrity: sha512-e2BR4lsJkkRlKZ/qCHPw9ZaSxc0MVUd7gtbtaB7aMvHeJVYe8sOB8DBZkP2DtISHGSku9sCK6T6cnY0CtXrOCQ==} + + '@swc/helpers@0.5.5': + resolution: {integrity: sha512-KGYxvIOXcceOAbEk4bi/dVLEK9z8sZ0uBB3Il5b1rhfClSpcX0yfRO0KmTkqR2cnQDymwLB+25ZyMzICg/cm/A==} + + '@tanstack/react-virtual@3.5.1': + resolution: {integrity: sha512-jIsuhfgy8GqA67PdWqg73ZB2LFE+HD9hjWL1L6ifEIZVyZVAKpYmgUG4WsKQ005aEyImJmbuimPiEvc57IY0Aw==} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 + + '@tanstack/virtual-core@3.5.1': + resolution: {integrity: sha512-046+AUSiDru/V9pajE1du8WayvBKeCvJ2NmKPy/mR8/SbKKrqmSbj7LJBfXE+nSq4f5TBXvnCzu0kcYebI9WdQ==} + + '@theguild/remark-mermaid@0.0.5': + resolution: {integrity: sha512-e+ZIyJkEv9jabI4m7q29wZtZv+2iwPGsXJ2d46Zi7e+QcFudiyuqhLhHG/3gX3ZEB+hxTch+fpItyMS8jwbIcw==} + peerDependencies: + react: ^18.2.0 + + '@theguild/remark-npm2yarn@0.2.1': + resolution: {integrity: sha512-jUTFWwDxtLEFtGZh/TW/w30ySaDJ8atKWH8dq2/IiQF61dPrGfETpl0WxD0VdBfuLOeU14/kop466oBSRO/5CA==} + + '@types/acorn@4.0.6': + resolution: {integrity: sha512-veQTnWP+1D/xbxVrPC3zHnCZRjSrKfhbMUlEA43iMZLu7EsnTtkJklIuwrCPbOi8YkvDQAiW05VQQFvvz9oieQ==} + + '@types/d3-scale-chromatic@3.0.3': + resolution: {integrity: sha512-laXM4+1o5ImZv3RpFAsTRn3TEkzqkytiOY0Dz0sq5cnd1dtNlk6sHLon4OvqaiJb28T0S/TdsBI3Sjsy+keJrw==} + + '@types/d3-scale@4.0.8': + resolution: {integrity: sha512-gkK1VVTr5iNiYJ7vWDI+yUFFlszhNMtVeneJ6lUTKPjprsvLLI9/tgEGiXJOnlINJA8FyA88gfnQsHbybVZrYQ==} + + '@types/d3-time@3.0.3': + resolution: {integrity: sha512-2p6olUZ4w3s+07q3Tm2dbiMZy5pCDfYwtLXXHUnVzXgQlZ/OyPtUz6OL382BkOuGlLXqfT+wqv8Fw2v8/0geBw==} + + '@types/debug@4.1.12': + resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==} + + '@types/estree-jsx@1.0.5': + resolution: {integrity: sha512-52CcUVNFyfb1A2ALocQw/Dd1BQFNmSdkuC3BkZ6iqhdMfQz7JWOFRuJFloOzjk+6WijU56m9oKXFAXc7o3Towg==} + + '@types/estree@1.0.5': + resolution: {integrity: sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==} + + '@types/hast@2.3.10': + resolution: {integrity: sha512-McWspRw8xx8J9HurkVBfYj0xKoE25tOFlHGdx4MJ5xORQrMGZNqJhVQWaIbm6Oyla5kYOXtDiopzKRJzEOkwJw==} + + '@types/hast@3.0.4': + resolution: {integrity: sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ==} + + '@types/js-yaml@4.0.9': + resolution: {integrity: sha512-k4MGaQl5TGo/iipqb2UDG2UwjXziSWkh0uysQelTlJpX1qGlpUZYm8PnO4DxG1qBomtJUdYJ6qR6xdIah10JLg==} + + '@types/katex@0.16.7': + resolution: {integrity: sha512-HMwFiRujE5PjrgwHQ25+bsLJgowjGjm5Z8FVSf0N6PwgJrwxH0QxzHYDcKsTfV3wva0vzrpqMTJS2jXPr5BMEQ==} + + '@types/mdast@3.0.15': + resolution: {integrity: sha512-LnwD+mUEfxWMa1QpDraczIn6k0Ee3SMicuYSSzS6ZYl2gKS09EClnJYGd8Du6rfc5r/GZEk5o1mRb8TaTj03sQ==} + + '@types/mdast@4.0.4': + resolution: {integrity: sha512-kGaNbPh1k7AFzgpud/gMdvIm5xuECykRR+JnWKQno9TAXVa6WIVCGTPvYGekIDL4uwCZQSYbUxNBSb1aUo79oA==} + + '@types/mdx@2.0.13': + resolution: {integrity: sha512-+OWZQfAYyio6YkJb3HLxDrvnx6SWWDbC0zVPfBRzUk0/nqoDyf6dNxQi3eArPe8rJ473nobTMQ/8Zk+LxJ+Yuw==} + + '@types/ms@0.7.34': + resolution: {integrity: sha512-nG96G3Wp6acyAgJqGasjODb+acrI7KltPiRxzHPXnP3NgI28bpQDRv53olbqGXbfcgF5aiiHmO3xpwEpS5Ld9g==} + + '@types/node@20.14.2': + resolution: {integrity: sha512-xyu6WAMVwv6AKFLB+e/7ySZVr/0zLCzOa7rSpq6jNwpqOrUbcACDWC+53d4n2QHOnDou0fbIsg8wZu/sxrnI4Q==} + + '@types/prop-types@15.7.12': + resolution: {integrity: sha512-5zvhXYtRNRluoE/jAp4GVsSduVUzNWKkOZrCDBWYtE7biZywwdC2AcEzg+cSMLFRfVgeAFqpfNabiPjxFddV1Q==} + + '@types/react-dom@18.3.0': + resolution: {integrity: sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==} + + '@types/react@18.3.3': + resolution: {integrity: sha512-hti/R0pS0q1/xx+TsI73XIqk26eBsISZ2R0wUijXIngRK9R/e7Xw/cXVxQK7R5JjW+SV4zGcn5hXjudkN/pLIw==} + + '@types/unist@2.0.10': + resolution: {integrity: sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA==} + + '@types/unist@3.0.2': + resolution: {integrity: sha512-dqId9J8K/vGi5Zr7oo212BGii5m3q5Hxlkwy3WpYuKPklmBEvsbMYYyLxAQpSffdLl/gdW0XUpKWFvYmyoWCoQ==} + + '@ungap/structured-clone@1.2.0': + resolution: {integrity: sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==} + + acorn-jsx@5.3.2: + resolution: {integrity: sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==} + peerDependencies: + acorn: ^6.0.0 || ^7.0.0 || ^8.0.0 + + acorn@8.11.3: + resolution: {integrity: sha512-Y9rRfJG5jcKOE0CLisYbojUjIrIEE7AGMzA/Sm4BslANhbS+cDMpgBdcPT91oJ7OuJ9hYJBx59RjbhxVnrF8Xg==} + engines: {node: '>=0.4.0'} + hasBin: true + + ansi-sequence-parser@1.1.1: + resolution: {integrity: sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==} + + ansi-styles@3.2.1: + resolution: {integrity: sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==} + engines: {node: '>=4'} + + arch@2.2.0: + resolution: {integrity: sha512-Of/R0wqp83cgHozfIYLbBMnej79U/SVGOOyuB3VVFv1NRM/PSFMK12x9KVtiYzJqmnU5WR2qp0Z5rHb7sWGnFQ==} + + arg@1.0.0: + resolution: {integrity: sha512-Wk7TEzl1KqvTGs/uyhmHO/3XLd3t1UeU4IstvPXVzGPM522cTjqjNZ99esCkcL52sjqjo8e8CTBcWhkxvGzoAw==} + + argparse@1.0.10: + resolution: {integrity: sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==} + + argparse@2.0.1: + resolution: {integrity: sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==} + + astring@1.8.6: + resolution: {integrity: sha512-ISvCdHdlTDlH5IpxQJIex7BWBywFWgjJSVdwst+/iQCoEYnyOaQ95+X1JGshuBjGp6nxKUy1jMgE3zPqN7fQdg==} + hasBin: true + + bail@2.0.2: + resolution: {integrity: sha512-0xO6mYd7JB2YesxDKplafRpsiOzPt9V02ddPCLbY1xYGPOX24NTyN50qnUxgCPcSoYMhKpAuBTjQoRZCAkUDRw==} + + busboy@1.6.0: + resolution: {integrity: sha512-8SFQbg/0hQ9xy3UNTB0YEnsNBbWfhf7RtnzpL7TkBiTBRfrQ9Fxcnz7VJsleJpyp6rVLvXiuORqjlHi5q+PYuA==} + engines: {node: '>=10.16.0'} + + caniuse-lite@1.0.30001629: + resolution: {integrity: sha512-c3dl911slnQhmxUIT4HhYzT7wnBK/XYpGnYLOj4nJBaRiw52Ibe7YxlDaAeRECvA786zCuExhxIUJ2K7nHMrBw==} + + ccount@2.0.1: + resolution: {integrity: sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==} + + chalk@2.3.0: + resolution: {integrity: sha512-Az5zJR2CBujap2rqXGaJKaPHyJ0IrUimvYNX+ncCy8PJP4ltOGTrHUIo097ZaL2zMeKYpiCdqDvS6zdrTFok3Q==} + engines: {node: '>=4'} + + character-entities-html4@2.1.0: + resolution: {integrity: sha512-1v7fgQRj6hnSwFpq1Eu0ynr/CDEw0rXo2B61qXrLNdHZmPKgb7fqS1a2JwF0rISo9q77jDI8VMEHoApn8qDoZA==} + + character-entities-legacy@3.0.0: + resolution: {integrity: sha512-RpPp0asT/6ufRm//AJVwpViZbGM/MkjQFxJccQRHmISF/22NBtsHqAWmL+/pmkPWoIUJdWyeVleTl1wydHATVQ==} + + character-entities@2.0.2: + resolution: {integrity: sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ==} + + character-reference-invalid@2.0.1: + resolution: {integrity: sha512-iBZ4F4wRbyORVsu0jPV7gXkOsGYjGHPmAyv+HiHG8gi5PtC9KI2j1+v8/tlibRvjoWX027ypmG/n0HtO5t7unw==} + + client-only@0.0.1: + resolution: {integrity: sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==} + + clipboardy@1.2.2: + resolution: {integrity: sha512-16KrBOV7bHmHdxcQiCvfUFYVFyEah4FI8vYT1Fr7CGSA4G+xBWMEfUEQJS1hxeHGtI9ju1Bzs9uXSbj5HZKArw==} + engines: {node: '>=4'} + + clsx@2.1.1: + resolution: {integrity: sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==} + engines: {node: '>=6'} + + color-convert@1.9.3: + resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==} + + color-name@1.1.3: + resolution: {integrity: sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==} + + comma-separated-tokens@2.0.3: + resolution: {integrity: sha512-Fu4hJdvzeylCfQPp9SGWidpzrMs7tTrlu6Vb8XGaRGck8QSNZJJp538Wrb60Lax4fPwR64ViY468OIUTbRlGZg==} + + commander@7.2.0: + resolution: {integrity: sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==} + engines: {node: '>= 10'} + + commander@8.3.0: + resolution: {integrity: sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==} + engines: {node: '>= 12'} + + compute-scroll-into-view@3.1.0: + resolution: {integrity: sha512-rj8l8pD4bJ1nx+dAkMhV1xB5RuZEyVysfxJqB1pRchh1KVvwOv9b7CGB8ZfjTImVv2oF+sYMUkMZq6Na5Ftmbg==} + + cose-base@1.0.3: + resolution: {integrity: sha512-s9whTXInMSgAp/NVXVNuVxVKzGH2qck3aQlVHxDCdAEPgtMKwc4Wq6/QKhgdEdgbLSi9rBTAcPoRa6JpiG4ksg==} + + cross-spawn@5.1.0: + resolution: {integrity: sha512-pTgQJ5KC0d2hcY8eyL1IzlBPYjTkyH72XRZPnLyKus2mBfNjQs3klqbJU2VILqZryAZUt9JOb3h/mWMy23/f5A==} + + csstype@3.1.3: + resolution: {integrity: sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==} + + cytoscape-cose-bilkent@4.1.0: + resolution: {integrity: sha512-wgQlVIUJF13Quxiv5e1gstZ08rnZj2XaLHGoFMYXz7SkNfCDOOteKBE6SYRfA9WxxI/iBc3ajfDoc6hb/MRAHQ==} + peerDependencies: + cytoscape: ^3.2.0 + + cytoscape@3.29.2: + resolution: {integrity: sha512-2G1ycU28Nh7OHT9rkXRLpCDP30MKH1dXJORZuBhtEhEW7pKwgPi77ImqlCWinouyE1PNepIOGZBOrE84DG7LyQ==} + engines: {node: '>=0.10'} + + d3-array@2.12.1: + resolution: {integrity: sha512-B0ErZK/66mHtEsR1TkPEEkwdy+WDesimkM5gpZr5Dsg54BiTA5RXtYW5qTLIAcekaS9xfZrzBLF/OAkB3Qn1YQ==} + + d3-array@3.2.4: + resolution: {integrity: sha512-tdQAmyA18i4J7wprpYq8ClcxZy3SC31QMeByyCFyRt7BVHdREQZ5lpzoe5mFEYZUWe+oq8HBvk9JjpibyEV4Jg==} + engines: {node: '>=12'} + + d3-axis@3.0.0: + resolution: {integrity: sha512-IH5tgjV4jE/GhHkRV0HiVYPDtvfjHQlQfJHs0usq7M30XcSBvOotpmH1IgkcXsO/5gEQZD43B//fc7SRT5S+xw==} + engines: {node: '>=12'} + + d3-brush@3.0.0: + resolution: {integrity: sha512-ALnjWlVYkXsVIGlOsuWH1+3udkYFI48Ljihfnh8FZPF2QS9o+PzGLBslO0PjzVoHLZ2KCVgAM8NVkXPJB2aNnQ==} + engines: {node: '>=12'} + + d3-chord@3.0.1: + resolution: {integrity: sha512-VE5S6TNa+j8msksl7HwjxMHDM2yNK3XCkusIlpX5kwauBfXuyLAtNg9jCp/iHH61tgI4sb6R/EIMWCqEIdjT/g==} + engines: {node: '>=12'} + + d3-color@3.1.0: + resolution: {integrity: sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA==} + engines: {node: '>=12'} + + d3-contour@4.0.2: + resolution: {integrity: sha512-4EzFTRIikzs47RGmdxbeUvLWtGedDUNkTcmzoeyg4sP/dvCexO47AaQL7VKy/gul85TOxw+IBgA8US2xwbToNA==} + engines: {node: '>=12'} + + d3-delaunay@6.0.4: + resolution: {integrity: sha512-mdjtIZ1XLAM8bm/hx3WwjfHt6Sggek7qH043O8KEjDXN40xi3vx/6pYSVTwLjEgiXQTbvaouWKynLBiUZ6SK6A==} + engines: {node: '>=12'} + + d3-dispatch@3.0.1: + resolution: {integrity: sha512-rzUyPU/S7rwUflMyLc1ETDeBj0NRuHKKAcvukozwhshr6g6c5d8zh4c2gQjY2bZ0dXeGLWc1PF174P2tVvKhfg==} + engines: {node: '>=12'} + + d3-drag@3.0.0: + resolution: {integrity: sha512-pWbUJLdETVA8lQNJecMxoXfH6x+mO2UQo8rSmZ+QqxcbyA3hfeprFgIT//HW2nlHChWeIIMwS2Fq+gEARkhTkg==} + engines: {node: '>=12'} + + d3-dsv@3.0.1: + resolution: {integrity: sha512-UG6OvdI5afDIFP9w4G0mNq50dSOsXHJaRE8arAS5o9ApWnIElp8GZw1Dun8vP8OyHOZ/QJUKUJwxiiCCnUwm+Q==} + engines: {node: '>=12'} + hasBin: true + + d3-ease@3.0.1: + resolution: {integrity: sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==} + engines: {node: '>=12'} + + d3-fetch@3.0.1: + resolution: {integrity: sha512-kpkQIM20n3oLVBKGg6oHrUchHM3xODkTzjMoj7aWQFq5QEM+R6E4WkzT5+tojDY7yjez8KgCBRoj4aEr99Fdqw==} + engines: {node: '>=12'} + + d3-force@3.0.0: + resolution: {integrity: sha512-zxV/SsA+U4yte8051P4ECydjD/S+qeYtnaIyAs9tgHCqfguma/aAQDjo85A9Z6EKhBirHRJHXIgJUlffT4wdLg==} + engines: {node: '>=12'} + + d3-format@3.1.0: + resolution: {integrity: sha512-YyUI6AEuY/Wpt8KWLgZHsIU86atmikuoOmCfommt0LYHiQSPjvX2AcFc38PX0CBpr2RCyZhjex+NS/LPOv6YqA==} + engines: {node: '>=12'} + + d3-geo@3.1.1: + resolution: {integrity: sha512-637ln3gXKXOwhalDzinUgY83KzNWZRKbYubaG+fGVuc/dxO64RRljtCTnf5ecMyE1RIdtqpkVcq0IbtU2S8j2Q==} + engines: {node: '>=12'} + + d3-hierarchy@3.1.2: + resolution: {integrity: sha512-FX/9frcub54beBdugHjDCdikxThEqjnR93Qt7PvQTOHxyiNCAlvMrHhclk3cD5VeAaq9fxmfRp+CnWw9rEMBuA==} + engines: {node: '>=12'} + + d3-interpolate@3.0.1: + resolution: {integrity: sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g==} + engines: {node: '>=12'} + + d3-path@1.0.9: + resolution: {integrity: sha512-VLaYcn81dtHVTjEHd8B+pbe9yHWpXKZUC87PzoFmsFrJqgFwDe/qxfp5MlfsfM1V5E/iVt0MmEbWQ7FVIXh/bg==} + + d3-path@3.1.0: + resolution: {integrity: sha512-p3KP5HCf/bvjBSSKuXid6Zqijx7wIfNW+J/maPs+iwR35at5JCbLUT0LzF1cnjbCHWhqzQTIN2Jpe8pRebIEFQ==} + engines: {node: '>=12'} + + d3-polygon@3.0.1: + resolution: {integrity: sha512-3vbA7vXYwfe1SYhED++fPUQlWSYTTGmFmQiany/gdbiWgU/iEyQzyymwL9SkJjFFuCS4902BSzewVGsHHmHtXg==} + engines: {node: '>=12'} + + d3-quadtree@3.0.1: + resolution: {integrity: sha512-04xDrxQTDTCFwP5H6hRhsRcb9xxv2RzkcsygFzmkSIOJy3PeRJP7sNk3VRIbKXcog561P9oU0/rVH6vDROAgUw==} + engines: {node: '>=12'} + + d3-random@3.0.1: + resolution: {integrity: sha512-FXMe9GfxTxqd5D6jFsQ+DJ8BJS4E/fT5mqqdjovykEB2oFbTMDVdg1MGFxfQW+FBOGoB++k8swBrgwSHT1cUXQ==} + engines: {node: '>=12'} + + d3-sankey@0.12.3: + resolution: {integrity: sha512-nQhsBRmM19Ax5xEIPLMY9ZmJ/cDvd1BG3UVvt5h3WRxKg5zGRbvnteTyWAbzeSvlh3tW7ZEmq4VwR5mB3tutmQ==} + + d3-scale-chromatic@3.1.0: + resolution: {integrity: sha512-A3s5PWiZ9YCXFye1o246KoscMWqf8BsD9eRiJ3He7C9OBaxKhAd5TFCdEx/7VbKtxxTsu//1mMJFrEt572cEyQ==} + engines: {node: '>=12'} + + d3-scale@4.0.2: + resolution: {integrity: sha512-GZW464g1SH7ag3Y7hXjf8RoUuAFIqklOAq3MRl4OaWabTFJY9PN/E1YklhXLh+OQ3fM9yS2nOkCoS+WLZ6kvxQ==} + engines: {node: '>=12'} + + d3-selection@3.0.0: + resolution: {integrity: sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==} + engines: {node: '>=12'} + + d3-shape@1.3.7: + resolution: {integrity: sha512-EUkvKjqPFUAZyOlhY5gzCxCeI0Aep04LwIRpsZ/mLFelJiUfnK56jo5JMDSE7yyP2kLSb6LtF+S5chMk7uqPqw==} + + d3-shape@3.2.0: + resolution: {integrity: sha512-SaLBuwGm3MOViRq2ABk3eLoxwZELpH6zhl3FbAoJ7Vm1gofKx6El1Ib5z23NUEhF9AsGl7y+dzLe5Cw2AArGTA==} + engines: {node: '>=12'} + + d3-time-format@4.1.0: + resolution: {integrity: sha512-dJxPBlzC7NugB2PDLwo9Q8JiTR3M3e4/XANkreKSUxF8vvXKqm1Yfq4Q5dl8budlunRVlUUaDUgFt7eA8D6NLg==} + engines: {node: '>=12'} + + d3-time@3.1.0: + resolution: {integrity: sha512-VqKjzBLejbSMT4IgbmVgDjpkYrNWUYJnbCGo874u7MMKIWsILRX+OpX/gTk8MqjpT1A/c6HY2dCA77ZN0lkQ2Q==} + engines: {node: '>=12'} + + d3-timer@3.0.1: + resolution: {integrity: sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==} + engines: {node: '>=12'} + + d3-transition@3.0.1: + resolution: {integrity: sha512-ApKvfjsSR6tg06xrL434C0WydLr7JewBB3V+/39RMHsaXTOG0zmt/OAXeng5M5LBm0ojmxJrpomQVZ1aPvBL4w==} + engines: {node: '>=12'} + peerDependencies: + d3-selection: 2 - 3 + + d3-zoom@3.0.0: + resolution: {integrity: sha512-b8AmV3kfQaqWAuacbPuNbL6vahnOJflOhexLzMMNLga62+/nh0JzvJ0aO/5a5MVgUFGS7Hu1P9P03o3fJkDCyw==} + engines: {node: '>=12'} + + d3@7.9.0: + resolution: {integrity: sha512-e1U46jVP+w7Iut8Jt8ri1YsPOvFpg46k+K8TpCb0P+zjCkjkPnV7WzfDJzMHy1LnA+wj5pLT1wjO901gLXeEhA==} + engines: {node: '>=12'} + + dagre-d3-es@7.0.10: + resolution: {integrity: sha512-qTCQmEhcynucuaZgY5/+ti3X/rnszKZhEQH/ZdWdtP1tA/y3VoHJzcVrO9pjjJCNpigfscAtoUB5ONcd2wNn0A==} + + dayjs@1.11.11: + resolution: {integrity: sha512-okzr3f11N6WuqYtZSvm+F776mB41wRZMhKP+hc34YdW+KmtYYK9iqvHSwo2k9FEH3fhGXvOPV6yz2IcSrfRUDg==} + + debug@4.3.5: + resolution: {integrity: sha512-pt0bNEmneDIvdL1Xsd9oDQ/wrQRkXDT4AUWlNZNPKvW5x/jyO9VFXkJUP07vQ2upmw5PlaITaPKc31jK13V+jg==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + + decode-named-character-reference@1.0.2: + resolution: {integrity: sha512-O8x12RzrUF8xyVcY0KJowWsmaJxQbmy0/EtnNtHRpsOcT7dFk5W598coHqBVpmWo1oQQfsCqfCmkZN5DJrZVdg==} + + delaunator@5.0.1: + resolution: {integrity: sha512-8nvh+XBe96aCESrGOqMp/84b13H9cdKbG5P2ejQCh4d4sK9RL4371qou9drQjMhvnPmhWl5hnmqbEE0fXr9Xnw==} + + dequal@2.0.3: + resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==} + engines: {node: '>=6'} + + devlop@1.1.0: + resolution: {integrity: sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==} + + diff@5.2.0: + resolution: {integrity: sha512-uIFDxqpRZGZ6ThOk84hEfqWoHx2devRFvpTZcTHur85vImfaxUbTW9Ryh4CpCuDnToOP1CEtXKIgytHBPVff5A==} + engines: {node: '>=0.3.1'} + + dompurify@3.1.5: + resolution: {integrity: sha512-lwG+n5h8QNpxtyrJW/gJWckL+1/DQiYMX8f7t8Z2AZTPw1esVrqjI63i7Zc2Gz0aKzLVMYC1V1PL/ky+aY/NgA==} + + elkjs@0.9.3: + resolution: {integrity: sha512-f/ZeWvW/BCXbhGEf1Ujp29EASo/lk1FDnETgNKwJrsVvGZhUWCZyg3xLJjAsxfOmt8KjswHmI5EwCQcPMpOYhQ==} + + entities@4.5.0: + resolution: {integrity: sha512-V0hjH4dGPh9Ao5p0MoRY6BVqtwCjhz6vI5LT8AJ55H+4g9/4vbHx1I54fS0XuclLhDHArPQCiMjDxjaL8fPxhw==} + engines: {node: '>=0.12'} + + escape-string-regexp@1.0.5: + resolution: {integrity: sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==} + engines: {node: '>=0.8.0'} + + escape-string-regexp@5.0.0: + resolution: {integrity: sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw==} + engines: {node: '>=12'} + + esprima@4.0.1: + resolution: {integrity: sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==} + engines: {node: '>=4'} + hasBin: true + + estree-util-attach-comments@2.1.1: + resolution: {integrity: sha512-+5Ba/xGGS6mnwFbXIuQiDPTbuTxuMCooq3arVv7gPZtYpjp+VXH/NkHAP35OOefPhNG/UGqU3vt/LTABwcHX0w==} + + estree-util-build-jsx@2.2.2: + resolution: {integrity: sha512-m56vOXcOBuaF+Igpb9OPAy7f9w9OIkb5yhjsZuaPm7HoGi4oTOQi0h2+yZ+AtKklYFZ+rPC4n0wYCJCEU1ONqg==} + + estree-util-is-identifier-name@2.1.0: + resolution: {integrity: sha512-bEN9VHRyXAUOjkKVQVvArFym08BTWB0aJPppZZr0UNyAqWsLaVfAqP7hbaTJjzHifmB5ebnR8Wm7r7yGN/HonQ==} + + estree-util-to-js@1.2.0: + resolution: {integrity: sha512-IzU74r1PK5IMMGZXUVZbmiu4A1uhiPgW5hm1GjcOfr4ZzHaMPpLNJjR7HjXiIOzi25nZDrgFTobHTkV5Q6ITjA==} + + estree-util-value-to-estree@1.3.0: + resolution: {integrity: sha512-Y+ughcF9jSUJvncXwqRageavjrNPAI+1M/L3BI3PyLp1nmgYTGUXU6t5z1Y7OWuThoDdhPME07bQU+d5LxdJqw==} + engines: {node: '>=12.0.0'} + + estree-util-visit@1.2.1: + resolution: {integrity: sha512-xbgqcrkIVbIG+lI/gzbvd9SGTJL4zqJKBFttUl5pP27KhAjtMKbX/mQXJ7qgyXpMgVy/zvpm0xoQQaGL8OloOw==} + + estree-walker@3.0.3: + resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} + + execa@0.8.0: + resolution: {integrity: sha512-zDWS+Rb1E8BlqqhALSt9kUhss8Qq4nN3iof3gsOdyINksElaPyNBtKUMTR62qhvgVWR0CqCX7sdnKe4MnUbFEA==} + engines: {node: '>=4'} + + extend-shallow@2.0.1: + resolution: {integrity: sha512-zCnTtlxNoAiDc3gqY2aYAWFx7XWWiasuF2K8Me5WbN8otHKTUKBwjPtNpRs/rbUZm7KxWAaNj7P1a/p52GbVug==} + engines: {node: '>=0.10.0'} + + extend@3.0.2: + resolution: {integrity: sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g==} + + flexsearch@0.7.43: + resolution: {integrity: sha512-c5o/+Um8aqCSOXGcZoqZOm+NqtVwNsvVpWv6lfmSclU954O3wvQKxxK8zj74fPaSJbXpSLTs4PRhh+wnoCXnKg==} + + focus-visible@5.2.0: + resolution: {integrity: sha512-Rwix9pBtC1Nuy5wysTmKy+UjbDJpIfg8eHjw0rjZ1mX4GNLz1Bmd16uDpI3Gk1i70Fgcs8Csg2lPm8HULFg9DQ==} + + get-stream@3.0.0: + resolution: {integrity: sha512-GlhdIUuVakc8SJ6kK0zAFbiGzRFzNnY4jUuEbV9UROo4Y+0Ny4fjvcZFVTeDA4odpFyOQzaw6hXukJSq/f28sQ==} + engines: {node: '>=4'} + + git-up@7.0.0: + resolution: {integrity: sha512-ONdIrbBCFusq1Oy0sC71F5azx8bVkvtZtMJAsv+a6lz5YAmbNnLD6HAB4gptHZVLPR8S2/kVN6Gab7lryq5+lQ==} + + git-url-parse@13.1.1: + resolution: {integrity: sha512-PCFJyeSSdtnbfhSNRw9Wk96dDCNx+sogTe4YNXeXSJxt7xz5hvXekuRn9JX7m+Mf4OscCu8h+mtAl3+h5Fo8lQ==} + + github-slugger@2.0.0: + resolution: {integrity: sha512-IaOQ9puYtjrkq7Y0Ygl9KDZnrf/aiUJYUpVf89y8kyaxbRG7Y1SrX/jaumrv81vc61+kiMempujsM3Yw7w5qcw==} + + graceful-fs@4.2.11: + resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} + + gray-matter@4.0.3: + resolution: {integrity: sha512-5v6yZd4JK3eMI3FqqCouswVqwugaA9r4dNZB1wwcmrD02QkV5H0y7XBQW8QwQqEaZY1pM9aqORSORhJRdNK44Q==} + engines: {node: '>=6.0'} + + has-flag@2.0.0: + resolution: {integrity: sha512-P+1n3MnwjR/Epg9BBo1KT8qbye2g2Ou4sFumihwt6I4tsUX7jnLcX4BTOSKg/B1ZrIYMN9FcEnG4x5a7NB8Eng==} + engines: {node: '>=0.10.0'} + + hash-obj@4.0.0: + resolution: {integrity: sha512-FwO1BUVWkyHasWDW4S8o0ssQXjvyghLV2rfVhnN36b2bbcj45eGiuzdn9XOvOpjV3TKQD7Gm2BWNXdE9V4KKYg==} + engines: {node: '>=12'} + + hast-util-from-dom@5.0.0: + resolution: {integrity: sha512-d6235voAp/XR3Hh5uy7aGLbM3S4KamdW0WEgOaU1YoewnuYw4HXb5eRtv9g65m/RFGEfUY1Mw4UqCc5Y8L4Stg==} + + hast-util-from-html-isomorphic@2.0.0: + resolution: {integrity: sha512-zJfpXq44yff2hmE0XmwEOzdWin5xwH+QIhMLOScpX91e/NSGPsAzNCvLQDIEPyO2TXi+lBmU6hjLIhV8MwP2kw==} + + hast-util-from-html@2.0.1: + resolution: {integrity: sha512-RXQBLMl9kjKVNkJTIO6bZyb2n+cUH8LFaSSzo82jiLT6Tfc+Pt7VQCS+/h3YwG4jaNE2TA2sdJisGWR+aJrp0g==} + + hast-util-from-parse5@8.0.1: + resolution: {integrity: sha512-Er/Iixbc7IEa7r/XLtuG52zoqn/b3Xng/w6aZQ0xGVxzhw5xUFxcRqdPzP6yFi/4HBYRaifaI5fQ1RH8n0ZeOQ==} + + hast-util-is-element@3.0.0: + resolution: {integrity: sha512-Val9mnv2IWpLbNPqc/pUem+a7Ipj2aHacCwgNfTiK0vJKl0LF+4Ba4+v1oPHFpf3bLYmreq0/l3Gud9S5OH42g==} + + hast-util-parse-selector@4.0.0: + resolution: {integrity: sha512-wkQCkSYoOGCRKERFWcxMVMOcYE2K1AaNLU8DXS9arxnLOUEWbOXKXiJUNzEpqZ3JOKpnha3jkFrumEjVliDe7A==} + + hast-util-raw@9.0.3: + resolution: {integrity: sha512-ICWvVOF2fq4+7CMmtCPD5CM4QKjPbHpPotE6+8tDooV0ZuyJVUzHsrNX+O5NaRbieTf0F7FfeBOMAwi6Td0+yQ==} + + hast-util-to-estree@2.3.3: + resolution: {integrity: sha512-ihhPIUPxN0v0w6M5+IiAZZrn0LH2uZomeWwhn7uP7avZC6TE7lIiEh2yBMPr5+zi1aUCXq6VoYRgs2Bw9xmycQ==} + + hast-util-to-parse5@8.0.0: + resolution: {integrity: sha512-3KKrV5ZVI8if87DVSi1vDeByYrkGzg4mEfeu4alwgmmIeARiBLKCZS2uw5Gb6nU9x9Yufyj3iudm6i7nl52PFw==} + + hast-util-to-text@4.0.2: + resolution: {integrity: sha512-KK6y/BN8lbaq654j7JgBydev7wuNMcID54lkRav1P0CaE1e47P72AWWPiGKXTJU271ooYzcvTAn/Zt0REnvc7A==} + + hast-util-whitespace@2.0.1: + resolution: {integrity: sha512-nAxA0v8+vXSBDt3AnRUNjyRIQ0rD+ntpbAp4LnPkumc5M9yUbSMa4XDU9Q6etY4f1Wp4bNgvc1yjiZtsTTrSng==} + + hastscript@8.0.0: + resolution: {integrity: sha512-dMOtzCEd3ABUeSIISmrETiKuyydk1w0pa+gE/uormcTpSYuaNJPbX1NU3JLyscSLjwAQM8bWMhhIlnCqnRvDTw==} + + html-void-elements@3.0.0: + resolution: {integrity: sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg==} + + iconv-lite@0.6.3: + resolution: {integrity: sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==} + engines: {node: '>=0.10.0'} + + inline-style-parser@0.1.1: + resolution: {integrity: sha512-7NXolsK4CAS5+xvdj5OMMbI962hU/wvwoxk+LWR9Ek9bVtyuuYScDN6eS0rUm6TxApFpw7CX1o4uJzcd4AyD3Q==} + + internmap@1.0.1: + resolution: {integrity: sha512-lDB5YccMydFBtasVtxnZ3MRBHuaoE8GKsppq+EchKL2U4nK/DmEpPHNH8MZe5HkMtpSiTSOZwfN0tzYjO/lJEw==} + + internmap@2.0.3: + resolution: {integrity: sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==} + engines: {node: '>=12'} + + intersection-observer@0.12.2: + resolution: {integrity: sha512-7m1vEcPCxXYI8HqnL8CKI6siDyD+eIWSwgB3DZA+ZTogxk9I4CDnj4wilt9x/+/QbHI4YG5YZNmC6458/e9Ktg==} + + is-alphabetical@2.0.1: + resolution: {integrity: sha512-FWyyY60MeTNyeSRpkM2Iry0G9hpr7/9kD40mD/cGQEuilcZYS4okz8SN2Q6rLCJ8gbCt6fN+rC+6tMGS99LaxQ==} + + is-alphanumerical@2.0.1: + resolution: {integrity: sha512-hmbYhX/9MUMF5uh7tOXyK/n0ZvWpad5caBA17GsC6vyuCqaWliRG5K1qS9inmUhEMaOBIW7/whAnSwveW/LtZw==} + + is-buffer@2.0.5: + resolution: {integrity: sha512-i2R6zNFDwgEHJyQUtJEk0XFi1i0dPFn/oqjK3/vPCcDeJvW5NQ83V8QbicfF1SupOaB0h8ntgBC2YiE7dfyctQ==} + engines: {node: '>=4'} + + is-decimal@2.0.1: + resolution: {integrity: sha512-AAB9hiomQs5DXWcRB1rqsxGUstbRroFOPPVAomNk/3XHR5JyEZChOyTWe2oayKnsSsr/kcGqF+z6yuH6HHpN0A==} + + is-extendable@0.1.1: + resolution: {integrity: sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==} + engines: {node: '>=0.10.0'} + + is-hexadecimal@2.0.1: + resolution: {integrity: sha512-DgZQp241c8oO6cA1SbTEWiXeoxV42vlcJxgH+B3hi1AiqqKruZR3ZGF8In3fj4+/y/7rHvlOZLZtgJ/4ttYGZg==} + + is-obj@3.0.0: + resolution: {integrity: sha512-IlsXEHOjtKhpN8r/tRFj2nDyTmHvcfNeu/nrRIcXE17ROeatXchkojffa1SpdqW4cr/Fj6QkEf/Gn4zf6KKvEQ==} + engines: {node: '>=12'} + + is-plain-obj@3.0.0: + resolution: {integrity: sha512-gwsOE28k+23GP1B6vFl1oVh/WOzmawBrKwo5Ev6wMKzPkaXaCDIQKzLnvsA42DRlbVTWorkgTKIviAKCWkfUwA==} + engines: {node: '>=10'} + + is-plain-obj@4.1.0: + resolution: {integrity: sha512-+Pgi+vMuUNkJyExiMBt5IlFoMyKnr5zhJ4Uspz58WOhBF5QoIZkFyNHIbBAtHwzVAgk5RtndVNsDRN61/mmDqg==} + engines: {node: '>=12'} + + is-reference@3.0.2: + resolution: {integrity: sha512-v3rht/LgVcsdZa3O2Nqs+NMowLOxeOm7Ay9+/ARQ2F+qEoANRcqrjAZKGN0v8ymUetZGgkp26LTnGT7H0Qo9Pg==} + + is-ssh@1.4.0: + resolution: {integrity: sha512-x7+VxdxOdlV3CYpjvRLBv5Lo9OJerlYanjwFrPR9fuGPjCiNiCzFgAWpiLAohSbsnH4ZAys3SBh+hq5rJosxUQ==} + + is-stream@1.1.0: + resolution: {integrity: sha512-uQPm8kcs47jx38atAcWTVxyltQYoPT68y9aWYdV6yWXSyW8mzSat0TL6CiWdZeCdF3KrAvpVtnHbTv4RN+rqdQ==} + engines: {node: '>=0.10.0'} + + isexe@2.0.0: + resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==} + + js-tokens@4.0.0: + resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} + + js-yaml@3.14.1: + resolution: {integrity: sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==} + hasBin: true + + js-yaml@4.1.0: + resolution: {integrity: sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==} + hasBin: true + + jsonc-parser@3.2.1: + resolution: {integrity: sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==} + + katex@0.16.10: + resolution: {integrity: sha512-ZiqaC04tp2O5utMsl2TEZTXxa6WSC4yo0fv5ML++D3QZv/vx2Mct0mTlRx3O+uUkjfuAgOkzsCmq5MiUEsDDdA==} + hasBin: true + + khroma@2.1.0: + resolution: {integrity: sha512-Ls993zuzfayK269Svk9hzpeGUKob/sIgZzyHYdjQoAdQetRKpOLj+k/QQQ/6Qi0Yz65mlROrfd+Ev+1+7dz9Kw==} + + kind-of@6.0.3: + resolution: {integrity: sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==} + engines: {node: '>=0.10.0'} + + kleur@4.1.5: + resolution: {integrity: sha512-o+NO+8WrRiQEE4/7nwRJhN1HWpVmJm511pBHUxPLtp0BUISzlBplORYSmTclCnJvQq2tKu/sgl3xVpkc7ZWuQQ==} + engines: {node: '>=6'} + + layout-base@1.0.2: + resolution: {integrity: sha512-8h2oVEZNktL4BH2JCOI90iD1yXwL6iNW7KcCKT2QZgQJR2vbqDsldCTPRU9NifTCqHZci57XvQQ15YTu+sTYPg==} + + lodash-es@4.17.21: + resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==} + + lodash.get@4.4.2: + resolution: {integrity: sha512-z+Uw/vLuy6gQe8cfaFWD7p0wVv8fJl3mbzXh33RS+0oW2wvUqiRXiQ69gLWSLpgB5/6sU+r6BlQR0MBILadqTQ==} + + longest-streak@3.1.0: + resolution: {integrity: sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==} + + loose-envify@1.4.0: + resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} + hasBin: true + + lru-cache@4.1.5: + resolution: {integrity: sha512-sWZlbEP2OsHNkXrMl5GYk/jKk70MBng6UU4YI/qGDYbgf6YbP4EvmqISbXCoJiRKs+1bSpFHVgQxvJ17F2li5g==} + + markdown-extensions@1.1.1: + resolution: {integrity: sha512-WWC0ZuMzCyDHYCasEGs4IPvLyTGftYwh6wIEOULOF0HXcqZlhwRzrK0w2VUlxWA98xnvb/jszw4ZSkJ6ADpM6Q==} + engines: {node: '>=0.10.0'} + + markdown-table@3.0.3: + resolution: {integrity: sha512-Z1NL3Tb1M9wH4XESsCDEksWoKTdlUafKc4pt0GRwjUyXaCFZ+dc3g2erqB6zm3szA2IUSi7VnPI+o/9jnxh9hw==} + + match-sorter@6.3.4: + resolution: {integrity: sha512-jfZW7cWS5y/1xswZo8VBOdudUiSd9nifYRWphc9M5D/ee4w4AoXLgBEdRbgVaxbMuagBPeUC5y2Hi8DO6o9aDg==} + + mdast-util-definitions@5.1.2: + resolution: {integrity: sha512-8SVPMuHqlPME/z3gqVwWY4zVXn8lqKv/pAhC57FuJ40ImXyBpmO5ukh98zB2v7Blql2FiHjHv9LVztSIqjY+MA==} + + mdast-util-find-and-replace@2.2.2: + resolution: {integrity: sha512-MTtdFRz/eMDHXzeK6W3dO7mXUlF82Gom4y0oOgvHhh/HXZAGvIQDUvQ0SuUx+j2tv44b8xTHOm8K/9OoRFnXKw==} + + mdast-util-from-markdown@1.3.1: + resolution: {integrity: sha512-4xTO/M8c82qBcnQc1tgpNtubGUW/Y1tBQ1B0i5CtSoelOLKFYlElIr3bvgREYYO5iRqbMY1YuqZng0GVOI8Qww==} + + mdast-util-gfm-autolink-literal@1.0.3: + resolution: {integrity: sha512-My8KJ57FYEy2W2LyNom4n3E7hKTuQk/0SES0u16tjA9Z3oFkF4RrC/hPAPgjlSpezsOvI8ObcXcElo92wn5IGA==} + + mdast-util-gfm-footnote@1.0.2: + resolution: {integrity: sha512-56D19KOGbE00uKVj3sgIykpwKL179QsVFwx/DCW0u/0+URsryacI4MAdNJl0dh+u2PSsD9FtxPFbHCzJ78qJFQ==} + + mdast-util-gfm-strikethrough@1.0.3: + resolution: {integrity: sha512-DAPhYzTYrRcXdMjUtUjKvW9z/FNAMTdU0ORyMcbmkwYNbKocDpdk+PX1L1dQgOID/+vVs1uBQ7ElrBQfZ0cuiQ==} + + mdast-util-gfm-table@1.0.7: + resolution: {integrity: sha512-jjcpmNnQvrmN5Vx7y7lEc2iIOEytYv7rTvu+MeyAsSHTASGCCRA79Igg2uKssgOs1i1po8s3plW0sTu1wkkLGg==} + + mdast-util-gfm-task-list-item@1.0.2: + resolution: {integrity: sha512-PFTA1gzfp1B1UaiJVyhJZA1rm0+Tzn690frc/L8vNX1Jop4STZgOE6bxUhnzdVSB+vm2GU1tIsuQcA9bxTQpMQ==} + + mdast-util-gfm@2.0.2: + resolution: {integrity: sha512-qvZ608nBppZ4icQlhQQIAdc6S3Ffj9RGmzwUKUWuEICFnd1LVkN3EktF7ZHAgfcEdvZB5owU9tQgt99e2TlLjg==} + + mdast-util-math@2.0.2: + resolution: {integrity: sha512-8gmkKVp9v6+Tgjtq6SYx9kGPpTf6FVYRa53/DLh479aldR9AyP48qeVOgNZ5X7QUK7nOy4yw7vg6mbiGcs9jWQ==} + + mdast-util-mdx-expression@1.3.2: + resolution: {integrity: sha512-xIPmR5ReJDu/DHH1OoIT1HkuybIfRGYRywC+gJtI7qHjCJp/M9jrmBEJW22O8lskDWm562BX2W8TiAwRTb0rKA==} + + mdast-util-mdx-jsx@2.1.4: + resolution: {integrity: sha512-DtMn9CmVhVzZx3f+optVDF8yFgQVt7FghCRNdlIaS3X5Bnym3hZwPbg/XW86vdpKjlc1PVj26SpnLGeJBXD3JA==} + + mdast-util-mdx@2.0.1: + resolution: {integrity: sha512-38w5y+r8nyKlGvNjSEqWrhG0w5PmnRA+wnBvm+ulYCct7nsGYhFVb0lljS9bQav4psDAS1eGkP2LMVcZBi/aqw==} + + mdast-util-mdxjs-esm@1.3.1: + resolution: {integrity: sha512-SXqglS0HrEvSdUEfoXFtcg7DRl7S2cwOXc7jkuusG472Mmjag34DUDeOJUZtl+BVnyeO1frIgVpHlNRWc2gk/w==} + + mdast-util-phrasing@3.0.1: + resolution: {integrity: sha512-WmI1gTXUBJo4/ZmSk79Wcb2HcjPJBzM1nlI/OUWA8yk2X9ik3ffNbBGsU+09BFmXaL1IBb9fiuvq6/KMiNycSg==} + + mdast-util-to-hast@12.3.0: + resolution: {integrity: sha512-pits93r8PhnIoU4Vy9bjW39M2jJ6/tdHyja9rrot9uujkN7UTU9SDnE6WNJz/IGyQk3XHX6yNNtrBH6cQzm8Hw==} + + mdast-util-to-hast@13.1.0: + resolution: {integrity: sha512-/e2l/6+OdGp/FB+ctrJ9Avz71AN/GRH3oi/3KAx/kMnoUsD6q0woXlDT8lLEeViVKE7oZxE7RXzvO3T8kF2/sA==} + + mdast-util-to-markdown@1.5.0: + resolution: {integrity: sha512-bbv7TPv/WC49thZPg3jXuqzuvI45IL2EVAr/KxF0BSdHsU0ceFHOmwQn6evxAh1GaoK/6GQ1wp4R4oW2+LFL/A==} + + mdast-util-to-string@3.2.0: + resolution: {integrity: sha512-V4Zn/ncyN1QNSqSBxTrMOLpjr+IKdHl2v3KVLoWmDPscP4r9GcCi71gjgvUV1SFSKh92AjAG4peFuBl2/YgCJg==} + + mermaid@10.9.1: + resolution: {integrity: sha512-Mx45Obds5W1UkW1nv/7dHRsbfMM1aOKA2+Pxs/IGHNonygDHwmng8xTHyS9z4KWVi0rbko8gjiBmuwwXQ7tiNA==} + + micromark-core-commonmark@1.1.0: + resolution: {integrity: sha512-BgHO1aRbolh2hcrzL2d1La37V0Aoz73ymF8rAcKnohLy93titmv62E0gP8Hrx9PKcKrqCZ1BbLGbP3bEhoXYlw==} + + micromark-extension-gfm-autolink-literal@1.0.5: + resolution: {integrity: sha512-z3wJSLrDf8kRDOh2qBtoTRD53vJ+CWIyo7uyZuxf/JAbNJjiHsOpG1y5wxk8drtv3ETAHutCu6N3thkOOgueWg==} + + micromark-extension-gfm-footnote@1.1.2: + resolution: {integrity: sha512-Yxn7z7SxgyGWRNa4wzf8AhYYWNrwl5q1Z8ii+CSTTIqVkmGZF1CElX2JI8g5yGoM3GAman9/PVCUFUSJ0kB/8Q==} + + micromark-extension-gfm-strikethrough@1.0.7: + resolution: {integrity: sha512-sX0FawVE1o3abGk3vRjOH50L5TTLr3b5XMqnP9YDRb34M0v5OoZhG+OHFz1OffZ9dlwgpTBKaT4XW/AsUVnSDw==} + + micromark-extension-gfm-table@1.0.7: + resolution: {integrity: sha512-3ZORTHtcSnMQEKtAOsBQ9/oHp9096pI/UvdPtN7ehKvrmZZ2+bbWhi0ln+I9drmwXMt5boocn6OlwQzNXeVeqw==} + + micromark-extension-gfm-tagfilter@1.0.2: + resolution: {integrity: sha512-5XWB9GbAUSHTn8VPU8/1DBXMuKYT5uOgEjJb8gN3mW0PNW5OPHpSdojoqf+iq1xo7vWzw/P8bAHY0n6ijpXF7g==} + + micromark-extension-gfm-task-list-item@1.0.5: + resolution: {integrity: sha512-RMFXl2uQ0pNQy6Lun2YBYT9g9INXtWJULgbt01D/x8/6yJ2qpKyzdZD3pi6UIkzF++Da49xAelVKUeUMqd5eIQ==} + + micromark-extension-gfm@2.0.3: + resolution: {integrity: sha512-vb9OoHqrhCmbRidQv/2+Bc6pkP0FrtlhurxZofvOEy5o8RtuuvTq+RQ1Vw5ZDNrVraQZu3HixESqbG+0iKk/MQ==} + + micromark-extension-math@2.1.2: + resolution: {integrity: sha512-es0CcOV89VNS9wFmyn+wyFTKweXGW4CEvdaAca6SWRWPyYCbBisnjaHLjWO4Nszuiud84jCpkHsqAJoa768Pvg==} + + micromark-extension-mdx-expression@1.0.8: + resolution: {integrity: sha512-zZpeQtc5wfWKdzDsHRBY003H2Smg+PUi2REhqgIhdzAa5xonhP03FcXxqFSerFiNUr5AWmHpaNPQTBVOS4lrXw==} + + micromark-extension-mdx-jsx@1.0.5: + resolution: {integrity: sha512-gPH+9ZdmDflbu19Xkb8+gheqEDqkSpdCEubQyxuz/Hn8DOXiXvrXeikOoBA71+e8Pfi0/UYmU3wW3H58kr7akA==} + + micromark-extension-mdx-md@1.0.1: + resolution: {integrity: sha512-7MSuj2S7xjOQXAjjkbjBsHkMtb+mDGVW6uI2dBL9snOBCbZmoNgDAeZ0nSn9j3T42UE/g2xVNMn18PJxZvkBEA==} + + micromark-extension-mdxjs-esm@1.0.5: + resolution: {integrity: sha512-xNRBw4aoURcyz/S69B19WnZAkWJMxHMT5hE36GtDAyhoyn/8TuAeqjFJQlwk+MKQsUD7b3l7kFX+vlfVWgcX1w==} + + micromark-extension-mdxjs@1.0.1: + resolution: {integrity: sha512-7YA7hF6i5eKOfFUzZ+0z6avRG52GpWR8DL+kN47y3f2KhxbBZMhmxe7auOeaTBrW2DenbbZTf1ea9tA2hDpC2Q==} + + micromark-factory-destination@1.1.0: + resolution: {integrity: sha512-XaNDROBgx9SgSChd69pjiGKbV+nfHGDPVYFs5dOoDd7ZnMAE+Cuu91BCpsY8RT2NP9vo/B8pds2VQNCLiu0zhg==} + + micromark-factory-label@1.1.0: + resolution: {integrity: sha512-OLtyez4vZo/1NjxGhcpDSbHQ+m0IIGnT8BoPamh+7jVlzLJBH98zzuCoUeMxvM6WsNeh8wx8cKvqLiPHEACn0w==} + + micromark-factory-mdx-expression@1.0.9: + resolution: {integrity: sha512-jGIWzSmNfdnkJq05c7b0+Wv0Kfz3NJ3N4cBjnbO4zjXIlxJr+f8lk+5ZmwFvqdAbUy2q6B5rCY//g0QAAaXDWA==} + + micromark-factory-space@1.1.0: + resolution: {integrity: sha512-cRzEj7c0OL4Mw2v6nwzttyOZe8XY/Z8G0rzmWQZTBi/jjwyw/U4uqKtUORXQrR5bAZZnbTI/feRV/R7hc4jQYQ==} + + micromark-factory-title@1.1.0: + resolution: {integrity: sha512-J7n9R3vMmgjDOCY8NPw55jiyaQnH5kBdV2/UXCtZIpnHH3P6nHUKaH7XXEYuWwx/xUJcawa8plLBEjMPU24HzQ==} + + micromark-factory-whitespace@1.1.0: + resolution: {integrity: sha512-v2WlmiymVSp5oMg+1Q0N1Lxmt6pMhIHD457whWM7/GUlEks1hI9xj5w3zbc4uuMKXGisksZk8DzP2UyGbGqNsQ==} + + micromark-util-character@1.2.0: + resolution: {integrity: sha512-lXraTwcX3yH/vMDaFWCQJP1uIszLVebzUa3ZHdrgxr7KEU/9mL4mVgCpGbyhvNLNlauROiNUq7WN5u7ndbY6xg==} + + micromark-util-character@2.1.0: + resolution: {integrity: sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==} + + micromark-util-chunked@1.1.0: + resolution: {integrity: sha512-Ye01HXpkZPNcV6FiyoW2fGZDUw4Yc7vT0E9Sad83+bEDiCJ1uXu0S3mr8WLpsz3HaG3x2q0HM6CTuPdcZcluFQ==} + + micromark-util-classify-character@1.1.0: + resolution: {integrity: sha512-SL0wLxtKSnklKSUplok1WQFoGhUdWYKggKUiqhX+Swala+BtptGCu5iPRc+xvzJ4PXE/hwM3FNXsfEVgoZsWbw==} + + micromark-util-combine-extensions@1.1.0: + resolution: {integrity: sha512-Q20sp4mfNf9yEqDL50WwuWZHUrCO4fEyeDCnMGmG5Pr0Cz15Uo7KBs6jq+dq0EgX4DPwwrh9m0X+zPV1ypFvUA==} + + micromark-util-decode-numeric-character-reference@1.1.0: + resolution: {integrity: sha512-m9V0ExGv0jB1OT21mrWcuf4QhP46pH1KkfWy9ZEezqHKAxkj4mPCy3nIH1rkbdMlChLHX531eOrymlwyZIf2iw==} + + micromark-util-decode-string@1.1.0: + resolution: {integrity: sha512-YphLGCK8gM1tG1bd54azwyrQRjCFcmgj2S2GoJDNnh4vYtnL38JS8M4gpxzOPNyHdNEpheyWXCTnnTDY3N+NVQ==} + + micromark-util-encode@1.1.0: + resolution: {integrity: sha512-EuEzTWSTAj9PA5GOAs992GzNh2dGQO52UvAbtSOMvXTxv3Criqb6IOzJUBCmEqrrXSblJIJBbFFv6zPxpreiJw==} + + micromark-util-encode@2.0.0: + resolution: {integrity: sha512-pS+ROfCXAGLWCOc8egcBvT0kf27GoWMqtdarNfDcjb6YLuV5cM3ioG45Ys2qOVqeqSbjaKg72vU+Wby3eddPsA==} + + micromark-util-events-to-acorn@1.2.3: + resolution: {integrity: sha512-ij4X7Wuc4fED6UoLWkmo0xJQhsktfNh1J0m8g4PbIMPlx+ek/4YdW5mvbye8z/aZvAPUoxgXHrwVlXAPKMRp1w==} + + micromark-util-html-tag-name@1.2.0: + resolution: {integrity: sha512-VTQzcuQgFUD7yYztuQFKXT49KghjtETQ+Wv/zUjGSGBioZnkA4P1XXZPT1FHeJA6RwRXSF47yvJ1tsJdoxwO+Q==} + + micromark-util-normalize-identifier@1.1.0: + resolution: {integrity: sha512-N+w5vhqrBihhjdpM8+5Xsxy71QWqGn7HYNUvch71iV2PM7+E3uWGox1Qp90loa1ephtCxG2ftRV/Conitc6P2Q==} + + micromark-util-resolve-all@1.1.0: + resolution: {integrity: sha512-b/G6BTMSg+bX+xVCshPTPyAu2tmA0E4X98NSR7eIbeC6ycCqCeE7wjfDIgzEbkzdEVJXRtOG4FbEm/uGbCRouA==} + + micromark-util-sanitize-uri@1.2.0: + resolution: {integrity: sha512-QO4GXv0XZfWey4pYFndLUKEAktKkG5kZTdUNaTAkzbuJxn2tNBOr+QtxR2XpWaMhbImT2dPzyLrPXLlPhph34A==} + + micromark-util-sanitize-uri@2.0.0: + resolution: {integrity: sha512-WhYv5UEcZrbAtlsnPuChHUAsu/iBPOVaEVsntLBIdpibO0ddy8OzavZz3iL2xVvBZOpolujSliP65Kq0/7KIYw==} + + micromark-util-subtokenize@1.1.0: + resolution: {integrity: sha512-kUQHyzRoxvZO2PuLzMt2P/dwVsTiivCK8icYTeR+3WgbuPqfHgPPy7nFKbeqRivBvn/3N3GBiNC+JRTMSxEC7A==} + + micromark-util-symbol@1.1.0: + resolution: {integrity: sha512-uEjpEYY6KMs1g7QfJ2eX1SQEV+ZT4rUD3UcF6l57acZvLNK7PBZL+ty82Z1qhK1/yXIY4bdx04FKMgR0g4IAag==} + + micromark-util-symbol@2.0.0: + resolution: {integrity: sha512-8JZt9ElZ5kyTnO94muPxIGS8oyElRJaiJO8EzV6ZSyGQ1Is8xwl4Q45qU5UOg+bGH4AikWziz0iN4sFLWs8PGw==} + + micromark-util-types@1.1.0: + resolution: {integrity: sha512-ukRBgie8TIAcacscVHSiddHjO4k/q3pnedmzMQ4iwDcK0FtFCohKOlFbaOL/mPgfnPsL3C1ZyxJa4sbWrBl3jg==} + + micromark-util-types@2.0.0: + resolution: {integrity: sha512-oNh6S2WMHWRZrmutsRmDDfkzKtxF+bc2VxLC9dvtrDIRFln627VsFP6fLMgTryGDljgLPjkrzQSDcPrjPyDJ5w==} + + micromark@3.2.0: + resolution: {integrity: sha512-uD66tJj54JLYq0De10AhWycZWGQNUvDI55xPgk2sQM5kn1JYlhbCMTtEeT27+vAhW2FBQxLlOmS3pmA7/2z4aA==} + + mri@1.2.0: + resolution: {integrity: sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==} + engines: {node: '>=4'} + + ms@2.1.2: + resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} + + nanoid@3.3.7: + resolution: {integrity: sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==} + engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} + hasBin: true + + next-mdx-remote@4.4.1: + resolution: {integrity: sha512-1BvyXaIou6xy3XoNF4yaMZUCb6vD2GTAa5ciOa6WoO+gAUTYsb1K4rI/HSC2ogAWLrb/7VSV52skz07vOzmqIQ==} + engines: {node: '>=14', npm: '>=7'} + peerDependencies: + react: '>=16.x <=18.x' + react-dom: '>=16.x <=18.x' + + next-seo@6.5.0: + resolution: {integrity: sha512-MfzUeWTN/x/rsKp/1n0213eojO97lIl0unxqbeCY+6pAucViHDA8GSLRRcXpgjsSmBxfCFdfpu7LXbt4ANQoNQ==} + peerDependencies: + next: ^8.1.1-canary.54 || >=9.0.0 + react: '>=16.0.0' + react-dom: '>=16.0.0' + + next-themes@0.2.1: + resolution: {integrity: sha512-B+AKNfYNIzh0vqQQKqQItTS8evEouKD7H5Hj3kmuPERwddR2TxvDSFZuTj6T7Jfn1oyeUyJMydPl1Bkxkh0W7A==} + peerDependencies: + next: '*' + react: '*' + react-dom: '*' + + next@14.2.3: + resolution: {integrity: sha512-dowFkFTR8v79NPJO4QsBUtxv0g9BrS/phluVpMAt2ku7H+cbcBJlopXjkWlwxrk/xGqMemr7JkGPGemPrLLX7A==} + engines: {node: '>=18.17.0'} + hasBin: true + peerDependencies: + '@opentelemetry/api': ^1.1.0 + '@playwright/test': ^1.41.2 + react: ^18.2.0 + react-dom: ^18.2.0 + sass: ^1.3.0 + peerDependenciesMeta: + '@opentelemetry/api': + optional: true + '@playwright/test': + optional: true + sass: + optional: true + + nextra-theme-docs@2.13.4: + resolution: {integrity: sha512-2XOoMfwBCTYBt8ds4ZHftt9Wyf2XsykiNo02eir/XEYB+sGeUoE77kzqfidjEOKCSzOHYbK9BDMcg2+B/2vYRw==} + peerDependencies: + next: '>=9.5.3' + nextra: 2.13.4 + react: '>=16.13.1' + react-dom: '>=16.13.1' + + nextra@2.13.4: + resolution: {integrity: sha512-7of2rSBxuUa3+lbMmZwG9cqgftcoNOVQLTT6Rxf3EhBR9t1EI7b43dted8YoqSNaigdE3j1CoyNkX8N/ZzlEpw==} + engines: {node: '>=16'} + peerDependencies: + next: '>=9.5.3' + react: '>=16.13.1' + react-dom: '>=16.13.1' + + non-layered-tidy-tree-layout@2.0.2: + resolution: {integrity: sha512-gkXMxRzUH+PB0ax9dUN0yYF0S25BqeAYqhgMaLUFmpXLEk7Fcu8f4emJuOAY0V8kjDICxROIKsTAKsV/v355xw==} + + npm-run-path@2.0.2: + resolution: {integrity: sha512-lJxZYlT4DW/bRUtFh1MQIWqmLwQfAxnqWG4HhEdjMlkrJYnJn0Jrr2u3mgxqaWsdiBc76TYkTG/mhrnYTuzfHw==} + engines: {node: '>=4'} + + npm-to-yarn@2.2.1: + resolution: {integrity: sha512-O/j/ROyX0KGLG7O6Ieut/seQ0oiTpHF2tXAcFbpdTLQFiaNtkyTXXocM1fwpaa60dg1qpWj0nHlbNhx6qwuENQ==} + engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} + + p-finally@1.0.0: + resolution: {integrity: sha512-LICb2p9CB7FS+0eR1oqWnHhp0FljGLZCWBE9aix0Uye9W8LTQPwMTYVGWQWIw9RdQiDg4+epXQODwIYJtSJaow==} + engines: {node: '>=4'} + + p-limit@3.1.0: + resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==} + engines: {node: '>=10'} + + parse-entities@4.0.1: + resolution: {integrity: sha512-SWzvYcSJh4d/SGLIOQfZ/CoNv6BTlI6YEQ7Nj82oDVnRpwe/Z/F1EMx42x3JAOwGBlCjeCH0BRJQbQ/opHL17w==} + + parse-numeric-range@1.3.0: + resolution: {integrity: sha512-twN+njEipszzlMJd4ONUYgSfZPDxgHhT9Ahed5uTigpQn90FggW4SA/AIPq/6a149fTbE9qBEcSwE3FAEp6wQQ==} + + parse-path@7.0.0: + resolution: {integrity: sha512-Euf9GG8WT9CdqwuWJGdf3RkUcTBArppHABkO7Lm8IzRQp0e2r/kkFnmhu4TSK30Wcu5rVAZLmfPKSBBi9tWFog==} + + parse-url@8.1.0: + resolution: {integrity: sha512-xDvOoLU5XRrcOZvnI6b8zA6n9O9ejNk/GExuz1yBuWUGn9KA97GI6HTs6u02wKara1CeVmZhH+0TZFdWScR89w==} + + parse5@7.1.2: + resolution: {integrity: sha512-Czj1WaSVpaoj0wbhMzLmWD69anp2WH7FXMB9n1Sy8/ZFF9jolSQVMu1Ij5WIyGmcBmhk7EOndpO4mIpihVqAXw==} + + path-key@2.0.1: + resolution: {integrity: sha512-fEHGKCSmUSDPv4uoj8AlD+joPlq3peND+HRYyxFz4KPw4z926S/b8rIuFs2FYJg3BwsxJf6A9/3eIdLaYC+9Dw==} + engines: {node: '>=4'} + + periscopic@3.1.0: + resolution: {integrity: sha512-vKiQ8RRtkl9P+r/+oefh25C3fhybptkHKCZSPlcXiJux2tJF55GnEj3BVn4A5gKfq9NWWXXrxkHBwVPUfH0opw==} + + picocolors@1.0.1: + resolution: {integrity: sha512-anP1Z8qwhkbmu7MFP5iTt+wQKXgwzf7zTyGlcdzabySa9vd0Xt392U0rVmz9poOaBj0uHJKyyo9/upk0HrEQew==} + + postcss@8.4.31: + resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==} + engines: {node: ^10 || ^12 || >=14} + + property-information@6.5.0: + resolution: {integrity: sha512-PgTgs/BlvHxOu8QuEN7wi5A0OmXaBcHpmCSTehcs6Uuu9IkDIEo13Hy7n898RHfrQ49vKCoGeWZSaAK01nwVig==} + + protocols@2.0.1: + resolution: {integrity: sha512-/XJ368cyBJ7fzLMwLKv1e4vLxOju2MNAIokcr7meSaNcVbWz/CPcW22cP04mwxOErdA5mwjA8Q6w/cdAQxVn7Q==} + + pseudomap@1.0.2: + resolution: {integrity: sha512-b/YwNhb8lk1Zz2+bXXpS/LK9OisiZZ1SNsSLxN1x2OXVEhW2Ckr/7mWE5vrC1ZTiJlD9g19jWszTmJsB+oEpFQ==} + + react-dom@18.3.1: + resolution: {integrity: sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==} + peerDependencies: + react: ^18.3.1 + + react@18.3.1: + resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==} + engines: {node: '>=0.10.0'} + + reading-time@1.5.0: + resolution: {integrity: sha512-onYyVhBNr4CmAxFsKS7bz+uTLRakypIe4R+5A824vBSkQy/hB3fZepoVEf8OVAxzLvK+H/jm9TzpI3ETSm64Kg==} + + regenerator-runtime@0.14.1: + resolution: {integrity: sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==} + + rehype-katex@7.0.0: + resolution: {integrity: sha512-h8FPkGE00r2XKU+/acgqwWUlyzve1IiOKwsEkg4pDL3k48PiE0Pt+/uLtVHDVkN1yA4iurZN6UES8ivHVEQV6Q==} + + rehype-pretty-code@0.9.11: + resolution: {integrity: sha512-Eq90eCYXQJISktfRZ8PPtwc5SUyH6fJcxS8XOMnHPUQZBtC6RYo67gGlley9X2nR8vlniPj0/7oCDEYHKQa/oA==} + engines: {node: '>=16'} + peerDependencies: + shiki: '*' + + rehype-raw@7.0.0: + resolution: {integrity: sha512-/aE8hCfKlQeA8LmyeyQvQF3eBiLRGNlfBJEvWH7ivp9sBqs7TNqBL5X3v157rM4IFETqDnIOO+z5M/biZbo9Ww==} + + remark-gfm@3.0.1: + resolution: {integrity: sha512-lEFDoi2PICJyNrACFOfDD3JlLkuSbOa5Wd8EPt06HUdptv8Gn0bxYTdbU/XXQ3swAPkEaGxxPN9cbnMHvVu1Ig==} + + remark-math@5.1.1: + resolution: {integrity: sha512-cE5T2R/xLVtfFI4cCePtiRn+e6jKMtFDR3P8V3qpv8wpKjwvHoBA4eJzvX+nVrnlNy0911bdGmuspCSwetfYHw==} + + remark-mdx@2.3.0: + resolution: {integrity: sha512-g53hMkpM0I98MU266IzDFMrTD980gNF3BJnkyFcmN+dD873mQeD5rdMO3Y2X+x8umQfbSE0PcoEDl7ledSA+2g==} + + remark-parse@10.0.2: + resolution: {integrity: sha512-3ydxgHa/ZQzG8LvC7jTXccARYDcRld3VfcgIIFs7bI6vbRSxJJmzgLEIIoYKyrfhaY+ujuWaf/PJiMZXoiCXgw==} + + remark-reading-time@2.0.1: + resolution: {integrity: sha512-fy4BKy9SRhtYbEHvp6AItbRTnrhiDGbqLQTSYVbQPGuRCncU1ubSsh9p/W5QZSxtYcUXv8KGL0xBgPLyNJA1xw==} + + remark-rehype@10.1.0: + resolution: {integrity: sha512-EFmR5zppdBp0WQeDVZ/b66CWJipB2q2VLNFMabzDSGR66Z2fQii83G5gTBbgGEnEEA0QRussvrFHxk1HWGJskw==} + + remove-accents@0.5.0: + resolution: {integrity: sha512-8g3/Otx1eJaVD12e31UbJj1YzdtVvzH85HV7t+9MJYk/u3XmkOUJ5Ys9wQrf9PCPK8+xn4ymzqYCiZl6QWKn+A==} + + robust-predicates@3.0.2: + resolution: {integrity: sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg==} + + rw@1.3.3: + resolution: {integrity: sha512-PdhdWy89SiZogBLaw42zdeqtRJ//zFd2PgQavcICDUgJT5oW10QCRKbJ6bg4r0/UY2M6BWd5tkxuGFRvCkgfHQ==} + + sade@1.8.1: + resolution: {integrity: sha512-xal3CZX1Xlo/k4ApwCFrHVACi9fBqJ7V+mwhBsuf/1IOKbBy098Fex+Wa/5QMubw09pSZ/u8EY8PWgevJsXp1A==} + engines: {node: '>=6'} + + safer-buffer@2.1.2: + resolution: {integrity: sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==} + + scheduler@0.23.2: + resolution: {integrity: sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==} + + scroll-into-view-if-needed@3.1.0: + resolution: {integrity: sha512-49oNpRjWRvnU8NyGVmUaYG4jtTkNonFZI86MmGRDqBphEK2EXT9gdEUoQPZhuBM8yWHxCWbobltqYO5M4XrUvQ==} + + section-matter@1.0.0: + resolution: {integrity: sha512-vfD3pmTzGpufjScBh50YHKzEu2lxBWhVEHsNGoEXmCmn2hKGfeNLYMzCJpe8cD7gqX7TJluOVpBkAequ6dgMmA==} + engines: {node: '>=4'} + + shebang-command@1.2.0: + resolution: {integrity: sha512-EV3L1+UQWGor21OmnvojK36mhg+TyIKDh3iFBKBohr5xeXIhNBcx8oWdgkTEEQ+BEFFYdLRuqMfd5L84N1V5Vg==} + engines: {node: '>=0.10.0'} + + shebang-regex@1.0.0: + resolution: {integrity: sha512-wpoSFAxys6b2a2wHZ1XpDSgD7N9iVjg29Ph9uV/uaP9Ex/KXlkTZTeddxDPSYQpgvzKLGJke2UU0AzoGCjNIvQ==} + engines: {node: '>=0.10.0'} + + shiki@0.14.7: + resolution: {integrity: sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==} + + signal-exit@3.0.7: + resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==} + + slash@3.0.0: + resolution: {integrity: sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==} + engines: {node: '>=8'} + + sort-keys@5.0.0: + resolution: {integrity: sha512-Pdz01AvCAottHTPQGzndktFNdbRA75BgOfeT1hH+AMnJFv8lynkPi42rfeEhpx1saTEI3YNMWxfqu0sFD1G8pw==} + engines: {node: '>=12'} + + source-map-js@1.2.0: + resolution: {integrity: sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==} + engines: {node: '>=0.10.0'} + + source-map@0.7.4: + resolution: {integrity: sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA==} + engines: {node: '>= 8'} + + space-separated-tokens@2.0.2: + resolution: {integrity: sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==} + + sprintf-js@1.0.3: + resolution: {integrity: sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==} + + streamsearch@1.1.0: + resolution: {integrity: sha512-Mcc5wHehp9aXz1ax6bZUyY5afg9u2rv5cqQI3mRrYkGC8rW2hM02jWuwjtL++LS5qinSyhj2QfLyNsuc+VsExg==} + engines: {node: '>=10.0.0'} + + stringify-entities@4.0.4: + resolution: {integrity: sha512-IwfBptatlO+QCJUo19AqvrPNqlVMpW9YEL2LIVY+Rpv2qsjCGxaDLNRgeGsQWJhfItebuJhsGSLjaBbNSQ+ieg==} + + strip-bom-string@1.0.0: + resolution: {integrity: sha512-uCC2VHvQRYu+lMh4My/sFNmF2klFymLX1wHJeXnbEJERpV/ZsVuonzerjfrGpIGF7LBVa1O7i9kjiWvJiFck8g==} + engines: {node: '>=0.10.0'} + + strip-eof@1.0.0: + resolution: {integrity: sha512-7FCwGGmx8mD5xQd3RPUvnSpUXHM3BWuzjtpD4TXsfcZ9EL4azvVVUscFYwD9nx8Kh+uCBC00XBtAykoMHwTh8Q==} + engines: {node: '>=0.10.0'} + + style-to-object@0.4.4: + resolution: {integrity: sha512-HYNoHZa2GorYNyqiCaBgsxvcJIn7OHq6inEga+E6Ke3m5JkoqpQbnFssk4jwe+K7AhGa2fcha4wSOf1Kn01dMg==} + + styled-jsx@5.1.1: + resolution: {integrity: sha512-pW7uC1l4mBZ8ugbiZrcIsiIvVx1UmTfw7UkC3Um2tmfUq9Bhk8IiyEIPl6F8agHgjzku6j0xQEZbfA5uSgSaCw==} + engines: {node: '>= 12.0.0'} + peerDependencies: + '@babel/core': '*' + babel-plugin-macros: '*' + react: '>= 16.8.0 || 17.x.x || ^18.0.0-0' + peerDependenciesMeta: + '@babel/core': + optional: true + babel-plugin-macros: + optional: true + + stylis@4.3.2: + resolution: {integrity: sha512-bhtUjWd/z6ltJiQwg0dUfxEJ+W+jdqQd8TbWLWyeIJHlnsqmGLRFFd8e5mA0AZi/zx90smXRlN66YMTcaSFifg==} + + supports-color@4.5.0: + resolution: {integrity: sha512-ycQR/UbvI9xIlEdQT1TQqwoXtEldExbCEAJgRo5YXlmSKjv6ThHnP9/vwGa1gr19Gfw+LkFd7KqYMhzrRC5JYw==} + engines: {node: '>=4'} + + title@3.5.3: + resolution: {integrity: sha512-20JyowYglSEeCvZv3EZ0nZ046vLarO37prvV0mbtQV7C8DJPGgN967r8SJkqd3XK3K3lD3/Iyfp3avjfil8Q2Q==} + hasBin: true + + titleize@1.0.0: + resolution: {integrity: sha512-TARUb7z1pGvlLxgPk++7wJ6aycXF3GJ0sNSBTAsTuJrQG5QuZlkUQP+zl+nbjAh4gMX9yDw9ZYklMd7vAfJKEw==} + engines: {node: '>=0.10.0'} + + trim-lines@3.0.1: + resolution: {integrity: sha512-kRj8B+YHZCc9kQYdWfJB2/oUl9rA99qbowYYBtr4ui4mZyAQ2JpvVBd/6U2YloATfqBhBTSMhTpgBHtU0Mf3Rg==} + + trough@2.2.0: + resolution: {integrity: sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==} + + ts-dedent@2.2.0: + resolution: {integrity: sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==} + engines: {node: '>=6.10'} + + tslib@2.6.3: + resolution: {integrity: sha512-xNvxJEOUiWPGhUuUdQgAJPKOOJfGnIyKySOc09XkKsgdUV/3E2zvwZYdejjmRgPCgcym1juLH3226yA7sEFJKQ==} + + type-fest@1.4.0: + resolution: {integrity: sha512-yGSza74xk0UG8k+pLh5oeoYirvIiWo5t0/o3zHHAO2tRDiZcxWP7fywNlXhqb6/r6sWvwi+RsyQMWhVLe4BVuA==} + engines: {node: '>=10'} + + typescript@5.4.5: + resolution: {integrity: sha512-vcI4UpRgg81oIRUFwR0WSIHKt11nJ7SAVlYNIu+QpqeyXP+gpQJy/Z4+F0aGxSE4MqwjyXvW/TzgkLAx2AGHwQ==} + engines: {node: '>=14.17'} + hasBin: true + + undici-types@5.26.5: + resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} + + unified@10.1.2: + resolution: {integrity: sha512-pUSWAi/RAnVy1Pif2kAoeWNBa3JVrx0MId2LASj8G+7AiHWoKZNTomq6LG326T68U7/e263X6fTdcXIy7XnF7Q==} + + unist-util-find-after@5.0.0: + resolution: {integrity: sha512-amQa0Ep2m6hE2g72AugUItjbuM8X8cGQnFoHk0pGfrFeT9GZhzN5SW8nRsiGKK7Aif4CrACPENkA6P/Lw6fHGQ==} + + unist-util-generated@2.0.1: + resolution: {integrity: sha512-qF72kLmPxAw0oN2fwpWIqbXAVyEqUzDHMsbtPvOudIlUzXYFIeQIuxXQCRCFh22B7cixvU0MG7m3MW8FTq/S+A==} + + unist-util-is@5.2.1: + resolution: {integrity: sha512-u9njyyfEh43npf1M+yGKDGVPbY/JWEemg5nH05ncKPfi+kBbKBJoTdsogMu33uhytuLlv9y0O7GH7fEdwLdLQw==} + + unist-util-is@6.0.0: + resolution: {integrity: sha512-2qCTHimwdxLfz+YzdGfkqNlH0tLi9xjTnHddPmJwtIG9MGsdbutfTc4P+haPD7l7Cjxf/WZj+we5qfVPvvxfYw==} + + unist-util-position-from-estree@1.1.2: + resolution: {integrity: sha512-poZa0eXpS+/XpoQwGwl79UUdea4ol2ZuCYguVaJS4qzIOMDzbqz8a3erUCOmubSZkaOuGamb3tX790iwOIROww==} + + unist-util-position@4.0.4: + resolution: {integrity: sha512-kUBE91efOWfIVBo8xzh/uZQ7p9ffYRtUbMRZBNFYwf0RK8koUMx6dGUfwylLOKmaT2cs4wSW96QoYUSXAyEtpg==} + + unist-util-position@5.0.0: + resolution: {integrity: sha512-fucsC7HjXvkB5R3kTCO7kUjRdrS0BJt3M/FPxmHMBOm8JQi2BsHAHFsy27E0EolP8rp0NzXsJ+jNPyDWvOJZPA==} + + unist-util-remove-position@4.0.2: + resolution: {integrity: sha512-TkBb0HABNmxzAcfLf4qsIbFbaPDvMO6wa3b3j4VcEzFVaw1LBKwnW4/sRJ/atSLSzoIg41JWEdnE7N6DIhGDGQ==} + + unist-util-remove-position@5.0.0: + resolution: {integrity: sha512-Hp5Kh3wLxv0PHj9m2yZhhLt58KzPtEYKQQ4yxfYFEO7EvHwzyDYnduhHnY1mDxoqr7VUwVuHXk9RXKIiYS1N8Q==} + + unist-util-remove@4.0.0: + resolution: {integrity: sha512-b4gokeGId57UVRX/eVKej5gXqGlc9+trkORhFJpu9raqZkZhU0zm8Doi05+HaiBsMEIJowL+2WtQ5ItjsngPXg==} + + unist-util-stringify-position@3.0.3: + resolution: {integrity: sha512-k5GzIBZ/QatR8N5X2y+drfpWG8IDBzdnVj6OInRNWm1oXrzydiaAT2OQiA8DPRRZyAKb9b6I2a6PxYklZD0gKg==} + + unist-util-stringify-position@4.0.0: + resolution: {integrity: sha512-0ASV06AAoKCDkS2+xw5RXJywruurpbC4JZSm7nr7MOt1ojAzvyyaO+UxZf18j8FCF6kmzCZKcAgN/yu2gm2XgQ==} + + unist-util-visit-parents@4.1.1: + resolution: {integrity: sha512-1xAFJXAKpnnJl8G7K5KgU7FY55y3GcLIXqkzUj5QF/QVP7biUm0K0O2oqVkYsdjzJKifYeWn9+o6piAK2hGSHw==} + + unist-util-visit-parents@5.1.3: + resolution: {integrity: sha512-x6+y8g7wWMyQhL1iZfhIPhDAs7Xwbn9nRosDXl7qoPTSCy0yNxnKc+hWokFifWQIDGi154rdUqKvbCa4+1kLhg==} + + unist-util-visit-parents@6.0.1: + resolution: {integrity: sha512-L/PqWzfTP9lzzEa6CKs0k2nARxTdZduw3zyh8d2NVBnsyvHjSX4TWse388YrrQKbvI8w20fGjGlhgT96WwKykw==} + + unist-util-visit@3.1.0: + resolution: {integrity: sha512-Szoh+R/Ll68QWAyQyZZpQzZQm2UPbxibDvaY8Xc9SUtYgPsDzx5AWSk++UUt2hJuow8mvwR+rG+LQLw+KsuAKA==} + + unist-util-visit@4.1.2: + resolution: {integrity: sha512-MSd8OUGISqHdVvfY9TPhyK2VdUrPgxkUtWSuMHF6XAAFuL4LokseigBnZtPnJMu+FbynTkFNnFlyjxpVKujMRg==} + + unist-util-visit@5.0.0: + resolution: {integrity: sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg==} + + uuid@9.0.1: + resolution: {integrity: sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==} + hasBin: true + + uvu@0.5.6: + resolution: {integrity: sha512-+g8ENReyr8YsOc6fv/NVJs2vFdHBnBNdfE49rshrTzDWOlUx4Gq7KOS2GD8eqhy2j+Ejq29+SbKH8yjkAqXqoA==} + engines: {node: '>=8'} + hasBin: true + + vfile-location@5.0.2: + resolution: {integrity: sha512-NXPYyxyBSH7zB5U6+3uDdd6Nybz6o6/od9rk8bp9H8GR3L+cm/fC0uUTbqBmUTnMCUDslAGBOIKNfvvb+gGlDg==} + + vfile-matter@3.0.1: + resolution: {integrity: sha512-CAAIDwnh6ZdtrqAuxdElUqQRQDQgbbIrYtDYI8gCjXS1qQ+1XdLoK8FIZWxJwn0/I+BkSSZpar3SOgjemQz4fg==} + + vfile-message@3.1.4: + resolution: {integrity: sha512-fa0Z6P8HUrQN4BZaX05SIVXic+7kE3b05PWAtPuYP9QLHsLKYR7/AlLW3NtOrpXRLeawpDLMsVkmk5DG0NXgWw==} + + vfile-message@4.0.2: + resolution: {integrity: sha512-jRDZ1IMLttGj41KcZvlrYAaI3CfqpLpfpf+Mfig13viT6NKvRzWZ+lXz0Y5D60w6uJIBAOGq9mSHf0gktF0duw==} + + vfile@5.3.7: + resolution: {integrity: sha512-r7qlzkgErKjobAmyNIkkSpizsFPYiUPuJb5pNW1RB4JcYVZhs4lIbVqk8XPk033CV/1z8ss5pkax8SuhGpcG8g==} + + vfile@6.0.1: + resolution: {integrity: sha512-1bYqc7pt6NIADBJ98UiG0Bn/CHIVOoZ/IyEkqIruLg0mE1BKzkOXY2D6CSqQIcKqgadppE5lrxgWXJmXd7zZJw==} + + vscode-oniguruma@1.7.0: + resolution: {integrity: sha512-L9WMGRfrjOhgHSdOYgCt/yRMsXzLDJSL7BPrOZt73gU0iWO4mpqzqQzOz5srxqTvMBaR0XZTSrVWo4j55Rc6cA==} + + vscode-textmate@8.0.0: + resolution: {integrity: sha512-AFbieoL7a5LMqcnOF04ji+rpXadgOXnZsxQr//r83kLPr7biP7am3g9zbaZIaBGwBRWeSvoMD4mgPdX3e4NWBg==} + + web-namespaces@2.0.1: + resolution: {integrity: sha512-bKr1DkiNa2krS7qxNtdrtHAmzuYGFQLiQ13TsorsdT6ULTkPLKuu5+GsFpDlg6JFjUTwX2DyhMPG2be8uPrqsQ==} + + web-worker@1.3.0: + resolution: {integrity: sha512-BSR9wyRsy/KOValMgd5kMyr3JzpdeoR9KVId8u5GVlTTAtNChlsE4yTxeY7zMdNSyOmoKBv8NH2qeRY9Tg+IaA==} + + which@1.3.1: + resolution: {integrity: sha512-HxJdYWq1MTIQbJ3nw0cqssHoTNU267KlrDuGZ1WYlxDStUtKUhOaJmh112/TZmHxxUfuJqPXSOm7tDyas0OSIQ==} + hasBin: true + + yallist@2.1.2: + resolution: {integrity: sha512-ncTzHV7NvsQZkYe1DW7cbDLm0YpzHmZF5r/iyP3ZnQtMiJ+pjzisCiMNI+Sj+xQF5pXhSHxSB3uDbsBTzY/c2A==} + + yocto-queue@0.1.0: + resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} + engines: {node: '>=10'} + + zod@3.23.8: + resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==} + + zwitch@2.0.4: + resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==} + +snapshots: + + '@babel/runtime@7.24.7': + dependencies: + regenerator-runtime: 0.14.1 + + '@braintree/sanitize-url@6.0.4': {} + + '@headlessui/react@1.7.19(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/react-virtual': 3.5.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + client-only: 0.0.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@mdx-js/mdx@2.3.0': + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/mdx': 2.0.13 + estree-util-build-jsx: 2.2.2 + estree-util-is-identifier-name: 2.1.0 + estree-util-to-js: 1.2.0 + estree-walker: 3.0.3 + hast-util-to-estree: 2.3.3 + markdown-extensions: 1.1.1 + periscopic: 3.1.0 + remark-mdx: 2.3.0 + remark-parse: 10.0.2 + remark-rehype: 10.1.0 + unified: 10.1.2 + unist-util-position-from-estree: 1.1.2 + unist-util-stringify-position: 3.0.3 + unist-util-visit: 4.1.2 + vfile: 5.3.7 + transitivePeerDependencies: + - supports-color + + '@mdx-js/react@2.3.0(react@18.3.1)': + dependencies: + '@types/mdx': 2.0.13 + '@types/react': 18.3.3 + react: 18.3.1 + + '@napi-rs/simple-git-android-arm-eabi@0.1.16': + optional: true + + '@napi-rs/simple-git-android-arm64@0.1.16': + optional: true + + '@napi-rs/simple-git-darwin-arm64@0.1.16': + optional: true + + '@napi-rs/simple-git-darwin-x64@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-arm-gnueabihf@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-arm64-gnu@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-arm64-musl@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-x64-gnu@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-x64-musl@0.1.16': + optional: true + + '@napi-rs/simple-git-win32-arm64-msvc@0.1.16': + optional: true + + '@napi-rs/simple-git-win32-x64-msvc@0.1.16': + optional: true + + '@napi-rs/simple-git@0.1.16': + optionalDependencies: + '@napi-rs/simple-git-android-arm-eabi': 0.1.16 + '@napi-rs/simple-git-android-arm64': 0.1.16 + '@napi-rs/simple-git-darwin-arm64': 0.1.16 + '@napi-rs/simple-git-darwin-x64': 0.1.16 + '@napi-rs/simple-git-linux-arm-gnueabihf': 0.1.16 + '@napi-rs/simple-git-linux-arm64-gnu': 0.1.16 + '@napi-rs/simple-git-linux-arm64-musl': 0.1.16 + '@napi-rs/simple-git-linux-x64-gnu': 0.1.16 + '@napi-rs/simple-git-linux-x64-musl': 0.1.16 + '@napi-rs/simple-git-win32-arm64-msvc': 0.1.16 + '@napi-rs/simple-git-win32-x64-msvc': 0.1.16 + + '@next/env@14.2.3': {} + + '@next/swc-darwin-arm64@14.2.3': + optional: true + + '@next/swc-darwin-x64@14.2.3': + optional: true + + '@next/swc-linux-arm64-gnu@14.2.3': + optional: true + + '@next/swc-linux-arm64-musl@14.2.3': + optional: true + + '@next/swc-linux-x64-gnu@14.2.3': + optional: true + + '@next/swc-linux-x64-musl@14.2.3': + optional: true + + '@next/swc-win32-arm64-msvc@14.2.3': + optional: true + + '@next/swc-win32-ia32-msvc@14.2.3': + optional: true + + '@next/swc-win32-x64-msvc@14.2.3': + optional: true + + '@popperjs/core@2.11.8': {} + + '@swc/counter@0.1.3': {} + + '@swc/helpers@0.5.5': + dependencies: + '@swc/counter': 0.1.3 + tslib: 2.6.3 + + '@tanstack/react-virtual@3.5.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/virtual-core': 3.5.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@tanstack/virtual-core@3.5.1': {} + + '@theguild/remark-mermaid@0.0.5(react@18.3.1)': + dependencies: + mermaid: 10.9.1 + react: 18.3.1 + unist-util-visit: 5.0.0 + transitivePeerDependencies: + - supports-color + + '@theguild/remark-npm2yarn@0.2.1': + dependencies: + npm-to-yarn: 2.2.1 + unist-util-visit: 5.0.0 + + '@types/acorn@4.0.6': + dependencies: + '@types/estree': 1.0.5 + + '@types/d3-scale-chromatic@3.0.3': {} + + '@types/d3-scale@4.0.8': + dependencies: + '@types/d3-time': 3.0.3 + + '@types/d3-time@3.0.3': {} + + '@types/debug@4.1.12': + dependencies: + '@types/ms': 0.7.34 + + '@types/estree-jsx@1.0.5': + dependencies: + '@types/estree': 1.0.5 + + '@types/estree@1.0.5': {} + + '@types/hast@2.3.10': + dependencies: + '@types/unist': 2.0.10 + + '@types/hast@3.0.4': + dependencies: + '@types/unist': 3.0.2 + + '@types/js-yaml@4.0.9': {} + + '@types/katex@0.16.7': {} + + '@types/mdast@3.0.15': + dependencies: + '@types/unist': 2.0.10 + + '@types/mdast@4.0.4': + dependencies: + '@types/unist': 3.0.2 + + '@types/mdx@2.0.13': {} + + '@types/ms@0.7.34': {} + + '@types/node@20.14.2': + dependencies: + undici-types: 5.26.5 + + '@types/prop-types@15.7.12': {} + + '@types/react-dom@18.3.0': + dependencies: + '@types/react': 18.3.3 + + '@types/react@18.3.3': + dependencies: + '@types/prop-types': 15.7.12 + csstype: 3.1.3 + + '@types/unist@2.0.10': {} + + '@types/unist@3.0.2': {} + + '@ungap/structured-clone@1.2.0': {} + + acorn-jsx@5.3.2(acorn@8.11.3): + dependencies: + acorn: 8.11.3 + + acorn@8.11.3: {} + + ansi-sequence-parser@1.1.1: {} + + ansi-styles@3.2.1: + dependencies: + color-convert: 1.9.3 + + arch@2.2.0: {} + + arg@1.0.0: {} + + argparse@1.0.10: + dependencies: + sprintf-js: 1.0.3 + + argparse@2.0.1: {} + + astring@1.8.6: {} + + bail@2.0.2: {} + + busboy@1.6.0: + dependencies: + streamsearch: 1.1.0 + + caniuse-lite@1.0.30001629: {} + + ccount@2.0.1: {} + + chalk@2.3.0: + dependencies: + ansi-styles: 3.2.1 + escape-string-regexp: 1.0.5 + supports-color: 4.5.0 + + character-entities-html4@2.1.0: {} + + character-entities-legacy@3.0.0: {} + + character-entities@2.0.2: {} + + character-reference-invalid@2.0.1: {} + + client-only@0.0.1: {} + + clipboardy@1.2.2: + dependencies: + arch: 2.2.0 + execa: 0.8.0 + + clsx@2.1.1: {} + + color-convert@1.9.3: + dependencies: + color-name: 1.1.3 + + color-name@1.1.3: {} + + comma-separated-tokens@2.0.3: {} + + commander@7.2.0: {} + + commander@8.3.0: {} + + compute-scroll-into-view@3.1.0: {} + + cose-base@1.0.3: + dependencies: + layout-base: 1.0.2 + + cross-spawn@5.1.0: + dependencies: + lru-cache: 4.1.5 + shebang-command: 1.2.0 + which: 1.3.1 + + csstype@3.1.3: {} + + cytoscape-cose-bilkent@4.1.0(cytoscape@3.29.2): + dependencies: + cose-base: 1.0.3 + cytoscape: 3.29.2 + + cytoscape@3.29.2: {} + + d3-array@2.12.1: + dependencies: + internmap: 1.0.1 + + d3-array@3.2.4: + dependencies: + internmap: 2.0.3 + + d3-axis@3.0.0: {} + + d3-brush@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-transition: 3.0.1(d3-selection@3.0.0) + + d3-chord@3.0.1: + dependencies: + d3-path: 3.1.0 + + d3-color@3.1.0: {} + + d3-contour@4.0.2: + dependencies: + d3-array: 3.2.4 + + d3-delaunay@6.0.4: + dependencies: + delaunator: 5.0.1 + + d3-dispatch@3.0.1: {} + + d3-drag@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-selection: 3.0.0 + + d3-dsv@3.0.1: + dependencies: + commander: 7.2.0 + iconv-lite: 0.6.3 + rw: 1.3.3 + + d3-ease@3.0.1: {} + + d3-fetch@3.0.1: + dependencies: + d3-dsv: 3.0.1 + + d3-force@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-quadtree: 3.0.1 + d3-timer: 3.0.1 + + d3-format@3.1.0: {} + + d3-geo@3.1.1: + dependencies: + d3-array: 3.2.4 + + d3-hierarchy@3.1.2: {} + + d3-interpolate@3.0.1: + dependencies: + d3-color: 3.1.0 + + d3-path@1.0.9: {} + + d3-path@3.1.0: {} + + d3-polygon@3.0.1: {} + + d3-quadtree@3.0.1: {} + + d3-random@3.0.1: {} + + d3-sankey@0.12.3: + dependencies: + d3-array: 2.12.1 + d3-shape: 1.3.7 + + d3-scale-chromatic@3.1.0: + dependencies: + d3-color: 3.1.0 + d3-interpolate: 3.0.1 + + d3-scale@4.0.2: + dependencies: + d3-array: 3.2.4 + d3-format: 3.1.0 + d3-interpolate: 3.0.1 + d3-time: 3.1.0 + d3-time-format: 4.1.0 + + d3-selection@3.0.0: {} + + d3-shape@1.3.7: + dependencies: + d3-path: 1.0.9 + + d3-shape@3.2.0: + dependencies: + d3-path: 3.1.0 + + d3-time-format@4.1.0: + dependencies: + d3-time: 3.1.0 + + d3-time@3.1.0: + dependencies: + d3-array: 3.2.4 + + d3-timer@3.0.1: {} + + d3-transition@3.0.1(d3-selection@3.0.0): + dependencies: + d3-color: 3.1.0 + d3-dispatch: 3.0.1 + d3-ease: 3.0.1 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-timer: 3.0.1 + + d3-zoom@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-transition: 3.0.1(d3-selection@3.0.0) + + d3@7.9.0: + dependencies: + d3-array: 3.2.4 + d3-axis: 3.0.0 + d3-brush: 3.0.0 + d3-chord: 3.0.1 + d3-color: 3.1.0 + d3-contour: 4.0.2 + d3-delaunay: 6.0.4 + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-dsv: 3.0.1 + d3-ease: 3.0.1 + d3-fetch: 3.0.1 + d3-force: 3.0.0 + d3-format: 3.1.0 + d3-geo: 3.1.1 + d3-hierarchy: 3.1.2 + d3-interpolate: 3.0.1 + d3-path: 3.1.0 + d3-polygon: 3.0.1 + d3-quadtree: 3.0.1 + d3-random: 3.0.1 + d3-scale: 4.0.2 + d3-scale-chromatic: 3.1.0 + d3-selection: 3.0.0 + d3-shape: 3.2.0 + d3-time: 3.1.0 + d3-time-format: 4.1.0 + d3-timer: 3.0.1 + d3-transition: 3.0.1(d3-selection@3.0.0) + d3-zoom: 3.0.0 + + dagre-d3-es@7.0.10: + dependencies: + d3: 7.9.0 + lodash-es: 4.17.21 + + dayjs@1.11.11: {} + + debug@4.3.5: + dependencies: + ms: 2.1.2 + + decode-named-character-reference@1.0.2: + dependencies: + character-entities: 2.0.2 + + delaunator@5.0.1: + dependencies: + robust-predicates: 3.0.2 + + dequal@2.0.3: {} + + devlop@1.1.0: + dependencies: + dequal: 2.0.3 + + diff@5.2.0: {} + + dompurify@3.1.5: {} + + elkjs@0.9.3: {} + + entities@4.5.0: {} + + escape-string-regexp@1.0.5: {} + + escape-string-regexp@5.0.0: {} + + esprima@4.0.1: {} + + estree-util-attach-comments@2.1.1: + dependencies: + '@types/estree': 1.0.5 + + estree-util-build-jsx@2.2.2: + dependencies: + '@types/estree-jsx': 1.0.5 + estree-util-is-identifier-name: 2.1.0 + estree-walker: 3.0.3 + + estree-util-is-identifier-name@2.1.0: {} + + estree-util-to-js@1.2.0: + dependencies: + '@types/estree-jsx': 1.0.5 + astring: 1.8.6 + source-map: 0.7.4 + + estree-util-value-to-estree@1.3.0: + dependencies: + is-plain-obj: 3.0.0 + + estree-util-visit@1.2.1: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/unist': 2.0.10 + + estree-walker@3.0.3: + dependencies: + '@types/estree': 1.0.5 + + execa@0.8.0: + dependencies: + cross-spawn: 5.1.0 + get-stream: 3.0.0 + is-stream: 1.1.0 + npm-run-path: 2.0.2 + p-finally: 1.0.0 + signal-exit: 3.0.7 + strip-eof: 1.0.0 + + extend-shallow@2.0.1: + dependencies: + is-extendable: 0.1.1 + + extend@3.0.2: {} + + flexsearch@0.7.43: {} + + focus-visible@5.2.0: {} + + get-stream@3.0.0: {} + + git-up@7.0.0: + dependencies: + is-ssh: 1.4.0 + parse-url: 8.1.0 + + git-url-parse@13.1.1: + dependencies: + git-up: 7.0.0 + + github-slugger@2.0.0: {} + + graceful-fs@4.2.11: {} + + gray-matter@4.0.3: + dependencies: + js-yaml: 3.14.1 + kind-of: 6.0.3 + section-matter: 1.0.0 + strip-bom-string: 1.0.0 + + has-flag@2.0.0: {} + + hash-obj@4.0.0: + dependencies: + is-obj: 3.0.0 + sort-keys: 5.0.0 + type-fest: 1.4.0 + + hast-util-from-dom@5.0.0: + dependencies: + '@types/hast': 3.0.4 + hastscript: 8.0.0 + web-namespaces: 2.0.1 + + hast-util-from-html-isomorphic@2.0.0: + dependencies: + '@types/hast': 3.0.4 + hast-util-from-dom: 5.0.0 + hast-util-from-html: 2.0.1 + unist-util-remove-position: 5.0.0 + + hast-util-from-html@2.0.1: + dependencies: + '@types/hast': 3.0.4 + devlop: 1.1.0 + hast-util-from-parse5: 8.0.1 + parse5: 7.1.2 + vfile: 6.0.1 + vfile-message: 4.0.2 + + hast-util-from-parse5@8.0.1: + dependencies: + '@types/hast': 3.0.4 + '@types/unist': 3.0.2 + devlop: 1.1.0 + hastscript: 8.0.0 + property-information: 6.5.0 + vfile: 6.0.1 + vfile-location: 5.0.2 + web-namespaces: 2.0.1 + + hast-util-is-element@3.0.0: + dependencies: + '@types/hast': 3.0.4 + + hast-util-parse-selector@4.0.0: + dependencies: + '@types/hast': 3.0.4 + + hast-util-raw@9.0.3: + dependencies: + '@types/hast': 3.0.4 + '@types/unist': 3.0.2 + '@ungap/structured-clone': 1.2.0 + hast-util-from-parse5: 8.0.1 + hast-util-to-parse5: 8.0.0 + html-void-elements: 3.0.0 + mdast-util-to-hast: 13.1.0 + parse5: 7.1.2 + unist-util-position: 5.0.0 + unist-util-visit: 5.0.0 + vfile: 6.0.1 + web-namespaces: 2.0.1 + zwitch: 2.0.4 + + hast-util-to-estree@2.3.3: + dependencies: + '@types/estree': 1.0.5 + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/unist': 2.0.10 + comma-separated-tokens: 2.0.3 + estree-util-attach-comments: 2.1.1 + estree-util-is-identifier-name: 2.1.0 + hast-util-whitespace: 2.0.1 + mdast-util-mdx-expression: 1.3.2 + mdast-util-mdxjs-esm: 1.3.1 + property-information: 6.5.0 + space-separated-tokens: 2.0.2 + style-to-object: 0.4.4 + unist-util-position: 4.0.4 + zwitch: 2.0.4 + transitivePeerDependencies: + - supports-color + + hast-util-to-parse5@8.0.0: + dependencies: + '@types/hast': 3.0.4 + comma-separated-tokens: 2.0.3 + devlop: 1.1.0 + property-information: 6.5.0 + space-separated-tokens: 2.0.2 + web-namespaces: 2.0.1 + zwitch: 2.0.4 + + hast-util-to-text@4.0.2: + dependencies: + '@types/hast': 3.0.4 + '@types/unist': 3.0.2 + hast-util-is-element: 3.0.0 + unist-util-find-after: 5.0.0 + + hast-util-whitespace@2.0.1: {} + + hastscript@8.0.0: + dependencies: + '@types/hast': 3.0.4 + comma-separated-tokens: 2.0.3 + hast-util-parse-selector: 4.0.0 + property-information: 6.5.0 + space-separated-tokens: 2.0.2 + + html-void-elements@3.0.0: {} + + iconv-lite@0.6.3: + dependencies: + safer-buffer: 2.1.2 + + inline-style-parser@0.1.1: {} + + internmap@1.0.1: {} + + internmap@2.0.3: {} + + intersection-observer@0.12.2: {} + + is-alphabetical@2.0.1: {} + + is-alphanumerical@2.0.1: + dependencies: + is-alphabetical: 2.0.1 + is-decimal: 2.0.1 + + is-buffer@2.0.5: {} + + is-decimal@2.0.1: {} + + is-extendable@0.1.1: {} + + is-hexadecimal@2.0.1: {} + + is-obj@3.0.0: {} + + is-plain-obj@3.0.0: {} + + is-plain-obj@4.1.0: {} + + is-reference@3.0.2: + dependencies: + '@types/estree': 1.0.5 + + is-ssh@1.4.0: + dependencies: + protocols: 2.0.1 + + is-stream@1.1.0: {} + + isexe@2.0.0: {} + + js-tokens@4.0.0: {} + + js-yaml@3.14.1: + dependencies: + argparse: 1.0.10 + esprima: 4.0.1 + + js-yaml@4.1.0: + dependencies: + argparse: 2.0.1 + + jsonc-parser@3.2.1: {} + + katex@0.16.10: + dependencies: + commander: 8.3.0 + + khroma@2.1.0: {} + + kind-of@6.0.3: {} + + kleur@4.1.5: {} + + layout-base@1.0.2: {} + + lodash-es@4.17.21: {} + + lodash.get@4.4.2: {} + + longest-streak@3.1.0: {} + + loose-envify@1.4.0: + dependencies: + js-tokens: 4.0.0 + + lru-cache@4.1.5: + dependencies: + pseudomap: 1.0.2 + yallist: 2.1.2 + + markdown-extensions@1.1.1: {} + + markdown-table@3.0.3: {} + + match-sorter@6.3.4: + dependencies: + '@babel/runtime': 7.24.7 + remove-accents: 0.5.0 + + mdast-util-definitions@5.1.2: + dependencies: + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + unist-util-visit: 4.1.2 + + mdast-util-find-and-replace@2.2.2: + dependencies: + '@types/mdast': 3.0.15 + escape-string-regexp: 5.0.0 + unist-util-is: 5.2.1 + unist-util-visit-parents: 5.1.3 + + mdast-util-from-markdown@1.3.1: + dependencies: + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + decode-named-character-reference: 1.0.2 + mdast-util-to-string: 3.2.0 + micromark: 3.2.0 + micromark-util-decode-numeric-character-reference: 1.1.0 + micromark-util-decode-string: 1.1.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + unist-util-stringify-position: 3.0.3 + uvu: 0.5.6 + transitivePeerDependencies: + - supports-color + + mdast-util-gfm-autolink-literal@1.0.3: + dependencies: + '@types/mdast': 3.0.15 + ccount: 2.0.1 + mdast-util-find-and-replace: 2.2.2 + micromark-util-character: 1.2.0 + + mdast-util-gfm-footnote@1.0.2: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-to-markdown: 1.5.0 + micromark-util-normalize-identifier: 1.1.0 + + mdast-util-gfm-strikethrough@1.0.3: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-to-markdown: 1.5.0 + + mdast-util-gfm-table@1.0.7: + dependencies: + '@types/mdast': 3.0.15 + markdown-table: 3.0.3 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-gfm-task-list-item@1.0.2: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-to-markdown: 1.5.0 + + mdast-util-gfm@2.0.2: + dependencies: + mdast-util-from-markdown: 1.3.1 + mdast-util-gfm-autolink-literal: 1.0.3 + mdast-util-gfm-footnote: 1.0.2 + mdast-util-gfm-strikethrough: 1.0.3 + mdast-util-gfm-table: 1.0.7 + mdast-util-gfm-task-list-item: 1.0.2 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-math@2.0.2: + dependencies: + '@types/mdast': 3.0.15 + longest-streak: 3.1.0 + mdast-util-to-markdown: 1.5.0 + + mdast-util-mdx-expression@1.3.2: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-mdx-jsx@2.1.4: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + ccount: 2.0.1 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + parse-entities: 4.0.1 + stringify-entities: 4.0.4 + unist-util-remove-position: 4.0.2 + unist-util-stringify-position: 3.0.3 + vfile-message: 3.1.4 + transitivePeerDependencies: + - supports-color + + mdast-util-mdx@2.0.1: + dependencies: + mdast-util-from-markdown: 1.3.1 + mdast-util-mdx-expression: 1.3.2 + mdast-util-mdx-jsx: 2.1.4 + mdast-util-mdxjs-esm: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-mdxjs-esm@1.3.1: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-phrasing@3.0.1: + dependencies: + '@types/mdast': 3.0.15 + unist-util-is: 5.2.1 + + mdast-util-to-hast@12.3.0: + dependencies: + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-definitions: 5.1.2 + micromark-util-sanitize-uri: 1.2.0 + trim-lines: 3.0.1 + unist-util-generated: 2.0.1 + unist-util-position: 4.0.4 + unist-util-visit: 4.1.2 + + mdast-util-to-hast@13.1.0: + dependencies: + '@types/hast': 3.0.4 + '@types/mdast': 4.0.4 + '@ungap/structured-clone': 1.2.0 + devlop: 1.1.0 + micromark-util-sanitize-uri: 2.0.0 + trim-lines: 3.0.1 + unist-util-position: 5.0.0 + unist-util-visit: 5.0.0 + vfile: 6.0.1 + + mdast-util-to-markdown@1.5.0: + dependencies: + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + longest-streak: 3.1.0 + mdast-util-phrasing: 3.0.1 + mdast-util-to-string: 3.2.0 + micromark-util-decode-string: 1.1.0 + unist-util-visit: 4.1.2 + zwitch: 2.0.4 + + mdast-util-to-string@3.2.0: + dependencies: + '@types/mdast': 3.0.15 + + mermaid@10.9.1: + dependencies: + '@braintree/sanitize-url': 6.0.4 + '@types/d3-scale': 4.0.8 + '@types/d3-scale-chromatic': 3.0.3 + cytoscape: 3.29.2 + cytoscape-cose-bilkent: 4.1.0(cytoscape@3.29.2) + d3: 7.9.0 + d3-sankey: 0.12.3 + dagre-d3-es: 7.0.10 + dayjs: 1.11.11 + dompurify: 3.1.5 + elkjs: 0.9.3 + katex: 0.16.10 + khroma: 2.1.0 + lodash-es: 4.17.21 + mdast-util-from-markdown: 1.3.1 + non-layered-tidy-tree-layout: 2.0.2 + stylis: 4.3.2 + ts-dedent: 2.2.0 + uuid: 9.0.1 + web-worker: 1.3.0 + transitivePeerDependencies: + - supports-color + + micromark-core-commonmark@1.1.0: + dependencies: + decode-named-character-reference: 1.0.2 + micromark-factory-destination: 1.1.0 + micromark-factory-label: 1.1.0 + micromark-factory-space: 1.1.0 + micromark-factory-title: 1.1.0 + micromark-factory-whitespace: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-chunked: 1.1.0 + micromark-util-classify-character: 1.1.0 + micromark-util-html-tag-name: 1.2.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-resolve-all: 1.1.0 + micromark-util-subtokenize: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-autolink-literal@1.0.5: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-sanitize-uri: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-extension-gfm-footnote@1.1.2: + dependencies: + micromark-core-commonmark: 1.1.0 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-sanitize-uri: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-strikethrough@1.0.7: + dependencies: + micromark-util-chunked: 1.1.0 + micromark-util-classify-character: 1.1.0 + micromark-util-resolve-all: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-table@1.0.7: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-tagfilter@1.0.2: + dependencies: + micromark-util-types: 1.1.0 + + micromark-extension-gfm-task-list-item@1.0.5: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm@2.0.3: + dependencies: + micromark-extension-gfm-autolink-literal: 1.0.5 + micromark-extension-gfm-footnote: 1.1.2 + micromark-extension-gfm-strikethrough: 1.0.7 + micromark-extension-gfm-table: 1.0.7 + micromark-extension-gfm-tagfilter: 1.0.2 + micromark-extension-gfm-task-list-item: 1.0.5 + micromark-util-combine-extensions: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-extension-math@2.1.2: + dependencies: + '@types/katex': 0.16.7 + katex: 0.16.10 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-mdx-expression@1.0.8: + dependencies: + '@types/estree': 1.0.5 + micromark-factory-mdx-expression: 1.0.9 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-events-to-acorn: 1.2.3 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-mdx-jsx@1.0.5: + dependencies: + '@types/acorn': 4.0.6 + '@types/estree': 1.0.5 + estree-util-is-identifier-name: 2.1.0 + micromark-factory-mdx-expression: 1.0.9 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-extension-mdx-md@1.0.1: + dependencies: + micromark-util-types: 1.1.0 + + micromark-extension-mdxjs-esm@1.0.5: + dependencies: + '@types/estree': 1.0.5 + micromark-core-commonmark: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-events-to-acorn: 1.2.3 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + unist-util-position-from-estree: 1.1.2 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-extension-mdxjs@1.0.1: + dependencies: + acorn: 8.11.3 + acorn-jsx: 5.3.2(acorn@8.11.3) + micromark-extension-mdx-expression: 1.0.8 + micromark-extension-mdx-jsx: 1.0.5 + micromark-extension-mdx-md: 1.0.1 + micromark-extension-mdxjs-esm: 1.0.5 + micromark-util-combine-extensions: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-factory-destination@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-factory-label@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-factory-mdx-expression@1.0.9: + dependencies: + '@types/estree': 1.0.5 + micromark-util-character: 1.2.0 + micromark-util-events-to-acorn: 1.2.3 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + unist-util-position-from-estree: 1.1.2 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-factory-space@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-types: 1.1.0 + + micromark-factory-title@1.1.0: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-factory-whitespace@1.1.0: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-character@1.2.0: + dependencies: + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-character@2.1.0: + dependencies: + micromark-util-symbol: 2.0.0 + micromark-util-types: 2.0.0 + + micromark-util-chunked@1.1.0: + dependencies: + micromark-util-symbol: 1.1.0 + + micromark-util-classify-character@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-combine-extensions@1.1.0: + dependencies: + micromark-util-chunked: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-decode-numeric-character-reference@1.1.0: + dependencies: + micromark-util-symbol: 1.1.0 + + micromark-util-decode-string@1.1.0: + dependencies: + decode-named-character-reference: 1.0.2 + micromark-util-character: 1.2.0 + micromark-util-decode-numeric-character-reference: 1.1.0 + micromark-util-symbol: 1.1.0 + + micromark-util-encode@1.1.0: {} + + micromark-util-encode@2.0.0: {} + + micromark-util-events-to-acorn@1.2.3: + dependencies: + '@types/acorn': 4.0.6 + '@types/estree': 1.0.5 + '@types/unist': 2.0.10 + estree-util-visit: 1.2.1 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-util-html-tag-name@1.2.0: {} + + micromark-util-normalize-identifier@1.1.0: + dependencies: + micromark-util-symbol: 1.1.0 + + micromark-util-resolve-all@1.1.0: + dependencies: + micromark-util-types: 1.1.0 + + micromark-util-sanitize-uri@1.2.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-encode: 1.1.0 + micromark-util-symbol: 1.1.0 + + micromark-util-sanitize-uri@2.0.0: + dependencies: + micromark-util-character: 2.1.0 + micromark-util-encode: 2.0.0 + micromark-util-symbol: 2.0.0 + + micromark-util-subtokenize@1.1.0: + dependencies: + micromark-util-chunked: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-util-symbol@1.1.0: {} + + micromark-util-symbol@2.0.0: {} + + micromark-util-types@1.1.0: {} + + micromark-util-types@2.0.0: {} + + micromark@3.2.0: + dependencies: + '@types/debug': 4.1.12 + debug: 4.3.5 + decode-named-character-reference: 1.0.2 + micromark-core-commonmark: 1.1.0 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-chunked: 1.1.0 + micromark-util-combine-extensions: 1.1.0 + micromark-util-decode-numeric-character-reference: 1.1.0 + micromark-util-encode: 1.1.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-resolve-all: 1.1.0 + micromark-util-sanitize-uri: 1.2.0 + micromark-util-subtokenize: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + transitivePeerDependencies: + - supports-color + + mri@1.2.0: {} + + ms@2.1.2: {} + + nanoid@3.3.7: {} + + next-mdx-remote@4.4.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@mdx-js/mdx': 2.3.0 + '@mdx-js/react': 2.3.0(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + vfile: 5.3.7 + vfile-matter: 3.0.1 + transitivePeerDependencies: + - supports-color + + next-seo@6.5.0(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + next-themes@0.2.1(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@next/env': 14.2.3 + '@swc/helpers': 0.5.5 + busboy: 1.6.0 + caniuse-lite: 1.0.30001629 + graceful-fs: 4.2.11 + postcss: 8.4.31 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + styled-jsx: 5.1.1(react@18.3.1) + optionalDependencies: + '@next/swc-darwin-arm64': 14.2.3 + '@next/swc-darwin-x64': 14.2.3 + '@next/swc-linux-arm64-gnu': 14.2.3 + '@next/swc-linux-arm64-musl': 14.2.3 + '@next/swc-linux-x64-gnu': 14.2.3 + '@next/swc-linux-x64-musl': 14.2.3 + '@next/swc-win32-arm64-msvc': 14.2.3 + '@next/swc-win32-ia32-msvc': 14.2.3 + '@next/swc-win32-x64-msvc': 14.2.3 + transitivePeerDependencies: + - '@babel/core' + - babel-plugin-macros + + nextra-theme-docs@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(nextra@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@headlessui/react': 1.7.19(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@popperjs/core': 2.11.8 + clsx: 2.1.1 + escape-string-regexp: 5.0.0 + flexsearch: 0.7.43 + focus-visible: 5.2.0 + git-url-parse: 13.1.1 + intersection-observer: 0.12.2 + match-sorter: 6.3.4 + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next-seo: 6.5.0(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next-themes: 0.2.1(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + nextra: 2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + scroll-into-view-if-needed: 3.1.0 + zod: 3.23.8 + + nextra@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@headlessui/react': 1.7.19(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@mdx-js/mdx': 2.3.0 + '@mdx-js/react': 2.3.0(react@18.3.1) + '@napi-rs/simple-git': 0.1.16 + '@theguild/remark-mermaid': 0.0.5(react@18.3.1) + '@theguild/remark-npm2yarn': 0.2.1 + clsx: 2.1.1 + github-slugger: 2.0.0 + graceful-fs: 4.2.11 + gray-matter: 4.0.3 + katex: 0.16.10 + lodash.get: 4.4.2 + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next-mdx-remote: 4.4.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + p-limit: 3.1.0 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + rehype-katex: 7.0.0 + rehype-pretty-code: 0.9.11(shiki@0.14.7) + rehype-raw: 7.0.0 + remark-gfm: 3.0.1 + remark-math: 5.1.1 + remark-reading-time: 2.0.1 + shiki: 0.14.7 + slash: 3.0.0 + title: 3.5.3 + unist-util-remove: 4.0.0 + unist-util-visit: 5.0.0 + zod: 3.23.8 + transitivePeerDependencies: + - supports-color + + non-layered-tidy-tree-layout@2.0.2: {} + + npm-run-path@2.0.2: + dependencies: + path-key: 2.0.1 + + npm-to-yarn@2.2.1: {} + + p-finally@1.0.0: {} + + p-limit@3.1.0: + dependencies: + yocto-queue: 0.1.0 + + parse-entities@4.0.1: + dependencies: + '@types/unist': 2.0.10 + character-entities: 2.0.2 + character-entities-legacy: 3.0.0 + character-reference-invalid: 2.0.1 + decode-named-character-reference: 1.0.2 + is-alphanumerical: 2.0.1 + is-decimal: 2.0.1 + is-hexadecimal: 2.0.1 + + parse-numeric-range@1.3.0: {} + + parse-path@7.0.0: + dependencies: + protocols: 2.0.1 + + parse-url@8.1.0: + dependencies: + parse-path: 7.0.0 + + parse5@7.1.2: + dependencies: + entities: 4.5.0 + + path-key@2.0.1: {} + + periscopic@3.1.0: + dependencies: + '@types/estree': 1.0.5 + estree-walker: 3.0.3 + is-reference: 3.0.2 + + picocolors@1.0.1: {} + + postcss@8.4.31: + dependencies: + nanoid: 3.3.7 + picocolors: 1.0.1 + source-map-js: 1.2.0 + + property-information@6.5.0: {} + + protocols@2.0.1: {} + + pseudomap@1.0.2: {} + + react-dom@18.3.1(react@18.3.1): + dependencies: + loose-envify: 1.4.0 + react: 18.3.1 + scheduler: 0.23.2 + + react@18.3.1: + dependencies: + loose-envify: 1.4.0 + + reading-time@1.5.0: {} + + regenerator-runtime@0.14.1: {} + + rehype-katex@7.0.0: + dependencies: + '@types/hast': 3.0.4 + '@types/katex': 0.16.7 + hast-util-from-html-isomorphic: 2.0.0 + hast-util-to-text: 4.0.2 + katex: 0.16.10 + unist-util-visit-parents: 6.0.1 + vfile: 6.0.1 + + rehype-pretty-code@0.9.11(shiki@0.14.7): + dependencies: + '@types/hast': 2.3.10 + hash-obj: 4.0.0 + parse-numeric-range: 1.3.0 + shiki: 0.14.7 + + rehype-raw@7.0.0: + dependencies: + '@types/hast': 3.0.4 + hast-util-raw: 9.0.3 + vfile: 6.0.1 + + remark-gfm@3.0.1: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-gfm: 2.0.2 + micromark-extension-gfm: 2.0.3 + unified: 10.1.2 + transitivePeerDependencies: + - supports-color + + remark-math@5.1.1: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-math: 2.0.2 + micromark-extension-math: 2.1.2 + unified: 10.1.2 + + remark-mdx@2.3.0: + dependencies: + mdast-util-mdx: 2.0.1 + micromark-extension-mdxjs: 1.0.1 + transitivePeerDependencies: + - supports-color + + remark-parse@10.0.2: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-from-markdown: 1.3.1 + unified: 10.1.2 + transitivePeerDependencies: + - supports-color + + remark-reading-time@2.0.1: + dependencies: + estree-util-is-identifier-name: 2.1.0 + estree-util-value-to-estree: 1.3.0 + reading-time: 1.5.0 + unist-util-visit: 3.1.0 + + remark-rehype@10.1.0: + dependencies: + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-to-hast: 12.3.0 + unified: 10.1.2 + + remove-accents@0.5.0: {} + + robust-predicates@3.0.2: {} + + rw@1.3.3: {} + + sade@1.8.1: + dependencies: + mri: 1.2.0 + + safer-buffer@2.1.2: {} + + scheduler@0.23.2: + dependencies: + loose-envify: 1.4.0 + + scroll-into-view-if-needed@3.1.0: + dependencies: + compute-scroll-into-view: 3.1.0 + + section-matter@1.0.0: + dependencies: + extend-shallow: 2.0.1 + kind-of: 6.0.3 + + shebang-command@1.2.0: + dependencies: + shebang-regex: 1.0.0 + + shebang-regex@1.0.0: {} + + shiki@0.14.7: + dependencies: + ansi-sequence-parser: 1.1.1 + jsonc-parser: 3.2.1 + vscode-oniguruma: 1.7.0 + vscode-textmate: 8.0.0 + + signal-exit@3.0.7: {} + + slash@3.0.0: {} + + sort-keys@5.0.0: + dependencies: + is-plain-obj: 4.1.0 + + source-map-js@1.2.0: {} + + source-map@0.7.4: {} + + space-separated-tokens@2.0.2: {} + + sprintf-js@1.0.3: {} + + streamsearch@1.1.0: {} + + stringify-entities@4.0.4: + dependencies: + character-entities-html4: 2.1.0 + character-entities-legacy: 3.0.0 + + strip-bom-string@1.0.0: {} + + strip-eof@1.0.0: {} + + style-to-object@0.4.4: + dependencies: + inline-style-parser: 0.1.1 + + styled-jsx@5.1.1(react@18.3.1): + dependencies: + client-only: 0.0.1 + react: 18.3.1 + + stylis@4.3.2: {} + + supports-color@4.5.0: + dependencies: + has-flag: 2.0.0 + + title@3.5.3: + dependencies: + arg: 1.0.0 + chalk: 2.3.0 + clipboardy: 1.2.2 + titleize: 1.0.0 + + titleize@1.0.0: {} + + trim-lines@3.0.1: {} + + trough@2.2.0: {} + + ts-dedent@2.2.0: {} + + tslib@2.6.3: {} + + type-fest@1.4.0: {} + + typescript@5.4.5: {} + + undici-types@5.26.5: {} + + unified@10.1.2: + dependencies: + '@types/unist': 2.0.10 + bail: 2.0.2 + extend: 3.0.2 + is-buffer: 2.0.5 + is-plain-obj: 4.1.0 + trough: 2.2.0 + vfile: 5.3.7 + + unist-util-find-after@5.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + + unist-util-generated@2.0.1: {} + + unist-util-is@5.2.1: + dependencies: + '@types/unist': 2.0.10 + + unist-util-is@6.0.0: + dependencies: + '@types/unist': 3.0.2 + + unist-util-position-from-estree@1.1.2: + dependencies: + '@types/unist': 2.0.10 + + unist-util-position@4.0.4: + dependencies: + '@types/unist': 2.0.10 + + unist-util-position@5.0.0: + dependencies: + '@types/unist': 3.0.2 + + unist-util-remove-position@4.0.2: + dependencies: + '@types/unist': 2.0.10 + unist-util-visit: 4.1.2 + + unist-util-remove-position@5.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-visit: 5.0.0 + + unist-util-remove@4.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + unist-util-visit-parents: 6.0.1 + + unist-util-stringify-position@3.0.3: + dependencies: + '@types/unist': 2.0.10 + + unist-util-stringify-position@4.0.0: + dependencies: + '@types/unist': 3.0.2 + + unist-util-visit-parents@4.1.1: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + + unist-util-visit-parents@5.1.3: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + + unist-util-visit-parents@6.0.1: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + + unist-util-visit@3.1.0: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + unist-util-visit-parents: 4.1.1 + + unist-util-visit@4.1.2: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + unist-util-visit-parents: 5.1.3 + + unist-util-visit@5.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + unist-util-visit-parents: 6.0.1 + + uuid@9.0.1: {} + + uvu@0.5.6: + dependencies: + dequal: 2.0.3 + diff: 5.2.0 + kleur: 4.1.5 + sade: 1.8.1 + + vfile-location@5.0.2: + dependencies: + '@types/unist': 3.0.2 + vfile: 6.0.1 + + vfile-matter@3.0.1: + dependencies: + '@types/js-yaml': 4.0.9 + is-buffer: 2.0.5 + js-yaml: 4.1.0 + + vfile-message@3.1.4: + dependencies: + '@types/unist': 2.0.10 + unist-util-stringify-position: 3.0.3 + + vfile-message@4.0.2: + dependencies: + '@types/unist': 3.0.2 + unist-util-stringify-position: 4.0.0 + + vfile@5.3.7: + dependencies: + '@types/unist': 2.0.10 + is-buffer: 2.0.5 + unist-util-stringify-position: 3.0.3 + vfile-message: 3.1.4 + + vfile@6.0.1: + dependencies: + '@types/unist': 3.0.2 + unist-util-stringify-position: 4.0.0 + vfile-message: 4.0.2 + + vscode-oniguruma@1.7.0: {} + + vscode-textmate@8.0.0: {} + + web-namespaces@2.0.1: {} + + web-worker@1.3.0: {} + + which@1.3.1: + dependencies: + isexe: 2.0.0 + + yallist@2.1.2: {} + + yocto-queue@0.1.0: {} + + zod@3.23.8: {} + + zwitch@2.0.4: {} diff --git a/docs/assets/banner.png b/docs/public/assets/banner.png similarity index 100% rename from docs/assets/banner.png rename to docs/public/assets/banner.png diff --git a/docs/assets/icon.png b/docs/public/assets/icon.png similarity index 100% rename from docs/assets/icon.png rename to docs/public/assets/icon.png diff --git a/docs/assets/sample-onnx-graph.png b/docs/public/assets/sample-onnx-graph.png similarity index 100% rename from docs/assets/sample-onnx-graph.png rename to docs/public/assets/sample-onnx-graph.png diff --git a/docs/assets/trend-banner.png b/docs/public/assets/trend-banner.png similarity index 100% rename from docs/assets/trend-banner.png rename to docs/public/assets/trend-banner.png diff --git a/docs/setup/linking.mdx b/docs/setup/linking.mdx deleted file mode 100644 index ecba449f..00000000 --- a/docs/setup/linking.mdx +++ /dev/null @@ -1,106 +0,0 @@ ---- -title: Linking -description: Here's how `ort` links to ONNX Runtime, and how to configure its behavior. ---- - -In some cases, you'll want to use a custom build of ONNX Runtime with `ort`. Luckily, we make this very easy by handling all of the linking configuration automagically. Just point `ort` to the output of ONNX Runtime's build pipeline and it'll Just Work™. - -## Static linking -Most ONNX Runtime compile configurations will support static linking - just run `build.sh` without the `--build_shared_lib` argument. You should prefer static linking if your execution providers support it, as it avoids many issues and follows de facto Rust practices. If you compile both static libraries and dynamic libraries, `ort` will prefer linking to the static libraries. - -To direct `ort` to your statically built binaries, use the `ORT_LIB_LOCATION` environment variable when running `cargo build`. Point it to the location where the static libraries (`.a`/`.lib` files) are compiled to. This will typically be `onnxruntime/build/`. For example: -```shell -$ ORT_LIB_LOCATION=~/onnxruntime/build/Linux cargo build -``` - -For iOS (or for other platforms if you are compiling multiple profiles at once), you'll need to manually specify the profile with the `ORT_LIB_PROFILE` environment variable. If not specified, `ort` will prefer `Release` over `RelWithDebInfo` over `MinSizeRel` over `Debug`. - -## Dynamic linking -Some execution providers unfortunately only support dynamic linking. Dynamic linking doesn't play well with the Rust ecosystem, though `ort` tries to alleviate the pain as much as possible. - -When it comes to dynamic linking, there are two options: `load-dynamic`, or standard compile-time dynamic linking. We recommend `load-dynamic` as it gives more control and is often far less troublesome to work with. - -### Runtime loading with `load-dynamic` -The `load-dynamic` Cargo feature solves a few of the issues with dynamic linking by **loading the library at runtime** rather than **linking at compile time**. This means that the path to the ONNX Runtime library can be configured at runtime, and the executable will not just completely fail to start if the binary couldn't be found. - -To use `load-dynamic`: - - - ```toml Cargo.toml - [dependencies] - ort = { version = "2", features = [ "load-dynamic" ] } - ``` - - - - - ```rust main.rs - fn main() -> anyhow::Result<()> { - // Find our custom ONNX Runtime dylib path somehow - // (i.e. resolving it from the root of our program's install folder) - let dylib_path = crate::internal::find_onnxruntime_dylib()?; - // The path should point to the `libonnxruntime` binary, which looks like: - // - on Unix: /etc/.../libonnxruntime.so - // - on Windows: C:\Program Files\...\onnxruntime.dll - - // Initialize ort with the path to the dylib. This **must** be called before any usage of `ort`! - // `init_from` returns an `EnvironmentBuilder` which you can use to further configure the environment - // before `.commit()`ing; see the Environment docs for more information on what you can configure. - ort::init_from(dylib_path).commit()?; - - Ok(()) - } - ``` - - - Set the `ORT_DYLIB_PATH` environment variable to the path to `libonnxruntime.so`/`onnxruntime.dll`. - - ```shell - $ ORT_DYLIB_PATH=../onnxruntime-build/linux-x64/libonnxruntime.so ./mirai - ``` - - - - - -`ORT_DYLIB_PATH` is relative to the executable. Cargo examples and tests are compiled to a different directory than binary crates: `target//examples` and `target//deps` respectively. Keep this in mind if you're going to use relative paths. - -### Compile-time dynamic linking -For compile-time dynamic linking, you'll need to configure your environment in the exact same way as if you were [statically linking](#static-linking). - -Note that the dylibs then have to be placed in a certain location for them to be found by the executable. For Windows, this is either somewhere on the `PATH`, or in the same folder as the executable. On macOS and Linux, they have to be placed somewhere in the `LD_LIBRARY_PATH`, or you can use rpath to configure the executable to search for dylibs in its parent folder. We've had the least issues with rpath, but YMMV. - -To configure rpath, you'll need to: - - - ```toml - [profile.dev] - rpath = true - - [profile.release] - rpath = true - - # do this for any other profiles - ``` - - - - - ```toml - [target.x86_64-unknown-linux-gnu] - rustflags = [ "-Clink-args=-Wl,-rpath,\\$ORIGIN" ] - - # do this for any other Linux targets as well - ``` - - - ```toml - [target.x86_64-apple-darwin] - rustflags = [ "-Clink-args=-Wl,-rpath,@loader_path" ] - - # do this for any other macOS targets as well - ``` - - - - diff --git a/docs/theme.config.jsx b/docs/theme.config.jsx new file mode 100644 index 00000000..ef71c4eb --- /dev/null +++ b/docs/theme.config.jsx @@ -0,0 +1,33 @@ +import Image from 'next/image'; + +/** @type {import('nextra-theme-docs').DocsThemeConfig} */ +const config = { + project: { + link: 'https://github.com/pykeio/ort' + }, + chat: { + link: 'https://discord.gg/uQtsNu2xMa' + }, + docsRepositoryBase: 'https://github.com/pykeio/ort/blob/main/docs', + useNextSeoProps() { + return { + titleTemplate: '%s | ort' + } + }, + logo: , + darkMode: true, + nextThemes: { + defaultTheme: 'system' + }, + footer: { + text:
+

made with 💜 by pykesponsor

+
+ }, + primaryHue: 20, + primarySaturation: 100, + toc: { + float: true + } +}; +export default config; diff --git a/docs/tsconfig.json b/docs/tsconfig.json new file mode 100644 index 00000000..19deeffc --- /dev/null +++ b/docs/tsconfig.json @@ -0,0 +1,28 @@ +{ + "compilerOptions": { + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], + "allowJs": true, + "skipLibCheck": true, + "strict": false, + "noEmit": true, + "incremental": true, + "module": "esnext", + "esModuleInterop": true, + "moduleResolution": "node", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "preserve" + }, + "include": [ + "next-env.d.ts", + "**/*.ts", + "**/*.tsx" +, "pages/_app.mdx" ], + "exclude": [ + "node_modules" + ] +} diff --git a/examples/cudarc/src/main.rs b/examples/cudarc/src/main.rs index 1ffc01f0..20013a9f 100644 --- a/examples/cudarc/src/main.rs +++ b/examples/cudarc/src/main.rs @@ -11,7 +11,7 @@ fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); ort::init() - .with_execution_providers([CUDAExecutionProvider::default().build()]) + .with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()]) .commit()?; let model = diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index 2d590f0c..1206860c 100644 --- a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -82,7 +82,7 @@ impl Kernel for CustomOpTwoKernel { fn main() -> ort::Result<()> { let session = Session::builder()? - .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)? + .with_operators(OperatorDomain::new("test.customop")?.add::()?.add::()?)? .commit_from_file("tests/data/custom_op_test.onnx")?; let values = session.run(ort::inputs![Array2::::zeros((3, 5)), Array2::::ones((3, 5))]?)?; diff --git a/examples/training/Cargo.toml b/examples/training/Cargo.toml new file mode 100644 index 00000000..69679924 --- /dev/null +++ b/examples/training/Cargo.toml @@ -0,0 +1,18 @@ +[package] +publish = false +name = "example-training" +version = "0.0.0" +edition = "2021" + +[dependencies] +voicevox-ort = { path = "../../", features = [ "training" ] } +ndarray = "0.15" +tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] } +rand = "0.8" +simd-json = "0.13" +kdam = "0.5" +tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } + +[features] +load-dynamic = [ "voicevox-ort/load-dynamic" ] +cuda = [ "voicevox-ort/cuda" ] diff --git a/examples/training/README.md b/examples/training/README.md new file mode 100644 index 00000000..7c99d643 --- /dev/null +++ b/examples/training/README.md @@ -0,0 +1,26 @@ +# Training Examples + +## `train-clm` +This example trains a tiny causal language model on a small subset of pyke's [**OshiChats v2**](https://huggingface.co/datasets/pykeio/oshichats-v2), a dataset of live text chat messages collected from various [VTuber](https://en.wikipedia.org/wiki/VTuber) live streams. The model is not particularly useful or interesting (due to both the low-quality dataset and small model size), but it showcases that entire language models can be trained from scratch entirely in Rust on (almost) any device. + +To get started, create a Python virtual environment and install the following packages: +``` +pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 onnx~=1.17 torch~=2.3 +``` + +We're installing the CPU version of the `onnxruntime-training` & `torch` packages because we only need to use Python to *create* the initial graph which will be used for training. Run `python tools/train-data/mini-clm.py` from the root directory of the `ort` repo to create the training artifacts. + +Next, we need to convert our dataset into tokens to feed the model. This can be achieved by downloading the `oshicats-v2.jsonl` file from the OshiChats v2 dataset and running `cargo run -p example-training --example pretokenize -- ~/oshichats-v2.jsonl`, or if you (rightfully) don't wish to waste 30 GB worth of disk space and bandwidth on brainrot, you may download a [1 MB pre-tokenized subset of the dataset](https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_data/dataset.bin). Make sure `dataset.bin` is in the root of the `ort` repo. + +Finally, we can train our model! Run `cargo run -p example-training --example train-clm` to start training. If you have an NVIDIA GPU, add `--features cuda` to enable CUDA, though it's not required and you can train directly on CPU instead. **This will use ~8 GB of (V)RAM!** You can lower the memory usage by adjusting the `BATCH_SIZE` and `SEQUENCE_LENGTH` constants in `train-clm.rs`, though note that changing the batch size may require adjustments to the learning rate. + +While training, the progress bar will show the cross-entropy loss at each training step. At the end of training, the final trained model will be saved to `trained-clm.onnx`, and the program will use the model to generate a small snippet of text: +``` +100%|██████████████████████████████████████| 5000/5000 [06:29<00:00, 12.83it/s, loss=3.611] +I'm so much better than the game<|endoftext|>I think you can't see it<|endoftext|>I think you can't see it<|endoftext|>I think so it's a new game<|endoftext|>I think I'm sure you can't see what you can't see it<|endoftext|> +``` + +Not bad, considering the model & dataset size! This example can easily be scaled up to pre-train or fine-tune (both full-parameter and PEFT) larger language models like Llama/Phi, so long as you have enough compute. + +## `train-clm-simple` +This example is functionally identical to `train-clm`, except it uses ort's "simple" Trainer API instead of implementing a manual training loop. The simple API is more akin to 🤗 Transformer's [`Trainer`](https://huggingface.co/docs/transformers/en/main_classes/trainer) API or [PyTorch Lightning](https://lightning.ai/pytorch-lightning). With the simple API, all you have to do is pass a data loader & parameters, and let `ort` handle training for you! diff --git a/examples/training/build.rs b/examples/training/build.rs new file mode 100644 index 00000000..79d3a0bb --- /dev/null +++ b/examples/training/build.rs @@ -0,0 +1,5 @@ +fn main() { + // Need this for CoreML. See: https://ort.pyke.io/perf/execution-providers#coreml + #[cfg(target_os = "macos")] + println!("cargo:rustc-link-arg=-fapple-link-rtlib"); +} diff --git a/examples/training/examples/pretokenize.rs b/examples/training/examples/pretokenize.rs new file mode 100644 index 00000000..79eee195 --- /dev/null +++ b/examples/training/examples/pretokenize.rs @@ -0,0 +1,44 @@ +use std::{ + env, + fs::File, + io::{BufRead, BufReader, BufWriter, Write}, + path::Path +}; + +use simd_json::derived::ValueObjectAccessAsScalar; +use tokenizers::Tokenizer; + +const MAX_TOKENS: usize = 500_000; + +fn main() { + let input = env::args().nth(1).expect("provide input jsonl"); + let output = env::args().nth(2).unwrap_or_else(|| "dataset.bin".into()); + + let input = BufReader::new(File::open(input).unwrap()); + let mut output = BufWriter::new(File::create(output).unwrap()); + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + let mut bytes_written = 0; + + for line in input.lines() { + let line: simd_json::OwnedValue = unsafe { simd_json::from_str(&mut line.unwrap()).unwrap() }; + let tokenized = tokenizer + .encode(format!("<|endoftext|>{}", line.get_str("message").unwrap()), false) + .unwrap(); + let id_bytes: Vec = tokenized.get_ids().iter().flat_map(|c| (*c as u16).to_le_bytes()).collect(); + output.write_all(&id_bytes).unwrap(); + bytes_written += id_bytes.len(); + if bytes_written >= MAX_TOKENS * 2 { + output.flush().unwrap(); + break; + } + } +} diff --git a/examples/training/examples/train-clm-simple.rs b/examples/training/examples/train-clm-simple.rs new file mode 100644 index 00000000..0c3ac326 --- /dev/null +++ b/examples/training/examples/train-clm-simple.rs @@ -0,0 +1,118 @@ +use std::{ + fs::File, + io::{Read, Seek, SeekFrom, Write}, + path::Path +}; + +use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis}; +use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments}; +use rand::RngCore; +use tokenizers::Tokenizer; + +const BATCH_SIZE: usize = 16; +const SEQUENCE_LENGTH: usize = 256; + +fn main() -> ort::Result<()> { + tracing_subscriber::fmt::init(); + + ort::init().commit()?; + + let trainer = Trainer::new_from_artifacts( + SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, + Allocator::default(), + "tools/train-data/mini-clm", + None + )?; + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + + let mut dataset = File::open("dataset.bin").unwrap(); + let file_size = dataset.metadata().unwrap().len(); + let num_tokens = (file_size / 2) as usize; // 16-bit tokens + let mut rng = rand::thread_rng(); + let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let dataloader = move |_: usize| { + for batch in 0..BATCH_SIZE { + let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64; + dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + } + + Ok(( + ort::inputs![Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?, + ort::inputs![Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap()]? + )) + }; + + trainer.train( + TrainingArguments::new(dataloader) + .with_lr(7e-5) + .with_max_steps(5000) + .with_ckpt_strategy(CheckpointStrategy::Steps(500)) + )?; + + trainer.export("trained-clm.onnx", ["probs"])?; + + let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + + let mut stdout = std::io::stdout(); + + let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); + let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); + + let mut tokens = Array1::from_iter(tokens.iter().cloned()); + + for _ in 0..50 { + let array = tokens.view().insert_axis(Axis(0)); + let outputs = session.run(ort::inputs![array]?)?; + let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; + + let probabilities = &mut generated_tokens + .slice(s![-1, ..]) + .to_owned() + .iter() + .cloned() + .enumerate() + .collect::>(); + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); + + let token = probabilities[0].0; + tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + + let token_str = tokenizer.decode(&[token as _], false).unwrap(); + print!("{}", token_str); + stdout.flush().unwrap(); + } + + println!(); + Ok(()) +} diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs new file mode 100644 index 00000000..9e46bf44 --- /dev/null +++ b/examples/training/examples/train-clm.rs @@ -0,0 +1,133 @@ +use std::{ + fs::File, + io::{Read, Seek, SeekFrom, Write}, + path::Path +}; + +use kdam::BarExt; +use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis}; +use ort::{Allocator, CUDAExecutionProvider, Checkpoint, Session, SessionBuilder, Trainer}; +use rand::RngCore; +use tokenizers::Tokenizer; + +const BATCH_SIZE: usize = 16; +const SEQUENCE_LENGTH: usize = 256; + +fn main() -> ort::Result<()> { + tracing_subscriber::fmt::init(); + + ort::init().commit()?; + + kdam::term::init(true); + let _ = kdam::term::hide_cursor(); + + let trainer = Trainer::new( + SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, + Allocator::default(), + Checkpoint::load("tools/train-data/mini-clm/checkpoint")?, + "tools/train-data/mini-clm/training_model.onnx", + "tools/train-data/mini-clm/eval_model.onnx", + "tools/train-data/mini-clm/optimizer_model.onnx" + )?; + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + + let optimizer = trainer.optimizer(); + optimizer.set_lr(7e-5)?; + + let mut dataset = File::open("dataset.bin").unwrap(); + let file_size = dataset.metadata().unwrap().len(); + let num_tokens = (file_size / 2) as usize; // 16-bit tokens + let mut rng = rand::thread_rng(); + + let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut pb = kdam::tqdm!(total = 5000); + for _ in 0..5000 { + for batch in 0..BATCH_SIZE { + let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64; + dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + } + + let inputs = Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap(); + let labels = Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap(); + + let outputs = trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?; + let loss = outputs[0].try_extract_scalar::()?; + pb.set_postfix(format!("loss={loss:.3}")); + pb.update(1).unwrap(); + if loss.is_nan() { + return Ok(()); + } + optimizer.step()?; + optimizer.reset_grad()?; + } + + eprintln!(); + let _ = kdam::term::show_cursor(); + + trainer.export("trained-clm.onnx", ["probs"])?; + + let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + + let mut stdout = std::io::stdout(); + + let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); + let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); + + let mut tokens = Array1::from_iter(tokens.iter().cloned()); + + for _ in 0..50 { + let array = tokens.view().insert_axis(Axis(0)); + let outputs = session.run(ort::inputs![array]?)?; + let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; + + let probabilities = &mut generated_tokens + .slice(s![-1, ..]) + .to_owned() + .iter() + .cloned() + .enumerate() + .collect::>(); + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); + + let token = probabilities[0].0; + tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + + let token_str = tokenizer.decode(&[token as _], false).unwrap(); + print!("{}", token_str); + stdout.flush().unwrap(); + } + + println!(); + Ok(()) +} diff --git a/examples/webassembly/Cargo.toml b/examples/webassembly/Cargo.toml index de8e675f..75692990 100644 --- a/examples/webassembly/Cargo.toml +++ b/examples/webassembly/Cargo.toml @@ -16,6 +16,7 @@ web-sys = "0.3" tracing = "0.1" tracing-subscriber = "0.3" tracing-subscriber-wasm = "0.1" +image = { version = "0.25", default-features = false, features = [ "jpeg" ]} [dev-dependencies] wasm-bindgen-test = "0.3" diff --git a/examples/webassembly/src/lib.rs b/examples/webassembly/src/lib.rs index 589bc11d..dbc0d73a 100644 --- a/examples/webassembly/src/lib.rs +++ b/examples/webassembly/src/lib.rs @@ -1,22 +1,42 @@ -use ndarray::{Array4, ArrayViewD}; -use ort::Session; +use image::{ImageBuffer, Luma, Pixel}; +use ort::{ArrayExtensions, Session}; use wasm_bindgen::prelude::*; -static MODEL_BYTES: &[u8] = include_bytes!("upsample.ort"); +static IMAGE_BYTES: &[u8] = include_bytes!("../../../tests/data/mnist_5.jpg"); +static MODEL_BYTES: &[u8] = include_bytes!("mnist.ort"); pub fn upsample_inner() -> ort::Result<()> { let session = Session::builder()? .commit_from_memory_directly(MODEL_BYTES) .expect("Could not read model from memory"); - let array = Array4::::zeros((1, 224, 224, 3)); + // NOTE: An earlier nightly version of Rust 1.78 includes a patch required for ONNX Runtime to link properly, but a + // later version enables debug assertions in `dlmalloc`, which surfaces an allocation bug in the `image` crate: + // https://github.com/rustwasm/wasm-pack/issues/1389 Because of this, using `image::load_from_memory` crashes. + // For demonstration purposes, we're replacing the image loading code shown below with zeros(). In a real application, + // you can get the image from another source, like an HTML canvas. + // + // let image_buffer: ImageBuffer, Vec> = image::load_from_memory(IMAGE_BYTES).unwrap().to_luma8(); + // let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + // let pixel = image_buffer.get_pixel(i as u32, j as u32); + // let channels = pixel.channels(); + // (channels[c] as f32) / 255.0 + // }); + let array = ndarray::Array4::::zeros((1, 1, 28, 28)); let outputs = session.run(ort::inputs![array]?)?; - assert_eq!(outputs.len(), 1); - let output: ArrayViewD = outputs[0].try_extract_tensor()?; + let mut probabilities: Vec<(usize, f32)> = outputs[0] + .try_extract_tensor()? + .softmax(ndarray::Axis(1)) + .iter() + .copied() + .enumerate() + .collect::>(); - assert_eq!(output.shape(), [1, 448, 448, 3]); + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + assert_eq!(probabilities[0].0, 5, "Expecting class '5' (got {})", probabilities[0].0); Ok(()) } diff --git a/examples/webassembly/src/mnist.ort b/examples/webassembly/src/mnist.ort new file mode 100644 index 00000000..18458997 Binary files /dev/null and b/examples/webassembly/src/mnist.ort differ diff --git a/examples/webassembly/src/upsample.ort b/examples/webassembly/src/upsample.ort deleted file mode 100644 index b3e43d00..00000000 Binary files a/examples/webassembly/src/upsample.ort and /dev/null differ diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index 36e7fc28..a7358228 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "voicevox-ort-sys" -description = "Unsafe Rust bindings for ONNX Runtime 1.17 - Optimize and Accelerate Machine Learning Inferencing" -version = "2.0.0-rc.2" +description = "Unsafe Rust bindings for ONNX Runtime 1.18 - Optimize and Accelerate Machine Learning Inferencing" +version = "2.0.0-rc.4" edition = "2021" rust-version = "1.70" license = "MIT OR Apache-2.0" @@ -19,6 +19,7 @@ name = "ort_sys" [features] default = [] +training = [] download-binaries = [ "ureq", "tar", "flate2", "sha2" ] load-dynamic = [] copy-dylibs = [] @@ -41,9 +42,9 @@ vitis = [] cann = [] qnn = [] - [build-dependencies] -ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] } +ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls", "socks-proxy" ] } tar = { version = "0.4", optional = true } flate2 = { version = "1.0", optional = true } sha2 = { version = "0.10", optional = true } +pkg-config = "0.3.30" diff --git a/ort-sys/VERSION_NUMBER b/ort-sys/VERSION_NUMBER index b9a05a6d..ec6d649b 100644 --- a/ort-sys/VERSION_NUMBER +++ b/ort-sys/VERSION_NUMBER @@ -1 +1 @@ -1.17.3 +1.18.1 diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 8ceefd79..e5ce7fae 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -1,8 +1,12 @@ use std::{ env, fs, - path::{Path, PathBuf} + path::{Path, PathBuf}, + process::Command }; +#[allow(unused)] +const ONNXRUNTIME_VERSION: &str = "1.18.1"; + const ORT_ENV_SYSTEM_LIB_LOCATION: &str = "ORT_LIB_LOCATION"; const ORT_ENV_SYSTEM_LIB_PROFILE: &str = "ORT_LIB_PROFILE"; @@ -35,38 +39,15 @@ fn fetch_file(source_url: &str) -> Vec { buffer } -fn find_dist(target: &str, designator: &str) -> Option<(&'static str, &'static str)> { +fn find_dist(target: &str, feature_set: &str) -> Option<(&'static str, &'static str)> { DIST_TABLE .split('\n') .filter(|c| !c.is_empty() && !c.starts_with('#')) .map(|c| c.split('\t').collect::>()) - .find(|c| c[0] == designator && c[1] == target) + .find(|c| c[0] == feature_set && c[1] == target) .map(|c| (c[2], c[3])) } -fn lib_exists(name: &str) -> bool { - #[cfg(any(target_family = "windows", unix))] - let lib_str = std::ffi::CString::new(name).unwrap(); - // note that we're not performing any cleanup here because this is a short lived build script; the OS will clean it up - // for us when we finish - #[cfg(target_family = "windows")] - return unsafe { - extern "C" { - fn LoadLibraryA(lplibfilename: *const std::ffi::c_char) -> isize; - } - LoadLibraryA(lib_str.as_ptr()) != 0 - }; - #[cfg(unix)] - return unsafe { - extern "C" { - fn dlopen(file: *const std::ffi::c_char, mode: std::ffi::c_int) -> *const std::ffi::c_void; - } - !dlopen(lib_str.as_ptr(), 1).is_null() - }; - #[cfg(not(any(target_family = "windows", unix)))] - return false; -} - #[cfg(feature = "download-binaries")] fn hex_str_to_bytes(c: impl AsRef<[u8]>) -> Vec { fn nibble(c: u8) -> u8 { @@ -134,7 +115,7 @@ fn copy_libraries(lib_dir: &Path, out_dir: &Path) { #[cfg(target_os = "linux")] { let main_dy = lib_dir.join("libonnxruntime.so"); - let versioned_dy = out_dir.join("libonnxruntime.so.1.17.3"); + let versioned_dy = out_dir.join(format!("libonnxruntime.so.{}", ONNXRUNTIME_VERSION)); if main_dy.exists() && !versioned_dy.exists() { if versioned_dy.is_symlink() { fs::remove_file(&versioned_dy).unwrap(); @@ -255,8 +236,12 @@ fn prepare_libort_dir() -> (PathBuf, bool) { println!("cargo:rustc-link-lib=static=onnx"); println!("cargo:rustc-link-lib=static=onnx_proto"); - add_search_dir(transform_dep(external_lib_dir.join("google_nsync-build"), &profile)); - println!("cargo:rustc-link-lib=static=nsync_cpp"); + let nsync_path = transform_dep(external_lib_dir.join("google_nsync-build"), &profile); + // some builds of ONNX Runtime, particularly the default no-EP windows build, don't require nsync + if nsync_path.exists() { + add_search_dir(nsync_path); + println!("cargo:rustc-link-lib=static=nsync_cpp"); + } if target_arch != "wasm32" { add_search_dir(transform_dep(external_lib_dir.join("pytorch_cpuinfo-build"), &profile)); @@ -275,6 +260,9 @@ fn prepare_libort_dir() -> (PathBuf, bool) { add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("base"), &profile)); println!("cargo:rustc-link-lib=static=absl_base"); + println!("cargo:rustc-link-lib=static=absl_spinlock_wait"); + println!("cargo:rustc-link-lib=static=absl_malloc_internal"); + println!("cargo:rustc-link-lib=static=absl_raw_logging_internal"); println!("cargo:rustc-link-lib=static=absl_throw_delegate"); add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("hash"), &profile)); println!("cargo:rustc-link-lib=static=absl_hash"); @@ -282,6 +270,23 @@ fn prepare_libort_dir() -> (PathBuf, bool) { println!("cargo:rustc-link-lib=static=absl_low_level_hash"); add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("container"), &profile)); println!("cargo:rustc-link-lib=static=absl_raw_hash_set"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("synchronization"), &profile)); + println!("cargo:rustc-link-lib=static=absl_kernel_timeout_internal"); + println!("cargo:rustc-link-lib=static=absl_graphcycles_internal"); + println!("cargo:rustc-link-lib=static=absl_synchronization"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("time"), &profile)); + println!("cargo:rustc-link-lib=static=absl_time_zone"); + println!("cargo:rustc-link-lib=static=absl_time"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("numeric"), &profile)); + println!("cargo:rustc-link-lib=static=absl_int128"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("strings"), &profile)); + println!("cargo:rustc-link-lib=static=absl_str_format_internal"); + println!("cargo:rustc-link-lib=static=absl_strings"); + println!("cargo:rustc-link-lib=static=absl_string_view"); + println!("cargo:rustc-link-lib=static=absl_strings_internal"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("debugging"), &profile)); + println!("cargo:rustc-link-lib=static=absl_symbolize"); + println!("cargo:rustc-link-lib=static=absl_stacktrace"); if cfg!(feature = "coreml") && (target_os == "macos" || target_os == "ios") { println!("cargo:rustc-link-lib=framework=CoreML"); @@ -334,31 +339,75 @@ fn prepare_libort_dir() -> (PathBuf, bool) { compile_error!("unsupported EP"); let target = env::var("TARGET").unwrap().to_string(); - let designator = if cfg!(any(feature = "cuda", feature = "tensorrt")) { - if lib_exists("cudart64_12.dll") || lib_exists("libcudart.so.12") { "cu12" } else { "cu11" } + + let mut feature_set = Vec::new(); + if cfg!(feature = "training") { + feature_set.push("train"); + } + if cfg!(any(feature = "cuda", feature = "tensorrt")) { + // pytorch's CUDA docker images set `NV_CUDNN_VERSION` + let cu12_tag = match env::var("NV_CUDNN_VERSION").or_else(|_| env::var("ORT_CUDNN_VERSION")).as_deref() { + Ok(v) => { + if v.starts_with("8") { + "cu12+cudnn8" + } else { + "cu12" + } + } + Err(_) => "cu12" + }; + + match env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() { + Ok("11") => feature_set.push("cu11"), + Ok("12") => feature_set.push("cu12"), + _ => { + let mut success = false; + if let Ok(nvcc_output) = Command::new("nvcc").arg("--version").output() { + if nvcc_output.status.success() { + let stdout = String::from_utf8_lossy(&nvcc_output.stdout); + let version_line = stdout.lines().nth(3).unwrap(); + let release_section = version_line.split(", ").nth(1).unwrap(); + let version_number = release_section.split(' ').nth(1).unwrap(); + if version_number.starts_with("12") { + feature_set.push(cu12_tag); + } else { + feature_set.push("cu11"); + } + success = true; + } + } + + if !success { + println!("cargo:warning=nvcc call did not succeed. falling back to CUDA 12"); + // fallback to CUDA 12. + feature_set.push(cu12_tag); + } + } + } } else if cfg!(feature = "rocm") { - "rocm" - } else { - "none" - }; - let _ = designator; // 上記のものを無視する - let designator = if cfg!(feature = "directml") { + feature_set.push("rocm"); + } + let feature_set = if !feature_set.is_empty() { feature_set.join(",") } else { "none".to_owned() }; + let _ = feature_set; // 上記のものを無視する + let feature_set = if cfg!(feature = "directml") { "directml" } else if cfg!(feature = "cuda") { - "cu12" // ビルド環境に何がインストールされていようが、常にCUDA 12を使う + "cu12" // ビルド環境に何がインストールされていようが、常にCUDA 12とcuDNN 9を使う } else { "none" - }; - let mut dist = find_dist(&target, designator); - if dist.is_none() && designator != "none" { + } + .to_owned(); + println!("selected feature set: {feature_set}"); + let mut dist = find_dist(&target, &feature_set); + if dist.is_none() && feature_set != "none" { dist = find_dist(&target, "none"); } if dist.is_none() { panic!( "downloaded binaries not available for target {target}{}\nyou may have to compile ONNX Runtime from source", - if designator != "none" { - format!(" (note: also requested `{designator}`)") + if feature_set != "none" { + format!(" (note: also requested features `{feature_set}`)") } else { String::new() } @@ -406,6 +455,36 @@ fn prepare_libort_dir() -> (PathBuf, bool) { } } +fn try_setup_with_pkg_config() -> bool { + match pkg_config::Config::new().probe("libonnxruntime") { + Ok(lib) => { + let expected_minor = ONNXRUNTIME_VERSION.split('.').nth(1).unwrap().parse::().unwrap(); + let got_minor = lib.version.split('.').nth(1).unwrap().parse::().unwrap(); + if got_minor < expected_minor { + println!("libonnxruntime provided by pkg-config is out of date, so it will be ignored - expected {}, got {}", ONNXRUNTIME_VERSION, lib.version); + return false; + } + + // Setting the link paths + for path in lib.link_paths { + println!("cargo:rustc-link-search=native={}", path.display()); + } + + // Setting the libraries to link against + for lib in lib.libs { + println!("cargo:rustc-link-lib={}", lib); + } + + println!("Using onnxruntime found by pkg-config."); + true + } + Err(_) => { + println!("onnxruntime not found using pkg-config, falling back to manual setup."); + false + } + } +} + fn real_main(link: bool) { println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_LOCATION); println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_PROFILE); @@ -416,7 +495,15 @@ fn real_main(link: bool) { if link { if needs_link { - println!("cargo:rustc-link-lib=onnxruntime"); + let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap(); + let static_lib_file_name = if target_os.contains("windows") { "onnxruntime.lib" } else { "libonnxruntime.a" }; + + let static_lib_path = lib_dir.join(static_lib_file_name); + if static_lib_path.exists() { + println!("cargo:rustc-link-lib=static=onnxruntime"); + } else { + println!("cargo:rustc-link-lib=onnxruntime"); + } println!("cargo:rustc-link-search=native={}", lib_dir.display()); } @@ -436,14 +523,20 @@ fn main() { } if cfg!(feature = "load-dynamic") { - // we only need to execute the real main step if we are using the download strategy... - if cfg!(feature = "download-binaries") && env::var(ORT_ENV_SYSTEM_LIB_LOCATION).is_err() { - // but we don't need to link to the binaries we download (so all we are doing is downloading them and placing them in - // the output directory) - real_main(false); + if !try_setup_with_pkg_config() { + // Only execute the real main step if pkg-config fails and if we are using the download + // strategy + if cfg!(feature = "download-binaries") && env::var(ORT_ENV_SYSTEM_LIB_LOCATION).is_err() { + // but we don't need to link to the binaries we download (so all we are doing is + // downloading them and placing them in the output directory) + real_main(false); // but we don't need to link to the binaries we download + } } } else { - // if we are not using the load-dynamic feature then we need to link to dylibs. - real_main(true); + // If pkg-config setup was successful, we don't need further action + // Otherwise, if we are not using the load-dynamic feature, we need to link to the dylibs. + if !try_setup_with_pkg_config() { + real_main(true); + } } } diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index 7530f6fe..ff7f30c3 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt @@ -1,30 +1,15 @@ -none aarch64-apple-darwin https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-osx-arm64-1.17.3.tgz A11F02FB263783C4202E11E41199D88C510C66F375C092259BC73FA8F25DEB7D -#none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_static-v1.17.3-aarch64-pc-windows-msvc.tgz 7BCECBBC15F64C631051C894C5044FB01658F171D7B5D4E9635D7D7F464B8B3E -none aarch64-unknown-linux-gnu https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-linux-arm64-1.17.3.tgz 5176F952D3C694825D2408BE37C0D13463631D2B3E956F6BFB84EAEE931726AC - -none armv7-unknown-linux-gnueabihf https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-linux-armhf-1.17.3.tgz F1A0A14FB47B6904EEB38BEE7D550B858788784765684CF7928AA11136354FC7 - -none x86_64-apple-darwin https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-osx-x86_64-1.17.3.tgz 14A29904F236B8EF90835422FEA985E6892A94BA8B0E660325025F1DF045C870 -none x86_64-pc-windows-msvc https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-win-x64-1.17.3.tgz EB40E58F3BA0BA8D256847429B4511CDAE4C81805891C94F0C58220E3E17CDA8 -none x86_64-unknown-linux-gnu https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-linux-x64-1.17.3.tgz 65BBCEC997F9854A721B982CD89D11F92D48F122270B9068BCD919E6C13606B5 - -none i686-pc-windows-msvc https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-win-x86-1.17.3.tgz 5B4CD789F5FA4936E28A640753EE46E86C4F65792BE7D074E8AC900410D94757 - -directml x86_64-pc-windows-msvc https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-win-x64-gpu-1.17.3.tgz 85345B4D7371C9A07C38ABDE853694FB0311EA2E41953DF3CCB052C564BAC8C2 -#cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_dylib_cuda11-v1.17.3-x86_64-pc-windows-msvc.tgz 9AE21DECB9BE1270CD850276AAC1AB9C2E2AE978B563B733804F5DF7DACC3BE5 -cu12 x86_64-pc-windows-msvc https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-win-x64-gpu-1.17.3.tgz 85345B4D7371C9A07C38ABDE853694FB0311EA2E41953DF3CCB052C564BAC8C2 -#cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_dylib_cuda11-v1.17.3-x86_64-unknown-linux-gnu.tgz D4C264EA6805790D4C8B51D166EF6BD407FB3ECC641B33AEFE77FCD5BF0C6ECA -cu12 x86_64-unknown-linux-gnu https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-linux-x64-gpu-1.17.3.tgz 6212F2BD130BAD7F59027DB24DC4DD8510E01FAFE875E04E7863A99B9D8BE400 -#rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_dylib_rocm-v1.17.3-x86_64-unknown-linux-gnu.tgz 50E39B38484A0676B3D24365149CE9E7760F658B552EFE5A9382AB503D73D7E7 - -# todo: update WASM build to 1.17.3 -#none wasm32-wasi https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz 41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089 -#none wasm32-wasi-preview1 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz 41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089 -#none wasm32-wasi-preview2 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz 41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089 -#none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz 41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089 - -none aarch64-linux-android https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-android-arm64-1.17.3.tgz A8EB02DF2153EF85BDE5C115F6802508077A98348F20FCFE12EC6D1A80DA1B53 -none x86_64-linux-android https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-android-x64-1.17.3.tgz DEB8D2687E5B8CFD0D80AF4D1C572248B999C62C6BCBF13108D5F22B4BBC78CE -none aarch64-apple-ios https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-ios-arm64-1.17.3.tgz 8687C8E358B65D84D5D2AD4BE9C577227B7D20CD866DF27133185DD0575B5B36 -none aarch64-apple-ios-sim https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-ios-sim-arm64-1.17.3.tgz 0561A9A1CA0BF3071C5C5CFCA69A1208338C8DDFE924A24A1A26DA618651B4AB -none x86_64-apple-ios https://github.com/VOICEVOX/onnxruntime-builder/releases/download/1.17.3/onnxruntime-ios-sim-x86_64-1.17.3.tgz B3C253258292E71B7D136E4C0DD0E4598BDCDE98285958020BE86B7014C0E1BA +none aarch64-linux-android https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-android-arm64-1.18.1.tgz 8F20DCFF66D693BDE5D0C10927A3F435372F3F3BC92928C6A5475256DB7453F3 +none x86_64-linux-android https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-android-x64-1.18.1.tgz F82A7226CE1A13283E14053BA167A85A34246DD179728CC120942AC21409B5FB +none aarch64-apple-ios https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-ios-arm64-1.18.1.tgz 1BBB7389A38D70C929136B5868E9A9691FB46C4FC229DD252F0629B0487C1273 +none aarch64-apple-ios-sim https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-ios-sim-arm64-1.18.1.tgz 4CE643E05251CFBED0C17D36BAF626FE70220CB8443B610D1B568966032E27B6 +none x86_64-apple-ios https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-ios-sim-x86_64-1.18.1.tg CC298BF4C21BD308BA5E2983C0AEB44689E5080DEAA14F3E78C3FDAF154EFE1Cz +none aarch64-unknown-linux-gnu https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-linux-arm64-1.18.1.tgz 90DEA27F55554025FB15D05204FFD86115DA460AE47DA1FA3ED7925CE7E2F2B3 +none armv7-unknown-linux-gnueabihf https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-linux-armhf-1.18.1.tgz F20A91E6578EE49547FB56AF8193A51CB50E20791ED979EAA91AFC25B4C8D4FC +none x86_64-unknown-linux-gnu https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-linux-x64-1.18.1.tgz 64D5651853DC53338A46533D6EA9D0B3D21E46E6FB73C4896ECFC35BE5695D72 +cu12 x86_64-unknown-linux-gnu https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-linux-x64-gpu-1.18.1.tgz A40210A28A4D5624B114F5011348FDA58854C2DBE8319A3F93B43C11E37F059F +none aarch64-apple-darwin https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-osx-arm64-1.18.1.tgz DDFFA06D0D7C51283822A8783316E62EFB4F9B58AEA70E9625211EC5B9813FDF +none x86_64-apple-darwin https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-osx-x86_64-1.18.1.tgz BAAF71F6CFB5405F85112C5A6683C8F6094C86CC61C7CD7AF2E7A075D08AA5F0 +none x86_64-pc-windows-msvc https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-win-x64-1.18.1.tgz AB1A4FC9AC0D818303AE54755F9E04507E2C673B2AD00A6DA4EB57C364DF53EB +cu12 x86_64-pc-windows-msvc https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-win-x64-cuda-1.18.1.tgz B6C2145B147BE7B5FE2EA8D3BDCE76F59884D5CE816CB4C6A6F10FE6071C5B89 +directml x86_64-pc-windows-msvc https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-win-x64-dml-1.18.1.tgz 02657C8B0D860AC3C25C355A741A01329E38C8D057A388BBB0A773FF883D869B +none i686-pc-windows-msvc https://github.com/qryxip/onnxruntime-builder/releases/download/1.18.1/onnxruntime-win-x86-1.18.1.tgz A1365AB4B33629EC47D28F403C57EE4ACAE1F3D76360857FA2AB1EBA2990E201 diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index ee08e634..f7cb853e 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -10,7 +10,7 @@ #[doc(hidden)] pub mod internal; -pub const ORT_API_VERSION: u32 = 17; +pub const ORT_API_VERSION: u32 = 18; pub use std::ffi::{c_char, c_int, c_ulong, c_ulonglong, c_ushort, c_void}; @@ -298,7 +298,8 @@ pub struct OrtAllocator { #[doc = "< Free a block of memory previously allocated with OrtAllocator::Alloc"] pub Free: ::std::option::Option<_system!(unsafe fn(this_: *mut OrtAllocator, p: *mut ::std::os::raw::c_void))>, #[doc = "< Return a pointer to an ::OrtMemoryInfo that describes this allocator"] - pub Info: ::std::option::Option<_system!(unsafe fn(this_: *const OrtAllocator) -> *const OrtMemoryInfo)> + pub Info: ::std::option::Option<_system!(unsafe fn(this_: *const OrtAllocator) -> *const OrtMemoryInfo)>, + pub Reserve: ::std::option::Option<_system!(unsafe fn(this_: *const OrtAllocator, size: size_t) -> *mut ::std::os::raw::c_void)> } #[test] fn bindgen_test_layout_OrtAllocator() { @@ -522,6 +523,7 @@ pub struct OrtROCMProviderOptions { pub user_compute_stream: *mut ::std::os::raw::c_void, #[doc = " \\brief ROCM memory arena configuration parameters"] pub default_memory_arena_cfg: *mut OrtArenaCfg, + pub enable_hip_graph: ::std::os::raw::c_int, #[doc = " \\brief Enable TunableOp for using.\n Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.\n This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE."] pub tunable_op_enable: ::std::os::raw::c_int, #[doc = " \\brief Enable TunableOp for tuning.\n Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default.\n This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE."] @@ -821,9 +823,117 @@ fn bindgen_test_layout_OrtOpenVINOProviderOptions() { } #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct OrtTrainingApi { +pub struct OrtTrainingSession { + _unused: [u8; 0] +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct OrtCheckpointState { _unused: [u8; 0] } +#[repr(i32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum OrtPropertyType { + OrtIntProperty = 0, + OrtFloatProperty = 1, + OrtStringProperty = 2 +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct OrtTrainingApi { + pub LoadCheckpoint: + ::std::option::Option<_system!(unsafe fn(checkpoint_path: *const ortchar, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr)>, + pub SaveCheckpoint: ::std::option::Option< + _system!(unsafe fn(checkpoint_state: *mut OrtCheckpointState, checkpoint_path: *const ortchar, include_optimizer_state: bool) -> OrtStatusPtr) + >, + pub CreateTrainingSession: ::std::option::Option< + _system!( + unsafe fn( + env: *const OrtEnv, + options: *const OrtSessionOptions, + checkpoint_state: *mut OrtCheckpointState, + train_model_path: *const ortchar, + eval_model_path: *const ortchar, + optimizer_model_path: *const ortchar, + out: *mut *mut OrtTrainingSession + ) -> OrtStatusPtr + ) + >, + pub CreateTrainingSessionFromBuffer: ::std::option::Option< + _system!( + unsafe fn( + env: *const OrtEnv, + options: *const OrtSessionOptions, + checkpoint_state: *mut OrtCheckpointState, + train_model_data: *const (), + train_data_length: size_t, + eval_model_data: *const (), + eval_data_length: size_t, + optimizer_model_data: *const (), + optimizer_data_length: size_t, + out: *mut *mut OrtTrainingSession + ) -> OrtStatusPtr + ) + >, + pub TrainingSessionGetTrainingModelOutputCount: + ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, + pub TrainingSessionGetEvalModelOutputCount: ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, + pub TrainingSessionGetTrainingModelOutputName: ::std::option::Option< + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr) + >, + pub TrainingSessionGetEvalModelOutputName: ::std::option::Option< + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr) + >, + pub LazyResetGrad: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>, + pub TrainStep: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + run_options: *const OrtRunOptions, + inputs_len: size_t, + inputs: *const *const OrtValue, + outputs_len: size_t, + outputs: *mut *mut OrtValue + ) -> OrtStatusPtr + ) + >, + pub EvalStep: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + run_options: *const OrtRunOptions, + inputs_len: size_t, + inputs: *const *const OrtValue, + outputs_len: size_t, + outputs: *mut *mut OrtValue + ) -> OrtStatusPtr + ) + >, + pub SetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: f32) -> OrtStatusPtr)>, + pub GetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: *mut f32) -> OrtStatusPtr)>, + pub OptimizerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, run_options: *const OrtRunOptions) -> OrtStatusPtr)>, + pub RegisterLinearLRScheduler: ::std::option::Option< + _system!(unsafe fn(session: *mut OrtTrainingSession, warmup_step_count: i64, total_step_count: i64, initial_lr: f32) -> OrtStatusPtr) + >, + pub SchedulerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>, + pub GetParametersSize: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, out: *mut size_t, trainable_only: bool) -> OrtStatusPtr)>, + pub CopyParametersToBuffer: + ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>, + pub CopyBufferToParameters: + ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>, + pub ReleaseTrainingSession: ::std::option::Option<_system!(unsafe fn(input: *mut OrtTrainingSession))>, + pub ReleaseCheckpointState: ::std::option::Option<_system!(unsafe fn(input: *mut OrtCheckpointState))>, + pub ExportModelForInferencing: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + inference_model_path: *const ortchar, + graph_outputs_len: usize, + graph_output_names: *const *const c_char + ) -> OrtStatusPtr + ) + > +} #[doc = " \\brief The helper interface to get the right version of OrtApi\n\n Get a pointer to this structure through ::OrtGetApiBase"] #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -1831,6 +1941,39 @@ pub struct OrtApi { num_keys: size_t ) -> OrtStatusPtr ) + >, + pub SessionOptionsAppendExecutionProvider_VitisAI: ::std::option::Option< + _system!( + unsafe fn( + options: *mut OrtSessionOptions, + provider_options_keys: *const *const ::std::os::raw::c_char, + provider_options_values: *const *const ::std::os::raw::c_char, + num_keys: size_t + ) -> OrtStatusPtr + ) + >, + pub KernelContext_GetScratchBuffer: ::std::option::Option< + _system!( + unsafe fn( + context: *const OrtKernelContext, + mem_info: *const OrtMemoryInfo, + count_or_bytes: size_t, + out: *mut *mut ::std::os::raw::c_void + ) -> OrtStatusPtr + ) + >, + pub KernelInfoGetAllocator: + ::std::option::Option<_system!(unsafe fn(info: *const OrtKernelInfo, mem_type: OrtMemType, out: *mut *mut OrtAllocator) -> OrtStatusPtr)>, + pub AddExternalInitializersFromMemory: ::std::option::Option< + _system!( + unsafe fn( + options: *mut OrtSessionOptions, + external_initializer_file_names: *const *const ortchar, + external_initializer_file_buffer_array: *const *mut ::std::os::raw::c_char, + external_initializer_file_lengths: *const size_t, + num_external_initializer_files: size_t + ) -> OrtStatusPtr + ) > } #[test] @@ -3254,7 +3397,13 @@ pub struct OrtCustomOp { pub KernelComputeV2: ::std::option::Option<_system!(unsafe fn(op_kernel: *mut ::std::os::raw::c_void, context: *mut OrtKernelContext) -> OrtStatusPtr)>, pub InferOutputShapeFn: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp, arg1: *mut OrtShapeInferContext) -> OrtStatusPtr)>, pub GetStartVersion: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp) -> ::std::os::raw::c_int)>, - pub GetEndVersion: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp) -> ::std::os::raw::c_int)> + pub GetEndVersion: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp) -> ::std::os::raw::c_int)>, + pub GetMayInplace: + ::std::option::Option<_system!(unsafe fn(input_index: *mut *mut ::std::os::raw::c_int, output_index: *mut *mut ::std::os::raw::c_int) -> size_t)>, + pub ReleaseMayInplace: ::std::option::Option<_system!(unsafe fn(input_index: *mut ::std::os::raw::c_int, output_index: *mut *mut ::std::os::raw::c_int))>, + pub GetAliasMap: + ::std::option::Option<_system!(unsafe fn(input_index: *mut *mut ::std::os::raw::c_int, output_index: *mut *mut ::std::os::raw::c_int) -> size_t)>, + pub ReleaseAliasMap: ::std::option::Option<_system!(unsafe fn(input_index: *mut ::std::os::raw::c_int, output_index: *mut *mut ::std::os::raw::c_int))> } #[test] fn bindgen_test_layout_OrtCustomOp() { diff --git a/src/environment.rs b/src/environment.rs index 58d10681..eb56862a 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -15,7 +15,8 @@ use tracing::{debug, Level}; use crate::G_ORT_DYLIB_PATH; use crate::{ error::{Error, Result}, - extern_system_fn, ortsys, ExecutionProviderDispatch + execution_providers::ExecutionProviderDispatch, + extern_system_fn, ortsys }; struct EnvironmentSingleton { @@ -34,7 +35,7 @@ pub struct Environment { } impl Environment { - /// Loads the underlying [`ort_sys::OrtEnv`] pointer. + /// Returns the underlying [`ort_sys::OrtEnv`] pointer. pub fn ptr(&self) -> *mut ort_sys::OrtEnv { self.env_ptr.load(Ordering::Relaxed) } @@ -52,13 +53,14 @@ impl Drop for Environment { } } -/// Gets a reference to the global environment, creating one if an environment has been committed yet. +/// Gets a reference to the global environment, creating one if an environment has not been +/// [`commit`](EnvironmentBuilder::commit)ted yet. pub fn get_environment() -> Result<&'static Arc> { if let Some(c) = unsafe { &*G_ENV.cell.get() } { Ok(c) } else { debug!("Environment not yet initialized, creating a new one"); - EnvironmentBuilder::default().commit()?; + EnvironmentBuilder::new().commit()?; Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() }) } @@ -72,26 +74,26 @@ pub struct EnvironmentGlobalThreadPoolOptions { pub intra_op_thread_affinity: Option } -/// Struct used to build an `Environment`. +/// Struct used to build an [`Environment`]; see [`crate::init`]. pub struct EnvironmentBuilder { name: String, + telemetry: bool, execution_providers: Vec, global_thread_pool_options: Option } -impl Default for EnvironmentBuilder { - fn default() -> Self { +impl EnvironmentBuilder { + pub(crate) fn new() -> Self { EnvironmentBuilder { name: "default".to_string(), + telemetry: true, execution_providers: vec![], global_thread_pool_options: None } } -} -impl EnvironmentBuilder { /// Configure the environment with a given name for logging purposes. - #[must_use] + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_name(mut self, name: S) -> Self where S: Into @@ -100,6 +102,22 @@ impl EnvironmentBuilder { self } + /// Enable or disable sending telemetry events to Microsoft. + /// + /// Typically, only Windows builds of ONNX Runtime provided by Microsoft will have telemetry enabled. + /// Pre-built binaries provided by pyke, or binaries compiled from source, won't have telemetry enabled. + /// + /// The exact kind of telemetry data sent can be found [here](https://github.com/microsoft/onnxruntime/blob/v1.18.1/onnxruntime/core/platform/windows/telemetry.cc). + /// Currently, this includes (but is not limited to): ONNX graph version, model producer name & version, whether or + /// not FP16 is used, operator domains & versions, model graph name & custom metadata, execution provider names, + /// error messages, and the total number & time of session inference runs. The ONNX Runtime team uses this data to + /// better understand how customers use ONNX Runtime and where performance can be improved. + #[must_use = "commit() must be called in order for the environment to take effect"] + pub fn with_telemetry(mut self, enable: bool) -> Self { + self.telemetry = enable; + self + } + /// Sets a list of execution providers which all sessions created in this environment will register. /// /// If a session is created in this environment with [`crate::SessionBuilder::with_execution_providers`], those EPs @@ -108,14 +126,14 @@ impl EnvironmentBuilder { /// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built /// with support for the corresponding execution provider. Execution providers that do not have their corresponding /// feature enabled will emit a warning. - #[must_use] + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Self { self.execution_providers = execution_providers.as_ref().to_vec(); self } /// Enables the global thread pool for this environment. - #[must_use] + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> Self { self.global_thread_pool_options = Some(options); self @@ -156,14 +174,17 @@ impl EnvironmentBuilder { ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment]; } - ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( + ortsys![ + unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( logging_function, logger_param, ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), thread_options, &mut env_ptr - ) -> Error::CreateEnvironment; nonNull(env_ptr)]; + ) -> Error::CreateEnvironment; + nonNull(env_ptr) + ]; ortsys![unsafe ReleaseThreadingOptions(thread_options)]; (env_ptr, true) } else { @@ -172,17 +193,26 @@ impl EnvironmentBuilder { // FIXME: What should go here? let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!()); - ortsys![unsafe CreateEnvWithCustomLogger( + ortsys![ + unsafe CreateEnvWithCustomLogger( logging_function, logger_param, ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), &mut env_ptr - ) -> Error::CreateEnvironment; nonNull(env_ptr)]; + ) -> Error::CreateEnvironment; + nonNull(env_ptr) + ]; (env_ptr, false) }; debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created"); + if self.telemetry { + ortsys![unsafe EnableTelemetryEvents(env_ptr) -> Error::CreateEnvironment]; + } else { + ortsys![unsafe DisableTelemetryEvents(env_ptr) -> Error::CreateEnvironment]; + } + unsafe { *G_ENV.cell.get() = Some(Arc::new(Environment { execution_providers: self.execution_providers, @@ -197,15 +227,25 @@ impl EnvironmentBuilder { /// Creates an ONNX Runtime environment. /// +/// ``` +/// # use ort::CUDAExecutionProvider; +/// # fn main() -> ort::Result<()> { +/// ort::init() +/// .with_execution_providers([CUDAExecutionProvider::default().build()]) +/// .commit()?; +/// # Ok(()) +/// # } +/// ``` +/// /// # Notes /// - It is not required to call this function. If this is not called by the time any other `ort` APIs are used, a /// default environment will be created. -/// - Library crates that use `ort` shouldn't create their own environment. Let downstream applications create it. +/// - **Library crates that use `ort` shouldn't create their own environment.** Let downstream applications create it. /// - In order for environment settings to apply, this must be called **before** you use other APIs like /// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function. -#[must_use] +#[must_use = "commit() must be called in order for the environment to take effect"] pub fn init() -> EnvironmentBuilder { - EnvironmentBuilder::default() + EnvironmentBuilder::new() } /// Creates an ONNX Runtime environment, dynamically loading ONNX Runtime from the library file (`.dll`/`.so`/`.dylib`) @@ -213,56 +253,43 @@ pub fn init() -> EnvironmentBuilder { /// /// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded. /// +/// ```no_run +/// # use ort::CUDAExecutionProvider; +/// # fn main() -> ort::Result<()> { +/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib"); +/// ort::init_from(lib_path.join("onnxruntime.dll")) +/// .with_execution_providers([CUDAExecutionProvider::default().build()]) +/// .commit()?; +/// # Ok(()) +/// # } +/// ``` +/// /// # Notes /// - In order for environment settings to apply, this must be called **before** you use other APIs like /// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function. #[cfg(feature = "load-dynamic")] #[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))] -#[must_use] +#[must_use = "commit() must be called in order for the environment to take effect"] pub fn init_from(path: impl ToString) -> EnvironmentBuilder { let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string())); - EnvironmentBuilder::default() -} - -/// ONNX's logger sends the code location where the log occurred, which will be parsed into this struct. -#[derive(Debug)] -struct CodeLocation<'a> { - file: &'a str, - line: &'a str, - function: &'a str -} - -impl<'a> From<&'a str> for CodeLocation<'a> { - fn from(code_location: &'a str) -> Self { - let mut splitter = code_location.split(' '); - let file_and_line = splitter.next().unwrap_or(":"); - let function = splitter.next().unwrap_or(""); - let mut file_and_line_splitter = file_and_line.split(':'); - let file = file_and_line_splitter.next().unwrap_or(""); - let line = file_and_line_splitter.next().unwrap_or(""); - - CodeLocation { file, line, function } - } + EnvironmentBuilder::new() } extern_system_fn! { /// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate. - pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, category: *const c_char, _: *const c_char, code_location: *const c_char, message: *const c_char) { - assert_ne!(category, ptr::null()); - let category = unsafe { CStr::from_ptr(category) }.to_str().unwrap_or(""); + pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, _: *const c_char, id: *const c_char, code_location: *const c_char, message: *const c_char) { assert_ne!(code_location, ptr::null()); - let code_location_str = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or(""); + let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or(""); assert_ne!(message, ptr::null()); let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or(""); + assert_ne!(id, ptr::null()); + let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or(""); - let code_location = CodeLocation::from(code_location_str); let span = tracing::span!( Level::TRACE, "ort", - category = category, - file = code_location.file, - line = code_location.line, - function = code_location.function + id = id, + location = code_location ); match severity { @@ -318,7 +345,7 @@ mod tests { assert!(!is_env_initialized()); assert_eq!(env_ptr(), None); - EnvironmentBuilder::default().with_name("env_is_initialized").commit()?; + EnvironmentBuilder::new().with_name("env_is_initialized").commit()?; assert!(is_env_initialized()); assert_ne!(env_ptr(), None); Ok(()) diff --git a/src/error.rs b/src/error.rs index 07375207..2e84580a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,7 +4,7 @@ use std::{convert::Infallible, ffi::CString, io, path::PathBuf, ptr, string}; use thiserror::Error; -use super::{char_p_to_string, ortsys, tensor::TensorElementType, ValueType}; +use crate::{char_p_to_string, ortsys, tensor::TensorElementType, value::ValueType}; /// Type alias for the Result type returned by ORT functions. pub type Result = std::result::Result; @@ -121,9 +121,6 @@ pub enum Error { /// Error occurred when filling a tensor with string data #[error("Failed to fill string tensor: {0}")] FillStringTensor(ErrorInternal), - /// Error occurred when checking if a value is a tensor - #[error("Failed to check if value is a tensor: {0}")] - FailedTensorCheck(ErrorInternal), /// Error occurred when getting tensor type and shape #[error("Failed to get tensor type and shape: {0}")] GetTensorTypeAndShape(ErrorInternal), @@ -159,12 +156,6 @@ pub enum Error { /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models). #[error("Failed to download ONNX model: {0}")] DownloadError(#[from] FetchModelError), - /// Type of input data and the ONNX model do not match. - #[error("Data types do not match: expected {model:?}, got {input:?}")] - NonMatchingDataTypes { input: TensorElementType, model: TensorElementType }, - /// Dimensions of input data and the ONNX model do not match. - #[error("Dimensions do not match: {0:?}")] - NonMatchingDimensions(NonMatchingDimensionsError), /// File does not exist #[error("File `{filename:?}` does not exist")] FileDoesNotExist { @@ -180,19 +171,12 @@ pub enum Error { /// Attempt to build a Rust `CString` when the original string contains a null character. #[error("Failed to build CString when original contains null: {0}")] FfiStringNull(#[from] std::ffi::NulError), - /// Attempt to build a `WideCString` when the original string contains a null character. - #[cfg(all(windows, feature = "profiling"))] - #[error("Failed to build CString when original contains null: {0}")] - WideFfiStringNull(#[from] widestring::error::ContainsNul), #[error("`{0}` should be a null pointer")] /// ORT pointer should have been null PointerShouldBeNull(&'static str), /// ORT pointer should not have been null #[error("`{0}` should not be a null pointer")] PointerShouldNotBeNull(&'static str), - /// The runtime type was undefined. - #[error("Undefined tensor element type")] - UndefinedTensorElementType, /// Could not retrieve model metadata. #[error("Failed to retrieve model metadata: {0}")] GetModelMetadata(ErrorInternal), @@ -212,8 +196,8 @@ pub enum Error { ExecutionProviderNotRegistered(&'static str), #[error("Expected tensor to be on CPU in order to get data, but had allocation device `{0}`.")] TensorNotOnCpu(&'static str), - #[error("String tensors require the session's allocator to be provided through `Value::from_array`.")] - StringTensorRequiresAllocator, + #[error("Cannot extract scalar value from a {0}-dimensional tensor")] + TensorNot0Dimensional(usize), #[error("Failed to create memory info: {0}")] CreateMemoryInfo(ErrorInternal), #[error("Could not get allocation device from `MemoryInfo`: {0}")] @@ -226,10 +210,10 @@ pub enum Error { BindInput(ErrorInternal), #[error("Error when binding output: {0}")] BindOutput(ErrorInternal), - #[error("Failed to clear IO binding: {0}")] - ClearBinding(ErrorInternal), #[error("Error when retrieving session outputs from `IoBinding`: {0}")] GetBoundOutputs(ErrorInternal), + #[error("Cannot use `extract_tensor` on a value that is {0:?}")] + NotTensor(ValueType), #[error("Cannot use `extract_sequence` on a value that is {0:?}")] NotSequence(ValueType), #[error("Cannot use `extract_map` on a value that is {0:?}")] @@ -256,6 +240,16 @@ pub enum Error { GetOperatorInput(ErrorInternal), #[error("Failed to get operator output: {0}")] GetOperatorOutput(ErrorInternal), + #[error("Failed to get operator node name: {0}")] + GetOperatorNodeName(ErrorInternal), + #[error("Failed to retrieve GPU compute stream from kernel context: {0}")] + GetKernelGPUComputeStream(ErrorInternal), + #[error("Failed to retrieve EP resource from kernel context: {0}")] + GetKernelResource(ErrorInternal), + #[error("Failed to create allocator in kernel context: {0}")] + GetKernelAllocator(ErrorInternal), + #[error("Failed to allocate temporary buffer in kernel context: {0}")] + GetKernelBuffer(ErrorInternal), #[error("{0}")] CustomError(#[from] Box), #[error("String tensors cannot be borrowed as mutable")] @@ -267,40 +261,25 @@ pub enum Error { #[error("Could't get `AllocatorType` from memory info: {0}")] GetAllocatorType(ErrorInternal), #[error("Could't get device ID from memory info: {0}")] - GetDeviceId(ErrorInternal) + GetDeviceId(ErrorInternal), + #[error("Training API is not enabled in this build of ONNX Runtime.")] + TrainingNotEnabled } -impl From for Error { - fn from(_: Infallible) -> Self { - Error::Infallible +impl Error { + /// Wrap a custom, user-provided error in an [`ort::Error`](Error). The resulting error will be the + /// [`Error::CustomError`] variant. + /// + /// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort` + /// related operation fails. + pub fn wrap(err: T) -> Self { + Error::CustomError(Box::new(err) as Box) } } -/// Error used when the input dimensions defined in the model and passed from an inference call do not match. -#[non_exhaustive] -#[derive(Error, Debug)] -pub enum NonMatchingDimensionsError { - /// Number of inputs from model does not match the number of inputs from inference call. - #[error( - "Non-matching number of inputs: {inference_input_count:?} provided vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})" - )] - InputsCount { - /// Number of input dimensions used by inference call - inference_input_count: usize, - /// Number of input dimensions defined in model - model_input_count: usize, - /// Input dimensions used by inference call - inference_input: Vec>, - /// Input dimensions defined in model - model_input: Vec>> - }, - /// Inputs length from model does not match the expected input from inference call - #[error("Different input lengths; expected input: {model_input:?}, received input: {inference_input:?}")] - InputsLength { - /// Input dimensions used by inference call - inference_input: Vec>, - /// Input dimensions defined in model - model_input: Vec>> +impl From for Error { + fn from(_: Infallible) -> Self { + Error::Infallible } } @@ -311,8 +290,8 @@ pub enum ErrorInternal { /// Details about the error. #[error("{0}")] Msg(String), - /// Converting the ONNX error message to UTF-8 failed. - #[error("an error occurred, but ort failed to convert the error message to UTF-8")] + /// Converting an FFI string to UTF-8 failed. + #[error("failed to convert string to UTF-8: {0}")] IntoStringError(std::ffi::IntoStringError) } diff --git a/src/execution_providers/acl.rs b/src/execution_providers/acl.rs index a8e3bdb3..ddce6299 100644 --- a/src/execution_providers/acl.rs +++ b/src/execution_providers/acl.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "acl"))] extern "C" { @@ -26,7 +29,7 @@ impl ACLExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: ACLExecutionProvider) -> Self { - ExecutionProviderDispatch::ACL(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/armnn.rs b/src/execution_providers/armnn.rs index 53a38795..c428feb8 100644 --- a/src/execution_providers/armnn.rs +++ b/src/execution_providers/armnn.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "armnn"))] extern "C" { @@ -26,7 +29,7 @@ impl ArmNNExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: ArmNNExecutionProvider) -> Self { - ExecutionProviderDispatch::ArmNN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/cann.rs b/src/execution_providers/cann.rs index c43e8e06..91895681 100644 --- a/src/execution_providers/cann.rs +++ b/src/execution_providers/cann.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone, PartialEq, Eq)] #[non_exhaustive] @@ -109,7 +112,7 @@ impl CANNExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: CANNExecutionProvider) -> Self { - ExecutionProviderDispatch::CANN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/coreml.rs b/src/execution_providers/coreml.rs index 94971e8a..2fa4aa77 100644 --- a/src/execution_providers/coreml.rs +++ b/src/execution_providers/coreml.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "coreml"))] extern "C" { @@ -46,7 +49,7 @@ impl CoreMLExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: CoreMLExecutionProvider) -> Self { - ExecutionProviderDispatch::CoreML(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/cpu.rs b/src/execution_providers/cpu.rs index 2f98095d..06e031b8 100644 --- a/src/execution_providers/cpu.rs +++ b/src/execution_providers/cpu.rs @@ -1,5 +1,9 @@ -use super::ExecutionProvider; -use crate::{error::status_to_result, ortsys, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{status_to_result, Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + ortsys, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct CPUExecutionProvider { @@ -21,7 +25,7 @@ impl CPUExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: CPUExecutionProvider) -> Self { - ExecutionProviderDispatch::CPU(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/cuda.rs b/src/execution_providers/cuda.rs index 6881068b..67cad84c 100644 --- a/src/execution_providers/cuda.rs +++ b/src/execution_providers/cuda.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; /// The type of search done for cuDNN convolution algorithms. #[derive(Debug, Clone)] @@ -161,7 +164,7 @@ impl CUDAExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: CUDAExecutionProvider) -> Self { - ExecutionProviderDispatch::CUDA(value) + ExecutionProviderDispatch::new(value) } } @@ -171,7 +174,7 @@ impl ExecutionProvider for CUDAExecutionProvider { } fn supported_by_platform(&self) -> bool { - cfg!(any(all(target_os = "linux", any(target_os = "aarch64", target_os = "x86_64")), all(target_os = "windows", target_arch = "x86_64"))) + cfg!(any(all(target_os = "linux", any(target_arch = "aarch64", target_arch = "x86_64")), all(target_os = "windows", target_arch = "x86_64"))) } #[allow(unused, unreachable_code)] diff --git a/src/execution_providers/directml.rs b/src/execution_providers/directml.rs index 71802553..085e68f0 100644 --- a/src/execution_providers/directml.rs +++ b/src/execution_providers/directml.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "directml"))] extern "C" { @@ -26,7 +29,7 @@ impl DirectMLExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: DirectMLExecutionProvider) -> Self { - ExecutionProviderDispatch::DirectML(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/migraphx.rs b/src/execution_providers/migraphx.rs new file mode 100644 index 00000000..d3cc62aa --- /dev/null +++ b/src/execution_providers/migraphx.rs @@ -0,0 +1,86 @@ +use std::ffi::CString; + +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; + +#[derive(Debug, Default, Clone)] +pub struct MIGraphXExecutionProvider { + device_id: i32, + enable_fp16: bool, + enable_int8: bool, + use_native_calibration_table: bool, + int8_calibration_table_name: Option +} + +impl MIGraphXExecutionProvider { + #[must_use] + pub fn with_device_id(mut self, device_id: i32) -> Self { + self.device_id = device_id; + self + } + + #[must_use] + pub fn with_fp16(mut self) -> Self { + self.enable_fp16 = true; + self + } + + #[must_use] + pub fn with_int8(mut self) -> Self { + self.enable_int8 = true; + self + } + + #[must_use] + pub fn with_native_calibration_table(mut self, table_name: Option>) -> Self { + self.use_native_calibration_table = true; + self.int8_calibration_table_name = table_name.map(|c| CString::new(c.as_ref()).expect("invalid string")); + self + } + + #[must_use] + pub fn build(self) -> ExecutionProviderDispatch { + self.into() + } +} + +impl From for ExecutionProviderDispatch { + fn from(value: MIGraphXExecutionProvider) -> Self { + ExecutionProviderDispatch::new(value) + } +} + +impl ExecutionProvider for MIGraphXExecutionProvider { + fn as_str(&self) -> &'static str { + "MIGraphXExecutionProvider" + } + + fn supported_by_platform(&self) -> bool { + cfg!(any(all(target_os = "linux", target_arch = "x86_64"), all(target_os = "windows", target_arch = "x86_64"))) + } + + #[allow(unused, unreachable_code)] + fn register(&self, session_builder: &SessionBuilder) -> Result<()> { + #[cfg(any(feature = "load-dynamic", feature = "migraphx"))] + { + let options = ort_sys::OrtMIGraphXProviderOptions { + device_id: self.device_id, + migraphx_fp16_enable: self.enable_fp16.into(), + migraphx_int8_enable: self.enable_int8.into(), + migraphx_use_native_calibration_table: self.use_native_calibration_table.into(), + migraphx_int8_calibration_table_name: self + .int8_calibration_table_name + .as_ref() + .map(|c| c.as_ptr()) + .unwrap_or_else(std::ptr::null) + }; + crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options) -> Error::ExecutionProvider]; + return Ok(()); + } + + Err(Error::ExecutionProviderNotRegistered(self.as_str())) + } +} diff --git a/src/execution_providers/mod.rs b/src/execution_providers/mod.rs index 24ec6acf..52e452da 100644 --- a/src/execution_providers/mod.rs +++ b/src/execution_providers/mod.rs @@ -1,6 +1,11 @@ -use std::{fmt::Debug, os::raw::c_char}; +use std::{fmt::Debug, os::raw::c_char, sync::Arc}; -use crate::{char_p_to_string, ortsys, Error, Result, SessionBuilder}; +use crate::{ + char_p_to_string, + error::{Error, Result}, + ortsys, + session::SessionBuilder +}; mod cpu; pub use self::cpu::CPUExecutionProvider; @@ -32,6 +37,8 @@ mod xnnpack; pub use self::xnnpack::XNNPACKExecutionProvider; mod armnn; pub use self::armnn::ArmNNExecutionProvider; +mod migraphx; +pub use self::migraphx::MIGraphXExecutionProvider; /// ONNX Runtime works with different hardware acceleration libraries through its extensible **Execution Providers** /// (EP) framework to optimally execute the ONNX models on the hardware platform. This interface enables flexibility for @@ -39,7 +46,7 @@ pub use self::armnn::ArmNNExecutionProvider; /// optimize the execution by taking advantage of the compute capabilities of the platform. /// /// ![](https://www.onnxruntime.ai/images/ONNX_Runtime_EP1.png) -pub trait ExecutionProvider { +pub trait ExecutionProvider: Sync + Send { /// Returns the identifier of this execution provider used internally by ONNX Runtime. /// /// This is the same as what's used in ONNX Runtime's Python API to register this execution provider, i.e. @@ -60,16 +67,17 @@ pub trait ExecutionProvider { true } - /// Returns `Ok(true)` if ONNX Runtime was compiled with support for this execution provider, and `Ok(false)` + /// Returns `Ok(true)` if ONNX Runtime was *compiled with support* for this execution provider, and `Ok(false)` /// otherwise. /// /// An `Err` may be returned if a serious internal error occurs, in which case your application should probably /// just abort. /// - /// Note that this does not always mean the execution provider is *usable* for a specific model. A model may use - /// operators not supported by an execution provider, or the EP may encounter an error while attempting to load a - /// dynamic library during registration. In most cases (i.e. showing the user an error message if CUDA could not be - /// enabled), you'll instead want to detect and handle errors from [`ExecutionProvider::register`]. + /// **Note that this does not always mean the execution provider is *usable* for a specific session.** A model may + /// use operators not supported by an execution provider, or the EP may encounter an error while attempting to load + /// dependencies during session creation. In most cases (i.e. showing the user an error message if CUDA could not be + /// enabled), you'll instead want to manually register this EP via [`ExecutionProvider::register`] and detect + /// and handle any errors returned by that function. fn is_available(&self) -> Result { let mut providers: *mut *mut c_char = std::ptr::null_mut(); let mut num_providers = 0; @@ -110,56 +118,50 @@ pub enum ArenaExtendStrategy { SameAsRequested } -/// Execution provider container. See [the ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/) for more -/// info on execution providers. Execution providers are actually registered via the functions [`crate::SessionBuilder`] -/// (per-session) or [`EnvironmentBuilder`](crate::environment::EnvironmentBuilder) (default for all sessions in an -/// environment). -#[derive(Debug, Clone)] +/// Dynamic execution provider container, used to provide a list of multiple types of execution providers when +/// configuring execution providers for a [`SessionBuilder`](crate::SessionBuilder) or +/// [`EnvironmentBuilder`](crate::environment::EnvironmentBuilder). +/// +/// See [`ExecutionProvider`] for more info on execution providers. +#[derive(Clone)] #[allow(missing_docs)] #[non_exhaustive] -pub enum ExecutionProviderDispatch { - CPU(CPUExecutionProvider), - CUDA(CUDAExecutionProvider), - TensorRT(TensorRTExecutionProvider), - OpenVINO(OpenVINOExecutionProvider), - ACL(ACLExecutionProvider), - OneDNN(OneDNNExecutionProvider), - CoreML(CoreMLExecutionProvider), - DirectML(DirectMLExecutionProvider), - ROCm(ROCmExecutionProvider), - NNAPI(NNAPIExecutionProvider), - QNN(QNNExecutionProvider), - TVM(TVMExecutionProvider), - CANN(CANNExecutionProvider), - XNNPACK(XNNPACKExecutionProvider), - ArmNN(ArmNNExecutionProvider) +pub struct ExecutionProviderDispatch { + pub(crate) inner: Arc, + error_on_failure: bool } -macro_rules! impl_dispatch { - ($($variant:ident),*) => { - impl ExecutionProvider for ExecutionProviderDispatch { - fn as_str(&self) -> &'static str { - match self { - $(Self::$variant(inner) => inner.as_str(),)* - } - } +impl ExecutionProviderDispatch { + pub(crate) fn new(ep: E) -> Self { + ExecutionProviderDispatch { + inner: Arc::new(ep) as Arc, + error_on_failure: false + } + } - fn is_available(&self) -> $crate::Result { - match self { - $(Self::$variant(inner) => inner.is_available(),)* - } - } + /// Configures this execution provider to silently log an error if registration of the EP fails. + /// This is the default behavior; it can be overridden with [`ExecutionProviderDispatch::error_on_failure`]. + pub fn fail_silently(mut self) -> Self { + self.error_on_failure = false; + self + } - fn register(&self, session_builder: &$crate::SessionBuilder) -> $crate::Result<()> { - match self { - $(Self::$variant(inner) => inner.register(session_builder),)* - } - } - } - }; + /// Configures this execution provider to return an error upon EP registration if registration of this EP fails. + /// The default behavior is to silently fail and fall back to the next execution provider, or the CPU provider if no + /// registrations succeed. + pub fn error_on_failure(mut self) -> Self { + self.error_on_failure = true; + self + } } -impl_dispatch!(CPU, CUDA, TensorRT, ACL, OneDNN, OpenVINO, CoreML, CANN, ROCm, DirectML, TVM, NNAPI, QNN, XNNPACK, ArmNN); +impl Debug for ExecutionProviderDispatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(self.inner.as_str()) + .field("error_on_failure", &self.error_on_failure) + .finish() + } +} #[allow(unused)] macro_rules! map_keys { @@ -207,26 +209,31 @@ macro_rules! get_ep_register { pub(crate) use get_ep_register; #[tracing::instrument(skip_all)] -pub(crate) fn apply_execution_providers(session_builder: &SessionBuilder, execution_providers: impl Iterator) { +pub(crate) fn apply_execution_providers(session_builder: &SessionBuilder, execution_providers: impl Iterator) -> Result<()> { let execution_providers: Vec<_> = execution_providers.collect(); let mut fallback_to_cpu = !execution_providers.is_empty(); for ex in execution_providers { - if let Err(e) = ex.register(session_builder) { + if let Err(e) = ex.inner.register(session_builder) { + if ex.error_on_failure { + return Err(e); + } + if let &Error::ExecutionProviderNotRegistered(ep_name) = &e { - if ex.supported_by_platform() { + if ex.inner.supported_by_platform() { tracing::warn!("{e}"); } else { - tracing::debug!("{e} (additionally, `{ep_name}` is not supported on this platform)"); + tracing::debug!("{e} (note: additionally, `{ep_name}` is not supported on this platform)"); } } else { - tracing::warn!("An error occurred when attempting to register `{}`: {e}", ex.as_str()); + tracing::error!("An error occurred when attempting to register `{}`: {e}", ex.inner.as_str()); } } else { - tracing::info!("Successfully registered `{}`", ex.as_str()); + tracing::info!("Successfully registered `{}`", ex.inner.as_str()); fallback_to_cpu = false; } } if fallback_to_cpu { tracing::warn!("No execution providers registered successfully. Falling back to CPU."); } + Ok(()) } diff --git a/src/execution_providers/nnapi.rs b/src/execution_providers/nnapi.rs index 472db339..68d275af 100644 --- a/src/execution_providers/nnapi.rs +++ b/src/execution_providers/nnapi.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "nnapi"))] extern "C" { @@ -59,7 +62,7 @@ impl NNAPIExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: NNAPIExecutionProvider) -> Self { - ExecutionProviderDispatch::NNAPI(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/onednn.rs b/src/execution_providers/onednn.rs index 04166757..45dec270 100644 --- a/src/execution_providers/onednn.rs +++ b/src/execution_providers/onednn.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "onednn"))] extern "C" { @@ -29,7 +32,7 @@ impl OneDNNExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: OneDNNExecutionProvider) -> Self { - ExecutionProviderDispatch::OneDNN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/openvino.rs b/src/execution_providers/openvino.rs index fb8f932b..61924c53 100644 --- a/src/execution_providers/openvino.rs +++ b/src/execution_providers/openvino.rs @@ -1,7 +1,10 @@ use std::os::raw::c_void; -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone)] pub struct OpenVINOExecutionProvider { @@ -103,7 +106,7 @@ impl OpenVINOExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: OpenVINOExecutionProvider) -> Self { - ExecutionProviderDispatch::OpenVINO(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/qnn.rs b/src/execution_providers/qnn.rs index 6262aac3..54ee71b4 100644 --- a/src/execution_providers/qnn.rs +++ b/src/execution_providers/qnn.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone)] pub enum QNNExecutionProviderPerformanceMode { @@ -110,7 +113,7 @@ impl QNNExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: QNNExecutionProvider) -> Self { - ExecutionProviderDispatch::QNN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/rocm.rs b/src/execution_providers/rocm.rs index 50e7de54..3c3553be 100644 --- a/src/execution_providers/rocm.rs +++ b/src/execution_providers/rocm.rs @@ -1,7 +1,10 @@ use std::os::raw::c_void; -use super::ExecutionProvider; -use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone)] pub struct ROCmExecutionProvider { @@ -12,6 +15,7 @@ pub struct ROCmExecutionProvider { do_copy_in_default_stream: bool, user_compute_stream: Option<*mut c_void>, default_memory_arena_cfg: Option<*mut ort_sys::OrtArenaCfg>, + enable_hip_graph: bool, tunable_op_enable: bool, tunable_op_tuning_enable: bool, tunable_op_max_tuning_duration_ms: i32 @@ -30,6 +34,7 @@ impl Default for ROCmExecutionProvider { do_copy_in_default_stream: true, user_compute_stream: None, default_memory_arena_cfg: None, + enable_hip_graph: false, tunable_op_enable: false, tunable_op_tuning_enable: false, tunable_op_max_tuning_duration_ms: 0 @@ -80,6 +85,12 @@ impl ROCmExecutionProvider { self } + #[must_use] + pub fn with_hip_graph(mut self, enable: bool) -> Self { + self.enable_hip_graph = enable; + self + } + #[must_use] pub fn with_tunable_op(mut self, enable: bool) -> Self { self.tunable_op_enable = enable; @@ -106,7 +117,7 @@ impl ROCmExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: ROCmExecutionProvider) -> Self { - ExecutionProviderDispatch::ROCm(value) + ExecutionProviderDispatch::new(value) } } @@ -135,6 +146,7 @@ impl ExecutionProvider for ROCmExecutionProvider { has_user_compute_stream: self.user_compute_stream.is_some().into(), user_compute_stream: self.user_compute_stream.unwrap_or_else(std::ptr::null_mut), default_memory_arena_cfg: self.default_memory_arena_cfg.unwrap_or_else(std::ptr::null_mut), + enable_hip_graph: self.enable_hip_graph.into(), tunable_op_enable: self.tunable_op_enable.into(), tunable_op_tuning_enable: self.tunable_op_tuning_enable.into(), tunable_op_max_tuning_duration_ms: self.tunable_op_max_tuning_duration_ms diff --git a/src/execution_providers/tensorrt.rs b/src/execution_providers/tensorrt.rs index e1df3ad0..599eaf09 100644 --- a/src/execution_providers/tensorrt.rs +++ b/src/execution_providers/tensorrt.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct TensorRTExecutionProvider { @@ -210,7 +213,7 @@ impl TensorRTExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: TensorRTExecutionProvider) -> Self { - ExecutionProviderDispatch::TensorRT(value) + ExecutionProviderDispatch::new(value) } } @@ -220,13 +223,18 @@ impl ExecutionProvider for TensorRTExecutionProvider { } fn supported_by_platform(&self) -> bool { - cfg!(any(all(target_os = "linux", any(target_os = "aarch64", target_os = "x86_64")), all(target_os = "windows", target_arch = "x86_64"))) + cfg!(any(all(target_os = "linux", any(target_arch = "aarch64", target_arch = "x86_64")), all(target_os = "windows", target_arch = "x86_64"))) } #[allow(unused, unreachable_code)] fn register(&self, session_builder: &SessionBuilder) -> Result<()> { #[cfg(any(feature = "load-dynamic", feature = "tensorrt"))] { + // The TensorRT execution provider specifically is pretty picky about requiring an environment to be initialized by the + // time we register it. This isn't always the case in `ort`, so if we get to this point, let's make sure we have an + // environment initialized. + let _ = crate::get_environment(); + let mut trt_options: *mut ort_sys::OrtTensorRTProviderOptionsV2 = std::ptr::null_mut(); crate::error::status_to_result(crate::ortsys![unsafe CreateTensorRTProviderOptions(&mut trt_options)]).map_err(Error::ExecutionProvider)?; let (key_ptrs, value_ptrs, len, keys, values) = super::map_keys! { diff --git a/src/execution_providers/tvm.rs b/src/execution_providers/tvm.rs index a054a704..6e43601f 100644 --- a/src/execution_providers/tvm.rs +++ b/src/execution_providers/tvm.rs @@ -1,9 +1,12 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "tvm"))] extern "C" { - fn OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr; + fn OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut ort_sys::OrtSessionOptions, opt_str: *const std::os::raw::c_char) -> ort_sys::OrtStatusPtr; } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -54,7 +57,7 @@ impl TVMExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: TVMExecutionProvider) -> Self { - ExecutionProviderDispatch::TVM(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/xnnpack.rs b/src/execution_providers/xnnpack.rs index b344cc3b..bd3763e0 100644 --- a/src/execution_providers/xnnpack.rs +++ b/src/execution_providers/xnnpack.rs @@ -1,7 +1,10 @@ use std::num::NonZeroUsize; -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct XNNPACKExecutionProvider { @@ -23,7 +26,7 @@ impl XNNPACKExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: XNNPACKExecutionProvider) -> Self { - ExecutionProviderDispatch::XNNPACK(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/io_binding.rs b/src/io_binding.rs index a94f2913..e3b9b76e 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -1,32 +1,97 @@ use std::{ + collections::HashMap, ffi::CString, fmt::Debug, + marker::PhantomData, ptr::{self, NonNull}, sync::Arc }; use crate::{ + error::{Error, Result}, memory::MemoryInfo, ortsys, - session::{output::SessionOutputs, RunOptions}, - value::{Value, ValueRefMut}, - Error, Result, Session, ValueTypeMarker + session::{output::SessionOutputs, NoSelectedOutputs, RunOptions, Session}, + value::{DynValue, Value, ValueInner, ValueTypeMarker} }; /// Enables binding of session inputs and/or outputs to pre-allocated memory. /// -/// Note that this arrangement is designed to minimize data copies, and to that effect, your memory allocations must -/// match what is expected by the model, whether you run on CPU or GPU. Data will still be copied if the -/// pre-allocated memory location does not match the one expected by the model. However, copies with `IoBinding`s are -/// only done once, at the time of the binding, not at run time. This means, that if your input data required a copy, -/// your further input modifications would not be seen by ONNX Runtime unless you rebind it, even if it is the same -/// buffer. If your scenario requires that the data is copied, `IoBinding` may not be the best match for your use case. -/// The fact that data copy is not made during runtime may also have performance implications. +/// [`IoBinding`] minimizes copies between a device (like a GPU) and the host (CPU) by allowing the user to bind a +/// certain input/output to a pre-allocated value on a specific device. +/// +/// [`IoBinding`] is most suitable for: +/// - An ensemble of models in which the output from one model is the input of another and does not need to pass through +/// the CPU to perform additional processing. +/// - Situations where the output should stay on a device (e.g. to perform additional processing with CUDA). +/// - Diffusion models, for instance, that accept an unchanging embedding for conditioning. +/// +/// [`IoBinding`] will not provide any meaningful benefit for: +/// - Models where every input changes with each invocation, such as a causal language model or object recognition +/// model. +/// - Pipelines that go straight from CPU -> GPU -> CPU. +/// +/// # Example +/// A diffusion model which takes a text condition input. +/// +/// ```no_run +/// # use ort::{Allocator, AllocatorType, AllocationDevice, CUDAExecutionProvider, MemoryInfo, MemoryType, Session, Tensor, IoBinding}; +/// # fn main() -> ort::Result<()> { +/// let text_encoder = Session::builder()? +/// .with_execution_providers([CUDAExecutionProvider::default().build()])? +/// .commit_from_file("text_encoder.onnx")?; +/// let unet = Session::builder()? +/// .with_execution_providers([CUDAExecutionProvider::default().build()])? +/// .commit_from_file("unet.onnx")?; +/// +/// let text_condition = text_encoder +/// .run(ort::inputs![Tensor::::from_array(( +/// vec![27], +/// vec![ +/// 23763, 15460, 473, 68, 312, 265, 17463, 4098, 304, 1077, 283, 198, 7676, 5976, 272, 285, 3609, 435, +/// 21680, 321, 265, 300, 1689, 64, 285, 4763, 64 +/// ] +/// ))?]?)? +/// .remove("output0") +/// .unwrap(); +/// +/// let input_allocator = Allocator::new( +/// &unet, +/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? +/// )?; +/// let mut latents = Tensor::::new(&input_allocator, [1, 4, 64, 64])?; +/// +/// let mut io_binding = unet.create_binding()?; +/// io_binding.bind_input("condition", &text_condition)?; +/// +/// let output_allocator = Allocator::new( +/// &unet, +/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUOutput)? +/// )?; +/// io_binding.bind_output("noise_pred", Tensor::::new(&output_allocator, [1, 4, 64, 64])?)?; +/// +/// for _ in 0..20 { +/// io_binding.bind_input("latents", &latents)?; +/// let noise_pred = io_binding.run()?.remove("noise_pred").unwrap(); +/// +/// let mut latents = latents.extract_tensor_mut(); +/// latents += &noise_pred.try_extract_tensor::()?; +/// } +/// # Ok(()) +/// # } +/// ``` +/// +/// [`IoBinding`] may provide a decent speedup in this example since the `condition` tensor is unchanging between runs. +/// If we were to use normal session inference, the `condition` tensor would be needlessly copied with each invocation +/// of `unet.run()`, and this copying can come with significant latency & overhead. With [`IoBinding`], the `condition` +/// tensor is only copied to the device once instead of 20 times. #[derive(Debug)] pub struct IoBinding<'s> { pub(crate) ptr: NonNull, session: &'s Session, - output_names: Vec + held_inputs: HashMap>, + output_names: Vec, + output_values: HashMap } impl<'s> IoBinding<'s> { @@ -36,25 +101,47 @@ impl<'s> IoBinding<'s> { Ok(Self { ptr: unsafe { NonNull::new_unchecked(ptr) }, session, - output_names: Vec::new() + held_inputs: HashMap::new(), + output_names: Vec::new(), + output_values: HashMap::new() }) } /// Bind a [`Value`] to a session input. - pub fn bind_input<'i: 's, T: ValueTypeMarker, S: AsRef>(&mut self, name: S, ort_value: &'i mut Value) -> Result> { + /// + /// Upon invocation, the value's data will be copied to the device the session is allocated on. The copied data will + /// be used as an input (specified by `name`) in all future invocations of [`IoBinding::run`] until the input is + /// overridden (by calling [`IoBinding::bind_input`] again) or until all inputs are cleared (via + /// [`IoBinding::clear_inputs`] or [`IoBinding::clear`]). + /// + /// The data is only copied **once**, immediately upon invocation of this function. Any changes to the given + /// value afterwards will not affect the data seen by the session until the value is re-bound. Subsequent re-binds + /// will still copy data, hence why [`IoBinding`] is really only suitable when one or more inputs do not change + /// between runs. + pub fn bind_input>(&mut self, name: S, ort_value: &Value) -> Result<()> { let name = name.as_ref(); let cname = CString::new(name)?; ortsys![unsafe BindInput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindInput]; - Ok(ort_value.view_mut()) + self.held_inputs.insert(name.to_string(), Arc::clone(&ort_value.inner)); + Ok(()) } /// Bind a session output to a pre-allocated [`Value`]. - pub fn bind_output<'o: 's, T: ValueTypeMarker, S: AsRef>(&mut self, name: S, ort_value: &'o mut Value) -> Result> { + /// + /// This allows for the pre-allocation and reuse of memory in the session output (see [`crate::Tensor::new`]). Any + /// subsequent runs via [`IoBinding::run`] will reuse the same tensor to store the output instead of creating a new + /// one each time. + /// + /// The output will be accessible in the value returned by [`IoBinding::run`], under the name specified by `name`. + pub fn bind_output>(&mut self, name: S, ort_value: Value) -> Result<()> { let name = name.as_ref(); let cname = CString::new(name)?; ortsys![unsafe BindOutput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindOutput]; self.output_names.push(name.to_string()); - Ok(ort_value.view_mut()) + // Clear the old bound output if we have any. + drop(self.output_values.remove(name)); + self.output_values.insert(name.to_string(), ort_value.into_dyn()); + Ok(()) } /// Bind a session output to a device which is specified by `mem_info`. @@ -66,15 +153,35 @@ impl<'s> IoBinding<'s> { Ok(()) } - pub fn run<'i: 's>(&'i self) -> Result> { + /// Clears all bound inputs specified by [`IoBinding::bind_input`]. + pub fn clear_inputs(&mut self) { + ortsys![unsafe ClearBoundInputs(self.ptr.as_ptr())]; + drop(self.held_inputs.drain()); + } + /// Clears all bound outputs specified by [`IoBinding::bind_output`] or [`IoBinding::bind_output_to_device`]. + pub fn clear_outputs(&mut self) { + ortsys![unsafe ClearBoundOutputs(self.ptr.as_ptr())]; + drop(self.output_names.drain(..)); + drop(self.output_values.drain()); + } + /// Clears both the bound inputs & outputs; equivalent to [`IoBinding::clear_inputs`] followed by + /// [`IoBinding::clear_outputs`]. + pub fn clear(&mut self) { + self.clear_inputs(); + self.clear_outputs(); + } + + /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. + pub fn run(&mut self) -> Result> { self.run_inner(None) } - pub fn run_with_options<'i: 's>(&'i self, run_options: Arc) -> Result> { + /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. + pub fn run_with_options(&mut self, run_options: &RunOptions) -> Result> { self.run_inner(Some(run_options)) } - fn run_inner<'i: 's>(&'i self, run_options: Option>) -> Result> { + fn run_inner(&mut self, run_options: Option<&RunOptions>) -> Result> { let run_options_ptr = if let Some(run_options) = run_options { run_options.run_options_ptr.as_ptr() } else { @@ -82,6 +189,7 @@ impl<'s> IoBinding<'s> { }; ortsys![unsafe RunWithBinding(self.session.inner.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr()) -> Error::SessionRunWithIoBinding]; + let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc> = self.output_values.values().map(|c| (c.ptr(), &c.inner)).collect(); let mut count = self.output_names.len() as ort_sys::size_t; if count > 0 { let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut(); @@ -91,10 +199,17 @@ impl<'s> IoBinding<'s> { let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count as _).to_vec() } .into_iter() .map(|v| unsafe { - Value::from_ptr( - NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), - Some(Arc::clone(&self.session.inner)) - ) + if let Some(inner) = owned_ptrs.get(&v) { + DynValue { + inner: Arc::clone(*inner), + _markers: PhantomData + } + } else { + DynValue::from_ptr( + NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), + Some(Arc::clone(&self.session.inner)) + ) + } }); // output values will be freed when the `Value`s in `SessionOutputs` drop diff --git a/src/lib.rs b/src/lib.rs index bd8f560b..8e83d6d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,9 @@ pub(crate) mod metadata; pub(crate) mod operator; pub(crate) mod session; pub(crate) mod tensor; +#[cfg(feature = "training")] +pub(crate) mod training; +pub(crate) mod util; pub(crate) mod value; #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))] #[cfg(target_arch = "wasm32")] @@ -61,18 +64,21 @@ pub use self::operator::{ InferShapeFn, Operator, OperatorDomain }; pub use self::session::{ - GraphOptimizationLevel, InMemorySession, Input, Output, RunOptions, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, - SharedSessionInner + GraphOptimizationLevel, HasSelectedOutputs, InMemorySession, InferenceFut, Input, NoSelectedOutputs, Output, OutputSelector, RunOptions, + SelectedOutputMarker, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, SharedSessionInner }; #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub use self::tensor::ArrayExtensions; -pub use self::tensor::{IntoTensorElementType, TensorElementType}; +pub use self::tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; +#[cfg(feature = "training")] +#[cfg_attr(docsrs, doc(cfg(feature = "training")))] +pub use self::training::*; pub use self::value::{ DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, - SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, Value, ValueRef, - ValueRefMut, ValueType, ValueTypeMarker + SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker, Value, + ValueRef, ValueRefMut, ValueType, ValueTypeMarker }; /// このクレートのフィーチャが指定された状態になっていなければコンパイルエラー。 @@ -294,7 +300,7 @@ fn create_env(api: NonNull, tp_options: Option> = OnceLock::ne /// May panic if: /// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime. /// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled. +/// +/// # Examples +/// The primary (public-facing) use case for this function is accessing APIs that do not have a corresponding safe +/// implementation in `ort`. For example, [`GetBuildInfoString`](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0a7dba37b0017c0ef3a0ab4e266a967d): +/// +/// ``` +/// # use std::ffi::CStr; +/// # fn main() -> ort::Result<()> { +/// let api = ort::api().as_ptr(); +/// let build_info = unsafe { CStr::from_ptr((*api).GetBuildInfoString.unwrap()()) }; +/// println!("{}", build_info.to_string_lossy()); +/// // ORT Build Info: git-branch=HEAD, git-commit-id=4573740, build type=Release, cmake cxx flags: /DWIN32 /D_WINDOWS /EHsc /EHsc /wd26812 -DEIGEN_HAS_C99_MATH -DCPUINFO_SUPPORTED +/// # Ok(()) +/// # } +/// ``` +/// +/// For the full list of ONNX Runtime APIs, consult the [`ort_sys::OrtApi`] struct and the [ONNX Runtime C API](https://onnxruntime.ai/docs/api/c/struct_ort_api.html). pub fn api() -> NonNull { #[cfg(feature = "__init-for-voicevox")] if true { @@ -436,6 +459,26 @@ pub(crate) fn char_p_to_string(raw: *const c_char) -> Result { .map_err(Error::FfiStringConversion) } +pub(crate) struct PrivateTraitMarker; + +macro_rules! private_trait { + () => { + #[doc(hidden)] + #[allow(private_interfaces)] + fn _private() -> crate::PrivateTraitMarker; + }; +} +macro_rules! private_impl { + () => { + #[allow(private_interfaces)] + fn _private() -> crate::PrivateTraitMarker { + crate::PrivateTraitMarker + } + }; +} +pub(crate) use private_impl; +pub(crate) use private_trait; + #[cfg(test)] mod test { use std::ffi::CString; diff --git a/src/memory.rs b/src/memory.rs index 00464f27..74eb7703 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,20 +1,76 @@ use std::{ ffi::{c_char, c_int, CString}, - ptr::NonNull + ptr::NonNull, + sync::Arc }; -use super::{ - error::{Error, Result}, - ortsys +use crate::{ + char_p_to_string, + error::{status_to_result, Error, Result}, + ortsys, + session::{Session, SharedSessionInner} }; -use crate::{char_p_to_string, error::status_to_result, Session}; -/// An ONNX Runtime allocator, used to manage the allocation of [`crate::Value`]s. +/// A device allocator used to manage the allocation of [`crate::Value`]s. +/// +/// # Direct allocation +/// [`Allocator`] can be used to directly allocate device memory. This can be useful if you have a +/// postprocessing step that runs on the GPU. +/// ```no_run +/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; +/// # fn main() -> ort::Result<()> { +/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let allocator = Allocator::new( +/// &session, +/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? +/// )?; +/// +/// let mut tensor = Tensor::::new(&allocator, [1, 3, 224, 224])?; +/// // Here, `data_ptr` is a pointer to **device memory** inaccessible to the CPU; we'll need another crate, like +/// // `cudarc`, to access it. +/// let data_ptr = tensor.data_ptr_mut()?; +/// # Ok(()) +/// # } +/// ``` +/// +/// Note that `ort` does not facilitate the transfer of data between host & device outside of session inputs & +/// outputs; you'll need to use a separate crate for that, like [`cudarc`](https://crates.io/crates/cudarc) for CUDA. +/// +/// # Pinned allocation +/// Memory allocated on the host CPU is often *pageable* and may reside on the disk (swap memory). Transferring +/// pageable memory to another device is slow because the device has to go through the CPU to access the +/// memory. Many execution providers thus provide a "pinned" allocator type, which allocates *unpaged* CPU memory +/// that the device is able to access directly, bypassing the CPU and allowing for faster host-to-device data +/// transfer. +/// +/// If you create a session with a device allocator that supports pinned memory, like CUDA or ROCm, you can create +/// an allocator for that session, and use it to allocate tensors with faster pinned memory: +/// ```no_run +/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; +/// # fn main() -> ort::Result<()> { +/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let allocator = Allocator::new( +/// &session, +/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? +/// )?; +/// +/// // Create a tensor with our pinned allocator. +/// let mut tensor = Tensor::::new(&allocator, [1, 3, 224, 224])?; +/// let data = tensor.extract_tensor_mut(); +/// // ...fill `data` with data... +/// # Ok(()) +/// # } +/// ``` #[derive(Debug)] pub struct Allocator { pub(crate) ptr: NonNull, + /// The 'default' CPU allocator, provided by `GetAllocatorWithDefaultOptions` and implemented by + /// [`Allocator::default`], should **not** be released, so this field marks whether or not we should call + /// `ReleaseAllocator` on drop. is_default: bool, - _info: Option + _info: Option, + /// Hold a reference to the session if this allocator is tied to one. + _session_inner: Option> } impl Allocator { @@ -22,47 +78,46 @@ impl Allocator { Allocator { ptr: NonNull::new_unchecked(ptr), is_default: false, + // currently, this function is only ever used in session creation, where we call `CreateAllocator` manually and store the allocator resulting from + // this function in the `SharedSessionInner` - we don't need to hold onto the session, because the session is holding onto us. + _session_inner: None, _info: None } } + /// Frees an object allocated by this allocator, given the object's C pointer. pub(crate) unsafe fn free(&self, ptr: *mut T) { self.ptr.as_ref().Free.unwrap_or_else(|| unreachable!("Allocator method `Free` is null"))(self.ptr.as_ptr(), ptr.cast()); } /// Creates a new [`Allocator`] for the given session, to allocate memory on the device described in the /// [`MemoryInfo`]. - /// - /// For example, to create an allocator to allocate pinned memory for CUDA: - /// ```no_run - /// # use ort::{Allocator, Session, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; - /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; - /// let allocator = Allocator::new( - /// &session, - /// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? - /// )?; - /// # Ok(()) - /// # } - /// ``` pub fn new(session: &Session, memory_info: MemoryInfo) -> Result { let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); ortsys![unsafe CreateAllocator(session.ptr(), memory_info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)]; Ok(Self { ptr: unsafe { NonNull::new_unchecked(allocator_ptr) }, is_default: false, + _session_inner: Some(session.inner()), _info: Some(memory_info) }) } } impl Default for Allocator { + /// Returns the default CPU allocator; equivalent to `MemoryInfo::new(AllocationDevice::CPU, 0, + /// AllocatorType::Device, MemoryType::Default)`. + /// + /// The allocator returned by this function is actually shared across all invocations (though this behavior is + /// transparent to the user). fn default() -> Self { let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); status_to_result(ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr); nonNull(allocator_ptr)]).expect("Failed to get default allocator"); Self { ptr: unsafe { NonNull::new_unchecked(allocator_ptr) }, is_default: true, + // The default allocator isn't tied to a session. + _session_inner: None, _info: None } } @@ -70,8 +125,6 @@ impl Default for Allocator { impl Drop for Allocator { fn drop(&mut self) { - // per GetAllocatorWithDefaultOptions docs: Returned value should NOT be freed - // https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a8dec797ae52ee1a681e4f88be1fb4bb3 if !self.is_default { ortsys![unsafe ReleaseAllocator(self.ptr.as_ptr())]; } @@ -81,7 +134,8 @@ impl Drop for Allocator { /// Represents possible devices that have their own device allocator. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AllocationDevice { - // https://github.com/microsoft/onnxruntime/blob/v1.17.0/include/onnxruntime/core/framework/allocator.h#L43-L53 + // https://github.com/microsoft/onnxruntime/blob/v1.18.1/include/onnxruntime/core/framework/allocator.h#L43-L53 + // ort will likely never support WebGPU, so I think it's best to leave `WebGPU_Buffer` out entirely to reduce confusion CPU, CUDA, CUDAPinned, @@ -91,12 +145,10 @@ pub enum AllocationDevice { HIP, HIPPinned, OpenVINOCPU, - OpenVINOGPU, - WebGPUBuffer + OpenVINOGPU } impl AllocationDevice { - #[must_use] pub fn as_str(&self) -> &'static str { match self { Self::CPU => "Cpu", @@ -108,10 +160,15 @@ impl AllocationDevice { Self::HIP => "Hip", Self::HIPPinned => "HipPinned", Self::OpenVINOCPU => "OpenVINO_CPU", - Self::OpenVINOGPU => "OpenVINO_GPU", - Self::WebGPUBuffer => "WebGPU_Buffer" + Self::OpenVINOGPU => "OpenVINO_GPU" } } + + /// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device, + /// it could be extracted to an `ndarray` or slice. + pub fn is_cpu_accessible(&self) -> bool { + matches!(self, Self::CPU | Self::CUDAPinned | Self::CANNPinned | Self::HIPPinned | Self::OpenVINOCPU) + } } impl TryFrom for AllocationDevice { @@ -129,14 +186,13 @@ impl TryFrom for AllocationDevice { "HipPinned" => Ok(AllocationDevice::HIPPinned), "OpenVINO_CPU" => Ok(AllocationDevice::OpenVINOCPU), "OpenVINO_GPU" => Ok(AllocationDevice::OpenVINOGPU), - "WebGPUBuffer" => Ok(AllocationDevice::WebGPUBuffer), _ => Err(value) } } } /// Execution provider allocator type. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AllocatorType { /// Default device-specific allocator. Device, @@ -154,11 +210,11 @@ impl From for ort_sys::OrtAllocatorType { } /// Memory types for allocated memory. -#[derive(Default, Debug, Copy, Clone)] +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] pub enum MemoryType { /// Any CPU memory used by non-CPU execution provider. CPUInput, - /// CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED. + /// CPU-accessible memory output by a non-CPU execution provider, i.e. [`AllocatorDevice::CUDAPinned`]. CPUOutput, /// The default allocator for an execution provider. #[default] @@ -190,6 +246,12 @@ impl From for MemoryType { } } +/// Structure describing a memory location - the device on which the memory resides, the type of allocator (device +/// default, or arena) used, and the type of memory allocated (device-only, or CPU accessible). +/// +/// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which +/// device value data should reside, and how that data should be accessible with regard to the CPU (if a non-CPU device +/// is requested). #[derive(Debug)] pub struct MemoryInfo { pub(crate) ptr: NonNull, @@ -197,24 +259,24 @@ pub struct MemoryInfo { } impl MemoryInfo { - pub(crate) fn from_raw(ptr: NonNull, should_release: bool) -> Self { - MemoryInfo { ptr, should_release } - } - - #[tracing::instrument] - pub fn new_cpu(allocator: AllocatorType, memory_type: MemoryType) -> Result { - let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut(); - ortsys![ - unsafe CreateCpuMemoryInfo(allocator.into(), memory_type.into(), &mut memory_info_ptr) -> Error::CreateMemoryInfo; - nonNull(memory_info_ptr) - ]; - Ok(Self { - ptr: unsafe { NonNull::new_unchecked(memory_info_ptr) }, - should_release: true - }) - } - - #[tracing::instrument] + /// Creates a [`MemoryInfo`], describing a memory location on a device allocator. + /// + /// # Examples + /// `MemoryInfo` can be used to specify the device & memory type used by an [`Allocator`] to allocate tensors. + /// See [`Allocator`] for more information & potential applications. + /// ```no_run + /// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let allocator = Allocator::new( + /// &session, + /// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? + /// )?; + /// + /// let mut tensor = Tensor::::new(&allocator, [1, 3, 224, 224])?; + /// # Ok(()) + /// # } + /// ``` pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemoryType) -> Result { let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut(); let allocator_name = CString::new(allocation_device.as_str()).unwrap_or_else(|_| unreachable!()); @@ -229,7 +291,19 @@ impl MemoryInfo { }) } + pub(crate) fn from_raw(ptr: NonNull, should_release: bool) -> Self { + MemoryInfo { ptr, should_release } + } + /// Returns the [`MemoryType`] described by this struct. + /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.memory_type()?, MemoryType::Default); + /// # Ok(()) + /// # } + /// ``` pub fn memory_type(&self) -> Result { let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault; ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetMemoryType]; @@ -237,6 +311,14 @@ impl MemoryInfo { } /// Returns the [`AllocatorType`] described by this struct. + /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.allocator_type()?, AllocatorType::Device); + /// # Ok(()) + /// # } + /// ``` pub fn allocator_type(&self) -> Result { let mut raw_type: ort_sys::OrtAllocatorType = ort_sys::OrtAllocatorType::OrtInvalidAllocator; ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetAllocatorType]; @@ -248,6 +330,14 @@ impl MemoryInfo { } /// Returns the [`AllocationDevice`] this struct was created with. + /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.allocation_device()?, AllocationDevice::CPU); + /// # Ok(()) + /// # } + /// ``` pub fn allocation_device(&self) -> Result { let mut name_ptr: *const c_char = std::ptr::null_mut(); ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr) -> Error::GetAllocationDevice; nonNull(name_ptr)]; @@ -258,6 +348,14 @@ impl MemoryInfo { } /// Returns the ID of the [`AllocationDevice`] described by this struct. + /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.device_id()?, 0); + /// # Ok(()) + /// # } + /// ``` pub fn device_id(&self) -> Result { let mut raw: ort_sys::c_int = 0; ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw) -> Error::GetDeviceId]; @@ -266,24 +364,9 @@ impl MemoryInfo { } impl Drop for MemoryInfo { - #[tracing::instrument] fn drop(&mut self) { if self.should_release { ortsys![unsafe ReleaseMemoryInfo(self.ptr.as_ptr())]; } } } - -#[cfg(test)] -mod tests { - use test_log::test; - - use super::*; - - #[test] - fn create_memory_info() -> crate::Result<()> { - let memory_info = MemoryInfo::new_cpu(AllocatorType::Device, MemoryType::Default)?; - std::mem::drop(memory_info); - Ok(()) - } -} diff --git a/src/metadata.rs b/src/metadata.rs index 5464e5f9..84fc69e2 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -1,7 +1,11 @@ use std::{ffi::CString, os::raw::c_char, ptr::NonNull}; -use super::{char_p_to_string, error::Result, ortsys, Error}; -use crate::Allocator; +use crate::{ + char_p_to_string, + error::{Error, Result}, + memory::Allocator, + ortsys +}; /// Container for model metadata, including name & producer information. pub struct ModelMetadata<'s> { diff --git a/src/operator/bound.rs b/src/operator/bound.rs index 30a7ccbb..452736fa 100644 --- a/src/operator/bound.rs +++ b/src/operator/bound.rs @@ -11,8 +11,7 @@ use super::{ }; use crate::{error::IntoStatus, extern_system_fn}; -#[repr(C)] -#[derive(Clone)] +#[repr(C)] // <- important! a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later pub(crate) struct BoundOperator { implementation: ort_sys::OrtCustomOp, name: CString, @@ -43,6 +42,10 @@ impl BoundOperator { GetVariadicInputMinArity: Some(BoundOperator::::GetVariadicInputMinArity), GetVariadicOutputHomogeneity: Some(BoundOperator::::GetVariadicOutputHomogeneity), GetVariadicOutputMinArity: Some(BoundOperator::::GetVariadicOutputMinArity), + GetAliasMap: None, + ReleaseAliasMap: None, + GetMayInplace: None, + ReleaseMayInplace: None, InferOutputShapeFn: if O::get_infer_shape_function().is_some() { Some(BoundOperator::::InferOutputShapeFn) } else { @@ -212,7 +215,10 @@ unsafe impl Send for ErasedBoundOperator {} impl ErasedBoundOperator { pub(crate) fn new(bound: BoundOperator) -> Self { - ErasedBoundOperator(NonNull::from(unsafe { &mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ()) })) + ErasedBoundOperator(NonNull::from(unsafe { + // horrible horrible horrible horrible horrible horrible horrible horrible horrible + &mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ()) + })) } pub(crate) fn op_ptr(&self) -> *mut ort_sys::OrtCustomOp { diff --git a/src/operator/io.rs b/src/operator/io.rs index 5a7507a8..16d0e93e 100644 --- a/src/operator/io.rs +++ b/src/operator/io.rs @@ -1,4 +1,4 @@ -use crate::{MemoryType, TensorElementType}; +use crate::{memory::MemoryType, tensor::TensorElementType}; #[repr(i32)] #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 733d6158..15c8a9f8 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -1,9 +1,16 @@ use std::{ - ffi::CString, + ffi::{c_char, CString}, + ops::{Deref, DerefMut}, ptr::{self, NonNull} }; -use crate::{error::status_to_result, ortsys, value::ValueRefMut, Error, Result, Value, ValueRef}; +use crate::{ + error::{status_to_result, Error, ErrorInternal, Result}, + memory::{Allocator, MemoryInfo}, + ortsys, + session::{Input, Output}, + value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType} +}; pub trait Kernel { fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()>; @@ -25,19 +32,74 @@ impl KernelAttributes { } #[allow(private_bounds)] - pub fn get(&self, name: impl AsRef) -> Option { + pub fn get<'s, T: GetKernelAttribute<'s>>(&'s self, name: impl AsRef) -> Option { let name = CString::new(name.as_ref()).ok()?; T::get_from(self.0.as_ptr(), name.as_ptr()) } + + pub fn inputs(&self) -> Result> { + let mut num_inputs: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetInputCount(self.0.as_ptr(), &mut num_inputs) -> Error::GetOperatorInput]; + + let mut inputs = Vec::with_capacity(num_inputs as _); + for idx in 0..num_inputs as usize { + let mut name_len: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx as _, ptr::null_mut(), &mut name_len) -> Error::GetOperatorInput]; + let mut name = vec![0u8; name_len as _]; + ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx as _, name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorInput]; + let name = CString::from_vec_with_nul(name) + .map_err(|e| Error::FfiStringConversion(ErrorInternal::Msg(e.to_string())))? + .into_string() + .map_err(|e| Error::FfiStringConversion(ErrorInternal::IntoStringError(e)))?; + let mut type_info = ptr::null_mut(); + ortsys![unsafe KernelInfo_GetInputTypeInfo(self.0.as_ptr(), idx as _, &mut type_info) -> Error::GetOperatorInput; nonNull(type_info)]; + let input_type = ValueType::from_type_info(type_info)?; + inputs.push(Input { name, input_type }) + } + Ok(inputs) + } + + pub fn outputs(&self) -> Result> { + let mut num_outputs: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetOutputCount(self.0.as_ptr(), &mut num_outputs) -> Error::GetOperatorOutput]; + + let mut outputs = Vec::with_capacity(num_outputs as _); + for idx in 0..num_outputs as usize { + let mut name_len: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx as _, ptr::null_mut(), &mut name_len) -> Error::GetOperatorOutput]; + let mut name = vec![0u8; name_len as _]; + ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx as _, name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorOutput]; + let name = CString::from_vec_with_nul(name) + .map_err(|e| Error::FfiStringConversion(ErrorInternal::Msg(e.to_string())))? + .into_string() + .map_err(|e| Error::FfiStringConversion(ErrorInternal::IntoStringError(e)))?; + let mut type_info = ptr::null_mut(); + ortsys![unsafe KernelInfo_GetOutputTypeInfo(self.0.as_ptr(), idx as _, &mut type_info) -> Error::GetOperatorOutput; nonNull(type_info)]; + let output_type = ValueType::from_type_info(type_info)?; + outputs.push(Output { name, output_type }) + } + Ok(outputs) + } + + pub fn node_name(&self) -> Result { + let mut name_len: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), ptr::null_mut(), &mut name_len) -> Error::GetOperatorNodeName]; + let mut name = vec![0u8; name_len as _]; + ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorNodeName]; + CString::from_vec_with_nul(name) + .map_err(|e| Error::FfiStringConversion(ErrorInternal::Msg(e.to_string())))? + .into_string() + .map_err(|e| Error::FfiStringConversion(ErrorInternal::IntoStringError(e))) + } } -pub trait GetKernelAttribute { +pub trait GetKernelAttribute<'s> { fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option where Self: Sized; } -impl GetKernelAttribute for f32 { +impl<'s> GetKernelAttribute<'s> for f32 { fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option where Self: Sized @@ -48,6 +110,100 @@ impl GetKernelAttribute for f32 { } } +impl<'s> GetKernelAttribute<'s> for i64 { + fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + let mut value = Self::default(); + status_to_result(ortsys![unsafe KernelInfoGetAttribute_int64(info, name, &mut value)]).ok()?; + Some(value) + } +} + +impl<'s> GetKernelAttribute<'s> for String { + fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + let mut size = ort_sys::size_t::default(); + status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, ptr::null_mut(), &mut size)]).ok()?; + let mut out = vec![0u8; size as _]; + status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, out.as_mut_ptr().cast::(), &mut size)]).ok()?; + CString::from_vec_with_nul(out).ok().and_then(|c| c.into_string().ok()) + } +} + +impl<'s> GetKernelAttribute<'s> for Vec { + fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + let mut size = ort_sys::size_t::default(); + status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, ptr::null_mut(), &mut size)]).ok()?; + let mut out = vec![0f32; size as _]; + status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, out.as_mut_ptr(), &mut size)]).ok()?; + Some(out) + } +} + +impl<'s> GetKernelAttribute<'s> for Vec { + fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + let mut size = ort_sys::size_t::default(); + status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_int64(info, name, ptr::null_mut(), &mut size)]).ok()?; + let mut out = vec![0i64; size as _]; + status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_int64(info, name, out.as_mut_ptr(), &mut size)]).ok()?; + Some(out) + } +} + +impl<'s, T: DowncastableTarget> GetKernelAttribute<'s> for ValueRef<'s, T> { + fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + // TODO: This should probably be customizable - docs say the allocator is required for "internal tensor state", but it's + // not clear if this also includes tensor data (and thus it should instead be allocated on an appropriate device). + let allocator = Allocator::default(); + + let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); + status_to_result(ortsys![unsafe KernelInfoGetAttribute_tensor(info, name, allocator.ptr.as_ptr(), &mut value_ptr)]).ok()?; + unsafe { ValueRef::new(DynValue::from_ptr(NonNull::new(value_ptr)?, None)) } + .downcast() + .ok() + } +} + +pub struct ScratchBuffer { + allocator: Allocator, + buffer: *mut T, + size: usize +} + +impl Deref for ScratchBuffer { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + unsafe { std::slice::from_raw_parts(self.buffer.cast_const(), self.size) } + } +} +impl DerefMut for ScratchBuffer { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { std::slice::from_raw_parts_mut(self.buffer, self.size) } + } +} + +impl Drop for ScratchBuffer { + fn drop(&mut self) { + unsafe { + self.allocator.free(self.buffer); + } + } +} + pub struct KernelContext { ptr: NonNull } @@ -71,4 +227,60 @@ impl KernelContext { ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx as ort_sys::size_t, shape.as_ptr(), shape.len() as _, &mut value_ptr) -> Error::GetOperatorOutput]; Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) }))) } + + pub fn num_inputs(&self) -> Result { + let mut num: ort_sys::size_t = 0; + ortsys![unsafe KernelContext_GetInputCount(self.ptr.as_ptr(), &mut num) -> Error::GetOperatorInput]; + Ok(num as _) + } + + pub fn num_outputs(&self) -> Result { + let mut num: ort_sys::size_t = 0; + ortsys![unsafe KernelContext_GetOutputCount(self.ptr.as_ptr(), &mut num) -> Error::GetOperatorOutput]; + Ok(num as _) + } + + pub fn allocator(&self, memory_info: &MemoryInfo) -> Result { + let mut allocator_ptr = ptr::null_mut(); + ortsys![unsafe KernelContext_GetAllocator(self.ptr.as_ptr(), memory_info.ptr.as_ptr(), &mut allocator_ptr) -> Error::GetKernelAllocator]; + println!("allocator ptr {allocator_ptr:?}"); + Ok(unsafe { Allocator::from_raw_unchecked(allocator_ptr) }) + } + + pub fn get_resource(&self, id: ort_sys::c_int, version: ort_sys::c_int) -> Result>> { + let mut resource_ptr: *mut ort_sys::c_void = ptr::null_mut(); + ortsys![unsafe KernelContext_GetResource(self.ptr.as_ptr(), version, id, &mut resource_ptr) -> Error::GetKernelResource]; + Ok(NonNull::new(resource_ptr)) + } + + // TODO: STATUS_ACCESS_VIOLATION inside `KernelContext_GetScratchBuffer`. gonna assume this one is just an internal ONNX + // Runtime bug. + // + // pub fn allocate(&self, memory_info: &MemoryInfo, len: usize) -> Result> { + // let mut buffer = ptr::null_mut(); + // let allocator = self.allocator(memory_info)?; + // ortsys![ + // unsafe KernelContext_GetScratchBuffer( + // self.ptr.as_ptr(), + // memory_info.ptr.as_ptr(), + // (len * std::mem::size_of::()) as ort_sys::size_t, + // &mut buffer + // ) -> Error::GetKernelBuffer; + // nonNull(buffer) + // ]; + // Ok(ScratchBuffer { + // allocator, + // buffer: buffer.cast::(), + // size: len + // }) + // } + + /// Returns a pointer to the GPU compute stream (i.e. `cudaStream_t`) used by the execution provider, if this + /// kernel's operator was configured to use said execution provider (see + /// [`super::Operator::execution_provider_type`]). + pub fn compute_stream(&self) -> Result>> { + let mut stream_ptr: *mut ort_sys::c_void = ptr::null_mut(); + ortsys![unsafe KernelContext_GetGPUComputeStream(self.ptr.as_ptr(), &mut stream_ptr) -> Error::GetKernelGPUComputeStream]; + Ok(NonNull::new(stream_ptr)) + } } diff --git a/src/operator/mod.rs b/src/operator/mod.rs index 7e873516..207a74d9 100644 --- a/src/operator/mod.rs +++ b/src/operator/mod.rs @@ -8,19 +8,37 @@ pub(crate) mod io; pub(crate) mod kernel; use self::{ - bound::ErasedBoundOperator, + bound::{BoundOperator, ErasedBoundOperator}, io::{OperatorInput, OperatorOutput}, kernel::{DummyKernel, Kernel, KernelAttributes} }; -use crate::{operator::bound::BoundOperator, ortsys, Error, Result}; +use crate::{ + error::{Error, Result}, + ortsys +}; pub type InferShapeFn = dyn FnMut(*mut ort_sys::OrtShapeInferContext) -> crate::Result<()>; +/// A custom operator descriptor, which describes the expected inputs & outputs of a graph operator. +/// +/// [`Operator`]s are bound to [`OperatorDomain`]s. Multiple operators can have the same name as long as they have +/// different input/output types, in which case the exact operator will be picked depending on the input/output +/// types. If you want to, for example, define a `Sort` operator that can accept either a single `f32` or `i64` tensor +/// input, you'll need to define 2 separate operators (which can be done via a macro); but both of these +/// [`Operator`] structs can return the same name in [`Operator::name`] so that they are usable as simply +/// `my.domain:Sort` in the graph. pub trait Operator: Send { type Kernel: Kernel; + /// Returns the name of the operator. fn name() -> &'static str; + /// Returns the execution provider this operator runs on, e.g. `CUDAExecutionProvider`. + /// + /// If the returned type is not `None`, and the execution provider used by the session matches this operator's + /// EP type, the value will not be copied to the CPU and you may use functions like [`crate::Tensor::data_ptr`] to + /// access the underlying device memory, and [`super::KernelContext::compute_stream`] to access the GPU compute + /// stream. fn execution_provider_type() -> Option<&'static str> { None } @@ -42,6 +60,7 @@ pub trait Operator: Send { } } +/// Dummy type implementing [`Operator`] used by [`ErasedBoundOperator`] to cheat the type system. struct DummyOperator; impl Operator for DummyOperator { @@ -84,7 +103,7 @@ impl OperatorDomain { } #[allow(clippy::should_implement_trait)] - pub fn add(mut self, _operator: O) -> Result { + pub fn add(mut self) -> Result { let name = O::name(); let bound = BoundOperator::::new(CString::new(name)?, O::execution_provider_type().map(CString::new).transpose()?); diff --git a/src/session/async.rs b/src/session/async.rs index 104618b1..c02a8eb2 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -2,61 +2,41 @@ use std::{ cell::UnsafeCell, ffi::{c_char, CString}, future::Future, - mem::MaybeUninit, + ops::Deref, pin::Pin, ptr::NonNull, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex - }, + sync::{Arc, Mutex}, task::{Context, Poll, Waker} }; use ort_sys::{c_void, OrtStatus}; -use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; - -pub(crate) enum InnerValue { - Present(T), - Pending, - Closed -} - -const VALUE_PRESENT: usize = 1 << 0; -const CHANNEL_CLOSED: usize = 1 << 1; +use crate::{ + error::{assert_non_null_pointer, Error, Result}, + session::{RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner}, + value::Value +}; #[derive(Debug)] -pub(crate) struct InferenceFutInner<'s> { - presence: AtomicUsize, - value: UnsafeCell>>>, +pub(crate) struct InferenceFutInner<'r, 's> { + value: UnsafeCell>>>, waker: Mutex> } -impl<'s> InferenceFutInner<'s> { +impl<'r, 's> InferenceFutInner<'r, 's> { pub(crate) fn new() -> Self { InferenceFutInner { - presence: AtomicUsize::new(0), waker: Mutex::new(None), - value: UnsafeCell::new(MaybeUninit::uninit()) + value: UnsafeCell::new(None) } } - pub(crate) fn try_take(&self) -> InnerValue>> { - let state_snapshot = self.presence.fetch_and(!VALUE_PRESENT, Ordering::Acquire); - if state_snapshot & VALUE_PRESENT == 0 { - if self.presence.load(Ordering::Acquire) & CHANNEL_CLOSED != 0 { - InnerValue::Closed - } else { - InnerValue::Pending - } - } else { - InnerValue::Present(unsafe { (*self.value.get()).assume_init_read() }) - } + pub(crate) fn try_take(&self) -> Option>> { + unsafe { &mut *self.value.get() }.take() } - pub(crate) fn emplace_value(&self, value: Result>) { - unsafe { (*self.value.get()).write(value) }; - self.presence.fetch_or(VALUE_PRESENT, Ordering::Release); + pub(crate) fn emplace_value(&self, value: Result>) { + unsafe { &mut *self.value.get() }.replace(value); } pub(crate) fn set_waker(&self, waker: Option<&Waker>) { @@ -68,31 +48,47 @@ impl<'s> InferenceFutInner<'s> { waker.wake(); } } +} + +unsafe impl<'r, 's> Send for InferenceFutInner<'r, 's> {} +unsafe impl<'r, 's> Sync for InferenceFutInner<'r, 's> {} + +pub enum RunOptionsRef<'r, O: SelectedOutputMarker> { + Arc(Arc>), + Ref(&'r RunOptions) +} - pub(crate) fn close(&self) -> bool { - self.presence.fetch_or(CHANNEL_CLOSED, Ordering::Acquire) & CHANNEL_CLOSED == 0 +impl<'r, O: SelectedOutputMarker> From<&Arc>> for RunOptionsRef<'r, O> { + fn from(value: &Arc>) -> Self { + Self::Arc(Arc::clone(value)) } } -impl<'s> Drop for InferenceFutInner<'s> { - fn drop(&mut self) { - if self.presence.load(Ordering::Acquire) & VALUE_PRESENT != 0 { - unsafe { (*self.value.get()).assume_init_drop() }; - } +impl<'r, O: SelectedOutputMarker> From<&'r RunOptions> for RunOptionsRef<'r, O> { + fn from(value: &'r RunOptions) -> Self { + Self::Ref(value) } } -unsafe impl<'s> Send for InferenceFutInner<'s> {} -unsafe impl<'s> Sync for InferenceFutInner<'s> {} +impl<'r, O: SelectedOutputMarker> Deref for RunOptionsRef<'r, O> { + type Target = RunOptions; + + fn deref(&self) -> &Self::Target { + match self { + Self::Arc(r) => r, + Self::Ref(r) => r + } + } +} -pub struct InferenceFut<'s> { - inner: Arc>, - run_options: Arc, +pub struct InferenceFut<'s, 'r, O: SelectedOutputMarker> { + inner: Arc>, + run_options: RunOptionsRef<'r, O>, did_receive: bool } -impl<'s> InferenceFut<'s> { - pub(crate) fn new(inner: Arc>, run_options: Arc) -> Self { +impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, O> { + pub(crate) fn new(inner: Arc>, run_options: RunOptionsRef<'r, O>) -> Self { Self { inner, run_options, @@ -101,38 +97,33 @@ impl<'s> InferenceFut<'s> { } } -impl<'s> Future for InferenceFut<'s> { - type Output = Result>; +impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, O> { + type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = Pin::into_inner(self); - match this.inner.try_take() { - InnerValue::Present(v) => { - this.did_receive = true; - return Poll::Ready(v); - } - InnerValue::Pending => {} - InnerValue::Closed => panic!() - }; + if let Some(v) = this.inner.try_take() { + this.did_receive = true; + return Poll::Ready(v); + } this.inner.set_waker(Some(cx.waker())); - Poll::Pending } } -impl<'s> Drop for InferenceFut<'s> { +impl<'s, 'r, O: SelectedOutputMarker> Drop for InferenceFut<'s, 'r, O> { fn drop(&mut self) { - if !self.did_receive && self.inner.close() { + if !self.did_receive { let _ = self.run_options.terminate(); self.inner.set_waker(None); } } } -pub(crate) struct AsyncInferenceContext<'s> { - pub(crate) inner: Arc>, +pub(crate) struct AsyncInferenceContext<'r, 's> { + pub(crate) inner: Arc>, pub(crate) _input_values: Vec>, pub(crate) input_ort_values: Vec<*const ort_sys::OrtValue>, pub(crate) input_name_ptrs: Vec<*const c_char>, @@ -144,7 +135,7 @@ pub(crate) struct AsyncInferenceContext<'s> { crate::extern_system_fn! { pub(crate) fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: ort_sys::size_t, status: *mut OrtStatus) { - let ctx = unsafe { Box::from_raw(user_data.cast::>()) }; + let ctx = unsafe { Box::from_raw(user_data.cast::>()) }; // Reconvert name ptrs to CString so drop impl is called and memory is freed drop( diff --git a/src/session/builder.rs b/src/session/builder.rs index 60f716e0..458c6ade 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -1,7 +1,7 @@ -#[cfg(unix)] -use std::os::unix::ffi::OsStrExt; -#[cfg(target_family = "windows")] -use std::os::windows::ffi::OsStrExt; +#[cfg(any(feature = "operator-libraries", not(windows)))] +use std::ffi::CString; +#[cfg(not(target_arch = "wasm32"))] +use std::path::Path; #[cfg(feature = "fetch-models")] use std::path::PathBuf; use std::{ @@ -11,8 +11,6 @@ use std::{ rc::Rc, sync::{atomic::Ordering, Arc} }; -#[cfg(not(target_arch = "wasm32"))] -use std::{ffi::CString, path::Path}; use super::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner}; #[cfg(feature = "fetch-models")] @@ -21,8 +19,9 @@ use crate::{ environment::get_environment, error::{assert_non_null_pointer, status_to_result, Error, Result}, execution_providers::{apply_execution_providers, ExecutionProviderDispatch}, - memory::Allocator, - ortsys, MemoryInfo, OperatorDomain + memory::{Allocator, MemoryInfo}, + operator::OperatorDomain, + ortsys }; /// Creates a session using the builder pattern. @@ -112,7 +111,7 @@ impl SessionBuilder { /// `CUDAExecutionProvider`) **is discouraged** unless you allow the user to configure the execution providers by /// providing a `Vec` of [`ExecutionProviderDispatch`]es. pub fn with_execution_providers(self, execution_providers: impl IntoIterator) -> Result { - apply_execution_providers(&self, execution_providers.into_iter()); + apply_execution_providers(&self, execution_providers.into_iter())?; Ok(self) } @@ -313,23 +312,10 @@ impl SessionBuilder { }); } - // Build an OsString, then a vector of bytes to pass to C - let model_path = std::ffi::OsString::from(model_filepath); - #[cfg(target_family = "windows")] - let model_path: Vec = model_path - .encode_wide() - .chain(std::iter::once(0)) // Make sure we have a null terminated string - .collect(); - #[cfg(not(target_family = "windows"))] - let model_path: Vec = model_path - .as_encoded_bytes() - .iter() - .chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string - .map(|b| *b as std::os::raw::c_char) - .collect(); + let model_path = crate::util::path_to_os_char(model_filepath); let env = get_environment()?; - apply_execution_providers(&self, env.execution_providers.iter().cloned()); + apply_execution_providers(&self, env.execution_providers.iter().cloned())?; if env.has_global_threadpool { ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; @@ -406,7 +392,7 @@ impl SessionBuilder { let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); let env = get_environment()?; - apply_execution_providers(&self, env.execution_providers.iter().cloned()); + apply_execution_providers(&self, env.execution_providers.iter().cloned())?; if env.has_global_threadpool { ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; diff --git a/src/session/input.rs b/src/session/input.rs index ce33f006..31a1433f 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -1,9 +1,6 @@ use std::{borrow::Cow, collections::HashMap, ops::Deref}; -use crate::{ - value::{DynValueTypeMarker, ValueTypeMarker}, - Value, ValueRef, ValueRefMut -}; +use crate::value::{DynValueTypeMarker, Value, ValueRef, ValueRefMut, ValueTypeMarker}; pub enum SessionInputValue<'v> { ViewMut(ValueRefMut<'v, DynValueTypeMarker>), @@ -92,16 +89,15 @@ impl<'i, 'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs< /// # } /// ``` /// -/// Note that string tensors must be created manually with [`Value::from_string_array`]. +/// Note that string tensors must be created manually with [`crate::Tensor::from_string_array`]. /// /// ```no_run /// # use std::{error::Error, sync::Arc}; /// # use ndarray::Array1; -/// # use ort::{GraphOptimizationLevel, Session, Value}; +/// # use ort::{GraphOptimizationLevel, Session, Tensor}; /// # fn main() -> Result<(), Box> { /// # let mut session = Session::builder()?.commit_from_file("model.onnx")?; -/// let _ = session -/// .run(ort::inputs![Value::from_string_array(session.allocator(), Array1::from_vec(vec!["hello", "world"]))?]?); +/// let _ = session.run(ort::inputs![Tensor::from_string_array(Array1::from_vec(vec!["hello", "world"]))?]?); /// # Ok(()) /// # } /// ``` @@ -141,7 +137,8 @@ macro_rules! inputs { mod tests { use std::{collections::HashMap, sync::Arc}; - use crate::{DynTensor, SessionInputs}; + use super::SessionInputs; + use crate::value::DynTensor; #[test] fn test_hashmap_static_keys() -> crate::Result<()> { diff --git a/src/session/mod.rs b/src/session/mod.rs index 7e91a908..e31793e3 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -2,7 +2,7 @@ use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc}; -use super::{ +use crate::{ char_p_to_string, environment::Environment, error::{assert_non_null_pointer, assert_null_pointer, status_to_result, Error, ErrorInternal, Result}, @@ -18,12 +18,14 @@ mod r#async; pub(crate) mod builder; pub(crate) mod input; pub(crate) mod output; -use self::r#async::{AsyncInferenceContext, InferenceFutInner}; +mod run_options; +use self::r#async::{AsyncInferenceContext, InferenceFutInner, RunOptionsRef}; pub use self::{ r#async::InferenceFut, builder::{GraphOptimizationLevel, SessionBuilder}, input::{SessionInputValue, SessionInputs}, - output::SessionOutputs + output::SessionOutputs, + run_options::{HasSelectedOutputs, NoSelectedOutputs, OutputSelector, RunOptions, SelectedOutputMarker} }; /// Holds onto an [`ort_sys::OrtSession`] pointer and its associated allocator. @@ -112,101 +114,6 @@ pub struct Output { pub output_type: ValueType } -/// A structure which can be passed to [`Session::run_with_options`] to allow terminating/unterminating a session -/// inference run from a different thread. -#[derive(Debug)] -pub struct RunOptions { - pub(crate) run_options_ptr: NonNull -} - -// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 -unsafe impl Send for RunOptions {} -unsafe impl Sync for RunOptions {} - -impl RunOptions { - /// Creates a new [`RunOptions`] struct. - pub fn new() -> Result { - let mut run_options_ptr: *mut ort_sys::OrtRunOptions = std::ptr::null_mut(); - ortsys![unsafe CreateRunOptions(&mut run_options_ptr) -> Error::CreateRunOptions; nonNull(run_options_ptr)]; - Ok(Self { - run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) } - }) - } - - /// Sets a tag to identify this run in logs. - pub fn set_tag(&mut self, tag: impl AsRef) -> Result<()> { - let tag = CString::new(tag.as_ref())?; - ortsys![unsafe RunOptionsSetRunTag(self.run_options_ptr.as_ptr(), tag.as_ptr()) -> Error::RunOptionsSetTag]; - Ok(()) - } - - /// Sets the termination flag for the runs associated with this [`RunOptions`]. - /// - /// This function returns immediately (it does not wait for the session run to terminate). The run will terminate as - /// soon as it is able to. - /// - /// ```no_run - /// # // no_run because upsample.onnx is too simple of a model for the termination signal to be reliable enough - /// # use std::sync::Arc; - /// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType}; - /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; - /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; - /// let run_options = Arc::new(RunOptions::new()?); - /// - /// let run_options_ = Arc::clone(&run_options); - /// std::thread::spawn(move || { - /// let _ = run_options_.terminate(); - /// }); - /// - /// let res = session.run_with_options(ort::inputs![input]?, run_options); - /// // upon termination, the session will return an `Error::SessionRun` error.` - /// assert_eq!( - /// &res.unwrap_err().to_string(), - /// "Failed to run inference on model: Exiting due to terminate flag being set to true." - /// ); - /// # Ok(()) - /// # } - /// ``` - pub fn terminate(&self) -> Result<()> { - ortsys![unsafe RunOptionsSetTerminate(self.run_options_ptr.as_ptr()) -> Error::RunOptionsSetTerminate]; - Ok(()) - } - - /// Resets the termination flag for the runs associated with [`RunOptions`]. - /// - /// ```no_run - /// # use std::sync::Arc; - /// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType}; - /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; - /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; - /// let run_options = Arc::new(RunOptions::new()?); - /// - /// let run_options_ = Arc::clone(&run_options); - /// std::thread::spawn(move || { - /// let _ = run_options_.terminate(); - /// // ...oops, didn't mean to do that - /// let _ = run_options_.unterminate(); - /// }); - /// - /// let res = session.run_with_options(ort::inputs![input]?, run_options); - /// assert!(res.is_ok()); - /// # Ok(()) - /// # } - /// ``` - pub fn unterminate(&self) -> Result<()> { - ortsys![unsafe RunOptionsUnsetTerminate(self.run_options_ptr.as_ptr()) -> Error::RunOptionsUnsetTerminate]; - Ok(()) - } -} - -impl Drop for RunOptions { - fn drop(&mut self) { - ortsys![unsafe ReleaseRunOptions(self.run_options_ptr.as_ptr())]; - } -} - impl Session { /// Creates a new [`SessionBuilder`]. pub fn builder() -> Result { @@ -252,17 +159,19 @@ impl Session { /// # Ok(()) /// # } /// ``` - pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into>) -> Result> { + pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into>) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { - self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) + self.run_inner::(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) } SessionInputs::ValueArray(input_values) => { - self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) - } - SessionInputs::ValueMap(input_values) => { - self.run_inner(&input_values.iter().map(|(k, _)| k.as_ref()).collect::>(), input_values.iter().map(|(_, v)| v), None) + self.run_inner::(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) } + SessionInputs::ValueMap(input_values) => self.run_inner::( + &input_values.iter().map(|(k, _)| k.as_ref()).collect::>(), + input_values.iter().map(|(_, v)| v), + None + ) } } @@ -283,7 +192,7 @@ impl Session { /// let _ = run_options_.terminate(); /// }); /// - /// let res = session.run_with_options(ort::inputs![input]?, run_options); + /// let res = session.run_with_options(ort::inputs![input]?, &*run_options); /// // upon termination, the session will return an `Error::SessionRun` error.` /// assert_eq!( /// &res.unwrap_err().to_string(), @@ -292,11 +201,11 @@ impl Session { /// # Ok(()) /// # } /// ``` - pub fn run_with_options<'s, 'i, 'v: 'i, const N: usize>( + pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, O: SelectedOutputMarker, const N: usize>( &'s self, input_values: impl Into>, - run_options: Arc - ) -> Result> { + run_options: &'r RunOptions + ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), Some(run_options)) @@ -310,25 +219,34 @@ impl Session { } } - fn run_inner<'i, 'v: 'i>( - &self, + fn run_inner<'i, 'r, 's: 'r, 'v: 'i, O: SelectedOutputMarker>( + &'s self, input_names: &[&str], input_values: impl Iterator>, - run_options: Option> - ) -> Result> { + run_options: Option<&'r RunOptions> + ) -> Result> { let input_names_ptr: Vec<*const c_char> = input_names .iter() .map(|n| CString::new(n.as_bytes()).unwrap_or_else(|_| unreachable!())) .map(|n| n.into_raw().cast_const()) .collect(); - let output_names_ptr: Vec<*const c_char> = self - .outputs + + let (output_names, output_tensors) = match run_options { + Some(r) => r.outputs.resolve_outputs(&self.outputs), + None => (self.outputs.iter().map(|o| o.name.as_str()).collect(), std::iter::repeat_with(|| None).take(self.outputs.len()).collect()) + }; + let output_names_ptr: Vec<*const c_char> = output_names .iter() - .map(|output| CString::new(output.name.as_str()).unwrap_or_else(|_| unreachable!())) + .map(|n| CString::new(*n).unwrap_or_else(|_| unreachable!())) .map(|n| n.into_raw().cast_const()) .collect(); - - let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()]; + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = output_tensors + .iter() + .map(|c| match c { + Some(v) => v.ptr(), + None => std::ptr::null_mut() + }) + .collect(); // The C API expects pointers for the arrays (pointers to C-arrays) let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); @@ -352,10 +270,17 @@ impl Session { ) -> Error::SessionRun ]; - let outputs: Vec = output_tensor_ptrs + let outputs: Vec = output_tensors .into_iter() - .map(|tensor_ptr| unsafe { - Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), Some(Arc::clone(&self.inner))) + .enumerate() + .map(|(i, v)| match v { + Some(value) => value, + None => unsafe { + Value::from_ptr( + NonNull::new(output_tensor_ptrs[i]).expect("OrtValue ptr returned from session Run should not be null"), + Some(Arc::clone(&self.inner)) + ) + } }) .collect(); @@ -371,7 +296,7 @@ impl Session { .collect::>>()? ); - Ok(SessionOutputs::new(self.outputs.iter().map(|o| o.name.as_str()), outputs)) + Ok(SessionOutputs::new(output_names.into_iter(), outputs)) } /// Asynchronously run input data through the ONNX graph, performing inference. @@ -393,22 +318,56 @@ impl Session { /// # Ok(()) /// # }) } /// ``` - pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>(&'s self, input_values: impl Into> + 'static) -> Result> { + pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>( + &'s self, + input_values: impl Into> + 'static + ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), SessionInputs::ValueArray(input_values) => { - self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::>(), input_values.into_iter()) + self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::>(), input_values.into_iter(), None) } SessionInputs::ValueMap(input_values) => { - self.run_inner_async(&input_values.iter().map(|(k, _)| k.to_string()).collect::>(), input_values.into_iter().map(|(_, v)| v)) + self.run_inner_async(&input_values.iter().map(|(k, _)| k.to_string()).collect::>(), input_values.into_iter().map(|(_, v)| v), None) } } } - fn run_inner_async<'s, 'v: 's>(&'s self, input_names: &[String], input_values: impl Iterator>) -> Result> { - // create a `RunOptions` to pass to the future so that when it drops, it terminates inference - crucial - // (performance-wise) for routines involving `tokio::select!` or timeouts - let run_options = Arc::new(RunOptions::new()?); + /// Asynchronously run input data through the ONNX graph, performing inference, with the given [`RunOptions`]. + /// See [`Session::run_with_options`] and [`Session::run_async`] for more details. + pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, O: SelectedOutputMarker, const N: usize>( + &'s self, + input_values: impl Into> + 'static, + run_options: &'r RunOptions + ) -> Result> { + match input_values.into() { + SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), + SessionInputs::ValueArray(input_values) => { + self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::>(), input_values.into_iter(), Some(run_options)) + } + SessionInputs::ValueMap(input_values) => self.run_inner_async( + &input_values.iter().map(|(k, _)| k.to_string()).collect::>(), + input_values.into_iter().map(|(_, v)| v), + Some(run_options) + ) + } + } + + fn run_inner_async<'s, 'v: 's, 'r, O: SelectedOutputMarker>( + &'s self, + input_names: &[String], + input_values: impl Iterator>, + run_options: Option<&'r RunOptions> + ) -> Result> { + let run_options = match run_options { + Some(r) => RunOptionsRef::Ref(r), + // create a `RunOptions` to pass to the future so that when it drops, it terminates inference - crucial + // (performance-wise) for routines involving `tokio::select!` or timeouts + None => RunOptionsRef::Arc(Arc::new(unsafe { + // SAFETY: transmuting from `RunOptions` to `RunOptions`; safe because its just a marker + std::mem::transmute(RunOptions::new()?) + })) + }; let input_name_ptrs: Vec<*const c_char> = input_names .iter() @@ -489,7 +448,6 @@ unsafe impl Sync for Session {} mod dangerous { use super::*; - use crate::value::{extract_data_type_from_map_info, extract_data_type_from_sequence_info, extract_data_type_from_tensor_info}; pub(super) fn extract_inputs_count(session_ptr: NonNull) -> Result { let f = ortsys![unsafe SessionGetInputCount]; @@ -586,29 +544,6 @@ mod dangerous { status_to_result(status).map_err(Error::GetTypeInfo)?; assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?; - let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - let status = ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty)]; - status_to_result(status).map_err(Error::GetOnnxTypeFromTypeInfo)?; - let io_type = match ty { - ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { - let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToTensorInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_tensor_info(info_ptr)? } - } - ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { - let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToSequenceTypeInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_sequence_info(info_ptr)? } - } - ort_sys::ONNXType::ONNX_TYPE_MAP => { - let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToMapTypeInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_map_info(info_ptr)? } - } - _ => unreachable!() - }; - - ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; - Ok(io_type) + ValueType::from_type_info(typeinfo_ptr) } } diff --git a/src/session/output.rs b/src/session/output.rs index c0fed437..2409d6b6 100644 --- a/src/session/output.rs +++ b/src/session/output.rs @@ -4,7 +4,7 @@ use std::{ ops::{Deref, DerefMut, Index} }; -use crate::{Allocator, DynValue}; +use crate::{memory::Allocator, value::DynValue}; /// The outputs returned by a [`crate::Session`] inference call. /// @@ -25,16 +25,16 @@ use crate::{Allocator, DynValue}; /// # } /// ``` #[derive(Debug)] -pub struct SessionOutputs<'s> { - map: BTreeMap<&'s str, DynValue>, - idxs: Vec<&'s str>, +pub struct SessionOutputs<'r, 's> { + map: BTreeMap<&'r str, DynValue>, + idxs: Vec<&'r str>, backing_ptr: Option<(&'s Allocator, *mut c_void)> } -unsafe impl<'s> Send for SessionOutputs<'s> {} +unsafe impl<'r, 's> Send for SessionOutputs<'r, 's> {} -impl<'s> SessionOutputs<'s> { - pub(crate) fn new(output_names: impl Iterator + Clone, output_values: impl IntoIterator) -> Self { +impl<'r, 's> SessionOutputs<'r, 's> { + pub(crate) fn new(output_names: impl Iterator + Clone, output_values: impl IntoIterator) -> Self { let map = output_names.clone().zip(output_values).collect(); Self { map, @@ -44,7 +44,7 @@ impl<'s> SessionOutputs<'s> { } pub(crate) fn new_backed( - output_names: impl Iterator + Clone, + output_names: impl Iterator + Clone, output_values: impl IntoIterator, allocator: &'s Allocator, backing_ptr: *mut c_void @@ -66,7 +66,7 @@ impl<'s> SessionOutputs<'s> { } } -impl<'s> Drop for SessionOutputs<'s> { +impl<'r, 's> Drop for SessionOutputs<'r, 's> { fn drop(&mut self) { if let Some((allocator, ptr)) = self.backing_ptr { unsafe { allocator.free(ptr) }; @@ -74,35 +74,35 @@ impl<'s> Drop for SessionOutputs<'s> { } } -impl<'s> Deref for SessionOutputs<'s> { - type Target = BTreeMap<&'s str, DynValue>; +impl<'r, 's> Deref for SessionOutputs<'r, 's> { + type Target = BTreeMap<&'r str, DynValue>; fn deref(&self) -> &Self::Target { &self.map } } -impl<'s> DerefMut for SessionOutputs<'s> { +impl<'r, 's> DerefMut for SessionOutputs<'r, 's> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.map } } -impl<'s> Index<&str> for SessionOutputs<'s> { +impl<'r, 's> Index<&str> for SessionOutputs<'r, 's> { type Output = DynValue; fn index(&self, index: &str) -> &Self::Output { self.map.get(index).expect("no entry found for key") } } -impl<'s> Index for SessionOutputs<'s> { +impl<'r, 's> Index for SessionOutputs<'r, 's> { type Output = DynValue; fn index(&self, index: String) -> &Self::Output { self.map.get(index.as_str()).expect("no entry found for key") } } -impl<'s> Index for SessionOutputs<'s> { +impl<'r, 's> Index for SessionOutputs<'r, 's> { type Output = DynValue; fn index(&self, index: usize) -> &Self::Output { self.map.get(&self.idxs[index]).expect("no entry found for key") diff --git a/src/session/run_options.rs b/src/session/run_options.rs new file mode 100644 index 00000000..42368923 --- /dev/null +++ b/src/session/run_options.rs @@ -0,0 +1,292 @@ +use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, sync::Arc}; + +use crate::{ + error::{Error, Result}, + ortsys, + session::Output, + value::{DynValue, Value, ValueTypeMarker} +}; + +/// Allows selecting/deselecting/preallocating the outputs of a [`crate::Session`] inference call. +/// +/// ``` +/// # use std::sync::Arc; +/// # use ort::{Session, Allocator, RunOptions, OutputSelector, Tensor}; +/// # fn main() -> ort::Result<()> { +/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; +/// +/// let output0 = session.outputs[0].name.as_str(); +/// let options = RunOptions::new()?.with_outputs( +/// // Disable all outputs... +/// OutputSelector::no_default() +/// // except for the first one... +/// .with(output0) +/// // and since this is a 2x upsampler model, pre-allocate the output to be twice as large. +/// .preallocate(output0, Tensor::::new(&Allocator::default(), [1, 128, 128, 3])?) +/// ); +/// +/// // `outputs[0]` will be the tensor we just pre-allocated. +/// let outputs = session.run_with_options(ort::inputs![input]?, &options)?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct OutputSelector { + use_defaults: bool, + default_blocklist: Vec, + allowlist: Vec, + preallocated_outputs: HashMap +} + +impl Default for OutputSelector { + /// Creates an [`OutputSelector`] that enables all outputs by default. Use [`OutputSelector::without`] to disable a + /// specific output. + fn default() -> Self { + Self { + use_defaults: true, + allowlist: Vec::new(), + default_blocklist: Vec::new(), + preallocated_outputs: HashMap::new() + } + } +} + +impl OutputSelector { + /// Creates an [`OutputSelector`] that does not enable any outputs. Use [`OutputSelector::with`] to enable a + /// specific output. + pub fn no_default() -> Self { + Self { + use_defaults: false, + ..Default::default() + } + } + + /// Mark the output specified by the `name` for inclusion. + pub fn with(mut self, name: impl Into) -> Self { + self.allowlist.push(name.into()); + self + } + + /// Mark the output specified by `name` to be **excluded**. ONNX Runtime may prune some of the output node's + /// ancestor nodes. + pub fn without(mut self, name: impl Into) -> Self { + self.default_blocklist.push(name.into()); + self + } + + /// Pre-allocates an output. Assuming the type & shape of the value matches what is expected by the model, the + /// output value corresponding to `name` returned by the inference call will be the exact same value as the + /// pre-allocated value. + /// + /// **The same value will be reused as long as this [`OutputSelector`] and its parent [`RunOptions`] is used**, so + /// if you use the same `RunOptions` across multiple runs with a preallocated value, the preallocated value will be + /// overwritten upon each run. + /// + /// This can improve performance if the size and type of the output is known, and does not change between runs, i.e. + /// for an ODE or embeddings model. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Allocator, RunOptions, OutputSelector, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; + /// + /// let output0 = session.outputs[0].name.as_str(); + /// let options = RunOptions::new()?.with_outputs( + /// OutputSelector::default().preallocate(output0, Tensor::::new(&Allocator::default(), [1, 128, 128, 3])?) + /// ); + /// + /// let outputs = session.run_with_options(ort::inputs![input]?, &options)?; + /// # Ok(()) + /// # } + /// ``` + pub fn preallocate(mut self, name: impl Into, value: Value) -> Self { + self.preallocated_outputs.insert(name.into(), value.into_dyn()); + self + } + + pub(crate) fn resolve_outputs<'a, 's: 'a>(&'a self, outputs: &'s [Output]) -> (Vec<&'a str>, Vec>) { + if self.use_defaults { outputs.iter() } else { [].iter() } + .map(|o| &o.name) + .filter(|n| !self.default_blocklist.contains(n)) + .chain(self.allowlist.iter()) + .map(|n| { + ( + n.as_str(), + self.preallocated_outputs.get(n).map(|v| DynValue { + inner: Arc::clone(&v.inner), + _markers: PhantomData + }) + ) + }) + .unzip() + } +} + +/// Types that specify whether a [`RunOptions`] was configured with an [`OutputSelector`]. +pub trait SelectedOutputMarker {} +/// Marks that a [`RunOptions`] was not configured with an [`OutputSelector`]. +pub struct NoSelectedOutputs; +impl SelectedOutputMarker for NoSelectedOutputs {} +/// Marks that a [`RunOptions`] was configured with an [`OutputSelector`]. +pub struct HasSelectedOutputs; +impl SelectedOutputMarker for HasSelectedOutputs {} + +/// Allows for finer control over session inference. +/// +/// [`RunOptions`] provides three main features: +/// - **Run tagging**: Each individual session run can have a uniquely identifiable tag attached with +/// [`RunOptions::set_tag`], which will show up in logs. This can be especially useful for debugging +/// performance/errors in inference servers. +/// - **Termination**: Allows for terminating an inference call from another thread; when [`RunOptions::terminate`] is +/// called, any sessions currently running under that [`RunOptions`] instance will halt graph execution as soon as the +/// termination signal is received. This allows for [`crate::Session::run_async`]'s cancel-safety. +/// - **Output specification**: Certain session outputs can be [disabled](`OutputSelector::without`) or +/// [pre-allocated](`OutputSelector::preallocate`). Disabling an output might mean ONNX Runtime will not execute parts +/// of the graph that are only used by that output. Pre-allocation can reduce expensive re-allocations by allowing you +/// to use the same memory across runs. +/// +/// [`RunOptions`] can be passed to most places where a session can be inferred, e.g. +/// [`crate::Session::run_with_options`], [`crate::Session::run_async_with_options`], +/// [`crate::IoBinding::run_with_options`]. Some of these patterns (notably `IoBinding`) do not accept +/// [`OutputSelector`], hence [`RunOptions`] contains an additional type parameter that marks whether or not outputs +/// have been selected. +#[derive(Debug)] +pub struct RunOptions { + pub(crate) run_options_ptr: NonNull, + pub(crate) outputs: OutputSelector, + _marker: PhantomData +} + +// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 +unsafe impl Send for RunOptions {} +// Only allow `Sync` if we don't have (potentially pre-allocated) outputs selected. +// Allowing `Sync` here would mean a single pre-allocated `Value` could be mutated simultaneously in different threads - +// a brazen crime against crabkind. +unsafe impl Sync for RunOptions {} + +impl RunOptions { + /// Creates a new [`RunOptions`] struct. + pub fn new() -> Result> { + let mut run_options_ptr: *mut ort_sys::OrtRunOptions = std::ptr::null_mut(); + ortsys![unsafe CreateRunOptions(&mut run_options_ptr) -> Error::CreateRunOptions; nonNull(run_options_ptr)]; + Ok(RunOptions { + run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) }, + outputs: OutputSelector::default(), + _marker: PhantomData + }) + } +} + +impl RunOptions { + /// Select/deselect/preallocate outputs for this run. + /// + /// See [`OutputSelector`] for more details. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Allocator, RunOptions, OutputSelector, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; + /// + /// let output0 = session.outputs[0].name.as_str(); + /// let options = RunOptions::new()?.with_outputs( + /// // Disable all outputs... + /// OutputSelector::no_default() + /// // except for the first one... + /// .with(output0) + /// // and since this is a 2x upsampler model, pre-allocate the output to be twice as large. + /// .preallocate(output0, Tensor::::new(&Allocator::default(), [1, 128, 128, 3])?) + /// ); + /// + /// // `outputs[0]` will be the tensor we just pre-allocated. + /// let outputs = session.run_with_options(ort::inputs![input]?, &options)?; + /// # Ok(()) + /// # } + /// ``` + pub fn with_outputs(mut self, outputs: OutputSelector) -> RunOptions { + self.outputs = outputs; + unsafe { std::mem::transmute(self) } + } + + /// Sets a tag to identify this run in logs. + pub fn with_tag(mut self, tag: impl AsRef) -> Result { + self.set_tag(tag).map(|_| self) + } + + /// Sets a tag to identify this run in logs. + pub fn set_tag(&mut self, tag: impl AsRef) -> Result<()> { + let tag = CString::new(tag.as_ref())?; + ortsys![unsafe RunOptionsSetRunTag(self.run_options_ptr.as_ptr(), tag.as_ptr()) -> Error::RunOptionsSetTag]; + Ok(()) + } + + /// Sets the termination flag for the runs associated with this [`RunOptions`]. + /// + /// This function returns immediately (it does not wait for the session run to terminate). The run will terminate as + /// soon as it is able to. + /// + /// ```no_run + /// # // no_run because upsample.onnx is too simple of a model for the termination signal to be reliable enough + /// # use std::sync::Arc; + /// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType}; + /// # fn main() -> ort::Result<()> { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; + /// let run_options = Arc::new(RunOptions::new()?); + /// + /// let run_options_ = Arc::clone(&run_options); + /// std::thread::spawn(move || { + /// let _ = run_options_.terminate(); + /// }); + /// + /// let res = session.run_with_options(ort::inputs![input]?, &*run_options); + /// // upon termination, the session will return an `Error::SessionRun` error.` + /// assert_eq!( + /// &res.unwrap_err().to_string(), + /// "Failed to run inference on model: Exiting due to terminate flag being set to true." + /// ); + /// # Ok(()) + /// # } + /// ``` + pub fn terminate(&self) -> Result<()> { + ortsys![unsafe RunOptionsSetTerminate(self.run_options_ptr.as_ptr()) -> Error::RunOptionsSetTerminate]; + Ok(()) + } + + /// Resets the termination flag for the runs associated with [`RunOptions`]. + /// + /// ```no_run + /// # use std::sync::Arc; + /// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType}; + /// # fn main() -> ort::Result<()> { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; + /// let run_options = Arc::new(RunOptions::new()?); + /// + /// let run_options_ = Arc::clone(&run_options); + /// std::thread::spawn(move || { + /// let _ = run_options_.terminate(); + /// // ...oops, didn't mean to do that + /// let _ = run_options_.unterminate(); + /// }); + /// + /// let res = session.run_with_options(ort::inputs![input]?, &*run_options); + /// assert!(res.is_ok()); + /// # Ok(()) + /// # } + /// ``` + pub fn unterminate(&self) -> Result<()> { + ortsys![unsafe RunOptionsUnsetTerminate(self.run_options_ptr.as_ptr()) -> Error::RunOptionsUnsetTerminate]; + Ok(()) + } +} + +impl Drop for RunOptions { + fn drop(&mut self) { + ortsys![unsafe ReleaseRunOptions(self.run_options_ptr.as_ptr())]; + } +} diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 98cf2ae2..a1a5440e 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -22,4 +22,4 @@ mod types; pub use self::ndarray::ArrayExtensions; #[cfg(feature = "ndarray")] pub(crate) use self::types::{extract_primitive_array, extract_primitive_array_mut}; -pub use self::types::{IntoTensorElementType, TensorElementType, Utf8Data}; +pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; diff --git a/src/tensor/types.rs b/src/tensor/types.rs index 08a57d8f..f5f0a1ab 100644 --- a/src/tensor/types.rs +++ b/src/tensor/types.rs @@ -2,7 +2,10 @@ use std::ptr; #[cfg(feature = "ndarray")] -use crate::{ortsys, Error, Result}; +use crate::{ + error::{Error, Result}, + ortsys +}; /// Enum mapping ONNX Runtime's supported tensor data types. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -91,6 +94,12 @@ impl From for TensorElementType { pub trait IntoTensorElementType { /// Returns the ONNX tensor element data type corresponding to the given Rust type. fn into_tensor_element_type() -> TensorElementType; + + crate::private_trait!(); +} + +pub trait PrimitiveTensorElementType: IntoTensorElementType { + crate::private_trait!(); } macro_rules! impl_type_trait { @@ -99,6 +108,12 @@ macro_rules! impl_type_trait { fn into_tensor_element_type() -> TensorElementType { TensorElementType::$variant } + + crate::private_impl!(); + } + + impl PrimitiveTensorElementType for $type_ { + crate::private_impl!(); } }; } @@ -121,6 +136,14 @@ impl_type_trait!(u64, Uint64); #[cfg_attr(docsrs, doc(cfg(feature = "half")))] impl_type_trait!(half::bf16, Bfloat16); +impl IntoTensorElementType for String { + fn into_tensor_element_type() -> TensorElementType { + TensorElementType::String + } + + crate::private_impl!(); +} + /// Adapter for common Rust string types to ONNX strings. pub trait Utf8Data { /// Returns the contents of this value as a slice of UTF-8 bytes. diff --git a/src/training/mod.rs b/src/training/mod.rs new file mode 100644 index 00000000..d66db11d --- /dev/null +++ b/src/training/mod.rs @@ -0,0 +1,142 @@ +use std::{ + path::Path, + ptr::{self, NonNull}, + sync::{ + atomic::{AtomicPtr, Ordering}, + OnceLock + } +}; + +use crate::{ortsys, Error, Result, RunOptions}; + +mod simple; +mod trainer; + +pub use self::{ + simple::{iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainingArguments}, + trainer::Trainer +}; + +pub(crate) static TRAINING_API: OnceLock> = OnceLock::new(); + +/// Returns a pointer to the global [`ort_sys::OrtTrainingApi`] object, or errors if the Training API is not enabled. +/// +/// # Panics +/// May panic if: +/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime. +/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled. +pub fn training_api() -> Result> { + NonNull::new( + TRAINING_API + .get_or_init(|| { + let training_api = ortsys![unsafe GetTrainingApi(ort_sys::ORT_API_VERSION)]; + AtomicPtr::new(training_api.cast_mut()) + }) + .load(Ordering::Relaxed) + ) + .ok_or(Error::TrainingNotEnabled) +} + +macro_rules! trainsys { + ($method:ident) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) + }; + (unsafe $method:ident) => { + unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) } + }; + ($method:ident($($n:expr),+ $(,)?)) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) + }; + (unsafe $method:ident($($n:expr),+ $(,)?)) => { + unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) } + }; + ($method:ident($($n:expr),+ $(,)?).expect($e:expr)) => { + $crate::error::status_to_result($crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e) + }; + (unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => { + $crate::error::status_to_result(unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e) + }; + ($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+); + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }; + (unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + let _x = unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }; + $($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+ + _x + }}; + ($method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => { + $crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?; + }; + (unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => { + $crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?; + }; + ($method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => { + $crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?; + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }; + (unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + $crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?; + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }}; +} +pub(crate) use trainsys; + +#[derive(Debug)] +pub struct Checkpoint { + pub(crate) ptr: NonNull +} + +impl Checkpoint { + pub fn load(path: impl AsRef) -> Result { + let path = crate::util::path_to_os_char(path); + let mut ptr: *mut ort_sys::OrtCheckpointState = ptr::null_mut(); + trainsys![unsafe LoadCheckpoint(path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + Ok(Checkpoint { + ptr: unsafe { NonNull::new_unchecked(ptr) } + }) + } + + pub fn save(&self, path: impl AsRef, include_optimizer_state: bool) -> Result<()> { + let path = crate::util::path_to_os_char(path); + trainsys![unsafe SaveCheckpoint(self.ptr.as_ptr(), path.as_ptr(), include_optimizer_state) -> Error::CreateSession]; + Ok(()) + } +} + +impl Drop for Checkpoint { + fn drop(&mut self) { + tracing::trace!("dropping checkpoint"); + trainsys![unsafe ReleaseCheckpointState(self.ptr.as_ptr())]; + } +} + +#[derive(Debug)] +pub struct Optimizer(NonNull); + +impl Optimizer { + pub fn reset_grad(&self) -> Result<()> { + trainsys![unsafe LazyResetGrad(self.0.as_ptr()) -> Error::CreateSession]; + Ok(()) + } + + pub fn lr(&self) -> Result { + let mut lr = f32::NAN; + trainsys![unsafe GetLearningRate(self.0.as_ptr(), &mut lr) -> Error::CreateSession]; + Ok(lr) + } + + pub fn set_lr(&self, lr: f32) -> Result<()> { + trainsys![unsafe SetLearningRate(self.0.as_ptr(), lr) -> Error::CreateSession]; + Ok(()) + } + + pub fn step(&self) -> Result<()> { + self.step_with_options(RunOptions::new()?) + } + + pub fn step_with_options(&self, options: RunOptions) -> Result<()> { + trainsys![unsafe OptimizerStep(self.0.as_ptr(), options.run_options_ptr.as_ptr()) -> Error::CreateSession]; + Ok(()) + } +} diff --git a/src/training/simple.rs b/src/training/simple.rs new file mode 100644 index 00000000..267f3c64 --- /dev/null +++ b/src/training/simple.rs @@ -0,0 +1,240 @@ +use std::{collections::VecDeque, fs, path::PathBuf}; + +use crate::{Result, SessionInputs}; + +#[allow(clippy::len_without_is_empty)] +pub trait DataLoader { + fn load(&mut self, idx: usize) -> Result<(I, L)>; + + fn len(&self) -> Option { + None + } +} + +pub struct IterableDataLoader Result<(I, L)>> { + items: Box<[T]>, + collator: C +} + +impl Result<(I, L)>> DataLoader for IterableDataLoader { + fn load(&mut self, idx: usize) -> Result<(I, L)> { + (self.collator)(&self.items[idx]) + } + + fn len(&self) -> Option { + Some(self.items.len()) + } +} + +pub fn iterable_data_loader Result<(I, L)>>(iterable: impl Iterator, collator: C) -> IterableDataLoader { + IterableDataLoader { + items: iterable.collect::>().into_boxed_slice(), + collator + } +} + +impl Result<(I, L)>> DataLoader for F { + fn load(&mut self, idx: usize) -> Result<(I, L)> { + (self)(idx) + } + + fn len(&self) -> Option { + None + } +} + +pub enum EvaluationStrategy { + None, + Steps(usize), + Epochs(usize) +} + +impl EvaluationStrategy { + pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option) -> bool { + match self { + Self::None => false, + Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0, + Self::Epochs(epochs) => { + if let Some(dataloader_size) = dataloader_size { + iter_step > 0 && iter_step % (dataloader_size * epochs) == 0 + } else { + false + } + } + } + } +} + +pub enum CheckpointStrategy { + None, + Steps(usize), + Epochs(usize) +} + +impl CheckpointStrategy { + pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option) -> bool { + match self { + Self::None => false, + Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0, + Self::Epochs(epochs) => { + if let Some(dataloader_size) = dataloader_size { + iter_step > 0 && iter_step % (dataloader_size * epochs) == 0 + } else { + false + } + } + } + } +} + +pub struct TrainingArguments>, L: Into>, const NI: usize, const NL: usize> { + loader: Box>, + eval_loader: Option>>, + eval_strategy: EvaluationStrategy, + ckpt_strategy: CheckpointStrategy, + ckpt_path: PathBuf, + lr: f32, + max_saved_ckpts: usize, + gradient_accumulation_steps: usize, + max_steps: usize, + max_eval_steps: usize +} + +impl>, L: Into>, const NI: usize, const NL: usize> + TrainingArguments +{ + pub fn new + 'static>(train_loader: D) -> Self { + Self { + loader: Box::new(train_loader), + eval_loader: None, + eval_strategy: EvaluationStrategy::None, + ckpt_strategy: CheckpointStrategy::Epochs(1), + ckpt_path: PathBuf::from("checkpoints"), + lr: 1e-4, + gradient_accumulation_steps: 1, + max_saved_ckpts: 1, + max_steps: usize::MAX, + max_eval_steps: usize::MAX + } + } + + pub fn with_lr(mut self, lr: f32) -> Self { + self.lr = lr; + self + } + + pub fn with_max_steps(mut self, steps: usize) -> Self { + self.max_steps = steps; + self + } + + pub fn with_max_eval_steps(mut self, steps: usize) -> Self { + self.max_eval_steps = steps; + self + } + + pub fn with_gradient_accumulation(mut self, steps: usize) -> Self { + self.gradient_accumulation_steps = steps; + self + } + + pub fn with_ckpt_path(mut self, path: impl Into) -> Self { + self.ckpt_path = path.into(); + self + } + + pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self { + self.ckpt_strategy = strategy; + self + } + + pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self { + self.max_saved_ckpts = max_ckpts; + self + } + + pub fn with_eval_loader + 'static>(mut self, eval_loader: D) -> Self { + self.eval_loader = Some(Box::new(eval_loader)); + self + } + + pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self { + self.eval_strategy = strategy; + self + } +} + +impl super::Trainer { + pub fn train>, L: Into>, const NI: usize, const NL: usize>( + &self, + mut args: TrainingArguments + ) -> crate::Result<()> { + let optimizer = self.optimizer(); + optimizer.set_lr(args.lr)?; + + let mut saved_ckpts = VecDeque::new(); + let mut global_step = 0; + for (iter_step, _) in (0..args.max_steps).enumerate() { + let epoch = iter_step / args.loader.len().unwrap_or(usize::MAX); + let (inputs, labels) = args.loader.load(iter_step)?; + let (inputs, labels) = (inputs.into(), labels.into()); + + let outputs = self.step(inputs, labels)?; + let loss = outputs[0].try_extract_scalar::()?; + println!("epoch={epoch} step={global_step} loss={loss}"); + + if iter_step % args.gradient_accumulation_steps == 0 { + optimizer.step()?; + optimizer.reset_grad()?; + global_step += 1; + } + + if args.ckpt_strategy.should_fire(global_step, iter_step, args.loader.len()) { + if !args.ckpt_path.exists() { + let _ = fs::create_dir_all(&args.ckpt_path); + } + + let ckpt_path = args.ckpt_path.join(format!("epoch={epoch},step={global_step}.ortckpt")); + self.checkpoint().save(&ckpt_path, true)?; + + saved_ckpts.push_front(ckpt_path.clone()); + while saved_ckpts.len() > args.max_saved_ckpts { + let Some(old_ckpt) = saved_ckpts.pop_back() else { + break; + }; + let _ = fs::remove_file(old_ckpt); + } + } + + if args + .eval_strategy + .should_fire(global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len())) + { + let eval_loss = self.eval_inner(&mut args)?; + println!("eval_loss={eval_loss}"); + } + } + Ok(()) + } + + pub(crate) fn eval_inner>, L: Into>, const NI: usize, const NL: usize>( + &self, + args: &mut TrainingArguments + ) -> crate::Result { + let Some(eval_loader) = &mut args.eval_loader else { + return Ok(0.0); + }; + + let mut total_loss = 0.0; + for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) { + let (inputs, labels) = eval_loader.load(step)?; + let (inputs, labels) = (inputs.into(), labels.into()); + + let outputs = self.eval_step(inputs, labels)?; + let loss = outputs[0].try_extract_scalar::()?; + total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.); + } + + Ok(total_loss) + } +} diff --git a/src/training/trainer.rs b/src/training/trainer.rs new file mode 100644 index 00000000..f7c7cb38 --- /dev/null +++ b/src/training/trainer.rs @@ -0,0 +1,235 @@ +use std::{ + ffi::CString, + path::Path, + ptr::{self, NonNull}, + sync::Arc +}; + +use ort_sys::c_char; + +use super::{trainsys, Checkpoint, Optimizer}; +use crate::{ + char_p_to_string, + error::{assert_non_null_pointer, status_to_result}, + Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value +}; + +#[derive(Debug)] +pub struct Trainer { + pub(crate) ptr: NonNull, + train_output_names: Vec, + optimizer: Optimizer, + ckpt: Checkpoint, + _allocator: Allocator +} + +impl Trainer { + pub fn new( + session_options: SessionBuilder, + allocator: Allocator, + ckpt: Checkpoint, + training_model_path: impl AsRef, + eval_model_path: impl AsRef, + optimizer_model_path: impl AsRef + ) -> Result { + let training_model_path = crate::util::path_to_os_char(training_model_path); + let eval_model_path = crate::util::path_to_os_char(eval_model_path); + let optimizer_model_path = crate::util::path_to_os_char(optimizer_model_path); + + let env = crate::get_environment()?; + + let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut(); + trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.session_options_ptr.as_ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + + let ptr = unsafe { NonNull::new_unchecked(ptr) }; + + let mut train_output_len = 0; + trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len) -> Error::CreateSession]; + let train_output_names = (0..train_output_len) + .map(|i| { + let mut name_bytes: *mut c_char = std::ptr::null_mut(); + trainsys![unsafe TrainingSessionGetTrainingModelOutputName(ptr.as_ptr(), i, allocator.ptr.as_ptr(), &mut name_bytes) -> Error::CreateSession]; + let name = match char_p_to_string(name_bytes) { + Ok(name) => name, + Err(e) => { + unsafe { allocator.free(name_bytes) }; + return Err(e); + } + }; + unsafe { allocator.free(name_bytes) }; + Ok(name) + }) + .collect::>>()?; + + Ok(Self { + ptr, + _allocator: allocator, + train_output_names, + optimizer: Optimizer(ptr), + ckpt + }) + } + + pub fn new_from_artifacts( + session_options: SessionBuilder, + allocator: Allocator, + base_dir: impl AsRef, + override_ckpt: Option + ) -> Result { + let base_dir = base_dir.as_ref(); + let ckpt = if let Some(ckpt) = override_ckpt { + ckpt + } else { + Checkpoint::load(base_dir.join("checkpoint"))? + }; + Self::new( + session_options, + allocator, + ckpt, + base_dir.join("training_model.onnx"), + base_dir.join("eval_model.onnx"), + base_dir.join("optimizer_model.onnx") + ) + } + + pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option<&'r RunOptions> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe TrainStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + .map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn eval_step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option<&'r RunOptions> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe EvalStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + .map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn export>(&self, out_path: impl AsRef, output_names: impl AsRef<[O]>) -> Result<()> { + let out_path = crate::util::path_to_os_char(out_path); + + let output_names_ptr: Vec<*const c_char> = output_names + .as_ref() + .iter() + .map(|output| CString::new(output.as_ref()).unwrap_or_else(|_| unreachable!())) + .map(|n| n.into_raw().cast_const()) + .collect(); + + let res = trainsys![unsafe ExportModelForInferencing(self.ptr.as_ptr(), out_path.as_ptr(), output_names_ptr.len(), output_names_ptr.as_ptr())]; + + // Reconvert name ptrs to CString so drop impl is called and memory is freed + drop( + output_names_ptr + .into_iter() + .map(|p| { + assert_non_null_pointer(p, "c_char for CString")?; + unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } + }) + .collect::>>()? + ); + + status_to_result(res).map_err(Error::CreateSession)?; + + Ok(()) + } + + pub fn optimizer(&self) -> &Optimizer { + &self.optimizer + } + + pub fn checkpoint(&self) -> &Checkpoint { + &self.ckpt + } +} + +impl Drop for Trainer { + fn drop(&mut self) { + tracing::trace!("dropping trainer"); + trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())]; + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 00000000..bfa11d98 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,26 @@ +#[cfg(not(target_family = "windows"))] +use std::os::raw::c_char; +#[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(target_family = "windows")] +use std::os::windows::ffi::OsStrExt; +use std::{ffi::OsString, path::Path}; + +#[cfg(target_family = "windows")] +type OsCharArray = Vec; +#[cfg(not(target_family = "windows"))] +type OsCharArray = Vec; + +pub fn path_to_os_char(path: impl AsRef) -> OsCharArray { + let model_path = OsString::from(path.as_ref()); + #[cfg(target_family = "windows")] + let model_path: Vec = model_path.encode_wide().chain(std::iter::once(0)).collect(); + #[cfg(not(target_family = "windows"))] + let model_path: Vec = model_path + .as_encoded_bytes() + .iter() + .chain(std::iter::once(&b'\0')) + .map(|b| *b as c_char) + .collect(); + model_path +} diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index 1421e5a5..87d653f4 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -3,25 +3,42 @@ use std::{ fmt::Debug, hash::Hash, marker::PhantomData, - ptr::{self, NonNull} + ptr::{self, NonNull}, + sync::Arc }; -use super::{ValueInner, ValueTypeMarker}; +use super::{ + impl_tensor::{calculate_tensor_size, DynTensor, Tensor}, + DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker +}; use crate::{ - memory::Allocator, ortsys, value::impl_tensor::DynTensor, DynValue, Error, IntoTensorElementType, Result, Tensor, Value, ValueRef, ValueRefMut, ValueType + error::{Error, Result}, + memory::Allocator, + ortsys, + tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType} }; -pub trait MapValueTypeMarker: ValueTypeMarker {} +pub trait MapValueTypeMarker: ValueTypeMarker { + crate::private_trait!(); +} #[derive(Debug)] pub struct DynMapValueType; -impl ValueTypeMarker for DynMapValueType {} -impl MapValueTypeMarker for DynMapValueType {} +impl ValueTypeMarker for DynMapValueType { + crate::private_impl!(); +} +impl MapValueTypeMarker for DynMapValueType { + crate::private_impl!(); +} #[derive(Debug)] pub struct MapValueType(PhantomData<(K, V)>); -impl ValueTypeMarker for MapValueType {} -impl MapValueTypeMarker for MapValueType {} +impl ValueTypeMarker for MapValueType { + crate::private_impl!(); +} +impl MapValueTypeMarker for MapValueType { + crate::private_impl!(); +} pub type DynMap = Value; pub type Map = Value>; @@ -32,10 +49,7 @@ pub type MapRef<'v, K, V> = ValueRef<'v, MapValueType>; pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType>; impl Value { - pub fn try_extract_map( - &self, - allocator: &Allocator - ) -> Result> { + pub fn try_extract_map(&self) -> Result> { match self.dtype()? { ValueType::Map { key, value } => { let k_type = K::into_tensor_element_type(); @@ -47,47 +61,95 @@ impl Value { return Err(Error::InvalidMapValueType { expected: v_type, actual: value }); } + let allocator = Allocator::default(); + let mut key_tensor_ptr = ptr::null_mut(); ortsys![unsafe GetValue(self.ptr(), 0, allocator.ptr.as_ptr(), &mut key_tensor_ptr) -> Error::ExtractMap; nonNull(key_tensor_ptr)]; let key_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(key_tensor_ptr), None) }; - let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_tensor::()?; + if K::into_tensor_element_type() != TensorElementType::String { + let dtype = key_value.dtype()?; + let (key_tensor_shape, key_tensor) = match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = key_value.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == K::into_tensor_element_type() { + let mut output_array_ptr: *mut K = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + + let len = calculate_tensor_size(&dimensions); + (dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) }) + } else { + return Err(Error::DataTypeMismatch { + actual: ty, + requested: K::into_tensor_element_type() + }); + } + } + _ => unreachable!() + }; - let mut value_tensor_ptr = ptr::null_mut(); - ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; - let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; - let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; + let mut value_tensor_ptr = ptr::null_mut(); + ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; + let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; + let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; - assert_eq!(key_tensor_shape.len(), 1); - assert_eq!(value_tensor_shape.len(), 1); - assert_eq!(key_tensor_shape[0], value_tensor_shape[0]); + assert_eq!(key_tensor_shape.len(), 1); + assert_eq!(value_tensor_shape.len(), 1); + assert_eq!(key_tensor_shape[0], value_tensor_shape[0]); - let mut vec = Vec::with_capacity(key_tensor_shape[0] as _); - for i in 0..key_tensor_shape[0] as usize { - vec.push((key_tensor[i].clone(), value_tensor[i].clone())); + let mut vec = Vec::with_capacity(key_tensor_shape[0] as _); + for i in 0..key_tensor_shape[0] as usize { + vec.push((key_tensor[i].clone(), value_tensor[i].clone())); + } + Ok(vec.into_iter().collect()) + } else { + let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_string_tensor()?; + // SAFETY: `IntoTensorElementType` is a private trait, and we only map the `String` type to `TensorElementType::String`, + // so at this point, `K` is **always** the `String` type, and this transmute really does nothing but please the type + // checker. + let key_tensor: Vec = unsafe { std::mem::transmute(key_tensor) }; + + let mut value_tensor_ptr = ptr::null_mut(); + ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; + let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; + let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; + + assert_eq!(key_tensor_shape.len(), 1); + assert_eq!(value_tensor_shape.len(), 1); + assert_eq!(key_tensor_shape[0], value_tensor_shape[0]); + + let mut vec = Vec::with_capacity(key_tensor_shape[0] as _); + for i in 0..key_tensor_shape[0] as usize { + vec.push((key_tensor[i].clone(), value_tensor[i].clone())); + } + Ok(vec.into_iter().collect()) } - Ok(vec.into_iter().collect()) } t => Err(Error::NotMap(t)) } } } -impl Value> { +impl Value> { /// Creates a [`Map`] from an iterable emitting `K` and `V`. /// /// ``` /// # use std::collections::HashMap; - /// # use ort::{Allocator, Map}; + /// # use ort::Map; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let mut map = HashMap::::new(); /// map.insert(0, 1.0); /// map.insert(1, 2.0); /// map.insert(2, 3.0); /// - /// let value = Map::new(map)?; + /// let value = Map::::new(map)?; /// - /// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0); + /// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0); /// # Ok(()) /// # } /// ``` @@ -95,20 +157,45 @@ impl, Vec) = data.into_iter().unzip(); Self::new_kv(Tensor::from_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?) } +} +impl Value> { + /// Creates a [`Map`] from an iterable emitting `K` and `V`. + /// + /// ``` + /// # use std::collections::HashMap; + /// # use ort::Map; + /// # fn main() -> ort::Result<()> { + /// let mut map = HashMap::::new(); + /// map.insert(0, 1.0); + /// map.insert(1, 2.0); + /// map.insert(2, 3.0); + /// + /// let value = Map::::new(map)?; + /// + /// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0); + /// # Ok(()) + /// # } + /// ``` + pub fn new(data: impl IntoIterator) -> Result { + let (keys, values): (Vec, Vec) = data.into_iter().unzip(); + Self::new_kv(Tensor::from_string_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?) + } +} + +impl Value> { /// Creates a [`Map`] from two tensors of keys & values respectively. /// /// ``` /// # use std::collections::HashMap; - /// # use ort::{Allocator, Map, Tensor}; + /// # use ort::{Map, Tensor}; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let keys = Tensor::::from_array(([4], vec![0, 1, 2, 3]))?; /// let values = Tensor::::from_array(([4], vec![1., 2., 3., 4.]))?; /// /// let value = Map::new_kv(keys, values)?; /// - /// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0); + /// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0); /// # Ok(()) /// # } /// ``` @@ -122,21 +209,23 @@ impl Value> { - pub fn extract_map(&self, allocator: &Allocator) -> HashMap { - self.try_extract_map(allocator).expect("Failed to extract map") +impl Value> { + pub fn extract_map(&self) -> HashMap { + self.try_extract_map().expect("Failed to extract map") } +} +impl Value> { /// Converts from a strongly-typed [`Map`] to a type-erased [`DynMap`]. #[inline] pub fn upcast(self) -> DynMap { @@ -149,7 +238,7 @@ impl(PhantomData); -impl ValueTypeMarker for SequenceValueType {} -impl SequenceValueTypeMarker for SequenceValueType {} +impl ValueTypeMarker for SequenceValueType { + crate::private_impl!(); +} +impl SequenceValueTypeMarker for SequenceValueType { + crate::private_impl!(); +} pub type DynSequence = Value; pub type Sequence = Value>; @@ -89,11 +104,11 @@ impl Value Value Value { + /// Construct a [`DynTensor`] from an array of strings. /// - /// Just like numeric tensors, string tensor `Value`s can be created from: + /// Just like numeric tensors, string tensors can be created from: /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); /// - (with feature `ndarray`) an owned [`ndarray::Array`]; @@ -36,26 +35,19 @@ impl DynTensor { /// ``` /// # use ort::{Session, Value}; /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/vectorizer.onnx")?; - /// // You'll need to obtain an `Allocator` from a session in order to create string tensors. - /// let allocator = session.allocator(); - /// /// // Create a string tensor from a raw data vector /// let data = vec!["hello", "world"]; - /// let value = Value::from_string_array(allocator, ([data.len()], data.into_boxed_slice()))?; + /// let value = Value::from_string_array(([data.len()], data.into_boxed_slice()))?; /// /// // Create a string tensor from an `ndarray::Array` /// #[cfg(feature = "ndarray")] - /// let value = Value::from_string_array( - /// allocator, - /// ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap() - /// )?; + /// let value = Value::from_string_array(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap())?; /// # Ok(()) /// # } /// ``` /// /// Note that string data will *always* be copied, no matter what form the data is provided in. - pub fn from_string_array(allocator: &Allocator, input: impl IntoValueTensor) -> Result { + pub fn from_string_array(input: impl IntoValueTensor) -> Result> { let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); let (shape, data) = input.ref_parts()?; @@ -64,7 +56,7 @@ impl DynTensor { // create tensor without data -- data is filled in later ortsys![ - unsafe CreateTensorAsOrtValue(allocator.ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr) + unsafe CreateTensorAsOrtValue(Allocator::default().ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr) -> Error::CreateTensor; nonNull(value_ptr) ]; @@ -84,18 +76,18 @@ impl DynTensor { ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> Error::FillStringTensor]; Ok(Value { - inner: ValueInner::RustOwned { + inner: Arc::new(ValueInner::RustOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: Box::new(()), _memory_info: None - }, + }), _markers: PhantomData }) } } -impl Tensor { - /// Construct a tensor [`Value`] in a given allocator with a given shape and datatype. The data contained in the +impl Tensor { + /// Construct a tensor in a given allocator with a given shape and datatype. The data contained in the /// value will be zero-allocated on the allocation device. /// /// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned @@ -132,18 +124,18 @@ impl Tensor { ]; Ok(Value { - inner: ValueInner::RustOwned { + inner: Arc::new(ValueInner::RustOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: Box::new(()), _memory_info: None - }, + }), _markers: PhantomData }) } - /// Construct a tensor [`Value`] from an array of data. + /// Construct a tensor from an array of data. /// - /// Tensor `Value`s can be created from: + /// Tensors can be created from: /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); /// - (with feature `ndarray`) an owned [`ndarray::Array`]; @@ -154,19 +146,19 @@ impl Tensor { /// * and `data` is one of `Vec`, `Box<[T]>`, `Arc>`, or `&[T]`. /// /// ``` - /// # use ort::Value; + /// # use ort::Tensor; /// # fn main() -> ort::Result<()> { /// // Create a tensor from a raw data vector - /// let value = Value::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?; + /// let tensor = Tensor::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?; /// /// // Create a tensor from an `ndarray::Array` /// #[cfg(feature = "ndarray")] - /// let value = Value::from_array(ndarray::Array4::::zeros((1, 16, 16, 3)))?; + /// let tensor = Tensor::from_array(ndarray::Array4::::zeros((1, 16, 16, 3)))?; /// # Ok(()) /// # } /// ``` /// - /// Creating string tensors requires a separate method; see [`Value::from_string_array`]. + /// Creating string tensors requires a separate method; see [`DynTensor::from_string_array`]. /// /// Note that data provided in an `ndarray` may be copied in some circumstances: /// - `&CowArray<'_, T, D>` will always be copied regardless of whether it is uniquely owned or borrowed. @@ -177,7 +169,7 @@ impl Tensor { /// Raw data provided as a `Arc>`, `Box<[T]>`, or `Vec` will never be copied. Raw data is expected to be /// in standard, contigous layout. pub fn from_array(input: impl IntoValueTensor) -> Result> { - let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemoryType::Default)?; + let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::Default)?; let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); @@ -203,17 +195,17 @@ impl Tensor { ]; Ok(Value { - inner: ValueInner::RustOwned { + inner: Arc::new(ValueInner::RustOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: guard, _memory_info: Some(memory_info) - }, + }), _markers: PhantomData }) } } -impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> { +impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { /// Create a mutable tensor view from a raw pointer and shape. /// /// The length of data is determined by `T` and the given shape, so the given buffer must be at least @@ -260,11 +252,11 @@ impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> { ]; Ok(TensorRefMut::new(Value { - inner: ValueInner::CppOwned { + inner: Arc::new(ValueInner::CppOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, drop: true, _session: None - }, + }), _markers: PhantomData })) } @@ -290,7 +282,7 @@ macro_rules! impl_to_dimensions { .enumerate() .map(|(i, c)| if *c >= 1 { Ok(*c as i64) } else { Err(Error::InvalidDimension(i)) }) .collect::>()?; - let sum = v.iter().product::() as usize; + let sum = calculate_tensor_size(&v); if let Some(expected_size) = expected_size { if sum != expected_size { Err(Error::TensorShapeMismatch { @@ -318,6 +310,14 @@ macro_rules! impl_to_dimensions { }; } +impl ToDimensions for () { + fn to_dimensions(&self, expected_size: Option) -> Result> { + match expected_size { + Some(1) | None => Ok(vec![]), + Some(x) => Err(Error::TensorShapeMismatch { input: vec![], total: 1, expected: x }) + } + } +} impl_to_dimensions!(for &[usize], for &[i32], for &[i64], for Vec, for Vec, for Vec); impl_to_dimensions!( for [usize; N], for [i32; N], for [i64; N]); @@ -500,7 +500,7 @@ impl IntoValueTensor for (D, Arc TryFrom<&'i CowArray<'v, T, D>> for Tensor +impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Tensor where 'i: 'v { @@ -512,7 +512,7 @@ where #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Tensor { +impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Tensor { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { Tensor::from_array(arr) @@ -521,7 +521,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor +impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor where 'i: 'v { @@ -533,7 +533,7 @@ where #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynTensor { +impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynTensor { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { Tensor::from_array(arr).map(|c| c.upcast()) @@ -542,7 +542,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue +impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue where 'i: 'v { @@ -554,7 +554,7 @@ where #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynValue { +impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynValue { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { Tensor::from_array(arr).map(|c| c.into_dyn()) @@ -564,19 +564,19 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta macro_rules! impl_try_from { (@T,I $($t:ty),+) => { $( - impl TryFrom<$t> for Tensor { + impl TryFrom<$t> for Tensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value) } } - impl TryFrom<$t> for DynTensor { + impl TryFrom<$t> for DynTensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.upcast()) } } - impl TryFrom<$t> for crate::DynValue { + impl TryFrom<$t> for crate::DynValue { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.into_dyn()) @@ -587,21 +587,21 @@ macro_rules! impl_try_from { (@T,D $($t:ty),+) => { $( #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for Tensor { + impl TryFrom<$t> for Tensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value) } } #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for DynTensor { + impl TryFrom<$t> for DynTensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.upcast()) } } #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for crate::DynValue { + impl TryFrom<$t> for crate::DynValue { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.into_dyn()) diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index a4859e73..1b27d8ab 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -1,15 +1,16 @@ -use std::{fmt::Debug, os::raw::c_char, ptr, string::FromUtf8Error}; +use std::{fmt::Debug, ptr, string::FromUtf8Error}; #[cfg(feature = "ndarray")] use ndarray::IxDyn; -use super::TensorValueTypeMarker; +use super::{calculate_tensor_size, Tensor, TensorValueTypeMarker}; #[cfg(feature = "ndarray")] use crate::tensor::{extract_primitive_array, extract_primitive_array_mut}; use crate::{ + error::{Error, Result}, ortsys, - tensor::{IntoTensorElementType, TensorElementType}, - Error, Result, Tensor, Value + tensor::{PrimitiveTensorElementType, TensorElementType}, + value::{Value, ValueType} }; impl Value { @@ -38,38 +39,81 @@ impl Value { /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the /// infallible [`Tensor::extract_tensor`] instead)* /// - The provided type `T` does not match the tensor's element type. + /// - The tensor's data is not allocated in CPU memory. #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - pub fn try_extract_tensor(&self) -> Result> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; + pub fn try_extract_tensor(&self) -> Result> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == T::into_tensor_element_type() { + Ok(extract_primitive_array(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr())?) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } + } + t => Err(Error::NotTensor(t)) + } + } - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::>()); - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok(extract_primitive_array(shape, self.ptr())?) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + /// Attempt to extract the scalar from a tensor of type `T`. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Value}; + /// # fn main() -> ort::Result<()> { + /// let value = Value::from_array(((), vec![3.14_f32]))?; + /// + /// let extracted = value.try_extract_scalar::()?; + /// assert_eq!(extracted, 3.14); + /// # Ok(()) + /// # } + /// ``` + /// + /// # Errors + /// May return an error if: + /// - The tensor is not 0-dimensional. + /// - The provided type `T` does not match the tensor's element type. + /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the + /// infallible [`Tensor::extract_tensor`] instead)* + /// - The tensor's data is not allocated in CPU memory. + pub fn try_extract_scalar(&self) -> Result { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if !dimensions.is_empty() { + return Err(Error::TensorNot0Dimensional(dimensions.len())); + } + + if ty == T::into_tensor_element_type() { + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + + Ok(unsafe { *output_array_ptr }) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data of type `T` into a mutable read-only [`ndarray::ArrayViewMut`]. @@ -101,36 +145,26 @@ impl Value { /// - The provided type `T` does not match the tensor's element type. #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - pub fn try_extract_tensor_mut(&mut self) -> Result> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::>()); - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok(extract_primitive_array_mut(shape, self.ptr())?) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + pub fn try_extract_tensor_mut(&mut self) -> Result> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == T::into_tensor_element_type() { + Ok(extract_primitive_array_mut(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr())?) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an @@ -159,40 +193,32 @@ impl Value { /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the /// infallible [`Tensor::extract_raw_tensor`] instead)* /// - The provided type `T` does not match the tensor's element type. - pub fn try_extract_raw_tensor(&self) -> Result<(Vec, &[T])> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - - let mut output_array_ptr: *mut T = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok((node_dims, unsafe { std::slice::from_raw_parts(output_array_ptr, len as _) })) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + pub fn try_extract_raw_tensor(&self) -> Result<(Vec, &[T])> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == T::into_tensor_element_type() { + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + + let len = calculate_tensor_size(&dimensions); + Ok((dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) })) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a @@ -218,50 +244,41 @@ impl Value { /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the /// infallible [`Tensor::extract_raw_tensor_mut`] instead)* /// - The provided type `T` does not match the tensor's element type. - pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - - let mut output_array_ptr: *mut T = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok((node_dims, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len as _) })) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == T::into_tensor_element_type() { + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + + let len = calculate_tensor_size(&dimensions); + Ok((dimensions, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len) })) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data into a Rust `ndarray`. /// /// ``` - /// # use ort::{Allocator, Session, DynTensor, TensorElementType}; + /// # use ort::{Session, Tensor, TensorElementType}; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let array = ndarray::Array1::from_vec(vec!["hello", "world"]); - /// let tensor = DynTensor::from_string_array(&allocator, array.clone())?; + /// let tensor = Tensor::from_string_array(array.clone())?; /// /// let extracted = tensor.try_extract_string_tensor()?; /// assert_eq!(array.into_dyn(), extracted); @@ -271,78 +288,68 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_string_tensor(&self) -> Result> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == TensorElementType::String { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::>()); - - let mut len: ort_sys::size_t = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - // Total length of string data, not including \0 suffix - let mut total_length: ort_sys::size_t = 0; - ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; - - // In the JNI impl of this, tensor_element_len was included in addition to total_length, - // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes - // don't seem to be written to in practice either. - // If the string data actually did go farther, it would panic below when using the offset - // data to get slices for each string. - let mut string_contents = vec![0u8; total_length as _]; - // one extra slot so that the total length can go in the last one, making all per-string - // length calculations easy - let mut offsets = vec![0; (len + 1) as _]; - - ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent]; - - // final offset = overall length so that per-string length calculations work for the last string - debug_assert_eq!(0, offsets[len as usize]); - offsets[len as usize] = total_length; - - let strings = offsets - // offsets has 1 extra offset past the end so that all windows work - .windows(2) - .map(|w| { - let slice = &string_contents[w[0] as _..w[1] as _]; - String::from_utf8(slice.into()) + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == TensorElementType::String { + let len = calculate_tensor_size(&dimensions); + + // Total length of string data, not including \0 suffix + let mut total_length: ort_sys::size_t = 0; + ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; + + // In the JNI impl of this, tensor_element_len was included in addition to total_length, + // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes + // don't seem to be written to in practice either. + // If the string data actually did go farther, it would panic below when using the offset + // data to get slices for each string. + let mut string_contents = vec![0u8; total_length as _]; + // one extra slot so that the total length can go in the last one, making all per-string + // length calculations easy + let mut offsets = vec![0; (len + 1) as _]; + + ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len as _) -> Error::GetStringTensorContent]; + + // final offset = overall length so that per-string length calculations work for the last string + debug_assert_eq!(0, offsets[len]); + offsets[len] = total_length; + + let strings = offsets + // offsets has 1 extra offset past the end so that all windows work + .windows(2) + .map(|w| { + let slice = &string_contents[w[0] as _..w[1] as _]; + String::from_utf8(slice.into()) + }) + .collect::, FromUtf8Error>>() + .map_err(Error::StringFromUtf8Error)?; + + Ok(ndarray::Array::from_shape_vec(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), strings) + .expect("Shape extracted from tensor didn't match tensor contents")) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: TensorElementType::String }) - .collect::, FromUtf8Error>>() - .map_err(Error::StringFromUtf8Error)?; - - Ok(ndarray::Array::from_shape_vec(shape, strings) - .expect("Shape extracted from tensor didn't match tensor contents") - .into_dyn()) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: TensorElementType::String - }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's dimensions and /// an owned `Vec` of its data. /// /// ``` - /// # use ort::{Allocator, Session, DynTensor, TensorElementType}; + /// # use ort::{Session, Tensor, TensorElementType}; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let array = vec!["hello", "world"]; - /// let tensor = DynTensor::from_string_array(&allocator, ([array.len()], array.clone().into_boxed_slice()))?; + /// let tensor = Tensor::from_string_array(([array.len()], array.clone().into_boxed_slice()))?; /// /// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?; /// assert_eq!(extracted_data, array); @@ -351,68 +358,57 @@ impl Value { /// # } /// ``` pub fn try_extract_raw_string_tensor(&self) -> Result<(Vec, Vec)> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == TensorElementType::String { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - - let mut output_array_ptr: *mut c_char = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut c_char = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - - let mut len: ort_sys::size_t = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - // Total length of string data, not including \0 suffix - let mut total_length = 0; - ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; - - // In the JNI impl of this, tensor_element_len was included in addition to total_length, - // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes - // don't seem to be written to in practice either. - // If the string data actually did go farther, it would panic below when using the offset - // data to get slices for each string. - let mut string_contents = vec![0u8; total_length as _]; - // one extra slot so that the total length can go in the last one, making all per-string - // length calculations easy - let mut offsets = vec![0; len as usize + 1]; - - ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length as _, offsets.as_mut_ptr(), len as _) -> Error::GetStringTensorContent]; - - // final offset = overall length so that per-string length calculations work for the last string - debug_assert_eq!(0, offsets[len as usize]); - offsets[len as usize] = total_length; - - let strings = offsets - // offsets has 1 extra offset past the end so that all windows work - .windows(2) - .map(|w| { - let slice = &string_contents[w[0] as _..w[1] as _]; - String::from_utf8(slice.into()) + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == TensorElementType::String { + let len = calculate_tensor_size(&dimensions); + + // Total length of string data, not including \0 suffix + let mut total_length: ort_sys::size_t = 0; + ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; + + // In the JNI impl of this, tensor_element_len was included in addition to total_length, + // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes + // don't seem to be written to in practice either. + // If the string data actually did go farther, it would panic below when using the offset + // data to get slices for each string. + let mut string_contents = vec![0u8; total_length as _]; + // one extra slot so that the total length can go in the last one, making all per-string + // length calculations easy + let mut offsets = vec![0; (len + 1) as _]; + + ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len as _) -> Error::GetStringTensorContent]; + + // final offset = overall length so that per-string length calculations work for the last string + debug_assert_eq!(0, offsets[len]); + offsets[len] = total_length; + + let strings = offsets + // offsets has 1 extra offset past the end so that all windows work + .windows(2) + .map(|w| { + let slice = &string_contents[w[0] as _..w[1] as _]; + String::from_utf8(slice.into()) + }) + .collect::, FromUtf8Error>>() + .map_err(Error::StringFromUtf8Error)?; + + Ok((dimensions, strings)) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: TensorElementType::String }) - .collect::, FromUtf8Error>>() - .map_err(Error::StringFromUtf8Error)?; - - Ok((node_dims, strings)) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: TensorElementType::String - }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Returns the shape of the tensor. @@ -445,7 +441,7 @@ impl Value { } } -impl Tensor { +impl Tensor { /// Extracts the underlying data into a read-only [`ndarray::ArrayView`]. /// /// ``` diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index d7f1db3b..92a08c96 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -8,44 +8,104 @@ use std::{ ptr::NonNull }; -use super::{DowncastableTarget, Value, ValueInner, ValueTypeMarker}; -use crate::{ortsys, DynValue, Error, IntoTensorElementType, MemoryInfo, Result, ValueRef, ValueRefMut, ValueType}; +use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker}; +use crate::{ + error::{Error, Result}, + memory::MemoryInfo, + ortsys, + tensor::IntoTensorElementType +}; -pub trait TensorValueTypeMarker: ValueTypeMarker {} +pub trait TensorValueTypeMarker: ValueTypeMarker { + crate::private_trait!(); +} #[derive(Debug)] pub struct DynTensorValueType; -impl ValueTypeMarker for DynTensorValueType {} -impl TensorValueTypeMarker for DynTensorValueType {} +impl ValueTypeMarker for DynTensorValueType { + crate::private_impl!(); +} +impl TensorValueTypeMarker for DynTensorValueType { + crate::private_impl!(); +} #[derive(Debug)] pub struct TensorValueType(PhantomData); -impl ValueTypeMarker for TensorValueType {} -impl TensorValueTypeMarker for TensorValueType {} +impl ValueTypeMarker for TensorValueType { + crate::private_impl!(); +} +impl TensorValueTypeMarker for TensorValueType { + crate::private_impl!(); +} +/// A tensor [`Value`] whose data type is unknown. pub type DynTensor = Value; +/// A strongly-typed tensor [`Value`]. pub type Tensor = Value>; +/// A reference to a tensor [`Value`] whose data type is unknown. pub type DynTensorRef<'v> = ValueRef<'v, DynTensorValueType>; +/// A mutable reference to a tensor [`Value`] whose data type is unknown. pub type DynTensorRefMut<'v> = ValueRefMut<'v, DynTensorValueType>; +/// A reference to a strongly-typed tensor [`Value`]. pub type TensorRef<'v, T> = ValueRef<'v, TensorValueType>; +/// A mutable reference to a strongly-typed tensor [`Value`]. pub type TensorRefMut<'v, T> = ValueRefMut<'v, TensorValueType>; impl DowncastableTarget for DynTensorValueType { fn can_downcast(dtype: &ValueType) -> bool { matches!(dtype, ValueType::Tensor { .. }) } + + crate::private_impl!(); } impl Value { /// Returns a mutable pointer to the tensor's data. + /// + /// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a + /// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be + /// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access. + /// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before + /// accessing it. + /// + /// ``` + /// # use ort::{Allocator, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let mut tensor = Tensor::::from_array((vec![5], vec![0, 1, 2, 3, 4]))?; + /// let ptr = tensor.data_ptr_mut()?.cast::(); + /// unsafe { + /// *ptr.add(3) = 42; + /// }; + /// + /// let (_, extracted) = tensor.extract_raw_tensor(); + /// assert_eq!(&extracted, &[0, 1, 2, 42, 4]); + /// # Ok(()) + /// # } + /// ``` pub fn data_ptr_mut(&mut self) -> Result<*mut ort_sys::c_void> { let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)]; Ok(buffer_ptr) } - /// Returns a pointer to the tensor's data. + /// Returns an immutable pointer to the tensor's underlying data. + /// + /// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a + /// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be + /// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access. + /// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before + /// accessing it. + /// + /// ``` + /// # use ort::{Allocator, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::from_array((vec![5], vec![0, 1, 2, 3, 4]))?; + /// let ptr = tensor.data_ptr()?.cast::(); + /// assert_eq!(unsafe { *ptr.add(3) }, 3); + /// # Ok(()) + /// # } + /// ``` pub fn data_ptr(&self) -> Result<*const ort_sys::c_void> { let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)]; @@ -53,6 +113,26 @@ impl Value { } /// Returns information about the device this tensor is allocated on. + /// + /// ``` + /// # use ort::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType, Session, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::new(&Allocator::default(), [1, 3, 224, 224])?; + /// // Tensors are allocated on CPU by default. + /// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CPU); + /// + /// # if false { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let cuda_allocator = Allocator::new( + /// &session, + /// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? + /// )?; + /// let tensor = Tensor::::new(&cuda_allocator, [1, 3, 224, 224])?; + /// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CUDA); + /// # } + /// # Ok(()) + /// # } + /// ``` pub fn memory_info(&self) -> Result { let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut(); ortsys![unsafe GetTensorMemoryInfo(self.ptr(), &mut memory_info_ptr) -> Error::GetTensorMemoryInfo; nonNull(memory_info_ptr)]; @@ -62,29 +142,68 @@ impl Value { impl Tensor { /// Converts from a strongly-typed [`Tensor`] to a type-erased [`DynTensor`]. + /// + /// ``` + /// # use ort::{Allocator, DynTensor, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::new(&Allocator::default(), [1, 3, 224, 224])?; + /// let tensor_dyn = tensor.upcast(); + /// assert!(tensor_dyn.try_extract_raw_tensor::().is_ok()); + /// assert!(tensor_dyn.try_extract_raw_tensor::().is_err()); + /// # Ok(()) + /// # } + /// ``` #[inline] pub fn upcast(self) -> DynTensor { unsafe { std::mem::transmute(self) } } - /// Converts from a strongly-typed [`Tensor`] to a reference to a type-erased [`DynTensor`]. + /// Creates a type-erased [`DynTensorRef`] from a strongly-typed [`Tensor`]. + /// + /// ``` + /// # use ort::{Allocator, DynTensor, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::new(&Allocator::default(), [1, 3, 224, 224])?; + /// let tensor_dyn = tensor.upcast_ref(); + /// + /// let (_, original_extract) = tensor.extract_raw_tensor(); + /// let (_, ref_extract) = tensor_dyn.try_extract_raw_tensor::()?; + /// assert_eq!(original_extract, ref_extract); + /// # Ok(()) + /// # } + /// ``` #[inline] pub fn upcast_ref(&self) -> DynTensorRef { DynTensorRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } /// Converts from a strongly-typed [`Tensor`] to a mutable reference to a type-erased [`DynTensor`]. + /// + /// ``` + /// # use ort::{Allocator, DynTensor, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let mut tensor = Tensor::::from_array((vec![5], vec![1, 2, 3, 4, 5]))?; + /// let mut tensor_dyn = tensor.upcast_mut(); + /// + /// let (_, mut_view) = tensor_dyn.try_extract_raw_tensor_mut::()?; + /// mut_view[3] = 0; + /// + /// let (_, original_view) = tensor.extract_raw_tensor(); + /// assert_eq!(original_view, &[1, 2, 3, 0, 5]); + /// # Ok(()) + /// # } + /// ``` #[inline] pub fn upcast_mut(&mut self) -> DynTensorRefMut { DynTensorRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } @@ -97,6 +216,8 @@ impl DowncastableTarget for TensorValueType _ => false } } + + crate::private_impl!(); } impl From>> for DynValue { @@ -113,6 +234,17 @@ impl From> for DynValue { impl Index<[i64; N]> for Tensor { type Output = T; fn index(&self, index: [i64; N]) -> &Self::Output { + // Interestingly, the `TensorAt` API doesn't check if the tensor is on CPU, so we have to perform the check ourselves. + if !self + .memory_info() + .expect("could not retrieve tensor memory info") + .allocation_device() + .expect("could not retrieve tensor allocation device") + .is_cpu_accessible() + { + panic!("Cannot directly index a tensor which is not allocated on the CPU."); + } + let mut out: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")]; unsafe { &*out.cast::() } @@ -120,25 +252,46 @@ impl Index<[i64; N]> f } impl IndexMut<[i64; N]> for Tensor { fn index_mut(&mut self, index: [i64; N]) -> &mut Self::Output { + if !self + .memory_info() + .expect("could not retrieve tensor memory info") + .allocation_device() + .expect("could not retrieve tensor allocation device") + .is_cpu_accessible() + { + panic!("Cannot directly index a tensor which is not allocated on the CPU."); + } + let mut out: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")]; unsafe { &mut *out.cast::() } } } +pub(crate) fn calculate_tensor_size(shape: &[i64]) -> usize { + let mut size = 1usize; + for dim in shape { + if *dim < 0 { + return 0; + } + size *= *dim as usize; + } + size +} + #[cfg(test)] mod tests { use std::sync::Arc; use ndarray::{ArcArray1, Array1, CowArray}; - use crate::{Allocator, DynTensor, TensorElementType, Value, ValueType}; + use crate::{Tensor, TensorElementType, ValueType}; #[test] #[cfg(feature = "ndarray")] fn test_tensor_value() -> crate::Result<()> { let v: Vec = vec![1., 2., 3., 4., 5.]; - let value = Value::from_array(Array1::from_vec(v.clone()))?; + let value = Tensor::from_array(Array1::from_vec(v.clone()))?; assert!(value.is_tensor()?); assert_eq!(value.dtype()?.tensor_type(), Some(TensorElementType::Float32)); assert_eq!( @@ -163,17 +316,17 @@ mod tests { let arc1 = ArcArray1::from_vec(v.clone()); let mut arc2 = ArcArray1::clone(&arc1); - let value = Value::from_array(&mut arc2)?; + let value = Tensor::from_array(&mut arc2)?; drop((arc1, arc2)); assert_eq!(value.extract_raw_tensor().1, &v); let cow = CowArray::from(Array1::from_vec(v.clone())); - let value = Value::from_array(&cow)?; + let value = Tensor::from_array(&cow)?; assert_eq!(value.extract_raw_tensor().1, &v); let owned = Array1::from_vec(v.clone()); - let value = Value::from_array(owned.view())?; + let value = Tensor::from_array(owned.view())?; drop(owned); assert_eq!(value.extract_raw_tensor().1, &v); @@ -186,7 +339,7 @@ mod tests { let arc = Arc::new(v.clone().into_boxed_slice()); let shape = vec![v.len() as i64]; - let value = Value::from_array((shape, Arc::clone(&arc)))?; + let value = Tensor::from_array((shape, Arc::clone(&arc)))?; drop(arc); assert_eq!(value.try_extract_raw_tensor::()?.1, &v); @@ -196,10 +349,9 @@ mod tests { #[test] #[cfg(feature = "ndarray")] fn test_string_tensor_ndarray() -> crate::Result<()> { - let allocator = Allocator::default(); let v = Array1::from_vec(vec!["hello world".to_string(), "こんにちは世界".to_string()]); - let value = DynTensor::from_string_array(&allocator, v.view())?; + let value = Tensor::from_string_array(v.view())?; let extracted = value.try_extract_string_tensor()?; assert_eq!(extracted, v.into_dyn()); @@ -208,10 +360,9 @@ mod tests { #[test] fn test_string_tensor_raw() -> crate::Result<()> { - let allocator = Allocator::default(); let v = vec!["hello world".to_string(), "こんにちは世界".to_string()]; - let value = DynTensor::from_string_array(&allocator, (vec![v.len() as i64], v.clone().into_boxed_slice()))?; + let value = Tensor::from_string_array((vec![v.len() as i64], v.clone().into_boxed_slice()))?; let (extracted_shape, extracted_view) = value.try_extract_raw_string_tensor()?; assert_eq!(extracted_shape, [v.len() as i64]); assert_eq!(extracted_view, v); @@ -224,10 +375,10 @@ mod tests { let v: Vec = vec![1., 2., 3., 4., 5.]; let shape = [v.len()]; - let value_arc_box = Value::from_array((shape, Arc::new(v.clone().into_boxed_slice())))?; - let value_box = Value::from_array((shape, v.clone().into_boxed_slice()))?; - let value_vec = Value::from_array((shape, v.clone()))?; - let value_slice = Value::from_array((shape, &v[..]))?; + let value_arc_box = Tensor::from_array((shape, Arc::new(v.clone().into_boxed_slice())))?; + let value_box = Tensor::from_array((shape, v.clone().into_boxed_slice()))?; + let value_vec = Tensor::from_array((shape, v.clone()))?; + let value_slice = Tensor::from_array((shape, &v[..]))?; assert_eq!(value_arc_box.extract_raw_tensor().1, &v); assert_eq!(value_box.extract_raw_tensor().1, &v); diff --git a/src/value/mod.rs b/src/value/mod.rs index 61b49d18..2dad7ffb 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -16,9 +16,15 @@ pub use self::{ impl_sequence::{ DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker }, - impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker} + impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker} +}; +use crate::{ + error::{status_to_result, Error, Result}, + memory::MemoryInfo, + ortsys, + session::SharedSessionInner, + tensor::TensorElementType }; -use crate::{error::status_to_result, memory::MemoryInfo, ortsys, session::SharedSessionInner, tensor::TensorElementType, Error, Result}; /// The type of a [`Value`], or a session input/output. /// @@ -85,6 +91,30 @@ pub enum ValueType { } impl ValueType { + pub(crate) fn from_type_info(typeinfo_ptr: *mut ort_sys::OrtTypeInfo) -> Result { + let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; + ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty) -> Error::GetOnnxTypeFromTypeInfo]; + let io_type = match ty { + ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { + let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToTensorInfo; nonNull(info_ptr)]; + unsafe { extract_data_type_from_tensor_info(info_ptr)? } + } + ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { + let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToSequenceTypeInfo; nonNull(info_ptr)]; + unsafe { extract_data_type_from_sequence_info(info_ptr)? } + } + ort_sys::ONNXType::ONNX_TYPE_MAP => { + let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToMapTypeInfo; nonNull(info_ptr)]; + unsafe { extract_data_type_from_map_info(info_ptr)? } + } + _ => unreachable!() + }; + ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; + Ok(io_type) + } /// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map. /// /// ``` @@ -166,6 +196,14 @@ pub(crate) enum ValueInner { } } +impl ValueInner { + pub(crate) fn ptr(&self) -> *mut ort_sys::OrtValue { + match self { + ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr() + } + } +} + /// A temporary version of a [`Value`] with a lifetime specifier. #[derive(Debug)] pub struct ValueRef<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> { @@ -178,6 +216,14 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> { ValueRef { inner, lifetime: PhantomData } } + /// Attempts to downcast a temporary dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed + /// variant, like [`TensorRef`]. + #[inline] + pub fn downcast(self) -> Result> { + let dt = self.dtype()?; + if OtherType::can_downcast(&dt) { Ok(unsafe { std::mem::transmute(self) }) } else { panic!() } + } + pub fn into_dyn(self) -> ValueRef<'v, DynValueTypeMarker> { unsafe { std::mem::transmute(self) } } @@ -203,6 +249,14 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> { ValueRefMut { inner, lifetime: PhantomData } } + /// Attempts to downcast a temporary mutable dynamic value (like [`DynValue`] or [`DynTensor`]) to a more + /// strongly typed variant, like [`TensorRefMut`]. + #[inline] + pub fn downcast(self) -> Result> { + let dt = self.dtype()?; + if OtherType::can_downcast(&dt) { Ok(unsafe { std::mem::transmute(self) }) } else { panic!() } + } + pub fn into_dyn(self) -> ValueRefMut<'v, DynValueTypeMarker> { unsafe { std::mem::transmute(self) } } @@ -261,8 +315,8 @@ impl<'v, Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'v, Type> { /// - [`Tensor::extract_tensor`], [`Tensor::extract_raw_tensor`] #[derive(Debug)] pub struct Value { - inner: ValueInner, - _markers: PhantomData + pub(crate) inner: Arc, + pub(crate) _markers: PhantomData } /// A dynamic value, which could be a [`Tensor`], [`Sequence`], or [`Map`]. @@ -275,11 +329,15 @@ pub type DynValue = Value; /// /// For example, [`Tensor::try_extract_tensor`] can only be used on [`Value`]s with the [`TensorValueTypeMarker`] (which /// inherits this trait), i.e. [`Tensor`]s, [`DynTensor`]s, and [`DynValue`]s. -pub trait ValueTypeMarker: Debug {} +pub trait ValueTypeMarker: Debug { + crate::private_trait!(); +} /// Represents a type that a [`DynValue`] can be downcast to. pub trait DowncastableTarget: ValueTypeMarker { fn can_downcast(dtype: &ValueType) -> bool; + + crate::private_trait!(); } // this implementation is used in case we want to extract `DynValue`s from a [`Sequence`]; see `try_extract_sequence` @@ -287,15 +345,25 @@ impl DowncastableTarget for DynValueTypeMarker { fn can_downcast(_: &ValueType) -> bool { true } + + crate::private_impl!(); } /// The dynamic type marker, used for values which can be of any type. #[derive(Debug)] pub struct DynValueTypeMarker; -impl ValueTypeMarker for DynValueTypeMarker {} -impl MapValueTypeMarker for DynValueTypeMarker {} -impl SequenceValueTypeMarker for DynValueTypeMarker {} -impl TensorValueTypeMarker for DynValueTypeMarker {} +impl ValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} +impl MapValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} +impl SequenceValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} +impl TensorValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} unsafe impl Send for Value {} @@ -304,37 +372,14 @@ impl Value { pub fn dtype(&self) -> Result { let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); ortsys![unsafe GetTypeInfo(self.ptr(), &mut typeinfo_ptr) -> Error::GetTypeInfo; nonNull(typeinfo_ptr)]; - - let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty) -> Error::GetOnnxTypeFromTypeInfo]; - let io_type = match ty { - ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { - let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToTensorInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_tensor_info(info_ptr)? } - } - ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { - let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToSequenceTypeInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_sequence_info(info_ptr)? } - } - ort_sys::ONNXType::ONNX_TYPE_MAP => { - let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToMapTypeInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_map_info(info_ptr)? } - } - _ => unreachable!() - }; - - ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; - Ok(io_type) + ValueType::from_type_info(typeinfo_ptr) } /// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer. /// /// If the value belongs to a session (i.e. if it is returned from [`crate::Session::run`] or /// [`crate::IoBinding::run`]), you must provide the [`SharedSessionInner`] (acquired from - /// [`crate::Session::inner`]). This ensures the session is not dropped until the value is. + /// [`crate::Session::inner`]). This ensures the session is not dropped until any values owned by it is. /// /// # Safety /// @@ -343,7 +388,7 @@ impl Value { #[must_use] pub unsafe fn from_ptr(ptr: NonNull, session: Option>) -> Value { Value { - inner: ValueInner::CppOwned { ptr, drop: true, _session: session }, + inner: Arc::new(ValueInner::CppOwned { ptr, drop: true, _session: session }), _markers: PhantomData } } @@ -353,16 +398,14 @@ impl Value { #[must_use] pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull, session: Option>) -> Value { Value { - inner: ValueInner::CppOwned { ptr, drop: false, _session: session }, + inner: Arc::new(ValueInner::CppOwned { ptr, drop: false, _session: session }), _markers: PhantomData } } /// Returns the underlying [`ort_sys::OrtValue`] pointer. pub fn ptr(&self) -> *mut ort_sys::OrtValue { - match &self.inner { - ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr() - } + self.inner.ptr() } /// Create a view of this value's data. @@ -370,7 +413,7 @@ impl Value { ValueRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } @@ -380,7 +423,7 @@ impl Value { ValueRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } @@ -426,7 +469,7 @@ impl Value { Ok(ValueRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) })) } else { @@ -434,7 +477,7 @@ impl Value { } } - /// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed + /// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed /// mutable-reference variant, like [`TensorRefMut`]. #[inline] pub fn downcast_mut(&mut self) -> Result> { @@ -443,7 +486,7 @@ impl Value { Ok(ValueRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) })) } else { @@ -452,17 +495,17 @@ impl Value { } } -impl Drop for Value { +impl Drop for ValueInner { fn drop(&mut self) { let ptr = self.ptr(); tracing::trace!( "dropping {} value at {ptr:p}", - match &self.inner { + match self { ValueInner::RustOwned { .. } => "rust-owned", ValueInner::CppOwned { .. } => "cpp-owned" } ); - if !matches!(&self.inner, ValueInner::CppOwned { drop: false, .. }) { + if !matches!(self, ValueInner::CppOwned { drop: false, .. }) { ortsys![unsafe ReleaseValue(ptr)]; } } diff --git a/src/wasm.rs b/src/wasm.rs index fe8bdf6d..235a2198 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -1,6 +1,6 @@ //! Utilities for using `ort` in WebAssembly. //! -//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs: +//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs in WASM: //! ``` //! # use ort::Session; //! # static MODEL_BYTES: &[u8] = include_bytes!("../tests/data/upsample.ort"); @@ -212,17 +212,23 @@ mod emscripten_shims { let c = str::from_utf8_unchecked(slice::from_raw_parts(str, len)); tracing::error!("Emscripten error: {c}"); } + + // despite disabling exceptions literally everywhere when compiling, we still have to stub this... + #[no_mangle] + pub unsafe extern "C" fn __cxa_throw(_ptr: *const (), _type: *const (), _destructor: *const ()) -> ! { + std::process::abort(); + } } #[no_mangle] #[export_name = "_initialize"] pub fn initialize() { - // No idea what the hell this does, but the presence of an `_initialize` function prevents the linker from calling - // `__wasm_call_ctors` at the top of every function - including the functions `wasm-bindgen` interprets to generate - // JS glue code. The `__wasm_call_ctors` call was calling complex functions that the interpreter isn't equipped to - // handle, which was preventing wbg from outputting anything. I don't know what specific constructors this is calling, - // and most basic ONNX Runtime APIs *do* work without calling this, but we encourage the user to perform this - // initialization at program start anyways to be safe. + // The presence of an `_initialize` function prevents the linker from calling `__wasm_call_ctors` at the top of every + // function - including the functions `wasm-bindgen` interprets to generate JS glue code. `__wasm_call_ctors` calls + // complex functions that wbg's interpreter isn't equipped to handle, which was preventing wbg from outputting + // anything. + // I'm not entirely sure what `__wasm_call_ctors` is initializing, but it seems to have something to do with C++ + // vtables, and it's crucial for proper operation. extern "C" { fn __wasm_call_ctors(); } diff --git a/tests/vectorizer.rs b/tests/vectorizer.rs index bef1a572..f3af20a7 100644 --- a/tests/vectorizer.rs +++ b/tests/vectorizer.rs @@ -3,7 +3,7 @@ use std::path::Path; use ndarray::{ArrayD, IxDyn}; -use ort::{inputs, DynTensor, GraphOptimizationLevel, Session}; +use ort::{inputs, GraphOptimizationLevel, Session, Tensor}; use test_log::test; #[test] @@ -22,7 +22,7 @@ fn vectorizer() -> ort::Result<()> { let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap()); // Just one input - let input_tensor_values = inputs![DynTensor::from_string_array(session.allocator(), &array)?]?; + let input_tensor_values = inputs![Tensor::from_string_array(&array)?]?; // Perform the inference let outputs = session.run(input_tensor_values)?; diff --git a/tools/requirements.txt b/tools/requirements.txt new file mode 100644 index 00000000..d49cd910 --- /dev/null +++ b/tools/requirements.txt @@ -0,0 +1,4 @@ +torch~=2.3 +torch-ort~=1.17 +onnx~=1.16 +--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 diff --git a/tools/train-data/mini-clm.py b/tools/train-data/mini-clm.py new file mode 100644 index 00000000..6f06a70f --- /dev/null +++ b/tools/train-data/mini-clm.py @@ -0,0 +1,140 @@ +import math + +import onnx +from onnxruntime.training import artifacts +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +class RMSNorm(nn.Module): + def __init__(self, dim: int, *, eps: float = 1e-6): + super().__init__() + + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + if x.dtype != torch.float32: + xf = x.to(dtype=torch.float32) + else: + xf = x + output = (xf * torch.sqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)) + if x.dtype != torch.float32: + output = output.to(dtype=x.dtype) + return output * self.weight + +class RoPE(nn.Module): + def __init__(self, embedding_dim: int, *, max_seq_length: int = 2048, base: float = 10000.0): + super().__init__() + + pe = torch.zeros(max_seq_length, embedding_dim) + position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embedding_dim, step=2).float() * (-math.log(base) / embedding_dim)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe, persistent=False) + + @torch.no_grad() + def forward(self, x: Tensor) -> Tensor: + return x + self.pe[:, :x.shape[1], :] + +class Attention(nn.Module): + def __init__(self, embedding_dim: int, *, rope: RoPE, max_seq_length: int = 2048, n_heads: int = 4): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_heads = n_heads + self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=False) + self.proj = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.rope = rope + self.register_buffer('bias', torch.tril(torch.ones(max_seq_length, max_seq_length))[None, None, :, :], persistent=False) + + def forward(self, x: Tensor) -> Tensor: + b, t, c = x.size() + + x = self.rope(x) + + q, k, v = self.qkv(x).split(self.embedding_dim, dim=2) + q = q.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + k = k.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + v = v.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + y = att @ v + y = y.transpose(1, 2).contiguous().view(b, t, c) + + return self.proj(y) + +class FFN(nn.Module): + def __init__(self, embedding_dim: int, intermediate_dim: int | None = None): + super().__init__() + + intermediate_dim = intermediate_dim or embedding_dim * 4 + + self.w1 = nn.Linear(embedding_dim, intermediate_dim * 2, bias=False) + self.w2 = nn.Linear(intermediate_dim, embedding_dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x, gate = self.w1(x).chunk(2, dim=-1) + return self.w2(F.gelu(gate) * x) + +class Layer(nn.Module): + def __init__(self, embedding_dim: int, rope: RoPE): + super().__init__() + + self.attn = Attention(embedding_dim, rope=rope) + self.norm1 = RMSNorm(embedding_dim) + self.ffn = FFN(embedding_dim) + self.norm2 = RMSNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x + +class CLM(nn.Module): + def __init__(self, embedding_dim: int, n_layers: int, *, vocab_size: int): + super().__init__() + + rope = RoPE(embedding_dim) + self.layers = nn.ModuleList([Layer(embedding_dim, rope=rope) for _ in range(n_layers)]) + self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) + self.norm = RMSNorm(embedding_dim) + self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x = self.word_embeddings(x) + for layer in self.layers: + x = layer(x) + logits = self.lm_head(self.norm(x)) + return logits.view(-1, logits.size(-1)) + +lm = CLM(256, 4, vocab_size=50257) +torch.onnx.export( + lm, + torch.randint(0, 50256, (1, 64)), + f'tools/train-data/mini-clm/model.onnx', + input_names=['input_ids'], + output_names=['probs'], + dynamic_axes={ + 'input_ids': {0: 'batch', 1: 'seq'}, + 'probs': {0: 'batch_seq'} + }, + opset_version=14 +) + +onnx_model = onnx.load('tools/train-data/mini-clm/model.onnx') +requires_grad = [param.name for param in onnx_model.graph.initializer] + +artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + frozen_params=[], + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory='tools/train-data/mini-clm' +)