Skip to content

Commit

Permalink
rollback prepare_segmentation_data
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Dec 16, 2024
1 parent 146d484 commit e73f05b
Showing 1 changed file with 98 additions and 62 deletions.
160 changes: 98 additions & 62 deletions train/src/ui/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,74 +316,110 @@ def init_class_charts_series(state):
g.api.app.set_fields(g.task_id, fields)


def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes=None):
# def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes=None):
# target_classes = target_classes or state["selectedClasses"]
# temp_project_seg_dir = g.project_seg_dir + "_temp"
# bg_name = get_bg_class_name(target_classes) or "__bg__"
# bg_color = (0, 0, 0)
# if bg_name in target_classes:
# try:
# bg_color = palette[target_classes.index(bg_name)]
# except:
# pass

# project = sly.Project(g.project_dir, sly.OpenMode.READ)
# with TqdmProgress(
# message="Converting project to segmentation task",
# total=project.total_items,
# ) as p:
# sly.Project.to_segmentation_task(
# g.project_dir,
# temp_project_seg_dir,
# target_classes=target_classes,
# progress_cb=p.update,
# bg_color=bg_color,
# bg_name=bg_name,
# )

# palette_lookup = np.zeros(256**3, dtype=np.int32)
# for idx, color in enumerate(palette, 1):
# key = (color[0] << 16) | (color[1] << 8) | color[2]
# palette_lookup[key] = idx

# datasets = os.listdir(temp_project_seg_dir)
# os.makedirs(os.path.join(g.project_seg_dir, img_dir), exist_ok=True)
# os.makedirs(os.path.join(g.project_seg_dir, ann_dir), exist_ok=True)

# with TqdmProgress(
# message="Converting masks to required format",
# total=project.total_items,
# ) as p:
# for dataset in datasets:
# if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)):
# if dataset == "meta.json":
# shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir)
# continue
# # convert masks to required format and save to general ann_dir
# mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir))
# for mask_file in mask_files:
# path = os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file)
# mask = cv2.imread(path)[:, :, ::-1]

# mask_keys = (
# (mask[:, :, 0].astype(np.int32) << 16)
# | (mask[:, :, 1].astype(np.int32) << 8)
# | mask[:, :, 2].astype(np.int32)
# )
# result = palette_lookup[mask_keys]
# cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result)

# p.update(1)

# imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir))
# for filename in imgfiles_to_move:
# shutil.move(
# os.path.join(temp_project_seg_dir, dataset, img_dir, filename),
# os.path.join(g.project_seg_dir, img_dir),
# )

# shutil.rmtree(temp_project_seg_dir)
# g.api.app.set_field(g.task_id, "state.preparingData", False)


def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes):
target_classes = target_classes or state["selectedClasses"]
temp_project_seg_dir = g.project_seg_dir + "_temp"
bg_name = get_bg_class_name(target_classes) or "__bg__"
bg_color = (0, 0, 0)
if bg_name in target_classes:
try:
bg_color = palette[target_classes.index(bg_name)]
except:
pass

project = sly.Project(g.project_dir, sly.OpenMode.READ)
with TqdmProgress(
message="Converting project to segmentation task",
total=project.total_items,
) as p:
sly.Project.to_segmentation_task(
g.project_dir,
temp_project_seg_dir,
target_classes=target_classes,
progress_cb=p.update,
bg_color=bg_color,
bg_name=bg_name,
)

palette_lookup = np.zeros(256**3, dtype=np.int32)
for idx, color in enumerate(palette, 1):
key = (color[0] << 16) | (color[1] << 8) | color[2]
palette_lookup[key] = idx
sly.Project.to_segmentation_task(
g.project_dir, temp_project_seg_dir, target_classes=target_classes
)

datasets = os.listdir(temp_project_seg_dir)
os.makedirs(os.path.join(g.project_seg_dir, img_dir), exist_ok=True)
os.makedirs(os.path.join(g.project_seg_dir, ann_dir), exist_ok=True)

with TqdmProgress(
message="Converting masks to required format",
total=project.total_items,
) as p:
for dataset in datasets:
if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)):
if dataset == "meta.json":
shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir)
continue
# convert masks to required format and save to general ann_dir
mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir))
for mask_file in mask_files:
path = os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file)
mask = cv2.imread(path)[:, :, ::-1]

mask_keys = (
(mask[:, :, 0].astype(np.int32) << 16)
| (mask[:, :, 1].astype(np.int32) << 8)
| mask[:, :, 2].astype(np.int32)
)
result = palette_lookup[mask_keys]
cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result)

p.update(1)

imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir))
for filename in imgfiles_to_move:
shutil.move(
os.path.join(temp_project_seg_dir, dataset, img_dir, filename),
os.path.join(g.project_seg_dir, img_dir),
)

shutil.rmtree(temp_project_seg_dir)
g.api.app.set_field(g.task_id, "state.preparingData", False)
for dataset in datasets:
if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)):
if dataset == "meta.json":
shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir)
continue
# convert masks to required format and save to general ann_dir
mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir))
for mask_file in mask_files:
mask = cv2.imread(os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file))[
:, :, ::-1
]
result = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int32)
# human masks to machine masks
for color_idx, color in enumerate(palette):
colormap = np.where(np.all(mask == color, axis=-1))
result[colormap] = color_idx
cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result)

imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir))
for filename in imgfiles_to_move:
shutil.move(
os.path.join(temp_project_seg_dir, dataset, img_dir, filename),
os.path.join(g.project_seg_dir, img_dir),
)


def run_benchmark(api: sly.Api, task_id, classes, cfg, state, remote_dir):
Expand Down

0 comments on commit e73f05b

Please sign in to comment.