From 91c973b89e01a1d4893b1e0b0b143f8a695cacb8 Mon Sep 17 00:00:00 2001 From: changwangss Date: Thu, 13 Jun 2024 19:37:18 -0700 Subject: [PATCH] fix pylint Signed-off-by: changwangss --- .../neural_chat/models/model_utils.py | 2 +- .../transformers/llm/quantization/sq_utils.py | 5 +- .../transformers/llm/quantization/utils.py | 12 +- .../transformers/modeling/modeling_auto.py | 439 ++++++++++++------ .../transformers/utils/utility.py | 412 +--------------- 5 files changed, 306 insertions(+), 564 deletions(-) diff --git a/intel_extension_for_transformers/neural_chat/models/model_utils.py b/intel_extension_for_transformers/neural_chat/models/model_utils.py index dd0c2c99102..9c3e837c7d9 100644 --- a/intel_extension_for_transformers/neural_chat/models/model_utils.py +++ b/intel_extension_for_transformers/neural_chat/models/model_utils.py @@ -699,7 +699,7 @@ def load_model( assert ipex.__version__ >= "2.1.0+cpu", "Please use Intel Extension for PyTorch >=2.1.0+cpu." if re.search("falcon", model_name, re.IGNORECASE): assert transformers.__version__ <= "4.33.3", "Please pip install transformers==4.33.3" - from intel_extension_for_transformers.transformers.llm.evaluation.models import TSModelCausalLMForITREX + from intel_extension_for_transformers.transformers.llm.quantization.sq_utils import TSModelCausalLMForITREX model = TSModelCausalLMForITREX.from_pretrained( model_name, file_name="best_model.pt" diff --git a/intel_extension_for_transformers/transformers/llm/quantization/sq_utils.py b/intel_extension_for_transformers/transformers/llm/quantization/sq_utils.py index 1ffb7b47001..634ea7499c6 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/sq_utils.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/sq_utils.py @@ -14,12 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re + from typing import Optional, Tuple import transformers from datasets import load_dataset -from optimum.intel.generation.modeling import TSModelForCausalLM from torch.nn.functional import pad from torch.utils.data import DataLoader from transformers.modeling_outputs import CausalLMOutputWithPast @@ -315,7 +314,7 @@ def collate_batch(batch): ) return calib_dataloader - +from optimum.intel.generation.modeling import TSModelForCausalLM class TSModelCausalLMForITREX(TSModelForCausalLM): def _reorder_cache( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py index afad1d516c2..4a24dc7121d 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py @@ -57,9 +57,7 @@ ) if is_autoround_available(): - from auto_round.export.export_to_itrex.model_wrapper import ( - WeightOnlyLinear as auto_round_woqlinear, - ) # pylint: disable=E0401 + from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear as auto_round_woqlinear # pylint: disable=E0401 from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader as get_autoround_dataloader torch = LazyImport("torch") @@ -299,10 +297,8 @@ def _replace_linear( use_optimum_format=use_optimum_format, ) elif device == "xpu" or device == torch.device("xpu"): - from intel_extension_for_pytorch.nn.utils._quantize_convert import ( - WeightOnlyQuantizedLinear as ipex_linear, - ) # pylint: disable=E0401 - + from intel_extension_for_pytorch.nn.utils._quantize_convert import \ + WeightOnlyQuantizedLinear as ipex_linear # pylint: disable=E0401 model._modules[name] = ipex_linear( in_features, out_features, @@ -569,6 +565,8 @@ def convert_to_quantized_model(model, config, device="cpu"): ) model = prepare(model, quant_config) model = convert(model) + # qits module doesn't match with HQQ algorithm. + return model elif config.quant_method.value == "awq": quant_config = AWQConfig( dtype=dtype, diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index 89ab4f758ea..28dc9715782 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -164,7 +164,11 @@ def build_woq_model(model, quantization_config): if "lm_head" in n or "output_layer" in n or "embed_out" in n: continue if isinstance(m, torch.nn.Linear): - zp = getattr(quantization_config, "zero_point", not getattr(quantization_config, "sym", False)) + zp = getattr( + quantization_config, + "zero_point", + not getattr(quantization_config, "sym", False), + ) with init_empty_weights(): new_module = WeightOnlyLinear( m.in_features, @@ -201,6 +205,7 @@ def convert_model_to_public(model): ]: model = recover_export_model(model) + def make_contiguous(model): for param in model.parameters(): if param.data.ndimension() > 1: @@ -225,7 +230,8 @@ def save_low_bit( self.model.config.quantization_config = self.quantization_config self.model.config.save_pretrained(save_directory) weights_file = os.path.join( - os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME) + os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME + ) torch.save(self.quantized_state_dict(), weights_file) return @@ -239,25 +245,42 @@ def save_low_bit( ) if self.quantization_config.use_ipex: + def save_linear_parameters(model, save_directory): # only can save to pytorch model.bin due to ipex. weights_file = os.path.join( - os.path.abspath(os.path.expanduser(save_directory)), SAFE_WEIGHTS_NAME) + os.path.abspath(os.path.expanduser(save_directory)), SAFE_WEIGHTS_NAME + ) os.remove(weights_file) weights_file = os.path.join( - os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME) + os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME + ) linear_parameters = {} - from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_cpu_linear + from intel_extension_for_pytorch.nn.modules import ( + WeightOnlyQuantizedLinear as ipex_cpu_linear, + ) + for name, module in model.named_modules(): if isinstance(module, ipex_cpu_linear): - linear_parameters[name + ".ipex_scales"] = module._op_context.get_scales().contiguous() - linear_parameters[name + ".ipex_weight"] = \ - module._op_context.to_public(module._op_context.get_weight()).contiguous() - linear_parameters[name + ".ipex_zeros"] = module._op_context.get_zero_points().contiguous() + linear_parameters[name + ".ipex_scales"] = ( + module._op_context.get_scales().contiguous() + ) + linear_parameters[name + ".ipex_weight"] = ( + module._op_context.to_public( + module._op_context.get_weight() + ).contiguous() + ) + linear_parameters[name + ".ipex_zeros"] = ( + module._op_context.get_zero_points().contiguous() + ) if module._op_context.get_bias() is not None: - linear_parameters[name + ".ipex_bias"] = module._op_context.get_bias().contiguous() + linear_parameters[name + ".ipex_bias"] = ( + module._op_context.get_bias().contiguous() + ) if module._op_context.get_g_idx() is not None: - linear_parameters[name + ".ipex_g_idx"] = module._op_context.get_g_idx().contiguous() + linear_parameters[name + ".ipex_g_idx"] = ( + module._op_context.get_g_idx().contiguous() + ) others_parameters = model.state_dict() linear_parameters.update(others_parameters) @@ -346,17 +369,27 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): use_vllm = kwargs.pop("use_vllm", None) if use_vllm is not None: logger.info("The backend is vLLM.") - from vllm import LLM # pylint: disable=E1101 - from vllm.model_executor.model_loader import get_model_loader # pylint: disable=E0611 - from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611 - from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ColumnParallelLinear, - RowParallelLinear) # pylint: disable=E1101 + from vllm import LLM # pylint: disable=E1101 + from vllm.model_executor.model_loader import ( + get_model_loader, + ) # pylint: disable=E0611 + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + ) # pylint: disable=E0401 disable=E0611 + from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ColumnParallelLinear, + RowParallelLinear, + ) # pylint: disable=E1101 os.environ["backend"] = "use_vllm" - llm = LLM(model=pretrained_model_name_or_path, trust_remote_code=True) # Create an vllm instance. - model = llm.llm_engine.model_executor.driver_worker.model_runner.model # pylint: disable=E1101 + llm = LLM( + model=pretrained_model_name_or_path, trust_remote_code=True + ) # Create an vllm instance. + model = ( + llm.llm_engine.model_executor.driver_worker.model_runner.model + ) # pylint: disable=E1101 print("Original model =", model) original_parameter_memo = dict() @@ -366,12 +399,22 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): if "qkv_proj" in name or "gate_up_proj" in name: input_dim = getattr(params, "input_dim", None) output_dim = getattr(params, "output_dim", None) - original_parameter_memo[name] = (input_dim, output_dim, params.weight_loader) + original_parameter_memo[name] = ( + input_dim, + output_dim, + params.weight_loader, + ) class linear_adaptor(torch.nn.Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = True, \ - device=None, dtype=None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: super().__init__(in_features, out_features, bias, device, dtype) def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: @@ -379,34 +422,49 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: for name, module in model.named_modules(): bias_flag = False - if isinstance(module, QKVParallelLinear) or isinstance(module, MergedColumnParallelLinear) or \ - isinstance(module, RowParallelLinear) or isinstance(module, ColumnParallelLinear): + if ( + isinstance(module, QKVParallelLinear) + or isinstance(module, MergedColumnParallelLinear) + or isinstance(module, RowParallelLinear) + or isinstance(module, ColumnParallelLinear) + ): out_feature = module.weight.shape[0] in_feature = module.weight.shape[1] if getattr(module, "bias", False) != None: bias_flag = True weight_dtype = module.weight.dtype - torch_linear = linear_adaptor(in_features=in_feature, - out_features=out_feature, - bias=bias_flag, - dtype=weight_dtype) + torch_linear = linear_adaptor( + in_features=in_feature, + out_features=out_feature, + bias=bias_flag, + dtype=weight_dtype, + ) module_traversal = model - all_module_names = name.split('.') + all_module_names = name.split(".") all_module_names_except_last = all_module_names[:-1] for sub_module_name in all_module_names_except_last: module_traversal = module_traversal._modules[sub_module_name] - module_traversal._modules[all_module_names[-1]] = copy.deepcopy(torch_linear) + module_traversal._modules[all_module_names[-1]] = copy.deepcopy( + torch_linear + ) print("Optimized model =", model) - loader = get_model_loader(llm.llm_engine.load_config) # pylint: disable=E1101 + loader = get_model_loader( + llm.llm_engine.load_config + ) # pylint: disable=E1101 + + weights_iterator = loader._get_weights_iterator( + llm.llm_engine.model_config.model, + llm.llm_engine.model_config.revision, + fall_back_to_pt=True, + ) - weights_iterator = loader._get_weights_iterator(llm.llm_engine.model_config.model, - llm.llm_engine.model_config.revision, - fall_back_to_pt=True) + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + ) # pylint: disable=E0401 disable=E0611 - from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611 params_dict = dict(model.named_parameters(remove_duplicate=False)) for name in params_dict.keys(): params = params_dict[name] @@ -424,11 +482,13 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: print("INC quantizing...") config = kwargs.pop("config", None) if config is None: - config = RtnConfig(compute_dtype="int8", - group_size=128, - scale_dtype="bf16", - weight_dtype="int4_clip", - bits=4) + config = RtnConfig( + compute_dtype="int8", + group_size=128, + scale_dtype="bf16", + weight_dtype="int4_clip", + bits=4, + ) print("using default RTNConfig = ", config) print("Using customized config = ", config) model = convert_to_quantized_model(model, config) @@ -489,8 +549,12 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: return model device_map = kwargs.get("device_map", "cpu") - use_cpu = True if device_map == torch.device("cpu") or device_map == "cpu" else False - use_xpu = True if device_map == torch.device("xpu") or device_map == "xpu" else False + use_cpu = ( + True if device_map == torch.device("cpu") or device_map == "cpu" else False + ) + use_xpu = ( + True if device_map == torch.device("xpu") or device_map == "xpu" else False + ) config = kwargs.pop("config", None) model_hub = kwargs.pop("model_hub", "huggingface") @@ -498,20 +562,28 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: quantization_config = kwargs.pop("quantization_config", None) if not isinstance(config, PretrainedConfig): if model_hub == "modelscope": - import modelscope # pylint: disable=E0401 - config = modelscope.AutoConfig.from_pretrained(pretrained_model_name_or_path, - trust_remote_code=True) + import modelscope # pylint: disable=E0401 + + config = modelscope.AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) else: config, _ = AutoConfig.from_pretrained( pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs, - ) - if quantization_config is not None and quantization_config.quant_method in ["sq"]: + if quantization_config is not None and quantization_config.quant_method in [ + "sq" + ]: use_neural_speed = False - elif hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and "quant_method" in config.quantization_config and config.quantization_config["quant_method"] in ["sq"]: + elif ( + hasattr(config, "quantization_config") + and isinstance(config.quantization_config, dict) + and "quant_method" in config.quantization_config + and config.quantization_config["quant_method"] in ["sq"] + ): use_neural_speed = False elif kwargs.get("use_llm_runtime", None) is not None: use_neural_speed = kwargs.pop("use_llm_runtime", True) and not use_xpu @@ -544,30 +616,38 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: "Quantization_config loading failed. If you want to load saved " "low bit model, please check your quantizate_config.json." ) - elif use_neural_speed and not config.quantization_config["quant_method"] in ["dynamic", "static", "qat"]: + elif use_neural_speed and not config.quantization_config[ + "quant_method" + ] in ["dynamic", "static", "qat"]: if not os.path.exists(pretrained_model_name_or_path): from huggingface_hub import snapshot_download - pretrained_model_name_or_path = snapshot_download(repo_id=pretrained_model_name_or_path, - allow_patterns=["*.pt", "*.safetensors", "*.json", ".model"], - ) + + pretrained_model_name_or_path = snapshot_download( + repo_id=pretrained_model_name_or_path, + allow_patterns=["*.pt", "*.safetensors", "*.json", ".model"], + ) if quantization_config is None: - ConfigInit = {"rtn": RtnConfig, - "awq": AwqConfig, - "teq": TeqConfig, - "gptq": GPTQConfig, - "autoround": AutoRoundConfig, - } + ConfigInit = { + "rtn": RtnConfig, + "awq": AwqConfig, + "teq": TeqConfig, + "gptq": GPTQConfig, + "autoround": AutoRoundConfig, + } quantization_config = config.quantization_config - assert quantization_config.get("quant_method", None) in ConfigInit, \ - "Detect this model is not a low-bit model." - quantization_config = ConfigInit[quantization_config["quant_method"]].from_dict(quantization_config) + assert ( + quantization_config.get("quant_method", None) in ConfigInit + ), "Detect this model is not a low-bit model." + quantization_config = ConfigInit[ + quantization_config["quant_method"] + ].from_dict(quantization_config) logger.info("Loading Low Bits model by Neural Speed.") quantization_config.post_init_runtime() from neural_speed import Model model = Model() - model.init( # pylint: disable=E1123 + model.init( # pylint: disable=E1123 pretrained_model_name_or_path, weight_dtype=quantization_config.weight_dtype, alg=quantization_config.scheme, @@ -658,9 +738,15 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: else: quantization_config = RtnConfig( bits=4, - compute_dtype=torch.float32 if - (use_cpu and not CpuInfo().bf16 - and torch_dtype == torch.bfloat16) else convert_dtype_torch2str(torch_dtype), + compute_dtype=( + torch.float32 + if ( + use_cpu + and not CpuInfo().bf16 + and torch_dtype == torch.bfloat16 + ) + else convert_dtype_torch2str(torch_dtype) + ), weight_dtype="nf4" if use_cpu else "int4_fullrange", ) else: @@ -674,14 +760,21 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: if quantization_config is None: if use_neural_speed: quantization_config = RtnConfig( - compute_dtype="bf16" if CpuInfo().bf16 else "fp32", weight_dtype="int8" + compute_dtype="bf16" if CpuInfo().bf16 else "fp32", + weight_dtype="int8", ) else: quantization_config = RtnConfig( bits=8, - compute_dtype=torch.float32 if - (use_cpu and not CpuInfo().bf16 - and torch_dtype == torch.bfloat16) else convert_dtype_torch2str(torch_dtype), + compute_dtype=( + torch.float32 + if ( + use_cpu + and not CpuInfo().bf16 + and torch_dtype == torch.bfloat16 + ) + else convert_dtype_torch2str(torch_dtype) + ), weight_dtype="int8", ) else: @@ -731,7 +824,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: from neural_speed import Model model = Model() - model.init( # pylint: disable=E1123 + model.init( # pylint: disable=E1123 pretrained_model_name_or_path, weight_dtype=quantization_config.weight_dtype, alg=quantization_config.scheme, @@ -990,7 +1083,6 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: # torch.tensor(last_ind), # ) - # tokenized_dataset = calib_dataset.map(tokenize_function, batched=True) # tokenized_dataset.set_format(type="torch", columns=["input_ids"]) # calib_dataloader = DataLoader( @@ -1014,7 +1106,6 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: # ) # calib_func = calib_func - # # call inc static quant # from neural_compressor.torch.quantization import StaticQuantConfig, convert, prepare # quant_config = StaticQuantConfig( @@ -1130,7 +1221,6 @@ def collate_batch(batch): torch.tensor(last_ind), ) - tokenized_dataset = train_dataset.map(tokenize_function, batched=True) tokenized_dataset.set_format(type="torch", columns=["input_ids"]) train_dataloader = DataLoader( @@ -1157,7 +1247,7 @@ def train_func(model): optimizer.zero_grad() loss.backward() optimizer.step() - print('Iteration [{}], Loss: {:.4f}'.format(i+1, loss)) + print("Iteration [{}], Loss: {:.4f}".format(i + 1, loss)) return model logger.info( @@ -1170,6 +1260,7 @@ def train_func(model): # call inc static quant from neural_compressor import QuantizationAwareTrainingConfig, quantization from neural_compressor.training import prepare_compression + conf = QuantizationAwareTrainingConfig( backend=quantization_config.backend, excluded_precisions=quantization_config.excluded_precisions, @@ -1181,7 +1272,9 @@ def train_func(model): model = compression_manager.model train_func(model) compression_manager.callbacks.on_train_end() - compression_manager.model.save_pretrained = types.MethodType(save_low_bit, model) + compression_manager.model.save_pretrained = types.MethodType( + save_low_bit, model + ) quantization_config.remove_redundant_parameters() compression_manager.model.quantization_config = quantization_config logger.info("Quant Aware Training done.") @@ -1192,7 +1285,7 @@ def train_func(model): from neural_speed import Model model = Model() - model.init( # pylint: disable=E1123 + model.init( # pylint: disable=E1123 pretrained_model_name_or_path, weight_dtype="fp32", use_quant=False, @@ -1273,7 +1366,11 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): kwarg_attn_imp = kwargs.pop("attn_implementation", None) # lm-eval device map is dictionary - device_map = device_map[""] if isinstance(device_map, dict) and "" in device_map else device_map + device_map = ( + device_map[""] + if isinstance(device_map, dict) and "" in device_map + else device_map + ) if use_safetensors is None and not is_safetensors_available(): use_safetensors = False @@ -1289,8 +1386,12 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): ) token = use_auth_token - use_cpu = True if device_map == torch.device("cpu") or device_map == "cpu" else False - use_xpu = True if device_map == torch.device("xpu") or device_map == "xpu" else False + use_cpu = ( + True if device_map == torch.device("cpu") or device_map == "cpu" else False + ) + use_xpu = ( + True if device_map == torch.device("xpu") or device_map == "xpu" else False + ) user_agent = { "file_type": "model", @@ -1321,7 +1422,9 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): elif quantization_config["quant_method"] == "dynamic": quantization_config = DynamicQuantConfig.from_dict(quantization_config) elif quantization_config["quant_method"] == "qat": - quantization_config = QuantAwareTrainingConfig.from_dict(quantization_config) + quantization_config = QuantAwareTrainingConfig.from_dict( + quantization_config + ) elif quantization_config["quant_method"] == "sq": quantization_config = SmoothQuantConfig.from_dict(quantization_config) assert ( @@ -1462,11 +1565,15 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): "_raise_exceptions_for_missing_entries": False, "_commit_hash": commit_hash, } - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # result when internet is up, the repo and revision exist, but the file does not. - if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + if resolved_archive_file is None and filename == _add_variant( + SAFE_WEIGHTS_NAME, variant + ): # Maybe the checkpoint is sharded, we try to grab the index name in this case. resolved_archive_file = cached_file( pretrained_model_name_or_path, @@ -1487,9 +1594,13 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): # This repo has no safetensors file of any kind, we switch to PyTorch. filename = _add_variant(WEIGHTS_NAME, variant) resolved_archive_file = cached_file( - pretrained_model_name_or_path, filename, **cached_file_kwargs + pretrained_model_name_or_path, + filename, + **cached_file_kwargs, ) - if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + if resolved_archive_file is None and filename == _add_variant( + WEIGHTS_NAME, variant + ): # Maybe the checkpoint is sharded, we try to grab the index name in this case. resolved_archive_file = cached_file( pretrained_model_name_or_path, @@ -1508,7 +1619,9 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): "token": token, } if variant is not None and has_file( - pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs + pretrained_model_name_or_path, + WEIGHTS_NAME, + **has_file_kwargs, ): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" @@ -1571,8 +1684,11 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): if quantization_config.quant_method in ["static", "dynamic", "qat"]: model = model_class(config, *model_args, **kwargs) from neural_compressor.utils.pytorch import load + weights_file = os.path.join( - os.path.abspath(os.path.expanduser(pretrained_model_name_or_path)), WEIGHTS_NAME) + os.path.abspath(os.path.expanduser(pretrained_model_name_or_path)), + WEIGHTS_NAME, + ) q_model = load(weights_file, model, dataloader=None) del model return q_model @@ -1581,7 +1697,10 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): from intel_extension_for_transformers.transformers.llm.quantization.sq_utils import ( TSModelCausalLMForITREX, ) - q_model = torch.jit.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) + + q_model = torch.jit.load( + os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + ) origin_model_type = config.model_type if origin_model_type in ["chatglm", "qwen", "baichuan"]: config.model_type = "qwen2" @@ -1611,19 +1730,25 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): dtype_orig = model_class._set_default_torch_dtype(torch_dtype) if quantization_config.compute_dtype is None: if use_xpu: - quantization_config.compute_dtype = \ - "fp16" if (torch_dtype is None or - torch_dtype == torch.bfloat16) \ + quantization_config.compute_dtype = ( + "fp16" + if (torch_dtype is None or torch_dtype == torch.bfloat16) else convert_dtype_torch2str(torch_dtype) + ) else: - quantization_config.compute_dtype = \ - "fp32" if (torch_dtype is None or - (not CpuInfo().bf16 and torch_dtype == torch.bfloat16) or - (torch_dtype == torch.float16)) \ + quantization_config.compute_dtype = ( + "fp32" + if ( + torch_dtype is None + or (not CpuInfo().bf16 and torch_dtype == torch.bfloat16) + or (torch_dtype == torch.float16) + ) else convert_dtype_torch2str(torch_dtype) + ) else: - if ((not CpuInfo().bf16 and quantization_config.compute_dtype == "bf16") - or (use_cpu and quantization_config.compute_dtype == "fp16")): + if (not CpuInfo().bf16 and quantization_config.compute_dtype == "bf16") or ( + use_cpu and quantization_config.compute_dtype == "fp16" + ): quantization_config.compute_dtype = "fp32" if quantization_config.scale_dtype is None: @@ -1631,7 +1756,9 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): if quantization_config.scale_dtype not in ["fp32", "fp16", "bf16"]: logger.warning("scale_dtype only supports fp32, bf16, fp16.") quantization_config.scale_dtype = "fp32" - logger.warning("fp32 scale_dtype is used, please change the config.json if you don't want to use it.") + logger.warning( + "fp32 scale_dtype is used, please change the config.json if you don't want to use it." + ) # weight dtype is higher priority than bits in config.json when both existed. if quantization_config.weight_dtype is None: @@ -1639,36 +1766,47 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): quantization_config.weight_dtype = "int4_clip" logger.info( "{} quantization weight_dtype is used due to bits is 4 in config.json.".format( - quantization_config.weight_dtype) + quantization_config.weight_dtype ) + ) elif quantization_config.bits == 8: quantization_config.weight_dtype = "int8" logger.info( "{} quantization weight_dtype is used due to bits is 8 in config.json.".format( - quantization_config.weight_dtype) + quantization_config.weight_dtype ) + ) else: logger.warning("bits number only supports 4, 8.") quantization_config.weight_dtype = "int4_clip" logger.warning( - "int4_clip weight_dtype is used, please change the config.json if you don't want to use it.") + "int4_clip weight_dtype is used, please change the config.json if you don't want to use it." + ) else: - if quantization_config.weight_dtype not in ["int4_fullrange", - "int4_clip", - "int8", - "fp8_e5m2", - "fp8_e4m3", - "nf4", - "fp4_e2m1_bnb", - "fp4_e2m1"]: - logger.warning("Please provide the correct bits number or weight_dtype in config.json.") + if quantization_config.weight_dtype not in [ + "int4_fullrange", + "int4_clip", + "int8", + "fp8_e5m2", + "fp8_e4m3", + "nf4", + "fp4_e2m1_bnb", + "fp4_e2m1", + ]: + logger.warning( + "Please provide the correct bits number or weight_dtype in config.json." + ) raise ValueError( f"weight_dtype must be a string in " f"'int8', 'int4', 'int4_fullrange', 'int4_clip', 'nf4', " f"'fp4', 'fp4_e2m1_bnb', 'fp4_e2m1', 'fp8', 'fp8_e5m2, fp8_e4m3'" ) else: - logger.info("{} quantization weight_dtype is used.".format(quantization_config.weight_dtype)) + logger.info( + "{} quantization weight_dtype is used.".format( + quantization_config.weight_dtype + ) + ) init_contexts = [no_init_weights(_enable=_fast_init)] init_contexts.append(init_empty_weights()) @@ -1706,7 +1844,10 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): if is_ipex_available() and quantization_config.use_ipex: import intel_extension_for_pytorch as ipex - from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear + from intel_extension_for_pytorch.nn.modules import ( + WeightOnlyQuantizedLinear as ipex_linear, + ) + def replace_ipex_cpu_woq_linear(model, current_name=[]): for name, module in model.named_children(): current_name.append(name) @@ -1716,37 +1857,46 @@ def replace_ipex_cpu_woq_linear(model, current_name=[]): 8: ipex.quantization.WoqWeightDtype.INT8, } compute_dtype = { - "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype. + "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype. "bf16": ipex.quantization.WoqLowpMode.BF16, "fp16": ipex.quantization.WoqLowpMode.FP16, "int8": ipex.quantization.WoqLowpMode.INT8, - } - ipex_qconfig_mapping = ( - ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype[quantization_config.bits], - lowp_mode=compute_dtype[quantization_config.compute_dtype], - act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, - group_size=quantization_config.group_size, - ) + ipex_qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype[quantization_config.bits], + lowp_mode=compute_dtype[quantization_config.compute_dtype], + act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + group_size=quantization_config.group_size, ) tmp_linear = torch.nn.Linear( module.in_features, module.out_features, - True if hasattr(module, "bias") else False - ) + True if hasattr(module, "bias") else False, + ) tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig target_linear = ipex_linear.from_float_and_int4_weight( - mod = tmp_linear, - qweight = state_dict.pop('.'.join(current_name) + ".ipex_weight"), - scales = state_dict.pop('.'.join(current_name) + ".ipex_scales"), - zero_points = state_dict.pop('.'.join(current_name) + ".ipex_zeros"), - bias = state_dict.pop('.'.join(current_name) + ".ipex_bias") \ - if '.'.join(current_name) + ".ipex_bias" in state_dict else None, - group_size = quantization_config.group_size, - g_idx = state_dict.pop('.'.join(current_name) + ".ipex_g_idx") \ - if '.'.join(current_name) + ".ipex_g_idx" in state_dict else None, + mod=tmp_linear, + qweight=state_dict.pop( + ".".join(current_name) + ".ipex_weight" + ), + scales=state_dict.pop( + ".".join(current_name) + ".ipex_scales" + ), + zero_points=state_dict.pop( + ".".join(current_name) + ".ipex_zeros" + ), + bias=( + state_dict.pop(".".join(current_name) + ".ipex_bias") + if ".".join(current_name) + ".ipex_bias" in state_dict + else None + ), + group_size=quantization_config.group_size, + g_idx=( + state_dict.pop(".".join(current_name) + ".ipex_g_idx") + if ".".join(current_name) + ".ipex_g_idx" in state_dict + else None + ), ) setattr(model, name, target_linear) else: @@ -1783,14 +1933,18 @@ def replace_ipex_cpu_woq_linear(model, current_name=[]): # Set model in evaluation mode to deactivate DropOut modules by default model.eval() - if quantization_config.weight_dtype not in [ - "fp8_e5m2", - "fp8_e4m3", - "nf4", - "fp4_e2m1", - "fp4_e2m1_bnb", - "int4_fullrange", - ] and not quantization_config.use_ipex: + if ( + quantization_config.weight_dtype + not in [ + "fp8_e5m2", + "fp8_e4m3", + "nf4", + "fp4_e2m1", + "fp4_e2m1_bnb", + "int4_fullrange", + ] + and not quantization_config.use_ipex + ): model = replace_linear( model, quantization_config=quantization_config, @@ -1798,8 +1952,9 @@ def replace_ipex_cpu_woq_linear(model, current_name=[]): empty_weights=True, ) - if (not use_xpu and torch_dtype == torch.float16) or (not use_xpu and not CpuInfo().bf16 - and torch_dtype == torch.bfloat16): + if (not use_xpu and torch_dtype == torch.float16) or ( + not use_xpu and not CpuInfo().bf16 and torch_dtype == torch.bfloat16 + ): model.to(dtype=torch.float32) # If it is a model with generation capabilities, attempt to load the generation config diff --git a/intel_extension_for_transformers/transformers/utils/utility.py b/intel_extension_for_transformers/transformers/utils/utility.py index 2467531fab2..092a3a33a58 100644 --- a/intel_extension_for_transformers/transformers/utils/utility.py +++ b/intel_extension_for_transformers/transformers/utils/utility.py @@ -18,9 +18,7 @@ import argparse import os -from typing import Optional, Tuple -from neural_compressor.utils import logger -from neural_compressor.utils.utility import LazyImport, CpuInfo +from neural_compressor.utils.utility import LazyImport from intel_extension_for_transformers.tools.utils import is_ipex_available @@ -96,411 +94,3 @@ def __init__(self) -> None: self.dataset = dataloader.dataset return INCDataLoader() - - -def generate_dummy_past_key_values(config, input_bs): - """Generate the dummy past_key_values.""" - from optimum.utils import NormalizedConfigManager - if config.model_type == "qwen": - new_shape = [ - input_bs, - 0, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "baichuan": - new_shape = [ - input_bs, - config.num_attention_heads, - 0, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "chatglm": - new_shape = [ - 0, - input_bs, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_layers - else: - normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.model_type - )(config) - nb_pkv = 2 - num_layers = normalized_config.num_layers - num_attention_heads = normalized_config.num_attention_heads - hidden_size = normalized_config.hidden_size - d_k = hidden_size // num_attention_heads - num_key_value_heads = num_attention_heads - if hasattr(normalized_config, "num_key_value_heads"): - num_key_value_heads = normalized_config.num_key_value_heads - if hasattr(normalized_config, "multi_query_group_num"): - num_key_value_heads = normalized_config.multi_query_group_num - - if config.model_type == "bloom": - shape_key = (input_bs * num_attention_heads, d_k, 1) - shape_value = (input_bs * num_attention_heads, 1, d_k) - key = torch.ones(size=shape_key) - value = torch.ones(size=shape_value) - past_key_values = tuple( - tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) - for _ in range(num_layers) - ) - return past_key_values - elif config.model_type == "gpt_bigcode": - new_shape = [input_bs, 0, d_k * 2] - dummy_tensor = torch.zeros(size=new_shape) - past_key_values = tuple([dummy_tensor] * num_layers) - return past_key_values - elif config.model_type == "falcon": - new_shape = [input_bs, 1, 0, d_k] - else: - new_shape = [input_bs, num_key_value_heads, 0, d_k] - past_key_values = [ - ( - torch.zeros(size=new_shape).contiguous(), - torch.zeros(size=new_shape).contiguous(), - ) - for _ in range(num_layers) - ] - return tuple(past_key_values) - -def generate_dummy_past_key_values_for_inference(config, input_bs): - """Generate the dummy past_key_values.""" - from optimum.utils import NormalizedConfigManager - if config.model_type == "qwen": - new_shape = [ - input_bs, - 0, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "baichuan": - new_shape = [ - input_bs, - config.num_attention_heads, - 0, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "chatglm": - new_shape = [ - 0, - input_bs, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_layers - else: - normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.model_type - )(config) - nb_pkv = 2 - num_layers = normalized_config.num_layers - num_attention_heads = normalized_config.num_attention_heads - hidden_size = normalized_config.hidden_size - d_k = hidden_size // num_attention_heads - num_key_value_heads = num_attention_heads - if hasattr(normalized_config, "num_key_value_heads"): - num_key_value_heads = normalized_config.num_key_value_heads - if hasattr(normalized_config, "multi_query_group_num"): - num_key_value_heads = normalized_config.multi_query_group_num - - if config.model_type == "bloom": - shape_key = (input_bs * num_attention_heads, d_k, 0) - shape_value = (input_bs * num_attention_heads, 0, d_k) - key = torch.empty(size=shape_key) - value = torch.empty(size=shape_value) - past_key_values = tuple( - tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) - for _ in range(num_layers) - ) - return past_key_values - elif config.model_type == "gpt_bigcode": - new_shape = [input_bs, 0, d_k * 2] - dummy_tensor = torch.zeros(size=new_shape) - past_key_values = tuple([dummy_tensor] * num_layers) - return past_key_values - elif config.model_type == "falcon": - new_shape = [input_bs, 1, 0, d_k] - else: - new_shape = [input_bs, num_key_value_heads, 0, d_k] - past_key_values = [ - ( - torch.zeros(size=new_shape).contiguous(), - torch.zeros(size=new_shape).contiguous(), - ) - for _ in range(num_layers) - ] - return tuple(past_key_values) - -def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1): - """Generate the dummy past_key_values.""" - from optimum.utils import NormalizedConfigManager - if config.model_type == "qwen": - new_shape = [ - input_bs, - 1, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "baichuan": - new_shape = [ - input_bs, - config.num_attention_heads, - 1, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "chatglm": - new_shape = [ - 1, - input_bs, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_layers - else: - normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.model_type - )(config) - num_layers = normalized_config.num_layers - num_attention_heads = normalized_config.num_attention_heads - hidden_size = normalized_config.hidden_size - d_k = hidden_size // num_attention_heads - num_key_value_heads = num_attention_heads - nb_pkv = 2 - if hasattr(normalized_config, "num_key_value_heads"): - num_key_value_heads = normalized_config.num_key_value_heads - if hasattr(normalized_config, "multi_query_group_num"): - num_key_value_heads = normalized_config.multi_query_group_num - if config.model_type == "bloom": - for nb_pkv in range(nb_pkv): - if nb_pkv % 2 == 0: - new_shape = [input_bs * num_key_value_heads, d_k, 1] - else: - new_shape = [input_bs * num_key_value_heads, 1, d_k] - - else: - new_shape = [input_bs, num_key_value_heads, 1, d_k] - - beam_idx_tmp = torch.zeros( - (2048, int(input_bs * num_beams)), dtype=torch.long - ).contiguous() - past_key_values = [ - ( - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - torch.zeros(size=new_shape).contiguous(), - torch.zeros(size=new_shape).contiguous(), - beam_idx_tmp, - ) - for _ in range(num_layers) - ] - return tuple(past_key_values) - -IPEX_OPT_LLM_SUPPORTED_DICT = { - "2.2": ["gptj", "opt", "llama", "falcon", "chatglm", "baichuan", "gpt-neox"], - "2.3": [ - "gptj", - "opt", - "llama", - "falcon", - "chatglm", - "baichuan", - "qwen", - "bloom", - "codegen", - "gptbigcode", - "t5", - "mixtral", - "mpt", - ], -} - -MODEL_TYPES_REQUIRING_POSITION_IDS = { - "codegen", - "gpt2", - "gpt-bigcode", - "gpt-neo", - "gpt-neox", - "gptj", - "imagegpt", - "llama", - "mistral", - "chatglm", -} - -if is_ipex_available() and ipex.__version__ == "2.2.0+cpu": - logger.info( - "ipex.llm.optimize by 2.2.0 version supported model family: {}".format( - ",".join(IPEX_OPT_LLM_SUPPORTED_DICT["2.2"]) - ) - ) - logger.info( - "The recommended transformers version is 4.35.2 if you used IPEX 2.2.0 version." - ) - IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.2"] -elif is_ipex_available() and ipex.__version__ == "2.3.0+cpu": - logger.info( - "ipex.llm.optimize by 2.3.0 version supported model family: {}".format( - ", ".join(IPEX_OPT_LLM_SUPPORTED_DICT["2.3"]) - ) - ) - logger.info( - "The recommended transformers version is 4.38.1 if you used IPEX 2.3.0 version." - ) - IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.3"] -else: - logger.warning("Please check the intel_extension_for_pytorch version is 2.3.0+cpu.") - IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.3"] - -def get_example_inputs(model_config, batch_size=1, tokenizer=None, num_beams=4): - """Generate the dummy example inputs.""" - prompt = "Welcome to use Intel Extension for Transformers." - prompt = [prompt] * batch_size - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - model_type = model_config.model_type.replace("_", "-") - if model_type in IPEX_OPT_LLM_SUPPORTED: - past_key_values = generate_dummy_past_key_values_for_opt_llm( - config=model_config, - input_bs=batch_size, - num_beams=num_beams - ) - else: - past_key_values = generate_dummy_past_key_values(config=model_config, input_bs=batch_size) - - input_ids = input_ids[:, :512] - if model_type in ["bloom", "qwen"]: - attention_mask = torch.ones(input_ids.shape[0], input_ids.shape[1] + 1) - attention_mask[:,0] = 0 - else: - attention_mask = torch.ones(input_ids.shape) - position_ids = torch.arange(input_ids.shape[1]).repeat(batch_size, 1) - - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: - example_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "past_key_values": past_key_values - } - else: - example_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values - } - return example_inputs - - -def make_torchscript_model(model, json_file_path, example_inputs): - """Recover ipex model from JSON file. - - Args: - model (object): fp32 model need to do quantization. - json_file_path (json): configuration JSON file for ipex. - example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the ipex function. - - Returns: - (object): quantized model - """ - - ipex = LazyImport("intel_extension_for_pytorch") - from torch.ao.quantization.observer import MinMaxObserver - - if ipex.__version__ >= "2.1.100": - qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver) - else: - qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver()) - if isinstance(example_inputs, dict): - model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True) - else: - model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True) - model.load_qconf_summary(qconf_summary=json_file_path) - model = ipex.quantization.convert(model, inplace=True) - model.eval() - with torch.no_grad(): - try: - if isinstance(example_inputs, dict): - # pylint: disable=E1120,E1123 - model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) - else: - model = torch.jit.trace(model, example_inputs) - model = torch.jit.freeze(model.eval()) - except: - if isinstance(example_inputs, dict): - # pylint: disable=E1120,E1123 - model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) - else: - model = torch.jit.trace(model, example_inputs, strict=False) - model = torch.jit.freeze(model.eval()) - if isinstance(example_inputs, dict): - model(**example_inputs) - model(**example_inputs) - elif isinstance(example_inputs, tuple) or isinstance(example_inputs, list): - model(*example_inputs) - model(*example_inputs) - else: - model(example_inputs) - model(example_inputs) - return model - -def recover_model_from_json(fp32_model_name_or_path, json_file_path, trust_remote_code=False): - """Recover ipex model from JSON file. - - Args: - model (object): fp32 model need to do quantization. - json_file_path (json): configuration JSON file for ipex. - trust_remote_code (bool): trust remote code. - - Returns: - (object): quantized model - """ - from transformers import AutoModelForCausalLM - - # ipex recovered int8 model from configure.json requests float32 model input and on cpu device. - user_model = AutoModelForCausalLM.from_pretrained(fp32_model_name_or_path, - trust_remote_code=trust_remote_code).float() - if user_model.config.model_type in IPEX_OPT_LLM_SUPPORTED: - import intel_extension_for_pytorch as ipex - qconfig = ipex.quantization.default_static_qconfig_mapping - user_model = ipex.optimize_transformers( - user_model.eval(), - dtype=torch.float, - inplace=True, - quantization_config=qconfig, - deployment_mode=False, - ) - - # tokenizer - if user_model.config.model_type == "llama": - from transformers import LlamaTokenizer - tokenizer = LlamaTokenizer.from_pretrained(user_model.config.name_or_path) - else: - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained( - user_model.config.name_or_path, trust_remote_code=trust_remote_code - ) - - # example_inputs - example_inputs = get_example_inputs(user_model.config, tokenizer=tokenizer) - - # pylint: disable=E0611 - user_model.config.torchscript = True - config = user_model.config - user_model = make_torchscript_model(user_model, json_file_path, example_inputs) - import intel_extension_for_pytorch as ipex - from intel_extension_for_transformers.transformers.llm.evaluation.models import ( - TSModelCausalLMForITREX, - ) - origin_model_type = config.model_type - if origin_model_type in ["chatglm", "qwen", "baichuan"]: - config.model_type = "qwen2" - user_model = TSModelCausalLMForITREX(user_model, config=config) - user_model.config.model_type = origin_model_type - return user_model