Skip to content

Commit

Permalink
Merge branch 'select-classes-and-run-from-script' into inst-seg
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev authored Aug 30, 2024
2 parents 6da925c + 9306184 commit 64c517d
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 42 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
git+https://github.com/supervisely/supervisely.git@model-benchmark-stage-2
52 changes: 52 additions & 0 deletions src/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import src.globals as g
import supervisely as sly
from supervisely.nn.inference import Session

geometry_to_task_type = {
"object detection": [sly.Rectangle],
"instance segmentation": [sly.Bitmap, sly.Polygon, sly.AlphaMask],
# "semantic segmentation": [sly.Bitmap, sly.Polygon, sly.AlphaMask],
}


def get_project_classes():
meta = sly.ProjectMeta.from_json(g.api.project.get_meta(g.project_id))
return meta.obj_classes


def get_model_info():
if g.session is None:
g.session = Session(g.api, g.session_id)
model_meta = g.session.get_model_meta()
session_info = g.session.get_session_info()
return model_meta.obj_classes, session_info["task type"]


def get_classes():
project_classes = get_project_classes()
model_classes, task_type = get_model_info()
if task_type not in geometry_to_task_type:
raise ValueError(f"Task type {task_type} is not supported yet")
matched_proj_cls = []
matched_model_cls = []
not_matched_proj_cls = []
not_matched_model_cls = []
for obj_class in project_classes:
if model_classes.has_key(obj_class.name):
if obj_class.geometry_type in geometry_to_task_type[task_type]:
matched_proj_cls.append(obj_class)
matched_model_cls.append(model_classes.get(obj_class.name))
else:
not_matched_proj_cls.append(obj_class)
else:
not_matched_proj_cls.append(obj_class)

for obj_class in model_classes:
if not project_classes.has_key(obj_class.name):
not_matched_model_cls.append(obj_class)

return (matched_proj_cls, matched_model_cls), (not_matched_proj_cls, not_matched_model_cls)

def get_matched_class_names():
(cls_collection, _), _ = get_classes()
return [obj_cls.name for obj_cls in cls_collection]
3 changes: 3 additions & 0 deletions src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,7 @@
session_id = os.environ.get("modal.state.sessionId", None)
if session_id is not None:
session_id = int(session_id)
session = None
autostart = bool(strtobool(os.environ.get("modal.state.autoStart", "false")))

selected_classes = None
185 changes: 143 additions & 42 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

import src.functions as f
import src.globals as g
import src.workflow as w
import supervisely as sly
Expand All @@ -13,31 +14,47 @@

def main_func():
api = g.api
project = api.project.get_info_by_id(sel_project.get_selected_id())
session_id = sel_app_session.get_selected_id()
session = SessionJSON(api, session_id)
task_type = session.get_deploy_info()["task_type"]
project = api.project.get_info_by_id(g.project_id)
if g.session_id is None:
g.session = SessionJSON(api, g.session_id)
task_type = g.session.get_deploy_info()["task_type"]

# ==================== Workflow input ====================
w.workflow_input(api, project, session_id)
w.workflow_input(api, project, g.session_id)
# =======================================================

pbar.show()
report_model_benchmark.hide()

# selected_classes = select_classes.get_selected()

# matches = [x[1] for x in select_classes.get_stat()["match"]]
# selected_classes = [x[0] for x in selected_classes if x[0] in matches]
if g.selected_classes is None:
g.selected_classes = f.get_matched_class_names()

if task_type == "object detection":
bm = ObjectDetectionBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
api,
project.id,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
classes_whitelist=g.selected_classes,
)
elif task_type == "instance segmentation":
bm = InstanceSegmentationBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
api,
project.id,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
classes_whitelist=g.selected_classes,
)
sly.logger.info(f"{session_id = }")
bm.run_evaluation(model_session=session_id)
sly.logger.info(f"{g.session_id = }")
bm.run_evaluation(model_session=g.session_id)

task_info = api.task.get_info_by_id(g.session_id)
task_dir = f"{g.session_id}_{task_info['meta']['app']['name']}"

task_info = api.task.get_info_by_id(session_id)
task_dir = f"{session_id}_{task_info['meta']['app']['name']}"
eval_res_dir = f"/model-benchmark/evaluation/{project.id}_{project.name}/{task_dir}/"
eval_res_dir = api.storage.get_free_dir_name(g.team_id, eval_res_dir)

Expand All @@ -49,8 +66,6 @@ def main_func():
report = bm.upload_report_link(remote_dir)
api.task.set_output_report(g.task_id, report.id, report.name)

creating_report_f.hide()

template_vis_file = api.file.get_info_by_path(
sly.env.team_id(), eval_res_dir + "/visualizations/template.vue"
)
Expand All @@ -63,63 +78,149 @@ def main_func():
# =======================================================

