diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py
index 7e5b430..e0d6826 100644
--- a/ema_pytorch/ema_pytorch.py
+++ b/ema_pytorch/ema_pytorch.py
@@ -10,6 +10,9 @@
 def exists(val):
     return val is not None
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def get_module_device(m: Module):
     return next(m.parameters()).device
 
@@ -74,6 +77,8 @@ def __init__(
         include_online_model = True,         # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
         allow_different_devices = False,     # if the EMA model is on a different device (say CPU), automatically move the tensor
         use_foreach = False,
+        update_model_with_ema_every = None,  # update the model with EMA model weights every number of steps, for better continual learning https://arxiv.org/abs/2406.02596
+        update_model_with_ema_beta = 0.,     # amount of model weight to keep when updating to EMA (hare to tortoise)
         forward_method_names: tuple[str, ...] = (),
         move_ema_to_online_device = False,
         coerce_dtype = False,
@@ -123,6 +128,11 @@ def __init__(
         self.ignore_names = ignore_names
         self.ignore_startswith_names = ignore_startswith_names
 
+        # continual learning related
+
+        self.update_model_with_ema_every = update_model_with_ema_every
+        self.update_model_with_ema_beta = update_model_with_ema_beta
+
         # whether to manage if EMA model is kept on a different device
 
         self.allow_different_devices = allow_different_devices
@@ -216,7 +226,10 @@ def copy_params_from_ema_to_model(self):
         for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
             copy(current_buffers.data, ma_buffers.data)
 
-    def update_model_from_ema(self, decay):
+    def update_model_with_ema(self, decay = None):
+        if not exists(decay):
+            decay = self.update_model_with_ema_beta
+
         if decay == 0.:
             return self.copy_params_from_ema_to_model()
 
@@ -235,9 +248,6 @@ def update(self):
         step = self.step.item()
         self.step += 1
 
-        if (step % self.update_every) != 0:
-            return
-
         if not self.initted.item():
             if not exists(self.ema_model):
                 self.init_ema()
@@ -246,11 +256,17 @@ def update(self):
             self.initted.data.copy_(torch.tensor(True))
             return
 
-        if step <= self.update_after_step:
+        should_update = divisible_by(step, self.update_every)
+
+        if should_update and step <= self.update_after_step:
             self.copy_params_from_model_to_ema()
             return
 
-        self.update_moving_average(self.ema_model, self.model)
+        if should_update:
+            self.update_moving_average(self.ema_model, self.model)
+
+        if exists(self.update_model_with_ema_every) and divisible_by(step, self.update_model_with_ema_every):
+            self.update_model_with_ema()
 
     @torch.no_grad()
     def update_moving_average(self, ma_model, current_model, current_decay = None):
diff --git a/setup.py b/setup.py
index b2c71d5..9bb79bc 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.6.5',
+  version = '0.7.0',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
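
For reference, below is a minimal usage sketch (not part of the diff) exercising the new continual-learning options. It assumes the EMA wrapper exported by ema_pytorch and its usual constructor/.update() API; the two update_model_with_ema_* keyword arguments and the ema_model attribute come from the diff above, everything else (layer sizes, step counts, the weight mutation standing in for an optimizer step) is illustrative.

import torch
from torch import nn
from ema_pytorch import EMA

net = nn.Linear(512, 512)

ema = EMA(
    net,
    beta = 0.9999,
    update_after_step = 100,
    update_every = 10,
    update_model_with_ema_every = 1000,  # new: every 1000 .update() calls, push EMA weights back into the online model
    update_model_with_ema_beta = 0.      # new: 0. means a straight copy from EMA (tortoise) to online model (hare)
)

for _ in range(3000):
    # stand-in for an optimizer step mutating the online model
    with torch.no_grad():
        net.weight.add_(torch.randn_like(net.weight) * 1e-3)

    ema.update()  # with this diff, also periodically calls update_model_with_ema()

# evaluate with the EMA weights
out = ema.ema_model(torch.randn(1, 512))

With update_model_with_ema_beta = 0., the new update_model_with_ema() takes the fast path and simply calls copy_params_from_ema_to_model(); per the comment in the diff, a non-zero beta instead keeps that amount of the online model's weight when updating it toward the EMA weights.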