
Commit

Raise an error when OVQuantizer is invoked on an already compressed model
l-bat committed Jan 20, 2025
1 parent 2590794 commit 0420051
Showing 2 changed files with 59 additions and 2 deletions.
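In practice, the new guard fires when a model exported with load_in_8bit=True (NNCF 8-bit weight compression) is handed to OVQuantizer for a second round of compression. A minimal sketch of the behavior this commit introduces, mirroring the added test; the "gpt2" checkpoint is illustrative:

from optimum.intel import OVConfig, OVModelForCausalLM, OVQuantizer, OVWeightQuantizationConfig

# Export with 8-bit weight compression already applied ("gpt2" is just an example id).
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, load_in_8bit=True)
quantizer = OVQuantizer(model)
try:
    # A second compression pass now fails fast instead of silently re-compressing.
    quantizer.quantize(ov_config=OVConfig(quantization_config=OVWeightQuantizationConfig(bits=4, sym=True)))
except RuntimeError as err:
    print(err)  # "Couldn't apply optimization to the model because it was already compressed ..."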
44 changes: 42 additions & 2 deletions optimum/intel/openvino/quantization.py
@@ -315,6 +315,31 @@ def quantize(
        else:
            raise TypeError(f"Unsupported model type: {type(self.model)}")

    def _check_model_state(self, sub_model_names: List[str] = None):
        message_template = (
            "Couldn't apply optimization to the model because it was already compressed with config: {}. "
            "To avoid this issue, set load_in_8bit=False in the from_pretrained method when using the optimum-intel API, "
            "or explicitly specify the desired weight format using --weight_format fp16/fp32 for CLI."
        )

        def check_rt_info(ov_model):
            rt_info = ov_model.get_rt_info()
            if "nncf" in rt_info:
                model_weight_compression_config = rt_info["nncf"].get("weight_compression", None)
                model_quantization_config = rt_info["nncf"].get("quantization", None)
                if model_weight_compression_config is not None:
                    raise RuntimeError(message_template.format(model_weight_compression_config))
                elif model_quantization_config is not None:
                    raise RuntimeError(message_template.format(model_quantization_config))

        if sub_model_names is None:
            check_rt_info(self.model.model)
        else:
            for name in sub_model_names:
                if hasattr(self.model, name):
                    ov_model = getattr(self.model, name).model
                    check_rt_info(ov_model)

    def _quantize_ovbasemodel(
        self,
        ov_config: OVConfig,
@@ -325,7 +350,7 @@ def _quantize_ovbasemodel(
        remove_unused_columns: bool = True,
        **kwargs,
    ):
        from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper
        from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper, OVModelForSeq2SeqLM
        from optimum.intel.openvino.modeling_visual_language import OVModelForVisualCausalLM

        if is_diffusers_available():
@@ -404,6 +429,7 @@ def _quantize_ovbasemodel(
"text_encoder_2",
"text_encoder_3",
]
self._check_model_state(sub_model_names)
sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
for sub_model in sub_models:
_weight_only_quantization(sub_model.model, quantization_config_copy, **kwargs)
@@ -421,6 +447,7 @@ def _quantize_ovbasemodel(
                    self.model.clear_requests()
                else:
                    # The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc.
                    self._check_model_state()
                    self.model.model = _hybrid_quantization(
                        self.model.model, quantization_config, calibration_dataset, **kwargs
                    )
@@ -436,19 +463,31 @@ def _quantize_ovbasemodel(
"transformer",
"text_encoder_3",
]
self._check_model_state(sub_model_names)
sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
for sub_model in sub_models:
_weight_only_quantization(sub_model.model, quantization_config, **kwargs)
self.model.clear_requests()
elif isinstance(self.model, OVModelForVisualCausalLM):
language_model = self.model.language_model
_weight_only_quantization(language_model.model, quantization_config, calibration_dataset, **kwargs)
sub_model_names = ["vision_embeddings", "text_embeddings"] + self.model.additional_parts
self._check_model_state(sub_model_names + ["language_model"])
_weight_only_quantization(language_model.model, quantization_config, calibration_dataset, **kwargs)
sub_models = [getattr(self.model, f"{name}_model") for name in sub_model_names]
for sub_model in sub_models:
_weight_only_quantization(sub_model, OVWeightQuantizationConfig(bits=8, sym=True), **kwargs)
self.model.clear_requests()
elif isinstance(self.model, OVModelForSeq2SeqLM):
sub_model_names = ["encoder", "decoder"]
if self.model.decoder_with_past is not None:
sub_model_names.append("decoder_with_past")
self._check_model_state(sub_model_names)
sub_models = [getattr(self.model, name) for name in sub_model_names]
for sub_model in sub_models:
_weight_only_quantization(sub_model, quantization_config, **kwargs)
self.model.clear_requests()
else:
self._check_model_state()
_weight_only_quantization(self.model.model, quantization_config, calibration_dataset, **kwargs)
self.model.request = None
else:
@@ -460,6 +499,7 @@ def _quantize_ovbasemodel(

        # Quantize model(s)
        if isinstance(self.model, _OVModelForWhisper):
            self._check_model_state(["encoder_model", "decoder_model", "decoder_with_past_model"])
            self._quantize_whisper_model(quantization_config, calibration_dataset, **kwargs)
        else:
            quantized_model = _full_quantization(
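For context, _check_model_state works because NNCF stamps the applied configuration into the OpenVINO model's runtime info under the "nncf" key, exactly the lookup the added method performs. A standalone sketch of the same check, assuming an already-compressed model on disk ("model.xml" is a placeholder path):

import openvino as ov

ov_model = ov.Core().read_model("model.xml")
rt_info = ov_model.get_rt_info()
if "nncf" in rt_info:
    # Either entry is present only if NNCF previously compressed or quantized the model.
    print(rt_info["nncf"].get("weight_compression", None))
    print(rt_info["nncf"].get("quantization", None))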
17 changes: 17 additions & 0 deletions tests/openvino/test_quantization.py
@@ -698,6 +698,23 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust_remote_code):
            _, num_weight_nodes = get_num_quantized_nodes(model)
            self.assertEqual(expected_ov_int8[i], num_weight_nodes["int8"])

    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
    def test_raise_error_WC_over_WC(self, model_cls, model_type, trust_remote_code):
        model = model_cls.from_pretrained(
            MODEL_NAMES[model_type],
            export=True,
            load_in_8bit=True,
            trust_remote_code=trust_remote_code,
        )
        quantization_config = OVWeightQuantizationConfig(bits=4, sym=True)
        quantizer = OVQuantizer(model)
        if isinstance(model, OVModelOpenCLIPForZeroShotImageClassification):
            with pytest.raises(TypeError):
                quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))
        else:
            with pytest.raises(RuntimeError):
                quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))

    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION)
    def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8):
        model_id = MODEL_NAMES[model_type]
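The new error message also spells out the escape hatch: reload without the automatic 8-bit compression, then quantize once with the intended config. A sketch of that path (the "gpt2" model id is illustrative):

from optimum.intel import OVConfig, OVModelForCausalLM, OVQuantizer, OVWeightQuantizationConfig

# Load at full precision so no NNCF config is stamped into rt_info yet.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, load_in_8bit=False)
quantizer = OVQuantizer(model)
quantizer.quantize(ov_config=OVConfig(quantization_config=OVWeightQuantizationConfig(bits=4, sym=True)))

From the CLI, the equivalent is to pin the weight format at export time, e.g. optimum-cli export openvino --model gpt2 --weight_format fp16 ov_model_dir.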
