# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains an image-to-image translation network with an adversarial loss."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import data_provider
import networks

flags = tf.flags
tfgan = tf.contrib.gan


flags.DEFINE_integer('batch_size', 10, 'The number of images in each batch.')

flags.DEFINE_integer('patch_size', 32, 'The size of the patches to train on.')

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('train_log_dir', '/tmp/pix2pix/',
                    'Directory where to write event logs.')
flags.DEFINE_float('generator_lr', 0.00001,
                   'The generator learning rate.')

flags.DEFINE_float('discriminator_lr', 0.00001,
                   'The discriminator learning rate.')

flags.DEFINE_integer('max_number_of_steps', 2000000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

flags.DEFINE_float(
    'weight_factor', 0.0,
    'How much to weight the adversarial loss relative to pixel loss.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')


FLAGS = flags.FLAGS
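
# Example invocation (the script name and dataset path are illustrative):
#   python pix2pix_train.py --dataset_dir=/path/to/data \
#     --train_log_dir=/tmp/pix2pix/ --weight_factor=1.0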


def main(_):
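  """Builds the image-to-image translation model and trains it."""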
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    # Get real and distorted images.
    with tf.device('/cpu:0'), tf.name_scope('inputs'):
      real_images = data_provider.provide_data(
          'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
          patch_size=FLAGS.patch_size)
      distorted_images = _distort_images(
          real_images, downscale_size=int(FLAGS.patch_size / 2),
          upscale_size=FLAGS.patch_size)

    # Create a GANModel tuple.
    gan_model = tfgan.gan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        real_data=real_images,
        generator_inputs=distorted_images)
    tfgan.eval.add_image_comparison_summaries(
        gan_model, num_comparisons=3, display_diffs=True)
    tfgan.eval.add_gan_model_image_summaries(gan_model, grid_size=3)

    # Define the GANLoss tuple using standard library functions.
    with tf.name_scope('losses'):
      gan_loss = tfgan.gan_loss(
          gan_model,
          generator_loss_fn=tfgan.losses.least_squares_generator_loss,
          discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)

      # Define the standard L1 pixel loss.
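      # Note: tf.norm(..., ord=1) sums absolute differences over the entire
      # tensor, so dividing by patch_size**2 yields a per-pixel average that
      # is still summed over the batch and channel dimensions.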
      l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data,
                              ord=1) / FLAGS.patch_size ** 2

      # Modify the loss tuple to include the pixel loss. Add summaries as well.
      gan_loss = tfgan.losses.combine_adversarial_loss(
          gan_loss, gan_model, l1_pixel_loss,
          weight_factor=FLAGS.weight_factor)

    with tf.name_scope('train_ops'):
      # Get the GANTrain ops using the custom optimizers and gradient norm
      # clipping.
      gen_lr, dis_lr = _lr(FLAGS.generator_lr, FLAGS.discriminator_lr)
      gen_opt, dis_opt = _optimizer(gen_lr, dis_lr)
      train_ops = tfgan.gan_train_ops(
          gan_model,
          gan_loss,
          generator_optimizer=gen_opt,
          discriminator_optimizer=dis_opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True,
          aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N,
          transform_grads_fn=tf.contrib.training.clip_gradient_norms_fn(1e3))
      tf.summary.scalar('generator_lr', gen_lr)
      tf.summary.scalar('discriminator_lr', dis_lr)

    # Use GAN train step function if using adversarial loss, otherwise
    # only train the generator.
    train_steps = tfgan.GANTrainSteps(
        generator_train_steps=1,
        discriminator_train_steps=int(FLAGS.weight_factor > 0))

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')
    if FLAGS.max_number_of_steps == 0: return
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
               tf.train.LoggingTensorHook([status_message], every_n_iter=10)],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)


def _optimizer(gen_lr, dis_lr):
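  """Returns Adam optimizers for the generator and discriminator."""
  # beta1 = 0.5 is a common choice for GAN training (used, for example, in the
  # DCGAN paper); it tends to be more stable than the Adam default of 0.9.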
  kwargs = {'beta1': 0.5, 'beta2': 0.999}
  generator_opt = tf.train.AdamOptimizer(gen_lr, **kwargs)
  discriminator_opt = tf.train.AdamOptimizer(dis_lr, **kwargs)
  return generator_opt, discriminator_opt


def _lr(gen_lr_base, dis_lr_base):
  """Return the generator and discriminator learning rates."""
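  # The generator learning rate decays exponentially while the discriminator
  # learning rate is held constant.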
  gen_lr = tf.train.exponential_decay(
      learning_rate=gen_lr_base,
      global_step=tf.train.get_or_create_global_step(),
      decay_steps=100000,
      decay_rate=0.8,
      staircase=True)
  dis_lr = dis_lr_base

  return gen_lr, dis_lr


def _distort_images(images, downscale_size, upscale_size):
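  """Downscales then upscales images, removing high-frequency detail."""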
  downscaled = tf.image.resize_area(images, [downscale_size] * 2)
  upscaled = tf.image.resize_area(downscaled, [upscale_size] * 2)
  return upscaled


if __name__ == '__main__':
  tf.app.run()