-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLSTMClique.py
69 lines (63 loc) · 3.14 KB
/
LSTMClique.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
# Pick the tensor constructors once at import time: the CUDA variants when a
# CUDA device is available, the CPU variants otherwise.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor
class LSTMClique(nn.Module):
    """Score a clique of sentences by encoding each one with a shared LSTM.

    Every sentence position in the clique is embedded and run through the
    same LSTM; the final hidden states of all positions are concatenated,
    projected through a tanh layer, and fed to a linear prediction layer.
    """

    # Width of the prediction layer for each supported task.
    _NUM_LABELS = {'perm': 2, 'minority': 2, 'class': 3, 'score_pred': 1}

    def __init__(self, params, data_obj):
        """Build the model.

        Args:
            params: mapping with the keys 'embedding_dim', 'hidden_dim',
                'lstm_dim', 'dropout', 'clique_size' and 'task'.
            data_obj: object exposing ``word_embeds``, the embedding module
                applied to the token ids (presumably an ``nn.Embedding``
                whose width is ``embedding_dim`` -- confirm against caller).

        Raises:
            ValueError: if ``params['task']`` is not a supported task.
        """
        super(LSTMClique, self).__init__()

        # Validate the task up front so an unsupported value fails with a
        # clear message instead of an unbound-local error further down.
        task = params['task']
        if task not in self._NUM_LABELS:
            raise ValueError('unknown task %r; expected one of %s'
                             % (task, sorted(self._NUM_LABELS)))
        num_labels = self._NUM_LABELS[task]

        self.embedding_dim = params['embedding_dim']
        self.hidden_dim = params['hidden_dim']
        self.lstm_dim = params['lstm_dim']
        self.dropout = params['dropout']
        self.clique_size = params['clique_size']
        self.embeddings = data_obj.word_embeds
        self.lstm = nn.LSTM(self.embedding_dim, self.lstm_dim)
        self.hidden = None

        # Newer PyTorch releases expose the initializer as ``xavier_uniform_``
        # and deprecate the old spelling; use whichever is available.
        xavier_uniform = (getattr(nn.init, 'xavier_uniform_', None)
                          or nn.init.xavier_uniform)

        self.clique_layer = nn.Linear(self.clique_size * self.lstm_dim,
                                      self.hidden_dim)
        xavier_uniform(self.clique_layer.weight,
                       gain=nn.init.calculate_gain('tanh'))

        self.task = task
        self.predict_layer = nn.Linear(self.hidden_dim, num_labels)
        xavier_uniform(self.predict_layer.weight,
                       gain=nn.init.calculate_gain('sigmoid'))

        if USE_CUDA:
            self.clique_layer = self.clique_layer.cuda()
            self.predict_layer = self.predict_layer.cuda()

    def init_hidden(self, batch_size):
        """Return a zeroed ``(h_0, c_0)`` pair for a single-layer LSTM.

        Both tensors have shape ``(1, batch_size, lstm_dim)`` and are
        allocated with the same type and device as the LSTM weights, so the
        state always matches wherever the module currently lives.
        """
        weight = next(self.lstm.parameters()).data
        shape = (1, batch_size, self.lstm_dim)
        return (Variable(weight.new(*shape).zero_()),
                Variable(weight.new(*shape).zero_()))

    def forward(self, inputs, input_lengths, original_index):
        """Compute the prediction for a batch of cliques.

        Args:
            inputs: indexable by sentence position; ``inputs[i]`` holds the
                token ids of the i-th sentence of every clique (assumed
                batch-first ``(batch, max_len)``, since it is packed with
                ``batch_first=True`` -- confirm against caller).
            input_lengths: ``input_lengths[i]`` gives the sequence lengths
                for ``inputs[i]``; its ``len`` is used as the batch size.
                ``pack_padded_sequence`` expects these in decreasing order.
            original_index: ``original_index[i]`` is an index tensor used to
                reorder the LSTM outputs of position ``i`` along the batch
                axis (presumably undoing the length sort -- confirm).

        Returns:
            A ``(batch, num_labels)`` tensor: raw scores for the
            'score_pred' task, otherwise per-example probabilities over the
            labels.
        """
        sentence_vectors = []
        for i in range(self.clique_size):  # one sentence position at a time
            lengths = input_lengths[i]
            batch_size = len(lengths)
            self.hidden = self.init_hidden(batch_size)
            embedded = self.embeddings(inputs[i])
            packed_input = pack_padded_sequence(embedded, lengths,
                                                batch_first=True)
            _, (ht, _) = self.lstm(packed_input, self.hidden)

            # Final hidden state of the (only) layer: (batch, lstm_dim).
            final_output = ht[-1]
            # Reorder rows: row j of the result is row original_index[i][j]
            # of final_output.
            odx = original_index[i].view(-1, 1).expand(
                batch_size, final_output.size(-1))
            sentence_vectors.append(
                torch.gather(final_output, 0, Variable(odx)))

        # (batch, clique_size * lstm_dim)
        lstm_out = torch.cat(sentence_vectors, dim=1)
        clique_vector = torch.tanh(self.clique_layer(lstm_out))
        clique_vector = F.dropout(clique_vector, p=self.dropout,
                                  training=self.training)
        coherence_pred = self.predict_layer(clique_vector)
        if self.task != 'score_pred':
            # Normalize over the label axis so each example's scores sum to
            # one; dim=0 would normalize across the batch instead.
            coherence_pred = F.softmax(coherence_pred, dim=1)
        return coherence_pred