-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathgated_cnn.py
71 lines (59 loc) · 2.68 KB
/
gated_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
import torch.nn.functional as F
class GatedCNN(nn.Module):
    """Gated convolutional network for sequence classification.

    Tokens are embedded, then passed through a stack of gated convolutions of
    the form ``h = (conv(h) + b) * sigmoid(conv_gate(h) + c)``. A residual
    connection is added after stacked layer ``i`` whenever
    ``i % res_block_count == 0``. The final feature map is flattened and
    projected to ``ans_size`` classes.

    In : (N, seq_len) token ids; seq_len must equal the constructor's
         ``seq_len`` because the final linear layer is sized from it.
    Out: (N, ans_size) log-probabilities (log_softmax over dim 1).
    """

    def __init__(self,
                 seq_len,
                 vocab_size,
                 embd_size,
                 n_layers,
                 kernel,
                 out_chs,
                 res_block_count,
                 ans_size):
        """Build the embedding, gated conv stack and output projection.

        Args:
            seq_len: sequence length; fixes the ``fc`` input size
                (``out_chs * seq_len``).
            vocab_size: number of embedding rows.
            embd_size: embedding dimension.
            n_layers: number of gated conv layers after the first one.
            kernel: (height, width) of the first conv kernel. The height must
                be odd so the sequence length is preserved. The width should
                equal ``embd_size`` so the conv output width is 1; otherwise
                the flattened size will not match ``fc``.
            out_chs: number of conv channels in every layer.
            res_block_count: residual period over the stacked layers.
            ans_size: number of output classes.

        Raises:
            ValueError: if ``kernel[0]`` is even (symmetric "same" padding
                cannot preserve the sequence length).
        """
        super().__init__()
        if kernel[0] % 2 == 0:
            raise ValueError(
                f"kernel height must be odd to preserve seq_len, got {kernel[0]}")
        self.res_block_count = res_block_count
        self.embedding = nn.Embedding(vocab_size, embd_size)
        # "Same" padding along the sequence axis keeps seq_len rows in every
        # layer. For height-5 kernels this equals the former hard-coded 2.
        pad = (kernel[0] // 2, 0)
        # Creation order below matches the original so state_dict keys and
        # seeded parameter initialisation are unchanged.
        self.conv_0 = nn.Conv2d(1, out_chs, kernel, padding=pad)
        self.b_0 = nn.Parameter(torch.randn(1, out_chs, 1, 1))
        self.conv_gate_0 = nn.Conv2d(1, out_chs, kernel, padding=pad)
        self.c_0 = nn.Parameter(torch.randn(1, out_chs, 1, 1))
        self.conv = nn.ModuleList(
            [nn.Conv2d(out_chs, out_chs, (kernel[0], 1), padding=pad)
             for _ in range(n_layers)])
        self.conv_gate = nn.ModuleList(
            [nn.Conv2d(out_chs, out_chs, (kernel[0], 1), padding=pad)
             for _ in range(n_layers)])
        self.b = nn.ParameterList(
            [nn.Parameter(torch.randn(1, out_chs, 1, 1)) for _ in range(n_layers)])
        self.c = nn.ParameterList(
            [nn.Parameter(torch.randn(1, out_chs, 1, 1)) for _ in range(n_layers)])
        self.fc = nn.Linear(out_chs * seq_len, ans_size)

    def forward(self, x):
        """Return class log-probabilities for a batch of token-id sequences.

        Args:
            x: integer tensor of shape (N, seq_len).

        Returns:
            Tensor of shape (N, ans_size): log_softmax over dim 1.
        """
        bs = x.size(0)
        # (bs, seq_len, embd_size) -> (bs, 1, seq_len, embd_size): add a
        # single input channel for Conv2d.
        x = self.embedding(x).unsqueeze(1)

        # First gated layer. The (1, Cout, 1, 1) biases broadcast over batch
        # and sequence positions, so no explicit repeat is needed.
        A = self.conv_0(x) + self.b_0
        B = self.conv_gate_0(x) + self.c_0
        h = A * torch.sigmoid(B)  # (bs, Cout, seq_len, 1)

        # NOTE(review): the residual source starts as the first gated layer's
        # output (not the raw embedding), and the schedule fires at i == 0.
        # Confirm whether (i + 1) % res_block_count was intended.
        res_input = h
        for i, (conv, conv_gate) in enumerate(zip(self.conv, self.conv_gate)):
            A = conv(h) + self.b[i]
            B = conv_gate(h) + self.c[i]
            h = A * torch.sigmoid(B)  # (bs, Cout, seq_len, 1)
            if i % self.res_block_count == 0:
                h = h + res_input
                res_input = h

        h = h.view(bs, -1)  # (bs, Cout * seq_len)
        out = self.fc(h)  # (bs, ans_size)
        return F.log_softmax(out, dim=1)