From 02835ce8833b5e2b67ba1a87bf85b0739335ac4d Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Tue, 10 Dec 2024 09:57:51 +0100 Subject: [PATCH 1/7] [OV] Quantization of Whisper pipeline (#1040) * Add whisper quantization * Extend quantization to decoder model too * Add documentation * Fix tests * Add quantization on from_pretrained; change test whisper model * Update docs * Update test reference for older transformers version * Address comments * Tweak reference * Tweak test * Tweak reference * Style * Change references * Trigger Tests * Create 'tests-openvino' extra dependency * Style --- .github/workflows/test_openvino.yml | 2 +- .github/workflows/test_openvino_full.yml | 2 +- .github/workflows/test_openvino_slow.yml | 2 +- docs/source/openvino/optimization.mdx | 19 ++ optimum/intel/openvino/configuration.py | 79 +++++- optimum/intel/openvino/modeling_seq2seq.py | 22 +- optimum/intel/openvino/quantization.py | 265 +++++++++++++++------ optimum/intel/openvino/utils.py | 11 +- setup.py | 1 + tests/openvino/test_quantization.py | 61 ++++- tests/openvino/utils_tests.py | 2 +- 11 files changed, 378 insertions(+), 88 deletions(-) diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml index 7583c51078..67f6d680a4 100644 --- a/.github/workflows/test_openvino.yml +++ b/.github/workflows/test_openvino.yml @@ -43,7 +43,7 @@ jobs: run: | pip install --upgrade pip pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install .[openvino,openvino-tokenizers,diffusers,tests] transformers[testing] + pip install .[openvino,openvino-tokenizers,diffusers,tests,tests-openvino] transformers[testing] - if: ${{ matrix.transformers-version != 'latest' }} name: Downgrade Transformers and Accelerate diff --git a/.github/workflows/test_openvino_full.yml b/.github/workflows/test_openvino_full.yml index 914035b750..3455f8ca54 100644 --- a/.github/workflows/test_openvino_full.yml +++ b/.github/workflows/test_openvino_full.yml @@ -56,7 +56,7 @@ jobs: python -m pip install --upgrade pip # Install PyTorch CPU to prevent unnecessary downloading/installing of CUDA packages pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install .[tests] + pip install .[tests,tests-openvino] - name: Install openvino-nightly if: ${{ matrix.openvino == 'ov-nightly' }} diff --git a/.github/workflows/test_openvino_slow.yml b/.github/workflows/test_openvino_slow.yml index 9ad5ef2691..f7555c64bc 100644 --- a/.github/workflows/test_openvino_slow.yml +++ b/.github/workflows/test_openvino_slow.yml @@ -42,7 +42,7 @@ jobs: run: | pip install --upgrade pip pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install .[openvino,tests] transformers[testing] + pip install .[openvino,tests,tests-openvino] transformers[testing] pip uninstall -y nncf - if: ${{ matrix.transformers-version != 'latest' }} diff --git a/docs/source/openvino/optimization.mdx b/docs/source/openvino/optimization.mdx index 28de5ffa4b..147421dd4a 100644 --- a/docs/source/openvino/optimization.mdx +++ b/docs/source/openvino/optimization.mdx @@ -166,6 +166,25 @@ calibration_dataset = quantizer.get_calibration_dataset( The `quantize()` method applies post-training static quantization and export the resulting quantized model to the OpenVINO Intermediate Representation (IR). The resulting graph is represented with two files: an XML file describing the network topology and a binary file describing the weights. 
 The resulting model can be run on any target Intel device.
 
+#### Speech-to-text Models Quantization
+
+The speech-to-text Whisper model can be quantized without the need for preparing a custom calibration dataset. See the example below.
+
+```python
+model_id = "openai/whisper-tiny"
+ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+    model_id,
+    quantization_config=OVQuantizationConfig(
+        num_samples=10,
+        dataset="librispeech",
+        processor=model_id,
+        smooth_quant_alpha=0.95,
+    )
+)
+```
+
+With this configuration, the encoder, decoder and decoder-with-past models of the Whisper pipeline are fully quantized, including activations.
+
 ### Hybrid quantization
 
 Traditional optimization methods like post-training 8-bit quantization do not work well for Stable Diffusion (SD) models and can lead to poor generation results. On the other hand, weight compression does not improve performance significantly when applied to Stable Diffusion models, as the size of activations is comparable to weights.
diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py
index b34cd84cd0..61ba98119e 100644
--- a/optimum/intel/openvino/configuration.py
+++ b/optimum/intel/openvino/configuration.py
@@ -26,7 +26,7 @@
 from optimum.configuration_utils import BaseConfig
 
 from ..utils.import_utils import is_nncf_available
-from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_VISUAL_LM_DATASETS
+from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_SPEECH_TO_TEXT_DATASETS, PREDEFINED_VISUAL_LM_DATASETS
 
 
 if is_nncf_available():
@@ -255,6 +255,10 @@ def __init__(
         sym: bool = False,
         ignored_scope: Optional[dict] = None,
         num_samples: Optional[int] = None,
+        dataset: Optional[Union[str, List[str]]] = None,
+        tokenizer: Optional[str] = None,
+        processor: Optional[str] = None,
+        trust_remote_code: bool = False,
         **kwargs,
     ):
         """
@@ -272,6 +276,10 @@
         self.bits = bits
         self.sym = sym
         self.num_samples = num_samples
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.processor = processor
+        self.trust_remote_code = trust_remote_code
 
         if isinstance(ignored_scope, nncf.IgnoredScope):
             ignored_scope = ignored_scope.__dict__
@@ -313,6 +321,10 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
                 user or organization name, like `dbmdz/bert-base-german-cased`.
                 - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
                 using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+        trust_remote_code (`bool`, defaults to `False`):
+            Allows using custom code for the modeling hosted in the model repository. This option should only be set
+            for repositories you trust and in which you have read the code, as it will execute on your local machine
+            arbitrary code present in the model repository.
         dataset (`str or List[str]`, *optional*):
             The dataset used for data-aware compression with NNCF.
             - For language models you can provide your own dataset in a list of strings or just use one from the list
@@ -395,10 +407,16 @@ def __init__(
         backup_precision: Optional[str] = None,
         **kwargs,
     ):
-        super().__init__(bits=bits, sym=sym, ignored_scope=ignored_scope, num_samples=num_samples)
-        self.tokenizer = tokenizer
-        self.trust_remote_code = trust_remote_code
-        self.dataset = dataset
+        super().__init__(
+            bits=bits,
+            sym=sym,
+            ignored_scope=ignored_scope,
+            num_samples=num_samples,
+            dataset=dataset,
+            tokenizer=tokenizer,
+            processor=processor,
+            trust_remote_code=trust_remote_code,
+        )
         self.group_size = group_size or (-1 if bits == 8 else 128)
         self.ratio = ratio
         self.all_layers = all_layers
@@ -407,7 +425,6 @@ def __init__(
         self.scale_estimation = scale_estimation
         self.weight_format = weight_format
         self.gptq = gptq
-        self.processor = processor
         self.lora_correction = lora_correction
         self.backup_precision = backup_precision
         self.post_init()
@@ -535,6 +552,11 @@ def __init__(
         model_type: str = "transformer",
         fast_bias_correction: bool = True,
         overflow_fix: str = "disable",
+        dataset: Optional[str] = None,
+        tokenizer: Optional[str] = None,
+        processor: Optional[str] = None,
+        trust_remote_code: bool = False,
+        smooth_quant_alpha: Optional[float] = None,
         **kwargs,
     ):
         """
@@ -557,11 +579,42 @@
                 Whether to apply fast or full bias correction algorithm.
             overflow_fix (`str`, default to "disable"):
                 Parameter for controlling overflow fix setting.
+            dataset (`str`, *optional*):
+                The dataset used for quantization. For speech-to-text model quantization the allowed value is 'librispeech'.
+            tokenizer (`str`, *optional*):
+                The tokenizer used to process the dataset. You can pass either:
+                    - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
+                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
+                      user or organization name, like `dbmdz/bert-base-german-cased`.
+                    - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
+                      using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+            processor (`str`, *optional*):
+                A transformers processor used to process inputs for multi-modal models. You can pass either:
+                    - A string, the *model id* of a predefined processor hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing files required by the processor, for instance saved
+                      using the [`~AutoProcessor.save_pretrained`] method, e.g., `./my_model_directory/`.
+            trust_remote_code (`bool`, defaults to `False`):
+                Allows using custom code for the modeling hosted in the model repository. This option should only be set
+                for repositories you trust and in which you have read the code, as it will execute on your local machine
+                arbitrary code present in the model repository.
+            smooth_quant_alpha (`float`, *optional*):
+                SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and
+                reduces quantization error.
""" - super().__init__(bits=bits, sym=sym, ignored_scope=ignored_scope, num_samples=num_samples) + super().__init__( + bits=bits, + sym=sym, + ignored_scope=ignored_scope, + num_samples=num_samples, + dataset=dataset, + tokenizer=tokenizer, + processor=processor, + trust_remote_code=trust_remote_code, + ) self.model_type = model_type self.fast_bias_correction = fast_bias_correction self.overflow_fix = overflow_fix + self.smooth_quant_alpha = smooth_quant_alpha self.post_init() def post_init(self): @@ -573,6 +626,18 @@ def post_init(self): if self.bits != 8: raise ValueError(f"Only support 8-bit for static quantization but found {self.bits}") + if self.dataset is not None: + if self.dataset not in PREDEFINED_SPEECH_TO_TEXT_DATASETS: + raise ValueError( + f"You have entered the following string value for dataset: {self.dataset}. But it is not supported." + f" Currently you can only choose {list(PREDEFINED_SPEECH_TO_TEXT_DATASETS.keys())}." + ) + + if self.smooth_quant_alpha is not None and not (0 <= self.smooth_quant_alpha <= 1): + raise ValueError( + f"SmoothQuant alpha parameter must be in range [0, 1], but found {self.smooth_quant_alpha}" + ) + class OVConfig(BaseConfig): CONFIG_NAME = "openvino_config.json" diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 3e70bca7f3..fa48430a77 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy import logging import os from pathlib import Path @@ -35,7 +35,9 @@ from transformers.generation import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from .. 
import OVConfig, OVQuantizer from ..utils import is_transformers_version +from .configuration import OVQuantizationConfig, OVQuantizationConfigBase from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM from .utils import OV_TO_PT_TYPE, _print_compiled_model_properties @@ -973,9 +975,25 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: "PretrainedConfig", + load_in_8bit: bool = False, + quantization_config: Union[dict, OVQuantizationConfigBase] = None, **kwargs, ): - return super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs) + compile_only = kwargs.get("compile_only", False) + + if not compile_only and isinstance(quantization_config, OVQuantizationConfig): + model = super(OVModelForSpeechSeq2Seq, cls)._from_pretrained( + model_id, config, load_in_8bit=False, **kwargs + ) + quantization_config_copy = copy.deepcopy(quantization_config) + quantization_config_copy.processor = quantization_config.processor or model_id + OVQuantizer(model).quantize(ov_config=OVConfig(quantization_config=quantization_config_copy)) + else: + model = super(OVModelForSpeechSeq2Seq, cls)._from_pretrained( + model_id, config, load_in_8bit=load_in_8bit, quantization_config=quantization_config, **kwargs + ) + + return model class DummyWhisperModel: def __init__(self): diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index c6a625bd7b..7d227444ec 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -59,7 +59,13 @@ is_diffusers_available, ) from ..utils.modeling_utils import get_model_device -from .configuration import OVConfig, OVQuantizationConfig, OVQuantizationMethod, OVWeightQuantizationConfig +from .configuration import ( + OVConfig, + OVQuantizationConfig, + OVQuantizationConfigBase, + OVQuantizationMethod, + OVWeightQuantizationConfig, +) from .modeling_base import OVBaseModel from .utils import ( MAX_ONNX_OPSET, @@ -67,6 +73,7 @@ ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, PREDEFINED_SD_DATASETS, + PREDEFINED_SPEECH_TO_TEXT_DATASETS, PREDEFINED_VISUAL_LM_DATASETS, ) @@ -319,6 +326,7 @@ def _quantize_ovbasemodel( remove_unused_columns: bool = True, **kwargs, ): + from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper from optimum.intel.openvino.modeling_visual_language import OVModelForVisualCausalLM if is_diffusers_available(): @@ -344,7 +352,7 @@ def _quantize_ovbasemodel( data_collator=data_collator, ) if self.model.export_feature == "text-generation" and self.model.use_cache: - calibration_dataset = self._prepare_text_generation_dataset( + calibration_dataset = self._prepare_text_generation_calibration_data( quantization_config, calibration_dataloader ) else: @@ -357,31 +365,31 @@ def _quantize_ovbasemodel( f"`nncf.Dataset` or `datasets.Dataset`. Found: {type(calibration_dataset)}." ) - if isinstance(quantization_config, OVWeightQuantizationConfig): - if quantization_config.dataset is not None and calibration_dataset is not None: - logger.info( - "Both `quantization_config.dataset` and `calibration_dataset` were provided for weight only " - "quantization. Will rely on `calibration_dataset`." - ) - - if calibration_dataset is None and quantization_config.dataset is not None: - from optimum.intel import OVModelForCausalLM + if quantization_config.dataset is not None and calibration_dataset is not None: + logger.info( + "Both `quantization_config.dataset` and `calibration_dataset` were provided for weight only " + "quantization. Will rely on `calibration_dataset`." 
+ ) - if isinstance(self.model, OVModelForCausalLM): - calibration_dataset = self._prepare_causal_lm_dataset(quantization_config) - elif isinstance(self.model, OVModelForVisualCausalLM): - calibration_dataset = self._prepare_visual_causal_lm_dataset(quantization_config) - elif is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline): - if not isinstance(quantization_config.dataset, str): - raise ValueError("Please provide dataset as one of the accepted dataset labels.") - calibration_dataset = self._prepare_unet_dataset( - quantization_config.num_samples, dataset_name=quantization_config.dataset - ) - else: - raise ValueError( - f"Can't create weight compression calibration dataset from string for {type(self.model)}" - ) + if calibration_dataset is None and quantization_config.dataset is not None: + from optimum.intel import OVModelForCausalLM + + if isinstance(self.model, OVModelForCausalLM): + calibration_dataset = self._prepare_causal_lm_calibration_data(quantization_config) + elif isinstance(self.model, OVModelForVisualCausalLM): + calibration_dataset = self._prepare_visual_causal_lm_calibration_data(quantization_config) + elif isinstance(self.model, _OVModelForWhisper): + calibration_dataset = self._prepare_speech_to_text_calibration_data(quantization_config) + elif is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline): + if not isinstance(quantization_config.dataset, str): + raise ValueError("Please provide dataset as one of the accepted dataset labels.") + calibration_dataset = self._prepare_unet_dataset( + quantization_config.num_samples, dataset_name=quantization_config.dataset + ) + else: + raise ValueError(f"Can't create quantization calibration dataset from string for {type(self.model)}") + if isinstance(quantization_config, OVWeightQuantizationConfig): if quantization_config.quant_method == OVQuantizationMethod.HYBRID: if calibration_dataset is None: raise ValueError("Calibration dataset is required to run hybrid quantization.") @@ -399,22 +407,24 @@ def _quantize_ovbasemodel( ] sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names)) for sub_model in sub_models: - _weight_only_quantization(sub_model.model, quantization_config_copy) + _weight_only_quantization(sub_model.model, quantization_config_copy, **kwargs) if self.model.unet is not None: # Apply hybrid quantization to UNet self.model.unet.model = _hybrid_quantization( - self.model.unet.model, quantization_config, calibration_dataset + self.model.unet.model, quantization_config, calibration_dataset, **kwargs ) else: self.model.transformer.model = _hybrid_quantization( - self.model.transformer.model, quantization_config, calibration_dataset + self.model.transformer.model, quantization_config, calibration_dataset, **kwargs ) self.model.clear_requests() else: # The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc. 
- self.model.model = _hybrid_quantization(self.model.model, quantization_config, calibration_dataset) + self.model.model = _hybrid_quantization( + self.model.model, quantization_config, calibration_dataset, **kwargs + ) self.model.request = None else: if is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline): @@ -429,47 +439,36 @@ def _quantize_ovbasemodel( ] sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names)) for sub_model in sub_models: - _weight_only_quantization(sub_model.model, quantization_config) + _weight_only_quantization(sub_model.model, quantization_config, **kwargs) self.model.clear_requests() elif isinstance(self.model, OVModelForVisualCausalLM): language_model = self.model.language_model - _weight_only_quantization(language_model.model, quantization_config, calibration_dataset) + _weight_only_quantization(language_model.model, quantization_config, calibration_dataset, **kwargs) sub_model_names = ["vision_embeddings", "text_embeddings"] + self.model.additional_parts sub_models = [getattr(self.model, f"{name}_model") for name in sub_model_names] for sub_model in sub_models: - _weight_only_quantization(sub_model, OVWeightQuantizationConfig(bits=8, sym=True)) + _weight_only_quantization(sub_model, OVWeightQuantizationConfig(bits=8, sym=True), **kwargs) self.model.clear_requests() else: - _weight_only_quantization(self.model.model, quantization_config, calibration_dataset) + _weight_only_quantization(self.model.model, quantization_config, calibration_dataset, **kwargs) self.model.request = None - if save_directory is not None: - self.model.save_pretrained(save_directory) - ov_config.save_pretrained(save_directory) - return + else: + if not isinstance(quantization_config, OVQuantizationConfig): + raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}") - if not isinstance(quantization_config, OVQuantizationConfig): - raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}") - - if calibration_dataset is None: - raise ValueError("Calibration dataset is required to run quantization.") - - # Actual model quantization - quantized_model = nncf.quantize( - self.model.model, - calibration_dataset, - subset_size=quantization_config.num_samples, - ignored_scope=quantization_config.get_ignored_scope_instance(), - model_type=nncf.ModelType(quantization_config.model_type), - preset=nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED, - fast_bias_correction=quantization_config.fast_bias_correction, - advanced_parameters=nncf.AdvancedQuantizationParameters( - overflow_fix=OverflowFix(quantization_config.overflow_fix) - ), - **kwargs, - ) + if calibration_dataset is None: + raise ValueError("Calibration dataset is required to run quantization.") + + # Quantize model(s) + if isinstance(self.model, _OVModelForWhisper): + self._quantize_whisper_model(quantization_config, calibration_dataset, **kwargs) + else: + quantized_model = _full_quantization( + self.model.model, quantization_config, calibration_dataset, **kwargs + ) + self.model.model = quantized_model + self.model.request = None - self.model.model = quantized_model - self.model.request = None if save_directory is not None: self.model.save_pretrained(save_directory) ov_config.save_pretrained(save_directory) @@ -725,7 +724,7 @@ def _remove_unused_columns(self, dataset: "Dataset"): ignored_columns = list(set(dataset.column_names) - set(self._signature_columns)) return 
dataset.remove_columns(ignored_columns) - def _prepare_causal_lm_dataset(self, quantization_config: OVWeightQuantizationConfig): + def _prepare_causal_lm_calibration_data(self, quantization_config: OVQuantizationConfigBase): from optimum.gptq.data import get_dataset, prepare_dataset tokenizer = AutoTokenizer.from_pretrained( @@ -748,7 +747,7 @@ def _prepare_causal_lm_dataset(self, quantization_config: OVWeightQuantizationCo return calibration_dataset - def _prepare_visual_causal_lm_dataset(self, config: OVWeightQuantizationConfig): + def _prepare_visual_causal_lm_calibration_data(self, config: OVQuantizationConfigBase): dataset_name = config.dataset if dataset_name not in PREDEFINED_VISUAL_LM_DATASETS: raise ValueError( @@ -770,8 +769,8 @@ def _prepare_visual_causal_lm_dataset(self, config: OVWeightQuantizationConfig): tokenizer = None dataset_metadata = PREDEFINED_VISUAL_LM_DATASETS[dataset_name] - dataset = datasets.load_dataset(dataset_metadata["name"], split=dataset_metadata["split"]).shuffle(seed=0) - num_samples = min(config.num_samples or 128, len(dataset)) + dataset = datasets.load_dataset(dataset_metadata["id"], split=dataset_metadata["split"]).shuffle(seed=0) + num_samples = min(config.num_samples or 32, len(dataset)) dataset = islice(dataset, num_samples) calibration_dataset = [] @@ -809,8 +808,75 @@ def _prepare_visual_causal_lm_dataset(self, config: OVWeightQuantizationConfig): calibration_dataset = nncf.Dataset(calibration_dataset) return calibration_dataset - def _prepare_text_generation_dataset( - self, quantization_config: OVQuantizationConfig, calibration_dataloader: OVDataLoader + def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigBase): + if not is_datasets_available(): + raise ValueError(DATASETS_IMPORT_ERROR.format("OVQuantizer._prepare_whisper_calibration_data")) + + from datasets import load_dataset + + encoder_calibration_data = [] + encoder_model = self.model.encoder + encoder_model._compile() + encoder_model.request = InferRequestWrapper( + encoder_model.request, encoder_calibration_data, apply_caching=True + ) + + decoder_calibration_data = [] + decoder_model = self.model.decoder + decoder_model._compile() + decoder_model.request = InferRequestWrapper( + decoder_model.request, decoder_calibration_data, apply_caching=True + ) + + decoder_w_p_calibration_data = [] + decoder_w_p_model = self.model.decoder_with_past + decoder_w_p_model._compile() + decoder_w_p_model.request = InferRequestWrapper( + decoder_w_p_model.request, decoder_w_p_calibration_data, apply_caching=True + ) + + dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset] + + processor = AutoProcessor.from_pretrained(config.processor) + + try: + dataset = load_dataset( + dataset_metadata["id"], + dataset_metadata["name"], + split=dataset_metadata["split"], + streaming=True, + trust_remote_code=config.trust_remote_code, + ) + num_samples = config.num_samples or 128 + + audio_inputs = [] + # Download audio inputs beforehand to avoid possible connection issues + for item in tqdm(islice(dataset, num_samples), desc="Downloading audio inputs", total=num_samples): + audio = item + for key_name in dataset_metadata["inputs"]["audio"]: + audio = audio[key_name] + + sampling_rate = item + for key_name in dataset_metadata["inputs"]["sampling_rate"]: + sampling_rate = sampling_rate[key_name] + audio_inputs.append((audio, sampling_rate)) + + for audio, sampling_rate in tqdm(audio_inputs, desc="Collecting calibration data"): + input_features = processor(audio, 
sampling_rate=sampling_rate, return_tensors="pt").input_features + self.model.generate(input_features) + finally: + encoder_model.request = encoder_model.request.request + decoder_model.request = decoder_model.request.request + decoder_w_p_model.request = decoder_w_p_model.request.request + + return ( + nncf.Dataset(encoder_calibration_data), + nncf.Dataset(decoder_calibration_data), + nncf.Dataset(decoder_w_p_calibration_data), + ) + + def _prepare_text_generation_calibration_data( + self, quantization_config: OVQuantizationConfigBase, calibration_dataloader: OVDataLoader ) -> nncf.Dataset: # Prefetch past_key_values self.model.update_pkv_precision(True) @@ -898,11 +964,44 @@ def transform_fn(data_item): calibration_dataset = nncf.Dataset(calibration_data[:num_samples]) return calibration_dataset + def _quantize_whisper_model(self, quantization_config, calibration_dataset, **kwargs): + # Quantize encoder model + # quantization_config.num_samples of audio samples result in more actual model inputs + config = copy.deepcopy(quantization_config) + config.num_samples = calibration_dataset[0].get_length() + quantized_encoder_model = _full_quantization( + self.model.encoder_model, config, calibration_dataset[0], **kwargs + ) + self.model.encoder_model = quantized_encoder_model + self.model.encoder.model = quantized_encoder_model + self.model.encoder.request = None + + # Quantize decoder model + config = copy.deepcopy(quantization_config) + config.num_samples = calibration_dataset[1].get_length() + quantized_decoder_model = _full_quantization( + self.model.decoder_model, config, calibration_dataset[1], **kwargs + ) + self.model.decoder_model = quantized_decoder_model + self.model.decoder.model = quantized_decoder_model + self.model.decoder.request = None + + # Quantize decoder with past model + config = copy.deepcopy(quantization_config) + config.num_samples = calibration_dataset[2].get_length() + quantized_decoder_w_p_model = _full_quantization( + self.model.decoder_with_past_model, config, calibration_dataset[2], **kwargs + ) + self.model.decoder_with_past_model = quantized_decoder_w_p_model + self.model.decoder_with_past.model = quantized_decoder_w_p_model + self.model.decoder_with_past.request = None + def _weight_only_quantization( model: openvino.runtime.Model, quantization_config: Union[OVWeightQuantizationConfig, Dict], calibration_dataset: Optional[Union[nncf.Dataset, Iterable]] = None, + **kwargs, ) -> openvino.runtime.Model: config = quantization_config if isinstance(config, dict): @@ -950,9 +1049,40 @@ def _weight_only_quantization( gptq=config.gptq, lora_correction=config.lora_correction, backup_mode=None if config.backup_precision is None else nncf.BackupMode(config.backup_precision), + **kwargs, ) +def _full_quantization( + model: openvino.runtime.Model, + quantization_config: OVQuantizationConfig, + calibration_dataset: nncf.Dataset, + **kwargs, +): + advanced_parameters_kwargs = {} + if quantization_config.smooth_quant_alpha is not None: + advanced_parameters_kwargs["smooth_quant_alphas"] = AdvancedSmoothQuantParameters( + matmul=quantization_config.smooth_quant_alpha + ) + + quantized_model = nncf.quantize( + model, + calibration_dataset, + subset_size=quantization_config.num_samples, + ignored_scope=quantization_config.get_ignored_scope_instance(), + model_type=nncf.ModelType(quantization_config.model_type), + preset=nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED, + fast_bias_correction=quantization_config.fast_bias_correction, 
+ advanced_parameters=nncf.AdvancedQuantizationParameters( + overflow_fix=OverflowFix(quantization_config.overflow_fix), + **advanced_parameters_kwargs, + ), + **kwargs, + ) + + return quantized_model + + def _get_operation_const_op(operation, const_port_id: int): node = operation.input_value(const_port_id).get_node() queue = deque([node]) @@ -999,7 +1129,7 @@ def _collect_ops_with_weights(model): def _hybrid_quantization( - model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: nncf.Dataset + model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: nncf.Dataset, **kwargs ) -> openvino.runtime.Model: """ Quantize a model in hybrid mode with NNCF which means that we quantize: @@ -1021,7 +1151,7 @@ def _hybrid_quantization( wc_config = copy.deepcopy(quantization_config) wc_config.ignored_scope = wc_config.ignored_scope or {} wc_config.ignored_scope["types"] = wc_config.ignored_scope.get("types", []) + ["Convolution"] - compressed_model = _weight_only_quantization(model, wc_config) + compressed_model = _weight_only_quantization(model, wc_config, **kwargs) ptq_ignored_scope = quantization_config.get_ignored_scope_instance() ptq_ignored_scope.names += ops_to_compress @@ -1037,5 +1167,6 @@ def _hybrid_quantization( smooth_quant_alphas=AdvancedSmoothQuantParameters(matmul=-1) ), subset_size=subset_size, + **kwargs, ) return quantized_model diff --git a/optimum/intel/openvino/utils.py b/optimum/intel/openvino/utils.py index 36755cd64d..e54503e83d 100644 --- a/optimum/intel/openvino/utils.py +++ b/optimum/intel/openvino/utils.py @@ -142,12 +142,21 @@ PREDEFINED_VISUAL_LM_DATASETS = { "contextual": { - "name": "ucla-contextual/contextual_test", + "id": "ucla-contextual/contextual_test", "split": "test", "inputs": {"image_url": "image_url", "instruction": "instruction"}, } } +PREDEFINED_SPEECH_TO_TEXT_DATASETS = { + "librispeech": { + "id": "openslr/librispeech_asr", + "name": "clean", + "split": "validation", + "inputs": {"audio": ("audio", "array"), "sampling_rate": ("audio", "sampling_rate")}, + } +} + NEED_CONVERT_TO_FAST_TOKENIZER: Tuple[Type[PreTrainedTokenizer]] = (CLIPTokenizer,) diff --git a/setup.py b/setup.py index 7f7e91df33..8c8c19f119 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, + "tests-openvino": ["datasets[audio]>=1.4.0"], } setup( diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index e97eed5aed..5751bfdaae 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -93,6 +93,22 @@ class OVQuantizerTest(unittest.TestCase): (OVModelForSequenceClassification, "bert", 32, 35), (OVModelForCausalLM, "gpt2", 31, 22), ) + SUPPORTED_ARCHITECTURES_OV_MODEL_WITH_AUTO_DATASET = [ + ( + OVModelForSpeechSeq2Seq, + "whisper", + OVQuantizationConfig( + dataset="librispeech", + num_samples=1, + processor=MODEL_NAMES["whisper"], + trust_remote_code=True, + weight_only=False, + smooth_quant_alpha=0.95, + ), + (14, 22, 21) if is_transformers_version("<=", "4.36.0") else (14, 22, 25), + (14, 21, 17) if is_transformers_version("<=", "4.36.0") else (14, 22, 18), + ), + ] @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL) def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8): @@ -180,6 +196,31 @@ def preprocess_function(examples, tokenizer): loaded_config = OVConfig.from_pretrained(tmp_dir) 
self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict()) + @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL_WITH_AUTO_DATASET) + def test_ov_model_static_quantization_with_auto_dataset( + self, model_cls, model_name, quantization_config, expected_fake_quantize, expected_int8 + ): + model_id = MODEL_NAMES[model_name] + + with TemporaryDirectory() as tmp_dir: + ov_model = model_cls.from_pretrained(model_id, quantization_config=quantization_config) + ov_model.save_pretrained(tmp_dir) + + if model_cls == OVModelForSpeechSeq2Seq: + for model, expected_fq, expected_i8 in zip( + (ov_model.encoder.model, ov_model.decoder.model, ov_model.decoder_with_past.model), + expected_fake_quantize, + expected_int8, + ): + num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model) + self.assertEqual(expected_fq, num_fake_quantize) + self.assertEqual(expected_i8, num_weight_nodes["int8"]) + + input_features = torch.randn((1, 128, 3000), dtype=torch.float32) + ov_model.generate(input_features) + else: + raise Exception("Unexpected model class.") + class OVWeightCompressionTest(unittest.TestCase): SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = ( @@ -1054,9 +1095,9 @@ class OVQuantizationConfigTest(unittest.TestCase): (dict(num_samples=100), OVWeightQuantizationConfig, "Can't determine type of OV quantization config"), (dict(abc="def"), OVWeightQuantizationConfig, "Can't determine type of OV quantization config"), ( - dict(bits=4, fast_bias_correction=True, dataset="wikitext2"), - OVWeightQuantizationConfig, - "Can't determine type of OV quantization config", + dict(bits=8, fast_bias_correction=True, dataset="librispeech"), + OVQuantizationConfig, + None, ), (dict(model_type="transformer"), OVQuantizationConfig, None), ( @@ -1076,7 +1117,12 @@ class OVQuantizationConfigTest(unittest.TestCase): (dict(abc="def", weight_only=False), OVQuantizationConfig, None), (dict(abc="def", weight_only=True), OVWeightQuantizationConfig, None), ( - dict(bits=4, fast_bias_correction=True, dataset="wikitext2", weight_only=True), + dict(bits=8, fast_bias_correction=True, dataset="librispeech", weight_only=True), + OVQuantizationConfig, + None, + ), + ( + dict(bits=4, dataset="wikitext2", weight_only=True), OVWeightQuantizationConfig, None, ), @@ -1136,7 +1182,7 @@ def test_for_no_short_id_duplicates(self): class InferRequestWrapperTest(unittest.TestCase): - MODEL_ID = ("openai/whisper-tiny.en",) + MODEL_NAME = ("whisper",) APPLY_CACHING = (False, True) @staticmethod @@ -1150,8 +1196,9 @@ def _generate_random_audio_data(processor): ).input_features return input_features - @parameterized.expand(itertools.product(MODEL_ID, APPLY_CACHING)) - def test_calibration_data_uniqueness(self, model_id, apply_caching): + @parameterized.expand(itertools.product(MODEL_NAME, APPLY_CACHING)) + def test_calibration_data_uniqueness(self, model_name, apply_caching): + model_id = MODEL_NAMES[model_name] ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True) processor = AutoProcessor.from_pretrained(model_id) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 17d9dd1fbe..bf509a044f 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -147,7 +147,7 @@ "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", "wav2vec2-hf": "hf-internal-testing/tiny-random-Wav2Vec2Model", "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer", - "whisper": "openai/whisper-tiny.en", + 
"whisper": "yujiepan/whisper-v3-tiny-random", "xlm": "hf-internal-testing/tiny-random-xlm", "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta", "xglm": "hf-internal-testing/tiny-random-XGLMForCausalLM", From f6b73d0b9dfb6d33687bc70cfb88edd60654ff53 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Tue, 10 Dec 2024 16:35:53 +0100 Subject: [PATCH 2/7] Remove tests-openvino extra requirement (#1059) * Remove openvino-tests extra requirement * export LD_LIBRARY_PATH * Add tbb to tests requirements --- .github/workflows/test_openvino.yml | 2 +- .github/workflows/test_openvino_full.yml | 2 +- .github/workflows/test_openvino_slow.yml | 2 +- setup.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml index 67f6d680a4..7583c51078 100644 --- a/.github/workflows/test_openvino.yml +++ b/.github/workflows/test_openvino.yml @@ -43,7 +43,7 @@ jobs: run: | pip install --upgrade pip pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install .[openvino,openvino-tokenizers,diffusers,tests,tests-openvino] transformers[testing] + pip install .[openvino,openvino-tokenizers,diffusers,tests] transformers[testing] - if: ${{ matrix.transformers-version != 'latest' }} name: Downgrade Transformers and Accelerate diff --git a/.github/workflows/test_openvino_full.yml b/.github/workflows/test_openvino_full.yml index 3455f8ca54..914035b750 100644 --- a/.github/workflows/test_openvino_full.yml +++ b/.github/workflows/test_openvino_full.yml @@ -56,7 +56,7 @@ jobs: python -m pip install --upgrade pip # Install PyTorch CPU to prevent unnecessary downloading/installing of CUDA packages pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install .[tests,tests-openvino] + pip install .[tests] - name: Install openvino-nightly if: ${{ matrix.openvino == 'ov-nightly' }} diff --git a/.github/workflows/test_openvino_slow.yml b/.github/workflows/test_openvino_slow.yml index f7555c64bc..9ad5ef2691 100644 --- a/.github/workflows/test_openvino_slow.yml +++ b/.github/workflows/test_openvino_slow.yml @@ -42,7 +42,7 @@ jobs: run: | pip install --upgrade pip pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install .[openvino,tests,tests-openvino] transformers[testing] + pip install .[openvino,tests] transformers[testing] pip uninstall -y nncf - if: ${{ matrix.transformers-version != 'latest' }} diff --git a/setup.py b/setup.py index 8c8c19f119..ca415fca35 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,8 @@ "sentence-transformers", "open_clip_torch>=2.26.1", "peft", + "datasets[audio]>=1.4.0", + "tbb", ] QUALITY_REQUIRE = ["black~=23.1", "ruff==0.4.4"] @@ -68,7 +70,6 @@ "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, - "tests-openvino": ["datasets[audio]>=1.4.0"], } setup( From d7b1e1d1c0c9b9e22ec217f4990ebe9f45fd9c8b Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Wed, 11 Dec 2024 05:44:34 +0000 Subject: [PATCH 3/7] Add hybrid quantization for Flux model (#1060) * Add hybrid quantization for Flux model * Update optimum/intel/openvino/quantization.py Co-authored-by: Nikita Savelyev --------- Co-authored-by: Nikita Savelyev --- optimum/commands/export/openvino.py | 4 ++++ optimum/intel/openvino/quantization.py | 4 +++- tests/openvino/test_exporters_cli.py | 1 + tests/openvino/test_quantization.py | 2 ++ 4 files changed, 10 insertions(+), 1 deletion(-) diff --git 
a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index 5e951aa438..b8cc035393 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -354,6 +354,10 @@ def run(self): from optimum.intel import OVStableDiffusion3Pipeline model_cls = OVStableDiffusion3Pipeline + elif class_name == "FluxPipeline": + from optimum.intel import OVFluxPipeline + + model_cls = OVFluxPipeline else: raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.") diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 7d227444ec..6f739e2543 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -1150,7 +1150,9 @@ def _hybrid_quantization( wc_config = copy.deepcopy(quantization_config) wc_config.ignored_scope = wc_config.ignored_scope or {} - wc_config.ignored_scope["types"] = wc_config.ignored_scope.get("types", []) + ["Convolution"] + + wc_ignored_types = ["Convolution"] if any(op.get_type_name() == "Convolution" for op in model.get_ops()) else [] + wc_config.ignored_scope["types"] = wc_config.ignored_scope.get("types", []) + wc_ignored_types compressed_model = _weight_only_quantization(model, wc_config, **kwargs) ptq_ignored_scope = quantization_config.get_ignored_scope_instance() diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index f94d0f4b5d..97cbe8ef22 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -105,6 +105,7 @@ class OVCLIExportTestCase(unittest.TestCase): if is_transformers_version(">=", "4.45"): SUPPORTED_SD_HYBRID_ARCHITECTURES.append(("stable-diffusion-3", 9, 65)) + SUPPORTED_SD_HYBRID_ARCHITECTURES.append(("flux", 7, 56)) TEST_4BIT_CONFIGURATIONS = [ ("text-generation-with-past", "opt125m", "int4 --sym --group-size 128", {"int8": 4, "int4": 72}), diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 5751bfdaae..1fd58646e7 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -43,6 +43,7 @@ from optimum.intel import ( OVConfig, + OVFluxPipeline, OVLatentConsistencyModelPipeline, OVModelForAudioClassification, OVModelForCausalLM, @@ -491,6 +492,7 @@ class OVWeightCompressionTest(unittest.TestCase): SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION.extend( [ (OVStableDiffusion3Pipeline, "stable-diffusion-3", 9, 65), + (OVFluxPipeline, "flux", 7, 56), ] ) From 34bf5ae82a70fd56742376afa41da06b44a821b2 Mon Sep 17 00:00:00 2001 From: Nikita Malinin Date: Wed, 11 Dec 2024 06:46:00 +0100 Subject: [PATCH 4/7] Update default OV configuration (#1057) * Update Baichuan2 models default config * Update configuration.py --- optimum/intel/openvino/configuration.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index 61ba98119e..a0fc68361c 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -123,11 +123,18 @@ class OVQuantizationMethod(str, Enum): "mistralai/Mistral-7B-v0.1": {"bits": 4, "sym": True, "group_size": 128, "ratio": 0.9}, "baichuan-inc/Baichuan2-7B-Chat": { "bits": 4, - "sym": True, + "sym": False, "group_size": 128, "ratio": 0.8, + }, + "baichuan-inc/Baichuan2-13B-Chat": { + "bits": 4, + "sym": False, + "group_size": 128, + "ratio": 1.0, "dataset": "wikitext2", "quant_method": OVQuantizationMethod.AWQ, 
+ "scale_estimation": True, }, "lmsys/longchat-7b-16k": { "bits": 4, From 41d9a377c75a693c04113bbc1b408a1bbe88776b Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Wed, 11 Dec 2024 06:47:57 +0100 Subject: [PATCH 5/7] Update backup-precision option description (#1055) * Update backup-precision option description * Trigger Tests * Trigger Tests --- docs/source/openvino/export.mdx | 4 ++-- optimum/commands/export/openvino.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/openvino/export.mdx b/docs/source/openvino/export.mdx index dd542be735..4876885219 100644 --- a/docs/source/openvino/export.mdx +++ b/docs/source/openvino/export.mdx @@ -84,8 +84,8 @@ Optional arguments: The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization. --backup-precision {none,int8_sym,int8_asym} - Defines a backup precision for mixed-precision weight compression. Only valid for int4 weight - format. If not provided, backup precision is int8_asym. 'none' stands for original floating- + Defines a backup precision for mixed-precision weight compression. Only valid for 4-bit weight + formats. If not provided, backup precision is int8_asym. 'none' stands for original floating- point precision of the model weights, in this case weights are retained in their original precision without any quantization. 'int8_sym' stands for 8-bit integer symmetric quantization without zero point. 'int8_asym' stands for 8-bit integer asymmetric quantization with zero diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index b8cc035393..61c21c5c72 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -123,7 +123,7 @@ def parse_args_openvino(parser: "ArgumentParser"): choices=["none", "int8_sym", "int8_asym"], default=None, help=( - "Defines a backup precision for mixed-precision weight compression. Only valid for int4 weight format. " + "Defines a backup precision for mixed-precision weight compression. Only valid for 4-bit weight formats. " "If not provided, backup precision is int8_asym. 'none' stands for original floating-point precision of " "the model weights, in this case weights are retained in their original precision without any " "quantization. 'int8_sym' stands for 8-bit integer symmetric quantization without zero point. 
'int8_asym' " From bb51139f7d01b32fe82e053a8c225e3ea8c31221 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 11 Dec 2024 16:50:57 +0800 Subject: [PATCH 6/7] fix autotp linear check (#1062) Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index e741575edd..8d5f8afa1a 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -664,9 +664,9 @@ def __init__(self, module, config) -> None: if use_bias: concat_bias = torch.concat(bias_list, 0).contiguous() self.concat_linear.bias = nn.Parameter(concat_bias) - self.q_slice = self.q_proj.out_features - self.k_slice = self.q_slice + self.k_proj.out_features - self.v_slice = self.k_slice + self.v_proj.out_features + self.q_slice = self.q_proj.weight.shape[0] + self.k_slice = self.q_slice + self.k_proj.weight.shape[0] + self.v_slice = self.k_slice + self.v_proj.weight.shape[0] if self.module_device.type == "cpu": if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: self.mha_linear_add = LinearAdd(module.o_proj) From 35cf1d289a0f15e056c1dae5d9967ba027b3acec Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Dec 2024 16:59:07 +0800 Subject: [PATCH 7/7] Update readme (#1066) Signed-off-by: jiqing-feng --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b3879ef380..0cd317c78d 100644 --- a/README.md +++ b/README.md @@ -168,7 +168,7 @@ To load your IPEX model, you can just replace your `AutoModelForXxx` class with model_id = "gpt2" - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) -+ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True) ++ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) results = pipe("He's a dreadful magician and")