Merge remote-tracking branch 'origin/main' into stateful
slyalin committed Dec 12, 2023
2 parents 403adb5 + 5dac93d commit 6097bfd
Showing 6 changed files with 28 additions and 7 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -75,12 +75,13 @@ It is possible to export your model to the [OpenVINO](https://docs.openvino.ai/2
optimum-cli export openvino --model gpt2 ov_model
```

-If you add `--int8`, the weights will be quantized to INT8, the activations will be kept in floating point precision.
+If you add `--int8`, the model linear and embedding weights will be quantized to INT8, the activations will be kept in floating point precision.

```plain
optimum-cli export openvino --model gpt2 --int8 ov_model
```

+To apply quantization on both weights and activations, you can find more information in the [documentation](https://huggingface.co/docs/optimum/main/en/intel/optimization_ov).

#### Inference:

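For the weights-and-activations route mentioned above, the documented path goes through `OVQuantizer` with a calibration set rather than the CLI. A minimal sketch, with the model name and the `glue`/`sst2` calibration data purely as placeholders; the exact `quantize()` keyword arguments vary across optimum-intel versions:

```python
from functools import partial

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.intel import OVQuantizer

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # placeholder model
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)


def preprocess_fn(examples, tokenizer):
    return tokenizer(examples["sentence"], padding="max_length", truncation=True)


quantizer = OVQuantizer.from_pretrained(model)
# Static quantization needs sample activations, so build a small calibration set.
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
    num_samples=100,
    dataset_split="train",
)
# Quantizes both weights and activations, then saves the OpenVINO IR.
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="ov_model_int8")
```
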
2 changes: 1 addition & 1 deletion docs/source/inference.mdx
@@ -102,7 +102,7 @@ You can also apply INT8 quantization on your models weights when exporting your
optimum-cli export openvino --model gpt2 --int8 ov_model
```

-This will results in the exported model linear and embedding layers to be quanrtized to INT8, the activations will be kept in floating point precision.
+This will results in the exported model linear and embedding layers to be quantized to INT8, the activations will be kept in floating point precision.

This can also be done when loading your model by setting the `load_in_8bit` argument when calling the `from_pretrained()` method.

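A minimal sketch of that `load_in_8bit` path, with the model name as a placeholder:

```python
from optimum.intel import OVModelForCausalLM

# Export to OpenVINO IR on the fly and quantize the weights to INT8 at load time.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, load_in_8bit=True)
model.save_pretrained("ov_model_int8")
```
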
8 changes: 5 additions & 3 deletions optimum/exporters/openvino/convert.py
@@ -32,7 +32,7 @@
 from optimum.utils import is_diffusers_available
 from .stateful import patch_stateful

-from ...intel.utils.import_utils import is_nncf_available
+from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
 from .utils import (
     OV_XML_FILE_NAME,
     clear_class_registry,
@@ -311,8 +311,10 @@ def export_pytorch(
             # model.config.torchscript = True can not be used for patching, because it overrides return_dict to Flase
             if custom_patcher or dict_inputs:
                 patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
-                # DecoderModelPatcher does not override model forward
-                if isinstance(patcher, DecoderModelPatcher) or patcher.orig_forward_name != "forward":
+                # DecoderModelPatcher does not override model forward in optimum < 1.15
+                if (
+                    isinstance(patcher, DecoderModelPatcher) and is_optimum_version("<", "1.15.0")
+                ) or patcher.orig_forward_name != "forward":
                     patch_model_forward = True
                     patched_forward = model.forward
                 else:
14 changes: 13 additions & 1 deletion optimum/intel/generation/modeling.py
@@ -44,12 +44,24 @@
 logger = logging.getLogger(__name__)


+def get_float_type(model_dtype: torch.dtype):
+    if model_dtype == torch.bfloat16:
+        return "bf16"
+    elif model_dtype == torch.float16:
+        return "fp16"
+    else:
+        return "fp32"
+
+
 def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
     task = _TASK_ALIASES.get(task, task)
     signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
     onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
+    float_dtype = get_float_type(model.dtype)
     if "text-generation" in task:
-        onnx_config = onnx_config_class(model.config, use_past=use_cache, use_past_in_inputs=use_cache)
+        onnx_config = onnx_config_class(
+            model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
+        )
     else:
         onnx_config = onnx_config_class(model.config)

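A quick sanity check of the new helper, using only the mapping visible in the diff above:

```python
import torch

from optimum.intel.generation.modeling import get_float_type

assert get_float_type(torch.bfloat16) == "bf16"
assert get_float_type(torch.float16) == "fp16"
assert get_float_type(torch.float32) == "fp32"  # any other dtype falls through to "fp32"
```
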
3 changes: 2 additions & 1 deletion optimum/intel/openvino/quantization.py
@@ -260,7 +260,8 @@ def _quantize_ovcausallm(
         save_directory.mkdir(parents=True, exist_ok=True)

         if weights_only:
-            self.model.model = nncf.compress_weights(self.model.model)
+            model = nncf.compress_weights(self.model._original_model)
+            self.model.model = model
             self.model.save_pretrained(save_directory)
             return

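For context on the `weights_only` branch, `nncf.compress_weights` takes an OpenVINO (or PyTorch) model and returns one whose linear and embedding weights are compressed to INT8; a standalone sketch with illustrative paths:

```python
import nncf
from openvino.runtime import Core, serialize

core = Core()
ov_model = core.read_model("ov_model/openvino_model.xml")  # illustrative path

# Weights go to INT8; activations stay in floating point.
compressed = nncf.compress_weights(ov_model)
serialize(compressed, "ov_model_int8/openvino_model.xml")
```
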
5 changes: 5 additions & 0 deletions optimum/intel/utils/import_utils.py
@@ -29,6 +29,7 @@

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

+_optimum_version = importlib_metadata.version("optimum")

_transformers_available = importlib.util.find_spec("transformers") is not None
_transformers_version = "N/A"
@@ -175,6 +176,10 @@ def is_transformers_version(operation: str, version: str):
     return compare_versions(parse(_transformers_version), operation, version)


+def is_optimum_version(operation: str, version: str):
+    return compare_versions(parse(_optimum_version), operation, version)
+
+
 def is_neural_compressor_version(operation: str, version: str):
     """
     Compare the current Neural Compressor version to a given reference with an operation.
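
The new helper mirrors the existing `is_transformers_version`; a usage sketch matching the guard added in convert.py above:

```python
from optimum.intel.utils.import_utils import is_optimum_version

# Comparison operators come from STR_OPERATION_TO_FUNC, so <, <=, ==, !=, >=, > all work.
if is_optimum_version("<", "1.15.0"):
    # Older optimum releases need the forward-patching fallback in convert.py.
    ...
```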
