Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added support for dynamic number of keypoints #16

Merged
merged 1 commit into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions train/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@
else:
train_params_filepath = "training_params.yml" # for debug
train_counter, val_counter = 0, 0
center_matches = {}
keypoints_template = None
49 changes: 33 additions & 16 deletions train/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,13 @@ def select_task(task_type):
status="warning",
)
select_classes_button.disable()
elif "rectangle" not in project_shapes:
sly.app.show_dialog(
title="There are no classes of shape rectangle in selected project (bounding boxes are required for pose estimation)",
description="Please, change task type or select another project with classes of shape rectangle",
status="warning",
)
select_classes_button.disable()
else:
select_classes_button.enable()
models_table_columns = [key for key in g.pose_models_data[0].keys()]
Expand All @@ -630,21 +637,31 @@ def select_task(task_type):

@select_classes_button.click
def select_classes():
n_classes = len(classes_table.get_selected_classes())
if n_classes > 1:
classes_done.text = f"{n_classes} classes were selected successfully"
selected_classes = classes_table.get_selected_classes()
selected_shapes = [cls.geometry_type.geometry_name() for cls in project_meta.obj_classes if cls.name in selected_classes]
task_type = task_type_select.get_value()
if task_type == "pose estimation" and ("graph" not in selected_shapes or "rectangle" not in selected_shapes):
sly.app.show_dialog(
title="Pose estimation task requires input project to have at least one class of shape graph and one class of shape rectangle",
description="Please, select both classes of shape rectangle and graph or change task type",
status="warning",
)
else:
classes_done.text = f"{n_classes} class was selected successfully"
select_classes_button.hide()
classes_done.show()
select_other_classes_button.show()
classes_table.disable()
task_type_select.disable()
curr_step = stepper.get_active_step()
curr_step += 1
stepper.set_active_step(curr_step)
card_train_val_split.unlock()
card_train_val_split.uncollapse()
n_classes = len(classes_table.get_selected_classes())
if n_classes > 1:
classes_done.text = f"{n_classes} classes were selected successfully"
else:
classes_done.text = f"{n_classes} class was selected successfully"
select_classes_button.hide()
classes_done.show()
select_other_classes_button.show()
classes_table.disable()
task_type_select.disable()
curr_step = stepper.get_active_step()
curr_step += 1
stepper.set_active_step(curr_step)
card_train_val_split.unlock()
card_train_val_split.uncollapse()


@select_other_classes_button.click
Expand Down Expand Up @@ -831,7 +848,7 @@ def start_training():
necessary_geometries = ["rectangle"]
local_artifacts_dir = os.path.join(local_dir, "runs", "detect", "train")
elif task_type == "pose estimation":
necessary_geometries = ["graph"]
necessary_geometries = ["graph", "rectangle"]
local_artifacts_dir = os.path.join(local_dir, "runs", "pose", "train")
elif task_type == "instance segmentation":
necessary_geometries = ["bitmap", "polygon"]
Expand Down Expand Up @@ -1373,7 +1390,7 @@ def auto_train(request: Request):
necessary_geometries = ["rectangle"]
local_artifacts_dir = os.path.join(local_dir, "runs", "detect", "train")
elif task_type == "pose estimation":
necessary_geometries = ["graph"]
necessary_geometries = ["graph", "rectangle"]
local_artifacts_dir = os.path.join(local_dir, "runs", "pose", "train")
elif task_type == "instance segmentation":
necessary_geometries = ["bitmap", "polygon"]
Expand Down
72 changes: 57 additions & 15 deletions train/src/sly_to_yolov8.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import yaml
import supervisely as sly
import src.globals as g
import numpy as np
import math


