-
Notifications
You must be signed in to change notification settings - Fork 1
/
create_rotated_MNIST_dataset_two_class.py
96 lines (76 loc) · 3.56 KB
/
create_rotated_MNIST_dataset_two_class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import cv2
import os
def load_data(trainingData, trainingLabel, testingData, testingLabel, dataset="MNIST"):
    """Load train/test images and labels from ``.npy`` files.

    Every path argument is appended, by plain string concatenation, to the
    directory prefix held in the environment variable named *dataset*.
    Callers therefore pass suffixes with a leading separator, e.g.
    ``"/X_train.npy"``.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``. The image arrays are float32
        reshaped to ``(-1, 1, 28, 28)``, so the stored array's size must be a
        multiple of 784. The label arrays are uint8.

    Raises:
        KeyError: if the environment variable *dataset* is not set.
    """
    root = os.environ[dataset]

    def _images(suffix):
        # Single-channel 28x28 images, channel-first.
        return np.array(np.load(root + suffix), dtype=np.float32).reshape(-1, 1, 28, 28)

    def _labels(suffix):
        return np.array(np.load(root + suffix), dtype=np.uint8)

    return (
        _images(trainingData),
        _labels(trainingLabel),
        _images(testingData),
        _labels(testingLabel),
    )
def rotateImage(image, angle):
    """Rotate a single image about its centre, keeping its size.

    Args:
        image: 2-D ``(H, W)`` array, or a 3-D array whose first slice
            ``image[0]`` is used (e.g. a channel-first ``(1, H, W)`` image).
        angle: rotation in degrees, passed to ``cv2.getRotationMatrix2D``
            (see its documentation for the sign convention).

    Returns:
        float32 array of shape ``(1, H, W)`` holding the bilinearly
        interpolated rotated image.
    """
    if image.ndim == 3:
        image = image[0]
    height, width = image.shape
    # OpenCV takes points as (x, y) and sizes as (width, height), which is the
    # reverse of NumPy's (rows, cols) shape. Passing ``image.shape`` straight
    # through is only correct for square images.
    center = (width / 2.0, height / 2.0)
    # Plain Python float avoids depending on how the OpenCV bindings convert
    # NumPy integer scalars (angles here come from np.random.randint).
    rot_mat = cv2.getRotationMatrix2D(center, float(angle), 1.0)
    result = cv2.warpAffine(image, rot_mat, (width, height), flags=cv2.INTER_LINEAR)
    return np.array(result[np.newaxis, :, :], dtype=np.float32)
def extend_image(inputs, size=40):
    """Zero-pad a batch of images onto a larger square canvas, centred.

    Args:
        inputs: ``(N, H, W)`` or ``(N, C, H, W)`` array. A 3-D input is
            treated as single-channel and given a channel axis.
        size: side length of the square output canvas. It must be at least
            ``max(H, W)``.

    Returns:
        float32 array of shape ``(N, C, size, size)``, with ``C == 1`` for 3-D
        input. *inputs* is copied into the centre and everything else is zero.
        When ``size - H`` (or ``size - W``) is odd, the extra padding row
        (column) goes at the bottom (right).

    Raises:
        ValueError: if *size* is smaller than the image height or width.
    """
    if inputs.ndim == 3:
        inputs = inputs[:, np.newaxis, :, :]
    num, channels, height, width = inputs.shape
    if size < height or size < width:
        raise ValueError(
            "size=%s is smaller than the input images (%sx%s)" % (size, height, width))
    # Floor division: slice bounds must be ints. With ``/`` Python 3 yields a
    # float and the slice assignment raises TypeError.
    top = (size - height) // 2
    left = (size - width) // 2
    extended_images = np.zeros((num, channels, size, size), dtype=np.float32)
    extended_images[:, :, top:top + height, left:left + width] = inputs
    return extended_images
# ---------------------------------------------------------------------------
# Build a two-class (labels 0 and 1) rotated-MNIST dataset.
#
# For each split, images are zero-padded to a WORK_SIZE canvas and rotated by
# random integer angles in [-MAX_ANGLE, MAX_ANGLE). The central region of the
# original image size is cropped back out and then zero-padded onto an
# OUTPUT_SIZE canvas. Each split is augmented ROTATION_COPIES times, each copy
# with freshly drawn angles, and shuffled. Both splits go into one .npz file.
# ---------------------------------------------------------------------------
ROTATION_COPIES = 5  # rotated copies generated per source image
MAX_ANGLE = 50       # angles drawn from [-MAX_ANGLE, 0) and [0, MAX_ANGLE)
WORK_SIZE = 40       # canvas size used while rotating
OUTPUT_SIZE = 60     # canvas size of the saved images
OUTPUT_PATH = os.environ.get(
    "ROTATED_MNIST_OUTPUT",
    "/phddata/jiajun/Research/mnist/rotated_mnist_two_class.npz")

# Optional: set ROTATED_MNIST_SEED to make the generated dataset reproducible.
if os.environ.get("ROTATED_MNIST_SEED") is not None:
    np.random.seed(int(os.environ["ROTATED_MNIST_SEED"]))


def _keep_classes(images, labels, classes=(0, 1)):
    """Return the ``(images, labels)`` subset whose label is in *classes*."""
    mask = np.isin(labels, classes)
    return images[mask], labels[mask]


def _make_rotated_split(images, labels, crop_size):
    """Return shuffled, randomly rotated copies of *images* and their labels.

    *images* are WORK_SIZE x WORK_SIZE padded images. After rotation, the
    central ``crop_size`` x ``crop_size`` region is kept and re-padded to
    OUTPUT_SIZE via ``extend_image``.
    """
    num = labels.shape[0]
    if num == 0:
        raise ValueError("no images left to rotate after class filtering")
    half = (num + 1) // 2
    rotated, repeated_labels = [], []
    for _ in range(ROTATION_COPIES):
        # Half the angles are negative and half non-negative, then shuffled.
        # For odd ``num`` one extra angle is drawn and ignored by zip().
        negative = list(np.random.randint(low=-MAX_ANGLE, high=0, size=half))
        positive = list(np.random.randint(low=0, high=MAX_ANGLE, size=half))
        angles = np.array(negative + positive)
        np.random.shuffle(angles)
        rotated.append(np.array(
            [rotateImage(img, ang) for img, ang in zip(images, angles)],
            dtype=np.float32))
        repeated_labels.append(labels)
    all_images = np.vstack(rotated)
    all_labels = np.hstack(repeated_labels)
    print(all_images.shape, all_labels.shape)
    index = np.arange(ROTATION_COPIES * num)
    np.random.shuffle(index)
    margin = (WORK_SIZE - crop_size) // 2
    cropped = all_images[index, 0, margin:margin + crop_size, margin:margin + crop_size]
    return extend_image(cropped, OUTPUT_SIZE), all_labels[index]


X_train, y_train, X_test, y_test = load_data("/X_train.npy", "/Y_train.npy", "/X_test.npy", "/Y_test.npy")
X_test, y_test = _keep_classes(X_test, y_test)
X_train, y_train = _keep_classes(X_train, y_train)
orig_size = X_train.shape[-1]  # 28: load_data reshapes images to (-1, 1, 28, 28)
X_test = extend_image(X_test, WORK_SIZE)
X_train = extend_image(X_train, WORK_SIZE)
# Train before test, matching the order in which random numbers are drawn.
x_train, y_train = _make_rotated_split(X_train, y_train, orig_size)
x_test, y_test = _make_rotated_split(X_test, y_test, orig_size)
np.savez(OUTPUT_PATH, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)