Commit 1c7f038

Merge pull request #151 from huggingface/nouamane/moe-5
Script to fix duplicated ".safetensors" in checkpoints naming
3outeille authored Apr 25, 2024
2 parents b792a22 + 9af94e7 commit 1c7f038
Showing 7 changed files with 62 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -33,4 +33,4 @@ repos:
    - id: codespell
      args:
        - -w
-       - --ignore-words-list=nd,reacher,thist,ths,magent,ba,fo
+       - --ignore-words-list=nd,reacher,thist,ths,magent,ba,fo,doesnt
1 change: 0 additions & 1 deletion examples/moe/llamoe.py
@@ -325,7 +325,6 @@ def forward(
          # Double check that we use store only at inference time
          assert key_states.requires_grad is False
          assert value_states.requires_grad is False
-         print("Using store")
          if "position_offsets" in store:
              old_position_offsets = store["position_offsets"]
              position_ids = old_position_offsets[:, None] + sequence_mask
2 changes: 1 addition & 1 deletion examples/moe/moe.py
@@ -402,7 +402,7 @@ def __init__(self, module, expert_parallel_size: int):
        self.expert_parallel_size = expert_parallel_size

    def forward(self, *args, **kwargs):
-       # self.scale_gradients()
+       self.scale_gradients()
        return self.module(*args, **kwargs)

    def scale_gradients(self):
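The change above switches `scale_gradients()` from commented-out to active in the wrapper's forward pass. As background, expert-parallel parameters only receive gradient contributions from the tokens routed to their rank, so implementations typically rescale those gradients relative to dense parameters. A minimal editorial sketch of one plausible wrapper, assuming hook-based scaling (`ExpertParallelSketch`, the hook mechanism, and the scaling factor are assumptions, not nanotron's actual implementation):

```python
import torch.nn as nn


class ExpertParallelSketch(nn.Module):
    """Editorial sketch of a gradient-scaling wrapper; not nanotron's exact code."""

    def __init__(self, module: nn.Module, expert_parallel_size: int):
        super().__init__()
        self.module = module
        self.expert_parallel_size = expert_parallel_size

    def forward(self, *args, **kwargs):
        self.scale_gradients()
        return self.module(*args, **kwargs)

    def scale_gradients(self):
        # Assumption: scaling is done with tensor hooks that multiply each
        # expert parameter's gradient during backward. The exact factor and
        # mechanism in nanotron may differ.
        for p in self.module.parameters():
            if p.requires_grad and not getattr(p, "_ep_hook_registered", False):
                p.register_hook(lambda grad: grad * self.expert_parallel_size)
                p._ep_hook_registered = True  # register at most once per parameter
```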
51 changes: 51 additions & 0 deletions scripts/fix_checkpoint_bad_naming.py
@@ -0,0 +1,51 @@
"""Fixes the problem where '{type.value}_{suffix_name}.safetensors' was duplicated in checkpoint files
For example this script will change the following:
```
checkpoints/10/model/model/decoder/0/pp_block/attn/o_proj/model_model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors
to
checkpoints/10/model/model/decoder/0/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors
```
Example Usage:
python scripts/fix_checkpoint_bad_naming.py /fsx/nouamane/projects/nanotron/checkpoints/10
"""

import argparse
import os
import re
from pathlib import Path


def update_checkpoint(checkpoint_dir: str):
print(f"Updating checkpoint in {checkpoint_dir}")
for root, _, files in os.walk(checkpoint_dir):
for file in files:
if file.endswith(".safetensors"):
# r'(?<=model)_(model)' means match the string '_model' that is preceded by 'model'
if len(re.findall(r"(?<=model)_(model)", file)) == 0:
continue
# we remove second _model
new_file = re.sub(r"(?<=model)_(model)", "", file)
# we would have "model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

# let's assert we have two matches of ".safetensors"
assert len(re.findall(r".safetensors", new_file)) == 2
# then we remove first match
new_file = re.sub(r".safetensors", "", new_file, count=1)
# so that we get "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

print(f"Renaming {file} to {new_file}")
os.rename(os.path.join(root, file), os.path.join(root, new_file))


def main():
parser = argparse.ArgumentParser(description="Update checkpoint from 1.3 to 1.4")
parser.add_argument("checkpoint_dir", type=Path, help="Path to the checkpoint directory")
args = parser.parse_args()
update_checkpoint(args.checkpoint_dir)


if __name__ == "__main__":
main()
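As a quick sanity check of the two regex steps above, here is an editorial trace on the example filename from the docstring (not part of the commit):

```python
import re

file = "model_model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

# Step 1: drop the "_model" that directly follows "model"
new_file = re.sub(r"(?<=model)_(model)", "", file)
assert new_file == "model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

# Step 2: drop the first of the two ".safetensors" occurrences
new_file = re.sub(r"\.safetensors", "", new_file, count=1)
assert new_file == "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"
```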
2 changes: 1 addition & 1 deletion src/nanotron/constants.py
@@ -2,7 +2,7 @@

from packaging.version import Version, parse

- CHECKPOINT_VERSION = Version("1.3")
+ CHECKPOINT_VERSION = Version("1.4")

PY_VERSION = parse(platform.python_version())

6 changes: 3 additions & 3 deletions src/nanotron/serialize/utils.py
@@ -37,18 +37,18 @@ def get_path(
  suffix = tensor_name.split(".")
  suffix_path, suffix_name = suffix[:-1], suffix[-1]

- suffix_name = f"{type.value}_{suffix_name}.safetensors"
-
  if exp_tp_pp_rank_and_size:
      # We always show pp_rank and tp_rank if `exp_tp_pp_rank_and_size` is provided
-     # We only show exp_rank if tensor is exp_sharded and exp_size > 1
      (exp_rank, exp_size), (tp_rank, tp_size), (pp_rank, pp_size) = exp_tp_pp_rank_and_size
      if not is_expert_sharded or exp_size == 1:
          suffix_name = (
              f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}.safetensors"
          )
      else:
+         # We only show exp_rank if tensor is exp_sharded and exp_size > 1
          suffix_name = f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}_exp-rank-{exp_rank}-of-{exp_size}.safetensors"
+ else:
+     suffix_name = f"{type.value}_{suffix_name}.safetensors"

  suffix_path.append(suffix_name)
  if prefix is None:
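To summarize the fixed naming rule in `get_path`, here is a condensed editorial sketch (not nanotron's actual function; `exp_tp_pp_rank_and_size` packs the `(exp, tp, pp)` rank/size pairs as in the diff):

```python
from typing import Optional, Tuple

RankAndSize = Tuple[int, int]


def shard_name(
    type_value: str,
    suffix_name: str,
    exp_tp_pp_rank_and_size: Optional[Tuple[RankAndSize, RankAndSize, RankAndSize]],
    is_expert_sharded: bool,
) -> str:
    # Mirrors the fixed branch structure: the ".safetensors" extension is
    # appended exactly once, whichever branch is taken.
    if exp_tp_pp_rank_and_size is not None:
        (exp_rank, exp_size), (tp_rank, tp_size), (pp_rank, pp_size) = exp_tp_pp_rank_and_size
        name = f"{type_value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}"
        if is_expert_sharded and exp_size > 1:
            name += f"_exp-rank-{exp_rank}-of-{exp_size}"
        return f"{name}.safetensors"
    return f"{type_value}_{suffix_name}.safetensors"


# Matches the corrected example from the fix script's docstring:
assert shard_name("model", "weight", ((0, 1), (0, 2), (0, 1)), False) == (
    "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"
)
```

The key property is that `.safetensors` is appended exactly once per branch, which removes both the duplicated `model_` prefix and the duplicated extension that the old code produced.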
6 changes: 5 additions & 1 deletion src/nanotron/serialize/weights.py
@@ -244,6 +244,7 @@ def load_weights(
              exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from(
                  world_rank=get_global_rank(group=group, group_rank=group_rank), parallel_context=parallel_context
              )
+             # TODO @nouamane: do we consider exp_size=1 expert_sharded?
              is_expert_sharded = sharded_info.is_expert_sharded(parallel_context)
          else:
              exp_tp_pp_rank_and_size = None
@@ -280,7 +281,10 @@
          suffix = base_name.rsplit(".", 1)[-1]
          shards_path = list(path.parent.glob(f"{ObjectType.MODEL.value}_{suffix}*.safetensors"))
          if len(shards_path) <= 0:
-             raise ValueError(f"Could not find any shards in {path.parent}")
+             raise ValueError(
+                 f"Could not find any shards {ObjectType.MODEL.value}_{suffix}*.safetensors in {path.parent}. "
+                 f"If you notice `.safetensors` in the middle of the names of some checkpoint files, you need to run `scripts/fix_checkpoint_bad_naming.py`."
+             )

          if checkpoint_version is None:
              checkpoint_version = get_checkpoint_version(
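For context, the shard discovery that this error message guards can be read as the following standalone sketch (editorial; `load_weights` itself carries more state than shown here):

```python
from pathlib import Path
from typing import List


def find_shards(path: Path, suffix: str, type_value: str = "model") -> List[Path]:
    # All shards of a tensor share the "{type}_{suffix}" prefix and end in
    # ".safetensors"; the rank information sits between the two, so a bad
    # checkpoint with ".safetensors" in the middle of the name never matches.
    shards = list(path.parent.glob(f"{type_value}_{suffix}*.safetensors"))
    if not shards:
        raise ValueError(
            f"Could not find any shards {type_value}_{suffix}*.safetensors in {path.parent}. "
            "If you notice `.safetensors` in the middle of some checkpoint file names, "
            "run `scripts/fix_checkpoint_bad_naming.py`."
        )
    return shards
```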
