-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
85 lines (73 loc) · 2.52 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from utils.dataset import *
from config import *
from tqdm import tqdm
from utils.vision import show_box_masks, parse_rgb_allImage
from loss import *
from utils.vision import save_loss_rate
def _ensure_box_cache(image_set):
    """Return the BoxDetect cache directory for *image_set*, building it if missing.

    The expected location is ``DATA_ROOT/box.cache/<image_set>``. When that path
    does not exist, ``BoxDetect.prepare_voc_data`` is called and its return value
    is used as the directory instead. The resolved directory is printed.
    """
    cache_dir = os.path.join(DATA_ROOT, 'box.cache', image_set)
    if not os.path.exists(cache_dir):
        cache_dir = BoxDetect.prepare_voc_data(DATA_ROOT, image_set=image_set)
    print(cache_dir)
    return cache_dir


def test_dataset():
    """Visually inspect a few samples from the BoxDetect validation split.

    Makes sure the 'val' and 'trainval' box caches exist under DATA_ROOT
    (building them via BoxDetect.prepare_voc_data when missing). It then loads
    the first batch of the validation set and draws up to ``max_samples`` of
    its samples with show_box_masks.
    """
    max_samples = 20

    val_dir = _ensure_box_cache('val')
    _ensure_box_cache('trainval')

    # Load from the directory that was just checked/prepared rather than a
    # hard-coded relative path, so a non-default DATA_ROOT is respected.
    val_data_set = BoxDetect(val_dir)
    val_data_loader = DataLoader(
        val_data_set,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        # DataLoader raises ValueError for persistent_workers=True with num_workers == 0.
        persistent_workers=NUM_WORKERS > 0,
        drop_last=True,
    )
    for image, target in tqdm(val_data_loader, desc='Validate', leave=False):
        # Clamp to the real batch size: indexing past it raises IndexError
        # whenever BATCH_SIZE < max_samples.
        for i in range(min(max_samples, image.shape[0])):
            show_box_masks(image[i, :, :, :], target[i, :, :, :])
        break  # only the first batch is needed for a visual check
def test_loss():
    """Print BoxLoss for a prediction that exactly matches its target.

    A single 13-value vector is reshaped to (1, 1, 1, 13) and used for both the
    prediction and the target. The printed loss is presumably zero; TODO confirm
    against BoxLoss' definition.
    """
    loss = BoxLoss()
    # Renamed from `input` to avoid shadowing the builtin.
    values = [[0, 0, 0, 4, 5, 6, 6, 1, 5, 5, 8, 8, 1]]
    # Build float tensors: torch.tensor on int literals yields int64, which
    # many loss ops (e.g. mse_loss) do not support.
    pred = torch.tensor(values, dtype=torch.float32).reshape(1, 1, 1, 13)
    target = torch.tensor(values, dtype=torch.float32).reshape(1, 1, 1, 13)
    print(loss(pred, target))
def test_mask():
    """Visually inspect a few samples from the MaskDetect validation split.

    Assumes the cache at './data/box-mask.cache/val' already exists. It was
    presumably built with MaskDetect.prepare_voc_data(DATA_ROOT, image_set='val')
    (and image_set='trainval' for the training split). The first batch is loaded,
    and up to ``max_samples`` samples are drawn with their boxes and masks in red.
    """
    max_samples = 12

    val_data_set = MaskDetect('./data/box-mask.cache/val')
    val_data_loader = DataLoader(
        val_data_set,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        # DataLoader raises ValueError for persistent_workers=True with num_workers == 0.
        persistent_workers=NUM_WORKERS > 0,
        drop_last=True,
    )
    for image, target, mask in tqdm(val_data_loader, desc='Validate', leave=False):
        # Clamp to the real batch size: indexing past it raises IndexError
        # whenever BATCH_SIZE < max_samples.
        for i in range(min(max_samples, image.shape[0])):
            show_box_masks(
                image[i, :, :, :],
                target[i, :, :, :],
                mask[i, :, :, :],
                color=(1, 0, 0),
            )
        break  # only the first batch is needed for a visual check
def exp_rgbs():
    """Return a (3, 10, 10) int64 tensor whose k-th plane is filled with the k-th value.

    The values are the fixed triple (11, 12, 3). As a side effect, the number
    of channels is printed.
    """
    channel_values = torch.tensor([11, 12, 3])
    num_channels = channel_values.shape[-1]
    print(num_channels)
    # (C,) -> (C, 1, 1), then tile each singleton plane out to 10x10.
    planes = channel_values.reshape(num_channels, 1, 1)
    return planes.repeat(1, 10, 10)
if __name__ == "__main__":
    test_mask()

    # Scratch check: tiling a (3, 5) tensor along a new leading axis with
    # repeat, then reading a fixed (row, col) across that axis.
    rows = [
        [1, 2, 3, 4, 5],
        [2, 3, 4, 9, 8],
        [4, 4, 4, 4, 4],
    ]
    grid = torch.tensor(rows, dtype=torch.get_default_dtype())
    print(grid)
    print(grid.shape)
    tiled = grid.unsqueeze(0).repeat((8, 1, 1))
    for row, col in ((0, 0), (0, 4), (2, 0)):
        print(tiled[:, row, col])