-- project.lua
require("hdf5")
require("nn")
require("optim")
require("seq2seq_att")
-- require("hmm")
-- require("memm")
-- require("structure")
function run_fold(opt)
  local model
  -- Train.
  if opt.classifier == 'seq2seq_att' then
    model = seq2seq_att(opt)
  end
  -- Test: write the test output file used to evaluate the BLEU score.
  if opt.test == 1 then
    model.test(model, opt)
  end
end
function run(opt)
  if opt.gpuid >= 0 then
    print('using CUDA on GPU ' .. opt.gpuid .. '...')
    if opt.gpuid2 >= 0 then
      print('using CUDA on second GPU ' .. opt.gpuid2 .. '...')
    end
    require 'cutorch'
    require 'cunn'
    if opt.cudnn == 1 then
      print('loading cudnn...')
      require 'cudnn'
    end
    cutorch.setDevice(opt.gpuid)
    cutorch.manualSeed(opt.seed)
  end
  -- Reports training speed, training-set loss, training-set predictive
  -- accuracy, and validation predictive accuracy.
  run_fold(opt)
end
cmd = torch.CmdLine()
-- Cmd Args
cmd:option('-trainfile', 'mydata/flickr8k_train.hdf5', 'train data')
cmd:option('-validfile', 'mydata/flickr8k_valid.hdf5', 'valid data')
cmd:option('-testfile', 'mydata/flickr8k_test.hdf5', 'test data')
cmd:option('-outfile', 'out.txt', 'output file')
cmd:option('-lossfile', 'loss.png', 'file to write loss plot png')
cmd:option('-classifier', 'seq2seq_att', 'classifier to use')
cmd:option('-optimizer', 'adam', 'optimizer to use [adam|sgd]')
cmd:option('-train', 1, 'run training code')
cmd:option('-test', 1, 'run test')
cmd:option('-word_vec_size', 50, 'word embedding dimension')
cmd:option('-hop_attn', 0, [[If > 0, use a *hop attention* on this layer of the decoder. For example, if num_layers = 3 and hop_attn = 2, the model attends over the source sequence at the second layer (and uses that as input to the third layer) as well as at the penultimate layer]])
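-- Illustrative sketch (not wired into the model): one global-attention step of
-- the kind -hop_attn inserts at a decoder layer. The names dec_h and context
-- are hypothetical; assumes dec_h is batch x rnn_size and context is
-- batch x src_len x rnn_size.
local function attn_step_sketch(dec_h, context)
  -- Dot-product score of the decoder state against every source position.
  local scores = torch.bmm(context, dec_h:view(dec_h:size(1), -1, 1)):squeeze(3)
  local probs = nn.SoftMax():forward(scores) -- batch x src_len
  -- Attention-weighted sum of the source context vectors.
  return torch.bmm(probs:view(probs:size(1), 1, -1), context):squeeze(2)
end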
cmd:option('-res_net', 0, [[Use residual connections between LSTM stacks, whereby the input to the l-th LSTM layer is the hidden state of the (l-1)-th LSTM layer added to that of the (l-2)-th LSTM layer. We did not find this to help in our experiments]])
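-- Illustrative sketch of the residual combination -res_net describes: given
-- hypothetical hidden states from the (l-1)-th and (l-2)-th layers (both
-- batch x rnn_size), the input to layer l is their elementwise sum.
local function res_input_sketch(h_lm1, h_lm2)
  return nn.CAddTable():forward({h_lm1, h_lm2})
end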
cmd:option('-curriculum', 0, [[For this many epochs, order the minibatches based on source sequence length. Sometimes setting this to 1 will increase convergence speed.]])
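-- Illustrative sketch of the ordering -curriculum applies for its first few
-- epochs: sort minibatches by source length, shortest first. `batches` is a
-- hypothetical table of minibatches, each carrying a `source_length` field.
local function curriculum_sort_sketch(batches)
  table.sort(batches, function(a, b) return a.source_length < b.source_length end)
  return batches
end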
cmd:option('-source_size', 512, [[Source vocab size]])
cmd:option('-target_size', 8388, [[Target vocab size]])
cmd:option('-rnn_size', 500, [[Size of LSTM hidden states]])
cmd:option('-num_layers', 2, [[Number of layers in the LSTM encoder/decoder]])
cmd:option('-savefile', 'seq2seq_att.t7', [[Savefile name (the model will be saved as savefile.t7)]])
cmd:option('-train_from', 'seq2seq_att.t7', [[If training from a checkpoint then this is the path to the pretrained model.]])
-- GPU
cmd:option('-gpuid', 1, [[Which GPU to use (>= 1); -1 = use CPU]])
cmd:option('-gpuid2', -1, [[If this is >= 0, then the model will use two GPUs whereby the encoder is on the first GPU and the decoder is on the second GPU. This will allow you to train with bigger batches/models.]])
cmd:option('-cudnn', 1, [[Whether to use cudnn for convolutions (for the character model). cudnn has much faster convolutions, so this is highly recommended if using the character model]])
-- optimization
cmd:option('-epochs', 20, [[Number of training epochs]])
cmd:option('-param_init', 0.1, [[Parameters are initialized over uniform distribution with support (-param_init, param_init)]])
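-- Illustrative sketch of the initialization -param_init describes, on the
-- flattened parameter tensor one would get from module:getParameters().
local function init_params_sketch(params, param_init)
  params:uniform(-param_init, param_init)
end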
cmd:option('-learning_rate', 1, [[Starting learning rate]])
cmd:option('-max_grad_norm', 5, [[If the norm of the gradient vector exceeds this, renormalize it to have the norm equal to max_grad_norm]])
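-- Illustrative sketch of the renormalization -max_grad_norm describes, on the
-- flattened gradient tensor one would get from module:getParameters().
local function clip_grads_sketch(grad_params, max_grad_norm)
  local grad_norm = grad_params:norm()
  if grad_norm > max_grad_norm then
    grad_params:mul(max_grad_norm / grad_norm)
  end
end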
cmd:option('-dropout', 0.3, [[Dropout probability. Dropout is applied between vertical LSTM stacks.]])
cmd:option('-lr_decay', 0.5, [[Decay learning rate by this much if (i) perplexity does not decrease on the validation set or (ii) the epoch has gone past the start_decay_at limit]])
cmd:option('-start_decay_at', 9, [[Start decay after this epoch]])
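-- Illustrative sketch of the schedule -lr_decay and -start_decay_at describe:
-- decay when validation perplexity stops improving or once past the cutoff
-- epoch. `state.prev_ppl` is a hypothetical bookkeeping field.
local function decay_lr_sketch(lr, epoch, valid_ppl, state, opt)
  if epoch > opt.start_decay_at or
     (state.prev_ppl ~= nil and valid_ppl >= state.prev_ppl) then
    lr = lr * opt.lr_decay
  end
  state.prev_ppl = valid_ppl
  return lr
end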
cmd:option('-pre_word_vecs_dec', '', [[If a valid path is specified, then this will load pretrained word embeddings (hdf5 file) on the decoder side. See README for specific formatting instructions.]])
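-- Illustrative sketch of loading -pre_word_vecs_dec with torch-hdf5 and
-- copying into a decoder lookup table. The dataset name 'word_vecs' is an
-- assumption, not confirmed by this project.
local function load_word_vecs_sketch(path, lookup)
  local f = hdf5.open(path, 'r')
  local vecs = f:read('word_vecs'):all()
  f:close()
  lookup.weight:copy(vecs)
end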
cmd:option('-fix_word_vecs_dec', 0, [[If = 1, fix word embeddings on the decoder side]])
-- bookkeeping
cmd:option('-save_every', 1, [[Save every this many epochs]])
cmd:option('-print_every', 50, [[Print stats after this many batches]])
cmd:option('-seed', 3435, [[Seed for random initialization]])
-- beam
cmd:option('-beam', 5, [[Beam size]])
cmd:option('-targ_dict', 'mydata/idx_to_word.txt', [[Path to target vocabulary, "id word" per line]])
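-- Illustrative sketch: read the "id word" lines of -targ_dict into an
-- index->word table, as needed to turn beam-search output ids back into text.
local function load_targ_dict_sketch(path)
  local idx_to_word = {}
  for line in io.lines(path) do
    local id, word = line:match('^(%S+)%s+(%S+)')
    if id then idx_to_word[tonumber(id)] = word end
  end
  return idx_to_word
end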
cmd:option('-max_sent_l', 196, [[Maximum sentence length. If any sequence in the data is longer than this, it will error out]])
cmd:option('-simple', 0, [[If = 1, the output prediction is simply the first time the top of the beam ends with an end-of-sentence token. If = 0, the model considers all hypotheses generated so far that end with the end-of-sentence token and takes the highest-scoring of them.]])
cmd:option('-init_dec', 1, [[Initialize the hidden/cell state of the decoder at time
              0 to be the last hidden/cell state of the encoder. If 0,
              the initial states of the decoder are set to zero vectors]])
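-- Illustrative sketch of what -init_dec = 1 describes: seed the decoder's
-- initial hidden/cell states with the encoder's final ones (the names and the
-- table-of-tensors layout are assumptions).
local function init_dec_sketch(enc_final_states, dec_init_states)
  for l = 1, #enc_final_states do
    dec_init_states[l]:copy(enc_final_states[l])
  end
end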
cmd:option('-reverse_src', 0, [[If = 1, reverse the source sequence. The original
              sequence-to-sequence paper found that this was crucial to
              achieving good performance, but with attention models this
              does not seem necessary. Recommend leaving it at 0]])
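-- Illustrative sketch of the reversal -reverse_src describes, for a 1-D
-- LongTensor of token ids.
local function reverse_seq_sketch(seq)
  local n = seq:size(1)
  local rev_idx = torch.LongTensor(n)
  for i = 1, n do rev_idx[i] = n - i + 1 end
  return seq:index(1, rev_idx)
end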
cmd:option('-num_sentences', 5, [[Number of target sentences (captions) per data point]])
function main()
  -- Parse input params.
  opt = cmd:parse(arg)
  torch.manualSeed(opt.seed)
  run(opt)
end

main()