From 771341a37312caa46ba6abb3d63f5e7b650bec6e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 21 Nov 2024 12:25:55 -0800
Subject: [PATCH] first add vanilla adopt for running experiments against adam

---
 README.md                      |   9 +++
 adam_atan2_pytorch/__init__.py |   1 +
 adam_atan2_pytorch/adopt.py    | 115 +++++++++++++++++++++++++++++++++
 pyproject.toml                 |   2 +-
 4 files changed, 126 insertions(+), 1 deletion(-)
 create mode 100644 adam_atan2_pytorch/adopt.py

diff --git a/README.md b/README.md
index 8573c50..e1235b5 100644
--- a/README.md
+++ b/README.md
@@ -72,3 +72,12 @@ for _ in range(100):
     url = {https://api.semanticscholar.org/CorpusID:270380086}
 }
 ```
+
+```bibtex
+@inproceedings{Taniguchi2024ADOPTMA,
+    title  = {ADOPT: Modified Adam Can Converge with Any \$\beta\_2\$ with the Optimal Rate},
+    author = {Shohei Taniguchi and Keno Harada and Gouki Minegishi and Yuta Oshima and Seong Cheol Jeong and Go Nagahara and Tomoshi Iiyama and Masahiro Suzuki and Yusuke Iwasawa and Yutaka Matsuo},
+    year   = {2024},
+    url    = {https://api.semanticscholar.org/CorpusID:273822148}
+}
+```
diff --git a/adam_atan2_pytorch/__init__.py b/adam_atan2_pytorch/__init__.py
index 51abe3b..6fdc8d1 100644
--- a/adam_atan2_pytorch/__init__.py
+++ b/adam_atan2_pytorch/__init__.py
@@ -1,3 +1,4 @@
 from adam_atan2_pytorch.adam_atan2 import AdamAtan2
+from adam_atan2_pytorch.adopt import Adopt
 
 Adam = AdamAtan2
diff --git a/adam_atan2_pytorch/adopt.py b/adam_atan2_pytorch/adopt.py
new file mode 100644
index 0000000..67dea64
--- /dev/null
+++ b/adam_atan2_pytorch/adopt.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+from typing import Callable
+
+import torch
+from torch import atan2, sqrt
+from torch.optim.optimizer import Optimizer
+
+# functions
+
+def exists(val):
+    return val is not None
+
+# class
+
+class Adopt(Optimizer):
+    """
+    the proposed Adam substitute (ADOPT) from University of Tokyo
+
+    Algorithm 2 in https://arxiv.org/abs/2411.02853
+    """
+
+    def __init__(
+        self,
+        params,
+        lr = 1e-4,
+        betas: tuple[float, float] = (0.9, 0.9999),
+        eps = 1e-6,
+        weight_decay = 0.,
+        decoupled_wd = True
+    ):
+        assert lr > 0.
+        assert all([0. <= beta <= 1. for beta in betas])
+        assert weight_decay >= 0.
+
+        self._init_lr = lr
+        self.decoupled_wd = decoupled_wd
+
+        defaults = dict(
+            lr = lr,
+            betas = betas,
+            eps = eps,
+            weight_decay = weight_decay,
+        )
+
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(
+        self,
+        closure: Callable | None = None
+    ):
+
+        loss = None
+        if exists(closure):
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in filter(lambda p: exists(p.grad), group['params']):
+
+                grad, lr, wd, beta1, beta2, eps, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['eps'], self.state[p], self._init_lr
+
+                # maybe decoupled weight decay
+
+                if self.decoupled_wd:
+                    wd /= init_lr
+
+                # weight decay
+
+                if wd > 0.:
+                    p.mul_(1. - lr * wd)
+
+                # init state if needed
+
+                if len(state) == 0:
+                    state['steps'] = 0
+                    state['m'] = torch.empty_like(grad) # fully written on the first effective update below
+                    state['v'] = grad * grad
+
+                # get some of the states
+
+                m, v, steps = state['m'], state['v'], state['steps']
+
+                # for the first step do nothing - only the second moment v has been seeded from the initial gradient
+
+                if steps == 0:
+                    state['steps'] += 1
+                    continue
+
+                # logic
+
+                steps += 1
+
+                # calculate m - the gradient normalized by the previous second moment estimate
+
+                grad_sq = grad * grad
+
+                next_m = grad.div(v.sqrt().clamp(min = eps)) # they claim that max(value, eps) performs better than adding the epsilon
+
+                if steps == 2:
+                    # m is set outright on the first effective update (m1 in Algorithm 2), then becomes an EMA with beta1
+                    m.copy_(next_m)
+                else:
+                    m.lerp_(next_m, 1. - beta1)
+
+                # then update parameters
+
+                p.add_(m, alpha = -lr)
+
+                # update exp grad sq (v) with beta2
+
+                v.lerp_(grad_sq, 1. - beta2)
+
+                # increment steps
+
+                state['steps'] = steps
+
+        return loss
diff --git a/pyproject.toml b/pyproject.toml
index ee5e654..f677c60 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.1.1"
+version = "0.1.2"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
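For quick reference, below is a minimal usage sketch of the `Adopt` optimizer added by this patch, in the spirit of the training-loop example already in the README. The toy `nn.Linear` model, its dimensions, and the random data are illustrative assumptions, not part of the patch:

```python
# minimal sketch - assumes this patch is installed as the `adam-atan2-pytorch` package

import torch
from torch import nn

from adam_atan2_pytorch import Adopt

# toy model, for illustration only

model = nn.Linear(512, 256)

# Adopt exposes the usual optimizer API; defaults follow the patch (lr = 1e-4, betas = (0.9, 0.9999))

opt = Adopt(model.parameters(), lr = 1e-4)

# forward, backward, and optimizer step

for _ in range(100):
    loss = model(torch.randn(2, 512)).sum()
    loss.backward()

    opt.step()
    opt.zero_grad()
```

Note that, per Algorithm 2, the very first `opt.step()` call only seeds the second moment estimate `v` from the initial gradient and leaves the parameters untouched; parameter updates begin with the second call.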