faster BatchLoaderUnk
Yoon Kim committed Aug 17, 2015
1 parent 3ac4467 commit 5381f4a
Showing 1 changed file with 24 additions and 25 deletions.
util/BatchLoaderUnk.lua: 49 changes (24 additions & 25 deletions)
@@ -106,54 +106,53 @@ end
 function BatchLoaderUnk.text_to_tensor(input_files, out_vocabfile, out_tensorfile, out_charfile, max_word_l)
     print('Processing text into tensors...')
     local tokens = opt.tokens -- inherit global constants for tokens
-    local f, rawdata, output, output_char
+    local f, rawdata
     local output_tensors = {} -- output tensors for train/val/test
-    local output_chars = {} -- output character for train/val/test sets (not tensors yet)
+    local output_chars = {} -- output character tensors for train/val/test sets
     local vocab_count = {} -- vocab count
-    local max_word_l_tmp = 0
+    local max_word_l_tmp = 0 -- max word length of the corpus
     local idx2word = {tokens.UNK} -- unknown word token
     local word2idx = {}; word2idx[tokens.UNK] = 1
     local idx2char = {tokens.ZEROPAD, tokens.START, tokens.END} -- zero-pad, start-of-word, end-of-word tokens
     local char2idx = {}; char2idx[tokens.ZEROPAD] = 1; char2idx[tokens.START] = 2; char2idx[tokens.END] = 3
+    local split_counts = {}

     -- first go through train/valid/test to get max word length
     -- if actual max word length (e.g. 19 for PTB) is smaller than specified
-    -- we use that instead. this is inefficient, but only a one-off thing
+    -- we use that instead. this is inefficient, but only a one-off thing so should be fine
+    -- also counts the number of tokens
     for split = 1,3 do -- split = 1 (train), 2 (val), or 3 (test)
         f = io.open(input_files[split], 'r')
+        local counts = 0
         for line in f:lines() do
             line = stringx.replace(line, '<unk>', tokens.UNK) -- replace unk with a single character
             line = stringx.replace(line, tokens.START, '') --start-of-word token is reserved
             line = stringx.replace(line, tokens.END, '') --end-of-word token is reserved
             for word in line:gmatch'([^%s]+)' do
                 max_word_l_tmp = math.max(max_word_l_tmp, word:len())
+                counts = counts + 1
             end
+            if tokens.EOS ~= '' then
+                counts = counts + 1 --PTB uses \n for <eos>, so need to add one more token at the end
+            end
         end
         f:close()
+        split_counts[split] = counts
     end

     print('After first pass of data, max word length is: ' .. max_word_l_tmp)
+    print(string.format('Token count: train %d, val %d, test %d',
+        split_counts[1], split_counts[2], split_counts[3]))

     -- if actual max word length is less than the limit, use that
     max_word_l = math.min(max_word_l_tmp, max_word_l)

-    for split = 1,3 do -- split = 1 (train), 2 (val), or 3 (test)
-        output = {}
-        output_char = {}
-        f = io.open(input_files[split], 'r')
-        -- First count all the words in the string.
-        local counts = 0
-        for line in f:lines() do
-            line = stringx.replace(line, '<unk>', tokens.UNK) -- replace unk with a single character
-            line = stringx.replace(line, tokens.START, '') --start-of-word token is reserved
-            line = stringx.replace(line, tokens.END, '') --end-of-word token is reserved
-            for word in line:gmatch'([^%s]+)' do
-                counts = counts + 1
-            end
-            counts = counts + 1
-        end
-        f:close()
-
-        -- Next preallocate the tensors we will need.
+    for split = 1,3 do -- split = 1 (train), 2 (val), or 3 (test)
+        -- Preallocate the tensors we will need.
         -- Watch out the second one needs a lot of RAM.
-        output_tensors[split] = torch.LongTensor(counts)
-        output_chars[split] = torch.ones(counts, max_word_l):long()
+        output_tensors[split] = torch.LongTensor(split_counts[split])
+        output_chars[split] = torch.ones(split_counts[split], max_word_l):long()

         f = io.open(input_files[split], 'r')
         local word_num = 0
         for line in f:lines() do
@@ -193,7 +192,7 @@ function BatchLoaderUnk.text_to_tensor(input_files, out_vocabfile, out_tensorfil
                 append(rword)
             end
             if tokens.EOS ~= '' then --PTB does not have <eos> so we add a character for <eos> tokens
-                append(tokens.EOS) --other datasets with periods or <eos> already present do not need this
+                append(tokens.EOS) --other datasets don't need this
             end
         end
     end
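
The gist of the change: the first pass over train/valid/test now records a per-split token count in split_counts while it measures word lengths, so the second loop can preallocate output_tensors and output_chars directly instead of re-reading each file just to count tokens. Below is a minimal Lua sketch of that two-pass pattern, separate from the repository's code (hypothetical file names; assumes the Torch runtime for torch.LongTensor):

-- Standalone sketch (not the repository's loader): pass 1 counts tokens per
-- split, pass 2 preallocates a LongTensor of exactly that size and fills it.
require 'torch' -- assumes the Torch runtime used by the repository

local files = {'train.txt', 'valid.txt', 'test.txt'} -- hypothetical paths
local split_counts = {}

-- Pass 1: count whitespace-separated tokens, plus one end-of-sentence
-- token per line (mirroring the PTB-style handling above).
for split = 1, 3 do
    local counts = 0
    for line in io.lines(files[split]) do
        for _ in line:gmatch('([^%s]+)') do counts = counts + 1 end
        counts = counts + 1
    end
    split_counts[split] = counts
end

-- Pass 2: allocate once per split and fill in place; no table growth and
-- no separate counting read of the file.
local output_tensors = {}
for split = 1, 3 do
    output_tensors[split] = torch.LongTensor(split_counts[split])
    local word_num = 0
    for line in io.lines(files[split]) do
        for word in line:gmatch('([^%s]+)') do
            word_num = word_num + 1
            output_tensors[split][word_num] = 1 -- placeholder id; the real loader stores word2idx[word]
        end
        word_num = word_num + 1
        output_tensors[split][word_num] = 2 -- placeholder id for the <eos> token
    end
end

Preallocating from the counts gathered in the first pass saves one full read of every split compared with the old code, which is where the speedup in this commit comes from.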