def _transform_label(class_names, img_size, label: sly.Label, task_type):
def _transform_label(class_names, img_size, label: sly.Label, task_type, labels_list):
if task_type == "object detection":
class_number = class_names.index(label.obj_class.name)
rect_geometry = label.geometry.to_bbox()
Expand All @@ -15,18 +17,54 @@ def _transform_label(class_names, img_size, label: sly.Label, task_type):
height = round(rect_geometry.height / img_size[0], 6)
result = "{} {} {} {} {}".format(class_number, x_center, y_center, width, height)
elif task_type == "pose estimation":
class_number = class_names.index(label.obj_class.name)
rect_geometry = label.geometry.to_bbox()
center = rect_geometry.center
x_center = round(center.col / img_size[1], 6)
y_center = round(center.row / img_size[0], 6)
width = round(rect_geometry.width / img_size[1], 6)
height = round(rect_geometry.height / img_size[0], 6)
nodes = label.geometry.nodes
# find corresponding bbox for graph
graph_center = label.geometry.to_bbox().center
graph_center = [graph_center.col, graph_center.row]
boxes_list = [label.geometry for label in labels_list if isinstance(label.geometry, sly.Rectangle)]
center2box = {}
for box in boxes_list:
center2box[f"{box.center.col} {box.center.row}"] = box
distance2center = {}
for center in center2box.keys():
cx, cy = center.split()
distance = math.dist(graph_center, [int(cx), int(cy)])
distance2center[distance] = center
min_distance = min(distance2center.keys())
box_center = distance2center[min_distance]
# corresponding bbox for graph is the one with the smallest distance to graph center
matched_box = center2box[box_center]
box_x, box_y = box_center.split()
box_x, box_y = int(box_x), int(box_y)
box_center = [box_x, box_y]
if box_center not in g.center_matches.values():
g.center_matches[f"{graph_center[0]} {graph_center[1]}"] = box_center
class_number = class_names.index(label.obj_class.name)
x_center = round(box_center[0] / img_size[1], 6)
y_center = round(box_center[1] / img_size[0], 6)
width = round(matched_box.width / img_size[1], 6)
height = round(matched_box.height / img_size[0], 6)
# if failed to match graphs and boxes, get box by transforming graph to box
else:
sly.logger.warn("Failed to match graphs and boxes, boxes will be created by transforming graphs to boxes")
class_number = class_names.index(label.obj_class.name)
rect_geometry = label.geometry.to_bbox()
center = rect_geometry.center
x_center = round(center.col / img_size[1], 6)
y_center = round(center.row / img_size[0], 6)
width = round(rect_geometry.width / img_size[1], 6)
height = round(rect_geometry.height / img_size[0], 6)
graph_nodes = label.geometry.nodes
keypoints = []
for node in nodes.values():
keypoints.append(round(node.location.col / img_size[1], 6))
keypoints.append(round(node.location.row / img_size[0], 6))
for node_id in g.keypoints_template["nodes"].keys():
if node_id in graph_nodes.keys():
visibility = 2
graph_node = graph_nodes[node_id]
point_x = round(graph_node.location.col / img_size[1], 6)
point_y = round(graph_node.location.row / img_size[0], 6)
else:
visibility = 0
point_x, point_y = 0, 0
keypoints.extend([point_x, point_y, visibility])
keypoints_str = " ".join(str(point) for point in keypoints)
result = f"{class_number} {x_center} {y_center} {width} {height} {keypoints_str}"
elif task_type == "instance segmentation":
Expand All @@ -52,11 +90,12 @@ def _transform_label(class_names, img_size, label: sly.Label, task_type):
result = f"{class_number} {scaled_points_str}"
return result


def _create_data_config(output_dir, meta: sly.ProjectMeta, task_type):
class_names = []
class_colors = []
for obj_class in meta.obj_classes:
if task_type == "pose estimation" and obj_class.geometry_type.geometry_name() != "graph":
continue
class_names.append(obj_class.name)
class_colors.append(obj_class.color)
if task_type in ["object detection", "instance segmentation"]:
Expand All @@ -73,6 +112,7 @@ def _create_data_config(output_dir, meta: sly.ProjectMeta, task_type):
for obj_class in meta.obj_classes:
if obj_class.geometry_type.geometry_name() == "graph":
geometry_config = obj_class.geometry_config
g.keypoints_template = geometry_config
n_keypoints = len(geometry_config["nodes"])
flip_idx = [i for i in range(n_keypoints)]
break
Expand All @@ -81,7 +121,7 @@ def _create_data_config(output_dir, meta: sly.ProjectMeta, task_type):
"val": os.path.join(output_dir, "images/val"),
"labels_train": os.path.join(output_dir, "labels/train"),
"labels_val": os.path.join(output_dir, "labels/val"),
"kpt_shape": [n_keypoints, 2],
"kpt_shape": [n_keypoints, 3],
"flip_idx": flip_idx,
"names": class_names,
}
Expand All @@ -101,7 +141,9 @@ def _transform_annotation(ann, class_names, save_path, task_type):
yolov8_ann = []
for label in ann.labels:
if label.obj_class.name in class_names:
transformed_label = _transform_label(class_names, ann.img_size, label, task_type)
if task_type == "pose estimation" and isinstance(label.geometry, sly.Rectangle):
continue
transformed_label = _transform_label(class_names, ann.img_size, label, task_type, ann.labels)
if transformed_label:
yolov8_ann.append(transformed_label)

Expand Down