-
Notifications
You must be signed in to change notification settings - Fork 4
/
load_data.py
141 lines (104 loc) · 4.24 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import numpy as np
from PIL import Image
from matplotlib import pylab as plt
from scipy import misc
from utilities import split_train_test_set
import pickle
# Three-letter class codes mapped to integer labels; a code's position in the
# tuple below is its label (0..14).
class_keys = {
    code: label
    for label, code in enumerate((
        'BAS', 'EBO', 'EOS', 'KSC', 'LYA',
        'LYT', 'MMZ', 'MOB', 'MON', 'MYB',
        'MYO', 'NGB', 'NGS', 'PMB', 'PMO',
    ))
}
# Reverse lookup: integer label -> class code.
inv_class_keys = dict(zip(class_keys.values(), class_keys.keys()))
# Root folder of the data set; the loaders below list its entries and look
# each entry name up in class_keys.
main_folder = "AML-Cytomorphology_LMU"
def load_data(roof=100, downsample=False):
    """Load quarter-scale images and their integer labels from every class folder.

    Each sub-directory of ``main_folder`` is treated as one class: its name is
    mapped to an integer label through ``class_keys`` and at most ``roof``
    images are read from it. Every image is resized to 25 % of its original
    size. The first image that gets past the ``roof`` check is shown with
    matplotlib (``plt.show()``) as a visual sanity check.

    Args:
        roof: Maximum number of images read per class. A falsy value
            (``0`` or ``None``) disables the cap.
        downsample: Currently unused; kept so existing callers keep working.

    Returns:
        A tuple ``(images, labels)`` of NumPy arrays holding the resized
        images and the matching class indices, in loading order.

    Raises:
        KeyError: If a sub-directory name is not a key of ``class_keys``.
    """
    preview_pending = True
    image_data = []
    class_data = []
    for filename in os.listdir(main_folder):
        class_dir = os.path.join(main_folder, filename)
        if not os.path.isdir(class_dir):
            # Stray files next to the class folders are not classes; skip
            # them instead of failing on the class_keys lookup.
            continue
        class_index = class_keys[filename]
        print("class index: ", class_index)
        nr = 0
        for image in os.listdir(class_dir):
            nr += 1
            # The context manager releases the file handle as soon as the
            # pixel data has been copied into the array.
            with Image.open(os.path.join(class_dir, image)) as im:
                np_image = np.array(im)
            # NOTE: scipy.misc.imresize was deprecated in SciPy 1.0 and
            # removed in SciPy 1.3, so this call needs an older SciPy (or a
            # Pillow-based replacement).
            np_image = misc.imresize(np_image, 0.25)
            image_data.append(np_image)
            class_data.append(class_index)
            if roof and nr >= roof:
                break
            if preview_pending:
                plt.imshow(np_image)
                plt.show()
                print("np_image shape: ", np_image.shape)
                preview_pending = False
        print(filename, " got ", nr, " sets.")
    return np.array(image_data), np.array(class_data)
#creating an index list with names of the images divided into 11 buckets (10 members and 1 test set)
def create_split_index(buckets=11):
    """Assign every image of every class to a bucket and pickle the mapping.

    For each class folder inside ``main_folder`` the image file names (in
    ``os.listdir`` order) are cut into ``buckets`` contiguous, near-equal
    parts. The result is a nested dict ``{class_name: {image_name: bucket}}``
    written to ``name_dict.p`` in the current working directory.

    Only file names are needed to build the index, so the images themselves
    are not opened or decoded.

    Args:
        buckets: Number of buckets each class is divided into.

    Raises:
        KeyError: If a sub-directory name is not a key of ``class_keys``.
    """
    name_dict = {}
    for filename in os.listdir(main_folder):
        class_dir = os.path.join(main_folder, filename)
        if not os.path.isdir(class_dir):
            # Stray files next to the class folders are not classes.
            continue
        class_index = class_keys[filename]
        print("class index: ", class_index)
        image_names = os.listdir(class_dir)
        nr = len(image_names)
        print(filename, " got ", nr, " sets.")
        # buckets + 1 rounded boundaries from 0 to nr; bucket i covers
        # image_names[splits[i]:splits[i + 1]].
        splits = np.int32(np.round(np.linspace(0, nr, buckets + 1)))
        inner_dict = {}
        # Iterate over the requested number of buckets (previously a
        # hard-coded 11, which broke any other value of ``buckets``).
        for i in range(buckets):
            for name in image_names[splits[i]:splits[i + 1]]:
                inner_dict[name] = i
        name_dict[filename] = inner_dict
    with open('name_dict.p', "wb") as index_file:
        pickle.dump(name_dict, index_file)
