diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py
index b2d33e7647..8669126164 100644
--- a/optimum/commands/export/openvino.py
+++ b/optimum/commands/export/openvino.py
@@ -68,6 +68,8 @@ def parse_args_openvino(parser: "ArgumentParser"):
             "This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
         ),
     )
+    optional_group.add_argument("--fp16", action="store_true", help="Compress weights to half precision")
+    optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8")


 class OVExportCommand(BaseOptimumCLICommand):
@@ -102,5 +104,7 @@ def run(self):
             cache_dir=self.args.cache_dir,
             trust_remote_code=self.args.trust_remote_code,
             pad_token_id=self.args.pad_token_id,
+            fp16=self.args.fp16,
+            int8=self.args.int8,
             # **input_shapes,
         )
diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index a7d5874585..558a0970d5 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -19,25 +19,26 @@

 from requests.exceptions import ConnectionError as RequestsConnectionError
 from transformers import AutoTokenizer
-from transformers.utils import is_torch_available
+import openvino
+from openvino import Core

 from optimum.exporters import TasksManager
 from optimum.exporters.onnx import __main__ as optimum_main
 from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.save_utils import maybe_save_preprocessors

+from ...intel.utils.import_utils import is_nncf_available
 from ...intel.utils.modeling_utils import patch_decoder_attention_mask
 from .convert import export_models


+core = Core()
+
 OV_XML_FILE_NAME = "openvino_model.xml"

 logger = logging.getLogger(__name__)

-if is_torch_available():
-    import torch
-

 def main_export(
     model_name_or_path: str,
@@ -57,6 +58,7 @@ def main_export(
     model_kwargs: Optional[Dict[str, Any]] = None,
     custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
     fn_get_submodels: Optional[Callable] = None,
+    int8: Optional[bool] = False,
     **kwargs_shapes,
 ):
     """
@@ -123,6 +125,19 @@ def main_export(
     >>> main_export("gpt2", output="gpt2_onnx/")
     ```
     """
+    if int8:
+        if fp16:
+            raise ValueError(
+                "Both `fp16` and `int8` were set to `True`, please select only one of these options."
+            )
+
+        if not is_nncf_available():
+            raise ImportError(
+                "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
+            )
+
+        import nncf
+
     output = Path(output)
     if not output.exists():
         output.mkdir(parents=True)
@@ -139,8 +154,6 @@ def main_export(
             kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
         )

-    torch_dtype = None if fp16 is False else torch.float16
-
     if task == "auto":
         try:
             task = TasksManager.infer_task_from_model(model_name_or_path)
@@ -164,7 +177,6 @@ def main_export(
         force_download=force_download,
         trust_remote_code=trust_remote_code,
         framework=framework,
-        torch_dtype=torch_dtype,
         device=device,
     )

@@ -299,5 +311,18 @@ def main_export(
         output_names=files_subpaths,
         input_shapes=input_shapes,
         device=device,
+        fp16=fp16,
         model_kwargs=model_kwargs,
     )
+    del models_and_onnx_configs
+
+    if int8:
+        for model_path in files_subpaths:
+            model = core.read_model(output / model_path)
+            model = nncf.compress_weights(model)
+
+            for filename in (model_path, model_path.replace("xml", "bin")):
+                os.remove(output / filename)
+
+            openvino.save_model(model, output / model_path, compress_to_fp16=False)
+            del model
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 9a6cbec07b..998fe228b4 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -60,6 +60,7 @@ def export(
     device: str = "cpu",
     input_shapes: Optional[Dict] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
+    fp16: bool = False,
 ) -> Tuple[List[str], List[str]]:
     """
     Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -101,6 +102,7 @@ def export(
             device=device,
             input_shapes=input_shapes,
             model_kwargs=model_kwargs,
+            fp16=fp16,
         )

     elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -111,7 +113,7 @@ def export(
             raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
         if input_shapes is not None:
             logger.info("`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored.")
-        return export_tensorflow(model, config, opset, output)
+        return export_tensorflow(model, config, opset, output, fp16=fp16)

     else:
         raise RuntimeError(
@@ -119,7 +121,13 @@ def export(
         )


-def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, opset: int, output: Path):
+def export_tensorflow(
+    model: Union["PreTrainedModel", "ModelMixin"],
+    config: OnnxConfig,
+    opset: int,
+    output: Path,
+    fp16: bool = False,
+):
     """
     Export the TensorFlow model to OpenVINO format.

@@ -137,7 +145,7 @@ def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: Onn
     onnx_path = Path(output).with_suffix(".onnx")
     input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
     ov_model = convert_model(str(onnx_path))
-    save_model(ov_model, output.parent / output, compress_to_fp16=False)
+    save_model(ov_model, output.parent / output, compress_to_fp16=fp16)
     return input_names, output_names, True


@@ -149,6 +157,7 @@ def export_pytorch_via_onnx(
     device: str = "cpu",
     input_shapes: Optional[Dict] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
+    fp16: bool = False,
 ):
     """
     Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export.
@@ -190,7 +199,7 @@ def export_pytorch_via_onnx(
     save_model(
         ov_model,
         output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
-        compress_to_fp16=False,
+        compress_to_fp16=fp16,
     )
     return input_names, output_names, True

@@ -203,6 +212,7 @@ def export_pytorch(
     device: str = "cpu",
     input_shapes: Optional[Dict] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
+    fp16: bool = False,
 ) -> Tuple[List[str], List[str]]:
     """
     Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -297,7 +307,9 @@ def ts_patched_forward(*args, **kwargs):
             ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
         except Exception as ex:
             logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
-            return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs)
+            return export_pytorch_via_onnx(
+                model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16
+            )
         ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
         ordered_input_names = list(inputs)
         flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
@@ -318,7 +330,7 @@ def ts_patched_forward(*args, **kwargs):
             inp_tensor.get_node().set_partial_shape(static_shape)
             inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
         ov_model.validate_nodes_and_infer_types()
-        save_model(ov_model, output, compress_to_fp16=False)
+        save_model(ov_model, output, compress_to_fp16=fp16)
     clear_class_registry()
     del model
     gc.collect()
@@ -335,6 +347,7 @@ def export_models(
     device: str = "cpu",
     input_shapes: Optional[Dict] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
+    fp16: bool = False,
 ) -> Tuple[List[List[str]], List[List[str]]]:
     """
     Export the models to OpenVINO IR format
@@ -379,6 +392,7 @@ def export_models(
                 device=device,
                 input_shapes=input_shapes,
                 model_kwargs=model_kwargs,
+                fp16=fp16,
             )
         )
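Usage sketch for the new options (the model name, output paths, and exact CLI invocation below are illustrative assumptions, not part of the diff; `--int8`/`int8=True` additionally requires `nncf` to be installed):

    # CLI, assuming the usual `optimum-cli export openvino --model <model_id> <output_dir>` form:
    #   optimum-cli export openvino --model gpt2 --int8 gpt2_ov_int8/

    # Python API, mirroring the docstring example in optimum/exporters/openvino/__main__.py
    from optimum.exporters.openvino.__main__ import main_export

    # Export the IR, then compress its weights to int8 via nncf.compress_weights
    main_export("gpt2", output="gpt2_ov_int8/", int8=True)

    # Or keep the IR weights in half precision instead (passing both fp16 and int8 raises a ValueError)
    main_export("gpt2", output="gpt2_ov_fp16/", fp16=True)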