diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py
index 823d73b5..e1af96dc 100644
--- a/src/test/distributed/checkpoint_test.py
+++ b/src/test/distributed/checkpoint_test.py
@@ -822,7 +822,7 @@ def test_save_and_load_torch_fsdp_model(
     )
 
 
-def run_save_and_load_tensor_parallel_model(dir):
+def run_save_and_load_tensor_parallel_model(dir, take_step_before_checkpoint):
     tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))
 
     class FeedForward(nn.Module):
@@ -853,17 +853,22 @@ def forward(self, x):
     # Take a forward and backward pass.
     feed_forward(torch.rand((2, feed_forward.dim), device="cuda")).sum().backward()
 
-    # Take an optimizer step.
+    if take_step_before_checkpoint:
+        # Take an optimizer step.
+        optim.step()
 
     # Save checkpoint.
     save_model_and_optim_state(dir, feed_forward, optim)
 
 
 @requires_multi_gpu
-def test_save_and_load_tensor_parallel_model(tmp_path):
+@pytest.mark.parametrize(
+    "take_step_before_checkpoint", [pytest.param(True, id="after-step"), pytest.param(False, id="pre-step")]
+)
+def test_save_and_load_tensor_parallel_model(tmp_path, take_step_before_checkpoint):
     run_distributed_test(
         run_save_and_load_tensor_parallel_model,
         backend="nccl",
         start_method="spawn",
-        func_args=(tmp_path,),
+        func_args=(tmp_path, take_step_before_checkpoint),
     )
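
For reference, a minimal single-process sketch of the parametrization pattern used in the diff: `pytest.mark.parametrize` with `pytest.param(..., id=...)` gives each case a readable test id (`after-step` / `pre-step`), and the boolean flag decides whether an optimizer step is taken before the checkpoint is written. The toy `nn.Linear` model, `test_save_and_load_linear_model` name, and plain `torch.save`/`torch.load` checkpointing are assumptions for illustration only, not the `FeedForward`/`run_distributed_test`/`save_model_and_optim_state` harness from the actual test.

```python
# Standalone sketch (assumed names); demonstrates the parametrized
# "checkpoint before vs. after an optimizer step" pattern without GPUs.
import pytest
import torch
import torch.nn as nn


@pytest.mark.parametrize(
    "take_step_before_checkpoint",
    [pytest.param(True, id="after-step"), pytest.param(False, id="pre-step")],
)
def test_save_and_load_linear_model(tmp_path, take_step_before_checkpoint):
    model = nn.Linear(8, 8)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Take a forward and backward pass so gradients (and, after a step,
    # optimizer state) exist when the checkpoint is written.
    model(torch.rand(2, 8)).sum().backward()
    if take_step_before_checkpoint:
        optim.step()

    # Save a checkpoint, then restore it into fresh instances.
    ckpt = tmp_path / "checkpoint.pt"
    torch.save({"model": model.state_dict(), "optim": optim.state_dict()}, ckpt)

    restored_model = nn.Linear(8, 8)
    restored_optim = torch.optim.AdamW(restored_model.parameters(), lr=1e-3)
    state = torch.load(ckpt)
    restored_model.load_state_dict(state["model"])
    restored_optim.load_state_dict(state["optim"])

    # Restored parameters should match the originals in both cases.
    for original, restored in zip(model.parameters(), restored_model.parameters()):
        torch.testing.assert_close(original, restored)
```

Running `pytest -k tensor_parallel` against the real test would then report the two variants separately as `...[after-step]` and `...[pre-step]`, which is the point of giving each `pytest.param` an explicit id.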