From 8801566dff6fd22df376a0c9f9ffe388bf1753d7 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 2 Oct 2023 18:20:48 +0400 Subject: [PATCH 1/9] Added 8bit compression for decoders larger than 1B --- optimum/exporters/openvino/__main__.py | 11 +++++++++++ optimum/exporters/openvino/convert.py | 20 ++++++++++++++++---- optimum/intel/openvino/modeling_decoder.py | 3 ++- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index a7d5874585..736b5234a9 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -33,6 +33,8 @@ OV_XML_FILE_NAME = "openvino_model.xml" +_MAX_UNCOMPRESSED_DECODER_SIZE = 1e9 + logger = logging.getLogger(__name__) if is_torch_available(): @@ -232,6 +234,15 @@ def main_export( onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) onnx_config = onnx_config_constructor(model.config) models_and_onnx_configs = {"model": (model, onnx_config)} + load_in_8bit = model_kwargs.get("load_in_8bit", None) + if load_in_8bit is None: + if model_kwargs is None: + model_kwargs = {} + + if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE: + model_kwargs["load_in_8bit"] = True + else: + model_kwargs["load_in_8bit"] = False if not is_stable_diffusion: needs_pad_token_id = ( diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 9a6cbec07b..24af408f78 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -21,6 +21,8 @@ from transformers.utils import is_tf_available, is_torch_available +import nncf + from openvino.runtime import PartialShape, save_model from openvino.runtime.utils.types import get_element_type from openvino.tools.ovc import convert_model @@ -52,6 +54,12 @@ from transformers.modeling_tf_utils import TFPreTrainedModel +def _save_model(model, path: str, 
compress_to_fp16=False, load_in_8bit=False): + if load_in_8bit: + model = nncf.compress_weights(model) + save_model(model, path, compress_to_fp16) + + def export( model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], config: OnnxConfig, @@ -137,7 +145,7 @@ def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: Onn onnx_path = Path(output).with_suffix(".onnx") input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path) ov_model = convert_model(str(onnx_path)) - save_model(ov_model, output.parent / output, compress_to_fp16=False) + _save_model(ov_model, output.parent / output, compress_to_fp16=False, load_in_8bit=False) return input_names, output_names, True @@ -187,10 +195,11 @@ def export_pytorch_via_onnx( ) torch.onnx.export = orig_torch_onnx_export ov_model = convert_model(str(onnx_output)) - save_model( + _save_model( ov_model, output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output, compress_to_fp16=False, + load_in_8bit=model_kwargs.get("load_in_8bit", False) ) return input_names, output_names, True @@ -314,11 +323,14 @@ def ts_patched_forward(*args, **kwargs): dims = inputs[input_name] for dim in dims: - static_shape[dim] = -1 + static_shape[dim] = -1 inp_tensor.get_node().set_partial_shape(static_shape) inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) ov_model.validate_nodes_and_infer_types() - save_model(ov_model, output, compress_to_fp16=False) + _save_model(ov_model, + output, + compress_to_fp16=False, + load_in_8bit=model_kwargs.get("load_in_8bit", False)) clear_class_registry() del model gc.collect() diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 6c45172652..db3bc21108 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -86,7 +86,6 @@ "pegasus", } - @add_start_docstrings( """ Base OVBaseDecoderModel class. 
@@ -226,6 +225,7 @@ def _from_transformers( if use_cache: task = task + "-with-past" + main_export( model_name_or_path=model_id, output=save_dir_path, @@ -237,6 +237,7 @@ def _from_transformers( local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, + model_kwargs=kwargs, ) config.is_decoder = True From 821d2a9e732edf10f57625b2ae58af6c99914d6b Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 2 Oct 2023 18:21:44 +0400 Subject: [PATCH 2/9] Style --- optimum/exporters/openvino/__main__.py | 4 ++-- optimum/exporters/openvino/convert.py | 12 ++++-------- optimum/intel/openvino/modeling_decoder.py | 2 +- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 736b5234a9..1d3ab14907 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -238,9 +238,9 @@ def main_export( if load_in_8bit is None: if model_kwargs is None: model_kwargs = {} - + if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE: - model_kwargs["load_in_8bit"] = True + model_kwargs["load_in_8bit"] = True else: model_kwargs["load_in_8bit"] = False diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 24af408f78..b1597e60d2 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -19,9 +19,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union -from transformers.utils import is_tf_available, is_torch_available - import nncf +from transformers.utils import is_tf_available, is_torch_available from openvino.runtime import PartialShape, save_model from openvino.runtime.utils.types import get_element_type @@ -199,7 +198,7 @@ def export_pytorch_via_onnx( ov_model, output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output, compress_to_fp16=False, - load_in_8bit=model_kwargs.get("load_in_8bit", 
False) + load_in_8bit=model_kwargs.get("load_in_8bit", False), ) return input_names, output_names, True @@ -323,14 +322,11 @@ def ts_patched_forward(*args, **kwargs): dims = inputs[input_name] for dim in dims: - static_shape[dim] = -1 + static_shape[dim] = -1 inp_tensor.get_node().set_partial_shape(static_shape) inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) ov_model.validate_nodes_and_infer_types() - _save_model(ov_model, - output, - compress_to_fp16=False, - load_in_8bit=model_kwargs.get("load_in_8bit", False)) + _save_model(ov_model, output, compress_to_fp16=False, load_in_8bit=model_kwargs.get("load_in_8bit", False)) clear_class_registry() del model gc.collect() diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index db3bc21108..91a2c7ddc2 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -86,6 +86,7 @@ "pegasus", } + @add_start_docstrings( """ Base OVBaseDecoderModel class. 
@@ -225,7 +226,6 @@ def _from_transformers( if use_cache: task = task + "-with-past" - main_export( model_name_or_path=model_id, output=save_dir_path, From aa0c6adaf96ac7944ecf4dfb1dd5c6243c2f195f Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 2 Oct 2023 19:24:42 +0400 Subject: [PATCH 3/9] Fixed issue --- optimum/exporters/openvino/convert.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index b1597e60d2..22cbeb68b9 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -194,11 +194,12 @@ def export_pytorch_via_onnx( ) torch.onnx.export = orig_torch_onnx_export ov_model = convert_model(str(onnx_output)) + load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False) _save_model( ov_model, output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output, compress_to_fp16=False, - load_in_8bit=model_kwargs.get("load_in_8bit", False), + load_in_8bit=load_in_8bit, ) return input_names, output_names, True @@ -326,7 +327,8 @@ def ts_patched_forward(*args, **kwargs): inp_tensor.get_node().set_partial_shape(static_shape) inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) ov_model.validate_nodes_and_infer_types() - _save_model(ov_model, output, compress_to_fp16=False, load_in_8bit=model_kwargs.get("load_in_8bit", False)) + load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False) + _save_model(ov_model, output, compress_to_fp16=False, load_in_8bit=load_in_8bit) clear_class_registry() del model gc.collect() From 9adc2984ebc3f0870b763906823d3eeab23c38c0 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 3 Oct 2023 10:19:44 +0400 Subject: [PATCH 4/9] Fixed one more issue --- optimum/exporters/openvino/__main__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py 
b/optimum/exporters/openvino/__main__.py index 1d3ab14907..7d31060e6a 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -234,11 +234,10 @@ def main_export( onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) onnx_config = onnx_config_constructor(model.config) models_and_onnx_configs = {"model": (model, onnx_config)} + if model_kwargs is None: + model_kwargs = {} load_in_8bit = model_kwargs.get("load_in_8bit", None) if load_in_8bit is None: - if model_kwargs is None: - model_kwargs = {} - if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE: model_kwargs["load_in_8bit"] = True else: From bf519490249054ba47b67df042ad801668dba5dd Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 3 Oct 2023 12:06:15 +0400 Subject: [PATCH 5/9] Added warning for nncf absence in case of default compression to 8 bits --- optimum/exporters/openvino/convert.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 22cbeb68b9..c5bd4267cc 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -19,7 +19,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union -import nncf from transformers.utils import is_tf_available, is_torch_available from openvino.runtime import PartialShape, save_model @@ -29,6 +28,7 @@ from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx +from optimum.intel.openvino import is_nncf_available from optimum.utils import is_diffusers_available from .utils import ( @@ -55,6 +55,13 @@ def _save_model(model, path: str, compress_to_fp16=False, load_in_8bit=False): if load_in_8bit: + if not
is_nncf_available(): + logger.warning( + "The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf." + "please install it with `pip install nncf`" + ) + import nncf + model = nncf.compress_weights(model) save_model(model, path, compress_to_fp16) From 2d16741a446d057849979ac1456aeb3de0dcce27 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 3 Oct 2023 15:36:15 +0400 Subject: [PATCH 6/9] Fixed an issue. Added warning message when NNCF is not available --- optimum/exporters/openvino/__main__.py | 6 ++++++ optimum/exporters/openvino/convert.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 7d31060e6a..8152b92d29 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -27,6 +27,7 @@ from optimum.utils import DEFAULT_DUMMY_SHAPES from optimum.utils.save_utils import maybe_save_preprocessors +from ...intel.utils.import_utils import is_nncf_available from ...intel.utils.modeling_utils import patch_decoder_attention_mask from .convert import export_models @@ -242,6 +243,11 @@ def main_export( model_kwargs["load_in_8bit"] = True else: model_kwargs["load_in_8bit"] = False + else: + if not is_nncf_available(): + raise ImportError( + "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`" + ) if not is_stable_diffusion: needs_pad_token_id = ( diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index c5bd4267cc..b29efe253e 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -28,9 +28,9 @@ from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx -from optimum.intel.openvino 
import is_nncf_available from optimum.utils import is_diffusers_available +from ...intel.utils.import_utils import is_nncf_available from .utils import ( OV_XML_FILE_NAME, clear_class_registry, From d5ff9d1088fa4fb2ae3192ec7a90a0947c64baaf Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 3 Oct 2023 17:31:29 +0400 Subject: [PATCH 7/9] Revised logic of the default INT8 export --- optimum/exporters/openvino/__main__.py | 18 ++++++++---------- optimum/exporters/openvino/convert.py | 6 +++--- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 8152b92d29..3baa9119a1 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -235,19 +235,17 @@ def main_export( onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) onnx_config = onnx_config_constructor(model.config) models_and_onnx_configs = {"model": (model, onnx_config)} - if model_kwargs is None: - model_kwargs = {} + model_kwargs = model_kwargs or {} load_in_8bit = model_kwargs.get("load_in_8bit", None) if load_in_8bit is None: if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE: - model_kwargs["load_in_8bit"] = True - else: - model_kwargs["load_in_8bit"] = False - else: - if not is_nncf_available(): - raise ImportError( - "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`" - ) + if not is_nncf_available(): + logger.warning( + "The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf." 
+ "please install it with `pip install nncf`" + ) + else: + model_kwargs["load_in_8bit"] = True if not is_stable_diffusion: needs_pad_token_id = ( diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index b29efe253e..ab4a41e873 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -56,10 +56,10 @@ def _save_model(model, path: str, compress_to_fp16=False, load_in_8bit=False): if load_in_8bit: if not is_nncf_available(): - logger.warning( - "The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf." - "please install it with `pip install nncf`" + raise ImportError( + "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`" ) + import nncf model = nncf.compress_weights(model) From 6b22da2cb0b0912227fa771c8689e259af91c8c4 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 3 Oct 2023 18:20:37 +0400 Subject: [PATCH 8/9] Added tests for auto weights compression --- tests/openvino/test_quantization.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 55758b6683..6563eed7d8 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -150,6 +150,8 @@ class OVWeightCompressionTest(unittest.TestCase): (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 22), ) + UPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = ((OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 22),) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS) def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8): task = model_cls.export_feature @@ -197,6 +199,18 @@ def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int outputs = model(**tokens) self.assertTrue("logits" in outputs) + 
@parameterized.expand(UPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) + def test_ovmodel_load_with_compressed_weights(self, model_cls, model_name, expected_ov_int8): + model = model_cls.from_pretrained(model_name, export=True, load_in_8bit=True) + _, num_int8 = get_num_quantized_nodes(model) + self.assertEqual(expected_ov_int8, num_int8) + + @parameterized.expand(UPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) + def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_name, expected_ov_int8): + model = model_cls.from_pretrained(model_name, export=True, load_in_8bit=False) + _, num_int8 = get_num_quantized_nodes(model) + self.assertEqual(0, num_int8) + class OVQuantizerQATest(unittest.TestCase): SUPPORTED_ARCHITECTURES = (("hf-internal-testing/tiny-random-BertForQuestionAnswering",),) From dd9b40ef41a15d2bc8a2ecc7fb25a75677eb66ce Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 4 Oct 2023 10:55:09 +0400 Subject: [PATCH 9/9] Updated references --- tests/openvino/test_training.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/openvino/test_training.py b/tests/openvino/test_training.py index 6699687c69..91defbefbb 100644 --- a/tests/openvino/test_training.py +++ b/tests/openvino/test_training.py @@ -715,7 +715,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel): model_id="hf-internal-testing/tiny-random-Wav2Vec2Model", nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2], expected_fake_quantize=48, - expected_int8=31, + expected_int8=30, compression_metrics=["compression_loss"], ), "structured_movement_sparsity": OVTrainerTestDescriptor( @@ -734,7 +734,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel): model_id="hf-internal-testing/tiny-random-Wav2Vec2Model", nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_WAV2VEC2], expected_fake_quantize=48, - expected_int8=31, + expected_int8=30, expected_binary_masks=48, compression_metrics=["compression_loss"], ), 
@@ -742,7 +742,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel): model_id="hf-internal-testing/tiny-random-Wav2Vec2Model", nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_WAV2VEC2], expected_fake_quantize=48, - expected_int8=31, + expected_int8=30, expected_binary_masks=48, compression_metrics=["compression_loss"], ), @@ -751,7 +751,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel): teacher_model_id="hf-internal-testing/tiny-random-Wav2Vec2Model", nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_WAV2VEC2], expected_fake_quantize=48, - expected_int8=31, + expected_int8=30, expected_binary_masks=48, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ), @@ -760,7 +760,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel): teacher_model_id="hf-internal-testing/tiny-random-Wav2Vec2Model", nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_WAV2VEC2], expected_fake_quantize=48, - expected_int8=31, + expected_int8=30, expected_binary_masks=48, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ),