Skip to content

Commit

Permalink
refactor evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Dec 11, 2024
1 parent d8448aa commit 5dfe28e
Showing 1 changed file with 62 additions and 64 deletions.
126 changes: 62 additions & 64 deletions src/ui/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Union
from typing import Dict, Optional, Tuple, Union

import yaml

Expand All @@ -20,6 +20,14 @@
SlyTqdm,
Text,
)
from supervisely.nn.benchmark import (
InstanceSegmentationBenchmark,
InstanceSegmentationEvaluator,
ObjectDetectionBenchmark,
ObjectDetectionEvaluator,
SemanticSegmentationBenchmark,
SemanticSegmentationEvaluator,
)
from supervisely.nn.inference.session import SessionJSON

no_classes_label = Text(
Expand Down Expand Up @@ -78,6 +86,27 @@
]
)

benchmark_cls_type = Union[
ObjectDetectionBenchmark, InstanceSegmentationBenchmark, SemanticSegmentationBenchmark
]

evaluator_cls_type = Union[
ObjectDetectionEvaluator, InstanceSegmentationEvaluator, SemanticSegmentationEvaluator
]


def get_benchmark_and_evaluator_classes(
task_type: sly.nn.TaskType,
) -> Tuple[benchmark_cls_type, evaluator_cls_type]:
if task_type == sly.nn.TaskType.OBJECT_DETECTION:
return ObjectDetectionBenchmark, ObjectDetectionEvaluator
elif task_type == sly.nn.TaskType.INSTANCE_SEGMENTATION:
return (InstanceSegmentationBenchmark, InstanceSegmentationEvaluator)
elif task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
return (SemanticSegmentationBenchmark, SemanticSegmentationEvaluator)
else:
raise ValueError(f"Unknown task type: {task_type}")


@f.with_clean_up_progress(eval_pbar)
def run_evaluation(
Expand All @@ -87,10 +116,8 @@ def run_evaluation(
):
work_dir = g.STORAGE_DIR + "/benchmark_" + sly.rand_str(6)

if session_id is not None:
g.session_id = session_id
if project_id is not None:
g.project_id = project_id
g.session_id = session_id or g.session_id
g.project_id = project_id or g.project_id

project = g.api.project.get_info_by_id(g.project_id)
if g.session is None:
Expand Down Expand Up @@ -119,48 +146,23 @@ def run_evaluation(

params = eval_params.get_value() or params
if isinstance(params, str):
sly.Annotation.filter_labels_by_classes
params = yaml.safe_load(params)

if task_type == sly.nn.TaskType.OBJECT_DETECTION:
if params is None:
params = sly.nn.benchmark.ObjectDetectionEvaluator.load_yaml_evaluation_params()
params = yaml.safe_load(params)
bm = sly.nn.benchmark.ObjectDetectionBenchmark(
g.api,
project.id,
gt_dataset_ids=dataset_ids,
output_dir=work_dir,
progress=eval_pbar,
progress_secondary=sec_eval_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=params,
)
elif task_type == sly.nn.TaskType.INSTANCE_SEGMENTATION:
if params is None:
params = sly.nn.benchmark.InstanceSegmentationEvaluator.load_yaml_evaluation_params()
params = yaml.safe_load(params)
bm = sly.nn.benchmark.InstanceSegmentationBenchmark(
g.api,
project.id,
gt_dataset_ids=dataset_ids,
output_dir=work_dir,
progress=eval_pbar,
progress_secondary=sec_eval_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=params,
)
elif task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
params = sly.nn.benchmark.SemanticSegmentationEvaluator.load_yaml_evaluation_params()
bm = sly.nn.benchmark.SemanticSegmentationBenchmark(
g.api,
project.id,
gt_dataset_ids=dataset_ids,
output_dir=work_dir,
progress=eval_pbar,
progress_secondary=sec_eval_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=params,
)
bm_cls, evaluator_cls = get_benchmark_and_evaluator_classes(task_type)
if params is None:
params = evaluator_cls.load_yaml_evaluation_params()
params = yaml.safe_load(params)
bm: benchmark_cls_type = bm_cls(
g.api,
project.id,
gt_dataset_ids=dataset_ids,
output_dir=work_dir,
progress=eval_pbar,
progress_secondary=sec_eval_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=params,
)

bm.evaluator_app_info = g.api.task.get_info_by_id(g.task_id)
sly.logger.info(f"{g.session_id = }")
Expand Down Expand Up @@ -197,30 +199,21 @@ def run_evaluation(
bm.visualize()

bm.upload_eval_results(res_dir + "/evaluation/")
remote_dir = bm.upload_visualizations(res_dir + "/visualizations/")
bm.upload_visualizations(res_dir + "/visualizations/")

report = bm.upload_report_link(remote_dir)
g.api.task.set_output_report(g.task_id, report.id, report.name)

template_vis_file = g.api.file.get_info_by_path(
sly.env.team_id(), res_dir + "/visualizations/template.vue"
)
report_model_benchmark.set(template_vis_file)
g.api.task.set_output_report(g.task_id, bm.lnk.id, bm.lnk.name, "Click to open the report")
report_model_benchmark.set(bm.report)
report_model_benchmark.show()
eval_pbar.hide()

# ==================== Workflow output ====================
w.workflow_output(g.api, res_dir, template_vis_file)
w.workflow_output(g.api, res_dir, bm.report)
# =======================================================

sly.logger.info(
f"Predictions project: "
f" name {bm.dt_project_info.name}, "
f" workspace_id {bm.dt_project_info.workspace_id}. "
# f"Differences project: "
# f" name {bm.diff_project_info.name}, "
# f" workspace_id {bm.diff_project_info.workspace_id}"
f"Predictions project {bm.dt_project_info.name}, workspace ID: {bm.dt_project_info.workspace_id}."
)
sly.logger.info(f"Report link: {bm.get_report_link()}")

eval_button.loading = False

Expand Down Expand Up @@ -250,13 +243,18 @@ def update_eval_params():
g.session = SessionJSON(g.api, g.session_id)
task_type = g.session.get_deploy_info()["task_type"]
if task_type == sly.nn.TaskType.OBJECT_DETECTION:
params = sly.nn.benchmark.ObjectDetectionEvaluator.load_yaml_evaluation_params()
params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
elif task_type == sly.nn.TaskType.INSTANCE_SEGMENTATION:
params = sly.nn.benchmark.InstanceSegmentationEvaluator.load_yaml_evaluation_params()
params = InstanceSegmentationEvaluator.load_yaml_evaluation_params()
elif task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
params = "# Semantic Segmentation evaluation parameters are not available yet."
params = ""
eval_params.set_text(params, language_mode="yaml")
eval_params_card.uncollapse()

if task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
eval_params_card.hide()
else:
eval_params_card.show()
eval_params_card.uncollapse()


def handle_selectors(active: bool):
Expand Down

0 comments on commit 5dfe28e

Please sign in to comment.