Commit: added support for dynamic number of keypoints

MaxTeselkin committed Oct 24, 2023
1 parent bba8fdf, commit df73c9b

Showing 3 changed files with 92 additions and 31 deletions.

train/src/globals.py (2 additions, 0 deletions)
@@ -28,3 +28,5 @@
 else:
     train_params_filepath = "training_params.yml"  # for debug
 train_counter, val_counter = 0, 0
+center_matches = {}
+keypoints_template = None
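
For context, a minimal sketch of what these two globals are for, based on how they are used later in this commit: keypoints_template will hold the graph class's geometry config, whose "nodes" mapping defines the keypoint count and ordering, while center_matches records which bounding-box center each graph has claimed. The helper below is hypothetical and only illustrates how a dynamic keypoint layout can be derived from such a config.

    # Hedged sketch: derive the keypoint layout from a graph geometry config,
    # the way g.keypoints_template is consumed later in this commit (simplified).
    def keypoint_layout(geometry_config):
        node_ids = list(geometry_config["nodes"].keys())  # canonical keypoint order
        return node_ids, [len(node_ids), 3]  # kpt_shape: x, y, visibility per node

    node_ids, kpt_shape = keypoint_layout({"nodes": {"nose": {}, "left_eye": {}, "right_eye": {}}})
    print(node_ids)   # ['nose', 'left_eye', 'right_eye']
    print(kpt_shape)  # [3, 3]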

train/src/main.py (33 additions, 16 deletions)
@@ -614,6 +614,13 @@ def select_task(task_type):
                 status="warning",
             )
             select_classes_button.disable()
+        elif "rectangle" not in project_shapes:
+            sly.app.show_dialog(
+                title="There are no classes of shape rectangle in selected project (bounding boxes are required for pose estimation)",
+                description="Please, change task type or select another project with classes of shape rectangle",
+                status="warning",
+            )
+            select_classes_button.disable()
         else:
             select_classes_button.enable()
             models_table_columns = [key for key in g.pose_models_data[0].keys()]
@@ -630,21 +637,31 @@

 @select_classes_button.click
 def select_classes():
-    n_classes = len(classes_table.get_selected_classes())
-    if n_classes > 1:
-        classes_done.text = f"{n_classes} classes were selected successfully"
+    selected_classes = classes_table.get_selected_classes()
+    selected_shapes = [cls.geometry_type.geometry_name() for cls in project_meta.obj_classes if cls.name in selected_classes]
+    task_type = task_type_select.get_value()
+    if task_type == "pose estimation" and ("graph" not in selected_shapes or "rectangle" not in selected_shapes):
+        sly.app.show_dialog(
+            title="Pose estimation task requires input project to have at least one class of shape graph and one class of shape rectangle",
+            description="Please, select both classes of shape rectangle and graph or change task type",
+            status="warning",
+        )
     else:
-        classes_done.text = f"{n_classes} class was selected successfully"
-    select_classes_button.hide()
-    classes_done.show()
-    select_other_classes_button.show()
-    classes_table.disable()
-    task_type_select.disable()
-    curr_step = stepper.get_active_step()
-    curr_step += 1
-    stepper.set_active_step(curr_step)
-    card_train_val_split.unlock()
-    card_train_val_split.uncollapse()
+        n_classes = len(classes_table.get_selected_classes())
+        if n_classes > 1:
+            classes_done.text = f"{n_classes} classes were selected successfully"
+        else:
+            classes_done.text = f"{n_classes} class was selected successfully"
+        select_classes_button.hide()
+        classes_done.show()
+        select_other_classes_button.show()
+        classes_table.disable()
+        task_type_select.disable()
+        curr_step = stepper.get_active_step()
+        curr_step += 1
+        stepper.set_active_step(curr_step)
+        card_train_val_split.unlock()
+        card_train_val_split.uncollapse()


 @select_other_classes_button.click
@@ -831,7 +848,7 @@ def start_training():
         necessary_geometries = ["rectangle"]
         local_artifacts_dir = os.path.join(local_dir, "runs", "detect", "train")
     elif task_type == "pose estimation":
-        necessary_geometries = ["graph"]
+        necessary_geometries = ["graph", "rectangle"]
         local_artifacts_dir = os.path.join(local_dir, "runs", "pose", "train")
     elif task_type == "instance segmentation":
         necessary_geometries = ["bitmap", "polygon"]
@@ -1373,7 +1390,7 @@ def auto_train(request: Request):
         necessary_geometries = ["rectangle"]
         local_artifacts_dir = os.path.join(local_dir, "runs", "detect", "train")
     elif task_type == "pose estimation":
-        necessary_geometries = ["graph"]
+        necessary_geometries = ["graph", "rectangle"]
         local_artifacts_dir = os.path.join(local_dir, "runs", "pose", "train")
     elif task_type == "instance segmentation":
         necessary_geometries = ["bitmap", "polygon"]
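
Taken together, the main.py changes make pose estimation depend on both geometry types: graph classes supply the keypoints and rectangle classes supply the bounding boxes, so the UI now refuses a selection that lacks either one. A minimal sketch of that rule with a hypothetical helper (the app itself performs the check inline and shows sly.app.show_dialog warnings):

    # Hedged sketch of the selection rule enforced above (hypothetical helper).
    def pose_selection_is_valid(selected_shapes):
        # Pose estimation needs at least one "graph" and one "rectangle" class.
        return "graph" in selected_shapes and "rectangle" in selected_shapes

    print(pose_selection_is_valid(["graph", "rectangle"]))  # True
    print(pose_selection_is_valid(["graph"]))               # False -> warning dialog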

train/src/sly_to_yolov8.py (57 additions, 15 deletions)
@@ -1,10 +1,12 @@
 import os
 import yaml
 import supervisely as sly
+import src.globals as g
 import numpy as np
+import math


-def _transform_label(class_names, img_size, label: sly.Label, task_type):
+def _transform_label(class_names, img_size, label: sly.Label, task_type, labels_list):
     if task_type == "object detection":
         class_number = class_names.index(label.obj_class.name)
         rect_geometry = label.geometry.to_bbox()
@@ -15,18 +17,54 @@ def _transform_label(class_names, img_size, label: sly.Label, task_type):
         height = round(rect_geometry.height / img_size[0], 6)
         result = "{} {} {} {} {}".format(class_number, x_center, y_center, width, height)
     elif task_type == "pose estimation":
-        class_number = class_names.index(label.obj_class.name)
-        rect_geometry = label.geometry.to_bbox()
-        center = rect_geometry.center
-        x_center = round(center.col / img_size[1], 6)
-        y_center = round(center.row / img_size[0], 6)
-        width = round(rect_geometry.width / img_size[1], 6)
-        height = round(rect_geometry.height / img_size[0], 6)
-        nodes = label.geometry.nodes
+        # find corresponding bbox for graph
+        graph_center = label.geometry.to_bbox().center
+        graph_center = [graph_center.col, graph_center.row]
+        boxes_list = [label.geometry for label in labels_list if isinstance(label.geometry, sly.Rectangle)]
+        center2box = {}
+        for box in boxes_list:
+            center2box[f"{box.center.col} {box.center.row}"] = box
+        distance2center = {}
+        for center in center2box.keys():
+            cx, cy = center.split()
+            distance = math.dist(graph_center, [int(cx), int(cy)])
+            distance2center[distance] = center
+        min_distance = min(distance2center.keys())
+        box_center = distance2center[min_distance]
+        # corresponding bbox for graph is the one with the smallest distance to graph center
+        matched_box = center2box[box_center]
+        box_x, box_y = box_center.split()
+        box_x, box_y = int(box_x), int(box_y)
+        box_center = [box_x, box_y]
+        if box_center not in g.center_matches.values():
+            g.center_matches[f"{graph_center[0]} {graph_center[1]}"] = box_center
+            class_number = class_names.index(label.obj_class.name)
+            x_center = round(box_center[0] / img_size[1], 6)
+            y_center = round(box_center[1] / img_size[0], 6)
+            width = round(matched_box.width / img_size[1], 6)
+            height = round(matched_box.height / img_size[0], 6)
+        # if failed to match graphs and boxes, get box by transforming graph to box
+        else:
+            sly.logger.warn("Failed to match graphs and boxes, boxes will be created by transforming graphs to boxes")
+            class_number = class_names.index(label.obj_class.name)
+            rect_geometry = label.geometry.to_bbox()
+            center = rect_geometry.center
+            x_center = round(center.col / img_size[1], 6)
+            y_center = round(center.row / img_size[0], 6)
+            width = round(rect_geometry.width / img_size[1], 6)
+            height = round(rect_geometry.height / img_size[0], 6)
+        graph_nodes = label.geometry.nodes
         keypoints = []
-        for node in nodes.values():
-            keypoints.append(round(node.location.col / img_size[1], 6))
-            keypoints.append(round(node.location.row / img_size[0], 6))
+        for node_id in g.keypoints_template["nodes"].keys():
+            if node_id in graph_nodes.keys():
+                visibility = 2
+                graph_node = graph_nodes[node_id]
+                point_x = round(graph_node.location.col / img_size[1], 6)
+                point_y = round(graph_node.location.row / img_size[0], 6)
+            else:
+                visibility = 0
+                point_x, point_y = 0, 0
+            keypoints.extend([point_x, point_y, visibility])
         keypoints_str = " ".join(str(point) for point in keypoints)
         result = f"{class_number} {x_center} {y_center} {width} {height} {keypoints_str}"
     elif task_type == "instance segmentation":
@@ -52,11 +90,12 @@ def _transform_label(class_names, img_size, label: sly.Label, task_type):
         result = f"{class_number} {scaled_points_str}"
     return result


 def _create_data_config(output_dir, meta: sly.ProjectMeta, task_type):
     class_names = []
     class_colors = []
     for obj_class in meta.obj_classes:
+        if task_type == "pose estimation" and obj_class.geometry_type.geometry_name() != "graph":
+            continue
         class_names.append(obj_class.name)
         class_colors.append(obj_class.color)
     if task_type in ["object detection", "instance segmentation"]:
@@ -73,6 +112,7 @@ def _create_data_config(output_dir, meta: sly.ProjectMeta, task_type):
         for obj_class in meta.obj_classes:
             if obj_class.geometry_type.geometry_name() == "graph":
                 geometry_config = obj_class.geometry_config
+                g.keypoints_template = geometry_config
                 n_keypoints = len(geometry_config["nodes"])
                 flip_idx = [i for i in range(n_keypoints)]
                 break
@@ -81,7 +121,7 @@ def _create_data_config(output_dir, meta: sly.ProjectMeta, task_type):
             "val": os.path.join(output_dir, "images/val"),
             "labels_train": os.path.join(output_dir, "labels/train"),
             "labels_val": os.path.join(output_dir, "labels/val"),
-            "kpt_shape": [n_keypoints, 2],
+            "kpt_shape": [n_keypoints, 3],
             "flip_idx": flip_idx,
             "names": class_names,
         }
@@ -101,7 +141,9 @@ def _transform_annotation(ann, class_names, save_path, task_type):
     yolov8_ann = []
     for label in ann.labels:
         if label.obj_class.name in class_names:
-            transformed_label = _transform_label(class_names, ann.img_size, label, task_type)
+            if task_type == "pose estimation" and isinstance(label.geometry, sly.Rectangle):
+                continue
+            transformed_label = _transform_label(class_names, ann.img_size, label, task_type, ann.labels)
             if transformed_label:
                 yolov8_ann.append(transformed_label)

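
For reference, a hedged sketch of the conversion logic above: each keypoint graph is matched to the rectangle whose center lies nearest to the graph's bounding-box center, and the YOLOv8 pose label line becomes the class id, the normalized box, and then x, y, visibility for every node in the keypoints template (visibility 2 for annotated nodes, 0 with zero coordinates for missing ones, which is why kpt_shape becomes [n_keypoints, 3]). The function names below are illustrative, not the project's API.

    import math

    # Hedged sketch: nearest-center matching plus label-line assembly (simplified).
    def match_box(graph_center, box_centers):
        # The matched box is the one whose center is closest to the graph's bbox center.
        return min(box_centers, key=lambda c: math.dist(graph_center, c))

    def pose_label_line(class_number, box, keypoints):
        # box = (x_center, y_center, width, height), all normalized to [0, 1].
        # keypoints = one (x, y, visibility) triple per template node, in template order.
        flat = " ".join(str(v) for kp in keypoints for v in kp)
        return f"{class_number} {box[0]} {box[1]} {box[2]} {box[3]} {flat}"

    print(match_box((120, 80), [(118, 79), (300, 40)]))  # -> (118, 79)
    # A two-keypoint skeleton where the second node is missing (visibility 0):
    print(pose_label_line(0, (0.5, 0.5, 0.2, 0.4), [(0.48, 0.4, 2), (0, 0, 0)]))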
