-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathevaluate.py
executable file
·97 lines (72 loc) · 2.94 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import sys
import os
import cPickle as pickle
from simi_ite.logger import Logger as Log
# Enable verbose logging in simi_ite; evaluate() checks Log.VERBOSE before
# printing its cache-loading message. It is set before the remaining simi_ite
# imports, presumably so those modules see it at import time (TODO confirm).
Log.VERBOSE = True
import simi_ite.evaluation as evaluation
from simi_ite.plotting import *
def sort_by_config(results, configs, key):
    """Reorder evaluation results and their configs by one config entry.

    Computes the permutation that sorts ``[cfg[key] for cfg in configs]`` in
    ascending order (``np.argsort``) and applies it along the first axis of
    every array in ``results['train']`` and ``results['valid']``. It is also
    applied to the arrays in ``results['test']`` whose key appears in
    ``results['train']``. ``results`` is modified in place.

    Args:
        results: dict with 'train', 'valid' and 'test' sub-dicts of arrays.
            Every key of 'train' must also exist in 'valid'. The arrays are
            assumed to be aligned with ``configs`` along axis 0.
        configs: list of config dicts.
        key: the config entry to sort on.

    Returns:
        A tuple ``(results, sorted_configs)``: the same ``results`` object
        (now permuted) and a new list of configs in the matching order.
    """
    order = np.argsort(np.array([c[key] for c in configs]))

    for name in results['train']:
        # 'train' and 'valid' always share keys; 'test' may be a subset.
        for split in ('train', 'valid'):
            results[split][name] = results[split][name][order,]
        if name in results['test']:
            results['test'][name] = results['test'][name][order,]

    return results, [configs[idx] for idx in order]
def load_config(config_file):
    """Parse a ``key=value`` configuration file into a dict.

    Each line containing ``=`` is split at its FIRST ``=``. The text before
    it (whitespace-stripped) is the key. The text after it is passed to
    ``eval`` to produce the value. Lines without ``=`` are ignored. When a key
    appears more than once, the last occurrence wins.

    Args:
        config_file: Path to the configuration file.

    Returns:
        dict mapping each key to its evaluated value.

    Raises:
        Whatever ``open`` raises for an unreadable path, and whatever
        ``eval`` raises (e.g. SyntaxError, NameError) for a value that is not
        a valid expression.
    """
    with open(config_file, 'r') as f:
        # splitlines() also copes with '\r\n' line endings.
        lines = [line for line in f.read().splitlines() if '=' in line]

    cfg = {}
    for line in lines:
        # Split only once so values that themselves contain '=' (e.g. paths
        # or strings like 'run=1') are kept intact.
        key, value = line.split('=', 1)
        # SECURITY NOTE: eval executes arbitrary code; only load config files
        # from a trusted source.
        cfg[key.strip()] = eval(value)
    return cfg
def evaluate(config_file, overwrite=False, filters=None):
    """Evaluate an experiment (or load its cached evaluation) and plot it.

    Reads the experiment configuration with ``load_config``. It then either
    computes results with ``evaluation.evaluate`` and caches them in
    ``<outdir>/evaluation.npz``, or loads that cache when it exists and
    ``overwrite`` is false. Finally it dispatches to
    ``plot_evaluation_cont`` or ``plot_evaluation_bin``.

    Args:
        config_file: Path to a ``key=value`` config file. It must define
            ``outdir``, ``datadir``, ``dataform``, ``data_test`` and ``loss``.
        overwrite: If true, recompute the evaluation even when the cache
            file already exists.
        filters: Passed through unchanged to the plotting function.

    Raises:
        Exception: If ``config_file`` or the configured output directory
            does not exist.
    """
    if not os.path.isfile(config_file):
        raise Exception('Could not find config file at path: %s' % config_file)
    cfg = load_config(config_file)

    output_dir = cfg['outdir']
    if not os.path.isdir(output_dir):
        raise Exception('Could not find output at path: %s' % output_dir)

    data_train = cfg['datadir'] + '/' + cfg['dataform']
    data_test = cfg['datadir'] + '/' + cfg['data_test']
    # The binary flag is derived solely from the configured loss.
    binary = cfg['loss'] == 'log'

    # The cache is written with pickle despite its .npz suffix. The file name
    # is kept unchanged so existing caches remain usable.
    eval_path = '%s/evaluation.npz' % output_dir
    if overwrite or not os.path.isfile(eval_path):
        eval_results, configs = evaluation.evaluate(output_dir,
                                                    data_path_train=data_train,
                                                    data_path_test=data_test,
                                                    binary=binary)
        # 'with' guarantees the cache file is flushed and closed.
        with open(eval_path, 'wb') as f:
            pickle.dump((eval_results, configs), f)
    else:
        if Log.VERBOSE:
            print('Loading evaluation results_try1 from %s...' % eval_path)
        # NOTE: unpickling can execute arbitrary code; only load trusted caches.
        with open(eval_path, 'rb') as f:
            eval_results, configs = pickle.load(f)

    # Binary experiments use the binary plots unless the training data has
    # HAVE_TRUTH set, in which case the continuous plots are used. The
    # short-circuit means the training data is only loaded when binary.
    if binary and not load_data(data_train)['HAVE_TRUTH']:
        plot_evaluation_bin(eval_results, configs, output_dir, data_train, data_test, filters)
    else:
        plot_evaluation_cont(eval_results, configs, output_dir, data_train, data_test, filters)
def main(argv=None):
    """Command-line entry point for evaluate.py.

    Usage: ``python evaluate.py <config_file> [<overwrite>] [<filters>]``

    Args:
        argv: Argument vector including the program name; defaults to
            ``sys.argv``. The entries are:
            argv[1] is the config file path (required).
            argv[2] is ``'1'`` to recompute a cached evaluation; any other
            value, or its absence, keeps the cache.
            argv[3] is an optional Python expression, evaluated with ``eval``
            and passed to ``evaluate`` as ``filters``.

    Returns:
        0 on success, 2 when the config path is missing (usage error).
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        # Report misuse on stderr with a non-zero status so shell callers
        # can detect it.
        sys.stderr.write('Usage: python evaluate.py <config_file> '
                         '<overwrite (default 0)> <filters (optional)>\n')
        return 2

    config_file = argv[1]
    overwrite = len(argv) > 2 and argv[2] == '1'
    filters = None
    if len(argv) > 3:
        # SECURITY NOTE: eval runs arbitrary code taken from the command line.
        # It is kept so existing filter expressions continue to work.
        filters = eval(argv[3])
    evaluate(config_file, overwrite, filters=filters)
    return 0


if __name__ == "__main__":
    sys.exit(main())