
Commit

refactor convert
SangbumChoi committed Nov 9, 2024
1 parent 75b268f commit 5a1f6a3
Showing 2 changed files with 87 additions and 99 deletions.
18 changes: 2 additions & 16 deletions docs/source/en/model_doc/vitpose.md
@@ -40,7 +40,7 @@ The original code can be found [here](https://github.com/ViTAE-Transformer/ViTPo
>>> outputs = model(pixel_values, dataset_index)
```
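In that snippet, `dataset_index` presumably selects which dataset-specific expert to use (relevant for the ViTPose+ mixture-of-experts variants); a minimal hedged sketch of constructing it:

```py
# Hedged sketch: one dataset index per example in the batch.
# Index 0 is assumed to correspond to COCO; check the checkpoint card.
import torch

dataset_index = torch.tensor([0])
outputs = model(pixel_values, dataset_index)
```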

- ViTPose is a so-called top-down keypoint detection model. This means that one first uses an object detector, like [RT-DETR](rt-detr), to detect people (or other instances) in an image. Next, ViTPose takes the cropped images as input and predicts the keypoints.
- ViTPose is a so-called top-down keypoint detection model. This means that one first uses an object detector, like [RT-DETR](rt_detr.md), to detect people (or other instances) in an image. Next, ViTPose takes the cropped images as input and predicts the keypoints.
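  Before the fuller example below, a minimal hedged sketch of that two-stage flow (the checkpoint name and the `post_process_pose_estimation` helper are illustrative assumptions, not taken from this diff):

```py
import torch
from PIL import Image
from transformers import VitPoseForPoseEstimation, VitPoseImageProcessor

image = Image.open("people.jpg")

# Stage 1: person boxes in COCO (x, y, width, height) format from any detector (e.g. RT-DETR).
# Dummy boxes are used here so the sketch stays self-contained.
person_boxes = [[412.0, 157.0, 464.0, 728.0], [290.0, 177.0, 265.0, 657.0]]

# Stage 2: ViTPose predicts keypoints inside each person box.
processor = VitPoseImageProcessor.from_pretrained("<vitpose-checkpoint>")  # placeholder name
model = VitPoseForPoseEstimation.from_pretrained("<vitpose-checkpoint>")   # placeholder name

inputs = processor(image, boxes=[person_boxes], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

pose_results = processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
```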

```py
import math
@@ -117,20 +117,6 @@ for pose_result in pose_results:
x, y, score = keypoint
print(f"coordinate : [{x}, {y}], score : {score}")

def draw_points(image, keypoints, keypoint_colors, keypoint_score_threshold, radius, show_keypoint_weight):
if keypoint_colors is not None:
assert len(keypoint_colors) == len(keypoints)
for id, keypoint in enumerate(keypoints):
x_coord, y_coord, keypoint_score = int(keypoint[0]), int(keypoint[1]), keypoint[2]
if keypoint_score > keypoint_score_threshold:
color = tuple(int(c) for c in keypoint_colors[id])
if show_keypoint_weight:
cv2.circle(image, (x_coord, y_coord), radius, color, -1)
transparency = max(0, min(1, keypoint_score))
cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
else:
cv2.circle(image, (x_coord, y_coord), radius, color, -1)

def draw_links(image, keypoints, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width = 2):
height, width, _ = image.shape
if keypoint_edges is not None and link_colors is not None:
@@ -216,7 +202,7 @@ def visualize_keypoints(
return image

# Note: keypoint_edges and color palette are dataset-specific
keypoint_edges = config.keypoint_edges
keypoint_edges = config.edges

palette = np.array(
[
168 changes: 85 additions & 83 deletions src/transformers/models/vitpose/convert_vitpose_to_hf.py
@@ -29,7 +29,7 @@
from transformers import VitPoseBackboneConfig, VitPoseConfig, VitPoseForPoseEstimation, VitPoseImageProcessor


KEYS_TO_MODIFY_MAPPING = {
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"patch_embed.proj": "embeddings.patch_embeddings.projection",
r"pos_embed": "embeddings.position_embeddings",
r"blocks": "encoder.layer",
@@ -38,6 +38,8 @@
r"norm1": "layernorm_before",
r"norm2": "layernorm_after",
r"last_norm": "layernorm",
r"keypoint_head": "head",
r"final_layer": "conv",
}

MODEL_TO_FILE_NAME_MAPPING = {
@@ -72,7 +74,7 @@ def get_config(model_name):

use_simple_decoder = "simple" in model_name

keypoint_edges = (
edges = (
[
[15, 13],
[13, 11],
@@ -95,98 +97,56 @@
[4, 6],
],
)
keypoint_labels = (
[
"Nose",
"L_Eye",
"R_Eye",
"L_Ear",
"R_Ear",
"L_Shoulder",
"R_Shoulder",
"L_Elbow",
"R_Elbow",
"L_Wrist",
"R_Wrist",
"L_Hip",
"R_Hip",
"L_Knee",
"R_Knee",
"L_Ankle",
"R_Ankle",
],
)
id2label = {
0: "Nose",
1: "L_Eye",
2: "R_Eye",
3: "L_Ear",
4: "R_Ear",
5: "L_Shoulder",
6: "R_Shoulder",
7: "L_Elbow",
8: "R_Elbow",
9: "L_Wrist",
10: "R_Wrist",
11: "L_Hip",
12: "R_Hip",
13: "L_Knee",
14: "R_Knee",
15: "L_Ankle",
16: "R_Ankle",
}

label2id = {v: k for k, v in id2label.items()}

config = VitPoseConfig(
backbone_config=backbone_config,
num_labels=17,
use_simple_decoder=use_simple_decoder,
keypoint_edges=keypoint_edges,
keypoint_labels=keypoint_labels,
edges=edges,
id2label=id2label,
label2id=label2id,
)

return config
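
# For reference, a hedged usage sketch of get_config (the model name below is a
# placeholder for illustration; real names come from MODEL_TO_FILE_NAME_MAPPING):
```py
config = get_config("vitpose-base-simple")    # placeholder model name
print(config.use_simple_decoder)              # True, since "simple" is in the name
print(config.backbone_config.hidden_size)     # backbone hidden size used later for the qkv split
```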


def convert_old_keys_to_new_keys(state_dict, config):
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
"""
This function should be applied only once, on the concatenated keys to efficiently rename using
the key mappings.
"""
model_state_dict = {}

output_hypernetworks_qkv_pattern = r".*.qkv.*"
output_hypernetworks_head_pattern = r"keypoint_head.*"

dim = config.backbone_config.hidden_size

for key in state_dict.copy().keys():
value = state_dict.pop(key)
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)

if re.match(output_hypernetworks_qkv_pattern, key):
layer_num = int(key.split(".")[3])
if "weight" in key:
model_state_dict[f"backbone.encoder.layer.{layer_num}.attention.attention.query.weight"] = value[
:dim, :
]
model_state_dict[f"backbone.encoder.layer.{layer_num}.attention.attention.key.weight"] = value[
dim : dim * 2, :
]
model_state_dict[f"backbone.encoder.layer.{layer_num}.attention.attention.value.weight"] = value[
-dim:, :
]
else:
model_state_dict[f"backbone.encoder.layer.{layer_num}.attention.attention.query.bias"] = value[:dim]
model_state_dict[f"backbone.encoder.layer.{layer_num}.attention.attention.key.bias"] = value[
dim : dim * 2
]
model_state_dict[f"backbone.encoder.layer.{layer_num}.attention.attention.value.bias"] = value[-dim:]

if re.match(output_hypernetworks_head_pattern, key):
if config.use_simple_decoder:
key = key.replace("keypoint_head.final_layer", "head.conv")
else:
key = key.replace("keypoint_head", "head")
key = key.replace("deconv_layers.0.weight", "deconv1.weight")
key = key.replace("deconv_layers.1.weight", "batchnorm1.weight")
key = key.replace("deconv_layers.1.bias", "batchnorm1.bias")
key = key.replace("deconv_layers.1.running_mean", "batchnorm1.running_mean")
key = key.replace("deconv_layers.1.running_var", "batchnorm1.running_var")
key = key.replace("deconv_layers.1.num_batches_tracked", "batchnorm1.num_batches_tracked")
key = key.replace("deconv_layers.3.weight", "deconv2.weight")
key = key.replace("deconv_layers.4.weight", "batchnorm2.weight")
key = key.replace("deconv_layers.4.bias", "batchnorm2.bias")
key = key.replace("deconv_layers.4.running_mean", "batchnorm2.running_mean")
key = key.replace("deconv_layers.4.running_var", "batchnorm2.running_var")
key = key.replace("deconv_layers.4.num_batches_tracked", "batchnorm2.num_batches_tracked")
key = key.replace("final_layer.weight", "conv.weight")
key = key.replace("final_layer.bias", "conv.bias")
model_state_dict[key] = value

return model_state_dict
output_dict = {}
if state_dict_keys is not None:
old_text = "\n".join(state_dict_keys)
new_text = old_text
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
if replacement is None:
new_text = re.sub(pattern, "", new_text) # an empty line
continue
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
return output_dict
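
# A quick hedged illustration of the bulk-renaming trick above: the keys are joined into one
# newline-separated string, every regex is applied once over that text, and the result is
# zipped back into an old-to-new mapping (sample keys below are illustrative):
```py
sample_keys = [
    "patch_embed.proj.weight",
    "blocks.0.norm1.bias",
    "keypoint_head.final_layer.weight",
]
print(convert_old_keys_to_new_keys(sample_keys))
# Expected, per the visible patterns:
# {"patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight",
#  "blocks.0.norm1.bias": "encoder.layer.0.layernorm_before.bias",
#  "keypoint_head.final_layer.weight": "head.conv.weight"}
```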


# We will verify our results on a COCO image
@@ -220,13 +180,55 @@ def write_model(model_path, model_name, push_to_hub):
)

print("Converting model...")
state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
new_state_dict = convert_old_keys_to_new_keys(state_dict, config)
original_state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
all_keys = list(original_state_dict.keys())
new_keys = convert_old_keys_to_new_keys(all_keys)

dim = config.backbone_config.hidden_size

state_dict = {}
for key in all_keys:
new_key = new_keys[key]
value = original_state_dict[key]

if re.search("qkv", new_key):
if "weight" in new_key:
state_dict[new_key.replace("self.qkv", "attention.query")] = value[:dim, :]
state_dict[new_key.replace("self.qkv", "attention.key")] = value[dim : dim * 2, :]
state_dict[new_key.replace("self.qkv", "attention.value")] = value[-dim:, :]
else:
state_dict[new_key.replace("self.qkv", "attention.query")] = value[:dim]
state_dict[new_key.replace("self.qkv", "attention.key")] = value[dim : dim * 2]
state_dict[new_key.replace("self.qkv", "attention.value")] = value[-dim:]

elif re.search("head", new_key) and not config.use_simple_decoder:
# Pattern for deconvolution layers
print(new_key)
deconv_pattern = r"deconv_layers\.(0|3)\.weight"
new_key = re.sub(deconv_pattern, lambda m: f"deconv{int(m.group(1))//3 + 1}.weight", new_key)
# Pattern for batch normalization layers
bn_patterns = [
(r"deconv_layers\.(\d+)\.weight", r"batchnorm\1.weight"),
(r"deconv_layers\.(\d+)\.bias", r"batchnorm\1.bias"),
(r"deconv_layers\.(\d+)\.running_mean", r"batchnorm\1.running_mean"),
(r"deconv_layers\.(\d+)\.running_var", r"batchnorm\1.running_var"),
(r"deconv_layers\.(\d+)\.num_batches_tracked", r"batchnorm\1.num_batches_tracked"),
]

for pattern, replacement in bn_patterns:
if re.search(pattern, new_key):
# Convert the layer number to the correct batch norm index
layer_num = int(re.search(pattern, key).group(1))
bn_num = layer_num // 3 + 1
new_key = re.sub(pattern, replacement.replace(r"\1", str(bn_num)), new_key)
state_dict[new_key] = value
else:
state_dict[new_key] = value

print("Loading the checkpoint in a Vitpose model.")
model = VitPoseForPoseEstimation(config)
model.eval()
model.load_state_dict(new_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
print("Checkpoint loaded successfully.")

# create image processor
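As a side note, a minimal sketch of the two trickiest renames in the conversion loop above (shapes and indices are illustrative; `hidden` stands in for `config.backbone_config.hidden_size`):

```py
import torch

hidden = 4  # stand-in for config.backbone_config.hidden_size

# 1) The original checkpoint stores attention as one fused qkv projection of shape
#    (3 * hidden, hidden); it is sliced row-wise into separate query/key/value tensors.
qkv_weight = torch.randn(3 * hidden, hidden)
query = qkv_weight[:hidden, :]
key = qkv_weight[hidden : hidden * 2, :]
value = qkv_weight[-hidden:, :]
assert query.shape == key.shape == value.shape == (hidden, hidden)

# 2) In the classic decoder head, deconv_layers.{0,3} hold deconvolution weights and
#    deconv_layers.{1,4} hold batch-norm parameters, so index // 3 + 1 gives the
#    deconv1/deconv2 and batchnorm1/batchnorm2 numbering used after conversion.
for idx in (0, 1, 3, 4):
    print(f"deconv_layers.{idx} -> group {idx // 3 + 1}")  # 0,1 -> 1 and 3,4 -> 2
```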