sly.logger.info(
f"Predictions project name {bm.dt_project_info.name}, workspace_id {bm.dt_project_info.workspace_id}"
)
sly.logger.info(
f"Differences project name {bm.diff_project_info.name}, workspace_id {bm.diff_project_info.workspace_id}"
f"Predictions project: "
f" name {bm.dt_project_info.name}, "
f" workspace_id {bm.dt_project_info.workspace_id}. "
f"Differences project: "
f" name {bm.diff_project_info.name}, "
f" workspace_id {bm.diff_project_info.workspace_id}"
)

button.loading = False
app.stop()


# select_classes = widgets.MatchObjClasses(
# selectable=True,
# left_name="Model classes",
# right_name="GT project classes",
# )
# select_classes.hide()
# not_matched_classes = widgets.MatchObjClasses(
# left_name="Model classes",
# right_name="GT project classes",
# )
# not_matched_classes.hide()
no_classes_label = widgets.Text(
"Not found any classes in the project that are present in the model", status="error"
)
no_classes_label.hide()
total_classes_text = widgets.Text(status="info")
selected_matched_text = widgets.Text(status="success")
not_matched_text = widgets.Text(status="warning")

sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)

button = widgets.Button("Evaluate")
button.disable()

pbar = widgets.SlyTqdm()

report_model_benchmark = widgets.ReportThumbnail()
report_model_benchmark.hide()
creating_report_f = widgets.Field(widgets.Empty(), "", "Creating report on model...")
creating_report_f.hide()

controls_card = widgets.Card(
title="Settings",
description="Select Ground Truth project and deployed model session",
content=widgets.Container(
[sel_project, sel_app_session, button, report_model_benchmark, pbar]
),
)


# matched_card = widgets.Card(
# title="✅ Available classes",
# description="Select classes that are present in the model and in the project",
# content=select_classes,
# )
# matched_card.lock(message="Select project and model session to enable")

# not_matched_card = widgets.Card(
# title="❌ Not available classes",
# description="List of classes that are not matched between the model and the project",
# content=not_matched_classes,
# )
# not_matched_card.lock(message="Select project and model session to enable")


layout = widgets.Container(
widgets=[
widgets.Text("Select GT Project"),
sel_project,
sel_app_session,
button,
creating_report_f,
report_model_benchmark,
pbar,
]
widgets=[controls_card, widgets.Empty(), widgets.Empty()], # , matched_card, not_matched_card],
direction="horizontal",
fractions=[1, 1, 1],
)

main_layout = widgets.Container(
widgets=[layout, total_classes_text, selected_matched_text, not_matched_text, no_classes_label]
)


@sel_project.value_changed
def handle(project_id: Optional[int]):
active = project_id is not None and sel_app_session.get_selected_id() is not None
def handle_selectors(active: bool):
no_classes_label.hide()
selected_matched_text.hide()
not_matched_text.hide()
# select_classes.hide()
# not_matched_classes.hide()
if active:
button.enable()
else:
button.disable()
matched, not_matched = f.get_classes()
_, matched_model_classes = matched
_, not_matched_model_classes = not_matched

g.selected_classes = [obj_cls.name for obj_cls in matched_model_classes]
not_matched_classes_cnt = len(not_matched_model_classes)
total_classes = len(matched_model_classes) + len(not_matched_model_classes)

total_classes_text.text = f"{total_classes} classes found in the model."
selected_matched_text.text = f"{len(matched_model_classes)} classes can be used for evaluation."
not_matched_text.text = f"{not_matched_classes_cnt} classes are not available for evaluation (not found in the GT project or have different geometry type)."

if len(matched_model_classes) > 0:
selected_matched_text.show()
if not_matched_classes_cnt > 0:
not_matched_text.show()
button.enable()
return
else:
no_classes_label.show()
# select_classes.set(left_collection=model_classes, right_collection=matched_model_classes)

# matched_model_classes, model_classes = not_matched
# not_matched_classes.set(left_collection=model_classes, right_collection=matched_model_classes)
# stats = select_classes.get_stat()
# if len(stats["match"]) > 0:
# select_classes.show()
# not_matched_classes.show()
# button.enable()
# # matched_card.unlock()
# # not_matched_card.unlock()
# return
# else:
# no_classes_label.show()
button.disable()


@sel_project.value_changed
def handle_sel_project(project_id: Optional[int]):
g.project_id = project_id
active = project_id is not None and g.session_id is not None
handle_selectors(active)


@sel_app_session.value_changed
def handle(session_id: Optional[int]):
active = session_id is not None and sel_project.get_selected_id() is not None
if active:
button.enable()
else:
button.disable()
def handle_sel_app_session(session_id: Optional[int]):
g.session_id = session_id
active = session_id is not None and g.project_id is not None
handle_selectors(active)


@button.click
def start_evaluation():
creating_report_f.show()
# select_classes.hide()
# not_matched_classes.hide()
main_func()


app = sly.Application(layout=layout, static_dir=g.STATIC_DIR)
app = sly.Application(layout=main_layout, static_dir=g.STATIC_DIR)

if g.project_id:
sel_project.set_project_id(g.project_id)
Expand Down

0 comments on commit 64c517d

Please sign in to comment.