forked from cxy1997/MNIST-baselines
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
54 lines (44 loc) · 2.29 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Data loader for the MNIST dataset.
# Paths are relative to the root directory of this repo.
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
class MnistLoader(object):
    """Load the MNIST-style dataset stored as ``.npy`` files and standardize it.

    After construction the instance exposes:

    * ``data_train`` / ``data_test`` -- float32 arrays standardized with the
      *training-set* mean and standard deviation;
    * ``label_train`` / ``label_test`` -- int64 label arrays;
    * ``mean`` / ``std`` -- the scalar training statistics used for the
      standardization, so new inputs can be normalized the same way.
    """

    DATA_SIZE = (60000, 10000)  # nominal number of (train, test) figures
    FIG_W = 45  # width (and height) of each square figure
    CLASSES = 10  # number of classes

    def __init__(self, flatten=False, data_path='data'):
        '''
        :param flatten: if True, reshape every figure into a 1-D feature vector
        :param data_path: the path to the mnist dataset; it must contain the
            ``mnist_train`` and ``mnist_test`` sub-directories
        '''
        self.flatten = flatten
        self._load(data_path)
        # Standardize BOTH splits with the training statistics. Using the test
        # set's own mean/std would leak test information and put the two
        # splits on different scales.
        self.mean = self.data_train.mean()
        self.std = self.data_train.std()
        self.data_train = (self.data_train - self.mean) / self.std
        self.data_test = (self.data_test - self.mean) / self.std

    def _load(self, data_path='data'):
        """Read the raw train/test arrays from ``data_path``.

        The training data is split across two files, which are concatenated
        along the first (sample) axis. Data is cast to float32 and labels to
        int64. When ``self.flatten`` is set, every figure is reshaped into a
        1-D vector.
        """
        train_dir = os.path.join(data_path, 'mnist_train')
        test_dir = os.path.join(data_path, 'mnist_test')
        self.data_train = np.concatenate(
            [np.load(os.path.join(train_dir, 'mnist_train_data_part1.npy')),
             np.load(os.path.join(train_dir, 'mnist_train_data_part2.npy'))],
            axis=0,
        ).astype(np.float32)
        self.label_train = np.load(
            os.path.join(train_dir, 'mnist_train_label.npy')).astype(np.int64)
        self.data_test = np.load(
            os.path.join(test_dir, 'mnist_test_data.npy')).astype(np.float32)
        self.label_test = np.load(
            os.path.join(test_dir, 'mnist_test_label.npy')).astype(np.int64)
        # flatten data: keep the sample axis, collapse everything else
        if self.flatten:
            self.data_train = self.data_train.reshape(self.data_train.shape[0], -1)
            self.data_test = self.data_test.reshape(self.data_test.shape[0], -1)

    def demo(self):
        """Print the dataset shapes, then show one random training figure.

        The figure is also written to ``demo.png`` in the current directory.
        """
        # show the structure of data & label
        print('Train data:', self.data_train.shape)
        print('Train labels:', self.label_train.shape)
        print('Test data:', self.data_test.shape)
        print('Test labels:', self.label_test.shape)
        # Pick from the samples actually loaded, not the nominal DATA_SIZE, so
        # a smaller dataset cannot produce an out-of-range index.
        ind = np.random.randint(0, self.data_train.shape[0])
        print("index: ", ind)
        print("label: ", self.label_train[ind])
        plt.imshow(self.data_train[ind].reshape(self.FIG_W, self.FIG_W))
        # Save before show(): with some backends show() closes the figure,
        # which would leave nothing to save afterwards.
        plt.savefig("demo.png")
        plt.show()