diff --git a/src/functions.py b/src/functions.py index 57d305f..6cda154 100644 --- a/src/functions.py +++ b/src/functions.py @@ -1,5 +1,5 @@ import os - +from typing import List import numpy as np import supervisely as sly from supervisely.geometry import graph @@ -55,7 +55,7 @@ def get_categories_map_from_meta(meta): def get_keypoints_and_skeleton(obj_class): - nodes = list(obj_class.geometry_config["nodes"].keys()) + nodes = get_nodes_labels(obj_class) edges = obj_class.geometry_config["edges"] skeleton = [] for edge in edges: @@ -63,6 +63,14 @@ def get_keypoints_and_skeleton(obj_class): return nodes, skeleton +def get_nodes_labels(obj_class: sly.ObjClass) -> List[str]: + nodes_dict = obj_class.geometry_config["nodes"] + nodes = [] + for node_dict in nodes_dict.values(): + nodes.append(node_dict["label"]) + return nodes + + def get_categories_from_meta(meta: sly.ProjectMeta): obj_classes = meta.obj_classes categories = [] @@ -142,8 +150,12 @@ def create_coco_annotation( groups = ann.get_bindings() for binding_key, labels in groups.items(): bbox = None - if binding_key is not None and any(label.obj_class.geometry_type == sly.Rectangle for label in labels): - bbox_label = list(filter(lambda label: label.obj_class.geometry_type == sly.Rectangle, labels))[0] + if binding_key is not None and any( + label.obj_class.geometry_type == sly.Rectangle for label in labels + ): + bbox_label = list( + filter(lambda label: label.obj_class.geometry_type == sly.Rectangle, labels) + )[0] bbox = coco_bbox(bbox_label) for label in labels: label: sly.Label