first add vanilla adopt for running experiments against adam
lucidrains committed Nov 21, 2024
1 parent 8600496 commit 771341a
Showing 4 changed files with 126 additions and 1 deletion.
9 changes: 9 additions & 0 deletions README.md
@@ -72,3 +72,12 @@ for _ in range(100):
url = {https://api.semanticscholar.org/CorpusID:270380086}
}
```

```bibtex
@inproceedings{Taniguchi2024ADOPTMA,
title = {ADOPT: Modified Adam Can Converge with Any \$\beta\_2\$ with the Optimal Rate},
author = {Shohei Taniguchi and Keno Harada and Gouki Minegishi and Yuta Oshima and Seong Cheol Jeong and Go Nagahara and Tomoshi Iiyama and Masahiro Suzuki and Yusuke Iwasawa and Yutaka Matsuo},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273822148}
}
```
1 change: 1 addition & 0 deletions adam_atan2_pytorch/__init__.py
@@ -1,3 +1,4 @@
from adam_atan2_pytorch.adam_atan2 import AdamAtan2
from adam_atan2_pytorch.adopt import Adopt

Adam = AdamAtan2
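
With this export, `Adopt` is importable from the package root alongside the existing `Adam` alias; a minimal import sketch (assuming the package is installed):

```python
from adam_atan2_pytorch import Adam, Adopt            # Adam is an alias for AdamAtan2
from adam_atan2_pytorch.adopt import Adopt as AdoptDirect

assert Adopt is AdoptDirect                           # same class, two import paths
```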
115 changes: 115 additions & 0 deletions adam_atan2_pytorch/adopt.py
@@ -0,0 +1,115 @@
from __future__ import annotations
from typing import Callable

import torch
from torch import atan2, sqrt
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
return val is not None

# class

class Adopt(Optimizer):
"""
    the proposed Adam substitute (ADOPT) from the University of Tokyo
Algorithm 2 in https://arxiv.org/abs/2411.02853
"""

def __init__(
self,
params,
lr = 1e-4,
betas: tuple[float, float] = (0.9, 0.9999),
eps = 1e-6,
weight_decay = 0.,
decoupled_wd = True
):
assert lr > 0.
assert all([0. <= beta <= 1. for beta in betas])
assert weight_decay >= 0.

self._init_lr = lr
self.decoupled_wd = decoupled_wd

defaults = dict(
lr = lr,
betas = betas,
eps = eps,
weight_decay = weight_decay,
)

super().__init__(params, defaults)

@torch.no_grad()
def step(
self,
closure: Callable | None = None
):

loss = None
if exists(closure):
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in filter(lambda p: exists(p.grad), group['params']):

grad, lr, wd, beta1, beta2, eps, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['eps'], self.state[p], self._init_lr

# maybe decoupled weight decay

if self.decoupled_wd:
wd /= init_lr

# weight decay

if wd > 0.:
p.mul_(1. - lr * wd)

# init state if needed

if len(state) == 0:
state['steps'] = 0
                    state['m'] = torch.empty_like(grad) # written on the first real update below
                    state['v'] = grad * grad # v_0 = g_0 ** 2, per Algorithm 2

# get some of the states

m, v, steps = state['m'], state['v'], state['steps']

# for the first step do nothing

if steps == 0:
state['steps'] += 1
continue

# logic

steps += 1

                # squared gradient, used for the v update below

                grad_sq = grad * grad

                # calculate m - normalize the gradient by the previous second moment estimate
                # they claim that a max(value, eps) performs better than adding the epsilon

                next_m = grad.div(v.sqrt().clamp(min = eps))

                if steps == 2:
                    # first real update - set m directly to the normalized gradient (m_1 in Algorithm 2)
                    m.copy_(next_m)
                else:
                    # thereafter, exponential moving average with beta1
                    m.lerp_(next_m, 1. - beta1)

# then update parameters

p.add_(m, alpha = -lr)

# update exp grad sq (v)

                v.lerp_(grad_sq, 1. - beta2) # second moment uses beta2 (0.9999 by default), matching the paper

# increment steps

state['steps'] = steps

return loss
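
For the "experiments against adam" mentioned in the commit message, a minimal comparison sketch could look like the following; the toy regression task, seed, and hyperparameters are illustrative assumptions, not part of the repository. The `Adam` and `Adopt` imports follow the `__init__.py` shown above.

```python
import torch
from torch import nn

from adam_atan2_pytorch import Adam, Adopt  # Adam is the AdamAtan2 alias from __init__.py

def run(optimizer_klass, steps = 100, lr = 1e-4):
    torch.manual_seed(0)
    model = nn.Linear(16, 1)                               # toy model, purely illustrative
    opt = optimizer_klass(model.parameters(), lr = lr)

    data = torch.randn(256, 16)
    target = data.sum(dim = -1, keepdim = True)            # synthetic regression target

    for _ in range(steps):
        loss = nn.functional.mse_loss(model(data), target)
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item()

print('adam-atan2:', run(Adam))
print('adopt:     ', run(Adopt))
```

Note that `Adopt` deliberately skips the parameter update on its very first `step` (it only seeds `v`), so the two runs differ by one effective update.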
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "adam-atan2-pytorch"
version = "0.1.1"
version = "0.1.2"
description = "Adam-atan2 for Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
