Skip to content

Commit

Permalink
use He initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Mar 4, 2024
1 parent cec39d3 commit 291f44e
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ class LocationParams(torch.nn.Module):
def __init__(self, n_channels, size) -> None:
super().__init__()

self.params = torch.nn.Parameter(torch.randn(n_channels, size, size))
# He initialization of weights
tensor = torch.empty(n_channels, size, size)
torch.nn.init.kaiming_normal_(tensor.weight)
self.params = torch.nn.Parameter(tensor)

def forward(self, cond):
batch_size = cond.shape[0]
Expand Down

0 comments on commit 291f44e

Please sign in to comment.