Skip to content

Commit

Permalink
Merge pull request #37 from henryaddison/randn-init-loc-params
Browse files Browse the repository at this point in the history
Random initization of location-specific parameters
  • Loading branch information
henryaddison authored Sep 18, 2024
2 parents 2bae4fe + e1b866d commit 5d1c28b
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ class LocationParams(torch.nn.Module):
def __init__(self, n_channels, size) -> None:
super().__init__()

self.params = torch.nn.Parameter(torch.zeros(n_channels, size, size))
# He initialization of weights
tensor = torch.randn(n_channels, size, size)
torch.nn.init.kaiming_normal_(tensor, mode="fan_out")
self.params = torch.nn.Parameter(tensor)


def forward(self, cond):
batch_size = cond.shape[0]
Expand Down

0 comments on commit 5d1c28b

Please sign in to comment.