-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
115 lines (108 loc) · 4.71 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from __future__ import print_function
from __future__ import division
import os
import sys
import time
import datetime
import argparse
import os.path as osp
import numpy as np
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torchFewShot.models.net import Model
from torchFewShot.data_manager import DataManager
from torchFewShot.losses import CrossEntropyLoss
from torchFewShot.optimizers import init_optimizer
from torchFewShot.utils.iotools import save_checkpoint, check_isfile
from torchFewShot.utils.avgmeter import AverageMeter
from torchFewShot.utils.logger import Logger
from torchFewShot.utils.torchtools import one_hot, adjust_learning_rate
from args_mini import add_arguments as add_arguments_mini
from args_tiered import add_arguments as add_arguments_tiered
from args_CBM_1_shot import add_arguments as add_arguments_CBM_1_shot
from args_CBM_5_shot import add_arguments as add_arguments_CBM_5_shot
from args_CBM_LLE_1_shot import add_arguments as add_arguments_CBM_LLE_1_shot
from args_CBM_LLE_5_shot import add_arguments as add_arguments_CBM_LLE_5_shot
# Command-line interface. The first positional argument selects one of the
# configurations registered below; the matching ``args_*`` module's
# ``add_arguments`` adds that configuration's options to its sub-parser.
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
parser_mini = subparsers.add_parser('mini')
add_arguments_mini(parser_mini)
parser_tiered = subparsers.add_parser('tiered')
add_arguments_tiered(parser_tiered)
parser_CBM_1_shot = subparsers.add_parser('CBM_1_shot')
add_arguments_CBM_1_shot(parser_CBM_1_shot)
parser_CBM_5_shot = subparsers.add_parser('CBM_5_shot')
add_arguments_CBM_5_shot(parser_CBM_5_shot)
parser_CBM_LLE_1_shot = subparsers.add_parser('CBM_LLE_1_shot')
add_arguments_CBM_LLE_1_shot(parser_CBM_LLE_1_shot)
parser_CBM_LLE_5_shot = subparsers.add_parser('CBM_LLE_5_shot')
add_arguments_CBM_LLE_5_shot(parser_CBM_LLE_5_shot)
# The top-level parser defines no options of its own, so without a
# sub-command ``args`` would lack every setting main() reads (seed,
# gpu_devices, save_dir, ...) and fail later with an AttributeError.
# Require the sub-command so argparse reports the problem up front.
subparsers.required = True
# An explicit metavar (the same "{a,b,...}" text argparse would show by
# default) lets older argparse versions format the "required" error message
# for a sub-parser group that has no ``dest``.
subparsers.metavar = '{' + ','.join(subparsers.choices) + '}'
args = parser.parse_args()
args.phase = 'test'
def main():
    """Load the best checkpoint for the configured shot setting and evaluate it.

    Reads and updates the module-level ``args``:

    - ``save_dir`` gets a ``'<nExemplars>-shot'`` suffix.
    - ``resume`` gets the same suffix plus ``'best_model.pth.tar'``.

    stdout is then redirected to ``log_test.txt`` inside ``save_dir``. Finally,
    the test loader produced by ``DataManager`` is evaluated with ``test``.
    """
    # Set the visible-GPU mask before any CUDA call made in this function.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    torch.manual_seed(args.seed)
    use_gpu = torch.cuda.is_available()

    shot_dir = str(args.nExemplars) + '-shot'
    args.save_dir = osp.join(args.save_dir, shot_dir)
    args.resume = osp.join(args.resume, shot_dir, 'best_model.pth.tar')

    sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print('Initializing image data manager')
    dm = DataManager(args, use_gpu)
    _, testloader = dm.return_dataloaders()  # the train loader is unused here

    model = Model(args)
    # Deserialize onto the CPU so a checkpoint written on a GPU machine also
    # loads when CUDA is unavailable; the model is moved to the GPU below if
    # one is present.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    print("Loaded checkpoint from '{}'".format(args.resume))
    if use_gpu:
        model = model.cuda()

    test(model, testloader, use_gpu)
def test(model, testloader, use_gpu):
    """Evaluate ``model`` on every batch of ``testloader`` and print accuracy.

    Args:
        model: network called as
            ``model(images_train, images_test, labels_train_1hot, labels_test_1hot)``.
            Its output is reshaped to ``(batch * num_test_examples, -1)`` and
            the argmax over the last dimension is taken as the prediction.
        testloader: iterable of ``(images_train, labels_train, images_test,
            labels_test)`` tuples.
            - ``images_train`` is expected to be 5-D with the batch in dim 0.
            - ``images_test`` holds the number of test examples per row in dim 1.
        use_gpu: when true, the images and one-hot labels are moved to CUDA.

    Returns:
        ``accs.avg``: the per-batch accuracies averaged with each batch weighted
        by its number of test examples.

    Side effect:
        Prints that accuracy together with a 95% confidence half-width. The
        half-width is computed over the per-row (per-episode) accuracies.
    """
    accs = AverageMeter()
    episode_accuracies = []  # one 1-D array of per-row accuracies per batch

    model.eval()
    with torch.no_grad():
        for images_train, labels_train, images_test, labels_test in testloader:
            batch_size = images_train.size(0)
            num_test_examples = images_test.size(1)

            # Build the one-hot labels from the original label tensors and only
            # move them to the GPU when one is in use, so CPU runs also work.
            labels_train_1hot = one_hot(labels_train)
            labels_test_1hot = one_hot(labels_test)
            if use_gpu:
                images_train = images_train.cuda()
                images_test = images_test.cuda()
                labels_train_1hot = labels_train_1hot.cuda()
                labels_test_1hot = labels_test_1hot.cuda()

            cls_scores = model(images_train, images_test,
                               labels_train_1hot, labels_test_1hot)
            cls_scores = cls_scores.view(batch_size * num_test_examples, -1)
            labels_test = labels_test.view(batch_size * num_test_examples)

            _, preds = torch.max(cls_scores.cpu(), 1)
            correct = (preds == labels_test.cpu()).float()
            accs.update(correct.mean().item(), labels_test.size(0))

            # Accuracy of each row (episode) in this batch.
            per_row = correct.view(batch_size, num_test_examples).numpy()
            episode_accuracies.append(per_row.mean(axis=1))

    accuracy = accs.avg
    # Concatenate rather than np.array(...): the last batch may be smaller,
    # and a ragged list would not form a regular 2-D array.
    if episode_accuracies:
        per_episode = np.concatenate(episode_accuracies)
    else:
        per_episode = np.zeros(0)
    # 95% confidence half-width of the mean (normal approximation, z = 1.96),
    # using the number of episodes actually evaluated.
    ci95 = 1.96 * np.std(per_episode) / np.sqrt(len(per_episode))
    print('Accuracy: {:.2%}, ci95: {:.2%}'.format(accuracy, ci95))
    return accuracy
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()