Skip to content

Commit

Permalink
Add custom GUI support (#48)
Browse files Browse the repository at this point in the history
* Add custom GUI support

* update and sort imports

* change column Name to Model

* update req.txt

* update req.txt

* update Supervisely version

* Update build_image.yml
  • Loading branch information
cxnt authored May 23, 2024
1 parent ce168ea commit 3dffb9e
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 91 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ train/src/.DS_Store
serve/src/__pycache__
train/src/__pycache__
train/src/ui/__pycache__
my_model
my_model

supervisely
2 changes: 1 addition & 1 deletion dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
supervisely==6.73.45
supervisely==6.73.89

openmim
ffmpeg-python==0.2.0
Expand Down
60 changes: 29 additions & 31 deletions serve/config.json
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
{
"name": "Serve MMSegmentation",
"type": "app",
"version": "2.0.0",
"categories": [
"neural network",
"images",
"videos",
"semantic segmentation",
"segmentation & tracking",
"serve"
],
"description": "Deploy model as REST API service",
"docker_image": "supervisely/mmseg:1.3.4",
"instance_version": "6.8.88",
"entrypoint": "python -m uvicorn main:m.app --app-dir ./serve/src --host 0.0.0.0 --port 8000 --ws websockets",
"port": 8000,
"task_location": "application_sessions",
"need_gpu": false,
"gpu": "preferred",
"icon": "https://i.imgur.com/GfjrdKI.png",
"isolate": true,
"icon_cover": true,
"session_tags": [
"deployed_nn"
],
"poster": "https://user-images.githubusercontent.com/48245050/182851208-d8e50d77-686e-470d-a136-428856a60ef5.jpg",
"community_agent": false,
"license": {
"type": "Apache-2.0"
}
}
"name": "Serve MMSegmentation",
"type": "app",
"version": "2.0.0",
"categories": [
"neural network",
"images",
"videos",
"semantic segmentation",
"segmentation & tracking",
"serve"
],
"description": "Deploy model as REST API service",
"docker_image": "supervisely/mmseg:1.3.5",
"instance_version": "6.8.88",
"entrypoint": "python -m uvicorn main:m.app --app-dir ./serve/src --host 0.0.0.0 --port 8000 --ws websockets",
"port": 8000,
"task_location": "application_sessions",
"need_gpu": false,
"gpu": "preferred",
"icon": "https://i.imgur.com/GfjrdKI.png",
"isolate": true,
"icon_cover": true,
"session_tags": ["deployed_nn"],
"poster": "https://user-images.githubusercontent.com/48245050/182851208-d8e50d77-686e-470d-a136-428856a60ef5.jpg",
"community_agent": false,
"license": {
"type": "Apache-2.0"
}
}
1 change: 0 additions & 1 deletion serve/requirements.txt

This file was deleted.

241 changes: 184 additions & 57 deletions serve/src/main.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,43 @@
import os
import shutil
import sys

try:
from typing import Literal
except:
from typing_extensions import Literal
from typing import List, Any, Dict

from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import pkg_resources
import torch
import yaml
from dotenv import load_dotenv
import torch
import supervisely as sly
import supervisely.app.widgets as Widgets
import supervisely.nn.inference.gui as GUI
import pkg_resources
from collections import OrderedDict
from mmcv import Config
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import load_checkpoint
from mmseg.models import build_segmentor
from mmseg.apis.inference import inference_segmentor
from mmseg.datasets import *
from mmseg.models import build_segmentor

import supervisely as sly
from serve.src import utils
from supervisely.app.widgets import (CustomModelsSelector,
PretrainedModelsSelector, RadioTabs,
Widget)
from supervisely.io.fs import silent_remove

root_source_path = str(Path(__file__).parents[2])
app_source_path = str(Path(__file__).parents[1])
load_dotenv(os.path.join(app_source_path, "local.env"))
load_dotenv(os.path.expanduser("~/supervisely.env"))

api = sly.Api.from_env()
team_id = sly.env.team_id()

use_gui_for_local_debug = bool(int(os.environ.get("USE_GUI", "1")))

models_meta_path = os.path.join(root_source_path, "models", "model_meta.json")
Expand All @@ -51,6 +61,158 @@ def str_to_class(classname):


class MMSegmentationModel(sly.nn.inference.SemanticSegmentation):
def initialize_custom_gui(self) -> Widget:
"""Create custom GUI layout for model selection. This method is called once when the application is started."""
models = self.get_models()
filtered_models = utils.filter_models_structure(models)
self.pretrained_models_table = PretrainedModelsSelector(filtered_models)
custom_models = sly.nn.checkpoints.mmsegmentation.get_list(api, team_id)
self.custom_models_table = CustomModelsSelector(
team_id,
custom_models,
show_custom_checkpoint_path=True,
custom_checkpoint_task_types=["semantic segmentation"],
)

