From a8e03f64583352158f2137e3710d74533ceafde3 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Tue, 13 Dec 2022 14:35:05 +0000 Subject: [PATCH 1/2] initialize loc-spec params from gaussian --- .../score_sde_pytorch/models/location_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py b/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py index f7ebd7f72..dd856bce5 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py @@ -5,7 +5,7 @@ class LocationParams(torch.nn.Module): def __init__(self, n_channels, size) -> None: super().__init__() - self.params = torch.nn.Parameter(torch.zeros(n_channels, size, size)) + self.params = torch.nn.Parameter(torch.randn(n_channels, size, size)) def forward(self, cond): batch_size = cond.shape[0] From e1b866dba88c9a0eed0c0458d5991c9055550294 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Mon, 4 Mar 2024 12:58:40 +0000 Subject: [PATCH 2/2] use He initialization --- .../score_sde_pytorch/models/location_params.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py b/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py index dd856bce5..ff4585e7c 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/models/location_params.py @@ -5,7 +5,11 @@ class LocationParams(torch.nn.Module): def __init__(self, n_channels, size) -> None: super().__init__() - self.params = torch.nn.Parameter(torch.randn(n_channels, size, size)) + # He initialization of weights + tensor = torch.randn(n_channels, size, size) + torch.nn.init.kaiming_normal_(tensor, mode="fan_out") + self.params = torch.nn.Parameter(tensor) + def forward(self, cond): batch_size = cond.shape[0]