-
Notifications
You must be signed in to change notification settings - Fork 7
/
helper.py
83 lines (73 loc) · 2.67 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import zipfile
import shutil
import gdown
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
DATA_DIR = 'data/'
DOCS_DIR = 'docs/'
LABELS_FILE = 'All_labels.txt'
DATA_URL = 'https://drive.google.com/uc?id=1w0TorBfTIqbquQVd6k3h_77ypnrvfGwf'
ZFILE = 'SCUT-FBP5500_v2.1.zip'
def extract_zipfile():
with zipfile.ZipFile(ZFILE) as zip_file:
for member in zip_file.namelist():
filename = os.path.basename(member)
# skip directories, all non-jpgs, except labels
if filename.endswith(".jpg") or filename == LABELS_FILE:
# copy file (taken from zipfile's extract)
source = zip_file.open(member)
target = open(os.path.join(DATA_DIR, filename), "wb")
with source, target:
shutil.copyfileobj(source, target)
def download_data():
# Download Dataset
if os.path.isfile(ZFILE) or os.path.isfile(DATA_DIR+LABELS_FILE):
print('data already downloaded')
else:
print ("data does not exist. downloading it.")
gdown.download(DATA_URL, ZFILE, quiet=False)
# Extract ZipFile
if os.path.isfile(DATA_DIR+LABELS_FILE):
print("data already extracted.")
else:
print("extracting data.")
if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR)
extract_zipfile()
os.remove(ZFILE)
def preprocess_image(image,target_size):
return cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB),target_size) / .255
def create_dataset(target_size):
X = []
y = []
labels_dict = get_labels_dict()
img_files = glob.glob(DATA_DIR+'*.jpg')
print(f'reading {len(img_files)} images into dataset')
for f in img_files:
img = preprocess_image(cv2.imread(f), target_size)
X.append(img)
y.append(labels_dict[os.path.split(f)[-1]])
return np.array(X), np.array(y)
def get_labels_dict():
labels_dict = {}
with open(DATA_DIR + LABELS_FILE) as fp:
for line in fp:
img,label = line.split(' ', 1)
labels_dict[img] = float(label)
return labels_dict
def plot_metrics(history, model_name, stage):
f,(ax1) = plt.subplots(1, 1, figsize=(15,7))
f.suptitle(f'Stage {stage} Model "{model_name}" training Metrics')
ax1.plot(history.history["loss"], color='darkblue', label="Train")
ax1.plot(history.history["val_loss"], color='darkred', label="Test")
ax1.set_title('Loss (Mean Squared Error) over epoch')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss (MSE)')
ax1.legend()
plt.savefig(f'{DOCS_DIR}metrics_stage_{stage}_{model_name}')
plt.show()
if __name__ == "__main__":
download_data()