"""
LICENSE AGREEMENT
By downloading, copying, installing or using the software you agree to this license.
If you do not agree to this license, do not download, install, modify, copy or use the software in any part, partial or whole.
Do not remove this license notice.
"""
## IMPORTS
# Image processing
import cv2
import torch
# Data management
import numpy as np
# Get adjacent zones
def get_adj_zone_locations(pos, array_dim=(3,2), addself=True):
    """
    Gets the flat-array indices adjacent to a given index when a flat (1d)
    array is interpreted as a 2d grid.
    Parameters
    ----------
    pos: Index of the tile in the flat array
    array_dim: Dimensions of the 2d grid as (width, height)
    addself: Whether to include pos itself in the returned list
    Return
    ----------
    Flat-array indices of the grid cells adjacent to pos.
    """
    x_pos = (pos%array_dim[0])
    y_pos = (pos//array_dim[0]) # Row index: integer division by the width, not modulo
    # Build the grid of flat indices in row-major order
    zoned_arr = []
    counter = 0
    for i in range(0, array_dim[1]):
        y = []
        for j in range(0, array_dim[0]):
            y.append(counter)
            counter+=1
        zoned_arr.append(y)
    # Collect the in-bounds neighbours of (x_pos, y_pos)
    return_arr=[]
    for r in [-1, 0, 1]:
        for c in [-1, 0, 1]:
            if r == c == 0:
                continue
            if 0 <= x_pos+r < array_dim[0] and 0 <= y_pos+c < array_dim[1]:
                return_arr.append(zoned_arr[y_pos+c][x_pos+r])
    if addself:
        return_arr.append(pos)
    return return_arr
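# Illustrative sanity check (follows from the row-major layout above):
# for a 3x2 grid laid out as
#   0 1 2
#   3 4 5
# get_adj_zone_locations(0, (3, 2)) returns [3, 1, 4, 0] - the zone below,
# the zone to the right, the diagonal neighbour, and zone 0 itself.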
# Normalizes an image for processing
def normalize(
        img,
        img_mean=(78.4263377603, 87.7689143744, 114.895847746),
        img_scale=1/256
    ):
    """
    Returns a normalized copy of the input image.
    Parameters
    ----------
    img: Image to be normalized
    img_mean: Per-channel mean subtracted from the image
    img_scale: Scale factor applied after the mean subtraction
    Return
    ----------
    Normalized image as a float32 array
    """
    img = np.array(img, dtype=np.float32)
    img = (img - img_mean) * img_scale
    return img
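# Illustrative example: normalizing a dummy mid-grey BGR image keeps the
# shape and yields small float32 values:
# dummy = np.full((2, 2, 3), 128, dtype=np.uint8)
# out = normalize(dummy)
# print(out.shape, out.dtype)  # (2, 2, 3) float32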
def get_iou(bb1, bb2):
    """
    Calculate the Intersection over Union (IoU) of two bounding boxes.
    Author: Martin Thoma
    Source: https://stackoverflow.com/a/42874377/13171500
    Parameters
    ----------
    bb1 : dict
        Keys: {'x1', 'x2', 'y1', 'y2'}
        The (x1, y1) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner
    bb2 : dict
        Keys: {'x1', 'x2', 'y1', 'y2'}
        The (x1, y1) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner
    Returns
    -------
    float
        in [0, 1]
    """
    assert bb1['x1'] < bb1['x2']
    assert bb1['y1'] < bb1['y2']
    assert bb2['x1'] < bb2['x2']
    assert bb2['y1'] < bb2['y2']
    # determine the coordinates of the intersection rectangle
    x_left = max(bb1['x1'], bb2['x1'])
    y_top = max(bb1['y1'], bb2['y1'])
    x_right = min(bb1['x2'], bb2['x2'])
    y_bottom = min(bb1['y2'], bb2['y2'])
    if x_right < x_left or y_bottom < y_top:
        return 0.0
    # The intersection of two axis-aligned bounding boxes is always an
    # axis-aligned bounding box
    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    # compute the area of both AABBs
    bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1'])
    bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1'])
    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the intersection area
    iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
    assert iou >= 0.0
    assert iou <= 1.0
    return iou
def get_iou_arr(bb1, bb2):
    """
    IoU wrapper for boxes given as [x1, y1, x2, y2, ...] sequences;
    any trailing elements (e.g. a confidence score) are ignored.
    """
    b1 = {'x1':bb1[0], 'x2':bb1[2], 'y1':bb1[1], 'y2':bb1[3]}
    b2 = {'x1':bb2[0], 'x2':bb2[2], 'y1':bb2[1], 'y2':bb2[3]}
    return get_iou(b1, b2)
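# Illustrative example: two 2x2 squares overlapping in a 1x1 region have an
# intersection of 1 and a union of 4 + 4 - 1 = 7, so:
# get_iou_arr([0, 0, 2, 2, 0.9], [1, 1, 3, 3, 0.8])  # -> 1/7, about 0.143
# (the trailing confidence values are ignored by the wrapper)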
def runpt( image_data,
        sensitivity=0.3,
        overlap=0.5,
        modelloc='weights/best.pt',
        model_shape=(640,480),
        MODEL_MEAN_VALUES=(78.4263377603, 87.7689143744, 114.895847746),
        img_save=False,
        img_show=False,
        debug=True,
        upscale_img_mult=1
    ):
    """
    Runs a YOLO detection PyTorch model on an image.
    Parameters
    ----------
    image_data: Either a file path or an image already loaded with cv2.imread
    sensitivity: Confidence threshold; detections scoring below it are discarded
    overlap: IoU threshold above which two boxes are treated as duplicates
    modelloc: Model location; can be used with other .pt models
    model_shape: Input size (width, height) expected by the model
    MODEL_MEAN_VALUES: Per-channel mean used to normalize each crop
    img_save: Not implemented; would save the image to device
    img_show: Show the annotated image on completion
    debug: Whether to print debug output
    upscale_img_mult: Upscaling the image can sometimes improve prediction
        accuracy, but this seems to depend on the individual case
    Return
    ----------
    {
        "detections": Number of items found,
        "data": Detected boxes as [x1, y1, x2, y2, confidence],
        "image": Image with the detections drawn on it
    }
    """
    # Model Shape
    model_x, model_y = model_shape
    mx_half, my_half = (int(model_x/2), int(model_y/2))
    # Load model
    try:
        model = torch.load(modelloc, map_location=torch.device("cpu")).get('model').float() # make sure model is loaded in correctly
        model.eval() # switch to inference mode
    except Exception:
        print("ERROR LOADING MODEL. Likely improper reference location.")
        return
    # Get image
    if type(image_data)==str:
        frame = cv2.imread(image_data)
        if frame is None: # cv2.imread returns None instead of raising on failure
            print("ERROR LOADING IMAGE. Likely improper image location.")
            return
    else:
        frame = image_data
        if not isinstance(frame, np.ndarray):
            print("ERROR LOADING IMAGE. Likely improper image format. Try OpenCV.")
            return
    frame_height, frame_width, ch = frame.shape
    if upscale_img_mult != 1:
        # cv2.resize expects (width, height)
        frame = cv2.resize(frame, (int(frame_width*upscale_img_mult), int(frame_height*upscale_img_mult)))
        frame_height, frame_width, ch = frame.shape
    if debug:
        print("Frame size "+str(frame_width)+"x"+str(frame_height))
    # Ensure we are scanning the whole frame by checking if there are "leftovers"
    weird_y = False
    weird_x = False
    if frame_height%my_half!=0:
        if debug:
            print("Irregular image height, adjusting parameters to better scan zones...")
        weird_y = True
    if frame_width%mx_half!=0:
        if debug:
            print("Irregular image width, adjusting parameters to better scan zones...")
        weird_x = True
    # Zones we are going to scan; we only want to do this math once
    zones_to_scan = [] # [[y,yend,x,xend]]
    for i in range(0,(int(frame_height/my_half)-1)):
        y = (my_half*i)
        y_end = (y+model_y)
        for j in range(0, (int(frame_width/mx_half)-1)):
            x = (mx_half*j)
            x_end = (x+model_x)
            zones_to_scan.append([y,y_end,x,x_end])
        if weird_x:
            zones_to_scan.append([y,y_end,frame_width-model_x,frame_width])
    if weird_y:
        for j in range(0, (int(frame_width/mx_half)-1)):
            x = (mx_half*j)
            x_end = (x+model_x)
            zones_to_scan.append([frame_height-model_y,frame_height,x,x_end])
        if weird_x:
            zones_to_scan.append([frame_height-model_y,frame_height,frame_width-model_x,frame_width])
    # Number of zones (total)
    num_zones = len(zones_to_scan)
    # Dimensions of the zone grid as (width, height)
    zone_array_dimensions = (int(frame_width/mx_half)-(0 if weird_x else 1), int(frame_height/my_half)-(0 if weird_y else 1))
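    # Illustration: for a 1280x960 frame with a 640x480 model, the half-strides
    # are 320x240, so the x offsets are {0, 320, 640} and the y offsets are
    # {0, 240, 480} - a 3x3 grid of overlapping zones with no "weird" edge
    # zones, and zone_array_dimensions comes out as (3, 3).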
    # Adjacent zones for each zone index
    adj_zone_arr = [get_adj_zone_locations(pos, zone_array_dimensions) for pos in range(0,len(zones_to_scan))]
    if debug:
        print("Scanning zones...")
    temp_boxes = [[] for i in range(0,num_zones)]
    zone_loc = 0
    for zone in zones_to_scan:
        y, x = zone[0], zone[2]
        cimg = frame[y:zone[1], x:zone[3]] # Crop image to the zone we are scanning
        if debug:
            print("Scanning zone: ("+str(x)+":"+str(zone[3])+", "+str(y)+":"+str(zone[1])+")")
        cimg = cv2.resize(cimg, model_shape, interpolation=cv2.INTER_CUBIC)
        cimg = normalize(cimg, MODEL_MEAN_VALUES, 1/256) # Model Mean Values are just a guess
        cimg = torch.from_numpy(cimg).permute(2,0,1).unsqueeze(0).float() # HWC -> NCHW
        with torch.no_grad():
            output = model(cimg) # Run model on image
        # Assumes the raw output rows are [x, y, w, h, confidence, ...]
        arr = np.squeeze(output[0].detach().cpu().numpy())
        for k in range(len(arr[0])):
            if arr[4][k] > sensitivity:
                # Translate the detection into full-frame coordinates
                box = [int(arr[0][k])+x, int(arr[1][k])+y,
                       int(arr[0][k]+arr[2][k])+x, int(arr[1][k]+arr[3][k])+y,
                       arr[4][k]]
                add_to_zone_checked = True
                for checked_boxes in temp_boxes[zone_loc]:
                    if get_iou_arr(checked_boxes, box) > overlap: # Duplicate within this zone
                        add_to_zone_checked = False
                        break
                if add_to_zone_checked:
                    temp_boxes[zone_loc].append(box)
        zone_loc+=1
    # Ensure that overlapping zones don't report the same object twice
    if debug:
        print("Checking for overlap...")
    boxes = []
    # Get the zones that are going to be checked (each zone plus its adjacent zones)
    for zones_to_check in adj_zone_arr:
        if debug:
            print("Checking zone: "+str(zones_to_check[-1]))
        # Check each zone
        for zone_being_checked in zones_to_check:
            items_in_zone = temp_boxes[zone_being_checked]
            # Check whether items in this zone overlap objects already accepted from other zones
            for i in items_in_zone:
                add_box = True
                for j in boxes: # TODO Check only against objects in adjacent zones rather than all zones - minor efficiency upgrade at scale
                    if get_iou_arr(i,j) > overlap: # If overlapping
                        add_box = False
                        break
                if add_box:
                    boxes.append(i)
                    # Shift the stored box by half its size when drawing
                    # (the model appears to output centre-based coordinates)
                    w_half, h_half = int((i[2]-i[0])/2), int((i[3]-i[1])/2)
                    cv2.rectangle(frame, (i[0]-w_half, i[1]-h_half), (i[2]-w_half, i[3]-h_half), (255,0,0), 2)
    if debug:
        print("Found "+str(len(boxes))+" item(s).")
    # Resize for display (hardcoded display size)
    frame = cv2.resize(frame, (640*2,480*2), interpolation=cv2.INTER_CUBIC)
    if img_show:
        while True:
            # Display the annotated frame until 'q' is pressed
            cv2.imshow("Color Frame", frame)
            key = cv2.waitKey(1)
            if key == ord('q'):
                break
        cv2.destroyAllWindows()
    return {"detections":len(boxes),"data":boxes,"image":frame}
if __name__=='__main__':
    runpt("images/test.png",
        sensitivity=0.6,
        overlap=0.3,
        img_show=True,
        upscale_img_mult=2
    )
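# Illustrative example of consuming the result programmatically (the image
# path is assumed to exist; runpt returns None when loading fails):
# result = runpt("images/test.png", sensitivity=0.5, img_show=False)
# if result is not None:
#     for x1, y1, x2, y2, score in result["data"]:
#         print("box ("+str(x1)+","+str(y1)+")-("+str(x2)+","+str(y2)+") score "+str(score))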