-
Notifications
You must be signed in to change notification settings - Fork 0
/
module.py
358 lines (288 loc) · 11.5 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class SpectralNorm:
    r"""
    Spectral Normalization for GANs (Miyato 2018).

    Mixin that spectrally normalizes the ``weight`` of the ``nn.Module`` it is
    combined with. The largest singular value of the weight (reshaped to
    ``(out_dim, -1)``) is estimated by power iteration, following Algorithm 1
    of Appendix A (Miyato 2018). The host class must call ``nn.Module.__init__``
    before ``SpectralNorm.__init__``, because buffers are registered here.

    Attributes:
        n_dim (int): Number of dimensions (rows of the reshaped weight).
        num_iters (int): Number of power-iteration steps per call.
        eps (float): Epsilon for zero division tolerance when normalizing.
    """

    def __init__(self, n_dim, num_iters=1, eps=1e-12):
        self.num_iters = num_iters
        self.eps = eps
        # Persistent state: the running left singular vector estimate (a row
        # vector) and the sigma from the last training-mode update.
        self.register_buffer("sn_u", torch.randn(1, n_dim))
        self.register_buffer("sn_sigma", torch.ones(1))

    @property
    def u(self):
        """Running left singular vector estimate, shape (1, n_dim)."""
        return self.sn_u

    @property
    def sigma(self):
        """Spectral norm stored by the most recent training-mode update."""
        return self.sn_sigma

    def _power_iteration(self, W, u, num_iters, eps=1e-12):
        """
        Run ``num_iters`` power-iteration steps on the 2-D matrix ``W``.

        Returns ``(sigma, u, v)``, where ``sigma = u W v^T`` has shape (1, 1).
        """
        # The vector refinements are constants with respect to autograd.
        with torch.no_grad():
            for _ in range(num_iters):
                v = F.normalize(u.matmul(W), eps=eps)
                u = F.normalize(v.matmul(W.t()), eps=eps)

        # Computed outside no_grad so gradients flow into W; without this the
        # normalized weights would never be updated.
        sigma = u.mm(W.mm(v.t()))
        return sigma, u, v

    def sn_weights(self):
        r"""
        Return the layer weight divided by its estimated spectral norm.

        In training mode the stored ``sigma`` and ``u`` buffers are refreshed
        with the new estimates; in eval mode the buffers are left untouched.
        """
        flat_weight = self.weight.view(self.weight.shape[0], -1)

        sigma, refined_u, _ = self._power_iteration(
            W=flat_weight, u=self.u, num_iters=self.num_iters, eps=self.eps
        )

        if self.training:
            with torch.no_grad():
                self.sigma[:] = sigma
                self.u[:] = refined_u

        return self.weight / sigma
