From eba6ca8dc4b3ac30393da5bca3e1865c3dd7bcb9 Mon Sep 17 00:00:00 2001 From: Nikolai Petukhov Date: Mon, 26 Aug 2024 17:17:45 -0300 Subject: [PATCH] add instance segmentation task support --- requirements.txt | 1 + src/main.py | 38 +++++++++++++++++++++++++++----------- 2 files changed, 28 insertions(+), 11 deletions(-) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..bf89648 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/supervisely/supervisely.git@model-benchmark-inst-seg#egg=supervisely[model-benchmark] \ No newline at end of file diff --git a/src/main.py b/src/main.py index 98892aa..14514ce 100644 --- a/src/main.py +++ b/src/main.py @@ -1,11 +1,14 @@ +from concurrent.futures import as_completed from typing import Optional -import supervisely as sly -import supervisely.app.widgets as widgets -from supervisely.nn.benchmark import ObjectDetectionBenchmark - import src.globals as g import src.workflow as w +import supervisely as sly +import supervisely.app.widgets as widgets +from supervisely.nn.benchmark import ( + InstanceSegmentationBenchmark, + ObjectDetectionBenchmark, +) def main_func(): @@ -17,13 +20,20 @@ def main_func(): w.workflow_input(api, project, session_id) # ======================================================= - pbar.show() report_model_benchmark.hide() - bm = ObjectDetectionBenchmark( - api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar - ) + task = sel_task.get_value() + if task == "detection": + bm = ObjectDetectionBenchmark( + api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar + ) + elif task == "segmentation": + bm = InstanceSegmentationBenchmark( + api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar + ) + else: + raise ValueError(f"Unknown task type: {task}") sly.logger.info(f"{session_id = }") bm.run_evaluation(model_session=session_id) @@ -66,6 +76,12 @@ def main_func(): sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True) sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id) +sel_task = widgets.Select( + items=[ + widgets.Select.Item("detection", "Detection"), + widgets.Select.Item("segmentation", "Segmentation"), + ] +) button = widgets.Button("Evaluate") pbar = widgets.SlyTqdm() report_model_benchmark = widgets.ReportThumbnail() @@ -75,9 +91,9 @@ def main_func(): layout = widgets.Container( widgets=[ - widgets.Text("Select GT Project"), - sel_project, - sel_app_session, + widgets.Field(sel_project, "Select GT Project"), + widgets.Field(sel_app_session, "Select Model Session"), + widgets.Field(sel_task, "Select Task"), button, creating_report_f, report_model_benchmark,