Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Changed the logic of default 8-bit weights compression #445

Merged
merged 10 commits into from
Oct 4, 2023
18 changes: 8 additions & 10 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,19 +235,17 @@ def main_export(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_constructor(model.config)
models_and_onnx_configs = {"model": (model, onnx_config)}
if model_kwargs is None:
model_kwargs = {}
model_kwargs = model_kwargs or {}
load_in_8bit = model_kwargs.get("load_in_8bit", None)
if load_in_8bit is None:
if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE:
model_kwargs["load_in_8bit"] = True
else:
model_kwargs["load_in_8bit"] = False
else:
if not is_nncf_available():
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
)
if not is_nncf_available():
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf."
"please install it with `pip install nncf`"
)
else:
model_kwargs["load_in_8bit"] = True

if not is_stable_diffusion:
needs_pad_token_id = (
Expand Down
6 changes: 3 additions & 3 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@
def _save_model(model, path: str, compress_to_fp16=False, load_in_8bit=False):
if load_in_8bit:
if not is_nncf_available():
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf."
"please install it with `pip install nncf`"
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
)

import nncf

model = nncf.compress_weights(model)
Expand Down
Loading