-
Notifications
You must be signed in to change notification settings - Fork 0
/
prune.py
135 lines (105 loc) · 5.1 KB
/
prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import time
import numpy as np
import torch
from torchvision import models
def replace_layers(model, i, indexes, layers):
    """Pick the layer that belongs at position ``i`` of a rebuilt container.

    Args:
        model: Indexable container of the current layers.
        i: Position being filled.
        indexes: Positions that are to be substituted.
        layers: Substitute layers, parallel to ``indexes``.

    Returns:
        The entry of ``layers`` paired with ``i`` when ``i`` occurs in
        ``indexes``; otherwise the existing ``model[i]``.
    """
    is_substituted = i in indexes
    if not is_substituted:
        return model[i]
    slot = indexes.index(i)
    return layers[slot]
def _drop_span(tensor, start, width, dim):
    """Return a copy of ``tensor`` with ``width`` entries cut out along ``dim``.

    Entries ``start`` .. ``start + width - 1`` are removed. The result is
    assembled with ``torch.cat`` from slices of ``tensor``, so it keeps the
    dtype and device of ``tensor``.
    """
    head = [slice(None)] * tensor.dim()
    tail = list(head)
    head[dim] = slice(0, start)
    tail[dim] = slice(start + width, None)
    return torch.cat((tensor[tuple(head)], tensor[tuple(tail)]), dim=dim)


def prune_vgg16_conv_layer(model, layer_index, filter_index):
    """Remove one output filter from a conv layer of a VGG-style model.

    The conv layer at ``model.features[layer_index]`` is replaced by a copy
    with filter ``filter_index`` removed. The consumer of that filter is
    shrunk to match: the next ``Conv2d`` in ``model.features`` loses the
    corresponding input channel, or, when no later conv exists, the first
    ``Linear`` in ``model.classifier`` loses the input columns fed by it.

    All new parameters are derived from the existing ones with torch
    operations, so they stay on the device and in the dtype of the layer
    they replace (no GPU is required).

    Args:
        model: Module exposing ``features`` and ``classifier`` sequential
            containers.
        layer_index: Position of the conv layer inside ``model.features``.
        filter_index: Index of the output filter to remove.

    Returns:
        The same ``model`` object, with ``features`` (and possibly
        ``classifier``) rebuilt.

    Raises:
        ValueError: If ``filter_index`` is not a valid filter of the layer.
        BaseException: If the last conv layer is pruned and the classifier
            holds no ``Linear`` layer.
    """
    # Snapshot the layer list once instead of rebuilding it per iteration.
    feature_layers = list(model.features)
    conv = feature_layers[layer_index]

    if not 0 <= filter_index < conv.out_channels:
        raise ValueError(
            "filter_index %r is out of range for a layer with %d filters"
            % (filter_index, conv.out_channels))

    # Locate the next convolution: its input channels depend on our filters.
    next_conv = None
    next_index = None
    for position in range(layer_index + 1, len(feature_layers)):
        candidate = feature_layers[position]
        if isinstance(candidate, torch.nn.Conv2d):
            next_index, next_conv = position, candidate
            break

    # ``bias`` must be a bool, not the bias tensor itself
    # (https://github.com/jacobgil/pytorch-pruning/issues/6); mirror whether
    # the original layer actually has a bias.
    new_conv = torch.nn.Conv2d(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels - 1,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
    )
    new_conv.weight.data = _drop_span(conv.weight.data, filter_index, 1, dim=0)
    if conv.bias is not None:
        new_conv.bias.data = _drop_span(conv.bias.data, filter_index, 1, dim=0)

    replacements = {layer_index: new_conv}

    if next_conv is not None:
        next_new_conv = torch.nn.Conv2d(
            in_channels=next_conv.in_channels - 1,
            out_channels=next_conv.out_channels,
            kernel_size=next_conv.kernel_size,
            stride=next_conv.stride,
            padding=next_conv.padding,
            dilation=next_conv.dilation,
            groups=next_conv.groups,
            bias=next_conv.bias is not None,
        )
        # Input channels live on dim 1 of a conv weight.
        next_new_conv.weight.data = _drop_span(
            next_conv.weight.data, filter_index, 1, dim=1)
        if next_conv.bias is not None:
            # Output channels are untouched, so the bias carries over as is.
            next_new_conv.bias.data = next_conv.bias.data
        replacements[next_index] = next_new_conv

    model.features = torch.nn.Sequential(
        *(replacements.get(position, layer)
          for position, layer in enumerate(feature_layers)))

    if next_conv is not None:
        return model

    # Pruning the last conv layer. This affects the first linear layer of
    # the classifier.
    linear_index = None
    for position, module in enumerate(model.classifier):
        if isinstance(module, torch.nn.Linear):
            linear_index = position
            break

    if linear_index is None:
        # Exception type and message kept as-is for existing callers.
        raise BaseException("No linear laye found in classifier")

    old_linear_layer = model.classifier[linear_index]

    # Each conv output channel feeds a contiguous run of linear inputs.
    params_per_input_channel = old_linear_layer.in_features // conv.out_channels

    new_linear_layer = torch.nn.Linear(
        old_linear_layer.in_features - params_per_input_channel,
        old_linear_layer.out_features,
        bias=old_linear_layer.bias is not None,
    )
    new_linear_layer.weight.data = _drop_span(
        old_linear_layer.weight.data,
        filter_index * params_per_input_channel,
        params_per_input_channel,
        dim=1,
    )
    if old_linear_layer.bias is not None:
        new_linear_layer.bias.data = old_linear_layer.bias.data

    model.classifier = torch.nn.Sequential(
        *(new_linear_layer if position == linear_index else layer
          for position, layer in enumerate(model.classifier)))

    return model
if __name__ == '__main__':
    # Demo: load a pretrained VGG16, prune one filter and time the operation.
    # Put the whole network on a single device up front (GPU when present)
    # so pruned and untouched layers do not end up on different devices.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.vgg16(pretrained=True).to(device)
    model.train()

    t0 = time.time()
    # Layer 28, filter 10 -- presumably the last conv layer of VGG16's
    # `features`, which would also exercise the classifier resize; verify
    # against the torchvision model definition.
    model = prune_vgg16_conv_layer(model, 28, 10)
    print("The prunning took", time.time() - t0)