Skip to content

Commit 9b51944

Browse files
authored
Merge pull request tensorflow#2932 from joel-shor/master
Add image to image translation example
2 parents 0b2bc49 + a585fc1 commit 9b51944

File tree

5 files changed

+433
-0
lines changed

5 files changed

+433
-0
lines changed

research/gan/pix2pix/launch_jobs.sh

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
#!/bin/bash
16+
#
17+
# This script performs the following operations:
18+
# 1. Downloads the Imagenet dataset.
19+
# 2. Trains a pix2pix image-to-image translation model on patches from Imagenet.
20+
# 3. Evaluates the models and writes sample images to disk.
21+
#
22+
# Usage:
23+
# cd models/research/gan/pix2pix
24+
# ./launch_jobs.sh ${weight_factor} ${git_repo}
25+
set -e

# Weight of the adversarial loss.
weight_factor=$1
if [[ -z "$weight_factor" ]]; then
  echo "'weight_factor' must not be empty." >&2
  # A bare `exit` would reuse echo's status (0) and report success to callers.
  exit 1
fi

# Location of the git repository.
git_repo=$2
if [[ -z "$git_repo" ]]; then
  echo "'git_repo' must not be empty." >&2
  exit 1
fi

# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/compression-model

# Base name for where the evaluation images will be saved to.
EVAL_DIR=/tmp/compression-model/eval

# Where the dataset is saved to.
DATASET_DIR=/tmp/imagenet-data

# Make the repository root, research/, slim/ and slim/nets importable.
export PYTHONPATH=$PYTHONPATH:$git_repo:$git_repo/research:$git_repo/research/slim:$git_repo/research/slim/nets
51+
52+
# A helper function for printing pretty output.
53+
Banner () {
  # Prints the first argument in green, then resets the terminal color.
  local message=$1
  local color_on='\033[0;32m'
  local color_off='\033[0m'  # No color.
  echo -e "${color_on}${message}${color_off}"
}
59+
60+
# Download the dataset.
bazel build "${git_repo}/research/slim:download_and_convert_imagenet"
"./bazel-bin/download_and_convert_imagenet" "${DATASET_DIR}"

# Run the pix2pix model.
NUM_STEPS=10000
MODEL_TRAIN_DIR="${TRAIN_DIR}/wt${weight_factor}"
Banner "Starting training a pix2pix model for ${NUM_STEPS} steps..."
# Launch the trainer that ships with this example (research/gan/pix2pix),
# not the image_compression trainer this script was adapted from.
python "${git_repo}/research/gan/pix2pix/train.py" \
  --train_log_dir="${MODEL_TRAIN_DIR}" \
  --dataset_dir="${DATASET_DIR}" \
  --max_number_of_steps="${NUM_STEPS}" \
  --weight_factor="${weight_factor}" \
  --alsologtostderr
Banner "Finished training pix2pix model for ${NUM_STEPS} steps."

research/gan/pix2pix/networks.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Networks for GAN Pix2Pix example using TFGAN."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import tensorflow as tf
22+
23+
from slim.nets import cyclegan
24+
from slim.nets import pix2pix
25+
26+
27+
def generator(input_images):
  """Translates a batch of images with the CycleGAN ResNet generator.

  Adapts `cyclegan.cyclegan_generator_resnet` to the single-argument generator
  signature used with TFGAN.

  Args:
    input_images: A batch of images to translate, with shape
      [batch, height, width, channels]. Images should already be normalized.

  Returns:
    The generated image batch.

  Raises:
    ValueError: If `input_images` is not rank 4.
  """
  input_images.shape.assert_has_rank(4)
  resnet_scope = cyclegan.cyclegan_arg_scope()
  with tf.contrib.framework.arg_scope(resnet_scope):
    # Only the first return value (the images) is needed here.
    translated, _ = cyclegan.cyclegan_generator_resnet(input_images)
  return translated
41+
42+
43+
def discriminator(image_batch, unused_conditioning=None):
  """Runs the Pix2Pix discriminator and returns logits shaped for TFGAN.

  Args:
    image_batch: A batch of images to score.
    unused_conditioning: Ignored by this function.

  Returns:
    A 2D tensor of logits, flattened from the discriminator's 4D output.
  """
  disc_scope = pix2pix.pix2pix_arg_scope()
  with tf.contrib.framework.arg_scope(disc_scope):
    raw_logits, _ = pix2pix.pix2pix_discriminator(
        image_batch, num_filters=[64, 128, 256, 512])
    raw_logits.shape.assert_has_rank(4)

  # The discriminator emits 4D logits; TFGAN wants them 2D.
  return tf.contrib.layers.flatten(raw_logits)