class SNConv2d(nn.Conv2d, SpectralNorm):
    r"""
    Spectrally normalized layer for Conv2d.

    Accepts every ``nn.Conv2d`` argument plus two optional keyword-only
    extras that are consumed here rather than forwarded to ``nn.Conv2d``:

    - ``num_iters`` (int, default 1): power-iteration steps per forward call.
    - ``eps`` (float, default 1e-12): tolerance used when normalizing vectors.

    Note: ``forward`` calls ``F.conv2d`` directly, which applies zero padding
    only, so a non-default ``padding_mode`` is not honoured.

    Attributes:
        in_channels (int): Input channel dimension.
        out_channels (int): Output channel dimensions.
    """

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        # Pop the spectral-norm options first: nn.Conv2d does not know them and
        # raises TypeError on unexpected keyword arguments.
        num_iters = kwargs.pop("num_iters", 1)
        eps = kwargs.pop("eps", 1e-12)
        nn.Conv2d.__init__(self, in_channels, out_channels, *args, **kwargs)
        SpectralNorm.__init__(
            self, n_dim=out_channels, num_iters=num_iters, eps=eps
        )

    def forward(self, x):
        return F.conv2d(
            input=x,
            weight=self.sn_weights(),
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
class SNLinear(nn.Linear, SpectralNorm):
    r"""
    Spectrally normalized layer for Linear.

    Accepts every ``nn.Linear`` argument plus two optional keyword-only
    extras that are consumed here rather than forwarded to ``nn.Linear``:

    - ``num_iters`` (int, default 1): power-iteration steps per forward call.
    - ``eps`` (float, default 1e-12): tolerance used when normalizing vectors.

    Attributes:
        in_features (int): Input feature dimensions.
        out_features (int): Output feature dimensions.
    """

    def __init__(self, in_features, out_features, *args, **kwargs):
        # Pop the spectral-norm options first: nn.Linear does not know them and
        # raises TypeError on unexpected keyword arguments.
        num_iters = kwargs.pop("num_iters", 1)
        eps = kwargs.pop("eps", 1e-12)
        nn.Linear.__init__(self, in_features, out_features, *args, **kwargs)
        SpectralNorm.__init__(
            self, n_dim=out_features, num_iters=num_iters, eps=eps
        )

    def forward(self, x):
        return F.linear(input=x, weight=self.sn_weights(), bias=self.bias)
class GBlock(nn.Module):
    r"""
    Residual block for generator.

    Residual path: BN -> ReLU -> (optional 2x upsample) -> 3x3 conv -> BN ->
    ReLU -> 3x3 conv. Shortcut path: identity, or (optional 2x upsample) ->
    1x1 conv when the channel count changes or upsampling is requested.

    Upsampling is bilinear (rather than nearest) with align_corners=False,
    matching how torchvision upsamples, as seen in:
    https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/_utils.py

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
        upsample (bool): If True, upsamples the input feature map.
    """

    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.learnable_sc = in_channels != out_channels or upsample
        self.upsample = upsample

        # Submodules are registered in a fixed order so parameter order (and
        # hence optimizer/state_dict layouts) stays stable.
        self.c1 = SNConv2d(in_channels, out_channels, 3, 1, padding=1)
        self.c2 = SNConv2d(out_channels, out_channels, 3, 1, padding=1)
        if self.learnable_sc:
            self.c_sc = SNConv2d(in_channels, out_channels, 1, 1, padding=0)
        self.b1 = nn.BatchNorm2d(in_channels)
        self.b2 = nn.BatchNorm2d(out_channels)
        # In-place is safe here: every activation input is a fresh BN output.
        self.activation = nn.ReLU(True)

        relu_gain = math.sqrt(2.0)
        nn.init.xavier_uniform_(self.c1.weight.data, relu_gain)
        nn.init.xavier_uniform_(self.c2.weight.data, relu_gain)
        if self.learnable_sc:
            nn.init.xavier_uniform_(self.c_sc.weight.data, 1.0)

    def _upsample_conv(self, x, conv):
        """Bilinearly upsample ``x`` by 2, then apply ``conv``."""
        enlarged = F.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=False
        )
        return conv(enlarged)

    def _residual(self, x):
        out = self.activation(self.b1(x))
        if self.upsample:
            out = self._upsample_conv(out, self.c1)
        else:
            out = self.c1(out)
        out = self.activation(self.b2(out))
        return self.c2(out)

    def _shortcut(self, x):
        if not self.learnable_sc:
            return x
        if self.upsample:
            return self._upsample_conv(x, self.c_sc)
        return self.c_sc(x)

    def forward(self, x):
        return self._residual(x) + self._shortcut(x)
