diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml
index cb58f412a6..d43cabe323 100644
--- a/.github/workflows/test_openvino.yml
+++ b/.github/workflows/test_openvino.yml
@@ -36,3 +36,9 @@ jobs:
       - name: Test with Pytest
         run: |
           pytest tests/openvino/ --ignore test_modeling_basic
+      - name: Test openvino-nightly import
+        run: |
+          pip uninstall -y openvino
+          pip install openvino-nightly
+          python -c "from optimum.intel import OVModelForCausalLM; OVModelForCausalLM.from_pretrained('hf-internal-testing/tiny-random-gpt2', export=True, compile=False)"
+
diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index d25c2c5f3a..cb011706c8 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -18,7 +18,7 @@
 from typing import Any, Callable, Dict, Optional, Union
 
 from requests.exceptions import ConnectionError as RequestsConnectionError
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer
 
 from optimum.exporters import TasksManager
 from optimum.exporters.onnx import __main__ as optimum_main
@@ -136,6 +136,41 @@ def main_export(
     original_task = task
     task = TasksManager.map_from_synonym(task)
 
+    # Patch the modules to export of GPTQ models w/o GPU
+    do_gptq_patching = False
+    try:
+        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
+        config_dict = config.to_dict()
+        quantization_config = config_dict.get("quantization_config", None)
+        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+    except Exception:
+        pass
+
+    if do_gptq_patching:
+        import torch
+
+        torch.set_default_dtype(torch.float32)
+        orig_cuda_check = torch.cuda.is_available
+        torch.cuda.is_available = lambda: True
+
+        from optimum.gptq import GPTQQuantizer
+
+        orig_post_init_model = GPTQQuantizer.post_init_model
+
+        def post_init_model(self, model):
+            from auto_gptq import exllama_set_max_input_length
+
+            class StoreAttr(object):
+                pass
+
+            model.quantize_config = StoreAttr()
+            model.quantize_config.desc_act = self.desc_act
+            if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                model = exllama_set_max_input_length(model, self.max_input_length)
+            return model
+
+        GPTQQuantizer.post_init_model = post_init_model
+
     framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
 
     # get the shapes to be used to generate dummy inputs
@@ -317,3 +352,8 @@ def main_export(
         int8=int8,
         model_kwargs=model_kwargs,
     )
+
+    # Unpatch modules after GPTQ export
+    if do_gptq_patching:
+        torch.cuda.is_available = orig_cuda_check
+        GPTQQuantizer.post_init_model = orig_post_init_model
diff --git a/optimum/intel/neural_compressor/trainer.py b/optimum/intel/neural_compressor/trainer.py
index 5ae4a1f72a..56a97dbdcc 100644
--- a/optimum/intel/neural_compressor/trainer.py
+++ b/optimum/intel/neural_compressor/trainer.py
@@ -22,6 +22,14 @@
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
+
+# Integrations must be imported before ML frameworks:
+# isort: off
+from transformers.integrations import hp_params
+from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+
+# isort: on
+
 import datasets
 import torch
 import torch.distributed as dist
@@ -36,6 +44,7 @@
 from transformers.data.data_collator import DataCollator
 from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
+
 # Integrations must be imported before ML frameworks:
 from transformers.integrations import deepspeed_init, deepspeed_load_checkpoint, hp_params, is_deepspeed_available
 from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype, unwrap_model
@@ -129,6 +138,8 @@ def __init__(
         task: Optional[str] = None,
         save_onnx_model: bool = False,
     ):
+        self.neftune_noise_alpha = None
+
         super().__init__(
             model,
             args,
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 5bd9437391..2ebf04979d 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -229,34 +229,6 @@ def _from_transformers(
         if use_cache:
             task = task + "-with-past"
 
-        # Patch the modules to export of GPTQ models w/o GPU
-        do_gptq_patching = False
-        config_dict = config.to_dict()
-        quantization_config = config_dict.get("quantization_config", None)
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
-        if do_gptq_patching:
-            torch.set_default_dtype(torch.float32)
-            orig_cuda_check = torch.cuda.is_available
-            torch.cuda.is_available = lambda: True
-
-            from optimum.gptq import GPTQQuantizer
-
-            orig_post_init_model = GPTQQuantizer.post_init_model
-
-            def post_init_model(self, model):
-                from auto_gptq import exllama_set_max_input_length
-
-                class StoreAttr(object):
-                    pass
-
-                model.quantize_config = StoreAttr()
-                model.quantize_config.desc_act = self.desc_act
-                if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
-                    model = exllama_set_max_input_length(model, self.max_input_length)
-                return model
-
-            GPTQQuantizer.post_init_model = post_init_model
-
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -271,11 +243,6 @@ class StoreAttr(object):
             int8=load_in_8bit,
         )
 
-        # Unpatch modules after GPTQ export
-        if do_gptq_patching:
-            torch.cuda.is_available = orig_cuda_check
-            GPTQQuantizer.post_init_model = orig_post_init_model
-
         config.is_decoder = True
         config.is_encoder_decoder = False
         config.save_pretrained(save_dir_path)
@@ -519,7 +486,7 @@ def _from_pretrained(
         elif model_type == "gpt-bigcode":
             init_cls = OVGPTBigCodeForCausalLM
         else:
-            init_cls = OVModelForCausalLM
+            init_cls = cls
 
         return init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs)
diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py
index 7cf7b017a1..7bbfc4599e 100644
--- a/optimum/intel/openvino/trainer.py
+++ b/optimum/intel/openvino/trainer.py
@@ -24,8 +24,15 @@
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
+
+# Integrations must be imported before ML frameworks:
+# isort: off
+from transformers.integrations import hp_params
+from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+
+# isort: on
+
 import openvino
-import openvino.runtime
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
@@ -190,6 +197,8 @@ def __init__(
         task: Optional[str] = None,
         feature: Optional[str] = None,
     ):
+        self.neftune_noise_alpha = None
+
         super().__init__(
             model,
             args,
@@ -821,12 +830,12 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
             if state_dict is None:
                 state_dict = self.model.state_dict()
             if is_pretrained_model:
-                unwrapped_model.save_pretrained(output_dir, state_dict=state_dict)
+                unwrapped_model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=False)
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                 torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, state_dict=state_dict)
+            self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=False)
 
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py
index 87ddcc1315..d15780384f 100644
--- a/optimum/intel/utils/import_utils.py
+++ b/optimum/intel/utils/import_utils.py
@@ -71,7 +71,10 @@
 try:
     _openvino_version = importlib_metadata.version("openvino")
 except importlib_metadata.PackageNotFoundError:
-    _openvino_available = False
+    try:
+        _openvino_version = importlib_metadata.version("openvino-nightly")
+    except importlib_metadata.PackageNotFoundError:
+        _openvino_available = False
 
 
 _nncf_available = importlib.util.find_spec("nncf") is not None
diff --git a/setup.py b/setup.py
index e2da72bf3e..49fcef2408 100644
--- a/setup.py
+++ b/setup.py
@@ -41,8 +41,9 @@
         "neural-compressor>=2.2.0",
         "onnx",
         "onnxruntime<1.15.0",
+        "transformers>=4.33.0",
     ],
-    "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime"],
+    "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime", "transformers>=4.33.0"],
     "nncf": ["nncf>=2.6.0"],
     "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"],
     "diffusers": ["diffusers"],
diff --git a/tests/neural_compressor/test_modeling.py b/tests/neural_compressor/test_modeling.py
index fc2a310595..8098f011c5 100644
--- a/tests/neural_compressor/test_modeling.py
+++ b/tests/neural_compressor/test_modeling.py
@@ -19,6 +19,7 @@
 import unittest
 
 import torch
+from packaging.version import Version, parse
 from parameterized import parameterized
 from transformers import AutoTokenizer, pipeline, set_seed
 
@@ -39,6 +40,7 @@
     INCTrainer,
 )
 from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, WEIGHTS_NAME
+from optimum.version import __version__ as _optimum_version
 
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -133,6 +135,7 @@ def test_pipeline(self, model_id, task):
         pipe(*inputs)
 
+    @unittest.skipIf(parse(_optimum_version) < Version("1.14.0"), "not supported, needs optimum>=v1.14.0")
    def test_compare_with_and_without_past_key_values(self):
         model_id = "echarlaix/tiny-random-gpt2-torchscript"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
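
Note on the GPTQ hunks: the detection and CUDA monkey-patching move out of `OVModelForCausalLM._from_transformers` into `main_export`, so every OpenVINO export path gets the same handling. The Python sketch below only illustrates that pattern in isolation; the helper names (`is_gptq_checkpoint`, `run_with_fake_cuda`) are invented for the example, and the real code additionally swaps `GPTQQuantizer.post_init_model` and sets the default dtype.

import torch
from transformers import AutoConfig


def is_gptq_checkpoint(model_name_or_path: str, trust_remote_code: bool = False) -> bool:
    # GPTQ checkpoints carry a quantization_config with quant_method == "gptq".
    try:
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
        quantization_config = config.to_dict().get("quantization_config", None)
        return bool(quantization_config) and quantization_config.get("quant_method") == "gptq"
    except Exception:
        return False


def run_with_fake_cuda(fn, *args, **kwargs):
    # optimum's GPTQ post-init refuses to run without a GPU, so pretend one exists
    # for the duration of the call and always restore the original check.
    orig_cuda_check = torch.cuda.is_available
    torch.cuda.is_available = lambda: True
    try:
        return fn(*args, **kwargs)
    finally:
        torch.cuda.is_available = orig_cuda_check

In `main_export` the equivalent of `is_gptq_checkpoint` sets the `do_gptq_patching` flag, and the explicit patch/unpatch pair brackets the export instead of a try/finally.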
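
The `init_cls = cls` change in `_from_pretrained` means a user-defined subclass of `OVModelForCausalLM` is no longer downcast to the base class when the architecture has no dedicated OV subclass (gpt-bigcode, in the visible context, keeps its special-cased class). A minimal illustration, assuming network access to the tiny test model used elsewhere in this patch; the subclass name is hypothetical:

from optimum.intel import OVModelForCausalLM


class MyOVModelForCausalLM(OVModelForCausalLM):
    """Hypothetical subclass that adds no behaviour of its own."""


model = MyOVModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-gpt2", export=True, compile=False
)
# Before this change the loader returned an OVModelForCausalLM instance;
# with init_cls = cls it now returns the subclass itself.
print(type(model).__name__)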
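
The `import_utils.py` fallback exists because the nightly wheel installs the same importable `openvino` package but registers its version metadata under the `openvino-nightly` distribution name, which is what the new CI step exercises. A quick way to check the resolved values after `pip install openvino-nightly`; note that `_openvino_version` is an internal variable, read here only for illustration:

from optimum.intel.utils.import_utils import _openvino_version, is_openvino_available

# With only openvino-nightly installed, the lookup of the "openvino" distribution
# fails and the fallback fills _openvino_version from "openvino-nightly".
print(is_openvino_available(), _openvino_version)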
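
The new skip decorator in `test_modeling.py` relies on PEP 440 ordering from `packaging`, so development builds of optimum (for example an install from git) also count as older than the 1.14.0 release. A small standalone check of how that comparison behaves:

from packaging.version import Version, parse

for candidate in ("1.13.2", "1.14.0.dev0", "1.14.0", "1.15.1"):
    # dev and pre-releases sort before the final release, so 1.14.0.dev0 also triggers the skip.
    print(candidate, parse(candidate) < Version("1.14.0"))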