From 14734de73bd471de162fc7c84fc54f46b50ed65a Mon Sep 17 00:00:00 2001
From: "Wang, Chang"
Date: Tue, 11 Jun 2024 11:13:44 +0800
Subject: [PATCH] Fix SQ baichuan without position_ids for torch and ipex 2.3.0 (#1597)

Signed-off-by: Wang, Chang
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 examples/.config/pytorch_optimize.json      | 26 ----------
 .../quantization/requirements_sq.txt        |  2 +-
 .../quantization/run_benchmark.sh           | 25 ++-------
 .../quantization/run_tuning.sh              | 32 ++----------
 .../transformers/llm/evaluation/models.py   | 11 ++--
 .../transformers/modeling/modeling_auto.py  |  7 ++-
 .../transformers/utils/utility.py           | 51 +++++++++++++++++--
 7 files changed, 63 insertions(+), 91 deletions(-)

diff --git a/examples/.config/pytorch_optimize.json b/examples/.config/pytorch_optimize.json
index 64fc3afd53c..1c1573ceb41 100644
--- a/examples/.config/pytorch_optimize.json
+++ b/examples/.config/pytorch_optimize.json
@@ -2268,32 +2268,6 @@
             }
         }
     },
-    "baichuan_7b_gen_ipex_static": {
-        "working_dir": "huggingface/pytorch/text-generation/quantization",
-        "tune": {
-            "cmd": "bash run_tuning.sh",
-            "params": {
-                "topology": "baichuan_7b",
-                "task": "generation",
-                "approach": "static",
-                "output_model": "saved_results"
-            }
-        },
-        "benchmark": {
-            "cmd": "bash run_benchmark.sh",
-            "params": {
-                "topology": "baichuan_7b",
-                "task": "generation",
-                "approach": "static",
-                "backend": "ipex",
-                "mode": "benchmark",
-                "batch_size": "112",
-                "iters": "100",
-                "int8": "false",
-                "config": "saved_results"
-            }
-        }
-    },
     "baichuan2_7b_gen_ipex_static": {
         "working_dir": "huggingface/pytorch/text-generation/quantization",
         "tune": {
diff --git a/examples/huggingface/pytorch/text-generation/quantization/requirements_sq.txt b/examples/huggingface/pytorch/text-generation/quantization/requirements_sq.txt
index 7da5200339d..047f65d091a 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/requirements_sq.txt
+++ b/examples/huggingface/pytorch/text-generation/quantization/requirements_sq.txt
@@ -5,7 +5,7 @@ protobuf
 sentencepiece != 0.1.92
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch==2.3.0+cpu
-transformers
+transformers==4.38.1
 intel_extension_for_pytorch==2.3.0
 optimum-intel==1.16.1
 bitsandbytes #baichuan
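Example (not part of the patch): the requirements above now pin the whole CPU stack (torch 2.3.0+cpu, intel_extension_for_pytorch 2.3.0, transformers 4.38.1). Below is a minimal, optional sanity check that the installed packages match those pins before running the SQ scripts; package names and versions are taken from requirements_sq.txt, everything else is stdlib.

    from importlib.metadata import PackageNotFoundError, version

    # Pins taken from requirements_sq.txt; torch reports "2.3.0+cpu", so we
    # compare on the base version only.
    expected = {
        "torch": "2.3.0",
        "intel_extension_for_pytorch": "2.3.0",
        "transformers": "4.38.1",
        "optimum-intel": "1.16.1",
    }

    for package, want in expected.items():
        try:
            got = version(package)
        except PackageNotFoundError:
            print(f"{package}: not installed (expected {want})")
            continue
        status = "OK" if got.split("+")[0] == want else f"MISMATCH (expected {want})"
        print(f"{package}: {got} {status}")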
diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh b/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh
index 085e3da3574..c92e733e9fe 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh
@@ -119,14 +119,12 @@ function run_benchmark {
     elif [ "${topology}" = "llama_7b" ]; then
         model_name_or_path="meta-llama/Llama-2-7b-chat-hf"
         script="run_generation_sq.py"
-        pip install transformers==4.35.2
     elif [ "${topology}" = "llama2_7b_gptq" ]; then
         model_name_or_path="meta-llama/Llama-2-7b-hf"
         script="run_generation_cpu_woq.py"
     elif [ "${topology}" = "llama_13b" ]; then
         model_name_or_path="meta-llama/Llama-2-13b-chat-hf"
         script="run_generation_sq.py"
-        pip install transformers==4.35.2
     elif [ "${topology}" = "dolly_v2_3b" ]; then
         model_name_or_path="/tf_dataset2/models/pytorch/dolly_v2_3b"
         script="run_generation_sq.py"
@@ -137,47 +135,32 @@ function run_benchmark {
         model_name_or_path="THUDM/chatglm3-6b"
         script="run_generation_sq.py"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.35.2
     elif [ "${topology}" = "chatglm2_6b" ]; then
         model_name_or_path="THUDM/chatglm2-6b"
         script="run_generation_sq.py"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.35.2
     elif [ "${topology}" = "chatglm_6b" ]; then
         model_name_or_path="THUDM/chatglm-6b"
         script="run_generation_sq.py"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.33
     elif [ "${topology}" = "falcon_7b" ]; then
         model_name_or_path="tiiuae/falcon-7b-instruct"
         script="run_generation_sq.py"
-        pip install transformers==4.33
-    elif [ "${topology}" = "baichuan_7b" ]; then
-        model_name_or_path="baichuan-inc/Baichuan-7B"
-        extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.33
-        script="run_generation_sq.py"
     elif [ "${topology}" = "baichuan_13b" ]; then
-        model_name_or_path="baichuan-inc/Baichuan-13B-Base"
+        model_name_or_path="baichuan-inc/Baichuan-13B-Chat"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        extra_cmd=$extra_cmd" --_commit_hash 14d5b0e204542744900f6fb52422c6d633bdcb00"
-        pip install transformers==4.33
         script="run_generation_sq.py"
     elif [ "${topology}" = "baichuan2_7b" ]; then
-        model_name_or_path="baichuan-inc/Baichuan2-7B-Base"
+        model_name_or_path="baichuan-inc/Baichuan2-7B-Chat"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.33
         script="run_generation_sq.py"
     elif [ "${topology}" = "baichuan2_13b" ]; then
-        model_name_or_path="baichuan-inc/Baichuan2-13B-Base"
+        model_name_or_path="baichuan-inc/Baichuan2-13B-Chat"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.35.2
         script="run_generation_sq.py"
     elif [ "${topology}" = "qwen_7b" ]; then
-        model_name_or_path="Qwen/Qwen-7B"
+        model_name_or_path="Qwen/Qwen-7B-Chat"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        extra_cmd=$extra_cmd" --_commit_hash f7bc352f27bb1c02ee371a4576942a7d96c8bb97"
-        pip install transformers==4.35.2
         script="run_generation_sq.py"
     elif [ "${topology}" = "mistral_7b" ]; then
         model_name_or_path="Intel/neural-chat-7b-v3"
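Example (not part of the patch): the benchmark entries above now point the Baichuan and Qwen topologies at the -Chat checkpoints and keep --trust_remote_code, because those repositories ship their own modeling code. A small sketch of loading one of these checkpoints directly with stock transformers; the checkpoint name and the prompt come from this patch, the generation settings are illustrative only.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "baichuan-inc/Baichuan2-7B-Chat"  # same checkpoint run_benchmark.sh selects
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,   # required: the repo ships custom modeling code
        torch_dtype=torch.float32,
    )
    inputs = tokenizer("Welcome to use Intel Extension for Transformers.", return_tensors="pt")
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))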
diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh b/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh
index 2d3aae508e4..16eaaa3182e 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh
@@ -133,7 +133,6 @@ function run_tuning {
         model_name_or_path="/tf_dataset2/models/nlp_toolkit/llama-2-7b-chat/Llama-2-7b-chat-hf"
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
-        pip install transformers==4.35.2
         script="run_generation_sq.py"
     elif [ "${topology}" = "llama_13b" ]; then
         alpha=0.8
@@ -141,7 +140,6 @@ function run_tuning {
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         script="run_generation_sq.py"
-        pip install transformers==4.35.2
     elif [ "${topology}" = "dolly_v2_3b" ]; then
         alpha=0.6
         model_name_or_path="/tf_dataset2/models/pytorch/dolly_v2_3b"
@@ -161,7 +159,6 @@ function run_tuning {
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code"
         script="run_generation_sq.py"
-        pip install transformers==4.35.2
     elif [ "${topology}" = "chatglm2_6b" ]; then
         alpha=0.75
         model_name_or_path="THUDM/chatglm2-6b"
@@ -169,64 +166,47 @@ function run_tuning {
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code"
         script="run_generation_sq.py"
-        pip install transformers==4.35.2
     elif [ "${topology}" = "chatglm_6b" ]; then
         alpha=0.75
         model_name_or_path="THUDM/chatglm-6b"
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.33
         script="run_generation_sq.py"
     elif [ "${topology}" = "falcon_7b" ]; then
         alpha=0.7
         model_name_or_path="tiiuae/falcon-7b-instruct"
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
-        pip install transformers==4.33.3
         script="run_generation_sq.py"
-    elif [ "${topology}" = "baichuan_7b" ]; then
-        alpha=0.85
-        model_name_or_path="baichuan-inc/Baichuan-7B"
-        extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
-        extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
-        extra_cmd=$extra_cmd" --trust_remote_code"
-        script="run_generation_sq.py"
-        pip install transformers==4.33
     elif [ "${topology}" = "baichuan_13b" ]; then
         alpha=0.85
-        model_name_or_path="baichuan-inc/Baichuan-13B-Base"
+        model_name_or_path="baichuan-inc/Baichuan-13B-Chat"
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        extra_cmd=$extra_cmd" --_commit_hash 14d5b0e204542744900f6fb52422c6d633bdcb00"
-        pip install transformers==4.33
         script="run_generation_sq.py"
     elif [ "${topology}" = "baichuan2_7b" ]; then
         alpha=0.85
-        model_name_or_path="baichuan-inc/Baichuan2-7B-Base"
+        model_name_or_path="baichuan-inc/Baichuan2-7B-Chat"
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.33
         script="run_generation_sq.py"
     elif [ "${topology}" = "baichuan2_13b" ]; then
         alpha=0.55
-        model_name_or_path="baichuan-inc/Baichuan2-13B-Base"
+        model_name_or_path="baichuan-inc/Baichuan2-13B-Chat"
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.35.2
         script="run_generation_sq.py"
     elif [ "${topology}" = "qwen_7b" ]; then
         alpha=0.9
-        model_name_or_path="Qwen/Qwen-7B"
+        model_name_or_path="Qwen/Qwen-7B-Chat"
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        extra_cmd=$extra_cmd" --_commit_hash f7bc352f27bb1c02ee371a4576942a7d96c8bb97"
-        pip install transformers==4.35.2
-        script="run_generation_sq.py"
+        script="run_generation_sq.py"
     elif [ "${topology}" = "mistral_7b" ]; then
         alpha=0.8
         model_name_or_path="Intel/neural-chat-7b-v3"
@@ -240,7 +220,6 @@ function run_tuning {
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.36.1
         script="run_generation_sq.py"
     elif [ "${topology}" = "phi_1_5b" ]; then
         alpha=0.5
@@ -248,7 +227,6 @@ function run_tuning {
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code"
-        pip install transformers==4.36.1
         script="run_generation_sq.py"
     elif [ "${topology}" = "llama2_7b_gptq" ]; then
         model_name_or_path="meta-llama/Llama-2-7b-hf"
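Example (not part of the patch): after this change run_tuning.sh no longer installs per-topology transformers versions or pins a --_commit_hash; it only selects the Chat checkpoint, the SmoothQuant alpha, and the common flags, then launches run_generation_sq.py. The sketch below shows roughly what the script assembles for the Baichuan/Qwen entries. The alpha values and the --sq/--alpha/--output_dir/--trust_remote_code flags come from the script above; the --model flag name is an assumption about run_generation_sq.py's CLI.

    import subprocess

    # SmoothQuant alpha per topology, as set in run_tuning.sh above.
    sq_alpha = {
        "baichuan-inc/Baichuan-13B-Chat": 0.85,
        "baichuan-inc/Baichuan2-7B-Chat": 0.85,
        "baichuan-inc/Baichuan2-13B-Chat": 0.55,
        "Qwen/Qwen-7B-Chat": 0.9,
    }

    model = "baichuan-inc/Baichuan2-7B-Chat"
    cmd = [
        "python", "run_generation_sq.py",
        "--model", model,              # flag name assumed; check run_generation_sq.py --help
        "--output_dir", "saved_results",
        "--trust_remote_code",
        "--sq", "--alpha", str(sq_alpha[model]),
    ]
    subprocess.run(cmd, check=True)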
diff --git a/intel_extension_for_transformers/transformers/llm/evaluation/models.py b/intel_extension_for_transformers/transformers/llm/evaluation/models.py
index c95bd1f462f..98dc24e3673 100644
--- a/intel_extension_for_transformers/transformers/llm/evaluation/models.py
+++ b/intel_extension_for_transformers/transformers/llm/evaluation/models.py
@@ -166,14 +166,9 @@ def forward(
         input_bs, input_len = input_ids.shape
         if self.use_cache and past_key_values is None:
             if model_type in IPEX_OPT_LLM_SUPPORTED:
-                if model_type == "llama" and transformers.__version__ >= "4.36":
-                    past_key_values = generate_dummy_past_key_values_for_inference(
-                        config=self.config, input_bs=input_bs
-                    )
-                else:
-                    past_key_values = generate_dummy_past_key_values_for_opt_llm(
-                        config=self.config, input_bs=input_bs, num_beams=1
-                    )
+                past_key_values = generate_dummy_past_key_values_for_opt_llm(
+                    config=self.config, input_bs=input_bs, num_beams=1
+                )
             else:
                 past_key_values = generate_dummy_past_key_values_for_inference(
                     config=self.config, input_bs=input_bs
diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
index c90f9c7fb16..e5aa98070cd 100644
--- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
+++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -841,8 +841,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
             model = model.float()
             model.eval()
             model_type = model.config.model_type.replace("_", "-")
-            if "llama" in model_type and transformers.__version__ >= "4.36.0":
-                quantization_config.ipex_opt_llm = False
+
             logger.info("Applying SmoothQuant.")
             # ipex.optimize_transformers
             if quantization_config.ipex_opt_llm is None:
@@ -851,7 +850,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
                 logger.info(
                     "quantization_config.ipex_opt_llm set to True and ipex.optimize_transformers is used."
                 )
-                logger.warning("The suggested transformers version is 4.35.2.")
+                logger.warning("The suggested transformers version is 4.38.1.")
             else:
                 quantization_config.ipex_opt_llm = False
             if quantization_config.ipex_opt_llm:
@@ -946,7 +945,7 @@ def collate_batch(batch):
                 )
                 last_ind.append(input_ids.shape[0] - 1)
-                if model_type in ["bloom", "qwen"]:
+                if model_type in ["bloom"]:
                     attention_mask = torch.ones(len(input_ids) + 1)
                     attention_mask[0] = 0
                 else:
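Example (not part of the patch): with the llama/transformers-4.36 branch removed, every model type in IPEX_OPT_LLM_SUPPORTED now goes through generate_dummy_past_key_values_for_opt_llm(..., num_beams=1) on the first forward call, and everything else gets an inference-shaped dummy cache. The sketch below only illustrates what a generic dummy cache looks like in the standard Hugging Face legacy layout (num_layers pairs of key/value tensors shaped [batch, num_heads, seq_len, head_dim]); the exact tensors the library's helpers build for the IPEX path may differ.

    import torch

    def dummy_past_key_values(config, input_bs, seq_len=1):
        """Generic legacy-format dummy cache; not the library's implementation."""
        num_layers = config.num_hidden_layers
        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        head_dim = config.hidden_size // config.num_attention_heads
        kv = torch.zeros(input_bs, num_heads, seq_len, head_dim)
        return tuple((kv.clone(), kv.clone()) for _ in range(num_layers))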
diff --git a/intel_extension_for_transformers/transformers/utils/utility.py b/intel_extension_for_transformers/transformers/utils/utility.py
index 21e207f91eb..90a03600835 100644
--- a/intel_extension_for_transformers/transformers/utils/utility.py
+++ b/intel_extension_for_transformers/transformers/utils/utility.py
@@ -21,6 +21,7 @@
 from typing import Optional, Tuple
 from neural_compressor.utils import logger
 from neural_compressor.utils.utility import LazyImport, CpuInfo
+from intel_extension_for_transformers.tools.utils import is_ipex_available
 
 CONFIG_NAME = "best_configure.yaml"
@@ -36,6 +37,8 @@
 SAFE_WEIGHTS_NAME = "model.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 
+if is_ipex_available():
+    import intel_extension_for_pytorch as ipex
 torch = LazyImport("torch")
 
 def str2bool(v):
@@ -300,8 +303,24 @@ def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1):
     ]
     return tuple(past_key_values)
 
-
-IPEX_OPT_LLM_SUPPORTED = {"gptj", "opt", "llama", "falcon", "chatglm", "baichuan"}
+IPEX_OPT_LLM_SUPPORTED_DICT = {
+    "2.2": ["gptj", "opt", "llama", "falcon", "chatglm", "baichuan", "gpt-neox"],
+    "2.3": [
+        "gptj",
+        "opt",
+        "llama",
+        "falcon",
+        "chatglm",
+        "baichuan",
+        "qwen",
+        "bloom",
+        "codegen",
+        "gptbigcode",
+        "t5",
+        "mixtral",
+        "mpt",
+    ],
+}
 
 MODEL_TYPES_REQUIRING_POSITION_IDS = {
     "codegen",
@@ -314,9 +333,32 @@ def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1):
     "llama",
     "mistral",
     "chatglm",
-    "baichuan"
 }
 
+if is_ipex_available() and ipex.__version__ == "2.2.0+cpu":
+    logger.info(
+        "ipex.llm.optimize by 2.2.0 version supported model family: {}".format(
+            ",".join(IPEX_OPT_LLM_SUPPORTED_DICT["2.2"])
+        )
+    )
+    logger.info(
+        "The recommended transformers version is 4.35.2 if you used IPEX 2.2.0 version."
+    )
+    IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.2"]
+elif is_ipex_available() and ipex.__version__ == "2.3.0+cpu":
+    logger.info(
+        "ipex.llm.optimize by 2.3.0 version supported model family: {}".format(
+            ", ".join(IPEX_OPT_LLM_SUPPORTED_DICT["2.3"])
+        )
+    )
+    logger.info(
+        "The recommended transformers version is 4.38.1 if you used IPEX 2.3.0 version."
+    )
+    IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.3"]
+else:
+    logger.warning("Please check the intel_extension_for_pytorch version is 2.3.0+cpu.")
+    IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.3"]
+
 def get_example_inputs(model_config, batch_size=1, tokenizer=None, num_beams=4):
     """Generate the dummy example inputs."""
     prompt = "Welcome to use Intel Extension for Transformers."
@@ -420,7 +462,8 @@ def recover_model_from_json(fp32_model_name_or_path, json_file_path, trust_remot
         (object): quantized model
     """
     from transformers import AutoModelForCausalLM
-    user_model = AutoModelForCausalLM.from_pretrained(fp32_model_name_or_path, trust_remote_code=trust_remote_code)
+    user_model = AutoModelForCausalLM.from_pretrained(fp32_model_name_or_path,
+                                                      trust_remote_code=trust_remote_code).float()
     if user_model.config.model_type in IPEX_OPT_LLM_SUPPORTED:
         import intel_extension_for_pytorch as ipex
         qconfig = ipex.quantization.default_static_qconfig_mapping
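Example (not part of the patch): the functional core of this fix is that "baichuan" stays in IPEX_OPT_LLM_SUPPORTED for both the 2.2 and 2.3 lists but is dropped from MODEL_TYPES_REQUIRING_POSITION_IDS, so calibration and trace inputs for Baichuan no longer carry a position_ids tensor under torch 2.3 / IPEX 2.3.0. The sketch below shows how example inputs can be assembled from that set; it is not the body of get_example_inputs, which is only partially visible above, and the helper name is illustrative.

    import torch
    from intel_extension_for_transformers.transformers.utils.utility import (
        MODEL_TYPES_REQUIRING_POSITION_IDS,
    )

    def build_example_inputs(model_type, input_ids, past_key_values):
        """Assemble trace inputs; include position_ids only when the model type needs them."""
        attention_mask = torch.ones_like(input_ids)
        if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
            position_ids = torch.arange(input_ids.shape[-1]).unsqueeze(0)
            return (input_ids, attention_mask, position_ids, past_key_values)
        # e.g. "baichuan" after this patch: no position_ids in the example inputs
        return (input_ids, attention_mask, past_key_values)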