self.model_source_tabs = RadioTabs(
titles=["Pretrained models", "Custom models"],
descriptions=["Publicly available models", "Models trained by you in Supervisely"],
contents=[self.pretrained_models_table, self.custom_models_table],
)
return self.model_source_tabs

def get_params_from_gui(self) -> dict:
model_source = self.model_source_tabs.get_active_tab()
self.device = self.gui.get_device()
if model_source == "Pretrained models":
model_params = self.pretrained_models_table.get_selected_model_params()
elif model_source == "Custom models":
model_params = self.custom_models_table.get_selected_model_params()
if self.custom_models_table.use_custom_checkpoint_path():
checkpoint_path = self.custom_models_table.get_custom_checkpoint_path()
model_params["config_url"] = (
f"{os.path.dirname(checkpoint_path).rstrip('/')}/config.py"
)
file_info = api.file.exists(team_id, model_params["config_url"])
if file_info is None:
raise FileNotFoundError(
f"Config file not found: {model_params['config_url']}. "
"Config should be placed in the same directory as the checkpoint file."
)

self.selected_model_name = model_params.get("arch_type")
self.checkpoint_name = model_params.get("checkpoint_name")
self.task_type = model_params.get("task_type")

deploy_params = {
"device": self.device,
**model_params,
}
return deploy_params

def load_model_meta(
self, model_source: str, cfg: Config, checkpoint_name: str = None, arch_type: str = None
):
def set_common_meta(classes, palette):
obj_classes = [sly.ObjClass(name, sly.Bitmap, color) for name, color in zip(classes, palette)]
self.checkpoint_name = checkpoint_name
self.dataset_name = cfg.dataset_type
self.class_names = classes
self._model_meta = sly.ProjectMeta(obj_classes=sly.ObjClassCollection(obj_classes))
self._get_confidence_tag_meta()

if model_source == "Custom models":
self.selected_model_name = cfg.pretrained_model
classes = cfg.checkpoint_config.meta.CLASSES
palette = cfg.checkpoint_config.meta.PALETTE
set_common_meta(classes, palette)

elif model_source == "Pretrained models":
self.selected_model_name = arch_type
dataset_class_name = cfg.dataset_type
classes = str_to_class(dataset_class_name).CLASSES
palette = str_to_class(dataset_class_name).PALETTE
set_common_meta(classes, palette)

self.model.CLASSES = classes
self.model.PALETTE = palette


def load_model(
self,
device: Literal["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"],
model_source: Literal["Pretrained models", "Custom models"],
task_type: Literal["semantic segmentation"],
checkpoint_name: str,
checkpoint_url: str,
config_url: str,
arch_type: str = None,
):
"""
Load model method is used to deploy model.
:param model_source: Specifies whether the model is pretrained or custom.
:type model_source: Literal["Pretrained models", "Custom models"]
:param device: The device on which the model will be deployed.
:type device: Literal["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"]
:param task_type: The type of task the model is designed for.
:type task_type: Literal["semantic segmentation"]
:param checkpoint_name: The name of the checkpoint from which the model is loaded.
:type checkpoint_name: str
:param checkpoint_url: The URL where the model checkpoint can be downloaded.
:type checkpoint_url: str
:param config_url: The URL where the model config can be downloaded.
:type config_url: str
:param arch_type: The architecture type of the model.
:type arch_type: str
"""
self.device = device
self.task_type = task_type

local_weights_path = os.path.join(self.model_dir, checkpoint_name)
if model_source == "Pretrained models":
if not sly.fs.file_exists(local_weights_path):
self.download(
src_path=checkpoint_url,
dst_path=local_weights_path,
)
local_config_path = os.path.join(root_source_path, config_url)
else:
self.download(
src_path=checkpoint_url,
dst_path=local_weights_path,
)
local_config_path = os.path.join(configs_dir, "custom", "config.py")
if sly.fs.file_exists(local_config_path):
silent_remove(local_config_path)
self.download(
src_path=config_url,
dst_path=local_config_path,
)
if not sly.fs.file_exists(local_config_path):
raise FileNotFoundError(
f"Config file not found: {config_url}. "
"Config should be placed in the same directory as the checkpoint file."
)

try:
cfg = Config.fromfile(local_config_path)
cfg.model.pretrained = None
cfg.model.train_cfg = None

self.model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
checkpoint = load_checkpoint(self.model, local_weights_path, map_location='cpu')

self.load_model_meta(model_source, cfg, checkpoint_name, arch_type)

