Commit
update eval params
almazgimaev committed Oct 9, 2024
1 parent 5d9ccf0 commit d20a8ef
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions src/main.py
@@ -1,15 +1,24 @@
 from typing import Optional
 
+import yaml
+
 import src.functions as f
 import src.globals as g
 import src.workflow as w
 import supervisely as sly
 import supervisely.app.widgets as widgets
+from supervisely.nn import TaskType
 from supervisely.nn.benchmark import (
     InstanceSegmentationBenchmark,
     ObjectDetectionBenchmark,
 )
 from supervisely.nn.benchmark.evaluation.base_evaluator import BaseEvaluator
+from supervisely.nn.benchmark.evaluation.instance_segmentation_evaluator import (
+    InstanceSegmentationEvaluator,
+)
+from supervisely.nn.benchmark.evaluation.object_detection_evaluator import (
+    ObjectDetectionEvaluator,
+)
 from supervisely.nn.inference.session import SessionJSON
 
 
@@ -41,7 +50,7 @@ def main_func():
             progress=pbar,
             progress_secondary=sec_pbar,
             classes_whitelist=g.selected_classes,
-            evaluation_parameters=evaluation_parameters,
+            evaluation_params=evaluation_parameters,
         )
     elif task_type == "instance segmentation":
         bm = InstanceSegmentationBenchmark(
@@ -51,7 +60,7 @@ def main_func():
             progress=pbar,
             progress_secondary=sec_pbar,
             classes_whitelist=g.selected_classes,
-            evaluation_parameters=evaluation_parameters,
+            evaluation_params=evaluation_parameters,
         )
     sly.logger.info(f"{g.session_id = }")
 
@@ -62,7 +71,6 @@ def main_func():
     res_dir = api.storage.get_free_dir_name(g.team_id, res_dir)
 
     bm.run_evaluation(model_session=g.session_id)
-    bm.upload_eval_results(res_dir + "/evaluation/")
 
     try:
         bm.run_speedtest(g.session_id, g.project_id)
@@ -72,6 +80,8 @@ def main_func():
         sly.logger.warn(f"Speedtest failed. Skipping. {e}")
 
     bm.visualize()
+
+    bm.upload_eval_results(res_dir + "/evaluation/")
     remote_dir = bm.upload_visualizations(res_dir + "/visualizations/")
 
     report = bm.upload_report_link(remote_dir)
@@ -113,7 +123,9 @@ def main_func():
 sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)
 
 eval_params = widgets.Editor(
-    initial_text=BaseEvaluator.default_parameters(), language_mode="yaml", height_lines=16
+    initial_text=BaseEvaluator._get_default_evaluation_params(),
+    language_mode="yaml",
+    height_lines=16,
 )
 eval_params_card = widgets.Card(
     title="Evaluation parameters",
@@ -178,6 +190,18 @@ def set_selected_classes_and_show_info():
         no_classes_label.show()
 
 
+def update_eval_params():
+    if g.session is None:
+        g.session = SessionJSON(g.api, g.session_id)
+    task_type = g.session.get_deploy_info()["task_type"]
+    if task_type == TaskType.OBJECT_DETECTION:
+        params = ObjectDetectionEvaluator._get_default_evaluation_params()
+    elif task_type == TaskType.INSTANCE_SEGMENTATION:
+        params = InstanceSegmentationEvaluator._get_default_evaluation_params()
+    eval_params.set_text(yaml.dump(params), language_mode="yaml")
+    eval_params_card.uncollapse()
+
+
 def handle_selectors(active: bool):
     no_classes_label.hide()
     selected_matched_text.hide()
@@ -201,11 +225,12 @@ def handle_sel_app_session(session_id: Optional[int]):
     active = session_id is not None and g.project_id is not None
     handle_selectors(active)
 
+    if g.session_id:
+        update_eval_params()
+
 
 @button.click
 def start_evaluation():
-    # select_classes.hide()
-    # not_matched_classes.hide()
     main_func()
 
 
@@ -216,6 +241,7 @@ def start_evaluation():
 
 if g.session_id:
     sel_app_session.set_session_id(g.session_id)
+    update_eval_params()
 
 if g.autostart:
     start_evaluation()
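For context, a minimal sketch (not part of the commit) of the parameter flow the diff wires up: evaluator defaults are rendered to YAML for the editor widget, optionally edited by the user, parsed back into a dict, and handed to the benchmark through the renamed evaluation_params keyword. Only names visible in the diff are used; the remaining ObjectDetectionBenchmark constructor arguments are not shown above and are left as a commented placeholder.

import yaml

from supervisely.nn.benchmark.evaluation.object_detection_evaluator import (
    ObjectDetectionEvaluator,
)

# Defaults exposed by the evaluator class (assumed to be a plain dict,
# as yaml.dump(params) in update_eval_params() suggests).
default_params = ObjectDetectionEvaluator._get_default_evaluation_params()

# The YAML text shown in the eval_params Editor widget.
params_yaml = yaml.dump(default_params)

# The dict that main_func() would pass to the benchmark via the renamed keyword.
evaluation_parameters = yaml.safe_load(params_yaml)

# bm = ObjectDetectionBenchmark(
#     ...,  # other constructor arguments as in main.py (not shown in the diff)
#     classes_whitelist=g.selected_classes,
#     evaluation_params=evaluation_parameters,
# )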
