train.py
#coding=utf-8
import tensorflow as tf
import model
import time
import os
from load_data import read_dataset, batch_iter
# Data loading, model, and training parameters (exposed as command-line flags)
tf.flags.DEFINE_string("data_dir", "data/data.dat", "data directory")
tf.flags.DEFINE_integer("vocab_size", 46960, "vocabulary size")
tf.flags.DEFINE_integer("num_classes", 5, "number of classes")
tf.flags.DEFINE_integer("embedding_size", 200, "Dimensionality of character embedding (default: 200)")
tf.flags.DEFINE_integer("hidden_size", 50, "Dimensionality of GRU hidden layer (default: 50)")
tf.flags.DEFINE_integer("batch_size", 32, "Batch Size (default: 64)")
tf.flags.DEFINE_integer("num_epochs", 10, "Number of training epochs (default: 50)")
tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps (default: 100)")
tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store (default: 5)")
tf.flags.DEFINE_integer("evaluate_every", 100, "evaluate every this many batches")
tf.flags.DEFINE_float("learning_rate", 0.01, "learning rate")
tf.flags.DEFINE_float("grad_clip", 5, "grad clip to prevent gradient explode")
FLAGS = tf.flags.FLAGS
train_x, train_y, dev_x, dev_y = read_dataset()
print "data load finished"

with tf.Session() as sess:
    han = model.HAN(vocab_size=FLAGS.vocab_size,
                    num_classes=FLAGS.num_classes,
                    embedding_size=FLAGS.embedding_size,
                    hidden_size=FLAGS.hidden_size)
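
    # model.HAN is expected to expose the placeholders used below
    # (input_x, input_y, max_sentence_num, max_sentence_length, batch_size)
    # and the unnormalized class scores as han.out.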

    with tf.name_scope('loss'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=han.input_y,
                                                                      logits=han.out,
                                                                      name='loss'))

    with tf.name_scope('accuracy'):
        predict = tf.argmax(han.out, axis=1, name='predict')
        label = tf.argmax(han.input_y, axis=1, name='label')
        acc = tf.reduce_mean(tf.cast(tf.equal(predict, label), tf.float32))
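    # softmax_cross_entropy_with_logits applies the softmax internally, so
    # han.out must be raw (unnormalized) logits; accuracy compares the argmax
    # of those logits against the argmax of the one-hot labels.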

    timestamp = str(int(time.time()))
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
    print("Writing to {}\n".format(out_dir))

    global_step = tf.Variable(0, trainable=False)
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

    # Gradient clipping, commonly used with RNNs, to keep gradients from exploding.
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), FLAGS.grad_clip)
    grads_and_vars = tuple(zip(grads, tvars))
    train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
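    # clip_by_global_norm rescales the whole gradient list by
    # grad_clip / global_norm whenever their joint L2 norm exceeds grad_clip,
    # preserving the update direction while bounding its magnitude.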

    # Keep track of gradient values and sparsity (optional)
    grad_summaries = []
    for g, v in grads_and_vars:
        if g is not None:
            grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
            grad_summaries.append(grad_hist_summary)
    grad_summaries_merged = tf.summary.merge(grad_summaries)

    loss_summary = tf.summary.scalar('loss', loss)
    acc_summary = tf.summary.scalar('accuracy', acc)

    train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
    train_summary_dir = os.path.join(out_dir, "summaries", "train")
    train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

    dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
    dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
    dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)
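    # Train and dev summaries go to separate sub-directories so the curves can be
    # compared side by side, e.g. `tensorboard --logdir runs/<timestamp>/summaries`.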

    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

    sess.run(tf.global_variables_initializer())

    def train_step(x_batch, y_batch):
        feed_dict = {
            han.input_x: x_batch,
            han.input_y: y_batch,
            # fixed document shape used throughout: 30 sentences of 30 words each
            han.max_sentence_num: 30,
            han.max_sentence_length: 30,
            # feed the actual batch size; the last slice of an epoch can be
            # smaller than FLAGS.batch_size
            han.batch_size: len(x_batch)
        }
        _, step, summaries, cost, accuracy = sess.run([train_op, global_step, train_summary_op, loss, acc], feed_dict)
        time_str = str(int(time.time()))
        print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, cost, accuracy))
        train_summary_writer.add_summary(summaries, step)
        return step

    def dev_step(x_batch, y_batch, writer=None):
        feed_dict = {
            han.input_x: x_batch,
            han.input_y: y_batch,
            han.max_sentence_num: 30,
            han.max_sentence_length: 30,
            # the dev split is evaluated in a single pass, so feed its real size
            han.batch_size: len(x_batch)
        }
        step, summaries, cost, accuracy = sess.run([global_step, dev_summary_op, loss, acc], feed_dict)
        time_str = str(int(time.time()))
        print("++++++++++++++++++dev++++++++++++++{}: step {}, loss {:g}, acc {:g}".format(time_str, step, cost, accuracy))
        if writer:
            writer.add_summary(summaries, step)

    for epoch in range(FLAGS.num_epochs):
        print('current epoch %s' % (epoch + 1))
        # iterate over the first 200000 training documents in mini-batches
        for i in range(0, 200000, FLAGS.batch_size):
            x = train_x[i:i + FLAGS.batch_size]
            y = train_y[i:i + FLAGS.batch_size]
            step = train_step(x, y)
            if step % FLAGS.evaluate_every == 0:
                dev_step(dev_x, dev_y, dev_summary_writer)
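            # Optional sketch, not part of the original loop: the Saver and the
            # checkpoint_every flag defined above are otherwise unused, so persist
            # a checkpoint at the cadence those flags suggest.
            if step % FLAGS.checkpoint_every == 0:
                path = saver.save(sess, checkpoint_prefix, global_step=step)
                print("Saved model checkpoint to {}\n".format(path))

# Usage sketch (assuming TensorFlow 1.x and a prepared data/ directory):
#   python train.py --batch_size 32 --num_epochs 10 --learning_rate 0.01
# Each tf.flags.DEFINE_* value above is exposed as a command-line flag of the same name.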