diff --git a/models/modules.py b/models/modules.py index 73c8494..8c2283b 100644 --- a/models/modules.py +++ b/models/modules.py @@ -236,10 +236,8 @@ def __init__( def zero_init(self): torch.nn.init.zeros_(self.embedding.weight) - for layer in self.decoder: - if isinstance(layer, nn.Linear): - torch.nn.init.zeros_(layer.weight) - torch.nn.init.zeros_(layer.bias) + torch.nn.init.zeros_(self.decoder[2].weight) + torch.nn.init.zeros_(self.decoder[2].bias) def forward(self, image_infos): if "img_idx" in image_infos and not self.in_test_set: