Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolaiPetukhov committed Mar 5, 2024
1 parent 5ca3852 commit e6979cc
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 5 deletions.
2 changes: 1 addition & 1 deletion train/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
supervisely==6.73.22
supervisely==6.73.41
73 changes: 73 additions & 0 deletions train/src/sly_project_cached.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os

import supervisely as sly
from supervisely.project.download import (
download_to_cache,
copy_from_cache,
is_cached,
get_cache_size,
)
from sly_train_progress import get_progress_cb
import sly_globals as g


def download_project(
    api: sly.Api,
    project_info: sly.ProjectInfo,
    project_dir: str,
    use_cache: bool,
    progress_index: int = 1,
):
    """Download a project into ``project_dir``, optionally via the dataset cache.

    If ``project_dir`` already exists it is cleaned first. With
    ``use_cache=False`` the whole project is downloaded directly with
    ``sly.download``. With ``use_cache=True`` every dataset that is not yet
    cached is downloaded into the cache, and then all datasets of the project
    are copied from the cache into ``project_dir``.

    Args:
        api: Supervisely API client used to list and download datasets.
        project_info: Info of the project to download (``id`` and
            ``items_count`` are read).
        project_dir: Destination directory for the project.
        use_cache: Whether to go through the dataset cache.
        progress_index: Index of the UI progress widget passed as the first
            argument to ``get_progress_cb``. The default of 1 is assumed to
            match the progress index of the input-project step -- confirm
            against the caller if that index ever changes.
    """
    if os.path.exists(project_dir):
        sly.fs.clean_dir(project_dir)

    if not use_cache:
        total = project_info.items_count
        # Progress total is twice the item count (presumably one tick per
        # image and one per annotation -- TODO confirm).
        download_progress = get_progress_cb(
            progress_index, "Downloading input data...", total * 2
        )
        sly.download(
            api=api,
            project_id=project_info.id,
            dest_dir=project_dir,
            dataset_ids=None,
            log_progress=True,
            progress_cb=download_progress,
        )
        return

    # Split datasets into cached / not cached in a single pass so that
    # is_cached() is evaluated once per dataset and both lists stay consistent.
    dataset_infos = api.dataset.get_list(project_info.id)
    to_download = []
    cached = []
    for info in dataset_infos:
        if is_cached(project_info.id, info.name):
            cached.append(info)
        else:
            to_download.append(info)

    if not cached:
        log_msg = "No cached datasets found"
    else:
        log_msg = "Using cached datasets: " + ", ".join(
            f"{ds_info.name} ({ds_info.id})" for ds_info in cached
        )
    sly.logger.info(log_msg)

    if not to_download:
        log_msg = "All datasets are cached. No datasets to download"
    else:
        log_msg = "Downloading datasets: " + ", ".join(
            f"{ds_info.name} ({ds_info.id})" for ds_info in to_download
        )
    sly.logger.info(log_msg)

    # Download the missing datasets into the cache.
    total = sum(ds_info.images_count for ds_info in to_download)
    download_progress = get_progress_cb(
        progress_index, "Downloading input data...", total * 2
    )
    download_to_cache(
        api=api,
        project_id=project_info.id,
        dataset_infos=to_download,
        log_progress=True,
        progress_cb=download_progress,
    )

    # Copy every dataset of the project from the cache; this progress is
    # reported by size (is_size=True) rather than by item count.
    total = sum(get_cache_size(project_info.id, ds.name) for ds in dataset_infos)
    dataset_names = [ds_info.name for ds_info in dataset_infos]
    download_progress = get_progress_cb(
        progress_index, "Retreiving data from cache...", total, is_size=True
    )
    copy_from_cache(
        project_id=project_info.id,
        dest_dir=project_dir,
        dataset_names=dataset_names,
        progress_cb=download_progress,
    )
4 changes: 4 additions & 0 deletions train/src/ui/input_project.html
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
:href="`/projects/${data.projectId}/datasets`">{{data.projectName}} ({{data.projectImagesCount}}
images)</a>
<sly-icon slot="icon" :options="{ imageUrl: `${data.projectPreviewUrl}` }"/>
<el-checkbox v-model="state.useCache">
<span v-if="data.isCached">Use cached data stored on the agent to optimize project download</span>
<span v-else>Cache data on the agent to optimize project download for future trainings</span>
</el-checkbox>
</sly-field>
<el-button
type="primary"
Expand Down
14 changes: 10 additions & 4 deletions train/src/ui/input_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import random
from collections import namedtuple
import supervisely as sly
from supervisely.project.download import is_cached
import sly_globals as g
from sly_train_progress import get_progress_cb, reset_progress, init_progress
from sly_project_cached import download_project

progress_index = 1
_images_infos = None # dataset_name -> image_name -> image_info
Expand All @@ -21,6 +23,8 @@ def init(data, state):
init_progress(progress_index, data)
data["done1"] = False
state["collapsed1"] = False
data["isCached"] = is_cached(g.project_info.id)
state["useCache"] = True


@g.my_app.callback("download_project")
Expand All @@ -32,10 +36,12 @@ def download(api: sly.Api, task_id, context, state, app_logger):
pass
else:
sly.fs.mkdir(g.project_dir)
download_progress = get_progress_cb(progress_index, "Download project", g.project_info.items_count * 2)
sly.download_project(g.api, g.project_id, g.project_dir,
cache=g.my_app.cache, progress_cb=download_progress,
only_image_tags=False, save_image_info=True)
download_project(
api=g.api,
project_info=g.project_info,
project_dir=g.project_dir,
use_cache=state["useCache"]
)
reset_progress(progress_index)

global project_fs
Expand Down

0 comments on commit e6979cc

Please sign in to comment.