# models.py
import torch
import tensorflow as tf
import numpy as np
import math
import utils
from torchvision import models as torch_models
from torch.nn import DataParallel
from madry_mnist.model import Model as madry_model_mnist
from madry_cifar10.model import Model as madry_model_cifar10
from logit_pairing.models import LeNet as lp_model_mnist, ResNet20_v2 as lp_model_cifar10
from post_avg.postAveragedModels import pa_resnet110_config1 as post_avg_cifar10_resnet
from post_avg.postAveragedModels import pa_resnet152_config1 as post_avg_imagenet_resnet


class Model:
    def __init__(self, batch_size, gpu_memory):
        self.batch_size = batch_size
        self.gpu_memory = gpu_memory

    def predict(self, x):
        raise NotImplementedError('use ModelTF or ModelPT')

    def loss(self, y, logits, targeted=False, loss_type='margin_loss'):
        """ Implements the margin loss (difference between the correct and 2nd best class). """
        if loss_type == 'margin_loss':
            preds_correct_class = (logits * y).sum(1, keepdims=True)
            diff = preds_correct_class - logits  # difference between the correct class and all other classes
            diff[y] = np.inf  # to exclude zeros coming from f_correct - f_correct
            margin = diff.min(1, keepdims=True)
            loss = margin * -1 if targeted else margin
        elif loss_type == 'cross_entropy':
            probs = utils.softmax(logits)
            loss = -np.log(probs[y])
            loss = loss * -1 if not targeted else loss
        else:
            raise ValueError('Wrong loss.')
        return loss.flatten()
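

# Illustrative sketch (not part of the original file): a tiny sanity check of the margin
# loss above, assuming labels `y` are given as a boolean one-hot mask, as elsewhere in
# this codebase. The logits below are made-up values purely for demonstration.
def _margin_loss_example():
    logits = np.array([[2.0, 0.5, -1.0],
                       [0.1, 0.3, 0.2]])
    y = np.array([[True, False, False],   # correct class of example 0 is class 0
                  [False, False, True]])  # correct class of example 1 is class 2
    m = Model(batch_size=2, gpu_memory=0.0)
    # margin = f_correct - max_{other} f_other: 2.0 - 0.5 = 1.5 and 0.2 - 0.3 = -0.1;
    # a negative margin means the example is misclassified (or successfully attacked).
    return m.loss(y, logits)  # -> array([ 1.5, -0.1])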


class ModelTF(Model):
    """
    Wrapper class around TensorFlow models.
    In order to incorporate a new model, one has to ensure that self.model has a TF variable `logits`,
    and that the preprocessing of the inputs is done correctly (e.g. subtracting the mean and dividing over the
    standard deviation).
    """
    def __init__(self, model_name, batch_size, gpu_memory):
        super().__init__(batch_size, gpu_memory)
        model_folder = model_path_dict[model_name]
        model_file = tf.train.latest_checkpoint(model_folder)
        self.model = model_class_dict[model_name]()
        self.batch_size = batch_size
        self.model_name = model_name
        self.model_file = model_file
        if 'logits' not in self.model.__dict__:
            self.model.logits = self.model.pre_softmax

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory)
        config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
        self.sess = tf.Session(config=config)
        tf.train.Saver().restore(self.sess, model_file)

    def predict(self, x):
        if 'mnist' in self.model_name:
            shape = self.model.x_input.shape[1:].as_list()
            x = np.reshape(x, [-1, *shape])
        elif 'cifar10' in self.model_name:
            x = np.transpose(x, axes=[0, 2, 3, 1])

        n_batches = math.ceil(x.shape[0] / self.batch_size)
        logits_list = []
        for i in range(n_batches):
            x_batch = x[i*self.batch_size:(i+1)*self.batch_size]
            logits = self.sess.run(self.model.logits, feed_dict={self.model.x_input: x_batch})
            logits_list.append(logits)
        logits = np.vstack(logits_list)
        return logits
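

# Illustrative sketch (not part of the original file). The docstring above requires a new
# TF model to expose an `x_input` placeholder and a `logits` (or `pre_softmax`) tensor; a
# toy example of such a class, plus a typical call to ModelTF, follows. The names, shapes,
# and batch size are made-up example values; ModelTF additionally expects a checkpoint
# folder registered in model_path_dict below.
class _ToyTFModel:
    def __init__(self):
        self.x_input = tf.placeholder(tf.float32, shape=[None, 784])  # flattened MNIST images
        self.logits = tf.layers.dense(self.x_input, 10)               # one logit per class


def _example_modeltf_usage():
    # Assumes the Madry MNIST robust checkpoint is available under madry_mnist/models/robust.
    model = ModelTF('madry_mnist_robust', batch_size=256, gpu_memory=0.5)
    x = np.random.rand(4, 1, 28, 28).astype(np.float32)  # MNIST inputs in [0, 1], NCHW
    return model.predict(x)  # reshaped internally to match model.x_input; shape (4, 10)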


class ModelPT(Model):
    """
    Wrapper class around PyTorch models.
    In order to incorporate a new model, one has to ensure that self.model is a callable object that returns logits,
    and that the preprocessing of the inputs is done correctly (e.g. subtracting the mean and dividing over the
    standard deviation).
    """
    def __init__(self, model_name, batch_size, gpu_memory):
        super().__init__(batch_size, gpu_memory)
        if model_name in ['pt_vgg', 'pt_resnet', 'pt_inception', 'pt_densenet']:
            # Pretrained torchvision ImageNet models with the standard ImageNet normalization.
            model = model_class_dict[model_name](pretrained=True)
            self.mean = np.reshape([0.485, 0.456, 0.406], [1, 3, 1, 1])
            self.std = np.reshape([0.229, 0.224, 0.225], [1, 3, 1, 1])
            model = DataParallel(model.cuda())
        else:
            model = model_class_dict[model_name]()
            if model_name in ['pt_post_avg_cifar10', 'pt_post_avg_imagenet']:
                # The post-averaging models come with their own weights; only the normalization is set here.
                # checkpoint = torch.load(model_path_dict[model_name])
                self.mean = np.reshape([0.485, 0.456, 0.406], [1, 3, 1, 1])
                self.std = np.reshape([0.229, 0.224, 0.225], [1, 3, 1, 1])
            else:
                # Remaining models: restore the weights from a `.pth` checkpoint.
                model = DataParallel(model).cuda()
                checkpoint = torch.load(model_path_dict[model_name] + '.pth')
                self.mean = np.reshape([0.485, 0.456, 0.406], [1, 3, 1, 1])
                self.std = np.reshape([0.225, 0.225, 0.225], [1, 3, 1, 1])
                model.load_state_dict(checkpoint)

        model.float()
        self.mean, self.std = self.mean.astype(np.float32), self.std.astype(np.float32)
        model.eval()
        self.model = model

    def predict(self, x):
        x = (x - self.mean) / self.std
        x = x.astype(np.float32)

        n_batches = math.ceil(x.shape[0] / self.batch_size)
        logits_list = []
        with torch.no_grad():  # otherwise consumes too much memory and leads to a slowdown
            for i in range(n_batches):
                x_batch = x[i*self.batch_size:(i+1)*self.batch_size]
                x_batch_torch = torch.as_tensor(x_batch, device=torch.device('cuda'))
                logits = self.model(x_batch_torch).cpu().numpy()
                logits_list.append(logits)
        logits = np.vstack(logits_list)
        return logits
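

# Illustrative sketch (not part of the original file): typical use of ModelPT with a
# pretrained torchvision ResNet-50. It assumes a CUDA device is available and that the
# pretrained ImageNet weights can be downloaded; shapes and batch size are example values.
def _example_modelpt_usage():
    model = ModelPT('pt_resnet', batch_size=32, gpu_memory=0.5)
    x = np.random.rand(4, 3, 224, 224).astype(np.float32)  # ImageNet-shaped inputs in [0, 1], NCHW
    logits = model.predict(x)  # shape (4, 1000)
    y = np.zeros_like(logits, dtype=bool)
    y[np.arange(4), logits.argmax(1)] = True  # treat the top-1 predictions as the correct labels
    return model.loss(y, logits)  # per-example margin loss (positive: still classified as labeled)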


model_path_dict = {'madry_mnist_robust': 'madry_mnist/models/robust',
                   'madry_cifar10_robust': 'madry_cifar10/models/robust',
                   'clp_mnist': 'logit_pairing/models/clp_mnist',
                   'lsq_mnist': 'logit_pairing/models/lsq_mnist',
                   'clp_cifar10': 'logit_pairing/models/clp_cifar10',
                   'lsq_cifar10': 'logit_pairing/models/lsq_cifar10',
                   'pt_post_avg_cifar10': 'post_avg/trainedModel/resnet110.th'
                   }

model_class_dict = {'pt_vgg': torch_models.vgg16_bn,
                    'pt_resnet': torch_models.resnet50,
                    'pt_inception': torch_models.inception_v3,
                    'pt_densenet': torch_models.densenet121,
                    'madry_mnist_robust': madry_model_mnist,
                    'madry_cifar10_robust': madry_model_cifar10,
                    'clp_mnist': lp_model_mnist,
                    'lsq_mnist': lp_model_mnist,
                    'clp_cifar10': lp_model_cifar10,
                    'lsq_cifar10': lp_model_cifar10,
                    'pt_post_avg_cifar10': post_avg_cifar10_resnet,
                    'pt_post_avg_imagenet': post_avg_imagenet_resnet,
                    }
all_model_names = list(model_class_dict.keys())
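

# Minimal sketch (not part of the original file) of how a caller might pick the right
# wrapper based on the naming convention used above (PyTorch model names are prefixed
# with 'pt_'); the exact dispatch and the default arguments here are assumptions.
def create_model(model_name, batch_size=256, gpu_memory=0.5):
    model_cls = ModelPT if model_name.startswith('pt_') else ModelTF
    return model_cls(model_name, batch_size, gpu_memory)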