Missed extra nadam algo step for capturable path
rwightman committed Jun 14, 2023
1 parent 4790c0f commit 2d597b1
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions timm/optim/nadamw.py
@@ -315,6 +315,11 @@ def _multi_tensor_nadamw(

bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

+# Only difference between NAdamW and AdamW in this implementation.
+# The official PyTorch implementation of NAdam uses a different algorithm.
+exp_avgs = torch._foreach_mul(exp_avgs, beta1)
+torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
+
exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
torch._foreach_div_(
exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
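
To illustrate what the added step does, here is a minimal scalar sketch in plain Python. It is not the code in `timm/optim/nadamw.py`: the function name `nadamw_scalar_step`, the `nesterov` flag, and the default hyperparameters are assumptions made for illustration, and it ignores the capturable and multi-tensor details. The point it shows is that the second `beta1` interpolation is computed out of place, so it changes only the numerator of the update and leaves the stored first moment untouched.

```python
import math


def nadamw_scalar_step(param, grad, exp_avg, exp_avg_sq, step,
                       lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                       weight_decay=0.0, nesterov=True):
    """One scalar update (hypothetical helper, for illustration only).

    Returns the new (param, exp_avg, exp_avg_sq).
    """
    # Decoupled weight decay, AdamW style (assumed form).
    param *= 1.0 - lr * weight_decay

    # Standard moment updates; these values are what the optimizer stores.
    exp_avg = beta1 * exp_avg + (1.0 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1.0 - beta2) * grad * grad

    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2 = 1.0 - beta2 ** step
    step_size = lr / bias_correction1

    # The extra step from the diff: interpolate with the gradient once more.
    # A new value is produced, so the stored exp_avg is not modified.
    numerator = exp_avg
    if nesterov:
        numerator = beta1 * exp_avg + (1.0 - beta1) * grad

    denom = math.sqrt(exp_avg_sq) / math.sqrt(bias_correction2) + eps
    param -= step_size * numerator / denom
    return param, exp_avg, exp_avg_sq


# Same starting state, with and without the extra step.
p_n, m_n, v_n = nadamw_scalar_step(1.0, 0.5, 0.0, 0.0, step=1, nesterov=True)
p_a, m_a, v_a = nadamw_scalar_step(1.0, 0.5, 0.0, 0.0, step=1, nesterov=False)
```

With `nesterov=False` the sketch reduces to a plain AdamW-style update, which is roughly what a capturable path lacking the extra step would compute: the stored moments match in both runs, but the parameter moves by a different amount.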