validation loss #1

Open · wants to merge 1 commit into master
8 changes: 8 additions & 0 deletions train_flags.py
@@ -61,6 +61,10 @@
    'dataset', default=None,
    help=('The file to read examples from'))

flags.DEFINE_string(
    'val_dataset', default=None,
    help=('The file to read examples from for validation'))

flags.DEFINE_string(
    'export_dataset', default=None,
    help=('Export the dataset as a .tfrecord file'))
@@ -128,6 +132,10 @@
    ' --iterations_per_loop. The larger this value is, the higher the'
    ' utilization on the TPU.'))

flags.DEFINE_integer(
    'loops_per_val_run', default=5,
    help=('Number of training loops between each validation run.'))

flags.DEFINE_integer(
    'num_parallel_calls', default=64,
    help=('Cycle length of the parallel interleave in tf.data.dataset.'))
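For reference, the new `loops_per_val_run` flag counts outer training loops (each `--iterations_per_loop` steps long) between validation passes. A minimal sketch of the intended cadence, using hypothetical `run_train_loop`/`run_validation` placeholders (the actual logic is in the train_runner.py changes below):

```python
# Sketch only: mirrors the countdown added to train_runner.py below.
# training, run_train_loop() and run_validation() are hypothetical placeholders.
loops_until_next_validation_run = 0
while training:
    if loops_until_next_validation_run == 0:
        run_validation()  # every FLAGS.loops_per_val_run outer loops
        loops_until_next_validation_run = FLAGS.loops_per_val_run
    loops_until_next_validation_run -= 1
    run_train_loop()  # one outer loop of FLAGS.iterations_per_loop steps
```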
55 changes: 54 additions & 1 deletion train_runner.py
@@ -26,12 +26,14 @@
from absl import flags
import tensorflow as tf
import tflex
import input_fns

from tensorflow.contrib import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.data.util import nest as data_nest
from tensorflow.python.framework import graph_io
from tensorflow.python.ops import control_flow_ops

FLAGS = flags.FLAGS

@@ -205,7 +207,6 @@ def infeed_thread_fn():
    tflex.run(self.input_sess, [self.enqueue_ops])

self.build_enqueue_ops(input_fn, params, 0)

def get_tpu_step(mparams):
    """Get the TPU graph generation function."""

@@ -261,9 +262,49 @@ def tpu_loop():
graph_io.write_graph(tf.Graph().as_graph_def(add_shapes=True),
                     FLAGS.model_dir, "graph.pbtxt")

val_tokens = input_fns.load_source_tokens(FLAGS.val_dataset)
self.val_tokens_var = tf.get_local_variable(
    'val_tokens',
    dtype=tf.int32,
    shape=[len(val_tokens)],
    use_resource=True,
)

def get_validation_tensor(mparams):
    # TODO: make deterministic so that the same samples are used for each validation run
    # Pick a random offset into the validation tokens, then slice out a
    # 1024-token context window; the labels are the same window shifted
    # right by one token.
    s = tf.size(self.val_tokens_var, out_type=tf.dtypes.int64)
    r = tf.random.uniform([], maxval=s - (1024 + 1), dtype=tf.dtypes.int64)
    r1 = tf.range(0, 1024, dtype=tf.int64) + r
    r2 = tf.range(0, 1024, dtype=tf.int64) + r + 1
    features = tf.stack([tf.gather(self.val_tokens_var, r1)])
    labels = tf.stack([tf.gather(self.val_tokens_var, r2)])
    estimator_spec = model_fn(
        features,
        labels,
        tf.estimator.ModeKeys.EVAL,
        mparams,
    )
    return estimator_spec.loss

# Build tpu train model session and initialize graph
self.sess = tf.Session(self.cluster_resolver.get_master(), config=self.config)
tflex.run(self.sess, initializer)
# Feed the token array directly into the initializer's value input so
# the (potentially large) constant is never baked into the graph.
self.sess.run(self.val_tokens_var.initializer, feed_dict={
    self.val_tokens_var.initializer.inputs[1]: val_tokens,
})

# Each shard sums 100 single-example eval losses in an on-device loop;
# only shard 0's total is returned (outputs_from_all_shards=False), and
# the host divides by 100 to recover the mean.
(self.validation_loss,) = tpu.shard(
    lambda: control_flow_ops.while_loop(
        cond=lambda x: tf.constant(True),
        body=lambda x: x + get_validation_tensor(params),
        loop_vars=[tf.constant(0.0, dtype=tf.float32)],
        back_prop=False,
        maximum_iterations=100,
    ),
    inputs=[],
    num_shards=FLAGS.num_cores,
    outputs_from_all_shards=False,
)

if FLAGS.restore_dir is not None:
    ckpt = tf.train.latest_checkpoint(FLAGS.restore_dir)
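On the TODO in get_validation_tensor above: one way to make the sampled windows reproducible is to draw each offset from a stateless RNG keyed on the loop index, so every validation run scores the same 100 windows. A minimal sketch, assuming `tf.random.stateless_uniform` is available (TF >= 1.13) and that the while_loop is extended to carry an iteration counter `step`:

```python
import tensorflow as tf

def deterministic_offset(val_tokens_var, step, context_len=1024):
    # Stateless RNG: the same (step, salt) seed always yields the same
    # offset, so each validation run scores identical windows. `step`
    # would come from an extra while_loop loop variable.
    s = tf.size(val_tokens_var, out_type=tf.int64)
    seed = tf.stack([tf.cast(step, tf.int32), 42])  # 42 is an arbitrary salt
    return tf.random.stateless_uniform(
        [], seed=seed, maxval=s - (context_len + 1), dtype=tf.int64)
```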
@@ -320,11 +361,21 @@ def save():
for i in range(num_threads):
    checkpoint_threads.append(None)
end_step = self.cur_step + self.train_steps
loops_until_next_validation_run = 0

while self.cur_step < end_step:
    tflex.check_commands()
    if tflex.should_quit():
        tf.logging.info("TrainRunner: quitting")
        break

    validation_loss = None
    if loops_until_next_validation_run == 0:
        # The on-device while_loop sums maximum_iterations=100 losses,
        # so divide by 100 to report the mean validation loss.
        validation_loss = self.sess.run(self.validation_loss) / 100.0
        tf.logging.info("validation loss: %.3f", validation_loss)
        loops_until_next_validation_run = FLAGS.loops_per_val_run
    loops_until_next_validation_run -= 1

    start = time.time()
    tf.logging.info("TrainRunner: start next %d steps", self.iterations)
    self.cur_step += self.iterations
@@ -363,6 +414,8 @@ def save():
    'train_batch_size_per_core': FLAGS.train_batch_size // FLAGS.num_cores,
    'num_cores': FLAGS.num_cores,
}
if validation_loss is not None:
    eval_results['validation_loss'] = validation_loss
for metric in eval_results:
    values = eval_results[metric]
    if not isinstance(values, list):
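A note for reviewers: `input_fns.load_source_tokens` is imported and called above, but its implementation is not part of this diff. The new code only relies on it returning a 1-D array of int32 token ids spanning the whole validation file. Purely as an illustration of that contract (not the actual implementation), a stand-in could look like:

```python
import numpy as np

def load_source_tokens(path, encoder=None):
    # Hypothetical stand-in; the real function lives in input_fns.py.
    # Pre-tokenized .npy files load directly; anything else is treated
    # as raw text and encoded with the caller-supplied `encoder`.
    if path.endswith('.npy'):
        return np.load(path).astype(np.int32)
    with open(path, encoding='utf-8') as f:
        return np.asarray(encoder.encode(f.read()), dtype=np.int32)
```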