class DBlock(nn.Module):
    """
    Residual block for discriminator.

    Residual path: ReLU -> 3x3 conv -> ReLU -> 3x3 conv -> (optional 2x2
    average pool). Shortcut path: identity, or 1x1 conv -> (optional 2x2
    average pool) when the channel count changes or downsampling is requested.
    The shortcut always sees the block input unmodified.

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
        downsample (bool): If True, downsamples the input feature map.
    """

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()
        self.downsample = downsample
        self.learnable_sc = (in_channels != out_channels) or downsample

        self.c1 = SNConv2d(in_channels, in_channels, 3, 1, 1)
        self.c2 = SNConv2d(in_channels, out_channels, 3, 1, 1)
        if self.learnable_sc:
            self.c_sc = SNConv2d(in_channels, out_channels, 1, 1, 0)

        # Must be out-of-place: _residual applies it directly to the block
        # input. An in-place ReLU would overwrite x before _shortcut reads it
        # (forward evaluates _residual first), mutate the caller's tensor, and
        # raise on leaf tensors that require grad.
        self.activation = nn.ReLU()

        nn.init.xavier_uniform_(self.c1.weight.data, math.sqrt(2.0))
        nn.init.xavier_uniform_(self.c2.weight.data, math.sqrt(2.0))
        if self.learnable_sc:
            nn.init.xavier_uniform_(self.c_sc.weight.data, 1.0)

    def _residual(self, x):
        h = self.activation(x)
        h = self.c1(h)
        h = self.activation(h)
        h = self.c2(h)
        if self.downsample:
            h = F.avg_pool2d(h, 2)
        return h

    def _shortcut(self, x):
        if self.learnable_sc:
            x = self.c_sc(x)
            x = F.avg_pool2d(x, 2) if self.downsample else x
        return x

    def forward(self, x):
        return self._residual(x) + self._shortcut(x)
class DBlockOptimized(nn.Module):
    """
    Optimized residual block for discriminator.

    Used as the first residual block, where downsampling always happens.
    Follows the official SNGAN reference implementation in chainer.
    Residual path: 3x3 conv -> ReLU -> 3x3 conv -> 2x2 average pool.
    Shortcut path: 2x2 average pool -> 1x1 conv. No activation is applied to
    the raw block input.

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.c1 = SNConv2d(in_channels, out_channels, 3, 1, 1)
        self.c2 = SNConv2d(out_channels, out_channels, 3, 1, 1)
        self.c_sc = SNConv2d(in_channels, out_channels, 1, 1, 0)
        # In-place is safe: the only activation input is a fresh conv output.
        self.activation = nn.ReLU(True)

        relu_gain = math.sqrt(2.0)
        for conv, gain in ((self.c1, relu_gain), (self.c2, relu_gain), (self.c_sc, 1.0)):
            nn.init.xavier_uniform_(conv.weight.data, gain)

    def _residual(self, x):
        hidden = self.activation(self.c1(x))
        return F.avg_pool2d(self.c2(hidden), 2)

    def _shortcut(self, x):
        pooled = F.avg_pool2d(x, 2)
        return self.c_sc(pooled)

    def forward(self, x):
        return self._residual(x) + self._shortcut(x)
# FastGAN modules
class UpSample(nn.Module):
    r"""
    Upsample block for FastGAN generator.

    Pipeline: 2x nearest-neighbour upsample -> spectrally normalized 3x3 conv
    (to ``2 * out_channels``, no bias) -> BatchNorm -> GLU over channels. The
    GLU halves the channel dimension, so the output has ``out_channels``
    channels and twice the input spatial size.

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        gated_channels = out_channels * 2  # GLU splits these into value/gate halves
        layers = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            spectral_norm(nn.Conv2d(in_channels, gated_channels, 3, 1, 1, bias=False)),
            nn.BatchNorm2d(gated_channels),
            nn.GLU(dim=1),
        ]
        # Kept as a single Sequential under this name so state_dict keys match.
        self._upsample = nn.Sequential(*layers)

    def forward(self, x):
        return self._upsample(x)
