Skip to content

Commit

Permalink
added binding keys support
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxTeselkin committed Nov 1, 2023
1 parent fd2b857 commit 837cac7
Showing 1 changed file with 47 additions and 34 deletions.
81 changes: 47 additions & 34 deletions train/src/sly_to_yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,55 @@ def _transform_label(class_names, img_size, label: sly.Label, task_type, labels_
height = round(rect_geometry.height / img_size[0], 6)
result = "{} {} {} {} {}".format(class_number, x_center, y_center, width, height)
elif task_type == "pose estimation":
# find corresponding bbox for graph
graph_center = label.geometry.to_bbox().center
graph_center = [graph_center.col, graph_center.row]
boxes_list = [label.geometry for label in labels_list if isinstance(label.geometry, sly.Rectangle)]
center2box = {}
for box in boxes_list:
center2box[f"{box.center.col} {box.center.row}"] = box
distance2center = {}
for center in center2box.keys():
cx, cy = center.split()
distance = math.dist(graph_center, [int(cx), int(cy)])
distance2center[distance] = center
min_distance = min(distance2center.keys())
box_center = distance2center[min_distance]
# corresponding bbox for graph is the one with the smallest distance to graph center
matched_box = center2box[box_center]
box_x, box_y = box_center.split()
box_x, box_y = int(box_x), int(box_y)
box_center = [box_x, box_y]
if box_center not in g.center_matches.values():
g.center_matches[f"{graph_center[0]} {graph_center[1]}"] = box_center
if label.binding_key:
binding_key = label.binding_key
boxes_list = []
for lbl in labels_list:
if isinstance(lbl.geometry, sly.Rectangle) and lbl.binding_key:
boxes_list.append(lbl)
box = [element.geometry for element in boxes_list if element.binding_key == binding_key][0]
x_center = round(box.center.col / img_size[1], 6)
y_center = round(box.center.row / img_size[0], 6)
width = round(box.width / img_size[1], 6)
height = round(box.height / img_size[0], 6)
class_number = class_names.index(label.obj_class.name)
x_center = round(box_center[0] / img_size[1], 6)
y_center = round(box_center[1] / img_size[0], 6)
width = round(matched_box.width / img_size[1], 6)
height = round(matched_box.height / img_size[0], 6)
# if failed to match graphs and boxes, get box by transforming graph to box
else:
sly.logger.warn("Failed to match graphs and boxes, boxes will be created by transforming graphs to boxes")
class_number = class_names.index(label.obj_class.name)
rect_geometry = label.geometry.to_bbox()
center = rect_geometry.center
x_center = round(center.col / img_size[1], 6)
y_center = round(center.row / img_size[0], 6)
width = round(rect_geometry.width / img_size[1], 6)
height = round(rect_geometry.height / img_size[0], 6)
# find corresponding bbox for graph
graph_center = label.geometry.to_bbox().center
graph_center = [graph_center.col, graph_center.row]
boxes_list = [label.geometry for label in labels_list if isinstance(label.geometry, sly.Rectangle)]
center2box = {}
for box in boxes_list:
center2box[f"{box.center.col} {box.center.row}"] = box
distance2center = {}
for center in center2box.keys():
cx, cy = center.split()
distance = math.dist(graph_center, [int(cx), int(cy)])
distance2center[distance] = center
min_distance = min(distance2center.keys())
box_center = distance2center[min_distance]
# corresponding bbox for graph is the one with the smallest distance to graph center
matched_box = center2box[box_center]
box_x, box_y = box_center.split()
box_x, box_y = int(box_x), int(box_y)
box_center = [box_x, box_y]
if box_center not in g.center_matches.values():
g.center_matches[f"{graph_center[0]} {graph_center[1]}"] = box_center
class_number = class_names.index(label.obj_class.name)
x_center = round(box_center[0] / img_size[1], 6)
y_center = round(box_center[1] / img_size[0], 6)
width = round(matched_box.width / img_size[1], 6)
height = round(matched_box.height / img_size[0], 6)
# if failed to match graphs and boxes, get box by transforming graph to box
else:
sly.logger.warn("Failed to match graphs and boxes, boxes will be created by transforming graphs to boxes")
class_number = class_names.index(label.obj_class.name)
rect_geometry = label.geometry.to_bbox()
center = rect_geometry.center
x_center = round(center.col / img_size[1], 6)
y_center = round(center.row / img_size[0], 6)
width = round(rect_geometry.width / img_size[1], 6)
height = round(rect_geometry.height / img_size[0], 6)
graph_nodes = label.geometry.nodes
keypoints = []
for node_id in g.keypoints_template["nodes"].keys():
Expand Down

0 comments on commit 837cac7

Please sign in to comment.