Skip to content

Commit

Permalink
update train serve
Browse files Browse the repository at this point in the history
  • Loading branch information
cxnt committed Jun 7, 2024
1 parent 1f052ab commit c8e2bcb
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 18 deletions.
6 changes: 3 additions & 3 deletions serve/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import supervisely as sly
from serve.src import utils
from supervisely.nn.checkpoints.mmsegmentation import MMSegmentationCheckpoint
from supervisely.nn.artifacts.mmsegmentation import MMSegmentation
from supervisely.app.widgets import (
CustomModelsSelector,
PretrainedModelsSelector,
Expand Down Expand Up @@ -72,8 +72,8 @@ def initialize_custom_gui(self) -> Widget:
models = self.get_models()
filtered_models = utils.filter_models_structure(models)
self.pretrained_models_table = PretrainedModelsSelector(filtered_models)
checkpoint = MMSegmentationCheckpoint(team_id)
custom_models = checkpoint.get_list()
sly_mmseg = MMSegmentation(team_id)
custom_models = sly_mmseg.get_list()
self.custom_models_table = CustomModelsSelector(
team_id,
custom_models,
Expand Down
10 changes: 7 additions & 3 deletions train/src/sly_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import sys
import supervisely as sly
from supervisely.app.v1.app_service import AppService
from supervisely.nn.artifacts.mmsegmentation import MMSegmentation

import shutil
import pkg_resources

Expand All @@ -29,9 +31,9 @@
api = my_app.public_api
task_id = my_app.task_id

team_id = int(os.environ['context.teamId'])
workspace_id = int(os.environ['context.workspaceId'])
project_id = int(os.environ['modal.state.slyProjectId'])
team_id = int(os.environ["context.teamId"])
workspace_id = int(os.environ["context.workspaceId"])
project_id = int(os.environ["modal.state.slyProjectId"])

project_info = api.project.get_info_by_id(project_id)

Expand All @@ -49,6 +51,8 @@
checkpoints_dir = os.path.join(artifacts_dir, "checkpoints")
sly.fs.mkdir(checkpoints_dir)

sly_mmseg = MMSegmentation(team_id)

configs_dir = os.path.join(root_source_dir, "configs")
mmseg_ver = pkg_resources.get_distribution("mmsegmentation").version
if os.path.isdir(f"/tmp/mmseg/mmsegmentation-{mmseg_ver}"):
Expand Down
23 changes: 11 additions & 12 deletions train/src/ui/monitoring.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import supervisely as sly
from supervisely.nn.checkpoints.mmsegmentation import MMSegmentationCheckpoint
from sly_train_progress import init_progress, _update_progress_ui
import sly_globals as g
import os
Expand Down Expand Up @@ -152,24 +151,24 @@ def upload_monitor(monitor, api: sly.Api, task_id, progress: sly.Progress):
)
progress_cb = partial(upload_monitor, api=g.api, task_id=g.task_id, progress=progress)

checkpoint = MMSegmentationCheckpoint(g.team_id)
model_dir = checkpoint.get_model_dir()
model_dir = g.sly_mmseg.framework_folder
remote_artifacts_dir = f"{model_dir}/{g.task_id}_{g.project_info.name}"
remote_weights_dir = os.path.join(remote_artifacts_dir, checkpoint.weights_dir)
remote_config_path = os.path.join(remote_weights_dir, checkpoint.config_file)
remote_weights_dir = os.path.join(remote_artifacts_dir, g.sly_mmseg.weights_folder)
remote_config_path = os.path.join(remote_weights_dir, g.sly_mmseg.config_file)

res_dir = g.api.file.upload_directory(
g.team_id, g.artifacts_dir, remote_artifacts_dir, progress_size_cb=progress_cb
)

# generate metadata file
checkpoint.generate_sly_metadata(
app_name=checkpoint.app_name,
session_id=g.task_id,
session_path=remote_artifacts_dir,
weights_dir=remote_weights_dir,
training_project_name=g.project_info.name,
task_type=checkpoint.task_type,
g.sly_mmseg.generate_metadata(
app_name=g.sly_mmseg.app_name,
task_id=g.task_id,
artifacts_folder=remote_artifacts_dir,
weights_folder=remote_weights_dir,
weights_ext=g.sly_mmseg.weights_ext,
project_name=g.project_info.name,
task_type=g.sly_mmseg.task_type,
config_path=remote_config_path,
)

Expand Down

0 comments on commit c8e2bcb

Please sign in to comment.