diff --git a/nebullvm/operations/conversions/converters.py b/nebullvm/operations/conversions/converters.py index 355749d8..ac709986 100644 --- a/nebullvm/operations/conversions/converters.py +++ b/nebullvm/operations/conversions/converters.py @@ -5,11 +5,13 @@ from nebullvm.operations.base import Operation from nebullvm.operations.conversions.pytorch import convert_torch_to_onnx from nebullvm.operations.conversions.tensorflow import convert_tf_to_onnx +from nebullvm.operations.conversions.onnx import convert_onnx_to_torch from nebullvm.optional_modules.onnx import onnx from nebullvm.optional_modules.tensorflow import tensorflow as tf from nebullvm.optional_modules.torch import torch from nebullvm.tools.base import DeepLearningFramework, ModelParams from nebullvm.tools.data import DataManager +from onnx import ModelProto class Converter(Operation, abc.ABC): @@ -27,7 +29,7 @@ def __init__(self, model_name: Optional[str] = None): self.model_name = model_name or "temp" def set_state( - self, model: Union[torch.nn.Module, tf.Module, str], data: DataManager + self, model: Union[torch.nn.Module, tf.Module, onnx.ModelProto, str], data: DataManager ): self.model = model self.data = data @@ -104,26 +106,32 @@ def pytorch_conversion(self): class ONNXConverter(Converter): - DEST_FRAMEWORKS = [] + DEST_FRAMEWORKS = [DeepLearningFramework.NUMPY] - def execute(self, save_path, model_params): - onnx_path = save_path / f"{self.model_name}{self.ONNX_EXTENSION}" - try: - model_onnx = onnx.load(str(self.model)) - onnx.save(model_onnx, str(onnx_path)) - except Exception: - self.logger.error( - "The provided onnx model path is invalid. Please provide" - " a valid path to a model in order to use Nebullvm." - ) - self.converted_models = [] - - self.converted_models = [str(onnx_path)] + def execute( + self, + save_path: Path, + model_params: ModelProto, + ): + self.converted_models = [self.model] + for framework in self.DEST_FRAMEWORKS: + if framework is DeepLearningFramework.NUMPY: + self.pytorch_conversion(save_path, model_params) + else: + raise NotImplementedError() def tensorflow_conversion(self): # TODO: Implement conversion from ONNX to Tensorflow raise NotImplementedError() - def pytorch_conversion(self): - # TODO: Implement conversion from ONNX to Pytorch - raise NotImplementedError() + def pytorch_conversion(self, save_path, model_params): + self.model_onnx = self.model + #torch_path = save_path / f"{self.model_name}{self.TORCH_EXTENSION}" + torch_model = convert_onnx_to_torch( + onnx_model=self.model_onnx + ) + if self.converted_models is None: + self.converted_models = [torch_model] + else: + self.converted_models.append(torch_model) + diff --git a/nebullvm/operations/conversions/onnx.py b/nebullvm/operations/conversions/onnx.py new file mode 100644 index 00000000..777da314 --- /dev/null +++ b/nebullvm/operations/conversions/onnx.py @@ -0,0 +1,27 @@ +import logging +from pathlib import Path + +from nebullvm.optional_modules.torch import torch +from nebullvm.tools.base import Device +from nebullvm.optional_modules.onnx import ModelProto +logger = logging.getLogger("nebullvm_logger") + +from nebullvm.optional_modules.onnx import convert + +def convert_onnx_to_torch( + onnx_model: ModelProto +): + """Function importing a custom ONNX model and converting it in Pytorch + + Args: + onnx_model: ONNX model (tested with model=onnx.load("model.onnx")). + """ + try: + torch_model = torch.fx.symbolic_trace(convert(onnx_model)) + return torch_model + except Exception as e: + logger.warning("Exception raised during conversion of ONNX to Pytorch." + "ONNX to Torch pipeline will be skipped") + logger.warning(e) + return None + diff --git a/nebullvm/operations/inference_learners/pytorch.py b/nebullvm/operations/inference_learners/pytorch.py index b5a3b5f4..32295aff 100644 --- a/nebullvm/operations/inference_learners/pytorch.py +++ b/nebullvm/operations/inference_learners/pytorch.py @@ -4,6 +4,8 @@ from tempfile import TemporaryDirectory from typing import Tuple, Union, Optional, List +import numpy as np + from nebullvm.operations.inference_learners.base import ( PytorchBaseInferenceLearner, LearnerMetadata, @@ -116,3 +118,19 @@ def from_torch_model( input_data=input_data, device=device, ) + + +class NumpyPytorchBackendInferenceLearner( + PytorchBackendInferenceLearner, PytorchBaseInferenceLearner +): + + """ + Wrapper around PytorchBackendInferenceLearner to allow numpy inputs and outputs + """ + + + def run(self, *input_tensors: np.ndarray) -> Tuple[np.ndarray, ...]: + input_tensors = [torch.from_numpy(t) for t in input_tensors] + # Call the PytorchBackendInferenceLearner run method + res = super().run(*input_tensors) + return tuple(out.numpy() for out in res) diff --git a/nebullvm/optional_modules/onnx.py b/nebullvm/optional_modules/onnx.py index 7ce9d2b4..4d70a3e7 100644 --- a/nebullvm/optional_modules/onnx.py +++ b/nebullvm/optional_modules/onnx.py @@ -1,5 +1,6 @@ try: import onnx # noqa F401 + from onnx import ModelProto except ImportError: onnx = None @@ -11,3 +12,8 @@ except ImportError: convert_float_to_float16_model_path = object + +try: + from onnx2torch import convert # noqa F401 +except ImportError: + convert = None