research/gan/pix2pix/networks_test.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for tfgan.examples.networks.networks."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import tensorflow as tf
22+
from google3.third_party.tensorflow_models.gan.pix2pix import networks
23+
24+
25+
class Pix2PixTest(tf.test.TestCase):
  """Smoke and graph-construction tests for the pix2pix network wrappers."""

  def test_generator_run(self):
    """The generator graph can be built and executed on a batch of zeros."""
    generated = networks.generator(tf.zeros([3, 128, 128, 3]))
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(generated)

  def test_generator_graph(self):
    """The generator output has the same static shape as its input."""
    leading_dims = ([4, 32, 32], [3, 128, 128], [2, 80, 400])
    for dims in leading_dims:
      tf.reset_default_graph()
      inputs = tf.ones(dims + [3])
      generated = networks.generator(inputs)

      self.assertAllEqual(dims + [3], generated.shape.as_list())

  def test_generator_graph_unknown_batch_dim(self):
    """An unknown batch dimension is carried through to the output shape."""
    expected_shape = [None, 32, 32, 3]
    inputs = tf.placeholder(tf.float32, shape=expected_shape)
    generated = networks.generator(inputs)

    self.assertAllEqual(expected_shape, generated.shape.as_list())

  def test_generator_invalid_input(self):
    """A rank-3 input is rejected by the generator."""
    rank3_input = tf.zeros([28, 28, 3])
    with self.assertRaisesRegexp(ValueError, 'must have rank 4'):
      networks.generator(rank3_input)

  def test_discriminator_run(self):
    """The discriminator graph can be built and executed on zeros."""
    logits = networks.discriminator(tf.zeros([3, 70, 70, 3]))
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(logits)

  def test_discriminator_graph(self):
    """Discriminator logits are 2D and keep the batch dimension."""
    # Check graph construction for a couple of (batch size, patch size) pairs.
    for batch_size, patch_size in ((3, 70), (6, 128)):
      tf.reset_default_graph()
      inputs = tf.ones([batch_size, patch_size, patch_size, 3])
      logits = networks.discriminator(inputs)

      self.assertEqual(2, logits.shape.ndims)
      self.assertEqual(batch_size, logits.shape.as_list()[0])

  def test_discriminator_invalid_input(self):
    """A rank-3 input is rejected by the discriminator."""
    rank3_input = tf.zeros([28, 28, 3])
    with self.assertRaisesRegexp(ValueError, 'Shape must be rank 4'):
      networks.discriminator(rank3_input)
73+
74+
75+
# Run the test cases above when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()

research/gan/pix2pix/train.py

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Trains an image-to-image translation network with an adversarial loss."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
22+
23+
import tensorflow as tf
24+
25+
import data_provider
26+
from google3.third_party.tensorflow_models.gan.pix2pix import networks
27+
28+
# Short aliases for the TensorFlow flags module and the TFGAN library.
flags = tf.flags
tfgan = tf.contrib.gan


flags.DEFINE_integer('batch_size', 10, 'The number of images in each batch.')

flags.DEFINE_integer('patch_size', 32, 'The size of the patches to train on.')

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('train_log_dir', '/tmp/pix2pix/',
                    'Directory where to write event logs.')

# The help text previously said "compression model", which was carried over
# from a different example; this flag feeds the generator's optimizer.
flags.DEFINE_float('generator_lr', 0.00001,
                   'The generator learning rate.')

flags.DEFINE_float('discriminator_lr', 0.00001,
                   'The discriminator learning rate.')

