Skip to content

Commit

Permalink
https://github.com/LiyuanLucasLiu/RAdam/pull/51
Browse files Browse the repository at this point in the history
  • Loading branch information
khlam committed Oct 11, 2021
1 parent d9fd30a commit cb883b6
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions radam/radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def step(self, closure=None):
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']

- exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
- exp_avg.mul_(beta1).add_(1 - beta1, grad)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2)
+ exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1)

state['step'] += 1
buffered = group['buffer'][int(state['step'] % 10)]
Expand All @@ -81,14 +81,14 @@ def step(self, closure=None):
# more conservative since it's an approximated value
if N_sma >= 5:
if group['weight_decay'] != 0:
- p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+ p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr'])
denom = exp_avg_sq.sqrt().add_(group['eps'])
- p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+ p_data_fp32.addcdiv_(exp_avg, denom, value = -step_size * group['lr'])
p.data.copy_(p_data_fp32)
elif step_size > 0:
if group['weight_decay'] != 0:
- p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
- p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+ p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr'])
+ p_data_fp32.add_(exp_avg, alpha = -step_size * group['lr'])
p.data.copy_(p_data_fp32)

return loss
Expand Down

0 comments on commit cb883b6

Please sign in to comment.