diff --git a/yoeo/utils/utils.py b/yoeo/utils/utils.py index afa0790..a11bbe8 100644 --- a/yoeo/utils/utils.py +++ b/yoeo/utils/utils.py @@ -437,6 +437,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non t = time.time() output = [torch.zeros((0, 6), device="cpu")] * prediction.shape[0] + if robot_class_ids: + robot_class_ids = torch.tensor(robot_class_ids, device=prediction.device, dtype=prediction.dtype) for xi, x in enumerate(prediction): # image index, image inference # Apply constraints @@ -474,7 +476,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non x = x[x[:, 4].argsort(descending=True)[:max_nms]] # Batched NMS - if not robot_class_ids: + if robot_class_ids is None: c = x[:, 5:6] * max_wh # classes else: # If multiple robot classes are present, all robot classes are treated as one class in order to perform