Skip to content

Commit

Permalink
prepare_segmentation_data: add progress
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Dec 16, 2024
1 parent e73f05b commit 75858e5
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions train/src/ui/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,30 +396,36 @@ def prepare_segmentation_data(state, img_dir, ann_dir, palette, target_classes):
datasets = os.listdir(temp_project_seg_dir)
os.makedirs(os.path.join(g.project_seg_dir, img_dir), exist_ok=True)
os.makedirs(os.path.join(g.project_seg_dir, ann_dir), exist_ok=True)
for dataset in datasets:
if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)):
if dataset == "meta.json":
shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir)
continue
# convert masks to required format and save to general ann_dir
mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir))
for mask_file in mask_files:
mask = cv2.imread(os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file))[
:, :, ::-1
]
result = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int32)
# human masks to machine masks
for color_idx, color in enumerate(palette):
colormap = np.where(np.all(mask == color, axis=-1))
result[colormap] = color_idx
cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result)

imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir))
for filename in imgfiles_to_move:
shutil.move(
os.path.join(temp_project_seg_dir, dataset, img_dir, filename),
os.path.join(g.project_seg_dir, img_dir),
)
total_items = sly.Project(g.project_dir, sly.OpenMode.READ).total_items
with TqdmProgress(
message="Converting masks to required format",
total=total_items,
) as p:
for dataset in datasets:
if not os.path.isdir(os.path.join(temp_project_seg_dir, dataset)):
if dataset == "meta.json":
shutil.move(os.path.join(temp_project_seg_dir, "meta.json"), g.project_seg_dir)
continue
# convert masks to required format and save to general ann_dir
mask_files = os.listdir(os.path.join(temp_project_seg_dir, dataset, ann_dir))
for mask_file in mask_files:
mask = cv2.imread(os.path.join(temp_project_seg_dir, dataset, ann_dir, mask_file))[
:, :, ::-1
]
result = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int32)
# human masks to machine masks
for color_idx, color in enumerate(palette):
colormap = np.where(np.all(mask == color, axis=-1))
result[colormap] = color_idx
cv2.imwrite(os.path.join(g.project_seg_dir, ann_dir, mask_file), result)
p.update(1)

imgfiles_to_move = os.listdir(os.path.join(temp_project_seg_dir, dataset, img_dir))
for filename in imgfiles_to_move:
shutil.move(
os.path.join(temp_project_seg_dir, dataset, img_dir, filename),
os.path.join(g.project_seg_dir, img_dir),
)


def run_benchmark(api: sly.Api, task_id, classes, cfg, state, remote_dir):
Expand Down

0 comments on commit 75858e5

Please sign in to comment.