From cabfa51cadfaf682e721154da174b7ed3767dc8a Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Sun, 24 Nov 2024 21:33:12 +0800 Subject: [PATCH 1/4] support minicpm-1b --- .../npu_pipeline_model/convert_pipeline.py | 33 +++ .../transformers/npu_pipeline_model/llama.py | 3 +- .../npu_pipeline_model/minicpm.py | 199 ++++++++++++++++-- 3 files changed, 211 insertions(+), 24 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index 18ee5b1d4ad..a99bd07a39d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -493,3 +493,36 @@ def convert_llm_for_deploy(model: torch.nn.Module, # save blob of lmhead and bin of embedding convert_lm_head_and_embedding(model, n_splits_linear, save_directory, weight_dir, True) + elif model.config.model_type == "minicpm": + layernorm_const = True + fused_layers = 2 + update_dict = {"kv_len": kv_len, + "num_head": model.model.layers[0].self_attn.num_heads, + "head_dim": model.model.layers[0].self_attn.head_dim, + "transpose_value_cache": transpose_value_cache, + "max_prompt_len": max_prompt_len, + "layernorm_const": layernorm_const, + "group_size": group_size, + "fused_layers": fused_layers, + "qkv_bias": False, + "use_prefill_sdp": False, + "weight_num": 7, + "weight_idx": 5, + "model_type": "minicpm", + "embedding_post": True} + model.config.update(update_dict) + model.config.save_pretrained(save_directory) + + from .minicpm import convert_minicpm_layer, convert_fused_minicpm_layer + from .minicpm import convert_lm_head_and_embedding + # save blob of lmhead and bin of embedding + convert_lm_head_and_embedding(model, n_splits_linear, + save_directory, weight_dir, True, max_prompt_len) + # save fused_layers blobs of fused decoder layers + convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj, + save_directory, weight_dir, transpose_value_cache, kv_len, + group_size, layernorm_const, "decode") + # save blob of single prefill layer + convert_minicpm_layer(model, 0, n_splits_linear, n_splits_down_proj, + save_directory, weight_dir, transpose_value_cache, max_prompt_len, + group_size, layernorm_const, "prefill") diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py index 3fb57381ef2..ecf1083a52b 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py @@ -147,10 +147,11 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir, ) else: # llama-3.2-3B & llama-3.2-1B + embedding_layer = model.model.embed_tokens new_embedding = Llama32Embedding( vocab_size=model.config.vocab_size, embedding_dim=model.config.hidden_size, - embedding_weight=model.model.embed_tokens.weight.to(torch.float16).detach().numpy(), + embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(), padding_idx=model.config.pad_token_id, inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16), attention_scaling=model.model.rotary_emb.attention_scaling, diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py index e5939efcb97..94b6ebd58be 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py +++ 
b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py @@ -58,6 +58,8 @@ def __init__( res = self.gather(node_masked_w, input, axis_node, 0) else: res = self.gather(weight, input, axis_node, 0) + print(res) + print(scale_emb) res = res * scale_emb # define outputs @@ -67,6 +69,31 @@ def __init__( self.compile() +class MiniCPMPostEmbedding(NNFactory): + def __init__( + self, + input_size, + embedding_dim, + dtype, # fp16 + scale_emb, + device: str = "NPU", + ): + super().__init__(False, device) + self.embedding_dim = embedding_dim + self.dtype = dtype + + input = self.parameter((1, input_size, embedding_dim), dtype=dtype) + print(input) + print(scale_emb) + res = input * scale_emb + + # define outputs + res = self.convert_to_fp16(res) + + print("start compiling") + self.compile() + + class MiniCPMLMHead(LLMBaseNNFactory): def __init__( self, @@ -134,7 +161,8 @@ def __init__( self.compile() -def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): +def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir, + convert_model=False, max_prompt_len=1): num_heads = model.model.layers[0].self_attn.num_heads num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads head_dim = model.model.layers[0].self_attn.head_dim @@ -180,7 +208,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): vocab_size=vocab_size, n_splits=n_splits_linear ) - last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir) + last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir, + True, True) # save weights bins files if n_splits_linear == 1: @@ -209,14 +238,31 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): dtype=np.float16, scale_emb=model.config.scale_emb, ) - first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", - temp_dir) + if convert_model: + bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin") + embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file) + first_blob_path = None + # save embedding post module + embedding_post = MiniCPMPostEmbedding(1, model.config.hidden_size, + dtype=np.float16, + scale_emb=model.config.scale_emb) + update_names_of_IR_and_export_blob(embedding_post, "embedding_post", + temp_dir, True, False) + embedding_post_prefill = MiniCPMPostEmbedding(max_prompt_len, model.config.hidden_size, + dtype=np.float16, + scale_emb=model.config.scale_emb) + update_names_of_IR_and_export_blob(embedding_post_prefill, + "embedding_post_prefill", + temp_dir, True, False) + else: + first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", + temp_dir, True, False) return first_blob_path, last_blob_path def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, temp_dir, weight_dir, transpose_value_cache, kv_len, group_size, - layernorm_const): + layernorm_const, mode="decode"): num_heads = model.model.layers[0].self_attn.num_heads num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads head_dim = model.model.layers[0].self_attn.head_dim @@ -252,8 +298,16 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, else: # FP16 Linear np_dtype = np.float16 + if mode == "decode": + input_len = 1 + decoder_name = f"decoder_layer_{layer_idx}" + else: + input_len = kv_len + decoder_name = "decoder_layer_prefill" + layernorm_const = False + single_decoder = 
LowBitMinicpmMultiDecoderlayer( - [1, 1, num_heads * head_dim], + [1, input_len, num_heads * head_dim], input_layernorm_weights=[layer_norm_0] if layernorm_const else None, post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None, cached_cos=cached_cos, @@ -266,7 +320,7 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, intermediate_size=intermediate_size, scale_depth=scale_depth, num_hidden_layers=num_hidden_layers, - mode="decode", + mode=mode, transpose_value=transpose_value_cache, dtype=np_dtype, n_splits_linear=n_splits_linear, @@ -274,20 +328,119 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, group_size=group_size ) rest_blob_path = update_names_of_IR_and_export_blob(single_decoder, - f"decoder_layer_{layer_idx}", - temp_dir) + decoder_name, + temp_dir, + True, True) - if layernorm_const: - st_idx = 5 - else: - input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin") - post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") - layer_norm_0.data.numpy().tofile(input_lm_bin_file) - layer_norm_1.data.numpy().tofile(post_lm_bin_file) - st_idx = 7 - for idx, (weight, scale) in enumerate(weights): - bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin") - weight.numpy().tofile(bin_file) - bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") - scale.numpy().tofile(bin_file) - del single_decoder + if mode == "decode": + if layernorm_const: + st_idx = 5 + else: + input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin") + post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") + layer_norm_0.data.numpy().tofile(input_lm_bin_file) + layer_norm_1.data.numpy().tofile(post_lm_bin_file) + st_idx = 7 + for idx, (weight, scale) in enumerate(weights): + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") + scale.numpy().tofile(bin_file) + del single_decoder + + +def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj, + save_dir, weight_dir, transpose_value_cache, kv_len, group_size, + layernorm_const, mode="decode"): + num_heads = model.model.layers[0].self_attn.num_heads + num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads + head_dim = model.model.layers[0].self_attn.head_dim + intermediate_size = model.config.intermediate_size + rms_norm_eps = model.config.rms_norm_eps + num_hidden_layers = model.config.num_hidden_layers + scale_depth = model.model.config.scale_depth + layer_num = len(model.model.layers) + fused_layer_num = layer_num // fused_layers + + from ipex_llm.transformers.npu_models.minicpm_mp import LowBitMinicpmMultiDecoderlayer + for i in range(fused_layers): + layer_start = i * fused_layer_num + layer_end = min((i + 1) * fused_layer_num, layer_num) + layer_weights = [] + input_layer_norm_weights = [] + post_attn_layernorm_weights = [] + layer_indexs = range(layer_start, layer_end) + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) + for layer_idx in layer_indexs: + curr_layer = model.model.layers[layer_idx] + attn_layer = curr_layer.self_attn + mlp_layer = curr_layer.mlp + + weights = [] + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + 
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) + layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) + + layer_weights.extend(weights) + input_layer_norm_weights.append(layer_norm_0) + post_attn_layernorm_weights.append(layer_norm_1) + + # save weight + input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin") + post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") + layer_norm_0.data.numpy().tofile(input_lm_bin_file) + layer_norm_1.data.numpy().tofile(post_lm_bin_file) + st_idx = 5 + # 6, 7 are past k/v + for idx, (weight, scale) in enumerate(weights): + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") + scale.numpy().tofile(bin_file) + + if isinstance(weights[0], tuple): + np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + np_dtype = np.float16 + + fused_decoder = LowBitMinicpmMultiDecoderlayer( + [1, 1, num_heads * head_dim], + input_layernorm_weights=input_layer_norm_weights, + post_attn_layernorm_weights=post_attn_layernorm_weights, + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + num_layers=fused_layer_num, + max_seq_len=kv_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + scale_depth=scale_depth, + num_hidden_layers=num_hidden_layers, + mode=mode, + transpose_value=transpose_value_cache, + dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size + ) + update_names_of_IR_and_export_blob(fused_decoder, + f"decoder_layer_{i}", + save_dir, + compile_blob=True, + keep_ir=False) + return 0 From d0b6d3377772354c1f628c9e73407386ba4f26b2 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Sun, 24 Nov 2024 21:45:23 +0800 Subject: [PATCH 2/4] update --- .../transformers/npu_pipeline_model/convert_pipeline.py | 6 +++--- .../src/ipex_llm/transformers/npu_pipeline_model/minicpm.py | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index a99bd07a39d..f418ea9e1fd 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -515,9 +515,6 @@ def convert_llm_for_deploy(model: torch.nn.Module, from .minicpm import convert_minicpm_layer, convert_fused_minicpm_layer from .minicpm import convert_lm_head_and_embedding - # save blob of lmhead and bin of embedding - convert_lm_head_and_embedding(model, n_splits_linear, - save_directory, weight_dir, True, max_prompt_len) # save fused_layers blobs of fused decoder layers convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj, save_directory, 
weight_dir, transpose_value_cache, kv_len, @@ -526,3 +523,6 @@ def convert_llm_for_deploy(model: torch.nn.Module, convert_minicpm_layer(model, 0, n_splits_linear, n_splits_down_proj, save_directory, weight_dir, transpose_value_cache, max_prompt_len, group_size, layernorm_const, "prefill") + # save blob of lmhead and bin of embedding + convert_lm_head_and_embedding(model, n_splits_linear, + save_directory, weight_dir, True, max_prompt_len) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py index 94b6ebd58be..1893db9c963 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py @@ -58,8 +58,6 @@ def __init__( res = self.gather(node_masked_w, input, axis_node, 0) else: res = self.gather(weight, input, axis_node, 0) - print(res) - print(scale_emb) res = res * scale_emb # define outputs @@ -83,8 +81,6 @@ def __init__( self.dtype = dtype input = self.parameter((1, input_size, embedding_dim), dtype=dtype) - print(input) - print(scale_emb) res = input * scale_emb # define outputs From 1baa8d77a04e700869eccb955e9ba6ca8b84e789 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Sun, 24 Nov 2024 22:20:18 +0800 Subject: [PATCH 3/4] tune fused_layers --- .../transformers/npu_pipeline_model/convert_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index f418ea9e1fd..4537a756acc 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -495,7 +495,7 @@ def convert_llm_for_deploy(model: torch.nn.Module, save_directory, weight_dir, True) elif model.config.model_type == "minicpm": layernorm_const = True - fused_layers = 2 + fused_layers = 4 update_dict = {"kv_len": kv_len, "num_head": model.model.layers[0].self_attn.num_heads, "head_dim": model.model.layers[0].self_attn.head_dim, From e857c391f4f7b02e54e77673b66d9661bd3a096a Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Mon, 25 Nov 2024 09:40:47 +0800 Subject: [PATCH 4/4] update readme.md --- .../NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md index 22a12a8313b..aa5e076c20d 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md @@ -9,6 +9,7 @@ In this directory, you will find a C++ example on how to run LLM models on Intel | Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) | | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | +| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) | ## 0. Requirements To run this C++ example with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
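
[Editor's note] For readers following the fused-decoder export introduced in PATCH 1 and tuned in PATCH 3: the sketch below reproduces only the layer-grouping arithmetic from convert_fused_minicpm_layer, standalone and outside of any ipex-llm code. The 52-layer count is an assumption matching openbmb/MiniCPM-1B-sft-bf16's config, fused_layers=4 is the value set in convert_pipeline.py by PATCH 3, and the helper name fused_layer_groups is illustrative only — it does not exist in the patched code.

    # Sketch of the grouping used when exporting fused decoder blobs.
    # Assumptions: layer_num=52 (MiniCPM-1B-sft), fused_layers=4 (PATCH 3).
    # Only the index math is mirrored here, not the blob export itself.

    def fused_layer_groups(layer_num: int, fused_layers: int):
        """Return the [start, end) decoder-layer ranges packed into each fused blob."""
        fused_layer_num = layer_num // fused_layers
        groups = []
        for i in range(fused_layers):
            layer_start = i * fused_layer_num
            layer_end = min((i + 1) * fused_layer_num, layer_num)
            groups.append((layer_start, layer_end))
        return groups

    if __name__ == "__main__":
        # Each tuple corresponds to one "decoder_layer_{i}" blob exported by
        # update_names_of_IR_and_export_blob in the fused "decode" path.
        print(fused_layer_groups(layer_num=52, fused_layers=4))
        # -> [(0, 13), (13, 26), (26, 39), (39, 52)]

As with the patched code, the integer division assumes layer_num divides evenly by fused_layers (52 / 4 here); the prefill path is unaffected, since it still exports a single decoder_layer_prefill blob per layer.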