Skip to content

Commit

Permalink
prepare_segmentation_data: using palette_lookup to speed up conversat…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
almazgimaev committed Dec 16, 2024
1 parent 2dff200 commit 12f05bc
Showing 1 changed file with 13 additions and 36 deletions.
49 changes: 13 additions & 36 deletions train/src/ui/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,43 +317,13 @@ def init_class_charts_series(state):


# def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes=None):
# bg_name = get_bg_class_name(target_classes) or "__bg__"
# bg_color = (0, 0, 0)
# if bg_name in target_classes:
# try:
# bg_color = palette[target_classes.index(bg_name)]
# except:
# pass

# project = sly.Project(g.project_dir, sly.OpenMode.READ)
# with TqdmProgress(
# message="Converting project to segmentation task",
# total=project.total_items,
# ) as p:
# sly.Project.to_segmentation_task(
# g.project_dir,
# temp_project_seg_dir,
# target_classes=target_classes,
# progress_cb=p.update,
# bg_color=bg_color,
# bg_name=bg_name,
# )

# palette_lookup = np.zeros(256**3, dtype=np.int32)
# for idx, color in enumerate(palette, 1):
# key = (color[0] << 16) | (color[1] << 8) | color[2]
# palette_lookup[key] = idx

# for dataset in datasets:
# if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)):
# if dataset == "meta.json":
# shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir)
# continue
# # convert masks to required format and save to general ann_dir
# mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir))
# for mask_file in mask_files:
# path = os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file)
# mask = cv2.imread(path)[:, :, ::-1]

# mask_keys = (
# (mask[:, :, 0].astype(np.int32) << 16)
Expand Down Expand Up @@ -389,6 +359,11 @@ def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes):
bg_name=bg_name,
)

palette_lookup = np.zeros(256**3, dtype=np.int32)
for idx, color in enumerate(palette):
key = (color[0] << 16) | (color[1] << 8) | color[2]
palette_lookup[key] = idx

datasets = os.listdir(temp_project_seg_dir)
os.makedirs(os.path.join(g.project_seg_dir, img_dir), exist_ok=True)
os.makedirs(os.path.join(g.project_seg_dir, ann_dir), exist_ok=True)
Expand All @@ -405,14 +380,16 @@ def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes):
# convert masks to required format and save to general ann_dir
mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir))
for mask_file in mask_files:
mask = cv2.imread(os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file))[
:, :, ::-1
]
path = os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file)
mask = cv2.imread(path)[:, :, ::-1]
result = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int32)
# human masks to machine masks
for color_idx, color in enumerate(palette):
colormap = np.where(np.all(mask == color, axis=-1))
result[colormap] = color_idx
mask_keys = (
(mask[:, :, 0].astype(np.int32) << 16)
| (mask[:, :, 1].astype(np.int32) << 8)
| mask[:, :, 2].astype(np.int32)
)
result = palette_lookup[mask_keys]
cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result)
p.update(1)

Expand Down

0 comments on commit 12f05bc

Please sign in to comment.