-
Notifications
You must be signed in to change notification settings - Fork 7
/
mnist_data.py
153 lines (119 loc) · 5.5 KB
/
mnist_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Some code was borrowed from https://github.com/petewarden/tensorflow_makefile/blob/master/tensorflow/models/image/mnist/convolutional.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import numpy
from scipy import ndimage
from six.moves import urllib
import tensorflow as tf
# Base URL the MNIST archive file names are appended to when downloading.
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
# Local directory the downloaded archives are stored in (created on demand).
DATA_DIRECTORY = "data"
# Params for MNIST
# Side length, in pixels, of one (square) image.
IMAGE_SIZE = 28
# Number of channels per pixel.
NUM_CHANNELS = 1
# Maximum raw pixel value; used to rescale pixels when extracting images.
PIXEL_DEPTH = 255
# Number of distinct classes, i.e. the width of a one-hot label row.
NUM_LABELS = 10
# NOTE(review): within this file this is only referenced by commented-out
# code in prepare_MNIST_data -- confirm whether other modules use it.
VALIDATION_SIZE = 5000 # Size of the validation set.
# Download MNIST data
def maybe_download(filename):
    """Download one archive from Yann's website, unless it's already here.

    Args:
        filename: Name of the archive. It is appended to SOURCE_URL to form
            the download URL and joined onto DATA_DIRECTORY to form the
            local path.

    Returns:
        The local path of the archive, whether it was just downloaded or
        was already present.
    """
    # urlretrieve always writes a plain local file, so the standard library
    # is enough for the existence/size checks. This also avoids `tf.gfile`,
    # which is only available in TensorFlow 1.x.
    if not os.path.exists(DATA_DIRECTORY):
        os.makedirs(DATA_DIRECTORY)
    filepath = os.path.join(DATA_DIRECTORY, filename)
    if not os.path.exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        size = os.path.getsize(filepath)
        print('Successfully downloaded', filename, size, 'bytes.')
    return filepath
# Extract the images
def extract_data(filename, num_images):
    """Extract images into a 2D array [image index, flattened pixels].

    Values are rescaled from [0, 255] down to [-0.5, 0.5].

    Args:
        filename: Path of the gzip-compressed image file.
        num_images: Number of images to read from the file.

    Returns:
        A float32 array of shape
        [num_images, IMAGE_SIZE * IMAGE_SIZE * NUM_CHANNELS].

    Raises:
        ValueError: If the file does not contain num_images full images.
    """
    print('Extracting', filename)
    pixels_per_image = IMAGE_SIZE * IMAGE_SIZE * NUM_CHANNELS
    expected_bytes = pixels_per_image * num_images
    with gzip.open(filename) as bytestream:
        # Skip the 16-byte file header that precedes the pixel data.
        bytestream.read(16)
        buf = bytestream.read(expected_bytes)
    # Fail with a readable message instead of an opaque reshape error when
    # the archive is truncated (e.g. an interrupted download).
    if len(buf) != expected_bytes:
        raise ValueError(
            'Expected %d bytes of pixel data in %s but read %d'
            % (expected_bytes, filename, len(buf)))
    data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
    data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
    # One row per image, pixels flattened.
    return data.reshape(num_images, pixels_per_image)
# Extract the labels
def extract_labels(filename, num_images):
    """Read labels from a gzip file and return them one-hot encoded.

    The first 8 bytes of the file are skipped, then up to num_images
    one-byte label IDs are read.

    Args:
        filename: Path of the gzip-compressed label file.
        num_images: Maximum number of labels to read.

    Returns:
        A float array of shape [number of labels read, NUM_LABELS] whose
        rows contain a single 1 in the column of the label ID.
    """
    print('Extracting', filename)
    with gzip.open(filename) as label_stream:
        label_stream.read(8)
        raw_bytes = label_stream.read(1 * num_images)
    label_ids = numpy.frombuffer(raw_bytes, dtype=numpy.uint8).astype(numpy.int64)
    # Row k of the identity matrix is the one-hot vector for class k, so
    # indexing it with the label IDs yields the whole encoding at once.
    return numpy.eye(NUM_LABELS)[label_ids]
# Augment training data
def expend_training_data(images, labels):
    """Augment a training set with randomly rotated and shifted copies.

    Every input image is kept as-is and followed by four generated variants.
    Each variant is rotated by a random integer angle drawn from [-15, 15)
    degrees and then shifted by a random integer offset drawn from [-2, 2)
    pixels along each axis. Pixels uncovered by a transform are filled with
    the image's median value.

    Args:
        images: Sequence of flattened 28x28 images (784 values each).
        labels: Sequence of 1-D label rows, aligned with images.

    Returns:
        A 2D array with five rows per input image. Each row holds the 784
        image pixels followed by the label values, and the rows are
        randomly shuffled.
    """
    expanded_images = []
    expanded_labels = []
    for count, (x, y) in enumerate(zip(images, labels), start=1):
        if count % 100 == 0:
            print ('expanding data : %03d / %03d' % (count, numpy.size(images, 0)))

        # register original data
        expanded_images.append(x)
        expanded_labels.append(y)

        # zero is the expected background value, but the median is used to
        # estimate it from the image itself
        bg_value = numpy.median(x)
        image = numpy.reshape(x, (-1, 28))

        for _ in range(4):
            # Draw a scalar: ndimage.rotate expects a single angle, and the
            # 1-element array produced by randint(..., size=1) breaks the
            # rotation-matrix computation in newer SciPy releases.
            # NOTE(review): the upper bounds of both random draws are
            # exclusive, so the ranges are slightly asymmetric -- confirm
            # whether that is intended.
            angle = numpy.random.randint(-15, 15)
            rotated = ndimage.rotate(image, angle, reshape=False, cval=bg_value)

            # shift the image by a random (row, column) offset
            shift = numpy.random.randint(-2, 2, 2)
            shifted = ndimage.shift(rotated, shift, cval=bg_value)

            # register new training data
            expanded_images.append(numpy.reshape(shifted, 784))
            expanded_labels.append(y)

    # Images and labels are joined column-wise so that shuffling whole rows
    # keeps every image paired with its own label.
    expanded_train_total_data = numpy.concatenate((expanded_images, expanded_labels), axis=1)
    numpy.random.shuffle(expanded_train_total_data)

    return expanded_train_total_data
# Prepare MNIST data
def prepare_MNIST_data(use_data_augmentation=True):
    """Download (if needed), extract and assemble the MNIST data.

    Args:
        use_data_augmentation: When true, the training set is expanded with
            expend_training_data; otherwise the extracted images and labels
            are simply joined column-wise.

    Returns:
        A tuple (train_total_data, train_size, test_data, test_labels).
        train_total_data has one row per training example, image values
        followed by label values; train_size is its number of rows.
    """
    # Fetch all four archives first, in the same order as they are used.
    archive_names = (
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
    )
    (train_images_path, train_labels_path,
     test_images_path, test_labels_path) = [
        maybe_download(name) for name in archive_names
    ]

    # Extract it into numpy arrays.
    train_data = extract_data(train_images_path, 60000)
    train_labels = extract_labels(train_labels_path, 60000)
    test_data = extract_data(test_images_path, 10000)
    test_labels = extract_labels(test_labels_path, 10000)

    # No validation split is carved out here: the full training set is used.

    # Images and labels end up in a single array so they can be shuffled
    # together without breaking the image/label pairing.
    if not use_data_augmentation:
        train_total_data = numpy.concatenate((train_data, train_labels), axis=1)
    else:
        train_total_data = expend_training_data(train_data, train_labels)

    return train_total_data, train_total_data.shape[0], test_data, test_labels