Skip to content

Commit

Permalink
use decay/ema_rate to effectively disable EMA
Browse files Browse the repository at this point in the history
a rate of 1 means no EMA
this is backwards compatible unlike adding a new config attribute
  • Loading branch information
henryaddison committed Aug 9, 2024
1 parent 75fd83c commit 753b1a0
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def get_default_configs():
model.loc_spec_channels = 0
model.num_scales = 1
model.ema_rate = 0.9999
model.ema_disabled = False
model.dropout = 0.1
model.embedding_type = 'fourier'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_config():
# model
model = config.model
model.name = 'det_cunet'
model.ema_disabled = True
model.ema_rate = 1 # basically disables EMA

# optimizer
optim = config.optim
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def get_config():
# model
model = config.model
model.name = 'det_cunet'
model.ema_disabled = False

# optimizer
optim = config.optim
Expand Down
12 changes: 6 additions & 6 deletions src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class ExponentialMovingAverage:
Maintains (exponential) moving average of a set of parameters.
"""

def __init__(self, parameters, decay, use_num_updates=True, disable_update=False):
def __init__(self, parameters, decay, use_num_updates=True):
"""
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the result of
Expand All @@ -28,7 +28,6 @@ def __init__(self, parameters, decay, use_num_updates=True, disable_update=False
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
self.collected_params = []
self.disable_update = disable_update

def update(self, parameters):
"""
Expand All @@ -41,7 +40,10 @@ def update(self, parameters):
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
parameters used to initialize this object.
"""
if not self.disable_update:
if self.decay == 1:
# complete decay so just maintain a copy of the parameters
self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
else:
decay = self.decay
if self.num_updates is not None:
self.num_updates += 1
Expand All @@ -51,9 +53,7 @@ def update(self, parameters):
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
s_param.sub_(one_minus_decay * (s_param - param))
else:
# if disabled then just maintain a copy of the parameters
self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]


def copy_to(self, parameters):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def train(config, workdir):
location_params = LocationParams(config.model.loc_spec_channels, config.data.image_size)
location_params = location_params.to(config.device)
location_params = torch.nn.DataParallel(location_params)
ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), decay=config.model.ema_rate, disable_update=config.model.ema_disabled)
ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), decay=config.model.ema_rate)

optimizer = losses.get_optimizer(config, itertools.chain(score_model.parameters(), location_params.parameters()))
state = dict(optimizer=optimizer, model=score_model, location_params=location_params, ema=ema, step=0, epoch=0)
Expand Down

0 comments on commit 753b1a0

Please sign in to comment.