Add more options to the unshard_checkpoint function to help scale (#145)

I was getting worried about unsharding really big checkpoints, like for
the 32B, which we'll need to do soon. The main issue at the moment is
that in order to unshard we need to load the entire model (or optimizer
state) into memory, which clearly isn't scalable. So I've added an
option to unshard the checkpoint into chunks of a given size, which
helps scale because only a single chunk (which could be as small as a
single tensor) needs to be loaded into memory at a time. Each chunk
gets written to a unique file. I think HuggingFace does something
similar.
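
For example, a minimal sketch of chunked unsharding with the new option
(the checkpoint path, target directory, and chunk size below are
hypothetical):

```python
from olmo_core.distributed.checkpoint import UnshardStrategy, unshard_checkpoint

# Hypothetical paths; any checkpoint saved via save_model_and_optim_state() works here.
model_path, _ = unshard_checkpoint(
    "s3://my-bucket/checkpoints/step1000",  # sharded checkpoint (hypothetical)
    "/tmp/unsharded",                       # local target directory (hypothetical)
    unshard_strategy=UnshardStrategy.chunks(2 * 1024**3),  # ~2 GiB per output file
)
# With the chunks strategy, model_path is a directory of files like chunk-00000.pt.
```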

Note: chunked unsharding is not supported for optimizer state yet. But,
speaking of optimizer state, this PR also adds a function called
`load_keys()` for loading (and unsharding) specific keys from a
checkpoint. So if you want to inspect part of the optimizer state, you
can use that function without having to unshard the whole optimizer
state.
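
A quick sketch of `load_keys()` for that kind of inspection (the
checkpoint path and key names below are hypothetical; use
`get_checkpoint_metadata()` to see which keys a checkpoint actually
contains):

```python
from olmo_core.distributed.checkpoint import load_keys

# Hypothetical checkpoint path and optimizer state keys; load_keys() yields the
# unsharded objects in the same order as the requested keys.
step, exp_avg = load_keys(
    "s3://my-bucket/checkpoints/step1000",
    ["optim.state.step", "optim.state.embeddings.weight.exp_avg"],
)
```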
epwalsh authored Jan 24, 2025
1 parent 16885ab commit b4a195b
Showing 5 changed files with 305 additions and 61 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `SkipStepAdamW` optimizer.
- The trainer can load model-only checkpoints now.
- Added the option to throttle checkpoint uploads to one rank from each node at a time.
- Added `unshard_strategy` parameter to the `unshard_checkpoint()` function in `olmo_core.distributed.checkpoint`.
- Added function `load_keys()` to `olmo_core.distributed.checkpoint`.

### Changed

267 changes: 223 additions & 44 deletions src/olmo_core/distributed/checkpoint/__init__.py
@@ -27,20 +27,23 @@

import logging
from concurrent.futures import Future
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.state_dict as dist_cp_sd
import torch.nn as nn
from rich.progress import track
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.metadata import Metadata, TensorStorageMetadata

from olmo_core.aliases import PathOrStr
from olmo_core.config import StrEnum
from olmo_core.io import clear_directory, dir_is_empty, is_url, normalize_path
from olmo_core.utils import gc_cuda, wait_for
from olmo_core.utils import gc_cuda, get_element_size, wait_for

from ..utils import barrier, get_fs_local_rank, is_distributed
from .filesystem import RemoteFileSystemReader, RemoteFileSystemWriter
@@ -51,7 +54,10 @@
"async_save_model_and_optim_state",
"load_model_and_optim_state",
"unshard_checkpoint",
"load_keys",
"get_checkpoint_metadata",
"UnshardStrategy",
"UnshardStrategyType",
]

log = logging.getLogger(__name__)
@@ -237,9 +243,9 @@ def load_model_and_optim_state(
reader = RemoteFileSystemReader(
dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
)
metadata = reader.read_metadata()

if key_mapping is not None:
metadata = reader.read_metadata()
for current_key, original_key in key_mapping.items():
if f"model.{original_key}" not in metadata.state_dict_metadata:
continue
Expand Down Expand Up @@ -267,7 +273,6 @@ def load_model_and_optim_state(
)

if key_mapping is not None:
metadata = reader.read_metadata()
for current_key, original_key in key_mapping.items():
if f"model.{original_key}" not in metadata.state_dict_metadata:
continue
@@ -298,15 +303,86 @@ def load_model_and_optim_state(
gc_cuda()


class UnshardStrategyType(StrEnum):
"""
An enumeration of the unsharding strategies that can be used with :func:`unshard_checkpoint`.
"""

one_file = "one_file"
"""
Save the unsharded model state into a single file, and optionally the optimizer state into
another file. The bigger the model, the more memory this requires. For very big models,
:data:`one_file_per_tensor` will scale better.
"""

one_file_per_tensor = "one_file_per_tensor"
"""
Save each unsharded tensor to its own file. Currently this is not compatible with optimizer
state.
"""

chunks = "chunks"
"""
Like :data:`one_file_per_tensor` but multiple tensors and objects may be grouped into the same file
up to the limit defined by :data:`UnshardStrategy.chunk_size_bytes`.
"""


@dataclass
class UnshardStrategy:
"""
Unsharding strategy config for :func:`unshard_checkpoint`.
"""

name: UnshardStrategyType = UnshardStrategyType.one_file
"""
The strategy type.
"""

chunk_size_bytes: Optional[int] = None
"""
The approximate max chunk size (i.e. per-file size), in bytes, for the :data:`UnshardStrategyType.chunks` strategy.
"""

def __post_init__(self):
if self.name == UnshardStrategyType.chunks and self.chunk_size_bytes is None:
raise ValueError("'chunk_size_bytes' is required for the 'chunks' strategy")
if self.chunk_size_bytes is not None and self.name != UnshardStrategyType.chunks:
raise ValueError("'chunk_size_bytes' is only valid for the 'chunks' strategy")

@classmethod
def one_file(cls) -> "UnshardStrategy":
"""
Use the :data:`UnshardStrategy.one_file` strategy.
"""
return cls(name=UnshardStrategyType.one_file)

@classmethod
def one_file_per_tensor(cls) -> "UnshardStrategy":
"""
Use the :data:`UnshardStrategy.one_file_per_tensor` strategy.
"""
return cls(name=UnshardStrategyType.one_file_per_tensor)

@classmethod
def chunks(cls, chunk_size_in_bytes: int) -> "UnshardStrategy":
"""
Use the :data:`UnshardStrategy.chunks` strategy.
"""
return cls(name=UnshardStrategyType.chunks, chunk_size_bytes=chunk_size_in_bytes)


def unshard_checkpoint(
dir: PathOrStr,
target_dir: PathOrStr,
*,
optim: Optional[bool] = None,
save_overwrite: bool = False,
use_safetensors: bool = False,
unshard_strategy: Optional[UnshardStrategy] = None,
pre_download: bool = False,
work_dir: Optional[PathOrStr] = None,
quiet: bool = False,
) -> Tuple[Path, Optional[Path]]:
"""
Convert a checkpoint saved via :func:`save_model_and_optim_state()` into unsharded
@@ -321,6 +397,9 @@ def unshard_checkpoint(
.. warning::
This should only be called in a non-distributed context. Otherwise a :class:`RuntimeError` is raised.
.. seealso::
:func:`load_keys()` if you only need to load and unshard certain keys in the checkpoint.
:param dir: The path/URL to the original checkpoint created via :func:`save_model_and_optim_state()`.
:param target_dir: The directory to save the unsharded model/optimizer checkpoint files to.
This must be a local directory. URLs are not supported.
@@ -329,82 +408,162 @@ def unshard_checkpoint(
:param save_overwrite: Overwrite any existing files in ``target_dir``.
:param use_safetensors: Save the unsharded files with :func:`safetensors.torch.save_file()` instead
of :func:`torch.save()`.
:param unshard_strategy: The strategy to use. Defaults to :meth:`UnshardStrategy.one_file`.
:param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
:param work_dir: A working directory for caching files/directories.
:param quiet: Do not show progress messages.
:return: The path to the unsharded model checkpoint and the path to the unsharded
optimizer checkpoint if ``optim=True``.
optimizer checkpoint if ``optim=True``. These paths may represent files or directories
depending on the ``unshard_strategy``.
:raises FileExistsError: If the ``target_dir`` is non-empty and ``save_overwrite=False``.
"""
# Adapted from `torch.distributed.checkpoint.format_utils.dcp_to_torch_save()`.

from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
if unshard_strategy is None:
unshard_strategy = UnshardStrategy.one_file()

if optim is None:
optim = not use_safetensors
optim = (not use_safetensors) and (unshard_strategy.name == UnshardStrategyType.one_file)
elif optim and use_safetensors:
raise NotImplementedError("`optim=True` is incompatible with `use_safetensors=True`")
elif optim and unshard_strategy.name != UnshardStrategyType.one_file:
raise NotImplementedError(
f"`optim=True` is incompatible with `unshard_strategy={unshard_strategy}`"
)

if is_distributed():
raise RuntimeError("'unshard_checkpoint' cannot be called in a distributed context")

dir = normalize_path(dir)

if is_url(target_dir):
raise ValueError("'target_dir' must be a local directory")
target_dir = Path(normalize_path(target_dir))
target_dir.mkdir(exist_ok=True, parents=True)

ext = "pt" if not use_safetensors else "safetensors"
metadata = get_checkpoint_metadata(dir)

def save(state_dict: Dict[str, Any], path: Path):
if path.is_file() and not save_overwrite:
raise FileExistsError(
f"'{path}' already exists, use `save_overwrite=True` to overwrite it"
)

path.parent.mkdir(parents=True, exist_ok=True)

if use_safetensors:
from safetensors.torch import save_file

save_file(state_dict, path)
else:
torch.save(state_dict, path)

dir = normalize_path(dir)
if is_url(target_dir):
raise ValueError("'target_dir' must be a local directory")
target_dir = Path(normalize_path(target_dir))
target_dir.mkdir(exist_ok=True, parents=True)

ext = "pt" if not use_safetensors else "safetensors"
model_path = target_dir / f"model.{ext}"
optim_path = target_dir / f"optim.{ext}" if optim else None

model_sd: Dict[str, Any] = {}
_load_state_dict(
model_sd,
storage_reader=RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir),
planner=_EmptyStateDictLoadPlanner(keys=["model"]),
no_dist=True,
)
if not model_sd:
raise RuntimeError("no model state found in checkpoint")
save(model_sd["model"], model_path)
del model_sd
gc_cuda()
def get_chunks(prefix: str) -> Tuple[Path, List[Tuple[Path, List[str]]]]:
assert unshard_strategy is not None
assert isinstance(target_dir, Path)

if unshard_strategy.name == UnshardStrategyType.one_file:
path = target_dir / f"{prefix}.{ext}"
return path, [(path, [prefix])]
elif unshard_strategy.name == UnshardStrategyType.one_file_per_tensor:
path = target_dir / prefix
chunks = []
for key in metadata.state_dict_metadata.keys():
if key.startswith(f"{prefix}."):
chunks.append((path / f"{key.replace('.', '-')}.{ext}", [key]))
return path, chunks
elif unshard_strategy.name == UnshardStrategyType.chunks:
assert unshard_strategy.chunk_size_bytes is not None
path = target_dir / prefix
chunks = []
current_size = 0
current_keys: List[str] = []
for key, meta in metadata.state_dict_metadata.items():
if key.startswith(f"{prefix}."):
if isinstance(meta, TensorStorageMetadata):
size = meta.size.numel() * get_element_size(meta.properties.dtype)
if current_keys and current_size + size > unshard_strategy.chunk_size_bytes:
chunks.append((path / f"chunk-{len(chunks):05d}.{ext}", current_keys))
current_size = 0
current_keys = []
current_size += size
current_keys.append(key)
else:
# This is a pickled Python object, which is probably pretty small,
# so we don't worry about recording the size.
current_keys.append(key)
if current_keys:
chunks.append((path / f"chunk-{len(chunks):05d}.{ext}", current_keys))
return path, chunks
else:
raise NotImplementedError(unshard_strategy.name)

if optim_path is not None:
optim_sd: Dict[str, Any] = {}
_load_state_dict(
optim_sd,
storage_reader=RemoteFileSystemReader(
dir, pre_download=pre_download, work_dir=work_dir
),
planner=_EmptyStateDictLoadPlanner(keys=["optim"]),
no_dist=True,
def unshard_chunk(prefix: str, path: Path, keys: List[str]):
state_dict: Dict[str, Any] = _load_unsharded_keys(
dir, keys, pre_download=pre_download, work_dir=work_dir
)
if not optim_sd:
raise RuntimeError("no optimizer state found in checkpoint")
save(optim_sd["optim"], optim_path)
del optim_sd
if not state_dict:
raise RuntimeError(f"missing keys '{keys}' in checkpoint")

save(state_dict[prefix], path)
del state_dict
gc_cuda()

model_path, model_chunks = get_chunks("model")
for chunk_path, chunk_keys in track(
model_chunks, description="Unsharding model chunks...", disable=quiet
):
unshard_chunk("model", chunk_path, chunk_keys)

optim_path: Optional[Path] = None
if optim:
optim_path, optim_chunks = get_chunks("optim")
for chunk_path, chunk_keys in track(
optim_chunks, description="Unsharding optim chunks...", disable=quiet
):
unshard_chunk("optim", chunk_path, chunk_keys)

return model_path, optim_path


def load_keys(
dir: PathOrStr,
keys: Iterable[str],
*,
pre_download: bool = False,
work_dir: Optional[PathOrStr] = None,
) -> Generator[Any, None, None]:
"""
Load specific keys from a checkpoint.
.. warning::
This should only be called in a non-distributed context. Otherwise a :class:`RuntimeError` is raised.
:param dir: The path/URL to the original checkpoint created via :func:`save_model_and_optim_state()`,
:func:`save_state_dict`, or one of the other functions in this module.
:param keys: The keys to load.
:param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
:param work_dir: A working directory for caching files/directories.
:returns: The (unsharded) objects from the checkpoint corresponding to the given keys, in the
same order as the keys.
"""
if is_distributed():
raise RuntimeError("'load_keys' cannot be called in a distributed context")

dir = normalize_path(dir)
# validate checkpoint.
get_checkpoint_metadata(dir)

keys = list(keys)
state_dict = _load_unsharded_keys(dir, keys, pre_download=pre_download, work_dir=work_dir)
for key in keys:
yield _get_key(state_dict, key, pop=True)


def get_checkpoint_metadata(dir: PathOrStr) -> Metadata:
"""
Load the metadata from a checkpoint.
@@ -460,6 +619,26 @@ def _prepare_state_dict(
return state_dict


def _load_unsharded_keys(
dir: PathOrStr,
keys: List[str],
*,
pre_download: bool = False,
work_dir: Optional[PathOrStr] = None,
) -> Dict[str, Any]:
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

state_dict: Dict[str, Any] = {}
_load_state_dict(
state_dict,
storage_reader=RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir),
planner=_EmptyStateDictLoadPlanner(keys=keys),
no_dist=True,
)
return state_dict


def _get_key(state_dict: Dict[str, Any], key: str, pop: bool = False) -> Any:
if key in state_dict:
if pop: