Skip to content

Commit

Permalink
add async
Browse files Browse the repository at this point in the history
  • Loading branch information
vorozhkog committed Nov 29, 2024
1 parent a621ffd commit 7619744
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 13 deletions.
2 changes: 1 addition & 1 deletion dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
supervisely==6.73.162
git+https://github.com/supervisely/supervisely.git@project_download_async
53 changes: 41 additions & 12 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import src.globals as g
import src.workflow as w

import asyncio
from tinytimer import Timer

api = None


Expand Down Expand Up @@ -80,20 +83,46 @@ def process(self, context: sly.app.Export.Context):
total_images_count += len(train_img_paths)
total_images_count += len(val_image_paths)

if api.server_address.startswith("https://"):
semaphore = asyncio.Semaphore(100)
else:
semaphore = None

download_progress = sly.Progress("Downloading images ...", total_images_count)
for ds in datasets:
for train_batch in sly.batched(
list(zip(all_ids[ds.id]["train_ids"], all_paths[ds.id]["train_paths"]))
):
img_ids, img_paths = zip(*train_batch)
api.image.download_paths(ds.id, img_ids, img_paths)
download_progress.iters_done_report(len(train_batch))
for val_batch in sly.batched(
list(zip(all_ids[ds.id]["val_ids"], all_paths[ds.id]["val_paths"]))
):
img_ids, img_paths = zip(*val_batch)
api.image.download_paths(ds.id, img_ids, img_paths)
download_progress.iters_done_report(len(val_batch))
ids = all_ids[ds.id]["train_ids"] + all_ids[ds.id]["val_ids"]
paths = all_paths[ds.id]["train_paths"] + all_paths[ds.id]["val_paths"]
with Timer() as t:
coro = api.image.download_paths_async(
ids, paths, semaphore, progress_cb=download_progress.iters_done_report
)
loop = sly.utils.get_or_create_event_loop()
if loop.is_running():
future = asyncio.run_coroutine_threadsafe(coro, loop)
future.result()
else:
loop.run_until_complete(coro)
sly.logger.info(
f"Downloading time: {t.elapsed:.4f} seconds per {len(ids)} images ({t.elapsed/len(ids):.4f} seconds per image)"
)

# with Timer() as t:
# total_ids_cnt = len(all_ids[ds.id]["train_ids"] + all_ids[ds.id]["val_ids"])
# for train_batch in sly.batched(
# list(zip(all_ids[ds.id]["train_ids"], all_paths[ds.id]["train_paths"]))
# ):
# img_ids, img_paths = zip(*train_batch)
# api.image.download_paths(ds.id, img_ids, img_paths)
# download_progress.iters_done_report(len(train_batch))
# for val_batch in sly.batched(
# list(zip(all_ids[ds.id]["val_ids"], all_paths[ds.id]["val_paths"]))
# ):
# img_ids, img_paths = zip(*val_batch)
# api.image.download_paths(ds.id, img_ids, img_paths)
# download_progress.iters_done_report(len(val_batch))
# sly.logger.info(
# f"Downloading time: {t.elapsed:.4f} seconds per {total_ids_cnt} images ({t.elapsed/total_ids_cnt:.4f} seconds per image)"
# )

f.prepare_yaml(result_dir_name, result_dir, class_names, class_colors, max_kpts_count)

Expand Down

0 comments on commit 7619744

Please sign in to comment.