diff --git a/models/simsiam.py b/models/simsiam.py index 7db8269..c231f6a 100644 --- a/models/simsiam.py +++ b/models/simsiam.py @@ -39,7 +39,7 @@ def __init__(self, in_dim, hidden_dim=2048, out_dim=2048): ) self.layer3 = nn.Sequential( nn.Linear(hidden_dim, out_dim), - nn.BatchNorm1d(hidden_dim) + nn.BatchNorm1d(out_dim) ) self.num_layers = 3 def set_layers(self, num_layers):