Skip to content

Commit

Permalink
Revert input and a few fixes.
Browse files Browse the repository at this point in the history
Nothing really seems to help to get better results...
  • Loading branch information
macournoyer committed Nov 24, 2015
1 parent 3c24f40 commit 05b6ce6
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 10 deletions.
13 changes: 13 additions & 0 deletions dataset.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@ Also build the vocabulary.
local DataSet = torch.class("e.DataSet")
local xlua = require "xlua"
local tokenizer = require "tokenizer"
local list = require "pl.list"

function DataSet:__init(filename, loader, loadFirst)
-- Discard words with lower frequency than this
self.minWordFreq = 1

-- Max length of one text
self.maxTextLen = 10

-- Load only first fews examples
self.loadFirst = loadFirst

Expand Down Expand Up @@ -125,6 +129,9 @@ function DataSet:visitConversation(lines, start)
local targetIds = self:visitText(target.text)

if inputIds and targetIds then
-- Reverse inputs (standard seq2seq trick to shorten input-to-output dependencies)
inputIds = list.reverse(inputIds)

table.insert(targetIds, 1, self.goToken)
table.insert(targetIds, self.eosToken)

Expand All @@ -141,8 +148,14 @@ function DataSet:visitText(text)
return
end

local i = 0

for t, word in tokenizer.tokenize(text) do
table.insert(words, self:makeWordId(word))
i = i + 1
if i > self.maxTextLen then
break
end
end

if #words == 0 then
Expand Down
35 changes: 27 additions & 8 deletions eval.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
require 'e'
local tokenizer = require "tokenizer"
local list = require "pl.list"

if dataset == nil then
cmd = torch.CmdLine()
Expand All @@ -17,20 +18,38 @@ if model == nil then
model = torch.load("data/model.t7")
end

-- Decode a 1-D tensor of word IDs into a space-separated sentence.
-- t: tensor of vocabulary IDs; reverse: when truthy, emit the words
-- in reverse order (inputs are stored reversed in the dataset).
function t2s(t, reverse)
  local count = t:size(1)
  local words = {}

  -- Walk the tensor forwards or backwards depending on `reverse`,
  -- looking each ID up in the shared vocabulary.
  local first, last, step = 1, count, 1
  if reverse then
    first, last, step = count, 1, -1
  end

  for i = first, last, step do
    words[#words + 1] = dataset.id2word[t[i]]
  end

  return table.concat(words, " ")
end

-- for i,example in ipairs(dataset.examples) do
-- print("-- " .. t2s(example[1], true))
-- print(">> " .. t2s(example[2]))
-- end

-- Tokenize `text`, run it through the seq2seq model, and print both
-- the (reversed) input sentence and the model's decoded reply.
-- NOTE: this span was a diff with its +/- markers stripped, leaving the
-- old and new implementations interleaved (two accumulator tables and an
-- unreachable print after a return); this is the reconstructed final version.
function say(text)
  local wordIds = {}

  -- Map each token to its vocabulary ID; words never seen during
  -- training fall back to the unknown-word token.
  for t, word in tokenizer.tokenize(text) do
    local id = dataset.word2id[word:lower()] or dataset.unknownToken
    table.insert(wordIds, id)
  end

  -- Inputs are reversed to match how the training examples were built
  -- (see dataset.lua's visitConversation).
  local input = torch.Tensor(list.reverse(wordIds))
  print("-- " .. t2s(input, true))

  local output = model:eval(input)

  print(">> " .. t2s(torch.Tensor(output)))
end
4 changes: 4 additions & 0 deletions seq2seq.lua
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ function Seq2Seq:train(input, target)
local zeroTensor = torch.Tensor(2):zero()
self.encoder:backward(encoderInput, zeroTensor)

self.encoder:updateGradParameters(self.momentum)
self.decoder:updateGradParameters(self.momentum)
self.decoder:updateParameters(self.learningRate)
self.encoder:updateParameters(self.learningRate)
self.encoder:zeroGradParameters()
self.decoder:zeroGradParameters()

self.decoder:forget()
self.encoder:forget()
Expand Down
2 changes: 2 additions & 0 deletions tokenizer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ function M.tokenize(text)
{ "^%s+", space },
{ "^['\"]", quote },
{ "^%w+", word },
{ "^%-+", space },
{ "^%.+", punct },
{ "^[,:;%.%?!%-]", punct },
{ "^</?.->", tag },
{ "^.", unknown },
Expand Down
5 changes: 3 additions & 2 deletions train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ dataset = e.DataSet("data/cornell_movie_dialogs_" .. (options.dataset or "full")
e.CornellMovieDialogs("data/cornell_movie_dialogs"), options.dataset)

-- Model
local hiddenSize = 100
local hiddenSize = 300
model = e.Seq2Seq(dataset.wordsCount, hiddenSize)
model.goToken = dataset.goToken
model.eosToken = dataset.eosToken

-- Training
model.criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
model.learningRate = 0.5
local epochCount = 20
model.momentum = 0.9
local epochCount = 3
local minErr = 0.1
local err = 0

Expand Down

0 comments on commit 05b6ce6

Please sign in to comment.