diff --git a/radam/radam.py b/radam/radam.py index 0f97c81..8c738f2 100644 --- a/radam/radam.py +++ b/radam/radam.py @@ -55,8 +55,8 @@ def step(self, closure=None): exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] - exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) - exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1) state['step'] += 1 buffered = group['buffer'][int(state['step'] % 10)] @@ -81,14 +81,14 @@ def step(self, closure=None): # more conservative since it's an approximated value if N_sma >= 5: if group['weight_decay'] != 0: - p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr']) denom = exp_avg_sq.sqrt().add_(group['eps']) - p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + p_data_fp32.addcdiv_(exp_avg, denom, value = -step_size * group['lr']) p.data.copy_(p_data_fp32) elif step_size > 0: if group['weight_decay'] != 0: - p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) - p_data_fp32.add_(-step_size * group['lr'], exp_avg) + p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr']) + p_data_fp32.add_(exp_avg, alpha = -step_size * group['lr']) p.data.copy_(p_data_fp32) return loss