-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
86 lines (65 loc) · 2.73 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from data2 import TrainTransforms, DinoTransforms
from data_voc import VOCSegmentation
import torch
import torch.nn as nn
import os
from torchvision import transforms
from params import SlotAttentionParams
from model2 import PatchModel
from cluster import UnsupervisedMaskIoUMetric, AverageBestOverlapMetric, BboxCorLocMetric
# Shared hyper-parameter container; its fields are read when Model2 builds the
# PatchModel and when the VOC validation dataset is constructed.
params: SlotAttentionParams = SlotAttentionParams()
class Model2(nn.Module):
    """Wrap a ``PatchModel`` so its weights live under the ``model.`` prefix.

    The network is stored as ``self.model``, so this wrapper's state_dict keys
    are ``model.<...>``. A checkpoint loaded into it must use the same prefix
    (presumably the training module also kept the network as ``self.model`` --
    confirm against the training code).
    """

    def __init__(self):
        super().__init__()
        # Slot count and cache behaviour come from the shared params object.
        # Iteration count and slot width are hard-coded; they presumably must
        # match the trained run for load_state_dict to succeed.
        patch_cfg = dict(
            num_slots=params.num_slots,
            num_iterations=3,
            empty_cache=params.empty_cache,
            slot_size=384,
        )
        self.model = PatchModel(**patch_cfg)

    def forward(self, x, y):
        """Delegate both inputs unchanged to the wrapped ``PatchModel``.

        In this script it is called with the image batch and the DINO batch.
        """
        outputs = self.model(x, y)
        return outputs
# Build the wrapper, then restore the newest checkpoint of one Lightning run.
model_key = Model2()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Run id selecting lightning_logs/version_<id>. It gets its own name so that it
# is not confused with the dataset object created later in the script.
ckpt_version = 43047446
ckpt_dir = os.path.join(
    '/home/rishavp/projects/def-mpederso/rishavp/object_unsup/lightning_logs',
    f'version_{ckpt_version}',
    'checkpoints',
)

# os.listdir() returns entries in arbitrary order, so indexing its result does
# not reliably yield the latest checkpoint. Pick the most recently modified
# .ckpt file instead, and fail loudly if there is none.
ckpt_files = [
    os.path.join(ckpt_dir, name)
    for name in os.listdir(ckpt_dir)
    if name.endswith('.ckpt')
]
if not ckpt_files:
    raise FileNotFoundError(f'no .ckpt files found in {ckpt_dir}')
addr = max(ckpt_files, key=os.path.getmtime)

# map_location lets a checkpoint saved from GPU load on a CPU-only host.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
ckpt_key = torch.load(addr, map_location=device)
# Weights sit under 'state_dict'; keys must carry the wrapper's 'model.' prefix.
model_key.load_state_dict(ckpt_key['state_dict'])
model_key.eval()
model_key.to(device)
# Inverse of an ImageNet-style Normalize(mean, std), used to recover displayable
# images. Normalize computes (x - m) / s, so Normalize(-m/s, 1/s) maps the
# normalised value back to x * s + m.
# NOTE(review): assumes data2.TrainTransforms normalises with these ImageNet
# statistics -- confirm. invTrans is not used elsewhere in the visible script.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
# Blue-channel std is 0.225; a 0.255 here would mis-scale that channel.
_IMAGENET_STD = (0.229, 0.224, 0.225)
invTrans = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(_IMAGENET_MEAN, _IMAGENET_STD)],
        std=[1 / s for s in _IMAGENET_STD],
    ),
])
# Pascal VOC 2012 validation split. The image goes through TrainTransforms and
# a second view through DinoTransforms.
# NOTE(review): TrainTransforms is applied to the val split -- confirm it adds
# no random augmentation, otherwise evaluation is not deterministic.
# NOTE(review): meaning of evo=True is not visible here -- confirm.
val_dataset = VOCSegmentation(
    root=params.data_root,
    year='2012',
    image_set='val',
    transform=TrainTransforms().transforms,
    dino_transform=DinoTransforms().transforms,
    evo=True,
)
# Evaluation only accumulates metrics, so a fixed order makes runs reproducible.
# This assumes the metrics aggregate independently of order.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False, num_workers=1
)
# Object-discovery metrics. Each call updates internal state that is read back
# later with .compute(). They are placed on the same ``device`` as the model, so
# predictions and labels need no extra transfers when passed in.
corloc = BboxCorLocMetric().to(device)
mbo = AverageBestOverlapMetric().to(device)
iou = UnsupervisedMaskIoUMetric(ignore_background=False).to(device)
with torch.no_grad():
    # eval() is recursive and idempotent, so one call before the loop replaces
    # the per-batch eval() / training=False / empty_cache() calls.
    model_key.eval()
    for img, dino, labels in val_dataloader:
        # Use the module-level ``device`` (where the model and metrics live)
        # rather than params.gpus, which could disagree with it.
        img, dino, labels = img.to(device), dino.to(device), labels.to(device)
        # Only the soft per-slot mask image is needed for evaluation.
        _, _, mask_as_image, _, _ = model_key(img, dino)
        # Hard assignment: each pixel goes to its highest-scoring slot, then is
        # one-hot encoded back to channel-first layout. The permute(0, 3, 1, 2)
        # implies the argmax result is rank 3, i.e. (batch, H, W).
        # NOTE(review): num_classes=8 is hard-coded; presumably it should match
        # params.num_slots -- confirm.
        pred_masks = torch.nn.functional.one_hot(
            torch.argmax(mask_as_image, dim=1), num_classes=8
        ).permute(0, 3, 1, 2)
        # Calls accumulate metric state; per-batch return values are not needed.
        corloc(pred_masks, labels)
        mbo(pred_masks, labels)
        iou(pred_masks, labels)

# Report the metrics accumulated over the whole validation split.
print(
    f"iou: {float(iou.compute()):.4f}"
    f" corloc: {float(corloc.compute()):.4f}"
    f" mbo: {float(mbo.compute()):.4f}"
)