From 673484b85c31bd2939923c1ec7a79b7a8bcda023 Mon Sep 17 00:00:00 2001
From: Liubov Talamanova
Date: Wed, 20 Sep 2023 17:31:25 +0100
Subject: [PATCH] Added 8-bit weight compression for OVModel (#415)

* Added 8-bit weight compression for OVModel

* fix test
---
 optimum/intel/openvino/quantization.py | 18 ++++++++++++----
 tests/openvino/test_quantization.py    | 29 ++++++++++++++++++++++----
 2 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 3349ce142f..da33eca733 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -167,10 +167,6 @@ def quantize(
             raise ValueError("`save_directory` needs to be specified")
 
         if weights_only:
-            if isinstance(self.model, OVBaseModel):
-                raise ValueError(
-                    "`weights_only` currently not supported for `OVModels`, only available for torch.nn.Module."
-                )
             if calibration_dataset is not None:
                 logger.warning(
                     "`calibration_dataset` was provided but will not be used as `weights_only` is set to `True`."
@@ -189,6 +185,7 @@
                 batch_size,
                 data_collator,
                 remove_unused_columns,
+                weights_only,
                 **kwargs,
             )
         elif isinstance(self.model, OVBaseModel):
@@ -198,6 +195,7 @@
                 batch_size,
                 data_collator,
                 remove_unused_columns,
+                weights_only,
                 **kwargs,
             )
         elif isinstance(self.model, torch.nn.Module):
@@ -221,11 +219,17 @@ def _quantize_ovbasemodel(
         batch_size: int = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
+        weights_only: bool = False,
         **kwargs,
     ):
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)
 
+        if weights_only:
+            self.model.model = nncf.compress_weights(self.model.model)
+            self.model.save_pretrained(save_directory)
+            return
+
         calibration_dataloader = self._get_calibration_dataloader(
             calibration_dataset=calibration_dataset,
             batch_size=batch_size,
@@ -251,11 +255,17 @@ def _quantize_ovcausallm(
         batch_size: int = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
+        weights_only: bool = False,
         **kwargs,
     ):
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)
 
+        if weights_only:
+            self.model.model = nncf.compress_weights(self.model.model)
+            self.model.save_pretrained(save_directory)
+            return
+
         calibration_dataloader = self._get_calibration_dataloader(
             calibration_dataset=calibration_dataset,
             batch_size=batch_size,
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index 369ad0f836..55758b6683 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -146,12 +146,12 @@ def preprocess_function(examples, tokenizer):
 class OVWeightCompressionTest(unittest.TestCase):
     # TODO : add models
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = (
-        (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70),
-        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45),
+        (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 35),
+        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 22),
     )
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
-    def test_automodel_weight_compression(self, model_cls, model_name, expected_int8):
+    def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
         task = model_cls.export_feature
 
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -166,7 +166,7 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_int8
 
             # TODO: uncomment once move to a newer version of NNCF which has some fixes
             _, num_int8 = get_num_quantized_nodes(model)
-            self.assertEqual(expected_int8, num_int8)
+            self.assertEqual(expected_pt_int8, num_int8)
 
             tokens = tokenizer("This is a sample input", return_tensors="pt")
             outputs = model(**tokens)
@@ -176,6 +176,27 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_int8
             loaded_config = OVConfig.from_pretrained(tmp_dir)
             self.assertIsNotNone(loaded_config)
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
+    def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
+        task = model_cls.export_feature
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            transformers_model = model_cls.from_pretrained(model_name, export=True)
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
+            quantizer.quantize(save_directory=tmp_dir, weights_only=True)
+            model = model_cls.from_pretrained(tmp_dir)
+
+            _, num_int8 = get_num_quantized_nodes(model)
+            self.assertEqual(expected_ov_int8, num_int8)
+
+            tokens = tokenizer("This is a sample input", return_tensors="pt")
+            outputs = model(**tokens)
+            self.assertTrue("logits" in outputs)
+
 
 class OVQuantizerQATest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (("hf-internal-testing/tiny-random-BertForQuestionAnswering",),)
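
Usage note (appended for review; not part of the patch): a minimal sketch of what this change enables, namely 8-bit weight-only compression of an already-exported OVModel without a calibration dataset. The checkpoint mirrors the tiny test model used in the new test above; the save directory name is illustrative.

    from optimum.intel import OVModelForCausalLM, OVQuantizer

    # Export a PyTorch checkpoint to an OpenVINO model.
    model = OVModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-gpt2", export=True
    )

    # With weights_only=True, the quantizer now runs nncf.compress_weights()
    # on the OpenVINO graph directly and skips calibration entirely.
    quantizer = OVQuantizer.from_pretrained(model, task=OVModelForCausalLM.export_feature)
    quantizer.quantize(save_directory="tiny-gpt2-int8", weights_only=True)

    # The compressed model loads and runs like any other OVModel.
    compressed_model = OVModelForCausalLM.from_pretrained("tiny-gpt2-int8")

Routing through nncf.compress_weights() rather than full post-training quantization stores MatMul and embedding weights in INT8 while activations stay in floating point, which is why no dataset is required.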