Skip to content

Commit

Permalink
Add CLI openvino export in documentation (#440)
Browse files Browse the repository at this point in the history
* add cli openvino export readme

* minor

* add int8

* add in documentation

* add int8 section

* format

* add comment
  • Loading branch information
echarlaix authored Nov 6, 2023
1 parent 99a3970 commit c5ed584
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 78 deletions.
43 changes: 34 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,51 @@ For more details on the supported compression techniques, please refer to the [d

Below are the examples of how to use OpenVINO and its [NNCF](https://docs.openvino.ai/latest/tmo_introduction.html) framework to accelerate inference.

#### Export:

It is possible to export your model to the [OpenVINO](https://docs.openvino.ai/2023.1/openvino_ir.html) IR format with the CLI :

```plain
optimum-cli export openvino --model gpt2 ov_model
```

If you add `--int8`, the weights will be quantized to INT8, the activations will be kept in floating point precision.

```plain
optimum-cli export openvino --model gpt2 --int8 ov_model
```


#### Inference:

To load a model and run inference with OpenVINO Runtime, you can just replace your `AutoModelForXxx` class with the corresponding `OVModelForXxx` class.
If you want to load a PyTorch checkpoint, set `export=True` to convert your model to the OpenVINO IR.


```diff
- from transformers import AutoModelForSequenceClassification
+ from optimum.intel import OVModelForSequenceClassification
- from transformers import AutoModelForSeq2SeqLM
+ from optimum.intel import OVModelForSeq2SeqLM
from transformers import AutoTokenizer, pipeline

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
+ model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)
model_id = "echarlaix/t5-small-openvino"
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+ model = OVModelForSeq2SeqLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.save_pretrained("./distilbert")
pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer)
results = pipe("He never went out without a book under his arm, and he often came back with two.")

classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
results = classifier("He's a dreadful magician.")
[{'translation_text': "Il n'est jamais sorti sans un livre sous son bras, et il est souvent revenu avec deux."}]
```

If you want to load a PyTorch checkpoint, set `export=True` to convert your model to the OpenVINO IR.

```python
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained("gpt2", export=True)
model.save_pretrained("./ov_model")
```


#### Post-training static quantization:

Post-training static quantization introduces an additional calibration step where data is fed through the network in order to compute the activations quantization parameters. Here is an example on how to apply static quantization on a fine-tuned DistilBERT.
Expand Down
196 changes: 127 additions & 69 deletions docs/source/inference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,111 @@ specific language governing permissions and limitations under the License.

Optimum Intel can be used to load optimized models from the [Hugging Face Hub](https://huggingface.co/models?library=openvino&sort=downloads) and create pipelines to run inference with OpenVINO Runtime without rewriting your APIs.

## Switching from Transformers to Optimum
## Transformers models

You can now easily perform inference with OpenVINO Runtime on a variety of Intel processors ([see](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html) the full list of supported devices).
For that, just replace the `AutoModelForXxx` class with the corresponding `OVModelForXxx` class.
To load a Transformers model and convert it to the OpenVINO format on-the-fly, you can set `export=True` when loading your model.

Here is an example on how to perform inference with OpenVINO Runtime for a text classification class:
As shown in the table below, each task is associated with a class enabling to automatically load your model.

| Task | Auto Class |
|--------------------------------------|--------------------------------------|
| `text-classification` | `OVModelForSequenceClassification` |
| `token-classification` | `OVModelForTokenClassification` |
| `question-answering` | `OVModelForQuestionAnswering` |
| `audio-classification` | `OVModelForAudioClassification` |
| `image-classification` | `OVModelForImageClassification` |
| `feature-extraction` | `OVModelForFeatureExtraction` |
| `fill-mask` | `OVModelForMaskedLM` |
| `text-generation` | `OVModelForCausalLM` |
| `text2text-generation` | `OVModelForSeq2SeqLM` |


### Export

It is possible to export your model to the [OpenVINO](https://docs.openvino.ai/2023.1/openvino_ir.html) IR format with the CLI :

```bash
optimum-cli export openvino --model gpt2 ov_model
```

The example above illustrates exporting a checkpoint from the 🤗 Hub. When exporting a local model, first make sure that you saved both the model’s weights and tokenizer files in the same directory (`local_path`).
When using CLI, pass the `local_path` to the model argument instead of the checkpoint name of the model hosted on the Hub and provide the `--task` argument. You can review the list of supported tasks in the 🤗 [Optimum documentation](https://huggingface.co/docs/optimum/exporters/task_manager). If task argument is not provided, it will default to the model architecture without any task specific head.
Here we set the `task` to `text-generation-with-past`, with the `-with-past` suffix enabling the re-use of the pre-computed key/values hidden-states `use_cache=True`.

```bash
optimum-cli export openvino --model local_path --task text-generation-with-past ov_model
```

Once the model is exported, you can load the OpenVINO model using :

```python
from optimum.intel import AutoModelForCausalLM

model_id = "helenai/gpt2-ov"
model = AutoModelForCausalLM.from_pretrained(model_id)
```

You can also load your PyTorch checkpoint and convert it to the OpenVINO format on-the-fly, by setting `export=True` when loading your model.

```python
from optimum.intel import AutoModelForCausalLM

model_id = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_id, export=True)
model.save_pretrained("ov_model")
```

### Inference

You can load an OpenVINO hosted on the hub and perform inference, no need to adapt your code to get it to work with `OVModelForXxx` classes:

```diff
- from transformers import AutoModelForSequenceClassification
+ from optimum.intel import OVModelForSequenceClassification
- from transformers import AutoModelForCausalLM
+ from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer, pipeline

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
+ model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)
model_id = "helenai/gpt2-ov"
- model = AutoModelForCausalLM.from_pretrained(model_id)
+ model = OVModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
cls_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
outputs = cls_pipe("He's a dreadful magician.")

[{'label': 'NEGATIVE', 'score': 0.9919503927230835}]
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
results = pipe("He's a dreadful magician and")
```

See the [reference documentation](reference_ov) for more information about parameters, and examples for different tasks.

To easily save the resulting model, you can use the `save_pretrained()` method, which will save both the BIN and XML files describing the graph. It is useful to save the tokenizer to the same directory, to enable easy loading of the tokenizer for the model.


```python
# Save the exported model
save_directory = "openvino_distilbert"
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
```

### Weight only quantization

You can also apply INT8 quantization on your models weights when exporting your model with the CLI:

```bash
optimum-cli export openvino --model gpt2 --int8 ov_model
```

This will results in the exported model linear and embedding layers to be quanrtized to INT8, the activations will be kept in floating point precision.

This can also be done when loading your model by setting the `load_in_8bit` argument when calling the `from_pretrained()` method.

```python
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
```

To apply quantization on both weights and activations, you can use the `OVQuantizer`, more information in the [documentation](https://huggingface.co/docs/optimum/main/en/intel/optimization_ov#optimization).

### Static shape

By default, `OVModelForXxx` support dynamic shapes, enabling inputs of every shapes. To speed up inference, static shapes can be enabled by giving the desired inputs shapes.

```python
Expand All @@ -55,7 +125,6 @@ model.reshape(1, 9)
model.compile()
```


When fixing the shapes with the `reshape()` method, inference cannot be performed with an input of a different shape. When instantiating your pipeline, you can specify the maximum total input sequence length after tokenization in order for shorter sequences to be padded and for longer sequences to be truncated.

```python
Expand All @@ -81,16 +150,7 @@ qa_pipe = pipeline(
metric = task_evaluator.compute(model_or_pipeline=qa_pipe, data=eval_dataset, metric="squad")
```


To run inference on Intel integrated or discrete GPU, use `.to("gpu")`. On GPU, models run in FP16 precision by default. (See [OpenVINO documentation](https://docs.openvino.ai/nightly/openvino_docs_install_guides_configurations_for_intel_gpu.html) about installing drivers for GPU inference).

```python
# Static shapes speed up inference
model.reshape(1, 9)
model.to("gpu")
# Compile the model before the first inference
model.compile()
```
### Compilation

By default the model will be compiled when instantiating our `OVModel`. In the case where the model is reshaped or placed to another device, the model will need to be recompiled again, which will happen by default before the first inference (thus inflating the latency of the first inference). To avoid an unnecessary compilation, you can disable the first compilation by setting `compile=False`. The model can be compiled before the first inference with `model.compile()`.

Expand All @@ -106,6 +166,19 @@ model.reshape(1,128)
model.compile()
```

To run inference on Intel integrated or discrete GPU, use `.to("gpu")`. On GPU, models run in FP16 precision by default. (See [OpenVINO documentation](https://docs.openvino.ai/nightly/openvino_docs_install_guides_configurations_for_intel_gpu.html) about installing drivers for GPU inference).

```python
# Static shapes speed up inference
model.reshape(1, 9)
model.to("gpu")
# Compile the model before the first inference
model.compile()
```

### Configuration


It is possible to pass an `ov_config` parameter to `from_pretrained()` with custom OpenVINO configuration values. This can be used for example to enable full precision inference on devices where FP16 or BF16 inference precision is used by default.


Expand All @@ -120,7 +193,7 @@ Optimum Intel leverages OpenVINO's model caching to speed up model compiling. By
model = OVModelForSequenceClassification.from_pretrained(model_id, ov_config={"CACHE_DIR":""})
```

## Sequence-to-sequence models
### Sequence-to-sequence models

Sequence-to-sequence (Seq2Seq) models, that generate a new sequence from an input, can also be used when running inference with OpenVINO. When Seq2Seq models are exported to the OpenVINO IR, they are decomposed into two parts : the encoder and the "decoder" (which actually consists of the decoder with the language modeling head), that are later combined during inference.
To speed up sequential decoding, a cache with pre-computed key/values hidden-states will be used by default. An additional model component will be exported: the "decoder" with pre-computed key/values as one of its inputs. This specific export comes from the fact that during the first pass, the decoder has no pre-computed key/values hidden-states, while during the rest of the generation past key/values will be used to speed up sequential decoding. To disable this cache, set `use_cache=False` in the `from_pretrained()` method.
Expand All @@ -147,23 +220,33 @@ tokenizer.save_pretrained(save_directory)
[{'translation_text': "Il n'est jamais sorti sans un livre sous son bras, et il est souvent revenu avec deux."}]
```

## Stable Diffusion
## Diffusers models

Make sure you have 🤗 Diffusers installed.

To install `diffusers`:
```bash
pip install optimum[diffusers]
```


### Stable Diffusion

Stable Diffusion models can also be used when running inference with OpenVINO. When Stable Diffusion models
are exported to the OpenVINO format, they are decomposed into three components that are later combined during inference:
are exported to the OpenVINO format, they are decomposed into different components that are later combined during inference:
- The text encoder
- The U-NET
- The VAE encoder
- The VAE decoder

Make sure you have 🤗 Diffusers installed.
| Task | Auto Class |
|--------------------------------------|--------------------------------------|
| `text-to-image` | `OVStableDiffusionPipeline` |
| `image-to-image` | `OVStableDiffusionImg2ImgPipeline` |
| `inpaint` | `OVStableDiffusionInpaintPipeline` |

To install `diffusers`:
```bash
pip install optimum[diffusers]
```

### Text-to-Image
#### Text-to-Image
Here is an example of how you can load an OpenVINO Stable Diffusion model and run inference using OpenVINO Runtime:

```python
Expand Down Expand Up @@ -208,7 +291,7 @@ In case you want to change any parameters such as the outputs height or width, y
<img src="https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/stable_diffusion_v1_5_sail_boat_rembrandt.png">
</div>

### Text-to-Image with Textual Inversion
#### Text-to-Image with Textual Inversion
Here is an example of how you can load an OpenVINO Stable Diffusion model with pre-trained textual inversion embeddings and run inference using OpenVINO Runtime:


Expand Down Expand Up @@ -248,7 +331,7 @@ The left image shows the generation result of original stable diffusion v1.5, th
| ![](https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/textual_inversion/stable_diffusion_v1_5_without_textual_inversion.png) | ![](https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/textual_inversion/stable_diffusion_v1_5_with_textual_inversion.png) |


### Image-to-Image
#### Image-to-Image

```python
import requests
Expand All @@ -269,16 +352,15 @@ image = pipeline(prompt=prompt, image=init_image, strength=0.75, guidance_scale=
image.save("fantasy_landscape.png")
```

## Stable Diffusion XL
### Stable Diffusion XL

Before using `OVtableDiffusionXLPipeline` make sure to have `diffusers` and `invisible_watermark` installed. You can install the libraries as follows:
| Task | Auto Class |
|--------------------------------------|--------------------------------------|
| `text-to-image` | `OVStableDiffusionXLPipeline` |
| `image-to-image` | `OVStableDiffusionXLImg2ImgPipeline` |

```bash
pip install diffusers
pip install invisible-watermark>=0.2.0
```

### Text-to-Image
#### Text-to-Image

Here is an example of how you can load a SDXL OpenVINO model from [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and run inference using OpenVINO Runtime:

Expand All @@ -296,7 +378,7 @@ image.save("train_station.png")
|---|---|
| ![](https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/sd_xl/train_station_friedrich.png) | ![](https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/sd_xl/train_station_friedrich_2.png) |

### Text-to-Image with Textual Inversion
#### Text-to-Image with Textual Inversion

Here is an example of how you can load an SDXL OpenVINO model from [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) with pre-trained textual inversion embeddings and run inference using OpenVINO Runtime:

Expand Down Expand Up @@ -338,7 +420,7 @@ The left image shows the generation result of the original SDXL base 1.0, the ri
| ![](https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/textual_inversion/sdxl_without_textual_inversion.png) | ![](https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/textual_inversion/sdxl_with_textual_inversion.png) |


### Image-to-Image
#### Image-to-Image

Here is an example of how you can load a PyTorch SDXL model, convert it to OpenVINO on-the-fly and run inference using OpenVINO Runtime for *image-to-image*:

Expand All @@ -358,7 +440,7 @@ pipeline.save_pretrained("openvino-sd-xl-refiner-1.0")
```


### Refining the image output
#### Refining the image output

The image can be refined by making use of a model like [stabilityai/stable-diffusion-xl-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0). In this case, you only have to output the latents from the base model.

Expand All @@ -372,27 +454,3 @@ refiner = OVStableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, export=Tr
image = base(prompt=prompt, output_type="latent").images[0]
image = refiner(prompt=prompt, image=image[None, :]).images[0]
```



## Supported tasks

As shown in the table below, each task is associated with a class enabling to automatically load your model.


| Task | Auto Class |
|--------------------------------------|--------------------------------------|
| `text-classification` | `OVModelForSequenceClassification` |
| `token-classification` | `OVModelForTokenClassification` |
| `question-answering` | `OVModelForQuestionAnswering` |
| `audio-classification` | `OVModelForAudioClassification` |
| `image-classification` | `OVModelForImageClassification` |
| `feature-extraction` | `OVModelForFeatureExtraction` |
| `fill-mask` | `OVModelForMaskedLM` |
| `text-generation` | `OVModelForCausalLM` |
| `text2text-generation` | `OVModelForSeq2SeqLM` |
| `text-to-image` | `OVStableDiffusionPipeline` |
| `text-to-image` | `OVStableDiffusionXLPipeline` |
| `image-to-image` | `OVStableDiffusionImg2ImgPipeline` |
| `image-to-image` | `OVStableDiffusionXLImg2ImgPipeline` |
| `inpaint` | `OVStableDiffusionInpaintPipeline` |

0 comments on commit c5ed584

Please sign in to comment.