Skip to content

Commit

Permalink
Support for Instance segmentation. Bug fixes and tiny improvements (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev authored Aug 30, 2024
1 parent 4b46be6 commit d7518ae
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 46 deletions.
2 changes: 1 addition & 1 deletion config.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"task_location": "workspace_tasks",
"entrypoint": "python -m uvicorn src.main:app --host 0.0.0.0 --port 8000",
"port": 8000,
"docker_image": "supervisely/model-benchmark:0.0.1",
"docker_image": "supervisely/model-benchmark:0.0.2",
"instance_version": "6.11.10",
"context_menu": {
"target": ["images_project"]
Expand Down
4 changes: 2 additions & 2 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# git+https://github.com/supervisely/supervisely.git@model-benchmark
supervisely==6.73.166
supervisely[model-benchmark]==6.73.166
supervisely==6.73.173
supervisely[model-benchmark]==6.73.173
52 changes: 52 additions & 0 deletions src/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import src.globals as g
import supervisely as sly
from supervisely.nn import TaskType
from supervisely.nn.inference import Session

# Allowed GT-class geometries per supported benchmark task type.
# NOTE(review): the name reads "geometry -> task type", but the mapping is
# the opposite direction: task type -> list of accepted geometry classes.
geometry_to_task_type = {
    TaskType.OBJECT_DETECTION: [sly.Rectangle],
    TaskType.INSTANCE_SEGMENTATION: [sly.Bitmap, sly.Polygon],
}


def get_project_classes():
    """Return the object-class collection of the selected GT project.

    Reads the project meta via the API using the globally selected
    ``g.project_id``.
    """
    meta_json = g.api.project.get_meta(g.project_id)
    project_meta = sly.ProjectMeta.from_json(meta_json)
    return project_meta.obj_classes


def get_model_info():
    """Connect to the deployed model session and describe it.

    Side effect: stores the opened inference session in ``g.session``
    so other parts of the app can reuse it.

    Returns:
        A pair ``(obj_classes, task_type)`` — the model's class
        collection and the task-type string reported by the session.
    """
    g.session = Session(g.api, g.session_id)
    meta_of_model = g.session.get_model_meta()
    info_of_session = g.session.get_session_info()
    # NOTE(review): the session-info key really is "task type" (with a
    # space) — do not "fix" it to "task_type".
    return meta_of_model.obj_classes, info_of_session["task type"]


def get_classes():
    """Match the GT project's classes against the model's classes.

    A project class is "matched" only when the model exposes a class of
    the same name AND the project class's geometry is allowed for the
    model's task type.

    Returns:
        Two pairs of lists:
        ``(matched_project_classes, matched_model_classes)`` and
        ``(unmatched_project_classes, unmatched_model_classes)``.

    Raises:
        ValueError: if the session reports a task type that is not in
            ``geometry_to_task_type``.
    """
    project_classes = get_project_classes()
    model_classes, task_type = get_model_info()
    if task_type not in geometry_to_task_type:
        raise ValueError(f"Task type {task_type} is not supported yet")

    allowed_geometries = geometry_to_task_type[task_type]
    matched_proj_cls, matched_model_cls = [], []
    not_matched_proj_cls = []

    for proj_cls in project_classes:
        known_to_model = model_classes.has_key(proj_cls.name)
        if known_to_model and proj_cls.geometry_type in allowed_geometries:
            matched_proj_cls.append(proj_cls)
            matched_model_cls.append(model_classes.get(proj_cls.name))
        else:
            # Either unknown to the model or wrong geometry for the task.
            not_matched_proj_cls.append(proj_cls)

    # Model classes that have no same-named counterpart in the project.
    not_matched_model_cls = [
        model_cls
        for model_cls in model_classes
        if not project_classes.has_key(model_cls.name)
    ]

    return (matched_proj_cls, matched_model_cls), (not_matched_proj_cls, not_matched_model_cls)


def get_matched_class_names():
    """Return the names of the project classes usable for evaluation."""
    (matched_project_classes, _), _unmatched = get_classes()
    return [matched.name for matched in matched_project_classes]
3 changes: 3 additions & 0 deletions src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,7 @@
session_id = os.environ.get("modal.state.sessionId", None)
if session_id is not None:
session_id = int(session_id)
session = None
autostart = bool(strtobool(os.environ.get("modal.state.autoStart", "false")))

selected_classes = None
201 changes: 158 additions & 43 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,60 @@
from typing import Optional

import src.functions as f
import src.globals as g
import src.workflow as w
import supervisely as sly
import supervisely.app.widgets as widgets
from supervisely.nn.benchmark import ObjectDetectionBenchmark
from supervisely.nn.benchmark import (
InstanceSegmentationBenchmark,
ObjectDetectionBenchmark,
)
from supervisely.nn.inference.session import SessionJSON


def main_func():
api = g.api
project = api.project.get_info_by_id(sel_project.get_selected_id())
session_id = sel_app_session.get_selected_id()
project = api.project.get_info_by_id(g.project_id)
if g.session_id is None:
g.session = SessionJSON(api, g.session_id)
task_type = g.session.get_deploy_info()["task_type"]

# ==================== Workflow input ====================
w.workflow_input(api, project, session_id)
w.workflow_input(api, project, g.session_id)
# =======================================================

pbar.show()
report_model_benchmark.hide()

bm = ObjectDetectionBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
)
sly.logger.info(f"{session_id = }")
bm.run_evaluation(model_session=session_id)
# selected_classes = select_classes.get_selected()

# matches = [x[1] for x in select_classes.get_stat()["match"]]
# selected_classes = [x[0] for x in selected_classes if x[0] in matches]
if g.selected_classes is None:
g.selected_classes = f.get_matched_class_names()

if task_type == "object detection":
bm = ObjectDetectionBenchmark(
api,
project.id,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
classes_whitelist=g.selected_classes,
)
elif task_type == "instance segmentation":
bm = InstanceSegmentationBenchmark(
api,
project.id,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
classes_whitelist=g.selected_classes,
)
sly.logger.info(f"{g.session_id = }")
bm.run_evaluation(model_session=g.session_id)

task_info = api.task.get_info_by_id(g.session_id)
task_dir = f"{g.session_id}_{task_info['meta']['app']['name']}"

session_info = api.task.get_info_by_id(session_id)
task_dir = f"{session_id}_{session_info['meta']['app']['name']}"
eval_res_dir = f"/model-benchmark/evaluation/{project.id}_{project.name}/{task_dir}/"
eval_res_dir = api.storage.get_free_dir_name(g.team_id, eval_res_dir)

Expand All @@ -38,8 +66,6 @@ def main_func():
report = bm.upload_report_link(remote_dir)
api.task.set_output_report(g.task_id, report.id, report.name)

creating_report_f.hide()

template_vis_file = api.file.get_info_by_path(
sly.env.team_id(), eval_res_dir + "/visualizations/template.vue"
)
Expand All @@ -52,63 +78,152 @@ def main_func():
# =======================================================

sly.logger.info(
f"Predictions project name {bm.dt_project_info.name}, workspace_id {bm.dt_project_info.workspace_id}"
)
sly.logger.info(
f"Differences project name {bm.diff_project_info.name}, workspace_id {bm.diff_project_info.workspace_id}"
f"Predictions project: "
f" name {bm.dt_project_info.name}, "
f" workspace_id {bm.dt_project_info.workspace_id}. "
f"Differences project: "
f" name {bm.diff_project_info.name}, "
f" workspace_id {bm.diff_project_info.workspace_id}"
)

button.loading = False
app.stop()


# select_classes = widgets.MatchObjClasses(
# selectable=True,
# left_name="Model classes",
# right_name="GT project classes",
# )
# select_classes.hide()
# not_matched_classes = widgets.MatchObjClasses(
# left_name="Model classes",
# right_name="GT project classes",
# )
# not_matched_classes.hide()
no_classes_label = widgets.Text(
"Not found any classes in the project that are present in the model", status="error"
)
no_classes_label.hide()
total_classes_text = widgets.Text(status="info")
selected_matched_text = widgets.Text(status="success")
not_matched_text = widgets.Text(status="warning")

sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)

