diff --git a/src/convert_sly_to_yolov5.py b/src/convert_sly_to_yolov5.py index 350af90..c57e2be 100644 --- a/src/convert_sly_to_yolov5.py +++ b/src/convert_sly_to_yolov5.py @@ -23,6 +23,8 @@ TEAM_ID = int(os.environ["context.teamId"]) WORKSPACE_ID = int(os.environ["context.workspaceId"]) PROJECT_ID = int(os.environ["modal.state.slyProjectId"]) +PROCCESS_SHAPES = os.environ.get("modal.state.processShapes", "transform") +PROCCESS_SHAPES_MSG = "skipped" if PROCCESS_SHAPES == "skip" else "transformed to rectangles" TRAIN_TAG_NAME = "train" VAL_TAG_NAME = "val" @@ -78,19 +80,16 @@ def transform(api: sly.Api, task_id, context, state, app_logger): ) error_classes = [] - ok_classes = [] for obj_class in meta.obj_classes: if obj_class.geometry_type != sly.Rectangle: error_classes.append(obj_class) - else: - ok_classes.append(obj_class) if len(error_classes) > 0: sly.logger.warn( - f"Project has unsupported classes. All unsupported classes will be transformed to rectangles: " + f"Project has unsupported classes. " + f"Objects with unsupported geometry types will be {PROCCESS_SHAPES_MSG}" f"{[obj_class.name for obj_class in error_classes]}" ) - def _write_new_ann(path, content): with open(path, "a") as f1: f1.write("\n".join(content)) @@ -109,6 +108,7 @@ def _add_to_split(image_id, img_name, split_ids, split_image_paths, labels_dir, for dataset in api.dataset.get_list(PROJECT_ID): images = api.image.get_list(dataset.id) + unsupported_shapes = 0 train_ids = [] train_image_paths = [] val_ids = [] @@ -125,6 +125,10 @@ def _add_to_split(image_id, img_name, split_ids, split_image_paths, labels_dir, yolov5_ann = [] for label in ann.labels: + if label.obj_class.geometry_type != sly.Rectangle: + unsupported_shapes += 1 + if PROCCESS_SHAPES == "skip": + continue yolov5_ann.append(transform_label(class_names, ann.img_size, label)) image_processed = False @@ -167,6 +171,11 @@ def _add_to_split(image_id, img_name, split_ids, split_image_paths, labels_dir, api.image.download_paths(dataset.id, val_ids, val_image_paths) progress.iters_done_report(len(batch)) + if unsupported_shapes > 0: + app_logger.warn( + f"Dataset {dataset.name} has {unsupported_shapes} objects with unsupported geometry types. " + f"These objects will be {PROCCESS_SHAPES_MSG}" + ) data_yaml = { "train": "../{}/images/train".format(result_dir_name),