Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Feb 21, 2025
1 parent 2c2a101 commit a59744d
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tests/distributed/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,12 @@ def move_grads_to_cpu(parameters):
xm.mark_step()

if max_grad_norm is not None:
accelerator.clip_grad_norm_(model.local_parameters(), max_norm=max_grad_norm, norm_type=2)
accelerator.clip_grad_norm_(
model.local_parameters(),
max_norm=max_grad_norm,
norm_type=2,
postpone_clipping_to_optimizer_step=True,
)

# Checking that at least some of the parameters have a gradient.
grads_on_cpu = move_grads_to_cpu(model.local_parameters())
Expand Down Expand Up @@ -259,7 +264,12 @@ def move_grads_to_cpu(parameters):
loss.backward()

if max_grad_norm is not None:
accelerator.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm, norm_type=2)
accelerator.clip_grad_norm_(
model.parameters(),
max_norm=max_grad_norm,
norm_type=2,
postpone_clipping_to_optimizer_step=True,
)

# Checking that at least some of the parameters have a gradient.
grads_on_cpu = move_grads_to_cpu(model.parameters())
Expand Down

0 comments on commit a59744d

Please sign in to comment.