button = widgets.Button("Evaluate")
button.disable()

pbar = widgets.SlyTqdm()

report_model_benchmark = widgets.ReportThumbnail()
report_model_benchmark.hide()
creating_report_f = widgets.Field(widgets.Empty(), "", "Creating report on model...")
creating_report_f.hide()

controls_card = widgets.Card(
title="Settings",
description="Select Ground Truth project and deployed model session",
content=widgets.Container(
[sel_project, sel_app_session, button, report_model_benchmark, pbar]
),
)


# matched_card = widgets.Card(
# title="✅ Available classes",
# description="Select classes that are present in the model and in the project",
# content=select_classes,
# )
# matched_card.lock(message="Select project and model session to enable")

# not_matched_card = widgets.Card(
# title="❌ Not available classes",
# description="List of classes that are not matched between the model and the project",
# content=not_matched_classes,
# )
# not_matched_card.lock(message="Select project and model session to enable")


layout = widgets.Container(
widgets=[
widgets.Text("Select GT Project"),
sel_project,
sel_app_session,
button,
creating_report_f,
report_model_benchmark,
pbar,
]
widgets=[controls_card, widgets.Empty(), widgets.Empty()], # , matched_card, not_matched_card],
direction="horizontal",
fractions=[1, 1, 1],
)

