Add int8 and fp16 to OV export CLI
echarlaix committed Sep 29, 2023
1 parent 72f369c commit 42d04b2
Showing 3 changed files with 56 additions and 13 deletions.
4 changes: 4 additions & 0 deletions optimum/commands/export/openvino.py
@@ -68,6 +68,8 @@ def parse_args_openvino(parser: "ArgumentParser"):
"This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
),
)
optional_group.add_argument("--fp16", action="store_true", help="Compress weights to half precision"),
optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8"),


class OVExportCommand(BaseOptimumCLICommand):
@@ -102,5 +104,7 @@ def run(self):
cache_dir=self.args.cache_dir,
trust_remote_code=self.args.trust_remote_code,
pad_token_id=self.args.pad_token_id,
fp16=self.args.fp16,
int8=self.args.int8,
# **input_shapes,
)
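
Taken together, the new options flow from the command line straight into `main_export`; on the CLI this would look something like `optimum-cli export openvino --model gpt2 --int8 gpt2_ov/`. Below is a minimal sketch of the equivalent programmatic call, assuming `main_export` is re-exported from `optimum.exporters.openvino` and using placeholder model and output names:

# Hypothetical example mirroring what OVExportCommand.run() does with the new flags.
# Requires nncf to be installed when int8=True.
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="gpt2",  # placeholder model id
    output="gpt2_ov/",          # placeholder output directory
    task="auto",
    fp16=False,  # --fp16: store IR weights in half precision
    int8=True,   # --int8: compress IR weights to int8 with nncf after export
)
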
39 changes: 32 additions & 7 deletions optimum/exporters/openvino/__main__.py
@@ -19,25 +19,26 @@

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoTokenizer
from transformers.utils import is_torch_available

import openvino
from openvino import Core
from optimum.exporters import TasksManager
from optimum.exporters.onnx import __main__ as optimum_main
from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.save_utils import maybe_save_preprocessors

from ...intel.utils.import_utils import is_nncf_available
from ...intel.utils.modeling_utils import patch_decoder_attention_mask
from .convert import export_models


core = Core()

OV_XML_FILE_NAME = "openvino_model.xml"

logger = logging.getLogger(__name__)

if is_torch_available():
import torch


def main_export(
model_name_or_path: str,
@@ -57,6 +58,7 @@ def main_export(
model_kwargs: Optional[Dict[str, Any]] = None,
custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
fn_get_submodels: Optional[Callable] = None,
int8: Optional[bool] = False,
**kwargs_shapes,
):
"""
@@ -123,6 +125,19 @@
>>> main_export("gpt2", output="gpt2_onnx/")
```
"""
    if int8:
        if fp16:
            raise ValueError(
                "Both `fp16` and `int8` were set to `True`, please select only one of these options."
            )

        if not is_nncf_available():
            raise ImportError(
                "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
            )

        import nncf

    output = Path(output)
    if not output.exists():
        output.mkdir(parents=True)
@@ -139,8 +154,6 @@
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)

torch_dtype = None if fp16 is False else torch.float16

if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_name_or_path)
@@ -164,7 +177,6 @@
force_download=force_download,
trust_remote_code=trust_remote_code,
framework=framework,
torch_dtype=torch_dtype,
device=device,
)

@@ -299,5 +311,18 @@ def main_export(
output_names=files_subpaths,
input_shapes=input_shapes,
device=device,
fp16=fp16,
model_kwargs=model_kwargs,
)
del models_and_onnx_configs

    if int8:
        for model_path in files_subpaths:
            model = core.read_model(output / model_path)
            model = nncf.compress_weights(model)

            for filename in (model_path, model_path.replace("xml", "bin")):
                os.remove(output / filename)

            openvino.save_model(model, output / model_path, compress_to_fp16=False)
            del model
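
When `int8` is requested, the model is first exported as a regular IR, then each XML/BIN pair is read back, its weights are compressed with nncf, and the files are rewritten in place with `compress_to_fp16=False` so the int8 constants are left untouched. A standalone sketch of that post-processing step, assuming nncf >= 2.6 and an OpenVINO release that provides `openvino.save_model`, with a placeholder output directory:

# Compress the weights of an already exported OpenVINO IR to int8.
import os
from pathlib import Path

import nncf
import openvino
from openvino import Core

output_dir = Path("gpt2_ov")     # placeholder output directory
xml_name = "openvino_model.xml"  # default OV_XML_FILE_NAME

core = Core()
ov_model = core.read_model(output_dir / xml_name)
ov_model = nncf.compress_weights(ov_model)  # int8 weight compression

# Remove the original IR before saving the compressed model under the same name.
for filename in (xml_name, xml_name.replace("xml", "bin")):
    os.remove(output_dir / filename)

openvino.save_model(ov_model, output_dir / xml_name, compress_to_fp16=False)
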
26 changes: 20 additions & 6 deletions optimum/exporters/openvino/convert.py
@@ -60,6 +60,7 @@ def export(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
) -> Tuple[List[str], List[str]]:
"""
    Exports a PyTorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -101,6 +102,7 @@ def export(
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
fp16=fp16,
)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -111,15 +113,21 @@
raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
if input_shapes is not None:
logger.info("`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored.")
return export_tensorflow(model, config, opset, output)
return export_tensorflow(model, config, opset, output, fp16=fp16)

else:
raise RuntimeError(
"You either provided a PyTorch model with only TensorFlow installed, or a TensorFlow model with only PyTorch installed."
)


def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, opset: int, output: Path):
def export_tensorflow(
model: Union["PreTrainedModel", "ModelMixin"],
config: OnnxConfig,
opset: int,
output: Path,
fp16: bool = False,
):
"""
Export the TensorFlow model to OpenVINO format.
@@ -137,7 +145,7 @@ def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: Onn
onnx_path = Path(output).with_suffix(".onnx")
input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
ov_model = convert_model(str(onnx_path))
save_model(ov_model, output.parent / output, compress_to_fp16=False)
save_model(ov_model, output.parent / output, compress_to_fp16=fp16)
return input_names, output_names, True


@@ -149,6 +157,7 @@ def export_pytorch_via_onnx(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
):
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export.
@@ -190,7 +199,7 @@
save_model(
ov_model,
output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
compress_to_fp16=False,
compress_to_fp16=fp16,
)
return input_names, output_names, True

@@ -203,6 +212,7 @@ def export_pytorch(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -297,7 +307,9 @@ def ts_patched_forward(*args, **kwargs):
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
except Exception as ex:
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs)
return export_pytorch_via_onnx(
model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16
)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
ordered_input_names = list(inputs)
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
@@ -318,7 +330,7 @@ def ts_patched_forward(*args, **kwargs):
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()
save_model(ov_model, output, compress_to_fp16=False)
save_model(ov_model, output, compress_to_fp16=fp16)
clear_class_registry()
del model
gc.collect()
@@ -335,6 +347,7 @@ def export_models(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Export the models to OpenVINO IR format
@@ -379,6 +392,7 @@
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
fp16=fp16,
)
)
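
Throughout `convert.py`, the new `fp16` flag is simply forwarded down to the final `save_model` call, where `compress_to_fp16` decides whether the serialized IR stores its weights in half precision. A minimal sketch of that last step, assuming an OpenVINO release where `convert_model` and `save_model` are exposed at the package top level, with a placeholder ONNX path:

# Convert an intermediate ONNX file and control weight precision at save time.
from openvino import convert_model, save_model

fp16 = True  # mirrors the new --fp16 CLI flag
ov_model = convert_model("model.onnx")  # placeholder ONNX file produced during export
save_model(ov_model, "openvino_model.xml", compress_to_fp16=fp16)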

