forked from kjunelee/MetaOptNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
156 lines (129 loc) · 6.17 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- coding: utf-8 -*-
import argparse
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from tqdm import tqdm
from models.protonet_embedding import ProtoNetEmbedding
from models.R2D2_embedding import R2D2Embedding
from models.ResNet12_embedding import resnet12
from models.classification_heads import ClassificationHead
from utils import pprint, set_gpu, Timer, count_accuracy, log
import numpy as np
import os
def get_model(options):
    """Build the embedding network and classification head selected in *options*.

    Args:
        options: parsed argparse namespace with ``network``, ``head`` and
            ``dataset`` attributes.

    Returns:
        Tuple ``(network, cls_head)`` of CUDA modules.

    Raises:
        ValueError: if ``options.network`` or ``options.head`` is not recognized.
    """
    # Choose the embedding network
    if options.network == 'ProtoNet':
        network = ProtoNetEmbedding().cuda()
    elif options.network == 'R2D2':
        network = R2D2Embedding().cuda()
    elif options.network == 'ResNet':
        if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet':
            # Larger inputs: bigger dropblock, and spread across 4 GPUs.
            network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=5).cuda()
            network = torch.nn.DataParallel(network, device_ids=[0, 1, 2, 3])
        else:
            network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=2).cuda()
    else:
        # BUG FIX: was print + assert(False); asserts are stripped under
        # `python -O`, silently continuing with `network` undefined.
        raise ValueError("Cannot recognize the network type: {}".format(options.network))

    # Choose the classification head.
    # BUG FIX: this section previously read the module-level global `opt`
    # instead of the `options` parameter, so the function only worked when
    # invoked from this script's __main__ block.
    if options.head == 'ProtoNet':
        cls_head = ClassificationHead(base_learner='ProtoNet').cuda()
    elif options.head == 'Ridge':
        cls_head = ClassificationHead(base_learner='Ridge').cuda()
    elif options.head == 'R2D2':
        cls_head = ClassificationHead(base_learner='R2D2').cuda()
    elif options.head == 'SVM':
        cls_head = ClassificationHead(base_learner='SVM-CS').cuda()
    else:
        raise ValueError("Cannot recognize the classification head type: {}".format(options.head))

    return (network, cls_head)
def get_dataset(options):
    """Return the test split and few-shot dataloader class for *options.dataset*.

    Imports are done lazily inside each branch so that only the selected
    dataset's dependencies are loaded.

    Args:
        options: parsed argparse namespace with a ``dataset`` attribute.

    Returns:
        Tuple ``(dataset_test, data_loader)`` where ``data_loader`` is the
        ``FewShotDataloader`` class (not an instance).

    Raises:
        ValueError: if ``options.dataset`` is not recognized.
    """
    if options.dataset == 'miniImageNet':
        from data.mini_imagenet import MiniImageNet, FewShotDataloader
        dataset_test = MiniImageNet(phase='test')
        data_loader = FewShotDataloader
    elif options.dataset == 'tieredImageNet':
        from data.tiered_imagenet import tieredImageNet, FewShotDataloader
        dataset_test = tieredImageNet(phase='test')
        data_loader = FewShotDataloader
    elif options.dataset == 'CIFAR_FS':
        from data.CIFAR_FS import CIFAR_FS, FewShotDataloader
        dataset_test = CIFAR_FS(phase='test')
        data_loader = FewShotDataloader
    elif options.dataset == 'FC100':
        from data.FC100 import FC100, FewShotDataloader
        dataset_test = FC100(phase='test')
        data_loader = FewShotDataloader
    else:
        # BUG FIX: was print + assert(False); asserts are stripped under
        # `python -O`, which would let execution fall through and crash later.
        raise ValueError("Cannot recognize the dataset type: {}".format(options.dataset))

    return (dataset_test, data_loader)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default='0')
    parser.add_argument('--load', default='./experiments/exp_1/best_model.pth',
                        help='path of the checkpoint file')
    parser.add_argument('--episode', type=int, default=1000,
                        help='number of episodes to test')
    parser.add_argument('--way', type=int, default=5,
                        help='number of classes in one test episode')
    parser.add_argument('--shot', type=int, default=1,
                        help='number of support examples per training class')
    parser.add_argument('--query', type=int, default=15,
                        help='number of query examples per training class')
    parser.add_argument('--network', type=str, default='ProtoNet',
                        help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
    # BUG FIX: the help strings of --head and --dataset were swapped.
    parser.add_argument('--head', type=str, default='ProtoNet',
                        help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
    parser.add_argument('--dataset', type=str, default='miniImageNet',
                        help='choose which dataset to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')

    opt = parser.parse_args()
    (dataset_test, data_loader) = get_dataset(opt)

    # One batch == one few-shot episode.
    dloader_test = data_loader(
        dataset=dataset_test,
        nKnovel=opt.way,
        nKbase=0,
        nExemplars=opt.shot,            # num training examples per novel category
        nTestNovel=opt.query * opt.way, # num test examples for all the novel categories
        nTestBase=0,                    # num test examples for all the base categories
        batch_size=1,
        num_workers=1,
        epoch_size=opt.episode,         # num of batches per epoch
    )

    set_gpu(opt.gpu)

    # Log next to the checkpoint being evaluated.
    log_file_path = os.path.join(os.path.dirname(opt.load), "test_log.txt")
    log(log_file_path, str(vars(opt)))

    # Define the models
    (embedding_net, cls_head) = get_model(opt)

    # Load saved model checkpoints and switch to inference mode.
    saved_models = torch.load(opt.load)
    embedding_net.load_state_dict(saved_models['embedding'])
    embedding_net.eval()
    cls_head.load_state_dict(saved_models['head'])
    cls_head.eval()

    # Evaluate on test set
    test_accuracies = []
    # Evaluation only: no_grad avoids building autograd graphs and saves memory.
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dloader_test()), 1):
            data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]

            n_support = opt.way * opt.shot
            n_query = opt.way * opt.query

            # Flatten episode dim, embed, then restore (1, n, feat) layout.
            emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
            emb_support = emb_support.reshape(1, n_support, -1)

            emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
            emb_query = emb_query.reshape(1, n_query, -1)

            if opt.head == 'SVM':
                # SVM head takes an iteration cap for its QP solver.
                logits = cls_head(emb_query, emb_support, labels_support, opt.way, opt.shot, maxIter=3)
            else:
                logits = cls_head(emb_query, emb_support, labels_support, opt.way, opt.shot)

            acc = count_accuracy(logits.reshape(-1, opt.way), labels_query.reshape(-1))
            test_accuracies.append(acc.item())

            if i % 50 == 0:
                # Stats are only reported here, so compute them here instead of
                # every episode (was accidentally O(n^2) over the whole run).
                accs = np.array(test_accuracies)
                avg = accs.mean()
                std = accs.std()
                # BUG FIX: enumerate(..., 1) makes i the number of completed
                # episodes, so the CI denominator is sqrt(i), not sqrt(i + 1).
                ci95 = 1.96 * std / np.sqrt(i)
                print('Episode [{}/{}]:\t\t\tAccuracy: {:.2f} ± {:.2f} % ({:.2f} %)'\
                      .format(i, opt.episode, avg, ci95, acc))