-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathpy_nms.py
More file actions
130 lines (102 loc) · 4.11 KB
/
py_nms.py
File metadata and controls
130 lines (102 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
# Module-level logger named after this module; not referenced by the code
# visible in this file.
logger = logging.getLogger(__name__)
def py_greedy_nms(dets, iou_thr):
    """Run classic greedy non-maximum suppression in plain NumPy.

    Boxes are visited from highest to lowest score. Each visited box is
    kept, and every remaining box whose IoU with it exceeds `iou_thr` is
    discarded.

    Args:
        dets (numpy.array): Detections with shape `(num, 5)`; each row is
            [x1, y1, x2, y2, score].
        iou_thr (float): Remaining boxes whose IoU with the currently
            selected box is greater than this value are dropped.

    Returns:
        numpy.array: The kept rows of `dets`, ordered by descending score.
    """
    left, top, right, bottom, conf = (dets[:, col] for col in range(5))
    # Coordinates are treated as inclusive pixel indices, hence the +1.
    box_area = (right - left + 1) * (bottom - top + 1)

    # Candidate indices, best score first.
    order = conf.argsort()[::-1]
    selected = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        selected.append(best)

        # Intersection rectangle between the best box and every other one.
        inter_left = np.maximum(left[best], left[rest])
        inter_top = np.maximum(top[best], top[rest])
        inter_right = np.minimum(right[best], right[rest])
        inter_bottom = np.minimum(bottom[best], bottom[rest])

        # Clamp at zero so disjoint boxes contribute no overlap.
        overlap_w = np.maximum(inter_right - inter_left + 1, 0.0)
        overlap_h = np.maximum(inter_bottom - inter_top + 1, 0.0)
        overlap = overlap_w * overlap_h
        ratio = overlap / (box_area[best] + box_area[rest] - overlap)

        # Only candidates that do not overlap too much stay in the queue.
        order = rest[ratio <= iou_thr]

    return dets[selected, :]
def py_soft_nms(dets, method='linear', iou_thr=0.3, sigma=0.5, score_thr=0.001):
    """Pure python implementation of soft NMS as described in the paper
    `Improving Object Detection With One Line of Code`_.

    Instead of discarding every box that overlaps the current best box,
    the scores of overlapping boxes are decayed; a box is only removed
    once its score falls below `score_thr`.

    Args:
        dets (numpy.array): Detection results with shape `(num, 5)`,
            data in second dimension are [x1, y1, x2, y2, score]
            respectively. The input array is not modified.
        method (str): Rescore method. Must be `linear`, `gaussian`
            or `greedy`.
        iou_thr (float): IoU threshold. Only used when method is `linear`
            or `greedy`.
        sigma (float): Gaussian function parameter. Only used when method
            is `gaussian`.
        score_thr (float): Boxes whose score drops below this value after
            rescoring are discarded.

    Returns:
        numpy.array: Retained boxes with shape `(k, 5)`, in the order in
        which they were selected and carrying their rescored scores. An
        input without any boxes yields an empty `(0, 5)` array.

    Raises:
        ValueError: If `method` is not one of the supported names.

    .. _`Improving Object Detection With One Line of Code`:
        https://arxiv.org/abs/1704.04503
    """
    if method not in ('linear', 'gaussian', 'greedy'):
        raise ValueError('method must be linear, gaussian or greedy')

    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    # Coordinates are treated as inclusive pixel indices, hence the +1.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # expand dets with areas, and the second dimension is
    # x1, y1, x2, y2, score, area
    # (np.concatenate also gives us a private copy to rescore in place)
    dets = np.concatenate((dets, areas[:, None]), axis=1)

    retained_box = []
    while dets.size > 0:
        # Move the current highest-scoring box to row 0.
        max_idx = np.argmax(dets[:, 4], axis=0)
        dets[[0, max_idx], :] = dets[[max_idx, 0], :]
        # Drop the helper area column from the reported box.
        retained_box.append(dets[0, :-1])

        # Intersection of the selected box with all remaining boxes.
        xx1 = np.maximum(dets[0, 0], dets[1:, 0])
        yy1 = np.maximum(dets[0, 1], dets[1:, 1])
        xx2 = np.minimum(dets[0, 2], dets[1:, 2])
        yy2 = np.minimum(dets[0, 3], dets[1:, 3])

        # Clamp at zero so disjoint boxes contribute no overlap.
        w = np.maximum(xx2 - xx1 + 1, 0.0)
        h = np.maximum(yy2 - yy1 + 1, 0.0)
        inter = w * h
        iou = inter / (dets[0, 5] + dets[1:, 5] - inter)

        if method == 'linear':
            # Scale by (1 - iou) for boxes above the threshold.
            weight = np.ones_like(iou)
            weight[iou > iou_thr] -= iou[iou > iou_thr]
        elif method == 'gaussian':
            weight = np.exp(-(iou * iou) / sigma)
        else:  # traditional nms
            weight = np.ones_like(iou)
            weight[iou > iou_thr] = 0

        dets[1:, 4] *= weight
        # Fancy indexing copies, so rows already appended to retained_box
        # are not affected by later in-place swaps and rescoring.
        retained_idx = np.where(dets[1:, 4] >= score_thr)[0]
        dets = dets[retained_idx + 1, :]

    if not retained_box:
        # The loop never ran because there were no input boxes.
        # np.vstack raises on an empty list, so return an empty
        # (0, 5) result instead (area column stripped).
        return dets[:, :-1]
    return np.vstack(retained_box)
if __name__ == '__main__':
    # Demo run on six hand-written boxes; each row is
    # [x1, y1, x2, y2, score]. Rows 0 and 3 are identical.
    sample_dets = np.array(
        [
            [100, 100, 210, 210, 0.72],
            [250, 250, 420, 420, 0.8],
            [220, 220, 320, 330, 0.92],
            [100, 100, 210, 210, 0.72],
            [230, 240, 325, 330, 0.81],
            [220, 230, 315, 340, 0.9],
        ],
        dtype=np.float32,
    )

    print('greedy result:')
    greedy_kept = py_greedy_nms(sample_dets, 0.7)
    print(greedy_kept)

    print('soft nms result:')
    soft_kept = py_soft_nms(sample_dets, method='gaussian')
    print(soft_kept)