From f28b04c35ed9b817d9d3c9a49262498f6f6fa686 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 8 Aug 2024 11:20:43 +0100
Subject: [PATCH] fix ema params not changing at all if disabled

When EMA is disabled, the shadow params should just be a copy of the
model params. Previously they were never updated at all, so they stayed
as the original init model params (i.e. random).
---
 src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py
index 59d199282..3bb26ae33 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py
@@ -51,6 +51,9 @@ def update(self, parameters):
         parameters = [p for p in parameters if p.requires_grad]
         for s_param, param in zip(self.shadow_params, parameters):
           s_param.sub_(one_minus_decay * (s_param - param))
+    else:
+      # if disabled then just maintain a copy of the parameters
+      self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
 
   def copy_to(self, parameters):
     """
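
For context, here is a minimal, self-contained sketch of the behaviour this
patch produces. Only the three added lines in the `else` branch come from the
diff above; the `enabled` flag, the constructor, and the `copy_to` body are
assumptions about the surrounding class, written to resemble the usual
score_sde-style ExponentialMovingAverage, not a copy of the actual file.

import torch


class ExponentialMovingAverage:
    def __init__(self, parameters, decay, enabled=True):
        self.decay = decay
        self.enabled = enabled  # assumed flag: EMA may be switched off
        # shadow params start as a detached copy of the model params
        self.shadow_params = [
            p.clone().detach() for p in parameters if p.requires_grad
        ]

    def update(self, parameters):
        if self.enabled:
            # standard EMA step:
            # shadow <- shadow - (1 - decay) * (shadow - param)
            one_minus_decay = 1.0 - self.decay
            with torch.no_grad():
                parameters = [p for p in parameters if p.requires_grad]
                for s_param, param in zip(self.shadow_params, parameters):
                    s_param.sub_(one_minus_decay * (s_param - param))
        else:
            # the fix: if disabled, keep shadow params in sync with the
            # live model params instead of leaving them at their random init
            self.shadow_params = [
                p.clone().detach() for p in parameters if p.requires_grad
            ]

    def copy_to(self, parameters):
        # copy shadow params into the model's params (e.g. before sampling)
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

With this fix, calling copy_to after training with EMA disabled writes the
current model weights into the model, rather than the stale random
initialization that the shadow params previously held forever.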