Commit afd3e67
[mnist]: Tweaks

- Remove `convert_to_records.py` and instead create `tf.data.Dataset` objects directly from the numpy arrays (see the sketch below).
- Format to the Google Python style with yapf (https://github.com/google/yapf/).
1 parent 5a5d330
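
For orientation, here is a minimal sketch of the pattern this commit adopts: build the datasets in memory from the arrays the MNIST reader already returns, instead of materializing TFRecord files on disk. It uses TF 1.x-era APIs to match the code below; the array contents and sizes are illustrative stand-ins, not real MNIST data.

```python
import numpy as np
import tensorflow as tf

# Illustrative stand-ins for the arrays that
# input_data.read_data_sets() provides.
images = np.random.rand(100, 784).astype(np.float32)
labels = np.random.randint(0, 10, size=100).astype(np.int32)

# from_tensor_slices slices along the first axis, yielding one
# (image, label) pair per element -- no TFRecord conversion needed.
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.shuffle(buffer_size=100).batch(32)

# TF 1.x graph-mode iteration, matching the style of this commit.
image_batch, label_batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(label_batch))  # labels for the first batch of 32
```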

3 files changed: +61 −185 lines

official/mnist/README.md

-7
@@ -12,13 +12,6 @@ APIs.
 ## Setup
 
 To begin, you'll simply need the latest version of TensorFlow installed.
-
-First convert the MNIST data to TFRecord file format by running the following:
-
-```
-python convert_to_records.py
-```
-
 Then to train the model, run the following:
 
 ```

official/mnist/convert_to_records.py

-97
This file was deleted.
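Per the README change above, this 97-line script converted the raw MNIST data to TFRecord files; that step is now handled in memory by the `tf.data` helpers added to `mnist.py` below.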

official/mnist/mnist.py

+61 −81
@@ -22,75 +22,53 @@
 import sys
 
 import tensorflow as tf
+from tensorflow.examples.tutorials.mnist import input_data
 
 parser = argparse.ArgumentParser()
 
 # Basic model parameters.
-parser.add_argument('--batch_size', type=int, default=100,
-                    help='Number of images to process in a batch')
+parser.add_argument(
+    '--batch_size',
+    type=int,
+    default=100,
+    help='Number of images to process in a batch')
 
-parser.add_argument('--data_dir', type=str, default='/tmp/mnist_data',
-                    help='Path to the MNIST data directory.')
+parser.add_argument(
+    '--data_dir',
+    type=str,
+    default='/tmp/mnist_data',
+    help='Path to directory containing the MNIST dataset')
 
-parser.add_argument('--model_dir', type=str, default='/tmp/mnist_model',
-                    help='The directory where the model will be stored.')
+parser.add_argument(
+    '--model_dir',
+    type=str,
+    default='/tmp/mnist_model',
+    help='The directory where the model will be stored.')
 
-parser.add_argument('--train_epochs', type=int, default=40,
-                    help='Number of epochs to train.')
+parser.add_argument(
+    '--train_epochs', type=int, default=40, help='Number of epochs to train.')
 
 parser.add_argument(
-    '--data_format', type=str, default=None,
+    '--data_format',
+    type=str,
+    default=None,
     choices=['channels_first', 'channels_last'],
     help='A flag to override the data format used in the model. channels_first '
-         'provides a performance boost on GPU but is not always compatible '
-         'with CPU. If left unspecified, the data format will be chosen '
-         'automatically based on whether TensorFlow was built for CPU or GPU.')
-
-_NUM_IMAGES = {
-    'train': 50000,
-    'validation': 10000,
-}
-
-
-def input_fn(is_training, filename, batch_size=1, num_epochs=1):
-  """A simple input_fn using the tf.data input pipeline."""
-
-  def example_parser(serialized_example):
-    """Parses a single tf.Example into image and label tensors."""
-    features = tf.parse_single_example(
-        serialized_example,
-        features={
-            'image_raw': tf.FixedLenFeature([], tf.string),
-            'label': tf.FixedLenFeature([], tf.int64),
-        })
-    image = tf.decode_raw(features['image_raw'], tf.uint8)
-    image.set_shape([28 * 28])
-
-    # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
-    image = tf.cast(image, tf.float32) / 255 - 0.5
-    label = tf.cast(features['label'], tf.int32)
-    return image, tf.one_hot(label, 10)
-
-  dataset = tf.data.TFRecordDataset([filename])
-
-  # Apply dataset transformations
-  if is_training:
-    # When choosing shuffle buffer sizes, larger sizes result in better
-    # randomness, while smaller sizes have better performance. Because MNIST is
-    # a small dataset, we can easily shuffle the full epoch.
-    dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
+    'provides a performance boost on GPU but is not always compatible '
+    'with CPU. If left unspecified, the data format will be chosen '
+    'automatically based on whether TensorFlow was built for CPU or GPU.')
+
 
-  # We call repeat after shuffling, rather than before, to prevent separate
-  # epochs from blending together.
-  dataset = dataset.repeat(num_epochs)
+def train_dataset(data_dir):
+  """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
+  data = input_data.read_data_sets(data_dir, one_hot=True).train
+  return tf.data.Dataset.from_tensor_slices((data.images, data.labels))
 
-  # Map example_parser over dataset, and batch results by up to batch_size
-  dataset = dataset.map(example_parser).prefetch(batch_size)
-  dataset = dataset.batch(batch_size)
-  iterator = dataset.make_one_shot_iterator()
-  images, labels = iterator.get_next()
 
-  return images, labels
+def eval_dataset(data_dir):
+  """Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
+  data = input_data.read_data_sets(data_dir, one_hot=True).test
+  return tf.data.Dataset.from_tensors((data.images, data.labels))
 
 
 def mnist_model(inputs, mode, data_format):
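The two helpers use different constructors on purpose: `from_tensor_slices` slices the arrays along the first axis into per-example elements, which training can shuffle and re-batch, while `from_tensors` wraps the whole array pair in a single element, so evaluation sees the entire test set as one batch. A small sketch of the distinction, with illustrative shapes:

```python
import numpy as np
import tensorflow as tf

# Illustrative shapes matching MNIST with one-hot labels.
images = np.zeros((10000, 784), dtype=np.float32)
labels = np.zeros((10000, 10), dtype=np.float32)

# from_tensor_slices: 10000 elements, one (image, label) pair each.
train = tf.data.Dataset.from_tensor_slices((images, labels))

# from_tensors: a single element containing both whole arrays.
test = tf.data.Dataset.from_tensors((images, labels))

print(train.output_shapes)  # per-element: (784,) and (10,)
print(test.output_shapes)   # per-element: (10000, 784) and (10000, 10)
```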
@@ -104,8 +82,8 @@ def mnist_model(inputs, mode, data_format):
     # When running on GPU, transpose the data from channels_last (NHWC) to
     # channels_first (NCHW) to improve performance.
     # See https://www.tensorflow.org/performance/performance_guide#data_formats
-    data_format = ('channels_first' if tf.test.is_built_with_cuda() else
-                   'channels_last')
+    data_format = ('channels_first'
+                   if tf.test.is_built_with_cuda() else 'channels_last')
 
   if data_format == 'channels_first':
     inputs = tf.transpose(inputs, [0, 3, 1, 2])
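As an aside on the transpose above: the permutation `[0, 3, 1, 2]` keeps the batch axis in place and moves the channel axis from last (NHWC) to second (NCHW). A quick sketch with an illustrative tensor:

```python
import tensorflow as tf

# A batch of two 28x28 single-channel images in channels_last (NHWC).
nhwc = tf.zeros([2, 28, 28, 1])

# [0, 3, 1, 2] keeps batch first and moves channels to axis 1, giving
# channels_first (NCHW), the layout cuDNN convolutions prefer.
nchw = tf.transpose(nhwc, [0, 3, 1, 2])
print(nchw.shape)  # (2, 1, 28, 28)
```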
@@ -127,8 +105,8 @@ def mnist_model(inputs, mode, data_format):
   # First max pooling layer with a 2x2 filter and stride of 2
   # Input Tensor Shape: [batch_size, 28, 28, 32]
   # Output Tensor Shape: [batch_size, 14, 14, 32]
-  pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2,
-                                  data_format=data_format)
+  pool1 = tf.layers.max_pooling2d(
+      inputs=conv1, pool_size=[2, 2], strides=2, data_format=data_format)
 
   # Convolutional Layer #2
   # Computes 64 features using a 5x5 filter.
@@ -147,8 +125,8 @@ def mnist_model(inputs, mode, data_format):
   # Second max pooling layer with a 2x2 filter and stride of 2
   # Input Tensor Shape: [batch_size, 14, 14, 64]
   # Output Tensor Shape: [batch_size, 7, 7, 64]
-  pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2,
-                                  data_format=data_format)
+  pool2 = tf.layers.max_pooling2d(
+      inputs=conv2, pool_size=[2, 2], strides=2, data_format=data_format)
 
   # Flatten tensor into a batch of vectors
   # Input Tensor Shape: [batch_size, 7, 7, 64]
@@ -159,8 +137,7 @@ def mnist_model(inputs, mode, data_format):
   # Densely connected layer with 1024 neurons
   # Input Tensor Shape: [batch_size, 7 * 7 * 64]
   # Output Tensor Shape: [batch_size, 1024]
-  dense = tf.layers.dense(inputs=pool2_flat, units=1024,
-                          activation=tf.nn.relu)
+  dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
 
   # Add dropout operation; 0.6 probability that element will be kept
   dropout = tf.layers.dropout(
@@ -211,34 +188,37 @@ def mnist_model_fn(features, labels, mode, params):
 
 
 def main(unused_argv):
-  # Make sure that training and testing data have been converted.
-  train_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
-  test_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
-  assert (tf.gfile.Exists(train_file) and tf.gfile.Exists(test_file)), (
-      'Run convert_to_records.py first to convert the MNIST data to TFRecord '
-      'file format.')
-
   # Create the Estimator
   mnist_classifier = tf.estimator.Estimator(
-      model_fn=mnist_model_fn, model_dir=FLAGS.model_dir,
-      params={'data_format': FLAGS.data_format})
+      model_fn=mnist_model_fn,
+      model_dir=FLAGS.model_dir,
+      params={
+          'data_format': FLAGS.data_format
+      })
 
   # Set up training hook that logs the training accuracy every 100 steps.
-  tensors_to_log = {
-      'train_accuracy': 'train_accuracy'
-  }
+  tensors_to_log = {'train_accuracy': 'train_accuracy'}
   logging_hook = tf.train.LoggingTensorHook(
       tensors=tensors_to_log, every_n_iter=100)
 
   # Train the model
-  mnist_classifier.train(
-      input_fn=lambda: input_fn(
-          True, train_file, FLAGS.batch_size, FLAGS.train_epochs),
-      hooks=[logging_hook])
+  def train_input_fn():
+    # When choosing shuffle buffer sizes, larger sizes result in better
+    # randomness, while smaller sizes use less memory. MNIST is a small
+    # enough dataset that we can easily shuffle the full epoch.
+    dataset = train_dataset(FLAGS.data_dir)
+    dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
+        FLAGS.train_epochs)
+    (images, labels) = dataset.make_one_shot_iterator().get_next()
+    return (images, labels)
+
+  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
 
   # Evaluate the model and print results
-  eval_results = mnist_classifier.evaluate(
-      input_fn=lambda: input_fn(False, test_file, FLAGS.batch_size))
+  def eval_input_fn():
+    return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
+
+  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
   print()
   print('Evaluation results:\n\t%s' % eval_results)
 
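A note on the ordering in `train_input_fn`: calling `repeat` after `shuffle` and `batch` keeps epochs separate, exactly as the deleted comment in the old `input_fn` explained, at the cost of a possibly short final batch per epoch. A toy sketch (values illustrative) of what such a pipeline emits:

```python
import numpy as np
import tensorflow as tf

# Five toy examples, two epochs, batches of two.
ds = tf.data.Dataset.from_tensor_slices(np.arange(5, dtype=np.int64))
ds = ds.shuffle(buffer_size=5).batch(2).repeat(2)

batch = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  try:
    while True:
      print(sess.run(batch))  # e.g. [3 1] [4 0] [2], then a reshuffled epoch
  except tf.errors.OutOfRangeError:
    pass  # both epochs exhausted
```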