-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy path classifier_test_accuracy.py
72 lines (48 loc) · 2.12 KB
/
classifier_test_accuracy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Get classification metrics for a trained classifier model
# Authors:
# Christian F. Baumgartner ([email protected])
import numpy as np
import os
import glob
from importlib.machinery import SourceFileLoader
import argparse
from sklearn.metrics import f1_score, classification_report, confusion_matrix
import config.system as sys_config
from classifier.model_classifier import classifier
import logging
# Configure the root logger once at import time: INFO level, each message
# prefixed with its timestamp.
_LOG_FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
def _get_data_loader(exp_config):
    """Return the dataset class selected by ``exp_config.data_identifier``.

    The dataset modules are imported lazily, so only the loader the
    experiment actually uses has to be importable.

    Raises:
        ValueError: if ``data_identifier`` is neither 'synthetic' nor 'adni'.
    """
    if exp_config.data_identifier == 'synthetic':
        from data.synthetic_data import synthetic_data as data_loader
    elif exp_config.data_identifier == 'adni':
        from data.adni_data import adni_data as data_loader
    else:
        raise ValueError('Unknown data identifier: %s' % exp_config.data_identifier)
    return data_loader


def main(model_path, exp_config, batch_size=32):
    """Evaluate a trained classifier on the test split and print metrics.

    Builds the dataset chosen by ``exp_config.data_identifier``, constructs the
    classifier, restores its weights with ``load_weights(model_path,
    type='latest')``, predicts every batch yielded by
    ``data.test.iterate_batches``, and prints the raw prediction list, the raw
    ground-truth list, sklearn's classification report and its confusion matrix.

    Args:
        model_path: Experiment directory passed to ``load_weights``.
        exp_config: Loaded experiment configuration module; must provide
            ``data_identifier`` plus whatever the data loader and classifier read.
        batch_size: Batch size passed to ``iterate_batches`` (default 32).

    Raises:
        ValueError: if the data identifier is unknown, or if the test split
            yields no samples (there is nothing to report on).
    """
    data = _get_data_loader(exp_config)(exp_config)

    # Build the classifier and restore its 'latest' checkpoint from model_path.
    classifier_model = classifier(exp_config=exp_config, data=data)
    classifier_model.load_weights(model_path, type='latest')

    # Single pass over the test split, collecting predictions and labels.
    pred_list = []
    gt_list = []
    for x, y in data.test.iterate_batches(batch_size):
        # Element 0 of predict()'s output is taken as the predicted labels
        # (NOTE(review): confirm against classifier.predict's return value).
        pred_list.extend(classifier_model.predict(x)[0])
        gt_list.extend(y)

    if not gt_list:
        raise ValueError('Test set yielded no samples; nothing to evaluate.')

    print(pred_list)
    print(gt_list)

    # Convert once and reuse the arrays for both reports.
    gt = np.asarray(gt_list)
    pred = np.asarray(pred_list)
    print(classification_report(gt, pred))
    print(confusion_matrix(gt, pred))
def load_exp_config(model_path):
    """Import the experiment configuration module stored in ``model_path``.

    The alphabetically first ``*.py`` file in ``model_path`` is executed as a
    module named after its file stem and registered in ``sys.modules``.

    Args:
        model_path: Experiment directory containing the config ``.py`` file.

    Returns:
        The executed configuration module.

    Raises:
        FileNotFoundError: if ``model_path`` contains no ``.py`` file.
    """
    import importlib.util
    import sys

    # Sort so the choice is deterministic; raw glob order is filesystem-dependent.
    config_files = sorted(glob.glob(os.path.join(model_path, '*.py')))
    if not config_files:
        raise FileNotFoundError('No experiment config (*.py) found in %s' % model_path)
    if len(config_files) > 1:
        logging.warning('Multiple config files in %s, using %s', model_path, config_files[0])
    config_file = config_files[0]

    # splitext removes exactly the '.py' suffix; str.rstrip('.py') would strip
    # any trailing '.', 'p' or 'y' characters (e.g. 'happy.py' -> 'ha').
    config_module = os.path.splitext(os.path.basename(config_file))[0]

    # Documented replacement for the deprecated SourceFileLoader.load_module();
    # registering in sys.modules mirrors what load_module() did.
    spec = importlib.util.spec_from_file_location(config_module, config_file)
    exp_config = importlib.util.module_from_spec(spec)
    sys.modules[config_module] = exp_config
    spec.loader.exec_module(exp_config)
    return exp_config


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Script for a simple test loop evaluating a network on the test dataset")
    parser.add_argument("EXP_PATH", type=str,
                        help="Path to the experiment folder, relative to the project root "
                             "(absolute paths are used as-is)")
    args = parser.parse_args()

    # The experiment path is resolved against the configured project root.
    base_path = sys_config.project_root
    model_path = os.path.join(base_path, args.EXP_PATH)
    exp_config = load_exp_config(model_path)
    main(model_path, exp_config=exp_config)