diff --git a/config.json b/config.json index 7898dc2..68659a7 100644 --- a/config.json +++ b/config.json @@ -12,5 +12,17 @@ "entrypoint": "python -m uvicorn src.main:app --host 0.0.0.0 --port 8000", "port": 8000, "docker_image": "supervisely/model-benchmark:0.0.1", - "instance_version": "6.11.10" + "instance_version": "6.11.10", + "context_menu": { + "target": ["images_project"] + }, + "modal_template": "src/modal.html", + "modal_template_state": { + "sessionId": null, + "sessionOptions": { + "sessionTags": ["deployed_nn"], + "showLabel": false, + "size": "small" + } + } } diff --git a/src/functions.py b/src/functions.py new file mode 100644 index 0000000..74c65a4 --- /dev/null +++ b/src/functions.py @@ -0,0 +1,35 @@ +import src.globals as g +import supervisely as sly +from supervisely.nn.inference import Session + +geometry_to_task_type = { + "object detection": [sly.Rectangle], + "instance segmentation": [sly.Bitmap, sly.Polygon, sly.AlphaMask], + # "semantic segmentation": [sly.Bitmap, sly.Polygon, sly.AlphaMask], +} + + +def get_project_classes(): + meta = sly.ProjectMeta.from_json(g.api.project.get_meta(g.project_id)) + return meta.obj_classes + + +def get_model_info(): + session = Session(g.api, g.session_id) + model_meta = session.get_model_meta() + session_info = session.get_session_info() + return model_meta.obj_classes, session_info["task type"] + + +def get_classes(): + project_classes = get_project_classes() + model_classes, task_type = get_model_info() + if task_type not in geometry_to_task_type: + raise ValueError(f"Task type {task_type} is not supported yet") + filtered_classes = [] + for obj_class in project_classes: + if model_classes.has_key(obj_class.name): + if obj_class.geometry_type in geometry_to_task_type[task_type]: + filtered_classes.append(obj_class) + + return filtered_classes diff --git a/src/globals.py b/src/globals.py index bf1f62b..543c631 100644 --- a/src/globals.py +++ b/src/globals.py @@ -1,8 +1,9 @@ import os -import supervisely as sly from dotenv import load_dotenv +import supervisely as sly + if sly.is_development(): load_dotenv("local.env") load_dotenv(os.path.expanduser("~/supervisely.env")) @@ -21,3 +22,6 @@ project_id = sly.env.project_id(raise_not_found=False) team_id = sly.env.team_id() task_id = sly.env.task_id(raise_not_found=False) +session_id = os.environ.get("modal.state.sessionId", None) +if session_id is not None: + session_id = int(session_id) diff --git a/src/main.py b/src/main.py index 98892aa..19924c1 100644 --- a/src/main.py +++ b/src/main.py @@ -1,34 +1,37 @@ from typing import Optional +import src.functions as f +import src.globals as g +import src.workflow as w import supervisely as sly import supervisely.app.widgets as widgets from supervisely.nn.benchmark import ObjectDetectionBenchmark -import src.globals as g -import src.workflow as w - def main_func(): api = g.api - project = api.project.get_info_by_id(sel_project.get_selected_id()) - session_id = sel_app_session.get_selected_id() + project = api.project.get_info_by_id(g.project_id) # ==================== Workflow input ==================== - w.workflow_input(api, project, session_id) + w.workflow_input(api, project, g.session_id) # ======================================================= - pbar.show() report_model_benchmark.hide() + selected_classes = select_classes.get_selected_classes() bm = ObjectDetectionBenchmark( - api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar + api, + project.id, + output_dir=g.STORAGE_DIR + "/benchmark", + progress=pbar, + classes=selected_classes, ) - sly.logger.info(f"{session_id = }") - bm.run_evaluation(model_session=session_id) + sly.logger.info(f"{g.session_id = }") + bm.run_evaluation(model_session=g.session_id) - session_info = api.task.get_info_by_id(session_id) - task_dir = f"{session_id}_{session_info['meta']['app']['name']}" + session_info = api.task.get_info_by_id(g.session_id) + task_dir = f"{g.session_id}_{session_info['meta']['app']['name']}" eval_res_dir = f"/model-benchmark/evaluation/{project.id}_{project.name}/{task_dir}/" eval_res_dir = api.storage.get_free_dir_name(g.team_id, eval_res_dir) @@ -54,20 +57,34 @@ def main_func(): # ======================================================= sly.logger.info( - f"Predictions project name {bm.dt_project_info.name}, workspace_id {bm.dt_project_info.workspace_id}" - ) - sly.logger.info( - f"Differences project name {bm.diff_project_info.name}, workspace_id {bm.diff_project_info.workspace_id}" + f"Predictions project: " + f" name {bm.dt_project_info.name}, " + f" workspace_id {bm.dt_project_info.workspace_id}. " + f"Differences project: " + f" name {bm.diff_project_info.name}, " + f" workspace_id {bm.diff_project_info.workspace_id}" ) button.loading = False app.stop() +select_classes = widgets.ClassesListSelector(multiple=True) +select_classes.hide() +no_classes_label = widgets.Text( + "Not found any classes in the project that are present in the model", + status="error", +) +no_classes_label.hide() + sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True) sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id) + button = widgets.Button("Evaluate") +button.disable() + pbar = widgets.SlyTqdm() + report_model_benchmark = widgets.ReportThumbnail() report_model_benchmark.hide() creating_report_f = widgets.Field(widgets.Empty(), "", "Creating report on model...") @@ -79,35 +96,59 @@ def main_func(): sel_project, sel_app_session, button, + select_classes, + no_classes_label, creating_report_f, report_model_benchmark, pbar, - ] + ], ) -@sel_project.value_changed -def handle(project_id: Optional[int]): - active = project_id is not None and sel_app_session.get_selected_id() is not None +def handle_selectors(active: bool): + no_classes_label.hide() + select_classes.hide() + button.loading = True if active: - button.enable() - else: - button.disable() + classes = f.get_classes() + if len(classes) > 0: + select_classes.set(classes) + select_classes.show() + select_classes.select_all() + button.loading = False + button.enable() + return + else: + no_classes_label.show() + button.loading = False + button.disable() + + +@sel_project.value_changed +def handle_sel_project(project_id: Optional[int]): + g.project_id = project_id + active = project_id is not None and g.session_id is not None + handle_selectors(active) @sel_app_session.value_changed -def handle(session_id: Optional[int]): - active = session_id is not None and sel_project.get_selected_id() is not None - if active: - button.enable() - else: - button.disable() +def handle_sel_app_session(session_id: Optional[int]): + g.session_id = session_id + active = session_id is not None and g.project_id is not None + handle_selectors(active) @button.click def handle(): creating_report_f.show() + select_classes.disable() main_func() app = sly.Application(layout=layout, static_dir=g.STATIC_DIR) + +if g.project_id: + sel_project.set_project_id(g.project_id) + +if g.session_id: + sel_app_session.set_session_id(g.session_id) diff --git a/src/modal.html b/src/modal.html new file mode 100644 index 0000000..782571d --- /dev/null +++ b/src/modal.html @@ -0,0 +1,8 @@ +