import sys

import tensorflow as tf
+ from tensorflow.examples.tutorials.mnist import input_data

parser = argparse.ArgumentParser()

# Basic model parameters.
- parser.add_argument('--batch_size', type=int, default=100,
-                     help='Number of images to process in a batch')
+ parser.add_argument(
+     '--batch_size',
+     type=int,
+     default=100,
+     help='Number of images to process in a batch')

- parser.add_argument('--data_dir', type=str, default='/tmp/mnist_data',
-                     help='Path to the MNIST data directory.')
+ parser.add_argument(
+     '--data_dir',
+     type=str,
+     default='/tmp/mnist_data',
+     help='Path to directory containing the MNIST dataset')

- parser.add_argument('--model_dir', type=str, default='/tmp/mnist_model',
-                     help='The directory where the model will be stored.')
+ parser.add_argument(
+     '--model_dir',
+     type=str,
+     default='/tmp/mnist_model',
+     help='The directory where the model will be stored.')

- parser.add_argument('--train_epochs', type=int, default=40,
-                     help='Number of epochs to train.')
+ parser.add_argument(
+     '--train_epochs', type=int, default=40, help='Number of epochs to train.')

parser.add_argument(
-     '--data_format', type=str, default=None,
+     '--data_format',
+     type=str,
+     default=None,
    choices=['channels_first', 'channels_last'],
    help='A flag to override the data format used in the model. channels_first '
-          'provides a performance boost on GPU but is not always compatible '
-          'with CPU. If left unspecified, the data format will be chosen '
-          'automatically based on whether TensorFlow was built for CPU or GPU.')
-
- _NUM_IMAGES = {
-     'train': 50000,
-     'validation': 10000,
- }
-
-
- def input_fn(is_training, filename, batch_size=1, num_epochs=1):
-   """A simple input_fn using the tf.data input pipeline."""
-
-   def example_parser(serialized_example):
-     """Parses a single tf.Example into image and label tensors."""
-     features = tf.parse_single_example(
-         serialized_example,
-         features={
-             'image_raw': tf.FixedLenFeature([], tf.string),
-             'label': tf.FixedLenFeature([], tf.int64),
-         })
-     image = tf.decode_raw(features['image_raw'], tf.uint8)
-     image.set_shape([28 * 28])
-
-     # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
-     image = tf.cast(image, tf.float32) / 255 - 0.5
-     label = tf.cast(features['label'], tf.int32)
-     return image, tf.one_hot(label, 10)
-
-   dataset = tf.data.TFRecordDataset([filename])
-
-   # Apply dataset transformations
-   if is_training:
-     # When choosing shuffle buffer sizes, larger sizes result in better
-     # randomness, while smaller sizes have better performance. Because MNIST is
-     # a small dataset, we can easily shuffle the full epoch.
-     dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
+     'provides a performance boost on GPU but is not always compatible '
+     'with CPU. If left unspecified, the data format will be chosen '
+     'automatically based on whether TensorFlow was built for CPU or GPU.')
+

-   # We call repeat after shuffling, rather than before, to prevent separate
-   # epochs from blending together.
-   dataset = dataset.repeat(num_epochs)
+ def train_dataset(data_dir):
+   """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
+   data = input_data.read_data_sets(data_dir, one_hot=True).train
+   return tf.data.Dataset.from_tensor_slices((data.images, data.labels))

-   # Map example_parser over dataset, and batch results by up to batch_size
-   dataset = dataset.map(example_parser).prefetch(batch_size)
-   dataset = dataset.batch(batch_size)
-   iterator = dataset.make_one_shot_iterator()
-   images, labels = iterator.get_next()

-   return images, labels
+ def eval_dataset(data_dir):
+   """Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
+   data = input_data.read_data_sets(data_dir, one_hot=True).test
+   return tf.data.Dataset.from_tensors((data.images, data.labels))
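
These two helpers return un-batched tf.data.Dataset objects; shuffling, batching, and repeating are left to the input functions defined in main() below. A minimal sketch of pulling one batch from the training Dataset by hand, assuming TF 1.x graph mode (the session loop is illustrative and not part of this commit):

    import tensorflow as tf

    # Assumes train_dataset() from this file is in scope (e.g. same module).
    dataset = train_dataset('/tmp/mnist_data').batch(32)
    images, labels = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
      image_batch, label_batch = sess.run([images, labels])
      print(image_batch.shape)  # (32, 784): flattened 28x28 images
      print(label_batch.shape)  # (32, 10): one-hot labels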


def mnist_model(inputs, mode, data_format):
@@ -104,8 +82,8 @@ def mnist_model(inputs, mode, data_format):
    # When running on GPU, transpose the data from channels_last (NHWC) to
    # channels_first (NCHW) to improve performance.
    # See https://www.tensorflow.org/performance/performance_guide#data_formats
-     data_format = ('channels_first' if tf.test.is_built_with_cuda() else
-                    'channels_last')
+     data_format = ('channels_first'
+                    if tf.test.is_built_with_cuda() else 'channels_last')

  if data_format == 'channels_first':
    inputs = tf.transpose(inputs, [0, 3, 1, 2])
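
The NHWC-to-NCHW conversion is a pure dimension permutation: perm [0, 3, 1, 2] moves the channel axis from last place to second. A quick shape check (illustrative only):

    import tensorflow as tf

    nhwc = tf.zeros([100, 28, 28, 1])        # batch of 28x28 single-channel images
    nchw = tf.transpose(nhwc, [0, 3, 1, 2])
    print(nchw.shape)  # (100, 1, 28, 28): channels now precede the spatial dims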
@@ -127,8 +105,8 @@ def mnist_model(inputs, mode, data_format):
  # First max pooling layer with a 2x2 filter and stride of 2
  # Input Tensor Shape: [batch_size, 28, 28, 32]
  # Output Tensor Shape: [batch_size, 14, 14, 32]
-   pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2,
-                                   data_format=data_format)
+   pool1 = tf.layers.max_pooling2d(
+       inputs=conv1, pool_size=[2, 2], strides=2, data_format=data_format)

  # Convolutional Layer #2
  # Computes 64 features using a 5x5 filter.
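
The shape comments follow from simple arithmetic: the 'same'-padded convolutions preserve height and width, while each 2x2 pool with stride 2 halves them, so 28 -> 14 after pool1 and 14 -> 7 after pool2. A back-of-the-envelope check of the pooling formula (illustrative only):

    def pooled_size(size, pool=2, stride=2):
      # Output size of an unpadded pool: floor((size - pool) / stride) + 1.
      return (size - pool) // stride + 1

    assert pooled_size(28) == 14  # after pool1
    assert pooled_size(14) == 7   # after pool2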
@@ -147,8 +125,8 @@ def mnist_model(inputs, mode, data_format):
  # Second max pooling layer with a 2x2 filter and stride of 2
  # Input Tensor Shape: [batch_size, 14, 14, 64]
  # Output Tensor Shape: [batch_size, 7, 7, 64]
-   pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2,
-                                   data_format=data_format)
+   pool2 = tf.layers.max_pooling2d(
+       inputs=conv2, pool_size=[2, 2], strides=2, data_format=data_format)

  # Flatten tensor into a batch of vectors
  # Input Tensor Shape: [batch_size, 7, 7, 64]
@@ -159,8 +137,7 @@ def mnist_model(inputs, mode, data_format):
  # Densely connected layer with 1024 neurons
  # Input Tensor Shape: [batch_size, 7 * 7 * 64]
  # Output Tensor Shape: [batch_size, 1024]
-   dense = tf.layers.dense(inputs=pool2_flat, units=1024,
-                           activation=tf.nn.relu)
+   dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)

  # Add dropout operation; 0.6 probability that element will be kept
  dropout = tf.layers.dropout(
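
The dropout call is cut off at the hunk boundary. Note that tf.layers.dropout takes the drop probability, so keeping each element with probability 0.6 corresponds to rate=0.4, and dropout should only be active in training mode. A sketch of the likely remainder of the call, inferred from the comment above (an assumption, not shown in this diff):

    dropout = tf.layers.dropout(
        inputs=dense,
        rate=0.4,  # drop probability = 1 - 0.6 keep probability
        training=(mode == tf.estimator.ModeKeys.TRAIN))  # assumed; no-op at eval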
@@ -211,34 +188,37 @@ def mnist_model_fn(features, labels, mode, params):


def main(unused_argv):
-   # Make sure that training and testing data have been converted.
-   train_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
-   test_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
-   assert (tf.gfile.Exists(train_file) and tf.gfile.Exists(test_file)), (
-       'Run convert_to_records.py first to convert the MNIST data to TFRecord '
-       'file format.')
-
  # Create the Estimator
  mnist_classifier = tf.estimator.Estimator(
-     model_fn=mnist_model_fn, model_dir=FLAGS.model_dir,
-     params={'data_format': FLAGS.data_format})
+     model_fn=mnist_model_fn,
+     model_dir=FLAGS.model_dir,
+     params={
+         'data_format': FLAGS.data_format
+     })

  # Set up training hook that logs the training accuracy every 100 steps.
-   tensors_to_log = {
-       'train_accuracy': 'train_accuracy'
-   }
+   tensors_to_log = {'train_accuracy': 'train_accuracy'}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=100)

  # Train the model
-   mnist_classifier.train(
-       input_fn=lambda: input_fn(
-           True, train_file, FLAGS.batch_size, FLAGS.train_epochs),
-       hooks=[logging_hook])
+   def train_input_fn():
+     # When choosing shuffle buffer sizes, larger sizes result in better
+     # randomness, while smaller sizes use less memory. MNIST is a small
+     # enough dataset that we can easily shuffle the full epoch.
+     dataset = train_dataset(FLAGS.data_dir)
+     dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
+         FLAGS.train_epochs)
+     (images, labels) = dataset.make_one_shot_iterator().get_next()
+     return (images, labels)
+
+   mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
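
Ordering matters in this pipeline: shuffle before repeat keeps each epoch a complete, separately reshuffled pass instead of blending examples across epoch boundaries, and batch before repeat keeps any partial final batch within its own epoch. A toy run of the same pipeline shape, with hypothetical stand-in data (TF 1.4+):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(5)  # stand-in for the 50000 training examples
    dataset = dataset.shuffle(buffer_size=5).batch(2).repeat(2)
    next_batch = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
      try:
        while True:
          print(sess.run(next_batch))  # e.g. [3 0], [4 1], [2], then a reshuffled pass
      except tf.errors.OutOfRangeError:
        pass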

  # Evaluate the model and print results
-   eval_results = mnist_classifier.evaluate(
-       input_fn=lambda: input_fn(False, test_file, FLAGS.batch_size))
+   def eval_input_fn():
+     return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
+
+   eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)

  print()
  print('Evaluation results:\n\t%s' % eval_results)