Skip to content

Commit

Permalink
add instance segmentation task support
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolaiPetukhov committed Aug 30, 2024
1 parent 4b46be6 commit 6da925c
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@
import src.workflow as w
import supervisely as sly
import supervisely.app.widgets as widgets
from supervisely.nn.benchmark import ObjectDetectionBenchmark
from supervisely.nn.benchmark import (
InstanceSegmentationBenchmark,
ObjectDetectionBenchmark,
)
from supervisely.nn.inference.session import SessionJSON


def main_func():
api = g.api
project = api.project.get_info_by_id(sel_project.get_selected_id())
session_id = sel_app_session.get_selected_id()
session = SessionJSON(api, session_id)
task_type = session.get_deploy_info()["task_type"]

# ==================== Workflow input ====================
w.workflow_input(api, project, session_id)
Expand All @@ -19,14 +25,19 @@ def main_func():
pbar.show()
report_model_benchmark.hide()

bm = ObjectDetectionBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
)
if task_type == "object detection":
bm = ObjectDetectionBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
)
elif task_type == "instance segmentation":
bm = InstanceSegmentationBenchmark(
api, project.id, output_dir=g.STORAGE_DIR + "/benchmark", progress=pbar
)
sly.logger.info(f"{session_id = }")
bm.run_evaluation(model_session=session_id)

session_info = api.task.get_info_by_id(session_id)
task_dir = f"{session_id}_{session_info['meta']['app']['name']}"
task_info = api.task.get_info_by_id(session_id)
task_dir = f"{session_id}_{task_info['meta']['app']['name']}"
eval_res_dir = f"/model-benchmark/evaluation/{project.id}_{project.name}/{task_dir}/"
eval_res_dir = api.storage.get_free_dir_name(g.team_id, eval_res_dir)

Expand Down Expand Up @@ -117,4 +128,4 @@ def start_evaluation():
sel_app_session.set_session_id(g.session_id)

if g.autostart:
start_evaluation()
start_evaluation()

0 comments on commit 6da925c

Please sign in to comment.