Skip to content

Commit

Permalink
add bbox info
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrii-Sheba committed Dec 2, 2024
1 parent 57c9776 commit 3826c29
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions tools/inference_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ def parse_args():
return parser.parse_args()


def draw_bboxes(image, bboxes):
"""Draw bounding boxes on the image using OpenCV."""
for bbox in bboxes:
x, y, x2, y2 = map(int, bbox)
cv2.rectangle(image, (x, y), (x2, y2), (255, 0, 0), 2) # Draw rectangle


def draw_keypoints(image, keypoints, scores, score_thr):
"""Draw keypoints on the image using OpenCV."""
for el1, el2 in zip(keypoints, scores):
Expand Down Expand Up @@ -78,7 +85,7 @@ def main():
model,
image_rgb,
person_bboxes,
bbox_format='xywh' # COCO annotations typically use 'xywh' format
bbox_format='xyxy' # COCO annotations typically use 'xywh' format
)

keypoints_results = []
Expand All @@ -87,19 +94,25 @@ def main():
if pred_instances is not None:
keypoints = pred_instances.keypoints
scores = pred_instances.keypoint_scores
bbox = person_bboxes[0] # Taking first bbox, adjust for multiple detections

# Save the bbox along with the keypoints data
keypoints_results.append({
'keypoints': keypoints.tolist(),
'scores': scores.tolist()
'scores': scores.tolist(),
'bbox': bbox.tolist() # Add bbox to the result
})

if args.out_dir:
draw_keypoints(image, keypoints, scores, args.score_thr)
draw_bboxes(image, person_bboxes)

# Save the visualized image if `out-dir` is provided
if args.out_dir:
out_file = os.path.join(args.out_dir, img_info['file_name'])
cv2.imwrite(out_file, image)

# Save individual prediction file
prediction_file = os.path.join(args.predictions_dir, f"{os.path.splitext(img_info['file_name'])[0]}.json")
with open(prediction_file, 'w') as f:
json.dump({
Expand Down

0 comments on commit 3826c29

Please sign in to comment.