-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathutils.py
103 lines (75 loc) · 2.92 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import shutil
import pickle
import numpy as np
from itertools import groupby
from matplotlib import pyplot as plt
from CONSTANTS import *
from mlp_generator import MLPSearchSpace
########################################################
# DATA PROCESSING #
########################################################
def unison_shuffled_copies(a, b):
    """Shuffle two equally long sequences with one shared random permutation.

    Element ``i`` of ``a`` and element ``i`` of ``b`` end up at the same
    position in the two returned objects, so paired data stays aligned.

    Args:
        a: First indexable sequence. Plain lists/tuples are converted to
            NumPy arrays; anything else is indexed as-is with an index array.
        b: Second sequence, same length as ``a``.

    Returns:
        A tuple ``(a[p], b[p])`` where ``p`` is a random permutation of
        ``range(len(a))`` drawn from NumPy's global random state.

    Raises:
        ValueError: If ``a`` and ``b`` differ in length.
    """
    # An explicit check instead of ``assert``: asserts are removed under
    # ``python -O``, which would let mismatched inputs through silently.
    if len(a) != len(b):
        raise ValueError(
            'a and b must have the same length, got {} and {}'.format(len(a), len(b))
        )
    # Built-in lists/tuples cannot be indexed with an index array; convert
    # only those so array-like inputs keep their original type.
    if isinstance(a, (list, tuple)):
        a = np.asarray(a)
    if isinstance(b, (list, tuple)):
        b = np.asarray(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]
########################################################
# LOGGING #
########################################################
def clean_log():
    """Delete every regular file located directly inside the ``LOGS`` directory.

    Subdirectories of ``LOGS`` and their contents are left untouched. If
    ``LOGS`` does not exist the function returns without doing anything.
    """
    try:
        names = os.listdir('LOGS')
    except FileNotFoundError:
        # No log directory yet, so there is nothing to clean.
        return
    for name in names:
        path = os.path.join('LOGS', name)
        if os.path.isfile(path):
            os.remove(path)
def log_event():
    """Archive the loose files in ``LOGS`` into a newly created directory.

    When ``LOGS`` does not exist yet, ``LOGS`` itself is the directory that
    gets created. Otherwise a directory ``LOGS/event<id>`` is created, with
    ``<id>`` drawn at random from ``[0, 10000)`` until the path is unused.
    Every regular file directly inside ``LOGS`` is then moved into the
    created directory.
    """
    if os.path.exists('LOGS'):
        # Keep drawing random ids until one names a path that is still free.
        target = 'LOGS/event{}'.format(np.random.randint(10000))
        while os.path.exists(target):
            target = 'LOGS/event{}'.format(np.random.randint(10000))
    else:
        target = 'LOGS'
    os.mkdir(target)

    candidates = ['LOGS/{}'.format(name) for name in os.listdir('LOGS')]
    for path in candidates:
        # Directories (including the one just created) are skipped.
        if os.path.isfile(path):
            shutil.move(path, target)
def get_latest_event_id():
    """Return the numeric id of the most recently modified event directory.

    Only subdirectories of ``LOGS`` named ``event<digits>`` are considered;
    any other subdirectory is ignored rather than breaking the id parsing.

    Returns:
        int: The ``<digits>`` part of the ``LOGS/event<digits>`` directory
        with the largest modification time.

    Raises:
        FileNotFoundError: If the ``LOGS`` directory does not exist.
        ValueError: If ``LOGS`` contains no ``event<digits>`` directory.
    """
    prefix = 'event'
    candidates = []
    for name in os.listdir('LOGS'):
        suffix = name[len(prefix):]
        if not (name.startswith(prefix) and suffix.isdecimal()):
            continue
        path = os.path.join('LOGS', name)
        if os.path.isdir(path):
            candidates.append((os.path.getmtime(path), int(suffix)))
    if not candidates:
        raise ValueError("no 'LOGS/event<id>' directory found")
    # Compare on modification time only, so ties keep the first one listed.
    latest = max(candidates, key=lambda candidate: candidate[0])
    return latest[1]
########################################################
# RESULTS PROCESSING #
########################################################
def load_nas_data():
    """Unpickle and return ``nas_data.pkl`` from the latest event directory.

    The event id comes from ``get_latest_event_id()``; the file read is
    ``LOGS/event<id>/nas_data.pkl``.
    """
    pickle_path = 'LOGS/event{}/nas_data.pkl'.format(get_latest_event_id())
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)
def sort_search_data(nas_data):
    """Return the entries of ``nas_data`` ordered by descending ``entry[1]``.

    Args:
        nas_data: Indexable collection of records whose element at index 1 is
            the sort key (printed elsewhere in this file as the validation
            accuracy).

    Returns:
        list: A new list with the same records, largest key first. The input
        is not modified.
    """
    # argsort yields ascending positions; walk them backwards for descending.
    ranking = np.argsort([entry[1] for entry in nas_data])
    return [nas_data[position] for position in reversed(ranking)]
########################################################
# EVALUATION AND PLOTS #
########################################################
def get_top_n_architectures(n):
    """Print the ``n`` best records from the latest NAS run.

    Records are loaded with ``load_nas_data()`` and ranked with
    ``sort_search_data()``. For each of the first ``n`` ranked records the
    decoded form of ``record[0]`` and the value ``record[1]`` are printed.

    Args:
        n: Number of top-ranked records to print.
    """
    ranked = sort_search_data(load_nas_data())
    decoder = MLPSearchSpace(TARGET_CLASSES)
    print('Top {} Architectures:'.format(n))
    for record in ranked[:n]:
        architecture = decoder.decode_sequence(record[0])
        print('Architecture', architecture)
        print('Validation Accuracy:', record[1])
def get_nas_accuracy_plot():
    """Plot ``record[1]`` of every loaded NAS record against its position.

    The x axis is the record index in the order returned by
    ``load_nas_data()``; the figure is displayed with ``plt.show()``.
    """
    records = load_nas_data()
    y_values = [record[1] for record in records]
    x_values = np.arange(len(records))
    plt.plot(x_values, y_values)
    plt.show()
def get_accuracy_distribution():
    """Show a bar chart of how many NAS records fall in each percent bucket.

    Each record's ``x[1]`` value is multiplied by 100 and truncated to an
    integer; the chart has one bar per distinct integer, whose height is the
    number of records in that bucket. The figure is displayed with
    ``plt.show()``.
    """
    # load_nas_data() already resolves the latest event directory, so no
    # separate lookup of the event id is needed here.
    data = load_nas_data()
    # NOTE: int() truncates rather than rounds, so a product that lands just
    # below a whole number because of float rounding goes to the lower bucket.
    percents = [int(x[1] * 100.) for x in data]
    # np.unique returns the distinct buckets in ascending order with counts.
    buckets, counts = np.unique(percents, return_counts=True)
    plt.bar(buckets.tolist(), counts.tolist())
    plt.show()