From 3e98be279fc07be55d738ef278015283a260c3d9 Mon Sep 17 00:00:00 2001
From: pan463194277
Date: Thu, 10 Aug 2017 15:29:26 +0800
Subject: [PATCH] Fix eval when variable_update=parameter_server|distributed_replicated

---
 scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py | 17 +++++++++++------
 scripts/tf_cnn_benchmarks/variable_mgr.py      |  9 +++++++--
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py b/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py
index a75ac5bf..4425732b 100644
--- a/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py
+++ b/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py
@@ -698,7 +698,7 @@ def get_perf_timing_str(batch_size, step_train_times, scale=1):
 def load_checkpoint(saver, sess, ckpt_dir):
   ckpt = tf.train.get_checkpoint_state(ckpt_dir)
   if ckpt and ckpt.model_checkpoint_path:
-    if os.path.isabs(ckpt.model_checkpoint_path):
+    if os.path.isabs(ckpt.model_checkpoint_path) or ckpt.model_checkpoint_path.startswith('hdfs'):
       # Restores from checkpoint with absolute path.
       model_checkpoint_path = ckpt.model_checkpoint_path
     else:
@@ -858,7 +858,8 @@ def print_info(self):
     log_fn('Model: %s' % self.model)
     log_fn('Mode: %s' % get_mode_from_flags())
     log_fn('Batch size: %s global' % self.batch_size)
-    log_fn('            %s per device' % (self.batch_size / len(self.devices)))
+    if self.devices:
+      log_fn('            %s per device' % (self.batch_size / len(self.devices)))
     log_fn('Devices: %s' % self.raw_devices)
     log_fn('Data format: %s' % self.data_format)
     log_fn('Optimizer: %s' % FLAGS.optimizer)
@@ -876,7 +877,7 @@ def run(self):
       log_fn('Running parameter server %s' % self.task_index)
       self.server.join()
       return
-
+
     with tf.Graph().as_default():
       if FLAGS.eval:
         self._eval_cnn()
@@ -886,11 +887,14 @@ def run(self):
   def _eval_cnn(self):
     """Evaluate the model from a checkpoint using validation dataset."""
     (enqueue_ops, fetches) = self._build_model()
-    saver = tf.train.Saver(tf.global_variables())
+    variables_to_save = self.variable_mgr.get_variables_to_save()
+    saver = tf.train.Saver(variables_to_save)
     summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                            tf.get_default_graph())
-    target = ''
+
+    target = self.server.target if self.server else ''
     with tf.Session(target=target, config=create_config_proto()) as sess:
+      sess.run(tf.local_variables_initializer())
       for i in xrange(len(enqueue_ops)):
         sess.run(enqueue_ops[:(i+1)])
       if FLAGS.train_dir is None:
@@ -959,10 +963,11 @@ def _benchmark_cnn(self):
     # passing in None for summary_op to avoid a summary_thread being started.
     # Running summaries and training operations in parallel could run out of
     # GPU memory.
+    saver = tf.train.Saver(self.variable_mgr.get_variables_to_save())
     sv = tf.train.Supervisor(
         is_chief=is_chief,
         logdir=FLAGS.train_dir,
-        saver=tf.train.Saver(tf.global_variables()),
+        saver=saver,
         global_step=global_step,
         summary_op=None,
         save_model_secs=FLAGS.save_model_secs,
diff --git a/scripts/tf_cnn_benchmarks/variable_mgr.py b/scripts/tf_cnn_benchmarks/variable_mgr.py
index fd134842..44f88414 100644
--- a/scripts/tf_cnn_benchmarks/variable_mgr.py
+++ b/scripts/tf_cnn_benchmarks/variable_mgr.py
@@ -202,7 +202,11 @@ def trainable_variables_on_device(self, device_num, writable=False):
     else:
       params = tf.trainable_variables()
     return params
-
+  def get_variables_to_save(self):
+    """Returns the collection of variables to save to the checkpoint.
+    Defaults to tf.global_variables().
+    """
+    return tf.global_variables()

 class VariableMgrIndependent(VariableMgr):
   """VariableMgr that implements the --independent mode for local jobs.
@@ -639,7 +643,8 @@ def strip_port(s):
   def get_devices(self):
     return self.benchmark_cnn.raw_devices
-
+  def get_variables_to_save(self):
+    return tf.local_variables()


 def sum_grad_and_var_all_reduce(grad_and_vars, devices):
   # Note that each grad_and_vars looks like the following:
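
A minimal sketch (TensorFlow 1.x) of the pattern this patch introduces: the
variable manager, not the training or eval loop, decides which variable
collection tf.train.Saver covers, so both paths build their saver from the
same hook and a manager whose per-worker copies live in LOCAL_VARIABLES can
point the saver at them. The class names VariableMgrBase and
VariableMgrLocalCopies below are illustrative only, not names from the
repository.

import tensorflow as tf


class VariableMgrBase(object):
  def get_variables_to_save(self):
    # Default: checkpoint everything in the GLOBAL_VARIABLES collection.
    return tf.global_variables()


class VariableMgrLocalCopies(VariableMgrBase):
  def get_variables_to_save(self):
    # A manager that keeps per-worker copies in LOCAL_VARIABLES (as the
    # second variable_mgr.py hunk above does) checkpoints those instead.
    return tf.local_variables()


def build_saver(variable_mgr):
  # Training (the Saver handed to tf.train.Supervisor) and eval both call
  # this, so eval restores exactly the set of variables training saved.
  return tf.train.Saver(variable_mgr.get_variables_to_save())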