diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index b2d33e7647..75d8db8f00 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -68,6 +68,8 @@ def parse_args_openvino(parser: "ArgumentParser"): "This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it." ), ) + optional_group.add_argument("--fp16", action="store_true", help="Compress weights to fp16"), + optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8"), class OVExportCommand(BaseOptimumCLICommand): @@ -102,5 +104,7 @@ def run(self): cache_dir=self.args.cache_dir, trust_remote_code=self.args.trust_remote_code, pad_token_id=self.args.pad_token_id, + fp16=self.args.fp16, + int8=self.args.int8, # **input_shapes, ) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 3baa9119a1..bc6d942c24 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -19,7 +19,6 @@ from requests.exceptions import ConnectionError as RequestsConnectionError from transformers import AutoTokenizer -from transformers.utils import is_torch_available from optimum.exporters import TasksManager from optimum.exporters.onnx import __main__ as optimum_main @@ -34,13 +33,10 @@ OV_XML_FILE_NAME = "openvino_model.xml" -_MAX_UNCOMPRESSED_DECODER_SIZE = 1e9 +_MAX_UNCOMPRESSED_SIZE = 1e9 logger = logging.getLogger(__name__) -if is_torch_available(): - import torch - def main_export( model_name_or_path: str, @@ -60,6 +56,7 @@ def main_export( model_kwargs: Optional[Dict[str, Any]] = None, custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, fn_get_submodels: Optional[Callable] = None, + int8: Optional[bool] = None, **kwargs_shapes, ): """ @@ -126,6 +123,13 @@ def main_export( >>> main_export("gpt2", output="gpt2_onnx/") ``` """ + if int8 and not is_nncf_available(): + raise ImportError( + "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`" + ) + + model_kwargs = model_kwargs or {} + output = Path(output) if not output.exists(): output.mkdir(parents=True) @@ -142,8 +146,6 @@ def main_export( kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] ) - torch_dtype = None if fp16 is False else torch.float16 - if task == "auto": try: task = TasksManager.infer_task_from_model(model_name_or_path) @@ -167,7 +169,6 @@ def main_export( force_download=force_download, trust_remote_code=trust_remote_code, framework=framework, - torch_dtype=torch_dtype, device=device, ) @@ -235,17 +236,19 @@ def main_export( onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) onnx_config = onnx_config_constructor(model.config) models_and_onnx_configs = {"model": (model, onnx_config)} - model_kwargs = model_kwargs or {} - load_in_8bit = model_kwargs.get("load_in_8bit", None) - if load_in_8bit is None: - if model.num_parameters() >= _MAX_UNCOMPRESSED_DECODER_SIZE: - if not is_nncf_available(): - logger.warning( - "The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf." 
- "please install it with `pip install nncf`" - ) - else: - model_kwargs["load_in_8bit"] = True + + if int8 is None: + int8 = False + num_parameters = model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters() + if num_parameters >= _MAX_UNCOMPRESSED_SIZE: + if is_nncf_available(): + int8 = True + logger.info("The model weights will be quantized to int8.") + else: + logger.warning( + "The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf." + "please install it with `pip install nncf`" + ) if not is_stable_diffusion: needs_pad_token_id = ( @@ -313,5 +316,7 @@ def main_export( output_names=files_subpaths, input_shapes=input_shapes, device=device, + fp16=fp16, + int8=int8, model_kwargs=model_kwargs, ) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index ab4a41e873..14636f1f77 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -74,6 +74,8 @@ def export( device: str = "cpu", input_shapes: Optional[Dict] = None, model_kwargs: Optional[Dict[str, Any]] = None, + fp16: bool = False, + int8: bool = False, ) -> Tuple[List[str], List[str]]: """ Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation. @@ -115,6 +117,8 @@ def export( device=device, input_shapes=input_shapes, model_kwargs=model_kwargs, + fp16=fp16, + int8=int8, ) elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): @@ -133,7 +137,12 @@ def export( ) -def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, opset: int, output: Path): +def export_tensorflow( + model: Union["PreTrainedModel", "ModelMixin"], + config: OnnxConfig, + opset: int, + output: Path, +): """ Export the TensorFlow model to OpenVINO format. @@ -163,6 +172,8 @@ def export_pytorch_via_onnx( device: str = "cpu", input_shapes: Optional[Dict] = None, model_kwargs: Optional[Dict[str, Any]] = None, + fp16: bool = False, + int8: bool = False, ): """ Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export. @@ -201,12 +212,11 @@ def export_pytorch_via_onnx( ) torch.onnx.export = orig_torch_onnx_export ov_model = convert_model(str(onnx_output)) - load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False) _save_model( ov_model, output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output, - compress_to_fp16=False, - load_in_8bit=load_in_8bit, + compress_to_fp16=fp16, + load_in_8bit=int8, ) return input_names, output_names, True @@ -219,6 +229,8 @@ def export_pytorch( device: str = "cpu", input_shapes: Optional[Dict] = None, model_kwargs: Optional[Dict[str, Any]] = None, + fp16: bool = False, + int8: bool = False, ) -> Tuple[List[str], List[str]]: """ Exports a PyTorch model to an OpenVINO Intermediate Representation. 
@@ -313,7 +325,9 @@ def ts_patched_forward(*args, **kwargs): ov_model = convert_model(model, example_input=dummy_inputs, input=input_info) except Exception as ex: logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX") - return export_pytorch_via_onnx(model, config, opset, output, device, input_shapes, model_kwargs) + return export_pytorch_via_onnx( + model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, int8=int8 + ) ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs} ordered_input_names = list(inputs) flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values()) @@ -334,8 +348,7 @@ def ts_patched_forward(*args, **kwargs): inp_tensor.get_node().set_partial_shape(static_shape) inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) ov_model.validate_nodes_and_infer_types() - load_in_8bit = False if model_kwargs is None else model_kwargs.get("load_in_8bit", False) - _save_model(ov_model, output, compress_to_fp16=False, load_in_8bit=load_in_8bit) + _save_model(ov_model, output, compress_to_fp16=fp16, load_in_8bit=int8) clear_class_registry() del model gc.collect() @@ -352,6 +365,8 @@ def export_models( device: str = "cpu", input_shapes: Optional[Dict] = None, model_kwargs: Optional[Dict[str, Any]] = None, + fp16: bool = False, + int8: bool = False, ) -> Tuple[List[List[str]], List[List[str]]]: """ Export the models to OpenVINO IR format @@ -396,6 +411,8 @@ def export_models( device=device, input_shapes=input_shapes, model_kwargs=model_kwargs, + fp16=fp16, + int8=int8, ) ) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index c477d487a2..58eb2163d0 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -29,7 +29,7 @@ from optimum.modeling_base import OptimizedModel from ...exporters.openvino import export, main_export -from ..utils.import_utils import is_transformers_version +from ..utils.import_utils import is_nncf_available, is_transformers_version from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME @@ -93,7 +93,7 @@ def __init__( self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None @staticmethod - def load_model(file_name: Union[str, Path]): + def load_model(file_name: Union[str, Path], load_in_8bit: bool = False): """ Loads the model. 
@@ -120,6 +120,15 @@ def fix_op_names_duplicates(model: openvino.runtime.Model): if file_name.suffix == ".onnx": model = fix_op_names_duplicates(model) # should be called during model conversion to IR + if load_in_8bit: + if not is_nncf_available(): + raise ImportError( + "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`" + ) + import nncf + + model = nncf.compress_weights(model) + return model def _save_pretrained(self, save_directory: Union[str, Path]): @@ -146,6 +155,7 @@ def _from_pretrained( file_name: Optional[str] = None, from_onnx: bool = False, local_files_only: bool = False, + load_in_8bit: bool = False, **kwargs, ): """ @@ -203,7 +213,8 @@ def _from_pretrained( model_save_dir = Path(model_cache_path).parent file_name = file_names[0] - model = cls.load_model(file_name) + model = cls.load_model(file_name, load_in_8bit=load_in_8bit) + return cls(model, config=config, model_save_dir=model_save_dir, **kwargs) @classmethod @@ -219,6 +230,7 @@ def _from_transformers( local_files_only: bool = False, task: Optional[str] = None, trust_remote_code: bool = False, + load_in_8bit: bool = False, **kwargs, ): """ @@ -253,10 +265,11 @@ def _from_transformers( local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, + int8=load_in_8bit, ) config.save_pretrained(save_dir_path) - return cls._from_pretrained(model_id=save_dir_path, config=config, **kwargs) + return cls._from_pretrained(model_id=save_dir_path, config=config, load_in_8bit=load_in_8bit, **kwargs) @classmethod def _to_load( diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index bedd63af6d..527adc4347 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -119,6 +119,7 @@ def _from_pretrained( local_files_only: bool = False, use_cache: bool = True, from_onnx: bool = False, + load_in_8bit: bool = False, **kwargs, ): """ @@ -159,14 +160,14 @@ def _from_pretrained( encoder_file_name = encoder_file_name or default_encoder_file_name decoder_file_name = decoder_file_name or default_decoder_file_name decoder_with_past_file_name = decoder_with_past_file_name or default_decoder_with_past_file_name - + decoder_with_past = None # Load model from a local directory if os.path.isdir(model_id): - encoder = cls.load_model(os.path.join(model_id, encoder_file_name)) - decoder = cls.load_model(os.path.join(model_id, decoder_file_name)) - decoder_with_past = ( - cls.load_model(os.path.join(model_id, decoder_with_past_file_name)) if use_cache else None - ) + encoder = cls.load_model(os.path.join(model_id, encoder_file_name), load_in_8bit) + decoder = cls.load_model(os.path.join(model_id, decoder_file_name), load_in_8bit) + if use_cache: + decoder_with_past = cls.load_model(os.path.join(model_id, decoder_with_past_file_name), load_in_8bit) + model_save_dir = Path(model_id) # Load model from hub @@ -193,9 +194,10 @@ def _from_pretrained( file_names[name] = model_cache_path model_save_dir = Path(model_cache_path).parent - encoder = cls.load_model(file_names["encoder"]) - decoder = cls.load_model(file_names["decoder"]) - decoder_with_past = cls.load_model(file_names["decoder_with_past"]) if use_cache else None + encoder = cls.load_model(file_names["encoder"], load_in_8bit) + decoder = cls.load_model(file_names["decoder"], load_in_8bit) + if use_cache: + decoder_with_past = cls.load_model(file_names["decoder_with_past"], load_in_8bit) return cls( 
encoder=encoder, @@ -220,6 +222,7 @@ def _from_transformers( task: Optional[str] = None, use_cache: bool = True, trust_remote_code: bool = False, + load_in_8bit: bool = False, **kwargs, ): """ @@ -261,10 +264,13 @@ def _from_transformers( local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, + int8=load_in_8bit, ) config.save_pretrained(save_dir_path) - return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs) + return cls._from_pretrained( + model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=load_in_8bit, **kwargs + ) def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True): shapes = {} diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 91a2c7ddc2..68d737fe74 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -210,6 +210,7 @@ def _from_transformers( task: Optional[str] = None, use_cache: bool = True, trust_remote_code: bool = False, + load_in_8bit: bool = False, **kwargs, ): if config.model_type not in _SUPPORTED_ARCHITECTURES: @@ -238,12 +239,15 @@ def _from_transformers( force_download=force_download, trust_remote_code=trust_remote_code, model_kwargs=kwargs, + int8=load_in_8bit, ) config.is_decoder = True config.is_encoder_decoder = False config.save_pretrained(save_dir_path) - return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs) + return cls._from_pretrained( + model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=load_in_8bit, **kwargs + ) def _reshape( self, diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index 1ca517397d..1ca0b93643 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -190,6 +190,7 @@ def _from_pretrained( local_files_only: bool = False, from_onnx: bool = False, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + load_in_8bit: bool = False, **kwargs, ): default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME @@ -252,7 +253,9 @@ def _from_pretrained( else: kwargs[name] = load_method(new_model_save_dir) - unet = cls.load_model(new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name) + unet = cls.load_model( + new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, load_in_8bit=load_in_8bit + ) components = { "vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, @@ -262,7 +265,7 @@ def _from_pretrained( } for key, value in components.items(): - components[key] = cls.load_model(value) if value.is_file() else None + components[key] = cls.load_model(value, load_in_8bit=load_in_8bit) if value.is_file() else None if model_save_dir is None: model_save_dir = new_model_save_dir @@ -295,6 +298,7 @@ def _from_transformers( tokenizer: "CLIPTokenizer" = None, scheduler: Union["DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"] = None, feature_extractor: Optional["CLIPFeatureExtractor"] = None, + load_in_8bit: bool = False, **kwargs, ): save_dir = TemporaryDirectory() @@ -311,6 +315,7 @@ def _from_transformers( use_auth_token=use_auth_token, local_files_only=local_files_only, force_download=force_download, + int8=load_in_8bit, ) return cls._from_pretrained( @@ -326,6 +331,7 @@ def _from_transformers( tokenizer=tokenizer, 
scheduler=scheduler, feature_extractor=feature_extractor, + load_in_8bit=load_in_8bit, **kwargs, ) diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index d143c4c3cc..d2b9960258 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -16,7 +16,7 @@ from tempfile import TemporaryDirectory from parameterized import parameterized -from utils_tests import MODEL_NAMES +from utils_tests import _ARCHITECTURES_TO_EXPECTED_INT8, MODEL_NAMES, get_num_quantized_nodes from optimum.exporters.openvino.__main__ import main_export from optimum.intel import ( # noqa @@ -41,25 +41,25 @@ class OVCLIExportTestCase(unittest.TestCase): """ SUPPORTED_ARCHITECTURES = ( - ["text-generation", "gpt2"], - ["text-generation-with-past", "gpt2"], - ["text2text-generation", "t5"], - ["text2text-generation-with-past", "t5"], - ["text-classification", "bert"], - ["question-answering", "distilbert"], - ["token-classification", "roberta"], - ["image-classification", "vit"], - ["audio-classification", "wav2vec2"], - ["fill-mask", "bert"], - ["feature-extraction", "blenderbot"], - ["stable-diffusion", "stable-diffusion"], - ["stable-diffusion-xl", "stable-diffusion-xl"], - ["stable-diffusion-xl", "stable-diffusion-xl-refiner"], + ("text-generation", "gpt2"), + ("text-generation-with-past", "gpt2"), + ("text2text-generation", "t5"), + ("text2text-generation-with-past", "t5"), + ("text-classification", "albert"), + ("question-answering", "distilbert"), + ("token-classification", "roberta"), + ("image-classification", "vit"), + ("audio-classification", "wav2vec2"), + ("fill-mask", "bert"), + ("feature-extraction", "blenderbot"), + ("stable-diffusion", "stable-diffusion"), + ("stable-diffusion-xl", "stable-diffusion-xl"), + ("stable-diffusion-xl", "stable-diffusion-xl-refiner"), ) - def _openvino_export(self, model_name: str, task: str): + def _openvino_export(self, model_name: str, task: str, fp16: bool = False, int8: bool = False): with TemporaryDirectory() as tmpdir: - main_export(model_name_or_path=model_name, output=tmpdir, task=task) + main_export(model_name_or_path=model_name, output=tmpdir, task=task, fp16=fp16, int8=int8) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_export(self, task: str, model_type: str): @@ -75,3 +75,40 @@ def test_exporters_cli(self, task: str, model_type: str): ) model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {} eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_exporters_cli_fp16(self, task: str, model_type: str): + with TemporaryDirectory() as tmpdir: + subprocess.run( + f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --fp16 {tmpdir}", + shell=True, + check=True, + ) + model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {} + eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_exporters_cli_int8(self, task: str, model_type: str): + with TemporaryDirectory() as tmpdir: + subprocess.run( + f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --int8 {tmpdir}", + shell=True, + check=True, + ) + model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {} + model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, 
**model_kwargs) + + if task.startswith("text2text-generation"): + models = [model.encoder, model.decoder] + if task.endswith("with-past"): + models.append(model.decoder_with_past) + elif task.startswith("stable-diffusion"): + models = [model.unet, model.vae_encoder, model.vae_decoder] + models.append(model.text_encoder if task == "stable-diffusion" else model.text_encoder_2) + else: + models = [model] + + expected_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] + for i, model in enumerate(models): + _, num_int8 = get_num_quantized_nodes(model) + self.assertEqual(expected_int8[i], num_int8) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index a2fb38e153..f09bd35acd 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -249,8 +249,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - # TODO : Replace from_transformers with export for optimum-intel v1.8 - model = OVModelForSequenceClassification.from_pretrained(model_id, from_transformers=True, compile=False) + model = OVModelForSequenceClassification.from_pretrained(model_id, export=True, compile=False) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("text-classification", model=model, tokenizer=tokenizer) text = "This restaurant is awesome" @@ -319,7 +318,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = OVModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) + model = OVModelForQuestionAnswering.from_pretrained(model_id, export=True) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("question-answering", model=model, tokenizer=tokenizer) question = "What's my name?" 
@@ -334,7 +333,7 @@ def test_pipeline(self, model_arch): def test_metric(self): model_id = "distilbert-base-cased-distilled-squad" set_seed(SEED) - ov_model = OVModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) + ov_model = OVModelForQuestionAnswering.from_pretrained(model_id, export=True) transformers_model = AutoModelForQuestionAnswering.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) data = load_dataset("squad", split="validation").select(range(50)) @@ -385,7 +384,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = OVModelForTokenClassification.from_pretrained(model_id, from_transformers=True) + model = OVModelForTokenClassification.from_pretrained(model_id, export=True) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("token-classification", model=model, tokenizer=tokenizer) outputs = pipe("My Name is Arthur and I live in Lyon.") @@ -433,7 +432,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = OVModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True) + model = OVModelForFeatureExtraction.from_pretrained(model_id, export=True) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("feature-extraction", model=model, tokenizer=tokenizer) outputs = pipe("My Name is Arthur and I live in Lyon.") @@ -490,7 +489,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) - model = OVModelForCausalLM.from_pretrained(model_id, from_transformers=True, use_cache=False, compile=False) + model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False, compile=False) model.config.encoder_no_repeat_ngram_size = 0 model.to("cpu") model.half() @@ -619,7 +618,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = OVModelForMaskedLM.from_pretrained(model_id, from_transformers=True) + model = OVModelForMaskedLM.from_pretrained(model_id, export=True) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer) outputs = pipe(f"This is a {tokenizer.mask_token}.") @@ -676,7 +675,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = OVModelForImageClassification.from_pretrained(model_id, from_transformers=True) + model = OVModelForImageClassification.from_pretrained(model_id, export=True) preprocessor = AutoFeatureExtractor.from_pretrained(model_id) pipe = pipeline("image-classification", model=model, feature_extractor=preprocessor) outputs = pipe("http://images.cocodataset.org/val2017/000000039769.jpg") @@ -771,7 +770,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) - model = OVModelForSeq2SeqLM.from_pretrained(model_id, from_transformers=True, compile=False) + model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, compile=False) model.half() 
model.to("cpu") model.compile() @@ -803,7 +802,7 @@ def test_pipeline(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_generate_utils(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = OVModelForSeq2SeqLM.from_pretrained(model_id, from_transformers=True) + model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True) tokenizer = AutoTokenizer.from_pretrained(model_id) text = "This is a sample input" tokens = tokenizer(text, return_tensors="pt") @@ -827,14 +826,14 @@ def test_compare_with_and_without_past_key_values(self): text = "This is a sample input" tokens = tokenizer(text, return_tensors="pt") - model_with_pkv = OVModelForSeq2SeqLM.from_pretrained(model_id, from_transformers=True, use_cache=True) + model_with_pkv = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, use_cache=True) _ = model_with_pkv.generate(**tokens) # warmup with Timer() as with_pkv_timer: outputs_model_with_pkv = model_with_pkv.generate( **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 ) - model_without_pkv = OVModelForSeq2SeqLM.from_pretrained(model_id, from_transformers=True, use_cache=False) + model_without_pkv = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, use_cache=False) _ = model_without_pkv.generate(**tokens) # warmup with Timer() as without_pkv_timer: outputs_model_without_pkv = model_without_pkv.generate( @@ -904,7 +903,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = OVModelForAudioClassification.from_pretrained(model_id, from_transformers=True) + model = OVModelForAudioClassification.from_pretrained(model_id, export=True) preprocessor = AutoFeatureExtractor.from_pretrained(model_id) pipe = pipeline("audio-classification", model=model, feature_extractor=preprocessor) outputs = pipe([np.random.random(16000)]) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 6563eed7d8..c1ec95ea9b 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -34,14 +34,24 @@ from optimum.intel import ( OVConfig, + OVModelForAudioClassification, + OVModelForCausalLM, + OVModelForFeatureExtraction, + OVModelForImageClassification, + OVModelForMaskedLM, OVModelForQuestionAnswering, + OVModelForSeq2SeqLM, OVModelForSequenceClassification, OVModelForTokenClassification, - OVModelForCausalLM, + OVStableDiffusionPipeline, + OVStableDiffusionXLPipeline, OVQuantizer, OVTrainer, ) + + from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG +from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8 _TASK_TO_DATASET = { "text-generation": ("wikitext", "wikitext-2-raw-v1", "text"), @@ -49,18 +59,6 @@ } -def get_num_quantized_nodes(ov_model): - num_fake_quantize = 0 - num_int8 = 0 - for elem in ov_model.model.get_ops(): - if "FakeQuantize" in elem.name: - num_fake_quantize += 1 - for i in range(elem.get_output_size()): - if "8" in elem.get_output_element_type(i).get_type_name(): - num_int8 += 1 - return num_fake_quantize, num_int8 - - class OVQuantizerTest(unittest.TestCase): # TODO : add models SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( @@ -150,7 +148,19 @@ class OVWeightCompressionTest(unittest.TestCase): (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 22), ) - UPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = 
((OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 22),) + SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = ( + (OVModelForCausalLM, "gpt2"), + (OVModelForMaskedLM, "bert"), + (OVModelForTokenClassification, "roberta"), + (OVModelForImageClassification, "vit"), + (OVModelForSeq2SeqLM, "t5"), + (OVModelForSequenceClassification, "albert"), + (OVModelForQuestionAnswering, "distilbert"), + (OVModelForAudioClassification, "wav2vec2"), + (OVModelForFeatureExtraction, "blenderbot"), + (OVStableDiffusionPipeline, "stable-diffusion"), + (OVStableDiffusionXLPipeline, "stable-diffusion-xl"), + ) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS) def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8): @@ -166,7 +176,6 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i quantizer.quantize(save_directory=tmp_dir, weights_only=True) model = model_cls.from_pretrained(tmp_dir) - # TODO: uncomment once move to a newer version of NNCF which has some fixes _, num_int8 = get_num_quantized_nodes(model) self.assertEqual(expected_pt_int8, num_int8) @@ -199,17 +208,38 @@ def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int outputs = model(**tokens) self.assertTrue("logits" in outputs) - @parameterized.expand(UPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) - def test_ovmodel_load_with_compressed_weights(self, model_cls, model_name, expected_ov_int8): - model = model_cls.from_pretrained(model_name, export=True, load_in_8bit=True) - _, num_int8 = get_num_quantized_nodes(model) - self.assertEqual(expected_ov_int8, num_int8) - - @parameterized.expand(UPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) - def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_name, expected_ov_int8): - model = model_cls.from_pretrained(model_name, export=True, load_in_8bit=False) - _, num_int8 = get_num_quantized_nodes(model) - self.assertEqual(0, num_int8) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) + def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type): + model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True) + + if model.export_feature.startswith("text2text-generation"): + models = [model.encoder, model.decoder, model.decoder_with_past] + elif model.export_feature.startswith("stable-diffusion"): + models = [model.unet, model.vae_encoder, model.vae_decoder] + models.append(model.text_encoder if model.export_feature == "stable-diffusion" else model.text_encoder_2) + else: + models = [model] + + expected_ov_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] + for i, model in enumerate(models): + _, num_int8 = get_num_quantized_nodes(model) + self.assertEqual(expected_ov_int8[i], num_int8) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) + def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type): + model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=False) + + if model.export_feature.startswith("text2text-generation"): + models = [model.encoder, model.decoder, model.decoder_with_past] + elif model.export_feature.startswith("stable-diffusion"): + models = [model.unet, model.vae_encoder, model.vae_decoder] + models.append(model.text_encoder if model.export_feature == "stable-diffusion" else model.text_encoder_2) + else: + models = [model] + + for i, model in enumerate(models): + _, num_int8 = get_num_quantized_nodes(model) + 
self.assertEqual(0, num_int8) class OVQuantizerQATest(unittest.TestCase): diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 94643b02f4..091548c4b1 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -92,3 +92,31 @@ } SEED = 42 + + +_ARCHITECTURES_TO_EXPECTED_INT8 = { + "bert": (34,), + "roberta": (34,), + "albert": (42,), + "vit": (31,), + "blenderbot": (35,), + "gpt2": (22,), + "wav2vec2": (15,), + "distilbert": (33,), + "t5": (32, 52, 42), + "stable-diffusion": (74, 4, 4, 32), + "stable-diffusion-xl": (148, 4, 4, 33), + "stable-diffusion-xl-refiner": (148, 4, 4, 33), +} + + +def get_num_quantized_nodes(ov_model): + num_fake_quantize = 0 + num_int8 = 0 + for elem in ov_model.model.get_ops(): + if "FakeQuantize" in elem.name: + num_fake_quantize += 1 + for i in range(elem.get_output_size()): + if "8" in elem.get_output_element_type(i).get_type_name(): + num_int8 += 1 + return num_fake_quantize, num_int8
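
For reviewers, a minimal usage sketch of what this patch enables, mirroring the new tests. The "gpt2" checkpoint and the "gpt2_ov/" output directory are placeholder values, and int8 weight compression requires `pip install nncf`; this is illustrative only, not part of the patch.

# Sketch only (placeholders: "gpt2", "gpt2_ov/").
# Roughly equivalent CLI: optimum-cli export openvino --model gpt2 --int8 gpt2_ov/
from optimum.exporters.openvino.__main__ import main_export
from optimum.intel import OVModelForCausalLM

# Export an OpenVINO IR with int8-compressed weights
# (fp16=True would instead save the IR with fp16-compressed weights).
main_export(model_name_or_path="gpt2", output="gpt2_ov/", task="text-generation", int8=True)

# Or request int8 weight compression through the modeling API while exporting;
# load_in_8bit triggers nncf-based weight compression of the converted model.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, load_in_8bit=True)

Note that when `int8` is left unset (None), main_export falls back to int8 weight compression for models with at least _MAX_UNCOMPRESSED_SIZE (1e9) parameters, provided nncf is available; otherwise it only logs a warning and exports uncompressed weights.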