Skip to content

Commit

Permalink
Added 8-bit weight compression for decoders larger than 1B parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexKoff88 committed Oct 2, 2023
1 parent 72f369c commit 8801566
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
11 changes: 11 additions & 0 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

OV_XML_FILE_NAME = "openvino_model.xml"

_MAX_UNCOMPRESSED_DECODER_SIZE = 1e9

logger = logging.getLogger(__name__)

if is_torch_available():
Expand Down Expand Up @@ -232,6 +234,15 @@ def main_export(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_constructor(model.config)
models_and_onnx_configs = {"model": (model, onnx_config)}
load_in_8bit = model_kwargs.get("load_in_8bit", None)
if load_in_8bit is None:
if model_kwargs is None:
model_kwargs = {}

if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE:
model_kwargs["load_in_8bit"] = True
else:
model_kwargs["load_in_8bit"] = False

if not is_stable_diffusion:
needs_pad_token_id = (
Expand Down
20 changes: 16 additions & 4 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from transformers.utils import is_tf_available, is_torch_available

import nncf

from openvino.runtime import PartialShape, save_model
from openvino.runtime.utils.types import get_element_type
from openvino.tools.ovc import convert_model
Expand Down Expand Up @@ -52,6 +54,12 @@
from transformers.modeling_tf_utils import TFPreTrainedModel


def _save_model(model, path: str, compress_to_fp16=False, load_in_8bit=False):
    """Serialize an OpenVINO model to *path* as IR.

    When ``load_in_8bit`` is true, the model's weights are first compressed
    to 8 bit via NNCF; ``compress_to_fp16`` is forwarded unchanged to
    OpenVINO's ``save_model``.
    """
    to_serialize = nncf.compress_weights(model) if load_in_8bit else model
    save_model(to_serialize, path, compress_to_fp16)


def export(
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
config: OnnxConfig,
Expand Down Expand Up @@ -137,7 +145,7 @@ def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: Onn
onnx_path = Path(output).with_suffix(".onnx")
input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
ov_model = convert_model(str(onnx_path))
save_model(ov_model, output.parent / output, compress_to_fp16=False)
_save_model(ov_model, output.parent / output, compress_to_fp16=False, load_in_8bit=False)
return input_names, output_names, True


Expand Down Expand Up @@ -187,10 +195,11 @@ def export_pytorch_via_onnx(
)
torch.onnx.export = orig_torch_onnx_export
ov_model = convert_model(str(onnx_output))
save_model(
_save_model(
ov_model,
output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
compress_to_fp16=False,
load_in_8bit=model_kwargs.get("load_in_8bit", False)
)
return input_names, output_names, True

Expand Down Expand Up @@ -314,11 +323,14 @@ def ts_patched_forward(*args, **kwargs):
dims = inputs[input_name]

for dim in dims:
static_shape[dim] = -1
static_shape[dim] = -1
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()
save_model(ov_model, output, compress_to_fp16=False)
_save_model(ov_model,
output,
compress_to_fp16=False,
load_in_8bit=model_kwargs.get("load_in_8bit", False))
clear_class_registry()
del model
gc.collect()
Expand Down
3 changes: 2 additions & 1 deletion optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@
"pegasus",
}


@add_start_docstrings(
"""
Base OVBaseDecoderModel class.
Expand Down Expand Up @@ -226,6 +225,7 @@ def _from_transformers(
if use_cache:
task = task + "-with-past"


main_export(
model_name_or_path=model_id,
output=save_dir_path,
Expand All @@ -237,6 +237,7 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
model_kwargs=kwargs,
)

config.is_decoder = True
Expand Down

0 comments on commit 8801566

Please sign in to comment.