
Commit c213f48: Split eval into another file.
macournoyer committed Nov 3, 2015
1 parent 7c1b903
Showing 3 changed files with 59 additions and 45 deletions.
e.lua (4 changes: 4 additions & 0 deletions)
@@ -1,3 +1,7 @@
+require 'torch'
+require 'nn'
+require 'rnn'
+
 e = {}
 
 torch.include('e', 'cornell_movie_dialogs.lua')
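This change makes e.lua self-contained: torch, nn and rnn are pulled in by the package itself, and torch.include loads each listed source file into the session relative to the package's install directory. A minimal sketch of a consumer script under that assumption (the print line is illustrative, not part of the repo):

-- Hypothetical consumer: requiring 'e' now brings in torch, nn and
-- rnn transitively, so none of them need a separate require here.
require 'e'
print(torch.isTensor(torch.Tensor(3)))  -- true: torch arrived via 'e'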
eval.lua (43 changes: 43 additions & 0 deletions)
@@ -0,0 +1,43 @@
+require 'e'
+local tokenizer = require "tokenizer"
+
+dataset = e.DataSet("data/cornell_movie_dialogs.t7",
+                    e.CornellMovieDialogs("data/cornell_movie_dialogs"))
+
+EOS = dataset.word2id["</s>"]
+
+print("-- Loading model")
+model = torch.load("data/model.t7")
+
+function output2wordId(t)
+  local max = t:max()
+  for i = 1, t:size(1) do
+    if t[i] == max then
+      return i
+    end
+  end
+end
+
+function say(text)
+  local inputs = {}
+  for t, word in tokenizer.tokenize(text) do
+    local t = dataset.word2id[word:lower()]
+    table.insert(inputs, t)
+  end
+
+  model:forget()
+
+  for i = #inputs, 1, -1 do
+    local input = inputs[i]
+    model:forward(torch.Tensor{input})
+  end
+
+  local input = EOS
+  repeat
+    local output = model:forward(torch.Tensor{input})
+    io.write(dataset.id2word[output2wordId(output)] .. " ")
+    input = output2wordId(output)
+  until input == EOS
+
+  print("")
+end
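The new eval.lua is the testing half of the old train.lua: it loads the saved model, output2wordId picks the argmax of the output log-probability vector, and say tokenizes a prompt, feeds it to the network in reverse order, then greedily emits words until the </s> token comes back. A possible interactive session, assuming data/model.t7 was already produced by train.lua (the prompts are made up):

-- e.g. from a Torch REPL started with: th -i eval.lua
say("Hello!")           -- prints the model's reply one word at a time
say("How are you?")

As an aside, Torch tensors can return the argmax index directly: local _, id = t:max(1) yields the index tensor, so output2wordId could be collapsed to return id[1].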
train.lua (57 changes: 12 additions & 45 deletions)
@@ -1,13 +1,11 @@
-require 'nn'
-require 'rnn'
-require 'xlua'
 require 'e'
+require 'xlua'
 
 -- Data
--- local dataset = e.DataSet("data/cornell_movie_dialogs.t7",
---                           e.CornellMovieDialogs("data/cornell_movie_dialogs"))
-dataset = e.DataSet("data/cornell_movie_dialogs_tiny.t7",
-                    e.CornellMovieDialogs("data/cornell_movie_dialogs"), 1000)
+dataset = e.DataSet("data/cornell_movie_dialogs.t7",
+                    e.CornellMovieDialogs("data/cornell_movie_dialogs"))
+-- dataset = e.DataSet("data/cornell_movie_dialogs_tiny.t7",
+--                     e.CornellMovieDialogs("data/cornell_movie_dialogs"), 1000)
 
 EOS = dataset.word2id["</s>"]
 
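Judging from the two variants, e.DataSet takes a cache path, a loader, and an optional cap on how many examples to load; the tiny variant keeps only 1,000 examples for fast iteration. A hedged sketch of switching back for a smoke test (constructor semantics inferred from the commented-out lines):

-- Hypothetical quick-iteration toggle: build/load the capped dataset
-- so a full epoch finishes quickly.
dataset = e.DataSet("data/cornell_movie_dialogs_tiny.t7",
                    e.CornellMovieDialogs("data/cornell_movie_dialogs"), 1000)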
@@ -22,8 +20,8 @@ model:add(nn.LookupTable(dataset.wordsCount, inputSize))
 model:add(nn.SplitTable(1,2))
 model:add(nn.Sequencer(nn.FastLSTM(inputSize, hiddenSize)))
 model:add(nn.Sequencer(nn.Dropout(dropout)))
--- model:add(nn.Sequencer(nn.FastLSTM(hiddenSize, hiddenSize)))
--- model:add(nn.Sequencer(nn.Dropout(dropout)))
+model:add(nn.Sequencer(nn.FastLSTM(hiddenSize, hiddenSize)))
+model:add(nn.Sequencer(nn.Dropout(dropout)))
 model:add(nn.Sequencer(nn.Linear(hiddenSize, dataset.wordsCount)))
 model:add(nn.JoinTable(1,2))
 model:add(nn.LogSoftMax())
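With the commented lines activated, the stack is now two LSTM layers deep: LookupTable, FastLSTM, Dropout, FastLSTM, Dropout, Linear, LogSoftMax. One thing to keep in mind when reusing the saved model elsewhere: nn.Dropout only randomizes activations in training mode, and Torch's containers expose a switch for that. A minimal sketch using the standard nn mode-switching API:

-- Dropout is stochastic only while training, so generation code
-- would typically do:
model:evaluate()   -- deterministic forward passes, dropout disabled
-- ...and training code restores it with:
model:training()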
@@ -70,44 +68,13 @@ for epoch = 1, epochCount do
 
     model:forget()
     xlua.progress(i, #dataset.examples)
-  end
-
-  print("-- Saving model")
-  torch.save("data/model.t7", model)
-end
-
-
--- Testing
-function output2wordId(t)
-  local max = t:max()
-  for i = 1, t:size(1) do
-    if t[i] == max then
-      return i
+
+    -- TODO remove this when training is faster
+    if i % 1000 == 0 then
+      torch.save("data/model.t7", model)
     end
   end
-end
-
-local tokenizer = require "tokenizer"
-function say(text)
-  local inputs = {}
-  for t, word in tokenizer.tokenize(text) do
-    local t = dataset.word2id[word:lower()]
-    table.insert(inputs, t)
-  end
-
-  model:forget()
-
-  for i = #inputs, 1, -1 do
-    local input = inputs[i]
-    model:forward(torch.Tensor{input})
-  end
-
-  local input = EOS
-  repeat
-    local output = model:forward(torch.Tensor{input})
-    io.write(dataset.id2word[output2wordId(output)] .. " ")
-    input = output2wordId(output)
-  until input == EOS
-
-  print("")
+
+  print("-- Saving model")
+  torch.save("data/model.t7", model)
 end
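The testing block is gone from train.lua, and in its place the inner loop now checkpoints to data/model.t7 every 1,000 examples; the TODO flags this as a stopgap until epochs are fast enough to save only at epoch boundaries. Since an interrupted torch.save can leave a truncated file, one defensive variant (file names hypothetical, not part of the commit) is to write to a temporary path and rename:

-- Hypothetical safer checkpoint: os.rename is atomic on POSIX when
-- both paths are on the same filesystem, so a crash mid-save cannot
-- clobber the previous model file.
local function checkpoint(model)
  torch.save("data/model.t7.tmp", model)
  os.rename("data/model.t7.tmp", "data/model.t7")
end

if i % 1000 == 0 then
  checkpoint(model)
end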
