From 9306184b0bbc82661b8412d31c591f2ac6b94912 Mon Sep 17 00:00:00 2001 From: almaz Date: Fri, 30 Aug 2024 12:01:11 +0200 Subject: [PATCH] add inst segm --- src/functions.py | 7 +++--- src/globals.py | 1 + src/main.py | 56 +++++++++++++++++++++++++++++++----------------- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/src/functions.py b/src/functions.py index 910d074..98af164 100644 --- a/src/functions.py +++ b/src/functions.py @@ -15,9 +15,10 @@ def get_project_classes(): def get_model_info(): - session = Session(g.api, g.session_id) - model_meta = session.get_model_meta() - session_info = session.get_session_info() + if g.session is None: + g.session = Session(g.api, g.session_id) + model_meta = g.session.get_model_meta() + session_info = g.session.get_session_info() return model_meta.obj_classes, session_info["task type"] diff --git a/src/globals.py b/src/globals.py index 2c27d06..34d190d 100644 --- a/src/globals.py +++ b/src/globals.py @@ -26,6 +26,7 @@ session_id = os.environ.get("modal.state.sessionId", None) if session_id is not None: session_id = int(session_id) +session = None autostart = bool(strtobool(os.environ.get("modal.state.autoStart", "false"))) selected_classes = None \ No newline at end of file diff --git a/src/main.py b/src/main.py index d8a55d2..1e53c7c 100644 --- a/src/main.py +++ b/src/main.py @@ -5,12 +5,19 @@ import src.workflow as w import supervisely as sly import supervisely.app.widgets as widgets -from supervisely.nn.benchmark import ObjectDetectionBenchmark +from supervisely.nn.benchmark import ( + InstanceSegmentationBenchmark, + ObjectDetectionBenchmark, +) +from supervisely.nn.inference.session import SessionJSON def main_func(): api = g.api project = api.project.get_info_by_id(g.project_id) + if g.session_id is None: + g.session = SessionJSON(api, g.session_id) + task_type = g.session.get_deploy_info()["task_type"] # ==================== Workflow input ==================== w.workflow_input(api, project, g.session_id) @@ -26,18 +33,27 @@ def main_func(): if g.selected_classes is None: g.selected_classes = f.get_matched_class_names() - bm = ObjectDetectionBenchmark( - api, - project.id, - output_dir=g.STORAGE_DIR + "/benchmark", - progress=pbar, - classes_whitelist=g.selected_classes, - ) + if task_type == "object detection": + bm = ObjectDetectionBenchmark( + api, + project.id, + output_dir=g.STORAGE_DIR + "/benchmark", + progress=pbar, + classes_whitelist=g.selected_classes, + ) + elif task_type == "instance segmentation": + bm = InstanceSegmentationBenchmark( + api, + project.id, + output_dir=g.STORAGE_DIR + "/benchmark", + progress=pbar, + classes_whitelist=g.selected_classes, + ) sly.logger.info(f"{g.session_id = }") bm.run_evaluation(model_session=g.session_id) - session_info = api.task.get_info_by_id(g.session_id) - task_dir = f"{g.session_id}_{session_info['meta']['app']['name']}" + task_info = api.task.get_info_by_id(g.session_id) + task_dir = f"{g.session_id}_{task_info['meta']['app']['name']}" eval_res_dir = f"/model-benchmark/evaluation/{project.id}_{project.name}/{task_dir}/" eval_res_dir = api.storage.get_free_dir_name(g.team_id, eval_res_dir) @@ -146,18 +162,18 @@ def handle_selectors(active: bool): # not_matched_classes.hide() if active: matched, not_matched = f.get_classes() - _, matched_model = matched - _, not_matched_model = not_matched + _, matched_model_classes = matched + _, not_matched_model_classes = not_matched - g.selected_classes = [obj_cls.name for obj_cls in matched_model] - not_matched_classes_cnt = len(not_matched_model) - total_classes = len(matched_model) + len(not_matched_model) + g.selected_classes = [obj_cls.name for obj_cls in matched_model_classes] + not_matched_classes_cnt = len(not_matched_model_classes) + total_classes = len(matched_model_classes) + len(not_matched_model_classes) total_classes_text.text = f"{total_classes} classes found in the model." - selected_matched_text.text = f"{len(matched_model)} classes can be used for evaluation." + selected_matched_text.text = f"{len(matched_model_classes)} classes can be used for evaluation." not_matched_text.text = f"{not_matched_classes_cnt} classes are not available for evaluation (not found in the GT project or have different geometry type)." - if len(matched_model) > 0: + if len(matched_model_classes) > 0: selected_matched_text.show() if not_matched_classes_cnt > 0: not_matched_text.show() @@ -165,10 +181,10 @@ def handle_selectors(active: bool): return else: no_classes_label.show() - # select_classes.set(left_collection=model_classes, right_collection=matched_model) + # select_classes.set(left_collection=model_classes, right_collection=matched_model_classes) - # matched_model, model_classes = not_matched - # not_matched_classes.set(left_collection=model_classes, right_collection=matched_model) + # matched_model_classes, model_classes = not_matched + # not_matched_classes.set(left_collection=model_classes, right_collection=matched_model_classes) # stats = select_classes.get_stat() # if len(stats["match"]) > 0: # select_classes.show()