Skip to content

Commit

Permalink
add inst segm
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Aug 30, 2024
1 parent 7bc024f commit 9306184
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 23 deletions.
7 changes: 4 additions & 3 deletions src/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ def get_project_classes():


def get_model_info():
session = Session(g.api, g.session_id)
model_meta = session.get_model_meta()
session_info = session.get_session_info()
if g.session is None:
g.session = Session(g.api, g.session_id)
model_meta = g.session.get_model_meta()
session_info = g.session.get_session_info()
return model_meta.obj_classes, session_info["task type"]


Expand Down
1 change: 1 addition & 0 deletions src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
session_id = os.environ.get("modal.state.sessionId", None)
if session_id is not None:
session_id = int(session_id)
session = None
autostart = bool(strtobool(os.environ.get("modal.state.autoStart", "false")))

selected_classes = None
56 changes: 36 additions & 20 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,19 @@
import src.workflow as w
import supervisely as sly
import supervisely.app.widgets as widgets
from supervisely.nn.benchmark import ObjectDetectionBenchmark
from supervisely.nn.benchmark import (
InstanceSegmentationBenchmark,
ObjectDetectionBenchmark,
)
from supervisely.nn.inference.session import SessionJSON


def main_func():
api = g.api
project = api.project.get_info_by_id(g.project_id)
if g.session_id is None:
g.session = SessionJSON(api, g.session_id)
task_type = g.session.get_deploy_info()["task_type"]

# ==================== Workflow input ====================
w.workflow_input(api, project, g.session_id)
Expand All @@ -26,18 +33,27 @@ def main_func():
if g.selected_classes is None:
g.selected_classes = f.get_matched_class_names()

bm = ObjectDetectionBenchmark(
api,
project.id,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
classes_whitelist=g.selected_classes,
)
if task_type == "object detection":
bm = ObjectDetectionBenchmark(
api,
project.id,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
classes_whitelist=g.selected_classes,
)
elif task_type == "instance segmentation":
bm = InstanceSegmentationBenchmark(
api,
project.id,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
classes_whitelist=g.selected_classes,
)
sly.logger.info(f"{g.session_id = }")
bm.run_evaluation(model_session=g.session_id)

session_info = api.task.get_info_by_id(g.session_id)
task_dir = f"{g.session_id}_{session_info['meta']['app']['name']}"
task_info = api.task.get_info_by_id(g.session_id)
task_dir = f"{g.session_id}_{task_info['meta']['app']['name']}"
eval_res_dir = f"/model-benchmark/evaluation/{project.id}_{project.name}/{task_dir}/"
eval_res_dir = api.storage.get_free_dir_name(g.team_id, eval_res_dir)

Expand Down Expand Up @@ -146,29 +162,29 @@ def handle_selectors(active: bool):
# not_matched_classes.hide()
if active:
matched, not_matched = f.get_classes()
_, matched_model = matched
_, not_matched_model = not_matched
_, matched_model_classes = matched
_, not_matched_model_classes = not_matched

g.selected_classes = [obj_cls.name for obj_cls in matched_model]
not_matched_classes_cnt = len(not_matched_model)
total_classes = len(matched_model) + len(not_matched_model)
g.selected_classes = [obj_cls.name for obj_cls in matched_model_classes]
not_matched_classes_cnt = len(not_matched_model_classes)
total_classes = len(matched_model_classes) + len(not_matched_model_classes)

total_classes_text.text = f"{total_classes} classes found in the model."
selected_matched_text.text = f"{len(matched_model)} classes can be used for evaluation."
selected_matched_text.text = f"{len(matched_model_classes)} classes can be used for evaluation."
not_matched_text.text = f"{not_matched_classes_cnt} classes are not available for evaluation (not found in the GT project or have different geometry type)."

if len(matched_model) > 0:
if len(matched_model_classes) > 0:
selected_matched_text.show()
if not_matched_classes_cnt > 0:
not_matched_text.show()
button.enable()
return
else:
no_classes_label.show()
# select_classes.set(left_collection=model_classes, right_collection=matched_model)
# select_classes.set(left_collection=model_classes, right_collection=matched_model_classes)

# matched_model, model_classes = not_matched
# not_matched_classes.set(left_collection=model_classes, right_collection=matched_model)
# matched_model_classes, model_classes = not_matched
# not_matched_classes.set(left_collection=model_classes, right_collection=matched_model_classes)
# stats = select_classes.get_stat()
# if len(stats["match"]) > 0:
# select_classes.show()
Expand Down

0 comments on commit 9306184

Please sign in to comment.