diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml index cb58f412a6..d43cabe323 100644 --- a/.github/workflows/test_openvino.yml +++ b/.github/workflows/test_openvino.yml @@ -36,3 +36,9 @@ jobs: - name: Test with Pytest run: | pytest tests/openvino/ --ignore test_modeling_basic + - name: Test openvino-nightly import + run: | + pip uninstall -y openvino + pip install openvino-nightly + python -c "from optimum.intel import OVModelForCausalLM; OVModelForCausalLM.from_pretrained('hf-internal-testing/tiny-random-gpt2', export=True, compile=False)" + diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 076f0e3896..56e8f050be 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, Optional, Union from requests.exceptions import ConnectionError as RequestsConnectionError -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from optimum.exporters import TasksManager from optimum.exporters.onnx import __main__ as optimum_main @@ -140,6 +140,41 @@ def main_export( original_task = task task = TasksManager.map_from_synonym(task) + # Patch the modules to export of GPTQ models w/o GPU + do_gptq_patching = False + try: + config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code) + config_dict = config.to_dict() + quantization_config = config_dict.get("quantization_config", None) + do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq" + except Exception: + pass + + if do_gptq_patching: + import torch + + torch.set_default_dtype(torch.float32) + orig_cuda_check = torch.cuda.is_available + torch.cuda.is_available = lambda: True + + from optimum.gptq import GPTQQuantizer + + orig_post_init_model = GPTQQuantizer.post_init_model + + def post_init_model(self, model): + from auto_gptq import exllama_set_max_input_length + + class StoreAttr(object): + pass + + model.quantize_config = StoreAttr() + model.quantize_config.desc_act = self.desc_act + if self.desc_act and not self.disable_exllama and self.max_input_length is not None: + model = exllama_set_max_input_length(model, self.max_input_length) + return model + + GPTQQuantizer.post_init_model = post_init_model + framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework) # get the shapes to be used to generate dummy inputs @@ -326,3 +361,8 @@ def main_export( compression_ratio=compression_ratio, model_kwargs=model_kwargs, ) + + # Unpatch modules after GPTQ export + if do_gptq_patching: + torch.cuda.is_available = orig_cuda_check + GPTQQuantizer.post_init_model = orig_post_init_model diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index f8acd0b5b8..217fe9a213 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -229,34 +229,6 @@ def _from_transformers( if use_cache: task = task + "-with-past" - # Patch the modules to export of GPTQ models w/o GPU - do_gptq_patching = False - config_dict = config.to_dict() - quantization_config = config_dict.get("quantization_config", None) - do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq" - if do_gptq_patching: - torch.set_default_dtype(torch.float32) - orig_cuda_check = torch.cuda.is_available - torch.cuda.is_available = lambda: True - - from optimum.gptq import GPTQQuantizer - - orig_post_init_model = GPTQQuantizer.post_init_model - - def post_init_model(self, model): - from auto_gptq import exllama_set_max_input_length - - class StoreAttr(object): - pass - - model.quantize_config = StoreAttr() - model.quantize_config.desc_act = self.desc_act - if self.desc_act and not self.disable_exllama and self.max_input_length is not None: - model = exllama_set_max_input_length(model, self.max_input_length) - return model - - GPTQQuantizer.post_init_model = post_init_model - main_export( model_name_or_path=model_id, output=save_dir_path, @@ -271,11 +243,6 @@ class StoreAttr(object): compression_option="i8" if load_in_8bit else None, ) - # Unpatch modules after GPTQ export - if do_gptq_patching: - torch.cuda.is_available = orig_cuda_check - GPTQQuantizer.post_init_model = orig_post_init_model - config.is_decoder = True config.is_encoder_decoder = False config.save_pretrained(save_dir_path) @@ -504,7 +471,7 @@ def _from_pretrained( elif model_type == "gpt-bigcode": init_cls = OVGPTBigCodeForCausalLM else: - init_cls = OVModelForCausalLM + init_cls = cls return init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs) diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py index 87ddcc1315..d15780384f 100644 --- a/optimum/intel/utils/import_utils.py +++ b/optimum/intel/utils/import_utils.py @@ -71,7 +71,10 @@ try: _openvino_version = importlib_metadata.version("openvino") except importlib_metadata.PackageNotFoundError: - _openvino_available = False + try: + _openvino_version = importlib_metadata.version("openvino-nightly") + except importlib_metadata.PackageNotFoundError: + _openvino_available = False _nncf_available = importlib.util.find_spec("nncf") is not None