@@ -33,7 +33,6 @@ limitations under the License.
33
33
#include " tensorflow/core/platform/logging.h"
34
34
35
35
namespace tensorflow {
36
- namespace {
37
36
38
37
typedef Eigen::ThreadPoolDevice CPUDevice;
39
38
@@ -90,59 +89,6 @@ static inline float ComputeIOU(typename TTypes<float, 2>::ConstTensor boxes,
90
89
return intersection_area / (area_i + area_j - intersection_area);
91
90
}
92
91
93
- void DoNonMaxSuppressionOp (OpKernelContext* context, const Tensor& boxes,
94
- const Tensor& scores, const Tensor& max_output_size,
95
- const float iou_threshold) {
96
- OP_REQUIRES (context, iou_threshold >= 0 && iou_threshold <= 1 ,
97
- errors::InvalidArgument (" iou_threshold must be in [0, 1]" ));
98
-
99
- int num_boxes = 0 ;
100
- ParseAndCheckBoxSizes (context, boxes, scores, &num_boxes);
101
- if (!context->status ().ok ()) {
102
- return ;
103
- }
104
-
105
- const int output_size = std::min (max_output_size.scalar <int >()(), num_boxes);
106
- typename TTypes<float , 2 >::ConstTensor boxes_data = boxes.tensor <float , 2 >();
107
-
108
- std::vector<float > scores_data (num_boxes);
109
- std::copy_n (scores.flat <float >().data (), num_boxes, scores_data.begin ());
110
- std::vector<int > sorted_indices;
111
- DecreasingArgSort (scores_data, &sorted_indices);
112
-
113
- std::vector<bool > active (num_boxes, true );
114
- std::vector<int > selected;
115
- int num_active = active.size ();
116
- for (int i = 0 ; i < num_boxes; ++i) {
117
- if (num_active == 0 || selected.size () >= output_size) break ;
118
- if (active[i]) {
119
- selected.push_back (sorted_indices[i]);
120
- } else {
121
- continue ;
122
- }
123
- for (int j = i + 1 ; j < num_boxes; ++j) {
124
- if (active[j]) {
125
- float iou =
126
- ComputeIOU (boxes_data, sorted_indices[i], sorted_indices[j]);
127
- if (iou > iou_threshold) {
128
- active[j] = false ;
129
- num_active--;
130
- }
131
- }
132
- }
133
- }
134
-
135
- // Allocate output tensor
136
- Tensor* output = nullptr ;
137
- TensorShape output_shape ({static_cast <int >(selected.size ())});
138
- OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
139
- typename TTypes<int , 1 >::Tensor selected_indices_data =
140
- output->tensor <int , 1 >();
141
- std::copy_n (selected.begin (), selected.size (), selected_indices_data.data ());
142
- }
143
-
144
- } // namespace
145
-
146
92
template <typename Device>
147
93
class NonMaxSuppressionOp : public OpKernel {
148
94
public:
@@ -152,6 +98,9 @@ class NonMaxSuppressionOp : public OpKernel {
152
98
}
153
99
154
100
void Compute (OpKernelContext* context) override {
101
+ OP_REQUIRES (context, iou_threshold_ >= 0 && iou_threshold_ <= 1 ,
102
+ errors::InvalidArgument (" iou_threshold must be in [0, 1]" ));
103
+
155
104
// boxes: [num_boxes, 4]
156
105
const Tensor& boxes = context->input (0 );
157
106
// scores: [num_boxes]
@@ -163,48 +112,59 @@ class NonMaxSuppressionOp : public OpKernel {
163
112
errors::InvalidArgument (" max_output_size must be 0-D, got shape " ,
164
113
max_output_size.shape ().DebugString ()));
165
114
166
- DoNonMaxSuppressionOp (context, boxes, scores, max_output_size,
167
- iou_threshold_);
115
+ int num_boxes = 0 ;
116
+ ParseAndCheckBoxSizes (context, boxes, scores, &num_boxes);
117
+ if (!context->status ().ok ()) {
118
+ return ;
119
+ }
120
+
121
+ const int output_size =
122
+ std::min (max_output_size.scalar <int >()(), num_boxes);
123
+ typename TTypes<float , 2 >::ConstTensor boxes_data =
124
+ boxes.tensor <float , 2 >();
125
+
126
+ std::vector<float > scores_data (num_boxes);
127
+ std::copy_n (scores.flat <float >().data (), num_boxes, scores_data.begin ());
128
+ std::vector<int > sorted_indices;
129
+ DecreasingArgSort (scores_data, &sorted_indices);
130
+
131
+ std::vector<bool > active (num_boxes, true );
132
+ std::vector<int > selected;
133
+ int num_active = active.size ();
134
+ for (int i = 0 ; i < num_boxes; ++i) {
135
+ if (num_active == 0 || selected.size () >= output_size) break ;
136
+ if (active[i]) {
137
+ selected.push_back (sorted_indices[i]);
138
+ } else {
139
+ continue ;
140
+ }
141
+ for (int j = i + 1 ; j < num_boxes; ++j) {
142
+ if (active[j]) {
143
+ float iou =
144
+ ComputeIOU (boxes_data, sorted_indices[i], sorted_indices[j]);
145
+ if (iou > iou_threshold_) {
146
+ active[j] = false ;
147
+ num_active--;
148
+ }
149
+ }
150
+ }
151
+ }
152
+
153
+ // Allocate output tensor
154
+ Tensor* output = nullptr ;
155
+ TensorShape output_shape ({static_cast <int >(selected.size ())});
156
+ OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
157
+ typename TTypes<int , 1 >::Tensor selected_indices_data =
158
+ output->tensor <int , 1 >();
159
+ std::copy_n (selected.begin (), selected.size (),
160
+ selected_indices_data.data ());
168
161
}
169
162
170
163
private:
171
164
float iou_threshold_;
172
165
};
173
166
174
- template <typename Device>
175
- class NonMaxSuppressionV2Op : public OpKernel {
176
- public:
177
- explicit NonMaxSuppressionV2Op (OpKernelConstruction* context)
178
- : OpKernel(context) {}
179
-
180
- void Compute (OpKernelContext* context) override {
181
- // boxes: [num_boxes, 4]
182
- const Tensor& boxes = context->input (0 );
183
- // scores: [num_boxes]
184
- const Tensor& scores = context->input (1 );
185
- // max_output_size: scalar
186
- const Tensor& max_output_size = context->input (2 );
187
- OP_REQUIRES (
188
- context, TensorShapeUtils::IsScalar (max_output_size.shape ()),
189
- errors::InvalidArgument (" max_output_size must be 0-D, got shape " ,
190
- max_output_size.shape ().DebugString ()));
191
- // iou_threshold: scalar
192
- const Tensor& iou_threshold = context->input (3 );
193
- OP_REQUIRES (context, TensorShapeUtils::IsScalar (iou_threshold.shape ()),
194
- errors::InvalidArgument (" iou_threshold must be 0-D, got shape " ,
195
- iou_threshold.shape ().DebugString ()));
196
-
197
- const float iou_threshold_val = iou_threshold.scalar <float >()();
198
-
199
- DoNonMaxSuppressionOp (context, boxes, scores, max_output_size,
200
- iou_threshold_val);
201
- }
202
- };
203
-
204
167
REGISTER_KERNEL_BUILDER (Name(" NonMaxSuppression" ).Device(DEVICE_CPU),
205
168
NonMaxSuppressionOp<CPUDevice>);
206
169
207
- REGISTER_KERNEL_BUILDER (Name(" NonMaxSuppressionV2" ).Device(DEVICE_CPU),
208
- NonMaxSuppressionV2Op<CPUDevice>);
209
-
210
170
} // namespace tensorflow
0 commit comments