From be690280b02627ed1d159efcd69444e997bdbfed Mon Sep 17 00:00:00 2001 From: Pierre Merriaux Date: Tue, 1 Oct 2024 13:59:07 -0400 Subject: [PATCH] fix: affine gradiant propagation --- models/modules.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/models/modules.py b/models/modules.py index 73c8494..8c2283b 100644 --- a/models/modules.py +++ b/models/modules.py @@ -236,10 +236,8 @@ def __init__( def zero_init(self): torch.nn.init.zeros_(self.embedding.weight) - for layer in self.decoder: - if isinstance(layer, nn.Linear): - torch.nn.init.zeros_(layer.weight) - torch.nn.init.zeros_(layer.bias) + torch.nn.init.zeros_(self.decoder[2].weight) + torch.nn.init.zeros_(self.decoder[2].bias) def forward(self, image_infos): if "img_idx" in image_infos and not self.in_test_set: