From 017757951a3be63328326e7ff24a3829f97b3bff Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 22 Nov 2024 06:14:37 -0800
Subject: [PATCH] facepalm

---
 adam_atan2_pytorch/adopt.py       | 3 +--
 adam_atan2_pytorch/adopt_atan2.py | 3 +--
 pyproject.toml                    | 2 +-
 3 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/adam_atan2_pytorch/adopt.py b/adam_atan2_pytorch/adopt.py
index 9137d11..1796af8 100644
--- a/adam_atan2_pytorch/adopt.py
+++ b/adam_atan2_pytorch/adopt.py
@@ -97,8 +97,7 @@ def step(
 
         next_m = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon
 
-        if steps > 1:
-            m.lerp_(next_m, 1. - beta1)
+        m.lerp_(next_m, 1. - (beta1 * int(steps > 1)))
 
         # then update parameters
 
diff --git a/adam_atan2_pytorch/adopt_atan2.py b/adam_atan2_pytorch/adopt_atan2.py
index e17ea38..350b827 100644
--- a/adam_atan2_pytorch/adopt_atan2.py
+++ b/adam_atan2_pytorch/adopt_atan2.py
@@ -100,8 +100,7 @@ def step(
 
         next_m = grad.atan2(b * v.sqrt())
 
-        if steps > 1:
-            m.lerp_(next_m, 1. - beta1)
+        m.lerp_(next_m, 1. - (beta1 * int(steps > 1)))
 
         # then update parameters
 
diff --git a/pyproject.toml b/pyproject.toml
index 9726487..0ea5a11 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.1.6"
+version = "0.1.7"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
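
Note (editor's sketch, not part of the patch): the change folds the `if steps > 1` guard into the lerp weight. Assuming m is zero-initialized, which the removed guard suggests, the old code left m unseeded on step 1, so the first update ran with m still at zero; the new weight evaluates to 1. on step 1, making lerp_ copy next_m into m outright, and to the usual EMA weight 1 - beta1 on every later step. A minimal standalone illustration, with names taken from the diff and tensor values made up:

    import torch

    beta1 = 0.9
    m = torch.zeros(3)                       # first moment, assumed zero-initialized
    next_m = torch.tensor([0.5, -1.0, 2.0])  # stand-in for grad / max(sqrt(v), eps)

    for steps in (1, 2):
        weight = 1. - (beta1 * int(steps > 1))  # 1. on step 1, 1 - beta1 afterwards
        m.lerp_(next_m, weight)                 # step 1: m <- next_m; later: EMA update
        print(steps, weight, m)

The branchless form also keeps the update a single in-place tensor op on every step, which is why it reads more cleanly than the conditional it replaces.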