-
Notifications
You must be signed in to change notification settings - Fork 346
/
seq2seq.lua
109 lines (86 loc) · 3.2 KB
/
seq2seq.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
-- Based on https://github.com/Element-Research/rnn/blob/master/examples/encoder-decoder-coupling.lua
-- Declares the class "neuralconvo.Seq2Seq" through torch.class and keeps the
-- returned class table in a file-local so the methods below can be attached.
-- NOTE(review): presumably the `neuralconvo` namespace table must already
-- exist when this file is loaded -- confirm against the loader.
local Seq2Seq = torch.class("neuralconvo.Seq2Seq")
-- Constructor: records the vocabulary and hidden sizes on the instance and
-- then builds the encoder/decoder networks through buildModel().
-- @param vocabSize  required; stored as self.vocabSize
-- @param hiddenSize required; stored as self.hiddenSize
-- Raises an error when either argument is nil or false.
function Seq2Seq:__init(vocabSize, hiddenSize)
  local vocab = assert(vocabSize, "vocabSize required at arg #1")
  self.vocabSize = vocab

  local hidden = assert(hiddenSize, "hiddenSize required at arg #2")
  self.hiddenSize = hidden

  self:buildModel()
end
-- Builds the two networks of the model and stores them on the instance:
--   self.encoder, self.encoderLSTM, self.decoder, self.decoderLSTM.
-- Each network starts with an nn.LookupTableMaskZero(vocabSize, hiddenSize)
-- followed by a zero-masked nn.FastLSTM wrapped in an nn.Sequencer.
-- The encoder ends with nn.Select(1, -1); the decoder ends with zero-masked
-- nn.Linear(hiddenSize, vocabSize) and nn.LogSoftMax, each in an nn.Sequencer.
-- Gradient parameters of both networks are zeroed before returning.
function Seq2Seq:buildModel()
  local vocab, hidden = self.vocabSize, self.hiddenSize

  -- Encoder. Modules are created in the same order as they are added so that
  -- any random weight initialisation happens in a fixed sequence.
  local encoder = nn.Sequential()
  encoder:add(nn.LookupTableMaskZero(vocab, hidden))
  local encoderLSTM = nn.FastLSTM(hidden, hidden):maskZero(1)
  encoder:add(nn.Sequencer(encoderLSTM))
  encoder:add(nn.Select(1, -1))
  self.encoder = encoder
  self.encoderLSTM = encoderLSTM

  -- Decoder.
  local decoder = nn.Sequential()
  decoder:add(nn.LookupTableMaskZero(vocab, hidden))
  local decoderLSTM = nn.FastLSTM(hidden, hidden):maskZero(1)
  decoder:add(nn.Sequencer(decoderLSTM))
  local projection = nn.MaskZero(nn.Linear(hidden, vocab), 1)
  decoder:add(nn.Sequencer(projection))
  local logSoftMax = nn.MaskZero(nn.LogSoftMax(), 1)
  decoder:add(nn.Sequencer(logSoftMax))
  self.decoder = decoder
  self.decoderLSTM = decoderLSTM

  encoder:zeroGradParameters()
  decoder:zeroGradParameters()
end
-- Calls :cuda() on the encoder and the decoder, and on self.criterion when
-- one has been assigned. Returns nothing.
function Seq2Seq:cuda()
  self.encoder:cuda()
  self.decoder:cuda()

  local criterion = self.criterion
  if not criterion then
    return
  end
  criterion:cuda()
end
-- Calls :float() on the encoder and the decoder, and on self.criterion when
-- one has been assigned. Returns nothing.
function Seq2Seq:float()
  self.encoder:float()
  self.decoder:float()

  local criterion = self.criterion
  if not criterion then
    return
  end
  criterion:float()
end
-- Calls :cl() on the encoder and the decoder, and on self.criterion when one
-- has been assigned. Returns nothing.
-- NOTE(review): :cl() is presumably the OpenCL conversion -- confirm.
function Seq2Seq:cl()
  self.encoder:cl()
  self.decoder:cl()

  local criterion = self.criterion
  if not criterion then
    return
  end
  criterion:cl()
end
-- Wraps the encoder and decoder in a temporary nn.Container and returns
-- whatever that container's getParameters() returns, so both networks are
-- handled together.
function Seq2Seq:getParameters()
  local both = nn.Container()
  both:add(self.encoder)
  both:add(self.decoder)
  return both:getParameters()
end
--[[ Forward coupling: copy the encoder LSTM's output and cell at step
     `inputSeqLen` into the decoder LSTM's userPrevOutput / userPrevCell. ]]--
function Seq2Seq:forwardConnect(inputSeqLen)
  local encoderLSTM, decoderLSTM = self.encoderLSTM, self.decoderLSTM
  local copy = nn.rnn.recursiveCopy

  decoderLSTM.userPrevOutput =
    copy(decoderLSTM.userPrevOutput, encoderLSTM.outputs[inputSeqLen])
  decoderLSTM.userPrevCell =
    copy(decoderLSTM.userPrevCell, encoderLSTM.cells[inputSeqLen])
end
--[[ Backward coupling: hand the decoder LSTM's gradient hidden state at
     step 0 to the encoder LSTM at step `inputSeqLen`. ]]--
function Seq2Seq:backwardConnect(inputSeqLen)
  local encoderLSTM, decoderLSTM = self.encoderLSTM, self.decoderLSTM
  -- The getter call stays in argument position so that everything it returns
  -- is forwarded to the setter.
  encoderLSTM:setGradHiddenState(inputSeqLen, decoderLSTM:getGradHiddenState(0))
end
-- Upper bound on the number of decoding steps performed by Seq2Seq:eval.
local MAX_OUTPUT_SIZE = 20

-- Decodes a response for `input`, one token at a time.
--
-- The encoder is run on `input`, its final LSTM state is copied into the
-- decoder (forwardConnect), and the decoder is then fed self.goToken followed
-- by its own best prediction from every previous step. Decoding stops after
-- MAX_OUTPUT_SIZE steps or as soon as the best prediction equals
-- self.eosToken; that terminating step is not included in the results.
-- Both networks are told to forget() before returning.
--
-- @param input  encoder input; input:size(1) is used as the sequence length
-- @return predictions    list with, per step, the top-5 word ids (best first)
-- @return probabilities  list with, per step, the matching decoder scores
-- Raises an error when self.goToken or self.eosToken has not been set.
function Seq2Seq:eval(input)
  assert(self.goToken, "No goToken specified")
  assert(self.eosToken, "No eosToken specified")

  self.encoder:forward(input)
  self:forwardConnect(input:size(1))

  local predictions = {}
  local probabilities = {}

  -- Forward <go> and all of its output recursively back to the decoder
  local output = {self.goToken}
  for _ = 1, MAX_OUTPUT_SIZE do
    -- Re-run the decoder on everything produced so far and keep only the
    -- result for the most recent position.
    local prediction = self.decoder:forward(torch.Tensor(output))[#output]

    -- prediction contains the scores for each word ID.
    -- The index of the score is the word ID.
    local prob, wordIds = prediction:topk(5, 1, true, true)

    -- First one is the most likely. Declared local: the original assigned a
    -- global here, leaking state into _G on every call.
    local nextOutput = wordIds[1]
    table.insert(output, nextOutput)

    -- Terminate on EOS token
    if nextOutput == self.eosToken then
      break
    end

    table.insert(predictions, wordIds)
    table.insert(probabilities, prob)
  end

  self.decoder:forget()
  self.encoder:forget()

  return predictions, probabilities
end