diff --git a/adam_atan2_pytorch/adopt_atan2.py b/adam_atan2_pytorch/adopt_atan2.py index 3bbe626..ba664bc 100644 --- a/adam_atan2_pytorch/adopt_atan2.py +++ b/adam_atan2_pytorch/adopt_atan2.py @@ -27,6 +27,7 @@ def __init__( betas: tuple[float, float] = (0.9, 0.99), weight_decay = 0., decoupled_wd = True, + clip_update = True, a = 1.27, b = 1. ): @@ -43,6 +44,7 @@ def __init__( a = a, b = b, weight_decay = weight_decay, + clip_update = clip_update ) super().__init__(params, defaults) @@ -61,7 +63,7 @@ def step( for group in self.param_groups: for p in filter(lambda p: exists(p.grad), group['params']): - grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr + grad, lr, wd, beta1, beta2, clip, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['clip_update'], group['a'], group['b'], self.state[p], self._init_lr # maybe decoupled weight decay @@ -94,9 +96,13 @@ def step( grad_sq = grad * grad - next_m = grad.atan2(b * v.sqrt()) + update = grad.atan2(b * v.sqrt()) - m.lerp_(next_m, 1. - beta1) + if clip: + clip_value = steps ** 0.25 + update.clamp_(-clip_value, clip_value) + + m.lerp_(update, 1. - beta1) # then update parameters diff --git a/pyproject.toml b/pyproject.toml index 9abbc06..df8f1dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "adam-atan2-pytorch" -version = "0.1.10" +version = "0.1.11" description = "Adam-atan2 for Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }