-
Notifications
You must be signed in to change notification settings - Fork 353
/
main.py
69 lines (55 loc) · 2.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/python
import tensorflow as tf
from config import Config
from model import CaptionGenerator
from dataset import prepare_train_data, prepare_eval_data, prepare_test_data
# Command-line flags; the parsed values are read through FLAGS inside main().
FLAGS = tf.app.flags.FLAGS

# Help strings that span several lines use adjacent-literal concatenation
# instead of a backslash inside the literal, so source indentation can never
# leak into the text shown by --help.
tf.flags.DEFINE_string('phase', 'train',
                       'The phase can be train, eval or test')

tf.flags.DEFINE_boolean('load', False,
                        'Turn on to load a pretrained model from either '
                        'the latest checkpoint or a specified file')

tf.flags.DEFINE_string('model_file', None,
                       'If specified, load a pretrained model from this file')

tf.flags.DEFINE_boolean('load_cnn', False,
                        'Turn on to load a pretrained CNN model')

tf.flags.DEFINE_string('cnn_model_file', './vgg16_no_fc.npy',
                       'The file containing a pretrained CNN model')

tf.flags.DEFINE_boolean('train_cnn', False,
                        'Turn on to train both CNN and RNN. '
                        'Otherwise, only RNN is trained')

tf.flags.DEFINE_integer('beam_size', 3,
                        'The size of beam search for caption generation')


def main(argv):
    """Run the caption generator in the phase selected by ``--phase``.

    Builds a ``Config`` from the command-line flags, opens a TensorFlow
    session and then trains, evaluates or tests a ``CaptionGenerator``.

    Args:
        argv: Command-line arguments handed over by ``tf.app.run``; unused.

    Raises:
        ValueError: If ``--phase`` is not ``train``, ``eval`` or ``test``.
    """
    phase = FLAGS.phase
    # Fail fast on an unknown phase. Without this check any value other than
    # 'train' or 'eval' (typos included) would silently run the test phase.
    if phase not in ('train', 'eval', 'test'):
        raise ValueError(
            "Unknown phase %r: expected 'train', 'eval' or 'test'" % (phase,))

    config = Config()
    config.phase = phase
    config.train_cnn = FLAGS.train_cnn
    config.beam_size = FLAGS.beam_size

    with tf.Session() as sess:
        if phase == 'train':
            # training phase
            data = prepare_train_data(config)
            model = CaptionGenerator(config)
            # Initialise every variable first; the optional loads below then
            # overwrite whatever values they restore.
            sess.run(tf.global_variables_initializer())
            if FLAGS.load:
                # model_file may be None; per the --load help text the latest
                # checkpoint is then used (decided inside model.load, which
                # is not visible here).
                model.load(sess, FLAGS.model_file)
            if FLAGS.load_cnn:
                model.load_cnn(sess, FLAGS.cnn_model_file)
            # Make the graph read-only so no ops get added while running.
            tf.get_default_graph().finalize()
            model.train(sess, data)
        elif phase == 'eval':
            # evaluation phase
            coco, data, vocabulary = prepare_eval_data(config)
            model = CaptionGenerator(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.eval(sess, coco, data, vocabulary)
        else:
            # testing phase
            data, vocabulary = prepare_test_data(config)
            model = CaptionGenerator(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.test(sess, data, vocabulary)


if __name__ == "__main__":
    # tf.app.run parses the flags defined in this module and then calls main;
    # passing main explicitly names the same function it looks up by default.
    tf.app.run(main=main)