diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 782aa0bc0d..b17d93aa5e 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -18,7 +18,7 @@
 from typing import Any, Callable, Dict, Optional, Union
 
 from requests.exceptions import ConnectionError as RequestsConnectionError
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer
 
 from optimum.exporters import TasksManager
 from optimum.exporters.onnx import __main__ as optimum_main
@@ -137,6 +137,41 @@ def main_export(
     original_task = task
     task = TasksManager.map_from_synonym(task)
 
+    # Patch the modules to export GPTQ models without a GPU
+    do_gptq_patching = False
+    try:
+        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
+        config_dict = config.to_dict()
+        quantization_config = config_dict.get("quantization_config", None)
+        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+    except Exception:
+        pass
+
+    if do_gptq_patching:
+        import torch
+
+        torch.set_default_dtype(torch.float32)
+        orig_cuda_check = torch.cuda.is_available
+        torch.cuda.is_available = lambda: True
+
+        from optimum.gptq import GPTQQuantizer
+
+        orig_post_init_model = GPTQQuantizer.post_init_model
+
+        def post_init_model(self, model):
+            from auto_gptq import exllama_set_max_input_length
+
+            class StoreAttr(object):
+                pass
+
+            model.quantize_config = StoreAttr()
+            model.quantize_config.desc_act = self.desc_act
+            if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                model = exllama_set_max_input_length(model, self.max_input_length)
+            return model
+
+        GPTQQuantizer.post_init_model = post_init_model
+
     framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
 
     # get the shapes to be used to generate dummy inputs
@@ -324,3 +359,8 @@ def main_export(
         int8=int8,
         model_kwargs=model_kwargs,
     )
+
+    # Unpatch modules after GPTQ export
+    if do_gptq_patching:
+        torch.cuda.is_available = orig_cuda_check
+        GPTQQuantizer.post_init_model = orig_post_init_model
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 0e018f9f62..4d87b7eec2 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -229,34 +229,6 @@ def _from_transformers(
         if use_cache:
             task = task + "-with-past"
 
-        # Patch the modules to export of GPTQ models w/o GPU
-        do_gptq_patching = False
-        config_dict = config.to_dict()
-        quantization_config = config_dict.get("quantization_config", None)
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
-        if do_gptq_patching:
-            torch.set_default_dtype(torch.float32)
-            orig_cuda_check = torch.cuda.is_available
-            torch.cuda.is_available = lambda: True
-
-            from optimum.gptq import GPTQQuantizer
-
-            orig_post_init_model = GPTQQuantizer.post_init_model
-
-            def post_init_model(self, model):
-                from auto_gptq import exllama_set_max_input_length
-
-                class StoreAttr(object):
-                    pass
-
-                model.quantize_config = StoreAttr()
-                model.quantize_config.desc_act = self.desc_act
-                if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
-                    model = exllama_set_max_input_length(model, self.max_input_length)
-                return model
-
-            GPTQQuantizer.post_init_model = post_init_model
-
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -271,11 +243,6 @@ class StoreAttr(object):
             int8=load_in_8bit,
         )
 
-        # Unpatch modules after GPTQ export
-        if do_gptq_patching:
-            torch.cuda.is_available = orig_cuda_check
-            GPTQQuantizer.post_init_model = orig_post_init_model
-
         config.is_decoder = True
         config.is_encoder_decoder = False
         config.save_pretrained(save_dir_path)
@@ -504,7 +471,7 @@ def _from_pretrained(
         elif model_type == "gpt-bigcode":
             init_cls = OVGPTBigCodeForCausalLM
         else:
-            init_cls = OVModelForCausalLM
+            init_cls = cls
 
         return init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs)
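
Usage note (not part of the diff): with the GPTQ patching moved into main_export, a GPTQ checkpoint can be exported on a CPU-only host through any entry point that reaches main_export, not only OVModelForCausalLM. A minimal sketch, assuming an illustrative model id and output directory:

    from optimum.intel import OVModelForCausalLM

    # export=True routes through main_export, which reads quantization_config
    # from the model config, applies the CUDA monkey-patches before conversion,
    # and restores the originals afterwards.
    model = OVModelForCausalLM.from_pretrained("some-org/llama-7b-gptq", export=True)
    model.save_pretrained("llama-7b-gptq-openvino")

The same path is exercised by the CLI, e.g. optimum-cli export openvino --model some-org/llama-7b-gptq llama-7b-gptq-openvino.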