forked from sh8/RePOSE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
92 lines (75 loc) · 2.74 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from lib.config import cfg, args
import numpy as np
import os
def run_evaluate():
    """Run inference over the test split, then compute and summarize ADD(-S) metrics.

    Phase 1 feeds every batch of the test data loader through the network
    (under ``torch.inference_mode``) and collects the outputs. Timing is
    accumulated from the ``elapsed_time`` value the network returns, but only
    for batches it flags as valid. Phase 2 iterates the data loader a second
    time and hands each collected output, together with its batch, to the
    evaluator. Finally the average FPS and the evaluator summary are printed.

    Raises:
        RuntimeError: if the data loader yields a different number of batches
            in the evaluation pass than in the inference pass (outputs and
            batches would otherwise be silently mis-paired or dropped).
    """
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    # Batch entries passed positionally to the network, in call order.
    input_keys = ('inp', 'K', 'x_ini', 'bbox', 'x2s', 'x4s', 'x8s', 'xfc')

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()
    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)

    outputs = []
    tot_elapsed_time = 0.0
    tot_valid_cnt = 0
    print('Start inference...')
    with torch.inference_mode():
        for batch in tqdm.tqdm(data_loader):
            inputs = [batch[key].cuda() for key in input_keys]
            output, elapsed_time, is_valid = network(*inputs)
            if is_valid:
                tot_elapsed_time += elapsed_time
                tot_valid_cnt += 1
            outputs.append(output)

    print('Start computing ADD(-S) metrics...')
    # NOTE(review): this second pass assumes the test loader yields the same
    # batches in the same order as the first pass (no shuffling/randomness) —
    # confirm against make_data_loader.
    num_evaluated = 0
    for i, batch in enumerate(tqdm.tqdm(data_loader)):
        if i >= len(outputs):
            raise RuntimeError(
                'Data loader yielded more batches during evaluation than '
                f'during inference ({len(outputs)}).')
        evaluator.evaluate(outputs[i], batch)
        num_evaluated = i + 1
    if num_evaluated != len(outputs):
        raise RuntimeError(
            f'Data loader yielded {num_evaluated} batches during evaluation '
            f'but {len(outputs)} during inference.')

    # FPS = 1000 / mean elapsed time, so elapsed_time is presumably in
    # milliseconds — TODO confirm against the network. Guard against no valid
    # (or zero-duration) batches, which previously raised ZeroDivisionError
    # and skipped the metric summary.
    if tot_valid_cnt > 0 and tot_elapsed_time > 0:
        print('Average FPS:', 1000 / (tot_elapsed_time / tot_valid_cnt))
    else:
        print('Average FPS: n/a (no valid timed batches)')
    evaluator.summarize()
def run_visualize():
    """Visualize network predictions for every batch of the test split.

    Builds the network on the GPU and loads the checkpoint for
    ``cfg.test.epoch`` (passing ``cfg.resume`` through). Then, for each test
    batch, runs a forward pass under ``torch.inference_mode`` and passes the
    first returned value, the untouched batch, and the batch index to the
    configured visualizer.
    """
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    net = make_network(cfg).cuda()
    load_network(net, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    net.eval()

    loader = make_data_loader(cfg, is_train=False)
    vis = make_visualizer(cfg)

    # Batch entries fed positionally to the network, in call order.
    input_names = ('inp', 'K', 'x_ini', 'bbox', 'x2s', 'x4s', 'x8s', 'xfc')
    for batch_idx, batch in enumerate(tqdm.tqdm(loader)):
        # Device transfer happens outside inference_mode, as before.
        gpu_inputs = tuple(batch[name].cuda() for name in input_names)
        with torch.inference_mode():
            prediction, _, _ = net(*gpu_inputs)
        vis.visualize(prediction, batch, batch_idx)
def run_linemod():
    """Invoke ``linemod_to_coco`` from ``lib.datasets.linemod`` with the global ``cfg``.

    Judging by the helper's name this converts LINEMOD data to COCO format;
    the conversion itself lives outside this file and is not verified here.
    """
    from lib.datasets.linemod import linemod_to_coco as converter
    converter.linemod_to_coco(cfg)
if __name__ == '__main__':
    # Dispatch to the module-level function named ``run_<args.type>``.
    # An unknown type used to surface as a bare KeyError; report the valid
    # choices instead (SystemExit with a message prints to stderr, exit code 1).
    runner = globals().get('run_' + args.type)
    if not callable(runner):
        available = sorted(
            name[len('run_'):]
            for name, obj in list(globals().items())
            if name.startswith('run_') and callable(obj))
        raise SystemExit(
            f"Unknown run type {args.type!r}; expected one of: "
            f"{', '.join(available)}")
    runner()