diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py index 59d199282..3bb26ae33 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py @@ -51,6 +51,9 @@ def update(self, parameters): parameters = [p for p in parameters if p.requires_grad] for s_param, param in zip(self.shadow_params, parameters): s_param.sub_(one_minus_decay * (s_param - param)) + else: + # if disabled then just maintain a copy of the parameters + self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] def copy_to(self, parameters): """