From 8f41d49b3bff385dfd39a0a693d82e56f877142a Mon Sep 17 00:00:00 2001
From: "Wang, Chang"
Date: Mon, 23 Oct 2023 22:50:09 +0800
Subject: [PATCH] [Optimization] Text-generation support qwen (#513)

---
 .../text-generation/quantization/run_generation.py | 8 +++++---
 .../transformers/utils/utility.py                  | 4 ++++
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
index 7bf58967e16..4083bcec269 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
@@ -119,7 +119,7 @@
         excluded_precisions=excluded_precisions,  # default is []
     )
 elif args.woq:
-    quantization_config = WeightOnlyQuantConfig()  # default is A32W4G32
+    quantization_config = WeightOnlyQuantConfig(compute_dtype="fp32", weight_dtype="int4_fullrange", group_size=32)  # default is A32W4G32
 # bitsandbytes
 elif args.bitsandbytes:
     # GPU device is needed for `load_in_4bit` and `load_in_8bit`.
@@ -133,6 +133,8 @@
 if quantization_config is not None:
     user_model = AutoModelForCausalLM.from_pretrained(args.model,
                                                       quantization_config=quantization_config,
+                                                      trust_remote_code=args.trust_remote_code,
+                                                      torchscript=True if args.sq else False,
                                                       use_llm_runtime=False
                                                       )
     if args.sq:
@@ -145,8 +147,8 @@
                                                       load_in_8bit=args.load_in_8bit,
                                                       use_llm_runtime=False
                                                       )
-elif not args.int8 or not args.int8_bf16_mixed:
-    user_model = AutoModelForCausalLM.from_pretrained(args.model, config=config, use_llm_runtime=False)
+elif not args.int8 and not args.int8_bf16_mixed:
+    user_model = AutoModelForCausalLM.from_pretrained(args.model, config=config, trust_remote_code=args.trust_remote_code, use_llm_runtime=False)
 # peft
 if args.peft_model_id is not None:
     from peft import PeftModel

diff --git a/intel_extension_for_transformers/transformers/utils/utility.py b/intel_extension_for_transformers/transformers/utils/utility.py
index 02a3b3d8f20..8f46e95c2e0 100644
--- a/intel_extension_for_transformers/transformers/utils/utility.py
+++ b/intel_extension_for_transformers/transformers/utils/utility.py
@@ -97,6 +97,10 @@ def generate_dummy_past_key_values(input_bs, model):
         else:
             new_shape = [input_bs * num_attention_heads, 1, d_k]
             pkv = pkv + (torch.ones(size=new_shape),)
+    elif model.config.model_type == "qwen":
+        new_shape = [input_bs, 1, num_attention_heads, d_k]
+        dummy_tensor = torch.ones(size=new_shape)
+        pkv = tuple(dummy_tensor for _ in range(nb_pkv))
     else:
         new_shape = [input_bs, num_attention_heads, 1, d_k]
         dummy_tensor = torch.ones(size=new_shape)
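
Putting the run_generation.py changes together, a hedged usage sketch of the weight-only path this patch enables for Qwen (the checkpoint id `Qwen/Qwen-7B-Chat` is illustrative, not taken from the patch; argument names follow the ITREX API as used above). The patch also passes `torchscript=True` when `--sq` is set, since the SmoothQuant path traces the model:

```python
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantConfig,
)

# Roughly equivalent to running the script with --woq --trust_remote_code
quantization_config = WeightOnlyQuantConfig(
    compute_dtype="fp32",           # fp32 activations (the "A32" in A32W4G32)
    weight_dtype="int4_fullrange",  # 4-bit weights ("W4")
    group_size=32,                  # one scale per 32 weights ("G32")
)
user_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat",            # illustrative Qwen checkpoint
    quantization_config=quantization_config,
    trust_remote_code=True,         # Qwen ships custom modeling code
    use_llm_runtime=False,
)
```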
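
The `elif` change in the second hunk is a De Morgan fix, not a refactor: the branch should load the plain FP32 model only when *neither* int8 flag is set, but `not a or not b` is true whenever at least one flag is unset, so int8 runs were falling into this branch too. A quick truth-table check (plain Python, flag names as in the script):

```python
for int8, int8_bf16_mixed in [(False, False), (True, False), (False, True), (True, True)]:
    old = not int8 or not int8_bf16_mixed   # buggy: True in three of the four cases
    new = not int8 and not int8_bf16_mixed  # fixed: True only when both flags are off
    print(f"int8={int8}, int8_bf16_mixed={int8_bf16_mixed} -> old={old}, new={new}")
```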
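
The utility.py hunk exists because Qwen lays out its KV cache as `[batch, seq_len, num_heads, head_dim]`, while the default branch assumes `[batch, num_heads, seq_len, head_dim]`. A minimal standalone sketch of the shape logic, with argument names mirroring the patched `generate_dummy_past_key_values` (the `nb_pkv=2` default, one key and one value tensor per layer, is an assumption for illustration):

```python
import torch

def dummy_pkv_shape_sketch(input_bs, model_type, num_attention_heads, d_k, nb_pkv=2):
    """Sketch of the per-layer dummy cache shapes from the patched function."""
    if model_type == "qwen":
        # Qwen keeps its KV cache as [batch, seq_len, num_heads, head_dim]
        new_shape = [input_bs, 1, num_attention_heads, d_k]
    else:
        # default layout: [batch, num_heads, seq_len, head_dim]
        new_shape = [input_bs, num_attention_heads, 1, d_k]
    dummy_tensor = torch.ones(size=new_shape)
    return tuple(dummy_tensor for _ in range(nb_pkv))

print(dummy_pkv_shape_sketch(1, "qwen", 32, 128)[0].shape)   # torch.Size([1, 1, 32, 128])
print(dummy_pkv_shape_sketch(1, "llama", 32, 128)[0].shape)  # torch.Size([1, 32, 1, 128])
```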