From f74654687578268983d28ca977c9fe1e59753de0 Mon Sep 17 00:00:00 2001
From: thinkpad
Date: Tue, 24 Apr 2018 00:29:28 +0800
Subject: [PATCH] fix some bugs reported in issue #8.

---
 README.md           |  2 +-
 cnn_lstm_otc_ocr.py | 16 ++++++++--------
 main.py             | 15 ++++++---------
 utils.py            |  4 ++--
 4 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/README.md b/README.md
index 06f002e..1c53885 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,7 @@ CUDA_VISIBLE_DEVICES=0 python ./main.py --train_dir=../imgs/train/ \
   --image_height=60 \
   --image_width=180 \
   --image_channel=1 \
-  --max_stepsize=64 \
+  --out_channels=64 \
   --num_hidden=128 \
   --batch_size=128 \
   --log_dir=./log/train \
diff --git a/cnn_lstm_otc_ocr.py b/cnn_lstm_otc_ocr.py
index 88e8bcf..1b27b74 100644
--- a/cnn_lstm_otc_ocr.py
+++ b/cnn_lstm_otc_ocr.py
@@ -18,7 +18,7 @@ def __init__(self, mode):
 
         # SparseTensor required by ctc_loss op
         self.labels = tf.sparse_placeholder(tf.int32)
         # 1d array of size [batch_size]
-        self.seq_len = tf.placeholder(tf.int32, [None])
+        # self.seq_len = tf.placeholder(tf.int32, [None])
         # l2
         self._extra_train_ops = []
@@ -29,7 +29,7 @@ def build_graph(self):
         self.merged_summay = tf.summary.merge_all()
 
     def _build_model(self):
-        filters = [1, 64, 128, 128, FLAGS.max_stepsize]
+        filters = [1, 64, 128, 128, FLAGS.out_channels]
         strides = [1, 2]
 
         feature_h = FLAGS.image_height
@@ -54,15 +54,15 @@ def _build_model(self):
             # print('----x.get_shape().as_list(): {}'.format(x.get_shape().as_list()))
 
         _, feature_h, feature_w, _ = x.get_shape().as_list()
-        print('feature_h: {}, feature_w: {}'.format(feature_h, feature_w))
+        print('\nfeature_h: {}, feature_w: {}'.format(feature_h, feature_w))
 
         # LSTM part
         with tf.variable_scope('lstm'):
-            x = tf.reshape(x, [FLAGS.batch_size, -1, filters[4]])  # [batch_size, num_features, max_stepsize]
-            x = tf.transpose(x, [0, 2, 1])  # [batch_size, max_stepsize, num_features]
-            # shp = x.get_shape().as_list()
-            # x.set_shape([FLAGS.batch_size, filters[3], shp[1]])
-            x.set_shape([FLAGS.batch_size, filters[4], feature_h * feature_w])
+            x = tf.transpose(x, [0, 2, 1, 3])  # [batch_size, feature_w, feature_h, FLAGS.out_channels]
+            x = tf.reshape(x, [FLAGS.batch_size, feature_w, feature_h * FLAGS.out_channels])
+            print('lstm input shape: {}'.format(x.get_shape().as_list()))
+            self.seq_len = tf.fill([x.get_shape().as_list()[0]], feature_w)
+            # print('self.seq_len.shape: {}'.format(self.seq_len.shape.as_list()))
 
             # tf.nn.rnn_cell.RNNCell, tf.nn.rnn_cell.GRUCell
             cell = tf.nn.rnn_cell.LSTMCell(FLAGS.num_hidden, state_is_tuple=True)
diff --git a/main.py b/main.py
index d416b2f..196e39e 100644
--- a/main.py
+++ b/main.py
@@ -31,7 +31,7 @@ def train(train_dir=None, val_dir=None, mode='train'):
 
     print('loading validation data')
     val_feeder = utils.DataIterator(data_dir=val_dir)
-    print('size: ', val_feeder.size)
+    print('size: {}\n'.format(val_feeder.size))
 
     num_train_samples = train_feeder.size  # 100000
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)  # example: 100000/100
@@ -69,12 +69,11 @@ def train(train_dir=None, val_dir=None, mode='train'):
                 batch_time = time.time()
                 indexs = [shuffle_idx[i % num_train_samples] for i in
                           range(cur_batch * FLAGS.batch_size, (cur_batch + 1) * FLAGS.batch_size)]
-                batch_inputs, batch_seq_len, batch_labels = \
+                batch_inputs, _, batch_labels = \
                     train_feeder.input_index_generate_batch(indexs)
                 # batch_inputs,batch_seq_len,batch_labels=utils.gen_batch(FLAGS.batch_size)
                 feed = {model.inputs: batch_inputs,
-                        model.labels: batch_labels,
-                        model.seq_len: batch_seq_len}
+                        model.labels: batch_labels}
 
                 # if summary is needed
                 # batch_cost,step,train_summary,_ = sess.run([cost,global_step,merged_summay,optimizer],feed)
@@ -103,11 +102,10 @@ def train(train_dir=None, val_dir=None, mode='train'):
                     for j in range(num_batches_per_epoch_val):
                         indexs_val = [shuffle_idx_val[i % num_val_samples] for i in
                                       range(j * FLAGS.batch_size, (j + 1) * FLAGS.batch_size)]
-                        val_inputs, val_seq_len, val_labels = \
+                        val_inputs, _, val_labels = \
                             val_feeder.input_index_generate_batch(indexs_val)
                         val_feed = {model.inputs: val_inputs,
-                                    model.labels: val_labels,
-                                    model.seq_len: val_seq_len}
+                                    model.labels: val_labels}
 
                         dense_decoded, lastbatch_err, lr = \
                             sess.run([model.dense_decoded, model.cost, model.lrn_rate],
@@ -177,8 +175,7 @@ def get_input_lens(seqs):
             seq_len_input = np.asarray(seq_len_input)
             seq_len_input = np.reshape(seq_len_input, [-1])
 
-            feed = {model.inputs: imgs_input,
-                    model.seq_len: seq_len_input}
+            feed = {model.inputs: imgs_input}
             dense_decoded_code = sess.run(model.dense_decoded, feed)
 
             for item in dense_decoded_code:
diff --git a/utils.py b/utils.py
index 7fb2012..b622051 100644
--- a/utils.py
+++ b/utils.py
@@ -21,7 +21,7 @@
 tf.app.flags.DEFINE_integer('image_channel', 1, 'image channels as input')
 
 tf.app.flags.DEFINE_integer('cnn_count', 4, 'count of cnn module to extract image features.')
-tf.app.flags.DEFINE_integer('max_stepsize', 64,
+tf.app.flags.DEFINE_integer('out_channels', 64,
                             'max stepsize in lstm, as well as the output channels of last layer in CNN')
 tf.app.flags.DEFINE_integer('num_hidden', 128, 'number of hidden units in lstm')
 tf.app.flags.DEFINE_float('output_keep_prob', 0.8, 'output_keep_prob in lstm')
@@ -101,7 +101,7 @@ def input_index_generate_batch(self, index=None):
 
         def get_input_lens(sequences):
             # 64 is the output channels of the last layer of CNN
-            lengths = np.asarray([FLAGS.max_stepsize for _ in sequences], dtype=np.int64)
+            lengths = np.asarray([FLAGS.out_channels for _ in sequences], dtype=np.int64)
 
             return sequences, lengths
 
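
Note on the seq_len change: after this patch the model no longer takes sequence
lengths through a feed dict; it derives them inside the graph from the width of
the CNN feature map, so every sample in a batch gets seq_len == feature_w.
Below is a minimal TF 1.x sketch of that shape logic; the concrete sizes
(batch_size=128, feature_h=4, feature_w=12, out_channels=64) are assumed for
illustration only and are not taken from the patch.

    import tensorflow as tf

    # Assumed sizes, for illustration only.
    batch_size, feature_h, feature_w, out_channels = 128, 4, 12, 64

    # Stand-in for the CNN output: [batch_size, feature_h, feature_w, out_channels]
    x = tf.zeros([batch_size, feature_h, feature_w, out_channels])

    # Make width the time axis, then fold height and channels into the
    # per-step feature vector fed to the LSTM.
    x = tf.transpose(x, [0, 2, 1, 3])  # [batch_size, feature_w, feature_h, out_channels]
    x = tf.reshape(x, [batch_size, feature_w, feature_h * out_channels])

    # One length per sample, all equal to the number of time steps (feature_w),
    # replacing the old tf.placeholder that callers had to feed.
    seq_len = tf.fill([batch_size], feature_w)

This is also why main.py now discards the second value returned by
input_index_generate_batch and drops model.seq_len from every feed dict: the
lengths produced by get_input_lens in utils.py are no longer consumed.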