Extend llamalike-converter (#2536)
* add various cases to the converter
vince62s authored Dec 14, 2023
1 parent f0bd36f commit cd49e76
Showing 1 changed file with 190 additions and 72 deletions.
tools/convert_HF_llamalike.py — 262 changes: 190 additions & 72 deletions
@@ -12,6 +12,27 @@
import huggingface_hub
from safetensors.torch import save_file

+key_maps = {}
+key_maps["LlamaForCausalLM"] = {
+    "layer_prefix": "model.layers.",
+    "decoder.embeddings.make_embedding.emb_luts.0.weight": "model.embed_tokens.weight",
+    "decoder.layer_norm.weight": "model.norm.weight",
+    "generator.weight": "lm_head.weight",
+    ".self_attn.linear_query.": ".self_attn.q_proj.",
+    ".self_attn.linear_keys.": ".self_attn.k_proj.",
+    ".self_attn.linear_values.": ".self_attn.v_proj.",
+    ".self_attn.final_linear.": ".self_attn.o_proj.",
+    ".feed_forward.w_1.": ".mlp.gate_proj.",
+    ".feed_forward.w_2.": ".mlp.down_proj.",
+    ".feed_forward.w_3.": ".mlp.up_proj.",
+    ".layer_norm_1.weight": ".input_layernorm.weight",
+    ".feed_forward.layer_norm.weight": ".post_attention_layernorm.weight",
+}
+key_maps["MistralForCausalLM"] = key_maps["LlamaForCausalLM"]
+ln_table = {"LlamaForCausalLM": "rms", "MistralForCausalLM": "rms"}
+act_table = {"LlamaForCausalLM": "silu", "MistralForCausalLM": "silu"}
+decoder_start_table = {"LlamaForCausalLM": "<s>", "MistralForCausalLM": "<s>"}


class Tokenizer:
    def __init__(self, model_path: str):
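The tables above drive the whole conversion: config["architectures"][0] selects a key map, and per-layer checkpoint names are assembled as layer_prefix + layer index + suffix. A minimal sketch of the lookup, using only names defined in the mapping above:

    arch = "LlamaForCausalLM"            # from config["architectures"][0]
    target = ".self_attn.linear_query."  # OpenNMT-py side of the mapping
    source = key_maps[arch][target]      # -> ".self_attn.q_proj."
    layer, param = 5, "weight"
    hf_name = key_maps[arch]["layer_prefix"] + str(layer) + source + param
    # hf_name == "model.layers.5.self_attn.q_proj.weight"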
@@ -77,9 +98,14 @@ def __init__(self, model_path: str):
    if os.path.exists(os.path.join(opt.model_dir, "tokenizer.model")):
        tokenizer_model = os.path.join(opt.model_dir, "tokenizer.model")
    else:
-        raise ValueError(
-            "You used a local directory but tokenizer.model is missing"
-        )
+        if os.path.exists(os.path.join(opt.model_dir, "tokenizer.json")):
+            tokenizer_json = os.path.join(opt.model_dir, "tokenizer.json")
+            tokenizer_model = None
+        else:
+            raise ValueError(
+                "You used a local directory but tokenizer.model"
+                " and/or tokenizer.json are missing"
+            )
else:
    directory_path, _ = os.path.split(opt.output)
    os.makedirs(directory_path, exist_ok=True)
@@ -91,9 +117,18 @@ def __init__(self, model_path: str):
            token=opt.token,
        )
    except huggingface_hub.utils.EntryNotFoundError:
-        raise huggingface_hub.utils.EntryNotFoundError(
-            "Make sure the repo contains tokenizer.model - needed for all Llama-like models"
-        )
+        try:
+            tokenizer_json = huggingface_hub.hf_hub_download(
+                repo_id=opt.model_dir,
+                filename="tokenizer.json",
+                local_dir=directory_path,
+                token=opt.token,
+            )
+            tokenizer_model = None
+        except huggingface_hub.utils.EntryNotFoundError:
+            raise huggingface_hub.utils.EntryNotFoundError(
+                "Make sure the repo contains tokenizer.model or tokenizer.json"
+            )
    try:
        config_path = huggingface_hub.hf_hub_download(
            repo_id=opt.model_dir,
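With this change the converter prefers tokenizer.model (SentencePiece) and falls back to tokenizer.json, both for local directories and for hub downloads. A condensed sketch of the fallback, using the same huggingface_hub calls as above (fetch_tokenizer is a hypothetical helper, not part of the script):

    import huggingface_hub

    def fetch_tokenizer(repo_id, local_dir, token):
        # Returns (tokenizer_model, tokenizer_json); exactly one is None.
        try:
            path = huggingface_hub.hf_hub_download(
                repo_id=repo_id, filename="tokenizer.model",
                local_dir=local_dir, token=token,
            )
            return path, None
        except huggingface_hub.utils.EntryNotFoundError:
            path = huggingface_hub.hf_hub_download(
                repo_id=repo_id, filename="tokenizer.json",
                local_dir=local_dir, token=token,
            )
            return None, path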
@@ -146,23 +181,46 @@ def __init__(self, model_path: str):
with open(config_path, encoding="utf-8") as fconfig:
    config = json.load(fconfig)

+arch = config["architectures"][0]
decoder_layers = config["num_hidden_layers"]
src_word_vec_size = config["hidden_size"]
tgt_word_vec_size = config["hidden_size"]
hidden_size = config["hidden_size"]
heads = config["num_attention_heads"]
vocab_size = config["vocab_size"]
-transformer_ff = config["intermediate_size"]

-if (
+if "intermediate_size" in config.keys():
+    transformer_ff = config["intermediate_size"]
+else:
+    transformer_ff = hidden_size * 4
+pos_ffn_activation_fn = act_table[arch]
+layer_norm = ln_table[arch]
+
+multiquery = False
+if "multi_query" in config.keys():
+    multiquery = config["multi_query"]
+    num_kv = 1
+elif (
    "num_key_value_heads" in config.keys()
    and config["num_key_value_heads"] != heads
):
    num_kv = config["num_key_value_heads"]
+elif "num_kv_heads" in config.keys() and config["num_kv_heads"] != heads:
+    num_kv = config["num_kv_heads"]
+elif "n_head_kv" in config.keys() and config["n_head_kv"] != heads:
+    num_kv = config["n_head_kv"]
else:
    num_kv = 0
+shared_layer = num_kv == 1

+if "parallel_attn" in config.keys():
+    parallel_residual = config["parallel_attn"]
+else:
+    parallel_residual = False
+
if "rms_norm_eps" in config.keys():
    norm_eps = config["rms_norm_eps"]
+elif "layer_norm_epsilon" in config.keys():
+    norm_eps = config["layer_norm_epsilon"]
else:
    norm_eps = 1e-6
if "sliding_window" in config.keys():
@@ -247,7 +305,7 @@ def get_load_ckpt(dir_path, file_path):
        )
    except PermissionError:
        ckpt_path = os.path.join(dir_path, file_path)
-    if ckpt_path[-3:] == ".pt":
+    if ckpt_path[-4:] == ".bin":
        checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
    else:
        checkpoint = ckpt_path
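The extension test now keys on HF's .bin pickle shards instead of OpenNMT's .pt; any other path (e.g. a .safetensors shard) is returned untouched, presumably for get_weight to open later. The same dispatch with os.path.splitext, as a stylistic sketch rather than the script's code:

    import os
    import torch

    def load_or_defer(ckpt_path):
        # Eagerly unpickle HF .bin checkpoints; defer anything else.
        if os.path.splitext(ckpt_path)[1] == ".bin":
            return torch.load(ckpt_path, map_location=torch.device("cpu"))
        return ckpt_path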
@@ -273,27 +331,24 @@ def get_weight(checkpoint, tensor_name):
onmt_safetensor = {}

if shard == 0:
-    sourcelist = [
-        "model.embed_tokens.weight",
-        "model.norm.weight",
-        "lm_head.weight",
-    ]
    targetlist = [
        "decoder.embeddings.make_embedding.emb_luts.0.weight",
        "decoder.layer_norm.weight",
+        "decoder.layer_norm.bias",
        "generator.weight",
    ]

-    for source, target in zip(sourcelist, targetlist):
-        if wmap_path:
-            checkpoint = get_load_ckpt(
-                os.path.split(wmap_path)[0], wmap["weight_map"][source]
-            )
-        else:
-            checkpoint = get_load_ckpt(*os.path.split(model_path))
-        w = get_weight(checkpoint, source)
-        if w is not None:
-            onmt_safetensor[target] = w
+    for target in targetlist:
+        if target in key_maps[arch].keys():
+            source = key_maps[arch][target]
+            if wmap_path:
+                checkpoint = get_load_ckpt(
+                    os.path.split(wmap_path)[0], wmap["weight_map"][source]
+                )
+            else:
+                checkpoint = get_load_ckpt(*os.path.split(model_path))
+            w = get_weight(checkpoint, source)
+            if w is not None:
+                onmt_safetensor[target] = w

    onmt_safetensor["generator.bias"] = torch.zeros(
        onmt_safetensor["generator.weight"].size(0), dtype=torch.float16
@@ -304,7 +359,7 @@ def get_weight(checkpoint, tensor_name):
    ckpt_list = []
    for key in weightmap.keys():
        if (
-            key.startswith("model.layers.")
+            key.startswith(key_maps[arch]["layer_prefix"])
            and int(key.split(".")[2])
            in range(
                -(decoder_layers // -opt.nshards) * shard,
@@ -317,6 +372,7 @@ def get_weight(checkpoint, tensor_name):
            and weightmap[key] not in ckpt_list
        ):
            ckpt_list.append(weightmap[key])
+            print(weightmap[key])
else:
    ckpt_list = [model_path]
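The shard ranges rely on negative floor division for a ceiling: -(decoder_layers // -opt.nshards) equals ceil(decoder_layers / nshards), so every shard but the last gets the same number of layers. For illustration (values hypothetical):

    decoder_layers, nshards = 32, 3
    per_shard = -(decoder_layers // -nshards)  # ceil(32 / 3) == 11
    for shard in range(nshards):
        lo = per_shard * shard
        hi = min(per_shard * (shard + 1), decoder_layers)
        # shard 0 -> layers 0..10, shard 1 -> 11..21, shard 2 -> 22..31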

@@ -332,24 +388,7 @@ def get_weight(checkpoint, tensor_name):
        1,
    ):

-        w = get_weight(
-            checkpoint, "model.layers." + str(i) + ".input_layernorm.weight"
-        )
-        if w is not None:
-            onmt_safetensor[
-                "decoder.transformer_layers." + str(i) + ".layer_norm_1.weight"
-            ] = w

        for param in params:
-            sourcelist = [
-                ".self_attn.q_proj.",
-                ".self_attn.k_proj.",
-                ".self_attn.v_proj.",
-                ".self_attn.o_proj.",
-                ".mlp.gate_proj.",
-                ".mlp.down_proj.",
-                ".mlp.up_proj.",
-            ]
            targetlist = [
                ".self_attn.linear_query.",
                ".self_attn.linear_keys.",
@@ -359,25 +398,86 @@ def get_weight(checkpoint, tensor_name):
".feed_forward.w_2.",
".feed_forward.w_3.",
]
for source, target in zip(sourcelist, targetlist):
for target in targetlist:
if target in key_maps[arch].keys():
source = key_maps[arch][target]
if type(source) == tuple:
srckey = source[0]
srcmap = source[1]
else:
srckey = source
w = get_weight(
checkpoint,
key_maps[arch]["layer_prefix"]
+ str(i)
+ srckey
+ param,
)

if w is not None:
if type(source) == tuple:
w = eval("w" + srcmap)
onmt_safetensor[
"decoder.transformer_layers."
+ str(i)
+ target
+ param
] = w

+        if shared_layer:
+            idx = 0
+        else:
+            idx = 1
+        for p in ["weight", "bias"]:
+            if ".layer_norm_1." + p in key_maps[arch].keys():
+                if type(key_maps[arch][".layer_norm_1." + p]) == tuple:
+                    w = get_weight(
+                        checkpoint,
+                        key_maps[arch]["layer_prefix"]
+                        + str(i)
+                        + key_maps[arch][".layer_norm_1." + p][idx],
+                    )
+                else:
+                    w = get_weight(
+                        checkpoint,
+                        key_maps[arch]["layer_prefix"]
+                        + str(i)
+                        + key_maps[arch][".layer_norm_1." + p],
+                    )
+                if w is not None:
+                    onmt_safetensor[
+                        "decoder.transformer_layers."
+                        + str(i)
+                        + ".layer_norm_1."
+                        + p
+                    ] = w
+            if ".layer_norm_res." + p in key_maps[arch].keys():
                w = get_weight(
-                    checkpoint, "model.layers." + str(i) + source + param
+                    checkpoint,
+                    key_maps[arch]["layer_prefix"]
+                    + str(i)
+                    + key_maps[arch][".layer_norm_res." + p],
                )
                if w is not None:
                    onmt_safetensor[
-                        "decoder.transformer_layers." + str(i) + target + param
+                        "decoder.transformer_layers."
+                        + str(i)
+                        + ".layer_norm_res."
+                        + p
                    ] = w
+        if ".feed_forward.layer_norm.weight" in key_maps[arch].keys():
+            w = get_weight(
+                checkpoint,
+                key_maps[arch]["layer_prefix"]
+                + str(i)
+                + key_maps[arch][".feed_forward.layer_norm.weight"],
+            )
+            if w is not None:
+                onmt_safetensor[
+                    "decoder.transformer_layers."
+                    + str(i)
+                    + ".feed_forward.layer_norm.weight"
+                ] = w

-        w = get_weight(
-            checkpoint,
-            "model.layers." + str(i) + ".post_attention_layernorm.weight",
-        )
-        if w is not None:
-            onmt_safetensor[
-                "decoder.transformer_layers."
-                + str(i)
-                + ".feed_forward.layer_norm.weight"
-            ] = w

    if shard == 0:
        vocab_size = onmt_safetensor["generator.weight"].size(0)
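A key-map value can also be a tuple (HF key, indexing expression); the expression is applied to the loaded tensor through eval("w" + srcmap), e.g. to slice one projection out of a fused matrix. No Llama or Mistral entry above uses this yet; a hypothetical entry to show the intent:

    # hypothetical architecture with a fused QKV projection
    key_maps["SomeArch"] = {
        "layer_prefix": "transformer.h.",
        ".self_attn.linear_query.": (".self_attn.qkv_proj.", "[:hidden_size, :]"),
    }
    source = key_maps["SomeArch"][".self_attn.linear_query."]
    srckey, srcmap = source  # HF tensor name suffix, slice applied after loading
    # after w is loaded: w = eval("w" + srcmap)  ->  w[:hidden_size, :]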
@@ -395,23 +495,40 @@ def get_weight(checkpoint, tensor_name):
onmt_cp["model"] = {}
onmt_cp["model"] = onmt_safetensor

tokenizer = Tokenizer(model_path=tokenizer_model)
directory_path, _ = os.path.split(opt.output)
os.makedirs(directory_path, exist_ok=True)
vocabs = {}
vocab = tokenizer.vocab
vocab[3] = DefaultTokens.PAD
src_vocab = pyonmttok.build_vocab_from_tokens(
vocab, maximum_size=tokenizer.n_words, special_tokens=["<unk>", "<s>", "</s>"]
)
if tokenizer_model is not None:
tokenizer = Tokenizer(model_path=tokenizer_model)
vocab = tokenizer.vocab
vocab[3] = DefaultTokens.PAD
src_vocab = pyonmttok.build_vocab_from_tokens(
vocab,
maximum_size=tokenizer.n_words,
special_tokens=["<unk>", "<s>", "</s>"],
)
else: # this section is not used for llama for now
with open(tokenizer_json, encoding="utf-8") as f:
data = json.load(f)
vocab = [
tok if tok != "Ā" else DefaultTokens.PAD for tok in data["model"]["vocab"]
]
vocab[11] = "</s>" # Falcon only
src_vocab = pyonmttok.build_vocab_from_tokens(vocab)
with open(
os.path.join(directory_path, "bpe.model"), "w", encoding="utf-8"
) as bpemodel:
bpemodel.write("v3;false;false;false;Ġ;Ġ\n")
for merge in data["model"]["merges"]:
bpemodel.write(merge + "\n")

vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab
vocabs["data_task"] = "lm"
vocabs["decoder_start_token"] = "<s>"

vocabs["decoder_start_token"] = decoder_start_table[arch]
onmt_cp["vocab"] = {}
onmt_cp["vocab"] = vocabs_to_dict(vocabs)

directory_path, _ = os.path.split(opt.output)
os.makedirs(directory_path, exist_ok=True)
with open(
os.path.join(directory_path, "vocab.txt"), "w", encoding="utf-8"
) as vocabfile:
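When only tokenizer.json is available, the byte-level BPE vocab is reused directly: "Ā" (the byte-level stand-in) is remapped to OpenNMT's PAD token, index 11 is forced to </s> for Falcon vocabularies, and the merges are re-exported to bpe.model behind a subword header line. A small sketch of reading the same fields back from a tokenizer.json (structure as assumed by the code above):

    import json

    with open("tokenizer.json", encoding="utf-8") as f:
        data = json.load(f)
    tokens = list(data["model"]["vocab"])  # iterating the vocab yields tokens in id order
    merges = data["model"]["merges"]       # one "left right" pair per merge rule
    print(len(tokens), "tokens,", len(merges), "merges")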
@@ -497,9 +614,9 @@ def get_weight(checkpoint, tensor_name):
    enc_hid_size=hidden_size,
    dec_hid_size=hidden_size,
    cnn_kernel_width=3,
-    layer_norm="rms",
+    layer_norm=layer_norm,
    norm_eps=norm_eps,
-    pos_ffn_activation_fn="silu",
+    pos_ffn_activation_fn=pos_ffn_activation_fn,
    input_feed=1,
    bridge=False,
    rnn_type="LSTM",
@@ -522,9 +639,10 @@ def get_weight(checkpoint, tensor_name):
    aan_useffn=False,
    add_qkvbias=False,
    add_ffnbias=False,
-    multiquery=False,
+    multiquery=multiquery,
    num_kv=num_kv,
-    parallel_residual=False,
+    parallel_residual=parallel_residual,
+    shared_layer_norm=shared_layer,
    lambda_align=0.0,
    alignment_layer=-3,
    alignment_heads=0,
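The detected values then flow into the saved opts (multiquery, num_kv, parallel_residual, shared_layer_norm above). For a hypothetical multi-query config the derivation reduces to:

    config = {"multi_query": True, "parallel_attn": True}  # hypothetical values
    multiquery = config["multi_query"]        # True
    num_kv = 1                                # forced alongside multi_query
    shared_layer = num_kv == 1                # True -> shared_layer_norm=True
    parallel_residual = config.get("parallel_attn", False)  # True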
