
Commit 5d822bb

make sure total_norm in clip grad is a scalar
NouamaneTazi committed Jan 31, 2024
1 parent 0a754a1 commit 5d822bb
Showing 2 changed files with 7 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/nanotron/optim/clip_grads.py
@@ -68,7 +68,7 @@ def clip_grad_norm(
             dtype=torch.float,
         ).pow(norm_type)
     else:
-        total_norm = torch.zeros(1, dtype=torch.float, device=torch.device("cuda"))
+        total_norm = torch.zeros([], dtype=torch.float, device=torch.device("cuda"))
     dist.all_reduce(total_norm, group=mp_pg, op=dist.ReduceOp.SUM)
     total_norm.pow_(1.0 / norm_type)

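For context, a minimal sketch (assuming only stock PyTorch; not part of this commit) of the shape difference the one-line fix relies on: torch.zeros(1) allocates a 1-D tensor of shape (1,), while torch.zeros([]) allocates a 0-dimensional scalar tensor, which is the shape that torch.nn.utils.clip_grad_norm_ returns and that the updated test asserts.

    # Illustrative sketch, not part of the commit; runs on CPU with plain PyTorch.
    import torch

    one_element = torch.zeros(1, dtype=torch.float)   # shape (1,): a 1-D tensor holding one value
    scalar = torch.zeros([], dtype=torch.float)       # shape (): a 0-dim scalar tensor

    print(one_element.shape, len(one_element.shape))  # torch.Size([1]) 1
    print(scalar.shape, len(scalar.shape))            # torch.Size([]) 0

    # The reference clipping utility returns a 0-dim tensor, so the scalar form matches it.
    p = torch.nn.Parameter(torch.randn(3))
    p.grad = torch.ones_like(p)
    ref_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
    print(ref_norm.shape)                             # torch.Size([])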
17 changes: 6 additions & 11 deletions tests/test_clip_grads.py
@@ -345,17 +345,9 @@ def test_clip_grads_tied_weights(norm_type: float):

 def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type: float):
     if dist.get_rank(parallel_context.pp_pg) == 0:
-        model = nn.ModuleDict(
-            {
-                "dense0": nn.Linear(10, 10, device="cuda"),
-            }
-        )
+        model = nn.ModuleDict({"dense0": nn.Linear(10, 10, device="cuda")})
     else:
-        model = nn.ModuleDict(
-            {
-                "dense1": nn.Linear(10, 10, device="cuda"),
-            }
-        )
+        model = nn.ModuleDict({"dense1": nn.Linear(10, 10, device="cuda")})

     # Tie weights/bias
     tie_parameters(
@@ -427,14 +419,17 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type:
         norm_type=norm_type,
     )
     ref_total_norm = torch.nn.utils.clip_grad_norm_([ref_weight, ref_bias], max_norm=1.0, norm_type=norm_type)
+    assert len(total_norm.shape) == 0, f"total_norm should be a scalar. Got {total_norm}"

     # Check that the gradients have changed
     assert not torch.allclose(old_grad, weight.grad), "Gradients should have changed after clipping"

     # Test that we get the same gradient after clipping
     torch.testing.assert_close(weight.grad, ref_weight.grad, rtol=1e-7, atol=1e-6)
     torch.testing.assert_close(bias.grad, ref_bias.grad, rtol=1e-7, atol=1e-6)
-    assert total_norm == ref_total_norm, "Total norm should be the same"
+    torch.testing.assert_close(
+        total_norm, ref_total_norm, rtol=0, atol=0, msg=lambda msg: f"{msg}\n" f"Got {total_norm} and {ref_total_norm}"
+    )


 @pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16])
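One note on the replaced assertion (an observation about torch.testing.assert_close, not wording from the commit): the old check, assert total_norm == ref_total_norm, still passes when total_norm has shape (1,), because comparing it against a 0-dim tensor broadcasts to a one-element boolean tensor. torch.testing.assert_close with rtol=0 and atol=0 demands exact value equality and also verifies that the shapes match, so it actually catches a non-scalar total_norm. A small sketch under those assumptions:

    # Sketch of the comparison behaviour; not part of the commit, plain PyTorch on CPU.
    import torch

    ref = torch.tensor(2.0)             # 0-dim reference norm, like clip_grad_norm_ returns
    vector_style = torch.tensor([2.0])  # shape (1,), like the old torch.zeros(1) initialisation
    scalar_style = torch.tensor(2.0)    # shape (), like the new torch.zeros([])

    assert vector_style == ref          # old-style check passes despite the shape mismatch

    torch.testing.assert_close(scalar_style, ref, rtol=0, atol=0)  # passes: same shape and value
    try:
        torch.testing.assert_close(vector_style, ref, rtol=0, atol=0)
    except AssertionError as err:
        print(err)                      # raises: assert_close also requires matching shapes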
