From eb6bd857df945465641662cf7edd07dbe1763623 Mon Sep 17 00:00:00 2001 From: Philipp Donn <30521025+phinik@users.noreply.github.com> Date: Thu, 13 Apr 2023 20:50:26 +0200 Subject: [PATCH] fix torch.isin(..) call --- yoeo/utils/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/yoeo/utils/utils.py b/yoeo/utils/utils.py index afa0790..a11bbe8 100644 --- a/yoeo/utils/utils.py +++ b/yoeo/utils/utils.py @@ -437,6 +437,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non t = time.time() output = [torch.zeros((0, 6), device="cpu")] * prediction.shape[0] + if robot_class_ids: + robot_class_ids = torch.tensor(robot_class_ids, device=prediction.device, dtype=prediction.dtype) for xi, x in enumerate(prediction): # image index, image inference # Apply constraints @@ -474,7 +476,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non x = x[x[:, 4].argsort(descending=True)[:max_nms]] # Batched NMS - if not robot_class_ids: + if robot_class_ids is None: c = x[:, 5:6] * max_wh # classes else: # If multiple robot classes are present, all robot classes are treated as one class in order to perform