Added openvino compiler with tensorflow interface #137
base: main
@@ -25,6 +25,137 @@
from nebullvm.tools.transformations import MultiStageTransformation


# Base class. Contains generic methods.
class TensorFlowOpenVINOCompiler(Compiler):
    supported_ops = {
        "cpu": [
            None,
            # QuantizationType.STATIC,
            # QuantizationType.HALF,
        ],
        "gpu": [],
    }

    def __init__(self):
        super().__init__()

    def execute(
        self,
        model: Union[str, Path],
Review comment: The input for TensorFlow compilers should be a TensorFlow model, not a file. You should then save it to file and only at the end give it as input to the OpenVINO CLI.
Reply: Sure, this is how I would do it:
Review comment: You can skip step 2; the execute method receives a TensorFlow model, so you can save it to file (remember to save it to a temporary directory, as we do in the other sections) and pass it to the OpenVINO CLI.
(A hedged sketch of this suggested flow appears after this file's diff, below.)
        model_params: ModelParams,
        input_tfms: MultiStageTransformation = None,
        metric_drop_ths: float = None,
        quantization_type: QuantizationType = None,
        input_data: DataManager = None,
        **kwargs,
    ):
        """Compile the input model using the TF-OpenVINO library.

        Args:
            model (str): The saved TensorFlow model path.
            model_params (ModelParams): The model parameters.
            input_tfms (MultiStageTransformation, optional): Transformations
                to be performed on the model's input tensors in order to
                get the prediction. Default: None.
            metric_drop_ths (float, optional): Threshold for the accepted drop
                in terms of precision. Any optimized model with a higher drop
                will be ignored. Default: None.
            quantization_type (QuantizationType, optional): The desired
                quantization algorithm to be used. Default: None.
            input_data (DataManager): User-defined data. Default: None.
        """

        if quantization_type not in self.supported_ops[self.device.value]:
            self.compiled_model = None
            return

        if quantization_type is QuantizationType.STATIC and input_data is None:
            raise ValueError("Input data is required for static quantization.")

        self.logger.info(
            f"Optimizing with {self.__class__.__name__} and "
            f"q_type: {quantization_type}."
        )

        check_quantization(quantization_type, metric_drop_ths)
        train_input_data = input_data.get_split("train").get_numpy_list(
            QUANTIZATION_DATA_NUM
        )

        # Syntax of the Model Optimizer (mo) command:
        # f"""mo
        #     --saved_model_dir "{model_path}"
        #     --input_shape "[1,224,224,3]"
        #     --mean_values="[127.5,127.5,127.5]"
        #     --scale_values="[127.5]"
        #     --model_name "{model_path.name}"
        #     --compress_to_fp16
        #     --output_dir "{model_path.parent}"
        # """

        cmd = [
            "mo",
            "--saved_model_dir",
            str(Path(model)),
            "--output_dir",
            str(Path(model)),
            "--input",
            ",".join(get_input_names(model)),
            "--input_shape",
            ",".join(
                [
                    f"{list((model_params.batch_size,) + shape)}"
                    for shape in model_params.input_sizes
                ]
            ),
        ]

        if quantization_type is QuantizationType.DYNAMIC:
            return None

        if quantization_type is QuantizationType.HALF:
            cmd = cmd + ["--data_type", "FP16"]

        process = subprocess.Popen(cmd)
        process.wait()
        base_path = Path(model).parent
        openvino_model_path = base_path / f"{Path(model).stem}.xml"
        openvino_model_weights = base_path / f"{Path(model).stem}.bin"

        if quantization_type not in [QuantizationType.HALF, None]:
            openvino_model_path, openvino_model_weights = self._quantize_model(
                model_topology=str(openvino_model_path),
                model_weights=str(openvino_model_weights),
                input_names=get_input_names(model),
                input_data=train_input_data,
            )

        self.compiled_model = str(
            Path(openvino_model_path).parent / Path(openvino_model_path).stem
        )

    def _compile_model(
        self,
        model_name: str,
        model_weights: str,
        network_parameters: ModelParams,
    ) -> CompiledModel:
        core = Core()
        model = core.read_model(model=model_name, weights=model_weights)

        dynamic_shape = self._get_dynamic_shape(model, network_parameters)

        if dynamic_shape is not None:
            model.reshape(dynamic_shape)

        return core.compile_model(model=model, device_name="CPU")


# ______________________________________________________________________________


class OpenVINOCompiler(Compiler):
    supported_ops = {
        "cpu": [
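For reference, a minimal sketch of the flow suggested in the review comments above, assuming the execute method receives an in-memory tf.keras.Model: serialize it to a temporary directory as a SavedModel, then invoke the OpenVINO Model Optimizer (mo) CLI on that directory. The function name compile_with_mo, its arguments, and the directory layout are illustrative only and are not part of this PR.

# Hedged sketch, not part of the PR: save the received TensorFlow model to a
# temporary directory, then hand the SavedModel path to the mo CLI.
import subprocess
import tempfile
from pathlib import Path

import tensorflow as tf


def compile_with_mo(tf_model: tf.keras.Model, input_shape: str) -> Path:
    """Serialize tf_model and convert it with OpenVINO's Model Optimizer.

    input_shape uses mo's syntax, e.g. "[1,224,224,3]".
    """
    tmp_dir = Path(tempfile.mkdtemp())
    saved_model_dir = tmp_dir / "saved_model"
    output_dir = tmp_dir / "openvino"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: save the in-memory model to disk in SavedModel format.
    tf_model.save(str(saved_model_dir))

    # Step 2: run the Model Optimizer CLI on the serialized model.
    subprocess.run(
        [
            "mo",
            "--saved_model_dir", str(saved_model_dir),
            "--output_dir", str(output_dir),
            "--input_shape", input_shape,
        ],
        check=True,
    )

    # mo writes <name>.xml / <name>.bin into output_dir; return the IR path.
    return next(output_dir.glob("*.xml"))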
@@ -0,0 +1,174 @@
from pathlib import Path

from typing import List, Tuple, Optional, Dict, Union, Sequence

from nebullvm.optional_modules.torch import torch, Module, DataLoader
from nebullvm.tools.base import DataType, InputInfo, Device
from nebullvm.tools.data import DataManager

FX_MODULE_NAME = "NebullvmFxModule"


def save_with_torch_fx(model: Module, path: Path):
    traced_model = torch.fx.symbolic_trace(model)
    traced_model.to_folder(path, FX_MODULE_NAME)


def load_with_torch_fx(
    path: Path, state_dict_name: str = "pruned_state_dict.pt"
):
    module_file = path / "module.py"
    with open(module_file, "r") as f:
        module_str = f.read()
    exec(module_str, globals())
    model = eval(FX_MODULE_NAME)()
    model.load_state_dict(torch.load(path / state_dict_name))
    return model


def get_outputs_sizes_torch(
    torch_model: Module,
    input_tensors: List[torch.Tensor],
    device: Device,
) -> List[Tuple[int, ...]]:
    if device is Device.GPU:
        input_tensors = [x.cuda() for x in input_tensors]
        torch_model.cuda()
    with torch.no_grad():
        outputs = torch_model(*input_tensors)
        if isinstance(outputs, torch.Tensor):
            return [tuple(outputs.size())[1:]]
        else:
            return [tuple(output.size())[1:] for output in outputs]


def create_model_inputs_torch(
    batch_size: int, input_infos: List[InputInfo]
) -> List[torch.Tensor]:
    input_tensors = (
        torch.randn((batch_size, *input_info.size))
        if input_info.dtype is DataType.FLOAT32
        else torch.randint(
            size=(batch_size, *input_info.size),
            low=input_info.min_value or 0,
            high=input_info.max_value or 100,
        )
        for input_info in input_infos
    )
    return list(input_tensors)


def run_torch_model(
    torch_model: Module,
    input_tensors: List[torch.Tensor],
    device: Device,
    dtype: torch.dtype = torch.float,
) -> List[torch.Tensor]:
    torch_model.eval()
    if device is Device.GPU:
        torch_model.cuda()
        if dtype != torch.half:
            input_tensors = (t.cuda() for t in input_tensors)
        else:
            input_tensors = (
                t.cuda().half() if t.dtype == torch.float else t.cuda()
                for t in input_tensors
            )
    with torch.no_grad():
        pred = torch_model(*input_tensors)
    if isinstance(pred, torch.Tensor):
        pred = [pred.cpu()]
    else:
        pred = [p.cpu() for p in pred]
    return pred


def _extract_dynamic_axis(
    torch_model: Module,
    dataloader: DataManager,
    input_sizes: List[Tuple[int, ...]],
    batch_size: int,
    device: Device,
    max_data: int = 100,
) -> Optional[Dict]:
    from nebullvm.tools.utils import inspect_dynamic_size

    dynamic_axis = {"inputs": [{}] * len(input_sizes), "outputs": []}
    output_sizes = []
    for i, input_data in enumerate(dataloader):
        input_tensors = input_data[0]
        if i >= max_data:
            break
        inspect_dynamic_size(
            input_tensors, input_sizes, batch_size, dynamic_axis["inputs"]
        )
        outputs = tuple(run_torch_model(torch_model, input_tensors, device))
        if i == 0:
            dynamic_axis["outputs"] = [{}] * len(outputs)
            output_sizes = [tuple(output.shape[1:]) for output in outputs]
        inspect_dynamic_size(
            outputs, output_sizes, batch_size, dynamic_axis["outputs"]
        )
    if any(
        len(x) > 0 for x in (dynamic_axis["inputs"] + dynamic_axis["outputs"])
    ):
        return dynamic_axis
    return None


def extract_info_from_torch_data(
    model: Module,
    dataloader: Union[DataLoader, Sequence],
    batch_size: int,
    input_sizes: List[Tuple[int, ...]],
    input_types: List[str],
    dynamic_axis: Dict,
    device: Device,
):
    from nebullvm.tools.utils import ifnone

    input_data = (
        dataloader[0]
        if isinstance(dataloader, Sequence)
        else next(iter(dataloader))
    )
    input_row = input_data[0]

    batch_size = ifnone(batch_size, int(input_row[0].shape[0]))
    input_sizes = ifnone(input_sizes, [tuple(x.shape[1:]) for x in input_row])
    input_types = ifnone(
        input_types,
        [
            "int64"
            if isinstance(x.cpu(), torch.LongTensor)
            else "int32"
            if isinstance(x.cpu(), torch.IntTensor)
            else "float32"
            for x in input_row
        ],
    )

    if dynamic_axis is not None:
        dynamic_axis["inputs"] = [
            {int(k): v for (k, v) in val.items()}
            for val in dynamic_axis["inputs"]
        ]
        dynamic_axis["outputs"] = [
            {int(k): v for (k, v) in val.items()}
            for val in dynamic_axis["outputs"]
        ]

    dynamic_axis = ifnone(
        dynamic_axis,
        _extract_dynamic_axis(
            model, dataloader, input_sizes, batch_size, device
        ),
    )
    return batch_size, input_sizes, input_types, dynamic_axis


def torch_is_gpu_available():
    return torch.cuda.is_available()


def torch_get_device_name():
    return torch.cuda.get_device_name(0)
Review comment: This class is actually useless. You can simply reuse the OpenVINOBuildInferenceLearner, passing TensorFlow as source_dl_framework.
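A hedged illustration of the suggestion above, assuming OpenVINOBuildInferenceLearner follows nebullvm's usual operation pattern (an execute call plus a result accessor). Only the class name and the source_dl_framework argument come from the review comment; the module paths, the execute signature, get_result, and the placeholder variables compiled_model_path, model_params are assumptions for illustration.

# Hypothetical sketch only: import paths and the call pattern are assumed,
# not taken from the nebullvm source.
from nebullvm.tools.base import DeepLearningFramework  # assumed location
from nebullvm.operations.inference_learners.builders import (  # assumed module
    OpenVINOBuildInferenceLearner,
)

build_op = OpenVINOBuildInferenceLearner()  # class named in the review comment
build_op.execute(  # assumed call pattern
    model=compiled_model_path,  # IR (.xml) produced by the mo step
    model_params=model_params,
    source_dl_framework=DeepLearningFramework.TENSORFLOW,  # per the comment
)
inference_learner = build_op.get_result()  # assumed accessor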