train.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Author: Gozde Sahin
# Code is based on Clara Vania's subword-lstm-lm project
import numpy as np
import argparse
import time
import os
import pickle
import sys
import math
from utils import TextLoader, get_last_model_path
from model import *
from torch.autograd import Variable
from optimizer import *
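

# Trains a (sub)word-aware RNN language model. The unit (char, char-ngram,
# morpheme, word, oracle, oracle-db) and the composition function (none,
# addition, bi-lstm, add-bi-lstm) are chosen via the command-line flags
# defined below, e.g.:
#   python train.py --train_file data/train.txt --dev_file data/dev.txt \
#       --unit char-ngram --composition addition --save_dir model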
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', type=str, default='data/train.txt',
                        help="training data")
    parser.add_argument('--dev_file', type=str, default='data/dev.txt',
                        help="development data")
    parser.add_argument('--output_vocab_file', type=str, default='',
                        help="vocab file; only use this if you want to specify a special output vocabulary")
    parser.add_argument('--output', '-o', type=str, default='train.log',
                        help='output file')
    parser.add_argument('--save_dir', type=str, default='model',
                        help='directory to store checkpointed models')
    parser.add_argument('--rnn_size', type=int, default=200,
                        help='size of RNN hidden state')
    parser.add_argument('--num_layers', type=int, default=2,
                        help='number of layers in the RNN')
    parser.add_argument('--model', type=str, default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument('--unit', type=str, default='char-ngram',
                        help='char, char-ngram, morpheme, word, oracle, or oracle-db')
    parser.add_argument('--composition', type=str, default='addition',
                        help='none (word), addition, or bi-lstm')
    parser.add_argument('--lowercase', dest='lowercase', action='store_true',
                        help='lowercase data', default=False)
    parser.add_argument('--batch_size', type=int, default=32,
                        help='minibatch size')
    parser.add_argument('--num_steps', type=int, default=20,
                        help='RNN sequence length')
    parser.add_argument('--out_vocab_size', type=int, default=5000,
                        help='size of output vocabulary')
    parser.add_argument('--num_epochs', type=int, default=100,
                        help='number of epochs')
    parser.add_argument('--patience', type=int, default=3,
                        help='the number of iterations allowed before decaying the '
                             'learning rate if there is no improvement on the dev set')
    parser.add_argument('--validation_interval', type=int, default=1,
                        help='validation interval')
    parser.add_argument('--init_scale', type=float, default=0.1,
                        help='initial weight scale')
    parser.add_argument('--param_init_type', type=str, default="uniform",
                        help="""Options are [orthogonal|uniform|xavier_n|xavier_u]""")
    parser.add_argument('--grad_clip', type=float, default=2.0,
                        help='maximum permissible norm of the gradient')
    parser.add_argument('--learning_rate', type=float, default=1.0,
                        help='initial learning rate')
    parser.add_argument('--decay_rate', type=float, default=0.5,
                        help='the decay of the learning rate')
    parser.add_argument('--keep_prob', type=float, default=0.5,
                        help='the probability of keeping weights in the dropout layer')
    parser.add_argument('--gpu', type=int, default=0,
                        help='the gpu id, if you have more than one gpu')
    parser.add_argument('--optimization', type=str, default='sgd',
                        help='sgd, momentum, or adagrad')
    parser.add_argument('--train', type=str, default='softmax',
                        help='training criterion (default: softmax)')
    parser.add_argument('--SOS', type=str, default='false',
                        help='start of sentence symbol')
    parser.add_argument('--EOS', type=str, default='true',
                        help='end of sentence symbol')
    parser.add_argument('--ngram', type=int, default=3,
                        help='length of character ngram (for char-ngram model only)')
    parser.add_argument('--char_dim', type=int, default=200,
                        help='dimension of char embedding (for C2W model only)')
    parser.add_argument('--morph_dim', type=int, default=200,
                        help='dimension of morpheme embedding (for M2W model only)')
    parser.add_argument('--word_dim', type=int, default=200,
                        help='dimension of word embedding (for C2W model only)')
    parser.add_argument('--cont', type=str, default='false',
                        help='continue training')
    parser.add_argument('--seed', type=int, default=0,
                        help='seed for random initialization')
    parser.add_argument('--lang', type=str, default='en',
                        help='Language (en|tr)')
    args = parser.parse_args()

    # check cuda
    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    otype = torch.cuda.LongTensor if use_cuda else torch.LongTensor
    args.dtype = dtype
    args.otype = otype
    args.use_cuda = use_cuda

    train(args)
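

# NLL criterion over the output vocabulary; losses are summed (size_average=False)
# so that run_epoch can normalize by the batch size itself.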
def lossCriterion(vocabSize):
    weight = torch.ones(vocabSize)
    crit = nn.NLLLoss(weight, size_average=False)
    return crit
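

# Runs one pass over `data`: feeds (x, y) minibatches through the model,
# accumulates the negative log-likelihood, and returns the per-word perplexity.
# With eval=True the model is put in evaluation mode and no weights are updated.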
def run_epoch(m, data, data_loader, optimizer, eval=False):
    if eval:
        m.eval()
    else:
        m.train()

    epoch_size = ((len(data) // m.batch_size) - 1) // m.num_steps
    start_time = time.time()
    costs = 0.0
    iters = 0
    crit = lossCriterion(data_loader.out_vocab_size)

    m.lm_hidden = m.init_hidden(m.num_layers, numdirec=1, batchsize=m.batch_size)
    if data_loader.composition == "bi-lstm" or data_loader.composition == "add-bi-lstm":
        m.comp_hidden = m.init_hidden(1, numdirec=2, batchsize=(m.batch_size * m.num_steps))
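
    # Iterate over the epoch in minibatches of shape (batch_size, num_steps);
    # the hidden states are carried over between minibatches and repackaged each
    # step so that gradients do not flow across minibatch boundaries.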
    for step, (x, y) in enumerate(data_loader.data_iterator(data, m.batch_size, m.num_steps)):
        # turn the inputs into tensors
        # x can hold either values or indices, depending on the composition
        if (data_loader.composition == "bi-lstm") or \
                (data_loader.composition == "none"):  # indices are returned
            x = torch.LongTensor(x).type(m.otype)
        elif data_loader.composition == "addition" or \
                (data_loader.composition == "add-bi-lstm"):  # values are returned
            x = torch.FloatTensor(x).type(m.dtype)
        # y is always indices
        y = torch.LongTensor(y).type(m.otype)

        # move input tensors to the gpu if possible
        if m.use_cuda:
            x = x.cuda()
            y = y.cuda()
            crit = crit.cuda()

        # requires_grad is False by default
        x_var = Variable(x, volatile=eval)
        y_var = Variable(y, volatile=eval)

        # zero the gradients
        m.zero_grad()
        m.lm_hidden = repackage_hidden(m.lm_hidden)
        if data_loader.composition == "bi-lstm" or data_loader.composition == "add-bi-lstm":
            m.comp_hidden = repackage_hidden(m.comp_hidden)

        log_probs = m(x_var)
        training_labels = y_var.view(log_probs.size(0))
        loss = crit(log_probs, training_labels).div(m.batch_size)
        costs += loss.data[0]
        iters += m.num_steps

        if not eval:
            # go backwards and update the weights
            loss.backward()
            optimizer.step()

        # report
        if not eval and step % (epoch_size // 10) == 10:
            print("perplexity: %.3f speed: %.0f wps" %
                  (np.exp(costs / iters),
                   iters * m.batch_size / (time.time() - start_time)))

    # calculate perplexity (cost per word)
    cost_norm = costs / iters
    ppl = math.exp(min(cost_norm, 100.0))
    return ppl
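

# Full training driver: builds the TextLoader and model from `args`, then
# alternates one training epoch with an evaluation pass on the dev set,
# checkpointing to save_dir whenever the dev perplexity improves by at least 0.1.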
def train(args):
    start = time.time()
    save_dir = args.save_dir
    try:
        os.stat(save_dir)
    except:
        os.mkdir(save_dir)

    args.eos = ''
    args.sos = ''
    if args.EOS == "true":
        args.eos = '</s>'
        args.out_vocab_size += 1
    if args.SOS == "true":
        args.sos = '<s>'
        args.out_vocab_size += 1

    local_test = False
    if local_test:
        # Gozde
        # char, char-ngram, morpheme, word, or oracle
        args.unit = "oracle-db"
        args.composition = "add-bi-lstm"
        args.train_file = "data/train.morph"
        args.dev_file = "data/dev.morph"
        args.batch_size = 12
        # End of test

    data_loader = TextLoader(args)
    train_data = data_loader.train_data
    dev_data = data_loader.dev_data

    fout = open(os.path.join(args.save_dir, args.output), "a")

    args.word_vocab_size = data_loader.word_vocab_size
    if args.unit != "word":
        args.subword_vocab_size = data_loader.subword_vocab_size
    fout.write(str(args) + "\n")

    # Statistics of words
    fout.write("Word vocab size: " + str(data_loader.word_vocab_size) + "\n")

    # Statistics of sub units
    if args.unit != "word":
        fout.write("Subword vocab size: " + str(data_loader.subword_vocab_size) + "\n")
        if args.composition == "bi-lstm":
            if args.unit == "char":
                fout.write("Maximum word length: " + str(data_loader.max_word_len) + "\n")
                args.bilstm_num_steps = data_loader.max_word_len
            elif args.unit == "char-ngram":
                fout.write("Maximum ngram per word: " + str(data_loader.max_ngram_per_word) + "\n")
                args.bilstm_num_steps = data_loader.max_ngram_per_word
            elif args.unit == "morpheme" or args.unit == "oracle":
                fout.write("Maximum morpheme per word: " + str(data_loader.max_morph_per_word) + "\n")
                args.bilstm_num_steps = data_loader.max_morph_per_word
            else:
                sys.exit("Wrong unit.")
        elif args.composition == "add-bi-lstm":
            fout.write("Maximum db per word: " + str(data_loader.max_db_per_word) + "\n")
            fout.write("Maximum morph per db: " + str(data_loader.max_morph_per_db) + "\n")
            args.bilstm_num_steps = data_loader.max_db_per_word
        elif args.composition == "addition":
            if args.unit not in ["char-ngram", "morpheme", "oracle"]:
                sys.exit("Wrong composition.")
        else:
            sys.exit("Wrong unit/composition.")
    else:
        if args.composition != "none":
            sys.exit("Wrong composition.")

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    print(args)
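
    # Pick the model class that matches the chosen unit/composition.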
    if args.unit == "word":
        lm_model = WordModel
    elif args.composition == "addition":
        lm_model = AdditiveModel
    elif args.composition == "bi-lstm":
        lm_model = BiLSTMModel
    elif args.composition == "add-bi-lstm":
        lm_model = AddBiLSTMModel
    else:
        sys.exit("Unknown unit or composition.")

    print("Begin training...")
    mtrain = lm_model(args)
    if args.use_cuda:
        mtrain = mtrain.cuda()

    nParams = sum([p.nelement() for p in mtrain.parameters()])
    print('* number of parameters: %d' % nParams)

    optim = Optim(
        args.optimization, args.learning_rate, args.grad_clip,
        lr_decay=args.decay_rate,
        patience=args.patience
    )
    # update all parameters
    optim.set_parameters(mtrain.parameters())

    dev_pp = 10000000.0
    if args.cont == 'true':  # continue training from a saved model
        # get the model parameters
        model_path, e = get_last_model_path(args.save_dir)
        saved_model = torch.load(model_path)
        mtrain.load_state_dict(saved_model['state_dict'])
        # get the optimizer state
        # the learning rate is not restored (by now it is probably too small to keep training with)
        optim.last_ppl = saved_model['last_ppl']
    else:
        # start from the first epoch
        e = 1

    # process each epoch
    while e <= args.num_epochs:
        print("Epoch: %d" % e)
        print("Learning rate: %.3f" % optim.lr)

        # (1) train for one epoch on the training set
        train_perplexity = run_epoch(mtrain, train_data, data_loader, optim, eval=False)
        print("Train Perplexity: %.3f" % train_perplexity)

        # (2) evaluate on the validation set
        dev_perplexity = run_epoch(mtrain, dev_data, data_loader, optim, eval=True)
        print("Valid Perplexity: %.3f" % dev_perplexity)

        # (3) update the learning rate
        optim.updateLearningRate(dev_perplexity, e)

        # (4) save results and report
        diff = dev_pp - dev_perplexity
        if diff >= 0.1:
            print("New best perplexity on the dev set, saving model.")
            checkpoint = {
                'state_dict': mtrain.state_dict(),
                'last_ppl': optim.last_ppl
            }
            torch.save(checkpoint,
                       '%s/%s-%d.pt' % (save_dir, "model", e))
            dev_pp = dev_perplexity

        # write results to file
        fout.write("Epoch: %d\n" % e)
        fout.write("Learning rate: %.3f\n" % optim.lr)
        fout.write("Train Perplexity: %.3f\n" % train_perplexity)
        fout.write("Valid Perplexity: %.3f\n" % dev_perplexity)
        fout.flush()

        if optim.lr < 0.0001:
            print('Learning rate too small, stop training.')
            break
        e += 1

    print("Training time: %.0f" % (time.time() - start))
    fout.write("Training time: %.0f\n" % (time.time() - start))


if __name__ == '__main__':
    main()