From 8910336e5ecfb320a08d5a35cdfba5cb2f8a3f2d Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 31 Jul 2024 14:44:31 +0000
Subject: [PATCH] safer not to inplace update grad

---
 adam_atan2_pytorch/foreach.py | 11 +++++++----
 pyproject.toml                |  2 +-
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/adam_atan2_pytorch/foreach.py b/adam_atan2_pytorch/foreach.py
index 81e7cd8..fc8aab7 100644
--- a/adam_atan2_pytorch/foreach.py
+++ b/adam_atan2_pytorch/foreach.py
@@ -63,6 +63,7 @@ def step(
         self,
         closure: Callable | None = None
     ):
+        init_lr = self._init_lr
         loss = None
 
         if exists(closure):
@@ -71,16 +72,19 @@ def step(
 
         for group in self.param_groups:
 
+            wd, lr, beta1, beta2, a, b = group['weight_decay'], group['lr'], *group['betas'], group['a'], group['b']
+
             # accumulate List[Tensor] for foreach inplace updates
 
             params = []
             grads = []
+            grad_squared = []
             exp_avgs = []
             exp_avg_sqs = []
 
             for p in filter(lambda p: exists(p.grad), group['params']):
 
-                grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+                grad, state = p.grad, self.state[p]
 
                 # decoupled weight decay
 
@@ -109,6 +113,7 @@ def step(
 
                 params.append(p)
                 grads.append(grad)
+                grad_squared.append(grad * grad)
                 exp_avgs.append(exp_avg)
                 exp_avg_sqs.append(exp_avg_sq)
 
@@ -123,9 +128,7 @@ def step(
 
             # decay running averages
 
             torch._foreach_lerp_(exp_avgs, grads, 1. - beta1)
-
-            torch._foreach_mul_(grads, grads)
-            torch._foreach_lerp_(exp_avg_sqs, grads, 1. - beta2) # grads is grad squared now
+            torch._foreach_lerp_(exp_avg_sqs, grad_squared, 1. - beta2)
 
             # clone for update
 
diff --git a/pyproject.toml b/pyproject.toml
index d0e1111..832caa6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.0.6"
+version = "0.0.7"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
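
Note (not part of the patch): a minimal standalone sketch of the aliasing hazard behind the commit title. It assumes only that the entries of grads are the live p.grad tensors collected in the loop above; the two toy parameters below are hypothetical stand-ins for group['params'], and torch._foreach_mul_ is the same private foreach helper the removed line used.

import torch

# two toy parameters standing in for group['params'] (hypothetical example)
params = [torch.zeros(3, requires_grad = True) for _ in range(2)]
for p in params:
    p.grad = torch.full_like(p, 2.)

grads = [p.grad for p in params]    # each list entry aliases p.grad

# old approach: squaring the list in place also overwrites p.grad itself
torch._foreach_mul_(grads, grads)
print(params[0].grad)               # tensor([4., 4., 4.]): gradient corrupted

# reset the gradients for the comparison
for p in params:
    p.grad = torch.full_like(p, 2.)

# new approach: build the squares out of place, so p.grad stays intact
grads = [p.grad for p in params]
grad_squared = [grad * grad for grad in grads]
print(params[0].grad)               # tensor([2., 2., 2.]): gradient preserved

This is why the updated hunk feeds a separately built grad_squared list to torch._foreach_lerp_ instead of squaring grads in place before the second moment update.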