diff --git a/train_flags.py b/train_flags.py
index 794ca93..c23f182 100644
--- a/train_flags.py
+++ b/train_flags.py
@@ -61,6 +61,10 @@
     'dataset', default=None,
     help=('The file to read examples from'))
 
+flags.DEFINE_string(
+    'val_dataset', default=None,
+    help=('The file to read examples from for validation'))
+
 flags.DEFINE_string(
     'export_dataset', default=None,
     help=('Export the dataset as a .tfrecord file'))
@@ -128,6 +132,10 @@
     ' --iterations_per_loop. The larger this value is, the higher the'
     ' utilization on the TPU.'))
 
+flags.DEFINE_integer(
+    'loops_per_val_run', default=5,
+    help=('Number of training loops to run between validation runs.'))
+
 flags.DEFINE_integer(
     'num_parallel_calls', default=64,
     help=('Cycle length of the parallel interleave in tf.data.dataset.'))
diff --git a/train_runner.py b/train_runner.py
index 1b4ec97..8dcb3db 100644
--- a/train_runner.py
+++ b/train_runner.py
@@ -26,12 +26,14 @@
 from absl import flags
 
 import tensorflow as tf
 import tflex
+import input_fns
 
 from tensorflow.contrib import tpu
 from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.data.util import nest as data_nest
 from tensorflow.python.framework import graph_io
+from tensorflow.python.ops import control_flow_ops
 
 FLAGS = flags.FLAGS
@@ -205,7 +207,6 @@ def infeed_thread_fn():
       tflex.run(self.input_sess, [self.enqueue_ops])
 
     self.build_enqueue_ops(input_fn, params, 0)
-
 
   def get_tpu_step(mparams):
     """Get the TPU graph generation function."""
@@ -261,9 +262,49 @@ def tpu_loop():
     graph_io.write_graph(tf.Graph().as_graph_def(add_shapes=True),
                          FLAGS.model_dir, "graph.pbtxt")
 
+    val_tokens = input_fns.load_source_tokens(FLAGS.val_dataset)
+    self.val_tokens_var = tf.get_local_variable(
+        'val_tokens',
+        dtype=tf.int32,
+        shape=[len(val_tokens)],
+        use_resource=True,
+    )
+
+    def get_validation_tensor(mparams):
+      # TODO: make deterministic so that same samples are used for each validation run
+      s = tf.size(self.val_tokens_var, out_type=tf.dtypes.int64)
+      r = tf.random.uniform([], maxval=s - (1024 + 1), dtype=tf.dtypes.int64)
+      r1 = tf.range(0, 1024, dtype=tf.int64) + r
+      r2 = tf.range(0, 1024, dtype=tf.int64) + r + 1
+      features = tf.stack([tf.gather(self.val_tokens_var, r1)])
+      labels = tf.stack([tf.gather(self.val_tokens_var, r2)])
+      estimator_spec = model_fn(
+          features,
+          labels,
+          tf.estimator.ModeKeys.EVAL,
+          mparams,
+      )
+      return estimator_spec.loss
+
     # Build tpu train model session and initialize graph
     self.sess = tf.Session(self.cluster_resolver.get_master(), config=self.config)
     tflex.run(self.sess, initializer)
+    self.sess.run(self.val_tokens_var.initializer, feed_dict={
+        self.val_tokens_var.initializer.inputs[1]: val_tokens,
+    })
+
+    (self.validation_loss,) = tpu.shard(
+        lambda: control_flow_ops.while_loop(
+            cond=lambda x: tf.constant(True),
+            body=lambda x: x + get_validation_tensor(params),
+            loop_vars=[tf.constant(0.0, dtype=tf.float32)],
+            back_prop=False,
+            maximum_iterations=100,
+        ),
+        inputs=[],
+        num_shards=FLAGS.num_cores,
+        outputs_from_all_shards=False,
+    )
 
     if FLAGS.restore_dir is not None:
       ckpt = tf.train.latest_checkpoint(FLAGS.restore_dir)
@@ -320,11 +361,21 @@ def save():
     for i in range(num_threads):
       checkpoint_threads.append(None)
     end_step = self.cur_step + self.train_steps
+    loops_until_next_validation_run = 0
+
     while self.cur_step < end_step:
       tflex.check_commands()
       if tflex.should_quit():
         tf.logging.info("TrainRunner: quitting")
         break
+
+      validation_loss = None
+      if loops_until_next_validation_run == 0:
+        validation_loss = self.sess.run(self.validation_loss) / 100.0
+        tf.logging.info("validation loss: %.3f", validation_loss)
+        loops_until_next_validation_run = FLAGS.loops_per_val_run
+      loops_until_next_validation_run -= 1
+
       start = time.time()
       tf.logging.info("TrainRunner: start next %d steps", self.iterations)
       self.cur_step += self.iterations
@@ -363,6 +414,8 @@
       'train_batch_size_per_core': FLAGS.train_batch_size // FLAGS.num_cores,
       'num_cores': FLAGS.num_cores,
     }
+    if validation_loss is not None:
+      eval_results['validation_loss'] = validation_loss
     for metric in eval_results:
       values = eval_results[metric]
       if not isinstance(values, list):
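
Note on the validation additions: each of the FLAGS.num_cores shards sums the EVAL-mode loss over maximum_iterations=100 randomly sampled 1024-token windows, which is why the train loop divides the fetched self.validation_loss by 100.0 to recover a mean; and the token array is pushed into val_tokens_var by feeding the initializer's value input, presumably to keep the tokens out of the serialized GraphDef. Below is a minimal CPU-runnable TF 1.x sketch of those two host-side pieces. It is not part of the diff: `tokens` and `sample_window` are illustrative names, and an explicit zeros initializer is added on the assumption that get_variable raises for integer dtypes when none is supplied.

# Sketch only: illustrates the initializer-feed trick and the random-window
# sampling used by get_validation_tensor, without a TPU or model_fn.
import numpy as np
import tensorflow as tf

tokens = np.arange(5000, dtype=np.int32)  # stand-in for load_source_tokens()

tokens_var = tf.get_local_variable(
    'tokens', dtype=tf.int32, shape=[len(tokens)],
    initializer=tf.zeros_initializer(), use_resource=True)

def sample_window():
  # Same sampling as get_validation_tensor: one random 1024-token window,
  # with labels shifted one position to the right of features.
  s = tf.size(tokens_var, out_type=tf.int64)
  r = tf.random.uniform([], maxval=s - (1024 + 1), dtype=tf.int64)
  idx = tf.range(0, 1024, dtype=tf.int64) + r
  return tf.gather(tokens_var, idx), tf.gather(tokens_var, idx + 1)

features, labels = sample_window()

with tf.Session() as sess:
  # Feed the token array through the initializer's value input rather than
  # baking a large constant into the graph.
  sess.run(tokens_var.initializer,
           feed_dict={tokens_var.initializer.inputs[1]: tokens})
  f, l = sess.run([features, labels])
  assert (f[1:] == l[:-1]).all()  # labels are features shifted by one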