From 8f14cf50a89030edbdea5214d22ac268511b400e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 26 Nov 2024 17:26:38 -0800
Subject: [PATCH] add the proposed cautious optimizer from
 https://arxiv.org/abs/2411.16085, but allow for attenuating unaligned updates
 with some factor instead of zeroing completely

---
 README.md                         |  9 +++++++++
 adam_atan2_pytorch/adam_atan2.py  | 12 +++++++++++-
 adam_atan2_pytorch/adopt_atan2.py | 17 ++++++++++++++---
 pyproject.toml                    |  2 +-
 4 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index e1235b5..6544294 100644
--- a/README.md
+++ b/README.md
@@ -81,3 +81,12 @@ for _ in range(100):
     url = {https://api.semanticscholar.org/CorpusID:273822148}
 }
 ```
+
+```bibtex
+@inproceedings{Liang2024CautiousOI,
+    title = {Cautious Optimizers: Improving Training with One Line of Code},
+    author = {Kaizhao Liang and Lizhang Chen and Bo Liu and Qiang Liu},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274234738}
+}
+```
diff --git a/adam_atan2_pytorch/adam_atan2.py b/adam_atan2_pytorch/adam_atan2.py
index ce14367..d6fd352 100644
--- a/adam_atan2_pytorch/adam_atan2.py
+++ b/adam_atan2_pytorch/adam_atan2.py
@@ -21,6 +21,7 @@ def __init__(
         weight_decay = 0.,
         regen_reg_rate = 0.,
         decoupled_wd = False,
+        cautious_factor = 1., # set to 0. for zeroing out any updates not in same direction as gradient as in https://arxiv.org/abs/2411.16085
         a = 1.27,
         b = 1.
     ):
@@ -29,6 +30,7 @@ def __init__(
         assert weight_decay >= 0.
         assert regen_reg_rate >= 0.
         assert not (weight_decay > 0. and regen_reg_rate > 0.)
+        assert 0. <= cautious_factor <= 1.
 
         self._init_lr = lr
         self.decoupled_wd = decoupled_wd
@@ -40,6 +42,7 @@ def __init__(
             b = b,
             weight_decay = weight_decay,
             regen_reg_rate = regen_reg_rate,
+            cautious_factor = cautious_factor
         )
 
         super().__init__(params, defaults)
@@ -58,7 +61,7 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
-                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
 
                 # maybe decoupled weight decay
 
@@ -109,6 +112,13 @@ def step(
                 den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
                 update = exp_avg.mul(1. / bias_correct1).atan2_(den)
 
+                # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085
+
+                if cautious_factor < 1.:
+                    align_mask = (update * grad) > 0
+                    scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
+                    update *= (scale / scale.mean().clamp(min = 1e-5))
+
                 # update parameters
 
                 p.add_(update, alpha = -lr * a)
diff --git a/adam_atan2_pytorch/adopt_atan2.py b/adam_atan2_pytorch/adopt_atan2.py
index 9206f4e..716a227 100644
--- a/adam_atan2_pytorch/adopt_atan2.py
+++ b/adam_atan2_pytorch/adopt_atan2.py
@@ -28,6 +28,7 @@ def __init__(
         weight_decay = 0.,
         regen_reg_rate = 0.,
         decoupled_wd = True,
+        cautious_factor = 1., # set to 0. for zeroing out any updates not in same direction as gradient as in https://arxiv.org/abs/2411.16085
         a = 1.27,
         b = 1.
     ):
@@ -45,7 +46,8 @@ def __init__(
             a = a,
             b = b,
             weight_decay = weight_decay,
-            regen_reg_rate = regen_reg_rate
+            regen_reg_rate = regen_reg_rate,
+            cautious_factor = cautious_factor
         )
 
         super().__init__(params, defaults)
@@ -64,7 +66,7 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
-                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
 
                 # maybe decoupled weight decay
 
@@ -110,9 +112,18 @@ def step(
 
                 m.lerp_(update, 1. - beta1)
 
+                # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085
+
+                scale = 1.
+
+                if cautious_factor < 1.:
+                    align_mask = (update * grad) > 0
+                    scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
+                    scale /= scale.mean().clamp(min = 1e-5)
+
                 # then update parameters
 
-                p.add_(m, alpha = -lr * a)
+                p.add_(m * scale, alpha = -lr * a)
 
                 # update exp grad sq (v)
 
diff --git a/pyproject.toml b/pyproject.toml
index 1cd3fa4..5ac4322 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.1.16"
+version = "0.1.18"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
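A minimal usage sketch of the new `cautious_factor` argument added by this patch. It assumes the package's top-level export is `AdamAtan2` (as in the repository README); the model, learning rate, and data below are placeholders, not part of the patch.

```python
import torch
from adam_atan2_pytorch import AdamAtan2  # assumed top-level export, per the repository README

model = torch.nn.Linear(16, 16)

# cautious_factor = 0. zeroes any update component whose sign disagrees with the gradient
# (algorithm 2 of https://arxiv.org/abs/2411.16085); values in (0., 1.) only attenuate those
# components; the default 1. leaves the update rule unchanged
optim = AdamAtan2(model.parameters(), lr = 1e-4, cautious_factor = 0.5)

loss = model(torch.randn(4, 16)).sum()
loss.backward()

optim.step()
optim.zero_grad()
```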