def load_split_set(buckets=11):
    """Load all images grouped by the bucket recorded in ``name_dict.p``.

    Reads the ``{class_name: {image_name: bucket}}`` index from
    ``name_dict.p``, then walks every class folder in ``main_folder`` and
    places each image, resized to 25 % of its original size, together with
    its class index into the bucket the index assigns to it.

    Args:
        buckets: Number of buckets to allocate; must be larger than the
            largest bucket number stored in the index.

    Returns:
        A tuple ``(image_data, class_data)``. Both are lists of ``buckets``
        independent lists; ``image_data[b]`` holds the images of bucket ``b``
        and ``class_data[b]`` the matching class indices.

    Raises:
        KeyError: If a class folder or an image name is missing from the
            index, or a folder name is not a key of ``class_keys``.
        IndexError: If the index contains a bucket number >= ``buckets``.
    """
    # One distinct list per bucket; ``[[]] * n`` would repeat a single shared
    # list object n times.
    image_data = [[] for _ in range(buckets)]
    class_data = [[] for _ in range(buckets)]
    # NOTE: unpickling can execute arbitrary code; only load an index file
    # that was produced locally.
    with open('name_dict.p', "rb") as index_file:
        main_dict = pickle.load(index_file)
    for filename in os.listdir(main_folder):
        class_dir = os.path.join(main_folder, filename)
        if not os.path.isdir(class_dir):
            # Stray files next to the class folders are not classes.
            continue
        inner_dict = main_dict[filename]
        class_index = class_keys[filename]
        for image in os.listdir(class_dir):
            ind = inner_dict[image]
            # The context manager releases the file handle as soon as the
            # pixel data has been copied into the array.
            with Image.open(os.path.join(class_dir, image)) as im:
                pixels = np.array(im)
            # NOTE: scipy.misc.imresize was deprecated in SciPy 1.0 and
            # removed in SciPy 1.3, so this call needs an older SciPy (or a
            # Pillow-based replacement).
            image_data[ind].append(misc.imresize(pixels, 0.25))
            class_data[ind].append(class_index)
    return image_data, class_data
# image_data, class_data = load_data(roof=20)
# print("image data shape: ", image_data.shape)
# print("class data shape: ", class_data.shape)
#
# train_x, train_y, test_x, test_y = split_train_test_set(image_data, class_data, split=0.1)
def check_class_nr(class_data):
    """Print the number of distinct classes, the sample total and per-class counts.

    Args:
        class_data: Sequence of integer class labels (keys of ``inv_class_keys``).
    """
    labels, counts = np.unique(class_data, return_counts=True)
    print("nr of classes: ", len(labels), ", nr of data points: ", len(class_data))
    # np.unique returns the labels sorted, with counts aligned to them.
    for label, count in zip(labels, counts):
        print("class ", inv_class_keys[label], ": ", count)
# print("train check: ")
# check_class_nr(train_y)
# print("test check: ")
# check_class_nr(test_y)
#create_split_index()
# main_dict = pickle.load(open('name_dict.p', "rb"))
# print("main_dict: ", main_dict)
# Load the bucketed data set and print the class distribution of each bucket.
image_data, class_data = load_split_set()
for i, bucket_labels in enumerate(class_data):
    print("bucket: ", i)
    check_class_nr(bucket_labels)