diff --git a/examples/neural_compressor/language-modeling/README.md b/examples/neural_compressor/language-modeling/README.md index 1c8e98b9ee..b005bb78a4 100644 --- a/examples/neural_compressor/language-modeling/README.md +++ b/examples/neural_compressor/language-modeling/README.md @@ -18,7 +18,7 @@ limitations under the License. The scripts [`run_clm.py`](https://github.com/huggingface/optimum-intel/blob/main/examples/neural_compressor/language-modeling/run_clm.py) and [`run_mlm.py`](https://github.com/huggingface/optimum-intel/blob/main/examples/neural_compressor/language-modeling/run_mlm.py) -allow us to apply different quantization approaches (such as dynamic, static and aware-training quantization) as well as pruning +allow us to apply different quantization approaches (such as dynamic, static, weight-only and aware-training quantization) as well as pruning using the [Intel Neural Compressor ](https://github.com/intel/neural-compressor) library for language modeling tasks. The SmoothQuant methodology is also available for post-training quantization. @@ -67,6 +67,7 @@ python run_clm.py \ --do_eval \ --verify_loading \ --output_dir /tmp/clm_output +``` ### RoBERTa/BERT/DistilBERT and masked language modeling @@ -91,7 +92,9 @@ python run_mlm.py \ --output_dir /tmp/mlm_output ``` -In order to apply dynamic, static or aware-training quantization, `quantization_approach` must be set to -respectively `dynamic`, `static` or `aware_training`. +In order to apply dynamic, static, weight-only or aware-training quantization, `quantization_approach` must be set to +respectively `dynamic`, `static`, `weight_only` or `aware_training`. The flag `--verify_loading` can be passed along to verify that the resulting quantized model can be loaded correctly. + +> **_Note:_** `weight_only` quantization_approach requires neural-compressor >= 2.3 diff --git a/examples/neural_compressor/language-modeling/run_clm.py b/examples/neural_compressor/language-modeling/run_clm.py index 54f1e7b617..cbc523b663 100644 --- a/examples/neural_compressor/language-modeling/run_clm.py +++ b/examples/neural_compressor/language-modeling/run_clm.py @@ -196,6 +196,28 @@ class OptimizationArguments: default=False, metadata={"help": "Whether or not to verify the loading of the quantized model."}, ) + bits: int = field( + default=8, + metadata={"help": "Bits for weight only quantization, 1-8 bits."}, + ) + group_size: int = field( + default=-1, + metadata={ + "help": "Group size for weight only quantization. Group_size=[1-N] indicates " + "splitting the input channel elements per group_size. -1 indicates " + "the per-channel quantization per output channel." + }, + ) + weight_only_scheme: str = field( + default="sym", + metadata={"help": "Scheme for weight only quantization. Choose from 'sym' and 'asym'."}, + ) + quantization_methodology: str = field( + default="RTN", + metadata={ + "help": "Quantization methodology for weight only quantization. Choose from 'RTN', 'AWQ' and 'GPTQ'." + }, + ) @dataclass @@ -539,7 +561,9 @@ def group_texts(examples): desc=f"Grouping texts in chunks of {block_size}", ) - if training_args.do_train or (optim_args.apply_quantization and optim_args.quantization_approach == "static"): + if training_args.do_train or ( + optim_args.apply_quantization and optim_args.quantization_approach in ["static", "weight_only"] + ): if "train" not in tokenized_datasets: raise ValueError("--do_train requires a train dataset") train_dataset = lm_datasets["train"] @@ -587,7 +611,7 @@ def compute_metrics(eval_preds): raise ValueError("`do_train` must be set to True.") if optim_args.apply_quantization: - supported_approach = {"static", "dynamic", "aware_training"} + supported_approach = {"static", "dynamic", "aware_training", "weight_only"} if optim_args.quantization_approach not in supported_approach: raise ValueError( f"Unknown quantization approach. Supported approach are {supported_approach}." @@ -600,7 +624,27 @@ def compute_metrics(eval_preds): recipes = {"smooth_quant": True, "smooth_quant_args": {"alpha": optim_args.smooth_quant_alpha}} else: recipes = {} - quantization_config = PostTrainingQuantConfig(approach=optim_args.quantization_approach, recipes=recipes) + if optim_args.quantization_approach == "weight_only": + op_type_dict = { + ".*": { + "weight": { + "bits": optim_args.bits, + "group_size": optim_args.group_size, + "scheme": optim_args.weight_only_scheme, + "algorithm": optim_args.quantization_methodology, + }, + }, + } + if optim_args.quantization_methodology == "GPTQ": + gptq_args = { + "pad_max_length": block_size, + } + recipes.update({"gptq_args": gptq_args}) + else: + op_type_dict = {} + quantization_config = PostTrainingQuantConfig( + approach=optim_args.quantization_approach, op_type_dict=op_type_dict, recipes=recipes + ) if optim_args.apply_pruning: if optim_args.end_step is None: @@ -677,10 +721,10 @@ def compute_metrics(eval_preds): trainer.save_metrics("train", metrics) trainer.save_state() - if optim_args.apply_quantization and optim_args.quantization_approach in {"static", "dynamic"}: + if optim_args.apply_quantization and optim_args.quantization_approach in {"static", "dynamic", "weight_only"}: model = trainer.model if isinstance(trainer.model, PreTrainedModel) else trainer.model._model quantizer = INCQuantizer.from_pretrained(model) - if optim_args.quantization_approach == "static": + if optim_args.quantization_approach in ["static", "weight_only"]: num_calibration_samples = min(len(train_dataset), optim_args.num_calibration_samples) train_dataset = train_dataset.select(range(num_calibration_samples)) quantization_config.calibration_sampling_size = num_calibration_samples @@ -688,8 +732,13 @@ def compute_metrics(eval_preds): quantizer.quantize( quantization_config=quantization_config, save_directory=training_args.output_dir, - calibration_dataset=train_dataset if optim_args.quantization_approach == "static" else None, - batch_size=training_args.per_device_train_batch_size, + calibration_dataset=train_dataset + if optim_args.quantization_approach in ["static", "weight_only"] + else None, + batch_size=1 # batch_size > 1 for GPTQ is WIP + if optim_args.quantization_approach == "weight_only" and optim_args.quantization_methodology == "GPTQ" + else training_args.per_device_train_batch_size, + weight_only=True if optim_args.quantization_approach == "weight_only" else False, ) trainer.model = quantizer._quantized_model if optim_args.apply_quantization and optim_args.verify_loading: diff --git a/optimum/exporters/openvino/__init__.py b/optimum/exporters/openvino/__init__.py new file mode 100644 index 0000000000..d87d8dda9e --- /dev/null +++ b/optimum/exporters/openvino/__init__.py @@ -0,0 +1,5 @@ +from .__main__ import main_export +from .convert import export, export_models, export_pytorch_via_onnx + + +__all__ = ["main_export", "export", "export_models"] diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py new file mode 100644 index 0000000000..5cf0adb176 --- /dev/null +++ b/optimum/exporters/openvino/__main__.py @@ -0,0 +1,293 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Union + +from requests.exceptions import ConnectionError as RequestsConnectionError +from transformers import AutoTokenizer +from transformers.utils import is_torch_available + +from optimum.exporters import TasksManager +from optimum.exporters.onnx import __main__ as optimum_main +from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast +from optimum.utils import DEFAULT_DUMMY_SHAPES +from optimum.utils.save_utils import maybe_save_preprocessors + +from .convert import export_models + + +OV_XML_FILE_NAME = "openvino_model.xml" + +logger = logging.getLogger(__name__) + +if is_torch_available(): + import torch + + +def main_export( + model_name_or_path: str, + output: Union[str, Path], + task: str = "auto", + device: str = "cpu", + fp16: Optional[bool] = False, + framework: Optional[str] = None, + cache_dir: Optional[str] = None, + trust_remote_code: bool = False, + pad_token_id: Optional[int] = None, + subfolder: str = "", + revision: str = "main", + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, + fn_get_submodels: Optional[Callable] = None, + **kwargs_shapes, +): + """ + Full-suite OpenVINO export. + + Args: + > Required parameters + + model_name_or_path (`str`): + Model ID on huggingface.co or path on disk to the model repository to export. + output (`Union[str, Path]`): + Path indicating the directory where to store the generated ONNX model. + + > Optional parameters + + task (`Optional[str]`, defaults to `None`): + The task to export the model for. If not specified, the task will be auto-inferred based on the model. For decoder models, + use `xxx-with-past` to export the model using past key values in the decoder. + device (`str`, defaults to `"cpu"`): + The device to use to do the export. Defaults to "cpu". + fp16 (`Optional[bool]`, defaults to `"False"`): + Use half precision during the export. PyTorch-only, requires `device="cuda"`. + framework (`Optional[str]`, defaults to `None`): + The framework to use for the ONNX export (`"pt"` or `"tf"`). If not provided, will attempt to automatically detect + the framework for the checkpoint. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories + you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the + model repository. + pad_token_id (`Optional[int]`, defaults to `None`): + This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[str]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): + Experimental usage: keyword arguments to pass to the model during + the export. This argument should be used along the `custom_onnx_configs` argument + in case, for example, the model inputs/outputs are changed (for example, if + `model_kwargs={"output_attentions": True}` is passed). + custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`): + Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model). + fn_get_submodels (`Optional[Callable]`, defaults to `None`): + Experimental usage: Override the default submodels that are used at the export. This is + especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success. + **kwargs_shapes (`Dict`): + Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. + + Example usage: + ```python + >>> from optimum.exporters.openvino import main_export + + >>> main_export("gpt2", output="gpt2_onnx/") + ``` + """ + output = Path(output) + if not output.exists(): + output.mkdir(parents=True) + + original_task = task + task = TasksManager.map_from_synonym(task) + + framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework) + + # get the shapes to be used to generate dummy inputs + input_shapes = {} + for input_name in DEFAULT_DUMMY_SHAPES.keys(): + input_shapes[input_name] = ( + kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] + ) + + torch_dtype = None if fp16 is False else torch.float16 + + if task == "auto": + try: + task = TasksManager.infer_task_from_model(model_name_or_path) + except KeyError as e: + raise KeyError( + f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" + ) + except RequestsConnectionError as e: + raise RequestsConnectionError( + f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" + ) + + model = TasksManager.get_model_from_task( + task, + model_name_or_path, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + framework=framework, + torch_dtype=torch_dtype, + device=device, + ) + + custom_architecture = False + is_stable_diffusion = "stable-diffusion" in task + model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-") + + if not is_stable_diffusion: + if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE: + raise ValueError( + f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. " + f"If you want to support {model_type} please propose a PR or open up an issue." + ) + if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task( + task, exporter="onnx" + ): + custom_architecture = True + + if custom_architecture and custom_onnx_configs is None: + raise ValueError( + "Trying to export a model with a custom architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models." + ) + + if custom_architecture and original_task == "auto": + raise ValueError( + f'Automatic task detection is not supported with custom architectures. Please specify the `task` argument. Suggestion: task="{task}" (or task="{task}-with-past" if the model is decoder-based and supports KV cache)' + ) + + if ( + not custom_architecture + and not is_stable_diffusion + and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx") + ): + if original_task == "auto": # Make -with-past the default if --task was not explicitely specified + task = task + "-with-past" + else: + logger.info( + f"The task `{task}` was manually specified, and past key values will not be reused in the decoding." + f" if needed, please pass `--task {task}-with-past` to export using the past key values." + ) + + if original_task == "auto": + synonyms_for_task = sorted(TasksManager.synonyms_for_task(task)) + if synonyms_for_task: + synonyms_for_task = ", ".join(synonyms_for_task) + possible_synonyms = f" (possible synonyms are: {synonyms_for_task})" + else: + possible_synonyms = "" + logger.info(f"Automatic task detection to {task}{possible_synonyms}.") + onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs( + model=model, + task=task, + monolith=False, + custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {}, + custom_architecture=custom_architecture, + fn_get_submodels=fn_get_submodels, + _variant="default", + ) + + if not is_stable_diffusion: + needs_pad_token_id = ( + isinstance(onnx_config, OnnxConfigWithPast) + and getattr(model.config, "pad_token_id", None) is None + and task in ["text-classification"] + ) + if needs_pad_token_id: + if pad_token_id is not None: + model.config.pad_token_id = pad_token_id + else: + try: + tok = AutoTokenizer.from_pretrained(model_name_or_path) + model.config.pad_token_id = tok.pad_token_id + except Exception: + raise ValueError( + "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument" + ) + # Saving the model config and preprocessor as this is needed sometimes. + model.config.save_pretrained(output) + generation_config = getattr(model, "generation_config", None) + if generation_config is not None: + generation_config.save_pretrained(output) + maybe_save_preprocessors(model_name_or_path, output) + + if model.config.is_encoder_decoder and task.startswith("text-generation"): + raise ValueError( + f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report" + f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model," + f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`." + ) + + files_subpaths = None + else: + # save the subcomponent configuration + for model_name in models_and_onnx_configs: + subcomponent = models_and_onnx_configs[model_name][0] + if hasattr(subcomponent, "save_config"): + subcomponent.save_config(output / model_name) + elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"): + subcomponent.config.save_pretrained(output / model_name) + + files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_onnx_configs] + + # Saving the additional components needed to perform inference. + model.scheduler.save_pretrained(output.joinpath("scheduler")) + + feature_extractor = getattr(model, "feature_extractor", None) + if feature_extractor is not None: + feature_extractor.save_pretrained(output.joinpath("feature_extractor")) + + tokenizer = getattr(model, "tokenizer", None) + if tokenizer is not None: + tokenizer.save_pretrained(output.joinpath("tokenizer")) + + tokenizer_2 = getattr(model, "tokenizer_2", None) + if tokenizer_2 is not None: + tokenizer_2.save_pretrained(output.joinpath("tokenizer_2")) + + model.save_config(output) + + export_models( + models_and_onnx_configs=models_and_onnx_configs, + output_dir=output, + output_names=files_subpaths, + input_shapes=input_shapes, + device=device, + model_kwargs=model_kwargs, + ) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py new file mode 100644 index 0000000000..ab688f92fa --- /dev/null +++ b/optimum/exporters/openvino/convert.py @@ -0,0 +1,390 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import gc +import inspect +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from transformers.utils import is_tf_available, is_torch_available + +from openvino.runtime import PartialShape, save_model +from openvino.runtime.utils.types import get_element_type +from openvino.tools.ovc import convert_model +from optimum.exporters.onnx.base import OnnxConfig +from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed +from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx +from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx +from optimum.utils import is_diffusers_available + +from .utils import ( + OV_XML_FILE_NAME, + clear_class_registry, + flattenize_inputs, + get_input_shapes, + remove_none_from_dummy_inputs, +) + + +logger = logging.getLogger(__name__) + +if is_torch_available(): + import torch.nn as nn + from transformers.modeling_utils import PreTrainedModel + +if is_diffusers_available(): + from diffusers import ModelMixin + +if is_tf_available(): + from transformers.modeling_tf_utils import TFPreTrainedModel + + +def export( + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], + config: OnnxConfig, + output: Path, + opset: Optional[int] = None, + device: str = "cpu", + input_shapes: Optional[Dict] = None, + model_kwargs: Optional[Dict[str, Any]] = None, +) -> Tuple[List[str], List[str]]: + """ + Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation. + + Args: + model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): + The model to export. + config ([`~exporters.onnx.config.OnnxConfig`]): + The ONNX configuration associated with the exported model. + output (`Path`): + Directory to store the exported model. + opset (`Optional[int]`, defaults to `None`): + The version of the ONNX operator set to use. + device (`str`, *optional*, defaults to `cpu`): + The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for + export on CUDA devices. + input_shapes (`Optional[Dict]`, defaults to `None`): + If specified, allows to use specific shapes for the example input provided to the exporter. + + Returns: + `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from + the ONNX configuration. + """ + if not (is_torch_available() or is_tf_available()): + raise ImportError( + "Cannot convert because neither PyTorch nor TensorFlow are installed. " + "Please install torch or tensorflow first." + ) + + if "diffusers" in str(model.__class__) and not is_diffusers_available(): + raise ImportError("The pip package `diffusers` is required to export stable diffusion models to ONNX.") + + if is_torch_available() and isinstance(model, nn.Module): + return export_pytorch( + model, + config, + opset, + output, + device=device, + input_shapes=input_shapes, + model_kwargs=model_kwargs, + ) + + elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): + output.parent.mkdir(parents=True, exist_ok=True) + if opset is None: + opset = config.DEFAULT_ONNX_OPSET + if device == "cuda": + raise RuntimeError("`tf2onnx` does not support export on CUDA device.") + if input_shapes is not None: + logger.info("`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored.") + return export_tensorflow(model, config, opset, output) + + else: + raise RuntimeError( + "You either provided a PyTorch model with only TensorFlow installed, or a TensorFlow model with only PyTorch installed." + ) + + +def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, opset: int, output: Path): + """ + Export the TensorFlow model to OpenVINO format. + + Args: + model (Union[): The model to export. + config (OnnxConfig): The configuration of the model. + opset (int): The ONNX opset version to use. + output (Path): The path to save the model. + + Returns: + input_names: list of input names from ONNX configuration + output_names: list of output names from ONNX configuration + bool: True if the model was exported successfully. + """ + onnx_path = Path(output).with_suffix(".onnx") + input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path) + ov_model = convert_model(str(onnx_path)) + save_model( + ov_model, + output.parent / output, + compress_to_fp16=False, + ) + return input_names, output_names, True + + +def export_pytorch_via_onnx( + model: Union["PreTrainedModel", "ModelMixin"], + config: OnnxConfig, + opset: int, + output: Path, + device: str = "cpu", + input_shapes: Optional[Dict] = None, + model_kwargs: Optional[Dict[str, Any]] = None, +): + """ + Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export. + + Args: + model ([`PreTrainedModel`]): + The model to export. + config ([`~exporters.onnx.config.OnnxConfig`]): + The configuration associated with the exported model. + opset (`int`): + The version of the ONNX operator set to use. + output (`Path`): + Directory to store the exported model. + device (`str`, defaults to `"cpu"`): + The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for + export on CUDA devices. + input_shapes (`optional[Dict]`, defaults to `None`): + If specified, allows to use specific shapes for the example input provided to the exporter. + model_kwargs (optional[Dict[str, Any]], defaults to `None`): + Additional kwargs for model export + + Returns: + `Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from + the ONNX configuration and boolean flag - was legacy ONNX path were applied to model or not. + """ + import torch + + output = Path(output) + orig_torch_onnx_export = torch.onnx.export + torch.onnx.export = functools.partial(orig_torch_onnx_export, do_constant_folding=False) + model.config.torchscript = False + model.config.return_dict = True + onnx_output = output.with_suffix(".onnx") + input_names, output_names = export_pytorch_to_onnx( + model, config, opset, onnx_output, device, input_shapes, model_kwargs + ) + torch.onnx.export = orig_torch_onnx_export + ov_model = convert_model(str(onnx_output)) + save_model( + ov_model, + output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output, + compress_to_fp16=False, + ) + return input_names, output_names, True + + +def export_pytorch( + model: Union["PreTrainedModel", "ModelMixin"], + config: OnnxConfig, + opset: int, + output: Path, + device: str = "cpu", + input_shapes: Optional[Dict] = None, + model_kwargs: Optional[Dict[str, Any]] = None, +) -> Tuple[List[str], List[str]]: + """ + Exports a PyTorch model to an OpenVINO Intermediate Representation. + + Args: + model ([`PreTrainedModel`]): + The model to export. + config ([`~exporters.onnx.config.OnnxConfig`]): + The configuration associated with the exported model. + opset (`int`): + The version of the ONNX operator set to use. + output (`Path`): + Directory to store the exported model. + device (`str`, defaults to `"cpu"`): + The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for + export on CUDA devices. + input_shapes (`optional[Dict]`, defaults to `None`): + If specified, allows to use specific shapes for the example input provided to the exporter. + model_kwargs (optional[Dict[str, Any]], defaults to `None`): + Additional kwargs for model export + + Returns: + `Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from + the ONNX configuration and boolean flag - was legacy ONNX path were applied to model or not. + """ + import torch + from torch.utils._pytree import tree_map + + logger.info(f"Using framework PyTorch: {torch.__version__}") + output = Path(output) + + with torch.no_grad(): + model.config.torchscript = False + model.config.return_dict = True + model.eval() + + # Check if we need to override certain configuration item + if config.values_override is not None: + logger.info(f"Overriding {len(config.values_override)} configuration item(s)") + for override_config_key, override_config_value in config.values_override.items(): + logger.info(f"\t- {override_config_key} -> {override_config_value}") + setattr(model.config, override_config_key, override_config_value) + + if input_shapes is None: + input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES + + # Check that inputs match, and order them properly + dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes) + device = torch.device(device) + if device.type == "cuda" and torch.cuda.is_available(): + model.to(device) + dummy_inputs = tree_map( + lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs + ) + check_dummy_inputs_are_allowed(model, dummy_inputs) + inputs = config.ordered_inputs(model) + input_names = list(inputs.keys()) + output_names = list(config.outputs.keys()) + if hasattr(model, "forward"): + sig = inspect.signature(model.forward) + else: + sig = inspect.signature(model.call) + + dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs) + input_info = get_input_shapes(dummy_inputs, inputs) + custom_patcher = type(config).patch_model_for_export != OnnxConfig.patch_model_for_export + try: + # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching, + # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output + # To handle it, additional wrapper on patcher forward applied. + # model.config.torchscript = True can not be used for patching, because it overrides return_dict to Flase + if custom_patcher or dict_inputs: + patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs) + patched_forward = patcher.patched_forward + + @functools.wraps(patched_forward) + def ts_patched_forward(*args, **kwargs): + for i in range(len(dict_inputs)): + input_name = dict_inputs[i][0] + keys = dict_inputs[i][1] + tuple_input = kwargs[input_name] + input_dict = dict(zip(keys, tuple_input)) + kwargs[input_name] = input_dict + outputs = patched_forward(*args, **kwargs) + return tuple(outputs.values()) + + patcher.patched_forward = ts_patched_forward + with patcher: + ov_model = convert_model(model, example_input=dummy_inputs, input=input_info) + else: + model.config.torchscript = True + ov_model = convert_model(model, example_input=dummy_inputs, input=input_info) + except Exception as ex: + logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX") + return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs) + ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs} + ordered_input_names = list(inputs) + flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values()) + ov_model.validate_nodes_and_infer_types() + for idx, out_tensor in enumerate(ov_model.outputs): + if idx < len(output_names): + out_tensor.get_tensor().set_names({output_names[idx]}) + + for idx, inp_tensor in enumerate(ov_model.inputs): + input_name = ordered_input_names[idx] + inp_tensor.get_tensor().set_names({input_name}) + inp_data = flatten_inputs[idx] + static_shape = PartialShape(inp_data.shape) + dims = inputs[input_name] + + for dim in dims: + static_shape[dim] = -1 + inp_tensor.get_node().set_partial_shape(static_shape) + inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) + ov_model.validate_nodes_and_infer_types() + save_model(ov_model, output, compress_to_fp16=False) + clear_class_registry() + del model + gc.collect() + return input_names, output_names, False + + +def export_models( + models_and_onnx_configs: Dict[ + str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"] + ], + output_dir: Path, + opset: Optional[int] = None, + output_names: Optional[List[str]] = None, + device: str = "cpu", + input_shapes: Optional[Dict] = None, + model_kwargs: Optional[Dict[str, Any]] = None, +) -> Tuple[List[List[str]], List[List[str]]]: + """ + Export the models to OpenVINO IR format + + Args: + models_and_onnx_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]): + output_dir (Path): output directory for saving models + opset (Optional[int], optional, Default to None): ONNX export opset + output_names (Optional[List[str]], optional, Defaults to None): model output names + device (str, optional, Defaults to "cpu"): + The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for + export on CUDA devices. + input_shapes (Optional[Dict], optional, Defaults to None): + If specified, allows to use specific shapes for the example input provided to the exporter. + model_kwargs (Optional[Dict[str, Any]], optional): + Additional kwargs for model export + + Raises: + ValueError: if custom names set not equal of number of models + + Returns: + list of input_names and output_names from ONNX configuration + """ + outputs = [] + + if output_names is not None and len(output_names) != len(models_and_onnx_configs): + raise ValueError( + f"Provided custom names {output_names} for the export of {len(models_and_onnx_configs)} models. Please provide the same number of names as models to export." + ) + + for i, model_name in enumerate(models_and_onnx_configs.keys()): + submodel, sub_onnx_config = models_and_onnx_configs[model_name] + output_name = output_names[i] if output_names is not None else Path(model_name + ".xml") + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + outputs.append( + export( + model=submodel, + config=sub_onnx_config, + output=output_path, + opset=opset, + device=device, + input_shapes=input_shapes, + model_kwargs=model_kwargs, + ) + ) + + outputs = list(map(list, zip(*outputs))) + return outputs diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py new file mode 100644 index 0000000000..f0d5366526 --- /dev/null +++ b/optimum/exporters/openvino/utils.py @@ -0,0 +1,142 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Tuple, Union + +from transformers.utils import is_torch_available + +from openvino.runtime import PartialShape +from optimum.utils import is_diffusers_available + + +if is_torch_available(): + import torch + import torch.nn as nn + from transformers.modeling_utils import PreTrainedModel + +if is_diffusers_available(): + from diffusers import ModelMixin + + +OV_XML_FILE_NAME = "openvino_model.xml" + + +def is_torch_model(model: Union["PreTrainedModel", "ModelMixin"]): + """ + Checks whether the model is a torch model. + + Args: + model (Union[PretrainedModel, ModelMixin]): The model to check. + + Returns: + bool: True if the model is a torch model. + """ + if not is_torch_available(): + return False + return isinstance(model, nn.Module) + + +def flattenize_inputs(inputs: List[Any]): + """ + Flatten the inputs into a list. + + Args: + inputs (List[Any]): The inputs to flatten. + + Returns: + List[Any]: The flattened inputs. + """ + flatten_inputs = [] + for input_data in inputs: + if input_data is None: + continue + if isinstance(input_data, (list, tuple)): + flatten_inputs.extend(flattenize_inputs(input_data)) + else: + flatten_inputs.append(input_data) + return flatten_inputs + + +def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]): + """ + Removes None values from the dictionary. + + Args: + dummy_inputs (Dict[str, Any]): Dictionary with None values. + Returns: + upd_dummy (Dict[str, Any]): updated dictionary with removed None values + dict_dummy (List[Tuple[str, List[str]]]): list of inputs represented as dictionary provided as pair name and list of nested keys + """ + + def remove_none_from_list_tuple(item: Union[List[Any], Tuple[Any]]): + """ + Removes None values from a list or tuple. + + Args: + item (list or tuple): The list or tuple to remove None values from. + + Returns: + list or tuple: The list or tuple with None values removed. + """ + new_item = [i for i in item if i is not None] + return type(item)(new_item) + + upd_dummy = {} + dict_dummy = [] + for k, v in dummy_inputs.items(): + if v is None: + continue + if isinstance(v, dict): + dict_dummy.append((k, list(v.keys()))) + upd_dummy[k] = remove_none_from_list_tuple(tuple(v.values())) + continue + if isinstance(v, (tuple, list)): + upd_dummy[k] = remove_none_from_list_tuple(v) + continue + upd_dummy[k] = v + return upd_dummy, dict_dummy + + +def get_input_shapes(dummy_inputs: Dict[str, Any], inputs: Dict[str, Any]): + """ + Resolves input shapes based on dynamic axes from input config and dummy input shapes + + Args: + dummy_inputs (Dict[str, Any]): A dictionary of dummy inputs. + inputs (Dict[str, Any]): A dictionary of input tensors. + + Returns: + input_info: List of input info for conversion + + """ + input_info = [] + for input_name, data in dummy_inputs.items(): + if isinstance(data, (tuple, list, dict)): + return None + static_shape = PartialShape(data.shape) + if input_name in inputs: + dynamic_dims = inputs[input_name] + for dim in dynamic_dims: + static_shape[dim] = -1 + input_info.append((input_name, static_shape)) + return input_info + + +def clear_class_registry(): + """ + Removes Torchscript cached modules + """ + torch._C._jit_clear_class_registry() + torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() + torch.jit._state._clear_class_state() diff --git a/optimum/intel/neural_compressor/configuration.py b/optimum/intel/neural_compressor/configuration.py index 32d5e95375..7f5370e5ee 100644 --- a/optimum/intel/neural_compressor/configuration.py +++ b/optimum/intel/neural_compressor/configuration.py @@ -25,6 +25,7 @@ "post_training_dynamic_quant": "dynamic", "post_training_static_quant": "static", "quant_aware_training": "aware_training", + "post_training_weight_only": "weight_only", } diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py index de5a5d5727..273d610e9d 100644 --- a/optimum/intel/neural_compressor/quantization.py +++ b/optimum/intel/neural_compressor/quantization.py @@ -74,6 +74,7 @@ logger = logging.getLogger(__name__) NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0" +NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION = "2.3.0" IPEX_MINIMUM_VERSION = "2.1.0" if is_neural_compressor_version("<", NEURAL_COMPRESSOR_MINIMUM_VERSION): @@ -87,6 +88,7 @@ class INCQuantizationMode(Enum): DYNAMIC = "post_training_dynamic_quant" STATIC = "post_training_static_quant" AWARE_TRAINING = "quant_aware_training" + WEIGHT_ONLY = "post_training_weight_only" SUPPORTED_QUANT_MODE = {approach.value for approach in INCQuantizationMode} @@ -142,6 +144,7 @@ def quantize( data_collator: Optional[DataCollator] = None, remove_unused_columns: bool = True, file_name: str = None, + weight_only: bool = False, **kwargs, ): """ @@ -160,6 +163,9 @@ def quantize( The function to use to form a batch from a list of elements of the calibration dataset. remove_unused_columns (`bool`, defaults to `True`): Whether or not to remove the columns unused by the model forward method. + weight_only (`bool`, defaults to `False`): + Whether compress weights to integer precision (4-bit by default) while keeping activations + floating-point. Fits best for LLM footprint reduction and performance acceleration. """ save_directory = Path(save_directory) save_directory.mkdir(parents=True, exist_ok=True) @@ -168,7 +174,40 @@ def quantize( calibration_dataloader = None self._set_task() - if INCQuantizationMode(quantization_config.approach) == INCQuantizationMode.STATIC: + if weight_only: + # check neural-compressor version + if is_neural_compressor_version("<", NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION): + raise ImportError( + f"Found an incompatible version of neural-compressor. Found version {_neural_compressor_version}, " + f"but only version {NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION} or higher supports weight-only quantization." + ) + + # If op_type_dict of quantization_config is not defined, it will use default values for weight-only quantization: + # {"bits": 4, "group_size": 32, "scheme": "sym", "algorithm": "RTN"} + if isinstance(quantization_config.op_type_dict, dict) and len(quantization_config.op_type_dict) > 0: + algo = [] + for _, val in quantization_config.op_type_dict.items(): + algo += val.get("weight", {}).get("algorithm", ["RTN"]) + else: + algo = ["RTN"] + + if calibration_dataset is None and ("GPTQ" in algo or "AWQ" in algo): + raise ValueError( + "Weight-only quantization needs a calibration dataset for both GPTQ and AWQ methodologies." + ) + + if calibration_dataset is None: + calibration_dataloader = None + else: + calibration_dataloader = self._get_calibration_dataloader( + calibration_dataset=calibration_dataset, + batch_size=batch_size, + remove_unused_columns=remove_unused_columns, + data_collator=data_collator, + use_label=False if "GPTQ" in algo else True, + ) + + elif INCQuantizationMode(quantization_config.approach) == INCQuantizationMode.STATIC: # Since PyTorch fx trace does not really require an example_inputs, only need calibration_dataset or calibration_fn here. if calibration_dataset is None and self.calibration_fn is None: raise ValueError( @@ -378,6 +417,7 @@ def _get_calibration_dataloader( batch_size: int, remove_unused_columns: bool, data_collator: Optional[DataCollator] = None, + use_label: Optional[bool] = True, ) -> INCDataLoader: data_collator = data_collator if data_collator is not None else default_data_collator if remove_unused_columns: @@ -394,7 +434,7 @@ def _get_calibration_dataloader( drop_last=False, ) - return INCDataLoader.from_pytorch_dataloader(calibration_dataloader) + return INCDataLoader.from_pytorch_dataloader(calibration_dataloader, use_label) def _remove_unused_columns(self, dataset: Dataset): ignored_columns = list(set(dataset.column_names) - set(self._signature_columns)) diff --git a/optimum/intel/neural_compressor/utils.py b/optimum/intel/neural_compressor/utils.py index dd77011c04..fa21122595 100644 --- a/optimum/intel/neural_compressor/utils.py +++ b/optimum/intel/neural_compressor/utils.py @@ -49,11 +49,14 @@ class INCDataLoader(DataLoader): + use_label = True + @classmethod - def from_pytorch_dataloader(cls, dataloader: DataLoader): + def from_pytorch_dataloader(cls, dataloader: DataLoader, use_label: bool = True): if not isinstance(dataloader, DataLoader): raise TypeError(f"Expected a PyTorch DataLoader, got: {type(dataloader)}.") inc_dataloader = cls(dataloader.dataset) + cls.use_label = use_label for key, value in dataloader.__dict__.items(): inc_dataloader.__dict__[key] = value return inc_dataloader @@ -63,7 +66,10 @@ def __iter__(self): if not isinstance(input, (dict, tuple, list, UserDict)): raise TypeError(f"Model calibration cannot use input of type {type(input)}.") label = input.get("labels") if isinstance(input, dict) else None - yield input, label + if self.use_label: + yield input, label + else: + yield input def _cfgs_to_fx_cfgs(op_cfgs: Dict, observer_type: str = "post_training_static_quant") -> Dict: diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index 1cea230429..95fb0aca8b 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -549,7 +549,7 @@ def from_pretrained( model = TimmForImageClassification.from_pretrained(model_id, **kwargs) onnx_config = TimmOnnxConfig(model.config) - return cls._to_onnx_to_load( + return cls._to_load( model=model, config=config, onnx_config=onnx_config, diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 59fc89649a..42bdb8edba 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -20,15 +20,16 @@ import openvino from huggingface_hub import hf_hub_download +from openvino import Core, convert_model from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation -from openvino.runtime import Core from transformers import PretrainedConfig from transformers.file_utils import add_start_docstrings -from optimum.exporters.onnx import OnnxConfig, export +from optimum.exporters.onnx import OnnxConfig from optimum.exporters.tasks import TasksManager from optimum.modeling_base import OptimizedModel +from ...exporters.openvino import export from ..utils.import_utils import is_transformers_version from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME @@ -127,9 +128,7 @@ def fix_op_names_duplicates(model: openvino.runtime.Model): if isinstance(file_name, str): file_name = Path(file_name) - bin_file_name = file_name.with_suffix(".bin") if file_name.suffix == ".xml" else None - - model = core.read_model(file_name, bin_file_name) + model = core.read_model(file_name) if not file_name.suffix == ".onnx" else convert_model(file_name) if file_name.suffix == ".onnx": model = fix_op_names_duplicates(model) # should be called during model conversion to IR @@ -145,7 +144,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]): The directory where to save the model files. """ dst_path = os.path.join(save_directory, OV_XML_FILE_NAME) - openvino.runtime.serialize(self.model, dst_path) + openvino.save_model(self.model, dst_path, compress_to_fp16=False) @classmethod def _from_pretrained( @@ -198,6 +197,7 @@ def _from_pretrained( else: model_file_names = [file_name] # If not ONNX then OpenVINO IR + if not from_onnx: model_file_names.append(file_name.replace(".xml", ".bin")) file_names = [] @@ -276,7 +276,7 @@ def _from_transformers( onnx_config = onnx_config_class(model.config) - return cls._to_onnx_to_load( + return cls._to_load( model=model, config=config, onnx_config=onnx_config, @@ -288,7 +288,7 @@ def _from_transformers( ) @classmethod - def _to_onnx_to_load( + def _to_load( cls, model: PreTrainedModel, config: PretrainedConfig, @@ -308,13 +308,13 @@ def _to_onnx_to_load( model=model, config=onnx_config, opset=onnx_config.DEFAULT_ONNX_OPSET, - output=save_dir_path / ONNX_WEIGHTS_NAME, + output=save_dir_path / OV_XML_FILE_NAME, ) return cls._from_pretrained( model_id=save_dir_path, config=config, - from_onnx=True, + from_onnx=False, use_auth_token=use_auth_token, revision=revision, force_download=force_download, diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index a8ce3d0bf5..f8e09b2c91 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -24,9 +24,10 @@ from transformers import PretrainedConfig from transformers.file_utils import add_start_docstrings -from optimum.exporters.onnx import export_models, get_encoder_decoder_models_for_export -from optimum.exporters.tasks import TasksManager +from optimum.exporters import TasksManager +from optimum.exporters.onnx import get_encoder_decoder_models_for_export +from ...exporters.openvino import export_models from ..utils.import_utils import is_transformers_version from .modeling_base import OVBaseModel from .utils import ( @@ -104,7 +105,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]): for src_file, dst_file_name in zip(src_files, dst_file_names): dst_path = os.path.join(save_directory, dst_file_name) - openvino.runtime.serialize(src_file, dst_path) + openvino.save_model(src_file, dst_path, compress_to_fp16=False) @classmethod def _from_pretrained( @@ -243,9 +244,6 @@ def _from_transformers( kwargs (`Dict`, *optional*): kwargs will be passed to the model during initialization """ - encoder_file_name = os.path.join("encoder", ONNX_ENCODER_NAME) - decoder_file_name = os.path.join("decoder", ONNX_DECODER_NAME) - decoder_with_past_file_name = os.path.join("decoder_with_past", ONNX_DECODER_WITH_PAST_NAME) task = task or cls.export_feature save_dir = TemporaryDirectory() @@ -265,6 +263,9 @@ def _from_transformers( onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) onnx_config = onnx_config_constructor(model.config, use_past=use_cache) models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config) + encoder_file_name = os.path.join("encoder", OV_ENCODER_NAME) + decoder_file_name = os.path.join("decoder", OV_DECODER_NAME) + decoder_with_past_file_name = os.path.join("decoder_with_past", OV_DECODER_WITH_PAST_NAME) output_names = [encoder_file_name, decoder_file_name] if use_cache is True: @@ -281,7 +282,7 @@ def _from_transformers( model_id=save_dir_path, config=config, use_cache=use_cache, - from_onnx=True, + from_onnx=False, use_auth_token=use_auth_token, revision=revision, force_download=force_download, diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 28c839f231..387e3daf51 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -27,14 +27,14 @@ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.modeling_outputs import CausalLMOutputWithPast -from optimum.exporters.onnx import export -from optimum.exporters.tasks import TasksManager +from optimum.exporters import TasksManager from optimum.utils import NormalizedConfigManager +from ...exporters.openvino import export from ..utils.import_utils import is_transformers_version -from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask +from ..utils.modeling_utils import patch_decoder_attention_mask from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel -from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE +from .utils import OV_XML_FILE_NAME, STR_TO_OV_TYPE if is_transformers_version("<", "4.25.0"): @@ -190,7 +190,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]): """ model_to_save = self.model if self._pkv_precision == Type.f32 else self._original_model dst_path = os.path.join(save_directory, OV_XML_FILE_NAME) - openvino.runtime.serialize(model_to_save, dst_path) + openvino.save_model(model_to_save, dst_path, compress_to_fp16=False) @classmethod def _from_transformers( @@ -232,25 +232,20 @@ def _from_transformers( onnx_config = onnx_config_constructor(model.config, use_past=use_cache) # TODO : create ModelPatcher to patch each architecture - if config.model_type in {"bloom", "mpt"}: - model.transformer._prepare_attn_mask = _prepare_attn_mask - elif config.model_type == "llama": - model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - elif config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}: - model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + model = patch_decoder_attention_mask(model) - # Export the model to the ONNX format - export(model=model, config=onnx_config, output=save_dir_path / ONNX_WEIGHTS_NAME) + # Export the model to the OpenVINO IR format + export(model=model, config=onnx_config, output=save_dir_path / OV_XML_FILE_NAME) return cls._from_pretrained( model_id=save_dir_path, config=config, - from_onnx=True, + from_onnx=False, use_auth_token=use_auth_token, revision=revision, force_download=force_download, cache_dir=cache_dir, - file_name=ONNX_WEIGHTS_NAME, + file_name=OV_XML_FILE_NAME, local_files_only=local_files_only, use_cache=use_cache, **kwargs, diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index b1679595d3..383171e144 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -36,7 +36,6 @@ from openvino.runtime import Core from transformers import CLIPFeatureExtractor, CLIPTokenizer -from optimum.exporters.onnx import main_export from optimum.pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin from optimum.pipelines.diffusers.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipelineMixin from optimum.pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin @@ -51,6 +50,7 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, ) +from ...exporters.openvino import main_export from .loaders import OVTextualInversionLoaderMixin from .modeling_base import OVBaseModel from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME @@ -159,7 +159,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]): if ov_model is not None: dst_path = save_directory / dst_path / OV_XML_FILE_NAME dst_path.parent.mkdir(parents=True, exist_ok=True) - openvino.runtime.serialize(ov_model.model, dst_path) + openvino.save_model(ov_model.model, dst_path, compress_to_fp16=False) model_dir = ov_model.config.get("_name_or_path", None) or ov_model._model_dir / ov_model._model_name config_path = Path(model_dir) / ov_model.CONFIG_NAME if config_path.is_file(): @@ -315,7 +315,7 @@ def _from_transformers( return cls._from_pretrained( model_id=save_dir_path, config=config, - from_onnx=True, + from_onnx=False, use_auth_token=use_auth_token, revision=revision, force_download=force_download, @@ -606,6 +606,11 @@ def __call__(self, latent_sample: np.ndarray): outputs = self.request(inputs, share_inputs=True, share_outputs=True) return list(outputs.values()) + def _compile(self): + if "GPU" in self.device: + self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"}) + super()._compile() + class OVModelVaeEncoder(OVModelPart): def __init__( @@ -622,6 +627,11 @@ def __call__(self, sample: np.ndarray): outputs = self.request(inputs, share_inputs=True, share_outputs=True) return list(outputs.values()) + def _compile(self): + if "GPU" in self.device: + self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"}) + super()._compile() + class OVStableDiffusionPipeline(OVStableDiffusionPipelineBase, StableDiffusionPipelineMixin): def __call__( diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index c56ef632e6..9edeaae30c 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -412,7 +412,6 @@ def forward( # Add the encoder_hidden_states inputs when needed if "encoder_hidden_states" in self.input_names and encoder_hidden_states is not None: inputs["encoder_hidden_states"] = encoder_hidden_states - # Run inference results = self.request.infer(inputs, share_inputs=True, share_outputs=True) logits = torch.from_numpy(results["logits"]).to(self.device) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 99e22e72f5..3349ce142f 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -24,21 +24,23 @@ import transformers from accelerate.data_loader import DataLoaderStateMixin from datasets import Dataset, load_dataset -from nncf import NNCFConfig -from nncf.torch import create_compressed_model, register_default_init_args +from nncf import NNCFConfig, compress_weights +from nncf.torch import create_compressed_model, register_default_init_args, register_module from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk from nncf.torch.initialization import PTInitializingDataLoader from openvino._offline_transformations import compress_quantize_weights_transformation from openvino.runtime import Core, Tensor -from torch.utils.data import DataLoader, RandomSampler, TensorDataset +from torch.utils.data import DataLoader, RandomSampler from transformers import DataCollator, PreTrainedModel, default_data_collator +from transformers.pytorch_utils import Conv1D -from optimum.exporters.onnx import export from optimum.exporters.tasks import TasksManager from optimum.quantization_base import OptimumQuantizer +from ...exporters.openvino import export, export_pytorch_via_onnx from ..utils.constant import _TASK_ALIASES -from .configuration import INT8_WEIGHT_COMPRESSION_CONFIG, OVConfig +from ..utils.modeling_utils import patch_decoder_attention_mask +from .configuration import OVConfig from .modeling_base import OVBaseModel from .modeling_decoder import OVBaseDecoderModel from .utils import ( @@ -49,6 +51,8 @@ ) +register_module(ignored_algorithms=[])(Conv1D) + core = Core() logger = logging.getLogger(__name__) @@ -332,8 +336,8 @@ def _quantize_torchmodel( self._set_task() save_directory = Path(save_directory) save_directory.mkdir(parents=True, exist_ok=True) - file_name = file_name if file_name is not None else OV_XML_FILE_NAME - output_path = save_directory.joinpath(file_name) + ov_file_name = file_name if file_name is not None else OV_XML_FILE_NAME + output_path = save_directory.joinpath(ov_file_name) output_path = output_path.with_suffix(".xml").as_posix() model_type = self.model.config.model_type.replace("_", "-") @@ -344,73 +348,73 @@ def _quantize_torchmodel( model_type=model_type, ) - if weights_only: - calibration_dataset = TensorDataset(torch.tensor([0.0, 1.0])) - calibration_dataset.column_names = [] - remove_unused_columns = False - onnx_config = onnx_config_class(self.model.config) - - def data_collator(batch): - return onnx_config.generate_dummy_inputs(framework="pt") - - calibration_dataloader = self._get_calibration_dataloader( - calibration_dataset=calibration_dataset, - batch_size=batch_size, - remove_unused_columns=remove_unused_columns, - data_collator=data_collator, - ) - if quantization_config is None: logger.info( "No configuration describing the quantization process was provided, a default OVConfig will be generated." ) - quantization_config = OVConfig(compression=INT8_WEIGHT_COMPRESSION_CONFIG) if weights_only else OVConfig() - - model_inputs = next(iter(calibration_dataloader)) - quantization_config.add_input_info(model_inputs) - nncf_config = NNCFConfig.from_dict(quantization_config.__dict__) - nncf_config = register_default_init_args(nncf_config, calibration_dataloader) - controller, compressed_model = create_compressed_model( - self.model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk + quantization_config = OVConfig() + onnx_file_name = ( + ONNX_WEIGHTS_NAME + if file_name is None and quantization_config.save_onnx_model + else Path(ov_file_name).with_suffix(".onnx") ) - compressed_model = controller.strip(do_copy=False) + if weights_only: + if getattr(self.model.config, "tie_word_embeddings", True): + # to fix problem with shared embedding weights in nncf compress_weights() + self.model.tie_weights() + compressed_model = compress_weights(self.model) + self.model = compressed_model + else: + calibration_dataloader = self._get_calibration_dataloader( + calibration_dataset=calibration_dataset, + batch_size=batch_size, + remove_unused_columns=remove_unused_columns, + data_collator=data_collator, + ) + + model_inputs = next(iter(calibration_dataloader)) + quantization_config.add_input_info(model_inputs) + nncf_config = NNCFConfig.from_dict(quantization_config.__dict__) + nncf_config = register_default_init_args(nncf_config, calibration_dataloader) + controller, compressed_model = create_compressed_model( + self.model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk + ) + compressed_model = controller.strip(do_copy=False) task = self.task model = self.model self.model.config.save_pretrained(save_directory) - + model = patch_decoder_attention_mask(model) if task == "text-generation": onnx_config = onnx_config_class(model.config, use_past=model.config.use_cache) else: onnx_config = onnx_config_class(model.config) - onnx_path = save_directory / ONNX_WEIGHTS_NAME - - # Export the model to the ONNX format + model_path = save_directory / (onnx_file_name if quantization_config.save_onnx_model else ov_file_name) + onnx_path = save_directory / onnx_file_name + export_fn = export if not quantization_config.save_onnx_model else export_pytorch_via_onnx opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET) opset = max(opset, MIN_ONNX_QDQ_OPSET) - export( - model=compressed_model, - config=onnx_config, - opset=opset, - output=onnx_path, - ) + _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset) + if is_onnx: + # Load and save the compressed model + model = core.read_model(onnx_path) + # Model required second saving for appling weights compression transformations + self._save_pretrained(model, output_path) + # if onnx conversion happens as fallback for pytorch conversion, remove onnx model + if not quantization_config.save_onnx_model: + os.remove(onnx_path) + try: + os.remove(f"{onnx_path}_data") + except FileNotFoundError: + pass - # Load and save the compressed model - model = core.read_model(onnx_path) - self._save_pretrained(model, output_path) quantization_config.save_pretrained(save_directory) - if not quantization_config.save_onnx_model: - os.remove(onnx_path) - try: - os.remove(f"{onnx_path}_data") - except FileNotFoundError: - pass @staticmethod def _save_pretrained(model: openvino.runtime.Model, output_path: str): compress_quantize_weights_transformation(model) - openvino.runtime.serialize(model, output_path) + openvino.save_model(model, output_path, compress_to_fp16=False) def _set_task(self): if self.task is None: diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index 811309806a..0bba054ad3 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -39,13 +39,13 @@ from nncf.torch.compression_method_api import PTCompressionAlgorithmController from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.algo import QuantizationController -from openvino._offline_transformations import compress_quantize_weights_transformation -from openvino.runtime import Core, PartialShape, serialize -from openvino.tools.mo.back.offline_transformations import ( +from openvino._offline_transformations import ( apply_fused_names_cleanup, apply_moc_transformations, - apply_user_transformations, + apply_pruning_transformation, + compress_quantize_weights_transformation, ) +from openvino.runtime import Core, PartialShape, save_model from torch.onnx import export as onnx_export from torch.utils._pytree import tree_map from torch.utils.data import DataLoader, Dataset, RandomSampler @@ -134,7 +134,7 @@ def remap(value): with torch.no_grad(): model.eval() # Disable node additions to be exported in the graph - model.disable_dynamic_graph_building() + model.nncf.disable_dynamic_graph_building() onnx_export( model, model_inputs, @@ -145,7 +145,7 @@ def remap(value): do_constant_folding=True, opset_version=opset, ) - model.enable_dynamic_graph_building() + model.nncf.enable_dynamic_graph_building() class OVTrainer(Trainer): @@ -752,10 +752,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): try: # OpenVINO IR pruning requires static-shaped input ov_model = self._reshape_ir(ov_model, static_shape=True) - apply_moc_transformations(ov_model) + apply_moc_transformations(ov_model, cf=False) if self._get_compression_controller_by_cls(QuantizationController) is not None: compress_quantize_weights_transformation(ov_model) - apply_user_transformations(ov_model, [("Pruning", {})]) + apply_pruning_transformation(ov_model) apply_fused_names_cleanup(ov_model) # Reshape back to dynamic shape IR ov_model = self._reshape_ir(ov_model, static_shape=False) @@ -772,7 +772,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): compress_quantize_weights_transformation(ov_model) # Serialize IR xml and bin - serialize(ov_model, output_path) + save_model(ov_model, output_path, compress_to_fp16=False) def _get_compression_controller_by_cls( self, controller_cls: Type[PTCompressionAlgorithmController] diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index c7be049990..17abf1059e 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -15,6 +15,7 @@ from typing import Tuple import torch +from transformers.modeling_utils import PreTrainedModel # Modified from transformers.models.bloom.modeling_bloom._make_causal_mask @@ -89,3 +90,22 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, ) return combined_attention_mask + + +def patch_decoder_attention_mask(model: "PreTrainedModel"): + """ + Apply patch on decoder with past model forward to resolve first inference based on model architecture + + Args: + model (PretrainedModel): The model to patch. + + Returns: + model with applied patch + """ + if model.config.model_type in {"bloom", "mpt"}: + model.transformer._prepare_attn_mask = _prepare_attn_mask + elif model.config.model_type == "llama": + model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}: + model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + return model diff --git a/setup.py b/setup.py index 0034ec2560..6d81b98b2a 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ "onnxruntime<1.15.0", ], "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime"], - "nncf": ["nncf>=2.5.0", "openvino-dev>=2023.1.0"], + "nncf": ["nncf>=2.6.0"], "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"], "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, diff --git a/tests/neural_compressor/test_optimization.py b/tests/neural_compressor/test_optimization.py index 578f556153..e31739b943 100644 --- a/tests/neural_compressor/test_optimization.py +++ b/tests/neural_compressor/test_optimization.py @@ -168,6 +168,91 @@ def test_ipex_static_quantization_with_smoothquant(self, task, model_name, expec num_samples=num_samples, ) + def test_weight_only_quantization(self): + model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM" + op_type_dict = { + ".*": { + "weight": { + "bits": 8, + "group_size": -1, + "scheme": "sym", + "algorithm": "RTN", + }, + }, + } + quantization_config = PostTrainingQuantConfig(approach="weight_only", op_type_dict=op_type_dict) + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + quantizer = INCQuantizer.from_pretrained(model, task="text-generation") + calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=2) + + with tempfile.TemporaryDirectory() as tmp_dir: + quantizer.quantize( + quantization_config=quantization_config, + calibration_dataset=calibration_dataset, + save_directory=tmp_dir, + weight_only=True, + ) + q_model = AutoModelForCausalLM.from_pretrained(tmp_dir) + inp = torch.tensor([calibration_dataset[0]["input_ids"]]) + out = model(inp)[0] + q_out = q_model(inp)[0] + self.assertTrue(torch.all(torch.isclose(out, q_out, atol=5e-1))) + + op_type_dict = { + ".*": { + "weight": { + "bits": 8, + "group_size": -1, + "scheme": "sym", + "algorithm": "AWQ", + }, + }, + } + quantization_config = PostTrainingQuantConfig(approach="weight_only", op_type_dict=op_type_dict) + + with tempfile.TemporaryDirectory() as tmp_dir: + quantizer.quantize( + quantization_config=quantization_config, + calibration_dataset=calibration_dataset, + save_directory=tmp_dir, + weight_only=True, + ) + q_model = AutoModelForCausalLM.from_pretrained(tmp_dir) + inp = torch.tensor([calibration_dataset[0]["input_ids"]]) + out = model(inp)[0] + q_out = q_model(inp)[0] + self.assertTrue(torch.all(torch.isclose(out, q_out, atol=6e-1))) + + op_type_dict = { + ".*": { + "weight": { + "bits": 8, + "group_size": -1, + "scheme": "sym", + "algorithm": "GPTQ", + }, + }, + } + recipes = {"gptq_args": {"pad_max_length": len(calibration_dataset[0]["input_ids"])}} + quantization_config = PostTrainingQuantConfig( + approach="weight_only", op_type_dict=op_type_dict, recipes=recipes + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + quantizer.quantize( + quantization_config=quantization_config, + calibration_dataset=calibration_dataset, + save_directory=tmp_dir, + weight_only=True, + ) + q_model = AutoModelForCausalLM.from_pretrained(tmp_dir) + inp = torch.tensor([calibration_dataset[0]["input_ids"]]) + out = model(inp)[0] + q_out = q_model(inp)[0] + self.assertTrue(torch.all(torch.isclose(out, q_out, atol=5e-1))) + def test_dynamic_accuracy_strategy_quantization(self): model_name = "distilbert-base-cased-distilled-squad" model = AutoModelForQuestionAnswering.from_pretrained(model_name) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 4b11435e0e..a4bf9b38e0 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -117,6 +117,9 @@ def test_load_from_hub_and_save_model(self): outputs = model(**tokens) self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits)) + del loaded_model + del model + gc.collect() def test_load_from_hub_and_save_decoder_model(self): tokenizer = AutoTokenizer.from_pretrained(self.OV_DECODER_MODEL_ID) @@ -134,6 +137,9 @@ def test_load_from_hub_and_save_decoder_model(self): outputs = model(**tokens) self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits)) + del loaded_model + del model + gc.collect() def test_load_from_hub_and_save_seq2seq_model(self): tokenizer = AutoTokenizer.from_pretrained(self.OV_SEQ2SEQ_MODEL_ID) @@ -153,6 +159,9 @@ def test_load_from_hub_and_save_seq2seq_model(self): outputs = model.generate(**tokens) self.assertTrue(torch.equal(loaded_model_outputs, outputs)) + del loaded_model + del model + gc.collect() @require_diffusers def test_load_from_hub_and_save_stable_diffusion_model(self): @@ -186,6 +195,8 @@ def test_load_from_hub_and_save_stable_diffusion_model(self): np.random.seed(0) outputs = pipeline(**inputs).images self.assertTrue(np.array_equal(pipeline_outputs, outputs)) + del pipeline + gc.collect() class OVModelForSequenceClassificationIntegrationTest(unittest.TestCase): @@ -228,6 +239,9 @@ def test_compare_to_transformers(self, model_arch): self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) # Compare tensor outputs self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) + del transformers_model + del ov_model + gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): @@ -257,6 +271,8 @@ def test_pipeline(self, model_arch): self.assertTrue(not model.is_dynamic) self.assertGreaterEqual(outputs[0]["score"], 0.0) self.assertIsInstance(outputs[0]["label"], str) + del model + del pipe gc.collect() @@ -293,6 +309,8 @@ def test_compare_to_transformers(self, model_arch): self.assertTrue( torch.allclose(torch.Tensor(ov_outputs.end_logits), transformers_outputs.end_logits, atol=1e-4) ) + del ov_model + del transformers_model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -307,6 +325,7 @@ def test_pipeline(self, model_arch): self.assertEqual(pipe.device, model.device) self.assertGreaterEqual(outputs["score"], 0.0) self.assertIsInstance(outputs["answer"], str) + del model gc.collect() def test_metric(self): @@ -323,6 +342,10 @@ def test_metric(self): ov_metric = task_evaluator.compute(model_or_pipeline=ov_pipe, data=data, metric="squad") self.assertEqual(ov_metric["exact_match"], transformers_metric["exact_match"]) self.assertEqual(ov_metric["f1"], transformers_metric["f1"]) + del transformers_pipe + del transformers_model + del ov_pipe + del ov_model gc.collect() @@ -352,6 +375,8 @@ def test_compare_to_transformers(self, model_arch): self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) # Compare tensor outputs self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) + del transformers_model + del ov_model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -363,6 +388,8 @@ def test_pipeline(self, model_arch): outputs = pipe("My Name is Arthur and I live in Lyon.") self.assertEqual(pipe.device, model.device) self.assertTrue(all(item["score"] > 0.0 for item in outputs)) + del model + del pipe gc.collect() @@ -396,6 +423,8 @@ def test_compare_to_transformers(self, model_arch): torch.Tensor(ov_outputs.last_hidden_state), transformers_outputs.last_hidden_state, atol=1e-4 ) ) + del transformers_model + del ov_model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -407,6 +436,8 @@ def test_pipeline(self, model_arch): outputs = pipe("My Name is Arthur and I live in Lyon.") self.assertEqual(pipe.device, model.device) self.assertTrue(all(all(isinstance(item, float) for item in row) for row in outputs[0])) + del pipe + del model gc.collect() @@ -448,6 +479,8 @@ def test_compare_to_transformers(self, model_arch): transformers_outputs = transformers_model(**tokens) # Compare tensor outputs self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4)) + del transformers_model + del ov_model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -463,6 +496,8 @@ def test_pipeline(self, model_arch): outputs = pipe("This is a sample", max_length=10) self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) + del pipe + del model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -478,6 +513,8 @@ def test_multiple_inputs(self, model_arch): outputs = model.generate(**tokens, generation_config=generation_config) self.assertIsInstance(outputs, torch.Tensor) self.assertEqual(outputs.shape[0], 3) + del model + gc.collect() def test_model_and_decoder_same_device(self): model_id = MODEL_NAMES["gpt2"] @@ -486,6 +523,8 @@ def test_model_and_decoder_same_device(self): self.assertEqual(model._device, "TEST") # Verify that request is being reset self.assertEqual(model.request, None) + del model + gc.collect() def test_compare_with_and_without_past_key_values(self): model_id = MODEL_NAMES["gpt2"] @@ -515,6 +554,9 @@ def test_compare_with_and_without_past_key_values(self): f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms," f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}", ) + del model_with_pkv + del model_without_pkv + gc.collect() class OVModelForMaskedLMIntegrationTest(unittest.TestCase): @@ -535,7 +577,7 @@ class OVModelForMaskedLMIntegrationTest(unittest.TestCase): "roformer", "squeezebert", "xlm", - # "xlm_roberta", + "xlm_roberta", ) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -557,6 +599,8 @@ def test_compare_to_transformers(self, model_arch): self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) # Compare tensor outputs self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) + del transformers_model + del ov_model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -568,6 +612,8 @@ def test_pipeline(self, model_arch): outputs = pipe(f"This is a {tokenizer.mask_token}.") self.assertEqual(pipe.device, model.device) self.assertTrue(all(item["score"] > 0.0 for item in outputs)) + del pipe + del model gc.collect() @@ -610,6 +656,8 @@ def test_compare_to_transformers(self, model_arch): self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) # Compare tensor outputs self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) + del transformers_model + del ov_model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -622,6 +670,8 @@ def test_pipeline(self, model_arch): self.assertEqual(pipe.device, model.device) self.assertGreaterEqual(outputs[0]["score"], 0.0) self.assertTrue(isinstance(outputs[0]["label"], str)) + del model + del pipe gc.collect() @parameterized.expand(TIMM_MODELS) @@ -703,6 +753,8 @@ def test_compare_to_transformers(self, model_arch): transformers_outputs = transformers_model(**tokens, **decoder_inputs) # Compare tensor outputs self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4)) + del transformers_model + del ov_model gc.collect() @@ -735,7 +787,8 @@ def test_pipeline(self, model_arch): outputs = pipe(text) self.assertEqual(pipe.device, model.device) self.assertIsInstance(outputs[0]["translation_text"], str) - + del pipe + del model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -755,6 +808,7 @@ def test_generate_utils(self, model_arch): outputs = model.generate(input_ids=tokens["input_ids"]) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) self.assertIsInstance(outputs[0], str) + del model gc.collect() @@ -786,6 +840,9 @@ def test_compare_with_and_without_past_key_values(self): f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms," f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}", ) + del model_with_pkv + del model_without_pkv + gc.collect() class OVModelForAudioClassificationIntegrationTest(unittest.TestCase): @@ -831,6 +888,10 @@ def test_compare_to_transformers(self, model_arch): # Compare tensor outputs self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-3)) + del transformers_model + del ov_model + gc.collect() + @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] @@ -840,6 +901,9 @@ def test_pipeline(self, model_arch): outputs = pipe([np.random.random(16000)]) self.assertEqual(pipe.device, model.device) self.assertTrue(all(item["score"] > 0.0 for item in outputs[0])) + del pipe + del model + gc.collect() class OVModelForCTCIntegrationTest(unittest.TestCase): @@ -893,6 +957,8 @@ def test_compare_to_transformers(self, model_arch): # compare tensor outputs self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) + del transformers_model + del ov_model gc.collect() @@ -945,6 +1011,8 @@ def test_compare_to_transformers(self, model_arch): torch.allclose(torch.Tensor(ov_outputs.embeddings), transformers_outputs.embeddings, atol=1e-4) ) + del transformers_model + del ov_model gc.collect() @@ -994,4 +1062,6 @@ def test_compare_to_transformers(self, model_arch): # compare tensor outputs self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) + del transformers_model + del ov_model gc.collect() diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index da9ba3b25a..369ad0f836 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -64,8 +64,8 @@ def get_num_quantized_nodes(ov_model): class OVQuantizerTest(unittest.TestCase): # TODO : add models SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( - (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 42, 32), - (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 41, 21), + (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 32, 35), + (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 41, 22), ) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) @@ -146,8 +146,8 @@ def preprocess_function(examples, tokenizer): class OVWeightCompressionTest(unittest.TestCase): # TODO : add models SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = ( - (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 39), - (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 5), + (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70), + (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45), ) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS) @@ -173,9 +173,8 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_int8 self.assertTrue("logits" in outputs) # Verify that that the configuration is correctly saved and loaded - expected_config = OVConfig(compression=INT8_WEIGHT_COMPRESSION_CONFIG) loaded_config = OVConfig.from_pretrained(tmp_dir) - self.assertEqual(expected_config.to_dict()["compression"], loaded_config.to_dict()["compression"]) + self.assertIsNotNone(loaded_config) class OVQuantizerQATest(unittest.TestCase): diff --git a/tests/openvino/test_stable_diffusion.py b/tests/openvino/test_stable_diffusion.py index e04e2d6fd3..781fbe0ec6 100644 --- a/tests/openvino/test_stable_diffusion.py +++ b/tests/openvino/test_stable_diffusion.py @@ -25,7 +25,8 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline, ) -from diffusers.utils import floats_tensor, load_image +from diffusers.utils import load_image +from diffusers.utils.testing_utils import floats_tensor from openvino.runtime.ie_api import CompiledModel from parameterized import parameterized from utils_tests import MODEL_NAMES, SEED