fix ema params not changing at all if disabled
When EMA is disabled, the shadow params should simply be a copy of the model params. Because they were never updated at all, they instead stayed at the model's original init values (i.e. random).
henryaddison committed Aug 8, 2024
1 parent 9309247 commit f28b04c
Showing 1 changed file with 3 additions and 0 deletions.
src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py
@@ -51,6 +51,9 @@ def update(self, parameters):
         parameters = [p for p in parameters if p.requires_grad]
         for s_param, param in zip(self.shadow_params, parameters):
           s_param.sub_(one_minus_decay * (s_param - param))
+    else:
+      # if disabled then just maintain a copy of the parameters
+      self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
 
   def copy_to(self, parameters):
     """
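For context, here is a minimal sketch of the behaviour this commit fixes, assuming a simplified EMA class with a fixed decay and an `enabled` flag (the flag name, constructor signature, and decay handling are illustrative assumptions; only the `else` branch mirrors the diff above):

```python
import torch


class ExponentialMovingAverage:
    """Minimal sketch of an EMA helper, not the repo's exact class."""

    def __init__(self, parameters, decay=0.999, enabled=True):
        self.decay = decay
        self.enabled = enabled  # assumed flag; the real gating may differ
        # Shadow copies start out equal to the model's (random) init params.
        self.shadow_params = [
            p.clone().detach() for p in parameters if p.requires_grad
        ]

    def update(self, parameters):
        if self.enabled:
            one_minus_decay = 1.0 - self.decay
            with torch.no_grad():
                parameters = [p for p in parameters if p.requires_grad]
                for s_param, param in zip(self.shadow_params, parameters):
                    # EMA step: shadow <- decay * shadow + (1 - decay) * param
                    s_param.sub_(one_minus_decay * (s_param - param))
        else:
            # The fix: when disabled, keep the shadow params in sync with
            # the live model params. Without this branch they stayed frozen
            # at their random initial values.
            self.shadow_params = [
                p.clone().detach() for p in parameters if p.requires_grad
            ]
```

A quick usage sketch of why the old behaviour was wrong: with EMA disabled, anything that later copies `shadow_params` back into the model would have restored random weights; after the fix it restores the current weights.

```python
model = torch.nn.Linear(4, 2)
ema = ExponentialMovingAverage(model.parameters(), enabled=False)
# ... optimiser steps mutate model.parameters() in place ...
ema.update(model.parameters())
# shadow_params now equal the current model params, not the random init
```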
