Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Aug 30, 2024
1 parent e2f9d56 commit d37b40e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 86 deletions.
8 changes: 2 additions & 6 deletions src/functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import src.globals as g
import supervisely as sly
from supervisely.nn import TaskType
from supervisely.nn.inference import Session

import src.globals as g

geometry_to_task_type = {
TaskType.OBJECT_DETECTION: [sly.Rectangle],
TaskType.INSTANCE_SEGMENTATION: [sly.Bitmap, sly.Polygon],
Expand Down Expand Up @@ -45,8 +46,3 @@ def get_classes():
not_matched_model_cls.append(obj_class)

return (matched_proj_cls, matched_model_cls), (not_matched_proj_cls, not_matched_model_cls)


def get_matched_class_names():
(cls_collection, _), _ = get_classes()
return [obj_cls.name for obj_cls in cls_collection]
108 changes: 28 additions & 80 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from typing import Optional

import src.functions as f
import src.globals as g
import src.workflow as w
import supervisely as sly
import supervisely.app.widgets as widgets
from supervisely.nn.benchmark import (
Expand All @@ -11,11 +8,15 @@
)
from supervisely.nn.inference.session import SessionJSON

import src.functions as f
import src.globals as g
import src.workflow as w


def main_func():
api = g.api
project = api.project.get_info_by_id(g.project_id)
if g.session_id is None:
if g.session is None:
g.session = SessionJSON(api, g.session_id)
task_type = g.session.get_deploy_info()["task_type"]

Expand All @@ -26,13 +27,7 @@ def main_func():
pbar.show()
report_model_benchmark.hide()

# selected_classes = select_classes.get_selected()

# matches = [x[1] for x in select_classes.get_stat()["match"]]
# selected_classes = [x[0] for x in selected_classes if x[0] in matches]
if g.selected_classes is None:
g.selected_classes = f.get_matched_class_names()

set_selected_classes_and_show_info()
if task_type == "object detection":
bm = ObjectDetectionBenchmark(
api,
Expand Down Expand Up @@ -90,17 +85,6 @@ def main_func():
app.stop()


# select_classes = widgets.MatchObjClasses(
# selectable=True,
# left_name="Model classes",
# right_name="GT project classes",
# )
# select_classes.hide()
# not_matched_classes = widgets.MatchObjClasses(
# left_name="Model classes",
# right_name="GT project classes",
# )
# not_matched_classes.hide()
no_classes_label = widgets.Text(
"Not found any classes in the project that are present in the model", status="error"
)
Expand All @@ -123,27 +107,10 @@ def main_func():
controls_card = widgets.Card(
title="Settings",
description="Select Ground Truth project and deployed model session",
content=widgets.Container(
[sel_project, sel_app_session, button, report_model_benchmark, pbar]
),
content=widgets.Container([sel_project, sel_app_session, button, report_model_benchmark, pbar]),
)


# matched_card = widgets.Card(
# title="✅ Available classes",
# description="Select classes that are present in the model and in the project",
# content=select_classes,
# )
# matched_card.lock(message="Select project and model session to enable")

# not_matched_card = widgets.Card(
# title="❌ Not available classes",
# description="List of classes that are not matched between the model and the project",
# content=not_matched_classes,
# )
# not_matched_card.lock(message="Select project and model session to enable")


layout = widgets.Container(
widgets=[controls_card, widgets.Empty(), widgets.Empty()], # , matched_card, not_matched_card],
direction="horizontal",
Expand All @@ -155,51 +122,32 @@ def main_func():
)


def set_selected_classes_and_show_info():
matched, not_matched = f.get_classes()
_, matched_model_classes = matched
_, not_matched_model_classes = not_matched
total_classes_text.text = (
f"{len(matched_model_classes) + len(not_matched_model_classes)} classes found in the model."
)
selected_matched_text.text = f"{len(matched_model_classes)} classes can be used for evaluation."
not_matched_text.text = f"{len(not_matched_model_classes)} classes are not available for evaluation (not found in the GT project or have different geometry type)."
if len(matched_model_classes) > 0:
g.selected_classes = [obj_cls.name for obj_cls in matched_model_classes]
selected_matched_text.show()
if len(not_matched_model_classes) > 0:
not_matched_text.show()
else:
no_classes_label.show()


def handle_selectors(active: bool):
no_classes_label.hide()
selected_matched_text.hide()
not_matched_text.hide()
button.loading = True
# select_classes.hide()
# not_matched_classes.hide()
if active:
matched, not_matched = f.get_classes()
_, matched_model_classes = matched
_, not_matched_model_classes = not_matched

g.selected_classes = [obj_cls.name for obj_cls in matched_model_classes]
not_matched_classes_cnt = len(not_matched_model_classes)
total_classes = len(matched_model_classes) + len(not_matched_model_classes)

total_classes_text.text = f"{total_classes} classes found in the model."
selected_matched_text.text = f"{len(matched_model_classes)} classes can be used for evaluation."
not_matched_text.text = f"{not_matched_classes_cnt} classes are not available for evaluation (not found in the GT project or have different geometry type)."

if len(matched_model_classes) > 0:
selected_matched_text.show()
if not_matched_classes_cnt > 0:
not_matched_text.show()
button.enable()
button.loading = False
return
else:
no_classes_label.show()
# select_classes.set(left_collection=model_classes, right_collection=matched_model_classes)

# matched_model_classes, model_classes = not_matched
# not_matched_classes.set(left_collection=model_classes, right_collection=matched_model_classes)
# stats = select_classes.get_stat()
# if len(stats["match"]) > 0:
# select_classes.show()
# not_matched_classes.show()
# button.enable()
# # matched_card.unlock()
# # not_matched_card.unlock()
# return
# else:
# no_classes_label.show()
button.loading = False
button.disable()
button.enable()
else:
button.disable()


@sel_project.value_changed
Expand Down

0 comments on commit d37b40e

Please sign in to comment.