-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial setup for LiteRT (TensorFlow Lite) model tests. (#59)
Progress on #5. This contains two simple test cases for demonstration purposes, one of which is currently failing due to a regression: iree-org/iree#19402. The test suite follows the same structure as the onnx_models test suite in this repository. Some cleanup and refactoring will be more evident as this grows. We could for example share the `compile_mlir_with_iree` helper function between both test suites.
- Loading branch information
Showing
14 changed files
with
392 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
name: Test LiteRT Models | ||
on: | ||
push: | ||
branches: | ||
- main | ||
paths: | ||
- ".github/workflows/test_litert_models.yml" | ||
- "litert_models/**" | ||
pull_request: | ||
paths: | ||
- ".github/workflows/test_litert_models.yml" | ||
- "litert_models/**" | ||
workflow_dispatch: | ||
schedule: | ||
# Runs at 3:00 PM UTC, which is 8:00 AM PST | ||
- cron: "0 15 * * *" | ||
|
||
concurrency: | ||
# A PR number if a pull request and otherwise the commit hash. This cancels | ||
# queued and in-progress runs for the same PR (presubmit) or commit | ||
# (postsubmit). The workflow name is prepended to avoid conflicts between | ||
# different workflows. | ||
group: ${{ github.workflow }}-${{ github.event.number || github.sha }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
test-litert-models: | ||
if: ${{ github.repository_owner == 'iree-org' || github.event_name != 'schedule' }} | ||
runs-on: ubuntu-24.04 | ||
env: | ||
VENV_DIR: ${{ github.workspace }}/.venv | ||
HTML_REPORT_PATH: litert_models/litert_models_test_report_cpu_llvm_task.html | ||
steps: | ||
- name: Checkout repository | ||
uses: actions/checkout@v4 | ||
|
||
# Install Python packages. | ||
- name: Setup Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: "3.11" | ||
- name: Setup Python venv | ||
run: python3 -m venv ${VENV_DIR} | ||
- name: Install IREE nightly release Python packages | ||
run: | | ||
source ${VENV_DIR}/bin/activate | ||
python3 -m pip install -r litert_models/requirements-iree.txt | ||
# Run tests. | ||
- name: Run LiteRT models test suite | ||
run: | | ||
source ${VENV_DIR}/bin/activate | ||
pytest litert_models/ \ | ||
-rA \ | ||
--log-cli-level=info \ | ||
--timeout=300 \ | ||
--durations=0 \ | ||
--html=${HTML_REPORT_PATH} \ | ||
--self-contained-html | ||
- name: Upload HTML report | ||
uses: actions/upload-artifact@v4 | ||
with: | ||
name: litert_models_test_report_cpu_llvm_task.html | ||
path: ${{ env.HTML_REPORT_PATH }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# LiteRT (Formely TFLite) Model Tests | ||
|
||
This test suite exercises | ||
[LiteRT, formely known as TensorFlow Lite](https://ai.google.dev/edge/litert) | ||
models. Most pretrained models are sourced from https://www.kaggle.com/models. | ||
|
||
Testing *currently* follows several stages: | ||
|
||
```mermaid | ||
graph LR | ||
Model["Download model"] | ||
Model --> ImportMLIR["Import into MLIR"] | ||
ImportMLIR --> CompileIREE["Compile with IREE"] | ||
``` | ||
|
||
Testing *could* also test inference and compare with LiteRT: | ||
|
||
```mermaid | ||
graph LR | ||
Model --> ImportMLIR["Import into MLIR"] | ||
ImportMLIR --> CompileIREE["Compile with IREE"] | ||
CompileIREE --> RunIREE["Run with IREE"] | ||
RunIREE --> Check | ||
Model --> LoadLiteRT["Load into LiteRT"] | ||
LoadLiteRT --> RunLiteRT["Run with LiteRT"] | ||
RunLiteRT --> Check | ||
Check["Compare results"] | ||
``` | ||
|
||
## Quickstart | ||
|
||
1. Set up your virtual environment and install requirements: | ||
|
||
```bash | ||
python -m venv .venv | ||
source .venv/bin/activate | ||
python -m pip install -r requirements.txt | ||
``` | ||
|
||
* To use `iree-compile` and `iree-run-module` from Python packages: | ||
|
||
```bash | ||
python -m pip install -r requirements-iree.txt | ||
``` | ||
|
||
* To use a custom version of IREE follow the instructions for | ||
[building the IREE Python packages from source](https://iree.dev/building-from-source/getting-started/#python-bindings), | ||
including the extra steps for the TFLite importer. | ||
|
||
2. Run pytest using typical flags: | ||
|
||
```bash | ||
pytest \ | ||
-rA \ | ||
--log-cli-level=info \ | ||
--durations=0 | ||
``` | ||
|
||
See https://docs.pytest.org/en/stable/how-to/usage.html for other options. | ||
|
||
## Advanced pytest usage | ||
|
||
* The `log-cli-level` level can also be set to `debug`, `warning`, or `error`. | ||
See https://docs.pytest.org/en/stable/how-to/logging.html. | ||
* Run only tests matching a name pattern: | ||
|
||
```bash | ||
pytest -k resnet | ||
``` | ||
|
||
* Ignore xfail marks | ||
(https://docs.pytest.org/en/stable/how-to/skipping.html#ignoring-xfail): | ||
|
||
```bash | ||
pytest --runxfail | ||
``` | ||
|
||
* Run tests in parallel using https://pytest-xdist.readthedocs.io/ | ||
(note that this swallows some logging): | ||
|
||
```bash | ||
# Run with an automatic number of threads (usually one per CPU core). | ||
pytest -n auto | ||
# Run on an explicit number of threads. | ||
pytest -n 4 | ||
``` | ||
|
||
* Create an HTMl report using https://pytest-html.readthedocs.io/en/latest/index.html | ||
|
||
```bash | ||
pytest --html=report.html --self-contained-html --log-cli-level=info | ||
``` | ||
|
||
See also | ||
https://docs.pytest.org/en/latest/how-to/output.html#creating-junitxml-format-files | ||
|
||
## Test suite implementation details | ||
|
||
### Kaggle | ||
|
||
Models are downloaded using https://github.com/Kaggle/kagglehub. | ||
|
||
By default, kagglehub caches downloads at `~/.cache/kagglehub/models/`. This | ||
can be overriden by setting the `KAGGLEHUB_CACHE` environment variable. See the | ||
[`kagglehub/config.py` source](https://github.com/Kaggle/kagglehub/blob/main/src/kagglehub/config.py) | ||
for other configuration options. | ||
|
||
### Working with `.mlirbc` files | ||
|
||
The `iree-import-tflite` tool outputs MLIR bytecode (`.mlirbc`) by default. To | ||
convert to MLIR text (`.mlir`): | ||
|
||
```bash | ||
iree-ir-tool cp input.mlirbc -o output.mlir | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import logging | ||
import pytest | ||
|
||
from .utils import * | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def tflite_import_and_iree_compile_fn(kaggle_model_name: str): | ||
model_path = download_from_kagglehub(kaggle_model_name) | ||
logger.info(f"model_path: {model_path}") | ||
|
||
mlir_path = import_litert_model_to_mlir(model_path) | ||
logger.info(f"mlir_path: {mlir_path}") | ||
|
||
vmfb_path = compile_mlir_with_iree( | ||
mlir_path, | ||
"cpu", | ||
[ | ||
"--iree-hal-target-backends=llvm-cpu", | ||
"--iree-llvmcpu-target-cpu=host", | ||
], | ||
) | ||
logger.info(f"vmfb_path: {vmfb_path}") | ||
|
||
# TODO(#5): test iree-run-module success and numerics | ||
# * On Linux... | ||
# * Determine interface via ai-edge-litert / tflite-runtime | ||
# * Produce test inputs, save to .bin for IREE | ||
# * Produce golden test outputs, save to .bin for IREE | ||
# * Run with inputs and expected outputs | ||
|
||
|
||
@pytest.fixture | ||
def tflite_import_and_iree_compile(): | ||
return tflite_import_and_iree_compile_fn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Requirements for using IREE from nightly packages. | ||
|
||
# Include base requirements. | ||
-r requirements.txt | ||
|
||
--find-links https://iree.dev/pip-release-links.html | ||
--pre | ||
iree-base-compiler | ||
iree-base-runtime | ||
iree-tools-tflite |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Baseline requirements for running the test suite. | ||
# * See requirements-iree.txt for using IREE packages. | ||
|
||
pytest | ||
pytest-html | ||
pytest-reportlog | ||
pytest-timeout | ||
pytest-xdist | ||
|
||
# Not available on Windows. Make optional? Generate test goldens on Linux? | ||
# ai-edge-litert | ||
# tflite-runtime | ||
kagglehub | ||
tensorflow |
Empty file.
Empty file.
Empty file.
22 changes: 22 additions & 0 deletions
22
litert_models/tests/kaggle/tensorflow/mobilenet_v1_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
# https://www.kaggle.com/models/tensorflow/mobilenet-v1/ | ||
|
||
import pytest | ||
|
||
from ....utils import * | ||
|
||
|
||
# https://www.kaggle.com/models/tensorflow/mobilenet-v1/tfLite/0-25-224 | ||
def test_mobilenet_v1_0_25_224(tflite_import_and_iree_compile): | ||
tflite_import_and_iree_compile("tensorflow/mobilenet-v1/tfLite/0-25-224") | ||
|
||
|
||
# https://www.kaggle.com/models/tensorflow/mobilenet-v1/tfLite/0-25-224-quantized/ | ||
@pytest.mark.xfail(raises=IreeCompileException) | ||
def test_mobilenet_v1_0_25_224_quantized(tflite_import_and_iree_compile): | ||
tflite_import_and_iree_compile("tensorflow/mobilenet-v1/tfLite/0-25-224-quantized") |
Oops, something went wrong.