diff --git a/README.md b/README.md
index 038c43b..b931401 100644
--- a/README.md
+++ b/README.md
@@ -59,3 +59,14 @@ for _ in range(100):
     url     = {https://api.semanticscholar.org/CorpusID:261076021}
 }
 ```
+
+```bibtex
+@article{Lewandowski2024LearningCB,
+    title   = {Learning Continually by Spectral Regularization},
+    author  = {Alex Lewandowski and Saurabh Kumar and Dale Schuurmans and Andr{\'a}s Gy{\"o}rgy and Marlos C. Machado},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2406.06811},
+    url     = {https://api.semanticscholar.org/CorpusID:270380086}
+}
+```
diff --git a/adam_atan2_pytorch/adam_atan2.py b/adam_atan2_pytorch/adam_atan2.py
index 653001c..554a91f 100644
--- a/adam_atan2_pytorch/adam_atan2.py
+++ b/adam_atan2_pytorch/adam_atan2.py
@@ -20,6 +20,7 @@ def __init__(
         betas: tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
         regen_reg_rate = 0.,
+        wasserstein_reg = False,
         decoupled_wd = False,
         a = 1.27,
         b = 1.
@@ -39,7 +40,8 @@ def __init__(
             a = a,
             b = b,
             weight_decay = weight_decay,
-            regen_reg_rate = regen_reg_rate
+            regen_reg_rate = regen_reg_rate,
+            wasserstein_reg = wasserstein_reg,
         )
 
         super().__init__(params, defaults)
@@ -58,7 +60,7 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
-                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, wasserstein_reg, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['wasserstein_reg'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
 
                 # maybe decoupled weight decay
 
diff --git a/adam_atan2_pytorch/adam_atan2_with_wasserstein_reg.py b/adam_atan2_pytorch/adam_atan2_with_wasserstein_reg.py
new file mode 100644
index 0000000..9a5e136
--- /dev/null
+++ b/adam_atan2_pytorch/adam_atan2_with_wasserstein_reg.py
@@ -0,0 +1,135 @@
+from __future__ import annotations
+from typing import Callable
+
+import torch
+from torch import atan2, sqrt
+from torch.optim.optimizer import Optimizer
+
+# functions
+
+def exists(val):
+    return val is not None
+
+# class
+
+class AdamAtan2(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr = 1e-4,
+        betas: tuple[float, float] = (0.9, 0.99),
+        weight_decay = 0.,
+        regen_reg_rate = 0.,
+        decoupled_wd = False,
+        a = 1.27,
+        b = 1.
+    ):
+        assert lr > 0.
+        assert all([0. <= beta <= 1. for beta in betas])
+        assert weight_decay >= 0.
+        assert regen_reg_rate >= 0.
+        assert not (weight_decay > 0. and regen_reg_rate > 0.)
+
+        self._init_lr = lr
+        self.decoupled_wd = decoupled_wd
+
+        defaults = dict(
+            lr = lr,
+            betas = betas,
+            a = a,
+            b = b,
+            weight_decay = weight_decay,
+            regen_reg_rate = regen_reg_rate,
+        )
+
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(
+        self,
+        closure: Callable | None = None
+    ):
+
+        loss = None
+        if exists(closure):
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in filter(lambda p: exists(p.grad), group['params']):
+
+                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+
+                # maybe decoupled weight decay
+
+                if self.decoupled_wd:
+                    wd /= init_lr
+
+                # weight decay
+
+                if wd > 0.:
+                    p.mul_(1. - lr * wd)
+
+                # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958
+
+                if regen_rate > 0. and 'param_init' in state:
+                    param_init = state['param_init']
+
+                    shape = param_init.shape
+
+                    # wasserstein compares using ordered statistics, iiuc
+
+                    indices = p.flatten().sort(dim = -1).indices
+                    indices = indices.argsort(dim = -1)
+
+                    target = param_init.flatten()[indices]
+                    target = target.reshape(shape)
+
+                    p.lerp_(target, lr / init_lr * regen_rate)
+
+                # init state if needed
+
+                if len(state) == 0:
+                    state['steps'] = 0
+                    state['exp_avg'] = torch.zeros_like(grad)
+                    state['exp_avg_sq'] = torch.zeros_like(grad)
+
+                    if regen_rate > 0.:
+                        shape = p.shape
+                        sorted_init = p.flatten().sort(dim = -1).values
+                        sorted_init = sorted_init.reshape(shape)
+
+                        state['param_init'] = sorted_init.clone()
+
+                # get some of the states
+
+                exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
+
+                steps += 1
+
+                # bias corrections
+
+                bias_correct1 = 1. - beta1 ** steps
+                bias_correct2 = 1. - beta2 ** steps
+
+                # decay running averages
+
+                exp_avg.lerp_(grad, 1. - beta1)
+                exp_avg_sq.lerp_(grad * grad, 1. - beta2)
+
+                # the following line is the proposed change to the update rule
+                # using atan2 instead of a division with epsilon in denominator
+                # a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))
+
+                den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
+                update = exp_avg.mul(1. / bias_correct1).atan2_(den)
+
+                # update parameters
+
+                p.add_(update, alpha = -lr * a)
+
+                # increment steps
+
+                state['steps'] = steps
+
+        return loss
diff --git a/pyproject.toml b/pyproject.toml
index af5f0d3..c706649 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.0.12"
+version = "0.0.14"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
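
For reference, a minimal usage sketch of the Wasserstein-regularized variant added above, mirroring the training loop already shown in the README. The import path assumes the class is used directly from `adam_atan2_pytorch.adam_atan2_with_wasserstein_reg` (the diff does not show whether it is re-exported in `__init__.py`), and the `regen_reg_rate` value is purely illustrative.

```python
# minimal usage sketch, assuming direct import from the new module path;
# whether the class is re-exported in __init__.py is not shown in this diff
import torch
from adam_atan2_pytorch.adam_atan2_with_wasserstein_reg import AdamAtan2

model = torch.nn.Linear(10, 1)

# regen_reg_rate > 0. enables the regenerative regularization; in this variant
# the regularization target is the order-statistic (sorted) view of the initial parameters
opt = AdamAtan2(model.parameters(), lr = 1e-4, regen_reg_rate = 1e-2)

for _ in range(100):
    loss = model(torch.randn(3, 10)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
```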