class SLEBlock(nn.Module):
    r"""
    Skip-layer Excitation block for FastGAN generator.

    Squeezes the low-resolution feature map ``f_low`` to 4x4, maps it through
    a spectrally normalized 4x4 conv (giving 1x1 spatial size), SiLU, a 1x1
    conv, and a sigmoid. The resulting per-channel gates in (0, 1) multiply the
    high-resolution feature map ``f_hi``.

    Attributes:
        ch_in (int): Channel size of ``f_low``.
        ch_out (int): Channel size of ``f_hi`` (and of the output).
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        squeeze = nn.AdaptiveAvgPool2d(4)
        reduce_conv = spectral_norm(nn.Conv2d(ch_in, ch_out, 4, 1, 0, bias=False))
        mix_conv = spectral_norm(nn.Conv2d(ch_out, ch_out, 1, 1, 0, bias=False))
        # Single Sequential under this name so state_dict keys match.
        self._sle = nn.Sequential(squeeze, reduce_conv, nn.SiLU(), mix_conv, nn.Sigmoid())

    def forward(self, f_low, f_hi):
        gates = self._sle(f_low)
        return f_hi * gates
class InitBlock(nn.Module):
    r"""
    Initial block for the FastGAN generator.

    Reshapes the input to ``(batch, nz, 1, 1)``, then applies a spectrally
    normalized 4x4 transposed conv (to ``2 * channel``, no bias), BatchNorm,
    and a channel-wise GLU. The output has shape ``(batch, channel, 4, 4)``.

    Attributes:
        nz (int): Number of input features per sample (presumably the latent
            dimension; every sample must flatten to exactly ``nz`` values).
        channel (int): Channel size of the output feature map.
    """

    def __init__(self, nz, channel):
        super().__init__()
        self._init = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(nz, channel*2, 4, 1, 0, bias=False)),
            nn.BatchNorm2d(channel*2),
            nn.GLU(dim=1)
        )

    def forward(self, z):
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced latents, are accepted; it only copies when it has to.
        z = z.reshape(z.shape[0], -1, 1, 1)
        return self._init(z)
class DownSample(nn.Module):
    r"""
    Downsample block for FastGAN discriminator.

    Averages two branches that each halve the spatial size:

    - ``downsample_1``: 4x4 stride-2 conv -> BN -> LeakyReLU(0.2) ->
      3x3 conv -> BN -> LeakyReLU(0.2).
    - ``downsample_2``: 2x2 average pool -> 1x1 conv -> BN -> LeakyReLU(0.2).

    All convs are spectrally normalized and bias-free.

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # In-place LeakyReLU is safe: each one follows a fresh BatchNorm output.
        main_branch = [
            spectral_norm(nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.downsample_1 = nn.Sequential(*main_branch)

        pooled_branch = [
            nn.AvgPool2d(2, 2),
            spectral_norm(nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.downsample_2 = nn.Sequential(*pooled_branch)

    def forward(self, x):
        branch_sum = self.downsample_1(x) + self.downsample_2(x)
        return 0.5 * branch_sum
class FastGANDecoder(nn.Module):
    r"""
    Simple decoder block for FastGAN discriminator.

    Pools the input feature map to 8x8, then applies four ``UpSample`` stages
    with 128, 64, 64 and 32 channels. A final spectrally normalized 3x3 conv to
    ``out_channels`` follows, then ``tanh``.

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Channel width per output resolution: multiplier * base width 32.
        width_multipliers = {16: 4, 32: 2, 64: 2, 128: 1}
        nfc = {res: int(mult * 32) for res, mult in width_multipliers.items()}

        stages = [
            nn.AdaptiveAvgPool2d(8),
            UpSample(in_channels, nfc[16]),
            UpSample(nfc[16], nfc[32]),
            UpSample(nfc[32], nfc[64]),
            UpSample(nfc[64], nfc[128]),
            spectral_norm(nn.Conv2d(nfc[128], out_channels, 3, 1, 1, bias=False)),
            nn.Tanh(),
        ]
        # Single Sequential under this name so state_dict keys match.
        self._decoder = nn.Sequential(*stages)

    def forward(self, x):
        return self._decoder(x)