-
Notifications
You must be signed in to change notification settings - Fork 2
/
complexnn.py
424 lines (360 loc) · 15.7 KB
/
complexnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def get_casual_padding1d():
    """Placeholder for a 1-D padding helper; currently unimplemented.

    The body performs no work and the function always returns ``None``.
    NOTE(review): the name presumably means "causal" -- TODO implement or remove.
    """
    return None
def get_casual_padding2d():
    """Placeholder for a 2-D padding helper; currently unimplemented.

    The body performs no work and the function always returns ``None``.
    NOTE(review): the name presumably means "causal" -- TODO implement or remove.
    """
    return None
class cPReLU(nn.Module):
    """Split-complex PReLU with separate learnable slopes for each half.

    The input is cut into two equal chunks along ``complex_axis``. The first
    chunk is treated as the real part and the second as the imaginary part.
    Each part goes through its own ``nn.PReLU``, and the results are
    concatenated back along the same axis.
    """

    def __init__(self, complex_axis=1):
        super().__init__()
        self.r_prelu = nn.PReLU()
        self.i_prelu = nn.PReLU()
        self.complex_axis = complex_axis

    def forward(self, inputs):
        axis = self.complex_axis
        real_half, imag_half = torch.chunk(inputs, 2, axis)
        activated = (self.r_prelu(real_half), self.i_prelu(imag_half))
        return torch.cat(list(activated), axis)
class NavieComplexLSTM(nn.Module):
    """Complex-valued LSTM assembled from two real ``nn.LSTM`` modules.

    ``real_lstm`` and ``imag_lstm`` play the roles of the real and imaginary
    weights, and the output is formed like a complex product::

        real_out = real_lstm(x_r) - imag_lstm(x_i)
        imag_out = real_lstm(x_i) + imag_lstm(x_r)

    If ``projection_dim`` is given, each output part then goes through its
    own ``nn.Linear``.

    Args:
        input_size: total feature size (real + imag). Each LSTM sees
            ``input_size // 2`` features.
        hidden_size: total hidden size (real + imag). Each LSTM has
            ``hidden_size // 2`` units.
        projection_dim: optional total projected size (real + imag). Each
            part is projected to ``projection_dim // 2``.
        bidirectional: passed to both LSTMs.
        batch_first: passed to both LSTMs. It was previously accepted but
            silently ignored.
    """

    def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False):
        super().__init__()
        self.input_dim = input_size // 2
        self.rnn_units = hidden_size // 2
        self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1,
                                 bidirectional=bidirectional, batch_first=batch_first)
        self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1,
                                 bidirectional=bidirectional, batch_first=batch_first)
        # A bidirectional LSTM concatenates both directions on the feature axis.
        num_directions = 2 if bidirectional else 1
        if projection_dim is not None:
            self.projection_dim = projection_dim // 2
            self.r_trans = nn.Linear(self.rnn_units * num_directions, self.projection_dim)
            self.i_trans = nn.Linear(self.rnn_units * num_directions, self.projection_dim)
        else:
            self.projection_dim = None

    def forward(self, inputs):
        """Run the complex LSTM.

        Args:
            inputs: either a ``[real, imag]`` list/tuple, or a single Tensor
                whose last axis holds ``[real | imag]``.

        Returns:
            ``[real_out, imag_out]`` as a list.

        Raises:
            TypeError: if ``inputs`` is neither a list/tuple nor a Tensor.
        """
        if isinstance(inputs, (list, tuple)):
            real, imag = inputs
        elif isinstance(inputs, torch.Tensor):
            # Split the last (feature) axis into two halves. The old code
            # passed -1 as the chunk *count*, which always raised.
            real, imag = torch.chunk(inputs, 2, -1)
        else:
            raise TypeError(
                'NavieComplexLSTM expects a Tensor or a [real, imag] list/tuple, '
                'got {}'.format(type(inputs).__name__))
        r2r_out = self.real_lstm(real)[0]
        r2i_out = self.imag_lstm(real)[0]
        i2r_out = self.real_lstm(imag)[0]
        i2i_out = self.imag_lstm(imag)[0]
        real_out = r2r_out - i2i_out
        imag_out = i2r_out + r2i_out
        if self.projection_dim is not None:
            real_out = self.r_trans(real_out)
            imag_out = self.i_trans(imag_out)
        return [real_out, imag_out]

    def flatten_parameters(self):
        """Call ``flatten_parameters`` on both underlying LSTMs."""
        self.imag_lstm.flatten_parameters()
        self.real_lstm.flatten_parameters()
