Skip to content

Commit

Permalink
add instance segmentation task support
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolaiPetukhov committed Aug 26, 2024
1 parent 2f771e1 commit eba6ca8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
git+https://github.com/supervisely/supervisely.git@model-benchmark-inst-seg#egg=supervisely[model-benchmark]
38 changes: 27 additions & 11 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from concurrent.futures import as_completed
from typing import Optional

import supervisely as sly
import supervisely.app.widgets as widgets
from supervisely.nn.benchmark import ObjectDetectionBenchmark

import src.globals as g
import src.workflow as w
import supervisely as sly
import supervisely.app.widgets as widgets
from supervisely.nn.benchmark import (
InstanceSegmentationBenchmark,
ObjectDetectionBenchmark,
)


def main_func():
Expand All @@ -17,13 +20,20 @@ def main_func():
w.workflow_input(api, project, session_id)
# =======================================================


pbar.show()
report_model_benchmark.hide()

bm = ObjectDetectionBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
)
task = sel_task.get_value()
if task == "detection":
bm = ObjectDetectionBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
)
elif task == "segmentation":
bm = InstanceSegmentationBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
)
else:
raise ValueError(f"Unknown task type: {task}")
sly.logger.info(f"{session_id = }")
bm.run_evaluation(model_session=session_id)

Expand Down Expand Up @@ -66,6 +76,12 @@ def main_func():

sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)
sel_task = widgets.Select(
items=[
widgets.Select.Item("detection", "Detection"),
widgets.Select.Item("segmentation", "Segmentation"),
]
)
button = widgets.Button("Evaluate")
pbar = widgets.SlyTqdm()
report_model_benchmark = widgets.ReportThumbnail()
Expand All @@ -75,9 +91,9 @@ def main_func():

layout = widgets.Container(
widgets=[
widgets.Text("Select GT Project"),
sel_project,
sel_app_session,
widgets.Field(sel_project, "Select GT Project"),
widgets.Field(sel_app_session, "Select Model Session"),
widgets.Field(sel_task, "Select Task"),
button,
creating_report_f,
report_model_benchmark,
Expand Down

0 comments on commit eba6ca8

Please sign in to comment.