
Commit b39c831

f
xrsrke committed Feb 13, 2024
1 parent 6bb69ff commit b39c831
Showing 1 changed file with 4 additions and 8 deletions.
tests/test_clip_grads.py: 12 changes (4 additions, 8 deletions)
@@ -189,7 +189,7 @@ def _test_clip_grads_with_pp(parallel_context: ParallelContext, norm_type: float
         to_rank=reference_rank,
     )

-    parallel_context.destroy()
+    # parallel_context.destroy()


 @pytest.mark.skipif(available_gpus() < 2, reason="test_clip_grads_with_tp requires at least 2 gpus")
@@ -340,11 +340,7 @@ def _test_clip_grads_with_tp(
     )
     torch.testing.assert_close(total_norm, ref_total_norm)

-    try:
-        parallel_context.destroy()
-    except Exception:
-        print("Failed to destroy parallel context")
-        print(f"parallel_contex.type: {type(parallel_context)}")
+    # parallel_context.destroy()


 @pytest.mark.skipif(available_gpus() < 2, reason="test_clip_grads_tied_weights requires at least 2 gpus")
@@ -439,7 +435,7 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type:
     assert torch.allclose(bias.grad, ref_bias.grad, rtol=1e-7, atol=1e-6)
     assert torch.allclose(total_norm, ref_total_norm, rtol=0, atol=0), f"Got {total_norm} and {ref_total_norm}"

-    parallel_context.destroy()
+    # parallel_context.destroy()


 @pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16])
@@ -629,4 +625,4 @@ def _test_clip_grads_fp32_accumulator(
         to_rank=reference_rank,
     )

-    parallel_context.destroy()
+    # parallel_context.destroy()
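Every hunk above replaces an explicit `parallel_context.destroy()` call (in one case, a try/except around it) with a commented-out line. If the intent is to keep teardown without duplicating the try/except in each test, one option is a shared guarded helper. A minimal sketch, assuming only the `ParallelContext.destroy()` API visible in the diff; the `safe_destroy` name and the `torch.distributed` barrier are illustrative assumptions, not part of this commit:

import torch.distributed as dist


def safe_destroy(parallel_context) -> None:
    """Best-effort teardown: synchronize ranks, then destroy the context."""
    try:
        if dist.is_available() and dist.is_initialized():
            # Ensure no rank is still using the process groups before
            # any rank starts tearing shared state down.
            dist.barrier()
        parallel_context.destroy()
    except Exception as exc:
        # Mirror the removed fallback, but keep its diagnostics in one place.
        print(f"Failed to destroy parallel context ({type(parallel_context)}): {exc}")

Each test's trailing `parallel_context.destroy()` line could then become `safe_destroy(parallel_context)`, so cleanup stays in place without per-test error handling.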
