Skip to content

Commit

Permalink
safer not to inplace update grad
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 31, 2024
1 parent 0044795 commit 8910336
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 7 additions & 4 deletions adam_atan2_pytorch/foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def step(
self,
closure: Callable | None = None
):
init_lr = self._init_lr

loss = None
if exists(closure):
Expand All @@ -71,16 +72,19 @@ def step(

for group in self.param_groups:

lr, wd, beta1, beta2, a, b = group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b']

# accumulate List[Tensor] for foreach inplace updates

params = []
grads = []
grad_squared = []
exp_avgs = []
exp_avg_sqs = []

for p in filter(lambda p: exists(p.grad), group['params']):

grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
grad, state = p.grad, self.state[p]

# decoupled weight decay

Expand Down Expand Up @@ -109,6 +113,7 @@ def step(

params.append(p)
grads.append(grad)
grad_squared.append(grad * grad)
exp_avgs.append(exp_avg)
exp_avg_sqs.append(exp_avg_sq)

Expand All @@ -123,9 +128,7 @@ def step(
# decay running averages

torch._foreach_lerp_(exp_avgs, grads, 1. - beta1)

torch._foreach_mul_(grads, grads)
torch._foreach_lerp_(exp_avg_sqs, grads, 1. - beta2) # grads is grad squared now
torch._foreach_lerp_(exp_avg_sqs, grad_squared, 1. - beta2) # use pre-computed grad squared, leaving grads untouched

# clone for update

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "adam-atan2-pytorch"
version = "0.0.6"
version = "0.0.7"
description = "Adam-atan2 for Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 8910336

Please sign in to comment.