Skip to content

Commit

Permalink
add inference script and usage examples
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrii-Sheba committed Nov 29, 2024
1 parent 4f6275a commit 9b263ac
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tools/inference.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Example 1: run inference, saving both visualized images (--out-dir)
# and the predicted keypoints as JSON (--output-file).
python tools/inference_custom.py \
/data/new_mmpose/mmpose/work_dirs/ls_1704_res18/td-hm_res50_8xb64-210e_coco-256x192.py \
/data/new_mmpose/mmpose/work_dirs/ls_1704_res18/epoch_300.pth \
--img-dir /data/new_mmpose/mmpose/data/1704_split_exported_data_project_id_422/val2017 \
--out-dir /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/out \
--bbox-json /data/new_mmpose/mmpose/data/1704_split_exported_data_project_id_422/annotations/forklift_keypoints_val2017.json \
--output-file /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/results.json

# Example 2: same as above but without --out-dir, so no visualizations
# are written — only the keypoint JSON.
python tools/inference_custom.py \
/data/new_mmpose/mmpose/work_dirs/ls_1704_res18/td-hm_res50_8xb64-210e_coco-256x192.py \
/data/new_mmpose/mmpose/work_dirs/ls_1704_res18/epoch_300.pth \
--img-dir /data/new_mmpose/mmpose/data/1704_split_exported_data_project_id_422/val2017 \
--bbox-json /data/new_mmpose/mmpose/data/1704_split_exported_data_project_id_422/annotations/forklift_keypoints_val2017.json \
--output-file /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/results.json
131 changes: 131 additions & 0 deletions tools/inference_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import argparse
import os
import json
import numpy as np
import cv2
from mmengine.config import Config, DictAction
from mmpose.apis import init_model, inference_topdown
from xtcocotools.coco import COCO
from mmengine.utils import ProgressBar


def parse_args():
    """Build and parse the command-line interface for this script."""
    p = argparse.ArgumentParser(description="Run MMPose inference on images")
    p.add_argument('config', help='Path to model config file')
    p.add_argument('model', help='Path to checkpoint or ONNX file')
    p.add_argument('--img-dir', type=str, required=True,
                   help='Directory with input images')
    p.add_argument('--bbox-json', type=str, required=True,
                   help='Path to COCO format bounding box JSON')
    p.add_argument('--out-dir', type=str,
                   help='Directory to save visualized results (optional)')
    p.add_argument('--output-file', type=str,
                   help='File to save keypoint results in JSON')
    p.add_argument('--device', default='cuda:0',
                   help='Device to run inference on (e.g., "cuda:0" or "cpu")')
    p.add_argument('--score-thr', type=float, default=0.3,
                   help='Keypoint score threshold')
    # DictAction parses repeated KEY=VALUE tokens into a dict for config merge.
    p.add_argument('--cfg-options', nargs='+', action=DictAction,
                   help='Override some settings in the config file. The key-value pair in '
                        'xxx=yyy format will be merged into the config.')
    return p.parse_args()


def draw_keypoints(image, keypoints, scores, score_thr):
    """Draw keypoints on the image using OpenCV.

    Each keypoint whose score exceeds ``score_thr`` is rendered in place on
    ``image`` as a filled green dot with its confidence percentage above it.
    """
    for instance_kpts, instance_scores in zip(keypoints, scores):
        for point, confidence in zip(instance_kpts, instance_scores):
            # Skip low-confidence predictions.
            if confidence <= score_thr:
                continue
            x, y = int(point[0]), int(point[1])
            label = f"{int(confidence*100)}%"
            cv2.circle(image, (x, y), 5, (0, 255, 0), -1)  # Draw keypoint
            cv2.putText(image, label, (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)


def main():
    """Run top-down pose inference over a COCO-annotated image directory.

    For every image referenced in ``--bbox-json``, runs the pose model on the
    annotated bounding boxes, optionally writes a visualization per image to
    ``--out-dir``, and collects the predicted keypoints into ``--output-file``
    as a JSON list of per-image records.
    """
    args = parse_args()

    # Load the COCO bounding boxes
    coco = COCO(args.bbox_json)
    img_ids = list(coco.imgs.keys())

    # Initialize the model
    cfg = Config.fromfile(args.config)
    if args.cfg_options:
        cfg.merge_from_dict(args.cfg_options)
    model = init_model(cfg, args.model, device=args.device)

    # Ensure output directory exists if `out-dir` is provided
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # Per-image results to be saved
    results = []
    progress_bar = ProgressBar(len(img_ids))

    for img_id in img_ids:
        img_info = coco.loadImgs([img_id])[0]
        img_path = os.path.join(args.img_dir, img_info['file_name'])

        if not os.path.exists(img_path):
            print(f"Image not found: {img_path}")
            progress_bar.update()
            continue

        # Load the image (BGR); the model expects RGB input.
        image = cv2.imread(img_path)
        if image is None:
            print(f"Failed to read image: {img_path}")
            progress_bar.update()
            continue
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load bounding boxes for the image
        ann_ids = coco.getAnnIds(imgIds=[img_id])
        annotations = coco.loadAnns(ann_ids)
        person_bboxes = np.array([ann['bbox'] for ann in annotations])

        keypoints_results = []
        # ROBUSTNESS: skip inference for images with no annotations instead of
        # passing an empty bbox array to the model.
        if len(person_bboxes) > 0:
            pose_results = inference_topdown(
                model,
                image_rgb,
                person_bboxes,
                bbox_format='xywh'  # COCO annotations typically use 'xywh' format
            )

            # Extract keypoints from each PoseDataSample and draw them.
            for pose in pose_results:
                pred_instances = pose.pred_instances
                if pred_instances is None:
                    continue
                keypoints = pred_instances.keypoints
                scores = pred_instances.keypoint_scores
                keypoints_results.append({
                    'keypoints': keypoints.tolist(),
                    'scores': scores.tolist()
                })
                # BUG FIX: draw inside the loop so every pose is visualized.
                # The original called draw_keypoints after the loop with the
                # leftover `keypoints`/`scores` of the LAST pose only, and
                # raised NameError when pose_results was empty.
                if args.out_dir:
                    draw_keypoints(image, keypoints, scores, args.score_thr)

        # Save the visualized image if `out-dir` is provided
        if args.out_dir:
            out_file = os.path.join(args.out_dir, img_info['file_name'])
            # file_name may contain subdirectories; make sure they exist.
            os.makedirs(os.path.dirname(out_file) or '.', exist_ok=True)
            cv2.imwrite(out_file, image)

        results.append({
            'image_id': img_id,
            'file_name': img_info['file_name'],
            'keypoints': keypoints_results
        })

        progress_bar.update()

    # Save results to output file
    if args.output_file:
        with open(args.output_file, 'w') as f:
            json.dump(results, f, indent=4)
        # BUG FIX: the original message referenced args.out_dir, but the
        # results JSON is written to args.output_file.
        print(f"Inference completed. Results saved to {args.output_file}.")
    else:
        print("Inference completed.")


# Script entry point: run inference only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()

0 comments on commit 9b263ac

Please sign in to comment.