flags.DEFINE_integer('max_number_of_steps', 2000000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

flags.DEFINE_float(
    'weight_factor', 0.0,
    'How much to weight the adversarial loss relative to pixel loss.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')


FLAGS = flags.FLAGS
68+
69+
70+
def main(_):
  """Builds the pix2pix training graph and runs the TFGAN training loop.

  Creates `FLAGS.train_log_dir` if it does not exist, builds a GAN model whose
  generator inputs are downscaled-then-upscaled copies of the real images,
  combines a least-squares adversarial loss with an L1 pixel loss, and trains
  with `tfgan.gan_train`. Returns after graph construction, without training,
  when `FLAGS.max_number_of_steps` is 0.

  Args:
    _: Unused.
  """
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    # Get real and distorted images.
    with tf.device('/cpu:0'), tf.name_scope('inputs'):
      real_images = data_provider.provide_data(
          'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
          patch_size=FLAGS.patch_size)
    # The generator input is the real batch resized to half the patch size
    # and then back up to the full patch size.
    distorted_images = _distort_images(
        real_images, downscale_size=int(FLAGS.patch_size / 2),
        upscale_size=FLAGS.patch_size)

    # Create a GANModel tuple.
    gan_model = tfgan.gan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        real_data=real_images,
        generator_inputs=distorted_images)
    tfgan.eval.add_image_comparison_summaries(
        gan_model, num_comparisons=3, display_diffs=True)
    tfgan.eval.add_gan_model_image_summaries(gan_model, grid_size=3)

    # Define the GANLoss tuple using standard library functions.
    with tf.name_scope('losses'):
      gan_loss = tfgan.gan_loss(
          gan_model,
          generator_loss_fn=tfgan.losses.least_squares_generator_loss,
          discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)

      # Define the standard L1 pixel loss: the L1 norm of the difference
      # between real and generated data, divided by patch_size squared.
      l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data,
                              ord=1) / FLAGS.patch_size ** 2

      # Modify the loss tuple to include the pixel loss. Add summaries as well.
      gan_loss = tfgan.losses.combine_adversarial_loss(
          gan_loss, gan_model, l1_pixel_loss,
          weight_factor=FLAGS.weight_factor)

    with tf.name_scope('train_ops'):
      # Get the GANTrain ops using the custom optimizers. Gradients are
      # transformed by clip_gradient_norms_fn(1e3) before being applied.
      gen_lr, dis_lr = _lr(FLAGS.generator_lr, FLAGS.discriminator_lr)
      gen_opt, dis_opt = _optimizer(gen_lr, dis_lr)
      train_ops = tfgan.gan_train_ops(
          gan_model,
          gan_loss,
          generator_optimizer=gen_opt,
          discriminator_optimizer=dis_opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True,
          aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N,
          transform_grads_fn=tf.contrib.training.clip_gradient_norms_fn(1e3))
      tf.summary.scalar('generator_lr', gen_lr)
      tf.summary.scalar('discriminator_lr', dis_lr)

    # Use GAN train step function if using adversarial loss, otherwise
    # only train the generator: the discriminator takes 1 step per iteration
    # when weight_factor > 0 and 0 steps otherwise.
    train_steps = tfgan.GANTrainSteps(
        generator_train_steps=1,
        discriminator_train_steps=int(FLAGS.weight_factor > 0))

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')
    if FLAGS.max_number_of_steps == 0: return
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
               tf.train.LoggingTensorHook([status_message], every_n_iter=10)],
        master=FLAGS.master,
        # Task 0 is treated as the chief worker.
        is_chief=FLAGS.task == 0)
148+
149+
150+
def _optimizer(gen_lr, dis_lr):
  """Builds the Adam optimizers for the generator and the discriminator.

  Args:
    gen_lr: Learning rate for the generator optimizer.
    dis_lr: Learning rate for the discriminator optimizer.

  Returns:
    A `(generator_optimizer, discriminator_optimizer)` tuple. Both optimizers
    use beta1=0.5 and beta2=0.999.
  """
  def _adam(learning_rate):
    return tf.train.AdamOptimizer(learning_rate, beta1=0.5, beta2=0.999)

  return _adam(gen_lr), _adam(dis_lr)
155+
156+
157+
def _lr(gen_lr_base, dis_lr_base):
  """Computes the learning rates for the generator and the discriminator.

  Args:
    gen_lr_base: Initial generator learning rate, before decay.
    dis_lr_base: Discriminator learning rate.

  Returns:
    A `(gen_lr, dis_lr)` tuple. `gen_lr` is `gen_lr_base` under a staircase
    exponential decay (rate 0.8 every 100000 global steps); `dis_lr` is
    `dis_lr_base` unchanged.
  """
  step = tf.train.get_or_create_global_step()
  decayed_gen_lr = tf.train.exponential_decay(
      learning_rate=gen_lr_base,
      global_step=step,
      decay_steps=100000,
      decay_rate=0.8,
      staircase=True)
  return decayed_gen_lr, dis_lr_base
168+
169+
170+
def _distort_images(images, downscale_size, upscale_size):
  """Distorts images by area-resizing them down and then back up.

  Args:
    images: Batch of images to distort.
    downscale_size: Side length of the square intermediate size.
    upscale_size: Side length of the square output size.

  Returns:
    The images after both resizes.
  """
  small_hw = [downscale_size, downscale_size]
  large_hw = [upscale_size, upscale_size]
  shrunk = tf.image.resize_area(images, small_hw)
  return tf.image.resize_area(shrunk, large_hw)
174+
175+
176+
# Script entry point: hand control to tf.app.run() when executed directly.
if __name__ == '__main__':
  tf.app.run()
178+

0 commit comments

Comments
 (0)