From 7137f5de5a11bb3740ca07c6054c38eada3e2341 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 22 Nov 2024 10:06:05 -0800
Subject: [PATCH] the adopt authors have updated paper with clipping for
 stability

---
 adam_atan2_pytorch/adopt.py       | 15 +++++++++++----
 adam_atan2_pytorch/adopt_atan2.py |  6 +++---
 pyproject.toml                    |  2 +-
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/adam_atan2_pytorch/adopt.py b/adam_atan2_pytorch/adopt.py
index 0e0ea13..86f6be5 100644
--- a/adam_atan2_pytorch/adopt.py
+++ b/adam_atan2_pytorch/adopt.py
@@ -16,7 +16,7 @@ class Adopt(Optimizer):
     """
     the proposed Adam substitute from University of Tokyo

-    Algorithm 2 in https://arxiv.org/abs/2411.02853
+    Algorithm 3 in https://arxiv.org/abs/2411.02853
     """

     def __init__(
@@ -74,7 +74,7 @@ def step(

             if len(state) == 0:
                 state['steps'] = 0
-                state['m'] = torch.empty_like(grad)
+                state['m'] = torch.zeros_like(grad)
                 state['v'] = grad * grad

             # get some of the states
@@ -91,9 +91,16 @@

             grad_sq = grad * grad

-            next_m = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon
+            update = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon

-            m.lerp_(next_m, 1. - (beta1 * int(steps > 1)))
+            # clip with t ^ 0.25 as in Algorithm 3
+
+            clip_value = steps ** 0.25
+            update.clamp_(min = -clip_value, max = clip_value)
+
+            # update m
+
+            m.lerp_(update, 1. - beta1)

             # then update parameters

diff --git a/adam_atan2_pytorch/adopt_atan2.py b/adam_atan2_pytorch/adopt_atan2.py
index 54ce3b7..3bbe626 100644
--- a/adam_atan2_pytorch/adopt_atan2.py
+++ b/adam_atan2_pytorch/adopt_atan2.py
@@ -17,7 +17,7 @@ class AdoptAtan2(Optimizer):
     the proposed Adam substitute from University of Tokyo
     combined with the proposed atan2 method for ridding of the eps from Google

-    Algorithm 2 in https://arxiv.org/abs/2411.02853
+    Algorithm 3 in https://arxiv.org/abs/2411.02853
     """

     def __init__(
@@ -77,7 +77,7 @@ def step(

             if len(state) == 0:
                 state['steps'] = 0
-                state['m'] = torch.empty_like(grad)
+                state['m'] = torch.zeros_like(grad)
                 state['v'] = grad * grad

             # get some of the states
@@ -96,7 +96,7 @@

             next_m = grad.atan2(b * v.sqrt())

-            m.lerp_(next_m, 1. - (beta1 * int(steps > 1)))
+            m.lerp_(next_m, 1. - beta1)

             # then update parameters

diff --git a/pyproject.toml b/pyproject.toml
index 8d6eb0c..9abbc06 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.1.9"
+version = "0.1.10"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
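
For reference, below is a minimal sketch of the clipped update rule the patch switches to (Algorithm 3 of https://arxiv.org/abs/2411.02853). The function and variable names are illustrative only, not the repository's API, and it assumes steps starts at 1.

import torch

def clipped_adopt_step(param, grad, m, v, steps, lr = 1e-4, beta1 = 0.9, beta2 = 0.9999, eps = 1e-6):
    # normalize the gradient by the previous second moment, using max(., eps) rather than adding eps
    update = grad / v.sqrt().clamp(min = eps)

    # the newly added stabilization: clip the normalized gradient to +/- steps ** 0.25
    clip_value = steps ** 0.25
    update = update.clamp(min = -clip_value, max = clip_value)

    # exponential moving average of the clipped, normalized gradient
    m = m.lerp(update, 1. - beta1)

    # descent step, then update the second moment with the raw squared gradient
    param = param - lr * m
    v = v.lerp(grad * grad, 1. - beta2)

    return param, m, v, steps + 1

Because m now starts at zero (hence the switch from torch.empty_like to torch.zeros_like above), the first parameter update is simply (1 - beta1) times the clipped, normalized gradient, which removes the need for the old int(steps > 1) special-casing.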