Commit

Merged with master
AlexKoff88 committed Nov 9, 2023
2 parents 6d22f96 + f248835 commit 975b277
Showing 27 changed files with 1,013 additions and 439 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test_inc.yml
@@ -30,7 +30,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
-pip install .[neural-compressor,ipex,diffusers,tests]
+pip install .[neural-compressor,diffusers,tests]
+pip install intel-extension-for-pytorch
- name: Test with Pytest
run: |
pytest tests/neural_compressor/
43 changes: 34 additions & 9 deletions README.md
@@ -67,26 +67,51 @@ For more details on the supported compression techniques, please refer to the [d

Below are the examples of how to use OpenVINO and its [NNCF](https://docs.openvino.ai/latest/tmo_introduction.html) framework to accelerate inference.

#### Export:

It is possible to export your model to the [OpenVINO](https://docs.openvino.ai/2023.1/openvino_ir.html) IR format with the CLI:

```plain
optimum-cli export openvino --model gpt2 ov_model
```

If you add `--int8`, the weights will be quantized to INT8 while the activations are kept in floating-point precision.

```plain
optimum-cli export openvino --model gpt2 --int8 ov_model
```


#### Inference:

To load a model and run inference with OpenVINO Runtime, you can just replace your `AutoModelForXxx` class with the corresponding `OVModelForXxx` class.
If you want to load a PyTorch checkpoint, set `export=True` to convert your model to the OpenVINO IR.


```diff
- from transformers import AutoModelForSequenceClassification
+ from optimum.intel import OVModelForSequenceClassification
- from transformers import AutoModelForSeq2SeqLM
+ from optimum.intel import OVModelForSeq2SeqLM
from transformers import AutoTokenizer, pipeline

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
+ model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)
model_id = "echarlaix/t5-small-openvino"
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+ model = OVModelForSeq2SeqLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.save_pretrained("./distilbert")
pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer)
results = pipe("He never went out without a book under his arm, and he often came back with two.")

classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
results = classifier("He's a dreadful magician.")
[{'translation_text': "Il n'est jamais sorti sans un livre sous son bras, et il est souvent revenu avec deux."}]
```

If you want to load a PyTorch checkpoint, set `export=True` to convert your model to the OpenVINO IR.

```python
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained("gpt2", export=True)
model.save_pretrained("./ov_model")
```


#### Post-training static quantization:

Post-training static quantization introduces an additional calibration step where data is fed through the network in order to compute the quantization parameters of the activations. Here is an example of how to apply static quantization to a fine-tuned DistilBERT.
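
A minimal sketch of that calibration-plus-quantization flow with `OVQuantizer` (the dataset, sample count, and save directory below are illustrative assumptions):

```python
from functools import partial

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.intel import OVQuantizer

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)


def preprocess_function(examples, tokenizer):
    return tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True)


quantizer = OVQuantizer.from_pretrained(model)
# A small calibration set is enough to estimate the activation quantization parameters.
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_function, tokenizer=tokenizer),
    num_samples=100,
    dataset_split="train",
)
# Export the quantized model to the OpenVINO IR in the chosen directory.
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="ov_quantized_model")
tokenizer.save_pretrained("ov_quantized_model")
```
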
196 changes: 127 additions & 69 deletions docs/source/inference.mdx

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions docs/source/optimization_ov.mdx
@@ -62,7 +62,6 @@ tokenizer.save_pretrained(save_dir)

The `quantize()` method applies post-training static quantization and exports the resulting quantized model to the OpenVINO Intermediate Representation (IR). The resulting graph is represented with two files: an XML file describing the network topology and a binary file describing the weights. The resulting model can be run on any target Intel device.
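
Assuming the quantized model was saved to `save_dir`, it can then be reloaded with the matching `OVModelForXxx` class and used like any other model, for instance:

```python
from transformers import AutoTokenizer, pipeline
from optimum.intel import OVModelForSequenceClassification

save_dir = "ov_quantized_model"  # directory passed to quantizer.quantize(); the name is illustrative
model = OVModelForSequenceClassification.from_pretrained(save_dir)
tokenizer = AutoTokenizer.from_pretrained(save_dir)

cls_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
print(cls_pipe("He's a dreadful magician."))
```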


## Training-time optimization

Apart from optimizing a model after training like post-training quantization above, `optimum.openvino` also provides optimization methods during training, namely Quantization-Aware Training (QAT) and Joint Pruning, Quantization and Distillation (JPQD).
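
For example, QAT typically only requires swapping the `transformers.Trainer` for `OVTrainer` and passing an `OVConfig`; a minimal sketch, in which the dataset, hyperparameters and the `task` argument are illustrative assumptions:

```python
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    default_data_collator,
)
from optimum.intel import OVConfig, OVTrainer

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# A small SST-2 subset keeps the sketch quick to run; real training would use the full split.
dataset = load_dataset("glue", "sst2", split="train[:1%]")
dataset = dataset.map(
    lambda examples: tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True),
    batched=True,
)

ov_config = OVConfig()  # default OVConfig; assumed here to enable 8-bit quantization-aware training

trainer = OVTrainer(
    model=model,
    args=TrainingArguments("qat_model", num_train_epochs=1),
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    ov_config=ov_config,
    task="text-classification",
)
trainer.train()
trainer.save_model()  # save the trained (quantized) model
```
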
@@ -221,4 +220,4 @@ text = "He's a dreadful magician."
outputs = cls_pipe(text)

[{'label': 'NEGATIVE', 'score': 0.9840195178985596}]
```
```
29 changes: 11 additions & 18 deletions optimum/exporters/openvino/__main__.py
@@ -27,7 +27,6 @@
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors

from ...intel.utils.import_utils import is_nncf_available
-from ...intel.utils.modeling_utils import patch_decoder_attention_mask
from .convert import export_models


@@ -260,24 +259,18 @@ class StoreAttr(object):
preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)
-if not task.startswith("text-generation"):
-    onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
-        model=model,
-        task=task,
-        monolith=False,
-        custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
-        custom_architecture=custom_architecture,
-        fn_get_submodels=fn_get_submodels,
-        preprocessors=preprocessors,
-        _variant="default",
-    )
-else:
-    # TODO : ModelPatcher will be added in next optimum release
-    model = patch_decoder_attention_mask(model)

-    onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-    onnx_config = onnx_config_constructor(model.config)
-    models_and_onnx_configs = {"model": (model, onnx_config)}
+onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
+    model=model,
+    task=task,
+    monolith=False,
+    custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
+    custom_architecture=custom_architecture,
+    fn_get_submodels=fn_get_submodels,
+    preprocessors=preprocessors,
+    _variant="default",
+    legacy=False,
+)

if compression_option is None:
num_parameters = model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters()
18 changes: 16 additions & 2 deletions optimum/exporters/openvino/convert.py
@@ -28,6 +28,7 @@
from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx
from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
+from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
from optimum.utils import is_diffusers_available

from ...intel.utils.import_utils import is_nncf_available
@@ -336,14 +337,21 @@ def export_pytorch(
dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)
input_info = get_input_shapes(dummy_inputs, inputs)
custom_patcher = type(config).patch_model_for_export != OnnxConfig.patch_model_for_export
+patch_model_forward = False
+orig_forward = model.forward
try:
# TorchScript is used behind the OpenVINO conversion. Optimum supports patching only for return_dict=True models,
# while TorchScript does not support dictionaries with values of mixed types (e.g. Tensor and None) in model inputs/outputs.
# To handle this, an additional wrapper is applied to the patcher forward.
# model.config.torchscript = True can not be used for patching, because it overrides return_dict to False.
if custom_patcher or dict_inputs:
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
-patched_forward = patcher.patched_forward
+# DecoderModelPatcher does not override model forward
+if isinstance(patcher, DecoderModelPatcher) or patcher.orig_forward_name != "forward":
+    patch_model_forward = True
+    patched_forward = model.forward
+else:
+    patched_forward = patcher.patched_forward

@functools.wraps(patched_forward)
def ts_patched_forward(*args, **kwargs):
@@ -356,14 +364,20 @@ def ts_patched_forward(*args, **kwargs):
outputs = patched_forward(*args, **kwargs)
return tuple(outputs.values())

-patcher.patched_forward = ts_patched_forward
+if not patch_model_forward:
+    patcher.patched_forward = ts_patched_forward
+else:
+    model.forward = ts_patched_forward
with patcher:
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
else:
model.config.torchscript = True
model.config.return_dict = False
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
except Exception as ex:
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
+if patch_model_forward:
+    model.forward = orig_forward
return export_pytorch_via_onnx(
model,
config,
4 changes: 4 additions & 0 deletions optimum/intel/__init__.py
@@ -62,6 +62,7 @@
"OVStableDiffusionInpaintPipeline",
"OVStableDiffusionXLPipeline",
"OVStableDiffusionXLImg2ImgPipeline",
"OVLatentConsistencyModelPipeline",
]
else:
_import_structure["openvino"].extend(
@@ -71,6 +72,7 @@
"OVStableDiffusionInpaintPipeline",
"OVStableDiffusionXLPipeline",
"OVStableDiffusionXLImg2ImgPipeline",
"OVLatentConsistencyModelPipeline",
]
)

@@ -158,6 +160,7 @@
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_openvino_and_diffusers_objects import (
+OVLatentConsistencyModelPipeline,
OVStableDiffusionImg2ImgPipeline,
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
@@ -166,6 +169,7 @@
)
else:
from .openvino import (
+OVLatentConsistencyModelPipeline,
OVStableDiffusionImg2ImgPipeline,
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
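
The newly exported `OVLatentConsistencyModelPipeline` can be used like the other OpenVINO diffusion pipelines; a minimal usage sketch, in which the checkpoint and generation arguments are illustrative assumptions:

```python
from optimum.intel import OVLatentConsistencyModelPipeline

# Any latent-consistency checkpoint convertible to OpenVINO should work; this one is illustrative.
pipeline = OVLatentConsistencyModelPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", export=True)
image = pipeline(
    "A sailing ship in a storm, detailed oil painting",
    num_inference_steps=4,
    guidance_scale=8.0,
).images[0]
image.save("ship.png")
```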