Commit f6ab117

add an atan2 flavor for adopt

lucidrains committed Nov 21, 2024
1 parent 16bdbcd commit f6ab117
Showing 2 changed files with 118 additions and 1 deletion.
117 changes: 117 additions & 0 deletions adam_atan2_pytorch/adopt_atan2.py
@@ -0,0 +1,117 @@
from __future__ import annotations
from typing import Callable

import torch
from torch import atan2, sqrt
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# class

class AdoptAtan2(Optimizer):
    """
    the proposed Adam substitute from University of Tokyo

    Algorithm 2 in https://arxiv.org/abs/2411.02853
    """

    def __init__(
        self,
        params,
        lr = 1e-4,
        betas: tuple[float, float] = (0.9, 0.9999),
        weight_decay = 0.,
        decoupled_wd = True,
        a = 1.27,
        b = 1.
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert weight_decay >= 0.

        self._init_lr = lr
        self.decoupled_wd = decoupled_wd

        defaults = dict(
            lr = lr,
            betas = betas,
            a = a,
            b = b,
            weight_decay = weight_decay,
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

                # maybe decoupled weight decay

                if self.decoupled_wd:
                    wd /= init_lr

                # weight decay

                if wd > 0.:
                    p.mul_(1. - lr * wd)

                # init state if needed

                if len(state) == 0:
                    state['steps'] = 0
                    state['m'] = torch.empty_like(grad)
                    state['v'] = grad * grad

                # get some of the states

                m, v, steps = state['m'], state['v'], state['steps']

                # for the first step do nothing

                if steps == 0:
                    state['steps'] += 1
                    continue

                # logic

                steps += 1

                # calculate m

                grad_sq = grad * grad

                # atan2 in place of ADOPT's division by max(sqrt(v), eps) - bounded update, no eps hyperparameter needed
                next_m = grad.atan2(b * v.sqrt())

                if steps == 2:
                    # first momentum update - m was created with empty_like, so copy rather than lerp with uninitialized values
                    m.copy_(next_m)
                else:
                    m.lerp_(next_m, 1. - beta1)

                # then update parameters

                p.add_(m, alpha = -lr * a)

                # update exp grad sq (v)

                v.lerp_(grad_sq, 1. - beta2)

                # increment steps

                state['steps'] = steps

        return loss
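
As a reference for how the new optimizer slots into a training loop, here is a minimal usage sketch. It imports AdoptAtan2 from the module added in this commit; the toy model, data, and loop structure are illustrative assumptions, and the hyperparameters shown are simply the constructor defaults above.

import torch
from torch import nn
from adam_atan2_pytorch.adopt_atan2 import AdoptAtan2

# hypothetical toy model and batch, purely for illustration
model = nn.Linear(512, 10)
data = torch.randn(8, 512)
labels = torch.randint(0, 10, (8,))

opt = AdoptAtan2(model.parameters(), lr = 1e-4, betas = (0.9, 0.9999))

# note: the very first call to step() only initializes the second moment
# estimate; parameters begin to move from the second step onward
for _ in range(3):
    loss = nn.functional.cross_entropy(model(data), labels)
    loss.backward()
    opt.step()
    opt.zero_grad()

The a and b defaults follow the Adam-atan2 idea used elsewhere in this repository: b scales the second-moment term fed into atan2 and a scales the resulting update, while the atan2 form keeps each step bounded and avoids the usual epsilon hyperparameter.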
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "adam-atan2-pytorch"
version = "0.1.4"
version = "0.1.5"
description = "Adam-atan2 for Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
