From 4cd74d95e07f14e35d6337d1a63239ffa04dcfc5 Mon Sep 17 00:00:00 2001 From: macournoyer Date: Wed, 16 Dec 2015 23:07:08 -0500 Subject: [PATCH] Make model output all sorted word IDs, not just top one (argmax). Also add --debug flag to eval script to output probabilities. --- eval.lua | 34 +++++++++++++++++++++++++++------- seq2seq.lua | 30 ++++++++++++++++-------------- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/eval.lua b/eval.lua index 42066e1..08e0016 100644 --- a/eval.lua +++ b/eval.lua @@ -1,11 +1,13 @@ require 'neuralconvo' local tokenizer = require "tokenizer" local list = require "pl.List" +local options = {} if dataset == nil then cmd = torch.CmdLine() cmd:text('Options:') cmd:option('--cuda', false, 'use CUDA. Training must be done on CUDA') + cmd:option('--debug', false, 'show debug info') cmd:text() options = cmd:parse(arg) @@ -24,15 +26,26 @@ if model == nil then model = torch.load("data/model.t7") end --- Word ID tensor to words -function t2w(t) +-- Word IDs to sentence +function pred2sent(wordIds, i) local words = {} + i = i or 1 - for i = 1, t:size(1) do - table.insert(words, dataset.id2word[t[i]]) + for _, wordId in ipairs(wordIds) do + local word = dataset.id2word[wordId[i]] + table.insert(words, word) end - return words + return tokenizer.join(words) +end + +function printProbabilities(wordIds, probabilities, i) + local words = {} + + for p, wordId in ipairs(wordIds) do + local word = dataset.id2word[wordId[i]] + print(string.format("%-23s(%4d%%)", word, probabilities[p][i] * 100)) + end end function say(text) @@ -44,7 +57,14 @@ function say(text) end local input = torch.Tensor(list.reverse(wordIds)) - local output = model:eval(input) + local wordIds, probabilities = model:eval(input) + + print(">> " .. pred2sent(wordIds)) - print(">> " .. tokenizer.join(t2w(torch.Tensor(output)))) + if options.debug then + for i = 1, 4 do + print(string.rep("-", 30)) + printProbabilities(wordIds, probabilities, i) + end + end end diff --git a/seq2seq.lua b/seq2seq.lua index eda7e26..1b1f422 100644 --- a/seq2seq.lua +++ b/seq2seq.lua @@ -91,15 +91,6 @@ function Seq2Seq:train(input, target) return Edecoder end -local function argmax(t) - local max = t:max() - for i = 1, t:size(1) do - if t[i] == max then - return i - end - end -end - local MAX_OUTPUT_SIZE = 20 function Seq2Seq:eval(input) @@ -109,20 +100,31 @@ function Seq2Seq:eval(input) self.encoder:forward(input) self:forwardConnect(input:size(1)) + local predictions = {} + local probabilities = {} + -- Forward and all of it's output recursively back to the decoder local output = self.goToken - local outputs = {} for i = 1, MAX_OUTPUT_SIZE do - local predictions = self.decoder:forward(torch.Tensor{output}) - output = argmax(predictions[1]) + local prediction = self.decoder:forward(torch.Tensor{output})[1] + -- prediction contains the probabilities for each word IDs. + -- The index of the probability is the word ID. + local prob, wordIds = prediction:sort(1, true) + + -- First one is the most likely. + output = wordIds[1] + + -- Terminate on EOS token if output == self.eosToken then break end - table.insert(outputs, output) + + table.insert(predictions, wordIds) + table.insert(probabilities, prob) end self.decoder:forget() self.encoder:forget() - return outputs + return predictions, probabilities end