Skip to content

Commit

Permalink
Added 8-bit weight compression for OVModel
Browse files Browse the repository at this point in the history
  • Loading branch information
l-bat committed Sep 20, 2023
1 parent 4b8ed24 commit 6cd4a85
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
18 changes: 14 additions & 4 deletions optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,6 @@ def quantize(
raise ValueError("`save_directory` needs to be specified")

if weights_only:
if isinstance(self.model, OVBaseModel):
raise ValueError(
"`weights_only` currently not supported for `OVModels`, only available for torch.nn.Module."
)
if calibration_dataset is not None:
logger.warning(
"`calibration_dataset` was provided but will not be used as `weights_only` is set to `True`."
Expand All @@ -189,6 +185,7 @@ def quantize(
batch_size,
data_collator,
remove_unused_columns,
weights_only,
**kwargs,
)
elif isinstance(self.model, OVBaseModel):
Expand All @@ -198,6 +195,7 @@ def quantize(
batch_size,
data_collator,
remove_unused_columns,
weights_only,
**kwargs,
)
elif isinstance(self.model, torch.nn.Module):
Expand All @@ -221,11 +219,17 @@ def _quantize_ovbasemodel(
batch_size: int = 1,
data_collator: Optional[DataCollator] = None,
remove_unused_columns: bool = True,
weights_only: bool = False,
**kwargs,
):
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)

if weights_only:
self.model.model = nncf.compress_weights(self.model.model)
self.model.save_pretrained(save_directory)
return

calibration_dataloader = self._get_calibration_dataloader(
calibration_dataset=calibration_dataset,
batch_size=batch_size,
Expand All @@ -251,11 +255,17 @@ def _quantize_ovcausallm(
batch_size: int = 1,
data_collator: Optional[DataCollator] = None,
remove_unused_columns: bool = True,
weights_only: bool = False,
**kwargs,
):
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)

if weights_only:
self.model.model = nncf.compress_weights(self.model.model)
self.model.save_pretrained(save_directory)
return

calibration_dataloader = self._get_calibration_dataloader(
calibration_dataset=calibration_dataset,
batch_size=batch_size,
Expand Down
29 changes: 25 additions & 4 deletions tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ def preprocess_function(examples, tokenizer):
class OVWeightCompressionTest(unittest.TestCase):
# TODO : add models
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = (
(OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70),
(OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45),
(OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 35),
(OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 23),
)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
def test_automodel_weight_compression(self, model_cls, model_name, expected_int8):
def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
task = model_cls.export_feature

with tempfile.TemporaryDirectory() as tmp_dir:
Expand All @@ -166,7 +166,7 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_int8

# TODO: uncomment once move to a newer version of NNCF which has some fixes
_, num_int8 = get_num_quantized_nodes(model)
self.assertEqual(expected_int8, num_int8)
self.assertEqual(expected_pt_int8, num_int8)

tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model(**tokens)
Expand All @@ -176,6 +176,27 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_int8
loaded_config = OVConfig.from_pretrained(tmp_dir)
self.assertIsNotNone(loaded_config)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
task = model_cls.export_feature

with tempfile.TemporaryDirectory() as tmp_dir:
transformers_model = model_cls.from_pretrained(model_name, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
quantizer.quantize(save_directory=tmp_dir, weights_only=True)
model = model_cls.from_pretrained(tmp_dir)

_, num_int8 = get_num_quantized_nodes(model)
self.assertEqual(expected_ov_int8, num_int8)

tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model(**tokens)
self.assertTrue("logits" in outputs)


class OVQuantizerQATest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (("hf-internal-testing/tiny-random-BertForQuestionAnswering",),)
Expand Down

0 comments on commit 6cd4a85

Please sign in to comment.