diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py
index 7c2c38ebd..12db7ee2e 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py
@@ -71,7 +71,6 @@ def get_default_configs():
   model.loc_spec_channels = 0
   model.num_scales = 1
   model.ema_rate = 0.9999
-  model.ema_disabled = False
   model.dropout = 0.1
   model.embedding_type = 'fourier'
 
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py
index 50432dd1f..fc5ce7c6d 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py
@@ -43,7 +43,7 @@ def get_config():
   # model
   model = config.model
   model.name = 'det_cunet'
-  model.ema_disabled = True
+  model.ema_rate = 1  # basically disables EMA
 
   # optimizer
   optim = config.optim
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py
index d032d3175..2cde45932 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py
@@ -43,7 +43,6 @@ def get_config():
   # model
   model = config.model
   model.name = 'det_cunet'
-  model.ema_disabled = False
 
   # optimizer
   optim = config.optim
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py
index 3bb26ae33..dd57fa3c1 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py
@@ -12,7 +12,7 @@ class ExponentialMovingAverage:
   Maintains (exponential) moving average of a set of parameters.
   """
 
-  def __init__(self, parameters, decay, use_num_updates=True, disable_update=False):
+  def __init__(self, parameters, decay, use_num_updates=True):
     """
     Args:
       parameters: Iterable of `torch.nn.Parameter`; usually the result of
@@ -28,7 +28,6 @@ def __init__(self, parameters, decay, use_num_updates=True, disable_update=False):
     self.shadow_params = [p.clone().detach()
                           for p in parameters if p.requires_grad]
     self.collected_params = []
-    self.disable_update = disable_update
 
   def update(self, parameters):
     """
@@ -41,7 +40,10 @@ def update(self, parameters):
       parameters: Iterable of `torch.nn.Parameter`; usually the same set of
         parameters used to initialize this object.
     """
-    if not self.disable_update:
+    if self.decay == 1:
+      # complete decay so just maintain a copy of the parameters
+      self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
+    else:
       decay = self.decay
       if self.num_updates is not None:
         self.num_updates += 1
@@ -51,9 +53,7 @@ def update(self, parameters):
       parameters = [p for p in parameters if p.requires_grad]
       for s_param, param in zip(self.shadow_params, parameters):
         s_param.sub_(one_minus_decay * (s_param - param))
-    else:
-      # if disabled then just maintain a copy of the parameters
-      self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
+
 
   def copy_to(self, parameters):
     """
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
index f391fab95..e185a675d 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -124,7 +124,7 @@ def train(config, workdir):
   location_params = LocationParams(config.model.loc_spec_channels, config.data.image_size)
   location_params = location_params.to(config.device)
   location_params = torch.nn.DataParallel(location_params)
-  ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), decay=config.model.ema_rate, disable_update=config.model.ema_disabled)
+  ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), decay=config.model.ema_rate)
   optimizer = losses.get_optimizer(config, itertools.chain(score_model.parameters(), location_params.parameters()))
   state = dict(optimizer=optimizer, model=score_model, location_params=location_params, ema=ema, step=0, epoch=0)
 