
Commit

parametrize test
epwalsh committed May 10, 2024
1 parent 9e278f3 commit 0ab1db5
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/test/distributed/checkpoint_test.py
@@ -822,7 +822,7 @@ def test_save_and_load_torch_fsdp_model(
     )


-def run_save_and_load_tensor_parallel_model(dir):
+def run_save_and_load_tensor_parallel_model(dir, take_step_before_checkpoint):
     tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

     class FeedForward(nn.Module):
@@ -853,17 +853,22 @@ def forward(self, x):
     # Take a forward and backward pass.
     feed_forward(torch.rand((2, feed_forward.dim), device="cuda")).sum().backward()

-    # Take an optimizer step.
+    if take_step_before_checkpoint:
+        # Take an optimizer step.
+        optim.step()

     # Save checkpoint.
     save_model_and_optim_state(dir, feed_forward, optim)


 @requires_multi_gpu
-def test_save_and_load_tensor_parallel_model(tmp_path):
+@pytest.mark.parametrize(
+    "take_step_before_checkpoint", [pytest.param(True, id="after-step"), pytest.param(False, id="pre-step")]
+)
+def test_save_and_load_tensor_parallel_model(tmp_path, take_step_before_checkpoint):
     run_distributed_test(
         run_save_and_load_tensor_parallel_model,
         backend="nccl",
         start_method="spawn",
-        func_args=(tmp_path,),
+        func_args=(tmp_path, take_step_before_checkpoint),
     )
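Note: the parametrization pattern above is standard pytest. `pytest.param(..., id=...)` gives each generated case a readable name in test reports, so a failure shows up as test_save_and_load_tensor_parallel_model[pre-step] rather than an opaque [False]. A minimal, self-contained sketch of the same pattern (the test name and body here are illustrative, not from this repo):

import pytest

# pytest collects two cases from this one function:
#   test_checkpoint[after-step] and test_checkpoint[pre-step]
@pytest.mark.parametrize(
    "take_step_before_checkpoint",
    [pytest.param(True, id="after-step"), pytest.param(False, id="pre-step")],
)
def test_checkpoint(take_step_before_checkpoint):
    # The flag toggles the behavior under test; here we just check the type.
    assert isinstance(take_step_before_checkpoint, bool)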

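The `run_distributed_test` helper itself is not part of this diff. As a rough, hypothetical sketch of what such a helper typically does on a single node with one worker process per GPU (this is not OLMo-core's actual implementation):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size, backend, func, func_args):
    # Hypothetical worker: every spawned process joins the same group.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    if backend == "nccl":
        torch.cuda.set_device(rank)  # pin each worker to its own GPU
    try:
        func(*func_args)
    finally:
        dist.destroy_process_group()


def run_distributed_test(func, backend="nccl", start_method="spawn", func_args=()):
    # Spawn one process per visible GPU and wait for all of them to finish.
    world_size = torch.cuda.device_count()
    mp.start_processes(
        _worker,
        args=(world_size, backend, func, func_args),
        nprocs=world_size,
        start_method=start_method,
    )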