diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index 263558fc2b..c17b874c21 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -30,6 +30,7 @@ is_model_deepspeed, partial_format, ) +from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME from composer.utils.compression import get_compressor, is_compressed_pt from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY @@ -37,8 +38,6 @@ __all__ = ['CheckpointSaver'] -_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME = '.metadata' - class CheckpointSaver(Callback): # noqa: D101 __doc__ = f"""Callback to save checkpoints. diff --git a/composer/checkpoint/save.py b/composer/checkpoint/save.py index 72e5311d0f..03166d8802 100644 --- a/composer/checkpoint/save.py +++ b/composer/checkpoint/save.py @@ -3,12 +3,15 @@ """Useful functions for saving state dicts to disk.""" +import json import logging import os +import pickle import textwrap import warnings +from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Sequence, Union import torch import torch.distributed.checkpoint as DCP @@ -16,6 +19,275 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._tensor import DTensor +from composer.checkpoint.state_dict import ( + get_metadata_state_dict, + get_model_state_dict, + get_optim_state_dict, + get_resumption_state_dict, +) +from composer.core import State, Time +from composer.devices import Device +from composer.models import ComposerModel +from composer.utils import dist +from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file +from composer.utils.file_helpers import format_name_with_dist_and_time + +log = logging.getLogger(__name__) + +MODEL_CHECKPOINT_DIRECTORY_NAME = 'model' +MONOLITHIC_MODEL_CHECKPOINT_FILENAME = 'model.pt' +OPTIM_CHECKPOINT_DIRECTORY_NAME = 'optim' +OPTIM_MONO_CHECKPOINT_FILENAME = 'optim.pt' +METADATA_CHECKPOINT_FILENAME = 'composer_metadata.json' +RESUMPTION_CHECKPOINT_FILENAME = 'resumption.pkl' + + +@dataclass +class CheckpointSaveOptions: + """Options for saving a checkpoint to disk. + + Args: + destination_dir (str): The directory to save the checkpoint to. + save_frequency (Union[str, int, Time]): The frequency to save the checkpoint. + If '1ep', the checkpoint will be saved after each epoch. + If '1ba', the checkpoint will be saved after each batch. + If an int, the checkpoint will be saved after that many epochs. + dir_prefix (str): The prefix to use for the directory name. Can include {epoch} and {batch}. + overwrite (bool): Whether to overwrite the checkpoint if it already exists. + save_model (bool): Whether to save the model. + save_optimizer (bool): Whether to save the optimizer. + save_resumption_state (bool): Whether to save the resumption state. + num_checkpoints_to_keep (int): The number of checkpoints to keep. + If -1, all checkpoints will be kept. + save_format (str): The format to save the model in. 'pt', which is the standard PyTorch serialization format, is the only option for now. + sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint. + precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'. + include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model. 
+ ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model. + """ + destination_dir: str + save_frequency: Union[str, int, Time] = '1ep' + dir_prefix: str = 'ep{epoch}-ba{batch}' + overwrite: bool = False + save_model: bool = True + save_optimizer: bool = True + save_resumption_state: bool = True + num_checkpoints_to_keep: int = -1 + save_format: str = 'pt' + sharded_checkpoint: bool = False + precision: str = 'bf16' + include_keys: Optional[Union[str, Sequence[str]]] = None + ignore_keys: Optional[Union[str, Sequence[str]]] = None + + +def save_checkpoint_to_disk( + state: State, + options: Optional[Union[CheckpointSaveOptions, Dict]] = None, + destination_dir: Optional[str] = None, +): + """Saves a checkpoint to disk. + + Args: + state (State): The state to save. + options (Optional[Union[CheckpointSaveOptions, Dict]]): The options for saving the checkpoint. + If None, destination_dir must be provided. + destination_dir (Optional[str]): The directory to save the checkpoint to. + If provided, this will override options.destination_dir. + """ + if options is None: + if destination_dir is None: + raise ValueError('destination_dir must be provided if options is None') + options = CheckpointSaveOptions(destination_dir=destination_dir) + else: + if isinstance(options, Dict): + options = CheckpointSaveOptions(**options) + if destination_dir is not None: + options.destination_dir = destination_dir + save_path = os.path.join(options.destination_dir, options.dir_prefix) + save_path = format_name_with_dist_and_time(save_path, state.run_name, state.timestamp) + os.makedirs(save_path, exist_ok=True) + if options.save_model: + save_model_to_disk( + state.model, + save_path, + options.sharded_checkpoint, + options.precision, + options.include_keys, + options.ignore_keys, + options.overwrite, + options.save_format, + ) + if options.save_optimizer: + optimizer = state.optimizers[0] + save_optim_to_disk( + state.model, + optimizer, + save_path, + options.sharded_checkpoint, + options.precision, + options.overwrite, + options.save_format, + ) + if options.save_resumption_state: + save_resumption_state_to_disk(state, save_path) + + save_composer_metadata_to_disk( + save_path, + state.model, + options.sharded_checkpoint, + options.precision, + state.device, + state.device_train_microbatch_size, + ) + + +def save_model_to_disk( + model: Union[ComposerModel, torch.nn.Module], + destination_dir: str, + sharded_checkpoint: bool = False, + precision: str = 'fp32', + include_keys: Optional[Union[str, Sequence[str]]] = None, + ignore_keys: Optional[Union[str, Sequence[str]]] = None, + overwrite: bool = False, + save_format: str = 'pt', # or hf, safetensor +) -> Optional[str]: + """Saves a model to disk. + + Args: + model (Union[ComposerModel, torch.nn.Module]): The model to save. + destination_dir (str): The directory to save the model to. + Model will be saved as destination_dir/model/model.pt if sharded_checkpoint is False, + otherwise all shards will be saved as destination_dir/model/___0.distcp. + sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint. + precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'. + include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model. + ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model. + overwrite (bool): If True, the file will be overwritten if it exists. + save_format (str): The format to save the model in. 
One of 'pt', 'hf', or 'safetensor'. + + Returns: + str: The full path to the saved model. + """ + if save_format != 'pt': + raise NotImplementedError( + f"Saving checkpoint in format {save_format} is not supported. Please choose from ['pt'].", + ) + model_state_dict = get_model_state_dict( + model, + sharded_checkpoint, + precision, + include_keys, + ignore_keys, + ) + + destination_file_path = ( + os.path.join(destination_dir, MODEL_CHECKPOINT_DIRECTORY_NAME) if sharded_checkpoint else + os.path.join(destination_dir, MODEL_CHECKPOINT_DIRECTORY_NAME, MONOLITHIC_MODEL_CHECKPOINT_FILENAME) + ) + saved_path = save_state_dict_to_disk( + state_dict=model_state_dict, + destination_file_path=destination_file_path, + overwrite=overwrite, + save_format=save_format, + ) + return saved_path + + +def save_optim_to_disk( + model: Union[ComposerModel, torch.nn.Module], + optimizer: torch.optim.Optimizer, + destination_dir: str, + sharded_checkpoint: bool = False, + precision: str = 'fp32', + overwrite: bool = False, + save_format: str = 'pt', +) -> Optional[str]: + """Saves an optimizer to disk. + + Args: + model (Union[ComposerModel, torch.nn.Module]): The model to save. + optimizer (torch.optim.Optimizer): The optimizer to save. + destination_dir (str): The directory to save the optimizer to. + Optimizer will be saved as destination_dir/optim/optim.pt if sharded_checkpoint is False, + otherwise all shards will be saved as destination_dir/optim/___0.distcp. + sharded_checkpoint (bool): Whether to save the optimizer as a sharded checkpoint. + precision (str): The precision to save the optimizer in. One of 'bf16', 'fp32', 'fp16', 'fp64'. + overwrite (bool): If True, the file will be overwritten if it exists. + save_format (str): The format to save the optimizer in. One of 'pt'. + """ + optim_state_dict = get_optim_state_dict( + model, + optimizer, + sharded_state_dict=sharded_checkpoint, + precision=precision, + ) + destination_file_path = os.path.join(destination_dir, + OPTIM_CHECKPOINT_DIRECTORY_NAME) if sharded_checkpoint else os.path.join( + destination_dir, + OPTIM_CHECKPOINT_DIRECTORY_NAME, + OPTIM_MONO_CHECKPOINT_FILENAME, + ) + saved_path = save_state_dict_to_disk( + state_dict=optim_state_dict, + destination_file_path=destination_file_path, + overwrite=overwrite, + save_format=save_format, + ) + + return saved_path + + +def save_composer_metadata_to_disk( + destination_dir: str, + model: Optional[Union[ComposerModel, torch.nn.Module]] = None, + sharded_state_dict: Optional[bool] = None, + precision: Optional[Union[str, torch.dtype]] = None, + device: Optional[Device] = None, + device_train_microbatch_size: Optional[Union[int, float]] = None, +): + """Saves metadata about the model to disk. + + Args: + destination_dir (str): The directory to save the metadata to. + model (Optional[Union[ComposerModel, torch.nn.Module]]): The model to save metadata about. + sharded_state_dict (Optional[bool]): Whether the model is sharded. + precision (Optional[Union[str, torch.dtype]]): The precision of the model. + device (Optional[Device]): The device the model is on. + device_train_microbatch_size (Optional[Union[int, float]]): The device train microbatch size. 
+ """ + md_dict = get_metadata_state_dict( + model, + sharded_state_dict, + precision, + device, + device_train_microbatch_size, + ) + os.makedirs(destination_dir, exist_ok=True) + destination_file_path = os.path.join(destination_dir, METADATA_CHECKPOINT_FILENAME) + + if dist.get_global_rank() == 0: + with open(destination_file_path, 'w') as f: + json.dump(md_dict, f, indent=4) + return destination_file_path + + +def save_resumption_state_to_disk( + state: State, + destination_dir: str, +): + """Saves the resumption state to disk. + + Args: + state (State): The state to save. + destination_dir (str): The directory to save the resumption state to. + """ + resumption_state_dict = get_resumption_state_dict(state) + destination_file_path = os.path.join(destination_dir, RESUMPTION_CHECKPOINT_FILENAME) + with open(destination_file_path, 'wb') as f: + pickle.dump(resumption_state_dict, f) + return destination_file_path + + from composer.utils import dist from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file @@ -80,6 +352,8 @@ def _save_sharded_state_dict_to_disk( ) destination_file_path = stripped_path + # Wait for all ranks to get here before checking if the directory exists. + dist.barrier() if dist.get_global_rank() == 0 and not overwrite and os.path.exists(destination_file_path): raise ValueError(f'Directory {destination_file_path} already exists. Set overwrite=True to overwrite it.') @@ -94,6 +368,9 @@ def _save_sharded_state_dict_to_disk( else: DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path)) + log.debug( + f'Finished saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}', + ) return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME @@ -106,13 +383,14 @@ def _save_full_state_dict_to_disk( if save_format != 'pt': raise NotImplementedError( - f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].", + f"Saving full state dict to disk in format {save_format} is not supported. Please choose from ['pt'].", ) if not overwrite and os.path.exists(destination_file_path): raise ValueError(f'File {destination_file_path} already exists. Set overwrite=True to overwrite it.') if dist.get_global_rank() == 0: + os.makedirs(os.path.dirname(destination_file_path), exist_ok=True) _write_checkpoint_file(state_dict=state_dict, filename=destination_file_path) return destination_file_path return None @@ -130,7 +408,7 @@ def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool: for value in state_dict.values(): if isinstance(value, ShardedTensor) or isinstance(value, DTensor): return True - if isinstance(value, Dict): + elif isinstance(value, Dict): is_sharded = is_state_dict_sharded(value) if is_sharded: return True diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index a20baaf165..5f82836d7b 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -380,7 +380,7 @@ def get_metadata_state_dict( sharded_state_dict: Optional[bool] = None, precision: Optional[Union[str, torch.dtype]] = None, device: Optional[Device] = None, - device_train_microbatch_size: Optional[int] = None, + device_train_microbatch_size: Optional[Union[int, float]] = None, ) -> dict[str, Any]: """Generate the metadata and integrations for a training run. 
diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index f2342eeb4c..f9ad516724 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -53,6 +53,7 @@ _COMPOSER_STATES_FILENAME = 'composer_states.pt' _DEEPSPEED_TAG = 'deepspeed' # always tag with the same, deterministic name. We'll rename the tarball to the appropriate name. _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME = f'__{dist.get_global_rank()}_0.distcp' +_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME = '.metadata' def _get_checkpoint_validation_function( diff --git a/tests/checkpoint/helpers.py b/tests/checkpoint/helpers.py index 047d30e813..4915c3a150 100644 --- a/tests/checkpoint/helpers.py +++ b/tests/checkpoint/helpers.py @@ -1,24 +1,85 @@ # Copyright 2024 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict +from typing import Any, Dict, Tuple, Union +from unittest.mock import MagicMock import torch +from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.api import CPUOffload from torch.optim import adam - +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import DataLoader + +from composer.algorithms import SWA +from composer.callbacks import SpeedMonitor +from composer.core import State +from composer.devices import Device, DeviceCPU, DeviceGPU +from composer.models import ComposerModel from tests.common.models import EvenSimplerMLP, SimpleComposerMLP __all__ = [ 'init_model_and_optimizer', 'init_model', 'init_optimizer', + 'init_state', ] +def init_state( + use_fsdp: bool = False, + device: str = 'cpu', + include_schedulers=False, + include_callbacks=False, + include_algorithms=False, + use_grad_scaler=False, + rank_zero_seed=10, + run_name='test_run', + take_step=False, +) -> State: + model, optimizer = init_model_and_optimizer( + use_fsdp=use_fsdp, + use_composer_model=True, + take_step=take_step, + device=device, + ) + + test_dataset_sd = {'test': 0} + device_obj: Device = DeviceCPU() if device == 'cpu' else DeviceGPU() + + dataloader = MagicMock(spec=DataLoader) + dataloader.dataset = MagicMock() + dataloader.dataset.state_dict = MagicMock(return_value=test_dataset_sd) + kwargs = {} + + if include_callbacks: + kwargs['callbacks'] = [SpeedMonitor(), SpeedMonitor()] + if include_algorithms: + kwargs['algorithms'] = [SWA()] + if use_grad_scaler: + if version.parse(torch.__version__) >= version.parse('2.3.0'): + from torch.amp.grad_scaler import GradScaler + else: + from torch.cuda.amp.grad_scaler import GradScaler + kwargs['scaler'] = GradScaler() + + state = State( + model=model, + rank_zero_seed=rank_zero_seed, + run_name=run_name, + device=device_obj, + train_dataloader=dataloader, + optimizers=[optimizer], + **kwargs, + ) + if include_schedulers: + state.schedulers = StepLR(optimizer=optimizer, step_size=2) + return state + + def init_model_and_optimizer( - use_composer_model: bool, + use_composer_model: bool = True, num_classes=3, batch_size=5, num_features=8, @@ -26,7 +87,7 @@ def init_model_and_optimizer( use_fsdp=False, tensor_type='sharded_tensor', device='cuda', -): +) -> Tuple[Union[ComposerModel, torch.nn.Module], torch.optim.Optimizer]: model, loss_fn = init_model( use_composer_model, num_classes=num_classes, @@ -59,7 +120,7 @@ def init_model( tensor_type='sharded_tensor', sync_module_states=True, cpu_offload=False, -): +) -> Tuple[Union[ComposerModel, torch.nn.Module], Any]: if use_composer_model: model = 
SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device=device) loss_fn = model._loss_fn diff --git a/tests/checkpoint/test_save.py b/tests/checkpoint/test_save.py index 03b12bbcbc..f4d41cc09d 100644 --- a/tests/checkpoint/test_save.py +++ b/tests/checkpoint/test_save.py @@ -1,6 +1,7 @@ # Copyright 2024 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import json import os import time import uuid @@ -12,15 +13,157 @@ import torch.distributed.checkpoint as DCP from packaging import version -from composer.checkpoint.save import save_state_dict_to_disk -from composer.checkpoint.state_dict import get_model_state_dict +from composer.checkpoint.save import ( + save_checkpoint_to_disk, + save_composer_metadata_to_disk, + save_model_to_disk, + save_optim_to_disk, + save_state_dict_to_disk, +) +from composer.checkpoint.state_dict import get_model_state_dict, get_optim_state_dict +from composer.core import Timestamp from composer.utils import dist -from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME -from tests.checkpoint.helpers import init_model +from composer.utils.checkpoint import ( + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, + _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME, +) +from tests.checkpoint.helpers import init_model, init_model_and_optimizer, init_state from tests.common.compare import deep_compare from tests.common.markers import world_size +@pytest.mark.gpu +@pytest.mark.parametrize( + 'world_size,sharded_model,sharded_checkpoint', + [ + pytest.param(1, False, False, marks=pytest.mark.world_size(1)), + pytest.param(2, True, True, marks=pytest.mark.world_size(2)), + pytest.param(2, True, False, marks=pytest.mark.world_size(2)), + ], +) +@pytest.mark.filterwarnings('ignore::UserWarning') +def test_save_checkpoint_to_disk(world_size: int, tmp_path: str, sharded_model: bool, sharded_checkpoint: bool): + destination_dir = os.path.join(tmp_path, str(uuid.uuid4())[:8]) + destination_dir = dist.all_gather_object(destination_dir)[0] + save_options = { + 'destination_dir': destination_dir, + 'save_model': True, + 'save_optimizer': True, + 'save_resumption_state': True, + 'sharded_checkpoint': sharded_checkpoint, + 'dir_prefix': 'ep{epoch}-ba{batch}', + } + state = init_state(use_fsdp=sharded_model, device='cuda', take_step=True) + state.run_name = 'foo' + state.timestamp = Timestamp() + expected_destination_dir = os.path.join(destination_dir, 'ep0-ba0') + save_checkpoint_to_disk(state, save_options) + expected_model_dir = os.path.join(expected_destination_dir, 'model') + expected_optim_dir = os.path.join(expected_destination_dir, 'optim') + expected_metadata_filepath = os.path.join(expected_destination_dir, 'composer_metadata.json') + expected_resumption_filepath = os.path.join(expected_destination_dir, 'resumption.pkl') + if sharded_checkpoint: + checkpoint_filenames = dist.all_gather_object(_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME) + for checkpoint_filename in checkpoint_filenames: + assert os.path.exists(os.path.join(expected_model_dir, checkpoint_filename)) + assert os.path.exists(os.path.join(expected_optim_dir, checkpoint_filename)) + assert os.path.exists(os.path.join(expected_model_dir, _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME)) + assert os.path.exists(os.path.join(expected_optim_dir, _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME)) + else: + assert os.path.exists(os.path.join(expected_model_dir, 'model.pt')) + assert os.path.exists(os.path.join(expected_optim_dir, 'optim.pt')) + + import time + + # Need to wait 
for the file to be written to avoid flaky test. + time.sleep(0.2) + assert os.path.exists(expected_metadata_filepath) + assert os.path.exists(expected_resumption_filepath) + + +def test_save_composer_metadata_to_disk(tmp_path: str): + destination_dir = os.path.join(tmp_path, str(uuid.uuid4())[:8]) + destination_dir = dist.all_gather_object(destination_dir)[0] + save_composer_metadata_to_disk(destination_dir) + expected_file_path = os.path.join(destination_dir, 'composer_metadata.json') + assert os.path.exists(expected_file_path) + json.load(open(expected_file_path, 'r')) + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'world_size,sharded_optimizer,sharded_checkpoint', + [ + pytest.param(1, False, False, marks=pytest.mark.world_size(1)), + pytest.param(2, True, True, marks=pytest.mark.world_size(2)), + pytest.param(2, True, False, marks=pytest.mark.world_size(2)), + ], +) +def test_save_optim_to_disk(world_size: int, tmp_path: str, sharded_optimizer: bool, sharded_checkpoint: bool): + destination_dir = os.path.join(tmp_path, str(uuid.uuid4())[:8]) + # Sync the path across all ranks + destination_dir = dist.all_gather_object(destination_dir)[0] + use_fsdp = sharded_optimizer + model, optim = init_model_and_optimizer(use_fsdp=use_fsdp, device='cuda') + optim_state_dict = get_optim_state_dict(model, optimizer=optim, sharded_state_dict=sharded_checkpoint) + optim_state_dict_saved = deepcopy(optim_state_dict) + save_optim_to_disk(model, optim, destination_dir=destination_dir, sharded_checkpoint=sharded_checkpoint) + + # Load new optim from disk + model, optim = init_model_and_optimizer(use_fsdp=use_fsdp, device='cuda') + cur_state_dict = get_optim_state_dict(model, optimizer=optim, sharded_state_dict=sharded_checkpoint) + + if sharded_checkpoint: + expected_file_path = os.path.join(destination_dir, 'optim') + if version.parse(torch.__version__) < version.parse('2.2.0'): + DCP.load_state_dict(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path)) + else: + DCP.load(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path)) + else: + if dist.get_global_rank() == 0: + expected_file_path = os.path.join(destination_dir, 'optim', 'optim.pt') + cur_state_dict = torch.load(expected_file_path, map_location='cuda') + + deep_compare(optim_state_dict_saved, cur_state_dict) + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'world_size,sharded_model,sharded_checkpoint', + [ + pytest.param(1, False, False, marks=pytest.mark.world_size(1)), + pytest.param(2, True, True, marks=pytest.mark.world_size(2)), + pytest.param(2, True, False, marks=pytest.mark.world_size(2)), + ], +) +def test_save_model_to_disk(world_size: int, tmp_path: str, sharded_model: bool, sharded_checkpoint: bool): + destination_dir = os.path.join(tmp_path, str(uuid.uuid4())[:8]) + # Sync the path across all ranks + destination_dir = dist.all_gather_object(destination_dir)[0] + use_fsdp = sharded_model + model, _ = init_model(use_fsdp=use_fsdp, device='cuda', sync_module_states=True) + state_dict = get_model_state_dict(model, sharded_state_dict=sharded_checkpoint) + state_dict_saved = deepcopy(state_dict) + save_model_to_disk(model, destination_dir=destination_dir, sharded_checkpoint=sharded_checkpoint) + + # Load new model from disk + new_model, _ = init_model(use_fsdp=use_fsdp, device='cuda', sync_module_states=True) + cur_state_dict = get_model_state_dict(new_model, sharded_state_dict=sharded_checkpoint) + + if sharded_checkpoint: + expected_file_path = os.path.join(destination_dir, 
'model') + if version.parse(torch.__version__) < version.parse('2.2.0'): + DCP.load_state_dict(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path)) + else: + DCP.load(state_dict=cur_state_dict, storage_reader=DCP.FileSystemReader(expected_file_path)) + else: + if dist.get_global_rank() == 0: + expected_file_path = os.path.join(destination_dir, 'model', 'model.pt') + cur_state_dict = torch.load(expected_file_path, map_location='cuda') + + deep_compare(state_dict_saved, cur_state_dict) + + @world_size(1, 2) @pytest.mark.gpu @pytest.mark.parametrize('sharded_model', [False, True]) diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 4f719254a7..12fde27249 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -3,27 +3,21 @@ import datetime from typing import Any -from unittest.mock import MagicMock import pytest import torch from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.optim.lr_scheduler import StepLR -from torch.utils.data import DataLoader -from composer.algorithms import SWA -from composer.callbacks import SpeedMonitor from composer.checkpoint import ( get_metadata_state_dict, get_model_state_dict, get_optim_state_dict, get_resumption_state_dict, ) -from composer.core import State -from composer.devices import DeviceCPU, DeviceGPU +from composer.devices import DeviceGPU from composer.utils import dist, reproducibility -from tests.checkpoint.helpers import init_model_and_optimizer +from tests.checkpoint.helpers import init_model_and_optimizer, init_state from tests.common.compare import deep_compare from tests.common.markers import world_size from tests.common.models import EvenSimplerMLP, SimpleComposerMLP, configure_tiny_gpt2_hf_model @@ -444,27 +438,17 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz @pytest.mark.filterwarnings('ignore:SWA has') def test_get_resumption_state_dict(): - - model, optimizer = init_model_and_optimizer(use_composer_model=True, take_step=True, device='cpu') - - rank_zero_seed = 10 run_name = 'test_run' - device = DeviceCPU() - test_dataset_sd = {'foo': 0} - dataloader = MagicMock(spec=DataLoader) - dataloader.dataset = MagicMock() - dataloader.dataset.state_dict = MagicMock(return_value=test_dataset_sd) - swa = SWA() - state = State( - model=model, + rank_zero_seed = 10 + state = init_state( + device='cpu', + include_algorithms=True, + include_callbacks=True, + include_schedulers=True, rank_zero_seed=rank_zero_seed, run_name=run_name, - device=device, - train_dataloader=dataloader, - algorithms=[swa], - callbacks=[SpeedMonitor(), SpeedMonitor()], ) - state.schedulers = StepLR(optimizer=optimizer, step_size=2) + test_dataset_sd = {'test': 0} rsd = get_resumption_state_dict(state) assert rsd['rank_zero_seed'] == rank_zero_seed @@ -505,27 +489,7 @@ def test_get_resumption_state_dict(): @pytest.mark.gpu def test_get_resumption_state_dict_gpu(): - if version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import GradScaler - else: - from torch.cuda.amp.grad_scaler import GradScaler - - model, _ = init_model_and_optimizer(use_composer_model=True, take_step=False, device='cuda') - - rank_zero_seed = 10 - run_name = 'test_run' - device = DeviceCPU() - test_dataset_sd = {'test': 0} - dataloader = MagicMock() - dataloader.dataset = MagicMock() - dataloader.dataset.state_dict = MagicMock(return_value=test_dataset_sd) - state = State( 
- model=model, - rank_zero_seed=rank_zero_seed, - run_name=run_name, - device=device, - scaler=GradScaler(), - ) + state = init_state(device='cuda', use_grad_scaler=True) rsd = get_resumption_state_dict(state) assert 'scaler' in rsd assert set(