Add function to load specific keys
epwalsh committed Jan 23, 2025
1 parent 94d756b commit 2050369
Showing 2 changed files with 50 additions and 1 deletion.
44 changes: 43 additions & 1 deletion src/olmo_core/distributed/checkpoint/__init__.py
@@ -29,7 +29,7 @@
from concurrent.futures import Future
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple

import torch
import torch.distributed as dist
@@ -53,6 +53,7 @@
"async_save_model_and_optim_state",
"load_model_and_optim_state",
"unshard_checkpoint",
"load_keys",
"get_checkpoint_metadata",
"UnshardStrategy",
"UnshardStrategyType",
@@ -527,6 +528,47 @@ def unshard_chunk(prefix: str, path: Path, keys: List[str]):
    return model_path, optim_path


def load_keys(
    dir: PathOrStr,
    keys: Iterable[str],
    *,
    pre_download: bool = False,
    work_dir: Optional[PathOrStr] = None,
) -> Generator[Any, None, None]:
    """
    Load specific keys from a checkpoint.

    .. warning::
        This should only be called in a non-distributed context. Otherwise a :class:`RuntimeError` is raised.

    :param dir: The path/URL to the original checkpoint created via :func:`save_model_and_optim_state()`,
        :func:`save_state_dict`, or one of the other functions in this module.
    :param keys: The keys to load.
    :param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
    :param work_dir: A working directory for caching files/directories.

    :returns: The (unsharded) objects from the checkpoint corresponding to the given keys.
    """
    from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

    if is_distributed():
        raise RuntimeError("'load_keys' cannot be called in a distributed context")

    dir = normalize_path(dir)

    keys = list(keys)
    state_dict: Dict[str, Any] = {}
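    # Read only the requested keys into the empty state dict; 'no_dist=True' lets
    # this run in a single process without an initialized process group.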
    _load_state_dict(
        state_dict,
        storage_reader=RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir),
        planner=_EmptyStateDictLoadPlanner(keys=keys),
        no_dist=True,
    )
    for key in keys:
        yield _get_key(state_dict, key, pop=True)


def get_checkpoint_metadata(dir: PathOrStr) -> Metadata:
"""
Load the metadata from a checkpoint.
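
A minimal usage sketch for the new load_keys function (not part of the diff; the checkpoint paths below are placeholders, and the key names follow the "model."-prefixed flattening used in the test):

from olmo_core.distributed.checkpoint import load_keys

# Hypothetical checkpoint directory written by save_model_and_optim_state.
ckpt_dir = "/path/to/checkpoint"
w1_weight, w2_bias = load_keys(ckpt_dir, ["model.w1.weight", "model.w2.bias"])
print(w1_weight.shape, w2_bias.shape)

# For a remote checkpoint, files can be cached locally before reading:
tensors = list(
    load_keys("s3://bucket/checkpoint", ["model.w1.weight"], pre_download=True, work_dir="/tmp/ckpt_cache")
)
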
7 changes: 7 additions & 0 deletions src/test/distributed/checkpoint_test.py
@@ -14,6 +14,7 @@
from olmo_core.distributed.checkpoint import (
    UnshardStrategy,
    async_save_model_and_optim_state,
    load_keys,
    load_model_and_optim_state,
    save_model_and_optim_state,
    save_state_dict,
@@ -312,6 +313,12 @@ def test_unshard_checkpoint(backend, tmp_path):
        combined_model_state.update(safetensors.torch.load_file(path))
    torch.testing.assert_close(combined_model_state, model_state_st)

    # Try loading specific keys.
    tensors = list(load_keys(sharded_checkpoint_dir, ["model.w1.weight", "model.w2.bias"]))
    assert len(tensors) == 2
    assert tensors[0].shape == (16, 16)
    assert tensors[1].shape == (16,)


def run_load_checkpoint_with_missing_keys(dir):
    class FF1(nn.Module):
