Skip to content

Commit

Permalink
add modal and classes selector
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Aug 27, 2024
1 parent 2f771e1 commit 455c588
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 31 deletions.
14 changes: 13 additions & 1 deletion config.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,17 @@
"entrypoint": "python -m uvicorn src.main:app --host 0.0.0.0 --port 8000",
"port": 8000,
"docker_image": "supervisely/model-benchmark:0.0.1",
"instance_version": "6.11.10"
"instance_version": "6.11.10",
"context_menu": {
"target": ["images_project"]
},
"modal_template": "src/modal.html",
"modal_template_state": {
"sessionId": null,
"sessionOptions": {
"sessionTags": ["deployed_nn"],
"showLabel": false,
"size": "small"
}
}
}
35 changes: 35 additions & 0 deletions src/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import src.globals as g
import supervisely as sly
from supervisely.nn.inference import Session

geometry_to_task_type = {
"object detection": [sly.Rectangle],
"instance segmentation": [sly.Bitmap, sly.Polygon, sly.AlphaMask],
# "semantic segmentation": [sly.Bitmap, sly.Polygon, sly.AlphaMask],
}


def get_project_classes():
meta = sly.ProjectMeta.from_json(g.api.project.get_meta(g.project_id))
return meta.obj_classes


def get_model_info():
session = Session(g.api, g.session_id)
model_meta = session.get_model_meta()
session_info = session.get_session_info()
return model_meta.obj_classes, session_info["task type"]


def get_classes():
project_classes = get_project_classes()
model_classes, task_type = get_model_info()
if task_type not in geometry_to_task_type:
raise ValueError(f"Task type {task_type} is not supported yet")
filtered_classes = []
for obj_class in project_classes:
if model_classes.has_key(obj_class.name):
if obj_class.geometry_type in geometry_to_task_type[task_type]:
filtered_classes.append(obj_class)

return filtered_classes
6 changes: 5 additions & 1 deletion src/globals.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os

import supervisely as sly
from dotenv import load_dotenv

import supervisely as sly

if sly.is_development():
load_dotenv("local.env")
load_dotenv(os.path.expanduser("~/supervisely.env"))
Expand All @@ -21,3 +22,6 @@
project_id = sly.env.project_id(raise_not_found=False)
team_id = sly.env.team_id()
task_id = sly.env.task_id(raise_not_found=False)
session_id = os.environ.get("modal.state.sessionId", None)
if session_id is not None:
session_id = int(session_id)
99 changes: 70 additions & 29 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,37 @@
from typing import Optional

import src.functions as f
import src.globals as g
import src.workflow as w
import supervisely as sly
import supervisely.app.widgets as widgets
from supervisely.nn.benchmark import ObjectDetectionBenchmark

import src.globals as g
import src.workflow as w


def main_func():
api = g.api
project = api.project.get_info_by_id(sel_project.get_selected_id())
session_id = sel_app_session.get_selected_id()
project = api.project.get_info_by_id(g.project_id)

# ==================== Workflow input ====================
w.workflow_input(api, project, session_id)
w.workflow_input(api, project, g.session_id)
# =======================================================


pbar.show()
report_model_benchmark.hide()

selected_classes = select_classes.get_selected_classes()
bm = ObjectDetectionBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
api,
project.id,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
classes=selected_classes,
)
sly.logger.info(f"{session_id = }")
bm.run_evaluation(model_session=session_id)
sly.logger.info(f"{g.session_id = }")
bm.run_evaluation(model_session=g.session_id)

session_info = api.task.get_info_by_id(session_id)
task_dir = f"{session_id}_{session_info['meta']['app']['name']}"
session_info = api.task.get_info_by_id(g.session_id)
task_dir = f"{g.session_id}_{session_info['meta']['app']['name']}"
eval_res_dir = f"/model-benchmark/evaluation/{project.id}_{project.name}/{task_dir}/"
eval_res_dir = api.storage.get_free_dir_name(g.team_id, eval_res_dir)

Expand All @@ -54,20 +57,34 @@ def main_func():
# =======================================================

sly.logger.info(
f"Predictions project name {bm.dt_project_info.name}, workspace_id {bm.dt_project_info.workspace_id}"
)
sly.logger.info(
f"Differences project name {bm.diff_project_info.name}, workspace_id {bm.diff_project_info.workspace_id}"
f"Predictions project: "
f" name {bm.dt_project_info.name}, "
f" workspace_id {bm.dt_project_info.workspace_id}. "
f"Differences project: "
f" name {bm.diff_project_info.name}, "
f" workspace_id {bm.diff_project_info.workspace_id}"
)

button.loading = False
app.stop()


select_classes = widgets.ClassesListSelector(multiple=True)
select_classes.hide()
no_classes_label = widgets.Text(
"Not found any classes in the project that are present in the model",
status="error",
)
no_classes_label.hide()

sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)

button = widgets.Button("Evaluate")
button.disable()

pbar = widgets.SlyTqdm()

report_model_benchmark = widgets.ReportThumbnail()
report_model_benchmark.hide()
creating_report_f = widgets.Field(widgets.Empty(), "", "Creating report on model...")
Expand All @@ -79,35 +96,59 @@ def main_func():
sel_project,
sel_app_session,
button,
select_classes,
no_classes_label,
creating_report_f,
report_model_benchmark,
pbar,
]
],
)


@sel_project.value_changed
def handle(project_id: Optional[int]):
active = project_id is not None and sel_app_session.get_selected_id() is not None
def handle_selectors(active: bool):
no_classes_label.hide()
select_classes.hide()
button.loading = True
if active:
button.enable()
else:
button.disable()
classes = f.get_classes()
if len(classes) > 0:
select_classes.set(classes)
select_classes.show()
select_classes.select_all()
button.loading = False
button.enable()
return
else:
no_classes_label.show()
button.loading = False
button.disable()


@sel_project.value_changed
def handle_sel_project(project_id: Optional[int]):
g.project_id = project_id
active = project_id is not None and g.session_id is not None
handle_selectors(active)


@sel_app_session.value_changed
def handle(session_id: Optional[int]):
active = session_id is not None and sel_project.get_selected_id() is not None
if active:
button.enable()
else:
button.disable()
def handle_sel_app_session(session_id: Optional[int]):
g.session_id = session_id
active = session_id is not None and g.project_id is not None
handle_selectors(active)


@button.click
def handle():
creating_report_f.show()
select_classes.disable()
main_func()


app = sly.Application(layout=layout, static_dir=g.STATIC_DIR)

if g.project_id:
sel_project.set_project_id(g.project_id)

if g.session_id:
sel_app_session.set_session_id(g.session_id)
8 changes: 8 additions & 0 deletions src/modal.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<div id="embeddings-calculator">
<sly-field title="Select Served session" description="select ">
<sly-select-app-session :group-id="context.teamId"
:app-session-id.sync="state.sessionId"
:options="state.sessionOptions">
</sly-select-app-session>
</sly-field>
</div>

0 comments on commit 455c588

Please sign in to comment.