
Commit

address #33
lucidrains committed Nov 21, 2024
1 parent dee87fb commit 7f6cd75
Showing 3 changed files with 22 additions and 5 deletions.
10 changes: 8 additions & 2 deletions ema_pytorch/ema_pytorch.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+from typing import Callable

from copy import deepcopy
from functools import partial
@@ -64,7 +65,7 @@ class EMA(Module):
def __init__(
self,
model: Module,
-ema_model: Module | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
+ema_model: Module | Callable[[], Module] | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
beta = 0.9999,
update_after_step = 100,
update_every = 10,
@@ -82,7 +83,7 @@ def __init__(
forward_method_names: tuple[str, ...] = (),
move_ema_to_online_device = False,
coerce_dtype = False,
-lazy_init_ema = False
+lazy_init_ema = False,
):
super().__init__()
self.beta = beta
@@ -98,6 +99,11 @@ def __init__(
else:
self.online_model = [model] # hack

+# handle callable returning ema module
+
+if callable(ema_model):
+    ema_model = ema_model()
+
# ema model

self.ema_model = None
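With this change, `ema_model` may be a zero-argument callable that builds the shadow copy, so models that cannot be deep-copied (for example ones containing lazy modules) no longer need a pre-constructed EMA instance. A minimal usage sketch, assuming the rest of the `EMA` constructor behaves as in the README; the `make_net` factory and the small `nn.Sequential` network are stand-ins:

import torch
from torch import nn
from ema_pytorch import EMA

def make_net():
    # hypothetical factory; stands in for any module that is awkward to deepcopy
    return nn.Sequential(nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, 512))

net = make_net()

ema = EMA(
    net,
    ema_model = make_net,        # invoked once to construct the EMA copy, instead of deepcopy(net)
    beta = 0.9999,
    update_after_step = 100,
    update_every = 10
)

out = net(torch.randn(1, 512))   # the usual training forward/backward would go here
ema.update()                     # keep the EMA copy in sync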
15 changes: 13 additions & 2 deletions ema_pytorch/post_hoc_ema.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+from typing import Callable

from pathlib import Path
from copy import deepcopy
@@ -53,7 +54,7 @@ def __init__(
model: Module,
sigma_rel: float | None = None,
gamma: float | None = None,
-ema_model: Module | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
+ema_model: Module | Callable[[], Module] | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
update_every: int = 100,
frozen: bool = False,
param_or_buffer_names_no_ema: Set[str] = set(),
@@ -74,6 +75,11 @@ def __init__(

self.online_model = [model]

+# handle callable returning ema module
+
+if callable(ema_model):
+    ema_model = ema_model()
+
# ema model

self.ema_model = ema_model
@@ -274,6 +280,7 @@ class PostHocEMA(Module):
def __init__(
self,
model: Module,
+ema_model: Callable[[], Module] | None = None,
sigma_rels: Tuple[float, ...] | None = None,
gammas: Tuple[float, ...] | None = None,
checkpoint_every_num_steps: int = 1000,
@@ -290,11 +297,13 @@ def __init__(
assert len(gammas) > 1, 'at least 2 ema models with different gammas in order to synthesize new ema models of a different gamma'
assert len(set(gammas)) == len(gammas), 'calculated gammas must be all unique'

+self.maybe_ema_model = ema_model
+
self.gammas = gammas
self.num_ema_models = len(gammas)

self._model = [model]
-self.ema_models = ModuleList([KarrasEMA(model, gamma = gamma, **kwargs) for gamma in gammas])
+self.ema_models = ModuleList([KarrasEMA(model, ema_model = ema_model, gamma = gamma, **kwargs) for gamma in gammas])

self.checkpoint_folder = Path(checkpoint_folder)
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
@@ -355,6 +364,7 @@ def synthesize_ema_model(

synthesized_ema_model = KarrasEMA(
model = self.model,
+ema_model = self.maybe_ema_model,
gamma = gamma,
**self.ema_kwargs
)
@@ -392,6 +402,7 @@ def synthesize_ema_model(

tmp_ema_model = KarrasEMA(
model = self.model,
+ema_model = self.maybe_ema_model,
gamma = gamma,
**self.ema_kwargs
)
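`PostHocEMA` now threads the same callable through to every internal `KarrasEMA`, including the ones it rebuilds inside `synthesize_ema_model`, so non-deepcopyable models also work with post-hoc synthesis. A sketch of possible usage, assuming the remaining constructor arguments (`sigma_rels`, `checkpoint_every_num_steps`, `checkpoint_folder`) and `synthesize_ema_model(sigma_rel = ...)` behave as in the README; the network factory and folder name are stand-ins:

from torch import nn
from ema_pytorch import PostHocEMA

def make_net():
    # called once per internal KarrasEMA (and again when synthesizing),
    # instead of deepcopy-ing the online model
    return nn.Sequential(nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, 512))

net = make_net()

emas = PostHocEMA(
    net,
    ema_model = make_net,
    sigma_rels = (0.05, 0.28),            # two EMA profiles are checkpointed during training
    checkpoint_every_num_steps = 10,
    checkpoint_folder = './post-hoc-ema-checkpoints'
)

# ... training loop, calling emas.update() after each optimizer step ...

# afterwards, synthesize an EMA model for a new sigma_rel from the stored checkpoints
synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)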
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ema-pytorch',
packages = find_packages(exclude=[]),
-version = '0.7.3',
+version = '0.7.4',
license='MIT',
description = 'Easy way to keep track of exponential moving average version of your pytorch module',
author = 'Phil Wang',
