diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..7abcfd78 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index 08f63a8e..6f2df537 100644 --- a/.gitignore +++ b/.gitignore @@ -128,3 +128,6 @@ dmypy.json # Pyre type checker .pyre/ + +# MacOS DS_Store +.DS_Store diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..648154f6 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,46 @@ +ARG STARTING_IMAGE=nvidia/cuda:11.2.0-runtime-ubuntu20.04 +FROM ${STARTING_IMAGE} + +# Set frontend as non-interactive +ARG DEBIAN_FRONTEND=noninteractive + +RUN apt-get update + +# Install python and pip +RUN apt-get install -y python3-opencv python3-pip && \ + python3 -m pip install --upgrade pip && \ + apt-get -y install git + +# Install nebullvm +ARG NEBULLVM_VERSION=latest +RUN if [ "$NEBULLVM_VERSION" = "latest" ] ; then \ + # pip install nebullvm ; \ + git clone https://github.com/nebuly-ai/nebullvm.git ; \ + cd nebullvm ; \ + pip install . ;\ + else \ + pip install nebullvm==${NEBULLVM_VERSION} ; \ + fi + +# Install required python modules +RUN pip install scipy==1.5.4 && \ + pip install cmake + +# Install default deep learning compilers +ARG COMPILER=all +ENV NO_COMPILER_INSTALLATION=1 +RUN if [ "$COMPILER" = "all" ] ; then \ + python3 -c "import os; os.environ['NO_COMPILER_INSTALLATION'] = '0'; import nebullvm" ; \ + elif [ "$COMPILER" = "tensorrt" ] ; then \ + python3 -c "from nebullvm.installers.installers import install_tensor_rt; install_tensor_rt()" ; \ + elif [ "$COMPILER" = "openvino" ] ; then \ + python3 -c "from nebullvm.installers.installers import install_openvino; install_openvino()" ; \ + elif [ "$COMPILER" = "onnxruntime" ] ; then \ + python3 -c "from nebullvm.installers.installers import install_onnxruntime; install_onnxruntime()" ; \ + fi + +# Install TVM +RUN if [ "$COMPILER" = "all" ] || [ "$COMPILER" = "tvm" ] ; then \ + python3 -c "from nebullvm.installers.installers import install_tvm; install_tvm()" ; \ + python3 -c "from tvm.runtime import Module" ; \ + fi diff --git a/docker_build.sh b/docker_build.sh new file mode 100644 index 00000000..10274702 --- /dev/null +++ b/docker_build.sh @@ -0,0 +1,8 @@ +# Create image with all compilers installed +docker build -t nebullvm-all . + +# Create an image for each compiler installed +docker build -t nebullvm-onnxruntime . --build-arg COMPILER="onnxruntime" +docker build -t nebullvm-openvino . --build-arg COMPILER="openvino" +docker build -t nebullvm-tvm . --build-arg COMPILER="tvm" +docker build -t nebullvm-tensorrt . 
--build-arg COMPILER="tensorrt" diff --git a/nebullvm/api/frontend/onnx.py b/nebullvm/api/frontend/onnx.py index d2530697..4002b50b 100644 --- a/nebullvm/api/frontend/onnx.py +++ b/nebullvm/api/frontend/onnx.py @@ -225,7 +225,7 @@ def optimize_onnx_model( ) if model_optimizer.usable: model_optimized = model_optimizer.optimize( - onnx_model=str(onnx_path), + model=str(onnx_path), output_library=dl_library, model_params=model_params, input_tfms=input_tfms, diff --git a/nebullvm/api/frontend/tf.py b/nebullvm/api/frontend/tf.py index 24910b30..0aeb2a5b 100644 --- a/nebullvm/api/frontend/tf.py +++ b/nebullvm/api/frontend/tf.py @@ -1,9 +1,13 @@ +import logging import os +import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import List, Tuple, Union, Dict, Optional, Callable, Any +import numpy as np import tensorflow as tf +from tqdm import tqdm from nebullvm.api.frontend.utils import ( ifnone, @@ -15,9 +19,13 @@ ModelParams, InputInfo, ModelCompiler, + QuantizationType, ) from nebullvm.converters import ONNXConverter +from nebullvm.inference_learners import TensorflowBaseInferenceLearner +from nebullvm.measure import compute_optimized_running_time from nebullvm.optimizers import BaseOptimizer +from nebullvm.optimizers.tensorflow import TensorflowBackendOptimizer from nebullvm.transformations.base import MultiStageTransformation from nebullvm.utils.data import DataManager from nebullvm.utils.tf import ( @@ -27,6 +35,12 @@ ) from nebullvm.optimizers.multi_compiler import MultiCompilerOptimizer +logging.basicConfig( + format="%(asctime)s %(message)s", datefmt="%d/%m/%Y %I:%M:%S %p" +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + def _extract_dynamic_axis( tf_model: tf.Module, @@ -216,13 +230,36 @@ def optimize_tf_model( ignore_compilers=ignore_compilers, extra_optimizers=custom_optimizers, debug_mode=int(os.environ.get("DEBUG_MODE", "0")) > 0, + logger=logger, ) with TemporaryDirectory() as tmp_dir: + logger.info("Running Optimization using tensorflow interface (1/3)") + if perf_loss_ths is not None: + q_types = [ + None, + QuantizationType.DYNAMIC, + QuantizationType.HALF, + ] + if dataset is not None: + q_types.append(QuantizationType.STATIC) + else: + q_types = [None] + torch_res = [ + _torch_api_optimization( + model, model_params, perf_loss_ths, q_type, False, input_data + ) + for q_type in tqdm(q_types) + ] + (tf_api_model, tf_api_latency, used_compilers,) = sorted( + torch_res, key=lambda x: x[1] + )[0] + ignore_compilers.extend(used_compilers) + logger.info("Running Optimization using ONNX interface (2/3)") onnx_path = model_converter.convert( model, model_params.input_sizes, Path(tmp_dir) ) model_optimized = model_optimizer.optimize( - onnx_model=str(onnx_path), + model=str(onnx_path), output_library=dl_library, model_params=model_params, input_tfms=input_tfms, @@ -230,5 +267,86 @@ def optimize_tf_model( perf_metric=perf_metric, input_data=input_data, ) + logger.info("Running comparison between optimized models (3/3).") + model_optimized = _compare_optimized_models( + model_optimized, tf_api_model, tf_api_latency + ) + if model_optimized is None: + raise RuntimeError( + "No valid compiled model has been produced. " + "Look at the logs for further information about the failure." 
+            )
     model_optimized.save(save_dir)
     return model_optimized.load(save_dir)
+
+
+def _compare_optimized_models(
+    new_model: TensorflowBaseInferenceLearner,
+    previous_best_model: TensorflowBaseInferenceLearner,
+    previous_latency: float,
+) -> TensorflowBaseInferenceLearner:
+    if new_model is not None:
+        new_latency = compute_optimized_running_time(new_model)
+        if new_latency < previous_latency:
+            return new_model
+    return previous_best_model
+
+
+def _get_optimizers_supporting_tf_api(use_extra_compilers: bool):
+    if use_extra_compilers:
+        logger.warning(
+            "No extra compiler found supporting the tensorflow interface."
+        )
+    return [(ModelCompiler.TFLITE, TensorflowBackendOptimizer(logger=logger))]
+
+
+def _torch_api_optimization(
+    model: tf.Module,
+    model_params: ModelParams,
+    quantization_ths: float,
+    quantization_type: QuantizationType,
+    use_extra_compilers: bool,
+    input_data: DataManager,
+) -> Tuple[Optional[TensorflowBaseInferenceLearner], float, List]:
+    used_compilers = []
+    best_tf_opt_model = None
+    best_latency = np.inf
+    for compiler, optimizer in tqdm(
+        _get_optimizers_supporting_tf_api(use_extra_compilers)
+    ):
+        try:
+            if hasattr(optimizer, "optimize_from_tf"):
+                candidate_model = optimizer.optimize_from_tf(
+                    torch_model=model,
+                    model_params=model_params,
+                    perf_loss_ths=quantization_ths
+                    if quantization_type is not None
+                    else None,
+                    quantization_type=quantization_type,
+                    input_data=input_data,
+                )
+            else:
+                candidate_model = optimizer.optimize(
+                    model=model,
+                    output_library=DeepLearningFramework.TENSORFLOW,
+                    model_params=model_params,
+                    perf_loss_ths=quantization_ths
+                    if quantization_type is not None
+                    else None,
+                    quantization_type=quantization_type,
+                    input_data=input_data,
+                )
+            candidate_latency = compute_optimized_running_time(candidate_model)
+            if candidate_latency < best_latency:
+                best_latency = candidate_latency
+                best_tf_opt_model = candidate_model
+                used_compilers.append(compiler)
+        except Exception as ex:
+            warnings.warn(
+                f"Compilation failed with tensorflow interface of {compiler}. "
+                f"Got error {ex}. If possible the compilation will be "
+                f"re-scheduled with the ONNX interface. Please consult the "
+                f"documentation for further info or open an issue on GitHub "
+                f"for receiving assistance."
+ ) + return best_tf_opt_model, best_latency, used_compilers diff --git a/nebullvm/api/frontend/torch.py b/nebullvm/api/frontend/torch.py index f776e626..32629ac3 100644 --- a/nebullvm/api/frontend/torch.py +++ b/nebullvm/api/frontend/torch.py @@ -1,3 +1,4 @@ +import logging import os import warnings from pathlib import Path @@ -7,6 +8,7 @@ import numpy as np import torch from torch.utils.data import DataLoader +from tqdm import tqdm from nebullvm.api.frontend.utils import ( check_inputs, @@ -22,6 +24,7 @@ QuantizationType, ) from nebullvm.converters import ONNXConverter +from nebullvm.optimizers.pytorch import PytorchBackendOptimizer from nebullvm.transformations.base import MultiStageTransformation from nebullvm.utils.data import DataManager from nebullvm.utils.torch import ( @@ -34,6 +37,12 @@ from nebullvm.optimizers import ApacheTVMOptimizer, BaseOptimizer from nebullvm.optimizers.multi_compiler import MultiCompilerOptimizer +logging.basicConfig( + format="%(asctime)s %(message)s", datefmt="%d/%m/%Y %I:%M:%S %p" +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + def _extract_dynamic_axis( torch_model: torch.nn.Module, @@ -246,39 +255,46 @@ def optimize_torch_model( ) input_tfms = MultiStageTransformation([]) with TemporaryDirectory() as tmp_dir: - if use_torch_api: - if perf_loss_ths is not None: - q_types = [ - None, - QuantizationType.DYNAMIC, - QuantizationType.HALF, - ] - if dataloader is not None: - q_types.append(QuantizationType.STATIC) - else: - q_types = [None] - torch_res = [ - _torch_api_optimization( - model, model_params, perf_loss_ths, q_type - ) - for q_type in q_types + logger.info("Running Optimization using torch interface (1/3)") + if perf_loss_ths is not None: + q_types = [ + None, + QuantizationType.DYNAMIC, + QuantizationType.HALF, ] - (torch_api_model, torch_api_latency, used_compilers,) = sorted( - torch_res, key=lambda x: x[1] - )[0] - ignore_compilers.extend(used_compilers) + if dataloader is not None: + q_types.append(QuantizationType.STATIC) + else: + q_types = [None] + torch_res = [ + _torch_api_optimization( + model, + model_params, + perf_loss_ths, + q_type, + use_torch_api, + input_data, + ) + for q_type in tqdm(q_types) + ] + (torch_api_model, torch_api_latency, used_compilers,) = sorted( + torch_res, key=lambda x: x[1] + )[0] + ignore_compilers.extend(used_compilers) + logger.info("Running Optimization using ONNX interface (2/3)") model_converter = ONNXConverter() model_optimizer = MultiCompilerOptimizer( ignore_compilers=ignore_compilers, extra_optimizers=custom_optimizers, debug_mode=int(os.environ.get("DEBUG_MODE", "0")) > 0, + logger=logger, ) if model_optimizer.usable: onnx_path = model_converter.convert( model, model_params, Path(tmp_dir) ) model_optimized = model_optimizer.optimize( - onnx_model=str(onnx_path), + model=str(onnx_path), output_library=dl_library, model_params=model_params, input_tfms=input_tfms, @@ -288,12 +304,12 @@ def optimize_torch_model( ) else: model_optimized = None - if use_torch_api: - model_optimized = _compare_optimized_models( - model_optimized, - torch_api_model, - torch_api_latency, - ) + logger.info("Running comparison between optimized models (3/3).") + model_optimized = _compare_optimized_models( + model_optimized, + torch_api_model, + torch_api_latency, + ) if model_optimized is None: raise RuntimeError( "No valid compiled model has been produced. 
" @@ -303,32 +319,68 @@ def optimize_torch_model( return model_optimized.load(save_dir) +def _get_optimizers_supporting_torch_api( + use_extra_compilers: bool, +) -> List[Tuple[ModelCompiler, BaseOptimizer]]: + optimizers = [ + (ModelCompiler.TORCHVISION, PytorchBackendOptimizer(logger=logger)), + ] + if use_extra_compilers: + optimizers.append( + (ModelCompiler.APACHE_TVM, ApacheTVMOptimizer(logger=logger)) + ) + return optimizers + + def _torch_api_optimization( model: torch.nn.Module, model_params: ModelParams, quantization_ths: float, quantization_type: QuantizationType, + use_extra_compilers: bool, + input_data: DataManager, ) -> Tuple[Optional[PytorchBaseInferenceLearner], float, List]: - try: - best_torch_opt_model = ApacheTVMOptimizer().optimize_from_torch( - torch_model=model, - model_params=model_params, - perf_loss_ths=quantization_ths - if quantization_type is not None - else None, - quantization_type=quantization_type, - ) - best_latency = compute_optimized_running_time(best_torch_opt_model) - used_compilers = [ModelCompiler.APACHE_TVM] - except Exception as ex: - warnings.warn( - f"Compilation failed with torch interface of TVM. " - f"Got error {ex}. The compilation will be re-scheduled " - f"with the ONNX interface." - ) - best_torch_opt_model = None - best_latency = np.inf - used_compilers = [] + used_compilers = [] + best_torch_opt_model = None + best_latency = np.inf + for compiler, optimizer in tqdm( + _get_optimizers_supporting_torch_api(use_extra_compilers) + ): + try: + if hasattr(optimizer, "optimize_from_torch"): + candidate_model = optimizer.optimize_from_torch( + torch_model=model, + model_params=model_params, + perf_loss_ths=quantization_ths + if quantization_type is not None + else None, + quantization_type=quantization_type, + input_data=input_data, + ) + else: + candidate_model = optimizer.optimize( + model=model, + output_library=DeepLearningFramework.PYTORCH, + model_params=model_params, + perf_loss_ths=quantization_ths + if quantization_type is not None + else None, + quantization_type=quantization_type, + input_data=input_data, + ) + candidate_latency = compute_optimized_running_time(candidate_model) + if candidate_latency < best_latency: + best_latency = candidate_latency + best_torch_opt_model = candidate_model + used_compilers.append(compiler) + except Exception as ex: + warnings.warn( + f"Compilation failed with torch interface of {compiler}. " + f"Got error {ex}. If possible the compilation will be " + f"re-scheduled with the ONNX interface. Please consult the " + f"documentation for further info or open an issue on GitHub " + f"for receiving assistance." 
+ ) return best_torch_opt_model, best_latency, used_compilers diff --git a/nebullvm/base.py b/nebullvm/base.py index 6d8746f0..3564a805 100644 --- a/nebullvm/base.py +++ b/nebullvm/base.py @@ -106,6 +106,8 @@ class ModelCompiler(Enum): OPENVINO = "openvino" APACHE_TVM = "tvm" ONNX_RUNTIME = "onnxruntime" + TORCHVISION = "torchvision" + TFLITE = "tflite" class QuantizationType(Enum): diff --git a/nebullvm/config.py b/nebullvm/config.py index 05135099..68c48ad2 100644 --- a/nebullvm/config.py +++ b/nebullvm/config.py @@ -37,3 +37,8 @@ "description_file": "description.xml", "weights": "weights.bin", } + +TENSORFLOW_BACKEND_FILENAMES = { + "tflite_model": "tf_model.tflite", + "tf_model": "tf_model.h5", +} diff --git a/nebullvm/inference_learners/onnx.py b/nebullvm/inference_learners/onnx.py index 9d254fea..2192c450 100644 --- a/nebullvm/inference_learners/onnx.py +++ b/nebullvm/inference_learners/onnx.py @@ -11,7 +11,11 @@ import torch from nebullvm.base import DeepLearningFramework, ModelParams -from nebullvm.config import ONNX_FILENAMES, CUDA_PROVIDERS +from nebullvm.config import ( + ONNX_FILENAMES, + CUDA_PROVIDERS, + NO_COMPILER_INSTALLATION, +) from nebullvm.inference_learners.base import ( BaseInferenceLearner, LearnerMetadata, @@ -24,13 +28,24 @@ try: import onnxruntime as ort except ImportError: - warnings.warn( - "No valid onnxruntime installation found. Trying to install it..." - ) - from nebullvm.installers.installers import install_onnxruntime + if NO_COMPILER_INSTALLATION: + warnings.warn( + "No valid onnxruntime installation found. The compiler will raise " + "an error if used." + ) - install_onnxruntime() - import onnxruntime as ort + class ort: + pass + + setattr(ort, "SessionOptions", None) + else: + warnings.warn( + "No valid onnxruntime installation found. Trying to install it..." + ) + from nebullvm.installers.installers import install_onnxruntime + + install_onnxruntime() + import onnxruntime as ort def _is_intel_cpu(): diff --git a/nebullvm/inference_learners/openvino.py b/nebullvm/inference_learners/openvino.py index 5c17d2f9..7d995ec8 100644 --- a/nebullvm/inference_learners/openvino.py +++ b/nebullvm/inference_learners/openvino.py @@ -10,7 +10,7 @@ import tensorflow as tf import torch -from nebullvm.config import OPENVINO_FILENAMES +from nebullvm.config import OPENVINO_FILENAMES, NO_COMPILER_INSTALLATION from nebullvm.inference_learners.base import ( BaseInferenceLearner, LearnerMetadata, @@ -24,7 +24,10 @@ try: from openvino.runtime import Core, Model, CompiledModel, InferRequest except ImportError: - if "intel" in cpuinfo.get_cpu_info()["brand_raw"].lower(): + if ( + "intel" in cpuinfo.get_cpu_info()["brand_raw"].lower() + and not NO_COMPILER_INSTALLATION + ): warnings.warn( "No valid OpenVino installation has been found. " "Trying to re-install it from source." 
diff --git a/nebullvm/inference_learners/pytorch.py b/nebullvm/inference_learners/pytorch.py new file mode 100644 index 00000000..02331e3e --- /dev/null +++ b/nebullvm/inference_learners/pytorch.py @@ -0,0 +1,63 @@ +from pathlib import Path +from typing import Tuple, Union, Optional + +import torch + +from nebullvm.base import ModelParams +from nebullvm.inference_learners import ( + PytorchBaseInferenceLearner, + LearnerMetadata, +) +from nebullvm.transformations.base import MultiStageTransformation + + +class PytorchBackendInferenceLearner(PytorchBaseInferenceLearner): + MODEL_NAME = "model_scripted.pt" + + def __init__(self, torch_model: torch.jit.ScriptModule, **kwargs): + super().__init__(**kwargs) + self.model = torch_model.eval() + if torch.cuda.is_available(): + self.model.cuda() + + def run(self, *input_tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]: + device = input_tensors[0].device + if torch.cuda.is_available(): + input_tensors = (t.cuda() for t in input_tensors) + with torch.no_grad(): + res = self.model(*input_tensors) + if not isinstance(res, tuple): + res = res.to(device) + return (res,) + return tuple(out.to(device) for out in res) + + def save(self, path: Union[str, Path], **kwargs): + path = Path(path) + metadata = LearnerMetadata.from_model(self, **kwargs) + metadata.save(path) + self.model.save(path / self.MODEL_NAME) + + @classmethod + def load(cls, path: Union[Path, str], **kwargs): + path = Path(path) + model = torch.jit.load(path / cls.MODEL_NAME) + metadata = LearnerMetadata.read(path) + return cls( + torch_model=model, + network_parameters=ModelParams(**metadata.network_parameters), + input_tfms=metadata.input_tfms, + ) + + @classmethod + def from_torch_model( + cls, + model: torch.nn.Module, + network_parameters: ModelParams, + input_tfms: Optional[MultiStageTransformation] = None, + ): + model_scripted = torch.jit.script(model) + return cls( + torch_model=model_scripted, + network_parameters=network_parameters, + input_tfms=input_tfms, + ) diff --git a/nebullvm/inference_learners/tensor_rt.py b/nebullvm/inference_learners/tensor_rt.py index ab4e7426..a55309ff 100644 --- a/nebullvm/inference_learners/tensor_rt.py +++ b/nebullvm/inference_learners/tensor_rt.py @@ -23,7 +23,7 @@ if torch.cuda.is_available(): try: import tensorrt as trt - import polygraphy + import polygraphy.cuda except ImportError: if not NO_COMPILER_INSTALLATION: from nebullvm.installers.installers import install_tensor_rt @@ -34,7 +34,7 @@ ) install_tensor_rt() import tensorrt as trt - import polygraphy + import polygraphy.cuda else: warnings.warn( "No TensorRT valid installation has been found. 
" @@ -343,7 +343,7 @@ def _synchronize_stream(self): @staticmethod def _get_default_cuda_stream() -> Any: - return polygraphy.Stream() + return polygraphy.cuda.Stream() @property def stream_ptr(self): @@ -362,7 +362,7 @@ def _predict_array( ) -> Generator[np.ndarray, None, None]: if self.network_parameters.dynamic_info is None: cuda_output_arrays = [ - polygraphy.DeviceArray( + polygraphy.cuda.DeviceArray( shape=(self.network_parameters.batch_size, *output_size) ) for output_size in self.network_parameters.output_sizes @@ -375,10 +375,10 @@ def _predict_array( ) cuda_output_arrays = [ - polygraphy.DeviceArray( + polygraphy.cuda.DeviceArray( shape=tuple( x - if i in dyn_out_axis.keys() + if i not in dyn_out_axis.keys() else dynamic_info.retrieve_output_dim( input_shapes, j, i, x ) @@ -440,9 +440,10 @@ def run(self, *input_tensors: tf.Tensor) -> Tuple[tf.Tensor, ...]: multiple-output of the model given a (multi-) tensor input. """ cuda_input_arrays = [ - polygraphy.DeviceArray.copy_from( - input_tensor.numpy(), stream=self.cuda_stream - ) + polygraphy.cuda.DeviceArray( + shape=tuple(input_tensor.shape), + dtype=input_tensor.numpy().dtype, + ).copy_from(input_tensor.numpy(), stream=self.cuda_stream) for input_tensor in input_tensors ] input_shapes = ( @@ -495,9 +496,9 @@ def run(self, *input_tensors: np.ndarray) -> Tuple[np.ndarray, ...]: input. """ cuda_input_arrays = [ - polygraphy.DeviceArray.copy_from( - input_tensor, stream=self.cuda_stream - ) + polygraphy.cuda.DeviceArray( + shape=tuple(input_tensor.shape), dtype=input_tensor.dtype + ).copy_from(input_tensor, stream=self.cuda_stream) for input_tensor in input_tensors ] input_shapes = ( diff --git a/nebullvm/inference_learners/tensorflow.py b/nebullvm/inference_learners/tensorflow.py new file mode 100644 index 00000000..1a9d5259 --- /dev/null +++ b/nebullvm/inference_learners/tensorflow.py @@ -0,0 +1,104 @@ +import shutil +from pathlib import Path +from typing import Tuple, Union, Dict, Type + +import tensorflow as tf + +from nebullvm.base import ModelParams +from nebullvm.config import TENSORFLOW_BACKEND_FILENAMES +from nebullvm.inference_learners import ( + TensorflowBaseInferenceLearner, + LearnerMetadata, +) + + +class TensorflowBackendInferenceLearner(TensorflowBaseInferenceLearner): + def __init__(self, tf_model: tf.Module, **kwargs): + super(TensorflowBackendInferenceLearner, self).__init__(**kwargs) + self.model = tf_model + + @tf.function(jit_compile=True) + def run(self, *input_tensors: tf.Tensor) -> Tuple[tf.Tensor, ...]: + res = self.model.predict(*input_tensors) + if not isinstance(res, tuple): + return (res,) + return res + + def save(self, path: Union[str, Path], **kwargs): + path = Path(path) + metadata = LearnerMetadata.from_model(self, **kwargs) + metadata.save(path) + self.model.save(path / TENSORFLOW_BACKEND_FILENAMES["tf_model"]) + + @classmethod + def load(cls, path: Union[Path, str], **kwargs): + path = Path(path) + metadata = LearnerMetadata.read(path) + network_parameters = ModelParams(**metadata.network_parameters) + input_tfms = metadata.input_tfms + model = tf.keras.models.load_model( + path / TENSORFLOW_BACKEND_FILENAMES["tf_model"] + ) + return cls( + tf_model=model, + network_parameters=network_parameters, + input_tfms=input_tfms, + ) + + +class TFLiteBackendInferenceLearner(TensorflowBaseInferenceLearner): + def __init__(self, tflite_file: str, **kwargs): + super(TFLiteBackendInferenceLearner, self).__init__(**kwargs) + self._tflite_file = self._store_file(tflite_file) + self.interpreter = 
tf.lite.Interpreter(tflite_file) + + def run(self, *input_tensors: tf.Tensor): + input_details = self.interpreter.get_input_details() + output_details = self.interpreter.get_output_details() + if self.network_parameters.dynamic_info: + for i, (input_tensor, detail) in enumerate( + zip(input_tensors, input_details) + ): + if input_tensor.shape != detail["shape"]: + self.interpreter.resize_tensor_input(i, input_tensor.shape) + self.interpreter.allocate_tensors() + for i, input_tensor in enumerate(input_tensors): + self.interpreter.set_tensor(i, input_tensor) + self.interpreter.invoke() + return tuple( + self.interpreter.get_tensor(output_detail["index"]) + for output_detail in output_details + ) + + def save(self, path: Union[str, Path], **kwargs): + path = Path(path) + metadata = LearnerMetadata.from_model(self, **kwargs) + metadata.save(path) + shutil.copy2( + self._tflite_file, + path / TENSORFLOW_BACKEND_FILENAMES["tflite_model"], + ) + + @classmethod + def load(cls, path: Union[Path, str], **kwargs): + path = Path(path) + tflite_file = str(path / TENSORFLOW_BACKEND_FILENAMES["tflite_model"]) + metadata = LearnerMetadata.read(path) + network_parameters = ModelParams(**metadata.network_parameters) + input_tfms = metadata.input_tfms + return cls( + tflite_file=tflite_file, + network_parameters=network_parameters, + input_tfms=input_tfms, + ) + + +TF_BACKEND_LEARNERS_DICT: Dict[ + str, + Type[ + Union[TensorflowBackendInferenceLearner, TFLiteBackendInferenceLearner] + ], +] = { + "tf": TensorflowBackendInferenceLearner, + "tflite": TFLiteBackendInferenceLearner, +} diff --git a/nebullvm/installers/install_tvm.sh b/nebullvm/installers/install_tvm.sh index 0d0e4a3f..2b685e88 100644 --- a/nebullvm/installers/install_tvm.sh +++ b/nebullvm/installers/install_tvm.sh @@ -8,6 +8,10 @@ then git clone --recursive https://github.com/apache/tvm tvm fi +# Fix for tvm +mv /root/tvm/configs /root/tvm/configs_orig +touch /root/tvm/configs + cd tvm mkdir -p build cp $CONFIG_PATH build/ diff --git a/nebullvm/optimizers/base.py b/nebullvm/optimizers/base.py index 8772e009..ec69556d 100644 --- a/nebullvm/optimizers/base.py +++ b/nebullvm/optimizers/base.py @@ -1,6 +1,7 @@ +import logging from abc import abstractmethod, ABC from logging import Logger -from typing import Optional, Callable +from typing import Optional, Callable, Any from nebullvm.base import DeepLearningFramework, ModelParams, QuantizationType from nebullvm.inference_learners.base import BaseInferenceLearner @@ -17,7 +18,7 @@ def __init__(self, logger: Logger = None): @abstractmethod def optimize( self, - onnx_model: str, + model: Any, output_library: DeepLearningFramework, model_params: ModelParams, input_tfms: MultiStageTransformation = None, @@ -27,3 +28,9 @@ def optimize( input_data: DataManager = None, ) -> Optional[BaseInferenceLearner]: raise NotImplementedError + + def _log(self, message: str, level: int = logging.INFO): + if self.logger is None: + logging.log(level, message) + else: + self.logger.log(level, message) diff --git a/nebullvm/optimizers/extra.py b/nebullvm/optimizers/extra.py index 164ad512..e80ee5b5 100644 --- a/nebullvm/optimizers/extra.py +++ b/nebullvm/optimizers/extra.py @@ -49,7 +49,7 @@ def __init__( def optimize( self, - onnx_model: str, + model: str, output_library: DeepLearningFramework, model_params: ModelParams, input_tfms: MultiStageTransformation = None, @@ -58,16 +58,18 @@ def optimize( perf_metric: Callable = None, input_data: DataManager = None, ) -> Optional[ONNXInferenceLearner]: - optimized_model = 
optimizer.optimize_model( - onnx_model, **self.hf_params + self._log( + f"Optimizing with {self.__class__.__name__} and " + f"q_type: {quantization_type}." ) + optimized_model = optimizer.optimize_model(model, **self.hf_params) if perf_loss_ths is not None: if quantization_type is not QuantizationType.HALF: return None optimized_model.convert_float_to_float16() - new_onnx_model = onnx_model.replace(".onnx", "_fp16.onnx") + new_onnx_model = model.replace(".onnx", "_fp16.onnx") else: - new_onnx_model = onnx_model.replace(".onnx", "_opt.onnx") + new_onnx_model = model.replace(".onnx", "_opt.onnx") optimized_model.save_model_to_file(new_onnx_model) learner = ONNX_INFERENCE_LEARNERS[output_library]( input_tfms=input_tfms, @@ -87,7 +89,7 @@ def optimize( tuple(convert_to_numpy(x) for x in input_) for input_ in inputs ] base_outputs = [ - tuple(run_onnx_model(onnx_model, list(input_onnx))) + tuple(run_onnx_model(model, list(input_onnx))) for input_onnx in inputs_onnx ] is_valid = check_precision( diff --git a/nebullvm/optimizers/multi_compiler.py b/nebullvm/optimizers/multi_compiler.py index 13d2abc6..a3a540a0 100644 --- a/nebullvm/optimizers/multi_compiler.py +++ b/nebullvm/optimizers/multi_compiler.py @@ -8,7 +8,7 @@ import cpuinfo import numpy as np import torch - +from tqdm import tqdm from nebullvm.base import ( ModelCompiler, @@ -162,7 +162,7 @@ def __init__( def optimize( self, - onnx_model: str, + model: str, output_library: DeepLearningFramework, model_params: ModelParams, input_tfms: MultiStageTransformation = None, @@ -174,7 +174,7 @@ def optimize( """Optimize the ONNX model using the available compilers. Args: - onnx_model (str): Path to the ONNX model. + model (str): Path to the ONNX model. output_library (DeepLearningFramework): Framework of the optimized model (either torch on tensorflow). model_params (ModelParams): Model parameters. @@ -208,7 +208,7 @@ def optimize( _optimize_with_compiler( compiler, logger=self.logger, - onnx_model=onnx_model, + model=model, output_library=output_library, model_params=model_params, input_tfms=input_tfms.copy() @@ -221,14 +221,15 @@ def optimize( input_data=input_data, ) for compiler in self.compilers - for q_type in quantization_types + for q_type in tqdm(quantization_types) ] if self.extra_optimizers is not None: + self._log("Running extra-optimizers...") optimized_models += [ _optimize_with_optimizer( op, logger=self.logger, - onnx_model=onnx_model, + model=model, output_library=output_library, model_params=model_params, input_tfms=input_tfms.copy() @@ -243,7 +244,7 @@ def optimize( input_data=input_data, ) for op in self.extra_optimizers - for q_type in quantization_types + for q_type in tqdm(quantization_types) ] optimized_models.sort(key=lambda x: x[1], reverse=False) return optimized_models[0][0] @@ -251,7 +252,7 @@ def optimize( def optimize_on_custom_metric( self, metric_func: Callable, - onnx_model: str, + model: str, output_library: DeepLearningFramework, model_params: ModelParams, input_tfms: MultiStageTransformation = None, @@ -270,7 +271,7 @@ def optimize_on_custom_metric( InferenceLearner and return a numerical value. Note that the outputs will be sorted in an ascendant order, i.e. the compiled model with the smallest value will be selected. - onnx_model (str): Path to the ONNX model. + model (str): Path to the ONNX model. output_library (DeepLearningFramework): Framework of the optimized model (either torch on tensorflow). model_params (ModelParams): Model parameters. 
@@ -311,7 +312,7 @@ def optimize_on_custom_metric( compiler, metric_func=metric_func, logger=self.logger, - onnx_model=onnx_model, + model=model, output_library=output_library, model_params=model_params, input_tfms=input_tfms.copy() @@ -324,14 +325,14 @@ def optimize_on_custom_metric( input_data=input_data, ) for compiler in self.compilers - for q_type in quantization_types + for q_type in tqdm(quantization_types) ] if self.extra_optimizers is not None: optimized_models += [ _optimize_with_optimizer( op, logger=self.logger, - onnx_model=onnx_model, + model=model, output_library=output_library, model_params=model_params, input_tfms=input_tfms.copy() @@ -346,7 +347,7 @@ def optimize_on_custom_metric( input_data=input_data, ) for op in self.extra_optimizers - for q_type in quantization_types + for q_type in tqdm(quantization_types) ] if return_all: return optimized_models diff --git a/nebullvm/optimizers/onnx.py b/nebullvm/optimizers/onnx.py index 752be94c..4772cc1e 100644 --- a/nebullvm/optimizers/onnx.py +++ b/nebullvm/optimizers/onnx.py @@ -27,7 +27,7 @@ class ONNXOptimizer(BaseOptimizer): def optimize( self, - onnx_model: str, + model: str, output_library: DeepLearningFramework, model_params: ModelParams, input_tfms: MultiStageTransformation = None, @@ -39,7 +39,7 @@ def optimize( """Build the ONNX runtime learner from the onnx model. Args: - onnx_model (str): Path to the saved onnx model. + model (str): Path to the saved onnx model. output_library (str): DL Framework the optimized model will be compatible with. model_params (ModelParams): Model parameters. @@ -61,6 +61,10 @@ def optimize( will have an interface in the DL library specified in `output_library`. """ + self._log( + f"Optimizing with {self.__class__.__name__} and " + f"q_type: {quantization_type}." + ) input_data_onnx, output_data_onnx, ys = [], [], None check_quantization(quantization_type, perf_loss_ths) if perf_loss_ths is not None: @@ -77,18 +81,18 @@ def optimize( 300, with_ys=True ) output_data_onnx = [ - tuple(run_onnx_model(onnx_model, list(input_tensors))) + tuple(run_onnx_model(model, list(input_tensors))) for input_tensors in input_data_onnx ] - onnx_model, input_tfms = quantize_onnx( - onnx_model, quantization_type, input_tfms, input_data_onnx + model, input_tfms = quantize_onnx( + model, quantization_type, input_tfms, input_data_onnx ) learner = ONNX_INFERENCE_LEARNERS[output_library]( input_tfms=input_tfms, network_parameters=model_params, - onnx_path=onnx_model, - input_names=get_input_names(onnx_model), - output_names=get_output_names(onnx_model), + onnx_path=model, + input_names=get_input_names(model), + output_names=get_output_names(model), ) if perf_loss_ths is not None: inputs = [ diff --git a/nebullvm/optimizers/openvino.py b/nebullvm/optimizers/openvino.py index fe30a50c..08caae7f 100644 --- a/nebullvm/optimizers/openvino.py +++ b/nebullvm/optimizers/openvino.py @@ -25,7 +25,7 @@ class OpenVinoOptimizer(BaseOptimizer): def optimize( self, - onnx_model: str, + model: str, output_library: DeepLearningFramework, model_params: ModelParams, input_tfms: MultiStageTransformation = None, @@ -37,7 +37,7 @@ def optimize( """Optimize the onnx model with OpenVino. Args: - onnx_model (str): Path to the saved onnx model. + model (str): Path to the saved onnx model. output_library (str): DL Framework the optimized model will be compatible with. model_params (ModelParams): Model parameters. @@ -59,14 +59,18 @@ def optimize( will have an interface in the DL library specified in `output_library`. 
""" + self._log( + f"Optimizing with {self.__class__.__name__} and " + f"q_type: {quantization_type}." + ) cmd = [ "mo", "--input_model", - onnx_model, + model, "--output_dir", - str(Path(onnx_model).parent), + str(Path(model).parent), "--input", - ",".join(get_input_names(onnx_model)), + ",".join(get_input_names(model)), "--input_shape", ",".join( [ @@ -87,31 +91,32 @@ def optimize( return None process = subprocess.Popen(cmd) process.wait() - base_path = Path(onnx_model).parent - openvino_model_path = base_path / f"{Path(onnx_model).stem}.xml" - openvino_model_weights = base_path / f"{Path(onnx_model).stem}.bin" + base_path = Path(model).parent + openvino_model_path = base_path / f"{Path(model).stem}.xml" + openvino_model_weights = base_path / f"{Path(model).stem}.bin" if ( perf_loss_ths is not None and quantization_type is not QuantizationType.HALF ): - if input_data is not None: - input_data_onnx = input_data.get_numpy_list(300, with_ys=False) + if input_data is not None and quantization_type: + input_data_onnx = input_data.get_numpy_list(300, with_ys=True) else: input_data_onnx = [ - tuple( + ( create_model_inputs_onnx( model_params.batch_size, model_params.input_infos - ) + ), + 0, ) ] # Add post training optimization openvino_model_path, openvino_model_weights = quantize_openvino( model_topology=str(openvino_model_path), model_weights=str(openvino_model_weights), - input_names=get_input_names(onnx_model), + input_names=get_input_names(model), input_data=input_data_onnx, ) - model = OPENVINO_INFERENCE_LEARNERS[output_library].from_model_name( + learner = OPENVINO_INFERENCE_LEARNERS[output_library].from_model_name( model_name=str(openvino_model_path), model_weights=str(openvino_model_weights), network_parameters=model_params, @@ -119,7 +124,7 @@ def optimize( ) if perf_loss_ths is not None: if input_data is None: - inputs = [model.get_inputs_example()] + inputs = [learner.get_inputs_example()] ys = None else: inputs, ys = input_data.get_list( @@ -128,14 +133,14 @@ def optimize( output_data_onnx = [ tuple( run_onnx_model( - onnx_model, + model, [convert_to_numpy(x) for x in tuple_], ) ) for tuple_ in inputs ] is_valid = check_precision( - model, + learner, inputs, output_data_onnx, perf_loss_ths, @@ -144,4 +149,4 @@ def optimize( ) if not is_valid: return None - return model + return learner diff --git a/nebullvm/optimizers/pytorch.py b/nebullvm/optimizers/pytorch.py new file mode 100644 index 00000000..254719fd --- /dev/null +++ b/nebullvm/optimizers/pytorch.py @@ -0,0 +1,118 @@ +from collections import Callable +from typing import Optional + +import torch.nn + +from nebullvm.base import DeepLearningFramework, ModelParams, QuantizationType +from nebullvm.inference_learners.pytorch import PytorchBackendInferenceLearner +from nebullvm.optimizers import BaseOptimizer +from nebullvm.optimizers.quantization.pytorch import quantize_torch +from nebullvm.optimizers.quantization.utils import ( + check_quantization, + check_precision, +) +from nebullvm.transformations.base import MultiStageTransformation +from nebullvm.utils.data import DataManager +from nebullvm.utils.onnx import convert_to_target_framework +from nebullvm.utils.torch import create_model_inputs_torch, run_torch_model + + +class PytorchBackendOptimizer(BaseOptimizer): + """Optimizer working directly on the pytorch backend, with no need of a + conversion to ONNX. The model will be finally compiled using torchscript. + For avoiding un-wanted modification to the input model models are copied + before being optimized. 
+
+    Attributes:
+        logger (Logger, optional): Optional logger for logging optimization
+            information.
+    """
+
+    def optimize(
+        self,
+        model: torch.nn.Module,
+        output_library: DeepLearningFramework,
+        model_params: ModelParams,
+        input_tfms: MultiStageTransformation = None,
+        perf_loss_ths: float = None,
+        quantization_type: QuantizationType = None,
+        perf_metric: Callable = None,
+        input_data: DataManager = None,
+    ) -> Optional[PytorchBackendInferenceLearner]:
+        """Optimize the input model using pytorch built-in techniques.
+
+        Args:
+            model (torch.nn.Module): The pytorch model. For avoiding un-wanted
+                modifications to the original model, it will be copied in the
+                method.
+            output_library (DeepLearningFramework): Output framework. At the
+                current stage just PYTORCH is supported.
+            model_params (ModelParams): Model parameters.
+            input_tfms (MultiStageTransformation, optional): Transformations
+                to be performed to the model's input tensors in order to
+                get the prediction.
+            perf_loss_ths (float, optional): Threshold for the accepted drop
+                in terms of precision. Any optimized model with a higher drop
+                will be ignored.
+            quantization_type (QuantizationType, optional): The desired
+                quantization algorithm to be used.
+            perf_metric (Callable, optional): If given it should
+                compute the difference between the quantized and the normal
+                prediction.
+            input_data (DataManager, optional): User defined data.
+
+        Returns:
+            PytorchBackendInferenceLearner: Model optimized for inference.
+        """
+        self._log(
+            f"Optimizing with {self.__class__.__name__} and "
+            f"q_type: {quantization_type}."
+        )
+        assert output_library is DeepLearningFramework.PYTORCH, (
+            "Other APIs than the Pytorch one are not supported "
+            "for the Pytorch Backend yet."
+        )
+        check_quantization(quantization_type, perf_loss_ths)
+        if perf_loss_ths is not None:
+            if input_data is None:
+                input_data_torch = [
+                    tuple(
+                        create_model_inputs_torch(
+                            model_params.batch_size, model_params.input_infos
+                        )
+                    )
+                ]
+                ys = None
+            else:
+                input_data_torch, ys = input_data.get_numpy_list(
+                    300, with_ys=True
+                )
+                input_data_torch = [
+                    tuple(
+                        convert_to_target_framework(t, output_library)
+                        for t in data_tuple
+                    )
+                    for data_tuple in input_data_torch
+                ]
+            output_data_torch = [
+                tuple(run_torch_model(model, list(input_tensors)))
+                for input_tensors in input_data_torch
+            ]
+            model, input_tfms = quantize_torch(
+                model, quantization_type, input_tfms, input_data_torch
+            )
+
+        learner = PytorchBackendInferenceLearner.from_torch_model(
+            model, network_parameters=model_params, input_tfms=input_tfms
+        )
+        if perf_loss_ths is not None:
+            is_valid = check_precision(
+                learner,
+                input_data_torch,
+                output_data_torch,
+                perf_loss_ths,
+                metric_func=perf_metric,
+                ys=ys,
+            )
+            if not is_valid:
+                return None
+        return learner
diff --git a/nebullvm/optimizers/quantization/onnx.py b/nebullvm/optimizers/quantization/onnx.py index 4377c008..5770184e 100644 --- a/nebullvm/optimizers/quantization/onnx.py +++ b/nebullvm/optimizers/quantization/onnx.py @@ -5,33 +5,38 @@ import numpy as np import onnx import torch -from onnxmltools.utils.float16_converter import ( - convert_float_to_float16_model_path, -) from torch.utils.data import DataLoader from nebullvm.base import QuantizationType +from nebullvm.config import NO_COMPILER_INSTALLATION from nebullvm.transformations.base import MultiStageTransformation from nebullvm.transformations.precision_tfms import HalfPrecisionTransformation from nebullvm.utils.onnx import get_input_names try: from
onnxruntime.quantization import ( - QuantType, - quantize_static, - quantize_dynamic, - CalibrationDataReader, + from onnxmltools.utils.float16_converter import ( + convert_float_to_float16_model_path, ) -except ImportError: - from nebullvm.installers.installers import install_onnxruntime - - install_onnxruntime() from onnxruntime.quantization import ( QuantType, quantize_static, quantize_dynamic, CalibrationDataReader, ) +except ImportError: + if NO_COMPILER_INSTALLATION: + QuantType = quantize_static = quantize_dynamic = None + CalibrationDataReader = object + else: + from nebullvm.installers.installers import install_onnxruntime + + install_onnxruntime() + from onnxruntime.quantization import ( + QuantType, + quantize_static, + quantize_dynamic, + CalibrationDataReader, + ) class _IterableCalibrationDataReader(CalibrationDataReader): diff --git a/nebullvm/optimizers/quantization/openvino.py b/nebullvm/optimizers/quantization/openvino.py index e85127d9..8ed5f3f3 100644 --- a/nebullvm/optimizers/quantization/openvino.py +++ b/nebullvm/optimizers/quantization/openvino.py @@ -2,6 +2,8 @@ import numpy as np +from nebullvm.config import NO_COMPILER_INSTALLATION + try: from openvino.tools.pot import DataLoader from openvino.tools.pot import IEEngine @@ -11,7 +13,10 @@ except ImportError: import cpuinfo - if "intel" in cpuinfo.get_cpu_info()["brand_raw"].lower(): + if ( + "intel" in cpuinfo.get_cpu_info()["brand_raw"].lower() + and not NO_COMPILER_INSTALLATION + ): from nebullvm.installers.installers import install_openvino install_openvino() @@ -32,10 +37,13 @@ def __init__( self._input_names = input_names def __len__(self): - return self._input_data + return len(self._input_data[0]) def __getitem__(self, item): - return dict(zip(self._input_names, self._input_data[item])) + return ( + dict(zip(self._input_names, self._input_data[0][item])), + self._input_data[1][item], + ) def quantize_openvino( @@ -56,7 +64,11 @@ def quantize_openvino( algorithms = [ { "name": "DefaultQuantization", - "params": {"target_device": "ANY", "stat_subset_size": 300}, + "params": { + "target_device": "ANY", + "preset": "performance", + "stat_subset_size": len(input_data), + }, } ] data_loader = _CalibrationDataLoader( diff --git a/nebullvm/optimizers/quantization/pytorch.py b/nebullvm/optimizers/quantization/pytorch.py new file mode 100644 index 00000000..dc1f093e --- /dev/null +++ b/nebullvm/optimizers/quantization/pytorch.py @@ -0,0 +1,70 @@ +from typing import List, Tuple + +import torch +from torch.ao.quantization.stubs import QuantStub, DeQuantStub + +from nebullvm.base import QuantizationType +from nebullvm.transformations.base import MultiStageTransformation +from nebullvm.transformations.precision_tfms import HalfPrecisionTransformation + + +class _QuantWrapper(torch.nn.Module): + def __init__(self, model: torch.nn.Module): + super(_QuantWrapper, self).__init__() + qconfig = model.qconfig if hasattr(model, "qconfig") else None + self.quant = QuantStub(qconfig) + self.model = model + self.dequant = DeQuantStub() + + def forward(self, *inputs: torch.Tensor): + inputs = (self.quant(x) for x in inputs) + outputs = self.model(*inputs) + return tuple(self.dequant(x) for x in outputs) + + +def _quantize_dynamic(model: torch.nn.Module): + layer_types = { + type(layer) + for layer in model.children() + if len(list(layer.parameters())) > 0 + } + quantized_model = torch.quantization.quantize_dynamic( + model=model, qconfig_spec=layer_types, dtype=torch.qint8 + ) + return quantized_model + + +def _quantize_static( + 
model: torch.nn.Module, input_data: List[Tuple[torch.Tensor, ...]] +): + model = _QuantWrapper(model) + model.qconfig = torch.quantization.get_default_qconfig("fbgemm") + model = torch.quantization.fuse_modules(model, [["conv", "relu"]]) + model = torch.quantization.prepare(model) + for tensors in input_data: + _ = model(*tensors) + return torch.quantization.convert(model) + + +def _half_precision(model: torch.nn.Module): + return model.half() + + +def quantize_torch( + model: torch.nn.Module, + quantization_type: QuantizationType, + input_tfms: MultiStageTransformation, + input_data_torch: List[Tuple[torch.Tensor, ...]], +): + if quantization_type is QuantizationType.HALF: + input_tfms.append(HalfPrecisionTransformation()) + return _half_precision(model), input_tfms + elif quantization_type is QuantizationType.STATIC: + return _quantize_static(model, input_data_torch), input_tfms + elif quantization_type is QuantizationType.DYNAMIC: + return _quantize_dynamic(model), input_tfms + else: + raise NotImplementedError( + f"No quantization implemented for quantization " + f"type {quantization_type}" + ) diff --git a/nebullvm/optimizers/quantization/tensorflow.py b/nebullvm/optimizers/quantization/tensorflow.py new file mode 100644 index 00000000..4823bf44 --- /dev/null +++ b/nebullvm/optimizers/quantization/tensorflow.py @@ -0,0 +1,61 @@ +import os.path +from typing import List, Tuple + +import tensorflow as tf + +from nebullvm.base import QuantizationType +from nebullvm.config import TENSORFLOW_BACKEND_FILENAMES +from nebullvm.transformations.base import MultiStageTransformation + + +def _quantize_dynamic(model: tf.Module): + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + tflite_quant_model = converter.convert() + return tflite_quant_model + + +def _quantize_static(model: tf.Module, dataset: List[Tuple[tf.Tensor, ...]]): + def representative_dataset(): + for data_tuple in dataset: + yield list(data_tuple) + + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + tflite_quant_model = converter.convert() + return tflite_quant_model + + +def _half_precision(model: tf.Module): + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.target_spec.supported_types = [tf.float16] + tflite_quant_model = converter.convert() + return tflite_quant_model + + +def quantize_tf( + model: tf.Module, + quantization_type: QuantizationType, + input_tfms: MultiStageTransformation, + input_data: List[Tuple[tf.Tensor, ...]], + tmp_dir: str, +): + if quantization_type is QuantizationType.DYNAMIC: + quantized_model = _quantize_dynamic(model) + elif quantization_type is QuantizationType.STATIC: + quantized_model = _quantize_static(model, input_data) + elif quantization_type is QuantizationType.HALF: + quantized_model = _half_precision(model) + else: + raise NotImplementedError( + f"Quantization not supported for type {quantization_type}" + ) + + filepath = os.path.join( + tmp_dir, TENSORFLOW_BACKEND_FILENAMES["tflite_model"] + ) + with open(filepath, "wb") as f: + f.write(quantized_model) + return filepath, input_tfms diff --git a/nebullvm/optimizers/tensor_rt.py b/nebullvm/optimizers/tensor_rt.py index da2ea86a..fbf164ab 100644 --- a/nebullvm/optimizers/tensor_rt.py +++ b/nebullvm/optimizers/tensor_rt.py @@ -13,7 +13,6 @@ from nebullvm.optimizers.base import 
( BaseOptimizer, ) -from nebullvm.optimizers.quantization.onnx import quantize_onnx from nebullvm.optimizers.quantization.tensor_rt import TensorRTCalibrator from nebullvm.optimizers.quantization.utils import ( check_precision, @@ -87,10 +86,10 @@ def _build_and_save_the_engine( config.set_flag(trt.BuilderFlag.INT8) config.int8_calibrator = calibrator elif quantization_type is QuantizationType.DYNAMIC: - onnx_model_path, _ = quantize_onnx( - onnx_model_path, quantization_type, input_tfms, input_data - ) - config.set_flag(trt.BuilderFlag.kINT8) + # onnx_model_path, _ = quantize_onnx( + # onnx_model_path, quantization_type, input_tfms, input_data + # ) + config.set_flag(trt.BuilderFlag.INT8) # import the model parser = trt.OnnxParser(network, nvidia_logger) success = parser.parse_from_file(onnx_model_path) @@ -134,7 +133,7 @@ def _build_and_save_the_engine( def optimize( self, - onnx_model: str, + model: str, output_library: DeepLearningFramework, model_params: ModelParams, input_tfms: MultiStageTransformation = None, @@ -146,7 +145,7 @@ def optimize( """Optimize the input model with TensorRT. Args: - onnx_model (str): Path to the saved onnx model. + model (str): Path to the saved onnx model. output_library (str): DL Framework the optimized model will be compatible with. model_params (ModelParams): Model parameters. @@ -168,13 +167,17 @@ def optimize( will have an interface in the DL library specified in `output_library`. """ + self._log( + f"Optimizing with {self.__class__.__name__} and " + f"q_type: {quantization_type}." + ) if not torch.cuda.is_available(): raise SystemError( "You are trying to run an optimizer developed for NVidia gpus " "on a machine not connected to any GPU supporting CUDA." ) check_quantization(quantization_type, perf_loss_ths) - engine_path = Path(onnx_model).parent / NVIDIA_FILENAMES["engine"] + engine_path = Path(model).parent / NVIDIA_FILENAMES["engine"] if ( perf_loss_ths is not None and quantization_type is QuantizationType.STATIC @@ -189,26 +192,31 @@ def optimize( ] else: input_data_onnx = input_data.get_numpy_list(300, with_ys=False) + elif ( + perf_loss_ths is not None + and quantization_type is QuantizationType.DYNAMIC + ): + return None # Dynamic quantization is not supported on tensorRT else: input_data_onnx = None self._build_and_save_the_engine( engine_path=engine_path, - onnx_model_path=onnx_model, + onnx_model_path=model, model_params=model_params, input_tfms=input_tfms, quantization_type=quantization_type, input_data=input_data_onnx, ) - model = NVIDIA_INFERENCE_LEARNERS[output_library].from_engine_path( + learner = NVIDIA_INFERENCE_LEARNERS[output_library].from_engine_path( input_tfms=input_tfms, network_parameters=model_params, engine_path=engine_path, - input_names=get_input_names(onnx_model), - output_names=get_output_names(onnx_model), + input_names=get_input_names(model), + output_names=get_output_names(model), ) if quantization_type is not None: if input_data is None: - inputs = [model.get_inputs_example()] + inputs = [learner.get_inputs_example()] ys = None else: inputs, ys = input_data.get_numpy_list( @@ -217,14 +225,14 @@ def optimize( output_data = [ tuple( run_onnx_model( - onnx_model, + model, [convert_to_numpy(x) for x in tuple_], ) ) for tuple_ in inputs ] is_valid = check_precision( - model, + learner, inputs, output_data, perf_loss_ths, @@ -233,4 +241,4 @@ def optimize( ) if not is_valid: return None - return model + return learner diff --git a/nebullvm/optimizers/tensorflow.py b/nebullvm/optimizers/tensorflow.py new file mode 100644 
index 00000000..f7a99cb4 --- /dev/null +++ b/nebullvm/optimizers/tensorflow.py @@ -0,0 +1,130 @@ +from tempfile import TemporaryDirectory +from typing import Callable, Optional + +import tensorflow as tf + +from nebullvm.base import DeepLearningFramework, ModelParams, QuantizationType +from nebullvm.inference_learners.tensorflow import ( + TensorflowBackendInferenceLearner, + TF_BACKEND_LEARNERS_DICT, +) +from nebullvm.optimizers import BaseOptimizer +from nebullvm.optimizers.quantization.tensorflow import quantize_tf +from nebullvm.optimizers.quantization.utils import ( + check_quantization, + check_precision, +) +from nebullvm.transformations.base import MultiStageTransformation +from nebullvm.utils.data import DataManager +from nebullvm.utils.onnx import convert_to_target_framework +from nebullvm.utils.tf import create_model_inputs_tf, run_tf_model + + +class TensorflowBackendOptimizer(BaseOptimizer): + """Optimizer working directly on the tensorflow backend, with no need of a + conversion to ONNX. The model will be finally compiled using tflite. + For avoiding un-wanted modification to the input model models are copied + before being optimized. + + Attributes: + logger (Logger, optional): Optional logger for logging optimization + information. + """ + + def optimize( + self, + model: tf.Module, + output_library: DeepLearningFramework, + model_params: ModelParams, + input_tfms: MultiStageTransformation = None, + perf_loss_ths: float = None, + quantization_type: QuantizationType = None, + perf_metric: Callable = None, + input_data: DataManager = None, + ) -> Optional[TensorflowBackendInferenceLearner]: + """Optimize the input model using pytorch built-in techniques. + + Args: + model (tf.Module): The tensorflow model. For avoiding un-wanted + modifications to the original model, it will be copied in the + method. + output_library (DeepLearningFramework): Output framework. At the + current stage just TENSORFLOW is supported. + model_params (ModelParams): Model parameters. + input_tfms (MultiStageTransformation, optional): Transformations + to be performed to the model's input tensors in order to + get the prediction. + perf_loss_ths (float, optional): Threshold for the accepted drop + in terms of precision. Any optimized model with an higher drop + will be ignored. + quantization_type (QuantizationType, optional): The desired + quantization algorithm to be used. + perf_metric (Callable, optional): If given it should + compute the difference between the quantized and the normal + prediction. + input_data (DataManager, optional): User defined data. + + Returns: + TensorflowBackendInferenceLearner or TFLiteBackendInferenceLearner: + Model optimized for inference. + """ + self._log( + f"Optimizing with {self.__class__.__name__} and " + f"q_type: {quantization_type}." + ) + assert output_library is DeepLearningFramework.TENSORFLOW, ( + "Other APIs than the Tensorflow one are not supported " + "for the Tensorflow Backend yet." 
+ ) + + check_quantization(quantization_type, perf_loss_ths) + with TemporaryDirectory() as tmp_dir: + if perf_loss_ths is not None: + if input_data is None: + input_data_tf = [ + tuple( + create_model_inputs_tf( + model_params.batch_size, + model_params.input_infos, + ) + ) + ] + ys = None + else: + input_data_tf, ys = input_data.get_numpy_list( + 300, with_ys=True + ) + input_data_tf = [ + tuple( + convert_to_target_framework(t, output_library) + for t in data_tuple + ) + for data_tuple in input_data_tf + ] + output_data_tf = [ + tuple(run_tf_model(model, input_tensors)) + for input_tensors in input_data_tf + ] + model, input_tfms = quantize_tf( + model=model, + quantization_type=quantization_type, + input_tfms=input_tfms, + input_data=input_data_tf, + tmp_dir=tmp_dir, + ) + + learner = TF_BACKEND_LEARNERS_DICT[ + "tflite" if perf_loss_ths is not None else "tf" + ](model, network_parameters=model_params, input_tfms=input_tfms) + if perf_loss_ths is not None: + is_valid = check_precision( + learner, + input_data_tf, + output_data_tf, + perf_loss_ths, + metric_func=perf_metric, + ys=ys, + ) + if not is_valid: + return None + return learner diff --git a/nebullvm/optimizers/tvm.py b/nebullvm/optimizers/tvm.py index 2a8824db..a0981629 100644 --- a/nebullvm/optimizers/tvm.py +++ b/nebullvm/optimizers/tvm.py @@ -57,6 +57,10 @@ def optimize_from_torch( perf_metric: Callable = None, input_data: DataManager = None, ) -> Optional[ApacheTVMInferenceLearner]: + self._log( + f"Optimizing with {self.__class__.__name__} and " + f"q_type: {quantization_type}." + ) target = self._get_target() mod, params = self._build_tvm_model_from_torch( torch_model, model_params @@ -125,7 +129,7 @@ def optimize_from_torch( def optimize( self, - onnx_model: str, + model: str, output_library: DeepLearningFramework, model_params: ModelParams, input_tfms: MultiStageTransformation = None, @@ -137,7 +141,7 @@ def optimize( """Optimize the input model with Apache TVM. Args: - onnx_model (str): Path to the saved onnx model. + model (str): Path to the saved onnx model. output_library (str): DL Framework the optimized model will be compatible with. model_params (ModelParams): Model parameters. @@ -159,9 +163,13 @@ def optimize( will have an interface in the DL library specified in `output_library`. """ + self._log( + f"Optimizing with {self.__class__.__name__} and " + f"q_type: {quantization_type}." 
+ ) check_quantization(quantization_type, perf_loss_ths) target = self._get_target() - mod, params = self._build_tvm_model_from_onnx(onnx_model, model_params) + mod, params = self._build_tvm_model_from_onnx(model, model_params) if perf_loss_ths is not None: if quantization_type is QuantizationType.HALF: mod = tvm.relay.transform.ToMixedPrecision( @@ -182,7 +190,7 @@ def optimize( ] else: inputs = input_data.get_numpy_list(300, with_ys=False) - inputs = TVMCalibrator(inputs, get_input_names(onnx_model)) + inputs = TVMCalibrator(inputs, get_input_names(model)) else: return mod = self._quantize(mod, params, input_data=inputs) @@ -195,7 +203,7 @@ def optimize( network_parameters=model_params, lib=lib, target_device=target, - input_names=get_input_names(onnx_model), + input_names=get_input_names(model), ) if quantization_type is not None: if input_data is None: @@ -208,7 +216,7 @@ def optimize( output_data = [ tuple( run_onnx_model( - onnx_model, + model, [convert_to_numpy(x) for x in tuple_], ) ) diff --git a/nebullvm/transformations/tensor_tfms.py b/nebullvm/transformations/tensor_tfms.py index 9d606820..8e137811 100644 --- a/nebullvm/transformations/tensor_tfms.py +++ b/nebullvm/transformations/tensor_tfms.py @@ -6,7 +6,9 @@ class VerifyContiguity(BaseTransformation): - def _transform(self, _input: torch.Tensor, **kwargs) -> Any: + def _transform(self, _input: Any, **kwargs) -> Any: + if not isinstance(_input, torch.Tensor): + return _input if not _input.is_contiguous(): _input = _input.contiguous() return _input diff --git a/nebullvm/utils/data.py b/nebullvm/utils/data.py index ec3de1dc..9339af97 100644 --- a/nebullvm/utils/data.py +++ b/nebullvm/utils/data.py @@ -58,7 +58,7 @@ def get_list( if shuffle: idx = np.random.choice(len(self), n, replace=n > len(self)) else: - idx = np.arange(0, max(n, len(self))) + idx = np.arange(0, min(n, len(self))) if n > len(self): idx = np.concatenate( [ @@ -72,7 +72,8 @@ def get_list( return [self[i][0] for i in idx] ys, xs = [], [] - for x, y in self: + for i in idx: + x, y = self[i] xs.append(x) ys.append(y) return xs, ys diff --git a/requirements.txt b/requirements.txt index c584086f..0b7df235 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,6 @@ py-cpuinfo==8.0.0 tensorflow>=2.7.0, <2.8.0 tf2onnx>=1.8.4 torch>=1.10.0 +tqdm>=4.63.0 transformers pytest \ No newline at end of file diff --git a/setup.py b/setup.py index d874fee4..167a12d9 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ "tensorflow>=2.7.0, <2.8.0", "tf2onnx>=1.8.4", "torch>=1.10.0", + "tqdm>=4.63.0", ] this_directory = Path(__file__).parent @@ -16,7 +17,7 @@ setup( name="nebullvm", - version="0.3.0", + version="0.3.1", packages=find_packages(), install_requires=REQUIREMENTS, package_data={