Skip to content

Commit

Permalink
fix eval params
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Oct 14, 2024
1 parent ff325dd commit 6eecdc9
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions src/ui/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Dict, Optional, Union

import yaml

Expand All @@ -13,7 +13,6 @@
InstanceSegmentationBenchmark,
ObjectDetectionBenchmark,
)
from supervisely.nn.benchmark.comparison.model_comparison import ModelComparison
from supervisely.nn.benchmark.evaluation.instance_segmentation_evaluator import (
InstanceSegmentationEvaluator,
)
Expand All @@ -23,7 +22,11 @@
from supervisely.nn.inference.session import SessionJSON


def run_evaluation(session_id: Optional[int] = None, project_id: Optional[int] = None):
def run_evaluation(
session_id: Optional[int] = None,
project_id: Optional[int] = None,
params: Optional[Union[str, Dict]] = None,
):
work_dir = g.STORAGE_DIR + "/benchmark_" + rand_str(6)

if session_id is not None:
Expand All @@ -48,26 +51,36 @@ def run_evaluation(session_id: Optional[int] = None, project_id: Optional[int] =

pbar.show()
sec_pbar.show()
evaluation_parameters = yaml.safe_load(eval_params.get_value())
if task_type == "object detection":

evaluation_params = eval_params.get_value() or params
if isinstance(evaluation_params, str):
evaluation_params = yaml.safe_load(evaluation_params)

if task_type == TaskType.OBJECT_DETECTION:
if evaluation_params is None:
evaluation_params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
evaluation_params = yaml.safe_load(evaluation_params)
bm = ObjectDetectionBenchmark(
g.api,
project.id,
output_dir=work_dir,
progress=pbar,
progress_secondary=sec_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=evaluation_parameters,
evaluation_params=evaluation_params,
)
elif task_type == "instance segmentation":
elif task_type == TaskType.INSTANCE_SEGMENTATION:
if evaluation_params is None:
evaluation_params = InstanceSegmentationEvaluator.load_yaml_evaluation_params()
evaluation_params = yaml.safe_load(evaluation_params)
bm = InstanceSegmentationBenchmark(
g.api,
project.id,
output_dir=work_dir,
progress=pbar,
progress_secondary=sec_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=evaluation_parameters,
evaluation_params=evaluation_params,
)
sly.logger.info(f"{g.session_id = }")

Expand Down

0 comments on commit 6eecdc9

Please sign in to comment.