Add more options to the unshard_checkpoint function to help scale #145
Conversation
chunks = []
for chunk_num, key in enumerate(metadata.state_dict_metadata.keys()):
    if key.startswith(f"{prefix}."):
        chunks.append((path / f"chunk-{chunk_num:05d}.{ext}", [key]))
thought: it could be nice to name it as chunk-{key}, maybe with some cleanup of the key. Then one could easily examine the weights of specific parameters. Would possibly be slightly cleaner for hypothetical future converter logic too.
done: bc79ab5
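For illustration, a rough sketch of the kind of key cleanup that naming scheme implies (a hypothetical helper, not the code from bc79ab5):

import re

def chunk_filename_for_key(key: str, ext: str) -> str:
    # Replace characters that are awkward in filenames so each chunk file is
    # named after the parameter it holds, e.g.
    # "model.blocks.0.attention.w_q.weight" -> "chunk-model_blocks_0_attention_w_q_weight.safetensors"
    cleaned = re.sub(r"[^A-Za-z0-9_-]", "_", key)
    return f"chunk-{cleaned}.{ext}"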
I was getting worried about unsharding really big checkpoints, like for the 32B, which we'll need to do soon. The main issue at the moment is that in order to unshard we need to load the entire model (or optimizer state) into memory, which clearly isn't scalable. So I've added an option to unshard the checkpoint into chunks of a given size, which helps scale because only a single chunk (which could be as small as a single tensor) needs to be loaded into memory at a time. Each chunk gets written to a unique file. I think HuggingFace does something similar.
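A hedged sketch of how the new option might be used; the import path, argument names, and the chunk-size parameter shown here are assumptions, not necessarily the final API:

# Usage sketch only -- argument names are assumptions, not the final API.
from olmo_core.distributed.checkpoint import unshard_checkpoint

unshard_checkpoint(
    dir="/path/to/sharded/checkpoint",
    target_dir="/path/to/unsharded/output",
    optim=False,                    # chunked unsharding of optimizer state isn't supported yet
    chunk_size_bytes=2 * 1024**3,   # assumed name: write ~2 GB chunks, one in memory at a time
)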
Note: this is not supported for optimizer state yet. But, speaking of optimizer state, this PR also adds a function called load_keys() for loading (and unsharding) specific keys from a checkpoint. So if you want to inspect part of the optimizer state, you could use that function without having to unshard the whole optimizer state.
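For example (the signature, return type, and optimizer-state key layout shown here are all assumptions):

# Usage sketch only -- the exact signature and key names are assumptions.
from olmo_core.distributed.checkpoint import load_keys

# Load (and unshard) just one piece of optimizer state instead of the whole thing.
keys = ["optim.state.model.blocks.0.attention.w_q.weight.exp_avg_sq"]
for key, tensor in zip(keys, load_keys("/path/to/sharded/checkpoint", keys)):
    print(key, tuple(tensor.shape), tensor.dtype)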