Skip to content

Commit

Permalink
Refactor progress reporting for annotation and image downloads using …
Browse files Browse the repository at this point in the history
…tqdm_sly for better user feedback
  • Loading branch information
GoldenAnpu committed Dec 19, 2024
1 parent b566b58 commit e01bb2e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def _write_new_ann(path, content):
img_names = [f"{ds.name}_{image_info.name}" for image_info in images_infos]
ann_infos = []

coro = api.annotation.download_bulk_async(ds.id, img_ids)
ann_progress = sly.tqdm_sly(desc="Downloading annotations...", total=len(img_ids))
coro = api.annotation.download_bulk_async(ds.id, img_ids, progress_cb=ann_progress)
loop = sly.utils.get_or_create_event_loop()
if loop.is_running():
future = asyncio.run_coroutine_threadsafe(coro, loop)
Expand Down Expand Up @@ -187,7 +188,7 @@ def _write_new_ann(path, content):
image_processed = True
train_count += 1

progress.iter_done_report()
progress(1)

sly.logger.info(
f"DATASET '{ds.name}': {train_count} images for train, {val_count} images for validation"
Expand Down
6 changes: 3 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def process(self, context: sly.app.Export.Context):
total_images_count = 0
skipped = []

progress = sly.Progress("Transforming annotations...", images_count)
progress = sly.tqdm_sly(desc="Transforming annotations...", total=images_count)
for ds in datasets:
train_info, val_info = f.process_images(
api,
Expand All @@ -81,7 +81,7 @@ def process(self, context: sly.app.Export.Context):
total_images_count += len(train_img_paths)
total_images_count += len(val_image_paths)

download_progress = sly.Progress("Downloading images...", total_images_count)
download_progress = sly.tqdm_sly(desc="Downloading images...", total=total_images_count)
img_ids = []
img_paths = []
for ds in datasets:
Expand All @@ -91,7 +91,7 @@ def process(self, context: sly.app.Export.Context):
img_paths.extend(all_paths[ds.id]["val_paths"])

coro = api.image.download_paths_async(
img_ids, img_paths, progress_cb=download_progress.iters_done_report
img_ids, img_paths, progress_cb=download_progress
)
loop = sly.utils.get_or_create_event_loop()
if loop.is_running():
Expand Down

0 comments on commit e01bb2e

Please sign in to comment.