add the proposed cautious optimizer from https://arxiv.org/abs/2411.16085, but allow for attenuating unaligned updates with some factor instead of zeroing completely
lucidrains committed Nov 27, 2024
1 parent 5a48ed9 commit 8f14cf5
Showing 4 changed files with 35 additions and 5 deletions.
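The attenuation described in the commit message reduces to a few lines. Below is a minimal, self-contained sketch of the generalized cautious scaling this commit adds; the helper name `cautious_update` and the example tensors are illustrative and not part of the package's API.

```python
# Illustrative sketch of the cautious scaling added in this commit (not the package's API).
import torch

def cautious_update(update: torch.Tensor, grad: torch.Tensor, cautious_factor: float = 0.) -> torch.Tensor:
    # entries where the proposed update agrees in sign with the raw gradient keep a scale of 1;
    # disagreeing entries are scaled by `cautious_factor` (0. recovers the paper's hard masking)
    align_mask = (update * grad) > 0
    scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
    # renormalize by the mean scale so the average update magnitude is preserved
    return update * (scale / scale.mean().clamp(min = 1e-5))

update, grad = torch.randn(6), torch.randn(6)
print(cautious_update(update, grad, cautious_factor = 0.5))
```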
README.md: 9 additions & 0 deletions
@@ -81,3 +81,12 @@ for _ in range(100):
url = {https://api.semanticscholar.org/CorpusID:273822148}
}
```

```bibtex
@inproceedings{Liang2024CautiousOI,
title = {Cautious Optimizers: Improving Training with One Line of Code},
author = {Kaizhao Liang and Lizhang Chen and Bo Liu and Qiang Liu},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:274234738}
}
```
adam_atan2_pytorch/adam_atan2.py: 11 additions & 1 deletion
@@ -21,6 +21,7 @@ def __init__(
weight_decay = 0.,
regen_reg_rate = 0.,
decoupled_wd = False,
cautious_factor = 1., # set to 0. to zero out any updates not aligned with the gradient, as in https://arxiv.org/abs/2411.16085
a = 1.27,
b = 1.
):
@@ -29,6 +30,7 @@ def __init__(
assert weight_decay >= 0.
assert regen_reg_rate >= 0.
assert not (weight_decay > 0. and regen_reg_rate > 0.)
assert 0. <= cautious_factor <= 1.

self._init_lr = lr
self.decoupled_wd = decoupled_wd
@@ -40,6 +42,7 @@ def __init__(
b = b,
weight_decay = weight_decay,
regen_reg_rate = regen_reg_rate,
cautious_factor = cautious_factor
)

super().__init__(params, defaults)
@@ -58,7 +61,7 @@ def step(
for group in self.param_groups:
for p in filter(lambda p: exists(p.grad), group['params']):

grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

# maybe decoupled weight decay

@@ -109,6 +112,13 @@ def step(
den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
update = exp_avg.mul(1. / bias_correct1).atan2_(den)

# maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085

if cautious_factor < 1.:
align_mask = (update * grad) > 0
scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
update *= (scale / scale.mean().clamp(min = 1e-5))

# update parameters

p.add_(update, alpha = -lr * a)
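For context, a hedged usage sketch of the new `cautious_factor` argument follows. It assumes the optimizer is exported as `AdamAtan2` from `adam_atan2_pytorch`, as in the package README; the model and data are placeholders.

```python
# Usage sketch (assumes `AdamAtan2` is the exported class; model and data are placeholders)
import torch
from torch import nn
from adam_atan2_pytorch import AdamAtan2

model = nn.Linear(16, 1)

# cautious_factor = 0.  -> zero out unaligned updates (Algorithm 2 of the paper)
# cautious_factor = 0.5 -> merely attenuate them, the relaxation this commit allows
opt = AdamAtan2(model.parameters(), lr = 1e-4, cautious_factor = 0.5)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```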
adam_atan2_pytorch/adopt_atan2.py: 14 additions & 3 deletions
@@ -28,6 +28,7 @@ def __init__(
weight_decay = 0.,
regen_reg_rate = 0.,
decoupled_wd = True,
cautious_factor = 1., # set to 0. to zero out any updates not aligned with the gradient, as in https://arxiv.org/abs/2411.16085
a = 1.27,
b = 1.
):
@@ -45,7 +46,8 @@ def __init__(
a = a,
b = b,
weight_decay = weight_decay,
regen_reg_rate = regen_reg_rate
regen_reg_rate = regen_reg_rate,
cautious_factor = cautious_factor
)

super().__init__(params, defaults)
@@ -64,7 +66,7 @@ def step(
for group in self.param_groups:
for p in filter(lambda p: exists(p.grad), group['params']):

grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

# maybe decoupled weight decay

@@ -110,9 +112,18 @@ def step(

m.lerp_(update, 1. - beta1)

# maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085

scale = 1.

if cautious_factor < 1.:
align_mask = (update * grad) > 0
scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
scale /= scale.mean().clamp(min = 1e-5)

# then update parameters

p.add_(m, alpha = -lr * a)
p.add_(m * scale, alpha = -lr * a)

# update exp grad sq (v)

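Note that in the ADOPT variant the scale is applied at the parameter update (`p.add_(m * scale, ...)`) rather than folded into the update tensor, but the renormalization is the same. A small numeric check of that renormalization, with illustrative values:

```python
# Numeric check of the mean renormalization used above (illustrative values only)
import torch

cautious_factor = 0.5
align_mask = torch.tensor([True, True, False, False])

scale = torch.where(align_mask, torch.ones(4), cautious_factor)  # tensor([1.0, 1.0, 0.5, 0.5])
scale = scale / scale.mean().clamp(min = 1e-5)                   # rescale so the mean is 1

print(scale)         # tensor([1.3333, 1.3333, 0.6667, 0.6667])
print(scale.mean())  # -> 1.0, so the average update magnitude is preserved
```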
pyproject.toml: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
[project]
name = "adam-atan2-pytorch"
version = "0.1.16"
version = "0.1.18"
description = "Adam-atan2 for Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
