-
Notifications
You must be signed in to change notification settings - Fork 346
/
Copy path eval.lua
43 lines (34 loc) · 902 Bytes
/
eval.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
-- Project library; it is used below through the global `e`
-- (e.DataSet, e.CornellMovieDialogs), so presumably `require` sets it up.
require("e")
local tokenizer = require("tokenizer")

-- `dataset`, `EOS` and `model` are deliberately global: say() reads them,
-- and they remain reachable when the script is loaded interactively.
do
  local corpus = e.CornellMovieDialogs("data/cornell_movie_dialogs")
  dataset = e.DataSet("data/cornell_movie_dialogs.t7", corpus)
end

-- Vocabulary id of the end-of-sentence marker "</s>".
EOS = dataset.word2id["</s>"]

print("-- Loading model")
model = torch.load("data/model.t7")
--- Index of the largest entry of `t` (greedy pick of the predicted word id).
-- `t` is expected to be 1-D: entries are read as `t[i]` for i = 1..t:size(1)
-- and compared against the scalar returned by `t:max()`.
-- Ties resolve to the lowest index. Returns nil if no entry compares equal
-- to the maximum.
-- @param t tensor-like object providing `max()`, `size(1)` and `t[i]`
-- @treturn number|nil 1-based index of the first maximal entry
function output2wordId(t)
  local best = t:max()
  local count = t:size(1)
  local idx = 1
  while idx <= count do
    if t[idx] == best then
      return idx
    end
    idx = idx + 1
  end
end
--- Write the model's reply to `text` on stdout, word by word.
-- The input is tokenized, each word lowercased and mapped through
-- `dataset.word2id`; words with no id are skipped. The ids are fed to the
-- model in reverse order (last word first). Generation then starts from
-- `EOS`: each step feeds the previous prediction back in and writes the
-- highest-scoring word followed by a space. It stops once `EOS` is predicted
-- (the EOS word itself is written as the last word) or after `maxLength`
-- generated words, whichever comes first. A final newline is printed.
-- @tparam string text user utterance
-- @tparam[opt=50] number maxLength cap on generated words; prevents an
--   endless loop when the model never predicts EOS or EOS is nil because
--   "</s>" is missing from the vocabulary
function say(text, maxLength)
  maxLength = maxLength or 50

  local inputs = {}
  -- Only the word (second value) from the tokenizer iterator is used.
  for _, word in tokenizer.tokenize(text) do
    local id = dataset.word2id[word:lower()]
    -- Out-of-vocabulary words have no id and are dropped.
    if id ~= nil then
      inputs[#inputs + 1] = id
    end
  end

  -- NOTE(review): presumably resets recurrent state left over from a
  -- previous call (rnn-style `forget`) — confirm against the model type.
  model:forget()

  -- Feed the input ids in reverse order, last word first.
  for i = #inputs, 1, -1 do
    model:forward(torch.Tensor{inputs[i]})
  end

  local input = EOS
  for _ = 1, maxLength do
    local output = model:forward(torch.Tensor{input})
    local wordId = output2wordId(output)
    io.write(dataset.id2word[wordId] .. " ")
    input = wordId
    if input == EOS then
      break
    end
  end
  print("")
end