From d7518aedcb9c6ca8afd13897e210d95928d74a8c Mon Sep 17 00:00:00 2001 From: Almaz <79905215+almazgimaev@users.noreply.github.com> Date: Fri, 30 Aug 2024 18:36:21 +0200 Subject: [PATCH] Support for Instance segmentation. Bug fixes and tiny improvements (#3) --- config.json | 2 +- dev_requirements.txt | 4 +- src/functions.py | 52 +++++++++++ src/globals.py | 3 + src/main.py | 201 ++++++++++++++++++++++++++++++++++--------- 5 files changed, 216 insertions(+), 46 deletions(-) create mode 100644 src/functions.py diff --git a/config.json b/config.json index 426fdb7..aac4b03 100644 --- a/config.json +++ b/config.json @@ -11,7 +11,7 @@ "task_location": "workspace_tasks", "entrypoint": "python -m uvicorn src.main:app --host 0.0.0.0 --port 8000", "port": 8000, - "docker_image": "supervisely/model-benchmark:0.0.1", + "docker_image": "supervisely/model-benchmark:0.0.2", "instance_version": "6.11.10", "context_menu": { "target": ["images_project"] diff --git a/dev_requirements.txt b/dev_requirements.txt index bb66a55..16598c4 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,3 +1,3 @@ # git+https://github.com/supervisely/supervisely.git@model-benchmark -supervisely==6.73.166 -supervisely[model-benchmark]==6.73.166 \ No newline at end of file +supervisely==6.73.173 +supervisely[model-benchmark]==6.73.173 \ No newline at end of file diff --git a/src/functions.py b/src/functions.py new file mode 100644 index 0000000..e038260 --- /dev/null +++ b/src/functions.py @@ -0,0 +1,52 @@ +import src.globals as g +import supervisely as sly +from supervisely.nn import TaskType +from supervisely.nn.inference import Session + +geometry_to_task_type = { + TaskType.OBJECT_DETECTION: [sly.Rectangle], + TaskType.INSTANCE_SEGMENTATION: [sly.Bitmap, sly.Polygon], +} + + +def get_project_classes(): + meta = sly.ProjectMeta.from_json(g.api.project.get_meta(g.project_id)) + return meta.obj_classes + + +def get_model_info(): + g.session = Session(g.api, g.session_id) + model_meta = g.session.get_model_meta() + session_info = g.session.get_session_info() + return model_meta.obj_classes, session_info["task type"] + + +def get_classes(): + project_classes = get_project_classes() + model_classes, task_type = get_model_info() + if task_type not in geometry_to_task_type: + raise ValueError(f"Task type {task_type} is not supported yet") + matched_proj_cls = [] + matched_model_cls = [] + not_matched_proj_cls = [] + not_matched_model_cls = [] + for obj_class in project_classes: + if model_classes.has_key(obj_class.name): + if obj_class.geometry_type in geometry_to_task_type[task_type]: + matched_proj_cls.append(obj_class) + matched_model_cls.append(model_classes.get(obj_class.name)) + else: + not_matched_proj_cls.append(obj_class) + else: + not_matched_proj_cls.append(obj_class) + + for obj_class in model_classes: + if not project_classes.has_key(obj_class.name): + not_matched_model_cls.append(obj_class) + + return (matched_proj_cls, matched_model_cls), (not_matched_proj_cls, not_matched_model_cls) + + +def get_matched_class_names(): + (cls_collection, _), _ = get_classes() + return [obj_cls.name for obj_cls in cls_collection] diff --git a/src/globals.py b/src/globals.py index 3b116fd..34d190d 100644 --- a/src/globals.py +++ b/src/globals.py @@ -26,4 +26,7 @@ session_id = os.environ.get("modal.state.sessionId", None) if session_id is not None: session_id = int(session_id) +session = None autostart = bool(strtobool(os.environ.get("modal.state.autoStart", "false"))) + +selected_classes = None \ No newline at end of file diff --git a/src/main.py b/src/main.py index c1a5532..d9e1472 100644 --- a/src/main.py +++ b/src/main.py @@ -1,32 +1,60 @@ from typing import Optional +import src.functions as f import src.globals as g import src.workflow as w import supervisely as sly import supervisely.app.widgets as widgets -from supervisely.nn.benchmark import ObjectDetectionBenchmark +from supervisely.nn.benchmark import ( + InstanceSegmentationBenchmark, + ObjectDetectionBenchmark, +) +from supervisely.nn.inference.session import SessionJSON def main_func(): api = g.api - project = api.project.get_info_by_id(sel_project.get_selected_id()) - session_id = sel_app_session.get_selected_id() + project = api.project.get_info_by_id(g.project_id) + if g.session_id is None: + g.session = SessionJSON(api, g.session_id) + task_type = g.session.get_deploy_info()["task_type"] # ==================== Workflow input ==================== - w.workflow_input(api, project, session_id) + w.workflow_input(api, project, g.session_id) # ======================================================= pbar.show() report_model_benchmark.hide() - bm = ObjectDetectionBenchmark( - api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar - ) - sly.logger.info(f"{session_id = }") - bm.run_evaluation(model_session=session_id) + # selected_classes = select_classes.get_selected() + + # matches = [x[1] for x in select_classes.get_stat()["match"]] + # selected_classes = [x[0] for x in selected_classes if x[0] in matches] + if g.selected_classes is None: + g.selected_classes = f.get_matched_class_names() + + if task_type == "object detection": + bm = ObjectDetectionBenchmark( + api, + project.id, + output_dir=g.STORAGE_DIR + "/benchmark", + progress=pbar, + classes_whitelist=g.selected_classes, + ) + elif task_type == "instance segmentation": + bm = InstanceSegmentationBenchmark( + api, + project.id, + output_dir=g.STORAGE_DIR + "/benchmark", + progress=pbar, + classes_whitelist=g.selected_classes, + ) + sly.logger.info(f"{g.session_id = }") + bm.run_evaluation(model_session=g.session_id) + + task_info = api.task.get_info_by_id(g.session_id) + task_dir = f"{g.session_id}_{task_info['meta']['app']['name']}" - session_info = api.task.get_info_by_id(session_id) - task_dir = f"{session_id}_{session_info['meta']['app']['name']}" eval_res_dir = f"/model-benchmark/evaluation/{project.id}_{project.name}/{task_dir}/" eval_res_dir = api.storage.get_free_dir_name(g.team_id, eval_res_dir) @@ -38,8 +66,6 @@ def main_func(): report = bm.upload_report_link(remote_dir) api.task.set_output_report(g.task_id, report.id, report.name) - creating_report_f.hide() - template_vis_file = api.file.get_info_by_path( sly.env.team_id(), eval_res_dir + "/visualizations/template.vue" ) @@ -52,63 +78,152 @@ def main_func(): # ======================================================= sly.logger.info( - f"Predictions project name {bm.dt_project_info.name}, workspace_id {bm.dt_project_info.workspace_id}" - ) - sly.logger.info( - f"Differences project name {bm.diff_project_info.name}, workspace_id {bm.diff_project_info.workspace_id}" + f"Predictions project: " + f" name {bm.dt_project_info.name}, " + f" workspace_id {bm.dt_project_info.workspace_id}. " + f"Differences project: " + f" name {bm.diff_project_info.name}, " + f" workspace_id {bm.diff_project_info.workspace_id}" ) button.loading = False app.stop() +# select_classes = widgets.MatchObjClasses( +# selectable=True, +# left_name="Model classes", +# right_name="GT project classes", +# ) +# select_classes.hide() +# not_matched_classes = widgets.MatchObjClasses( +# left_name="Model classes", +# right_name="GT project classes", +# ) +# not_matched_classes.hide() +no_classes_label = widgets.Text( + "Not found any classes in the project that are present in the model", status="error" +) +no_classes_label.hide() +total_classes_text = widgets.Text(status="info") +selected_matched_text = widgets.Text(status="success") +not_matched_text = widgets.Text(status="warning") + sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True) sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id) + button = widgets.Button("Evaluate") +button.disable() + pbar = widgets.SlyTqdm() + report_model_benchmark = widgets.ReportThumbnail() report_model_benchmark.hide() -creating_report_f = widgets.Field(widgets.Empty(), "", "Creating report on model...") -creating_report_f.hide() + +controls_card = widgets.Card( + title="Settings", + description="Select Ground Truth project and deployed model session", + content=widgets.Container( + [sel_project, sel_app_session, button, report_model_benchmark, pbar] + ), +) + + +# matched_card = widgets.Card( +# title="✅ Available classes", +# description="Select classes that are present in the model and in the project", +# content=select_classes, +# ) +# matched_card.lock(message="Select project and model session to enable") + +# not_matched_card = widgets.Card( +# title="❌ Not available classes", +# description="List of classes that are not matched between the model and the project", +# content=not_matched_classes, +# ) +# not_matched_card.lock(message="Select project and model session to enable") + layout = widgets.Container( - widgets=[ - widgets.Text("Select GT Project"), - sel_project, - sel_app_session, - button, - creating_report_f, - report_model_benchmark, - pbar, - ] + widgets=[controls_card, widgets.Empty(), widgets.Empty()], # , matched_card, not_matched_card], + direction="horizontal", + fractions=[1, 1, 1], ) +main_layout = widgets.Container( + widgets=[layout, total_classes_text, selected_matched_text, not_matched_text, no_classes_label] +) -@sel_project.value_changed -def handle(project_id: Optional[int]): - active = project_id is not None and sel_app_session.get_selected_id() is not None + +def handle_selectors(active: bool): + no_classes_label.hide() + selected_matched_text.hide() + not_matched_text.hide() + button.loading = True + # select_classes.hide() + # not_matched_classes.hide() if active: - button.enable() - else: - button.disable() + matched, not_matched = f.get_classes() + _, matched_model_classes = matched + _, not_matched_model_classes = not_matched + + g.selected_classes = [obj_cls.name for obj_cls in matched_model_classes] + not_matched_classes_cnt = len(not_matched_model_classes) + total_classes = len(matched_model_classes) + len(not_matched_model_classes) + + total_classes_text.text = f"{total_classes} classes found in the model." + selected_matched_text.text = f"{len(matched_model_classes)} classes can be used for evaluation." + not_matched_text.text = f"{not_matched_classes_cnt} classes are not available for evaluation (not found in the GT project or have different geometry type)." + + if len(matched_model_classes) > 0: + selected_matched_text.show() + if not_matched_classes_cnt > 0: + not_matched_text.show() + button.enable() + button.loading = False + return + else: + no_classes_label.show() + # select_classes.set(left_collection=model_classes, right_collection=matched_model_classes) + + # matched_model_classes, model_classes = not_matched + # not_matched_classes.set(left_collection=model_classes, right_collection=matched_model_classes) + # stats = select_classes.get_stat() + # if len(stats["match"]) > 0: + # select_classes.show() + # not_matched_classes.show() + # button.enable() + # # matched_card.unlock() + # # not_matched_card.unlock() + # return + # else: + # no_classes_label.show() + button.loading = False + button.disable() + + +@sel_project.value_changed +def handle_sel_project(project_id: Optional[int]): + g.project_id = project_id + active = project_id is not None and g.session_id is not None + handle_selectors(active) @sel_app_session.value_changed -def handle(session_id: Optional[int]): - active = session_id is not None and sel_project.get_selected_id() is not None - if active: - button.enable() - else: - button.disable() +def handle_sel_app_session(session_id: Optional[int]): + g.session_id = session_id + active = session_id is not None and g.project_id is not None + handle_selectors(active) @button.click def start_evaluation(): - creating_report_f.show() + # select_classes.hide() + # not_matched_classes.hide() main_func() -app = sly.Application(layout=layout, static_dir=g.STATIC_DIR) +app = sly.Application(layout=main_layout, static_dir=g.STATIC_DIR) if g.project_id: sel_project.set_project_id(g.project_id) @@ -117,4 +232,4 @@ def start_evaluation(): sel_app_session.set_session_id(g.session_id) if g.autostart: - start_evaluation() \ No newline at end of file + start_evaluation()