From cd49e76a6af79a53a99216390396ab3c247c5c83 Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Thu, 14 Dec 2023 14:00:55 +0100
Subject: [PATCH] Extend llamalike-converter (#2536)

* add various cases to the converter: per-architecture key maps, a
  tokenizer.json fallback when tokenizer.model is absent, and detection of
  multi-query attention, grouped KV heads, parallel residual and shared
  layer norm from the HF config
---
 tools/convert_HF_llamalike.py | 262 ++++++++++++++++++++++++----------
 1 file changed, 190 insertions(+), 72 deletions(-)

diff --git a/tools/convert_HF_llamalike.py b/tools/convert_HF_llamalike.py
index 64053d72a1..8fe9d4efc6 100755
--- a/tools/convert_HF_llamalike.py
+++ b/tools/convert_HF_llamalike.py
@@ -12,6 +12,27 @@
 import huggingface_hub
 from safetensors.torch import save_file
 
+key_maps = {}
+key_maps["LlamaForCausalLM"] = {
+    "layer_prefix": "model.layers.",
+    "decoder.embeddings.make_embedding.emb_luts.0.weight": "model.embed_tokens.weight",
+    "decoder.layer_norm.weight": "model.norm.weight",
+    "generator.weight": "lm_head.weight",
+    ".self_attn.linear_query.": ".self_attn.q_proj.",
+    ".self_attn.linear_keys.": ".self_attn.k_proj.",
+    ".self_attn.linear_values.": ".self_attn.v_proj.",
+    ".self_attn.final_linear.": ".self_attn.o_proj.",
+    ".feed_forward.w_1.": ".mlp.gate_proj.",
+    ".feed_forward.w_2.": ".mlp.down_proj.",
+    ".feed_forward.w_3.": ".mlp.up_proj.",
+    ".layer_norm_1.weight": ".input_layernorm.weight",
+    ".feed_forward.layer_norm.weight": ".post_attention_layernorm.weight",
+}
+key_maps["MistralForCausalLM"] = key_maps["LlamaForCausalLM"]
+ln_table = {"LlamaForCausalLM": "rms", "MistralForCausalLM": "rms"}
+act_table = {"LlamaForCausalLM": "silu", "MistralForCausalLM": "silu"}
+decoder_start_table = {"LlamaForCausalLM": "<s>", "MistralForCausalLM": "<s>"}
+
 
 class Tokenizer:
     def __init__(self, model_path: str):
@@ -77,9 +98,14 @@ def __init__(self, model_path: str):
         if os.path.exists(os.path.join(opt.model_dir, "tokenizer.model")):
             tokenizer_model = os.path.join(opt.model_dir, "tokenizer.model")
         else:
-            raise ValueError(
-                "You used a local directory but tokenizer.model is missing"
-            )
+            if os.path.exists(os.path.join(opt.model_dir, "tokenizer.json")):
+                tokenizer_json = os.path.join(opt.model_dir, "tokenizer.json")
+                tokenizer_model = None
+            else:
+                raise ValueError(
+                    "You used a local directory but tokenizer.model"
+                    " and tokenizer.json are missing"
+                )
     else:
         directory_path, _ = os.path.split(opt.output)
         os.makedirs(directory_path, exist_ok=True)
@@ -91,9 +117,18 @@ def __init__(self, model_path: str):
                 token=opt.token,
             )
         except huggingface_hub.utils.EntryNotFoundError:
-            raise huggingface_hub.utils.EntryNotFoundError(
-                "Make sure the repo contains tokenizer.model - needed for all Llama-like models"
-            )
+            try:
+                tokenizer_json = huggingface_hub.hf_hub_download(
+                    repo_id=opt.model_dir,
+                    filename="tokenizer.json",
+                    local_dir=directory_path,
+                    token=opt.token,
+                )
+                tokenizer_model = None
+            except huggingface_hub.utils.EntryNotFoundError:
+                raise huggingface_hub.utils.EntryNotFoundError(
+                    "Make sure the repo contains tokenizer.model or tokenizer.json"
+                )
         try:
             config_path = huggingface_hub.hf_hub_download(
                 repo_id=opt.model_dir,
@@ -146,23 +181,46 @@ def __init__(self, model_path: str):
     with open(config_path, encoding="utf-8") as fconfig:
         config = json.load(fconfig)
 
+    arch = config["architectures"][0]
     decoder_layers = config["num_hidden_layers"]
     src_word_vec_size = config["hidden_size"]
     tgt_word_vec_size = config["hidden_size"]
     hidden_size = config["hidden_size"]
     heads = config["num_attention_heads"]
     vocab_size = config["vocab_size"]
-    transformer_ff = config["intermediate_size"]
-
-    if (
+    if "intermediate_size" in config.keys():
+        transformer_ff = config["intermediate_size"]
config["intermediate_size"] + else: + transformer_ff = hidden_size * 4 + pos_ffn_activation_fn = act_table[arch] + layer_norm = ln_table[arch] + + multiquery = False + if "multi_query" in config.keys(): + multiquery = config["multi_query"] + num_kv = 1 + elif ( "num_key_value_heads" in config.keys() and config["num_key_value_heads"] != heads ): num_kv = config["num_key_value_heads"] + elif "num_kv_heads" in config.keys() and config["num_kv_heads"] != heads: + num_kv = config["num_kv_heads"] + elif "n_head_kv" in config.keys() and config["n_head_kv"] != heads: + num_kv = config["n_head_kv"] else: num_kv = 0 + shared_layer = num_kv == 1 + + if "parallel_attn" in config.keys(): + parallel_residual = config["parallel_attn"] + else: + parallel_residual = False + if "rms_norm_eps" in config.keys(): norm_eps = config["rms_norm_eps"] + elif "layer_norm_epsilon" in config.keys(): + norm_eps = config["layer_norm_epsilon"] else: norm_eps = 1e-6 if "sliding_window" in config.keys(): @@ -247,7 +305,7 @@ def get_load_ckpt(dir_path, file_path): ) except PermissionError: ckpt_path = os.path.join(dir_path, file_path) - if ckpt_path[-3:] == ".pt": + if ckpt_path[-4:] == ".bin": checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu")) else: checkpoint = ckpt_path @@ -273,27 +331,24 @@ def get_weight(checkpoint, tensor_name): onmt_safetensor = {} if shard == 0: - sourcelist = [ - "model.embed_tokens.weight", - "model.norm.weight", - "lm_head.weight", - ] targetlist = [ "decoder.embeddings.make_embedding.emb_luts.0.weight", "decoder.layer_norm.weight", + "decoder.layer_norm.bias", "generator.weight", ] - - for source, target in zip(sourcelist, targetlist): - if wmap_path: - checkpoint = get_load_ckpt( - os.path.split(wmap_path)[0], wmap["weight_map"][source] - ) - else: - checkpoint = get_load_ckpt(*os.path.split(model_path)) - w = get_weight(checkpoint, source) - if w is not None: - onmt_safetensor[target] = w + for target in targetlist: + if target in key_maps[arch].keys(): + source = key_maps[arch][target] + if wmap_path: + checkpoint = get_load_ckpt( + os.path.split(wmap_path)[0], wmap["weight_map"][source] + ) + else: + checkpoint = get_load_ckpt(*os.path.split(model_path)) + w = get_weight(checkpoint, source) + if w is not None: + onmt_safetensor[target] = w onmt_safetensor["generator.bias"] = torch.zeros( onmt_safetensor["generator.weight"].size(0), dtype=torch.float16 @@ -304,7 +359,7 @@ def get_weight(checkpoint, tensor_name): ckpt_list = [] for key in weightmap.keys(): if ( - key.startswith("model.layers.") + key.startswith(key_maps[arch]["layer_prefix"]) and int(key.split(".")[2]) in range( -(decoder_layers // -opt.nshards) * shard, @@ -317,6 +372,7 @@ def get_weight(checkpoint, tensor_name): and weightmap[key] not in ckpt_list ): ckpt_list.append(weightmap[key]) + print(weightmap[key]) else: ckpt_list = [model_path] @@ -332,24 +388,7 @@ def get_weight(checkpoint, tensor_name): 1, ): - w = get_weight( - checkpoint, "model.layers." + str(i) + ".input_layernorm.weight" - ) - if w is not None: - onmt_safetensor[ - "decoder.transformer_layers." 
+ str(i) + ".layer_norm_1.weight" - ] = w - for param in params: - sourcelist = [ - ".self_attn.q_proj.", - ".self_attn.k_proj.", - ".self_attn.v_proj.", - ".self_attn.o_proj.", - ".mlp.gate_proj.", - ".mlp.down_proj.", - ".mlp.up_proj.", - ] targetlist = [ ".self_attn.linear_query.", ".self_attn.linear_keys.", @@ -359,25 +398,86 @@ def get_weight(checkpoint, tensor_name): ".feed_forward.w_2.", ".feed_forward.w_3.", ] - for source, target in zip(sourcelist, targetlist): + for target in targetlist: + if target in key_maps[arch].keys(): + source = key_maps[arch][target] + if type(source) == tuple: + srckey = source[0] + srcmap = source[1] + else: + srckey = source + w = get_weight( + checkpoint, + key_maps[arch]["layer_prefix"] + + str(i) + + srckey + + param, + ) + + if w is not None: + if type(source) == tuple: + w = eval("w" + srcmap) + onmt_safetensor[ + "decoder.transformer_layers." + + str(i) + + target + + param + ] = w + + if shared_layer: + idx = 0 + else: + idx = 1 + for p in ["weight", "bias"]: + if ".layer_norm_1." + p in key_maps[arch].keys(): + if type(key_maps[arch][".layer_norm_1." + p]) == tuple: + w = get_weight( + checkpoint, + key_maps[arch]["layer_prefix"] + + str(i) + + key_maps[arch][".layer_norm_1." + p][idx], + ) + else: + w = get_weight( + checkpoint, + key_maps[arch]["layer_prefix"] + + str(i) + + key_maps[arch][".layer_norm_1." + p], + ) + if w is not None: + onmt_safetensor[ + "decoder.transformer_layers." + + str(i) + + ".layer_norm_1." + + p + ] = w + if ".layer_norm_res." + p in key_maps[arch].keys(): w = get_weight( - checkpoint, "model.layers." + str(i) + source + param + checkpoint, + key_maps[arch]["layer_prefix"] + + str(i) + + key_maps[arch][".layer_norm_res." + p], ) if w is not None: onmt_safetensor[ - "decoder.transformer_layers." + str(i) + target + param + "decoder.transformer_layers." + + str(i) + + ".layer_norm_res." + + p + ] = w + if ".feed_forward.layer_norm.weight" in key_maps[arch].keys(): + w = get_weight( + checkpoint, + key_maps[arch]["layer_prefix"] + + str(i) + + key_maps[arch][".feed_forward.layer_norm.weight"], + ) + if w is not None: + onmt_safetensor[ + "decoder.transformer_layers." + + str(i) + + ".feed_forward.layer_norm.weight" ] = w - - w = get_weight( - checkpoint, - "model.layers." + str(i) + ".post_attention_layernorm.weight", - ) - if w is not None: - onmt_safetensor[ - "decoder.transformer_layers." 
-                        + str(i)
-                        + ".feed_forward.layer_norm.weight"
-                    ] = w
 
         if shard == 0:
             vocab_size = onmt_safetensor["generator.weight"].size(0)
@@ -395,23 +495,40 @@ def get_weight(checkpoint, tensor_name):
     onmt_cp["model"] = {}
     onmt_cp["model"] = onmt_safetensor
 
-    tokenizer = Tokenizer(model_path=tokenizer_model)
+    directory_path, _ = os.path.split(opt.output)
+    os.makedirs(directory_path, exist_ok=True)
 
     vocabs = {}
-    vocab = tokenizer.vocab
-    vocab[3] = DefaultTokens.PAD
-    src_vocab = pyonmttok.build_vocab_from_tokens(
-        vocab, maximum_size=tokenizer.n_words, special_tokens=["<unk>", "<s>", "</s>"]
-    )
+    if tokenizer_model is not None:
+        tokenizer = Tokenizer(model_path=tokenizer_model)
+        vocab = tokenizer.vocab
+        vocab[3] = DefaultTokens.PAD
+        src_vocab = pyonmttok.build_vocab_from_tokens(
+            vocab,
+            maximum_size=tokenizer.n_words,
+            special_tokens=["<unk>", "<s>", "</s>"],
+        )
+    else:  # this section is not used for llama for now
+        with open(tokenizer_json, encoding="utf-8") as f:
+            data = json.load(f)
+        vocab = [
+            tok if tok != "Ā" else DefaultTokens.PAD for tok in data["model"]["vocab"]
+        ]
+        vocab[11] = ""  # Falcon only
+        src_vocab = pyonmttok.build_vocab_from_tokens(vocab)
+        with open(
+            os.path.join(directory_path, "bpe.model"), "w", encoding="utf-8"
+        ) as bpemodel:
+            bpemodel.write("v3;false;false;false;Ġ;Ġ\n")
+            for merge in data["model"]["merges"]:
+                bpemodel.write(merge + "\n")
+
     vocabs["src"] = src_vocab
     vocabs["tgt"] = src_vocab
     vocabs["data_task"] = "lm"
-    vocabs["decoder_start_token"] = "<s>"
-
+    vocabs["decoder_start_token"] = decoder_start_table[arch]
     onmt_cp["vocab"] = {}
     onmt_cp["vocab"] = vocabs_to_dict(vocabs)
 
-    directory_path, _ = os.path.split(opt.output)
-    os.makedirs(directory_path, exist_ok=True)
     with open(
         os.path.join(directory_path, "vocab.txt"), "w", encoding="utf-8"
     ) as vocabfile:
@@ -497,9 +614,9 @@ def get_weight(checkpoint, tensor_name):
             enc_hid_size=hidden_size,
             dec_hid_size=hidden_size,
             cnn_kernel_width=3,
-            layer_norm="rms",
+            layer_norm=layer_norm,
             norm_eps=norm_eps,
-            pos_ffn_activation_fn="silu",
+            pos_ffn_activation_fn=pos_ffn_activation_fn,
             input_feed=1,
             bridge=False,
             rnn_type="LSTM",
@@ -522,9 +639,10 @@ def get_weight(checkpoint, tensor_name):
             aan_useffn=False,
             add_qkvbias=False,
             add_ffnbias=False,
-            multiquery=False,
+            multiquery=multiquery,
             num_kv=num_kv,
-            parallel_residual=False,
+            parallel_residual=parallel_residual,
+            shared_layer_norm=shared_layer,
             lambda_align=0.0,
             alignment_layer=-3,
             alignment_heads=0,
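--
Note on the key_maps convention introduced above: for per-layer weights, the HF
tensor name is assembled as layer_prefix + layer index + map entry + parameter
suffix, and the OpenNMT-py name as "decoder.transformer_layers." + layer index
+ map key + suffix. A minimal sketch of that lookup, using two entries copied
from the LlamaForCausalLM map in this patch:

    # Sketch of the name assembly done in the conversion loop above.
    key_map = {
        "layer_prefix": "model.layers.",
        ".self_attn.linear_query.": ".self_attn.q_proj.",
    }
    i, target, param = 3, ".self_attn.linear_query.", "weight"
    hf_name = key_map["layer_prefix"] + str(i) + key_map[target] + param
    onmt_name = "decoder.transformer_layers." + str(i) + target + param
    print(hf_name)    # model.layers.3.self_attn.q_proj.weight
    print(onmt_name)  # decoder.transformer_layers.3.self_attn.linear_query.weight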
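The same loop also accepts a tuple value (hf_key, slice_expr) for checkpoints
that store several projections fused into a single tensor; the slice expression
is applied with eval("w" + srcmap) after loading. A hypothetical fused-QKV entry
to illustrate the mechanism (the key name and slice below are made up for the
example and are not part of this patch):

    import torch

    # Hypothetical tuple-valued map entry: (hf_key, slice_expr).
    entry = (".self_attention.query_key_value.", "[:8, :]")

    w = torch.zeros(24, 8)  # stand-in for a fused QKV weight (3 * hidden x hidden)
    srckey, srcmap = entry
    w = eval("w" + srcmap)  # same mechanism as the converter: w[:8, :] -> query block
    print(srckey, tuple(w.shape))  # .self_attention.query_key_value. (8, 8)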
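To exercise the extended Mistral path end to end, an invocation along these
lines should work; the flag names are assumed from the opt attributes used in
the script (model_dir, output, nshards, token), so check the argparse
definitions at the bottom of the file:

    python tools/convert_HF_llamalike.py \
        --model_dir mistralai/Mistral-7B-v0.1 \
        --output ./mistral-onmt/mistral.pt \
        --nshards 2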