Skip to content

Commit

Permalink
prepare_segmentation_data: handle bg class name and color
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Dec 16, 2024
1 parent 75858e5 commit 2dff200
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions train/src/ui/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,6 @@ def init_class_charts_series(state):


# def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes=None):
# target_classes = target_classes or state["selectedClasses"]
# temp_project_seg_dir = g.project_seg_dir + "_temp"
# bg_name = get_bg_class_name(target_classes) or "__bg__"
# bg_color = (0, 0, 0)
# if bg_name in target_classes:
Expand Down Expand Up @@ -346,14 +344,6 @@ def init_class_charts_series(state):
# key = (color[0] << 16) | (color[1] << 8) | color[2]
# palette_lookup[key] = idx

# datasets = os.listdir(temp_project_seg_dir)
# os.makedirs(os.path.join(g.project_seg_dir, img_dir), exist_ok=True)
# os.makedirs(os.path.join(g.project_seg_dir, ann_dir), exist_ok=True)

# with TqdmProgress(
# message="Converting masks to required format",
# total=project.total_items,
# ) as p:
# for dataset in datasets:
# if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)):
# if dataset == "meta.json":
Expand All @@ -373,25 +363,31 @@ def init_class_charts_series(state):
# result = palette_lookup[mask_keys]
# cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result)

# p.update(1)

# imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir))
# for filename in imgfiles_to_move:
# shutil.move(
# os.path.join(temp_project_seg_dir, dataset, img_dir, filename),
# os.path.join(g.project_seg_dir, img_dir),
# )

# shutil.rmtree(temp_project_seg_dir)
# g.api.app.set_field(g.task_id, "state.preparingData", False)


def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes):
target_classes = target_classes or state["selectedClasses"]
temp_project_seg_dir = g.project_seg_dir + "_temp"
sly.Project.to_segmentation_task(
g.project_dir, temp_project_seg_dir, target_classes=target_classes
)
bg_name = get_bg_class_name(target_classes) or "__bg__"
bg_color = (0, 0, 0)
if bg_name in target_classes:
try:
bg_color = palette[target_classes.index(bg_name)]
except:
pass

project = sly.Project(g.project_dir, sly.OpenMode.READ)
with TqdmProgress(
message="Converting project to segmentation task",
total=project.total_items,
) as p:
sly.Project.to_segmentation_task(
g.project_dir,
temp_project_seg_dir,
target_classes=target_classes,
progress_cb=p.update,
bg_color=bg_color,
bg_name=bg_name,
)

datasets = os.listdir(temp_project_seg_dir)
os.makedirs(os.path.join(g.project_seg_dir, img_dir), exist_ok=True)
Expand Down Expand Up @@ -427,6 +423,8 @@ def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes):
os.path.join(g.project_seg_dir, img_dir),
)

shutil.rmtree(temp_project_seg_dir)
g.api.app.set_field(g.task_id, "state.preparingData", False)

def run_benchmark(api: sly.Api, task_id, classes, cfg, state, remote_dir):
global m
Expand Down

0 comments on commit 2dff200

Please sign in to comment.