Added 8 bit weights compression by default for decoders larger than 1B (#444)

* Added 8bit compression for decoders larger than 1B

* Style

* Fixed issue

* Fixed one more issue

* Added warning for nncf absence in case of default compression to 8 bits

* Fixed an issue. Added warning message when NNCF is not available
AlexKoff88 authored Oct 3, 2023
1 parent 72f369c commit d207110
Showing 3 changed files with 37 additions and 3 deletions.
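In user-facing terms, this changes the default export path for large decoders. A minimal sketch of the new behavior (the model id is illustrative; it assumes NNCF is installed and that extra `from_pretrained` kwargs reach `main_export` through the `modeling_decoder.py` change below):

```python
from optimum.intel import OVModelForCausalLM

# Default path: a decoder with >= 1e9 parameters is exported with int8
# weight compression (a warning is logged if nncf is unavailable).
model = OVModelForCausalLM.from_pretrained("databricks/dolly-v2-3b", export=True)

# Explicit opt-out: keep full-precision weights regardless of model size.
model_fp32 = OVModelForCausalLM.from_pretrained(
    "databricks/dolly-v2-3b", export=True, load_in_8bit=False
)
```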
16 changes: 16 additions & 0 deletions optimum/exporters/openvino/__main__.py
@@ -27,12 +27,15 @@
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.save_utils import maybe_save_preprocessors
 
+from ...intel.utils.import_utils import is_nncf_available
 from ...intel.utils.modeling_utils import patch_decoder_attention_mask
 from .convert import export_models
 
 
 OV_XML_FILE_NAME = "openvino_model.xml"
 
+_MAX_UNCOMPRESSED_DECODER_SIZE = 1e9
+
 logger = logging.getLogger(__name__)
 
 if is_torch_available():
@@ -232,6 +235,19 @@ def main_export(
     onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
     onnx_config = onnx_config_constructor(model.config)
     models_and_onnx_configs = {"model": (model, onnx_config)}
+    if model_kwargs is None:
+        model_kwargs = {}
+    load_in_8bit = model_kwargs.get("load_in_8bit", None)
+    if load_in_8bit is None:
+        if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE:
+            model_kwargs["load_in_8bit"] = True
+        else:
+            model_kwargs["load_in_8bit"] = False
+    else:
+        if load_in_8bit and not is_nncf_available():
+            raise ImportError(
+                "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
+            )
 
     if not is_stable_diffusion:
         needs_pad_token_id = (
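The same default can also be exercised through the export entry point touched above. A hedged sketch, with the output directory and task string illustrative; an explicit `load_in_8bit` in `model_kwargs` skips the size check and, when NNCF is missing, triggers the ImportError added in this hunk:

```python
from optimum.exporters.openvino import main_export

# Force int8 weight compression for a model below the 1B threshold.
main_export(
    model_name_or_path="gpt2",
    output="ov_gpt2",
    task="text-generation-with-past",
    model_kwargs={"load_in_8bit": True},
)
```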
23 changes: 20 additions & 3 deletions optimum/exporters/openvino/convert.py
@@ -30,6 +30,7 @@
 from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
 from optimum.utils import is_diffusers_available
 
+from ...intel.utils.import_utils import is_nncf_available
 from .utils import (
     OV_XML_FILE_NAME,
     clear_class_registry,
@@ -52,6 +53,19 @@
 from transformers.modeling_tf_utils import TFPreTrainedModel
 
 
+def _save_model(model, path: str, compress_to_fp16=False, load_in_8bit=False):
+    if load_in_8bit:
+        if not is_nncf_available():
+            logger.warning(
+                "The model will be converted with no weight quantization: quantization of the weights "
+                "to int8 requires nncf, please install it with `pip install nncf`."
+            )
+        else:
+            import nncf
+            model = nncf.compress_weights(model)
+    save_model(model, path, compress_to_fp16)
+
+
 def export(
     model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
     config: OnnxConfig,
@@ -137,7 +151,7 @@ def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig,
     onnx_path = Path(output).with_suffix(".onnx")
     input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
     ov_model = convert_model(str(onnx_path))
-    save_model(ov_model, output.parent / output, compress_to_fp16=False)
+    _save_model(ov_model, output.parent / output, compress_to_fp16=False, load_in_8bit=False)
     return input_names, output_names, True


@@ -187,10 +201,12 @@ def export_pytorch_via_onnx(
     )
     torch.onnx.export = orig_torch_onnx_export
     ov_model = convert_model(str(onnx_output))
-    save_model(
+    load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False)
+    _save_model(
         ov_model,
         output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
         compress_to_fp16=False,
+        load_in_8bit=load_in_8bit,
     )
     return input_names, output_names, True

@@ -318,7 +334,8 @@ def ts_patched_forward(*args, **kwargs):
         inp_tensor.get_node().set_partial_shape(static_shape)
         inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
     ov_model.validate_nodes_and_infer_types()
-    save_model(ov_model, output, compress_to_fp16=False)
+    load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False)
+    _save_model(ov_model, output, compress_to_fp16=False, load_in_8bit=load_in_8bit)
     clear_class_registry()
     del model
     gc.collect()
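For reference, a standalone sketch of the compression step that `_save_model` performs when `load_in_8bit=True` (file paths are placeholders; assumes NNCF and OpenVINO are installed):

```python
import nncf
from openvino.runtime import Core, serialize

core = Core()
ov_model = core.read_model("openvino_model.xml")

# Weight-only int8 compression, the same call _save_model makes.
ov_model = nncf.compress_weights(ov_model)

serialize(ov_model, "openvino_model_int8.xml")
```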
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling_decoder.py
@@ -237,6 +237,7 @@ def _from_transformers(
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
+            model_kwargs=kwargs,
         )
 
         config.is_decoder = True
