forked from threelittlemonkeys/lstm-crf-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
248 lines (215 loc) · 10.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import torch
import torch.nn as nn
import torch.nn.functional as F
UNIT = "word" # unit of tokenization (char, word)
RNN_TYPE = "LSTM" # LSTM or GRU
NUM_DIRS = 2 # unidirectional: 1, bidirectional: 2
NUM_LAYERS = 2
BATCH_SIZE = 256
EMBED_UNIT = "char+word" # char, word, char+word
EMBED_SIZE = 300
HIDDEN_SIZE = 1000
DROPOUT = 0.5
LEARNING_RATE = 1e-4
SAVE_EVERY = 10
PAD = "<PAD>" # padding
SOS = "<SOS>" # start of sequence
EOS = "<EOS>" # end of sequence
UNK = "<UNK>" # unknown token
PAD_IDX = 0 # tag dành cho những token được pad vào câu cho đủ độ dài
SOS_IDX = 1 # t
EOS_IDX = 2
UNK_IDX = 3
torch.manual_seed(1)
CUDA = torch.cuda.is_available()
class rnn_crf(nn.Module):
def __init__(self, char_vocab_size, word_vocab_size, num_tags):
super().__init__()
self.rnn = rnn(char_vocab_size, word_vocab_size, num_tags)
self.crf = crf(num_tags)
self = self.cuda() if CUDA else self
def forward(self, cx, wx, y): # for training
# .gt(0) thực hiện phép so sánh wx > 0 để kiểm tra xem những chỗ nào không phải padding
# => cách s
mask = wx.data.gt(0).float()
h = self.rnn(cx, wx, mask)
Z = self.crf.forward(h, mask)
score = self.crf.score(h, y, mask)
return Z - score # NLL loss
def decode(self, cx, wx): # for prediction
mask = wx.data.gt(0).float()
h = self.rnn(cx, wx, mask)
return self.crf.decode(h, mask)
class embed(nn.Module):
def __init__(self, char_vocab_size, word_vocab_size, embed_size):
super().__init__()
num_embeds = EMBED_UNIT.count("+") + 1 # number of embeddings (1, 2)
dim = embed_size // num_embeds # dimension of each embedding vector
# architecture
if EMBED_UNIT[:4] == "char":
self.char_embed = self.cnn(char_vocab_size, dim)
if EMBED_UNIT[-4:] == "word":
self.word_embed = nn.Embedding(word_vocab_size, dim, padding_idx = PAD_IDX)
class cnn(nn.Module):
# với chuỗi từ, mỗi từ là 1 chuỗi ký tự, mỗi ký tự là 1 vector -> biến thành 1 chuỗi từ, mà mỗi từ giờ sẽ có
# biểu diễn là 1 vector thu được thông qua convolution nên tất cả các biểu diễn ký tự của t
def __init__(self, dim_in, dim_out):
super().__init__()
self.embed_size = 50
self.num_featmaps = 30 # feature maps generated by each kernel
self.kernel_sizes = [3]
# architecture
self.embed = nn.Embedding(dim_in, self.embed_size, padding_idx = PAD_IDX)
self.conv = nn.ModuleList([nn.Conv2d(
in_channels = 1, # Ci
out_channels = self.num_featmaps, # Co
kernel_size = (i, self.embed_size) # (height, width)
) for i in self.kernel_sizes]) # num_kernels (K)
# =>> có tất cả 'Ci x Co' filters, tương ứng với sự ghép cặp giữa 'Ci' in-channels và
# 'Co' out-channels. Output của phép toán convolution đối với mỗi out-channel sẽ được tính như sau:
# thực hiện phép toán convolution đối với từng in-channel -> xong cộng tất cả lại.
# xem thêm tại: https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d
self.dropout = nn.Dropout(DROPOUT)
self.fc = nn.Linear(len(self.kernel_sizes) * self.num_featmaps, dim_out)
def forward(self, x):
x = x.view(-1, x.size(2)) # [batch_size (B) * word_seq_len (L), char_seq_len (H)]
x = self.embed(x) # [B * L, H, embed_size (W)]
x = x.unsqueeze(1) # [B * L, Ci, H, W]
h = [conv(x) for conv in self.conv] # [B * L, Co, H-2, 1] * K
# h là 1 danh sách các feature map thu được khi thực hiện các phép toán convolution khác nhau
# với các kích cỡ kernel size khác nhau, mỗi out-channel sẽ có 1 feature map có kích cỡ là [H-2, 1]
# sau đây sẽ thực hiện max pooling trên mỗi feature map.
# conv2d là thưc hiện convolution trên ảnh, tức là theo 2 chiều -cao và chiều -rộng
# conv1d là thưc hiện convolution trên mảng, tức là theo 1 chiều -dài
# conv3d là thực hiện convolution trên hình khối 3 chiều, tức filter sẽ trượt trên cả chiều rộng, chiều cao
# và chiều sâu của đối tượng.
h = [F.relu(k).squeeze(3) for k in h] # [B * L, Co, H-2] * K
# max_pool1d(input, window_size) chỉ trượt trên chiều dài của input
# max_pool2d trượt trên 2 chiều: chiều cao + chiều rộng
# max_pool3d trượt trên 3 chiều: chiều cao + chiều rộng + chiều sâu của đối tượng
h = [F.max_pool1d(k, k.size(2)).squeeze(2) for k in h] # [B * L, Co] * K
h = torch.cat(h, 1) # [B * L, Co * K]
h = self.dropout(h)
h = self.fc(h) # [B * L, dim_out]
h = h.view(BATCH_SIZE, -1, h.size(1)) # [B, L, dim_out]
return h
class rnn(nn.Module): # TODO
def __init__(self, dim_in, dim_out):
pass
def forward(self, x):
pass
def forward(self, cx, wx):
# cx: charactor x
#
ch = self.char_embed(cx) if EMBED_UNIT[:4] == "char" else [] # thực hiện convolution trên các chuỗi ký tự để lấy biểu diễn của mỗi từ
wh = self.word_embed(wx) if EMBED_UNIT[-4:] == "word" else [] # thực hiện embedding để lấy biểu diễn của mỗi từ
h = torch.cat([ch, wh], 2)
return h
class rnn(nn.Module):
def __init__(self, char_vocab_size, word_vocab_size, num_tags):
super().__init__()
# architecture
self.embed = embed(char_vocab_size, word_vocab_size, EMBED_SIZE)
self.rnn = getattr(nn, RNN_TYPE)(
input_size = EMBED_SIZE,
hidden_size = HIDDEN_SIZE // NUM_DIRS,
num_layers = NUM_LAYERS,
bias = True,
batch_first = True,
dropout = DROPOUT,
bidirectional = NUM_DIRS == 2
)
self.out = nn.Linear(HIDDEN_SIZE, num_tags) # RNN output to tag
def init_hidden(self): # initialize hidden states
h = zeros(NUM_LAYERS * NUM_DIRS, BATCH_SIZE, HIDDEN_SIZE // NUM_DIRS) # hidden state
if RNN_TYPE == "LSTM":
c = zeros(NUM_LAYERS * NUM_DIRS, BATCH_SIZE, HIDDEN_SIZE // NUM_DIRS) # cell state
return (h, c)
return h
def forward(self, cx, wx, mask):
self.hidden = self.init_hidden()
x = self.embed(cx, wx)
x = nn.utils.rnn.pack_padded_sequence(x, mask.sum(1).int(), batch_first = True)
h, _ = self.rnn(x, self.hidden)
h, _ = nn.utils.rnn.pad_packed_sequence(h, batch_first = True)
h = self.out(h)
h *= mask.unsqueeze(2)
return h
class crf(nn.Module):
def __init__(self, num_tags):
super().__init__()
self.num_tags = num_tags
# matrix of transition scores from j to i
self.trans = nn.Parameter(randn(num_tags, num_tags))
self.trans.data[SOS_IDX, :] = -10000. # no transition to SOS
self.trans.data[:, EOS_IDX] = -10000. # no transition from EOS except to PAD
self.trans.data[:, PAD_IDX] = -10000. # no transition from PAD except to PAD
self.trans.data[PAD_IDX, :] = -10000. # no transition to PAD except from EOS
self.trans.data[PAD_IDX, EOS_IDX] = 0.
self.trans.data[PAD_IDX, PAD_IDX] = 0.
def forward(self, h, mask): # forward algorithm
# initialize forward variables in log space
score = Tensor(BATCH_SIZE, self.num_tags).fill_(-10000.) # [B, C]
score[:, SOS_IDX] = 0.
trans = self.trans.unsqueeze(0) # [1, C, C]
for t in range(h.size(1)): # recursion through the sequence
mask_t = mask[:, t].unsqueeze(1)
emit_t = h[:, t].unsqueeze(2) # [B, C, 1]
score_t = score.unsqueeze(1) + emit_t + trans # [B, 1, C] -> [B, C, C]
score_t = log_sum_exp(score_t) # [B, C, C] -> [B, C]
score = score_t * mask_t + score * (1 - mask_t)
score = log_sum_exp(score + self.trans[EOS_IDX])
return score # partition function
def score(self, h, y, mask): # calculate the score of a given sequence
score = Tensor(BATCH_SIZE).fill_(0.)
h = h.unsqueeze(3)
trans = self.trans.unsqueeze(2)
for t in range(h.size(1)): # recursion through the sequence
mask_t = mask[:, t]
emit_t = torch.cat([h[t, y[t + 1]] for h, y in zip(h, y)])
trans_t = torch.cat([trans[y[t + 1], y[t]] for y in y])
score += (emit_t + trans_t) * mask_t
last_tag = y.gather(1, mask.sum(1).long().unsqueeze(1)).squeeze(1)
score += self.trans[EOS_IDX, last_tag]
return score
def decode(self, h, mask): # Viterbi decoding
# initialize backpointers and viterbi variables in log space
bptr = LongTensor()
score = Tensor(BATCH_SIZE, self.num_tags).fill_(-10000.)
score[:, SOS_IDX] = 0.
for t in range(h.size(1)): # recursion through the sequence
mask_t = mask[:, t].unsqueeze(1)
score_t = score.unsqueeze(1) + self.trans # [B, 1, C] -> [B, C, C]
score_t, bptr_t = score_t.max(2) # best previous scores and tags
score_t += h[:, t] # plus emission scores
bptr = torch.cat((bptr, bptr_t.unsqueeze(1)), 1)
score = score_t * mask_t + score * (1 - mask_t)
score += self.trans[EOS_IDX]
best_score, best_tag = torch.max(score, 1)
# back-tracking
bptr = bptr.tolist()
best_path = [[i] for i in best_tag.tolist()]
for b in range(BATCH_SIZE):
x = best_tag[b] # best tag
y = int(mask[b].sum().item())
for bptr_t in reversed(bptr[b][:y]):
x = bptr_t[x]
best_path[b].append(x)
best_path[b].pop()
best_path[b].reverse()
return best_path
def Tensor(*args):
x = torch.Tensor(*args)
return x.cuda() if CUDA else x
def LongTensor(*args):
x = torch.LongTensor(*args)
return x.cuda() if CUDA else x
def randn(*args):
x = torch.randn(*args)
return x.cuda() if CUDA else x
def zeros(*args):
x = torch.zeros(*args)
return x.cuda() if CUDA else x
def log_sum_exp(x):
m = torch.max(x, -1)[0]
return m + torch.log(torch.sum(torch.exp(x - m.unsqueeze(-1)), -1))