-- nms.lua
-- Non-maximum suppression for scored bounding boxes stored in a Torch tensor.
--- Greedy non-maximum suppression over a set of scored boxes.
-- Each row of `boxes` is read as x1, y1, x2, y2 from columns 1-4, and the
-- score is taken from the last column. Rows are visited from highest to
-- lowest score: the visited row is kept, and every remaining row whose
-- intersection-over-union with it is greater than `overlap` is dropped.
-- Extents are computed as (max - min + 1).
-- @param boxes 2D Torch tensor with one box per row
-- @param overlap IoU threshold; remaining rows with IoU <= overlap survive
-- @return torch.LongTensor holding the kept row indices in visiting order
--   (an empty LongTensor when `boxes` has no elements)
local function nms(boxes, overlap)
  local keep = torch.LongTensor()
  if boxes:numel() == 0 then
    return keep
  end

  -- Column views into the input.
  local bx1 = boxes[{{}, 1}]
  local by1 = boxes[{{}, 2}]
  local bx2 = boxes[{{}, 3}]
  local by2 = boxes[{{}, 4}]
  local scores = boxes[{{}, -1}]

  -- Per-box area, built in two passes: the first stores the widths, the
  -- second multiplies the stored widths by the heights.
  local areas = boxes.new():resizeAs(scores):zero()
  areas:map2(bx2, bx1, function(_, hi, lo) return hi - lo + 1 end)
  areas:map2(by2, by1, function(acc, hi, lo) return acc * (hi - lo + 1) end)

  -- Ascending sort: the best-scoring candidate sits at the end of `order`.
  local _, order = scores:sort(1)

  keep:resize(scores:size()):zero()
  local n_kept = 0

  -- Scratch tensors reused on every iteration.
  local ix1 = boxes.new()
  local iy1 = boxes.new()
  local ix2 = boxes.new()
  local iy2 = boxes.new()
  local iw = boxes.new()
  local ih = boxes.new()

  -- Extent of an intersection along one axis, floored at zero when the
  -- two boxes do not overlap on that axis.
  local function clipped_extent(_, hi, lo)
    return math.max(hi - lo + 1, 0)
  end

  while order:numel() > 0 do
    local n = order:size(1)
    local best = order[n]
    n_kept = n_kept + 1
    keep[n_kept] = best
    if n == 1 then
      break
    end
    order = order[{{1, n - 1}}]

    -- Corners of the intersection between `best` and each remaining box.
    ix1:index(bx1, 1, order)
    ix1:cmax(bx1[best])
    ix2:index(bx2, 1, order)
    ix2:cmin(bx2[best])
    iy1:index(by1, 1, order)
    iy1:cmax(by1[best])
    iy2:index(by2, 1, order)
    iy2:cmin(by2[best])

    iw:resizeAs(ix2):zero()
    iw:map2(ix2, ix1, clipped_extent)
    ih:resizeAs(iy2):zero()
    ih:map2(iy2, iy1, clipped_extent)

    -- From here on `iw` holds the intersection area (width * height).
    iw:cmul(ih)

    -- `ix1` is recycled for the remaining boxes' areas and `ih` for the
    -- resulting IoU, so no extra buffers are needed.
    ix1:index(areas, 1, order)
    local union = ix1 + areas[best] - iw
    torch.cdiv(ih, iw, union)

    -- Only candidates that overlap `best` by at most `overlap` stay.
    order = order[ih:le(overlap)]
  end

  keep = keep[{{1, n_kept}}]
  return keep
end
-- Hand the nms function back as this chunk's return value.
return nms