#2 #33
Regarding the issues above, I think the following lines are the proper way to replace the original implementation with group normalization.
The second function is the official implementation of the SGE block, slightly modified to align with PyTorch's GroupNorm.
# In __init__: a single GroupNorm over a single channel
self.gn = nn.GroupNorm(1, 1)

def forward(self, x):
    b, c, h, w = x.size()
    x = x.view(b * self.groups, -1, h, w)
    xn = x * self.avg_pool(x)
    # Reduce the weighted channels in each group to obtain one attention map per group
    xn = xn.sum(dim=1, keepdim=True)
    # xn already has shape (b * groups, 1, h, w); this view is a no-op kept for clarity
    xn = xn.view(b * self.groups, -1, h, w)
    # GroupNorm(1, 1) normalizes each map over its full (1, h, w) extent
    t = self.gn(xn)
    x = x * self.sig(t.view(b * self.groups, 1, h, w))
    x = x.view(b, c, h, w)
    return x
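The reason nn.GroupNorm(1, 1) is the right drop-in: on an input of shape (b * groups, 1, h, w) it treats each sample as one group of one channel, so it normalizes every sample over its entire (1, h, w) extent with a biased variance estimate, which is exactly the per-group statistic computed by hand in oforward below. A minimal standalone check of this claim (shapes chosen arbitrarily):

import torch
import torch.nn as nn

xn = torch.rand(8, 1, 21, 21)   # stands in for (b * groups, 1, h, w)
gn = nn.GroupNorm(1, 1)         # default eps=1e-5, weight=1, bias=0

# Manual normalization over each sample's flattened (1, h, w) extent
t = xn.view(8, -1)
manual = (t - t.mean(dim=1, keepdim=True)) / torch.sqrt(
    t.var(dim=1, keepdim=True, unbiased=False) + 1e-5)
manual = manual.view(8, 1, 21, 21)

print((gn(xn) - manual).abs().max().item())  # tiny float residual, ~1e-7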
def oforward(self, x):
    b, c, h, w = x.size()
    x = x.view(b * self.groups, -1, h, w)
    xn = x * self.avg_pool(x)
    # Reduce the weighted channels in each group to obtain one attention map per group
    # (this reduction is not part of GN itself)
    xn = xn.sum(dim=1, keepdim=True)
    # Flatten the spatial dimensions within each group
    t = xn.view(b * self.groups, -1)
    # I think we should use the std of the original t, not the std of the
    # mean-subtracted t
    var = t.var(dim=1, keepdim=True, unbiased=False)
    t = (t - t.mean(dim=1, keepdim=True)) / torch.sqrt(var + self.eps)
    t = t.view(b, self.groups, h, w)
    t = t * self.weight + self.bias
    t = t.view(b * self.groups, 1, h, w)
    x = x * self.sig(t)
    x = x.view(b, c, h, w)
    return x
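For reference, since the two methods above refer to several attributes, here is a minimal constructor sketch they assume. The attribute names (avg_pool, sig, weight, bias, eps, gn) come from the snippets themselves; the SGE(groups, channels) signature is inferred from the test code below, and the weight/bias initialization is an assumption chosen so that oforward lines up with the default GroupNorm affine parameters (the official code initializes them to 0 and 1 instead).

import torch
import torch.nn as nn

class SGE(nn.Module):
    def __init__(self, groups, channels, eps=1e-5):
        super().__init__()
        # channels is accepted to match the test call but is not needed below
        self.groups = groups
        self.eps = eps
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global pooling per group
        self.sig = nn.Sigmoid()
        # Per-group affine parameters used by oforward; set to 1 and 0 here
        # (an assumption) so they match nn.GroupNorm's default weight and bias
        self.weight = nn.Parameter(torch.ones(1, groups, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.gn = nn.GroupNorm(1, 1)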
Following is the testing code; it prints an average maximum difference of 4.3839216232299807e-07:
import torch

running_sum = 0
for _ in range(100):
    t = torch.rand(32, 512, 21, 21)
    m = SGE(64, 512)  # number of groups and input channels
    running_sum += (m.forward(t) - m.oforward(t)).max().item()
print("The average maximum difference between the tensors is:", running_sum / 100)