Skip to content

Commit

Permalink
Make model output all sorted word IDs, not just top one (argmax).
Browse files Browse the repository at this point in the history
Also add --debug flag to eval script to output probabilities.
  • Loading branch information
macournoyer committed Dec 17, 2015
1 parent b177096 commit 4cd74d9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
34 changes: 27 additions & 7 deletions eval.lua
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
require 'neuralconvo'
local tokenizer = require "tokenizer"
local list = require "pl.List"
local options = {}

if dataset == nil then
cmd = torch.CmdLine()
cmd:text('Options:')
cmd:option('--cuda', false, 'use CUDA. Training must be done on CUDA')
cmd:option('--debug', false, 'show debug info')
cmd:text()
options = cmd:parse(arg)

Expand All @@ -24,15 +26,26 @@ if model == nil then
model = torch.load("data/model.t7")
end

-- Word ID tensor to words
function t2w(t)
-- Word IDs to sentence
function pred2sent(wordIds, i)
local words = {}
i = i or 1

for i = 1, t:size(1) do
table.insert(words, dataset.id2word[t[i]])
for _, wordId in ipairs(wordIds) do
local word = dataset.id2word[wordId[i]]
table.insert(words, word)
end

return words
return tokenizer.join(words)
end

function printProbabilities(wordIds, probabilities, i)
local words = {}

for p, wordId in ipairs(wordIds) do
local word = dataset.id2word[wordId[i]]
print(string.format("%-23s(%4d%%)", word, probabilities[p][i] * 100))
end
end

function say(text)
Expand All @@ -44,7 +57,14 @@ function say(text)
end

local input = torch.Tensor(list.reverse(wordIds))
local output = model:eval(input)
local wordIds, probabilities = model:eval(input)

print(">> " .. pred2sent(wordIds))

print(">> " .. tokenizer.join(t2w(torch.Tensor(output))))
if options.debug then
for i = 1, 4 do
print(string.rep("-", 30))
printProbabilities(wordIds, probabilities, i)
end
end
end
30 changes: 16 additions & 14 deletions seq2seq.lua
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,6 @@ function Seq2Seq:train(input, target)
return Edecoder
end

local function argmax(t)
local max = t:max()
for i = 1, t:size(1) do
if t[i] == max then
return i
end
end
end

local MAX_OUTPUT_SIZE = 20

function Seq2Seq:eval(input)
Expand All @@ -109,20 +100,31 @@ function Seq2Seq:eval(input)
self.encoder:forward(input)
self:forwardConnect(input:size(1))

local predictions = {}
local probabilities = {}

-- Forward <go> and all of it's output recursively back to the decoder
local output = self.goToken
local outputs = {}
for i = 1, MAX_OUTPUT_SIZE do
local predictions = self.decoder:forward(torch.Tensor{output})
output = argmax(predictions[1])
local prediction = self.decoder:forward(torch.Tensor{output})[1]
-- prediction contains the probabilities for each word IDs.
-- The index of the probability is the word ID.
local prob, wordIds = prediction:sort(1, true)

-- First one is the most likely.
output = wordIds[1]

-- Terminate on EOS token
if output == self.eosToken then
break
end
table.insert(outputs, output)

table.insert(predictions, wordIds)
table.insert(probabilities, prob)
end

self.decoder:forget()
self.encoder:forget()

return outputs
return predictions, probabilities
end

0 comments on commit 4cd74d9

Please sign in to comment.