
Commit a9bf1ff

add log messages
NikolaiPetukhov committed Feb 28, 2024
1 parent 3e2a27d commit a9bf1ff
Showing 1 changed file with 39 additions and 10 deletions.
train/src/dataset_cache.py (49 changes: 39 additions & 10 deletions)
@@ -21,11 +21,17 @@ def split_by_cache(project_id, dataset_ids) -> Tuple[set, set]:
         if sly.fs.dir_exists(cache_dataset_dir):
             cached.add(dataset_id)
             to_download.remove(dataset_id)

     return to_download, cached


-def download_project(api: sly.Api, project_info: sly.ProjectInfo, dataset_infos: List[sly.DatasetInfo], use_cache: bool, progress: Progress):
+def download_project(
+    api: sly.Api,
+    project_info: sly.ProjectInfo,
+    dataset_infos: List[sly.DatasetInfo],
+    use_cache: bool,
+    progress: Progress,
+):
     dataset_ids = [dataset_info.id for dataset_info in dataset_infos]
     if not use_cache:
         total = sum([dataset_info.images_count for dataset_info in dataset_infos])
@@ -40,17 +46,34 @@ def download_project(api: sly.Api, project_info: sly.ProjectInfo, dataset_infos:
         )
         return

-    dataset_infos_dict = {dataset_info.id:dataset_info for dataset_info in dataset_infos}
+    dataset_infos_dict = {
+        dataset_info.id: dataset_info for dataset_info in dataset_infos
+    }
     # get datasets to download and cached
     to_download, cached = split_by_cache(project_info.id, dataset_ids)
+    if len(cached) == 0:
+        log_msg = "No cached datasets found"
+    else:
+        log_msg = "Using cached datasets: " + ", ".join(
+            f"{dataset_infos_dict[dataset_id].name} ({dataset_id})"
+            for dataset_id in cached
+        )
+    sly.logger.info(log_msg)
+    if len(to_download) == 0:
+        log_msg = "All datasets are cached. No datasets to download"
+    else:
+        log_msg = "Downloading datasets: " + ", ".join(
+            f"{dataset_infos_dict[dataset_id].name} ({dataset_id})"
+            for dataset_id in to_download
+        )
+    sly.logger.info(log_msg)
     # get images count
-    total = sum([dataset_infos_dict[dataset_id].items_count for dataset_id in to_download])
+    total = sum(
+        [dataset_infos_dict[dataset_id].items_count for dataset_id in to_download]
+    )
     # clean project dir
     if os.path.exists(g.project_dir):
         sly.fs.clean_dir(g.project_dir)

     # TODO Check if to_download is empty

     # download
     with progress(message="Downloading input data...", total=total) as pbar:
         sly.download(
@@ -69,8 +92,12 @@ def download_project(api: sly.Api, project_info: sly.ProjectInfo, dataset_infos:
     total = sum([sly.fs.get_directory_size(dp) for dp in downloaded_dirs])
     with progress(message="Saving data to cache...", total=total) as pbar:
         for dataset_id, dataset_dir in zip(to_download, downloaded_dirs):
-            cache_dataset_dir = os.path.join(g.cache_dir, str(project_info.id), str(dataset_id))
-            sly.fs.copy_dir_recursively(dataset_dir, cache_dataset_dir, progress_cb=pbar.update)
+            cache_dataset_dir = os.path.join(
+                g.cache_dir, str(project_info.id), str(dataset_id)
+            )
+            sly.fs.copy_dir_recursively(
+                dataset_dir, cache_dataset_dir, progress_cb=pbar.update
+            )
     # copy cached datasets
     cached_dirs = [
         os.path.join(g.cache_dir, str(project_info.id), str(dataset_id))
@@ -81,4 +108,6 @@ def download_project(api: sly.Api, project_info: sly.ProjectInfo, dataset_infos:
     for dataset_id, cache_dataset_dir in zip(cached, cached_dirs):
         dataset_name = dataset_infos_dict[dataset_id].name
         dataset_dir = os.path.join(g.project_dir, dataset_name)
-        sly.fs.copy_dir_recursively(cache_dataset_dir, dataset_dir, progress_cb=pbar.update)
+        sly.fs.copy_dir_recursively(
+            cache_dataset_dir, dataset_dir, progress_cb=pbar.update
+        )
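
For context, the two sly.logger.info calls added in the second hunk emit one line per branch. With a hypothetical cached dataset "train" (id 101) and an uncached dataset "val" (id 102), the messages would read:

    Using cached datasets: train (101)
    Downloading datasets: val (102)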

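The TODO in the second hunk ("Check if to_download is empty") is not resolved by this commit. A minimal sketch of one possible guard, reusing only names that appear in the diff (to_download, total, progress, sly.logger, sly.download); the arguments of sly.download are truncated in the diff, so they are elided here as well:

    # Hypothetical follow-up for the TODO; not part of this commit.
    if len(to_download) == 0:
        # Everything requested is already cached, so skip the download step.
        sly.logger.info("Skipping download, datasets will be restored from cache")
    else:
        # download
        with progress(message="Downloading input data...", total=total) as pbar:
            sly.download(...)  # same call as in the diff, arguments omitted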