diff --git a/README.md b/README.md
index a1a95f7f..fc4af18d 100644
--- a/README.md
+++ b/README.md
@@ -206,6 +206,50 @@ running
 python -c "from tvm.runtime import Module"
 ```
+### Step 2-tris: Install ONNX-MLIR Compiler
+The `ONNX-MLIR` compiler has to be built from source: the installer first builds a
+specific commit of [llvm-project](https://github.com/llvm/llvm-project) and then
+builds [ONNX-MLIR](https://github.com/onnx/onnx-mlir) on top of it.
+The whole build can take several minutes (or even hours), which is why we decided
+not to include it in the default auto-installer. However, if you want to squeeze
+the maximum performance out of your model on your machine, we highly recommend
+installing ONNX-MLIR as well. With `nebullvm` it is super-easy! Just run
+
+```
+python -c "from nebullvm.installers.installers import install_onnx_mlir; install_onnx_mlir()"
+```
+
+and wait for the compiler to be installed! You can check that everything worked by running
+
+```
+python -c "from nebullvm.inference_learners import onnx_mlir"
+```
+
+By default `ONNX-MLIR` is installed in your home directory. If you prefer a
+different location, pass a custom installation directory to the
+`install_onnx_mlir` function:
+
+```
+python -c "from nebullvm.installers.installers import install_onnx_mlir;
+install_onnx_mlir(absolute_path_of_installation_directory)"
+```
+
+and wait for the compiler to be installed!
+In subsequent runs, point `nebullvm` to the custom installation directory by
+exporting the environment variable `MLIR_INSTALLATION_ROOT`, either by running
+
+```
+export MLIR_INSTALLATION_ROOT="absolute_path_of_installation_directory"
+```
+
+from your command line, or by adding
+
+```
+import os
+os.environ["MLIR_INSTALLATION_ROOT"] = "absolute_path_of_installation_directory"
+```
+
+to your Python code before importing `nebullvm` for the first time.
+You can check that everything worked by running
+
+```
+export MLIR_INSTALLATION_ROOT="absolute_path_of_installation_directory"
+python -c "from nebullvm.inference_learners import onnx_mlir"
+```
+
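As a note on the README instructions above: the custom-directory flow can also be driven from a single Python snippet. This is only a sketch; the installation path used below is purely illustrative.

```
import os

# Illustrative path only -- any writable directory works.
custom_dir = "/opt/onnx_mlir_root"
os.environ["MLIR_INSTALLATION_ROOT"] = custom_dir

# Build llvm-project and onnx-mlir inside the chosen directory
# (this is the long step: it can take from minutes to hours).
from nebullvm.installers.installers import install_onnx_mlir
install_onnx_mlir(custom_dir)

# Verify that the PyRuntime bindings are now importable.
from nebullvm.inference_learners import onnx_mlir
```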
diff --git a/nebullvm/base.py b/nebullvm/base.py index 6d8746f0..870c80b6 100644 --- a/nebullvm/base.py +++ b/nebullvm/base.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Tuple, List, Dict, Union +from typing import Dict, List, Tuple, Union class DataType(str, Enum): @@ -106,6 +106,7 @@ class ModelCompiler(Enum): OPENVINO = "openvino" APACHE_TVM = "tvm" ONNX_RUNTIME = "onnxruntime" + ONNX_MLIR = "onnx_mlir" class QuantizationType(Enum): diff --git a/nebullvm/config.py b/nebullvm/config.py index 05135099..511f8cd8 100644 --- a/nebullvm/config.py +++ b/nebullvm/config.py @@ -37,3 +37,7 @@ "description_file": "description.xml", "weights": "weights.bin", } + +ONNX_MLIR_FILENAMES = { + "model_name": "mlir_model.so", +} diff --git a/nebullvm/inference_learners/onnx_mlir.py b/nebullvm/inference_learners/onnx_mlir.py new file mode 100644 index 00000000..14b0de9a --- /dev/null +++ b/nebullvm/inference_learners/onnx_mlir.py @@ -0,0 +1,220 @@ +import copy +import os +import shutil +import sys +import warnings +from abc import ABC +from pathlib import Path +from typing import Dict, Generator, List, Tuple, Type, Union + +import cpuinfo +import numpy as np +import tensorflow as tf +import torch +from nebullvm.base import DeepLearningFramework, ModelParams +from nebullvm.config import ONNX_MLIR_FILENAMES +from nebullvm.inference_learners.base import ( + BaseInferenceLearner, + LearnerMetadata, + PytorchBaseInferenceLearner, + TensorflowBaseInferenceLearner, +) + +try: + # Set the ONNX_MLIR_HOME as the environment variable and append in the path, + # directory path where the MLIR is built + + # retrieve the ONNX-MLIR installation directory from environment variable + # if exists otherwise set to home directory + MLIR_INSTALLATION_ROOT = os.environ.get( + "MLIR_INSTALLATION_ROOT", Path.home() + ) + + os.environ["ONNX_MLIR_HOME"] = os.path.join( + MLIR_INSTALLATION_ROOT, + "onnx-mlir", + "build", + "Debug", + ) + + sys.path.append( + os.path.join( + os.environ.get("ONNX_MLIR_HOME", ""), + "lib", + ) + ) + import PyRuntime +except ImportError: + # Disable the ONNX-MLIR auto-installer for the time being as it takes long to + # install, can be installed by explicitly running the install_onnx_mlir function + + # TODO: Remove the False flag for allowing onnx-mlir to be installed by + # the Auto-Installer. + if False and not NO_COMPILER_INSTALLATION: + warnings.warn( + "No valid onnx-mlir installation found. Trying to install it..." + ) + from nebullvm.installers.installers import install_onnx_mlir + + install_onnx_mlir( + working_dir=MLIR_INSTALLATION_ROOT, + ) + import PyRuntime + else: + warnings.warn( + "Not found any valid onnx-mlir installation. " + "ONNX-MLIR will not be available in the following steps. " + "ONNX-MLIR should be explicitly installed using install_onnx_mlir." + ) + + +class ONNXMlirInferenceLearner(BaseInferenceLearner, ABC): + """Model converted from ONNX to Shared Object file using ONNX-MLIR dialect + and run with ONNX-MLIR's PyRuntime + created at onnx-mlir/build/Debug/lib/PyRuntime.cpython-.so. + + Attributes: + onnx_mlir_model_path (str or Path): Path to the shared object mlir model. + network_parameters (ModelParams): The model parameters as batch + size, input and output sizes. 
+ """ + + def __init__( + self, + onnx_mlir_model_path: Union[str, Path], + **kwargs, + ): + super().__init__(**kwargs) + self.onnx_mlir_model_path = onnx_mlir_model_path + self._session = PyRuntime.ExecutionSession( + os.path.abspath(str(self.onnx_mlir_model_path)), + ) + + def save(self, path: Union[str, Path], **kwargs): + """Save the model. + + Args: + path (Path or str): Path to the directory where the model will + be stored. + kwargs (Dict): Dictionary of key-value pairs that will be saved in + the model metadata file. + """ + metadata = LearnerMetadata.from_model(self, **kwargs) + metadata.save(path) + + shutil.copy( + self.onnx_mlir_model_path, + os.path.join(str(path), ONNX_MLIR_FILENAMES["model_name"]), + ) + + @classmethod + def load(cls, path: Union[Path, str], **kwargs): + """Load the model. + + Args: + path (Path or str): Path to the directory where the model is + stored. + kwargs (Dict): Dictionary of additional arguments for consistency + with other Learners. + + Returns: + ONNXInferenceLearner: The optimized model. + """ + if len(kwargs) > 0: + warnings.warn( + f"No extra keywords expected for the load method. " + f"Got {kwargs}." + ) + onnx_mlir_model_path = os.path.join( + str(path), ONNX_MLIR_FILENAMES["model_name"] + ) + metadata = LearnerMetadata.read(path) + + return cls( + network_parameters=ModelParams(**metadata.network_parameters), + onnx_mlir_model_path=onnx_mlir_model_path, + ) + + def _predict_arrays(self, input_arrays: Generator[np.ndarray, None, None]): + outputs = self._session.run(list(input_arrays)) + return outputs + + +class PytorchONNXMlirInferenceLearner( + ONNXMlirInferenceLearner, PytorchBaseInferenceLearner +): + """Model run with ONNX-MLIR's PyRuntime using a Pytorch interface. + + Attributes: + onnx_mlir_model_path (str or Path): Path to the shared object mlir model. + network_parameters (ModelParams): The model parameters as batch + size, input and output sizes. + """ + + def predict(self, *input_tensors: torch.Tensor) -> Tuple[torch.Tensor]: + """Predict on the input tensors. + + Note that the input tensors must be on the same batch. If a sequence + of tensors is given when the model is expecting a single input tensor + (with batch size >= 1) an error is raised. + + Args: + input_tensors (Tuple[Tensor]): Input tensors belonging to the same + batch. The tensors are expected having dimensions + (batch_size, dim1, dim2, ...). + + Returns: + Tuple[Tensor]: Output tensors. Note that the output tensors does + not correspond to the prediction on the input tensors with a + 1 to 1 mapping. In fact the output tensors are produced as the + multiple-output of the model given a (multi-) tensor input. + """ + input_arrays = ( + input_tensor.cpu().detach().numpy() + for input_tensor in input_tensors + ) + outputs = self._predict_arrays(input_arrays) + return tuple(torch.from_numpy(output) for output in outputs) + + +class TensorflowONNXMlirInferenceLearner( + ONNXMlirInferenceLearner, TensorflowBaseInferenceLearner +): + """Model run with ONNX-MLIR's PyRuntime using a tensorflow interface. + + Attributes: + onnx_mlir_model_path (str or Path): Path to the shared object mlir model. + network_parameters (ModelParams): The model parameters as batch + size, input and output sizes. + """ + + def predict(self, *input_tensors: tf.Tensor) -> Tuple[tf.Tensor]: + """Predict on the input tensors. + + Note that the input tensors must be on the same batch. 
If a sequence + of tensors is given when the model is expecting a single input tensor + (with batch size >= 1) an error is raised. + + Args: + input_tensors (Tuple[Tensor]): Input tensors belonging to the same + batch. The tensors are expected having dimensions + (batch_size, dim1, dim2, ...). + + Returns: + Tuple[Tensor]: Output tensors. Note that the output tensors does + not correspond to the prediction on the input tensors with a + 1 to 1 mapping. In fact the output tensors are produced as the + multiple-output of the model given a (multi-) tensor input. + """ + input_arrays = (input_tensor.numpy() for input_tensor in input_tensors) + outputs = self._predict_arrays(input_arrays) + # noinspection PyTypeChecker + return tuple(tf.convert_to_tensor(output) for output in outputs) + + +ONNX_MLIR_INFERENCE_LEARNERS: Dict[ + DeepLearningFramework, Type[ONNXMlirInferenceLearner] +] = { + DeepLearningFramework.PYTORCH: PytorchONNXMlirInferenceLearner, + DeepLearningFramework.TENSORFLOW: TensorflowONNXMlirInferenceLearner, +} diff --git a/nebullvm/installers/install_onnx_mlir.sh b/nebullvm/installers/install_onnx_mlir.sh new file mode 100644 index 00000000..9353b4a7 --- /dev/null +++ b/nebullvm/installers/install_onnx_mlir.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Installation steps to build the ONNX-MLIR from source + +# Set non interactive mode for apt-get +export DEBIAN_FRONTEND=noninteractive + +# Build ONNX-MLIR + +if [ ! -d "onnx-mlir" ] +then + + git clone --recursive https://github.com/onnx/onnx-mlir.git onnx-mlir +fi + + +if [ -z "$NPROC" ] +then + NPROC=4 +fi + + +# Export environment variables pointing to LLVM-Projects. +export MLIR_DIR=$(pwd)/llvm-project/build/lib/cmake/mlir + +# Get the python interpreter path +export PYTHON_LOCATION=$(which python3) + +mkdir onnx-mlir/build && cd onnx-mlir/build + +if [[ -z "$PYTHON_LOCATION" ]]; then + cmake -G Ninja \ + -DCMAKE_CXX_COMPILER=/usr/bin/c++ \ + -DMLIR_DIR=${MLIR_DIR} \ + .. +else + echo "Using python path " $PYTHON_LOCATION + echo "Using MLIR_DIR " $MLIR_DIR + + cmake -G Ninja \ + -DCMAKE_CXX_COMPILER=/usr/bin/c++ \ + -DPython3_ROOT_DIR=${PYTHON_LOCATION} \ + -DPython3_EXECUTABLE=${PYTHON_LOCATION} \ + -DMLIR_DIR=${MLIR_DIR} \ + .. + +fi + +cmake --build . --parallel $NPROC + +# Run lit tests: +export LIT_OPTS=-v +cmake --build . --parallel $NPROC --target check-onnx-lit diff --git a/nebullvm/installers/install_onnx_mlir_prerequisites.sh b/nebullvm/installers/install_onnx_mlir_prerequisites.sh new file mode 100644 index 00000000..4c4c42ff --- /dev/null +++ b/nebullvm/installers/install_onnx_mlir_prerequisites.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +# Installation steps to build and install the llvm-project from source + +# Set non interactive mode for apt-get +export DEBIAN_FRONTEND=noninteractive + +if [ -z "$NPROC" ] +then + export NPROC=4 +fi + +# Install the OS dependent required packages +if [[ $OSTYPE == "darwin"* ]] +then + brew install gcc git cmake ninja pybind11 +elif [[ "$(grep '^ID_LIKE' /etc/os-release)" == *"centos"* ]] +then + sudo yum update -q -y && \ + sudo yum install -q -y \ + autoconf automake ca-certificates cmake diffutils \ + file java-11-openjdk-devel java-11-openjdk-headless \ + gcc gcc-c++ git libtool make ncurses-devel \ + zlib-devel && \ + # Install ninja + git clone -b v1.10.2 https://github.com/ninja-build/ninja.git && \ + cd ninja && mkdir -p build && cd build && \ + cmake .. && \ + make -j$NPROC install && \ + cd ../.. 
&& rm -rf ninja; +else + sudo apt-get update && sudo apt-get install -y --no-install-recommends \ + autoconf automake ca-certificates cmake curl \ + default-jdk-headless gcc g++ git libncurses-dev \ + libtool make maven ninja-build openjdk-11-jdk-headless \ + zlib1g-dev + +fi + +# Install protobuf +PROTOBUF_VERSION=3.14.0 +git clone -b v$PROTOBUF_VERSION --recursive https://github.com/google/protobuf.git \ + && cd protobuf && ./autogen.sh \ + && ./configure --enable-static=no \ + && make -j$NPROC install && ldconfig \ + && cd python && python setup.py install \ + && cd ../.. && rm -rf protobuf + +# Install jsoniter +JSONITER_VERSION=0.9.23 +JSONITER_URL=https://repo1.maven.org/maven2/com/jsoniter/jsoniter/$JSONITER_VERSION \ + && JSONITER_FILE=jsoniter-$JSONITER_VERSION.jar \ + && curl -s $JSONITER_URL/$JSONITER_FILE -o /usr/share/java/$JSONITER_FILE + + +# ONNX-MLIR needs the llvm-project build from the source + +# Firstly, install MLIR (as a part of LLVM-Project): +git clone -n https://github.com/llvm/llvm-project.git + + +# Check out a specific branch that is known to work with ONNX-MLIR. +# TBD: Option to set the commit hash dynamically +cd llvm-project && git checkout a7ac120a9ad784998a5527fc0a71b2d0fd55eccb && cd .. + +mkdir llvm-project/build +cd llvm-project/build + +cmake -G Ninja ../llvm \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_TARGETS_TO_BUILD="host" \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_RTTI=ON + +cmake --build . --parallel $NPROC -- ${MAKEFLAGS} +cmake --build . --parallel $NPROC --target check-mlir diff --git a/nebullvm/installers/installers.py b/nebullvm/installers/installers.py index 54f13068..cc0de8d6 100644 --- a/nebullvm/installers/installers.py +++ b/nebullvm/installers/installers.py @@ -104,3 +104,32 @@ def install_onnxruntime(): # install requirements for onnxruntime.transformers cmd = ["pip3", "install", "coloredlogs", "sympy"] subprocess.run(cmd) + + +def install_onnx_mlir(working_dir: str = None): + """Helper function for installing Onnx-MLIR. + + This function build llvm-project from source that is a prerequisite for + ONNX-MLIR followed by the building of ONNX-MLIR dialect from source. + + Args: + working_dir (str, optional): The directory where the llvm-project and onnx-mlir + repo will be cloned and installed. 
+ """ + path = Path(__file__).parent + + # install pre-requisites, exclusively depends on the llvm-project + installation_file_prerequisites = str( + path / "install_onnx_mlir_prerequisites.sh" + ) + subprocess.run( + ["bash", installation_file_prerequisites], + cwd=working_dir or Path.home(), + ) + + # build and install onnx-mlir + installation_file = str(path / "install_onnx_mlir.sh") + subprocess.run( + ["bash", installation_file], + cwd=working_dir or Path.home(), + ) diff --git a/nebullvm/optimizers/__init__.py b/nebullvm/optimizers/__init__.py index 3d67fdbc..9e6e113f 100644 --- a/nebullvm/optimizers/__init__.py +++ b/nebullvm/optimizers/__init__.py @@ -3,5 +3,6 @@ from nebullvm.optimizers.openvino import OpenVinoOptimizer # noqa F401 from nebullvm.optimizers.tensor_rt import TensorRTOptimizer # noqa F401 from nebullvm.optimizers.tvm import ApacheTVMOptimizer # noqa F401 +from nebullvm.optimizers.onnx_mlir import ONNXMlirOptimizer # noqa F401 -__all__ = [k for k in globals().keys() if not k.startswith("_")] +__all__ = [k for k in globals().keys() if not k.startswith("_")] \ No newline at end of file diff --git a/nebullvm/optimizers/multi_compiler.py b/nebullvm/optimizers/multi_compiler.py index 13d2abc6..decf27bc 100644 --- a/nebullvm/optimizers/multi_compiler.py +++ b/nebullvm/optimizers/multi_compiler.py @@ -1,4 +1,6 @@ import json +import os +import sys import warnings from logging import Logger from pathlib import Path @@ -25,6 +27,7 @@ ApacheTVMOptimizer, OpenVinoOptimizer, ONNXOptimizer, + ONNXMlirOptimizer, ) from nebullvm.transformations.base import MultiStageTransformation from nebullvm.utils.data import DataManager @@ -34,9 +37,42 @@ ModelCompiler.OPENVINO: OpenVinoOptimizer, ModelCompiler.TENSOR_RT: TensorRTOptimizer, ModelCompiler.ONNX_RUNTIME: ONNXOptimizer, + ModelCompiler.ONNX_MLIR: ONNXMlirOptimizer, } +def _onnx_mlir_pyruntime_is_available() -> bool: + try: + # Set the ONNX_MLIR_HOME as the environment variable and append in the path, + # directory path where the MLIR is built + + # retrieve the ONNX-MLIR installation directory from environment variable + # if exists otherwise set to home directory + MLIR_INSTALLATION_ROOT = os.environ.get( + "MLIR_INSTALLATION_ROOT", Path.home() + ) + + os.environ["ONNX_MLIR_HOME"] = os.path.join( + MLIR_INSTALLATION_ROOT, + "onnx-mlir", + "build", + "Debug", + ) + + sys.path.append( + os.path.join( + os.environ.get("ONNX_MLIR_HOME", ""), + "lib", + ) + ) + + import PyRuntime # noqa F401 + + return True + except ImportError: + return False + + def _tvm_is_available() -> bool: try: import tvm # noqa F401 @@ -47,7 +83,13 @@ def _tvm_is_available() -> bool: def select_compilers_from_hardware(): - compilers = [ModelCompiler.ONNX_RUNTIME] + compilers = [ + ModelCompiler.ONNX_RUNTIME, + ] + + if _onnx_mlir_pyruntime_is_available(): + compilers.append(ModelCompiler.ONNX_MLIR) + if _tvm_is_available(): compilers.append(ModelCompiler.APACHE_TVM) if torch.cuda.is_available(): diff --git a/nebullvm/optimizers/onnx_mlir.py b/nebullvm/optimizers/onnx_mlir.py new file mode 100644 index 00000000..63ec77e5 --- /dev/null +++ b/nebullvm/optimizers/onnx_mlir.py @@ -0,0 +1,70 @@ +import os +import subprocess +from pathlib import Path + +from nebullvm.base import DeepLearningFramework, ModelParams +from nebullvm.inference_learners.onnx_mlir import ( + ONNX_MLIR_INFERENCE_LEARNERS, + ONNXMlirInferenceLearner, +) +from nebullvm.optimizers.base import BaseOptimizer + + +class ONNXMlirOptimizer(BaseOptimizer): + """Class for compiling the AI models from 
ONNX format to the equivalent MLIR dialect."""
+
+    def optimize(
+        self,
+        onnx_model: str,
+        output_library: DeepLearningFramework,
+        model_params: ModelParams,
+    ) -> ONNXMlirInferenceLearner:
+        """Compile the ONNX model with the MLIR Compiler Infrastructure.
+
+        Args:
+            onnx_model (str): Path to the saved onnx model.
+            output_library (str): DL Framework the optimized model will be
+                compatible with.
+            model_params (ModelParams): Model parameters.
+
+        Returns:
+            ONNXMlirInferenceLearner: Model optimized with ONNX-MLIR. The model
+                will have an interface in the DL library specified in
+                `output_library`.
+        """
+        inputs = list(model_params.input_sizes)
+
+        # Build the --shapeInformation flag, e.g. "0:1x3x224x224,1:1x3x224x224".
+        shape_info = "--shapeInformation="
+        for input_index in range(len(inputs)):
+            shape_info += (
+                f"{input_index}:{model_params.batch_size}x"
+                + "x".join(map(str, inputs[input_index]))
+                + ","
+            )
+        shape_info = shape_info[:-1]
+
+        command = [
+            "./onnx-mlir",
+            "--EmitLib",
+            "--O3",
+            shape_info,
+            onnx_model,
+        ]
+        process = subprocess.Popen(
+            command,
+            cwd=os.path.join(
+                os.environ.get("ONNX_MLIR_HOME", ""),
+                "bin",
+            ),
+        )
+        process.wait()
+
+        base_path = Path(onnx_model).parent
+        onnx_mlir_model_path = base_path / f"{Path(onnx_model).stem}.so"
+
+        model = ONNX_MLIR_INFERENCE_LEARNERS[output_library](
+            onnx_mlir_model_path=str(onnx_mlir_model_path),
+            network_parameters=model_params,
+        )
+
+        return model
diff --git a/nebullvm/optimizers/tests/test_onnx_mlir.py b/nebullvm/optimizers/tests/test_onnx_mlir.py
new file mode 100644
index 00000000..8e7bd576
--- /dev/null
+++ b/nebullvm/optimizers/tests/test_onnx_mlir.py
@@ -0,0 +1,34 @@
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from nebullvm.base import DeepLearningFramework
+from nebullvm.inference_learners.onnx_mlir import ONNX_MLIR_INFERENCE_LEARNERS
+from nebullvm.optimizers.onnx_mlir import ONNXMlirOptimizer
+from nebullvm.optimizers.tests.utils import get_onnx_model
+
+
+@pytest.mark.parametrize(
+    ("output_library", "dynamic"),
+    [
+        (DeepLearningFramework.PYTORCH, True),
+        (DeepLearningFramework.PYTORCH, False),
+    ],
+)
+def test_onnx_mlir(output_library: DeepLearningFramework, dynamic: bool):
+    with TemporaryDirectory() as tmp_dir:
+        model_path, model_params = get_onnx_model(tmp_dir, dynamic)
+        optimizer = ONNXMlirOptimizer()
+        model = optimizer.optimize(model_path, output_library, model_params)
+        assert isinstance(model, ONNX_MLIR_INFERENCE_LEARNERS[output_library])
+
+        inputs_example = list(model.get_inputs_example())
+        res = model.predict(*inputs_example)
+        assert res is not None
+
+        if dynamic:  # Check also with a smaller batch_size
+            inputs_example = [
+                input_[: len(input_) // 2] for input_ in inputs_example
+            ]
+            res = model.predict(*inputs_example)
+            assert res is not None
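As a quick smoke test of the compiler-selection change in `multi_compiler.py` above, one can check whether ONNX-MLIR is picked up on the current machine. This is a minimal sketch, assuming the `PyRuntime` import described earlier succeeds.

```
from nebullvm.base import ModelCompiler
from nebullvm.optimizers.multi_compiler import select_compilers_from_hardware

# ONNX_MLIR appears in the list only when the PyRuntime bindings built by
# install_onnx_mlir can be imported (i.e. MLIR_INSTALLATION_ROOT points to the
# custom build directory, or the default home-directory build exists).
compilers = select_compilers_from_hardware()
print(ModelCompiler.ONNX_MLIR in compilers)
```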