Added 4 bit compression into quantizer
AlexKoff88 committed Nov 2, 2023
Parent: 1ef49b1 · Commit: ceb73e4
Showing 4 changed files with 69 additions and 12 deletions.
2 changes: 1 addition & 1 deletion optimum/intel/openvino/configuration.py
@@ -107,7 +107,7 @@ def _enable_standard_onnx_export_option(self):
# save_onnx_model is defaulted to false so that the final model output is
# in OpenVINO IR to realize performance benefit in OpenVINO runtime.
# True value of save_onnx_model will save a model in onnx format.
- if isinstance(self.compression, dict) and self.compression["algorithm"] == "quantization":
+ if isinstance(self.compression, dict) and "algorithm" in self.compression and self.compression["algorithm"] == "quantization":
self.compression["export_to_onnx_standard_ops"] = self.save_onnx_model
elif isinstance(self.compression, list):
for i, algo_config in enumerate(self.compression):
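Note on the guard above: the weight-compression configs introduced by this commit (see the new test below) describe compression with "type" and "ratio" keys and carry no "algorithm" entry, so the old lookup would raise a KeyError. A minimal illustration, using the illustrative config from the new 4-bit test:

compression = {"type": "i4_sym_g128", "ratio": 0.8}

# Old check: raises KeyError, because this dict has no "algorithm" key.
# compression["algorithm"] == "quantization"

# New check: short-circuits safely and evaluates to False for such configs.
is_quantization = (
    isinstance(compression, dict)
    and "algorithm" in compression
    and compression["algorithm"] == "quantization"
)
print(is_quantization)  # False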
32 changes: 31 additions & 1 deletion optimum/intel/openvino/quantization.py
@@ -50,6 +50,25 @@
OV_XML_FILE_NAME,
)

+COMPRESSION_OPTIONS = {
+    "i8": { "mode": nncf.CompressWeightsMode.INT8 },
+    "i4_sym_g128": {
+        "mode": nncf.CompressWeightsMode.INT4_SYM,
+        "group_size": 128
+    },
+    "i4_asym_g128": {
+        "mode": nncf.CompressWeightsMode.INT4_ASYM,
+        "group_size": 128
+    },
+    "i4_sym_g64": {
+        "mode": nncf.CompressWeightsMode.INT4_SYM,
+        "group_size": 64
+    },
+    "i4_asym_g64": {
+        "mode": nncf.CompressWeightsMode.INT4_ASYM,
+        "group_size": 64
+    },
+}

register_module(ignored_algorithms=[])(Conv1D)

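Each entry in COMPRESSION_OPTIONS expands to keyword arguments for nncf.compress_weights. A sketch of the mapping for the "i4_sym_g128" preset combined with a ratio of 0.8 (the model path is illustrative; the ratio is attached by _get_compression_options further down):

import nncf
import openvino.runtime as ov

ov_model = ov.Core().read_model("model.xml")  # illustrative path

# Symmetric 4-bit weights with a quantization group size of 128; roughly 80% of
# the weights are compressed to 4 bit, the remainder stays in 8 bit.
compressed_model = nncf.compress_weights(
    ov_model,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    group_size=128,
    ratio=0.8,
)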
@@ -186,6 +205,7 @@ def quantize(
data_collator,
remove_unused_columns,
weights_only,
+ quantization_config,
**kwargs,
)
elif isinstance(self.model, OVBaseModel):
@@ -212,6 +232,14 @@
else:
raise TypeError(f"Unsupported model type: {type(self.model)}")

+    def _get_compression_options(self, config: OVConfig):
+        options = {}
+        if config is not None and "type" in config.compression:
+            options = COMPRESSION_OPTIONS[config.compression["type"]]
+            if "ratio" in config.compression:
+                options["ratio"] = config.compression["ratio"]
+        return options

def _quantize_ovbasemodel(
self,
calibration_dataset: Dataset,
@@ -256,13 +284,15 @@ def _quantize_ovcausallm(
data_collator: Optional[DataCollator] = None,
remove_unused_columns: bool = True,
weights_only: bool = False,
+ quantization_config: OVConfig = None,
**kwargs,
):
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)

if weights_only:
- self.model.model = nncf.compress_weights(self.model.model)
+ options = self._get_compression_options(quantization_config)
+ self.model.model = nncf.compress_weights(self.model.model, **options)
self.model.save_pretrained(save_directory)
return

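Taken together, the changes above let a caller request 4-bit weight compression through OVQuantizer. A minimal usage sketch mirroring the new test below (the model id and output directory are illustrative):

from optimum.intel import OVConfig, OVModelForCausalLM, OVQuantizer

model = OVModelForCausalLM.from_pretrained("facebook/opt-125m", export=True)  # illustrative checkpoint

quantizer = OVQuantizer.from_pretrained(model)
quantizer.quantize(
    save_directory="opt125m-4bit",
    weights_only=True,
    quantization_config=OVConfig(compression={"type": "i4_sym_g128", "ratio": 0.8}),
)

# The compressed model reloads like any other exported OVModel.
model_4bit = OVModelForCausalLM.from_pretrained("opt125m-4bit")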
43 changes: 35 additions & 8 deletions tests/openvino/test_quantization.py
@@ -93,7 +93,7 @@ def preprocess_function(examples, tokenizer):
model = model_cls.from_pretrained(tmp_dir, file_name=file_name)

# TODO: uncomment once move to a newer version of NNCF which has some fixes (addmm, baddmm)
- # num_fake_quantize, num_int8 = get_num_quantized_nodes(model)
+ # num_fake_quantize, num_int8, _ = get_num_quantized_nodes(model)
# self.assertEqual(expected_fake_quantize, num_fake_quantize)
# self.assertEqual(expected_int8, num_int8)

@@ -143,9 +143,13 @@ def preprocess_function(examples, tokenizer):

class OVWeightCompressionTest(unittest.TestCase):
# TODO : add models
- SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = (
- (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 35),
- (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 22),
+ SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = (
+ (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 70),
+ (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 44),
)

+ SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = (
+ (OVModelForCausalLM, "opt125m", 82, 323),
+ )

SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = (
@@ -162,7 +166,7 @@ class OVWeightCompressionTest(unittest.TestCase):
(OVStableDiffusionXLPipeline, "stable-diffusion-xl"),
)

- @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
+ @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
task = model_cls.export_feature

@@ -187,8 +191,8 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i
loaded_config = OVConfig.from_pretrained(tmp_dir)
self.assertIsNotNone(loaded_config)

- @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
- def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
+ @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
+ def test_ovmodel_8bit_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
task = model_cls.export_feature

with tempfile.TemporaryDirectory() as tmp_dir:
@@ -207,6 +211,29 @@ def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int
tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model(**tokens)
self.assertTrue("logits" in outputs)

+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS)
+    def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_int8, expected_int4):
+        task = model_cls.export_feature
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_id = MODEL_NAMES[model_name]
+            transformers_model = model_cls.from_pretrained(model_id, export=True)
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
+            quantizer.quantize(save_directory=tmp_dir, weights_only=True, quantization_config=OVConfig(compression={"type": "i4_sym_g128", "ratio": 0.8}))
+            model = model_cls.from_pretrained(tmp_dir)
+
+            _, num_int8, num_int4 = get_num_quantized_nodes(model)
+            self.assertEqual(expected_int8, num_int8)
+            self.assertEqual(expected_int4, num_int4)
+
+            tokens = tokenizer("This is a sample input", return_tensors="pt")
+            outputs = model(**tokens)
+            self.assertTrue("logits" in outputs)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
@@ -349,7 +376,7 @@ def compute_metrics(p):
trainer.save_model()

model = OVModelForSequenceClassification.from_pretrained(tmp_dir)
- num_fake_quantize, num_int8 = get_num_quantized_nodes(model)
+ num_fake_quantize, num_int8, _ = get_num_quantized_nodes(model)
self.assertEqual(expected_fake_quantize, num_fake_quantize)
self.assertEqual(expected_int8, num_int8)

4 changes: 2 additions & 2 deletions tests/openvino/utils_tests.py
@@ -109,12 +109,12 @@
"t5": (64, 104, 84),
"stable-diffusion": (148, 8, 8, 64),
"stable-diffusion-xl": (296, 8, 8, 66),
"stable-diffusion-xl-refiner": (296, 4, 8, 66),
"stable-diffusion-xl-refiner": (296, 8, 8, 66),
}


_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {
"opt125m": (128, 64)
"opt125m": (82, 323)
}


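The tests now unpack three values from get_num_quantized_nodes (FakeQuantize, INT8 and INT4 node counts). The updated helper itself is not part of this diff; a plausible sketch of the counting it implies, assuming the standard openvino.runtime node APIs, could look like:

def get_num_quantized_nodes(ov_model):
    # ov_model is an optimum-intel OVModel wrapper; the OpenVINO graph lives in ov_model.model.
    # The actual helper in utils_tests.py may differ in detail.
    num_fake_quantize = 0
    num_int8 = 0
    num_int4 = 0
    for op in ov_model.model.get_ops():
        if op.get_type_name() == "FakeQuantize":
            num_fake_quantize += 1
        for i in range(op.get_output_size()):
            type_name = op.get_output_element_type(i).get_type_name()
            if type_name in ("i8", "u8"):
                num_int8 += 1
            elif type_name in ("i4", "u4"):
                num_int4 += 1
    return num_fake_quantize, num_int8, num_int4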
