forked from open-mmlab/mmpose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: inference_custom.py
131 lines (106 loc) · 4.67 KB
/
inference_custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import os
import json
import numpy as np
import cv2
from mmengine.config import Config, DictAction
from mmpose.apis import init_model, inference_topdown
from xtcocotools.coco import COCO
from mmengine.utils import ProgressBar
def parse_args():
    """Build and parse the command-line interface for the inference script.

    Returns:
        argparse.Namespace: parsed arguments (config, model, img-dir,
        bbox-json, optional output paths, device, score threshold, and
        free-form config overrides).
    """
    ap = argparse.ArgumentParser(description="Run MMPose inference on images")

    # Positional arguments: model definition and weights.
    ap.add_argument('config', help='Path to model config file')
    ap.add_argument('model', help='Path to checkpoint or ONNX file')

    # Required inputs.
    ap.add_argument('--img-dir', type=str, required=True, help='Directory with input images')
    ap.add_argument('--bbox-json', type=str, required=True, help='Path to COCO format bounding box JSON')

    # Optional outputs and runtime knobs.
    ap.add_argument('--out-dir', type=str, help='Directory to save visualized results (optional)')
    ap.add_argument('--output-file', type=str, help='File to save keypoint results in JSON')
    ap.add_argument('--device', default='cuda:0', help='Device to run inference on (e.g., "cuda:0" or "cpu")')
    ap.add_argument('--score-thr', type=float, default=0.3, help='Keypoint score threshold')
    ap.add_argument(
        '--cfg-options', nargs='+', action=DictAction,
        help='Override some settings in the config file. The key-value pair in '
        'xxx=yyy format will be merged into the config.')

    return ap.parse_args()
def draw_keypoints(image, keypoints, scores, score_thr):
    """Overlay predicted keypoints and their confidences on *image* in place.

    Args:
        image: BGR image buffer (mutated in place via OpenCV draw calls).
        keypoints: per-instance sequences of (x, y) keypoint coordinates.
        scores: per-instance sequences of keypoint confidence scores,
            aligned one-to-one with ``keypoints``.
        score_thr: minimum confidence for a keypoint to be drawn.
    """
    for inst_kpts, inst_scores in zip(keypoints, scores):
        for kp, conf in zip(inst_kpts, inst_scores):
            # Skip low-confidence detections entirely (no marker, no label).
            if conf <= score_thr:
                continue
            x_px, y_px = int(kp[0]), int(kp[1])
            # Green dot at the keypoint, confidence (percent) just above it.
            cv2.circle(image, (x_px, y_px), 5, (0, 255, 0), -1)
            cv2.putText(image, f"{int(conf*100)}%", (x_px, y_px - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
def main():
    """Run top-down pose inference over a directory of COCO-annotated images.

    Reads person bounding boxes from a COCO JSON, runs the configured MMPose
    model per image, optionally writes keypoint-annotated visualizations to
    ``--out-dir``, and optionally dumps all keypoint predictions as JSON to
    ``--output-file``.
    """
    args = parse_args()

    # Load the COCO bounding boxes.
    coco = COCO(args.bbox_json)
    img_ids = list(coco.imgs.keys())

    # Initialize the model from config (plus optional CLI overrides).
    cfg = Config.fromfile(args.config)
    if args.cfg_options:
        cfg.merge_from_dict(args.cfg_options)
    model = init_model(cfg, args.model, device=args.device)

    # Ensure output directory exists if `--out-dir` is provided.
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    results = []
    progress_bar = ProgressBar(len(img_ids))

    for img_id in img_ids:
        img_info = coco.loadImgs([img_id])[0]
        img_path = os.path.join(args.img_dir, img_info['file_name'])
        if not os.path.exists(img_path):
            print(f"Image not found: {img_path}")
            progress_bar.update()
            continue

        image = cv2.imread(img_path)
        if image is None:
            print(f"Failed to read image: {img_path}")
            progress_bar.update()
            continue

        # MMPose expects RGB input; cv2.imread returns BGR.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Gather this image's person bounding boxes.
        ann_ids = coco.getAnnIds(imgIds=[img_id])
        annotations = coco.loadAnns(ann_ids)
        # BUGFIX: an image with zero annotations produced a (0,)-shaped array;
        # force a well-formed empty (0, 4) array so bbox handling stays valid.
        person_bboxes = np.array(
            [ann['bbox'] for ann in annotations], dtype=np.float32
        ).reshape(-1, 4)

        pose_results = inference_topdown(
            model,
            image_rgb,
            person_bboxes,
            bbox_format='xywh'  # COCO annotations typically use 'xywh' format
        )

        # Collect keypoints per detected instance.
        keypoints_results = []
        for pose in pose_results:
            pred_instances = pose.pred_instances
            if pred_instances is None:
                continue
            keypoints = pred_instances.keypoints
            scores = pred_instances.keypoint_scores
            keypoints_results.append({
                'keypoints': keypoints.tolist(),
                'scores': scores.tolist()
            })
            # BUGFIX: draw inside the per-pose loop so every instance is
            # visualized and `keypoints`/`scores` are never referenced while
            # unbound (the old code raised NameError when there were no
            # predictions, and only drew the last instance otherwise).
            draw_keypoints(image, keypoints, scores, args.score_thr)

        # Save the visualized image if `--out-dir` is provided.
        if args.out_dir:
            out_file = os.path.join(args.out_dir, img_info['file_name'])
            # file_name may contain sub-directories; create them as needed.
            os.makedirs(os.path.dirname(out_file) or '.', exist_ok=True)
            cv2.imwrite(out_file, image)

        results.append({
            'image_id': img_id,
            'file_name': img_info['file_name'],
            'keypoints': keypoints_results
        })
        progress_bar.update()

    # Dump all predictions as JSON if requested.
    if args.output_file:
        with open(args.output_file, 'w') as f:
            json.dump(results, f, indent=4)

    # BUGFIX: the old summary printed a dangling "Results ." when --out-dir
    # was unset and never mentioned --output-file; report every destination.
    destinations = [p for p in (args.out_dir, args.output_file) if p]
    if destinations:
        print(f"Inference completed. Results saved to {', '.join(destinations)}.")
    else:
        print("Inference completed.")
if __name__ == '__main__':
main()