the adopt authors have updated the paper with clipping for stability
lucidrains committed Nov 22, 2024
1 parent 4ba86bc commit 7137f5d
Showing 3 changed files with 15 additions and 8 deletions.
15 changes: 11 additions & 4 deletions adam_atan2_pytorch/adopt.py
@@ -16,7 +16,7 @@ class Adopt(Optimizer):
     """
     the proposed Adam substitute from University of Tokyo

-    Algorithm 2 in https://arxiv.org/abs/2411.02853
+    Algorithm 3 in https://arxiv.org/abs/2411.02853
     """

     def __init__(
@@ -74,7 +74,7 @@ def step(

         if len(state) == 0:
             state['steps'] = 0
-            state['m'] = torch.empty_like(grad)
+            state['m'] = torch.zeros_like(grad)
             state['v'] = grad * grad

         # get some of the states
@@ -91,9 +91,16 @@ def step(

         grad_sq = grad * grad

-        next_m = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon
+        update = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon

-        m.lerp_(next_m, 1. - (beta1 * int(steps > 1)))
+        # clip with t ^ 0.25 as in Algorithm 3
+
+        clip_value = steps ** 0.25
+        update.clamp_(min = -clip_value, max = clip_value)
+
+        # update m
+
+        m.lerp_(update, 1. - beta1)

         # then update parameters
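For context, here is a minimal sketch of the ADOPT update with the new Algorithm 3 clipping, written against plain torch tensors. It is not the repo's exact `step()` — the function name, the hyperparameter defaults, and the placement of the `v` refresh are assumptions based on the diff and the paper:

```python
import torch

def adopt_step_sketch(param, grad, m, v, steps, lr = 1e-3, beta1 = 0.9, beta2 = 0.9999, eps = 1e-6):
    # normalize the gradient by the previous second moment estimate,
    # clamping the denominator rather than adding eps, as in the diff above
    update = grad.div(v.sqrt().clamp(min = eps))

    # Algorithm 3 clipping: bound the normalized update to [-t^0.25, t^0.25];
    # the bound is tightest at step 1 (plus/minus 1) and relaxes as t grows
    clip_value = steps ** 0.25
    update.clamp_(min = -clip_value, max = clip_value)

    # fold the clipped update into the first moment; with m zero-initialized,
    # no first-step special casing is required
    m.lerp_(update, 1. - beta1)

    # take the parameter step, then refresh the second moment for the next call
    param.add_(m, alpha = -lr)
    v.lerp_(grad * grad, 1. - beta2)
```

At steps 1, 16, and 10,000 the clip bound is ±1, ±2, and ±10 respectively, so the clipping mainly guards the early steps where v is still a noisy estimate.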
6 changes: 3 additions & 3 deletions adam_atan2_pytorch/adopt_atan2.py
@@ -17,7 +17,7 @@ class AdoptAtan2(Optimizer):
     the proposed Adam substitute from University of Tokyo
     combined with the proposed atan2 method for ridding of the eps from Google

-    Algorithm 2 in https://arxiv.org/abs/2411.02853
+    Algorithm 3 in https://arxiv.org/abs/2411.02853
     """

     def __init__(
@@ -77,7 +77,7 @@ def step(

         if len(state) == 0:
             state['steps'] = 0
-            state['m'] = torch.empty_like(grad)
+            state['m'] = torch.zeros_like(grad)
             state['v'] = grad * grad

         # get some of the states
@@ -96,7 +96,7 @@ def step(

         next_m = grad.atan2(b * v.sqrt())

-        m.lerp_(next_m, 1. - (beta1 * int(steps > 1)))
+        m.lerp_(next_m, 1. - beta1)

         # then update parameters
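Both files previously initialized m with torch.empty_like and used the `1. - (beta1 * int(steps > 1))` weight so that the first lerp_ fully overwrote the uninitialized buffer; with m now zero-initialized, the plain `1. - beta1` weight suffices. A minimal sketch of the atan2 variant follows — the function name and surrounding bookkeeping are assumptions, and the default for `b` is illustrative. Note that atan2 is already bounded to [-pi/2, pi/2] when its second argument is non-negative, which is presumably why this file gains no explicit clipping:

```python
import torch

def adopt_atan2_step_sketch(param, grad, m, v, lr = 1e-3, beta1 = 0.9, beta2 = 0.9999, b = 1.):
    # atan2 replaces division by (sqrt(v) + eps): it is smooth, needs no eps,
    # and is naturally bounded since b * v.sqrt() >= 0
    next_m = grad.atan2(b * v.sqrt())

    # with m zero-initialized, the plain EMA weight works from the first step
    m.lerp_(next_m, 1. - beta1)

    # parameter step, then refresh the second moment for the next call
    param.add_(m, alpha = -lr)
    v.lerp_(grad * grad, 1. - beta2)
```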
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.1.9"
+version = "0.1.10"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
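A hypothetical usage sketch for the bumped release — the import path and class name come from the diff above, but any constructor arguments beyond the parameter list (here `lr`) are assumptions:

```python
import torch
from adam_atan2_pytorch.adopt import Adopt

model = torch.nn.Linear(512, 512)
opt = Adopt(model.parameters(), lr = 1e-4)  # lr value is illustrative

# a few toy optimization steps
for _ in range(3):
    loss = model(torch.randn(8, 512)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
```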
