Commit
take care of a paper in the continual learning literature that improves on regenerative reg by using wasserstein distance
lucidrains committed Oct 1, 2024
1 parent 068150a commit fcc9b1d
Showing 4 changed files with 151 additions and 3 deletions.
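
The gist of the change, as a hedged toy sketch (the tensors and rates below are illustrative, not from the repository): plain regenerative regularization pulls each weight back toward its own initial value, while the Wasserstein variant pulls it toward the initial value of the same rank, since for 1-D distributions the Wasserstein-optimal coupling simply matches order statistics. The effect appears to be regularizing toward the distribution of the initialization rather than toward its exact positions.

```python
# illustrative only: rank-matched (1d wasserstein) pull toward the initialization
import torch

p = torch.randn(6)                  # current weights
p_init = torch.randn(6)             # weights at initialization

sorted_init = p_init.sort().values  # order statistics of the init
ranks = p.sort().indices.argsort()  # rank of each current weight
target = sorted_init[ranks]         # init value of the same rank

regen_rate, lr, init_lr = 1e-2, 1e-4, 1e-4
p.lerp_(target, lr / init_lr * regen_rate)  # small step toward the rank-matched target
```
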
11 changes: 11 additions & 0 deletions README.md
@@ -59,3 +59,14 @@ for _ in range(100):
url = {https://api.semanticscholar.org/CorpusID:261076021}
}
```

```bibtex
@article{Lewandowski2024LearningCB,
title = {Learning Continually by Spectral Regularization},
author = {Alex Lewandowski and Saurabh Kumar and Dale Schuurmans and Andr{\'a}s Gy{\"o}rgy and Marlos C. Machado},
journal = {ArXiv},
year = {2024},
volume = {abs/2406.06811},
url = {https://api.semanticscholar.org/CorpusID:270380086}
}
```
6 changes: 4 additions & 2 deletions adam_atan2_pytorch/adam_atan2.py
@@ -20,6 +20,7 @@ def __init__(
betas: tuple[float, float] = (0.9, 0.99),
weight_decay = 0.,
regen_reg_rate = 0.,
wasserstein_reg = False,
decoupled_wd = False,
a = 1.27,
b = 1.
@@ -39,7 +40,8 @@ def __init__(
a = a,
b = b,
weight_decay = weight_decay,
-            regen_reg_rate = regen_reg_rate
+            regen_reg_rate = regen_reg_rate,
+            wasserstein_reg = wasserstein_reg,
)

super().__init__(params, defaults)
@@ -58,7 +60,7 @@ def step(
for group in self.param_groups:
for p in filter(lambda p: exists(p.grad), group['params']):

-                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, wasserstein_reg, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['wasserstein_reg'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

# maybe decoupled weight decay

135 changes: 135 additions & 0 deletions adam_atan2_pytorch/adam_atan2_with_wasserstein_reg.py
@@ -0,0 +1,135 @@
from __future__ import annotations
from typing import Callable

import torch
from torch import atan2, sqrt
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
return val is not None

# class

class AdamAtan2(Optimizer):
def __init__(
self,
params,
lr = 1e-4,
betas: tuple[float, float] = (0.9, 0.99),
weight_decay = 0.,
regen_reg_rate = 0.,
decoupled_wd = False,
a = 1.27,
b = 1.
):
assert lr > 0.
assert all([0. <= beta <= 1. for beta in betas])
assert weight_decay >= 0.
assert regen_reg_rate >= 0.
assert not (weight_decay > 0. and regen_reg_rate > 0.)

self._init_lr = lr
self.decoupled_wd = decoupled_wd

defaults = dict(
lr = lr,
betas = betas,
a = a,
b = b,
weight_decay = weight_decay,
regen_reg_rate = regen_reg_rate,
)

super().__init__(params, defaults)

@torch.no_grad()
def step(
self,
closure: Callable | None = None
):

loss = None
if exists(closure):
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in filter(lambda p: exists(p.grad), group['params']):

grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

# maybe decoupled weight decay

if self.decoupled_wd:
wd /= init_lr

# weight decay

if wd > 0.:
p.mul_(1. - lr * wd)

# regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958

if regen_rate > 0. and 'param_init' in state:
param_init = state['param_init']

shape = param_init.shape

                    # wasserstein compares using ordered statistics: each weight is
                    # pulled toward the stored initial value of the same rank

                    indices = p.flatten().sort(dim = -1).indices
                    indices = indices.argsort(dim = -1)

                    target = param_init.flatten()[indices]
                    target = target.reshape(shape)

p.lerp_(target, lr / init_lr * regen_rate)

# init state if needed

if len(state) == 0:
state['steps'] = 0
state['exp_avg'] = torch.zeros_like(grad)
state['exp_avg_sq'] = torch.zeros_like(grad)

                    if regen_rate > 0.:
                        # store the sorted initial weights (the order statistics used as targets above),
                        # without rebinding `p`, so the in-place update below still reaches the parameter
                        shape = p.shape
                        sorted_init = p.flatten().sort(dim = -1).values
                        state['param_init'] = sorted_init.reshape(shape).clone()

# get some of the states

exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']

steps += 1

# bias corrections

bias_correct1 = 1. - beta1 ** steps
bias_correct2 = 1. - beta2 ** steps

# decay running averages

exp_avg.lerp_(grad, 1. - beta1)
exp_avg_sq.lerp_(grad * grad, 1. - beta2)

# the following line is the proposed change to the update rule
# using atan2 instead of a division with epsilon in denominator
# a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))
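                # for contrast (illustrative, not part of this commit), stock Adam computes
                #   update = (exp_avg / bias_correct1) / (sqrt(exp_avg_sq / bias_correct2) + eps)
                # atan2 keeps the update bounded and removes the epsilon hyperparameter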

den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
update = exp_avg.mul(1. / bias_correct1).atan2_(den)

# update parameters

p.add_(update, alpha = -lr * a)

# increment steps

state['steps'] = steps

return loss
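
For completeness, a hedged usage sketch of the new module: the import path and constructor arguments are taken from this diff, while the model and training loop are illustrative only.

```python
import torch
from adam_atan2_pytorch.adam_atan2_with_wasserstein_reg import AdamAtan2

model = torch.nn.Linear(512, 256)

# a nonzero regen_reg_rate stores the sorted initial weights and, on every step,
# nudges each weight toward the initial value of the same rank
opt = AdamAtan2(model.parameters(), lr = 1e-4, regen_reg_rate = 1e-4)

for _ in range(100):
    loss = model(torch.randn(2, 512)).sum()
    loss.backward()

    opt.step()
    opt.zero_grad()
```
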
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "adam-atan2-pytorch"
version = "0.0.12"
version = "0.0.14"
description = "Adam-atan2 for Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
