@@ -54,12 +54,14 @@ def __init__(self, in_channels=3, out_channels=3):
54
54
self .down4 = UNetDown (256 , 512 , dropout = 0.5 )
55
55
self .down5 = UNetDown (512 , 512 , dropout = 0.5 )
56
56
self .down6 = UNetDown (512 , 512 , dropout = 0.5 )
57
+ self .down7 = UNetDown (512 , 512 , dropout = 0.5 )
57
58
58
- self .up1 = UNetUp (512 , 512 , dropout = 0.5 )
59
+ self .up1 = UNetUp (512 , 512 , dropout = 0.8 )
59
60
self .up2 = UNetUp (1024 , 512 , dropout = 0.5 )
60
- self .up3 = UNetUp (1024 , 256 , dropout = 0.5 )
61
- self .up4 = UNetUp (512 , 128 )
62
- self .up5 = UNetUp (256 , 64 )
61
+ self .up3 = UNetUp (1024 , 512 , dropout = 0.5 )
62
+ self .up4 = UNetUp (1024 , 256 , dropout = 0.5 )
63
+ self .up5 = UNetUp (512 , 128 )
64
+ self .up6 = UNetUp (256 , 64 )
63
65
64
66
65
67
final = [ nn .Upsample (scale_factor = 2 ),
@@ -75,14 +77,19 @@ def forward(self, x):
75
77
d4 = self .down4 (d3 )
76
78
d5 = self .down5 (d4 )
77
79
d6 = self .down6 (d5 )
78
- u1 = self .up1 (d6 , d5 )
79
- u2 = self .up2 (u1 , d4 )
80
- u3 = self .up3 (u2 , d3 )
81
- u4 = self .up4 (u3 , d2 )
82
- u5 = self .up5 (u4 , d1 )
80
+ d7 = self .down7 (d6 )
81
+ u1 = self .up1 (d7 , d6 )
82
+ u2 = self .up2 (u1 , d5 )
83
+ u3 = self .up3 (u2 , d4 )
84
+ u4 = self .up4 (u3 , d3 )
85
+ u5 = self .up5 (u4 , d2 )
86
+ u6 = self .up6 (u5 , d1 )
83
87
84
- return self .final (u5 )
88
+ return self .final (u6 )
85
89
90
+ ##############################
91
+ # Discriminator
92
+ ##############################
86
93
87
94
class Discriminator (nn .Module ):
88
95
def __init__ (self , img_shape ):
@@ -104,12 +111,5 @@ def block(in_features, out_features, normalization=True):
104
111
nn .Conv2d (512 , 1 , 3 , 1 , 1 )
105
112
)
106
113
107
- #in_size = img_shape[1] // 2**4
108
- #self.output_layer = nn.Sequential(nn.Linear(512*in_size**2, 1))
109
-
110
114
def forward (self , img ):
111
- # feature_repr = self.model(img)
112
- # feature_repr = feature_repr.view(feature_repr.size(0), -1)
113
- # validity = self.output_layer(feature_repr)
114
-
115
115
return self .model (img )
0 commit comments