From 4941807a15af9ed7342126d72bb8bb43b9ea4a21 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 21 Nov 2024 14:36:40 -0800 Subject: [PATCH] manual mode for PostHocEMA , contributed by @kalekundert --- ema_pytorch/post_hoc_ema.py | 11 ++++++++--- setup.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/ema_pytorch/post_hoc_ema.py b/ema_pytorch/post_hoc_ema.py index 9c0f7ef..b8cb871 100644 --- a/ema_pytorch/post_hoc_ema.py +++ b/ema_pytorch/post_hoc_ema.py @@ -22,6 +22,9 @@ def default(val, d): def first(arr): return arr[0] +def divisible_by(num, den): + return (num % den) == 0 + def get_module_device(m: Module): return next(m.parameters()).device @@ -337,9 +340,11 @@ def update(self): for ema_model in self.ema_models: ema_model.update() - if not (self.checkpoint_every_num_steps == 'manual'): - if not (self.step.item() % self.checkpoint_every_num_steps): - self.checkpoint() + if self.checkpoint_every_num_steps == 'manual': + return + + if divisible_by(self.step.item(), self.checkpoint_every_num_steps): + self.checkpoint() def checkpoint(self): step = self.step.item() diff --git a/setup.py b/setup.py index ddfed7f..05e1251 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'ema-pytorch', packages = find_packages(exclude=[]), - version = '0.7.5', + version = '0.7.6', license='MIT', description = 'Easy way to keep track of exponential moving average version of your pytorch module', author = 'Phil Wang',