main_layout = widgets.Container(
widgets=[layout, total_classes_text, selected_matched_text, not_matched_text, no_classes_label]
)

@sel_project.value_changed
def handle(project_id: Optional[int]):
active = project_id is not None and sel_app_session.get_selected_id() is not None

def handle_selectors(active: bool):
no_classes_label.hide()
selected_matched_text.hide()
not_matched_text.hide()
button.loading = True
# select_classes.hide()
# not_matched_classes.hide()
if active:
button.enable()
else:
button.disable()
matched, not_matched = f.get_classes()
_, matched_model_classes = matched
_, not_matched_model_classes = not_matched

g.selected_classes = [obj_cls.name for obj_cls in matched_model_classes]
not_matched_classes_cnt = len(not_matched_model_classes)
total_classes = len(matched_model_classes) + len(not_matched_model_classes)

total_classes_text.text = f"{total_classes} classes found in the model."
selected_matched_text.text = f"{len(matched_model_classes)} classes can be used for evaluation."
not_matched_text.text = f"{not_matched_classes_cnt} classes are not available for evaluation (not found in the GT project or have different geometry type)."

if len(matched_model_classes) > 0:
selected_matched_text.show()
if not_matched_classes_cnt > 0:
not_matched_text.show()
button.enable()
button.loading = False
return
else:
no_classes_label.show()
# select_classes.set(left_collection=model_classes, right_collection=matched_model_classes)

# matched_model_classes, model_classes = not_matched
# not_matched_classes.set(left_collection=model_classes, right_collection=matched_model_classes)
# stats = select_classes.get_stat()
# if len(stats["match"]) > 0:
# select_classes.show()
# not_matched_classes.show()
# button.enable()
# # matched_card.unlock()
# # not_matched_card.unlock()
# return
# else:
# no_classes_label.show()
button.loading = False
button.disable()


@sel_project.value_changed
def handle_sel_project(project_id: Optional[int]):
g.project_id = project_id
active = project_id is not None and g.session_id is not None
handle_selectors(active)


@sel_app_session.value_changed
def handle(session_id: Optional[int]):
active = session_id is not None and sel_project.get_selected_id() is not None
if active:
button.enable()
else:
button.disable()
def handle_sel_app_session(session_id: Optional[int]):
g.session_id = session_id
active = session_id is not None and g.project_id is not None
handle_selectors(active)


@button.click
def start_evaluation():
creating_report_f.show()
# select_classes.hide()
# not_matched_classes.hide()
main_func()


app = sly.Application(layout=layout, static_dir=g.STATIC_DIR)
app = sly.Application(layout=main_layout, static_dir=g.STATIC_DIR)

if g.project_id:
sel_project.set_project_id(g.project_id)
Expand All @@ -117,4 +232,4 @@ def start_evaluation():
sel_app_session.set_session_id(g.session_id)

if g.autostart:
start_evaluation()
start_evaluation()

0 comments on commit d7518ae

Please sign in to comment.