-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
executable file
·72 lines (51 loc) · 1.74 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python
"""Runs the testing protocol for a given dataset, model, and NMS methods.
Can either run as a single process or distributed using celery."""
import celery
from nyc3dcars import SESSION, Photo, Model
import scores
import logging
import argparse
from detect import detect
from geo_rescore import geo_rescore
from nms import nms
def test(model, remote, methods, dataset_id):
    """Executes the testing protocol.

    For every photo in ``dataset_id`` flagged ``test=True``, build one Celery
    chain -- ``detect`` first, then one ``geo_rescore`` step per method, then
    one ``nms`` step per method -- and run it.

    Args:
        model: Filename of the model to test; it must match exactly one
            ``Model.filename`` row.
        remote: If true, dispatch each chain with ``apply_async()``;
            otherwise run it in this process with ``apply()``.
        methods: Iterable of method names passed to ``geo_rescore`` and
            ``nms``. It is materialized into a list once up front, so
            one-shot iterables such as generators work correctly.
        dataset_id: Dataset whose test photos are processed.

    Raises:
        Any exception raised by the database queries or by running or
        dispatching the chains. This includes the error ``.one()`` raises
        when ``model`` does not match exactly one row. The session is rolled
        back and the exception is re-raised.
    """
    # Both signature lists below walk ``methods`` once per photo. Materialize
    # it so a generator is not exhausted during the first photo.
    methods = list(methods)

    session = SESSION()
    try:
        # Existence check only: .one() raises unless exactly one model row
        # has this filename, so a bad ``model`` fails before any task runs.
        session.query(Model) \
            .filter_by(filename=model) \
            .one()

        test_set = session.query(Photo) \
            .filter_by(test=True, dataset_id=dataset_id)

        for photo in test_set:
            logging.info('%s', photo.id)
            # In a Celery chain, each task receives the previous task's result
            # as its first positional argument, ahead of the arguments bound
            # here.
            celery_list = (
                [detect.s(photo.id, model)]
                + [geo_rescore.s(model, method) for method in methods]
                + [nms.s(model, method) for method in methods]
            )
            celery_task = celery.chain(celery_list)
            if remote:
                celery_task.apply_async()
            else:
                # NOTE(review): apply() runs the chain eagerly and returns a
                # result object that is discarded here. A failure in the
                # final step may therefore not raise, depending on Celery's
                # eager-propagation setting -- confirm this is intended.
                celery_task.apply()
    except BaseException:
        # Same as the original bare ``except:``: roll back on anything
        # (including KeyboardInterrupt), then let the exception propagate.
        session.rollback()
        raise
    finally:
        session.close()
if __name__ == '__main__':
    # Script entry point: parse the command line and hand every option
    # straight to test(); each flag maps onto one of its parameters.
    logging.basicConfig(level=logging.INFO)

    ARG_PARSER = argparse.ArgumentParser()
    ARG_PARSER.add_argument('--model', required=True)
    # With no --methods given, every method name known to ``scores`` is used.
    ARG_PARSER.add_argument(
        '--methods',
        nargs='+',
        default=scores.METHODS.keys(),
    )
    ARG_PARSER.add_argument('--dataset-id', type=int, required=True)
    ARG_PARSER.add_argument('--remote', action='store_true')
    CLI_ARGS = ARG_PARSER.parse_args()

    test(
        dataset_id=CLI_ARGS.dataset_id,
        methods=CLI_ARGS.methods,
        model=CLI_ARGS.model,
        remote=CLI_ARGS.remote,
    )