-
Notifications
You must be signed in to change notification settings - Fork 82
/
test_classification.py
81 lines (68 loc) · 3.47 KB
/
test_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
import torch
from utilities.utils import model_parameters, compute_flops
from utilities.train_eval_classification import validate
import os
from data_loader.classification.imagenet import val_loader as loader
from utilities.print_utils import *
# ---------------------------------------------------------------------------
# Module metadata
# ---------------------------------------------------------------------------
__author__ = __maintainer__ = "Sachin Mehta"
__license__ = "MIT"
def main(args):
    """Evaluate a pretrained classification network on the validation set.

    Builds the network selected by ``args.model``, reports its FLOPs and
    parameter count, loads its weights (from ``args.weights`` or, when that
    is empty, from the built-in weight map keyed by ``'<model>_<s>'``), and
    runs ``validate`` over the loader returned by ``loader(args)``.

    Args:
        args: parsed command-line namespace. The fields read here are
            ``model``, ``s`` and ``weights``; ``args`` is also forwarded to
            the model constructor and to the data loader.

    Raises:
        NotImplementedError: if ``args.model`` is not one of the supported
            models.
        ValueError: if no weight file was given and none is registered for
            ``'<model>_<s>'``.
    """
    model = _build_model(args)

    num_params = model_parameters(model)
    flops = compute_flops(model)
    print_info_message('FLOPs: {:.2f} million'.format(flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    _resolve_weights(args)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus >= 1 else 'cpu'
    # NOTE(review): torch.load unpickles the file, so only point --weights at
    # trusted checkpoints.
    weight_dict = torch.load(args.weights, map_location=torch.device(device))
    # Load into the bare network *before* the DataParallel wrap below, which
    # nests it under ``.module`` (presumably the checkpoints were saved from an
    # unwrapped model — confirm).
    model.load_state_dict(weight_dict)

    if num_gpus >= 1:
        model = torch.nn.DataParallel(model)
        model = model.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Data loading code
    val_loader = loader(args)
    validate(val_loader, model, criteria=None, device=device)


def _build_model(args):
    """Instantiate the classification network named by ``args.model``.

    Model modules are imported lazily so only the selected architecture is
    loaded.

    Raises:
        NotImplementedError: if ``args.model`` is not wired up here.
    """
    if args.model == 'dicenet':
        from model.classification import dicenet as net
        return net.CNNModel(args)
    if args.model == 'espnetv2':
        from model.classification import espnetv2 as net
        return net.EESPNet(args)
    if args.model == 'shufflenetv2':
        from model.classification import shufflenetv2 as net
        return net.CNNModel(args)
    # Previously the exception was only constructed (never raised) and the
    # process then called the site-only ``exit()``; raise it so the caller
    # sees why evaluation stopped.
    raise NotImplementedError('Model {} not yet implemented'.format(args.model))


def _resolve_weights(args):
    """Fill ``args.weights`` from the ImageNet weight map when it is empty.

    Raises:
        ValueError: if no entry exists for ``'<model>_<s>'``.
    """
    if args.weights:
        return
    print_info_message('Grabbing location of the ImageNet weights from the weight dictionary')
    from model.weight_locations.classification import model_weight_map

    # argparse does not apply ``type=float`` to a non-string default, so a
    # defaulted ``--s`` can arrive as int 1 (key '<model>_1') while ``--s 1``
    # arrives as 1.0 (key '<model>_1.0'). Normalise so both spell one key.
    weight_file_key = '{}_{}'.format(args.model, float(args.s))
    # Explicit check instead of ``assert``: asserts are stripped under -O.
    if weight_file_key not in model_weight_map:
        raise ValueError('{} does not exist'.format(weight_file_key))
    args.weights = model_weight_map[weight_file_key]
if __name__ == '__main__':
    from commons.general_details import classification_models, classification_datasets

    parser = argparse.ArgumentParser(description='Testing efficient networks')

    # Data / loader options
    parser.add_argument('--workers', default=4, type=int, help='number of data loading workers (default: 4)')
    parser.add_argument('--data', default='', help='path to dataset')
    parser.add_argument('--dataset', default='imagenet', help='Name of the dataset', choices=classification_datasets)
    parser.add_argument('--batch-size', default=512, type=int, help='mini-batch size (default: 512)')
    parser.add_argument('--num-classes', default=1000, type=int, help='# of classes in the dataset')
    # argparse applies ``type`` only to *string* defaults, so ``default=1``
    # stayed the int 1 while ``--s 1`` parsed to 1.0. main() builds the
    # weight-map key from args.s ('<model>_<s>'), so keep the default a float
    # to match the value produced when the flag is passed explicitly.
    parser.add_argument('--s', default=1.0, type=float, help='Width scaling factor')
    parser.add_argument('--weights', type=str, default='', help='weight file')
    parser.add_argument('--inpSize', default=224, type=int, help='Input size')

    # Model selection
    parser.add_argument('--model', default='dicenet', choices=classification_models, help='Which model?')
    parser.add_argument('--model-width', default=224, type=int, help='Model width')
    parser.add_argument('--model-height', default=224, type=int, help='Model height')
    parser.add_argument('--channels', default=3, type=int, help='Input channels')

    args = parser.parse_args()
    main(args)