From e53af72a06789f18f1d317d5e5f2e2002b567f94 Mon Sep 17 00:00:00 2001 From: Yann HALLOUARD Date: Sun, 27 Oct 2024 20:31:57 +0100 Subject: [PATCH 1/4] feat: Add soft Non-Max suppression --- supervision/detection/core.py | 62 +++++ supervision/detection/overlap_filter.py | 196 ++++++++++++-- test/detection/test_overlap_filter.py | 335 ++++++++++++++++++++++++ 3 files changed, 570 insertions(+), 23 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 113948fc9..4731e6f85 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -19,7 +19,9 @@ from supervision.detection.overlap_filter import ( box_non_max_merge, box_non_max_suppression, + box_soft_non_max_suppression, mask_non_max_suppression, + mask_soft_non_max_suppression, ) from supervision.detection.tools.transformers import ( process_transformers_detection_result, @@ -1320,6 +1322,66 @@ def with_nms( return self[indices] + def with_soft_nms( + self, threshold: float = 0.5, class_agnostic: bool = False, sigma: float = 0.5 + ) -> Detections: + """ + Perform soft non-maximum suppression on the current set of object detections. + + Args: + threshold (float): The intersection-over-union threshold + to use for non-maximum suppression. Defaults to 0.5. + class_agnostic (bool): Whether to perform class-agnostic + non-maximum suppression. If True, the class_id of each detection + will be ignored. Defaults to False. + sigma (float): The sigma value to use for the soft non-maximum suppression + algorithm. Defaults to 0.5. + + Returns: + Detections: A new Detections object containing the subset of detections + after non-maximum suppression. + + Raises: + AssertionError: If `confidence` is None and class_agnostic is False. + """ + if len(self) == 0: + return self + + assert ( + self.confidence is not None + ), "Detections confidence must be given for NMS to be executed." + + if class_agnostic: + predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1))) + else: + assert self.class_id is not None, ( + "Detections class_id must be given for NMS to be executed. If you" + " intended to perform class agnostic NMS set class_agnostic=True." + ) + predictions = np.hstack( + ( + self.xyxy, + self.confidence.reshape(-1, 1), + self.class_id.reshape(-1, 1), + ) + ) + + if self.mask is not None: + soft_confidences = mask_soft_non_max_suppression( + predictions=predictions, + masks=self.mask, + iou_threshold=threshold, + sigma=sigma, + ) + self.confidence = soft_confidences + else: + indices, soft_confidences = box_soft_non_max_suppression( + predictions=predictions, iou_threshold=threshold, sigma=sigma + ) + self.confidence = soft_confidences + + return self + def with_nmm( self, threshold: float = 0.5, class_agnostic: bool = False ) -> Detections: diff --git a/supervision/detection/overlap_filter.py b/supervision/detection/overlap_filter.py index 4c59295f6..9739a709b 100644 --- a/supervision/detection/overlap_filter.py +++ b/supervision/detection/overlap_filter.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import List, Union +from typing import List, Tuple, Union import numpy as np import numpy.typing as npt @@ -38,6 +38,55 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray: return resized_masks +def __prepare_data_for_mask_nms( + iou_threshold: float, + mask_dimension: int, + masks: np.ndarray, + predictions: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: + """ + Get IOUs from mask. Prepare the data for non-max suppression. + + Args: + iou_threshold (float): The intersection-over-union threshold + to use for non-maximum suppression. + mask_dimension (int): The dimension to which the masks should be + resized before computing IOU values. + masks (np.ndarray): A 3D array of binary masks corresponding to the predictions. + Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the + dimensions of each + predictions (np.ndarray): An array of object detection predictions in the format + of `(x_min, y_min, x_max, y_max, score)` or + `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`, + where N is the number of predictions. + + Returns: + Tuple[np.ndarray, np.ndarray, int, np.ndarray]: A tuple containing the + predictions, categories, IOUs, number of rows, and the sorted indices. + + Raises: + AssertionError: If `iou_threshold` is not within the closed range from + `0` to `1`. + """ + assert 0 <= iou_threshold <= 1, ( + "Value of `iou_threshold` must be in the closed range from 0 to 1, " + f"{iou_threshold} given." + ) + rows, columns = predictions.shape + + if columns == 5: + predictions = np.c_[predictions, np.zeros(rows)] + + sort_index = predictions[:, 4].argsort()[::-1] + predictions = predictions[sort_index] + masks = masks[sort_index] + masks_resized = resize_masks(masks, mask_dimension) + ious = mask_iou_batch(masks_resized, masks_resized) + categories = predictions[:, 5] + + return predictions, categories, ious, rows, sort_index + + def mask_non_max_suppression( predictions: np.ndarray, masks: np.ndarray, @@ -68,21 +117,9 @@ def mask_non_max_suppression( AssertionError: If `iou_threshold` is not within the closed range from `0` to `1`. """ - assert 0 <= iou_threshold <= 1, ( - "Value of `iou_threshold` must be in the closed range from 0 to 1, " - f"{iou_threshold} given." + _, categories, ious, rows, sort_index = __prepare_data_for_mask_nms( + iou_threshold, mask_dimension, masks, predictions ) - rows, columns = predictions.shape - - if columns == 5: - predictions = np.c_[predictions, np.zeros(rows)] - - sort_index = predictions[:, 4].argsort()[::-1] - predictions = predictions[sort_index] - masks = masks[sort_index] - masks_resized = resize_masks(masks, mask_dimension) - ious = mask_iou_batch(masks_resized, masks_resized) - categories = predictions[:, 5] keep = np.ones(rows, dtype=bool) for i in range(rows): @@ -93,26 +130,73 @@ def mask_non_max_suppression( return keep[sort_index.argsort()] -def box_non_max_suppression( - predictions: np.ndarray, iou_threshold: float = 0.5 +def mask_soft_non_max_suppression( + predictions: np.ndarray, + masks: np.ndarray, + iou_threshold: float = 0.5, + mask_dimension: int = 640, + sigma: float = 0.5, ) -> np.ndarray: """ - Perform Non-Maximum Suppression (NMS) on object detection predictions. + Perform Soft Non-Maximum Suppression (Soft-NMS) on segmentation predictions. - Args: + Args: predictions (np.ndarray): An array of object detection predictions in the format of `(x_min, y_min, x_max, y_max, score)` or `(x_min, y_min, x_max, y_max, score, class)`. iou_threshold (float): The intersection-over-union threshold to use for non-maximum suppression. + sigma (float): The sigma value to use for soft non-maximum suppression. Returns: - np.ndarray: A boolean array indicating which predictions to keep after n - on-maximum suppression. + np.ndarray: An array containing the updated confidence scores. Raises: AssertionError: If `iou_threshold` is not within the closed range from `0` to `1`. + AssertionError: If `sigma` is not within the open range from `0` to `1`. + """ + assert ( + 0 < sigma < 1 + ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given." + predictions, categories, ious, rows, sort_index = __prepare_data_for_mask_nms( + iou_threshold, mask_dimension, masks, predictions + ) + + not_this_row = np.ones(rows) + for i in range(rows): + not_this_row[i] = 0 + condition = (categories[i] == categories) * not_this_row + predictions[:, 4] = predictions[:, 4] * np.exp( + -(ious[i] ** 2) / sigma * condition + ) + + return predictions[sort_index.argsort(), 4] + + +def __prepare_data_for_box_nsm( + iou_threshold: float, predictions: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: + """ + Prepare the data for non-max suppression. + + Args: + iou_threshold (float): The intersection-over-union threshold + to use for non-maximum suppression. + predictions (np.ndarray): An array of object detection predictions in the + format of `(x_min, y_min, x_max, y_max, score)` or + `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`, + where N is the number of predictions. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: A tuple containing + the predictions, categories, IOUs, number of rows, and the sorted indices + + Raises: + AssertionError: If `iou_threshold` is not within the closed range from `0` + to `1`. + + """ assert 0 <= iou_threshold <= 1, ( "Value of `iou_threshold` must be in the closed range from 0 to 1, " @@ -127,14 +211,40 @@ def box_non_max_suppression( # sort predictions column #4 - score sort_index = np.flip(predictions[:, 4].argsort()) predictions = predictions[sort_index] - boxes = predictions[:, :4] categories = predictions[:, 5] ious = box_iou_batch(boxes, boxes) ious = ious - np.eye(rows) - keep = np.ones(rows, dtype=bool) + return predictions, categories, ious, rows, sort_index + + +def box_non_max_suppression( + predictions: np.ndarray, iou_threshold: float = 0.5 +) -> np.ndarray: + """ + Perform Non-Maximum Suppression (NMS) on object detection predictions. + + Args: + predictions (np.ndarray): An array of object detection predictions in + the format of `(x_min, y_min, x_max, y_max, score)` + or `(x_min, y_min, x_max, y_max, score, class)`. + iou_threshold (float): The intersection-over-union threshold + to use for non-maximum suppression. + + Returns: + np.ndarray: A boolean array indicating which predictions to keep after n + on-maximum suppression. + + Raises: + AssertionError: If `iou_threshold` is not within the + closed range from `0` to `1`. + """ + _, categories, ious, rows, sort_index = __prepare_data_for_box_nsm( + iou_threshold, predictions + ) + keep = np.ones(rows, dtype=bool) for index, (iou, category) in enumerate(zip(ious, categories)): if not keep[index]: continue @@ -147,6 +257,46 @@ def box_non_max_suppression( return keep[sort_index.argsort()] +def box_soft_non_max_suppression( + predictions: np.ndarray, iou_threshold: float = 0.5, sigma: float = 0.5 +) -> np.ndarray: + """ + Perform Soft Non-Maximum Suppression (Soft-NMS) on object detection predictions. + + Args: + predictions (np.ndarray): An array of object detection predictions in + the format of `(x_min, y_min, x_max, y_max, score)` + or `(x_min, y_min, x_max, y_max, score, class)`. + iou_threshold (float): The intersection-over-union threshold + to use for soft non-maximum suppression. + sigma (float): The sigma value to use for soft non-maximum suppression. + + Returns: + np.ndarray: An array containing the updated confidence scores. + Raises: + AssertionError: If `iou_threshold` is not within the + closed range from `0` to `1`. + AssertionError: If `sigma` is not within the opened range from `0` to `1`. + """ + + assert ( + 0 < sigma < 1 + ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given." + predictions, categories, ious, rows, sort_index = __prepare_data_for_box_nsm( + iou_threshold, predictions + ) + + not_this_row = np.ones(rows) + for i in range(rows): + not_this_row[i] = 0 + condition = (categories[i] == categories) * not_this_row + predictions[:, 4] = predictions[:, 4] * np.exp( + -(ious[i] ** 2) / sigma * condition + ) + + return predictions[sort_index.argsort(), 4] + + def group_overlapping_boxes( predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5 ) -> List[List[int]]: diff --git a/test/detection/test_overlap_filter.py b/test/detection/test_overlap_filter.py index f628c30f9..df97e4f14 100644 --- a/test/detection/test_overlap_filter.py +++ b/test/detection/test_overlap_filter.py @@ -6,8 +6,10 @@ from supervision.detection.overlap_filter import ( box_non_max_suppression, + box_soft_non_max_suppression, group_overlapping_boxes, mask_non_max_suppression, + mask_soft_non_max_suppression, ) @@ -243,6 +245,121 @@ def test_box_non_max_suppression( assert np.array_equal(result, expected_result) +@pytest.mark.parametrize( + "predictions, iou_threshold, sigma, expected_result, exception", + [ + ( + np.empty(shape=(0, 5)), + 0.5, + 0.1, + np.array([]), + DoesNotRaise(), + ), # single box with no category + ( + np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]), + 0.5, + 0.8, + np.array([0.8]), + DoesNotRaise(), + ), # single box with no category + ( + np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]), + 0.5, + 0.9, + np.array([0.8]), + DoesNotRaise(), + ), # single box with category + ( + np.array( + [ + [10.0, 10.0, 40.0, 40.0, 0.8], + [15.0, 15.0, 40.0, 40.0, 0.9], + ] + ), + 0.5, + 0.2, + np.array([0.07176137, 0.9]), + DoesNotRaise(), + ), # two boxes with no category + ( + np.array( + [ + [10.0, 10.0, 40.0, 40.0, 0.8, 0], + [15.0, 15.0, 40.0, 40.0, 0.9, 1], + ] + ), + 0.5, + 0.3, + np.array([0.8, 0.9]), + DoesNotRaise(), + ), # two boxes with different category + ( + np.array( + [ + [10.0, 10.0, 40.0, 40.0, 0.8, 0], + [15.0, 15.0, 40.0, 40.0, 0.9, 0], + ] + ), + 0.5, + 0.9, + np.array([0.46814354, 0.9]), + DoesNotRaise(), + ), # two boxes with same category + ( + np.array( + [ + [0.0, 0.0, 30.0, 40.0, 0.8], + [5.0, 5.0, 35.0, 45.0, 0.9], + [10.0, 10.0, 40.0, 50.0, 0.85], + ] + ), + 0.5, + 0.7, + np.array([0.42648529, 0.9, 0.53109062]), + DoesNotRaise(), + ), # three boxes with no category + ( + np.array( + [ + [0.0, 0.0, 30.0, 40.0, 0.8, 0], + [5.0, 5.0, 35.0, 45.0, 0.9, 1], + [10.0, 10.0, 40.0, 50.0, 0.85, 2], + ] + ), + 0.5, + 0.5, + np.array([0.8, 0.9, 0.85]), + DoesNotRaise(), + ), # three boxes with same category + ( + np.array( + [ + [0.0, 0.0, 30.0, 40.0, 0.8, 0], + [5.0, 5.0, 35.0, 45.0, 0.9, 0], + [10.0, 10.0, 40.0, 50.0, 0.85, 1], + ] + ), + 0.5, + 0.9, + np.array([0.55491779, 0.9, 0.85]), + DoesNotRaise(), + ), # three boxes with different category + ], +) +def test_box_soft_non_max_suppression( + predictions: np.ndarray, + iou_threshold: float, + sigma: float, + expected_result: Optional[np.ndarray], + exception: Exception, +) -> None: + with exception: + result = box_soft_non_max_suppression( + predictions=predictions, iou_threshold=iou_threshold, sigma=sigma + ) + np.testing.assert_almost_equal(result, expected_result, decimal=5) + + @pytest.mark.parametrize( "predictions, masks, iou_threshold, expected_result, exception", [ @@ -447,3 +564,221 @@ def test_mask_non_max_suppression( predictions=predictions, masks=masks, iou_threshold=iou_threshold ) assert np.array_equal(result, expected_result) + + +@pytest.mark.parametrize( + "predictions, masks, iou_threshold, sigma, expected_result, exception", + [ + ( + np.empty((0, 6)), + np.empty((0, 5, 5)), + 0.5, + 0.1, + np.array([]), + DoesNotRaise(), + ), # empty predictions and masks + ( + np.array([[0, 0, 0, 0, 0.8]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ] + ] + ), + 0.5, + 0.2, + np.array([0.8]), + DoesNotRaise(), + ), # single mask with no category + ( + np.array([[0, 0, 0, 0, 0.8, 0]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ] + ] + ), + 0.5, + 1, + np.array([0.8]), + DoesNotRaise(), + ), # single mask with category + ( + np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, True, True], + [False, False, False, True, True], + [False, False, False, False, False], + ], + ] + ), + 0.5, + 0.8, + np.array([0.8, 0.9]), + DoesNotRaise(), + ), # two masks non-overlapping with no category + ( + np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, False, False, False], + ], + ] + ), + 0.4, + 0.6, + np.array([0.3831756, 0.9]), + DoesNotRaise(), + ), # two masks partially overlapping with no category + ( + np.array([[0, 0, 0, 0, 0.8, 0], [0, 0, 0, 0, 0.9, 1]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, False, False, False], + ], + ] + ), + 0.5, + 0.9, + np.array([0.8, 0.9]), + DoesNotRaise(), + ), # two masks partially overlapping with different category + ( + np.array( + [ + [0, 0, 0, 0, 0.8], + [0, 0, 0, 0, 0.85], + [0, 0, 0, 0, 0.9], + ] + ), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, True, True], + [False, False, False, True, True], + [False, False, False, False, False], + [False, False, False, False, False], + ], + ] + ), + 0.5, + 0.3, + np.array([0.02853919, 0.85, 0.9]), + DoesNotRaise(), + ), # three masks with no category + ( + np.array( + [ + [0, 0, 0, 0, 0.8, 0], + [0, 0, 0, 0, 0.85, 1], + [0, 0, 0, 0, 0.9, 2], + ] + ), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + ] + ), + 0.5, + 0.1, + np.array([0.8, 0.85, 0.9]), + DoesNotRaise(), + ), # three masks with different category + ], +) +def test_mask_soft_non_max_suppression( + predictions: np.ndarray, + masks: np.ndarray, + iou_threshold: float, + sigma: float, + expected_result: Optional[np.ndarray], + exception: Exception, +) -> None: + with exception: + result = mask_soft_non_max_suppression( + predictions=predictions, + masks=masks, + iou_threshold=iou_threshold, + sigma=sigma, + ) + np.testing.assert_almost_equal(result, expected_result, decimal=6) From 920c5c12724895134710e234d499d65c7f4527e7 Mon Sep 17 00:00:00 2001 From: Yann HALLOUARD Date: Sun, 27 Oct 2024 20:47:22 +0100 Subject: [PATCH 2/4] add doc --- docs/detection/double_detection_filter.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/detection/double_detection_filter.md b/docs/detection/double_detection_filter.md index b02663715..6cb2393fd 100644 --- a/docs/detection/double_detection_filter.md +++ b/docs/detection/double_detection_filter.md @@ -16,12 +16,25 @@ comments: true :::supervision.detection.overlap_filter.box_non_max_suppression + + +:::supervision.detection.overlap_filter.box_soft_non_max_suppression + + :::supervision.detection.overlap_filter.mask_non_max_suppression + + +:::supervision.detection.overlap_filter.mask_soft_non_max_suppression + From 821968841e8d73b965b12b0862af37b185f4da5a Mon Sep 17 00:00:00 2001 From: Yann HALLOUARD Date: Sun, 27 Oct 2024 20:52:49 +0100 Subject: [PATCH 3/4] fix tests --- test/detection/test_overlap_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/detection/test_overlap_filter.py b/test/detection/test_overlap_filter.py index df97e4f14..1a5fb05b4 100644 --- a/test/detection/test_overlap_filter.py +++ b/test/detection/test_overlap_filter.py @@ -609,7 +609,7 @@ def test_mask_non_max_suppression( ] ), 0.5, - 1, + 0.99, np.array([0.8]), DoesNotRaise(), ), # single mask with category From f2a97f703760924cc615ff6262a465a83cb2cddb Mon Sep 17 00:00:00 2001 From: Yann HALLOUARD Date: Sun, 3 Nov 2024 01:59:16 +0100 Subject: [PATCH 4/4] fix bugs --- supervision/detection/core.py | 13 ++++------ supervision/detection/overlap_filter.py | 34 ++++++++++--------------- test/detection/test_overlap_filter.py | 28 +++----------------- 3 files changed, 21 insertions(+), 54 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 4731e6f85..ab3ab348d 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -1323,19 +1323,17 @@ def with_nms( return self[indices] def with_soft_nms( - self, threshold: float = 0.5, class_agnostic: bool = False, sigma: float = 0.5 + self, sigma: float = 0.5, class_agnostic: bool = False ) -> Detections: """ Perform soft non-maximum suppression on the current set of object detections. Args: - threshold (float): The intersection-over-union threshold - to use for non-maximum suppression. Defaults to 0.5. + sigma (float): The sigma value to use for the soft non-maximum suppression + algorithm. Defaults to 0.5. class_agnostic (bool): Whether to perform class-agnostic non-maximum suppression. If True, the class_id of each detection will be ignored. Defaults to False. - sigma (float): The sigma value to use for the soft non-maximum suppression - algorithm. Defaults to 0.5. Returns: Detections: A new Detections object containing the subset of detections @@ -1370,13 +1368,12 @@ def with_soft_nms( soft_confidences = mask_soft_non_max_suppression( predictions=predictions, masks=self.mask, - iou_threshold=threshold, sigma=sigma, ) self.confidence = soft_confidences else: - indices, soft_confidences = box_soft_non_max_suppression( - predictions=predictions, iou_threshold=threshold, sigma=sigma + soft_confidences = box_soft_non_max_suppression( + predictions=predictions, sigma=sigma ) self.confidence = soft_confidences diff --git a/supervision/detection/overlap_filter.py b/supervision/detection/overlap_filter.py index 9739a709b..a7ef40c19 100644 --- a/supervision/detection/overlap_filter.py +++ b/supervision/detection/overlap_filter.py @@ -39,7 +39,6 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray: def __prepare_data_for_mask_nms( - iou_threshold: float, mask_dimension: int, masks: np.ndarray, predictions: np.ndarray, @@ -48,8 +47,6 @@ def __prepare_data_for_mask_nms( Get IOUs from mask. Prepare the data for non-max suppression. Args: - iou_threshold (float): The intersection-over-union threshold - to use for non-maximum suppression. mask_dimension (int): The dimension to which the masks should be resized before computing IOU values. masks (np.ndarray): A 3D array of binary masks corresponding to the predictions. @@ -68,10 +65,6 @@ def __prepare_data_for_mask_nms( AssertionError: If `iou_threshold` is not within the closed range from `0` to `1`. """ - assert 0 <= iou_threshold <= 1, ( - "Value of `iou_threshold` must be in the closed range from 0 to 1, " - f"{iou_threshold} given." - ) rows, columns = predictions.shape if columns == 5: @@ -117,8 +110,12 @@ def mask_non_max_suppression( AssertionError: If `iou_threshold` is not within the closed range from `0` to `1`. """ + assert 0 <= iou_threshold <= 1, ( + "Value of `iou_threshold` must be in the closed range from 0 to 1, " + f"{iou_threshold} given." + ) _, categories, ious, rows, sort_index = __prepare_data_for_mask_nms( - iou_threshold, mask_dimension, masks, predictions + mask_dimension, masks, predictions ) keep = np.ones(rows, dtype=bool) @@ -133,7 +130,6 @@ def mask_non_max_suppression( def mask_soft_non_max_suppression( predictions: np.ndarray, masks: np.ndarray, - iou_threshold: float = 0.5, mask_dimension: int = 640, sigma: float = 0.5, ) -> np.ndarray: @@ -160,7 +156,7 @@ def mask_soft_non_max_suppression( 0 < sigma < 1 ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given." predictions, categories, ious, rows, sort_index = __prepare_data_for_mask_nms( - iou_threshold, mask_dimension, masks, predictions + mask_dimension, masks, predictions ) not_this_row = np.ones(rows) @@ -175,14 +171,12 @@ def mask_soft_non_max_suppression( def __prepare_data_for_box_nsm( - iou_threshold: float, predictions: np.ndarray + predictions: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: """ Prepare the data for non-max suppression. Args: - iou_threshold (float): The intersection-over-union threshold - to use for non-maximum suppression. predictions (np.ndarray): An array of object detection predictions in the format of `(x_min, y_min, x_max, y_max, score)` or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`, @@ -198,10 +192,6 @@ def __prepare_data_for_box_nsm( """ - assert 0 <= iou_threshold <= 1, ( - "Value of `iou_threshold` must be in the closed range from 0 to 1, " - f"{iou_threshold} given." - ) rows, columns = predictions.shape # add column #5 - category filled with zeros for agnostic nms @@ -240,9 +230,11 @@ def box_non_max_suppression( AssertionError: If `iou_threshold` is not within the closed range from `0` to `1`. """ - _, categories, ious, rows, sort_index = __prepare_data_for_box_nsm( - iou_threshold, predictions + assert 0 <= iou_threshold <= 1, ( + "Value of `iou_threshold` must be in the closed range from 0 to 1, " + f"{iou_threshold} given." ) + _, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(predictions) keep = np.ones(rows, dtype=bool) for index, (iou, category) in enumerate(zip(ious, categories)): @@ -258,7 +250,7 @@ def box_non_max_suppression( def box_soft_non_max_suppression( - predictions: np.ndarray, iou_threshold: float = 0.5, sigma: float = 0.5 + predictions: np.ndarray, sigma: float = 0.5 ) -> np.ndarray: """ Perform Soft Non-Maximum Suppression (Soft-NMS) on object detection predictions. @@ -283,7 +275,7 @@ def box_soft_non_max_suppression( 0 < sigma < 1 ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given." predictions, categories, ious, rows, sort_index = __prepare_data_for_box_nsm( - iou_threshold, predictions + predictions ) not_this_row = np.ones(rows) diff --git a/test/detection/test_overlap_filter.py b/test/detection/test_overlap_filter.py index 1a5fb05b4..6b0df77a4 100644 --- a/test/detection/test_overlap_filter.py +++ b/test/detection/test_overlap_filter.py @@ -246,25 +246,22 @@ def test_box_non_max_suppression( @pytest.mark.parametrize( - "predictions, iou_threshold, sigma, expected_result, exception", + "predictions, sigma, expected_result, exception", [ ( np.empty(shape=(0, 5)), - 0.5, 0.1, np.array([]), DoesNotRaise(), ), # single box with no category ( np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]), - 0.5, 0.8, np.array([0.8]), DoesNotRaise(), ), # single box with no category ( np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]), - 0.5, 0.9, np.array([0.8]), DoesNotRaise(), @@ -276,7 +273,6 @@ def test_box_non_max_suppression( [15.0, 15.0, 40.0, 40.0, 0.9], ] ), - 0.5, 0.2, np.array([0.07176137, 0.9]), DoesNotRaise(), @@ -288,7 +284,6 @@ def test_box_non_max_suppression( [15.0, 15.0, 40.0, 40.0, 0.9, 1], ] ), - 0.5, 0.3, np.array([0.8, 0.9]), DoesNotRaise(), @@ -300,7 +295,6 @@ def test_box_non_max_suppression( [15.0, 15.0, 40.0, 40.0, 0.9, 0], ] ), - 0.5, 0.9, np.array([0.46814354, 0.9]), DoesNotRaise(), @@ -313,7 +307,6 @@ def test_box_non_max_suppression( [10.0, 10.0, 40.0, 50.0, 0.85], ] ), - 0.5, 0.7, np.array([0.42648529, 0.9, 0.53109062]), DoesNotRaise(), @@ -327,7 +320,6 @@ def test_box_non_max_suppression( ] ), 0.5, - 0.5, np.array([0.8, 0.9, 0.85]), DoesNotRaise(), ), # three boxes with same category @@ -339,7 +331,6 @@ def test_box_non_max_suppression( [10.0, 10.0, 40.0, 50.0, 0.85, 1], ] ), - 0.5, 0.9, np.array([0.55491779, 0.9, 0.85]), DoesNotRaise(), @@ -348,15 +339,12 @@ def test_box_non_max_suppression( ) def test_box_soft_non_max_suppression( predictions: np.ndarray, - iou_threshold: float, sigma: float, expected_result: Optional[np.ndarray], exception: Exception, ) -> None: with exception: - result = box_soft_non_max_suppression( - predictions=predictions, iou_threshold=iou_threshold, sigma=sigma - ) + result = box_soft_non_max_suppression(predictions=predictions, sigma=sigma) np.testing.assert_almost_equal(result, expected_result, decimal=5) @@ -567,12 +555,11 @@ def test_mask_non_max_suppression( @pytest.mark.parametrize( - "predictions, masks, iou_threshold, sigma, expected_result, exception", + "predictions, masks, sigma, expected_result, exception", [ ( np.empty((0, 6)), np.empty((0, 5, 5)), - 0.5, 0.1, np.array([]), DoesNotRaise(), @@ -590,7 +577,6 @@ def test_mask_non_max_suppression( ] ] ), - 0.5, 0.2, np.array([0.8]), DoesNotRaise(), @@ -608,7 +594,6 @@ def test_mask_non_max_suppression( ] ] ), - 0.5, 0.99, np.array([0.8]), DoesNotRaise(), @@ -633,7 +618,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.5, 0.8, np.array([0.8, 0.9]), DoesNotRaise(), @@ -658,7 +642,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.4, 0.6, np.array([0.3831756, 0.9]), DoesNotRaise(), @@ -683,7 +666,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.5, 0.9, np.array([0.8, 0.9]), DoesNotRaise(), @@ -721,7 +703,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.5, 0.3, np.array([0.02853919, 0.85, 0.9]), DoesNotRaise(), @@ -759,7 +740,6 @@ def test_mask_non_max_suppression( ], ] ), - 0.5, 0.1, np.array([0.8, 0.85, 0.9]), DoesNotRaise(), @@ -769,7 +749,6 @@ def test_mask_non_max_suppression( def test_mask_soft_non_max_suppression( predictions: np.ndarray, masks: np.ndarray, - iou_threshold: float, sigma: float, expected_result: Optional[np.ndarray], exception: Exception, @@ -778,7 +757,6 @@ def test_mask_soft_non_max_suppression( result = mask_soft_non_max_suppression( predictions=predictions, masks=masks, - iou_threshold=iou_threshold, sigma=sigma, ) np.testing.assert_almost_equal(result, expected_result, decimal=6)