Skip to content

Commit

Permalink
handle no bbox on image case, clear the dict with matched centers
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Jan 12, 2024
1 parent 4707480 commit 8515b45
Showing 1 changed file with 39 additions and 25 deletions.
64 changes: 39 additions & 25 deletions train/src/sly_to_yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,38 +34,51 @@ def _transform_label(class_names, img_size, label: sly.Label, task_type, labels_
graph_center = label.geometry.to_bbox().center
graph_center = [graph_center.col, graph_center.row]
boxes_list = [label.geometry for label in labels_list if isinstance(label.geometry, sly.Rectangle)]
center2box = {}
for box in boxes_list:
center2box[f"{box.center.col} {box.center.row}"] = box
distance2center = {}
for center in center2box.keys():
cx, cy = center.split()
distance = math.dist(graph_center, [int(cx), int(cy)])
distance2center[distance] = center
min_distance = min(distance2center.keys())
box_center = distance2center[min_distance]
# corresponding bbox for graph is the one with the smallest distance to graph center
matched_box = center2box[box_center]
box_x, box_y = box_center.split()
box_x, box_y = int(box_x), int(box_y)
box_center = [box_x, box_y]
if box_center not in g.center_matches.values():
g.center_matches[f"{graph_center[0]} {graph_center[1]}"] = box_center
class_number = class_names.index(label.obj_class.name)
x_center = round(box_center[0] / img_size[1], 6)
y_center = round(box_center[1] / img_size[0], 6)
width = round(matched_box.width / img_size[1], 6)
height = round(matched_box.height / img_size[0], 6)
# if failed to match graphs and boxes, get box by transforming graph to box
else:
sly.logger.warn("Failed to match graphs and boxes, boxes will be created by transforming graphs to boxes")
if len(boxes_list) == 0:
sly.logger.warn(
"Failed to find bounding boxes for graphs, "
"boxes will be created by transforming graphs to boxes"
)
class_number = class_names.index(label.obj_class.name)
rect_geometry = label.geometry.to_bbox()
center = rect_geometry.center
x_center = round(center.col / img_size[1], 6)
y_center = round(center.row / img_size[0], 6)
width = round(rect_geometry.width / img_size[1], 6)
height = round(rect_geometry.height / img_size[0], 6)
else:
center2box = {}
for box in boxes_list:
center2box[f"{box.center.col} {box.center.row}"] = box
distance2center = {}
for center in center2box.keys():
cx, cy = center.split()
distance = math.dist(graph_center, [int(cx), int(cy)])
distance2center[distance] = center
min_distance = min(distance2center.keys())
box_center = distance2center[min_distance]
# corresponding bbox for graph is the one with the smallest distance to graph center
matched_box = center2box[box_center]
box_x, box_y = box_center.split()
box_x, box_y = int(box_x), int(box_y)
box_center = [box_x, box_y]
if box_center not in g.center_matches.values():
g.center_matches[f"{graph_center[0]} {graph_center[1]}"] = box_center
class_number = class_names.index(label.obj_class.name)
x_center = round(box_center[0] / img_size[1], 6)
y_center = round(box_center[1] / img_size[0], 6)
width = round(matched_box.width / img_size[1], 6)
height = round(matched_box.height / img_size[0], 6)
# if failed to match graphs and boxes, get box by transforming graph to box
else:
sly.logger.warn("Failed to match graphs and boxes, boxes will be created by transforming graphs to boxes")
class_number = class_names.index(label.obj_class.name)
rect_geometry = label.geometry.to_bbox()
center = rect_geometry.center
x_center = round(center.col / img_size[1], 6)
y_center = round(center.row / img_size[0], 6)
width = round(rect_geometry.width / img_size[1], 6)
height = round(rect_geometry.height / img_size[0], 6)
graph_nodes = label.geometry.nodes
keypoints = []
for node_id in g.keypoints_template["nodes"].keys():
Expand Down Expand Up @@ -151,6 +164,7 @@ def _create_data_config(output_dir, meta: sly.ProjectMeta, task_type):


def _transform_annotation(ann, class_names, save_path, task_type):
g.center_matches = {}
yolov8_ann = []
for label in ann.labels:
if label.obj_class.name in class_names:
Expand Down

0 comments on commit 8515b45

Please sign in to comment.