Merge branch 'main' into llama3_converter
Showing 27 changed files with 1,315 additions and 243 deletions.
@@ -0,0 +1,15 @@
```
on:
  push:

name: Secret Leaks

jobs:
  trufflehog:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Secret Scanning
        uses: trufflesecurity/trufflehog@main
```
@@ -0,0 +1,17 @@

## Debugging the tests with VS Code

To debug the tests with VS Code, add the following JSON to your `launch.json` file.

```
{
    "name": "Test conversion",
    "type": "python",
    "request": "launch",
    "module": "pytest",
    "console": "integratedTerminal",
    "args": [
        "examples/llama/tests"
    ],
    "justMyCode": false
}
```
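If you only want to run the tests rather than step through them, a minimal sketch of the equivalent invocation outside the debugger (assuming pytest is installed; the test path is the one from the `args` entry above):

```
# Run the same test suite programmatically, without attaching a debugger.
import sys

import pytest

if __name__ == "__main__":
    # "examples/llama/tests" matches the "args" entry in the launch configuration above.
    sys.exit(pytest.main(["examples/llama/tests"]))
```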
Empty file.
@@ -0,0 +1,119 @@
```
"""
Converts a HF model to nanotron format
Command:
    torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --checkpoint_path=hf_weights --save_path=nanotron_weights
"""

import dataclasses
import json
from argparse import ArgumentParser
from pathlib import Path

import nanotron
import torch
from convert_weights import get_config_mapping, get_weight_mapping, load_nanotron_model
from nanotron.config import LlamaConfig as NanotronLlamaConfig
from nanotron.models.llama import LlamaForTraining
from transformers import LlamaConfig as HFLlamaConfig
from transformers import LlamaForCausalLM


def _handle_attention_block(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, n_q_heads: int, n_kv_heads: int, d_qk: int
) -> torch.Tensor:
    # Huggingface Llama separates the q, k, v weights (as opposed to nanotron).
    # Furthermore, the rotary embeddings in nanotron expect interleaved pairs of even
    # and odd dimensions, GPT-J style, while the huggingface implementation expects
    # the whole 1st half and then the whole 2nd half, GPT-NeoX style (for more information
    # see flash_attn.layers.rotary.RotaryEmbedding).
    # This function handles the concatenation of the q, k, v weights and the proper permutation
    # to ensure a correct transformation.

    def interleave(w: torch.Tensor):
        w_new = []
        for head_w in w.split(d_qk):
            head_w = head_w.view(2, d_qk // 2, -1).transpose(0, 1).reshape(d_qk, -1)
            w_new.append(head_w)
        return torch.cat(w_new)

    q = interleave(q)
    k = interleave(k)
    return torch.cat([q, k, v])


def convert_hf_to_nt(model_hf: LlamaForCausalLM, model_nt: LlamaForTraining, config: NanotronLlamaConfig):
    """Converts the weights from the model_hf to model_nt, making modifications
    in-place."""

    hf_sd = model_hf.state_dict()
    nt_to_hf = get_weight_mapping(config, nt_to_hf=True)

    for module_name_nt, module_nt in model_nt.named_modules():
        for param_name_nt, param_nt in module_nt.named_parameters(recurse=False):
            # In the case of qkv_proj, nt_to_hf has exactly three keys, corresponding
            # to q, k, v.
            if "qkv_proj" in module_name_nt:
                key_k, key_q, key_v = sorted(nt_to_hf[f"{module_name_nt}.{param_name_nt}"])
                q = hf_sd[key_q]
                k = hf_sd[key_k]
                v = hf_sd[key_v]
                param = _handle_attention_block(
                    q,
                    k,
                    v,
                    config.num_attention_heads,
                    config.num_key_value_heads,
                    config.hidden_size // config.num_attention_heads,
                )
            # In the case of gate_up_proj, nt_to_hf has exactly two keys.
            elif "gate_up_proj" in module_name_nt:
                key_gate, key_up = sorted(nt_to_hf[f"{module_name_nt}.{param_name_nt}"])
                gate = hf_sd[key_gate]
                up = hf_sd[key_up]
                param = torch.cat([gate, up])
            # All other cases are a simple 1-to-1 correspondence.
            else:
                hf_key = nt_to_hf[f"{module_name_nt}.{param_name_nt}"]
                param = hf_sd[hf_key]

            with torch.no_grad():
                param_nt.copy_(param)


def get_nanotron_config(config: HFLlamaConfig) -> NanotronLlamaConfig:
    """Converts a huggingface configuration to nanotron configuration."""
    attrs = {key: getattr(config, value) for key, value in get_config_mapping(nt_to_hf=True).items()}
    return NanotronLlamaConfig(**attrs)


def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path):
    """Loads the huggingface checkpoint in `checkpoint_path`, creates
    a new nanotron instance, copies the weights from the huggingface checkpoint
    and saves the transformed nanotron model to `save_path`."""

    # Load huggingface model.
    hf_model = LlamaForCausalLM.from_pretrained(checkpoint_path)

    # Init nanotron model.
    model_config = get_nanotron_config(hf_model.config)
    nanotron_model = load_nanotron_model(model_config=model_config)

    # Copy weights and save model.
    parallel_context = nanotron.parallel.ParallelContext(
        data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=1
    )
    convert_hf_to_nt(hf_model, nanotron_model, model_config)
    nanotron.serialize.save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path)
    with open(save_path / "model_config.json", "w+") as f:
        json.dump(dataclasses.asdict(model_config), f)
    print(f"Model saved to {save_path}")


if __name__ == "__main__":
    parser = ArgumentParser(description="Convert HF weights to nanotron format")
    parser.add_argument("--checkpoint_path", type=Path, default="llama-7b", help="Path to the checkpoint")
    parser.add_argument("--save_path", type=Path, default="llama-7b-hf", help="Path to save the nanotron model")
    args = parser.parse_args()

    # Convert HF model to nanotron format.
    convert_checkpoint_and_save(checkpoint_path=args.checkpoint_path, save_path=args.save_path)
```
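To make the permutation in `_handle_attention_block` concrete, here is a small self-contained sketch (toy values; the `d_qk=4` single-head tensor is hypothetical, real Llama heads are much wider) showing how `interleave` reorders rows from the Hugging Face "first half, then second half" layout into the interleaved pairs nanotron's rotary embedding expects:

```
# Toy illustration of the interleave permutation used above (not part of the converter).
import torch

d_qk = 4  # hypothetical tiny head dimension for illustration


def interleave(w: torch.Tensor) -> torch.Tensor:
    # Same permutation as inside _handle_attention_block.
    w_new = []
    for head_w in w.split(d_qk):
        w_new.append(head_w.view(2, d_qk // 2, -1).transpose(0, 1).reshape(d_qk, -1))
    return torch.cat(w_new)


# One head, one column; row i holds the value i so the reordering is visible.
w = torch.arange(d_qk, dtype=torch.float32).reshape(d_qk, 1)
print(interleave(w).flatten().tolist())
# -> [0.0, 2.0, 1.0, 3.0]: row i is paired with row i + d_qk // 2,
#    turning the whole-first-half / whole-second-half ordering (GPT-NeoX style)
#    into interleaved even/odd pairs (GPT-J style).
```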