diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py b/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py index dd856bce5..ff4585e7c 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py @@ -5,7 +5,11 @@ class LocationParams(torch.nn.Module): def __init__(self, n_channels, size) -> None: super().__init__() - self.params = torch.nn.Parameter(torch.randn(n_channels, size, size)) + # He initialization of weights + tensor = torch.randn(n_channels, size, size) + torch.nn.init.kaiming_normal_(tensor, mode="fan_out") + self.params = torch.nn.Parameter(tensor) + def forward(self, cond): batch_size = cond.shape[0]