Validate loading model and optim state (#18)
* validate loading model and optim state

* Add another test for checkpointing with dtensors

* update docstring
epwalsh authored May 14, 2024
1 parent c726658 · commit fb20119
Showing 2 changed files with 34 additions and 3 deletions.
20 changes: 17 additions & 3 deletions src/olmo_core/distributed/checkpoint.py
@@ -132,6 +132,7 @@ def load_model_and_optim_state(
     dir: PathOrStr,
     model: nn.Module,
     optim: Optional[torch.optim.Optimizer] = None,
+    validate: bool = True,
 ):
     """
     Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
@@ -150,13 +151,18 @@ def load_model_and_optim_state(
     :param dir: Path/URL to the checkpoint saved via :func:`save_model_and_optim_state()`.
     :param model: The model to load the state into.
     :param optim: The optimizer to load the state into.
+    :param validate: Validate that all tensors have been loaded completely from the checkpoint by
+        pre-filling each tensor with NaNs prior to loading in-place, then checking afterwards
+        that there are no NaNs remaining.
     """
     dir = str(dir).rstrip("/")
     checkpointer = Checkpointer()

-    # Get model state in-place.
+    # Get model state and load in-place.
     model_state = _get_model_state_dict_for_checkpoint(model)
-    checkpointer.load(f"{dir}/model", model_state)
+    if validate:
+        _fill_state_dict_with_nan(model_state)
+    checkpointer.load(f"{dir}/model", model_state, _check_for_nans=validate)
     _load_model_state_dict(model, model_state)

     if optim is not None:
@@ -173,9 +179,11 @@
         for i in range(len(optim.param_groups)):
             flat_optim_state[f"param_group{i}"] = metadata.tensors[f"param_group{i}"].materialize_empty()
         flat_optim_state["state_keys"] = metadata.tensors["state_keys"].materialize_empty()
+        if validate:
+            _fill_state_dict_with_nan(flat_optim_state)

         # Load flattened optimizer state in place.
-        checkpointer.load(f"{dir}/optim", flat_optim_state, metadata=metadata)
+        checkpointer.load(f"{dir}/optim", flat_optim_state, metadata=metadata, _check_for_nans=validate)

         # Unflatten optimizer state and pass to optimizer.
         optim_state_to_load = _unflatten_optimizer_state(flat_optim_state)
@@ -1337,3 +1345,9 @@ def _wrap_tensor_for_sharded_parameter(tensor: torch.Tensor, param: Optional[tor
         return _wrap_tensor_for_sharded_parameter(tensor, param.data)
     else:
         return tensor
+
+
+def _fill_state_dict_with_nan(state_dict: Dict[str, torch.Tensor]):
+    for tensor in state_dict.values():
+        if tensor.dtype.is_floating_point:
+            _get_local_tensor_data(tensor).fill_(torch.nan)
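
For reference, here is a minimal, hypothetical usage sketch of the new `validate` flag (not part of this commit). It assumes `save_model_and_optim_state` takes the same `(dir, model, optim)` arguments as the loader, that a plain local directory path is acceptable, and that a single-process run (or an already-initialized process group) suffices; only the `load_model_and_optim_state` signature above is taken from the diff.

import torch
import torch.nn as nn

from olmo_core.distributed.checkpoint import (
    load_model_and_optim_state,
    save_model_and_optim_state,
)

# Build a tiny model/optimizer and give the optimizer some state to checkpoint.
model = nn.Linear(8, 8)
optim = torch.optim.AdamW(model.parameters())
model(torch.randn(2, 8)).sum().backward()
optim.step()

# Hypothetical checkpoint directory; the loader above reads from "{dir}/model" and "{dir}/optim".
ckpt_dir = "/tmp/example_ckpt"
save_model_and_optim_state(ckpt_dir, model, optim)

# Load into fresh objects. With validate=True (the new default), every destination tensor is
# pre-filled with NaNs and checked afterwards, so a tensor the checkpoint never wrote would be
# flagged instead of silently keeping its initial values.
new_model = nn.Linear(8, 8)
new_optim = torch.optim.AdamW(new_model.parameters())
load_model_and_optim_state(ckpt_dir, new_model, new_optim, validate=True)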
17 changes: 17 additions & 0 deletions src/test/distributed/checkpoint_test.py
@@ -21,6 +21,7 @@
     SafeTensorsLoader,
     TensorShardSpec,
     _flatten_optimizer_state,
+    _get_local_tensor_data,
     _get_model_state_dict_for_checkpoint,
     _offsets_overlap,
     _unflatten_optimizer_state,
@@ -250,6 +251,22 @@ def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir)
     assert full_state_dict["y"].shape == (2, 3)


+def run_get_local_tensor_data_with_dtensor():
+    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
+    dtensor = distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)])
+
+    # Make sure modifying the data returned from `_get_local_tensor_data` will modify the data
+    # in the actual tensor.
+    _get_local_tensor_data(dtensor).fill_(torch.nan)
+    assert _get_local_tensor_data(dtensor).isnan().all()
+    assert dtensor.full_tensor().isnan().all()
+
+
+@requires_multi_gpu
+def test_get_local_tensor_data_with_dtensor():
+    run_distributed_test(run_get_local_tensor_data_with_dtensor, backend="nccl")
+
+
 @pytest.mark.parametrize("backend", BACKENDS)
 def test_save_and_load_checkpoint_with_regular_and_sharded_tensors(backend, tmp_path):
     run_distributed_test(
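
To make the validation idea concrete outside the library, here is a small self-contained sketch in plain PyTorch (no olmo_core imports; `fill_with_nan` and the example tensors are made up for illustration). It mirrors what `_fill_state_dict_with_nan` plus the `_check_for_nans` check accomplish: any float tensor still containing NaN after loading was never written, and non-floating-point tensors are skipped because they cannot represent NaN.

from typing import Dict

import torch


def fill_with_nan(state_dict: Dict[str, torch.Tensor]) -> None:
    # Pre-fill only floating-point tensors; integer tensors cannot hold NaN.
    for tensor in state_dict.values():
        if tensor.dtype.is_floating_point:
            tensor.fill_(torch.nan)


dest = {
    "w": torch.empty(4),
    "b": torch.empty(4),
    "step": torch.zeros(1, dtype=torch.long),
}
fill_with_nan(dest)

# Simulate a buggy loader that only restores "w".
dest["w"].copy_(torch.ones(4))

# Any float tensor still containing NaN was never written by the loader.
not_restored = [k for k, t in dest.items() if t.dtype.is_floating_point and t.isnan().any()]
assert not_restored == ["b"]
print("tensors not restored:", not_restored)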
