Skip to content

Commit

Permalink
fix clip grads
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Feb 10, 2024
1 parent 44c0e05 commit b8eeb1e
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
5 changes: 3 additions & 2 deletions tests/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def get_all_3d_configurations(gpus: int) -> List[Tuple[int, int, int]]:
return result


def rerun_if_address_is_in_use():
def rerun_if_address_is_in_use(max_try: int = 100):
"""
This function reruns a wrapped function if "address already in use" occurs
in testing spawned with torch.multiprocessing
Expand All @@ -214,7 +214,7 @@ def test_something():
else:
exception = Exception

func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*", max_try=100)
func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*", max_try=max_try)
return func_wrapper


Expand Down Expand Up @@ -287,6 +287,7 @@ def _run_until_success(*args, **kwargs):
except exception_type as e:
error_lines = str(e).split("\n")
if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)):

print("Exception is caught, retrying...")
# when pattern is not specified, we always skip the exception
# when pattern is specified, we only skip when pattern is matched
Expand Down
4 changes: 1 addition & 3 deletions tests/test_clip_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,7 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type:
)
ref_total_norm = torch.nn.utils.clip_grad_norm_([ref_weight, ref_bias], max_norm=1.0, norm_type=norm_type)

assert (
total_norm.dim() == 0
), f"total_norm should be a scalar. Got {total_norm}, Debug: total_norm.dim()={total_norm.dim()}, type: {type(total_norm.dim())}"
assert total_norm.dim() == 1, f"total_norm should be a scalar. Got {total_norm}"
# Check that the gradients have changed
assert not torch.allclose(old_grad, weight.grad), "Gradients should have changed after clipping"

Expand Down
2 changes: 1 addition & 1 deletion tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def _test_init_parallel_context(parallel_context: ParallelContext):
for all_3d_configs in get_all_3d_configurations(gpus)
],
)
@rerun_if_address_is_in_use()
@rerun_if_address_is_in_use(max_try=150)
def test_init_parallel_context(tp: int, dp: int, pp: int):
init_distributed(tp=tp, dp=dp, pp=pp)(_test_init_parallel_context)()

0 comments on commit b8eeb1e

Please sign in to comment.