self.model.cfg = cfg # save the config in the model for convenience
self.model.to(device)
self.model.eval()
self.model = revert_sync_batchnorm(self.model)

except KeyError as e:
raise KeyError(f"Error loading config file: {local_config_path}. Error: {e}")

def load_on_device(
self,
model_dir: str,
Expand All @@ -69,6 +231,8 @@ def load_on_device(
# for local debug only
model_source = "Pretrained models"
weights_path, config_path = self.download_pretrained_files(selected_checkpoint, model_dir)


cfg = Config.fromfile(config_path)
cfg.model.pretrained = None
cfg.model.train_cfg = None
Expand Down Expand Up @@ -107,9 +271,6 @@ def load_on_device(
self._model_meta = sly.ProjectMeta(obj_classes=sly.ObjClassCollection(obj_classes))
print(f"✅ Model has been successfully loaded on {device.upper()} device")

def get_classes(self) -> List[str]:
return self.class_names # e.g. ["cat", "dog", ...]

def get_info(self) -> dict:
info = super().get_info()
info["model_name"] = self.selected_model_name
Expand All @@ -118,7 +279,7 @@ def get_info(self) -> dict:
info["device"] = self.device
return info

def get_models(self, add_links=False):
def get_models(self):
model_yamls = sly.json.load_json_file(models_meta_path)
model_config = {}
for model_meta in model_yamls:
Expand All @@ -134,7 +295,7 @@ def get_models(self, add_links=False):
model_config[model_meta["model_name"]]["config_url"] = os.path.dirname(model_yml_url)
for model in model_info["Models"]:
checkpoint_info = OrderedDict()
checkpoint_info["Name"] = model["Name"]
checkpoint_info["Model"] = model["Name"]
checkpoint_info["Backbone"] = model["Metadata"]["backbone"]
checkpoint_info["Method"] = model["In Collection"]
checkpoint_info["Dataset"] = model["Results"][0]["Dataset"]
Expand All @@ -150,53 +311,19 @@ def get_models(self, add_links=False):
checkpoint_info["Memory (Training, GB)"] = "-"
for metric_name, metric_val in model["Results"][0]["Metrics"].items():
checkpoint_info[metric_name] = metric_val
#checkpoint_info["config_file"] = os.path.join(f"https://github.com/open-mmlab/mmsegmentation/tree/v{mmseg_ver}", model["Config"])
if add_links:
checkpoint_info["config_file"] = os.path.join(root_source_path, model["Config"])
checkpoint_info["weights_file"] = model["Weights"]
#checkpoint_info["config_file"] = os.path.join(f"https://github.com/open-mmlab/mmsegmentation/tree/v{mmseg_ver}", model["Config"])
checkpoint_info["meta"] = {
"task_type": None,
"arch_type": None,
"arch_link": None,
"weights_url": model["Weights"],
"config_url": os.path.join(root_source_path, model["Config"]),
}
model_config[model_meta["model_name"]]["checkpoints"].append(checkpoint_info)
return model_config

def download_pretrained_files(self, selected_model: Dict[str, str], model_dir: str):
models = self.get_models(add_links=True)
if self.gui is not None:
model_name = list(self.gui.get_model_info().keys())[0]
else:
# for local debug only
model_name = selected_model_name
full_model_info = selected_model
for model_info in models[model_name]["checkpoints"]:
if model_info["Name"] == selected_model["Name"]:
full_model_info = model_info
weights_ext = sly.fs.get_file_ext(full_model_info["weights_file"])
config_ext = sly.fs.get_file_ext(full_model_info["config_file"])
weights_dst_path = os.path.join(model_dir, f"{selected_model['Name']}{weights_ext}")
if not sly.fs.file_exists(weights_dst_path):
self.download(
src_path=full_model_info["weights_file"],
dst_path=weights_dst_path
)
config_path = self.download(
src_path=full_model_info["config_file"],
dst_path=os.path.join(model_dir, f"config{config_ext}")
)

return weights_dst_path, config_path

def download_custom_files(self, custom_link: str, model_dir: str):
weight_filename = os.path.basename(custom_link)
weights_dst_path = os.path.join(model_dir, weight_filename)
if not sly.fs.file_exists(weights_dst_path):
self.download(
src_path=custom_link,
dst_path=weights_dst_path,
)
config_path = self.download(
src_path=os.path.join(os.path.dirname(custom_link), 'config.py'),
dst_path=os.path.join(model_dir, 'config.py'),
)

return weights_dst_path, config_path
def get_classes(self) -> List[str]:
return self.class_names # e.g. ["cat", "dog", ...]

def predict(
self, image_path: str, settings: Dict[str, Any]
Expand Down
Loading

0 comments on commit 3dffb9e

Please sign in to comment.