From 004479505e69dfc469d2b98ad681e1661dbe0d0f Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 31 Jul 2024 14:38:37 +0000 Subject: [PATCH] fix foreach --- adam_atan2_pytorch/foreach.py | 15 ++++++++------- pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/adam_atan2_pytorch/foreach.py b/adam_atan2_pytorch/foreach.py index 88df8d2..81e7cd8 100644 --- a/adam_atan2_pytorch/foreach.py +++ b/adam_atan2_pytorch/foreach.py @@ -70,6 +70,14 @@ def step( loss = closure() for group in self.param_groups: + + # accumulate List[Tensor] for foreach inplace updates + + params = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + for p in filter(lambda p: exists(p.grad), group['params']): grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr @@ -79,13 +87,6 @@ def step( if wd > 0.: wd /= init_lr - # accumulate List[Tensor] for foreach inplace updates - - params = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - # init state if needed if len(state) == 0: diff --git a/pyproject.toml b/pyproject.toml index e10c209..d0e1111 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "adam-atan2-pytorch" -version = "0.0.5" +version = "0.0.6" description = "Adam-atan2 for Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }