
Commit

add ability to update the online model with the EMA model periodically, either as a hard copy or by an interpolation factor, from a promising continual learning paper
lucidrains committed Oct 5, 2024
1 parent 8bdba16 commit f6c0771
Showing 2 changed files with 23 additions and 7 deletions.
28 changes: 22 additions & 6 deletions ema_pytorch/ema_pytorch.py
@@ -10,6 +10,9 @@
 def exists(val):
     return val is not None

+def divisible_by(num, den):
+    return (num % den) == 0
+
 def get_module_device(m: Module):
     return next(m.parameters()).device

@@ -74,6 +77,8 @@ def __init__(
         include_online_model = True, # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
         allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
         use_foreach = False,
+        update_model_with_ema_every = None, # update the model with EMA model weights every number of steps, for better continual learning https://arxiv.org/abs/2406.02596
+        update_model_with_ema_beta = 0., # amount of model weight to keep when updating to EMA (hare to tortoise)
         forward_method_names: tuple[str, ...] = (),
         move_ema_to_online_device = False,
         coerce_dtype = False,
@@ -123,6 +128,11 @@ def __init__(
         self.ignore_names = ignore_names
         self.ignore_startswith_names = ignore_startswith_names

+        # continual learning related
+
+        self.update_model_with_ema_every = update_model_with_ema_every
+        self.update_model_with_ema_beta = update_model_with_ema_beta
+
         # whether to manage if EMA model is kept on a different device

         self.allow_different_devices = allow_different_devices
@@ -216,7 +226,10 @@ def copy_params_from_ema_to_model(self):
         for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
             copy(current_buffers.data, ma_buffers.data)

-    def update_model_from_ema(self, decay):
+    def update_model_with_ema(self, decay = None):
+        if not exists(decay):
+            decay = self.update_model_with_ema_beta
+
         if decay == 0.:
             return self.copy_params_from_ema_to_model()

@@ -235,9 +248,6 @@ def update(self):
         step = self.step.item()
         self.step += 1

-        if (step % self.update_every) != 0:
-            return
-
         if not self.initted.item():
             if not exists(self.ema_model):
                 self.init_ema()
@@ -246,11 +256,17 @@
             self.initted.data.copy_(torch.tensor(True))
             return

-        if step <= self.update_after_step:
+        should_update = divisible_by(step, self.update_every)
+
+        if should_update and step <= self.update_after_step:
             self.copy_params_from_model_to_ema()
             return

-        self.update_moving_average(self.ema_model, self.model)
+        if should_update:
+            self.update_moving_average(self.ema_model, self.model)
+
+        if exists(self.update_model_with_ema_every) and divisible_by(step, self.update_model_with_ema_every):
+            self.update_model_with_ema()

     @torch.no_grad()
     def update_moving_average(self, ma_model, current_model, current_decay = None):
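For reference, a minimal usage sketch of the new option, assuming only the keyword names visible in the diff above (the model, hyperparameter values, and training loop below are placeholders):

import torch
from torch import nn
from ema_pytorch import EMA

net = nn.Linear(512, 512)

ema = EMA(
    net,
    beta = 0.9999,                        # EMA decay
    update_every = 10,                    # how often the EMA weights track the online weights
    update_model_with_ema_every = 10000,  # every 10k steps, reset the online model from the EMA model
    update_model_with_ema_beta = 0.       # 0. = hard copy; a value in (0, 1) keeps that fraction of the online weights
)

for step in range(100):
    # ... optimizer step on `net` would go here ...
    ema.update()

Leaving update_model_with_ema_every at its default of None keeps the previous behaviour, where the online model is never overwritten by the EMA weights.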
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.6.5',
+  version = '0.7.0',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
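The interpolated variant is worth spelling out: with update_model_with_ema_beta between 0 and 1, the reset keeps that fraction of the online ("hare") weights and takes the rest from the EMA ("tortoise") weights. A rough sketch of the intended per-parameter update, not the library's exact code path:

import torch
from torch import nn

@torch.no_grad()
def reset_online_from_ema(online: nn.Module, ema: nn.Module, beta: float = 0.):
    # beta = 0.    -> hard copy: online weights become the EMA weights
    # 0 < beta < 1 -> new online = beta * online + (1 - beta) * ema
    for p_online, p_ema in zip(online.parameters(), ema.parameters()):
        p_online.lerp_(p_ema, 1. - beta)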
