[fbsync] [references/classification] Adding gradient clipping (#4824)
Summary:
* [references] Adding gradient clipping

* ufmt formatting

* remove apex code

* resolve naming issue

Reviewed By: kazhang

Differential Revision: D32216659

fbshipit-source-id: 9c5ffb102fa5fd9861ae5ba0c44052920c34ebaf
prabhat00155 authored and facebook-github-bot committed Nov 8, 2021
1 parent 5357fc9 commit d0d2d63
Showing 2 changed files with 13 additions and 0 deletions.
references/classification/train.py (5 additions, 0 deletions)

@@ -42,6 +42,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         else:
             loss = criterion(output, target)
             loss.backward()
+
+            if args.clip_grad_norm is not None:
+                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+
             optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
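
For context, a minimal, self-contained sketch of the pattern this hunk introduces: gradients are rescaled in place after loss.backward() and before optimizer.step(). The toy model, data, and threshold below are assumptions for illustration only; the actual script reads the threshold from args.clip_grad_norm and passes the parameters via utils.get_optimizer_params(optimizer).

import torch
from torch import nn

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
clip_grad_norm = 1.0  # stand-in for args.clip_grad_norm

for _ in range(3):  # stand-in for iterating over the data loader
    image = torch.randn(8, 10)
    target = torch.randint(0, 2, (8,))

    output = model(image)
    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()

    if clip_grad_norm is not None:
        # Rescale all gradients in place so their combined L2 norm is at most clip_grad_norm.
        nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

    optimizer.step()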
@@ -472,6 +476,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
references/classification/utils.py (8 additions, 0 deletions)

@@ -409,3 +409,11 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
+
+
+def get_optimizer_params(optimizer):
+    """Generator to iterate over all parameters in the optimizer param_groups."""
+
+    for group in optimizer.param_groups:
+        for p in group["params"]:
+            yield p
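
A small usage sketch for the new helper (the toy model and optimizer below are assumptions for illustration). The generator yields every tensor registered in the optimizer's param_groups, so the clipping call in train_one_epoch operates on exactly the parameters the optimizer updates, even when they are split across several groups:

import torch
from torch import nn


def get_optimizer_params(optimizer):
    """Generator to iterate over all parameters in the optimizer param_groups."""
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p


model = nn.Linear(4, 2)
# Two param groups with different learning rates, e.g. weights vs. biases.
optimizer = torch.optim.SGD(
    [
        {"params": [model.weight], "lr": 0.1},
        {"params": [model.bias], "lr": 0.01},
    ]
)

model(torch.randn(3, 4)).sum().backward()  # populate .grad on both parameters

params = list(get_optimizer_params(optimizer))
assert len(params) == 2  # weight and bias, gathered across both groups

# clip_grad_norm_ accepts any iterable of tensors, so the generator can be passed directly.
nn.utils.clip_grad_norm_(get_optimizer_params(optimizer), max_norm=1.0)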
