-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patheval.py
108 lines (92 loc) · 4.94 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
With this script you can evaluate checkpoints or test models from two popular
landmark retrieval github repos.
The first is https://github.com/naver/deep-image-retrieval from Naver labs,
provides ResNet-50 and ResNet-101 trained with AP on Google Landmarks 18 clean.
$ python eval.py --off_the_shelf=naver --l2=none --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048
The second is https://github.com/filipradenovic/cnnimageretrieval-pytorch from
Radenovic, provides ResNet-50 and ResNet-101 trained with a triplet loss
on Google Landmarks 18 and sfm120k.
$ python eval.py --off_the_shelf=radenovic_gldv1 --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048
$ python eval.py --off_the_shelf=radenovic_sfm --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048
Note that although the architectures are almost the same, Naver's
implementation does not use a l2 normalization before/after the GeM aggregation,
while Radenovic's uses it after (and we use it before, which shows better
results in VG)
"""
import os
import sys
import torch
import parser
import logging
import sklearn
from os.path import join
from datetime import datetime
from torch.utils.model_zoo import load_url
from google_drive_downloader import GoogleDriveDownloader as gdd
import test
import util
import commons
import datasets_ws
from model import network
OFF_THE_SHELF_RADENOVIC = {
'resnet50conv5_sfm' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
'resnet101conv5_sfm' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
'resnet50conv5_gldv1' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
'resnet101conv5_gldv1' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
}
OFF_THE_SHELF_NAVER = {
"resnet50conv5" : "1oPtE_go9tnsiDLkWjN4NMpKjh-_md1G5",
'resnet101conv5' : "1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy"
}
######################################### SETUP #########################################
args = parser.parse_arguments()
start_time = datetime.now()
args.save_dir = join("test", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S'))
commons.setup_logging(args.save_dir)
commons.make_deterministic(args.seed)
logging.info(f"Arguments: {args}")
logging.info(f"The outputs are being saved in {args.save_dir}")
######################################### MODEL #########################################
model = network.GeoLocalizationNet(args)
model = model.to(args.device)
model = torch.nn.DataParallel(model)
if args.aggregation in ["netvlad", "crn"]:
args.features_dim *= args.netvlad_clusters
if args.off_the_shelf.startswith("radenovic") or args.off_the_shelf.startswith("naver"):
if args.off_the_shelf.startswith("radenovic"):
pretrain_dataset_name = args.off_the_shelf.split("_")[1] # sfm or gldv1 datasets
url = OFF_THE_SHELF_RADENOVIC[f"{args.backbone}_{pretrain_dataset_name}"]
state_dict = load_url(url, model_dir=join("data", "off_the_shelf_nets"))
else:
# This is a hacky workaround to maintain compatibility
sys.modules['sklearn.decomposition.pca'] = sklearn.decomposition._pca
zip_file_path = join("data", "off_the_shelf_nets", args.backbone + "_naver.zip")
if not os.path.exists(zip_file_path):
gdd.download_file_from_google_drive(file_id=OFF_THE_SHELF_NAVER[args.backbone],
dest_path=zip_file_path, unzip=True)
if args.backbone == "resnet50conv5":
state_dict_filename = "Resnet50-AP-GeM.pt"
elif args.backbone == "resnet101conv5":
state_dict_filename = "Resnet-101-AP-GeM.pt"
state_dict = torch.load(join("data", "off_the_shelf_nets", state_dict_filename))
state_dict = state_dict["state_dict"]
model_keys = model.state_dict().keys()
renamed_state_dict = {k: v for k, v in zip(model_keys, state_dict.values())}
model.load_state_dict(renamed_state_dict)
elif args.resume != None:
state_dict = torch.load(args.resume)["model_state_dict"]
model.load_state_dict(state_dict, False)
if args.pca_dim == None:
pca = None
else:
full_features_dim = args.features_dim
args.features_dim = args.pca_dim
pca = util.compute_pca(args, model, args.pca_dataset_folder, full_features_dim)
######################################### DATASETS #########################################
test_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "test")
logging.info(f"Test set: {test_ds}")
######################################### TEST on TEST SET #########################################
recalls, recalls_str = test.test(args, test_ds, model, args.test_method, pca)
logging.info(f"Recalls on {test_ds}: {recalls_str}")
logging.info(f"Finished in {str(datetime.now() - start_time)[:-7]}")