diff --git a/notebooks/openvino/README.md b/notebooks/openvino/README.md
index f63c13c55b..31c2580996 100644
--- a/notebooks/openvino/README.md
+++ b/notebooks/openvino/README.md
@@ -12,4 +12,5 @@ The notebooks have been tested with Python 3.8 and 3.10 on Ubuntu Linux.
|:----------|:-------------|:-------------|------:|
| [How to run inference with the OpenVINO](https://github.com/huggingface/optimum-intel/blob/main/notebooks/openvino/optimum_openvino_inference.ipynb) | Explains how to export your model to OpenVINO and to run inference with OpenVINO Runtime on various tasks| [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open in Colab"](https://colab.research.google.com/github/huggingface/optimum-intel/blob/main/notebooks/openvino/optimum_openvino_inference.ipynb)| [data:image/s3,"s3://crabby-images/93f0e/93f0eb76f7b1999493dd777417858b495378833c" alt="Open in AWS Studio"](https://studiolab.sagemaker.aws/import/github/huggingface/optimum-intel/blob/main/notebooks/openvino/optimum_openvino_inference.ipynb)|
| [How to quantize a question answering model with OpenVINO NNCF](https://github.com/huggingface/optimum-intel/blob/main/notebooks/openvino/question_answering_quantization.ipynb) | Show how to apply post-training quantization on a question answering model using [NNCF](https://github.com/openvinotoolkit/nncf) and to accelerate inference with OpenVINO| [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open in Colab"](https://colab.research.google.com/github/huggingface/optimum-intel/blob/main/notebooks/openvino/question_answering_quantization.ipynb)| [data:image/s3,"s3://crabby-images/93f0e/93f0eb76f7b1999493dd777417858b495378833c" alt="Open in AWS Studio"](https://studiolab.sagemaker.aws/import/github/huggingface/optimum-intel/blob/main/notebooks/openvino/question_answering_quantization.ipynb)|
-| [How to quantize Stable Diffusion model with OpenVINO NNCF](https://github.com/huggingface/optimum-intel/blob/main/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb)| Show how to apply post-training hybrid quantization on a Stable Diffusion model using [NNCF](https://github.com/openvinotoolkit/nncf) and to accelerate inference with OpenVINO| [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open in Colab"](https://colab.research.google.com/github/huggingface/optimum-intel/blob/main/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb)| [data:image/s3,"s3://crabby-images/93f0e/93f0eb76f7b1999493dd777417858b495378833c" alt="Open in AWS Studio"](https://studiolab.sagemaker.aws/import/github/huggingface/optimum-intel/blob/main/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb)|
\ No newline at end of file
+| [How to quantize Stable Diffusion model with OpenVINO NNCF](https://github.com/huggingface/optimum-intel/blob/main/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb)| Show how to apply post-training hybrid quantization on a Stable Diffusion model using [NNCF](https://github.com/openvinotoolkit/nncf) and to accelerate inference with OpenVINO| [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open in Colab"](https://colab.research.google.com/github/huggingface/optimum-intel/blob/main/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb)| [data:image/s3,"s3://crabby-images/93f0e/93f0eb76f7b1999493dd777417858b495378833c" alt="Open in AWS Studio"](https://studiolab.sagemaker.aws/import/github/huggingface/optimum-intel/blob/main/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb)|
+| [How to quantize Sentence Transformer model with OpenVINO NNCF](https://github.com/huggingface/optimum-intel/blob/main/notebooks/openvino/sentence_transformer_quantization.ipynb)| Show how to apply post-training 8-bit quantization on a Sentence Transformer model using [NNCF](https://github.com/openvinotoolkit/nncf) and to accelerate inference with OpenVINO| [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open in Colab"](https://colab.research.google.com/github/huggingface/optimum-intel/blob/main/notebooks/openvino/sentence_transformer_quantization.ipynb)| [data:image/s3,"s3://crabby-images/93f0e/93f0eb76f7b1999493dd777417858b495378833c" alt="Open in AWS Studio"](https://studiolab.sagemaker.aws/import/github/huggingface/optimum-intel/blob/main/notebooks/openvino/sentence_transformer_quantization.ipynb)|
diff --git a/notebooks/openvino/requirements.txt b/notebooks/openvino/requirements.txt
index bb7a517cff..64ccd6d8cc 100644
--- a/notebooks/openvino/requirements.txt
+++ b/notebooks/openvino/requirements.txt
@@ -4,4 +4,3 @@ evaluate[evaluator]
ipywidgets
pillow
torchaudio
-
diff --git a/notebooks/openvino/sentence_transformer_quantization.ipynb b/notebooks/openvino/sentence_transformer_quantization.ipynb
new file mode 100644
index 0000000000..714544aa9a
--- /dev/null
+++ b/notebooks/openvino/sentence_transformer_quantization.ipynb
@@ -0,0 +1,625 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Quantization of Text Embedding model from Sentence Transformers library"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install optimum[openvino]\n",
+ "%pip install evaluate"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Quantize staticly model to 8-bit with NNCF via Optimum-Intel API"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The code snippet below shows how to use Optimum-Intel [Model Optimization API](https://huggingface.co/docs/optimum/en/intel/openvino/optimization#static-quantization) to quantize the model staticly. It leaverages [NNCF](https://github.com/openvinotoolkit/nncf) capabilites for static quantization of Transformer models where a combination of the special quantization scheme + SmoothQuant method + Bias Correction method are used to provide state-of-the-art accuracy.\n",
+ "\n",
+ "The static quantization requires some data to estimate quantization parameters of activations. It means that some calibration dataset should be provided. `OVQuantizer` class used for quantization provides an API to build such a dataset with `.get_calibration_dataset()` method."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "No OpenVINO files were found for sentence-transformers/all-MiniLM-L6-v2, setting `export=True` to convert the model to the OpenVINO IR. Don't forget to save the resulting model with `.save_pretrained()`\n",
+ "Framework not specified. Using pt to export the model.\n",
+ "Using framework PyTorch: 2.4.1+cpu\n",
+ "Overriding 1 configuration item(s)\n",
+ "\t- use_cache -> False\n",
+ "Compiling the model to CPU ...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a9bd847756fd467e905a7ad7a243640c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9d8ad91623d642f48e85b60ac823aca4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a2a7d09a573c4092a830bbaadc39f756",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b67c493aab36426090f8fafd25a17a00",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Configuration saved in all-MiniLM-L6-v2_int8/openvino_config.json\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "('all-MiniLM-L6-v2_int8/tokenizer_config.json',\n",
+ " 'all-MiniLM-L6-v2_int8/special_tokens_map.json',\n",
+ " 'all-MiniLM-L6-v2_int8/vocab.txt',\n",
+ " 'all-MiniLM-L6-v2_int8/added_tokens.json',\n",
+ " 'all-MiniLM-L6-v2_int8/tokenizer.json')"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from functools import partial\n",
+ "import datasets\n",
+ "from transformers import AutoTokenizer\n",
+ "from optimum.intel import OVModelForFeatureExtraction, OVQuantizer, OVQuantizationConfig, OVConfig\n",
+ "\n",
+ "MODEL_ID = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
+ "base_model_path = \"all-MiniLM-L6-v2\"\n",
+ "int8_ptq_model_path = \"all-MiniLM-L6-v2_int8\"\n",
+ "\n",
+ "model = OVModelForFeatureExtraction.from_pretrained(MODEL_ID)\n",
+ "model.save_pretrained(base_model_path)\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
+ "tokenizer.save_pretrained(base_model_path)\n",
+ "\n",
+ "\n",
+ "quantizer = OVQuantizer.from_pretrained(model)\n",
+ "\n",
+ "def preprocess_function(examples, tokenizer):\n",
+ " return tokenizer(examples[\"sentence\"], padding=\"max_length\", max_length=384, truncation=True)\n",
+ "\n",
+ "\n",
+ "calibration_dataset = quantizer.get_calibration_dataset(\n",
+ " \"glue\",\n",
+ " dataset_config_name=\"sst2\",\n",
+ " preprocess_function=partial(preprocess_function, tokenizer=tokenizer),\n",
+ " num_samples=300,\n",
+ " dataset_split=\"train\",\n",
+ ")\n",
+ "\n",
+ "ov_config = OVConfig(quantization_config=OVQuantizationConfig())\n",
+ "\n",
+ "quantizer.quantize(ov_config=ov_config, calibration_dataset=calibration_dataset, save_directory=int8_ptq_model_path)\n",
+ "tokenizer.save_pretrained(int8_ptq_model_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmark model accuracy on GLUE STSB task"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here we estimate accuracy impact from model quantization. We evaluate accuracy of both the baseline and quantized model on a different task from the GLUE benchmark."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import Pipeline\n",
+ "import torch.nn.functional as F\n",
+ "import torch\n",
+ "\n",
+ "\n",
+ "# copied from the model card \"sentence-transformers/all-MiniLM-L6-v2\"\n",
+ "def mean_pooling(model_output, attention_mask):\n",
+ " token_embeddings = model_output[0] # First element of model_output contains all token embeddings\n",
+ " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
+ " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
+ "\n",
+ "\n",
+ "class SentenceEmbeddingPipeline(Pipeline):\n",
+ " def _sanitize_parameters(self, **kwargs):\n",
+ " # we don\"t have any hyperameters to sanitize\n",
+ " preprocess_kwargs = {}\n",
+ " return preprocess_kwargs, {}, {}\n",
+ "\n",
+ " def preprocess(self, inputs):\n",
+ " encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors=\"pt\")\n",
+ " return encoded_inputs\n",
+ "\n",
+ " def _forward(self, model_inputs):\n",
+ " outputs = self.model(**model_inputs)\n",
+ " return {\"outputs\": outputs, \"attention_mask\": model_inputs[\"attention_mask\"]}\n",
+ "\n",
+ " def postprocess(self, model_outputs):\n",
+ " # Perform pooling\n",
+ " sentence_embeddings = mean_pooling(model_outputs[\"outputs\"], model_outputs[\"attention_mask\"])\n",
+ " # Normalize embeddings\n",
+ " sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n",
+ " return sentence_embeddings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Compiling the model to CPU ...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Compiling the model to CPU ...\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = OVModelForFeatureExtraction.from_pretrained(base_model_path)\n",
+ "vanilla_emb = SentenceEmbeddingPipeline(model=model, tokenizer=tokenizer)\n",
+ "\n",
+ "q_model = OVModelForFeatureExtraction.from_pretrained(int8_ptq_model_path)\n",
+ "q8_emb = SentenceEmbeddingPipeline(model=q_model, tokenizer=tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "from evaluate import load\n",
+ "\n",
+ "eval_dataset = load_dataset(\"glue\", \"stsb\", split=\"validation\")\n",
+ "metric = load(\"glue\", \"stsb\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Parameter 'function'= of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5cab9e8fc58245a4b395a9575017633b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/1500 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def compute_sentence_similarity(sentence_1, sentence_2, pipeline):\n",
+ " embedding_1 = pipeline(sentence_1)\n",
+ " embedding_2 = pipeline(sentence_2)\n",
+ " # compute cosine similarity between two sentences\n",
+ " return torch.nn.functional.cosine_similarity(embedding_1, embedding_2, dim=1)\n",
+ "\n",
+ "\n",
+ "def evaluate_stsb(example):\n",
+ " default = compute_sentence_similarity(example[\"sentence1\"], example[\"sentence2\"], vanilla_emb)\n",
+ " quantized = compute_sentence_similarity(example[\"sentence1\"], example[\"sentence2\"], q8_emb)\n",
+ " return {\n",
+ " \"reference\": (example[\"label\"] - 1) / (5 - 1), # rescale to [0,1]\n",
+ " \"default\": float(default),\n",
+ " \"quantized\": float(quantized),\n",
+ " }\n",
+ "\n",
+ "\n",
+ "result = eval_dataset.map(evaluate_stsb)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "vanilla model: pearson= 0.869619439095004\n",
+ "quantized model: pearson= 0.869415534480936\n",
+ "The quantized model achieves 100.0 % accuracy of the fp32 model\n"
+ ]
+ }
+ ],
+ "source": [
+ "default_acc = metric.compute(predictions=result[\"default\"], references=result[\"reference\"])\n",
+ "quantized = metric.compute(predictions=result[\"quantized\"], references=result[\"reference\"])\n",
+ "\n",
+ "print(\"vanilla model: pearson=\", default_acc[\"pearson\"])\n",
+ "print(\"quantized model: pearson=\", quantized[\"pearson\"])\n",
+ "print(\n",
+ " \"The quantized model achieves \",\n",
+ " round(quantized[\"pearson\"] / default_acc[\"pearson\"], 2) * 100,\n",
+ " \"% accuracy of the fp32 model\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Compare performance of the baseline and INT8 models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We use OpenVINO `benchmark_app` with static input shape `[1,384]` for performance benchmarking. It should reflect the application performance as the tokenizer pads or trancates the input sequence to `max_length=384`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Step 1/11] Parsing and validating input arguments\n",
+ "[ INFO ] Parsing input parameters\n",
+ "[Step 2/11] Loading OpenVINO Runtime\n",
+ "[ INFO ] OpenVINO:\n",
+ "[ INFO ] Build ................................. 2024.4.1-16618-643f23d1318-releases/2024/4\n",
+ "[ INFO ] \n",
+ "[ INFO ] Device info:\n",
+ "[ INFO ] CPU\n",
+ "[ INFO ] Build ................................. 2024.4.1-16618-643f23d1318-releases/2024/4\n",
+ "[ INFO ] \n",
+ "[ INFO ] \n",
+ "[Step 3/11] Setting device configuration\n",
+ "[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to PerformanceMode.LATENCY.\n",
+ "[Step 4/11] Reading model files\n",
+ "[ INFO ] Loading model files\n",
+ "[ INFO ] Read model took 10.17 ms\n",
+ "[ INFO ] Original model I/O parameters:\n",
+ "[ INFO ] Model inputs:\n",
+ "[ INFO ] input_ids (node: input_ids) : i64 / [...] / [?,?]\n",
+ "[ INFO ] attention_mask (node: attention_mask) : i64 / [...] / [?,?]\n",
+ "[ INFO ] token_type_ids (node: token_type_ids) : i64 / [...] / [?,?]\n",
+ "[ INFO ] Model outputs:\n",
+ "[ INFO ] last_hidden_state (node: __module.encoder.layer.5.output.LayerNorm/aten::layer_norm/Add) : f32 / [...] / [?,?,384]\n",
+ "[Step 5/11] Resizing model to match image sizes and given batch\n",
+ "[ INFO ] Model batch size: 1\n",
+ "[ INFO ] Reshaping model: 'input_ids': [1,384], 'attention_mask': [1,384], 'token_type_ids': [1,384]\n",
+ "[ INFO ] Reshape model took 2.23 ms\n",
+ "[Step 6/11] Configuring input of the model\n",
+ "[ INFO ] Model inputs:\n",
+ "[ INFO ] input_ids (node: input_ids) : i64 / [...] / [1,384]\n",
+ "[ INFO ] attention_mask (node: attention_mask) : i64 / [...] / [1,384]\n",
+ "[ INFO ] token_type_ids (node: token_type_ids) : i64 / [...] / [1,384]\n",
+ "[ INFO ] Model outputs:\n",
+ "[ INFO ] last_hidden_state (node: __module.encoder.layer.5.output.LayerNorm/aten::layer_norm/Add) : f32 / [...] / [1,384,384]\n",
+ "[Step 7/11] Loading the model to the device\n",
+ "[ INFO ] Compile model took 134.63 ms\n",
+ "[Step 8/11] Querying optimal runtime parameters\n",
+ "[ INFO ] Model:\n",
+ "[ INFO ] NETWORK_NAME: Model0\n",
+ "[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 1\n",
+ "[ INFO ] NUM_STREAMS: 1\n",
+ "[ INFO ] INFERENCE_NUM_THREADS: 18\n",
+ "[ INFO ] PERF_COUNT: NO\n",
+ "[ INFO ] INFERENCE_PRECISION_HINT: \n",
+ "[ INFO ] PERFORMANCE_HINT: LATENCY\n",
+ "[ INFO ] EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE\n",
+ "[ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0\n",
+ "[ INFO ] ENABLE_CPU_PINNING: True\n",
+ "[ INFO ] SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE\n",
+ "[ INFO ] MODEL_DISTRIBUTION_POLICY: set()\n",
+ "[ INFO ] ENABLE_HYPER_THREADING: False\n",
+ "[ INFO ] EXECUTION_DEVICES: ['CPU']\n",
+ "[ INFO ] CPU_DENORMALS_OPTIMIZATION: False\n",
+ "[ INFO ] LOG_LEVEL: Level.NO\n",
+ "[ INFO ] CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0\n",
+ "[ INFO ] DYNAMIC_QUANTIZATION_GROUP_SIZE: 32\n",
+ "[ INFO ] KV_CACHE_PRECISION: \n",
+ "[ INFO ] AFFINITY: Affinity.CORE\n",
+ "[Step 9/11] Creating infer requests and preparing input tensors\n",
+ "[ WARNING ] No input files were given for input 'input_ids'!. This input will be filled with random values!\n",
+ "[ WARNING ] No input files were given for input 'attention_mask'!. This input will be filled with random values!\n",
+ "[ WARNING ] No input files were given for input 'token_type_ids'!. This input will be filled with random values!\n",
+ "[ INFO ] Fill input 'input_ids' with random values \n",
+ "[ INFO ] Fill input 'attention_mask' with random values \n",
+ "[ INFO ] Fill input 'token_type_ids' with random values \n",
+ "[Step 10/11] Measuring performance (Start inference synchronously, limits: 200 iterations)\n",
+ "[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).\n",
+ "[ INFO ] First inference took 12.27 ms\n",
+ "[Step 11/11] Dumping statistics report\n",
+ "[ INFO ] Execution Devices:['CPU']\n",
+ "[ INFO ] Count: 200 iterations\n",
+ "[ INFO ] Duration: 1988.84 ms\n",
+ "[ INFO ] Latency:\n",
+ "[ INFO ] Median: 9.74 ms\n",
+ "[ INFO ] Average: 9.77 ms\n",
+ "[ INFO ] Min: 9.59 ms\n",
+ "[ INFO ] Max: 11.12 ms\n",
+ "[ INFO ] Throughput: 100.56 FPS\n"
+ ]
+ }
+ ],
+ "source": [
+ "# FP32 baseline model\n",
+ "!benchmark_app -m all-MiniLM-L6-v2/openvino_model.xml -shape \"input_ids[1,384],attention_mask[1,384],token_type_ids[1,384]\" -api sync -niter 200"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Step 1/11] Parsing and validating input arguments\n",
+ "[ INFO ] Parsing input parameters\n",
+ "[Step 2/11] Loading OpenVINO Runtime\n",
+ "[ INFO ] OpenVINO:\n",
+ "[ INFO ] Build ................................. 2024.4.1-16618-643f23d1318-releases/2024/4\n",
+ "[ INFO ] \n",
+ "[ INFO ] Device info:\n",
+ "[ INFO ] CPU\n",
+ "[ INFO ] Build ................................. 2024.4.1-16618-643f23d1318-releases/2024/4\n",
+ "[ INFO ] \n",
+ "[ INFO ] \n",
+ "[Step 3/11] Setting device configuration\n",
+ "[ WARNING ] Performance hint was not explicitly specified in command line. Device(CPU) performance hint will be set to PerformanceMode.LATENCY.\n",
+ "[Step 4/11] Reading model files\n",
+ "[ INFO ] Loading model files\n",
+ "[ INFO ] Read model took 20.87 ms\n",
+ "[ INFO ] Original model I/O parameters:\n",
+ "[ INFO ] Model inputs:\n",
+ "[ INFO ] input_ids (node: input_ids) : i64 / [...] / [?,?]\n",
+ "[ INFO ] attention_mask (node: attention_mask) : i64 / [...] / [?,?]\n",
+ "[ INFO ] token_type_ids (node: token_type_ids) : i64 / [...] / [?,?]\n",
+ "[ INFO ] Model outputs:\n",
+ "[ INFO ] last_hidden_state (node: __module.encoder.layer.5.output.LayerNorm/aten::layer_norm/Add) : f32 / [...] / [?,?,384]\n",
+ "[Step 5/11] Resizing model to match image sizes and given batch\n",
+ "[ INFO ] Model batch size: 1\n",
+ "[ INFO ] Reshaping model: 'input_ids': [1,384], 'attention_mask': [1,384], 'token_type_ids': [1,384]\n",
+ "[ INFO ] Reshape model took 3.42 ms\n",
+ "[Step 6/11] Configuring input of the model\n",
+ "[ INFO ] Model inputs:\n",
+ "[ INFO ] input_ids (node: input_ids) : i64 / [...] / [1,384]\n",
+ "[ INFO ] attention_mask (node: attention_mask) : i64 / [...] / [1,384]\n",
+ "[ INFO ] token_type_ids (node: token_type_ids) : i64 / [...] / [1,384]\n",
+ "[ INFO ] Model outputs:\n",
+ "[ INFO ] last_hidden_state (node: __module.encoder.layer.5.output.LayerNorm/aten::layer_norm/Add) : f32 / [...] / [1,384,384]\n",
+ "[Step 7/11] Loading the model to the device\n",
+ "[ INFO ] Compile model took 323.91 ms\n",
+ "[Step 8/11] Querying optimal runtime parameters\n",
+ "[ INFO ] Model:\n",
+ "[ INFO ] NETWORK_NAME: Model0\n",
+ "[ INFO ] OPTIMAL_NUMBER_OF_INFER_REQUESTS: 1\n",
+ "[ INFO ] NUM_STREAMS: 1\n",
+ "[ INFO ] INFERENCE_NUM_THREADS: 18\n",
+ "[ INFO ] PERF_COUNT: NO\n",
+ "[ INFO ] INFERENCE_PRECISION_HINT: \n",
+ "[ INFO ] PERFORMANCE_HINT: LATENCY\n",
+ "[ INFO ] EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE\n",
+ "[ INFO ] PERFORMANCE_HINT_NUM_REQUESTS: 0\n",
+ "[ INFO ] ENABLE_CPU_PINNING: True\n",
+ "[ INFO ] SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE\n",
+ "[ INFO ] MODEL_DISTRIBUTION_POLICY: set()\n",
+ "[ INFO ] ENABLE_HYPER_THREADING: False\n",
+ "[ INFO ] EXECUTION_DEVICES: ['CPU']\n",
+ "[ INFO ] CPU_DENORMALS_OPTIMIZATION: False\n",
+ "[ INFO ] LOG_LEVEL: Level.NO\n",
+ "[ INFO ] CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0\n",
+ "[ INFO ] DYNAMIC_QUANTIZATION_GROUP_SIZE: 32\n",
+ "[ INFO ] KV_CACHE_PRECISION: \n",
+ "[ INFO ] AFFINITY: Affinity.CORE\n",
+ "[Step 9/11] Creating infer requests and preparing input tensors\n",
+ "[ WARNING ] No input files were given for input 'input_ids'!. This input will be filled with random values!\n",
+ "[ WARNING ] No input files were given for input 'attention_mask'!. This input will be filled with random values!\n",
+ "[ WARNING ] No input files were given for input 'token_type_ids'!. This input will be filled with random values!\n",
+ "[ INFO ] Fill input 'input_ids' with random values \n",
+ "[ INFO ] Fill input 'attention_mask' with random values \n",
+ "[ INFO ] Fill input 'token_type_ids' with random values \n",
+ "[Step 10/11] Measuring performance (Start inference synchronously, limits: 200 iterations)\n",
+ "[ INFO ] Benchmarking in inference only mode (inputs filling are not included in measurement loop).\n",
+ "[ INFO ] First inference took 6.72 ms\n",
+ "[Step 11/11] Dumping statistics report\n",
+ "[ INFO ] Execution Devices:['CPU']\n",
+ "[ INFO ] Count: 200 iterations\n",
+ "[ INFO ] Duration: 853.85 ms\n",
+ "[ INFO ] Latency:\n",
+ "[ INFO ] Median: 4.13 ms\n",
+ "[ INFO ] Average: 4.15 ms\n",
+ "[ INFO ] Min: 4.05 ms\n",
+ "[ INFO ] Max: 5.13 ms\n",
+ "[ INFO ] Throughput: 234.23 FPS\n"
+ ]
+ }
+ ],
+ "source": [
+ "# INT8 counterpart\n",
+ "!benchmark_app -m all-MiniLM-L6-v2_int8/openvino_model.xml -shape \"input_ids[1,384],attention_mask[1,384],token_type_ids[1,384]\" -api sync -niter 200"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "test3.11_cpu",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py
index 93528e0085..70d2e4885c 100644
--- a/optimum/commands/export/openvino.py
+++ b/optimum/commands/export/openvino.py
@@ -318,6 +318,10 @@ def run(self):
from optimum.intel import OVStableDiffusionPipeline
model_cls = OVStableDiffusionPipeline
+ elif class_name == "StableDiffusion3Pipeline":
+ from optimum.intel import OVStableDiffusion3Pipeline
+
+ model_cls = OVStableDiffusion3Pipeline
else:
raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")
diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 412ed21f6b..ee61563c98 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -493,7 +493,7 @@ def maybe_convert_tokenizers(library_name: str, output: Path, model=None, prepro
f"models won't be generated. Exception: {exception}"
)
elif model:
- for tokenizer_name in ("tokenizer", "tokenizer_2"):
+ for tokenizer_name in ("tokenizer", "tokenizer_2", "tokenizer_3"):
tokenizer = getattr(model, tokenizer_name, None)
if tokenizer:
export_tokenizer(tokenizer, output / tokenizer_name, task=task)
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index e731cd1801..2c076827d2 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import functools
import gc
import logging
@@ -31,7 +32,12 @@
from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx
from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
-from optimum.exporters.utils import _get_submodels_and_export_configs as _default_get_submodels_and_export_configs
+from optimum.exporters.utils import (
+ _get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
+)
+from optimum.exporters.utils import (
+ get_diffusion_models_for_export,
+)
from optimum.intel.utils.import_utils import (
_nncf_version,
_open_clip_version,
@@ -619,23 +625,27 @@ def export_from_model(
model, library_name, task, preprocessors, custom_export_configs, fn_get_submodels
)
- logging.disable(logging.INFO)
- export_config, models_and_export_configs, stateful_submodels = _get_submodels_and_export_configs(
- model=model,
- task=task,
- monolith=False,
- custom_export_configs=custom_export_configs if custom_export_configs is not None else {},
- custom_architecture=custom_architecture,
- fn_get_submodels=fn_get_submodels,
- preprocessors=preprocessors,
- library_name=library_name,
- model_kwargs=model_kwargs,
- _variant="default",
- legacy=False,
- exporter="openvino",
- stateful=stateful,
- )
- logging.disable(logging.NOTSET)
+ if library_name == "diffusers":
+ export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino")
+ stateful_submodels = False
+ else:
+ logging.disable(logging.INFO)
+ export_config, models_and_export_configs, stateful_submodels = _get_submodels_and_export_configs(
+ model=model,
+ task=task,
+ monolith=False,
+ custom_export_configs=custom_export_configs if custom_export_configs is not None else {},
+ custom_architecture=custom_architecture,
+ fn_get_submodels=fn_get_submodels,
+ preprocessors=preprocessors,
+ library_name=library_name,
+ model_kwargs=model_kwargs,
+ _variant="default",
+ legacy=False,
+ exporter="openvino",
+ stateful=stateful,
+ )
+ logging.disable(logging.NOTSET)
if library_name == "open_clip":
if hasattr(model.config, "save_pretrained"):
@@ -701,6 +711,10 @@ def export_from_model(
if tokenizer_2 is not None:
tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
+ tokenizer_3 = getattr(model, "tokenizer_3", None)
+ if tokenizer_3 is not None:
+ tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))
+
model.save_config(output)
export_models(
@@ -889,3 +903,218 @@ def _get_submodels_and_export_configs(
)
stateful_per_model = [stateful] * len(models_for_export)
return export_config, models_for_export, stateful_per_model
+
+
+def get_diffusion_models_for_export_ext(
+ pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", exporter: str = "openvino"
+):
+ try:
+ from diffusers import (
+ StableDiffusion3Img2ImgPipeline,
+ StableDiffusion3InpaintPipeline,
+ StableDiffusion3Pipeline,
+ )
+
+ is_sd3 = isinstance(
+ pipeline, (StableDiffusion3Pipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Img2ImgPipeline)
+ )
+ except ImportError:
+ is_sd3 = False
+
+ try:
+ from diffusers import FluxPipeline
+
+ is_flux = isinstance(pipeline, FluxPipeline)
+ except ImportError:
+ is_flux = False
+
+ if not is_sd3 and not is_flux:
+ return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
+ if is_sd3:
+ models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
+ else:
+ models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
+
+ return None, models_for_export
+
+
+def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
+ models_for_export = {}
+
+ # Text encoder
+ text_encoder = getattr(pipeline, "text_encoder", None)
+ if text_encoder is not None:
+ text_encoder.config.output_hidden_states = True
+ text_encoder.text_model.config.output_hidden_states = True
+ text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=text_encoder,
+ exporter=exporter,
+ library_name="diffusers",
+ task="feature-extraction",
+ model_type="clip-text-with-projection",
+ )
+ text_encoder_export_config = text_encoder_config_constructor(
+ pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+ )
+ models_for_export["text_encoder"] = (text_encoder, text_encoder_export_config)
+
+ transformer = pipeline.transformer
+ transformer.config.text_encoder_projection_dim = transformer.config.joint_attention_dim
+ transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
+ transformer.config.time_cond_proj_dim = None
+ export_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=transformer,
+ exporter=exporter,
+ library_name="diffusers",
+ task="semantic-segmentation",
+ model_type="sd3-transformer",
+ )
+ transformer_export_config = export_config_constructor(
+ pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
+ )
+ models_for_export["transformer"] = (transformer, transformer_export_config)
+
+ # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
+ vae_encoder = copy.deepcopy(pipeline.vae)
+ vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
+ vae_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=vae_encoder,
+ exporter=exporter,
+ library_name="diffusers",
+ task="semantic-segmentation",
+ model_type="vae-encoder",
+ )
+ vae_encoder_export_config = vae_config_constructor(
+ vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+ )
+ models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
+
+ # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
+ vae_decoder = copy.deepcopy(pipeline.vae)
+ vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
+ vae_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=vae_decoder,
+ exporter=exporter,
+ library_name="diffusers",
+ task="semantic-segmentation",
+ model_type="vae-decoder",
+ )
+ vae_decoder_export_config = vae_config_constructor(
+ vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+ )
+ models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
+
+ text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
+ if text_encoder_2 is not None:
+ text_encoder_2.config.output_hidden_states = True
+ text_encoder_2.text_model.config.output_hidden_states = True
+ export_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=text_encoder_2,
+ exporter=exporter,
+ library_name="diffusers",
+ task="feature-extraction",
+ model_type="clip-text-with-projection",
+ )
+ export_config = export_config_constructor(text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype)
+ models_for_export["text_encoder_2"] = (text_encoder_2, export_config)
+
+ text_encoder_3 = getattr(pipeline, "text_encoder_3", None)
+ if text_encoder_3 is not None:
+ export_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=text_encoder_3,
+ exporter=exporter,
+ library_name="diffusers",
+ task="feature-extraction",
+ model_type="t5-encoder-model",
+ )
+ export_config = export_config_constructor(
+ text_encoder_3.config,
+ int_dtype=int_dtype,
+ float_dtype=float_dtype,
+ )
+ models_for_export["text_encoder_3"] = (text_encoder_3, export_config)
+
+ return models_for_export
+
+
+def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
+ models_for_export = {}
+
+ # Text encoder
+ text_encoder = getattr(pipeline, "text_encoder", None)
+ if text_encoder is not None:
+ text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=text_encoder,
+ exporter=exporter,
+ library_name="diffusers",
+ task="feature-extraction",
+ model_type="clip-text-model",
+ )
+ text_encoder_export_config = text_encoder_config_constructor(
+ pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+ )
+ models_for_export["text_encoder"] = (text_encoder, text_encoder_export_config)
+
+ transformer = pipeline.transformer
+ transformer.config.text_encoder_projection_dim = transformer.config.joint_attention_dim
+ transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
+ transformer.config.time_cond_proj_dim = None
+ export_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=transformer,
+ exporter=exporter,
+ library_name="diffusers",
+ task="semantic-segmentation",
+ model_type="flux-transformer",
+ )
+ transformer_export_config = export_config_constructor(
+ pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
+ )
+ models_for_export["transformer"] = (transformer, transformer_export_config)
+
+ # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
+ vae_encoder = copy.deepcopy(pipeline.vae)
+ vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
+ vae_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=vae_encoder,
+ exporter=exporter,
+ library_name="diffusers",
+ task="semantic-segmentation",
+ model_type="vae-encoder",
+ )
+ vae_encoder_export_config = vae_config_constructor(
+ vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+ )
+ models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
+
+ # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
+ vae_decoder = copy.deepcopy(pipeline.vae)
+ vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
+ vae_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=vae_decoder,
+ exporter=exporter,
+ library_name="diffusers",
+ task="semantic-segmentation",
+ model_type="vae-decoder",
+ )
+ vae_decoder_export_config = vae_config_constructor(
+ vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+ )
+ models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
+
+ text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
+ if text_encoder_2 is not None:
+ export_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=text_encoder_2,
+ exporter=exporter,
+ library_name="diffusers",
+ task="feature-extraction",
+ model_type="t5-encoder-model",
+ )
+ export_config = export_config_constructor(
+ text_encoder_2.config,
+ int_dtype=int_dtype,
+ float_dtype=float_dtype,
+ )
+ models_for_export["text_encoder_2"] = (text_encoder_2, export_config)
+
+ return models_for_export
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 33190e6f1c..ace5c150df 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -35,22 +35,26 @@
MistralOnnxConfig,
MPTOnnxConfig,
PhiOnnxConfig,
+ UNetOnnxConfig,
VisionOnnxConfig,
)
from optimum.exporters.onnx.model_patcher import ModelPatcher
from optimum.exporters.tasks import TasksManager
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import (
+ DTYPE_MAPPER,
DummyInputGenerator,
DummyPastKeyValuesGenerator,
+ DummySeq2SeqDecoderTextInputGenerator,
DummyTextInputGenerator,
+ DummyTimestepInputGenerator,
DummyVisionInputGenerator,
FalconDummyPastKeyValuesGenerator,
MistralDummyPastKeyValuesGenerator,
)
-from optimum.utils.normalized_config import NormalizedTextConfig, NormalizedVisionConfig
+from optimum.utils.normalized_config import NormalizedConfig, NormalizedTextConfig, NormalizedVisionConfig
-from ...intel.utils.import_utils import _transformers_version, is_transformers_version
+from ...intel.utils.import_utils import _transformers_version, is_diffusers_version, is_transformers_version
from .model_patcher import (
AquilaModelPatcher,
ArcticModelPatcher,
@@ -60,6 +64,7 @@
DBRXModelPatcher,
DeciLMModelPatcher,
FalconModelPatcher,
+ FluxTransfromerModelPatcher,
Gemma2ModelPatcher,
GptNeoxJapaneseModelPatcher,
GptNeoxModelPatcher,
@@ -1570,3 +1575,166 @@ def patch_model_for_export(
if self._behavior != InternVLChatConfigBehavior.VISION_EMBEDDINGS:
return super().patch_model_for_export(model, model_kwargs)
return InternVLChatImageEmbeddingModelPatcher(self, model, model_kwargs)
+
+
+class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
+ SUPPORTED_INPUT_NAMES = ["pooled_projections"]
+
+ def __init__(
+ self,
+ task: str,
+ normalized_config: NormalizedConfig,
+ batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+ random_batch_size_range: Optional[Tuple[int, int]] = None,
+ **kwargs,
+ ):
+ self.task = task
+ self.batch_size = batch_size
+ self.pooled_projection_dim = normalized_config.config.pooled_projection_dim
+
+ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+ shape = [self.batch_size, self.pooled_projection_dim]
+ return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
+
+
+class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator):
+ SUPPORTED_INPUT_NAMES = ("timestep", "text_embeds", "time_ids", "timestep_cond", "guidance")
+
+ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+ if input_name in ["timestep", "guidance"]:
+ shape = [self.batch_size]
+ return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
+ return super().generate(input_name, framework, int_dtype, float_dtype)
+
+
+@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
+class SD3TransformerOpenVINOConfig(UNetOnnxConfig):
+ DUMMY_INPUT_GENERATOR_CLASSES = (
+ (DummyTransformerTimestpsInputGenerator,)
+ + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+ + (PooledProjectionsDummyInputGenerator,)
+ )
+ NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+ image_size="sample_size",
+ num_channels="in_channels",
+ hidden_size="joint_attention_dim",
+ vocab_size="attention_head_dim",
+ allow_new=True,
+ )
+
+ @property
+ def inputs(self):
+ common_inputs = super().inputs
+ common_inputs["pooled_projections"] = {0: "batch_size"}
+ return common_inputs
+
+ def rename_ambiguous_inputs(self, inputs):
+ # The input name in the model signature is `x, hence the export input name is updated.
+ hidden_states = inputs.pop("sample", None)
+ if hidden_states is not None:
+ inputs["hidden_states"] = hidden_states
+ return inputs
+
+
+@register_in_tasks_manager("t5-encoder-model", *["feature-extraction"], library_name="diffusers")
+class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
+ pass
+
+
+class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator):
+ SUPPORTED_INPUT_NAMES = (
+ "pixel_values",
+ "pixel_mask",
+ "sample",
+ "latent_sample",
+ "hidden_states",
+ "img_ids",
+ )
+
+ def __init__(
+ self,
+ task: str,
+ normalized_config: NormalizedVisionConfig,
+ batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+ num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+ width: int = DEFAULT_DUMMY_SHAPES["width"],
+ height: int = DEFAULT_DUMMY_SHAPES["height"],
+ **kwargs,
+ ):
+ super().__init__(task, normalized_config, batch_size, num_channels, width, height, **kwargs)
+ if getattr(normalized_config, "in_channels", None):
+ self.num_channels = normalized_config.in_channels // 4
+
+ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+ if input_name in ["hidden_states", "sample"]:
+ shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels * 4]
+ return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
+ if input_name == "img_ids":
+ img_ids_height = self.height // 2
+ img_ids_width = self.width // 2
+ return self.random_int_tensor(
+ [self.batch_size, img_ids_height * img_ids_width, 3]
+ if is_diffusers_version("<", "0.31.0")
+ else [img_ids_height * img_ids_width, 3],
+ min_value=0,
+ max_value=min(img_ids_height, img_ids_width),
+ framework=framework,
+ dtype=float_dtype,
+ )
+
+ return super().generate(input_name, framework, int_dtype, float_dtype)
+
+
+class DummyFluxTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
+ SUPPORTED_INPUT_NAMES = (
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ "encoder_outputs",
+ "encoder_hidden_states",
+ "txt_ids",
+ )
+
+ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+ if input_name == "txt_ids":
+ import torch
+
+ shape = (
+ [self.batch_size, self.sequence_length, 3]
+ if is_diffusers_version("<", "0.31.0")
+ else [self.sequence_length, 3]
+ )
+ dtype = DTYPE_MAPPER.pt(float_dtype)
+ return torch.full(shape, 0, dtype=dtype)
+ return super().generate(input_name, framework, int_dtype, float_dtype)
+
+
+@register_in_tasks_manager("flux-transformer", *["semantic-segmentation"], library_name="diffusers")
+class FluxTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig):
+ DUMMY_INPUT_GENERATOR_CLASSES = (
+ DummyTransformerTimestpsInputGenerator,
+ DummyFluxTransformerInputGenerator,
+ DummyFluxTextInputGenerator,
+ PooledProjectionsDummyInputGenerator,
+ )
+
+ @property
+ def inputs(self):
+ common_inputs = super().inputs
+ common_inputs.pop("sample", None)
+ common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
+ common_inputs["txt_ids"] = (
+ {0: "batch_size", 1: "sequence_length"} if is_diffusers_version("<", "0.31.0") else {0: "sequence_length"}
+ )
+ common_inputs["img_ids"] = (
+ {0: "batch_size", 1: "packed_height_width"}
+ if is_diffusers_version("<", "0.31.0")
+ else {0: "packed_height_width"}
+ )
+ if getattr(self._normalized_config, "guidance_embeds", False):
+ common_inputs["guidance"] = {0: "batch_size"}
+ return common_inputs
+
+ def patch_model_for_export(
+ self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+ ) -> ModelPatcher:
+ return FluxTransfromerModelPatcher(self, model, model_kwargs=model_kwargs)
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index eadce6d382..7e5cd76a76 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -29,6 +29,7 @@
_openvino_version,
_torch_version,
_transformers_version,
+ is_diffusers_version,
is_openvino_version,
is_torch_version,
is_transformers_version,
@@ -2504,6 +2505,26 @@ def patched_forward(*args, **kwargs):
self.patched_forward = patched_forward
+ def __enter__(self):
+ super().__enter__()
+ if is_transformers_version(">=", "4.45.0"):
+ from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
+
+ sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
+ eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
+
+ for layer in self._model.model.layers:
+ if isinstance(layer.self_attn, eager_attn):
+ layer.self_attn._orig_forward = layer.self_attn.forward
+ layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ super().__exit__(exc_type, exc_value, traceback)
+ if is_transformers_version(">=", "4.45.0"):
+ for layer in self._model.model.layers:
+ if hasattr(layer.self_attn, "_orig_forward"):
+ layer.self_attn.forward = layer.self_attn._orig_forward
+
def _decilm_attn_forward(
self,
@@ -2705,3 +2726,40 @@ def __init__(
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
+
+
+def _embednb_forward(self, ids: torch.Tensor) -> torch.Tensor:
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+ assert dim % 2 == 0, "The dimension must be even."
+
+ scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+
+ batch_size, seq_length = pos.shape
+ out = pos.unsqueeze(-1) * omega.unsqueeze(0).unsqueeze(0)
+ cos_out = torch.cos(out)
+ sin_out = torch.sin(out)
+
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+ return out.float()
+
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+ return emb.unsqueeze(1)
+
+
+class FluxTransfromerModelPatcher(ModelPatcher):
+ def __enter__(self):
+ super().__enter__()
+ if is_diffusers_version("<", "0.31.0"):
+ self._model.pos_embed._orig_forward = self._model.pos_embed.forward
+ self._model.pos_embed.forward = types.MethodType(_embednb_forward, self._model.pos_embed)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ super().__exit__(exc_type, exc_value, traceback)
+ if hasattr(self._model.pos_embed, "_orig_forward"):
+ self._model.pos_embed.forward = self._model.pos_embed._orig_forward
diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py
index 5926f1869c..67a01011a2 100644
--- a/optimum/intel/__init__.py
+++ b/optimum/intel/__init__.py
@@ -100,8 +100,12 @@
"OVStableDiffusionXLPipeline",
"OVStableDiffusionXLImg2ImgPipeline",
"OVStableDiffusionXLInpaintPipeline",
+ "OVStableDiffusion3Pipeline",
+ "OVStableDiffusion3Image2ImagePipeline",
+ "OVStableDiffusion3InpaintPipeline",
"OVLatentConsistencyModelPipeline",
"OVLatentConsistencyModelImg2ImgPipeline",
+ "OVFluxPipeline",
"OVPipelineForImage2Image",
"OVPipelineForText2Image",
"OVPipelineForInpainting",
@@ -116,8 +120,12 @@
"OVStableDiffusionXLPipeline",
"OVStableDiffusionXLImg2ImgPipeline",
"OVStableDiffusionXLInpaintPipeline",
+ "OVStableDiffusion3Pipeline",
+ "OVStableDiffusion3Image2ImagePipeline",
+ "OVStableDiffusion3InpaintPipeline",
"OVLatentConsistencyModelPipeline",
"OVLatentConsistencyModelImg2ImgPipeline",
+ "OVFluxPipeline",
"OVPipelineForImage2Image",
"OVPipelineForText2Image",
"OVPipelineForInpainting",
@@ -263,10 +271,14 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_openvino_and_diffusers_objects import (
OVDiffusionPipeline,
+ OVFluxPipeline,
OVLatentConsistencyModelPipeline,
OVPipelineForImage2Image,
OVPipelineForInpainting,
OVPipelineForText2Image,
+ OVStableDiffusion3Img2ImgPipeline,
+ OVStableDiffusion3InpaintPipeline,
+ OVStableDiffusion3Pipeline,
OVStableDiffusionImg2ImgPipeline,
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
@@ -276,11 +288,15 @@
else:
from .openvino import (
OVDiffusionPipeline,
+ OVFluxPipeline,
OVLatentConsistencyModelImg2ImgPipeline,
OVLatentConsistencyModelPipeline,
OVPipelineForImage2Image,
OVPipelineForInpainting,
OVPipelineForText2Image,
+ OVStableDiffusion3Img2ImgPipeline,
+ OVStableDiffusion3InpaintPipeline,
+ OVStableDiffusion3Pipeline,
OVStableDiffusionImg2ImgPipeline,
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py
index 549bf8170d..589a0938e3 100644
--- a/optimum/intel/openvino/__init__.py
+++ b/optimum/intel/openvino/__init__.py
@@ -82,11 +82,15 @@
if is_diffusers_available():
from .modeling_diffusion import (
OVDiffusionPipeline,
+ OVFluxPipeline,
OVLatentConsistencyModelImg2ImgPipeline,
OVLatentConsistencyModelPipeline,
OVPipelineForImage2Image,
OVPipelineForInpainting,
OVPipelineForText2Image,
+ OVStableDiffusion3Img2ImgPipeline,
+ OVStableDiffusion3InpaintPipeline,
+ OVStableDiffusion3Pipeline,
OVStableDiffusionImg2ImgPipeline,
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py
index 22e8bf314f..8bca8cc9a8 100644
--- a/optimum/intel/openvino/modeling_diffusion.py
+++ b/optimum/intel/openvino/modeling_diffusion.py
@@ -22,7 +22,7 @@
from copy import deepcopy
from pathlib import Path
from tempfile import gettempdir
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
import numpy as np
import openvino
@@ -82,6 +82,20 @@
else:
from diffusers.models.vae import DiagonalGaussianDistribution
+if is_diffusers_version(">=", "0.29.0"):
+ from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
+else:
+ StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline = StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+
+if is_diffusers_version(">=", "0.30.0"):
+ from diffusers import FluxPipeline, StableDiffusion3InpaintPipeline
+else:
+ StableDiffusion3InpaintPipeline = StableDiffusionInpaintPipeline
+ FluxPipeline = StableDiffusionPipeline
+
+
+DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER = "transformer"
+DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER = "text_encoder_3"
core = Core()
@@ -99,15 +113,18 @@ class OVDiffusionPipeline(OVBaseModel, DiffusionPipeline):
def __init__(
self,
scheduler: SchedulerMixin,
- unet: openvino.runtime.Model,
- vae_decoder: openvino.runtime.Model,
+ unet: Optional[openvino.runtime.Model] = None,
+ vae_decoder: Optional[openvino.runtime.Model] = None,
# optional pipeline models
vae_encoder: Optional[openvino.runtime.Model] = None,
text_encoder: Optional[openvino.runtime.Model] = None,
text_encoder_2: Optional[openvino.runtime.Model] = None,
+ text_encoder_3: Optional[openvino.runtime.Model] = None,
+ transformer: Optional[openvino.runtime.Model] = None,
# optional pipeline submodels
tokenizer: Optional[CLIPTokenizer] = None,
tokenizer_2: Optional[CLIPTokenizer] = None,
+ tokenizer_3: Optional[CLIPTokenizer] = None,
feature_extractor: Optional[CLIPFeatureExtractor] = None,
# stable diffusion xl specific arguments
force_zeros_for_empty_prompt: bool = True,
@@ -149,7 +166,15 @@ def __init__(
f"Please set `compile_only=False` or `dynamic_shapes={model_is_dynamic}`"
)
- self.unet = OVModelUnet(unet, self, DIFFUSION_MODEL_UNET_SUBFOLDER)
+ self.unet = OVModelUnet(unet, self, DIFFUSION_MODEL_UNET_SUBFOLDER) if unet is not None else None
+ self.transformer = (
+ OVModelTransformer(transformer, self, DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER)
+ if transformer is not None
+ else None
+ )
+
+ if unet is None and transformer is None:
+ raise ValueError("`unet` or `transformer` model should be provided for pipeline work")
self.vae_decoder = OVModelVaeDecoder(vae_decoder, self, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER)
self.vae_encoder = (
OVModelVaeEncoder(vae_encoder, self, DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER)
@@ -166,12 +191,18 @@ def __init__(
if text_encoder_2 is not None
else None
)
+ self.text_encoder_3 = (
+ OVModelTextEncoder(text_encoder_3, self, DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER)
+ if text_encoder_3 is not None
+ else None
+ )
# We wrap the VAE Decoder & Encoder in a single object to simulate diffusers API
self.vae = OVModelVae(decoder=self.vae_decoder, encoder=self.vae_encoder)
self.scheduler = scheduler
self.tokenizer = tokenizer
self.tokenizer_2 = tokenizer_2
+ self.tokenizer_3 = tokenizer_3
self.feature_extractor = feature_extractor
# we allow passing these as torch models for now
@@ -181,13 +212,16 @@ def __init__(
all_pipeline_init_args = {
"vae": self.vae,
"unet": self.unet,
+ "transformer": self.transformer,
"text_encoder": self.text_encoder,
"text_encoder_2": self.text_encoder_2,
+ "text_encoder_3": self.text_encoder_3,
"safety_checker": self.safety_checker,
"image_encoder": self.image_encoder,
"scheduler": self.scheduler,
"tokenizer": self.tokenizer,
"tokenizer_2": self.tokenizer_2,
+ "tokenizer_3": self.tokenizer_3,
"feature_extractor": self.feature_extractor,
"requires_aesthetics_score": requires_aesthetics_score,
"force_zeros_for_empty_prompt": force_zeros_for_empty_prompt,
@@ -236,6 +270,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
(self.vae_encoder, save_directory / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER),
(self.text_encoder, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER),
(self.text_encoder_2, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER),
+ (self.text_encoder_3, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER),
+ (self.transformer, save_directory / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER),
}
for model, save_path in models_to_save_paths:
if model is not None:
@@ -254,6 +290,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
self.tokenizer.save_pretrained(save_directory / "tokenizer")
if self.tokenizer_2 is not None:
self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2")
+ if self.tokenizer_3 is not None:
+ self.tokenizer_3.save_pretrained(save_directory / "tokenizer_3")
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
@@ -294,6 +332,8 @@ def _from_pretrained(
vae_encoder_file_name: Optional[str] = None,
text_encoder_file_name: Optional[str] = None,
text_encoder_2_file_name: Optional[str] = None,
+ text_encoder_3_file_name: Optional[str] = None,
+ transformer_file_name: Optional[str] = None,
from_onnx: bool = False,
load_in_8bit: bool = False,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
@@ -314,6 +354,8 @@ def _from_pretrained(
vae_decoder_file_name = vae_decoder_file_name or default_file_name
text_encoder_file_name = text_encoder_file_name or default_file_name
text_encoder_2_file_name = text_encoder_2_file_name or default_file_name
+ text_encoder_3_file_name = text_encoder_3_file_name or default_file_name
+ transformer_file_name = transformer_file_name or default_file_name
if not os.path.isdir(str(model_id)):
all_components = {key for key in config.keys() if not key.startswith("_")} | {"vae_encoder", "vae_decoder"}
@@ -321,15 +363,19 @@ def _from_pretrained(
allow_patterns.update(
{
unet_file_name,
+ transformer_file_name,
vae_encoder_file_name,
vae_decoder_file_name,
text_encoder_file_name,
text_encoder_2_file_name,
+ text_encoder_3_file_name,
unet_file_name.replace(".xml", ".bin"),
+ transformer_file_name.replace(".xml", ".bin"),
vae_encoder_file_name.replace(".xml", ".bin"),
vae_decoder_file_name.replace(".xml", ".bin"),
text_encoder_file_name.replace(".xml", ".bin"),
text_encoder_2_file_name.replace(".xml", ".bin"),
+ text_encoder_3_file_name.replace(".xml", ".bin"),
SCHEDULER_CONFIG_NAME,
cls.config_name,
CONFIG_NAME,
@@ -357,9 +403,15 @@ def _from_pretrained(
if model_save_dir is None:
model_save_dir = model_save_path
- submodels = {"scheduler": None, "tokenizer": None, "tokenizer_2": None, "feature_extractor": None}
+ submodels = {
+ "scheduler": None,
+ "tokenizer": None,
+ "tokenizer_2": None,
+ "tokenizer_3": None,
+ "feature_extractor": None,
+ }
for name in submodels.keys():
- if kwargs.get(name, None) is not None:
+ if kwargs.get(name) is not None:
submodels[name] = kwargs.pop(name)
elif config.get(name, (None, None))[0] is not None:
library_name, library_classes = config.get(name)
@@ -374,17 +426,19 @@ def _from_pretrained(
models = {
"unet": model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name,
+ "transformer": model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER / transformer_file_name,
"vae_decoder": model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name,
"vae_encoder": model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
"text_encoder": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
"text_encoder_2": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name,
+ "text_encoder_3": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER / text_encoder_3_file_name,
}
compile_only = kwargs.get("compile_only", False)
quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
if (quantization_config is None or quantization_config.dataset is None) and not compile_only:
for name, path in models.items():
- if kwargs.get(name, None) is not None:
+ if name in kwargs:
models[name] = kwargs.pop(name)
else:
models[name] = cls.load_model(path, quantization_config) if path.is_file() else None
@@ -395,7 +449,7 @@ def _from_pretrained(
if "GPU" in device.upper() and "INFERENCE_PRECISION_HINT" not in vae_ov_conifg:
vae_ov_conifg["INFERENCE_PRECISION_HINT"] = "f32"
for name, path in models.items():
- if kwargs.get(name, None) is not None:
+ if name in kwargs:
models[name] = kwargs.pop(name)
else:
models[name] = (
@@ -416,7 +470,7 @@ def _from_pretrained(
from optimum.intel import OVQuantizer
for name, path in models.items():
- if kwargs.get(name, None) is not None:
+ if name in kwargs:
models[name] = kwargs.pop(name)
else:
models[name] = cls.load_model(path) if path.is_file() else None
@@ -431,7 +485,6 @@ def _from_pretrained(
quantizer.quantize(ov_config=OVConfig(quantization_config=hybrid_quantization_config))
return ov_pipeline
-
ov_pipeline = ov_pipeline_class(
**models,
**submodels,
@@ -483,6 +536,7 @@ def _from_transformers(
no_post_process=True,
revision=revision,
cache_dir=cache_dir,
+ task=cls.export_feature,
token=token,
local_files_only=local_files_only,
force_download=force_download,
@@ -515,7 +569,7 @@ def to(self, *args, device: Optional[str] = None, dtype: Optional[torch.dtype] =
if isinstance(device, str):
self._device = device.upper()
- self.request = None
+ self.clear_requests()
elif device is not None:
raise ValueError(
"The `device` argument should be a string representing the device on which the model should be loaded."
@@ -531,21 +585,24 @@ def to(self, *args, device: Optional[str] = None, dtype: Optional[torch.dtype] =
@property
def height(self) -> int:
- height = self.unet.model.inputs[0].get_partial_shape()[2]
+ model = self.vae.decoder.model
+ height = model.inputs[0].get_partial_shape()[2]
if height.is_dynamic:
return -1
return height.get_length() * self.vae_scale_factor
@property
def width(self) -> int:
- width = self.unet.model.inputs[0].get_partial_shape()[3]
+ model = self.vae.decoder.model
+ width = model.inputs[0].get_partial_shape()[3]
if width.is_dynamic:
return -1
return width.get_length() * self.vae_scale_factor
@property
def batch_size(self) -> int:
- batch_size = self.unet.model.inputs[0].get_partial_shape()[0]
+ model = self.unet.model if self.unet is not None else self.transformer.model
+ batch_size = model.inputs[0].get_partial_shape()[0]
if batch_size.is_dynamic:
return -1
return batch_size.get_length()
@@ -597,6 +654,65 @@ def _reshape_unet(
model.reshape(shapes)
return model
+ def _reshape_transformer(
+ self,
+ model: openvino.runtime.Model,
+ batch_size: int = -1,
+ height: int = -1,
+ width: int = -1,
+ num_images_per_prompt: int = -1,
+ tokenizer_max_length: int = -1,
+ ):
+ if batch_size == -1 or num_images_per_prompt == -1:
+ batch_size = -1
+ else:
+ # The factor of 2 comes from the guidance scale > 1
+ batch_size *= num_images_per_prompt
+ if "img_ids" not in {inputs.get_any_name() for inputs in model.inputs}:
+ batch_size *= 2
+
+ height = height // self.vae_scale_factor if height > 0 else height
+ width = width // self.vae_scale_factor if width > 0 else width
+ packed_height = height // 2 if height > 0 else height
+ packed_width = width // 2 if width > 0 else width
+ packed_height_width = packed_width * packed_height if height > 0 and width > 0 else -1
+ shapes = {}
+ for inputs in model.inputs:
+ shapes[inputs] = inputs.get_partial_shape()
+ if inputs.get_any_name() in ["timestep", "guidance"]:
+ shapes[inputs][0] = batch_size
+ elif inputs.get_any_name() == "hidden_states":
+ in_channels = self.transformer.config.get("in_channels", None)
+ if in_channels is None:
+ in_channels = (
+ shapes[inputs][1] if inputs.get_partial_shape().rank.get_length() == 4 else shapes[inputs][2]
+ )
+ if in_channels.is_dynamic:
+ logger.warning(
+ "Could not identify `in_channels` from the unet configuration, to statically reshape the unet please provide a configuration."
+ )
+ self.is_dynamic = True
+ if inputs.get_partial_shape().rank.get_length() == 4:
+ shapes[inputs] = [batch_size, in_channels, height, width]
+ else:
+ shapes[inputs] = [batch_size, packed_height_width, in_channels]
+
+ elif inputs.get_any_name() == "pooled_projections":
+ shapes[inputs] = [batch_size, self.transformer.config["pooled_projection_dim"]]
+ elif inputs.get_any_name() == "img_ids":
+ shapes[inputs] = (
+ [batch_size, packed_height_width, 3]
+ if is_diffusers_version("<", "0.31.0")
+ else [packed_height_width, 3]
+ )
+ elif inputs.get_any_name() == "txt_ids":
+ shapes[inputs] = [batch_size, -1, 3] if is_diffusers_version("<", "0.31.0") else [-1, 3]
+ else:
+ shapes[inputs][0] = batch_size
+ shapes[inputs][1] = -1 # text_encoder_3 may have vary input length
+ model.reshape(shapes)
+ return model
+
def _reshape_text_encoder(
self, model: openvino.runtime.Model, batch_size: int = -1, tokenizer_max_length: int = -1
):
@@ -658,9 +774,14 @@ def reshape(
self.tokenizer.model_max_length if self.tokenizer is not None else self.tokenizer_2.model_max_length
)
- self.unet.model = self._reshape_unet(
- self.unet.model, batch_size, height, width, num_images_per_prompt, tokenizer_max_len
- )
+ if self.unet is not None:
+ self.unet.model = self._reshape_unet(
+ self.unet.model, batch_size, height, width, num_images_per_prompt, tokenizer_max_len
+ )
+ if self.transformer is not None:
+ self.transformer.model = self._reshape_transformer(
+ self.transformer.model, batch_size, height, width, num_images_per_prompt, tokenizer_max_len
+ )
self.vae_decoder.model = self._reshape_vae_decoder(
self.vae_decoder.model, height, width, num_images_per_prompt
)
@@ -678,6 +799,11 @@ def reshape(
self.text_encoder_2.model, batch_size, self.tokenizer_2.model_max_length
)
+ if self.text_encoder_3 is not None:
+ self.text_encoder_3.model = self._reshape_text_encoder(
+ self.text_encoder_3.model, batch_size, self.tokenizer_3.model_max_length
+ )
+
self.clear_requests()
return self
@@ -690,7 +816,15 @@ def half(self):
"`half()` is not supported with `compile_only` mode, please intialize model without this option"
)
- for component in {self.unet, self.vae_encoder, self.vae_decoder, self.text_encoder, self.text_encoder_2}:
+ for component in {
+ self.unet,
+ self.transformer,
+ self.vae_encoder,
+ self.vae_decoder,
+ self.text_encoder,
+ self.text_encoder_2,
+ self.text_encoder_3,
+ }:
if component is not None:
compress_model_transformation(component.model)
@@ -704,12 +838,28 @@ def clear_requests(self):
"`clear_requests()` is not supported with `compile_only` mode, please intialize model without this option"
)
- for component in {self.unet, self.vae_encoder, self.vae_decoder, self.text_encoder, self.text_encoder_2}:
+ for component in [
+ self.unet,
+ self.transformer,
+ self.vae_encoder,
+ self.vae_decoder,
+ self.text_encoder,
+ self.text_encoder_2,
+ self.text_encoder_3,
+ ]:
if component is not None:
component.request = None
def compile(self):
- for component in {self.unet, self.vae_encoder, self.vae_decoder, self.text_encoder, self.text_encoder_2}:
+ for component in [
+ self.unet,
+ self.transformer,
+ self.vae_encoder,
+ self.vae_decoder,
+ self.text_encoder,
+ self.text_encoder_2,
+ self.text_encoder_3,
+ ]:
if component is not None:
component._compile()
@@ -725,8 +875,10 @@ def components(self) -> Dict[str, Any]:
components = {
"vae": self.vae,
"unet": self.unet,
+ "transformer": self.transformer,
"text_encoder": self.text_encoder,
"text_encoder_2": self.text_encoder_2,
+ "text_encoder_3": self.text_encoder_2,
"safety_checker": self.safety_checker,
"image_encoder": self.image_encoder,
}
@@ -841,6 +993,12 @@ def modules(self):
class OVModelTextEncoder(OVPipelinePart):
+ def __init__(self, model: openvino.runtime.Model, parent_pipeline: OVDiffusionPipeline, model_name: str = ""):
+ super().__init__(model, parent_pipeline, model_name)
+ self.hidden_states_output_names = sorted(
+ {name for out in self.model.outputs for name in out.names if name.startswith("hidden_states")}
+ )
+
def forward(
self,
input_ids: Union[np.ndarray, torch.Tensor],
@@ -849,24 +1007,26 @@ def forward(
return_dict: bool = False,
):
self._compile()
-
model_inputs = {"input_ids": input_ids}
- ov_outputs = self.request(model_inputs, share_inputs=True).to_dict()
-
+ ov_outputs = self.request(model_inputs, share_inputs=True)
+ main_out = ov_outputs[0]
model_outputs = {}
- for key, value in ov_outputs.items():
- model_outputs[next(iter(key.names))] = torch.from_numpy(value)
-
- if output_hidden_states:
- model_outputs["hidden_states"] = []
- for i in range(self.config.num_hidden_layers):
- model_outputs["hidden_states"].append(model_outputs.pop(f"hidden_states.{i}"))
- model_outputs["hidden_states"].append(model_outputs.get("last_hidden_state"))
+ model_outputs[self.model.outputs[0].get_any_name()] = torch.from_numpy(main_out)
+ if len(self.model.outputs) > 1 and "pooler_output" in self.model.outputs[1].get_any_name():
+ model_outputs["pooler_output"] = torch.from_numpy(ov_outputs[1])
+ if self.hidden_states_output_names and "last_hidden_state" not in model_outputs:
+ model_outputs["last_hidden_state"] = torch.from_numpy(ov_outputs[self.hidden_states_output_names[-1]])
+ if (
+ self.hidden_states_output_names
+ and output_hidden_states
+ or getattr(self.config, "output_hidden_states", False)
+ ):
+ hidden_states = [torch.from_numpy(ov_outputs[out_name]) for out_name in self.hidden_states_output_names]
+ model_outputs["hidden_states"] = hidden_states
if return_dict:
return model_outputs
-
return ModelOutput(**model_outputs)
@@ -924,6 +1084,48 @@ def forward(
return ModelOutput(**model_outputs)
+class OVModelTransformer(OVPipelinePart):
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ block_controlnet_hidden_states: List = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ self._compile()
+
+ model_inputs = {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_projections,
+ }
+
+ if img_ids is not None:
+ model_inputs["img_ids"] = img_ids
+ if txt_ids is not None:
+ model_inputs["txt_ids"] = txt_ids
+ if guidance is not None:
+ model_inputs["guidance"] = guidance
+
+ ov_outputs = self.request(model_inputs, share_inputs=True).to_dict()
+
+ model_outputs = {}
+ for key, value in ov_outputs.items():
+ model_outputs[next(iter(key.names))] = torch.from_numpy(value)
+
+ if return_dict:
+ return model_outputs
+
+ return ModelOutput(**model_outputs)
+
+
class OVModelVaeEncoder(OVPipelinePart):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1197,6 +1399,34 @@ class OVLatentConsistencyModelImg2ImgPipeline(
auto_model_class = LatentConsistencyModelImg2ImgPipeline
+class OVStableDiffusion3Pipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusion3Pipeline):
+ main_input_name = "prompt"
+ export_feature = "text-to-image"
+ auto_model_class = StableDiffusion3Pipeline
+
+
+class OVStableDiffusion3Img2ImgPipeline(
+ OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusion3Img2ImgPipeline
+):
+ main_input_name = "image"
+ export_feature = "image-to-image"
+ auto_model_class = StableDiffusion3Img2ImgPipeline
+
+
+class OVStableDiffusion3InpaintPipeline(
+ OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusion3InpaintPipeline
+):
+ main_input_name = "image"
+ export_feature = "inpainting"
+ auto_model_class = StableDiffusion3InpaintPipeline
+
+
+class OVFluxPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, FluxPipeline):
+ main_input_name = "prompt"
+ export_feature = "text-to-image"
+ auto_model_class = FluxPipeline
+
+
SUPPORTED_OV_PIPELINES = [
OVStableDiffusionPipeline,
OVStableDiffusionImg2ImgPipeline,
@@ -1244,6 +1474,23 @@ def _get_ov_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tru
]
)
+if is_diffusers_version(">=", "0.29.0"):
+ SUPPORTED_OV_PIPELINES.extend(
+ [
+ OVStableDiffusion3Pipeline,
+ OVStableDiffusion3Img2ImgPipeline,
+ ]
+ )
+
+ OV_TEXT2IMAGE_PIPELINES_MAPPING["stable-diffusion-3"] = OVStableDiffusion3Pipeline
+ OV_IMAGE2IMAGE_PIPELINES_MAPPING["stable-diffusion-3"] = OVStableDiffusion3Img2ImgPipeline
+
+if is_diffusers_version(">=", "0.30.0"):
+ SUPPORTED_OV_PIPELINES.extend([OVStableDiffusion3InpaintPipeline, OVFluxPipeline])
+ OV_INPAINT_PIPELINES_MAPPING["stable-diffusion-3"] = OVStableDiffusion3InpaintPipeline
+ OV_TEXT2IMAGE_PIPELINES_MAPPING["flux"] = OVFluxPipeline
+
+
SUPPORTED_OV_PIPELINES_MAPPINGS = [
OV_TEXT2IMAGE_PIPELINES_MAPPING,
OV_IMAGE2IMAGE_PIPELINES_MAPPING,
@@ -1299,13 +1546,16 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
class OVPipelineForText2Image(OVPipelineForTask):
auto_model_class = AutoPipelineForText2Image
ov_pipelines_mapping = OV_TEXT2IMAGE_PIPELINES_MAPPING
+ export_feature = "text-to-image"
class OVPipelineForImage2Image(OVPipelineForTask):
auto_model_class = AutoPipelineForImage2Image
ov_pipelines_mapping = OV_IMAGE2IMAGE_PIPELINES_MAPPING
+ export_feature = "image-to-image"
class OVPipelineForInpainting(OVPipelineForTask):
auto_model_class = AutoPipelineForInpainting
ov_pipelines_mapping = OV_INPAINT_PIPELINES_MAPPING
+ export_feature = "inpainting"
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 1ad75477cc..c2e880e62a 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -380,15 +380,27 @@ def _quantize_ovbasemodel(
quantization_config_copy = copy.deepcopy(quantization_config)
quantization_config_copy.dataset = None
quantization_config_copy.quant_method = OVQuantizationMethod.DEFAULT
- sub_model_names = ["vae_encoder", "vae_decoder", "text_encoder", "text_encoder_2"]
+ sub_model_names = [
+ "vae_encoder",
+ "vae_decoder",
+ "text_encoder",
+ "text_encoder_2",
+ "text_encoder_3",
+ ]
sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
for sub_model in sub_models:
_weight_only_quantization(sub_model.model, quantization_config_copy)
- # Apply hybrid quantization to UNet
- self.model.unet.model = _hybrid_quantization(
- self.model.unet.model, quantization_config, calibration_dataset
- )
+ if self.model.unet is not None:
+ # Apply hybrid quantization to UNet
+ self.model.unet.model = _hybrid_quantization(
+ self.model.unet.model, quantization_config, calibration_dataset
+ )
+ else:
+ self.model.transformer.model = _hybrid_quantization(
+ self.model.transformer.model, quantization_config, calibration_dataset
+ )
+
self.model.clear_requests()
else:
# The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc.
@@ -396,7 +408,15 @@ def _quantize_ovbasemodel(
self.model.request = None
else:
if is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline):
- sub_model_names = ["vae_encoder", "vae_decoder", "text_encoder", "text_encoder_2", "unet"]
+ sub_model_names = [
+ "vae_encoder",
+ "vae_decoder",
+ "text_encoder",
+ "text_encoder_2",
+ "unet",
+ "transformer",
+ "text_encoder_3",
+ ]
sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
for sub_model in sub_models:
_weight_only_quantization(sub_model.model, quantization_config)
@@ -743,7 +763,9 @@ def _prepare_unet_dataset(
) -> nncf.Dataset:
self.model.compile()
- size = self.model.unet.config.get("sample_size", 64) * self.model.vae_scale_factor
+ diffuser = self.model.unet if self.model.unet is not None else self.model.transformer
+
+ size = diffuser.config.get("sample_size", 64) * self.model.vae_scale_factor
height, width = 2 * (min(size, 512),)
num_samples = num_samples or 200
@@ -784,7 +806,7 @@ def transform_fn(data_item):
calibration_data = []
try:
- self.model.unet.request = InferRequestWrapper(self.model.unet.request, calibration_data)
+ diffuser.request = InferRequestWrapper(diffuser.request, calibration_data)
for inputs in dataset:
inputs = transform_fn(inputs)
@@ -795,7 +817,7 @@ def transform_fn(data_item):
if len(calibration_data) >= num_samples:
break
finally:
- self.model.unet.request = self.model.unet.request.request
+ diffuser.request = diffuser.request.request
calibration_dataset = nncf.Dataset(calibration_data[:num_samples])
return calibration_dataset
diff --git a/optimum/intel/openvino/utils.py b/optimum/intel/openvino/utils.py
index fcc6944e9f..ca7d177201 100644
--- a/optimum/intel/openvino/utils.py
+++ b/optimum/intel/openvino/utils.py
@@ -119,6 +119,8 @@
"audio-classification": "OVModelForAudioClassification",
"stable-diffusion": "OVStableDiffusionPipeline",
"stable-diffusion-xl": "OVStableDiffusionXLPipeline",
+ "stable-diffusion-3": "OVStableDiffusion3Pipeline",
+ "flux": "OVFluxPipeline",
"pix2struct": "OVModelForPix2Struct",
"latent-consistency": "OVLatentConsistencyModelPipeline",
"open_clip_text": "OVModelOpenCLIPText",
diff --git a/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py b/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py
index 6ded4fd5df..38aea6c1f1 100644
--- a/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py
+++ b/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py
@@ -145,3 +145,47 @@ def __init__(self, *args, **kwargs):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["openvino", "diffusers"])
+
+
+class OVStableDiffusion3Img2ImgPipeline(metaclass=DummyObject):
+ _backends = ["openvino", "diffusers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["openvino", "diffusers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["openvino", "diffusers"])
+
+
+class OVStableDiffusion3Pipeline(metaclass=DummyObject):
+ _backends = ["openvino", "diffusers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["openvino", "diffusers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["openvino", "diffusers"])
+
+
+class OVStableDiffusion3InpaintPipeline(metaclass=DummyObject):
+ _backends = ["openvino", "diffusers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["openvino", "diffusers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["openvino", "diffusers"])
+
+
+class OVFluxPipeline(metaclass=DummyObject):
+ _backends = ["openvino", "diffusers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["openvino", "diffusers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["openvino", "diffusers"])
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index a05efc46c7..a39957bbf7 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -123,17 +123,20 @@ def _find_files_matching_pattern(
str(model_name_or_path), subfolder=subfolder, revision=revision, token=token
)
if library_name == "diffusers":
- subfolder = os.path.join(subfolder, "unet")
+ subfolders = [os.path.join(subfolder, "unet"), os.path.join(subfolder, "transformer")]
else:
- subfolder = subfolder or "."
+ subfolders = [subfolder or "."]
if model_path.is_dir():
- glob_pattern = subfolder + "/*"
- files = model_path.glob(glob_pattern)
- files = [p for p in files if re.search(pattern, str(p))]
+ files = []
+ for subfolder in subfolders:
+ glob_pattern = subfolder + "/*"
+ files_ = model_path.glob(glob_pattern)
+ files_ = [p for p in files_ if re.search(pattern, str(p))]
+ files.extend(files_)
else:
repo_files = map(Path, HfApi().list_repo_files(model_name_or_path, revision=revision, token=token))
- files = [Path(p) for p in repo_files if re.match(pattern, str(p)) and str(p.parent) == subfolder]
+ files = [Path(p) for p in repo_files if re.match(pattern, str(p)) and str(p.parent) in subfolders]
return files
diff --git a/optimum/intel/version.py b/optimum/intel/version.py
index e118ea7131..16bf124e0e 100644
--- a/optimum/intel/version.py
+++ b/optimum/intel/version.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "1.20.0.dev0"
+__version__ = "1.21.0.dev0"
diff --git a/tests/openvino/test_diffusion.py b/tests/openvino/test_diffusion.py
index 687c1f5c02..1467e5ed1f 100644
--- a/tests/openvino/test_diffusion.py
+++ b/tests/openvino/test_diffusion.py
@@ -35,6 +35,7 @@
OVPipelineForInpainting,
OVPipelineForText2Image,
)
+from optimum.intel.utils.import_utils import is_transformers_version
from optimum.utils.testing_utils import require_diffusers
@@ -73,6 +74,11 @@ def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type=
class OVPipelineForText2ImageTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]
+ NEGATIVE_PROMPT_SUPPORT_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]
+ if is_transformers_version(">=", "4.40.0"):
+ SUPPORTED_ARCHITECTURES.extend(["stable-diffusion-3", "flux"])
+ NEGATIVE_PROMPT_SUPPORT_ARCHITECTURES.append("stable-diffusion-3")
+ CALLBACK_SUPPORT_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]
OVMODEL_CLASS = OVPipelineForText2Image
AUTOMODEL_CLASS = AutoPipelineForText2Image
@@ -126,8 +132,8 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
height, width, batch_size = 128, 128, 1
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
- ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
- diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
+ ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
+ diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
for output_type in ["latent", "np", "pt"]:
inputs["output_type"] = output_type
@@ -135,9 +141,9 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
- np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
+ np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
- @parameterized.expand(SUPPORTED_ARCHITECTURES)
+ @parameterized.expand(CALLBACK_SUPPORT_ARCHITECTURES)
@require_diffusers
def test_callback(self, model_arch: str):
height, width, batch_size = 64, 128, 1
@@ -184,10 +190,26 @@ def test_shape(self, model_arch: str):
elif output_type == "pt":
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
else:
- self.assertEqual(
- outputs.shape,
- (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
- )
+ if model_arch != "flux":
+ out_channels = (
+ pipeline.unet.config.out_channels
+ if pipeline.unet is not None
+ else pipeline.transformer.config.out_channels
+ )
+ self.assertEqual(
+ outputs.shape,
+ (
+ batch_size,
+ out_channels,
+ height // pipeline.vae_scale_factor,
+ width // pipeline.vae_scale_factor,
+ ),
+ )
+ else:
+ packed_height = height // pipeline.vae_scale_factor
+ packed_width = width // pipeline.vae_scale_factor
+ channels = pipeline.transformer.config.in_channels
+ self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
@@ -205,7 +227,7 @@ def test_image_reproducibility(self, model_arch: str):
self.assertFalse(np.array_equal(ov_outputs_1.images[0], ov_outputs_3.images[0]))
np.testing.assert_allclose(ov_outputs_1.images[0], ov_outputs_2.images[0], atol=1e-4, rtol=1e-2)
- @parameterized.expand(SUPPORTED_ARCHITECTURES)
+ @parameterized.expand(NEGATIVE_PROMPT_SUPPORT_ARCHITECTURES)
def test_negative_prompt(self, model_arch: str):
height, width, batch_size = 64, 64, 1
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
@@ -229,6 +251,22 @@ def test_negative_prompt(self, model_arch: str):
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
+ elif model_arch == "stable-diffusion-3":
+ (
+ inputs["prompt_embeds"],
+ inputs["negative_prompt_embeds"],
+ inputs["pooled_prompt_embeds"],
+ inputs["negative_pooled_prompt_embeds"],
+ ) = pipeline.encode_prompt(
+ prompt=prompt,
+ prompt_2=None,
+ prompt_3=None,
+ num_images_per_prompt=1,
+ device=torch.device("cpu"),
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ )
+
else:
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = pipeline.encode_prompt(
prompt=prompt,
@@ -288,11 +326,18 @@ def test_height_width_properties(self, model_arch: str):
)
self.assertFalse(ov_pipeline.is_dynamic)
+ expected_batch = batch_size * num_images_per_prompt
+ if (
+ ov_pipeline.unet is not None
+ and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}
+ ) or (
+ ov_pipeline.transformer is not None
+ and "txt_ids" not in {inputs.get_any_name() for inputs in ov_pipeline.transformer.model.inputs}
+ ):
+ expected_batch *= 2
self.assertEqual(
ov_pipeline.batch_size,
- batch_size
- * num_images_per_prompt
- * (2 if "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs} else 1),
+ expected_batch,
)
self.assertEqual(ov_pipeline.height, height)
self.assertEqual(ov_pipeline.width, width)
@@ -324,6 +369,8 @@ def test_textual_inversion(self):
class OVPipelineForImage2ImageTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]
+ if is_transformers_version(">=", "4.40.0"):
+ SUPPORTED_ARCHITECTURES.append("stable-diffusion-3")
AUTOMODEL_CLASS = AutoPipelineForImage2Image
OVMODEL_CLASS = OVPipelineForImage2Image
@@ -369,7 +416,7 @@ def test_num_images_per_prompt(self, model_arch: str):
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
- @parameterized.expand(SUPPORTED_ARCHITECTURES)
+ @parameterized.expand(["stable-diffusion", "stable-diffusion-xl", "latent-consistency"])
@require_diffusers
def test_callback(self, model_arch: str):
height, width, batch_size = 32, 64, 1
@@ -416,9 +463,19 @@ def test_shape(self, model_arch: str):
elif output_type == "pt":
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
else:
+ out_channels = (
+ pipeline.unet.config.out_channels
+ if pipeline.unet is not None
+ else pipeline.transformer.config.out_channels
+ )
self.assertEqual(
outputs.shape,
- (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
+ (
+ batch_size,
+ out_channels,
+ height // pipeline.vae_scale_factor,
+ width // pipeline.vae_scale_factor,
+ ),
)
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -427,16 +484,17 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
height, width, batch_size = 128, 128, 1
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
- diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
- ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
+ diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
+ ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
for output_type in ["latent", "np", "pt"]:
+ print(output_type)
inputs["output_type"] = output_type
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
- np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
+ np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
@@ -500,12 +558,12 @@ def test_height_width_properties(self, model_arch: str):
)
self.assertFalse(ov_pipeline.is_dynamic)
- self.assertEqual(
- ov_pipeline.batch_size,
- batch_size
- * num_images_per_prompt
- * (2 if "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs} else 1),
- )
+ expected_batch = batch_size * num_images_per_prompt
+ if ov_pipeline.unet is None or "timestep_cond" not in {
+ inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs
+ }:
+ expected_batch *= 2
+ self.assertEqual(ov_pipeline.batch_size, expected_batch)
self.assertEqual(ov_pipeline.height, height)
self.assertEqual(ov_pipeline.width, width)
@@ -537,6 +595,9 @@ def test_textual_inversion(self):
class OVPipelineForInpaintingTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl"]
+ if is_transformers_version(">=", "4.40.0"):
+ SUPPORTED_ARCHITECTURES.append("stable-diffusion-3")
+
AUTOMODEL_CLASS = AutoPipelineForInpainting
OVMODEL_CLASS = OVPipelineForInpainting
@@ -586,7 +647,7 @@ def test_num_images_per_prompt(self, model_arch: str):
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
- @parameterized.expand(SUPPORTED_ARCHITECTURES)
+ @parameterized.expand(["stable-diffusion", "stable-diffusion-xl"])
@require_diffusers
def test_callback(self, model_arch: str):
height, width, batch_size = 32, 64, 1
@@ -633,9 +694,19 @@ def test_shape(self, model_arch: str):
elif output_type == "pt":
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
else:
+ out_channels = (
+ pipeline.unet.config.out_channels
+ if pipeline.unet is not None
+ else pipeline.transformer.config.out_channels
+ )
self.assertEqual(
outputs.shape,
- (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
+ (
+ batch_size,
+ out_channels,
+ height // pipeline.vae_scale_factor,
+ width // pipeline.vae_scale_factor,
+ ),
)
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -653,7 +724,7 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
- np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
+ np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
@@ -717,11 +788,14 @@ def test_height_width_properties(self, model_arch: str):
)
self.assertFalse(ov_pipeline.is_dynamic)
+ expected_batch = batch_size * num_images_per_prompt
+ if ov_pipeline.unet is None or "timestep_cond" not in {
+ inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs
+ }:
+ expected_batch *= 2
self.assertEqual(
ov_pipeline.batch_size,
- batch_size
- * num_images_per_prompt
- * (2 if "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs} else 1),
+ expected_batch,
)
self.assertEqual(ov_pipeline.height, height)
self.assertEqual(ov_pipeline.width, width)
diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py
index 43c535e673..6a42c4a09f 100644
--- a/tests/openvino/test_export.py
+++ b/tests/openvino/test_export.py
@@ -27,6 +27,7 @@
from optimum.exporters.openvino import export_from_model, main_export
from optimum.exporters.tasks import TasksManager
from optimum.intel import (
+ OVFluxPipeline,
OVLatentConsistencyModelPipeline,
OVModelForAudioClassification,
OVModelForCausalLM,
@@ -40,13 +41,14 @@
OVModelForSequenceClassification,
OVModelForSpeechSeq2Seq,
OVModelForTokenClassification,
+ OVStableDiffusion3Pipeline,
OVStableDiffusionPipeline,
OVStableDiffusionXLImg2ImgPipeline,
OVStableDiffusionXLPipeline,
)
from optimum.intel.openvino.modeling_base import OVBaseModel
from optimum.intel.openvino.utils import TemporaryDirectory
-from optimum.intel.utils.import_utils import _transformers_version
+from optimum.intel.utils.import_utils import _transformers_version, is_transformers_version
from optimum.utils.save_utils import maybe_load_preprocessors
@@ -70,6 +72,9 @@ class ExportModelTest(unittest.TestCase):
"latent-consistency": OVLatentConsistencyModelPipeline,
}
+ if is_transformers_version(">=", "4.45"):
+ SUPPORTED_ARCHITECTURES.update({"stable-diffusion-3": OVStableDiffusion3Pipeline, "flux": OVFluxPipeline})
+
GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper")
def _openvino_export(
diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py
index cea6c94fcd..7542a347da 100644
--- a/tests/openvino/test_exporters_cli.py
+++ b/tests/openvino/test_exporters_cli.py
@@ -25,6 +25,7 @@
from optimum.exporters.openvino.__main__ import main_export
from optimum.intel import ( # noqa
+ OVFluxPipeline,
OVLatentConsistencyModelPipeline,
OVModelForAudioClassification,
OVModelForCausalLM,
@@ -39,6 +40,7 @@
OVModelOpenCLIPText,
OVModelOpenCLIPVisual,
OVSentenceTransformer,
+ OVStableDiffusion3Pipeline,
OVStableDiffusionPipeline,
OVStableDiffusionXLPipeline,
)
@@ -48,6 +50,7 @@
compare_versions,
is_openvino_tokenizers_available,
is_tokenizers_version,
+ is_transformers_version,
)
@@ -56,7 +59,7 @@ class OVCLIExportTestCase(unittest.TestCase):
Integration tests ensuring supported models are correctly exported.
"""
- SUPPORTED_ARCHITECTURES = (
+ SUPPORTED_ARCHITECTURES = [
("text-generation", "gpt2"),
("text-generation-with-past", "gpt2"),
("text2text-generation", "t5"),
@@ -71,7 +74,10 @@ class OVCLIExportTestCase(unittest.TestCase):
("text-to-image", "stable-diffusion"),
("text-to-image", "stable-diffusion-xl"),
("image-to-image", "stable-diffusion-xl-refiner"),
- )
+ ]
+
+ if is_transformers_version(">=", "4.45"):
+ SUPPORTED_ARCHITECTURES.extend([("text-to-image", "stable-diffusion-3"), ("text-to-image", "flux")])
EXPECTED_NUMBER_OF_TOKENIZER_MODELS = {
"gpt2": 2 if is_tokenizers_version("<", "0.20") else 0,
"t5": 0, # no .model file in the repository
@@ -84,13 +90,18 @@ class OVCLIExportTestCase(unittest.TestCase):
"blenderbot": 2 if is_tokenizers_version("<", "0.20") else 0,
"stable-diffusion": 2 if is_tokenizers_version("<", "0.20") else 0,
"stable-diffusion-xl": 4 if is_tokenizers_version("<", "0.20") else 0,
+ "stable-diffusion-3": 6 if is_tokenizers_version("<", "0.20") else 2,
+ "flux": 4 if is_tokenizers_version("<", "0.20") else 0,
}
- SUPPORTED_SD_HYBRID_ARCHITECTURES = (
+ SUPPORTED_SD_HYBRID_ARCHITECTURES = [
("stable-diffusion", 72, 195),
("stable-diffusion-xl", 84, 331),
("latent-consistency", 50, 135),
- )
+ ]
+
+ if is_transformers_version(">=", "4.45"):
+ SUPPORTED_SD_HYBRID_ARCHITECTURES.append(("stable-diffusion-3", 9, 65))
TEST_4BIT_CONFIGURATONS = [
("text-generation-with-past", "opt125m", "int4 --sym --group-size 128", {"int8": 4, "int4": 72}),
@@ -208,8 +219,8 @@ def test_exporters_cli_int8(self, task: str, model_type: str):
models = [model.encoder, model.decoder]
if task.endswith("with-past"):
models.append(model.decoder_with_past)
- elif model_type.startswith("stable-diffusion"):
- models = [model.unet, model.vae_encoder, model.vae_decoder]
+ elif model_type.startswith("stable-diffusion") or model_type.startswith("flux"):
+ models = [model.unet or model.transformer, model.vae_encoder, model.vae_decoder]
models.append(model.text_encoder if model_type == "stable-diffusion" else model.text_encoder_2)
else:
models = [model]
@@ -228,7 +239,9 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in
check=True,
)
model = eval(_HEAD_TO_AUTOMODELS[model_type.replace("-refiner", "")]).from_pretrained(tmpdir)
- num_fq, num_weight_nodes = get_num_quantized_nodes(model.unet)
+ num_fq, num_weight_nodes = get_num_quantized_nodes(
+ model.unet if model.unet is not None else model.transformer
+ )
self.assertEqual(exp_num_int8, num_weight_nodes["int8"])
self.assertEqual(exp_num_fq, num_fq)
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 606f0a1048..68261dbd48 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -863,6 +863,10 @@ def test_compare_to_transformers(self, model_arch):
if model_arch in self.REMOTE_CODE_MODELS:
model_kwargs = {"trust_remote_code": True}
+ # starting from transformers 4.45.0 gemma2 uses eager attention by default, while ov - sdpa
+ if model_arch == "gemma2" and is_transformers_version(">=", "4.45.0"):
+ model_kwargs["attn_implementation"] = "sdpa"
+
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG, **model_kwargs)
self.assertIsInstance(ov_model.config, PretrainedConfig)
self.assertTrue(ov_model.use_cache)
@@ -1096,6 +1100,11 @@ def test_beam_search(self, model_arch):
"trust_remote_code": True,
}
+
+ # starting from transformers 4.45.0 gemma2 uses eager attention by default, while ov - sdpa
+ if model_arch == "gemma2" and is_transformers_version(">=", "4.45.0"):
+ model_kwargs["attn_implementation"] = "sdpa"
+
# Qwen tokenizer does not support padding, chatglm, glm4 testing models produce nan that incompatible with beam search
if model_arch in ["qwen", "chatglm", "glm4"]:
return
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index b294e3e221..f2a4dc723f 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -56,6 +56,8 @@
OVModelForSpeechSeq2Seq,
OVStableDiffusionPipeline,
OVStableDiffusionXLPipeline,
+ OVStableDiffusion3Pipeline,
+ OVFluxPipeline,
OVQuantizer,
OVTrainer,
OVQuantizationConfig,
@@ -300,11 +302,18 @@ class OVWeightCompressionTest(unittest.TestCase):
(OVModelOpenCLIPForZeroShotImageClassification, "open-clip"),
)
- SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION = (
+ SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION = [
(OVStableDiffusionPipeline, "stable-diffusion", 72, 195),
(OVStableDiffusionXLPipeline, "stable-diffusion-xl", 84, 331),
(OVLatentConsistencyModelPipeline, "latent-consistency", 50, 135),
- )
+ ]
+
+ if is_transformers_version(">=", "4.45.0"):
+ SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION.extend(
+ [
+ (OVStableDiffusion3Pipeline, "stable-diffusion-3", 9, 65),
+ ]
+ )
IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")
@@ -454,7 +463,9 @@ def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_num_f
with TemporaryDirectory() as tmp_dir:
model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config)
- num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model.unet)
+ num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+ model.unet if model.unet is not None else model.transformer
+ )
self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
self.assertEqual(expected_ov_int8, num_weight_nodes["int8"])
self.assertEqual(0, num_weight_nodes["int4"])
@@ -468,7 +479,9 @@ def test_stable_diffusion_with_weight_compression(self):
quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))
- num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(int8_pipe.unet)
+ num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+ int8_pipe.unet if int8_pipe.unet is not None else int8_pipe.transformer
+ )
self.assertEqual(0, num_fake_quantize)
self.assertEqual(242, num_weight_nodes["int8"])
self.assertEqual(0, num_weight_nodes["int4"])
@@ -487,7 +500,9 @@ def test_ovmodel_hybrid_quantization_with_custom_dataset(
self.assertEqual(quantization_config.quant_method, OVQuantizationMethod.HYBRID)
quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config), calibration_dataset=dataset)
- num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model.unet)
+ num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+ model.unet if model.unet is not None else model.transformer
+ )
self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
self.assertEqual(expected_ov_int8, num_weight_nodes["int8"])
self.assertEqual(0, num_weight_nodes["int4"])
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index d7eea01dba..e5a9f73a64 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -59,6 +59,7 @@
"falcon": "fxmarty/really-tiny-falcon-testing",
"falcon-40b": "katuni4ka/tiny-random-falcon-40b",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
+ "flux": "katuni4ka/tiny-random-flux",
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
@@ -118,6 +119,7 @@
"stable-diffusion-openvino": "hf-internal-testing/tiny-stable-diffusion-openvino",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"stable-diffusion-xl-refiner": "echarlaix/tiny-random-stable-diffusion-xl-refiner",
+ "stable-diffusion-3": "yujiepan/stable-diffusion-3-tiny-random",
"stablelm": "hf-internal-testing/tiny-random-StableLmForCausalLM",
"starcoder2": "hf-internal-testing/tiny-random-Starcoder2ForCausalLM",
"latent-consistency": "echarlaix/tiny-random-latent-consistency",
@@ -170,6 +172,8 @@
"stable-diffusion-xl": (366, 34, 42, 66),
"stable-diffusion-xl-refiner": (366, 34, 42, 66),
"open-clip": (20, 28),
+ "stable-diffusion-3": (66, 42, 58, 30),
+ "flux": (56, 24, 28, 64),
}