def complex_cat(inputs, axis):
    """Concatenate split-complex tensors, keeping all real halves first.

    Each tensor in ``inputs`` is cut into two chunks along ``axis``: real, then
    imaginary. All real chunks are joined in input order, then all imaginary
    chunks. The two results are concatenated along ``axis``, so the output
    again has the layout ``[real | imag]``.
    """
    halves = [torch.chunk(tensor, 2, axis) for tensor in inputs]
    real_part = torch.cat([re for re, _ in halves], axis)
    imag_part = torch.cat([im for _, im in halves], axis)
    return torch.cat([real_part, imag_part], axis)
class ComplexConv2d(nn.Module):
    """Complex 2-D convolution on split real/imaginary feature maps.

    Two real ``nn.Conv2d`` layers hold the real (``real_conv``) and imaginary
    (``imag_conv``) kernels, and the output is the complex product::

        real = real_conv(x_r) - imag_conv(x_i)
        imag = imag_conv(x_r) + real_conv(x_i)

    Args:
        in_channels: real + imag input channels. Each conv sees
            ``in_channels // 2``.
        out_channels: real + imag output channels. Each conv emits
            ``out_channels // 2``.
        kernel_size: for an input ``[B, C, D, T]``, the kernel size in ``[D, T]``.
        stride, dilation, groups: passed to both convolutions.
        padding: ``[D, T]`` padding. ``D`` padding is done by the convolutions
            themselves. ``T`` padding is applied in ``forward``: left side only
            when ``causal`` is true and the ``T`` padding is non-zero,
            otherwise both sides.
        causal: see ``padding``.
        complex_axis: axis on which the real and imaginary halves are
            concatenated. With ``0``, both halves are stacked on that axis and
            each conv runs once over the whole tensor.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1,1),
        stride=(1,1),
        padding=(0,0),
        dilation=1,
        groups = 1,
        causal=True,
        complex_axis=1,
    ):
        super().__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis

        def build_conv():
            # Only the D axis is padded by the layer; T padding happens in forward().
            return nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                             padding=[self.padding[0], 0], dilation=self.dilation,
                             groups=self.groups)

        # Build both convs before re-initialising either, so the seeded RNG is
        # consumed in the same order: real build, imag build, real normal_, imag normal_.
        self.real_conv = build_conv()
        self.imag_conv = build_conv()
        for layer in (self.real_conv, self.imag_conv):
            nn.init.normal_(layer.weight.data, std=0.05)
            nn.init.constant_(layer.bias, 0.)

    def forward(self, inputs):
        t_pad = self.padding[1]
        one_sided = t_pad != 0 and self.causal
        right_pad = 0 if one_sided else t_pad
        padded = F.pad(inputs, [t_pad, right_pad, 0, 0])

        axis = self.complex_axis
        if axis == 0:
            by_real_kernel = self.real_conv(padded)
            by_imag_kernel = self.imag_conv(padded)
            real2real, imag2real = torch.chunk(by_real_kernel, 2, axis)
            real2imag, imag2imag = torch.chunk(by_imag_kernel, 2, axis)
        else:
            x_real, x_imag = torch.chunk(padded, 2, axis)
            real2real = self.real_conv(x_real)
            imag2imag = self.imag_conv(x_imag)
            real2imag = self.imag_conv(x_real)
            imag2real = self.real_conv(x_imag)

        return torch.cat([real2real - imag2imag, real2imag + imag2real], axis)
class ComplexConvTranspose2d(nn.Module):
    """Complex 2-D transposed convolution on split real/imaginary feature maps.

    ``real_conv`` and ``imag_conv`` hold the real and imaginary kernels, and
    the output is the complex product::

        real = real_conv(x_r) - imag_conv(x_i)
        imag = imag_conv(x_r) + real_conv(x_i)

    Args:
        in_channels: real + imag input channels. Each layer sees
            ``in_channels // 2``.
        out_channels: real + imag output channels. Each layer emits
            ``out_channels // 2``.
        kernel_size, stride, padding, output_padding, groups, dilation: passed
            to both ``nn.ConvTranspose2d`` layers.
        causal: stored for symmetry with ``ComplexConv2d``. ``forward`` does
            not use it.
        complex_axis: axis on which the real and imaginary halves are
            concatenated. With ``0``, the halves are stacked on that axis and
            each layer runs once over the stacked tensor.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1,1),
        stride=(1,1),
        padding=(0,0),
        output_padding=(0,0),
        causal=False,
        complex_axis=1,
        groups=1,
        dilation=1
    ):
        super().__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups
        self.causal = causal
        self.dilation = dilation
        # Construction and re-initialisation order is kept fixed so that the
        # seeded RNG produces the same initial weights.
        self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding,
                                            groups=self.groups, dilation=dilation)
        self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding,
                                            groups=self.groups, dilation=dilation)
        self.complex_axis = complex_axis
        nn.init.normal_(self.real_conv.weight, std=0.05)
        nn.init.normal_(self.imag_conv.weight, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):
        """Apply the complex transposed convolution.

        Args:
            inputs: either a Tensor with ``[real | imag]`` concatenated along
                ``complex_axis``, or a ``(real, imag)`` tuple/list.

        Returns:
            A Tensor with ``[real | imag]`` concatenated along ``complex_axis``.

        Raises:
            TypeError: if ``inputs`` is neither a Tensor nor a tuple/list.
        """
        axis = self.complex_axis
        if isinstance(inputs, torch.Tensor):
            packed = inputs
            real = imag = None
        elif isinstance(inputs, (tuple, list)):
            packed = None
            real, imag = inputs[0], inputs[1]
        else:
            raise TypeError(
                'ComplexConvTranspose2d expects a Tensor or a (real, imag) '
                'tuple/list, got {}'.format(type(inputs).__name__))

        if axis == 0:
            # The halves share axis 0. Run each kernel once over the stacked
            # tensor, then split each result back into its two halves.
            if packed is None:
                packed = torch.cat([real, imag], axis)
            by_real_kernel = self.real_conv(packed)
            by_imag_kernel = self.imag_conv(packed)
            real2real, imag2real = torch.chunk(by_real_kernel, 2, axis)
            real2imag, imag2imag = torch.chunk(by_imag_kernel, 2, axis)
        else:
            if packed is not None:
                real, imag = torch.chunk(packed, 2, axis)
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        return torch.cat([real2real - imag2imag, real2imag + imag2real], axis)
# Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
# from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55
class ComplexBatchNorm(torch.nn.Module):
    """Complex batch normalization by 2x2 covariance whitening.

    The input is split along ``complex_axis`` into real and imaginary halves.
    Each complex channel (``num_features // 2`` of them) is handled as follows:

    1. It is centred.
    2. It is whitened with the inverse square root ``U`` of its 2x2
       real/imag covariance matrix ``V``.
    3. If ``affine`` is set, the result is multiplied by a learnable symmetric
       matrix ``[[Wrr, Wri], [Wri, Wii]]`` and shifted by ``(Br, Bi)``.

    Args:
        num_features: total channels (real + imag).
        eps: added to the diagonal of ``V`` before the inverse square root.
        momentum: running-statistics update factor. ``None`` selects a
            cumulative moving average.
        affine: whether to learn ``W`` and ``B``.
        track_running_stats: whether to keep running mean/covariance buffers.
            When they are kept, eval mode normalizes with them.
        complex_axis: axis along which the real and imaginary halves are
            concatenated.

    NOTE(review): statistics are always taken per index of dim 1 (see
    ``redux``/``vdim`` in ``forward``), which presumably assumes
    ``complex_axis == 1``. Verify before using other axes.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, complex_axis=1):
        super().__init__()
        self.num_features = num_features // 2
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.complex_axis = complex_axis
        if self.affine:
            # Per-channel symmetric affine matrix and bias; values are set in reset_parameters().
            self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
        else:
            self.register_parameter('Wrr', None)
            self.register_parameter('Wri', None)
            self.register_parameter('Wii', None)
            self.register_parameter('Br', None)
            self.register_parameter('Bi', None)
        if self.track_running_stats:
            # Running mean (RMr, RMi) and running covariance entries (RVrr, RVri, RVii).
            self.register_buffer('RMr', torch.zeros(self.num_features))
            self.register_buffer('RMi', torch.zeros(self.num_features))
            self.register_buffer('RVrr', torch.ones (self.num_features))
            self.register_buffer('RVri', torch.zeros(self.num_features))
            self.register_buffer('RVii', torch.ones (self.num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('RMr', None)
            self.register_parameter('RMi', None)
            self.register_parameter('RVrr', None)
            self.register_parameter('RVri', None)
            self.register_parameter('RVii', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0, running covariance to identity, and the batch counter to 0."""
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running statistics and, if affine, set W = [[1, u], [u, 1]] with u ~ U(-0.9, 0.9) and B = 0."""
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            self.Wri.data.uniform_(-.9, +.9) # |Wri| < 1 keeps W positive-definite at init
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        # Not called by forward().
        assert(xr.shape == xi.shape)
        assert(xr.size(1) == self.num_features)

    def forward(self, inputs):
        xr, xi = torch.chunk(inputs, 2, dim=self.complex_axis)

        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # Meaning of the `training` flag below:
        #   True:  normalize with batch statistics and update the running
        #          statistics if they are being tracked.
        #   False: normalize with the running statistics and ignore the batch.
        training = self.training or not self.track_running_stats
        # Reduce over every axis except 1; vdim broadcasts per-channel vectors back.
        redux = [i for i in reversed(range(xr.dim())) if i != 1]
        vdim = [1] * xr.dim()
        vdim[1] = xr.size(1)

        # Mean computation and centering.
        if training:
            Mr = xr.mean(dim=redux, keepdim=True)
            Mi = xi.mean(dim=redux, keepdim=True)
            if self.track_running_stats:
                # Update the buffers outside autograd. Otherwise the in-place
                # lerp_ gives them a grad_fn, and they keep earlier batches'
                # autograd history alive.
                with torch.no_grad():
                    self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
                    self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi

        # Covariance matrix V = [[Vrr, Vri], [Vri, Vii]].
        if training:
            Vrr = (xr * xr).mean(dim=redux, keepdim=True)
            Vri = (xr * xi).mean(dim=redux, keepdim=True)
            Vii = (xi * xi).mean(dim=redux, keepdim=True)
            if self.track_running_stats:
                with torch.no_grad():
                    self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
                    self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
                    self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        # Tikhonov regularisation on the diagonal only.
        Vrr = Vrr + self.eps
        Vii = Vii + self.eps

        # Inverse square root U = V^-1/2 of the 2x2 matrix, with
        # s = sqrt(det V) and t = sqrt(trace V + 2s):
        #   sqrt(V) = (V + s*I) / t
        #   U = [[Vii + s, -Vri], [-Vri, Vrr + s]] / (s * t)
        # See https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        tau = Vrr + Vii
        delta = Vrr * Vii - Vri * Vri  # det V; avoids the deprecated addcmul(value-positional) overload
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()
        rst = (s * t).reciprocal()
        Urr = (s + Vii) * rst
        Uii = (s + Vrr) * rst
        Uri = (-Vri) * rst

        # y = Z x + B, where Z = W U (or Z = U without affine):
        #   [yr]   [Wrr Wri] [Urr Uri] [xr]   [Br]
        #   [yi] = [Wri Wii] [Uri Uii] [xi] + [Bi]
        if self.affine:
            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        return torch.cat([yr, yi], self.complex_axis)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
                'track_running_stats={track_running_stats}'.format(**self.__dict__)
def complex_cat(inputs, axis):
    """Join split-complex tensors as ``[all reals | all imags]`` along ``axis``.

    Every tensor is halved along ``axis`` into a real and an imaginary chunk.
    The real chunks are concatenated in input order, followed by the
    imaginary chunks concatenated in the same order.
    """
    reals = []
    imags = []
    for tensor in inputs:
        real_half, imag_half = torch.chunk(tensor, 2, axis)
        reals.append(real_half)
        imags.append(imag_half)
    merged = [torch.cat(reals, axis), torch.cat(imags, axis)]
    return torch.cat(merged, axis)
if __name__ == '__main__':
    # Ad-hoc comparison of this module's ComplexConv2d with the implementation
    # in the project-local module `dc_crn7` (not visible here). Layers are
    # built in a fixed order after seeding, because each construction draws
    # from the RNG.
    import dc_crn7

    torch.manual_seed(20)
    layer_kwargs = dict(kernel_size=(3, 2), padding=(2, 1))
    ref_conv = dc_crn7.ComplexConv2d(12, 12, **layer_kwargs)
    ref_deconv = dc_crn7.ComplexConvTranspose2d(12, 12, **layer_kwargs)
    sample = torch.randn([1, 12, 12, 10])
    local_conv = ComplexConv2d(12, 12, causal=True, **layer_kwargs)
    local_deconv = ComplexConvTranspose2d(12, 12, **layer_kwargs)
    print(torch.mean(local_conv(sample) - ref_conv(sample)))