From e73f05bd2e4acbe23d2a791a59d1935d87ad9029 Mon Sep 17 00:00:00 2001 From: almaz Date: Mon, 16 Dec 2024 12:29:55 +0100 Subject: [PATCH] rollback prepare_segmentation_data --- train/src/ui/monitoring.py | 160 +++++++++++++++++++++++-------------- 1 file changed, 98 insertions(+), 62 deletions(-) diff --git a/train/src/ui/monitoring.py b/train/src/ui/monitoring.py index 0afbb98..2c50fbe 100644 --- a/train/src/ui/monitoring.py +++ b/train/src/ui/monitoring.py @@ -316,74 +316,110 @@ def init_class_charts_series(state): g.api.app.set_fields(g.task_id, fields) -def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes=None): +# def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes=None): +# target_classes = target_classes or state["selectedClasses"] +# temp_project_seg_dir = g.project_seg_dir + "_temp" +# bg_name = get_bg_class_name(target_classes) or "__bg__" +# bg_color = (0, 0, 0) +# if bg_name in target_classes: +# try: +# bg_color = palette[target_classes.index(bg_name)] +# except: +# pass + +# project = sly.Project(g.project_dir, sly.OpenMode.READ) +# with TqdmProgress( +# message="Converting project to segmentation task", +# total=project.total_items, +# ) as p: +# sly.Project.to_segmentation_task( +# g.project_dir, +# temp_project_seg_dir, +# target_classes=target_classes, +# progress_cb=p.update, +# bg_color=bg_color, +# bg_name=bg_name, +# ) + +# palette_lookup = np.zeros(256**3, dtype=np.int32) +# for idx, color in enumerate(palette, 1): +# key = (color[0] << 16) | (color[1] << 8) | color[2] +# palette_lookup[key] = idx + +# datasets = os.listdir(temp_project_seg_dir) +# os.makedirs(os.path.join(g.project_seg_dir, img_dir), exist_ok=True) +# os.makedirs(os.path.join(g.project_seg_dir, ann_dir), exist_ok=True) + +# with TqdmProgress( +# message="Converting masks to required format", +# total=project.total_items, +# ) as p: +# for dataset in datasets: +# if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)): +# if dataset == "meta.json": +# shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir) +# continue +# # convert masks to required format and save to general ann_dir +# mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir)) +# for mask_file in mask_files: +# path = os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file) +# mask = cv2.imread(path)[:, :, ::-1] + +# mask_keys = ( +# (mask[:, :, 0].astype(np.int32) << 16) +# | (mask[:, :, 1].astype(np.int32) << 8) +# | mask[:, :, 2].astype(np.int32) +# ) +# result = palette_lookup[mask_keys] +# cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result) + +# p.update(1) + +# imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir)) +# for filename in imgfiles_to_move: +# shutil.move( +# os.path.join(temp_project_seg_dir, dataset, img_dir, filename), +# os.path.join(g.project_seg_dir, img_dir), +# ) + +# shutil.rmtree(temp_project_seg_dir) +# g.api.app.set_field(g.task_id, "state.preparingData", False) + + +def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes): target_classes = target_classes or state["selectedClasses"] temp_project_seg_dir = g.project_seg_dir + "_temp" - bg_name = get_bg_class_name(target_classes) or "__bg__" - bg_color = (0, 0, 0) - if bg_name in target_classes: - try: - bg_color = palette[target_classes.index(bg_name)] - except: - pass - - project = sly.Project(g.project_dir, sly.OpenMode.READ) - with TqdmProgress( - message="Converting project to segmentation task", - total=project.total_items, - ) as p: - sly.Project.to_segmentation_task( - g.project_dir, - temp_project_seg_dir, - target_classes=target_classes, - progress_cb=p.update, - bg_color=bg_color, - bg_name=bg_name, - ) - - palette_lookup = np.zeros(256**3, dtype=np.int32) - for idx, color in enumerate(palette, 1): - key = (color[0] << 16) | (color[1] << 8) | color[2] - palette_lookup[key] = idx + sly.Project.to_segmentation_task( + g.project_dir, temp_project_seg_dir, target_classes=target_classes + ) datasets = os.listdir(temp_project_seg_dir) os.makedirs(os.path.join(g.project_seg_dir, img_dir), exist_ok=True) os.makedirs(os.path.join(g.project_seg_dir, ann_dir), exist_ok=True) - - with TqdmProgress( - message="Converting masks to required format", - total=project.total_items, - ) as p: - for dataset in datasets: - if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)): - if dataset == "meta.json": - shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir) - continue - # convert masks to required format and save to general ann_dir - mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir)) - for mask_file in mask_files: - path = os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file) - mask = cv2.imread(path)[:, :, ::-1] - - mask_keys = ( - (mask[:, :, 0].astype(np.int32) << 16) - | (mask[:, :, 1].astype(np.int32) << 8) - | mask[:, :, 2].astype(np.int32) - ) - result = palette_lookup[mask_keys] - cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result) - - p.update(1) - - imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir)) - for filename in imgfiles_to_move: - shutil.move( - os.path.join(temp_project_seg_dir, dataset, img_dir, filename), - os.path.join(g.project_seg_dir, img_dir), - ) - - shutil.rmtree(temp_project_seg_dir) - g.api.app.set_field(g.task_id, "state.preparingData", False) + for dataset in datasets: + if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)): + if dataset == "meta.json": + shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir) + continue + # convert masks to required format and save to general ann_dir + mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir)) + for mask_file in mask_files: + mask = cv2.imread(os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file))[ + :, :, ::-1 + ] + result = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int32) + # human masks to machine masks + for color_idx, color in enumerate(palette): + colormap = np.where(np.all(mask == color, axis=-1)) + result[colormap] = color_idx + cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result) + + imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir)) + for filename in imgfiles_to_move: + shutil.move( + os.path.join(temp_project_seg_dir, dataset, img_dir, filename), + os.path.join(g.project_seg_dir, img_dir), + ) def run_benchmark(api: sly.Api, task_id, classes, cfg, state, remote_dir):