diff --git a/README.md b/README.md
index fb57399..038c43b 100644
--- a/README.md
+++ b/README.md
@@ -50,3 +50,12 @@ for _ in range(100):
     url = {https://api.semanticscholar.org/CorpusID:271051056}
 }
 ```
+
+```bibtex
+@inproceedings{Kumar2023MaintainingPI,
+    title = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
+    author = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
+    year = {2023},
+    url = {https://api.semanticscholar.org/CorpusID:261076021}
+}
+```
diff --git a/adam_atan2_pytorch/adam_atan2.py b/adam_atan2_pytorch/adam_atan2.py
index 161a9d9..653001c 100644
--- a/adam_atan2_pytorch/adam_atan2.py
+++ b/adam_atan2_pytorch/adam_atan2.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Tuple, Callable
+from typing import Callable
 
 import torch
 from torch import atan2, sqrt
@@ -17,8 +17,9 @@ def __init__(
         self,
         params,
         lr = 1e-4,
-        betas: Tuple[float, float] = (0.9, 0.99),
+        betas: tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
+        regen_reg_rate = 0.,
         decoupled_wd = False,
         a = 1.27,
         b = 1.
@@ -26,6 +27,8 @@ def __init__(
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert weight_decay >= 0.
+        assert regen_reg_rate >= 0.
+        assert not (weight_decay > 0. and regen_reg_rate > 0.)
 
         self._init_lr = lr
         self.decoupled_wd = decoupled_wd
@@ -35,7 +38,8 @@ def __init__(
             betas = betas,
             a = a,
             b = b,
-            weight_decay = weight_decay
+            weight_decay = weight_decay,
+            regen_reg_rate = regen_reg_rate
         )
 
         super().__init__(params, defaults)
@@ -54,7 +58,7 @@ def step(
 
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
-                grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
 
                 # maybe decoupled weight decay
@@ -64,7 +68,13 @@ def step(
 
                 # weight decay
 
                 if wd > 0.:
-                    p.mul_(1. - lr / init_lr * wd)
+                    p.mul_(1. - lr * wd)
+
+                # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958
+
+                if regen_rate > 0. and 'param_init' in state:
+                    param_init = state['param_init']
+                    p.lerp_(param_init, lr / init_lr * regen_rate)
 
                 # init state if needed
@@ -73,6 +83,9 @@ def step(
                     state['exp_avg'] = torch.zeros_like(grad)
                     state['exp_avg_sq'] = torch.zeros_like(grad)
 
+                    if regen_rate > 0.:
+                        state['param_init'] = p.clone()
+
                 # get some of the states
 
                 exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
diff --git a/adam_atan2_pytorch/foreach.py b/adam_atan2_pytorch/foreach.py
index 91a726d..14dff48 100644
--- a/adam_atan2_pytorch/foreach.py
+++ b/adam_atan2_pytorch/foreach.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Tuple, List, Callable
+from typing import Callable
 
 import torch
 from torch import atan2, sqrt, Tensor
@@ -18,7 +18,7 @@ def default(*args):
 
 # slow foreach atan2
 
-def slow_foreach_atan2_(nums: List[Tensor], dens: List[Tensor]):
+def slow_foreach_atan2_(nums: list[Tensor], dens: list[Tensor]):
     for num, den, in zip(nums, dens):
         num.atan2_(den)
 
@@ -29,8 +29,9 @@ def __init__(
         self,
         params,
         lr = 1e-4,
-        betas: Tuple[float, float] = (0.9, 0.99),
+        betas: tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
+        regen_reg_rate = 0.,
         decoupled_wd = False,
         a = 1.27,
         b = 1.,
@@ -39,6 +40,8 @@ def __init__(
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert weight_decay >= 0.
+        assert regen_reg_rate >= 0.
+        assert not (weight_decay > 0. and regen_reg_rate > 0.)
         assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'lerp', 'sqrt')]), 'this version of torch does not have the prerequisite foreach functions'
 
         self._init_lr = lr
@@ -55,7 +58,8 @@ def __init__(
             betas = betas,
             a = a,
             b = b,
-            weight_decay = weight_decay
+            weight_decay = weight_decay,
+            regen_reg_rate = regen_reg_rate
         )
 
         super().__init__(params, defaults)
@@ -74,13 +78,16 @@ def step(
 
         for group in self.param_groups:
 
-            wd, lr, beta1, beta2, a, b = group['weight_decay'], group['lr'], *group['betas'], group['a'], group['b']
+            wd, regen_rate, lr, beta1, beta2, a, b = group['weight_decay'], group['regen_reg_rate'], group['lr'], *group['betas'], group['a'], group['b']
 
             has_weight_decay = wd > 0
 
+            has_regenerative_reg = regen_rate > 0
+
             # accumulate List[Tensor] for foreach inplace updates
 
             params = []
+            params_init = []
             grads = []
             grad_squared = []
             exp_avgs = []
@@ -101,10 +108,11 @@ def step(
                     state['steps'] = 0
                     state['exp_avg'] = torch.zeros_like(grad)
                     state['exp_avg_sq'] = torch.zeros_like(grad)
+                    state['param_init'] = p.clone()
 
                 # get some of the states
 
-                exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
+                exp_avg, exp_avg_sq, param_init, steps = state['exp_avg'], state['exp_avg_sq'], state['param_init'], state['steps']
 
                 steps += 1
 
@@ -116,6 +124,7 @@ def step(
                 # append to list
 
                 params.append(p)
+                params_init.append(param_init)
                 grads.append(grad)
                 grad_squared.append(grad * grad)
                 exp_avgs.append(exp_avg)
@@ -130,6 +139,11 @@ def step(
             if has_weight_decay:
                 torch._foreach_mul_(params, 1. - lr * wd)
 
+            # regenerative regularization
+
+            if has_regenerative_reg:
+                torch._foreach_lerp_(params, params_init, lr / init_lr * regen_rate)
+
             # decay running averages
 
             torch._foreach_lerp_(exp_avgs, grads, 1. - beta1)
diff --git a/pyproject.toml b/pyproject.toml
index 7bd34a0..af5f0d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.0.10"
+version = "0.0.12"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
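
Below is a minimal usage sketch, not part of the diff itself, showing how the new `regen_reg_rate` option introduced above might be enabled. It assumes `AdamAtan2` is importable from the package root as in the repo README; the toy model and the `1e-2` rate are illustrative values only.

```python
# minimal sketch (not part of the diff): enabling the regenerative regularization
# added in this change. assumes AdamAtan2 is re-exported at the package root;
# the toy model and the 1e-2 rate are illustrative assumptions only.

import torch
from torch import nn
from adam_atan2_pytorch import AdamAtan2

model = nn.Linear(512, 256)

# regen_reg_rate nudges weights back toward their initial values each step
# (scaled by lr / init_lr * regen_reg_rate); per the new asserts it cannot
# be combined with weight_decay
opt = AdamAtan2(model.parameters(), lr = 1e-4, regen_reg_rate = 1e-2)

for _ in range(3):
    loss = model(torch.randn(4, 512)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
```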