Save checkpoint to disk for API with new save layout (#3399)
eracah authored and mvpatel2000 committed Jul 21, 2024
1 parent 7a4644a commit 94f1ec1
Showing 7 changed files with 507 additions and 61 deletions.
3 changes: 1 addition & 2 deletions composer/callbacks/checkpoint_saver.py
@@ -30,15 +30,14 @@
is_model_deepspeed,
partial_format,
)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME
from composer.utils.compression import get_compressor, is_compressed_pt
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

log = logging.getLogger(__name__)

__all__ = ['CheckpointSaver']

_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME = '.metadata'


class CheckpointSaver(Callback): # noqa: D101
__doc__ = f"""Callback to save checkpoints.
284 changes: 281 additions & 3 deletions composer/checkpoint/save.py
@@ -3,19 +3,291 @@

"""Useful functions for saving state dicts to disk."""

import json
import logging
import os
import pickle
import textwrap
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union

import torch
import torch.distributed.checkpoint as DCP
from packaging import version
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor

from composer.checkpoint.state_dict import (
get_metadata_state_dict,
get_model_state_dict,
get_optim_state_dict,
get_resumption_state_dict,
)
from composer.core import State, Time
from composer.devices import Device
from composer.models import ComposerModel
from composer.utils import dist
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file
from composer.utils.file_helpers import format_name_with_dist_and_time

log = logging.getLogger(__name__)

MODEL_CHECKPOINT_DIRECTORY_NAME = 'model'
MONOLITHIC_MODEL_CHECKPOINT_FILENAME = 'model.pt'
OPTIM_CHECKPOINT_DIRECTORY_NAME = 'optim'
OPTIM_MONO_CHECKPOINT_FILENAME = 'optim.pt'
METADATA_CHECKPOINT_FILENAME = 'composer_metadata.json'
RESUMPTION_CHECKPOINT_FILENAME = 'resumption.pkl'
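
For orientation, the constants above combine with the default dir_prefix ('ep{epoch}-ba{batch}') used by save_checkpoint_to_disk below into the on-disk layout sketched here; this illustration is an editorial aside, not part of the commit.

```python
# Sketch of the layout implied by the constants above for a checkpoint
# saved at epoch 1, batch 500 (assuming the default dir_prefix):
#
#   <destination_dir>/ep1-ba500/
#       composer_metadata.json      # METADATA_CHECKPOINT_FILENAME
#       resumption.pkl              # RESUMPTION_CHECKPOINT_FILENAME
#       model/                      # MODEL_CHECKPOINT_DIRECTORY_NAME
#           model.pt                #   monolithic, or __<rank>_0.distcp shards
#       optim/                      # OPTIM_CHECKPOINT_DIRECTORY_NAME
#           optim.pt                #   monolithic, or __<rank>_0.distcp shards
```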


@dataclass
class CheckpointSaveOptions:
"""Options for saving a checkpoint to disk.
Args:
destination_dir (str): The directory to save the checkpoint to.
save_frequency (Union[str, int, Time]): The frequency to save the checkpoint.
If '1ep', the checkpoint will be saved after each epoch.
If '1ba', the checkpoint will be saved after each batch.
If an int, the checkpoint will be saved after that many epochs.
dir_prefix (str): The prefix to use for the directory name. Can include {epoch} and {batch}.
overwrite (bool): Whether to overwrite the checkpoint if it already exists.
save_model (bool): Whether to save the model.
save_optimizer (bool): Whether to save the optimizer.
save_resumption_state (bool): Whether to save the resumption state.
num_checkpoints_to_keep (int): The number of checkpoints to keep.
If -1, all checkpoints will be kept.
save_format (str): The format to save the model in. 'pt', the standard PyTorch serialization format, is the only option for now.
sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint.
precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model.
ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model.
"""
destination_dir: str
save_frequency: Union[str, int, Time] = '1ep'
dir_prefix: str = 'ep{epoch}-ba{batch}'
overwrite: bool = False
save_model: bool = True
save_optimizer: bool = True
save_resumption_state: bool = True
num_checkpoints_to_keep: int = -1
save_format: str = 'pt'
sharded_checkpoint: bool = False
precision: str = 'bf16'
include_keys: Optional[Union[str, Sequence[str]]] = None
ignore_keys: Optional[Union[str, Sequence[str]]] = None
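
As a quick orientation (an editorial sketch, not part of this diff), the options object can be built directly or from a plain dict; the paths and values below are hypothetical.

```python
# Minimal sketch of constructing save options; paths and values are hypothetical.
from composer.checkpoint.save import CheckpointSaveOptions

options = CheckpointSaveOptions(
    destination_dir='checkpoints',  # hypothetical output directory
    save_frequency='1ep',           # save once per epoch
    precision='fp32',
    num_checkpoints_to_keep=3,
)

# Equivalent dict form, accepted by save_checkpoint_to_disk below and expanded
# via CheckpointSaveOptions(**options_dict); destination_dir is still required.
options_dict = {'destination_dir': 'checkpoints', 'precision': 'fp32'}
```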


def save_checkpoint_to_disk(
state: State,
options: Optional[Union[CheckpointSaveOptions, Dict]] = None,
destination_dir: Optional[str] = None,
):
"""Saves a checkpoint to disk.
Args:
state (State): The state to save.
options (Optional[Union[CheckpointSaveOptions, Dict]]): The options for saving the checkpoint.
If None, destination_dir must be provided.
destination_dir (Optional[str]): The directory to save the checkpoint to.
If provided together with options, this value overrides options.destination_dir.
"""
if options is None:
if destination_dir is None:
raise ValueError('destination_dir must be provided if options is None')
options = CheckpointSaveOptions(destination_dir=destination_dir)
else:
if isinstance(options, Dict):
options = CheckpointSaveOptions(**options)
if destination_dir is not None:
options.destination_dir = destination_dir
save_path = os.path.join(options.destination_dir, options.dir_prefix)
save_path = format_name_with_dist_and_time(save_path, state.run_name, state.timestamp)
os.makedirs(save_path, exist_ok=True)
if options.save_model:
save_model_to_disk(
state.model,
save_path,
options.sharded_checkpoint,
options.precision,
options.include_keys,
options.ignore_keys,
options.overwrite,
options.save_format,
)
if options.save_optimizer:
optimizer = state.optimizers[0]
save_optim_to_disk(
state.model,
optimizer,
save_path,
options.sharded_checkpoint,
options.precision,
options.overwrite,
options.save_format,
)
if options.save_resumption_state:
save_resumption_state_to_disk(state, save_path)

save_composer_metadata_to_disk(
save_path,
state.model,
options.sharded_checkpoint,
options.precision,
state.device,
state.device_train_microbatch_size,
)
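
A minimal usage sketch (editorial, not part of this commit), assuming a Composer Trainer has already been constructed so that trainer.state carries the model, optimizers, and timestamp:

```python
# Hypothetical call; `trainer` is assumed to be an existing composer.Trainer.
from composer.checkpoint.save import save_checkpoint_to_disk

save_checkpoint_to_disk(
    state=trainer.state,
    options={
        'destination_dir': 'checkpoints',  # required when options is a dict
        'save_optimizer': True,
        'precision': 'bf16',
    },
)
# Files land under checkpoints/ep{epoch}-ba{batch}/ per the default dir_prefix.
```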


def save_model_to_disk(
model: Union[ComposerModel, torch.nn.Module],
destination_dir: str,
sharded_checkpoint: bool = False,
precision: str = 'fp32',
include_keys: Optional[Union[str, Sequence[str]]] = None,
ignore_keys: Optional[Union[str, Sequence[str]]] = None,
overwrite: bool = False,
save_format: str = 'pt', # or hf, safetensor
) -> Optional[str]:
"""Saves a model to disk.
Args:
model (Union[ComposerModel, torch.nn.Module]): The model to save.
destination_dir (str): The directory to save the model to.
The model will be saved to destination_dir/model/model.pt if sharded_checkpoint is False,
otherwise all shards will be saved to destination_dir/model/__<rank>_0.distcp.
sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint.
precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model.
ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model.
overwrite (bool): If True, the file will be overwritten if it exists.
save_format (str): The format to save the model in. One of 'pt', 'hf', or 'safetensor'.
Returns:
Optional[str]: The full path to the saved model if this rank wrote a checkpoint file, otherwise None.
"""
if save_format != 'pt':
raise NotImplementedError(
f"Saving checkpoint in format {save_format} is not supported. Please choose from ['pt'].",
)
model_state_dict = get_model_state_dict(
model,
sharded_checkpoint,
precision,
include_keys,
ignore_keys,
)

destination_file_path = (
os.path.join(destination_dir, MODEL_CHECKPOINT_DIRECTORY_NAME) if sharded_checkpoint else
os.path.join(destination_dir, MODEL_CHECKPOINT_DIRECTORY_NAME, MONOLITHIC_MODEL_CHECKPOINT_FILENAME)
)
saved_path = save_state_dict_to_disk(
state_dict=model_state_dict,
destination_file_path=destination_file_path,
overwrite=overwrite,
save_format=save_format,
)
return saved_path
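
For a plain torch module, a monolithic save looks roughly like this (a sketch assuming a single-process run; the hypothetical directory name mimics the default layout):

```python
# Hypothetical single-process example; the directory name is illustrative.
import torch

from composer.checkpoint.save import save_model_to_disk

model = torch.nn.Linear(8, 8)
saved_path = save_model_to_disk(model, 'checkpoints/ep1-ba500', precision='fp32')
print(saved_path)  # checkpoints/ep1-ba500/model/model.pt on the writing rank
```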


def save_optim_to_disk(
model: Union[ComposerModel, torch.nn.Module],
optimizer: torch.optim.Optimizer,
destination_dir: str,
sharded_checkpoint: bool = False,
precision: str = 'fp32',
overwrite: bool = False,
save_format: str = 'pt',
) -> Optional[str]:
"""Saves an optimizer to disk.
Args:
model (Union[ComposerModel, torch.nn.Module]): The model to save.
optimizer (torch.optim.Optimizer): The optimizer to save.
destination_dir (str): The directory to save the optimizer to.
Optimizer will be saved as destination_dir/optim/optim.pt if sharded_checkpoint is False,
otherwise all shards will be saved as destination_dir/optim/__<rank>_0.distcp.
sharded_checkpoint (bool): Whether to save the optimizer as a sharded checkpoint.
precision (str): The precision to save the optimizer in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
overwrite (bool): If True, the file will be overwritten if it exists.
save_format (str): The format to save the optimizer in. One of 'pt'.
"""
optim_state_dict = get_optim_state_dict(
model,
optimizer,
sharded_state_dict=sharded_checkpoint,
precision=precision,
)
destination_file_path = (
os.path.join(destination_dir, OPTIM_CHECKPOINT_DIRECTORY_NAME) if sharded_checkpoint else
os.path.join(destination_dir, OPTIM_CHECKPOINT_DIRECTORY_NAME, OPTIM_MONO_CHECKPOINT_FILENAME)
)
saved_path = save_state_dict_to_disk(
state_dict=optim_state_dict,
destination_file_path=destination_file_path,
overwrite=overwrite,
save_format=save_format,
)

return saved_path
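
A companion sketch for the optimizer (hypothetical objects; the monolithic file lands at destination_dir/optim/optim.pt):

```python
# Hypothetical example mirroring the model sketch above.
import torch

from composer.checkpoint.save import save_optim_to_disk

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
save_optim_to_disk(model, optimizer, 'checkpoints/ep1-ba500')
```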


def save_composer_metadata_to_disk(
destination_dir: str,
model: Optional[Union[ComposerModel, torch.nn.Module]] = None,
sharded_state_dict: Optional[bool] = None,
precision: Optional[Union[str, torch.dtype]] = None,
device: Optional[Device] = None,
device_train_microbatch_size: Optional[Union[int, float]] = None,
):
"""Saves metadata about the model to disk.
Args:
destination_dir (str): The directory to save the metadata to.
model (Optional[Union[ComposerModel, torch.nn.Module]]): The model to save metadata about.
sharded_state_dict (Optional[bool]): Whether the model is sharded.
precision (Optional[Union[str, torch.dtype]]): The precision of the model.
device (Optional[Device]): The device the model is on.
device_train_microbatch_size (Optional[Union[int, float]]): The device train microbatch size.
"""
md_dict = get_metadata_state_dict(
model,
sharded_state_dict,
precision,
device,
device_train_microbatch_size,
)
os.makedirs(destination_dir, exist_ok=True)
destination_file_path = os.path.join(destination_dir, METADATA_CHECKPOINT_FILENAME)

if dist.get_global_rank() == 0:
with open(destination_file_path, 'w') as f:
json.dump(md_dict, f, indent=4)
return destination_file_path


def save_resumption_state_to_disk(
state: State,
destination_dir: str,
):
"""Saves the resumption state to disk.
Args:
state (State): The state to save.
destination_dir (str): The directory to save the resumption state to.
"""
resumption_state_dict = get_resumption_state_dict(state)
destination_file_path = os.path.join(destination_dir, RESUMPTION_CHECKPOINT_FILENAME)
with open(destination_file_path, 'wb') as f:
pickle.dump(resumption_state_dict, f)
return destination_file_path
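
Since the resumption state is written with pickle, reading it back is symmetric (a sketch assuming a checkpoint directory produced by the functions above):

```python
# Hypothetical read-back of the pickled resumption state.
import os
import pickle

ckpt_dir = 'checkpoints/ep1-ba500'  # hypothetical checkpoint directory
with open(os.path.join(ckpt_dir, 'resumption.pkl'), 'rb') as f:
    resumption_state = pickle.load(f)
```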


from composer.utils import dist
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file

@@ -80,6 +352,8 @@ def _save_sharded_state_dict_to_disk(
)
destination_file_path = stripped_path

# Wait for all ranks to get here before checking if the directory exists.
dist.barrier()
if dist.get_global_rank() == 0 and not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'Directory {destination_file_path} already exists. Set overwrite=True to overwrite it.')

@@ -94,6 +368,9 @@
else:
DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))

log.debug(
f'Finished saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
)
return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME


@@ -106,13 +383,14 @@ def _save_full_state_dict_to_disk(

if save_format != 'pt':
raise NotImplementedError(
f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
f"Saving full state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
)

if not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'File {destination_file_path} already exists. Set overwrite=True to overwrite it.')

if dist.get_global_rank() == 0:
os.makedirs(os.path.dirname(destination_file_path), exist_ok=True)
_write_checkpoint_file(state_dict=state_dict, filename=destination_file_path)
return destination_file_path
return None
@@ -130,7 +408,7 @@ def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool:
for value in state_dict.values():
if isinstance(value, ShardedTensor) or isinstance(value, DTensor):
return True
if isinstance(value, Dict):
elif isinstance(value, Dict):
is_sharded = is_state_dict_sharded(value)
if is_sharded:
return True
2 changes: 1 addition & 1 deletion composer/checkpoint/state_dict.py
@@ -380,7 +380,7 @@ def get_metadata_state_dict(
sharded_state_dict: Optional[bool] = None,
precision: Optional[Union[str, torch.dtype]] = None,
device: Optional[Device] = None,
device_train_microbatch_size: Optional[int] = None,
device_train_microbatch_size: Optional[Union[int, float]] = None,
) -> dict[str, Any]:
"""Generate the metadata and integrations for a training run.
1 change: 1 addition & 0 deletions composer/utils/checkpoint.py
@@ -53,6 +53,7 @@
_COMPOSER_STATES_FILENAME = 'composer_states.pt'
_DEEPSPEED_TAG = 'deepspeed' # always tag with the same, deterministic name. We'll rename the tarball to the appropriate name.
_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME = f'__{dist.get_global_rank()}_0.distcp'
_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME = '.metadata'


def _get_checkpoint_validation_function(