ssd.py
import torch
import torchvision
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
from d2l import torch as d2l
# Per-position class prediction: one output channel per (anchor, class + background) pair.
def cls_predictor(num_inputs, num_anchors, num_classes):
    return nn.Conv2d(num_inputs, num_anchors * (num_classes + 1), kernel_size=3, padding=1)

# Per-position bounding-box prediction: four offsets per anchor.
def bbox_predictor(num_inputs, num_anchors):
    return nn.Conv2d(num_inputs, num_anchors * 4, kernel_size=3, padding=1)
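# Example (a sketch, not part of the original script): with num_classes=1 and
# 4 anchors per position, cls_predictor outputs 4 * (1 + 1) = 8 channels and
# bbox_predictor outputs 4 * 4 = 16 channels.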
def forward(x, block):
return block(x)
def flatten_pred(pred: torch.Tensor):
    # Move channels last, then flatten each example to a 1-D vector.
    return torch.flatten(pred.permute(0, 2, 3, 1), start_dim=1)

def concat_preds(preds):
    # Concatenate the flattened predictions from all scales along dim 1.
    return torch.cat([flatten_pred(p) for p in preds], dim=1)
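# Shape check (sketch, in the style of the commented checks below): two maps of
# shapes (2, 55, 20, 20) and (2, 33, 10, 10) flatten to 22000 and 3300 features
# per example, so concat_preds returns a (2, 25300) tensor.
# Y1, Y2 = torch.zeros((2, 55, 20, 20)), torch.zeros((2, 33, 10, 10))
# print(concat_preds([Y1, Y2]).shape)  # torch.Size([2, 25300])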
# Downsampling block: two 3x3 conv + BN + ReLU layers, then a 2x2 max-pooling
# layer that halves the height and width.
def down_sample_blk(in_channels, out_channels):
    blk = []
    for _ in range(2):
        blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        blk.append(nn.BatchNorm2d(out_channels))
        blk.append(nn.ReLU())
        in_channels = out_channels
    blk.append(nn.MaxPool2d(2))
    return nn.Sequential(*blk)
# Base network: each block halves the height and width while increasing the
# number of channels (3 -> 16 -> 32 -> 64).
def base_net():
    blk = []
    num_filters = [3, 16, 32, 64]
    for i in range(len(num_filters) - 1):
        blk.append(down_sample_blk(num_filters[i], num_filters[i + 1]))
    return nn.Sequential(*blk)
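# Shape check (sketch): three halvings reduce a 256x256 input to 32x32, e.g.
# base_net()(torch.zeros((2, 3, 256, 256))).shape == (2, 64, 32, 32).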
# Select the i-th stage of the model: the base network, three downsampling
# blocks, and finally global max pooling down to 1x1.
def get_blk(i):
    if i == 0:
        blk = base_net()
    elif i == 1:
        blk = down_sample_blk(64, 128)
    elif i == 4:
        blk = nn.AdaptiveMaxPool2d((1, 1))
    else:
        blk = down_sample_blk(128, 128)
    return blk
# Run one stage: compute its feature map, generate anchors on it, and predict
# classes and box offsets for every anchor.
def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):
    Y = blk(X)
    anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio)
    cls_preds = cls_predictor(Y)
    bbox_preds = bbox_predictor(Y)
    return (Y, anchors, cls_preds, bbox_preds)
sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79],
[0.88, 0.961]]
ratios = [[1, 2, 0.5]] * 5
num_anchors = len(sizes[0]) + len(ratios[0]) - 1
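# With 2 sizes and 3 ratios per scale, each position gets 2 + 3 - 1 = 4 anchors.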
class TinySSD(nn.Module):
    def __init__(self, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        idx_to_in_channels = [64, 128, 128, 128, 128]
        for i in range(5):
            # Register each stage and its class / box predictors as attributes
            # so their parameters are tracked by the module.
            setattr(self, f'blk_{i}', get_blk(i))
            setattr(self, f'cls_{i}', cls_predictor(idx_to_in_channels[i], num_anchors, num_classes))
            setattr(self, f'bbox_{i}', bbox_predictor(idx_to_in_channels[i], num_anchors))

    def forward(self, X):
        anchors, cls_preds, bbox_preds = [None] * 5, [None] * 5, [None] * 5
        for i in range(5):
            # X is updated to the current stage's feature map.
            X, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(
                X, getattr(self, f'blk_{i}'), sizes[i], ratios[i],
                getattr(self, f'cls_{i}'), getattr(self, f'bbox_{i}'))
        anchors = torch.cat(anchors, dim=1)
        cls_preds = concat_preds(cls_preds)
        cls_preds = cls_preds.reshape(cls_preds.shape[0], -1, self.num_classes + 1)
        bbox_preds = concat_preds(bbox_preds)
        return anchors, cls_preds, bbox_preds
# net = TinySSD(num_classes=1)
# X = torch.zeros((32, 3, 256, 256))
# anchors, cls_preds, bbox_preds = net(X)
# print('output anchors', anchors.shape)
# print('output cls_preds', cls_preds.shape)
# print('output bbox_preds', bbox_preds.shape)
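# Expected shapes for the check above (sketch, assuming a 256x256 input and
# num_classes=1): anchors (1, 5444, 4), cls_preds (32, 5444, 2),
# bbox_preds (32, 21776), since 4 * (32**2 + 16**2 + 8**2 + 4**2 + 1) = 5444.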
batch_size = 32
train_iter, _ = d2l.load_data_bananas(batch_size)
device, net = d2l.try_gpu(), TinySSD(num_classes=1)
trainer = torch.optim.SGD(net.parameters(), lr=0.2, weight_decay=5e-4)
cls_loss = nn.CrossEntropyLoss(reduction='none')
bbox_loss = nn.L1Loss(reduction='none')
def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):
batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]
cls = cls_loss(cls_preds.reshape(-1, num_classes),
cls_labels.reshape(-1)).reshape(batch_size, -1).mean(dim=1)
bbox = bbox_loss(bbox_preds * bbox_masks,
bbox_labels * bbox_masks).mean(dim=1)
return cls + bbox
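# Note (sketch of the recipe used here): the class term is averaged over all
# anchors of an example and added to the box term with equal weight;
# bbox_masks zeroes out offsets of background anchors so they do not
# contribute to the L1 term.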
def cls_eval(cls_preds, cls_labels):
    # Class predictions live on the final dimension, so argmax over dim=-1.
    return float((cls_preds.argmax(dim=-1).type(
        cls_labels.dtype) == cls_labels).sum())

def bbox_eval(bbox_preds, bbox_labels, bbox_masks):
    # Sum of absolute offset errors over the positive (unmasked) anchors.
    return float((torch.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())
num_epochs, timer = 10, d2l.Timer()
net = net.to(device)
# Training loop (commented out; the script loads saved parameters below).
# for epoch in range(num_epochs):
#     # Accumulates: sum of correct class predictions, number of class labels,
#     # sum of absolute box errors, number of box labels.
#     metric = d2l.Accumulator(4)
#     net.train()
#     for features, target in train_iter:
#         timer.start()
#         trainer.zero_grad()
#         X, Y = features.to(device), target.to(device)
#         # Generate multiscale anchors and predict classes and offsets.
#         anchors, cls_preds, bbox_preds = net(X)
#         # Label the class and offset of every anchor box.
#         bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)
#         # Compute the loss from the predicted and labeled values.
#         l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,
#                       bbox_masks)
#         l.mean().backward()
#         trainer.step()
#         metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),
#                    bbox_eval(bbox_preds, bbox_labels, bbox_masks),
#                    bbox_labels.numel())
#     cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]
#     print(f'epoch {epoch}, cls_err {cls_err:.2e}, bbox_mae {bbox_mae:.2e}')
# # torch.save(net.state_dict(), r'C:\Users\Orange\Desktop\python\deepLearn\nn\ssd.params')
# print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')
# print(f'{len(train_iter.dataset) / timer.stop():.1f} examples/sec on '
#       f'{str(device)}')
# Prediction
net.load_state_dict(torch.load(r'C:\Users\Orange\Desktop\python\deepLearn\nn\sssd.params'))
X = torchvision.io.read_image(r'C:\Users\Orange\Desktop\python\deepLearn\nn\result\banana.jpeg').unsqueeze(0).float()
img = X.squeeze(0).permute(1, 2, 0).long()
def predict(X):
    # Run the network in eval mode, turn class scores into probabilities, and
    # apply non-maximum suppression via multibox_detection.
    net.eval()
    anchors, cls_preds, bbox_preds = net(X.to(device))
    cls_probs = F.softmax(cls_preds, dim=2).permute(0, 2, 1)
    output = d2l.multibox_detection(cls_probs, bbox_preds, anchors)
    # Keep only rows whose predicted class is not background (-1).
    idx = [i for i, row in enumerate(output[0]) if row[0] != -1]
    return output[0, idx]
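# Each retained row has the form (class_id, confidence, xmin, ymin, xmax, ymax),
# with corner coordinates normalized to [0, 1] (d2l's multibox_detection format).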
output = predict(X)
def display(img, output, threshold):
    # Draw every detection whose confidence is at least `threshold`.
    d2l.set_figsize((5, 5))
    fig = d2l.plt.imshow(img)
    for row in output:
        score = float(row[1])
        if score < threshold:
            continue
        h, w = img.shape[0:2]
        # Scale normalized corner coordinates back to pixel coordinates.
        bbox = [row[2:6] * torch.tensor((w, h, w, h), device=row.device)]
        d2l.show_bboxes(fig.axes, bbox, '%.2f' % score, 'w')

display(img, output.cpu(), threshold=0.9)
plt.show()