Skip to content

Commit

Permalink
fix torch.isin(..) call
Browse files Browse the repository at this point in the history
  • Loading branch information
phinik committed Apr 13, 2023
1 parent 9f0589d commit eb6bd85
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion yoeo/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non

t = time.time()
output = [torch.zeros((0, 6), device="cpu")] * prediction.shape[0]
if robot_class_ids:
robot_class_ids = torch.tensor(robot_class_ids, device=prediction.device, dtype=prediction.dtype)

for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
Expand Down Expand Up @@ -474,7 +476,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
x = x[x[:, 4].argsort(descending=True)[:max_nms]]

# Batched NMS
if not robot_class_ids:
if robot_class_ids is None:
c = x[:, 5:6] * max_wh # classes
else:
# If multiple robot classes are present, all robot classes are treated as one class in order to perform
Expand Down

0 comments on commit eb6bd85

Please sign in to comment.