From 05b6ce66533e09cd2b6f7f9b759dd838c6aa8f86 Mon Sep 17 00:00:00 2001
From: macournoyer
Date: Tue, 24 Nov 2015 15:31:42 -0500
Subject: [PATCH] Reverse input and a few fixes. Nothing really seems to help
 get better results...

---
 dataset.lua   | 13 +++++++++++++
 eval.lua      | 35 +++++++++++++++++++++++++++--------
 seq2seq.lua   |  4 ++++
 tokenizer.lua |  2 ++
 train.lua     |  5 +++--
 5 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/dataset.lua b/dataset.lua
index ecb3857..bfbb5ec 100644
--- a/dataset.lua
+++ b/dataset.lua
@@ -13,11 +13,15 @@ Also build the vocabulary.
 local DataSet = torch.class("e.DataSet")
 local xlua = require "xlua"
 local tokenizer = require "tokenizer"
+local list = require "pl.list"
 
 function DataSet:__init(filename, loader, loadFirst)
   -- Discard words with lower frequency than this
   self.minWordFreq = 1
 
+  -- Max length of one text
+  self.maxTextLen = 10
+
   -- Load only first few examples
   self.loadFirst = loadFirst
 
@@ -125,6 +129,9 @@ function DataSet:visitConversation(lines, start)
       local targetIds = self:visitText(target.text)
 
       if inputIds and targetIds then
+        -- Reverse inputs
+        inputIds = list.reverse(inputIds)
+
         table.insert(targetIds, 1, self.goToken)
         table.insert(targetIds, self.eosToken)
 
@@ -141,8 +148,14 @@ function DataSet:visitText(text)
     return
   end
 
+  local i = 0
+
   for t, word in tokenizer.tokenize(text) do
     table.insert(words, self:makeWordId(word))
+    i = i + 1
+    if i > self.maxTextLen then
+      break
+    end
   end
 
   if #words == 0 then
diff --git a/eval.lua b/eval.lua
index e56490e..fc7cbfe 100644
--- a/eval.lua
+++ b/eval.lua
@@ -1,5 +1,6 @@
 require 'e'
 local tokenizer = require "tokenizer"
+local list = require "pl.list"
 
 if dataset == nil then
   cmd = torch.CmdLine()
@@ -17,20 +18,38 @@ if model == nil then
   model = torch.load("data/model.t7")
 end
 
+-- Word IDs tensor to sentence
+function t2s(t, reverse)
+  local words = {}
+
+  for i = 1, t:size(1) do
+    table.insert(words, dataset.id2word[t[i]])
+  end
+
+  if reverse then
+    words = list.reverse(words)
+  end
+
+  return table.concat(words, " ")
+end
+
+-- for i,example in ipairs(dataset.examples) do
+--   print("-- " .. t2s(example[1], true))
+--   print(">> " .. t2s(example[2]))
+-- end
+
 function say(text)
-  local inputs = {}
+  local wordIds = {}
 
   for t, word in tokenizer.tokenize(text) do
     local id = dataset.word2id[word:lower()] or dataset.unknownToken
-    table.insert(inputs, id)
+    table.insert(wordIds, id)
   end
 
-  local outputs = model:eval(torch.Tensor(inputs))
-  local words = {}
+  local input = torch.Tensor(list.reverse(wordIds))
+  print("-- " .. t2s(input, true))
 
-  for i,id in ipairs(outputs) do
-    table.insert(words, dataset.id2word[id])
-  end
+  local output = model:eval(input)
 
-  return table.concat(words, " ")
+  print(">> " .. t2s(torch.Tensor(output)))
 end
diff --git a/seq2seq.lua b/seq2seq.lua
index 37e30c4..c624162 100644
--- a/seq2seq.lua
+++ b/seq2seq.lua
@@ -61,8 +61,12 @@ function Seq2Seq:train(input, target)
   local zeroTensor = torch.Tensor(2):zero()
   self.encoder:backward(encoderInput, zeroTensor)
 
+  self.encoder:updateGradParameters(self.momentum)
+  self.decoder:updateGradParameters(self.momentum)
   self.decoder:updateParameters(self.learningRate)
   self.encoder:updateParameters(self.learningRate)
+  self.encoder:zeroGradParameters()
+  self.decoder:zeroGradParameters()
 
   self.decoder:forget()
   self.encoder:forget()
diff --git a/tokenizer.lua b/tokenizer.lua
index 04e6854..45e465a 100644
--- a/tokenizer.lua
+++ b/tokenizer.lua
@@ -31,6 +31,8 @@ function M.tokenize(text)
     { "^%s+", space },
     { "^['\"]", quote },
     { "^%w+", word },
+    { "^%-+", space },
+    { "^%.+", punct },
     { "^[,:;%.%?!%-]", punct },
     { "^<.->", tag },
     { "^.", unknown },
diff --git a/train.lua b/train.lua
index 8ba2a4e..46b5ab2 100644
--- a/train.lua
+++ b/train.lua
@@ -16,7 +16,7 @@ dataset = e.DataSet("data/cornell_movie_dialogs_" .. (options.dataset or "full")
                     e.CornellMovieDialogs("data/cornell_movie_dialogs"), options.dataset)
 
 -- Model
-local hiddenSize = 100
+local hiddenSize = 300
 model = e.Seq2Seq(dataset.wordsCount, hiddenSize)
 model.goToken = dataset.goToken
 model.eosToken = dataset.eosToken
@@ -24,7 +24,8 @@ model.eosToken = dataset.eosToken
 
 -- Training
 model.criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
 model.learningRate = 0.5
-local epochCount = 20
+model.momentum = 0.9
+local epochCount = 3
 local minErr = 0.1
 local err = 0
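
For context: reversing the encoder input, as dataset.lua and eval.lua now do, is
the trick from Sutskever et al.'s sequence-to-sequence paper; it shortens the
path between the start of the input and the start of the reply. Below is a
minimal, self-contained sketch of that round trip. It is illustration only, not
part of the patch: word2id/id2word and the reverse helper are hypothetical
stand-ins for dataset.word2id/dataset.id2word and pl.list's reverse.

    -- Hypothetical stand-ins for dataset.word2id / dataset.id2word.
    local word2id = { how = 1, are = 2, you = 3 }
    local id2word = { "how", "are", "you" }

    -- In-place table reverse, doing what list.reverse does in the patch.
    local function reverse(t)
      local n = #t
      for i = 1, math.floor(n / 2) do
        t[i], t[n - i + 1] = t[n - i + 1], t[i]
      end
      return t
    end

    -- Turn "how are you" into word IDs, then reverse them before they
    -- would reach the encoder.
    local ids = {}
    for w in ("how are you"):gmatch("%a+") do
      table.insert(ids, word2id[w])
    end
    reverse(ids)                                -- {3, 2, 1}: "you are how"

    -- t2s(input, true) in eval.lua un-reverses for display:
    local words = {}
    for _, id in ipairs(ids) do
      table.insert(words, id2word[id])
    end
    print(table.concat(reverse(words), " "))    -- "how are you"

The seq2seq.lua hunk is the usual dpnn momentum pattern: updateGradParameters(momentum)
folds the previous step's gradient buffer into the freshly computed gradients,
updateParameters then takes the learning-rate step, and zeroGradParameters clears
the accumulators so the next example starts from zero.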