From 55e66cd08fa1063ca65aa1fdfcf625225c7b7cdd Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 11:33:13 -0700 Subject: [PATCH 01/26] Skeleton ONNX model tests. --- README.md | 6 +++ onnx_models/README.md | 55 +++++++++++++++++++++++ onnx_models/artifacts/.gitignore | 2 + onnx_models/basic_test.py | 75 +++++++++++++++++++++++++++++++ onnx_models/requirements-iree.txt | 8 ++++ onnx_models/requirements.txt | 9 ++++ 6 files changed, 155 insertions(+) create mode 100644 onnx_models/README.md create mode 100644 onnx_models/artifacts/.gitignore create mode 100644 onnx_models/basic_test.py create mode 100644 onnx_models/requirements-iree.txt create mode 100644 onnx_models/requirements.txt diff --git a/README.md b/README.md index 8d8cc65..40f6121 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,12 @@ See https://groups.google.com/g/iree-discuss/c/GIWyj8hmP0k/ for context. * Built with [cmake](https://cmake.org/) and run via [ctest](https://cmake.org/cmake/help/latest/manual/ctest.1.html) (for now?). +### [onnx_models/](onnx_models/) : Open Neural Network Exchange models + +[![Test ONNX Models](https://github.com/iree-org/iree-test-suites/actions/workflows/test_onnx_models.yml/badge.svg?branch=main)](https://github.com/iree-org/iree-test-suites/actions/workflows/test_onnx_models.yml?query=branch%3Amain) + +TODO: overview / details + ### [onnx_ops/](onnx_ops/) : Open Neural Network Exchange operations [![Test ONNX Ops](https://github.com/iree-org/iree-test-suites/actions/workflows/test_onnx_ops.yml/badge.svg?branch=main)](https://github.com/iree-org/iree-test-suites/actions/workflows/test_onnx_ops.yml?query=branch%3Amain) diff --git a/onnx_models/README.md b/onnx_models/README.md new file mode 100644 index 0000000..92c2873 --- /dev/null +++ b/onnx_models/README.md @@ -0,0 +1,55 @@ +# ONNX Model Tests + +This test suite exercises ONNX (Open Neural Network Exchange: https://onnx.ai/) +models. Most pretrained models are sourced from https://github.com/onnx/models. + +Testing follows several stages: + +```mermaid +graph LR + Model --> ImportMLIR["Import into MLIR"] + ImportMLIR --> CompileIREE["Compile with IREE"] + CompileIREE --> RunIREE["Run with IREE"] + RunIREE --> Check + + Model --> LoadONNX["Load into ORT"] + LoadONNX --> RunONNX["Run with ORT"] + RunONNX --> Check + + Check["Compare results"] +``` + +## Quickstart + +1. Set up your virtual environment and install requirements: + + ```bash + python -m venv .venv + source .venv/bin/activate + python -m pip install -r requirements.txt + ``` + + * To use `iree-compile` and `iree-run-module` from Python packages: + + ```bash + python -m pip install -r requirements-iree.txt + ``` + + * To use local versions of `iree-compile` and `iree-run-module`, put them on + your `$PATH` ahead of your `.venv/Scripts` directory: + + ```bash + export PATH=path/to/iree-build:$PATH + ``` + +2. Run pytest using typical flags: + + ```bash + pytest \ + -n auto \ + -rA \ + --timeout=30 \ + --durations=20 \ + ``` + + See https://docs.pytest.org/en/stable/how-to/usage.html for other options. 
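For orientation, the ONNX Runtime reference half of the diagram above is only a few lines of Python. A sketch, assuming `onnxruntime` is installed and the MobileNet model used by the tests below has already been downloaded into `artifacts/`:

```python
# Sketch of the "Load into ORT -> Run with ORT" path from the diagram, using
# the MobileNet model exercised by the tests (paths/names taken from those
# tests; this is illustrative, not part of the suite's API).
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("artifacts/mobilenetv2-12.onnx")
data = np.random.default_rng(0).random((1, 3, 224, 224), dtype=np.float32)
reference_outputs = session.run(["output"], {"input": data})
print(reference_outputs[0].shape)  # reference results to compare IREE against
```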
diff --git a/onnx_models/artifacts/.gitignore b/onnx_models/artifacts/.gitignore new file mode 100644 index 0000000..89f5b7b --- /dev/null +++ b/onnx_models/artifacts/.gitignore @@ -0,0 +1,2 @@ +*.mlir +*.onnx diff --git a/onnx_models/basic_test.py b/onnx_models/basic_test.py new file mode 100644 index 0000000..5e2ed56 --- /dev/null +++ b/onnx_models/basic_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import onnx +import pytest +import subprocess +import urllib.request +from pathlib import Path + +logger = logging.getLogger(__name__) + +THIS_DIR = Path(__file__).parent +ARTIFACTS_DIR = THIS_DIR / "artifacts" + +ONNX_CONVERTER_OUTPUT_MIN_VERSION = 17 + + +# TODO(#18289): use real frontend API, import model in-memory? +def upgrade_onnx_model(original_path: Path): + original_model = onnx.load_model(original_path) + converted_model = onnx.version_converter.convert_version( + original_model, ONNX_CONVERTER_OUTPUT_MIN_VERSION + ) + upgraded_path = original_path.with_name( + original_path.stem + f"_version{ONNX_CONVERTER_OUTPUT_MIN_VERSION}.onnx" + ) + logging.info( + f"Upgrading '{original_path.relative_to(THIS_DIR)}' to '{upgraded_path.relative_to(THIS_DIR)}'" + ) + onnx.save(converted_model, upgraded_path) + return upgraded_path + + +# TODO(#18289): use real frontend API, import model in-memory? +def import_onnx_model(onnx_path: Path): + imported_mlir_path = onnx_path.with_suffix(".mlir") + logging.info( + f"Importing '{onnx_path.relative_to(THIS_DIR)}' to '{imported_mlir_path.relative_to(THIS_DIR)}'" + ) + exec_args = [ + "iree-import-onnx", + str(onnx_path), + "-o", + str(imported_mlir_path), + ] + ret = subprocess.run(exec_args, capture_output=True) + if ret.returncode != 0: + logger.error(f"Import of '{onnx_path.name}' failed!\niree-import-onnx stdout:") + logger.error(ret.stdout.decode("utf-8")) + logger.error("iree-import-onnx stderr:") + logger.error(ret.stderr.decode("utf-8")) + raise RuntimeError(f" '{onnx_path.name}' import failed") + return imported_mlir_path + + +def test_basic(): + print("test_basic") + + # TODO(scotttodd): move to fixture with cache / download on demand + onnx_url = "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx" + original_path = ARTIFACTS_DIR / "mobilenetv2-12.onnx" + # urllib.request.urlretrieve(onnx_url, original_path) + + upgraded_path = upgrade_onnx_model(original_path) + imported_mlir_path = import_onnx_model(upgraded_path) + # TODO(scotttodd): Load input data + # TODO(scotttodd): Compile with IREE + # TODO(scotttodd): Run with IREE + # TODO(scotttodd): Load into ONNX Runtime + # TODO(scotttodd): Run with ONNX Runtime + # TODO(scotttodd): Compare results diff --git a/onnx_models/requirements-iree.txt b/onnx_models/requirements-iree.txt new file mode 100644 index 0000000..87b6e82 --- /dev/null +++ b/onnx_models/requirements-iree.txt @@ -0,0 +1,8 @@ +# Requirements for using IREE from nightly packages. + +# Include base requirements. +-r requirements.txt + +--find-links https://iree.dev/pip-release-links.html +iree-compiler +iree-runtime diff --git a/onnx_models/requirements.txt b/onnx_models/requirements.txt new file mode 100644 index 0000000..1b87572 --- /dev/null +++ b/onnx_models/requirements.txt @@ -0,0 +1,9 @@ +# Baseline requirements for running the test suite. 
+# * See requirements-iree.txt for using IREE packages. + +onnx +pyjson5 +pytest +pytest-reportlog +pytest-timeout +pytest-xdist From 8e82c2da912613d54879a59c2404183196f96857 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 11:44:27 -0700 Subject: [PATCH 02/26] Compile with iree-compile. --- onnx_models/artifacts/.gitignore | 1 + onnx_models/basic_test.py | 28 +++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/onnx_models/artifacts/.gitignore b/onnx_models/artifacts/.gitignore index 89f5b7b..c025930 100644 --- a/onnx_models/artifacts/.gitignore +++ b/onnx_models/artifacts/.gitignore @@ -1,2 +1,3 @@ *.mlir *.onnx +*.vmfb diff --git a/onnx_models/basic_test.py b/onnx_models/basic_test.py index 5e2ed56..94ab0cb 100644 --- a/onnx_models/basic_test.py +++ b/onnx_models/basic_test.py @@ -10,6 +10,7 @@ import subprocess import urllib.request from pathlib import Path +from typing import List logger = logging.getLogger(__name__) @@ -57,6 +58,29 @@ def import_onnx_model(onnx_path: Path): return imported_mlir_path +def compile_model(input_program: Path, config_name: str, compile_flags: List[str]): + cwd = THIS_DIR + compiled_module_path = input_program.with_name( + input_program.stem + f"_{config_name}.vmfb" + ) + compile_args = ["iree-compile", input_program.relative_to(cwd)] + compile_args.extend(compile_flags) + compile_args.extend(["-o", compiled_module_path.relative_to(cwd)]) + compile_cmd = subprocess.list2cmdline(compile_args) + logging.getLogger().info( + f"Launching compile command:\n" # + f" cd {cwd} && {compile_cmd}" + ) + ret = subprocess.run(compile_cmd, shell=True, capture_output=True, cwd=cwd) + if ret.returncode != 0: + logging.getLogger().error(f"Compilation of '{compiled_module_path}' failed") + logger.error(ret.stdout.decode("utf-8")) + logger.error("iree-import-onnx stderr:") + logger.error(ret.stderr.decode("utf-8")) + raise RuntimeError(f" '{compiled_module_path.name}' compile failed") + return compiled_module_path + + def test_basic(): print("test_basic") @@ -67,8 +91,10 @@ def test_basic(): upgraded_path = upgrade_onnx_model(original_path) imported_mlir_path = import_onnx_model(upgraded_path) + compiled_module_path = compile_model( + imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] + ) # TODO(scotttodd): Load input data - # TODO(scotttodd): Compile with IREE # TODO(scotttodd): Run with IREE # TODO(scotttodd): Load into ONNX Runtime # TODO(scotttodd): Run with ONNX Runtime From e0726d40f71fe79b45a302ab17af66b8bad00c41 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 12:09:38 -0700 Subject: [PATCH 03/26] Generate random test input data and run the compiled program. 
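The random inputs are written to disk as raw binary buffers that iree-run-module
can read back with `--input=<shape>x<dtype>=@<file>`. A minimal sketch of that
packing for a small float32 tensor (example values only):

```python
# Pack a flattened float32 tensor with struct.pack, matching the helper added
# below; the resulting file can then be fed to iree-run-module, e.g.:
#   iree-run-module --module=model_cpu.vmfb --input=2x2xf32=@input_0.bin ...
import struct

import numpy as np

data = np.arange(4, dtype=np.float32).reshape(2, 2)  # example values
flat = data.flatten().tolist()
with open("input_0.bin", "wb") as f:
    f.write(struct.pack(f"{len(flat)}f", *flat))  # "f" is the float32 format char
```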
--- onnx_models/.gitignore | 1 + onnx_models/artifacts/.gitignore | 3 - onnx_models/basic_test.py | 115 +++++++++++++++++++++++++------ 3 files changed, 96 insertions(+), 23 deletions(-) create mode 100644 onnx_models/.gitignore delete mode 100644 onnx_models/artifacts/.gitignore diff --git a/onnx_models/.gitignore b/onnx_models/.gitignore new file mode 100644 index 0000000..62dd1dd --- /dev/null +++ b/onnx_models/.gitignore @@ -0,0 +1 @@ +artifacts/* diff --git a/onnx_models/artifacts/.gitignore b/onnx_models/artifacts/.gitignore deleted file mode 100644 index c025930..0000000 --- a/onnx_models/artifacts/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -*.mlir -*.onnx -*.vmfb diff --git a/onnx_models/basic_test.py b/onnx_models/basic_test.py index 94ab0cb..cd63135 100644 --- a/onnx_models/basic_test.py +++ b/onnx_models/basic_test.py @@ -5,7 +5,9 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import logging +import numpy as np import onnx +import struct import pytest import subprocess import urllib.request @@ -13,6 +15,7 @@ from typing import List logger = logging.getLogger(__name__) +rng = np.random.default_rng(0) THIS_DIR = Path(__file__).parent ARTIFACTS_DIR = THIS_DIR / "artifacts" @@ -21,7 +24,7 @@ # TODO(#18289): use real frontend API, import model in-memory? -def upgrade_onnx_model(original_path: Path): +def upgrade_onnx_model_version(original_path: Path): original_model = onnx.load_model(original_path) converted_model = onnx.version_converter.convert_version( original_model, ONNX_CONVERTER_OUTPUT_MIN_VERSION @@ -37,7 +40,7 @@ def upgrade_onnx_model(original_path: Path): # TODO(#18289): use real frontend API, import model in-memory? -def import_onnx_model(onnx_path: Path): +def import_onnx_model_to_mlir(onnx_path: Path): imported_mlir_path = onnx_path.with_suffix(".mlir") logging.info( f"Importing '{onnx_path.relative_to(THIS_DIR)}' to '{imported_mlir_path.relative_to(THIS_DIR)}'" @@ -50,7 +53,8 @@ def import_onnx_model(onnx_path: Path): ] ret = subprocess.run(exec_args, capture_output=True) if ret.returncode != 0: - logger.error(f"Import of '{onnx_path.name}' failed!\niree-import-onnx stdout:") + logger.error(f"Import of '{onnx_path.name}' failed!") + logger.error("iree-import-onnx stdout:") logger.error(ret.stdout.decode("utf-8")) logger.error("iree-import-onnx stderr:") logger.error(ret.stderr.decode("utf-8")) @@ -58,44 +62,115 @@ def import_onnx_model(onnx_path: Path): return imported_mlir_path -def compile_model(input_program: Path, config_name: str, compile_flags: List[str]): +def compile_mlir_with_iree(mlir_path: Path, config_name: str, compile_flags: List[str]): cwd = THIS_DIR - compiled_module_path = input_program.with_name( - input_program.stem + f"_{config_name}.vmfb" - ) - compile_args = ["iree-compile", input_program.relative_to(cwd)] + iree_module_path = mlir_path.with_name(mlir_path.stem + f"_{config_name}.vmfb") + compile_args = ["iree-compile", mlir_path.relative_to(cwd)] compile_args.extend(compile_flags) - compile_args.extend(["-o", compiled_module_path.relative_to(cwd)]) + compile_args.extend(["-o", iree_module_path.relative_to(cwd)]) compile_cmd = subprocess.list2cmdline(compile_args) - logging.getLogger().info( + logger.info( f"Launching compile command:\n" # f" cd {cwd} && {compile_cmd}" ) ret = subprocess.run(compile_cmd, shell=True, capture_output=True, cwd=cwd) if ret.returncode != 0: - logging.getLogger().error(f"Compilation of '{compiled_module_path}' failed") + logger.error(f"Compilation of '{iree_module_path}' failed") + 
logger.error("iree-compile stdout:") logger.error(ret.stdout.decode("utf-8")) - logger.error("iree-import-onnx stderr:") + logger.error("iree-compile stderr:") + logger.error(ret.stderr.decode("utf-8")) + raise RuntimeError(f" '{iree_module_path.name}' compile failed") + return iree_module_path + + +def run_iree_module(iree_module_path: Path, run_flags: List[str]): + cwd = THIS_DIR + run_args = ["iree-run-module", f"--module={iree_module_path.relative_to(cwd)}"] + run_args.extend(run_flags) + run_cmd = subprocess.list2cmdline(run_args) + logger.info( + f"Launching run command:\n" # + f" cd {cwd} && {run_cmd}" + ) + ret = subprocess.run(run_cmd, shell=True, capture_output=True, cwd=cwd) + if ret.returncode != 0: + logger.error(f"Run of '{iree_module_path}' failed") + logger.error("iree-run-module stdout:") + logger.error(ret.stdout.decode("utf-8")) + logger.error("iree-run-module stderr:") logger.error(ret.stderr.decode("utf-8")) - raise RuntimeError(f" '{compiled_module_path.name}' compile failed") - return compiled_module_path + raise RuntimeError(f" '{iree_module_path.name}' run failed") + logger.info(f"Run of '{iree_module_path}' succeeded") + logger.info("iree-run-module stdout:") + logger.info(ret.stdout.decode("utf-8")) + logger.info("iree-run-module stderr:") + logger.info(ret.stderr.decode("utf-8")) + + +# map numpy dtype -> (iree dtype, struct.pack format str) +numpy_to_iree_dtype_map = { + np.dtype("int64"): ("si64", "q"), + np.dtype("uint64"): ("ui64", "Q"), + np.dtype("int32"): ("si32", "i"), + np.dtype("uint32"): ("ui32", "I"), + np.dtype("int16"): ("si16", "h"), + np.dtype("uint16"): ("ui16", "H"), + np.dtype("int8"): ("si8", "b"), + np.dtype("uint8"): ("ui8", "B"), + np.dtype("float64"): ("f64", "d"), + np.dtype("float32"): ("f32", "f"), + np.dtype("float16"): ("f16", "e"), + np.dtype("bool"): ("i1", "?"), +} + + +def pack_ndarray_to_binary(ndarr: np.ndarray): + mylist = ndarr.flatten().tolist() + dtype = ndarr.dtype + bytearr = b"" + if dtype in numpy_to_iree_dtype_map: + iree_dtype = numpy_to_iree_dtype_map[dtype][1] + bytearr = struct.pack(f"{len(mylist)}{iree_dtype}", *mylist) + else: + raise NotImplementedError( + f"Unsupported data type in pack_ndarray_to_binary(): '{dtype}'" + ) + return bytearr + + +def write_binary_to_file(ndarr: np.ndarray, filename: Path): + with open(filename, "wb") as f: + bytearr = pack_ndarray_to_binary(ndarr) + f.write(bytearr) def test_basic(): - print("test_basic") + if not ARTIFACTS_DIR.is_dir(): + ARTIFACTS_DIR.mkdir(parents=True) # TODO(scotttodd): move to fixture with cache / download on demand onnx_url = "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx" original_path = ARTIFACTS_DIR / "mobilenetv2-12.onnx" # urllib.request.urlretrieve(onnx_url, original_path) - upgraded_path = upgrade_onnx_model(original_path) - imported_mlir_path = import_onnx_model(upgraded_path) - compiled_module_path = compile_model( + upgraded_path = upgrade_onnx_model_version(original_path) + imported_mlir_path = import_onnx_model_to_mlir(upgraded_path) + iree_module_path = compile_mlir_with_iree( imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] ) - # TODO(scotttodd): Load input data - # TODO(scotttodd): Run with IREE + + # TODO(scotttodd): prepare_input helper function + random_data = rng.random((1, 3, 224, 224), dtype=np.float32) + random_data_path = original_path.with_name(original_path.stem + "_input_0.bin") + write_binary_to_file(random_data, random_data_path) + # 
logger.info(dummy_data) + + run_iree_module( + iree_module_path, + ["--device=local-task", f"--input=1x3x224x224xf32=@{random_data_path}"], + ) + # TODO(scotttodd): Load into ONNX Runtime # TODO(scotttodd): Run with ONNX Runtime # TODO(scotttodd): Compare results From 64c4c6f2def557cefe82d2f8feffddd3e96d84b8 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 13:04:22 -0700 Subject: [PATCH 04/26] Run through onnxruntime. No results comparison yet. --- onnx_models/basic_test.py | 38 +++++++++++++++++++++--------------- onnx_models/requirements.txt | 1 + 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/onnx_models/basic_test.py b/onnx_models/basic_test.py index cd63135..05cbb26 100644 --- a/onnx_models/basic_test.py +++ b/onnx_models/basic_test.py @@ -11,6 +11,7 @@ import pytest import subprocess import urllib.request +from onnxruntime import InferenceSession from pathlib import Path from typing import List @@ -24,19 +25,19 @@ # TODO(#18289): use real frontend API, import model in-memory? -def upgrade_onnx_model_version(original_path: Path): - original_model = onnx.load_model(original_path) +def upgrade_onnx_model_version(original_onnx_path: Path): + original_model = onnx.load_model(original_onnx_path) converted_model = onnx.version_converter.convert_version( original_model, ONNX_CONVERTER_OUTPUT_MIN_VERSION ) - upgraded_path = original_path.with_name( - original_path.stem + f"_version{ONNX_CONVERTER_OUTPUT_MIN_VERSION}.onnx" + upgraded_onnx_path = original_onnx_path.with_name( + original_onnx_path.stem + f"_version{ONNX_CONVERTER_OUTPUT_MIN_VERSION}.onnx" ) logging.info( - f"Upgrading '{original_path.relative_to(THIS_DIR)}' to '{upgraded_path.relative_to(THIS_DIR)}'" + f"Upgrading '{original_onnx_path.relative_to(THIS_DIR)}' to '{upgraded_onnx_path.relative_to(THIS_DIR)}'" ) - onnx.save(converted_model, upgraded_path) - return upgraded_path + onnx.save(converted_model, upgraded_onnx_path) + return upgraded_onnx_path # TODO(#18289): use real frontend API, import model in-memory? 
@@ -101,6 +102,7 @@ def run_iree_module(iree_module_path: Path, run_flags: List[str]): logger.error("iree-run-module stderr:") logger.error(ret.stderr.decode("utf-8")) raise RuntimeError(f" '{iree_module_path.name}' run failed") + # TODO(scotttodd): write outputs to files, or use --expected_output logger.info(f"Run of '{iree_module_path}' succeeded") logger.info("iree-run-module stdout:") logger.info(ret.stdout.decode("utf-8")) @@ -151,26 +153,30 @@ def test_basic(): # TODO(scotttodd): move to fixture with cache / download on demand onnx_url = "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx" - original_path = ARTIFACTS_DIR / "mobilenetv2-12.onnx" - # urllib.request.urlretrieve(onnx_url, original_path) + original_onnx_path = ARTIFACTS_DIR / "mobilenetv2-12.onnx" + # urllib.request.urlretrieve(onnx_url, original_onnx_path) - upgraded_path = upgrade_onnx_model_version(original_path) - imported_mlir_path = import_onnx_model_to_mlir(upgraded_path) + upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) + imported_mlir_path = import_onnx_model_to_mlir(upgraded_onnx_path) iree_module_path = compile_mlir_with_iree( imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] ) - # TODO(scotttodd): prepare_input helper function + # # TODO(scotttodd): prepare_input helper function random_data = rng.random((1, 3, 224, 224), dtype=np.float32) - random_data_path = original_path.with_name(original_path.stem + "_input_0.bin") + random_data_path = original_onnx_path.with_name( + original_onnx_path.stem + "_input_0.bin" + ) write_binary_to_file(random_data, random_data_path) - # logger.info(dummy_data) + # logger.info(random_data) run_iree_module( iree_module_path, ["--device=local-task", f"--input=1x3x224x224xf32=@{random_data_path}"], ) - # TODO(scotttodd): Load into ONNX Runtime - # TODO(scotttodd): Run with ONNX Runtime + onnx_session = InferenceSession(upgraded_onnx_path) + onnx_results = onnx_session.run(["output"], {"input": random_data}) + # logger.info(onnx_results) + # TODO(scotttodd): Compare results diff --git a/onnx_models/requirements.txt b/onnx_models/requirements.txt index 1b87572..24feafa 100644 --- a/onnx_models/requirements.txt +++ b/onnx_models/requirements.txt @@ -2,6 +2,7 @@ # * See requirements-iree.txt for using IREE packages. onnx +onnxruntime pyjson5 pytest pytest-reportlog From afc0abf577bc680b3ae602da80b2bde8010acb07 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 13:12:55 -0700 Subject: [PATCH 05/26] Try resnet50. 
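Switching models means new hard-coded input/output names ("data" and
"resnetv17_dense0_fwd" here). The signature can also be read from the model
itself, which later patches in this series move toward. A sketch using the onnx
protobuf, assuming the model has already been downloaded into artifacts/:

```python
# Read input/output names and shapes from the ONNX protobuf instead of
# hard-coding them per model.
import onnx

model = onnx.load("artifacts/resnet50-v1-12.onnx")
for value_info in list(model.graph.input) + list(model.graph.output):
    dims = [
        d.dim_value if d.HasField("dim_value") else d.dim_param
        for d in value_info.type.tensor_type.shape.dim
    ]
    print(value_info.name, dims)
# Expected along the lines of:
#   data ['N', 3, 224, 224]
#   resnetv17_dense0_fwd ['N', 1000]
```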
--- onnx_models/basic_test.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/onnx_models/basic_test.py b/onnx_models/basic_test.py index 05cbb26..a8edb8d 100644 --- a/onnx_models/basic_test.py +++ b/onnx_models/basic_test.py @@ -152,8 +152,10 @@ def test_basic(): ARTIFACTS_DIR.mkdir(parents=True) # TODO(scotttodd): move to fixture with cache / download on demand - onnx_url = "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx" - original_onnx_path = ARTIFACTS_DIR / "mobilenetv2-12.onnx" + # onnx_url = "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx" + # original_onnx_path = ARTIFACTS_DIR / "mobilenetv2-12.onnx" + onnx_url = "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx" + original_onnx_path = ARTIFACTS_DIR / "resnet50-v1-12.onnx" # urllib.request.urlretrieve(onnx_url, original_onnx_path) upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) @@ -176,7 +178,20 @@ def test_basic(): ) onnx_session = InferenceSession(upgraded_onnx_path) - onnx_results = onnx_session.run(["output"], {"input": random_data}) - # logger.info(onnx_results) + # onnx_results = onnx_session.run(["output"], {"input": random_data}) + onnx_results = onnx_session.run(["resnetv17_dense0_fwd"], {"data": random_data}) + logger.info(onnx_results) # TODO(scotttodd): Compare results + + +# What varies between each test: +# Model URL +# Model name +# Function signature +# Number of inputs +# Names of inputs +# Shapes of inputs +# Number of outputs +# Names of outputs +# Shapes of outputs From 3e984810c98a6ef26ed41194bbe7afb341a56516 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 13:27:39 -0700 Subject: [PATCH 06/26] Compare outputs between ONNX Runtime and IREE. 
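The comparison is delegated to iree-run-module's `--expected_output` flag. The
in-memory alternative noted in the code comment below would look roughly like
this (tolerances are an assumption, and the IREE-side dump path is a
hypothetical name, not produced by this patch):

```python
# Read the raw float32 buffers back and compare in numpy instead of relying on
# --expected_output. Shapes match the ResNet output used in this patch.
import numpy as np

expected = np.fromfile(
    "artifacts/resnet50-v1-12_output_0.bin", dtype=np.float32
).reshape(1, 1000)
# Hypothetical file: would require dumping IREE's output to disk as well.
actual = np.fromfile("artifacts/iree_output_0.bin", dtype=np.float32).reshape(1, 1000)
np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)
```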
--- onnx_models/basic_test.py | 52 +++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/onnx_models/basic_test.py b/onnx_models/basic_test.py index a8edb8d..2afdd0b 100644 --- a/onnx_models/basic_test.py +++ b/onnx_models/basic_test.py @@ -103,11 +103,11 @@ def run_iree_module(iree_module_path: Path, run_flags: List[str]): logger.error(ret.stderr.decode("utf-8")) raise RuntimeError(f" '{iree_module_path.name}' run failed") # TODO(scotttodd): write outputs to files, or use --expected_output - logger.info(f"Run of '{iree_module_path}' succeeded") - logger.info("iree-run-module stdout:") - logger.info(ret.stdout.decode("utf-8")) - logger.info("iree-run-module stderr:") - logger.info(ret.stderr.decode("utf-8")) + # logger.info(f"Run of '{iree_module_path}' succeeded") + # logger.info("iree-run-module stdout:") + # logger.info(ret.stdout.decode("utf-8")) + # logger.info("iree-run-module stderr:") + # logger.info(ret.stderr.decode("utf-8")) # map numpy dtype -> (iree dtype, struct.pack format str) @@ -159,31 +159,40 @@ def test_basic(): # urllib.request.urlretrieve(onnx_url, original_onnx_path) upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) - imported_mlir_path = import_onnx_model_to_mlir(upgraded_onnx_path) - iree_module_path = compile_mlir_with_iree( - imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] - ) # # TODO(scotttodd): prepare_input helper function - random_data = rng.random((1, 3, 224, 224), dtype=np.float32) - random_data_path = original_onnx_path.with_name( + input_data = rng.random((1, 3, 224, 224), dtype=np.float32) + input_data_path = original_onnx_path.with_name( original_onnx_path.stem + "_input_0.bin" ) - write_binary_to_file(random_data, random_data_path) - # logger.info(random_data) + write_binary_to_file(input_data, input_data_path) + # logger.info(input_data) + + # Run through ONNX Runtime. + onnx_session = InferenceSession(upgraded_onnx_path) + # onnx_results = onnx_session.run(["output"], {"input": input_data}) + onnx_results = onnx_session.run(["resnetv17_dense0_fwd"], {"data": input_data}) + # logger.info(np.array(onnx_results[0])) + reference_output_data_path = original_onnx_path.with_name( + original_onnx_path.stem + "_output_0.bin" + ) + write_binary_to_file(onnx_results[0], reference_output_data_path) + # Import, compile, then run with IREE. + imported_mlir_path = import_onnx_model_to_mlir(upgraded_onnx_path) + iree_module_path = compile_mlir_with_iree( + imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] + ) + # Note: could load the output into memory here and compare using numpy. 
run_iree_module( iree_module_path, - ["--device=local-task", f"--input=1x3x224x224xf32=@{random_data_path}"], + [ + "--device=local-task", + f"--input=1x3x224x224xf32=@{input_data_path}", + f"--expected_output=1x1000xf32=@{reference_output_data_path}", + ], ) - onnx_session = InferenceSession(upgraded_onnx_path) - # onnx_results = onnx_session.run(["output"], {"input": random_data}) - onnx_results = onnx_session.run(["resnetv17_dense0_fwd"], {"data": random_data}) - logger.info(onnx_results) - - # TODO(scotttodd): Compare results - # What varies between each test: # Model URL @@ -195,3 +204,4 @@ def test_basic(): # Number of outputs # Names of outputs # Shapes of outputs +# Can get the function signature from the loaded ONNX model From f6c1f28467a13e63ffd0dc5d3658eb2edc22d55c Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 13:49:57 -0700 Subject: [PATCH 07/26] Extract common code into a helper, test resnet and mobilenet. --- onnx_models/basic_test.py | 2 + onnx_models/conftest.py | 221 ++++++++++++++++++++++++++++++++++ onnx_models/helper_fn_test.py | 31 +++++ 3 files changed, 254 insertions(+) create mode 100644 onnx_models/conftest.py create mode 100644 onnx_models/helper_fn_test.py diff --git a/onnx_models/basic_test.py b/onnx_models/basic_test.py index 2afdd0b..6a0fe75 100644 --- a/onnx_models/basic_test.py +++ b/onnx_models/basic_test.py @@ -151,6 +151,8 @@ def test_basic(): if not ARTIFACTS_DIR.is_dir(): ARTIFACTS_DIR.mkdir(parents=True) + # TODO(scotttodd): group model artifacts into subfolders + # TODO(scotttodd): move to fixture with cache / download on demand # onnx_url = "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx" # original_onnx_path = ARTIFACTS_DIR / "mobilenetv2-12.onnx" diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py new file mode 100644 index 0000000..6f19c31 --- /dev/null +++ b/onnx_models/conftest.py @@ -0,0 +1,221 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import numpy as np +import onnx +import struct +import pytest +import subprocess +import urllib.request +from onnxruntime import InferenceSession +from pathlib import Path +from typing import List + +logger = logging.getLogger(__name__) +rng = np.random.default_rng(0) + +THIS_DIR = Path(__file__).parent +ARTIFACTS_DIR = THIS_DIR / "artifacts" + +ONNX_CONVERTER_OUTPUT_MIN_VERSION = 17 + + +# TODO(#18289): use real frontend API, import model in-memory? +def upgrade_onnx_model_version(original_onnx_path: Path): + original_model = onnx.load_model(original_onnx_path) + converted_model = onnx.version_converter.convert_version( + original_model, ONNX_CONVERTER_OUTPUT_MIN_VERSION + ) + upgraded_onnx_path = original_onnx_path.with_name( + original_onnx_path.stem + f"_version{ONNX_CONVERTER_OUTPUT_MIN_VERSION}.onnx" + ) + logging.info( + f"Upgrading '{original_onnx_path.relative_to(THIS_DIR)}' to '{upgraded_onnx_path.relative_to(THIS_DIR)}'" + ) + onnx.save(converted_model, upgraded_onnx_path) + return upgraded_onnx_path + + +# TODO(#18289): use real frontend API, import model in-memory? 
+def import_onnx_model_to_mlir(onnx_path: Path): + imported_mlir_path = onnx_path.with_suffix(".mlir") + logging.info( + f"Importing '{onnx_path.relative_to(THIS_DIR)}' to '{imported_mlir_path.relative_to(THIS_DIR)}'" + ) + exec_args = [ + "iree-import-onnx", + str(onnx_path), + "-o", + str(imported_mlir_path), + ] + ret = subprocess.run(exec_args, capture_output=True) + if ret.returncode != 0: + logger.error(f"Import of '{onnx_path.name}' failed!") + logger.error("iree-import-onnx stdout:") + logger.error(ret.stdout.decode("utf-8")) + logger.error("iree-import-onnx stderr:") + logger.error(ret.stderr.decode("utf-8")) + raise RuntimeError(f" '{onnx_path.name}' import failed") + return imported_mlir_path + + +def compile_mlir_with_iree(mlir_path: Path, config_name: str, compile_flags: List[str]): + cwd = THIS_DIR + iree_module_path = mlir_path.with_name(mlir_path.stem + f"_{config_name}.vmfb") + compile_args = ["iree-compile", mlir_path.relative_to(cwd)] + compile_args.extend(compile_flags) + compile_args.extend(["-o", iree_module_path.relative_to(cwd)]) + compile_cmd = subprocess.list2cmdline(compile_args) + logger.info( + f"Launching compile command:\n" # + f" cd {cwd} && {compile_cmd}" + ) + ret = subprocess.run(compile_cmd, shell=True, capture_output=True, cwd=cwd) + if ret.returncode != 0: + logger.error(f"Compilation of '{iree_module_path}' failed") + logger.error("iree-compile stdout:") + logger.error(ret.stdout.decode("utf-8")) + logger.error("iree-compile stderr:") + logger.error(ret.stderr.decode("utf-8")) + raise RuntimeError(f" '{iree_module_path.name}' compile failed") + return iree_module_path + + +def run_iree_module(iree_module_path: Path, run_flags: List[str]): + cwd = THIS_DIR + run_args = ["iree-run-module", f"--module={iree_module_path.relative_to(cwd)}"] + run_args.extend(run_flags) + run_cmd = subprocess.list2cmdline(run_args) + logger.info( + f"Launching run command:\n" # + f" cd {cwd} && {run_cmd}" + ) + ret = subprocess.run(run_cmd, shell=True, capture_output=True, cwd=cwd) + if ret.returncode != 0: + logger.error(f"Run of '{iree_module_path}' failed") + logger.error("iree-run-module stdout:") + logger.error(ret.stdout.decode("utf-8")) + logger.error("iree-run-module stderr:") + logger.error(ret.stderr.decode("utf-8")) + raise RuntimeError(f" '{iree_module_path.name}' run failed") + # TODO(scotttodd): write outputs to files, or use --expected_output + # logger.info(f"Run of '{iree_module_path}' succeeded") + # logger.info("iree-run-module stdout:") + # logger.info(ret.stdout.decode("utf-8")) + # logger.info("iree-run-module stderr:") + # logger.info(ret.stderr.decode("utf-8")) + + +# map numpy dtype -> (iree dtype, struct.pack format str) +numpy_to_iree_dtype_map = { + np.dtype("int64"): ("si64", "q"), + np.dtype("uint64"): ("ui64", "Q"), + np.dtype("int32"): ("si32", "i"), + np.dtype("uint32"): ("ui32", "I"), + np.dtype("int16"): ("si16", "h"), + np.dtype("uint16"): ("ui16", "H"), + np.dtype("int8"): ("si8", "b"), + np.dtype("uint8"): ("ui8", "B"), + np.dtype("float64"): ("f64", "d"), + np.dtype("float32"): ("f32", "f"), + np.dtype("float16"): ("f16", "e"), + np.dtype("bool"): ("i1", "?"), +} + + +def pack_ndarray_to_binary(ndarr: np.ndarray): + mylist = ndarr.flatten().tolist() + dtype = ndarr.dtype + bytearr = b"" + if dtype in numpy_to_iree_dtype_map: + iree_dtype = numpy_to_iree_dtype_map[dtype][1] + bytearr = struct.pack(f"{len(mylist)}{iree_dtype}", *mylist) + else: + raise NotImplementedError( + f"Unsupported data type in pack_ndarray_to_binary(): '{dtype}'" 
+ ) + return bytearr + + +def write_binary_to_file(ndarr: np.ndarray, filename: Path): + with open(filename, "wb") as f: + bytearr = pack_ndarray_to_binary(ndarr) + f.write(bytearr) + + +# What varies between each test: +# Model URL +# Model name +# Function signature +# Number of inputs +# Names of inputs +# Shapes of inputs +# Number of outputs +# Names of outputs +# Shapes of outputs +# Can get the function signature from the loaded ONNX model + + +@pytest.fixture +def compare_between_iree_and_onnxruntime(): + def fn( + model_name: str, + model_url: str, + input_name: str, + input_shape: tuple[int, ...], + input_type: str, + output_name: str, + output_shape: tuple[int, ...], + output_type: str, + ): + if not ARTIFACTS_DIR.is_dir(): + ARTIFACTS_DIR.mkdir(parents=True) + # TODO(scotttodd): group model artifacts into subfolders + + # TODO(scotttodd): move to fixture with cache / download on demand + # TODO(scotttodd): extract name from URL? + original_onnx_path = ARTIFACTS_DIR / f"{model_name}.onnx" + urllib.request.urlretrieve(model_url, original_onnx_path) + + upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) + + # TODO(scotttodd): prepare_input helper function, multiple inputs + # TODO(scotttodd): dtype from input_shape (or ONNX model reflection) + input_data = rng.random(input_shape, dtype=np.float32) + input_data_path = original_onnx_path.with_name( + original_onnx_path.stem + "_input_0.bin" + ) + write_binary_to_file(input_data, input_data_path) + # logger.info(input_data) + + # Run through ONNX Runtime. + onnx_session = InferenceSession(upgraded_onnx_path) + # TODO(scotttodd): multiple inputs/outputs + onnx_results = onnx_session.run([output_name], {input_name: input_data}) + # logger.info(np.array(onnx_results[0])) + reference_output_data_path = original_onnx_path.with_name( + original_onnx_path.stem + "_output_0.bin" + ) + write_binary_to_file(onnx_results[0], reference_output_data_path) + + # Import, compile, then run with IREE. + imported_mlir_path = import_onnx_model_to_mlir(upgraded_onnx_path) + iree_module_path = compile_mlir_with_iree( + imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] + ) + # Note: could load the output into memory here and compare using numpy. + # TODO(scotttodd): signature conversions from onnx/numpy to IREE + run_iree_module( + iree_module_path, + [ + "--device=local-task", + f"--input=1x3x224x224xf32=@{input_data_path}", + f"--expected_output=1x1000xf32=@{reference_output_data_path}", + ], + ) + + return fn diff --git a/onnx_models/helper_fn_test.py b/onnx_models/helper_fn_test.py new file mode 100644 index 0000000..1f90afd --- /dev/null +++ b/onnx_models/helper_fn_test.py @@ -0,0 +1,31 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_name="mobilenetv2-12", + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", + input_name="input", + input_shape=(1, 3, 224, 224), + input_type="", + output_name="output", + output_shape=(), + output_type="", + ) + + +def test_resnet50_v1_12(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_name="resnet50-v1-12", + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx", + input_name="data", + input_shape=(1, 3, 224, 224), + input_type="", + output_name="resnetv17_dense0_fwd", + output_shape=(), + output_type="", + ) From 1f7d28900638cad6edd2914170af2a4a6318561a Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 15:13:39 -0700 Subject: [PATCH 08/26] Progress on extracting metadata from onnx protos. --- onnx_models/conftest.py | 245 ++++++++++++++++++++++++++++------ onnx_models/helper_fn_test.py | 7 + 2 files changed, 212 insertions(+), 40 deletions(-) diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index 6f19c31..6be4b44 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -21,6 +21,51 @@ THIS_DIR = Path(__file__).parent ARTIFACTS_DIR = THIS_DIR / "artifacts" +############################################################################### +# General utilities +############################################################################### + +# map numpy dtype -> (iree dtype, struct.pack format str) +numpy_to_iree_dtype_map = { + np.dtype("int64"): ("si64", "q"), + np.dtype("uint64"): ("ui64", "Q"), + np.dtype("int32"): ("si32", "i"), + np.dtype("uint32"): ("ui32", "I"), + np.dtype("int16"): ("si16", "h"), + np.dtype("uint16"): ("ui16", "H"), + np.dtype("int8"): ("si8", "b"), + np.dtype("uint8"): ("ui8", "B"), + np.dtype("float64"): ("f64", "d"), + np.dtype("float32"): ("f32", "f"), + np.dtype("float16"): ("f16", "e"), + np.dtype("bool"): ("i1", "?"), +} + + +def pack_ndarray_to_binary(ndarr: np.ndarray): + mylist = ndarr.flatten().tolist() + dtype = ndarr.dtype + bytearr = b"" + if dtype in numpy_to_iree_dtype_map: + iree_dtype = numpy_to_iree_dtype_map[dtype][1] + bytearr = struct.pack(f"{len(mylist)}{iree_dtype}", *mylist) + else: + raise NotImplementedError( + f"Unsupported data type in pack_ndarray_to_binary(): '{dtype}'" + ) + return bytearr + + +def write_binary_to_file(ndarr: np.ndarray, filename: Path): + with open(filename, "wb") as f: + bytearr = pack_ndarray_to_binary(ndarr) + f.write(bytearr) + + +############################################################################### +# ONNX loading, running, import, etc. 
+############################################################################### + ONNX_CONVERTER_OUTPUT_MIN_VERSION = 17 @@ -63,6 +108,139 @@ def import_onnx_model_to_mlir(onnx_path: Path): return imported_mlir_path +def convert_proto_elem_type_to_iree_dtype(etype): + if etype == onnx.TensorProto.FLOAT: + return "f32" + if etype == onnx.TensorProto.UINT8: + return "i8" + if etype == onnx.TensorProto.INT8: + return "i8" + if etype == onnx.TensorProto.UINT16: + return "i16" + if etype == onnx.TensorProto.INT16: + return "i16" + if etype == onnx.TensorProto.INT32: + return "i32" + if etype == onnx.TensorProto.INT64: + return "i64" + if etype == onnx.TensorProto.BOOL: + return "i1" + if etype == onnx.TensorProto.FLOAT16: + return "f16" + if etype == onnx.TensorProto.DOUBLE: + return "f64" + if etype == onnx.TensorProto.UINT32: + return "i32" + if etype == onnx.TensorProto.UINT64: + return "i64" + if etype == onnx.TensorProto.COMPLEX64: + return "complex" + if etype == onnx.TensorProto.COMPLEX128: + return "complex" + if etype == onnx.TensorProto.BFLOAT16: + return "bf16" + if etype == onnx.TensorProto.FLOAT8E4M3FN: + return "f8e4m3fn" + if etype == onnx.TensorProto.FLOAT8E4M3FNUZ: + return "f8e4m3fnuz" + if etype == onnx.TensorProto.FLOAT8E5M2: + return "f8e5m2" + if etype == onnx.TensorProto.FLOAT8E5M2FNUZ: + return "f8e5m2fnuz" + if etype == onnx.TensorProto.UINT4: + return "i4" + if etype == onnx.TensorProto.INT4: + return "i4" + return "" + + +def convert_onnx_type_proto_to_numpy_dimensions( + type_proto: onnx.onnx_ml_pb2.TypeProto, +) -> str: + if type_proto.HasField("tensor_type"): + # Note: turning dynamic dimensions into just 1 here, since we need + # a concrete (static) shape buffer of input data in the tests. + return tuple( + d.dim_value if d.HasField("dim_value") else 1 + for d in type_proto.tensor_type.shape.dim + ) + else: + raise NotImplementedError(f"Unsupported proto type: {type_proto}") + + +def convert_onnx_type_proto_to_iree_type_string( + type_proto: onnx.onnx_ml_pb2.TypeProto, +) -> str: + if type_proto.HasField("tensor_type"): + tensor_type = type_proto.tensor_type + shape = tensor_type.shape + # Note: turning dynamic dimensions into just "1" here, since we need + # a concrete (static) shape buffer of input data in the tests. 
+ shape = "x".join( + [str(d.dim_value) if d.HasField("dim_value") else "1" for d in shape.dim] + ) + dtype = convert_proto_elem_type_to_iree_dtype(tensor_type.elem_type) + if shape == "": + return dtype + return f"{shape}x{dtype}" + else: + raise NotImplementedError(f"Unsupported proto type: {type_proto}") + + +def get_onnx_model_metadata(onnx_path: Path): + logger.info(f"Getting model metadata for '{onnx_path.relative_to(THIS_DIR)}'") + model = onnx.load(onnx_path) + + inputs = [] + outputs = [] + + # input_data = rng.random(input_shape, dtype=np.float32) + # input_data_path = original_onnx_path.with_name( + # original_onnx_path.stem + "_input_0.bin" + # ) + # write_binary_to_file(input_data, input_data_path) + # logger.debug(input_data) + + # help(model.graph.input) + # print(model.graph.input) + for graph_input in model.graph.input: + numpy_dimensions = convert_onnx_type_proto_to_numpy_dimensions(graph_input.type) + iree_type = convert_onnx_type_proto_to_iree_type_string(graph_input.type) + inputs.append( + { + "name": graph_input.name, + "numpy_dimensions": numpy_dimensions, + "iree_type": iree_type, + } + ) + for graph_output in model.graph.output: + numpy_dimensions = convert_onnx_type_proto_to_numpy_dimensions( + graph_output.type + ) + iree_type = convert_onnx_type_proto_to_iree_type_string(graph_output.type) + outputs.append( + { + "name": graph_output.name, + "numpy_dimensions": numpy_dimensions, + "iree_type": iree_type, + } + ) + + # Concrete shape to generate test data with + # (N, 3, 224, 224), with dynamic dim "N" should turn into + # (1, 3, 224, 224) + + return { + "inputs": inputs, + "outputs": outputs, + } + + +############################################################################### +# IREE compilation and running +############################################################################### + + def compile_mlir_with_iree(mlir_path: Path, config_name: str, compile_flags: List[str]): cwd = THIS_DIR iree_module_path = mlir_path.with_name(mlir_path.stem + f"_{config_name}.vmfb") @@ -110,43 +288,6 @@ def run_iree_module(iree_module_path: Path, run_flags: List[str]): # logger.info(ret.stderr.decode("utf-8")) -# map numpy dtype -> (iree dtype, struct.pack format str) -numpy_to_iree_dtype_map = { - np.dtype("int64"): ("si64", "q"), - np.dtype("uint64"): ("ui64", "Q"), - np.dtype("int32"): ("si32", "i"), - np.dtype("uint32"): ("ui32", "I"), - np.dtype("int16"): ("si16", "h"), - np.dtype("uint16"): ("ui16", "H"), - np.dtype("int8"): ("si8", "b"), - np.dtype("uint8"): ("ui8", "B"), - np.dtype("float64"): ("f64", "d"), - np.dtype("float32"): ("f32", "f"), - np.dtype("float16"): ("f16", "e"), - np.dtype("bool"): ("i1", "?"), -} - - -def pack_ndarray_to_binary(ndarr: np.ndarray): - mylist = ndarr.flatten().tolist() - dtype = ndarr.dtype - bytearr = b"" - if dtype in numpy_to_iree_dtype_map: - iree_dtype = numpy_to_iree_dtype_map[dtype][1] - bytearr = struct.pack(f"{len(mylist)}{iree_dtype}", *mylist) - else: - raise NotImplementedError( - f"Unsupported data type in pack_ndarray_to_binary(): '{dtype}'" - ) - return bytearr - - -def write_binary_to_file(ndarr: np.ndarray, filename: Path): - with open(filename, "wb") as f: - bytearr = pack_ndarray_to_binary(ndarr) - f.write(bytearr) - - # What varies between each test: # Model URL # Model name @@ -178,11 +319,17 @@ def fn( # TODO(scotttodd): move to fixture with cache / download on demand # TODO(scotttodd): extract name from URL? + # TODO(scotttodd): overwrite if already existing? check SHA? 
original_onnx_path = ARTIFACTS_DIR / f"{model_name}.onnx" - urllib.request.urlretrieve(model_url, original_onnx_path) + if not original_onnx_path.exists(): + urllib.request.urlretrieve(model_url, original_onnx_path) upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) + onnx_model_metadata = get_onnx_model_metadata(upgraded_onnx_path) + logger.info(onnx_model_metadata) + return + # TODO(scotttodd): prepare_input helper function, multiple inputs # TODO(scotttodd): dtype from input_shape (or ONNX model reflection) input_data = rng.random(input_shape, dtype=np.float32) @@ -190,13 +337,31 @@ def fn( original_onnx_path.stem + "_input_0.bin" ) write_binary_to_file(input_data, input_data_path) - # logger.info(input_data) + logger.debug(input_data) # Run through ONNX Runtime. onnx_session = InferenceSession(upgraded_onnx_path) + + # We can either + # A) List all metadata explicitly + # B) Get metadata on demand from the .onnx protobuf using 'onnx' + # C) Get metadata on demand from the InferenceSession using 'onnxruntime' + inputs = onnx_session.get_inputs() + logger.info("inputs") + for input in inputs: + logger.info(f"{input.name}, {input.shape}, {input.type}") + # if input.is_tensor(): + # logger.info(f" input element type: {input.element_type}") + outputs = onnx_session.get_outputs() + logger.info("outputs") + for output in outputs: + logger.info(f"{output.name}, {output.shape}, {output.type}") + # input[0] : data, ['N', 3, 224, 224], tensor(float) + # output[0]: resnetv17_dense0_fwd, ['N', 1000], tensor(float) + # TODO(scotttodd): multiple inputs/outputs onnx_results = onnx_session.run([output_name], {input_name: input_data}) - # logger.info(np.array(onnx_results[0])) + logger.debug(np.array(onnx_results[0])) reference_output_data_path = original_onnx_path.with_name( original_onnx_path.stem + "_output_0.bin" ) diff --git a/onnx_models/helper_fn_test.py b/onnx_models/helper_fn_test.py index 1f90afd..b890026 100644 --- a/onnx_models/helper_fn_test.py +++ b/onnx_models/helper_fn_test.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( model_name="mobilenetv2-12", @@ -18,6 +19,7 @@ def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): ) +# https://github.com/onnx/models/tree/main/validated/vision/classification/resnet def test_resnet50_v1_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( model_name="resnet50-v1-12", @@ -29,3 +31,8 @@ def test_resnet50_v1_12(compare_between_iree_and_onnxruntime): output_shape=(), output_type="", ) + + +# TODO(scotttodd): add annotations: +# xfail (with Exception subclass / reason) +# marks (size of test, hardware required, etc.) From 43c77f5b47633c6c5a2e12e33a809b298ca8ee6b Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 15:52:14 -0700 Subject: [PATCH 09/26] Extract metadata and generate test inputs automatically. 
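Each graph input's tensor type is turned into both a concrete numpy shape
(dynamic dims pinned to 1) and an IREE type string such as `1x3x224x224xf32`.
A compact sketch of that conversion, simplified to the float32 case and
assuming the upgraded model from the earlier patches is present (the patch
below handles the full dtype table):

```python
# ONNX tensor type -> concrete shape + IREE type string (float32 only here).
import onnx

model = onnx.load("artifacts/mobilenetv2-12_version17.onnx")
tensor_type = model.graph.input[0].type.tensor_type
# Dynamic dimensions (dim_param such as 'N') become 1 so a static buffer of
# test data can be generated.
dims = [d.dim_value if d.HasField("dim_value") else 1 for d in tensor_type.shape.dim]
assert tensor_type.elem_type == onnx.TensorProto.FLOAT
iree_type = "x".join(str(d) for d in dims) + "xf32"
print(dims, iree_type)  # e.g. [1, 3, 224, 224] 1x3x224x224xf32
```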
--- onnx_models/conftest.py | 202 ++++++++++++++-------------------- onnx_models/helper_fn_test.py | 12 -- 2 files changed, 82 insertions(+), 132 deletions(-) diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index 6be4b44..1308f86 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -11,6 +11,7 @@ import pytest import subprocess import urllib.request +from onnx.mapping import TENSOR_TYPE_MAP from onnxruntime import InferenceSession from pathlib import Path from typing import List @@ -154,82 +155,81 @@ def convert_proto_elem_type_to_iree_dtype(etype): return "" -def convert_onnx_type_proto_to_numpy_dimensions( - type_proto: onnx.onnx_ml_pb2.TypeProto, +def convert_onnx_tensor_proto_to_numpy_dimensions( + tensor_proto: onnx.onnx_ml_pb2.TensorProto, ) -> str: - if type_proto.HasField("tensor_type"): - # Note: turning dynamic dimensions into just 1 here, since we need - # a concrete (static) shape buffer of input data in the tests. - return tuple( - d.dim_value if d.HasField("dim_value") else 1 - for d in type_proto.tensor_type.shape.dim - ) - else: - raise NotImplementedError(f"Unsupported proto type: {type_proto}") + # Note: turning dynamic dimensions into just 1 here, since we need + # a concrete (static) shape buffer of input data in the tests. + return tuple( + d.dim_value if d.HasField("dim_value") else 1 for d in tensor_proto.shape.dim + ) -def convert_onnx_type_proto_to_iree_type_string( - type_proto: onnx.onnx_ml_pb2.TypeProto, +def convert_onnx_tensor_proto_to_iree_type_string( + tensor_proto: onnx.onnx_ml_pb2.TensorProto, ) -> str: - if type_proto.HasField("tensor_type"): - tensor_type = type_proto.tensor_type - shape = tensor_type.shape - # Note: turning dynamic dimensions into just "1" here, since we need - # a concrete (static) shape buffer of input data in the tests. - shape = "x".join( - [str(d.dim_value) if d.HasField("dim_value") else "1" for d in shape.dim] - ) - dtype = convert_proto_elem_type_to_iree_dtype(tensor_type.elem_type) - if shape == "": - return dtype - return f"{shape}x{dtype}" - else: - raise NotImplementedError(f"Unsupported proto type: {type_proto}") + shape = tensor_proto.shape + # Note: turning dynamic dimensions into just "1" here, since we need + # a concrete (static) shape buffer of input data in the tests. + shape = "x".join( + [str(d.dim_value) if d.HasField("dim_value") else "1" for d in shape.dim] + ) + dtype = convert_proto_elem_type_to_iree_dtype(tensor_proto.elem_type) + if shape == "": + return dtype + return f"{shape}x{dtype}" def get_onnx_model_metadata(onnx_path: Path): + # We can either + # A) List all metadata explicitly + # B) Get metadata on demand from the .onnx protobuf using 'onnx' + # C) Get metadata on demand from the InferenceSession using 'onnxruntime' + # This is option B. + logger.info(f"Getting model metadata for '{onnx_path.relative_to(THIS_DIR)}'") model = onnx.load(onnx_path) inputs = [] - outputs = [] + for idx, graph_input in enumerate(model.graph.input): + type_proto = graph_input.type + if not type_proto.HasField("tensor_type"): + raise NotImplementedError(f"Unsupported proto type: {type_proto}") + tensor_type = type_proto.tensor_type + + # Create a numpy tensor with some random data for the input. 
+ numpy_dimensions = convert_onnx_tensor_proto_to_numpy_dimensions(tensor_type) + numpy_dtype = TENSOR_TYPE_MAP[tensor_type.elem_type].np_dtype + input_data = rng.random(numpy_dimensions, dtype=numpy_dtype) + logger.debug(input_data) + input_data_path = onnx_path.with_name(onnx_path.stem + f"_input_{idx}.bin") + write_binary_to_file(input_data, input_data_path) - # input_data = rng.random(input_shape, dtype=np.float32) - # input_data_path = original_onnx_path.with_name( - # original_onnx_path.stem + "_input_0.bin" - # ) - # write_binary_to_file(input_data, input_data_path) - # logger.debug(input_data) - - # help(model.graph.input) - # print(model.graph.input) - for graph_input in model.graph.input: - numpy_dimensions = convert_onnx_type_proto_to_numpy_dimensions(graph_input.type) - iree_type = convert_onnx_type_proto_to_iree_type_string(graph_input.type) + iree_type = convert_onnx_tensor_proto_to_iree_type_string(tensor_type) inputs.append( { "name": graph_input.name, - "numpy_dimensions": numpy_dimensions, "iree_type": iree_type, + "input_data": input_data, + "input_data_path": input_data_path, } ) + + outputs = [] for graph_output in model.graph.output: - numpy_dimensions = convert_onnx_type_proto_to_numpy_dimensions( - graph_output.type - ) - iree_type = convert_onnx_type_proto_to_iree_type_string(graph_output.type) + type_proto = graph_output.type + if not type_proto.HasField("tensor_type"): + raise NotImplementedError(f"Unsupported proto type: {type_proto}") + tensor_type = type_proto.tensor_type + + iree_type = convert_onnx_tensor_proto_to_iree_type_string(tensor_type) outputs.append( { "name": graph_output.name, - "numpy_dimensions": numpy_dimensions, "iree_type": iree_type, } ) - # Concrete shape to generate test data with - # (N, 3, 224, 224), with dynamic dim "N" should turn into - # (1, 3, 224, 224) - return { "inputs": inputs, "outputs": outputs, @@ -280,25 +280,6 @@ def run_iree_module(iree_module_path: Path, run_flags: List[str]): logger.error("iree-run-module stderr:") logger.error(ret.stderr.decode("utf-8")) raise RuntimeError(f" '{iree_module_path.name}' run failed") - # TODO(scotttodd): write outputs to files, or use --expected_output - # logger.info(f"Run of '{iree_module_path}' succeeded") - # logger.info("iree-run-module stdout:") - # logger.info(ret.stdout.decode("utf-8")) - # logger.info("iree-run-module stderr:") - # logger.info(ret.stderr.decode("utf-8")) - - -# What varies between each test: -# Model URL -# Model name -# Function signature -# Number of inputs -# Names of inputs -# Shapes of inputs -# Number of outputs -# Names of outputs -# Shapes of outputs -# Can get the function signature from the loaded ONNX model @pytest.fixture @@ -306,12 +287,6 @@ def compare_between_iree_and_onnxruntime(): def fn( model_name: str, model_url: str, - input_name: str, - input_shape: tuple[int, ...], - input_type: str, - output_name: str, - output_shape: tuple[int, ...], - output_type: str, ): if not ARTIFACTS_DIR.is_dir(): ARTIFACTS_DIR.mkdir(parents=True) @@ -324,48 +299,41 @@ def fn( if not original_onnx_path.exists(): urllib.request.urlretrieve(model_url, original_onnx_path) + # TODO(scotttodd): cache ONNX metadata and runtime results upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) onnx_model_metadata = get_onnx_model_metadata(upgraded_onnx_path) - logger.info(onnx_model_metadata) - return - - # TODO(scotttodd): prepare_input helper function, multiple inputs - # TODO(scotttodd): dtype from input_shape (or ONNX model reflection) - input_data = 
rng.random(input_shape, dtype=np.float32) - input_data_path = original_onnx_path.with_name( - original_onnx_path.stem + "_input_0.bin" - ) - write_binary_to_file(input_data, input_data_path) - logger.debug(input_data) + logger.debug("ONNX model metadata:") + logger.debug(onnx_model_metadata) # Run through ONNX Runtime. onnx_session = InferenceSession(upgraded_onnx_path) - - # We can either - # A) List all metadata explicitly - # B) Get metadata on demand from the .onnx protobuf using 'onnx' - # C) Get metadata on demand from the InferenceSession using 'onnxruntime' - inputs = onnx_session.get_inputs() - logger.info("inputs") - for input in inputs: - logger.info(f"{input.name}, {input.shape}, {input.type}") - # if input.is_tensor(): - # logger.info(f" input element type: {input.element_type}") - outputs = onnx_session.get_outputs() - logger.info("outputs") - for output in outputs: - logger.info(f"{output.name}, {output.shape}, {output.type}") - # input[0] : data, ['N', 3, 224, 224], tensor(float) - # output[0]: resnetv17_dense0_fwd, ['N', 1000], tensor(float) - - # TODO(scotttodd): multiple inputs/outputs - onnx_results = onnx_session.run([output_name], {input_name: input_data}) - logger.debug(np.array(onnx_results[0])) - reference_output_data_path = original_onnx_path.with_name( - original_onnx_path.stem + "_output_0.bin" - ) - write_binary_to_file(onnx_results[0], reference_output_data_path) + output_names = [output["name"] for output in onnx_model_metadata["outputs"]] + inputs = {} + for input in onnx_model_metadata["inputs"]: + inputs[input["name"]] = input["input_data"] + onnx_results = onnx_session.run(output_names, inputs) + + # Prepare inputs and expected outputs for running through IREE. + run_module_args = [] + for input in onnx_model_metadata["inputs"]: + input_type = input["iree_type"] + input_data_path = input["input_data_path"] + run_module_args.append(f"--input={input_type}=@{input_data_path}") + + assert len(onnx_model_metadata["outputs"]) == len(onnx_results) + for idx in range(len(onnx_results)): + output = onnx_model_metadata["outputs"][idx] + output_type = output["iree_type"] + onnx_result = onnx_results[idx] + logger.debug(np.array(onnx_result)) + reference_output_data_path = original_onnx_path.with_name( + original_onnx_path.stem + f"_output_{idx}.bin" + ) + write_binary_to_file(onnx_result, reference_output_data_path) + run_module_args.append( + f"--expected_output={output_type}=@{reference_output_data_path}" + ) # Import, compile, then run with IREE. imported_mlir_path = import_onnx_model_to_mlir(upgraded_onnx_path) @@ -373,14 +341,8 @@ def fn( imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] ) # Note: could load the output into memory here and compare using numpy. 
- # TODO(scotttodd): signature conversions from onnx/numpy to IREE - run_iree_module( - iree_module_path, - [ - "--device=local-task", - f"--input=1x3x224x224xf32=@{input_data_path}", - f"--expected_output=1x1000xf32=@{reference_output_data_path}", - ], - ) + run_flags = ["--device=local-task"] + run_flags.extend(run_module_args) + run_iree_module(iree_module_path, run_flags) return fn diff --git a/onnx_models/helper_fn_test.py b/onnx_models/helper_fn_test.py index b890026..4057e56 100644 --- a/onnx_models/helper_fn_test.py +++ b/onnx_models/helper_fn_test.py @@ -10,12 +10,6 @@ def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( model_name="mobilenetv2-12", model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", - input_name="input", - input_shape=(1, 3, 224, 224), - input_type="", - output_name="output", - output_shape=(), - output_type="", ) @@ -24,12 +18,6 @@ def test_resnet50_v1_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( model_name="resnet50-v1-12", model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx", - input_name="data", - input_shape=(1, 3, 224, 224), - input_type="", - output_name="resnetv17_dense0_fwd", - output_shape=(), - output_type="", ) From 498d516d5454ad44c3a4db75d61a698b5426ec48 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 16:14:45 -0700 Subject: [PATCH 10/26] Cleanup, initial xfail testing support. --- onnx_models/__init__.py | 0 onnx_models/basic_test.py | 209 ------------------ onnx_models/conftest.py | 15 +- onnx_models/utils.py | 17 ++ ...elper_fn_test.py => vision_models_test.py} | 20 +- 5 files changed, 43 insertions(+), 218 deletions(-) create mode 100644 onnx_models/__init__.py delete mode 100644 onnx_models/basic_test.py create mode 100644 onnx_models/utils.py rename onnx_models/{helper_fn_test.py => vision_models_test.py} (63%) diff --git a/onnx_models/__init__.py b/onnx_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/onnx_models/basic_test.py b/onnx_models/basic_test.py deleted file mode 100644 index 6a0fe75..0000000 --- a/onnx_models/basic_test.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2024 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import numpy as np -import onnx -import struct -import pytest -import subprocess -import urllib.request -from onnxruntime import InferenceSession -from pathlib import Path -from typing import List - -logger = logging.getLogger(__name__) -rng = np.random.default_rng(0) - -THIS_DIR = Path(__file__).parent -ARTIFACTS_DIR = THIS_DIR / "artifacts" - -ONNX_CONVERTER_OUTPUT_MIN_VERSION = 17 - - -# TODO(#18289): use real frontend API, import model in-memory? 
-def upgrade_onnx_model_version(original_onnx_path: Path): - original_model = onnx.load_model(original_onnx_path) - converted_model = onnx.version_converter.convert_version( - original_model, ONNX_CONVERTER_OUTPUT_MIN_VERSION - ) - upgraded_onnx_path = original_onnx_path.with_name( - original_onnx_path.stem + f"_version{ONNX_CONVERTER_OUTPUT_MIN_VERSION}.onnx" - ) - logging.info( - f"Upgrading '{original_onnx_path.relative_to(THIS_DIR)}' to '{upgraded_onnx_path.relative_to(THIS_DIR)}'" - ) - onnx.save(converted_model, upgraded_onnx_path) - return upgraded_onnx_path - - -# TODO(#18289): use real frontend API, import model in-memory? -def import_onnx_model_to_mlir(onnx_path: Path): - imported_mlir_path = onnx_path.with_suffix(".mlir") - logging.info( - f"Importing '{onnx_path.relative_to(THIS_DIR)}' to '{imported_mlir_path.relative_to(THIS_DIR)}'" - ) - exec_args = [ - "iree-import-onnx", - str(onnx_path), - "-o", - str(imported_mlir_path), - ] - ret = subprocess.run(exec_args, capture_output=True) - if ret.returncode != 0: - logger.error(f"Import of '{onnx_path.name}' failed!") - logger.error("iree-import-onnx stdout:") - logger.error(ret.stdout.decode("utf-8")) - logger.error("iree-import-onnx stderr:") - logger.error(ret.stderr.decode("utf-8")) - raise RuntimeError(f" '{onnx_path.name}' import failed") - return imported_mlir_path - - -def compile_mlir_with_iree(mlir_path: Path, config_name: str, compile_flags: List[str]): - cwd = THIS_DIR - iree_module_path = mlir_path.with_name(mlir_path.stem + f"_{config_name}.vmfb") - compile_args = ["iree-compile", mlir_path.relative_to(cwd)] - compile_args.extend(compile_flags) - compile_args.extend(["-o", iree_module_path.relative_to(cwd)]) - compile_cmd = subprocess.list2cmdline(compile_args) - logger.info( - f"Launching compile command:\n" # - f" cd {cwd} && {compile_cmd}" - ) - ret = subprocess.run(compile_cmd, shell=True, capture_output=True, cwd=cwd) - if ret.returncode != 0: - logger.error(f"Compilation of '{iree_module_path}' failed") - logger.error("iree-compile stdout:") - logger.error(ret.stdout.decode("utf-8")) - logger.error("iree-compile stderr:") - logger.error(ret.stderr.decode("utf-8")) - raise RuntimeError(f" '{iree_module_path.name}' compile failed") - return iree_module_path - - -def run_iree_module(iree_module_path: Path, run_flags: List[str]): - cwd = THIS_DIR - run_args = ["iree-run-module", f"--module={iree_module_path.relative_to(cwd)}"] - run_args.extend(run_flags) - run_cmd = subprocess.list2cmdline(run_args) - logger.info( - f"Launching run command:\n" # - f" cd {cwd} && {run_cmd}" - ) - ret = subprocess.run(run_cmd, shell=True, capture_output=True, cwd=cwd) - if ret.returncode != 0: - logger.error(f"Run of '{iree_module_path}' failed") - logger.error("iree-run-module stdout:") - logger.error(ret.stdout.decode("utf-8")) - logger.error("iree-run-module stderr:") - logger.error(ret.stderr.decode("utf-8")) - raise RuntimeError(f" '{iree_module_path.name}' run failed") - # TODO(scotttodd): write outputs to files, or use --expected_output - # logger.info(f"Run of '{iree_module_path}' succeeded") - # logger.info("iree-run-module stdout:") - # logger.info(ret.stdout.decode("utf-8")) - # logger.info("iree-run-module stderr:") - # logger.info(ret.stderr.decode("utf-8")) - - -# map numpy dtype -> (iree dtype, struct.pack format str) -numpy_to_iree_dtype_map = { - np.dtype("int64"): ("si64", "q"), - np.dtype("uint64"): ("ui64", "Q"), - np.dtype("int32"): ("si32", "i"), - np.dtype("uint32"): ("ui32", "I"), - np.dtype("int16"): 
("si16", "h"), - np.dtype("uint16"): ("ui16", "H"), - np.dtype("int8"): ("si8", "b"), - np.dtype("uint8"): ("ui8", "B"), - np.dtype("float64"): ("f64", "d"), - np.dtype("float32"): ("f32", "f"), - np.dtype("float16"): ("f16", "e"), - np.dtype("bool"): ("i1", "?"), -} - - -def pack_ndarray_to_binary(ndarr: np.ndarray): - mylist = ndarr.flatten().tolist() - dtype = ndarr.dtype - bytearr = b"" - if dtype in numpy_to_iree_dtype_map: - iree_dtype = numpy_to_iree_dtype_map[dtype][1] - bytearr = struct.pack(f"{len(mylist)}{iree_dtype}", *mylist) - else: - raise NotImplementedError( - f"Unsupported data type in pack_ndarray_to_binary(): '{dtype}'" - ) - return bytearr - - -def write_binary_to_file(ndarr: np.ndarray, filename: Path): - with open(filename, "wb") as f: - bytearr = pack_ndarray_to_binary(ndarr) - f.write(bytearr) - - -def test_basic(): - if not ARTIFACTS_DIR.is_dir(): - ARTIFACTS_DIR.mkdir(parents=True) - - # TODO(scotttodd): group model artifacts into subfolders - - # TODO(scotttodd): move to fixture with cache / download on demand - # onnx_url = "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx" - # original_onnx_path = ARTIFACTS_DIR / "mobilenetv2-12.onnx" - onnx_url = "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx" - original_onnx_path = ARTIFACTS_DIR / "resnet50-v1-12.onnx" - # urllib.request.urlretrieve(onnx_url, original_onnx_path) - - upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) - - # # TODO(scotttodd): prepare_input helper function - input_data = rng.random((1, 3, 224, 224), dtype=np.float32) - input_data_path = original_onnx_path.with_name( - original_onnx_path.stem + "_input_0.bin" - ) - write_binary_to_file(input_data, input_data_path) - # logger.info(input_data) - - # Run through ONNX Runtime. - onnx_session = InferenceSession(upgraded_onnx_path) - # onnx_results = onnx_session.run(["output"], {"input": input_data}) - onnx_results = onnx_session.run(["resnetv17_dense0_fwd"], {"data": input_data}) - # logger.info(np.array(onnx_results[0])) - reference_output_data_path = original_onnx_path.with_name( - original_onnx_path.stem + "_output_0.bin" - ) - write_binary_to_file(onnx_results[0], reference_output_data_path) - - # Import, compile, then run with IREE. - imported_mlir_path = import_onnx_model_to_mlir(upgraded_onnx_path) - iree_module_path = compile_mlir_with_iree( - imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] - ) - # Note: could load the output into memory here and compare using numpy. 
- run_iree_module( - iree_module_path, - [ - "--device=local-task", - f"--input=1x3x224x224xf32=@{input_data_path}", - f"--expected_output=1x1000xf32=@{reference_output_data_path}", - ], - ) - - -# What varies between each test: -# Model URL -# Model name -# Function signature -# Number of inputs -# Names of inputs -# Shapes of inputs -# Number of outputs -# Names of outputs -# Shapes of outputs -# Can get the function signature from the loaded ONNX model diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index 1308f86..4ce14d0 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -16,6 +16,8 @@ from pathlib import Path from typing import List +from .utils import * + logger = logging.getLogger(__name__) rng = np.random.default_rng(0) @@ -105,7 +107,7 @@ def import_onnx_model_to_mlir(onnx_path: Path): logger.error(ret.stdout.decode("utf-8")) logger.error("iree-import-onnx stderr:") logger.error(ret.stderr.decode("utf-8")) - raise RuntimeError(f" '{onnx_path.name}' import failed") + raise IreeImportOnnxException(f" '{onnx_path.name}' import failed") return imported_mlir_path @@ -200,7 +202,12 @@ def get_onnx_model_metadata(onnx_path: Path): # Create a numpy tensor with some random data for the input. numpy_dimensions = convert_onnx_tensor_proto_to_numpy_dimensions(tensor_type) numpy_dtype = TENSOR_TYPE_MAP[tensor_type.elem_type].np_dtype - input_data = rng.random(numpy_dimensions, dtype=numpy_dtype) + if numpy_dtype == np.float32 or numpy_dtype == np.float64: + input_data = rng.random(numpy_dimensions, dtype=numpy_dtype) + elif numpy_dtype == np.int32 or numpy_dtype == np.int64: + input_data = rng.integers(numpy_dimensions, dtype=numpy_dtype) + else: + raise NotImplementedError(f"Unsupported numpy type: {numpy_dtype}") logger.debug(input_data) input_data_path = onnx_path.with_name(onnx_path.stem + f"_input_{idx}.bin") write_binary_to_file(input_data, input_data_path) @@ -259,7 +266,7 @@ def compile_mlir_with_iree(mlir_path: Path, config_name: str, compile_flags: Lis logger.error(ret.stdout.decode("utf-8")) logger.error("iree-compile stderr:") logger.error(ret.stderr.decode("utf-8")) - raise RuntimeError(f" '{iree_module_path.name}' compile failed") + raise IreeCompileException(f" '{iree_module_path.name}' compile failed") return iree_module_path @@ -279,7 +286,7 @@ def run_iree_module(iree_module_path: Path, run_flags: List[str]): logger.error(ret.stdout.decode("utf-8")) logger.error("iree-run-module stderr:") logger.error(ret.stderr.decode("utf-8")) - raise RuntimeError(f" '{iree_module_path.name}' run failed") + raise IreeRunException(f" '{iree_module_path.name}' run failed") @pytest.fixture diff --git a/onnx_models/utils.py b/onnx_models/utils.py new file mode 100644 index 0000000..75bc712 --- /dev/null +++ b/onnx_models/utils.py @@ -0,0 +1,17 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +class IreeImportOnnxException(RuntimeError): + pass + + +class IreeCompileException(RuntimeError): + pass + + +class IreeRunException(RuntimeError): + pass diff --git a/onnx_models/helper_fn_test.py b/onnx_models/vision_models_test.py similarity index 63% rename from onnx_models/helper_fn_test.py rename to onnx_models/vision_models_test.py index 4057e56..929f466 100644 --- a/onnx_models/helper_fn_test.py +++ b/onnx_models/vision_models_test.py @@ -4,8 +4,23 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + +from .utils import * + + +# https://github.com/onnx/models/tree/main/validated/vision/classification/mnist +# TODO(scotttodd): fix test runner (only use "Input3" input?) +# @pytest.mark.xfail(raises=IreeCompileException) +# def test_mnist_7(compare_between_iree_and_onnxruntime): +# compare_between_iree_and_onnxruntime( +# model_name="mnist-7", +# model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-7.onnx", +# ) + # https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet +@pytest.mark.xfail(raises=IreeRunException) def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( model_name="mobilenetv2-12", @@ -19,8 +34,3 @@ def test_resnet50_v1_12(compare_between_iree_and_onnxruntime): model_name="resnet50-v1-12", model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx", ) - - -# TODO(scotttodd): add annotations: -# xfail (with Exception subclass / reason) -# marks (size of test, hardware required, etc.) From ad2340ade4c637af027338b0c916c7d693c67d96 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 16:16:14 -0700 Subject: [PATCH 11/26] Add test workflow. --- .github/workflows/test_onnx_models.yml | 62 ++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 .github/workflows/test_onnx_models.yml diff --git a/.github/workflows/test_onnx_models.yml b/.github/workflows/test_onnx_models.yml new file mode 100644 index 0000000..258e0e5 --- /dev/null +++ b/.github/workflows/test_onnx_models.yml @@ -0,0 +1,62 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: Test ONNX Models +on: + push: + branches: + - main + paths: + - ".github/workflows/test_onnx_models.yml" + - "onnx_models/**" + pull_request: + paths: + - ".github/workflows/test_onnx_models.yml" + - "onnx_models/**" + workflow_dispatch: + schedule: + # Runs at 3:00 PM UTC, which is 8:00 AM PST + - cron: "0 15 * * *" + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test-onnx-models: + runs-on: ubuntu-24.04 + env: + VENV_DIR: ${{ github.workspace }}/.venv + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + # Install Python packages. 
+ - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Setup Python venv + run: python3 -m venv ${VENV_DIR} + - name: Install IREE nightly release Python packages + run: | + source ${VENV_DIR}/bin/activate + python3 -m pip install -r onnx_models/requirements-iree.txt + + # Run tests. + - name: Run ONNX models test suite + run: | + source ${VENV_DIR}/bin/activate + pytest onnx_models/ \ + -n auto \ + -rA \ + --log-cli-level=info \ + --timeout=60 \ + --durations=0 From e601d9314f624388cbfb4b75c6d50b1a31cf5407 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 16:18:36 -0700 Subject: [PATCH 12/26] Update docs. --- README.md | 6 ++++-- onnx_models/README.md | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 40f6121..cb98a96 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,9 @@ See https://groups.google.com/g/iree-discuss/c/GIWyj8hmP0k/ for context. [![Test ONNX Models](https://github.com/iree-org/iree-test-suites/actions/workflows/test_onnx_models.yml/badge.svg?branch=main)](https://github.com/iree-org/iree-test-suites/actions/workflows/test_onnx_models.yml?query=branch%3Amain) -TODO: overview / details +* Tests that import, compile, and run ONNX models through IREE then compare + the outputs against a reference (ONNX Runtime). +* Runnable via [pytest](https://docs.pytest.org/). ### [onnx_ops/](onnx_ops/) : Open Neural Network Exchange operations @@ -30,5 +32,5 @@ TODO: overview / details * 1250+ tests for [ONNX](https://onnx.ai/) framework [operators](https://onnx.ai/onnx/operators/). -* Runnable via [pytest](https://docs.pytest.org/en/stable/) using a +* Runnable via [pytest](https://docs.pytest.org/) using a configurable set of flags to `iree-compile` and `iree-run-module`. diff --git a/onnx_models/README.md b/onnx_models/README.md index 92c2873..25a4913 100644 --- a/onnx_models/README.md +++ b/onnx_models/README.md @@ -48,8 +48,7 @@ graph LR pytest \ -n auto \ -rA \ - --timeout=30 \ - --durations=20 \ + --durations=0 \ ``` See https://docs.pytest.org/en/stable/how-to/usage.html for other options. From 18073c704db62bc921453984971dc3b13a45c338 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 16:24:48 -0700 Subject: [PATCH 13/26] Trim dep. --- onnx_models/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/onnx_models/requirements.txt b/onnx_models/requirements.txt index 24feafa..9ded4d3 100644 --- a/onnx_models/requirements.txt +++ b/onnx_models/requirements.txt @@ -3,7 +3,6 @@ onnx onnxruntime -pyjson5 pytest pytest-reportlog pytest-timeout From 628b1eb00dd87b5e956a252cdb809cb58d4c808f Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 16:43:54 -0700 Subject: [PATCH 14/26] Adjust flags for more complete log output. 
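For context on why the logging flags matter: the test helpers log their compile and run commands through Python's standard `logging` module, so pytest's live-log level decides how much of that output reaches the terminal. A minimal sketch of the pattern, with a made-up test body for illustration:

    import logging

    logger = logging.getLogger(__name__)

    def test_example():
        # Printed to the terminal when pytest runs with --log-cli-level=info.
        logger.info("launching iree-compile ...")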
--- .github/workflows/test_onnx_models.yml | 1 - onnx_models/README.md | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test_onnx_models.yml b/.github/workflows/test_onnx_models.yml index 258e0e5..0fe5ddf 100644 --- a/.github/workflows/test_onnx_models.yml +++ b/.github/workflows/test_onnx_models.yml @@ -55,7 +55,6 @@ jobs: run: | source ${VENV_DIR}/bin/activate pytest onnx_models/ \ - -n auto \ -rA \ --log-cli-level=info \ --timeout=60 \ diff --git a/onnx_models/README.md b/onnx_models/README.md index 25a4913..d80260a 100644 --- a/onnx_models/README.md +++ b/onnx_models/README.md @@ -46,8 +46,8 @@ graph LR ```bash pytest \ - -n auto \ -rA \ + --log-cli-level=info --durations=0 \ ``` From c688c8dbd07d97f7ef89b3581a13b40809ebb74c Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 6 Sep 2024 16:46:31 -0700 Subject: [PATCH 15/26] Use logger consistently. --- onnx_models/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index 4ce14d0..873b99e 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -81,7 +81,7 @@ def upgrade_onnx_model_version(original_onnx_path: Path): upgraded_onnx_path = original_onnx_path.with_name( original_onnx_path.stem + f"_version{ONNX_CONVERTER_OUTPUT_MIN_VERSION}.onnx" ) - logging.info( + logger.info( f"Upgrading '{original_onnx_path.relative_to(THIS_DIR)}' to '{upgraded_onnx_path.relative_to(THIS_DIR)}'" ) onnx.save(converted_model, upgraded_onnx_path) @@ -91,7 +91,7 @@ def upgrade_onnx_model_version(original_onnx_path: Path): # TODO(#18289): use real frontend API, import model in-memory? def import_onnx_model_to_mlir(onnx_path: Path): imported_mlir_path = onnx_path.with_suffix(".mlir") - logging.info( + logger.info( f"Importing '{onnx_path.relative_to(THIS_DIR)}' to '{imported_mlir_path.relative_to(THIS_DIR)}'" ) exec_args = [ From 70dd0734484b25f6f8d8ec2924d1a69ad6b55963 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 08:28:05 -0700 Subject: [PATCH 16/26] Move more util code from conftest.py to utils.py. 
--- onnx_models/conftest.py | 132 +----------------------------------- onnx_models/utils.py | 144 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 130 deletions(-) diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index 873b99e..5f78bb8 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -7,7 +7,6 @@ import logging import numpy as np import onnx -import struct import pytest import subprocess import urllib.request @@ -24,138 +23,11 @@ THIS_DIR = Path(__file__).parent ARTIFACTS_DIR = THIS_DIR / "artifacts" -############################################################################### -# General utilities -############################################################################### - -# map numpy dtype -> (iree dtype, struct.pack format str) -numpy_to_iree_dtype_map = { - np.dtype("int64"): ("si64", "q"), - np.dtype("uint64"): ("ui64", "Q"), - np.dtype("int32"): ("si32", "i"), - np.dtype("uint32"): ("ui32", "I"), - np.dtype("int16"): ("si16", "h"), - np.dtype("uint16"): ("ui16", "H"), - np.dtype("int8"): ("si8", "b"), - np.dtype("uint8"): ("ui8", "B"), - np.dtype("float64"): ("f64", "d"), - np.dtype("float32"): ("f32", "f"), - np.dtype("float16"): ("f16", "e"), - np.dtype("bool"): ("i1", "?"), -} - - -def pack_ndarray_to_binary(ndarr: np.ndarray): - mylist = ndarr.flatten().tolist() - dtype = ndarr.dtype - bytearr = b"" - if dtype in numpy_to_iree_dtype_map: - iree_dtype = numpy_to_iree_dtype_map[dtype][1] - bytearr = struct.pack(f"{len(mylist)}{iree_dtype}", *mylist) - else: - raise NotImplementedError( - f"Unsupported data type in pack_ndarray_to_binary(): '{dtype}'" - ) - return bytearr - - -def write_binary_to_file(ndarr: np.ndarray, filename: Path): - with open(filename, "wb") as f: - bytearr = pack_ndarray_to_binary(ndarr) - f.write(bytearr) - ############################################################################### # ONNX loading, running, import, etc. ############################################################################### -ONNX_CONVERTER_OUTPUT_MIN_VERSION = 17 - - -# TODO(#18289): use real frontend API, import model in-memory? -def upgrade_onnx_model_version(original_onnx_path: Path): - original_model = onnx.load_model(original_onnx_path) - converted_model = onnx.version_converter.convert_version( - original_model, ONNX_CONVERTER_OUTPUT_MIN_VERSION - ) - upgraded_onnx_path = original_onnx_path.with_name( - original_onnx_path.stem + f"_version{ONNX_CONVERTER_OUTPUT_MIN_VERSION}.onnx" - ) - logger.info( - f"Upgrading '{original_onnx_path.relative_to(THIS_DIR)}' to '{upgraded_onnx_path.relative_to(THIS_DIR)}'" - ) - onnx.save(converted_model, upgraded_onnx_path) - return upgraded_onnx_path - - -# TODO(#18289): use real frontend API, import model in-memory? 
-def import_onnx_model_to_mlir(onnx_path: Path): - imported_mlir_path = onnx_path.with_suffix(".mlir") - logger.info( - f"Importing '{onnx_path.relative_to(THIS_DIR)}' to '{imported_mlir_path.relative_to(THIS_DIR)}'" - ) - exec_args = [ - "iree-import-onnx", - str(onnx_path), - "-o", - str(imported_mlir_path), - ] - ret = subprocess.run(exec_args, capture_output=True) - if ret.returncode != 0: - logger.error(f"Import of '{onnx_path.name}' failed!") - logger.error("iree-import-onnx stdout:") - logger.error(ret.stdout.decode("utf-8")) - logger.error("iree-import-onnx stderr:") - logger.error(ret.stderr.decode("utf-8")) - raise IreeImportOnnxException(f" '{onnx_path.name}' import failed") - return imported_mlir_path - - -def convert_proto_elem_type_to_iree_dtype(etype): - if etype == onnx.TensorProto.FLOAT: - return "f32" - if etype == onnx.TensorProto.UINT8: - return "i8" - if etype == onnx.TensorProto.INT8: - return "i8" - if etype == onnx.TensorProto.UINT16: - return "i16" - if etype == onnx.TensorProto.INT16: - return "i16" - if etype == onnx.TensorProto.INT32: - return "i32" - if etype == onnx.TensorProto.INT64: - return "i64" - if etype == onnx.TensorProto.BOOL: - return "i1" - if etype == onnx.TensorProto.FLOAT16: - return "f16" - if etype == onnx.TensorProto.DOUBLE: - return "f64" - if etype == onnx.TensorProto.UINT32: - return "i32" - if etype == onnx.TensorProto.UINT64: - return "i64" - if etype == onnx.TensorProto.COMPLEX64: - return "complex" - if etype == onnx.TensorProto.COMPLEX128: - return "complex" - if etype == onnx.TensorProto.BFLOAT16: - return "bf16" - if etype == onnx.TensorProto.FLOAT8E4M3FN: - return "f8e4m3fn" - if etype == onnx.TensorProto.FLOAT8E4M3FNUZ: - return "f8e4m3fnuz" - if etype == onnx.TensorProto.FLOAT8E5M2: - return "f8e5m2" - if etype == onnx.TensorProto.FLOAT8E5M2FNUZ: - return "f8e5m2fnuz" - if etype == onnx.TensorProto.UINT4: - return "i4" - if etype == onnx.TensorProto.INT4: - return "i4" - return "" - def convert_onnx_tensor_proto_to_numpy_dimensions( tensor_proto: onnx.onnx_ml_pb2.TensorProto, @@ -210,7 +82,7 @@ def get_onnx_model_metadata(onnx_path: Path): raise NotImplementedError(f"Unsupported numpy type: {numpy_dtype}") logger.debug(input_data) input_data_path = onnx_path.with_name(onnx_path.stem + f"_input_{idx}.bin") - write_binary_to_file(input_data, input_data_path) + write_ndarray_to_binary_file(input_data, input_data_path) iree_type = convert_onnx_tensor_proto_to_iree_type_string(tensor_type) inputs.append( @@ -337,7 +209,7 @@ def fn( reference_output_data_path = original_onnx_path.with_name( original_onnx_path.stem + f"_output_{idx}.bin" ) - write_binary_to_file(onnx_result, reference_output_data_path) + write_ndarray_to_binary_file(onnx_result, reference_output_data_path) run_module_args.append( f"--expected_output={output_type}=@{reference_output_data_path}" ) diff --git a/onnx_models/utils.py b/onnx_models/utils.py index 75bc712..91ebe32 100644 --- a/onnx_models/utils.py +++ b/onnx_models/utils.py @@ -4,6 +4,21 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import logging +import numpy as np +import onnx +import struct +import subprocess +from pathlib import Path + +logger = logging.getLogger(__name__) + +THIS_DIR = Path(__file__).parent + +############################################################################### +# Exception types +############################################################################### + class IreeImportOnnxException(RuntimeError): pass @@ -15,3 +30,132 @@ class IreeCompileException(RuntimeError): class IreeRunException(RuntimeError): pass + + +############################################################################### +# Numpy utilities +############################################################################### + + +def write_ndarray_to_binary_file(ndarr: np.ndarray, filename: Path): + with open(filename, "wb") as f: + bytearr = pack_ndarray_to_binary(ndarr) + f.write(bytearr) + + +# map numpy dtype -> (iree dtype, struct.pack format str) +numpy_to_iree_dtype_map = { + np.dtype("int64"): ("si64", "q"), + np.dtype("uint64"): ("ui64", "Q"), + np.dtype("int32"): ("si32", "i"), + np.dtype("uint32"): ("ui32", "I"), + np.dtype("int16"): ("si16", "h"), + np.dtype("uint16"): ("ui16", "H"), + np.dtype("int8"): ("si8", "b"), + np.dtype("uint8"): ("ui8", "B"), + np.dtype("float64"): ("f64", "d"), + np.dtype("float32"): ("f32", "f"), + np.dtype("float16"): ("f16", "e"), + np.dtype("bool"): ("i1", "?"), +} + + +def pack_ndarray_to_binary(ndarr: np.ndarray): + mylist = ndarr.flatten().tolist() + dtype = ndarr.dtype + bytearr = b"" + if dtype in numpy_to_iree_dtype_map: + iree_dtype = numpy_to_iree_dtype_map[dtype][1] + bytearr = struct.pack(f"{len(mylist)}{iree_dtype}", *mylist) + else: + raise NotImplementedError( + f"Unsupported data type in pack_ndarray_to_binary(): '{dtype}'" + ) + return bytearr + + +############################################################################### +# ONNX utilities +############################################################################### + + +def convert_proto_elem_type_to_iree_dtype(etype) -> str: + if etype == onnx.TensorProto.BOOL: + return "i1" + if etype == onnx.TensorProto.INT4 or etype == onnx.TensorProto.UINT4: + return "i4" + if etype == onnx.TensorProto.INT8 or etype == onnx.TensorProto.UINT8: + return "i8" + if etype == onnx.TensorProto.INT16 or etype == onnx.TensorProto.UINT16: + return "i16" + if etype == onnx.TensorProto.INT32 or etype == onnx.TensorProto.UINT32: + return "i32" + if etype == onnx.TensorProto.INT64 or etype == onnx.TensorProto.UINT64: + return "i64" + if etype == onnx.TensorProto.FLOAT16: + return "f16" + if etype == onnx.TensorProto.FLOAT: + return "f32" + if etype == onnx.TensorProto.DOUBLE: + return "f64" + if etype == onnx.TensorProto.COMPLEX64: + return "complex" + if etype == onnx.TensorProto.COMPLEX128: + return "complex" + if etype == onnx.TensorProto.BFLOAT16: + return "bf16" + if etype == onnx.TensorProto.FLOAT8E4M3FN: + return "f8e4m3fn" + if etype == onnx.TensorProto.FLOAT8E4M3FNUZ: + return "f8e4m3fnuz" + if etype == onnx.TensorProto.FLOAT8E5M2: + return "f8e5m2" + if etype == onnx.TensorProto.FLOAT8E5M2FNUZ: + return "f8e5m2fnuz" + return "" + + +# TODO(#18289): use real frontend API, import model in-memory? 
+def upgrade_onnx_model_version(original_onnx_path: Path, min_version=17): + original_model = onnx.load_model(original_onnx_path) + original_version = original_model.opset_import[0].version + if original_version >= min_version: + logger.debug( + f"ONNX model at {original_onnx_path.relative_to(THIS_DIR)} version {original_version} >= {min_version}, skipping upgrade" + ) + return original_onnx_path + + converted_model = onnx.version_converter.convert_version( + original_model, min_version + ) + upgraded_onnx_path = original_onnx_path.with_name( + original_onnx_path.stem + f"_version{min_version}.onnx" + ) + logger.info( + f"Upgrading '{original_onnx_path.relative_to(THIS_DIR)}' to '{upgraded_onnx_path.relative_to(THIS_DIR)}'" + ) + onnx.save(converted_model, upgraded_onnx_path) + return upgraded_onnx_path + + +# TODO(#18289): use real frontend API, import model in-memory? +def import_onnx_model_to_mlir(onnx_path: Path): + imported_mlir_path = onnx_path.with_suffix(".mlir") + logger.info( + f"Importing '{onnx_path.relative_to(THIS_DIR)}' to '{imported_mlir_path.relative_to(THIS_DIR)}'" + ) + exec_args = [ + "iree-import-onnx", + str(onnx_path), + "-o", + str(imported_mlir_path), + ] + ret = subprocess.run(exec_args, capture_output=True) + if ret.returncode != 0: + logger.error(f"Import of '{onnx_path.name}' failed!") + logger.error("iree-import-onnx stdout:") + logger.error(ret.stdout.decode("utf-8")) + logger.error("iree-import-onnx stderr:") + logger.error(ret.stderr.decode("utf-8")) + raise IreeImportOnnxException(f" '{onnx_path.name}' import failed") + return imported_mlir_path From 812b3023c54cbe97abac80b12a224f8ef041dc6f Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 10:24:27 -0700 Subject: [PATCH 17/26] Extract model names from source URLs. --- onnx_models/conftest.py | 7 +++++-- onnx_models/vision_models_test.py | 3 --- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index 5f78bb8..56eb644 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -164,15 +164,18 @@ def run_iree_module(iree_module_path: Path, run_flags: List[str]): @pytest.fixture def compare_between_iree_and_onnxruntime(): def fn( - model_name: str, model_url: str, ): + # "https://github.com/.../mobilenetv2-12.onnx" --> "mobilenetv2-12.onnx" + model_file_name = model_url.rsplit("/", 1)[-1] + # "mobilenetv2-12.onnx" --> "mobilenetv2-12" + model_name = model_file_name.rsplit(".", 1)[0] + if not ARTIFACTS_DIR.is_dir(): ARTIFACTS_DIR.mkdir(parents=True) # TODO(scotttodd): group model artifacts into subfolders # TODO(scotttodd): move to fixture with cache / download on demand - # TODO(scotttodd): extract name from URL? # TODO(scotttodd): overwrite if already existing? check SHA? 
original_onnx_path = ARTIFACTS_DIR / f"{model_name}.onnx" if not original_onnx_path.exists(): diff --git a/onnx_models/vision_models_test.py b/onnx_models/vision_models_test.py index 929f466..a35178f 100644 --- a/onnx_models/vision_models_test.py +++ b/onnx_models/vision_models_test.py @@ -14,7 +14,6 @@ # @pytest.mark.xfail(raises=IreeCompileException) # def test_mnist_7(compare_between_iree_and_onnxruntime): # compare_between_iree_and_onnxruntime( -# model_name="mnist-7", # model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-7.onnx", # ) @@ -23,7 +22,6 @@ @pytest.mark.xfail(raises=IreeRunException) def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( - model_name="mobilenetv2-12", model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", ) @@ -31,6 +29,5 @@ def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): # https://github.com/onnx/models/tree/main/validated/vision/classification/resnet def test_resnet50_v1_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( - model_name="resnet50-v1-12", model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx", ) From 6963989fa0fcda4da783f63e2ec7f6a0ee503e38 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 11:19:57 -0700 Subject: [PATCH 18/26] Rework metadata extraction. --- onnx_models/conftest.py | 111 +++++++++++++++++------------- onnx_models/utils.py | 26 ++++++- onnx_models/vision_models_test.py | 11 ++- 3 files changed, 95 insertions(+), 53 deletions(-) diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index 56eb644..b05b418 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -11,9 +11,8 @@ import subprocess import urllib.request from onnx.mapping import TENSOR_TYPE_MAP -from onnxruntime import InferenceSession +from onnxruntime import InferenceSession, NodeArg from pathlib import Path -from typing import List from .utils import * @@ -31,7 +30,7 @@ def convert_onnx_tensor_proto_to_numpy_dimensions( tensor_proto: onnx.onnx_ml_pb2.TensorProto, -) -> str: +) -> tuple[int]: # Note: turning dynamic dimensions into just 1 here, since we need # a concrete (static) shape buffer of input data in the tests. return tuple( @@ -39,6 +38,14 @@ def convert_onnx_tensor_proto_to_numpy_dimensions( ) +def convert_onnxruntime_node_arg_to_numpy_dimensions( + node_arg: NodeArg, +) -> tuple[int]: + # Note: turning dynamic dimensions into just 1 here, since we need + # a concrete (static) shape buffer of input data in the tests. + return tuple(x if isinstance(x, int) else 1 for x in node_arg.shape) + + def convert_onnx_tensor_proto_to_iree_type_string( tensor_proto: onnx.onnx_ml_pb2.TensorProto, ) -> str: @@ -54,61 +61,89 @@ def convert_onnx_tensor_proto_to_iree_type_string( return f"{shape}x{dtype}" +def convert_onnxruntime_shape_to_iree_type_string( + node_arg: NodeArg, +) -> str: + # Note: turning dynamic dimensions into just "1" here, since we need + # a concrete (static) shape buffer of input data in the tests. 
+ shape = "x".join([str(x) if isinstance(x, int) else "1" for x in node_arg.shape]) + dtype = convert_node_arg_type_to_iree_dtype(node_arg.type) + if shape == "": + return dtype + return f"{shape}x{dtype}" + + def get_onnx_model_metadata(onnx_path: Path): # We can either # A) List all metadata explicitly # B) Get metadata on demand from the .onnx protobuf using 'onnx' # C) Get metadata on demand from the InferenceSession using 'onnxruntime' - # This is option B. + # This is option C. + onnx_session = InferenceSession(onnx_path) logger.info(f"Getting model metadata for '{onnx_path.relative_to(THIS_DIR)}'") - model = onnx.load(onnx_path) - inputs = [] - for idx, graph_input in enumerate(model.graph.input): - type_proto = graph_input.type - if not type_proto.HasField("tensor_type"): - raise NotImplementedError(f"Unsupported proto type: {type_proto}") - tensor_type = type_proto.tensor_type + onnx_inputs = {} + for idx, input in enumerate(onnx_session.get_inputs()): + logger.debug(f"Session input [{idx}]") + logger.debug(f" name: '{input.name}'") + numpy_dimensions = convert_onnxruntime_node_arg_to_numpy_dimensions(input) + iree_type = convert_onnxruntime_shape_to_iree_type_string(input) + logger.debug(f" shape: {input.shape}") + logger.debug(f" numpy shape: {numpy_dimensions}") + logger.debug(f" type: '{input.type}'") + logger.debug(f" iree parameter: {iree_type}") # Create a numpy tensor with some random data for the input. - numpy_dimensions = convert_onnx_tensor_proto_to_numpy_dimensions(tensor_type) - numpy_dtype = TENSOR_TYPE_MAP[tensor_type.elem_type].np_dtype + numpy_dtype = convert_node_arg_type_to_numpy_dtype(input.type) if numpy_dtype == np.float32 or numpy_dtype == np.float64: input_data = rng.random(numpy_dimensions, dtype=numpy_dtype) elif numpy_dtype == np.int32 or numpy_dtype == np.int64: input_data = rng.integers(numpy_dimensions, dtype=numpy_dtype) else: raise NotImplementedError(f"Unsupported numpy type: {numpy_dtype}") - logger.debug(input_data) input_data_path = onnx_path.with_name(onnx_path.stem + f"_input_{idx}.bin") write_ndarray_to_binary_file(input_data, input_data_path) - iree_type = convert_onnx_tensor_proto_to_iree_type_string(tensor_type) inputs.append( { - "name": graph_input.name, + "name": input.name, "iree_type": iree_type, - "input_data": input_data, "input_data_path": input_data_path, } ) + onnx_inputs[input.name] = input_data + output_names = [output.name for output in onnx_session.get_outputs()] + onnx_results = onnx_session.run(output_names, onnx_inputs) + + assert len(onnx_session.get_outputs()) == len(onnx_results) outputs = [] - for graph_output in model.graph.output: - type_proto = graph_output.type - if not type_proto.HasField("tensor_type"): - raise NotImplementedError(f"Unsupported proto type: {type_proto}") - tensor_type = type_proto.tensor_type + for i in range(len(onnx_results)): + output = onnx_session.get_outputs()[i] + result = onnx_results[i] + logger.debug(f"Session output [{idx}]") + logger.debug(f" name: '{output.name}'") + logger.debug(f" shape (actual): {result.shape}") + logger.debug(f" type (numpy): '{result.dtype}'") + iree_type = convert_numpy_to_iree_type_string(result) + output_data_path = onnx_path.with_name(onnx_path.stem + f"_output_{idx}.bin") + write_ndarray_to_binary_file(result, output_data_path) - iree_type = convert_onnx_tensor_proto_to_iree_type_string(tensor_type) outputs.append( { - "name": graph_output.name, + "name": output.name, "iree_type": iree_type, + "output_data_path": output_data_path, } ) + outputs = [] + for 
idx, output in enumerate(onnx_session.get_outputs()): + logger.debug( + f"Session output [{idx}] name: '{output.name}', shape: {output.shape}, type: {output.type}" + ) + return { "inputs": inputs, "outputs": outputs, @@ -120,7 +155,7 @@ def get_onnx_model_metadata(onnx_path: Path): ############################################################################### -def compile_mlir_with_iree(mlir_path: Path, config_name: str, compile_flags: List[str]): +def compile_mlir_with_iree(mlir_path: Path, config_name: str, compile_flags: list[str]): cwd = THIS_DIR iree_module_path = mlir_path.with_name(mlir_path.stem + f"_{config_name}.vmfb") compile_args = ["iree-compile", mlir_path.relative_to(cwd)] @@ -142,7 +177,7 @@ def compile_mlir_with_iree(mlir_path: Path, config_name: str, compile_flags: Lis return iree_module_path -def run_iree_module(iree_module_path: Path, run_flags: List[str]): +def run_iree_module(iree_module_path: Path, run_flags: list[str]): cwd = THIS_DIR run_args = ["iree-run-module", f"--module={iree_module_path.relative_to(cwd)}"] run_args.extend(run_flags) @@ -185,36 +220,20 @@ def fn( upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) onnx_model_metadata = get_onnx_model_metadata(upgraded_onnx_path) - logger.debug("ONNX model metadata:") + logger.debug("ONNX model metadata2:") logger.debug(onnx_model_metadata) - # Run through ONNX Runtime. - onnx_session = InferenceSession(upgraded_onnx_path) - output_names = [output["name"] for output in onnx_model_metadata["outputs"]] - inputs = {} - for input in onnx_model_metadata["inputs"]: - inputs[input["name"]] = input["input_data"] - onnx_results = onnx_session.run(output_names, inputs) - # Prepare inputs and expected outputs for running through IREE. run_module_args = [] for input in onnx_model_metadata["inputs"]: input_type = input["iree_type"] input_data_path = input["input_data_path"] run_module_args.append(f"--input={input_type}=@{input_data_path}") - - assert len(onnx_model_metadata["outputs"]) == len(onnx_results) - for idx in range(len(onnx_results)): - output = onnx_model_metadata["outputs"][idx] + for output in onnx_model_metadata["outputs"]: output_type = output["iree_type"] - onnx_result = onnx_results[idx] - logger.debug(np.array(onnx_result)) - reference_output_data_path = original_onnx_path.with_name( - original_onnx_path.stem + f"_output_{idx}.bin" - ) - write_ndarray_to_binary_file(onnx_result, reference_output_data_path) + output_data_path = output["output_data_path"] run_module_args.append( - f"--expected_output={output_type}=@{reference_output_data_path}" + f"--expected_output={output_type}=@{output_data_path}" ) # Import, compile, then run with IREE. 
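The metadata extraction above pins every dynamic (symbolic) dimension to 1 before formatting an IREE type string, and each type string plus its data file becomes one `--input=` or `--expected_output=` flag. A minimal sketch of that mapping, assuming a session input reported by ONNX Runtime with shape `['N', 3, 224, 224]` and type `tensor(float)`; the file name below is made up for illustration:

    # NodeArg.shape as reported by onnxruntime; 'N' is a dynamic batch dim.
    shape = ["N", 3, 224, 224]
    dims = [d if isinstance(d, int) else 1 for d in shape]
    iree_type = "x".join(str(d) for d in dims) + "xf32"
    print(f"--input={iree_type}=@resnet50-v1-12_input_0.bin")
    # --input=1x3x224x224xf32=@resnet50-v1-12_input_0.bin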
diff --git a/onnx_models/utils.py b/onnx_models/utils.py index 91ebe32..37146df 100644 --- a/onnx_models/utils.py +++ b/onnx_models/utils.py @@ -60,6 +60,14 @@ def write_ndarray_to_binary_file(ndarr: np.ndarray, filename: Path): } +def convert_numpy_to_iree_type_string(ndarr: np.ndarray): + shape = "x".join(str(x) for x in ndarr.shape) + dtype = numpy_to_iree_dtype_map[ndarr.dtype] + if shape == "": + return dtype + return f"{shape}x{dtype}" + + def pack_ndarray_to_binary(ndarr: np.ndarray): mylist = ndarr.flatten().tolist() dtype = ndarr.dtype @@ -112,7 +120,23 @@ def convert_proto_elem_type_to_iree_dtype(etype) -> str: return "f8e5m2" if etype == onnx.TensorProto.FLOAT8E5M2FNUZ: return "f8e5m2fnuz" - return "" + raise NotImplementedError( + f"type conversion for '{etype}' enum value not implemented" + ) + + +def convert_node_arg_type_to_numpy_dtype(type: str): + # TODO(scotttodd): use onnx.TensorProto instead? enums > strings + if type == "tensor(float)": + return np.float32 + raise NotImplementedError(f"type conversion for '{type}' not implemented") + + +def convert_node_arg_type_to_iree_dtype(type: str) -> str: + # TODO(scotttodd): use onnx.TensorProto instead? enums > strings + if type == "tensor(float)": + return "f32" + raise NotImplementedError(f"type conversion for '{type}' not implemented") # TODO(#18289): use real frontend API, import model in-memory? diff --git a/onnx_models/vision_models_test.py b/onnx_models/vision_models_test.py index a35178f..19d6168 100644 --- a/onnx_models/vision_models_test.py +++ b/onnx_models/vision_models_test.py @@ -10,12 +10,11 @@ # https://github.com/onnx/models/tree/main/validated/vision/classification/mnist -# TODO(scotttodd): fix test runner (only use "Input3" input?) -# @pytest.mark.xfail(raises=IreeCompileException) -# def test_mnist_7(compare_between_iree_and_onnxruntime): -# compare_between_iree_and_onnxruntime( -# model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-7.onnx", -# ) +@pytest.mark.xfail(raises=IreeCompileException) +def test_mnist_7(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-7.onnx", + ) # https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet From 0a77c0c3aee9f6944664934f1c5b00db4b54da08 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 11:57:37 -0700 Subject: [PATCH 19/26] Add alexnet test (just trying more models). 
--- onnx_models/vision_models_test.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnx_models/vision_models_test.py b/onnx_models/vision_models_test.py index 19d6168..436e0c8 100644 --- a/onnx_models/vision_models_test.py +++ b/onnx_models/vision_models_test.py @@ -30,3 +30,10 @@ def test_resnet50_v1_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx", ) + + +# https://github.com/onnx/models/tree/main/validated/vision/classification/alexnet +def test_alexnet_9(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/alexnet/model/bvlcalexnet-9.onnx", + ) From 1723696fc0264c9e3737a7faa318033a9730607f Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 13:01:31 -0700 Subject: [PATCH 20/26] Enable xfail_strict for the entire directory. --- onnx_models/pytest.ini | 2 ++ onnx_models/vision_models_test.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 onnx_models/pytest.ini diff --git a/onnx_models/pytest.ini b/onnx_models/pytest.ini new file mode 100644 index 0000000..d61d029 --- /dev/null +++ b/onnx_models/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +xfail_strict=true diff --git a/onnx_models/vision_models_test.py b/onnx_models/vision_models_test.py index 436e0c8..58e642e 100644 --- a/onnx_models/vision_models_test.py +++ b/onnx_models/vision_models_test.py @@ -8,9 +8,12 @@ from .utils import * +# Note: can mark tests as expected to fail at a specific stage with: +# @pytest.mark.xfail(raises=IreeCompileException) +# @pytest.mark.xfail(raises=IreeRunException) + # https://github.com/onnx/models/tree/main/validated/vision/classification/mnist -@pytest.mark.xfail(raises=IreeCompileException) def test_mnist_7(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-7.onnx", @@ -18,7 +21,6 @@ def test_mnist_7(compare_between_iree_and_onnxruntime): # https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet -@pytest.mark.xfail(raises=IreeRunException) def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", From 256f5faa8691da1ade5920f42c93bf108b000fab Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 13:33:01 -0700 Subject: [PATCH 21/26] Cleanup: add dataclasses, prune unused code, update comments. 
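A rough sketch of how the dataclasses introduced below are meant to be consumed when building `iree-run-module` flags; the field names come from the patch, while the concrete values are made up for illustration:

    from dataclasses import dataclass
    from pathlib import Path

    @dataclass(frozen=True)
    class IreeModelParameterMetadata:
        name: str
        type: str        # e.g. "1x3x224x224xf32"
        data_file: Path  # input or expected output binary file

    meta = IreeModelParameterMetadata(
        name="data",
        type="1x3x224x224xf32",
        data_file=Path("artifacts/resnet50-v1-12_input_0.bin"),
    )
    print(f"--input={meta.type}=@{meta.data_file}")
    # --input=1x3x224x224xf32=@artifacts/resnet50-v1-12_input_0.bin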
--- onnx_models/conftest.py | 115 ++++++++++++++---------------- onnx_models/utils.py | 44 +----------- onnx_models/vision_models_test.py | 1 + 3 files changed, 58 insertions(+), 102 deletions(-) diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index b05b418..e218e4c 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -6,11 +6,10 @@ import logging import numpy as np -import onnx import pytest import subprocess import urllib.request -from onnx.mapping import TENSOR_TYPE_MAP +from dataclasses import dataclass from onnxruntime import InferenceSession, NodeArg from pathlib import Path @@ -28,14 +27,32 @@ ############################################################################### -def convert_onnx_tensor_proto_to_numpy_dimensions( - tensor_proto: onnx.onnx_ml_pb2.TensorProto, -) -> tuple[int]: - # Note: turning dynamic dimensions into just 1 here, since we need - # a concrete (static) shape buffer of input data in the tests. - return tuple( - d.dim_value if d.HasField("dim_value") else 1 for d in tensor_proto.shape.dim - ) +@dataclass(frozen=True) +class IreeModelParameterMetadata: + """Metadata for a single input or output used with iree-run-module tooling. + + Args: + name: The name of the parameter. + type: The type of the parameter as expected by the tools, e.g. "2x2xi32". + data_file: Path to either the input or expected output binary file for this parameter. + """ + + name: str + type: str + data_file: Path + + +@dataclass(frozen=True) +class OnnxModelMetadata: + """Metadata for an ONNX model. + + Args: + inputs: One parameter metadata per input. + outputs: One parameter metadata per output. + """ + + inputs: list[IreeModelParameterMetadata] + outputs: list[IreeModelParameterMetadata] def convert_onnxruntime_node_arg_to_numpy_dimensions( @@ -46,21 +63,6 @@ def convert_onnxruntime_node_arg_to_numpy_dimensions( return tuple(x if isinstance(x, int) else 1 for x in node_arg.shape) -def convert_onnx_tensor_proto_to_iree_type_string( - tensor_proto: onnx.onnx_ml_pb2.TensorProto, -) -> str: - shape = tensor_proto.shape - # Note: turning dynamic dimensions into just "1" here, since we need - # a concrete (static) shape buffer of input data in the tests. - shape = "x".join( - [str(d.dim_value) if d.HasField("dim_value") else "1" for d in shape.dim] - ) - dtype = convert_proto_elem_type_to_iree_dtype(tensor_proto.elem_type) - if shape == "": - return dtype - return f"{shape}x{dtype}" - - def convert_onnxruntime_shape_to_iree_type_string( node_arg: NodeArg, ) -> str: @@ -73,7 +75,7 @@ def convert_onnxruntime_shape_to_iree_type_string( return f"{shape}x{dtype}" -def get_onnx_model_metadata(onnx_path: Path): +def get_onnx_model_metadata(onnx_path: Path) -> OnnxModelMetadata: # We can either # A) List all metadata explicitly # B) Get metadata on demand from the .onnx protobuf using 'onnx' @@ -106,14 +108,15 @@ def get_onnx_model_metadata(onnx_path: Path): write_ndarray_to_binary_file(input_data, input_data_path) inputs.append( - { - "name": input.name, - "iree_type": iree_type, - "input_data_path": input_data_path, - } + IreeModelParameterMetadata( + name=input.name, + type=iree_type, + data_file=input_data_path, + ) ) onnx_inputs[input.name] = input_data + # Run through onnxruntime and then save the output results. 
output_names = [output.name for output in onnx_session.get_outputs()] onnx_results = onnx_session.run(output_names, onnx_inputs) @@ -131,23 +134,14 @@ def get_onnx_model_metadata(onnx_path: Path): write_ndarray_to_binary_file(result, output_data_path) outputs.append( - { - "name": output.name, - "iree_type": iree_type, - "output_data_path": output_data_path, - } - ) - - outputs = [] - for idx, output in enumerate(onnx_session.get_outputs()): - logger.debug( - f"Session output [{idx}] name: '{output.name}', shape: {output.shape}, type: {output.type}" + IreeModelParameterMetadata( + name=output.name, + type=iree_type, + data_file=output_data_path, + ) ) - return { - "inputs": inputs, - "outputs": outputs, - } + return OnnxModelMetadata(inputs=inputs, outputs=outputs) ############################################################################### @@ -201,39 +195,37 @@ def compare_between_iree_and_onnxruntime(): def fn( model_url: str, ): + if not ARTIFACTS_DIR.is_dir(): + ARTIFACTS_DIR.mkdir(parents=True) + # TODO(scotttodd): group model artifacts into subfolders + + # Extract path and file components from the model URL. # "https://github.com/.../mobilenetv2-12.onnx" --> "mobilenetv2-12.onnx" model_file_name = model_url.rsplit("/", 1)[-1] # "mobilenetv2-12.onnx" --> "mobilenetv2-12" model_name = model_file_name.rsplit(".", 1)[0] - if not ARTIFACTS_DIR.is_dir(): - ARTIFACTS_DIR.mkdir(parents=True) - # TODO(scotttodd): group model artifacts into subfolders - + # Download the model as needed. # TODO(scotttodd): move to fixture with cache / download on demand # TODO(scotttodd): overwrite if already existing? check SHA? original_onnx_path = ARTIFACTS_DIR / f"{model_name}.onnx" if not original_onnx_path.exists(): urllib.request.urlretrieve(model_url, original_onnx_path) - # TODO(scotttodd): cache ONNX metadata and runtime results + # TODO(scotttodd): cache ONNX metadata and runtime results (pickle?) upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) onnx_model_metadata = get_onnx_model_metadata(upgraded_onnx_path) - logger.debug("ONNX model metadata2:") + logger.debug("ONNX model metadata:") logger.debug(onnx_model_metadata) # Prepare inputs and expected outputs for running through IREE. run_module_args = [] - for input in onnx_model_metadata["inputs"]: - input_type = input["iree_type"] - input_data_path = input["input_data_path"] - run_module_args.append(f"--input={input_type}=@{input_data_path}") - for output in onnx_model_metadata["outputs"]: - output_type = output["iree_type"] - output_data_path = output["output_data_path"] + for input in onnx_model_metadata.inputs: + run_module_args.append(f"--input={input.type}=@{input.data_file}") + for output in onnx_model_metadata.outputs: run_module_args.append( - f"--expected_output={output_type}=@{output_data_path}" + f"--expected_output={output.type}=@{output.data_file}" ) # Import, compile, then run with IREE. @@ -241,7 +233,8 @@ def fn( iree_module_path = compile_mlir_with_iree( imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] ) - # Note: could load the output into memory here and compare using numpy. + # Note: could load the output into memory here and compare using numpy + # if the pass/fail criteria is difficult to model in the native tooling. 
run_flags = ["--device=local-task"] run_flags.extend(run_module_args) run_iree_module(iree_module_path, run_flags) diff --git a/onnx_models/utils.py b/onnx_models/utils.py index 37146df..66e8c8f 100644 --- a/onnx_models/utils.py +++ b/onnx_models/utils.py @@ -62,7 +62,7 @@ def write_ndarray_to_binary_file(ndarr: np.ndarray, filename: Path): def convert_numpy_to_iree_type_string(ndarr: np.ndarray): shape = "x".join(str(x) for x in ndarr.shape) - dtype = numpy_to_iree_dtype_map[ndarr.dtype] + dtype = numpy_to_iree_dtype_map[ndarr.dtype][0] if shape == "": return dtype return f"{shape}x{dtype}" @@ -87,53 +87,15 @@ def pack_ndarray_to_binary(ndarr: np.ndarray): ############################################################################### -def convert_proto_elem_type_to_iree_dtype(etype) -> str: - if etype == onnx.TensorProto.BOOL: - return "i1" - if etype == onnx.TensorProto.INT4 or etype == onnx.TensorProto.UINT4: - return "i4" - if etype == onnx.TensorProto.INT8 or etype == onnx.TensorProto.UINT8: - return "i8" - if etype == onnx.TensorProto.INT16 or etype == onnx.TensorProto.UINT16: - return "i16" - if etype == onnx.TensorProto.INT32 or etype == onnx.TensorProto.UINT32: - return "i32" - if etype == onnx.TensorProto.INT64 or etype == onnx.TensorProto.UINT64: - return "i64" - if etype == onnx.TensorProto.FLOAT16: - return "f16" - if etype == onnx.TensorProto.FLOAT: - return "f32" - if etype == onnx.TensorProto.DOUBLE: - return "f64" - if etype == onnx.TensorProto.COMPLEX64: - return "complex" - if etype == onnx.TensorProto.COMPLEX128: - return "complex" - if etype == onnx.TensorProto.BFLOAT16: - return "bf16" - if etype == onnx.TensorProto.FLOAT8E4M3FN: - return "f8e4m3fn" - if etype == onnx.TensorProto.FLOAT8E4M3FNUZ: - return "f8e4m3fnuz" - if etype == onnx.TensorProto.FLOAT8E5M2: - return "f8e5m2" - if etype == onnx.TensorProto.FLOAT8E5M2FNUZ: - return "f8e5m2fnuz" - raise NotImplementedError( - f"type conversion for '{etype}' enum value not implemented" - ) - - def convert_node_arg_type_to_numpy_dtype(type: str): - # TODO(scotttodd): use onnx.TensorProto instead? enums > strings + # TODO(scotttodd): use onnx.TensorProto instead? prefer enums over strings if type == "tensor(float)": return np.float32 raise NotImplementedError(f"type conversion for '{type}' not implemented") def convert_node_arg_type_to_iree_dtype(type: str) -> str: - # TODO(scotttodd): use onnx.TensorProto instead? enums > strings + # TODO(scotttodd): use onnx.TensorProto instead? prefer enums over strings if type == "tensor(float)": return "f32" raise NotImplementedError(f"type conversion for '{type}' not implemented") diff --git a/onnx_models/vision_models_test.py b/onnx_models/vision_models_test.py index 58e642e..554bea0 100644 --- a/onnx_models/vision_models_test.py +++ b/onnx_models/vision_models_test.py @@ -21,6 +21,7 @@ def test_mnist_7(compare_between_iree_and_onnxruntime): # https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet +@pytest.mark.xfail(raises=IreeRunException) def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): compare_between_iree_and_onnxruntime( model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", From e39121ed995a0538eb5452e5980f500edb283edf Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 13:49:00 -0700 Subject: [PATCH 22/26] Save on indentation by using a helper function. 
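The pattern applied below, sketched here with shortened names: the comparison logic moves to a plain module-level function and the pytest fixture simply returns that callable, which drops one level of nesting inside the fixture.

    import pytest

    def compare_fn(model_url: str):
        # Download the model, run it with ONNX Runtime, then import,
        # compile, and run it with IREE against the reference outputs.
        ...

    @pytest.fixture
    def compare_between_iree_and_onnxruntime():
        return compare_fn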
--- onnx_models/conftest.py | 93 ++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 48 deletions(-) diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index e218e4c..986a5d8 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -190,53 +190,50 @@ def run_iree_module(iree_module_path: Path, run_flags: list[str]): raise IreeRunException(f" '{iree_module_path.name}' run failed") -@pytest.fixture -def compare_between_iree_and_onnxruntime(): - def fn( - model_url: str, - ): - if not ARTIFACTS_DIR.is_dir(): - ARTIFACTS_DIR.mkdir(parents=True) - # TODO(scotttodd): group model artifacts into subfolders - - # Extract path and file components from the model URL. - # "https://github.com/.../mobilenetv2-12.onnx" --> "mobilenetv2-12.onnx" - model_file_name = model_url.rsplit("/", 1)[-1] - # "mobilenetv2-12.onnx" --> "mobilenetv2-12" - model_name = model_file_name.rsplit(".", 1)[0] - - # Download the model as needed. - # TODO(scotttodd): move to fixture with cache / download on demand - # TODO(scotttodd): overwrite if already existing? check SHA? - original_onnx_path = ARTIFACTS_DIR / f"{model_name}.onnx" - if not original_onnx_path.exists(): - urllib.request.urlretrieve(model_url, original_onnx_path) - - # TODO(scotttodd): cache ONNX metadata and runtime results (pickle?) - upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) - - onnx_model_metadata = get_onnx_model_metadata(upgraded_onnx_path) - logger.debug("ONNX model metadata:") - logger.debug(onnx_model_metadata) - - # Prepare inputs and expected outputs for running through IREE. - run_module_args = [] - for input in onnx_model_metadata.inputs: - run_module_args.append(f"--input={input.type}=@{input.data_file}") - for output in onnx_model_metadata.outputs: - run_module_args.append( - f"--expected_output={output.type}=@{output.data_file}" - ) +def compare_between_iree_and_onnxruntime_fn(model_url: str): + if not ARTIFACTS_DIR.is_dir(): + ARTIFACTS_DIR.mkdir(parents=True) + # TODO(scotttodd): group model artifacts into subfolders + + # Extract path and file components from the model URL. + # "https://github.com/.../mobilenetv2-12.onnx" --> "mobilenetv2-12.onnx" + model_file_name = model_url.rsplit("/", 1)[-1] + # "mobilenetv2-12.onnx" --> "mobilenetv2-12" + model_name = model_file_name.rsplit(".", 1)[0] + + # Download the model as needed. + # TODO(scotttodd): move to fixture with cache / download on demand + # TODO(scotttodd): overwrite if already existing? check SHA? + original_onnx_path = ARTIFACTS_DIR / f"{model_name}.onnx" + if not original_onnx_path.exists(): + urllib.request.urlretrieve(model_url, original_onnx_path) + + # TODO(scotttodd): cache ONNX metadata and runtime results (pickle?) + upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) + + onnx_model_metadata = get_onnx_model_metadata(upgraded_onnx_path) + logger.debug("ONNX model metadata:") + logger.debug(onnx_model_metadata) + + # Prepare inputs and expected outputs for running through IREE. + run_module_args = [] + for input in onnx_model_metadata.inputs: + run_module_args.append(f"--input={input.type}=@{input.data_file}") + for output in onnx_model_metadata.outputs: + run_module_args.append(f"--expected_output={output.type}=@{output.data_file}") + + # Import, compile, then run with IREE. 
+ imported_mlir_path = import_onnx_model_to_mlir(upgraded_onnx_path) + iree_module_path = compile_mlir_with_iree( + imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] + ) + # Note: could load the output into memory here and compare using numpy + # if the pass/fail criteria is difficult to model in the native tooling. + run_flags = ["--device=local-task"] + run_flags.extend(run_module_args) + run_iree_module(iree_module_path, run_flags) - # Import, compile, then run with IREE. - imported_mlir_path = import_onnx_model_to_mlir(upgraded_onnx_path) - iree_module_path = compile_mlir_with_iree( - imported_mlir_path, "cpu", ["--iree-hal-target-backends=llvm-cpu"] - ) - # Note: could load the output into memory here and compare using numpy - # if the pass/fail criteria is difficult to model in the native tooling. - run_flags = ["--device=local-task"] - run_flags.extend(run_module_args) - run_iree_module(iree_module_path, run_flags) - return fn +@pytest.fixture +def compare_between_iree_and_onnxruntime(): + return compare_between_iree_and_onnxruntime_fn From cb1d83312aca669df3a6d683fbbc9f9852a7415c Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 14:43:49 -0700 Subject: [PATCH 23/26] Move into subdir, add size marks, import more tests, add to docs. --- .github/workflows/test_onnx_models.yml | 3 +- onnx_models/README.md | 34 +++++ onnx_models/conftest.py | 12 +- onnx_models/pytest.ini | 2 + onnx_models/tests/__init__.py | 0 onnx_models/tests/vision/__init__.py | 0 .../vision/classification_models_test.py | 137 ++++++++++++++++++ onnx_models/utils.py | 5 + onnx_models/vision_models_test.py | 42 ------ 9 files changed, 186 insertions(+), 49 deletions(-) create mode 100644 onnx_models/tests/__init__.py create mode 100644 onnx_models/tests/vision/__init__.py create mode 100644 onnx_models/tests/vision/classification_models_test.py delete mode 100644 onnx_models/vision_models_test.py diff --git a/.github/workflows/test_onnx_models.yml b/.github/workflows/test_onnx_models.yml index 0fe5ddf..9d5bdec 100644 --- a/.github/workflows/test_onnx_models.yml +++ b/.github/workflows/test_onnx_models.yml @@ -56,6 +56,7 @@ jobs: source ${VENV_DIR}/bin/activate pytest onnx_models/ \ -rA \ + -n 4 \ --log-cli-level=info \ - --timeout=60 \ + --timeout=120 \ --durations=0 diff --git a/onnx_models/README.md b/onnx_models/README.md index d80260a..3c5c5d9 100644 --- a/onnx_models/README.md +++ b/onnx_models/README.md @@ -52,3 +52,37 @@ graph LR ``` See https://docs.pytest.org/en/stable/how-to/usage.html for other options. + +## Advanced pytest usage + +* The `log-cli-level` level can also be set to `debug`, `warning`, or `error`. + See https://docs.pytest.org/en/stable/how-to/logging.html. +* Run only tests matching a name pattern: + + ```bash + pytest -k resnet + ``` + +* Skip "medium" sized tests using custom markers + (https://docs.pytest.org/en/stable/example/markers.html): + + ```bash + pytest -m "not size_medium" + ``` + +* Ignore xfail marks + (https://docs.pytest.org/en/stable/how-to/skipping.html#ignoring-xfail): + + ```bash + pytest --runxfail + ``` + +* Run tests in parallel using https://pytest-xdist.readthedocs.io/en/stable/: + + ```bash + # Run with an automatic number of threads (usually one per CPU core). + pytest -n auto + + # Run on an explicit number of threads. 
+ pytest -n 4 + ``` diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index 986a5d8..b6db1e9 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -19,7 +19,7 @@ rng = np.random.default_rng(0) THIS_DIR = Path(__file__).parent -ARTIFACTS_DIR = THIS_DIR / "artifacts" +ARTIFACTS_ROOT = THIS_DIR / "artifacts" ############################################################################### @@ -190,10 +190,10 @@ def run_iree_module(iree_module_path: Path, run_flags: list[str]): raise IreeRunException(f" '{iree_module_path.name}' run failed") -def compare_between_iree_and_onnxruntime_fn(model_url: str): - if not ARTIFACTS_DIR.is_dir(): - ARTIFACTS_DIR.mkdir(parents=True) - # TODO(scotttodd): group model artifacts into subfolders +def compare_between_iree_and_onnxruntime_fn(model_url: str, artifacts_subdir=""): + test_artifacts_dir = ARTIFACTS_ROOT / artifacts_subdir + if not test_artifacts_dir.is_dir(): + test_artifacts_dir.mkdir(parents=True) # Extract path and file components from the model URL. # "https://github.com/.../mobilenetv2-12.onnx" --> "mobilenetv2-12.onnx" @@ -204,7 +204,7 @@ def compare_between_iree_and_onnxruntime_fn(model_url: str): # Download the model as needed. # TODO(scotttodd): move to fixture with cache / download on demand # TODO(scotttodd): overwrite if already existing? check SHA? - original_onnx_path = ARTIFACTS_DIR / f"{model_name}.onnx" + original_onnx_path = test_artifacts_dir / f"{model_name}.onnx" if not original_onnx_path.exists(): urllib.request.urlretrieve(model_url, original_onnx_path) diff --git a/onnx_models/pytest.ini b/onnx_models/pytest.ini index d61d029..a54d927 100644 --- a/onnx_models/pytest.ini +++ b/onnx_models/pytest.ini @@ -1,2 +1,4 @@ [pytest] xfail_strict=true +markers = + size_medium: mark tests as being "medium" size (500MB+ data, 30 seconds+ runtime) diff --git a/onnx_models/tests/__init__.py b/onnx_models/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/onnx_models/tests/vision/__init__.py b/onnx_models/tests/vision/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/onnx_models/tests/vision/classification_models_test.py b/onnx_models/tests/vision/classification_models_test.py new file mode 100644 index 0000000..87edac9 --- /dev/null +++ b/onnx_models/tests/vision/classification_models_test.py @@ -0,0 +1,137 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# https://github.com/onnx/models/tree/main/validated/vision/classification/ + +import pytest + +from ...utils import * + + +def test_alexnet(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/alexnet/model/bvlcalexnet-12.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_caffenet(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/caffenet/model/caffenet-12.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_densenet_121(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/densenet-121/model/densenet-12.onnx", + artifacts_subdir="vision/classification", + ) + + +@pytest.mark.xfail(raises=IreeCompileException) +def test_efficientnet_lite4(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_googlenet(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/inception_and_googlenet/googlenet/model/googlenet-12.onnx", + artifacts_subdir="vision/classification", + ) + + +@pytest.mark.xfail(raises=IreeCompileException) +def test_inception_v1(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-12.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_inception_v2(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_mnist(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-12.onnx", + artifacts_subdir="vision/classification", + ) + + +@pytest.mark.xfail(raises=IreeRunException) +def test_mobilenet(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", + artifacts_subdir="vision/classification", + ) + + +@pytest.mark.xfail(raises=IreeCompileException) +def test_rcnn_ilsvrc13(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_resnet50_v1(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_resnet50_v2(compare_between_iree_and_onnxruntime): + 
compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_shufflenet(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/shufflenet/model/shufflenet-9.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_shufflenet_v2(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/shufflenet/model/shufflenet-v2-12.onnx", + artifacts_subdir="vision/classification", + ) + + +def test_squeezenet(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/squeezenet/model/squeezenet1.0-9.onnx", + artifacts_subdir="vision/classification", + ) + + +@pytest.mark.size_medium +def test_vgg19(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/vgg/model/vgg19-7.onnx", + artifacts_subdir="vision/classification", + ) + + +@pytest.mark.size_medium +@pytest.mark.xfail(raises=IreeCompileException) +def test_zfnet_512(compare_between_iree_and_onnxruntime): + compare_between_iree_and_onnxruntime( + model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/zfnet-512/model/zfnet512-12.onnx", + artifacts_subdir="vision/classification", + ) diff --git a/onnx_models/utils.py b/onnx_models/utils.py index 66e8c8f..c7eefaf 100644 --- a/onnx_models/utils.py +++ b/onnx_models/utils.py @@ -19,6 +19,11 @@ # Exception types ############################################################################### +# Note: can mark tests as expected to fail at a specific stage with: +# @pytest.mark.xfail(raises=IreeImportOnnxException) +# @pytest.mark.xfail(raises=IreeCompileException) +# @pytest.mark.xfail(raises=IreeRunException) + class IreeImportOnnxException(RuntimeError): pass diff --git a/onnx_models/vision_models_test.py b/onnx_models/vision_models_test.py deleted file mode 100644 index 554bea0..0000000 --- a/onnx_models/vision_models_test.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import pytest - -from .utils import * - -# Note: can mark tests as expected to fail at a specific stage with: -# @pytest.mark.xfail(raises=IreeCompileException) -# @pytest.mark.xfail(raises=IreeRunException) - - -# https://github.com/onnx/models/tree/main/validated/vision/classification/mnist -def test_mnist_7(compare_between_iree_and_onnxruntime): - compare_between_iree_and_onnxruntime( - model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-7.onnx", - ) - - -# https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet -@pytest.mark.xfail(raises=IreeRunException) -def test_mobilenetv2_12(compare_between_iree_and_onnxruntime): - compare_between_iree_and_onnxruntime( - model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", - ) - - -# https://github.com/onnx/models/tree/main/validated/vision/classification/resnet -def test_resnet50_v1_12(compare_between_iree_and_onnxruntime): - compare_between_iree_and_onnxruntime( - model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-v1-12.onnx", - ) - - -# https://github.com/onnx/models/tree/main/validated/vision/classification/alexnet -def test_alexnet_9(compare_between_iree_and_onnxruntime): - compare_between_iree_and_onnxruntime( - model_url="https://github.com/onnx/models/raw/main/validated/vision/classification/alexnet/model/bvlcalexnet-9.onnx", - ) From 2700a76cca3a10d868ee06f71691d4b453b15aad Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 15:11:10 -0700 Subject: [PATCH 24/26] Iterate on docs and logging. --- .github/workflows/test_onnx_models.yml | 1 - onnx_models/README.md | 76 +++++++++++++++++++++++++- onnx_models/conftest.py | 12 ++-- 3 files changed, 83 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test_onnx_models.yml b/.github/workflows/test_onnx_models.yml index 9d5bdec..3d2d878 100644 --- a/.github/workflows/test_onnx_models.yml +++ b/.github/workflows/test_onnx_models.yml @@ -56,7 +56,6 @@ jobs: source ${VENV_DIR}/bin/activate pytest onnx_models/ \ -rA \ - -n 4 \ --log-cli-level=info \ --timeout=120 \ --durations=0 diff --git a/onnx_models/README.md b/onnx_models/README.md index 3c5c5d9..c5a53a3 100644 --- a/onnx_models/README.md +++ b/onnx_models/README.md @@ -77,7 +77,8 @@ graph LR pytest --runxfail ``` -* Run tests in parallel using https://pytest-xdist.readthedocs.io/en/stable/: +* Run tests in parallel using https://pytest-xdist.readthedocs.io/ + (note that this swallows some logging): ```bash # Run with an automatic number of threads (usually one per CPU core). @@ -86,3 +87,76 @@ graph LR # Run on an explicit number of threads. 
pytest -n 4 ``` + +## Debugging tests outside of pytest + +Each test generates some files as it runs: + +```text +├── artifacts +│ └── vision +│ └── classification +│ ├── mnist-12_version17_cpu.vmfb (Program compiled using IREE's llvm-cpu target) +│ ├── mnist-12_version17_input_0.bin (Random input generated using numpy) +│ ├── mnist-12_version17_output_0.bin (Reference output from onnxruntime) +│ ├── mnist-12_version17.mlir (The model imported to MLIR) +│ ├── mnist-12_version17.onnx (The model upgraded to a minimum supported version) +│ └── mnist-12.onnx (The downloaded ONNX model) +``` + +Running a test with logging enabled will show what the test is doing: + +```console +pytest --log-cli-level=debug -k mnist + +======================================= test session starts ======================================= +platform win32 -- Python 3.11.2, pytest-8.3.3, pluggy-1.5.0 +rootdir: D:\dev\projects\iree-test-suites\onnx_models +configfile: pytest.ini +plugins: reportlog-0.4.0, timeout-2.3.1, xdist-3.6.1 +collected 17 items / 16 deselected / 1 selected + +tests/vision/classification_models_test.py::test_mnist +------------------------------------------ live log call ------------------------------------------ +INFO onnx_models.utils:utils.py:125 Upgrading 'artifacts\vision\classification\mnist-12.onnx' to 'artifacts\vision\classification\mnist-12_version17.onnx' +DEBUG onnx_models.conftest:conftest.py:90 Session input [0] +DEBUG onnx_models.conftest:conftest.py:91 name: 'Input3' +DEBUG onnx_models.conftest:conftest.py:94 shape: [1, 1, 28, 28] +DEBUG onnx_models.conftest:conftest.py:95 numpy shape: (1, 1, 28, 28) +DEBUG onnx_models.conftest:conftest.py:96 type: 'tensor(float)' +DEBUG onnx_models.conftest:conftest.py:97 iree parameter: 1x1x28x28xf32 +DEBUG onnx_models.conftest:conftest.py:129 Session output [0] +DEBUG onnx_models.conftest:conftest.py:130 name: 'Plus214_Output_0' +DEBUG onnx_models.conftest:conftest.py:131 shape (actual): (1, 10) +DEBUG onnx_models.conftest:conftest.py:132 type (numpy): 'float32' +DEBUG onnx_models.conftest:conftest.py:133 iree parameter: 1x10xf32 +DEBUG onnx_models.conftest:conftest.py:217 OnnxModelMetadata(inputs=[IreeModelParameterMetadata(name='Input3', type='1x1x28x28xf32', data_file=WindowsPath('D:/dev/projects/iree-test-suites/onnx_models/artifacts/vision/classification/mnist-12_version17_input_0.bin'))], outputs=[IreeModelParameterMetadata(name='Plus214_Output_0', type='1x10xf32', data_file=WindowsPath('D:/dev/projects/iree-test-suites/onnx_models/artifacts/vision/classification/mnist-12_version17_output_0.bin'))]) +INFO onnx_models.utils:utils.py:135 Importing 'artifacts\vision\classification\mnist-12_version17.onnx' to 'artifacts\vision\classification\mnist-12_version17.mlir' +INFO onnx_models.conftest:conftest.py:160 Launching compile command: + cd D:\dev\projects\iree-test-suites\onnx_models && iree-compile artifacts\vision\classification\mnist-12_version17.mlir --iree-hal-target-backends=llvm-cpu -o artifacts\vision\classification\mnist-12_version17_cpu.vmfb +INFO onnx_models.conftest:conftest.py:180 Launching run command: + cd D:\dev\projects\iree-test-suites\onnx_models && iree-run-module --module=artifacts\vision\classification\mnist-12_version17_cpu.vmfb --device=local-task --input=1x1x28x28xf32=@artifacts\vision\classification\mnist-12_version17_input_0.bin --expected_output=1x10xf32=@artifacts\vision\classification\mnist-12_version17_output_0.bin +PASSED [100%] + +================================ 1 passed, 16 deselected in 1.81s 
================================= +``` + +For this test case there is one input with shape/type `1x1x28x28xf32` stored at +`artifacts/vision/classification/mnist-12_version17_input_0.bin` and one output +with shape/type `1x10xf32` stored at +`artifacts/vision/classification/mnist-12_version17_output_0.bin`. + +We can reproduce the compile and run commands with: + +```bash +iree-compile \ + artifacts/vision/classification/mnist-12_version17.mlir \ + --iree-hal-target-backends=llvm-cpu \ + -o artifacts/vision/classification/mnist-12_version17_cpu.vmfb + +iree-run-module \ + --module=artifacts/vision/classification/mnist-12_version17_cpu.vmfb \ + --device=local-task \ + --input=1x1x28x28xf32=@artifacts/vision/classification/mnist-12_version17_input_0.bin \ + --expected_output=1x10xf32=@artifacts/vision/classification/mnist-12_version17_output_0.bin +``` diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index b6db1e9..b009784 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -125,11 +125,12 @@ def get_onnx_model_metadata(onnx_path: Path) -> OnnxModelMetadata: for i in range(len(onnx_results)): output = onnx_session.get_outputs()[i] result = onnx_results[i] + iree_type = convert_numpy_to_iree_type_string(result) logger.debug(f"Session output [{idx}]") logger.debug(f" name: '{output.name}'") logger.debug(f" shape (actual): {result.shape}") logger.debug(f" type (numpy): '{result.dtype}'") - iree_type = convert_numpy_to_iree_type_string(result) + logger.debug(f" iree parameter: {iree_type}") output_data_path = onnx_path.with_name(onnx_path.stem + f"_output_{idx}.bin") write_ndarray_to_binary_file(result, output_data_path) @@ -212,15 +213,18 @@ def compare_between_iree_and_onnxruntime_fn(model_url: str, artifacts_subdir="") upgraded_onnx_path = upgrade_onnx_model_version(original_onnx_path) onnx_model_metadata = get_onnx_model_metadata(upgraded_onnx_path) - logger.debug("ONNX model metadata:") logger.debug(onnx_model_metadata) # Prepare inputs and expected outputs for running through IREE. run_module_args = [] for input in onnx_model_metadata.inputs: - run_module_args.append(f"--input={input.type}=@{input.data_file}") + run_module_args.append( + f"--input={input.type}=@{input.data_file.relative_to(THIS_DIR)}" + ) for output in onnx_model_metadata.outputs: - run_module_args.append(f"--expected_output={output.type}=@{output.data_file}") + run_module_args.append( + f"--expected_output={output.type}=@{output.data_file.relative_to(THIS_DIR)}" + ) # Import, compile, then run with IREE. imported_mlir_path = import_onnx_model_to_mlir(upgraded_onnx_path) From 05e1c078f437db050edc89d3187c68da3089e9d8 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Sep 2024 15:29:49 -0700 Subject: [PATCH 25/26] Add pytest-html package to deps. 
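pytest-html generates a standalone HTML report, which is handy for browsing the pass/xfail status of the many model tests on one page. If report customization is wanted later, pytest-html also exposes documented hooks; a small sketch using its `pytest_html_report_title` hook follows (the title string is an arbitrary example, not something this patch configures):

```python
# conftest.py sketch -- assumes the pytest-html plugin added in this patch
# is installed and active.
def pytest_html_report_title(report):
    # Replace the default "Test Report" heading in the generated HTML file.
    report.title = "IREE ONNX model test results"
```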
--- onnx_models/README.md | 9 +++++++++ onnx_models/requirements.txt | 1 + 2 files changed, 10 insertions(+) diff --git a/onnx_models/README.md b/onnx_models/README.md index c5a53a3..f607027 100644 --- a/onnx_models/README.md +++ b/onnx_models/README.md @@ -88,6 +88,15 @@ graph LR pytest -n 4 ``` +* Create an HTMl report using https://pytest-html.readthedocs.io/en/latest/index.html + + ```bash + pytest --html=report.html --self-contained-html --log-cli-level=info + ``` + + See also + https://docs.pytest.org/en/latest/how-to/output.html#creating-junitxml-format-files + ## Debugging tests outside of pytest Each test generates some files as it runs: diff --git a/onnx_models/requirements.txt b/onnx_models/requirements.txt index 9ded4d3..2cbade5 100644 --- a/onnx_models/requirements.txt +++ b/onnx_models/requirements.txt @@ -4,6 +4,7 @@ onnx onnxruntime pytest +pytest-html pytest-reportlog pytest-timeout pytest-xdist From 5e967114f98b206cbf75771f99b465522e993529 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Wed, 18 Sep 2024 12:09:00 -0700 Subject: [PATCH 26/26] Refactor to support more dtypes. --- onnx_models/conftest.py | 36 ++------------------- onnx_models/utils.py | 70 +++++++++++++++++++++++++++++++++++------ 2 files changed, 63 insertions(+), 43 deletions(-) diff --git a/onnx_models/conftest.py b/onnx_models/conftest.py index b009784..ac2f215 100644 --- a/onnx_models/conftest.py +++ b/onnx_models/conftest.py @@ -5,18 +5,16 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import logging -import numpy as np import pytest import subprocess import urllib.request from dataclasses import dataclass -from onnxruntime import InferenceSession, NodeArg +from onnxruntime import InferenceSession from pathlib import Path from .utils import * logger = logging.getLogger(__name__) -rng = np.random.default_rng(0) THIS_DIR = Path(__file__).parent ARTIFACTS_ROOT = THIS_DIR / "artifacts" @@ -55,26 +53,6 @@ class OnnxModelMetadata: outputs: list[IreeModelParameterMetadata] -def convert_onnxruntime_node_arg_to_numpy_dimensions( - node_arg: NodeArg, -) -> tuple[int]: - # Note: turning dynamic dimensions into just 1 here, since we need - # a concrete (static) shape buffer of input data in the tests. - return tuple(x if isinstance(x, int) else 1 for x in node_arg.shape) - - -def convert_onnxruntime_shape_to_iree_type_string( - node_arg: NodeArg, -) -> str: - # Note: turning dynamic dimensions into just "1" here, since we need - # a concrete (static) shape buffer of input data in the tests. - shape = "x".join([str(x) if isinstance(x, int) else "1" for x in node_arg.shape]) - dtype = convert_node_arg_type_to_iree_dtype(node_arg.type) - if shape == "": - return dtype - return f"{shape}x{dtype}" - - def get_onnx_model_metadata(onnx_path: Path) -> OnnxModelMetadata: # We can either # A) List all metadata explicitly @@ -89,21 +67,13 @@ def get_onnx_model_metadata(onnx_path: Path) -> OnnxModelMetadata: for idx, input in enumerate(onnx_session.get_inputs()): logger.debug(f"Session input [{idx}]") logger.debug(f" name: '{input.name}'") - numpy_dimensions = convert_onnxruntime_node_arg_to_numpy_dimensions(input) - iree_type = convert_onnxruntime_shape_to_iree_type_string(input) + iree_type = convert_ort_to_iree_type(input) logger.debug(f" shape: {input.shape}") - logger.debug(f" numpy shape: {numpy_dimensions}") logger.debug(f" type: '{input.type}'") logger.debug(f" iree parameter: {iree_type}") # Create a numpy tensor with some random data for the input. 
- numpy_dtype = convert_node_arg_type_to_numpy_dtype(input.type) - if numpy_dtype == np.float32 or numpy_dtype == np.float64: - input_data = rng.random(numpy_dimensions, dtype=numpy_dtype) - elif numpy_dtype == np.int32 or numpy_dtype == np.int64: - input_data = rng.integers(numpy_dimensions, dtype=numpy_dtype) - else: - raise NotImplementedError(f"Unsupported numpy type: {numpy_dtype}") + input_data = generate_numpy_input_for_ort_node_arg(input) input_data_path = onnx_path.with_name(onnx_path.stem + f"_input_{idx}.bin") write_ndarray_to_binary_file(input_data, input_data_path) diff --git a/onnx_models/utils.py b/onnx_models/utils.py index c7eefaf..bdfc0b3 100644 --- a/onnx_models/utils.py +++ b/onnx_models/utils.py @@ -9,9 +9,11 @@ import onnx import struct import subprocess +from onnxruntime import NodeArg from pathlib import Path logger = logging.getLogger(__name__) +rng = np.random.default_rng(0) THIS_DIR = Path(__file__).parent @@ -92,18 +94,66 @@ def pack_ndarray_to_binary(ndarr: np.ndarray): ############################################################################### -def convert_node_arg_type_to_numpy_dtype(type: str): - # TODO(scotttodd): use onnx.TensorProto instead? prefer enums over strings - if type == "tensor(float)": - return np.float32 - raise NotImplementedError(f"type conversion for '{type}' not implemented") +def convert_ort_shape_to_numpy_dimensions( + node_arg: NodeArg, +) -> tuple[int]: + # Note: turning dynamic dimensions into just 1 here, since we need + # a concrete (static) shape buffer of input data in the tests. + # TODO(scotttodd): allow this to be overriden as needed + return tuple(x if isinstance(x, int) else 1 for x in node_arg.shape) + + +def convert_ort_type_to_numpy_dtype(node_arg: NodeArg): + type_str = node_arg.type + if type_str[0:6] != "tensor": + raise TypeError(f"node: {node_arg} has unhandled non-tensor type '{type_str}'") + dtype_str = type_str[7:-1] + if dtype_str == "float": + return np.dtype("float32") + if dtype_str == "int" or dtype_str == "int32": + return np.dtype("int32") + if dtype_str == "int64": + return np.dtype("int64") + if dtype_str == "int8": + return np.dtype("int8") + if dtype_str == "uint8": + return np.dtype("uint8") + if dtype_str == "bool": + return np.dtype("bool") + raise NotImplementedError(f"type conversion for '{type_str}' not implemented") + + +def convert_ort_type_to_iree_dtype(node_arg: NodeArg) -> str: + numpy_dtype = convert_ort_type_to_numpy_dtype(node_arg) + return numpy_to_iree_dtype_map[numpy_dtype][0] + + +def convert_ort_to_iree_type( + node_arg: NodeArg, +) -> str: + # Note: turning dynamic dimensions into just "1" here, since we need + # a concrete (static) shape buffer of input data in the tests. 
+ # TODO(scotttodd): allow this to be overriden as needed + shape = "x".join([str(x) if isinstance(x, int) else "1" for x in node_arg.shape]) + dtype = convert_ort_type_to_iree_dtype(node_arg) + if shape == "": + return dtype + return f"{shape}x{dtype}" + + +def generate_numpy_input_for_ort_node_arg(node_arg: NodeArg): + numpy_dimensions = convert_ort_shape_to_numpy_dimensions(node_arg) + numpy_type = convert_ort_type_to_numpy_dtype(node_arg).type + if numpy_type == np.float32 or numpy_type == np.float64: + return rng.random(numpy_dimensions, dtype=numpy_type) + if numpy_type == np.int32 or numpy_type == np.int64: + return rng.integers(numpy_dimensions, dtype=numpy_type) + # TODO(scotttodd): test i8, bool, and other dtypes + # if numpy_type == np.int8: + # return rng.integers(-127, 128, size=numpy_dimensions, dtype=numpy_type) -def convert_node_arg_type_to_iree_dtype(type: str) -> str: - # TODO(scotttodd): use onnx.TensorProto instead? prefer enums over strings - if type == "tensor(float)": - return "f32" - raise NotImplementedError(f"type conversion for '{type}' not implemented") + raise NotImplementedError(f"Unsupported numpy type: {numpy_type}") # TODO(#18289): use real frontend API, import model in-memory?
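A note on random input generation as used above: `numpy.random.Generator.random` only supports float32/float64 and samples from [0, 1), while for integer dtypes `Generator.integers` takes the output shape through its `size=` keyword (the first positional argument is the `low` bound). A standalone sketch of a generator along those lines, with illustrative (assumed) value ranges rather than the suite's chosen defaults:

```python
import numpy as np

rng = np.random.default_rng(0)


def generate_random_input(shape: tuple[int, ...], dtype) -> np.ndarray:
    """Sketch: produce a deterministic random tensor for a static shape."""
    dtype = np.dtype(dtype)
    if dtype in (np.float32, np.float64):
        # rng.random samples from [0, 1) and only supports these two dtypes.
        return rng.random(shape, dtype=dtype)
    if np.issubdtype(dtype, np.integer):
        # The output shape goes through `size=`; the bounds are illustrative.
        info = np.iinfo(dtype)
        low = max(info.min, -128)
        high = min(info.max, 127) + 1  # `high` is exclusive by default.
        return rng.integers(low, high, size=shape, dtype=dtype)
    if dtype == np.bool_:
        return rng.integers(0, 2, size=shape, dtype=np.int8).astype(np.bool_)
    raise NotImplementedError(f"Unsupported dtype: {dtype}")


# Example: a 1x1x28x28 float32 tensor like the mnist-12 input shown earlier.
sample = generate_random_input((1, 1, 28, 28), np.float32)
assert sample.shape == (1, 1, 28, 28) and sample.dtype == np.float32
```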