__init__.py
import torch
import torch.nn as nn


class AbstractMethodInterface(object):
    """Interface for methods: subclasses implement propose_H, train_H, test_H, and method_identifier."""

    def __init__(self):
        self.name = self.__class__.__name__

    def propose_H(self, dataset):
        raise NotImplementedError("%s does not have an implementation for this" % self.name)

    def train_H(self, dataset):
        raise NotImplementedError("%s does not have an implementation for this" % self.name)

    def test_H(self, dataset):
        raise NotImplementedError("%s does not have an implementation for this" % self.name)

    def method_identifier(self):
        raise NotImplementedError("Please implement the identifier method for %s" % self.name)


class AbstractModelWrapper(nn.Module):
    def __init__(self, base_model):
        super(AbstractModelWrapper, self).__init__()
        self.base_model = base_model
        # Freeze the base model: keep it in eval mode and disable its gradients.
        if hasattr(self.base_model, 'eval'):
            self.base_model.eval()
        if hasattr(self.base_model, 'parameters'):
            for parameter in self.base_model.parameters():
                parameter.requires_grad = False
        self.eval_direct = False
        self.cache = {}  # Be careful what you cache! You don't have infinite memory.

    def set_eval_direct(self, eval_direct):
        self.eval_direct = eval_direct

    def train(self, mode=True):
        """Override train mode because the base_model must always stay in eval mode."""
        self.training = mode
        for module in self.children():
            module.train(mode)
        # Now revert the base_model back to eval.
        if hasattr(self.base_model, 'eval'):
            self.base_model.eval()
        return self

    def subnetwork_eval(self, x):
        raise NotImplementedError

    def wrapper_eval(self, x):
        raise NotImplementedError

    def subnetwork_cached_eval(self, x, indices, group):
        """Evaluate the subnetwork, reusing per-sample cached outputs when available."""
        cache = self.cache[group] if group in self.cache else {}

        if all(ind in cache for ind in indices):
            # Every requested entry is cached; fetch and concatenate.
            all_outputs = [cache[ind] for ind in indices]
            output = torch.cat(all_outputs)
        else:
            # Cache miss: evaluate and store each sample's output.
            output = self.subnetwork_eval(x)
            for i, entry in enumerate(output):
                cache[indices[i]] = entry.unsqueeze_(0)
            self.cache[group] = cache

        return output

    def forward(self, x, indices=None, group=None):
        if not self.eval_direct:
            if indices is None:
                input = self.subnetwork_eval(x)
            else:
                input = self.subnetwork_cached_eval(x, indices=indices, group=group)
            input = input.detach()
            input.requires_grad = False
        else:
            input = x
        output = self.wrapper_eval(input)
        return output


class SVMLoss(nn.Module):
    """Hinge loss: 0/1 targets are remapped to -1/+1 before computing margin violations."""

    def __init__(self, margin=1.0):
        super(SVMLoss, self).__init__()
        self.margin = margin
        self.size_average = True

    def forward(self, x, target):
        target = target.clone()
        # 0 labels should be set to -1 for this loss.
        target.data[target.data < 0.1] = -1
        error = self.margin - x * target
        loss = torch.clamp(error, min=0)
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss
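

# --- Illustrative usage (not part of the original file) ---------------------
# SVMLoss takes raw scores and 0/1 labels; zeros are remapped to -1 inside
# forward, so the result is a standard averaged hinge loss. The demo below is
# a sketch and its names are hypothetical.
def _svm_loss_demo():
    criterion = SVMLoss(margin=1.0)
    scores = torch.tensor([0.8, -0.3, 1.5])   # classifier outputs
    labels = torch.tensor([1.0, 0.0, 1.0])    # 0/1 targets
    return criterion(scores, labels)          # mean(clamp(margin - scores*targets, min=0))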


def get_cached(model, dataset_loader, device):
    """Run model.subnetwork_eval over a loader and return the concatenated outputs and targets."""
    from tqdm import tqdm

    outputX, outputY = [], []
    with torch.set_grad_enabled(False):
        with tqdm(total=len(dataset_loader)) as pbar:
            pbar.set_description('Caching data')
            for i, (image, label) in enumerate(dataset_loader):
                pbar.update()
                input, target = image.to(device), label.to(device)
                new_input = model.subnetwork_eval(input)
                outputX.append(new_input)
                outputY.append(target)
    return torch.cat(outputX, 0), torch.cat(outputY, 0)