# utils.py
import os

import torch


class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Apply the same augmentation pipeline twice to get two independent views.
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]
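
# Usage sketch (not part of the original file): TwoCropsTransform is typically
# wrapped around a torchvision augmentation pipeline and passed to a dataset,
# so each image yields a query/key pair of augmented views, e.g.:
#
#   from torchvision import transforms, datasets
#   augmentation = transforms.Compose([
#       transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
#       transforms.RandomHorizontalFlip(),
#       transforms.ToTensor(),
#   ])
#   train_dataset = datasets.ImageFolder(
#       'path/to/train', TwoCropsTransform(augmentation))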

def load_checkpoints(resume, model, optimizer, gpu):
    """Resume training from a checkpoint containing 'epoch', 'state_dict' and 'optimizer'."""
    start_epoch = 0
    if os.path.isfile(resume):
        print("=> loading checkpoint '{}'".format(resume))
        if gpu is None:
            checkpoint = torch.load(resume)
        else:
            # Map model to be loaded to specified single gpu.
            loc = 'cuda:{}'.format(gpu)
            checkpoint = torch.load(resume, map_location=loc)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(resume, checkpoint['epoch']))
    else:
        # No checkpoint found: fall back to epoch 0 so callers can still unpack the result.
        print("=> no checkpoint found at '{}'".format(resume))
    return model, optimizer, start_epoch

def load_pretrained_checkpoints(pretrained, model, optimizer, gpu):
    """Initialize the encoder from a MoCo pre-trained checkpoint (the fc head stays random)."""
    start_epoch = 0
    # load from pre-trained
    if pretrained:
        if os.path.isfile(pretrained):
            print("=> loading checkpoint '{}'".format(pretrained))
            checkpoint = torch.load(pretrained, map_location="cpu")

            # rename moco pre-trained keys
            state_dict = checkpoint['state_dict']
            for k in list(state_dict.keys()):
                # retain only the encoder weights up to before the embedding (fc) layer
                if k.startswith('encoder') and not k.startswith('encoder.fc'):
                    # remove the 'encoder.' prefix
                    state_dict[k[len("encoder."):]] = state_dict[k]
                # delete renamed or unused key
                del state_dict[k]

            msg = model.load_state_dict(state_dict, strict=False)
            # only the freshly initialized classification head should be missing
            assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
            print("=> loaded pre-trained model '{}'".format(pretrained))
        else:
            print("=> no checkpoint found at '{}'".format(pretrained))
    return model, optimizer, start_epoch
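

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). The ResNet-50
    # backbone, SGD hyper-parameters, and checkpoint paths below are assumptions
    # chosen only to illustrate how these helpers are typically called.
    import torchvision.models as models

    model = models.resnet50(num_classes=100)  # hypothetical classifier head size
    optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)

    # Resume full training state if the (placeholder) checkpoint file exists,
    # otherwise start from epoch 0.
    model, optimizer, start_epoch = load_checkpoints(
        'checkpoint.pth.tar', model, optimizer, gpu=None)

    # Initialize the backbone from (placeholder) MoCo pre-trained weights if
    # available; only the fc head stays randomly initialized.
    model, optimizer, start_epoch = load_pretrained_checkpoints(
        'moco_pretrained.pth.tar', model, optimizer, gpu=None)
    print("starting from epoch", start_epoch)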