From e3a0affd8b501fac5ff15407b9c6754ec3309198 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 22 Nov 2024 13:19:52 -0800 Subject: [PATCH] remove clipping from adopt atan2 --- adam_atan2_pytorch/adopt_atan2.py | 8 +------- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/adam_atan2_pytorch/adopt_atan2.py b/adam_atan2_pytorch/adopt_atan2.py index ba664bc..b44e6ad 100644 --- a/adam_atan2_pytorch/adopt_atan2.py +++ b/adam_atan2_pytorch/adopt_atan2.py @@ -27,7 +27,6 @@ def __init__( betas: tuple[float, float] = (0.9, 0.99), weight_decay = 0., decoupled_wd = True, - clip_update = True, a = 1.27, b = 1. ): @@ -44,7 +43,6 @@ def __init__( a = a, b = b, weight_decay = weight_decay, - clip_update = clip_update ) super().__init__(params, defaults) @@ -63,7 +61,7 @@ def step( for group in self.param_groups: for p in filter(lambda p: exists(p.grad), group['params']): - grad, lr, wd, beta1, beta2, clip, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['clip_update'], group['a'], group['b'], self.state[p], self._init_lr + grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr # maybe decoupled weight decay @@ -98,10 +96,6 @@ def step( update = grad.atan2(b * v.sqrt()) - if clip: - clip_value = steps ** 0.25 - update.clamp_(-clip_value, clip_value) - m.lerp_(update, 1. - beta1) # then update parameters diff --git a/pyproject.toml b/pyproject.toml index df8f1dd..7422024 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "adam-atan2-pytorch" -version = "0.1.11" +version = "0.1.12" description = "Adam-atan2 for Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }