diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index 397739cb72a..b5691961ef6 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -98,6 +98,8 @@ def __init__( n_splits_linear: int = 1, n_splits_down_proj: int = 1, group_size: int = 0, + cos_len: int = 1, + keep_position_ids=True, asym: bool = False, ): super().__init__(max_seq_len=max_seq_len, @@ -114,18 +116,13 @@ def __init__( self.dtype = dtype self.cached_cos = cached_cos self.cached_sin = cached_sin + self.cos_len = cos_len self.batch_size, self.seq_len, self.hidden_size = hidden_shape self.mode = mode self.rms_norm_eps = rms_norm_eps self.transpose_value = transpose_value self.num_layers = num_layers - cos = self.constant(self.cached_cos) - self.cos = self.unsqueeze(cos, axis=0) - - sin = self.constant(self.cached_sin) - self.sin = self.unsqueeze(sin, axis=0) - if mode == "decode": self.kv_seq_len = self.max_seq_len + 1 else: @@ -148,7 +145,21 @@ def __init__( attention_mask = self.create_input_op( (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.float16) - position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) + if self.cached_cos is None: + if mode == "prefill" and keep_position_ids: + position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) + cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim), + dtype=np.float32) + self.cos = self.convert_to_fp16(cos) + sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim), + dtype=np.float32) + self.sin = self.convert_to_fp16(sin) + else: + position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) + cos = self.constant(self.cached_cos) + self.cos = self.unsqueeze(cos, axis=0) + sin = self.constant(self.cached_sin) + self.sin = self.unsqueeze(sin, axis=0) if input_layernorm_weights is None: input_layernorm_weights = [] @@ -211,11 +222,12 @@ def __init__( hidden_states = input curr_key_values = [] + cos_condition = cached_cos is not None or (mode == "prefill" and keep_position_ids) for i in range(num_layers): hidden_states, new_key_states, new_value_states = self.build_decoder( hidden_states=hidden_states, attention_mask=attention_mask, - position_ids=position_ids, + position_ids=position_ids if cos_condition else None, input_layernorm_weight=input_layernorm_weights[i], post_attention_layernorm_weight=post_attn_layernorm_weights[i], q_bias=q_biases[i], diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py index fbccd683d70..13dbb013a43 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py @@ -173,6 +173,105 @@ def __init__( self.compile() +class Llama32Embedding(NNFactory): + def __init__( + self, + vocab_size, + embedding_dim, + embedding_weight, + padding_idx, + inv_freq, + attention_scaling, + dtype, # fp16 + device: str = "NPU", + ): + super().__init__(False, device) + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.attention_scaling = attention_scaling + self.dtype = dtype + + # define input + weight = self.constant(embedding_weight) + input = self.parameter((1, 1), dtype=np.int32) + position_ids = self.parameter((1, 1), dtype=np.int64) + 
inv_freq = self.constant(inv_freq) + + # embed_tokens module + if padding_idx == -1: + padding_idx += vocab_size + + axis_node = self.constant(np.array([0], dtype=np.int64)) + if padding_idx is not None: + masked_embeddings = np.ones(weight.shape, dtype=np.float16) + masked_embeddings[padding_idx, :] = 0.0 # mask + + node_mask = self.constant(masked_embeddings) + node_masked_w = self.eltwise_mul(weight, node_mask) + res = self.gather(node_masked_w, input, axis_node, 0) + else: + res = self.gather(weight, input, axis_node, 0) + + # rotary_emb module + inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1)) + position_ids = self.reshape(position_ids, (1, 1, 1)) + freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq), + self.convert_to_fp32(position_ids)) + freqs = self.transpose(freqs, [0, 2, 1]) + emb = self.concat(freqs, freqs, axis=2) + cos = self.cos(emb) + sin = self.sin(emb) + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + # define outputs + res = self.convert_to_fp16(res) + cos = self.convert_to_fp32(cos) + sin = self.convert_to_fp32(sin) + + print("start compiling") + self.compile() + + +class Llama32PostEmbedding(NNFactory): + def __init__( + self, + inv_freq, + attention_scaling, + input_len: int = 1, + device: str = "NPU", + ): + super().__init__(False, device) + self.attention_scaling = attention_scaling + + # define input + position_ids = self.parameter((1, input_len), dtype=np.int64) + inv_freq = self.constant(inv_freq) + + # rotary_emb module + inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1)) + position_ids = self.reshape(position_ids, (1, 1, input_len)) + freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq), + self.convert_to_fp32(position_ids)) + freqs = self.transpose(freqs, [0, 2, 1]) + emb = self.concat(freqs, freqs, axis=2) + cos = self.cos(emb) + sin = self.sin(emb) + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + if input_len > 1: + cos = self.unsqueeze(cos, [1]) + sin = self.unsqueeze(sin, [1]) + + # define outputs + cos = self.convert_to_fp32(cos) + sin = self.convert_to_fp32(sin) + + print("start compiling") + self.compile() + + def obtain_weight_from_single_layer(attn_layer, mlp_layer): weights = [] if hasattr(attn_layer, "q_proj_dq_list"): @@ -216,3 +315,65 @@ def obtain_qkv_bias_from_single_layer(attn_layer): k_bias = attn_layer.k_proj.bias.to(torch.float16) v_bias = attn_layer.v_proj.bias.to(torch.float16) return q_bias, k_bias, v_bias + + +def obtain_embedding_from_model(model, convert_model, temp_dir, weight_dir, + max_prompt_len, keep_ir, compile_blob): + if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"): + # llama-2-7B & llama-3-8B + embedding_layer = model.model.embed_tokens + new_embedding = LLMEmbedding( + vocab_size=model.config.vocab_size, + embedding_dim=model.config.hidden_size, + embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(), + padding_idx=model.config.pad_token_id, + dtype=np.float16, + ) + if convert_model: + bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin") + embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file) + first_blob_path = None + else: + first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", + temp_dir, keep_ir=keep_ir, + compile_blob=compile_blob) + os.remove(os.path.join(temp_dir, "embedding.bin")) + else: + # llama-3.2-3B & llama-3.2-1B + # for transformers >= 4.45.0 + embedding_layer = model.model.embed_tokens + new_embedding = Llama32Embedding( + 
vocab_size=model.config.vocab_size, + embedding_dim=model.config.hidden_size, + embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(), + padding_idx=model.config.pad_token_id, + inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16), + attention_scaling=model.model.rotary_emb.attention_scaling, + dtype=np.float16, + ) + if convert_model: + bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin") + embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file) + first_blob_path = None + # save embedding post module + inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16) + attention_scaling = model.model.rotary_emb.attention_scaling + embedding_post = Llama32PostEmbedding(inv_freq=inv_freq, + attention_scaling=attention_scaling, + input_len=1) + update_names_of_IR_and_export_blob(embedding_post, "embedding_post", + temp_dir, keep_ir=keep_ir, compile_blob=compile_blob) + embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq, + attention_scaling=attention_scaling, + input_len=max_prompt_len) + update_names_of_IR_and_export_blob(embedding_post_prefill, + "embedding_post_prefill", + temp_dir, keep_ir=keep_ir, compile_blob=compile_blob) + os.remove(os.path.join(temp_dir, "embedding_post.bin")) + os.remove(os.path.join(temp_dir, "embedding_post_prefill.bin")) + else: + first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", + temp_dir, keep_ir=keep_ir, + compile_blob=compile_blob) + os.remove(os.path.join(temp_dir, "embedding.bin")) + return first_blob_path diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index 24bbca21fb5..023b6b53460 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -31,6 +31,7 @@ import numpy as np from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead from multiprocessing import Pool +import transformers def generate( @@ -456,6 +457,8 @@ def convert_llm_for_deploy(model: torch.nn.Module, custom_object_save(model, save_directory, config=model.config) if model.config.model_type == "qwen2": + cos_sin_input = not hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached") + embedding_post = not hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached") if group_size == 0: if model.config.hidden_size == 1536: # Qwen2-1.5B-Instruct @@ -476,6 +479,8 @@ def convert_llm_for_deploy(model: torch.nn.Module, "use_prefill_sdp": False, "weight_num": 7, "weight_idx": 8, + "embedding_post": embedding_post, + "cos_sin_input": cos_sin_input, "n_splits_linear": n_splits_linear, "n_splits_down_proj": n_splits_down_proj, "lm_head_low_bit": lm_head_low_bit} @@ -493,8 +498,8 @@ def convert_llm_for_deploy(model: torch.nn.Module, group_size, layernorm_const, "prefill", keep_ir=keep_ir, compile_blob=compile_blob) # save blob of lmhead and bin of embedding - convert_lm_head_and_embedding(model, save_directory, weight_dir, - convert_model=True, group_size=group_size, + convert_lm_head_and_embedding(model, save_directory, weight_dir, convert_model=True, + group_size=group_size, max_prompt_len=max_prompt_len, keep_ir=keep_ir, compile_blob=compile_blob) elif model.config.model_type == "llama": embedding_post = False diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py 
index dea8c0f32a4..714213796dd 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py @@ -18,108 +18,8 @@ import torch import numpy as np import os -from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead, \ - obtain_weight_from_single_layer -from intel_npu_acceleration_library.backend.factory import NNFactory - - -class Llama32Embedding(NNFactory): - def __init__( - self, - vocab_size, - embedding_dim, - embedding_weight, - padding_idx, - inv_freq, - attention_scaling, - dtype, # fp16 - device: str = "NPU", - ): - super().__init__(False, device) - self.vocab_size = vocab_size - self.embedding_dim = embedding_dim - self.padding_idx = padding_idx - self.attention_scaling = attention_scaling - self.dtype = dtype - - # define input - weight = self.constant(embedding_weight) - input = self.parameter((1, 1), dtype=np.int32) - position_ids = self.parameter((1, 1), dtype=np.int64) - inv_freq = self.constant(inv_freq) - - # embed_tokens module - if padding_idx == -1: - padding_idx += vocab_size - - axis_node = self.constant(np.array([0], dtype=np.int64)) - if padding_idx is not None: - masked_embeddings = np.ones(weight.shape, dtype=np.float16) - masked_embeddings[padding_idx, :] = 0.0 # mask - - node_mask = self.constant(masked_embeddings) - node_masked_w = self.eltwise_mul(weight, node_mask) - res = self.gather(node_masked_w, input, axis_node, 0) - else: - res = self.gather(weight, input, axis_node, 0) - - # rotary_emb module - inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1)) - position_ids = self.reshape(position_ids, (1, 1, 1)) - freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq), - self.convert_to_fp32(position_ids)) - freqs = self.transpose(freqs, [0, 2, 1]) - emb = self.concat(freqs, freqs, axis=2) - cos = self.cos(emb) - sin = self.sin(emb) - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - # define outputs - res = self.convert_to_fp16(res) - cos = self.convert_to_fp32(cos) - sin = self.convert_to_fp32(sin) - - print("start compiling") - self.compile() - - -class Llama32PostEmbedding(NNFactory): - def __init__( - self, - inv_freq, - attention_scaling, - input_len: int = 1, - device: str = "NPU", - ): - super().__init__(False, device) - self.attention_scaling = attention_scaling - - # define input - position_ids = self.parameter((1, input_len), dtype=np.int64) - inv_freq = self.constant(inv_freq) - - # rotary_emb module - inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1)) - position_ids = self.reshape(position_ids, (1, 1, input_len)) - freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq), - self.convert_to_fp32(position_ids)) - freqs = self.transpose(freqs, [0, 2, 1]) - emb = self.concat(freqs, freqs, axis=2) - cos = self.cos(emb) - sin = self.sin(emb) - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - if input_len > 1: - cos = self.unsqueeze(cos, [1]) - sin = self.unsqueeze(sin, [1]) - - # define outputs - cos = self.convert_to_fp32(cos) - sin = self.convert_to_fp32(sin) - - print("start compiling") - self.compile() +from .common import update_names_of_IR_and_export_blob, LowBitLLMLMHead, \ + obtain_weight_from_single_layer, obtain_embedding_from_model def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir, @@ -197,62 +97,10 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir, bin_file = os.path.join(weight_dir, 
f"model_lm_head_input_{1+idx}.bin") weight.tofile(bin_file) - if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"): - # llama-2-7B & llama-3-8B - embedding_layer = model.model.embed_tokens - new_embedding = LLMEmbedding( - vocab_size=model.config.vocab_size, - embedding_dim=model.config.hidden_size, - embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(), - padding_idx=model.config.pad_token_id, - dtype=np.float16, - ) - if convert_model: - bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin") - embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file) - first_blob_path = None - else: - first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", - temp_dir, keep_ir=keep_ir, - compile_blob=compile_blob) - os.remove(os.path.join(temp_dir, "embedding.bin")) - else: - # llama-3.2-3B & llama-3.2-1B - embedding_layer = model.model.embed_tokens - new_embedding = Llama32Embedding( - vocab_size=model.config.vocab_size, - embedding_dim=model.config.hidden_size, - embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(), - padding_idx=model.config.pad_token_id, - inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16), - attention_scaling=model.model.rotary_emb.attention_scaling, - dtype=np.float16, - ) - if convert_model: - bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin") - embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file) - first_blob_path = None - # save embedding post module - inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16) - attention_scaling = model.model.rotary_emb.attention_scaling - embedding_post = Llama32PostEmbedding(inv_freq=inv_freq, - attention_scaling=attention_scaling, - input_len=1) - update_names_of_IR_and_export_blob(embedding_post, "embedding_post", - temp_dir, keep_ir=keep_ir, compile_blob=compile_blob) - embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq, - attention_scaling=attention_scaling, - input_len=max_prompt_len) - update_names_of_IR_and_export_blob(embedding_post_prefill, - "embedding_post_prefill", - temp_dir, keep_ir=keep_ir, compile_blob=compile_blob) - os.remove(os.path.join(temp_dir, "embedding_post.bin")) - os.remove(os.path.join(temp_dir, "embedding_post_prefill.bin")) - else: - first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", - temp_dir, keep_ir=keep_ir, - compile_blob=compile_blob) - os.remove(os.path.join(temp_dir, "embedding.bin")) + first_blob_path = obtain_embedding_from_model(model, convert_model, + temp_dir, weight_dir, + max_prompt_len, + keep_ir, compile_blob) return first_blob_path, last_blob_path diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py index 183b71b9f42..076fc70bbc9 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py @@ -18,13 +18,14 @@ import torch import numpy as np import os -from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead, \ - obtain_weight_from_single_layer, obtain_qkv_bias_from_single_layer +from .common import update_names_of_IR_and_export_blob, LowBitLLMLMHead, \ + obtain_weight_from_single_layer, obtain_qkv_bias_from_single_layer, \ + obtain_embedding_from_model from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead def convert_lm_head_and_embedding(model, temp_dir, weight_dir, - 
convert_model=False, group_size=0, + convert_model=False, group_size=0, max_prompt_len=1, keep_ir=False, compile_blob=True): num_heads = model.model.layers[0].self_attn.num_heads head_dim = model.model.layers[0].self_attn.head_dim @@ -107,24 +108,10 @@ def convert_lm_head_and_embedding(model, temp_dir, weight_dir, bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin") weight.tofile(bin_file) - embedding_layer = model.model.embed_tokens - new_embedding = LLMEmbedding( - vocab_size=model.config.vocab_size, - embedding_dim=model.config.hidden_size, - embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(), - padding_idx=model.config.pad_token_id, - dtype=np.float16, - input_length=1, - ) - if convert_model: - bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin") - embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file) - first_blob_path = True - else: - first_blob_path = update_names_of_IR_and_export_blob(new_embedding, f"embedding", - temp_dir, keep_ir=keep_ir, - compile_blob=compile_blob) - os.remove(os.path.join(temp_dir, "embedding.bin")) + first_blob_path = obtain_embedding_from_model(model, convert_model, + temp_dir, weight_dir, + max_prompt_len, + keep_ir, compile_blob) return first_blob_path, last_blob_path @@ -145,8 +132,13 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, mlp_layer = curr_layer.mlp weights = obtain_weight_from_single_layer(attn_layer, mlp_layer) q_bias, k_bias, v_bias = obtain_qkv_bias_from_single_layer(attn_layer) - cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) - cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + else: + # transformers >= 4.45.0 + cached_cos = None + cached_sin = None layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) @@ -158,10 +150,12 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, if mode == "decode": input_len = 1 decoder_name = f"decoder_layer_{layer_idx}" + keep_position_ids = True npu_dpu_groups = None else: input_len = kv_len decoder_name = "decoder_layer_prefill" + keep_position_ids = False npu_dpu_groups = 6 single_decoder = LowBitQwenMultiDecoderlayer( @@ -185,6 +179,8 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, group_size=group_size, + cos_len=input_len, + keep_position_ids=keep_position_ids, asym=asym ) rest_blob_path = update_names_of_IR_and_export_blob(single_decoder, @@ -196,14 +192,25 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, # 0, 1, 2 are input_embed/attention_mask/position_id if mode == "decode": - if layernorm_const: - st_idx = 3 + if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): + if layernorm_const: + st_idx = 3 + else: + input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin") + post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") + layer_norm_0.data.numpy().tofile(input_lm_bin_file) + layer_norm_1.data.numpy().tofile(post_lm_bin_file) + st_idx = 5 else: - input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin") - 
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") - layer_norm_0.data.numpy().tofile(input_lm_bin_file) - layer_norm_1.data.numpy().tofile(post_lm_bin_file) - st_idx = 5 + # transformers >= 4.45.0 + if layernorm_const: + st_idx = 4 + else: + input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") + post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin") + layer_norm_0.data.numpy().tofile(input_lm_bin_file) + layer_norm_1.data.numpy().tofile(post_lm_bin_file) + st_idx = 6 q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx}.bin") k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+1}.bin") v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+2}.bin") @@ -261,8 +268,13 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down attn_layer = curr_layer.self_attn mlp_layer = curr_layer.mlp weights = obtain_weight_from_single_layer(attn_layer, mlp_layer) - cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) - cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + else: + # transformers >= 4.45.0 + cached_cos = None + cached_sin = None layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
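Note: the Llama32Embedding and Llama32PostEmbedding factories introduced above build the rotary-embedding cos/sin tables on the NPU from inv_freq and position_ids rather than relying on a cos_cached/sin_cached buffer. As a reference for what those graphs compute (the eltwise_mul → transpose → concat → cos/sin → attention_scaling sequence), here is a minimal NumPy sketch of the same math; the inv_freq construction uses a hypothetical base of 10000 and head_dim of 64 purely for illustration, since the real tensor comes from model.model.rotary_emb.inv_freq.

```python
import numpy as np

def rotary_cos_sin(position_ids, inv_freq, attention_scaling=1.0):
    # position_ids: (batch, seq_len) int64; inv_freq: (head_dim // 2,) float
    # freqs[b, s, i] = position_ids[b, s] * inv_freq[i]
    freqs = position_ids[..., None].astype(np.float32) * inv_freq[None, None, :].astype(np.float32)
    emb = np.concatenate([freqs, freqs], axis=-1)       # (batch, seq_len, head_dim)
    cos = np.cos(emb) * attention_scaling
    sin = np.sin(emb) * attention_scaling
    return cos, sin

# example: one decode step (seq_len == 1), hypothetical head_dim of 64
inv_freq = 1.0 / (10000.0 ** (np.arange(0, 64, 2) / 64.0))
cos, sin = rotary_cos_sin(np.array([[5]], dtype=np.int64), inv_freq)
print(cos.shape, sin.shape)  # (1, 1, 64) (1, 1, 64)
```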
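Several call sites in this patch repeat the same check for whether the per-layer rotary embedding still exposes cos_cached/sin_cached (dropped for these models in transformers >= 4.45). A small sketch of that pattern, consolidated into one helper; the helper name is illustrative and not part of the patch:

```python
import torch

def get_cached_cos_sin(decoder_layer):
    """Return (cached_cos, cached_sin) in fp16, or (None, None) when the
    rotary embedding no longer caches them (transformers >= 4.45); in that
    case the NPU decoder graphs take cos/sin as runtime inputs instead."""
    rotary_emb = decoder_layer.self_attn.rotary_emb
    if hasattr(rotary_emb, "cos_cached"):
        return (rotary_emb.cos_cached.to(torch.float16),
                rotary_emb.sin_cached.to(torch.float16))
    return None, None
```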
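For decode-mode Qwen layers, the weight-bin numbering in convert_qwen_layer shifts by one when there are no cached cos/sin tables, since the decoder graph then takes runtime cos/sin inputs. A sketch of the resulting index mapping, with an assumed helper name for clarity:

```python
def qwen_decode_bias_start_index(has_cached_cos: bool, layernorm_const: bool) -> int:
    # With cached cos/sin: q/k/v bias bins start at 3 (layernorm_const) or at 5
    # (layer norms written out as model_{i}_input_3 / input_4).
    # Without cached cos/sin (transformers >= 4.45): everything shifts by one,
    # so the start index becomes 4 or 6 and the norms go to input_4 / input_5.
    base = 3 if has_cached_cos else 4
    return base if layernorm_const else base + 2
```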