diff --git a/src/model/common.py b/src/model/common.py index aeee3351..d5b159d4 100644 --- a/src/model/common.py +++ b/src/model/common.py @@ -26,7 +26,7 @@ def __init__( self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True)): - m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + m = [conv(in_channels, out_channels, kernel_size, bias=(bias and not bn))] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: @@ -42,7 +42,7 @@ def __init__( super(ResBlock, self).__init__() m = [] for i in range(2): - m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + m.append(conv(n_feats, n_feats, kernel_size, bias=(bias and not bn))) if bn: m.append(nn.BatchNorm2d(n_feats)) if i == 0: @@ -63,7 +63,7 @@ def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): - m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(conv(n_feats, 4 * n_feats, 3, bias=(bias and not bn))) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feats))