Merge branch 'yoonkim-master'
srush committed Aug 20, 2015
2 parents c0dac23 + 9427c73 commit e10f5f9
Showing 5 changed files with 138 additions and 45 deletions.
73 changes: 69 additions & 4 deletions README.md
@@ -1,17 +1,17 @@
## Neural Language Modeling with Characters
## Character-Aware Neural Language Models
A neural language model (NLM) built on character inputs only. Predictions
are still made at the word level. The model applies a convolutional neural
network (CNN) over the characters of each word and uses the CNN's output as
the input to a long short-term memory (LSTM) recurrent neural network
language model (RNN-LM). Optionally, the CNN output can first be passed
through a [Highway Network](http://arxiv.org/abs/1507.06228), which improves
performance.
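To make the data flow concrete, here is a minimal, self-contained Torch sketch of a
single-width char-CNN plus one highway layer. This is illustrative only, not the
repo's code; all sizes and the random weights are hypothetical:

```
require 'nn'

-- hypothetical sizes: 51-char vocabulary, 15-dim char embeddings,
-- words padded to 21 characters, 100 feature maps of width 3
local char_vocab, char_vec, max_word_l = 51, 15, 21
local feature_maps, width = 100, 3

local char_cnn = nn.Sequential()
char_cnn:add(nn.LookupTable(char_vocab, char_vec))           -- max_word_l x char_vec
char_cnn:add(nn.TemporalConvolution(char_vec, feature_maps, width))
char_cnn:add(nn.Tanh())
char_cnn:add(nn.TemporalMaxPooling(max_word_l - width + 1))  -- max over time
char_cnn:add(nn.Squeeze())                                   -- feature_maps-dim word vector

local word = torch.LongTensor(max_word_l):random(char_vocab) -- one (padded) word's char indices
local h = char_cnn:forward(word)                             -- the word's CNN representation

-- one highway layer on top: y = t .* tanh(W_H*h) + (1 - t) .* h
local W_H = torch.randn(feature_maps, feature_maps)
local W_T = torch.randn(feature_maps, feature_maps)
local t = torch.sigmoid(W_T * h)                             -- transform gate
local y = torch.cmul(t, torch.tanh(W_H * h)) + torch.cmul(t:clone():mul(-1):add(1), h)
print(y:size(1))                                             -- 100; this vector is the LSTM input
```

The actual model applies filters of several widths in parallel and concatenates the
pooled features (see `model/TDNN.lua` below).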

Note: Code is messy/experimental. Cleaner (and faster) code coming. Paper
will be posted on arXiv very soon.
Note: Paper will be posted on arXiv very soon.

Much of the base code is from Andrej Karpathy's excellent character RNN implementation,
available at https://github.com/karpathy/char-rnn

Also, the repo name 'word-char-rnn' is a bit of a misnomer, as the primary motivation
is to use character-level inputs only. But as a baseline we implemented the
word-level models (and also experimented with models whereby the input
@@ -20,24 +20,38 @@ hence the name.

### Requirements
Code is written in Lua and requires Torch. It additionally requires
the `nngraph` and `optim` packages, which can be installed via:
```
luarocks install nngraph
luarocks install optim
```
GPU usage will additionally require:
GPU usage will additionally require the `cutorch` and `cunn` packages:
```
luarocks install cutorch
luarocks install cunn
```

`cudnn` also will result in a good (10x) speed-up.

`cudnn` will result in a good (8x-10x) speed-up for convolutions, so it is
highly recommended. This makes the training time of a character-level model
roughly competitive with that of a word-level model (0.5 secs/batch vs 0.25 secs/batch
for the large character-level and word-level models described below).

```
git clone https://github.com/soumith/cudnn.torch.git
luarocks make cudnn-scm-1.rockspec
```
### Data
Data should be put into the `data/` directory, split into `train.txt`,
`valid.txt`, and `test.txt`.
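For example, with the default data directory the layout would be (hypothetical,
matching the `data/ptb` default used elsewhere in the repo):

```
data/ptb/train.txt
data/ptb/valid.txt
data/ptb/test.txt
```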

Each line of the .txt file should be a sentence. The English Penn
Treebank (PTB) data (Tomas Mikolov's pre-processed version with vocab size equal to 10K,
widely used by the language modeling community) is given as the default.

The paper also runs the models on non-English data (Czech, French, German, Russian, and Spanish), from the ICML 2014
paper [Compositional Morphology for Word Representations and Language Modelling](http://arxiv.org/abs/1405.4273)
by Jan Botha and Phil Blunsom. This can be downloaded from [Jan's website](https://bothameister.github.io).

#### Note on PTB
The PTB data above does not have end-of-sentence tokens for each sentence, and hence these must be
manually appended. This can be done by adding `-EOS '+'` to the script (you
can use a character other than `+` to represent the end-of-sentence token---we recommend a single
unused character).

Jan's datasets already have end-of-sentence tokens for each line, so you do not need
the `-EOS` flag (this is equivalent to `-EOS ''`, which is the default).
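For example, to train on the default PTB data with `+` as the end-of-sentence token:

```
th main.lua -EOS '+'
```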

### Model
Here are some example scripts. Add `-gpuid 0` to each line to use a GPU (which is
required to get any reasonable speed with the CNN), and `-cudnn 1` to use the
cudnn package.
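For instance, the large character-level model below, run on the first GPU with cudnn:

```
th main.lua -savefile char-large -EOS '+' -gpuid 0 -cudnn 1
```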

#### Character-level models
Large character-level model (LSTM-CharCNN-Large in the paper).
This is the default configuration; it should get ~82 on valid and ~79 on test.
```
th main.lua -savefile char-large -EOS '+'
```
Small character-level model (LSTM-CharCNN-Small in the paper).
This should get ~96 on valid and ~93 on test.
```
th main.lua -savefile char-small -rnn_size 300 -highway_layers 1
-kernels '{1,2,3,4,5,6}' -feature_maps '{25,50,75,100,125,150}' -EOS '+'
```

#### Word-level models
Large word-level model (LSTM-Word-Large in the paper).
This should get ~89 on valid and ~85 on test.
```
th main.lua -savefile word-large -word_vec_size 650 -highway_layers 0
-use_chars 0 -use_words 1 -EOS '+'
```
Small word-level model (LSTM-Word-Small in the paper).
This should get ~101 on valid and ~98 on test.
```
th main.lua -savefile word-small -word_vec_size 200 -highway_layers 0
-use_chars 0 -use_words 1 -rnn_size 200 -EOS '+'
```

#### Combining both
Note that if `-use_chars` and `-use_words` are both set to 1, the model
will concatenate the output from the CNN with the word embedding. We've
found this model to underperform a purely character-level model, though.
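A minimal example of such a combined run (the `char-word` savefile name is arbitrary):

```
th main.lua -savefile char-word -use_chars 1 -use_words 1 -EOS '+'
```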
21 changes: 14 additions & 7 deletions evaluate.lua
@@ -16,23 +16,19 @@ require 'util.misc'

BatchLoader = require 'util.BatchLoaderUnk'
model_utils = require 'util.model_utils'
HighwayMLP = require 'model.HighwayMLP'
TDNN = require 'model.TDNN'
LSTMTDNN = require 'model.LSTMTDNN'

local stringx = require('pl.stringx')

cmd = torch.CmdLine()
cmd:text('Options')
-- data
cmd:option('-data_dir','data/ptb','data directory. Should contain train.txt/valid.txt/test.txt with input data')
cmd:option('-savefile', 'cv-ptb/lm_results.t7', 'save results to')
cmd:option('-model', 'cv-ptb/lm_model.t7', 'model checkpoint file')
cmd:option('-savefile', 'final-results/lm_results.t7', 'save results to')
cmd:option('-model', 'final-results/en-large-word-model.t7', 'model checkpoint file')
-- GPU/CPU
cmd:option('-gpuid',-1,'which gpu to use. -1 = use CPU')
cmd:option('-cudnn',0,'use cudnn (1=yes). should match how the model was trained')
cmd:text()


-- parse input params
opt2 = cmd:parse(arg)
if opt2.gpuid >= 0 then
@@ -41,6 +37,17 @@ if opt2.gpuid >= 0 then
require 'cunn'
cutorch.setDevice(opt2.gpuid + 1)
end

-- opt (the checkpoint's options) is only loaded below, so read the flag from opt2
if opt2.cudnn == 1 then
assert(opt2.gpuid >= 0, 'GPU must be used if using cudnn')
print('using cudnn')
require 'cudnn'
end

HighwayMLP = require 'model.HighwayMLP'
TDNN = require 'model.TDNN'
LSTMTDNN = require 'model.LSTMTDNN'

checkpoint = torch.load(opt2.model)
opt = checkpoint.opt
protos = checkpoint.protos
@@ -131,4 +138,4 @@ test_results.vocab = {idx2word, word2idx, idx2char, char2idx}
test_results.opt = opt
test_results.val_losses = checkpoint.val_losses
torch.save(opt2.savefile, test_results)
collectgarbage()
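With a trained checkpoint in hand, a typical invocation is the following (the paths
are simply the defaults from the options above):

```
th evaluate.lua -model final-results/en-large-word-model.t7 -savefile final-results/lm_results.t7 -gpuid 0
```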
33 changes: 25 additions & 8 deletions introspect.lua
@@ -25,7 +25,7 @@ cmd:text('Perform model introspection')
cmd:text()
cmd:text('Options')
-- data
cmd:option('-model','final-results/en-large-model.t7', 'model file')
cmd:option('-model','final-results/en-large-word-model.t7', 'model file')
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:option('-savefile', 'chargrams.tsv', 'save max chargrams to')
cmd:text()
@@ -149,15 +149,23 @@ end
function get_all_chargrams(idx2word)
    local idx2chargram = {}
    local chargram2idx = {}
    local count = {}
    -- first pass: count how often each character n-gram (n = 2..7) occurs
    for i = 1, #idx2word do
        local ngrams = get_chargrams(opt.tokens.START .. idx2word[i] .. opt.tokens.END, 2, 7)
        for _, ngram in pairs(ngrams) do
            if count[ngram] == nil then
                count[ngram] = 1
            else
                count[ngram] = count[ngram] + 1
            end
        end
    end
    -- second pass: keep only chargrams that occur more than 3 times
    for ngram, c in pairs(count) do
        if c > 3 then
            idx2chargram[#idx2chargram + 1] = ngram
            chargram2idx[ngram] = #idx2chargram
        end
    end
    return idx2chargram, chargram2idx
end
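For intuition, here is the same count-then-filter idea in a tiny standalone form,
using a toy word list and 2-grams only (the real code uses `get_chargrams` with
n = 2..7; the START/END tokens are rendered here as `{` and `}` purely for illustration):

```
local count = {}
for _, w in ipairs({'{cat}', '{cats}', '{car}'}) do
    for i = 1, #w - 1 do
        local g = w:sub(i, i + 1)       -- every 2-gram of the word
        count[g] = (count[g] or 0) + 1
    end
end
for g, c in pairs(count) do print(g, c) end  -- e.g. 'ca' occurs 3 times
```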

@@ -178,15 +186,24 @@ function get_chargram(word, word_len, n, ngrams)
end
end

function get_chargram_vecs(savefile)
    idx2chargram, chargram2idx = get_all_chargrams(idx2word)
    print(#idx2chargram)
    chargram_idx_all = torch.zeros(#idx2chargram, opt.max_word_l)
    for i = 1, #idx2chargram do
        chargram_idx_all[i] = word2char2idx(idx2chargram[i], opt.max_word_l)
    end
    -- embed the chargrams, then run each one through the CNN to get its feature vector
    chargram_vecs_all = char_vecs:forward(chargram_idx_all)
    chargram_vecs = torch.zeros(#idx2chargram, torch.sum(torch.Tensor(opt.feature_maps)))
    for i = 1, #idx2chargram do
        chargram_vecs[i] = cnn:forward(chargram_vecs_all[i]:view(1, opt.max_word_l, opt.char_vec_size)):float()
    end
    -- write the chargram dictionary and the corresponding vectors to disk
    local f = io.open(savefile..'-dic.txt', 'w')
    for _, ngram in ipairs(idx2chargram) do
        f:write(ngram..'\n')
    end
    f:close()
    torch.save(savefile..'.t7', chargram_vecs)
end

--get contribution of each character to the feature vector by counting
29 changes: 17 additions & 12 deletions main.lua
@@ -1,5 +1,6 @@
--[[
Trains a word+character-level multi-layer rnn language model
Trains a word-level or character-level (for inputs) LSTM language model.
Predictions are still made at the word level.
Much of the code is borrowed from the following implementations
https://github.com/karpathy/char-rnn
@@ -16,9 +17,6 @@ require 'util.misc'

BatchLoader = require 'util.BatchLoaderUnk'
model_utils = require 'util.model_utils'
TDNN = require 'model.TDNN'
LSTMTDNN = require 'model.LSTMTDNN'
HighwayMLP = require 'model.HighwayMLP'

local stringx = require('pl.stringx')

@@ -55,12 +53,13 @@ cmd:option('-threads', 16, 'number of threads')
-- bookkeeping
cmd:option('-seed',3435,'torch manual random number generator seed')
cmd:option('-print_every',100,'how many steps/minibatches between printing out the loss')
cmd:option('-checkpoint_dir', 'cv-ptb', 'output directory where checkpoints get written')
cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
cmd:option('-savefile','char','filename to autosave the checkpoint to. Will be inside checkpoint_dir/')
cmd:option('-checkpoint', 'checkpoint.t7', 'start from a checkpoint if a valid checkpoint.t7 file is given')
cmd:option('-EOS', '+', '<EOS> symbol. should be a single unused character (like +) for PTB and blank for others')
cmd:option('-EOS', '', '<EOS> symbol. should be a single unused character (like +) for PTB and blank for others')
-- GPU/CPU
cmd:option('-gpuid',-1,'which gpu to use. -1 = use CPU')
cmd:option('-cudnn', 0,'use cudnn (1=yes). this should greatly speed up convolutions')
cmd:option('-time', 0, 'print batch times')
cmd:text()

@@ -83,6 +82,18 @@ if opt.gpuid >= 0 then
cutorch.setDevice(opt.gpuid + 1)
end

if opt.cudnn == 1 then
assert(opt.gpuid >= 0, 'GPU must be used if using cudnn')
print('using cudnn...')
require 'cudnn'
end

-- load models. we do this here instead of before
-- because of cudnn
TDNN = require 'model.TDNN'
LSTMTDNN = require 'model.LSTMTDNN'
HighwayMLP = require 'model.HighwayMLP'

-- some housekeeping
loadstring('opt.kernels = ' .. opt.kernels)() -- get kernel sizes
loadstring('opt.feature_maps = ' .. opt.feature_maps)() -- get feature map sizes
@@ -353,12 +364,6 @@ for i = 1, iterations do
end
end

some_function()
another_function()
coroutine.resume( some_coroutine )
ProFi:stop()
ProFi:writeReport( 'MyProfilingReport.txt' )

--evaluate on full test set. this just uses the model from the last epoch
--rather than best-performing model. it is also incredibly inefficient
--because of batch size issues. for faster evaluation, use evaluate.lua, i.e.
27 changes: 13 additions & 14 deletions model/TDNN.lua
@@ -1,7 +1,7 @@
-- Time-delayed Neural Network (i.e. 1-d CNN) with multiple filter widths

local TDNN = {}
local cudnn_status, cudnn = pcall(require, 'cudnn')
--local cudnn_status, cudnn = pcall(require, 'cudnn')

function TDNN.tdnn(length, input_size, feature_maps, kernels)
-- length = length of sentences/words (zero padded to be of same length)
@@ -15,26 +15,25 @@ function TDNN.tdnn(length, input_size, feature_maps, kernels)
for i = 1, #kernels do
    local reduced_l = length - kernels[i] + 1
    local pool_layer
    if opt.cudnn == 1 then
        -- Use CuDNN for the temporal convolution.
        if not cudnn then require 'cudnn' end
        -- Fake the spatial convolution.
        local conv = cudnn.SpatialConvolution(1, feature_maps[i], input_size,
                                              kernels[i], 1, 1, 0)
        conv.name = 'conv_filter_' .. kernels[i] .. '_' .. feature_maps[i]
        local conv_layer = conv(nn.View(1, -1, input_size):setNumInputDims(2)(input))
        --pool_layer = nn.Max(3)(nn.Max(3)(nn.Tanh()(conv_layer)))
        pool_layer = nn.Squeeze()(cudnn.SpatialMaxPooling(1, reduced_l, 1, 1, 0, 0)(conv_layer))
    else
        -- Temporal convolution: much slower than the cudnn path.
        local conv = nn.TemporalConvolution(input_size, feature_maps[i], kernels[i])
        local conv_layer = conv(input)
        conv.name = 'conv_filter_' .. kernels[i] .. '_' .. feature_maps[i]
        --pool_layer = nn.Max(2)(nn.Tanh()(conv_layer))
        pool_layer = nn.TemporalMaxPooling(reduced_l)(nn.Tanh()(conv_layer))
        pool_layer = nn.Squeeze()(pool_layer)
    end
    table.insert(layer1, pool_layer)
end
