forked from open-mmlab/mmpose
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add inference script and usage examples
- Loading branch information
1 parent
4f6275a
commit 9b263ac
Showing
2 changed files
with
145 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
python tools/inference_custom.py \
    /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/td-hm_res50_8xb64-210e_coco-256x192.py \
    /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/epoch_300.pth \
    --img-dir /data/new_mmpose/mmpose/data/1704_split_exported_data_project_id_422/val2017 \
    --out-dir /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/out \
    --bbox-json /data/new_mmpose/mmpose/data/1704_split_exported_data_project_id_422/annotations/forklift_keypoints_val2017.json \
    --output-file /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/results.json

python tools/inference_custom.py \
    /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/td-hm_res50_8xb64-210e_coco-256x192.py \
    /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/epoch_300.pth \
    --img-dir /data/new_mmpose/mmpose/data/1704_split_exported_data_project_id_422/val2017 \
    --bbox-json /data/new_mmpose/mmpose/data/1704_split_exported_data_project_id_422/annotations/forklift_keypoints_val2017.json \
    --output-file /data/new_mmpose/mmpose/work_dirs/ls_1704_res18/results.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import argparse | ||
import os | ||
import json | ||
import numpy as np | ||
import cv2 | ||
from mmengine.config import Config, DictAction | ||
from mmpose.apis import init_model, inference_topdown | ||
from xtcocotools.coco import COCO | ||
from mmengine.utils import ProgressBar | ||
|
||
|
||
def parse_args():
    """Parse command-line options for the custom inference script."""
    parser = argparse.ArgumentParser(description="Run MMPose inference on images")
    # Positional arguments: model definition and weights.
    parser.add_argument('config', help='Path to model config file')
    parser.add_argument('model', help='Path to checkpoint or ONNX file')
    # Required inputs.
    parser.add_argument('--img-dir', type=str, required=True,
                        help='Directory with input images')
    parser.add_argument('--bbox-json', type=str, required=True,
                        help='Path to COCO format bounding box JSON')
    # Optional outputs.
    parser.add_argument('--out-dir', type=str,
                        help='Directory to save visualized results (optional)')
    parser.add_argument('--output-file', type=str,
                        help='File to save keypoint results in JSON')
    # Runtime knobs.
    parser.add_argument('--device', default='cuda:0',
                        help='Device to run inference on (e.g., "cuda:0" or "cpu")')
    parser.add_argument('--score-thr', type=float, default=0.3,
                        help='Keypoint score threshold')
    parser.add_argument(
        '--cfg-options', nargs='+', action=DictAction,
        help='Override some settings in the config file. The key-value pair in '
        'xxx=yyy format will be merged into the config.')
    return parser.parse_args()
|
||
|
||
def draw_keypoints(image, keypoints, scores, score_thr):
    """Overlay predicted keypoints on ``image`` in place using OpenCV.

    Args:
        image: The image (OpenCV/BGR array) to draw on; modified in place.
        keypoints: Per-instance sequences of (x, y) keypoint coordinates.
        scores: Per-instance sequences of confidence scores, aligned with
            ``keypoints``.
        score_thr: Only keypoints whose score is strictly greater than this
            threshold are drawn.
    """
    for inst_kpts, inst_scores in zip(keypoints, scores):
        for kp, score in zip(inst_kpts, inst_scores):
            if score <= score_thr:
                continue  # skip low-confidence points
            x, y = int(kp[0]), int(kp[1])
            # Green dot for the keypoint, confidence percentage just above it.
            cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
            cv2.putText(image, f"{int(score*100)}%", (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||
|
||
def main():
    """Run top-down pose inference over a COCO-annotated image directory.

    Loads bounding boxes from the COCO JSON given on the command line, runs
    keypoint inference for every image, optionally writes visualizations to
    ``--out-dir``, and optionally dumps all predictions to ``--output-file``.
    """
    args = parse_args()

    # Load the COCO bounding boxes.
    coco = COCO(args.bbox_json)
    img_ids = list(coco.imgs.keys())

    # Initialize the model, applying any CLI config overrides first.
    cfg = Config.fromfile(args.config)
    if args.cfg_options:
        cfg.merge_from_dict(args.cfg_options)
    model = init_model(cfg, args.model, device=args.device)

    # Ensure the visualization directory exists if requested.
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # Per-image results to be saved.
    results = []

    progress_bar = ProgressBar(len(img_ids))

    for img_id in img_ids:
        img_info = coco.loadImgs([img_id])[0]
        img_path = os.path.join(args.img_dir, img_info['file_name'])

        if not os.path.exists(img_path):
            print(f"Image not found: {img_path}")
            progress_bar.update()
            continue

        image = cv2.imread(img_path)
        if image is None:
            print(f"Failed to read image: {img_path}")
            progress_bar.update()
            continue

        # Inference expects RGB; OpenCV loads BGR.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Bounding boxes for this image (COCO annotations use 'xywh').
        ann_ids = coco.getAnnIds(imgIds=[img_id])
        annotations = coco.loadAnns(ann_ids)
        person_bboxes = np.array([ann['bbox'] for ann in annotations])

        # Skip images with no annotated boxes rather than handing an empty
        # bbox array to inference_topdown.
        if person_bboxes.size == 0:
            progress_bar.update()
            continue

        # Run pose inference.
        pose_results = inference_topdown(
            model,
            image_rgb,
            person_bboxes,
            bbox_format='xywh')

        # Extract keypoints per pose and draw each pose on the (BGR) image.
        # BUG FIX: drawing previously happened once after this loop using the
        # variables left over from the LAST iteration, so only the final pose
        # was visualized (and a NameError was raised when pose_results was
        # empty). Drawing now happens per pose, inside the loop.
        keypoints_results = []
        for pose in pose_results:
            pred_instances = pose.pred_instances
            if pred_instances is None:
                continue
            keypoints = pred_instances.keypoints
            scores = pred_instances.keypoint_scores
            keypoints_results.append({
                'keypoints': keypoints.tolist(),
                'scores': scores.tolist()
            })
            draw_keypoints(image, keypoints, scores, args.score_thr)

        # Save the visualized image if `--out-dir` is provided.
        if args.out_dir:
            out_file = os.path.join(args.out_dir, img_info['file_name'])
            cv2.imwrite(out_file, image)

        # Record keypoints for this image.
        results.append({
            'image_id': img_id,
            'file_name': img_info['file_name'],
            'keypoints': keypoints_results
        })

        progress_bar.update()

    # Persist all predictions if requested.
    # BUG FIX: the completion message previously referenced args.out_dir
    # inside this branch and printed "Results ." when no out-dir was set.
    if args.output_file:
        with open(args.output_file, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"Inference completed. Results saved to {args.output_file}.")
    else:
        print("Inference completed.")


if __name__ == '__main__':
    main()