add regenerative regularization for better continual plasticity
lucidrains committed Aug 26, 2024
1 parent 665ec96 commit 6dd472d
Showing 4 changed files with 48 additions and 12 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -50,3 +50,12 @@ for _ in range(100):
url = {https://api.semanticscholar.org/CorpusID:271051056}
}
```

```bibtex
@inproceedings{Kumar2023MaintainingPI,
title = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
author = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:261076021}
}
```
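
The new citation covers the regenerative regularization wired into both optimizer variants below. As a minimal standalone sketch of the idea (illustrative names, not this package's API): each step nudges a parameter back toward a snapshot of its initial value, which is what the `lerp_` calls in the diffs implement, with a weight of `lr / init_lr * regen_reg_rate` so the pull follows the learning-rate schedule.

```python
import torch

def regen_reg_step(param: torch.Tensor, param_init: torch.Tensor, lr: float, regen_rate: float):
    # in-place lerp toward the init snapshot:
    # param <- param - lr * regen_rate * (param - param_init)
    param.lerp_(param_init, lr * regen_rate)

p = torch.randn(8)
p_init = p.clone()  # snapshot taken once, when optimizer state is created
regen_reg_step(p, p_init, lr = 1e-4, regen_rate = 0.1)
```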
23 changes: 18 additions & 5 deletions adam_atan2_pytorch/adam_atan2.py
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Tuple, Callable
from typing import Callable

import torch
from torch import atan2, sqrt
@@ -17,15 +17,18 @@ def __init__(
self,
params,
lr = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
betas: tuple[float, float] = (0.9, 0.99),
weight_decay = 0.,
regen_reg_rate = 0.,
decoupled_wd = False,
a = 1.27,
b = 1.
):
assert lr > 0.
assert all([0. <= beta <= 1. for beta in betas])
assert weight_decay >= 0.
assert regen_reg_rate >= 0.
assert not (weight_decay > 0. and regen_reg_rate > 0.)

self._init_lr = lr
self.decoupled_wd = decoupled_wd
@@ -35,7 +38,8 @@ def __init__(
betas = betas,
a = a,
b = b,
weight_decay = weight_decay
weight_decay = weight_decay,
regen_reg_rate = regen_reg_rate
)

super().__init__(params, defaults)
@@ -54,7 +58,7 @@ def step(
for group in self.param_groups:
for p in filter(lambda p: exists(p.grad), group['params']):

grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

# maybe decoupled weight decay

@@ -64,7 +68,13 @@ def step(
# weight decay

if wd > 0.:
p.mul_(1. - lr / init_lr * wd)
p.mul_(1. - lr * wd)

# regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958

if regen_rate > 0. and 'param_init' in state:
param_init = state['param_init']
p.lerp_(param_init, lr / init_lr * regen_rate)

# init state if needed

@@ -73,6 +83,9 @@ def step(
state['exp_avg'] = torch.zeros_like(grad)
state['exp_avg_sq'] = torch.zeros_like(grad)

if regen_rate > 0.:
state['param_init'] = p.clone()

# get some of the states

exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
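
A usage sketch for the single-tensor optimizer with the new argument. The import path and class name `AdamAtan2` are assumed from the package layout (`adam_atan2_pytorch/adam_atan2.py`) and may differ; note the constructor asserts that `weight_decay` and `regen_reg_rate` are not used together.

```python
import torch
from torch import nn

from adam_atan2_pytorch import AdamAtan2  # assumed import path

model = nn.Linear(512, 512)

# regenerative regularization instead of weight decay (the two are mutually exclusive)
opt = AdamAtan2(model.parameters(), lr = 1e-4, regen_reg_rate = 1e-3)

for _ in range(100):
    loss = model(torch.randn(2, 512)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
```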
26 changes: 20 additions & 6 deletions adam_atan2_pytorch/foreach.py
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Tuple, List, Callable
from typing import Callable

import torch
from torch import atan2, sqrt, Tensor
@@ -18,7 +18,7 @@ def default(*args):

# slow foreach atan2

def slow_foreach_atan2_(nums: List[Tensor], dens: List[Tensor]):
def slow_foreach_atan2_(nums: list[Tensor], dens: list[Tensor]):
for num, den, in zip(nums, dens):
num.atan2_(den)

@@ -29,8 +29,9 @@ def __init__(
self,
params,
lr = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
betas: tuple[float, float] = (0.9, 0.99),
weight_decay = 0.,
regen_reg_rate = 0.,
decoupled_wd = False,
a = 1.27,
b = 1.,
@@ -39,6 +40,8 @@ def __init__(
assert lr > 0.
assert all([0. <= beta <= 1. for beta in betas])
assert weight_decay >= 0.
assert regen_reg_rate >= 0.
assert not (weight_decay > 0. and regen_reg_rate > 0.)
assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'lerp', 'sqrt')]), 'this version of torch does not have the prerequisite foreach functions'

self._init_lr = lr
@@ -55,7 +58,8 @@ def __init__(
betas = betas,
a = a,
b = b,
weight_decay = weight_decay
weight_decay = weight_decay,
regen_reg_rate = regen_reg_rate
)

super().__init__(params, defaults)
@@ -74,13 +78,16 @@ def step(

for group in self.param_groups:

wd, lr, beta1, beta2, a, b = group['weight_decay'], group['lr'], *group['betas'], group['a'], group['b']
wd, regen_rate, lr, beta1, beta2, a, b = group['weight_decay'], group['regen_reg_rate'], group['lr'], *group['betas'], group['a'], group['b']

has_weight_decay = wd > 0

has_regenerative_reg = regen_rate > 0

# accumulate List[Tensor] for foreach inplace updates

params = []
params_init = []
grads = []
grad_squared = []
exp_avgs = []
@@ -101,10 +108,11 @@ def step(
state['steps'] = 0
state['exp_avg'] = torch.zeros_like(grad)
state['exp_avg_sq'] = torch.zeros_like(grad)
state['param_init'] = p.clone()

# get some of the states

exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
exp_avg, exp_avg_sq, param_init, steps = state['exp_avg'], state['exp_avg_sq'], state['param_init'], state['steps']

steps += 1

@@ -116,6 +124,7 @@ def step(
# append to list

params.append(p)
params_init.append(param_init)
grads.append(grad)
grad_squared.append(grad * grad)
exp_avgs.append(exp_avg)
@@ -130,6 +139,11 @@ def step(
if has_weight_decay:
torch._foreach_mul_(params, 1. - lr * wd)

# regenerative regularization

if has_regenerative_reg:
torch._foreach_lerp_(params, params_init, lr / init_lr * regen_rate)

# decay running averages

torch._foreach_lerp_(exp_avgs, grads, 1. - beta1)
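
The foreach variant applies the same regenerative pull with one fused call per parameter group. A standalone sketch of just that step, using the same `torch._foreach_lerp_` primitive the diff checks for (tensor names are illustrative, not this repo's code):

```python
import torch

# stand-ins for one parameter group and the init snapshots kept in optimizer state
params = [torch.randn(3, 3) for _ in range(4)]
params_init = [p.clone() for p in params]

lr, init_lr, regen_rate = 1e-4, 1e-4, 0.1

# one in-place lerp over the whole list:
# p <- p + w * (p_init - p), with w = lr / init_lr * regen_rate
torch._foreach_lerp_(params, params_init, lr / init_lr * regen_rate)
```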
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "adam-atan2-pytorch"
version = "0.0.10"
version = "0.0.12"
description = "Adam-atan2 for Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
