add a way to disable the EMA updating
so EMA can effectively be disabled with a flag.
This is another difference between the U-Net trained deterministically on the score_sde side
and the separate deterministic training approach.

In theory a decay rate of 1 should allow this, but it is complicated by the num_updates parameter too.
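A minimal sketch of that num_updates complication (illustrative only; the schedule is the one shown in ema.py below): with use_num_updates enabled, the decay actually applied at update n is capped at (1 + n) / (10 + n), so early in training the configured rate is not what takes effect.

decay = 1.0  # configured ema_rate
for n in range(1, 4):
    effective = min(decay, (1 + n) / (10 + n))
    print(f"update {n}: effective decay = {effective:.3f}")
# update 1: effective decay = 0.182
# update 2: effective decay = 0.250
# update 3: effective decay = 0.308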
henryaddison committed Aug 9, 2024
1 parent 34369c5 commit de2d0a9
Showing 4 changed files with 19 additions and 11 deletions.
@@ -71,6 +71,7 @@ def get_default_configs():
     model.loc_spec_channels = 0
     model.num_scales = 1
     model.ema_rate = 0.9999
+    model.ema_disabled = False
     model.dropout = 0.1
     model.embedding_type = 'fourier'

@@ -43,6 +43,7 @@ def get_config():
     # model
     model = config.model
     model.name = 'det_cunet'
+    model.ema_disabled = True

     # optimizer
     optim = config.optim
25 changes: 15 additions & 10 deletions src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py
@@ -12,7 +12,7 @@ class ExponentialMovingAverage:
     Maintains (exponential) moving average of a set of parameters.
     """

-    def __init__(self, parameters, decay, use_num_updates=True):
+    def __init__(self, parameters, decay, use_num_updates=True, disable_update=False):
         """
         Args:
             parameters: Iterable of `torch.nn.Parameter`; usually the result of
@@ -28,6 +28,7 @@ def __init__(self, parameters, decay, use_num_updates=True):
         self.shadow_params = [p.clone().detach()
                               for p in parameters if p.requires_grad]
         self.collected_params = []
+        self.disable_update = disable_update

     def update(self, parameters):
         """
@@ -40,15 +41,19 @@ def update(self, parameters):
             parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                 parameters used to initialize this object.
         """
-        decay = self.decay
-        if self.num_updates is not None:
-            self.num_updates += 1
-            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
-        one_minus_decay = 1.0 - decay
-        with torch.no_grad():
-            parameters = [p for p in parameters if p.requires_grad]
-            for s_param, param in zip(self.shadow_params, parameters):
-                s_param.sub_(one_minus_decay * (s_param - param))
+        if not self.disable_update:
+            decay = self.decay
+            if self.num_updates is not None:
+                self.num_updates += 1
+                decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+            one_minus_decay = 1.0 - decay
+            with torch.no_grad():
+                parameters = [p for p in parameters if p.requires_grad]
+                for s_param, param in zip(self.shadow_params, parameters):
+                    s_param.sub_(one_minus_decay * (s_param - param))
+        else:
+            # if disabled then just maintain a copy of the parameters
+            self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]

     def copy_to(self, parameters):
         """
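A minimal usage sketch of the new flag (the model and the import path are illustrative, assuming the package layout implied by the file path above): with disable_update=True the shadow parameters are simply refreshed to a copy of the live parameters on every update, so copying them back leaves the model weights unchanged.

import torch
from ml_downscaling_emulator.score_sde_pytorch.models.ema import ExponentialMovingAverage

model = torch.nn.Linear(4, 2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.9999, disable_update=True)

# ... after optimizer.step() in a training loop ...
ema.update(model.parameters())  # shadow params become a fresh copy of the live params

# Copying the shadow params back into the model is then a no-op in effect,
# i.e. sampling uses the raw (non-averaged) weights.
ema.copy_to(model.parameters())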
3 changes: 2 additions & 1 deletion src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -124,7 +124,8 @@ def train(config, workdir):
     location_params = LocationParams(config.model.loc_spec_channels, config.data.image_size)
     location_params = location_params.to(config.device)
     location_params = torch.nn.DataParallel(location_params)
-    ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), decay=config.model.ema_rate)
+    ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), decay=config.model.ema_rate, disable_update=config.model.ema_disabled)
+
     optimizer = losses.get_optimizer(config, itertools.chain(score_model.parameters(), location_params.parameters()))
     state = dict(optimizer=optimizer, model=score_model, location_params=location_params, ema=ema, step=0, epoch=0)

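For completeness, a hypothetical sketch of how the flag flows from config to training (get_default_configs, train, and model.ema_disabled come from the hunks above; the override itself is illustrative):

config = get_default_configs()
config.model.ema_disabled = True  # opt out of EMA, as the det_cunet config above does
train(config, workdir="path/to/workdir")  # builds ExponentialMovingAverage(..., disable_update=True)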
