Added openvino compiler with tensorflow interface #137

Open · wants to merge 7 commits into base: main
Changes from 1 commit
19 changes: 19 additions & 0 deletions nebullvm/operations/inference_learners/builders.py
@@ -146,6 +146,25 @@ def execute(
)


class TensorFlowOpenVINOBuildInferenceLearner(BuildInferenceLearner):
Collaborator:
This class is actually useless. You can simply reuse the OpenVINOBuildInferenceLearner, passing TensorFlow as source_dl_framework.
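A minimal sketch of that reuse (import paths and the OpenVINOBuildInferenceLearner signature are assumed to match this hunk; model_params and input_tfms are placeholders provided by the calling optimization step):

# Hedged sketch: reuse the existing OpenVINOBuildInferenceLearner for
# TensorFlow models instead of adding a dedicated subclass. Import paths
# are assumptions; model_params and input_tfms come from the caller.
from nebullvm.operations.inference_learners.builders import (
    OpenVINOBuildInferenceLearner,
)
from nebullvm.tools.base import DeepLearningFramework  # assumed module path

builder = OpenVINOBuildInferenceLearner()
builder.execute(
    model="/tmp/openvino_ir/model",  # placeholder prefix of the .xml/.bin pair
    model_params=model_params,       # ModelParams describing the model inputs
    input_tfms=input_tfms,           # MultiStageTransformation pipeline
    source_dl_framework=DeepLearningFramework.TENSORFLOW,
)
tf_learner = builder.inference_learner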

def execute(
self,
model: str,
model_params: ModelParams,
input_tfms: MultiStageTransformation,
source_dl_framework: DeepLearningFramework,
**kwargs,
):
self.inference_learner = OPENVINO_INFERENCE_LEARNERS[
source_dl_framework
].from_model_name(
model_name=model + ".xml",
model_weights=model + ".bin",
input_tfms=input_tfms,
network_parameters=model_params,
)


class PyTorchTensorRTBuildInferenceLearner(BuildInferenceLearner):
def execute(
self,
2 changes: 1 addition & 1 deletion nebullvm/operations/inference_learners/openvino.py
@@ -400,4 +400,4 @@ def run(self, *input_tensors: np.ndarray) -> Tuple[np.ndarray, ...]:
DeepLearningFramework.PYTORCH: PytorchOpenVinoInferenceLearner,
DeepLearningFramework.TENSORFLOW: TensorflowOpenVinoInferenceLearner,
DeepLearningFramework.NUMPY: NumpyOpenVinoInferenceLearner,
}
}
12 changes: 10 additions & 2 deletions nebullvm/operations/optimizations/base.py
@@ -18,6 +18,9 @@
TFLiteBuildInferenceLearner,
TensorflowBuildInferenceLearner,
)
from nebullvm.operations.inference_learners.openvino import (
TensorflowOpenVinoInferenceLearner,
)
from nebullvm.operations.measures.measures import MetricDropMeasure
from nebullvm.operations.measures.utils import (
compute_relative_difference,
@@ -35,6 +38,7 @@
)
from nebullvm.operations.optimizations.compilers.openvino import (
OpenVINOCompiler,
TensorFlowOpenVINOCompiler,
)
from nebullvm.operations.optimizations.compilers.pytorch import (
PytorchBackendCompiler,
@@ -319,7 +323,10 @@ def get_result(self) -> List:
DeepLearningFramework.NUMPY: ONNXApacheTVMCompiler,
},
ModelCompiler.ONNX_RUNTIME: {DeepLearningFramework.NUMPY: ONNXCompiler},
ModelCompiler.OPENVINO: {DeepLearningFramework.NUMPY: OpenVINOCompiler},
ModelCompiler.OPENVINO: {
DeepLearningFramework.NUMPY: OpenVINOCompiler,
DeepLearningFramework.TENSORFLOW: TensorFlowOpenVINOCompiler,
},
ModelCompiler.TFLITE: {
DeepLearningFramework.TENSORFLOW: TFLiteBackendCompiler
},
@@ -352,7 +359,8 @@ def get_result(self) -> List:
DeepLearningFramework.NUMPY: ONNXBuildInferenceLearner
},
ModelCompiler.OPENVINO: {
DeepLearningFramework.NUMPY: OpenVINOBuildInferenceLearner
DeepLearningFramework.NUMPY: OpenVINOBuildInferenceLearner,
DeepLearningFramework.TENSORFLOW: TensorflowOpenVinoInferenceLearner,
},
ModelCompiler.TFLITE: {
DeepLearningFramework.TENSORFLOW: TFLiteBuildInferenceLearner
131 changes: 131 additions & 0 deletions nebullvm/operations/optimizations/compilers/openvino.py
@@ -25,6 +25,137 @@
from nebullvm.tools.transformations import MultiStageTransformation


# Base class. Contains generic methods.
diegofiori marked this conversation as resolved.
class TensorFlowOpenVINOCompiler(Compiler):
supported_ops = {
"cpu": [
None,
# QuantizationType.STATIC,
# QuantizationType.HALF,
],
"gpu": [],
}

def __init__(self):
super().__init__()

def execute(
self,
model: Union[str, Path],
Collaborator:
The input for TensorFlow compilers should be a TensorFlow model, not a file. You should then save it to file and only at the end give it as input to the OpenVINO CLI.

Author: @marcoschouten (Jan 9, 2023)
Sure, this is how I would do it:

  1. Save the model to a file (would we need to define a specific folder?):
     model.save('saved_model/my_model')

  2. Load the model:
     my_model = tf.keras.models.load_model('saved_model/my_model')

  3. Pass it as input to the OpenVINO CLI:
     mo --input_model <my_model>.pb

Collaborator:
You can skip step 2: the execute method receives a TensorFlow model, so you can save it to file (remember to save it to a temp directory, as we do in the other sections) and then pass it to the mo command.
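A rough sketch of that flow, using only the mo flags already present in this PR (the helper name and temp-dir handling are assumptions, not the final implementation):

# Hedged sketch: receive a TensorFlow model object, serialize it to a
# temporary SavedModel directory, and only then hand the path to the
# OpenVINO Model Optimizer (mo) CLI, as suggested above. The helper name
# is hypothetical.
import subprocess
import tempfile
from pathlib import Path

import tensorflow as tf


def compile_tf_model_with_mo(model: tf.keras.Model) -> Path:
    tmp_dir = Path(tempfile.mkdtemp())
    saved_model_dir = tmp_dir / "saved_model"
    model.save(str(saved_model_dir))  # step 1: write the SavedModel to disk

    output_dir = tmp_dir / "openvino_ir"
    cmd = [
        "mo",
        "--saved_model_dir", str(saved_model_dir),
        "--output_dir", str(output_dir),
    ]
    subprocess.run(cmd, check=True)  # step 3: convert to OpenVINO IR (.xml/.bin)
    return output_dir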

model_params: ModelParams,
input_tfms: MultiStageTransformation = None,
metric_drop_ths: float = None,
quantization_type: QuantizationType = None,
input_data: DataManager = None,
**kwargs,
):
"""Compile the input model using TF-OPENVINO library.

Args:
model (str): The path to the TensorFlow SavedModel directory.
model_params (ModelParams): The model parameters.
input_tfms (MultiStageTransformation, optional): Transformations
to be performed to the model's input tensors in order to
get the prediction. Default: None.
metric_drop_ths (float, optional): Threshold for the accepted drop
in terms of precision. Any optimized model with a higher drop
will be ignored. Default: None.
quantization_type (QuantizationType, optional): The desired
quantization algorithm to be used. Default: None.
input_data (DataManager): User defined data. Default: None
"""

if quantization_type not in self.supported_ops[self.device.value]:
self.compiled_model = None
return

if quantization_type is QuantizationType.STATIC and input_data is None:
raise ValueError("Input data is required for static quantization.")

self.logger.info(
f"Optimizing with {self.__class__.__name__} and "
f"q_type: {quantization_type}."
)

check_quantization(quantization_type, metric_drop_ths)
train_input_data = input_data.get_split("train").get_numpy_list(
QUANTIZATION_DATA_NUM
)

# SYNTAX for MO command
# f"""mo
# --saved_model_dir "{model_path}"
# --input_shape "[1,224,224,3]"
# --mean_values="[127.5,127.5,127.5]"
# --scale_values="[127.5]"
# --model_name "{model_path.name}"
# --compress_to_fp16
# --output_dir "{model_path.parent}"
# """
diegofiori marked this conversation as resolved.

cmd = [
"mo",
"--saved_model_dir",
str(Path(model)),
"--output_dir",
str(Path(model)),
"--input",
",".join(get_input_names(model)),
"--input_shape",
",".join(
[
f"{list((model_params.batch_size,) + shape)}"
for shape in model_params.input_sizes
]
),
]

if quantization_type is QuantizationType.DYNAMIC:
return None

if quantization_type is QuantizationType.HALF:
cmd = cmd + ["--data_type", "FP16"]

process = subprocess.Popen(cmd)
process.wait()
base_path = Path(model).parent
openvino_model_path = base_path / f"{Path(model).stem}.xml"
openvino_model_weights = base_path / f"{Path(model).stem}.bin"

if quantization_type not in [QuantizationType.HALF, None]:
openvino_model_path, openvino_model_weights = self._quantize_model(
model_topology=str(openvino_model_path),
model_weights=str(openvino_model_weights),
input_names=get_input_names(model),
input_data=train_input_data,
)

self.compiled_model = str(
Path(openvino_model_path).parent / Path(openvino_model_path).stem
)

def _compile_model(
self,
model_name: str,
model_weights: str,
network_parameters: ModelParams,
) -> CompiledModel:
core = Core()
model = core.read_model(model=model_name, weights=model_weights)

dynamic_shape = self._get_dynamic_shape(model, network_parameters)

if dynamic_shape is not None:
model.reshape(dynamic_shape)

return core.compile_model(model=model, device_name="CPU")


# ____________________________________________________________________________________________________________________________________________________________________________
# ____________________________________________________________________________________________________________________________________________________________________________
# ____________________________________________________________________________________________________________________________________________________________________________


class OpenVINOCompiler(Compiler):
supported_ops = {
"cpu": [
2 changes: 2 additions & 0 deletions nebullvm/operations/optimizations/optimizers.py
@@ -52,6 +52,8 @@ def _select_compilers_from_hardware(self):
if tensorflow_is_available():
compilers.append(ModelCompiler.XLA)
compilers.append(ModelCompiler.TFLITE)
if self.device is Device.CPU and openvino_is_available():
compilers.append(ModelCompiler.OPENVINO)
return compilers


174 changes: 174 additions & 0 deletions nebullvm/tools/openvino.py
@@ -0,0 +1,174 @@
from pathlib import Path
diegofiori marked this conversation as resolved.
from typing import List, Tuple, Optional, Dict, Union, Sequence

from nebullvm.optional_modules.torch import torch, Module, DataLoader
from nebullvm.tools.base import DataType, InputInfo, Device
from nebullvm.tools.data import DataManager

FX_MODULE_NAME = "NebullvmFxModule"


def save_with_torch_fx(model: Module, path: Path):
traced_model = torch.fx.symbolic_trace(model)
traced_model.to_folder(path, FX_MODULE_NAME)


def load_with_torch_fx(
path: Path, state_dict_name: str = "pruned_state_dict.pt"
):
module_file = path / "module.py"
with open(module_file, "r") as f:
module_str = f.read()
exec(module_str, globals())
model = eval(FX_MODULE_NAME)()
model.load_state_dict(torch.load(path / state_dict_name))
return model


def get_outputs_sizes_torch(
torch_model: Module,
input_tensors: List[torch.Tensor],
device: Device,
) -> List[Tuple[int, ...]]:
if device is Device.GPU:
input_tensors = [x.cuda() for x in input_tensors]
torch_model.cuda()
with torch.no_grad():
outputs = torch_model(*input_tensors)
if isinstance(outputs, torch.Tensor):
return [tuple(outputs.size())[1:]]
else:
return [tuple(output.size())[1:] for output in outputs]


def create_model_inputs_torch(
batch_size: int, input_infos: List[InputInfo]
) -> List[torch.Tensor]:
input_tensors = (
torch.randn((batch_size, *input_info.size))
if input_info.dtype is DataType.FLOAT32
else torch.randint(
size=(batch_size, *input_info.size),
low=input_info.min_value or 0,
high=input_info.max_value or 100,
)
for input_info in input_infos
)
return list(input_tensors)


def run_torch_model(
torch_model: Module,
input_tensors: List[torch.Tensor],
device: Device,
dtype: torch.dtype = torch.float,
) -> List[torch.Tensor]:
torch_model.eval()
if device is Device.GPU:
torch_model.cuda()
if dtype != torch.half:
input_tensors = (t.cuda() for t in input_tensors)
else:
input_tensors = (
t.cuda().half() if t.dtype == torch.float else t.cuda()
for t in input_tensors
)
with torch.no_grad():
pred = torch_model(*input_tensors)
if isinstance(pred, torch.Tensor):
pred = [pred.cpu()]
else:
pred = [p.cpu() for p in pred]
return pred


def _extract_dynamic_axis(
torch_model: Module,
dataloader: DataManager,
input_sizes: List[Tuple[int, ...]],
batch_size: int,
device: Device,
max_data: int = 100,
) -> Optional[Dict]:
from nebullvm.tools.utils import inspect_dynamic_size

dynamic_axis = {"inputs": [{}] * len(input_sizes), "outputs": []}
output_sizes = []
for i, input_data in enumerate(dataloader):
input_tensors = input_data[0]
if i >= max_data:
break
inspect_dynamic_size(
input_tensors, input_sizes, batch_size, dynamic_axis["inputs"]
)
outputs = tuple(run_torch_model(torch_model, input_tensors, device))
if i == 0:
dynamic_axis["outputs"] = [{}] * len(outputs)
output_sizes = [tuple(output.shape[1:]) for output in outputs]
inspect_dynamic_size(
outputs, output_sizes, batch_size, dynamic_axis["outputs"]
)
if any(
len(x) > 0 for x in (dynamic_axis["inputs"] + dynamic_axis["outputs"])
):
return dynamic_axis
return None


def extract_info_from_torch_data(
model: Module,
dataloader: Union[DataLoader, Sequence],
batch_size: int,
input_sizes: List[Tuple[int, ...]],
input_types: List[str],
dynamic_axis: Dict,
device: Device,
):
from nebullvm.tools.utils import ifnone

input_data = (
dataloader[0]
if isinstance(dataloader, Sequence)
else next(iter(dataloader))
)
input_row = input_data[0]

batch_size = ifnone(batch_size, int(input_row[0].shape[0]))
input_sizes = ifnone(input_sizes, [tuple(x.shape[1:]) for x in input_row])
input_types = ifnone(
input_types,
[
"int64"
if isinstance(x.cpu(), torch.LongTensor)
else "int32"
if isinstance(x.cpu(), torch.IntTensor)
else "float32"
for x in input_row
],
)

if dynamic_axis is not None:
dynamic_axis["inputs"] = [
{int(k): v for (k, v) in val.items()}
for val in dynamic_axis["inputs"]
]
dynamic_axis["outputs"] = [
{int(k): v for (k, v) in val.items()}
for val in dynamic_axis["outputs"]
]

dynamic_axis = ifnone(
dynamic_axis,
_extract_dynamic_axis(
model, dataloader, input_sizes, batch_size, device
),
)
return batch_size, input_sizes, input_types, dynamic_axis


def torch_is_gpu_available():
return torch.cuda.is_available()


def torch_get_device_name():
return torch.cuda.get_device_name(0)