diff --git a/src/ui/evaluation.py b/src/ui/evaluation.py
index 4e265d9..0390eb8 100644
--- a/src/ui/evaluation.py
+++ b/src/ui/evaluation.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Tuple, Type, Union
 
 import yaml
 
@@ -20,6 +20,14 @@
     SlyTqdm,
     Text,
 )
+from supervisely.nn.benchmark import (
+    InstanceSegmentationBenchmark,
+    InstanceSegmentationEvaluator,
+    ObjectDetectionBenchmark,
+    ObjectDetectionEvaluator,
+    SemanticSegmentationBenchmark,
+    SemanticSegmentationEvaluator,
+)
 from supervisely.nn.inference.session import SessionJSON
 
 no_classes_label = Text(
@@ -78,6 +86,28 @@
     ]
 )
 
+benchmark_cls_type = Union[
+    ObjectDetectionBenchmark, InstanceSegmentationBenchmark, SemanticSegmentationBenchmark
+]
+
+evaluator_cls_type = Union[
+    ObjectDetectionEvaluator, InstanceSegmentationEvaluator, SemanticSegmentationEvaluator
+]
+
+
+def get_benchmark_and_evaluator_classes(
+    task_type: sly.nn.TaskType,
+) -> Tuple[Type[benchmark_cls_type], Type[evaluator_cls_type]]:
+    """Return the benchmark and evaluator classes matching the given task type."""
+    if task_type == sly.nn.TaskType.OBJECT_DETECTION:
+        return ObjectDetectionBenchmark, ObjectDetectionEvaluator
+    elif task_type == sly.nn.TaskType.INSTANCE_SEGMENTATION:
+        return InstanceSegmentationBenchmark, InstanceSegmentationEvaluator
+    elif task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
+        return SemanticSegmentationBenchmark, SemanticSegmentationEvaluator
+    else:
+        raise ValueError(f"Unknown task type: {task_type}")
+
 
 @f.with_clean_up_progress(eval_pbar)
 def run_evaluation(
@@ -87,10 +117,8 @@ def run_evaluation(
 ):
     work_dir = g.STORAGE_DIR + "/benchmark_" + sly.rand_str(6)
 
-    if session_id is not None:
-        g.session_id = session_id
-    if project_id is not None:
-        g.project_id = project_id
+    g.session_id = session_id or g.session_id
+    g.project_id = project_id or g.project_id
 
     project = g.api.project.get_info_by_id(g.project_id)
     if g.session is None:
@@ -119,48 +147,22 @@ def run_evaluation(
     params = eval_params.get_value() or params
     if isinstance(params, str):
         params = yaml.safe_load(params)
 
-    if task_type == sly.nn.TaskType.OBJECT_DETECTION:
-        if params is None:
-            params = sly.nn.benchmark.ObjectDetectionEvaluator.load_yaml_evaluation_params()
-            params = yaml.safe_load(params)
-        bm = sly.nn.benchmark.ObjectDetectionBenchmark(
-            g.api,
-            project.id,
-            gt_dataset_ids=dataset_ids,
-            output_dir=work_dir,
-            progress=eval_pbar,
-            progress_secondary=sec_eval_pbar,
-            classes_whitelist=g.selected_classes,
-            evaluation_params=params,
-        )
-    elif task_type == sly.nn.TaskType.INSTANCE_SEGMENTATION:
-        if params is None:
-            params = sly.nn.benchmark.InstanceSegmentationEvaluator.load_yaml_evaluation_params()
-            params = yaml.safe_load(params)
-        bm = sly.nn.benchmark.InstanceSegmentationBenchmark(
-            g.api,
-            project.id,
-            gt_dataset_ids=dataset_ids,
-            output_dir=work_dir,
-            progress=eval_pbar,
-            progress_secondary=sec_eval_pbar,
-            classes_whitelist=g.selected_classes,
-            evaluation_params=params,
-        )
-    elif task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
-        params = sly.nn.benchmark.SemanticSegmentationEvaluator.load_yaml_evaluation_params()
-        bm = sly.nn.benchmark.SemanticSegmentationBenchmark(
-            g.api,
-            project.id,
-            gt_dataset_ids=dataset_ids,
-            output_dir=work_dir,
-            progress=eval_pbar,
-            progress_secondary=sec_eval_pbar,
-            classes_whitelist=g.selected_classes,
-            evaluation_params=params,
-        )
+    bm_cls, evaluator_cls = get_benchmark_and_evaluator_classes(task_type)
+    if params is None:
+        params = evaluator_cls.load_yaml_evaluation_params()
+        params = yaml.safe_load(params)
+    bm: benchmark_cls_type = bm_cls(
+        g.api,
+        project.id,
+        gt_dataset_ids=dataset_ids,
+        output_dir=work_dir,
+        progress=eval_pbar,
+        progress_secondary=sec_eval_pbar,
+        classes_whitelist=g.selected_classes,
+        evaluation_params=params,
+    )
 
     bm.evaluator_app_info = g.api.task.get_info_by_id(g.task_id)
 
     sly.logger.info(f"{g.session_id = }")
@@ -197,30 +199,21 @@ def run_evaluation(
     bm.visualize()
 
     bm.upload_eval_results(res_dir + "/evaluation/")
-    remote_dir = bm.upload_visualizations(res_dir + "/visualizations/")
+    bm.upload_visualizations(res_dir + "/visualizations/")
 
-    report = bm.upload_report_link(remote_dir)
-    g.api.task.set_output_report(g.task_id, report.id, report.name)
-
-    template_vis_file = g.api.file.get_info_by_path(
-        sly.env.team_id(), res_dir + "/visualizations/template.vue"
-    )
-    report_model_benchmark.set(template_vis_file)
+    g.api.task.set_output_report(g.task_id, bm.lnk.id, bm.lnk.name, "Click to open the report")
+    report_model_benchmark.set(bm.report)
     report_model_benchmark.show()
 
     eval_pbar.hide()
 
     # ==================== Workflow output ====================
-    w.workflow_output(g.api, res_dir, template_vis_file)
+    w.workflow_output(g.api, res_dir, bm.report)
     # =======================================================
 
     sly.logger.info(
-        f"Predictions project: "
-        f"  name {bm.dt_project_info.name}, "
-        f"  workspace_id {bm.dt_project_info.workspace_id}. "
-        # f"Differences project: "
-        # f"  name {bm.diff_project_info.name}, "
-        # f"  workspace_id {bm.diff_project_info.workspace_id}"
+        f"Predictions project {bm.dt_project_info.name}, workspace ID: {bm.dt_project_info.workspace_id}."
     )
+    sly.logger.info(f"Report link: {bm.get_report_link()}")
 
     eval_button.loading = False
@@ -250,13 +243,18 @@ def update_eval_params():
     g.session = SessionJSON(g.api, g.session_id)
     task_type = g.session.get_deploy_info()["task_type"]
     if task_type == sly.nn.TaskType.OBJECT_DETECTION:
-        params = sly.nn.benchmark.ObjectDetectionEvaluator.load_yaml_evaluation_params()
+        params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
     elif task_type == sly.nn.TaskType.INSTANCE_SEGMENTATION:
-        params = sly.nn.benchmark.InstanceSegmentationEvaluator.load_yaml_evaluation_params()
+        params = InstanceSegmentationEvaluator.load_yaml_evaluation_params()
     elif task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
-        params = "# Semantic Segmentation evaluation parameters are not available yet."
+        params = ""
     eval_params.set_text(params, language_mode="yaml")
-    eval_params_card.uncollapse()
+
+    if task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
+        eval_params_card.hide()
+    else:
+        eval_params_card.show()
+        eval_params_card.uncollapse()
 
 
 def handle_selectors(active: bool):
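
A minimal sketch of how the refactored dispatch is expected to behave, assuming the snippet runs where get_benchmark_and_evaluator_classes and the benchmark imports from src/ui/evaluation.py are in scope and supervisely is importable as sly; every call it makes appears in the diff above:

    import supervisely as sly
    import yaml

    # One lookup replaces the three near-identical task-type branches.
    bm_cls, evaluator_cls = get_benchmark_and_evaluator_classes(
        sly.nn.TaskType.INSTANCE_SEGMENTATION
    )
    assert bm_cls is InstanceSegmentationBenchmark

    # Default evaluation parameters now come from the evaluator class that
    # matches the task, parsed from YAML exactly as in run_evaluation().
    params = yaml.safe_load(evaluator_cls.load_yaml_evaluation_params())

    # An unrecognized task type now fails fast with ValueError; the old
    # if/elif chain had no else-branch, so it surfaced later as a NameError
    # on the unbound `bm`.
    try:
        get_benchmark_and_evaluator_classes("unknown task")
    except ValueError as e:
        print(e)  # Unknown task type: unknown task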