Skip to content

Commit

Permalink
refactor detection and add sem segm
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Nov 13, 2024
1 parent 1188b87 commit 2d00efa
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 38 deletions.
1 change: 1 addition & 0 deletions src/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
geometry_to_task_type = {
TaskType.OBJECT_DETECTION: [sly.Rectangle, sly.AnyGeometry],
TaskType.INSTANCE_SEGMENTATION: [sly.Bitmap, sly.Polygon, sly.AnyGeometry],
TaskType.SEMANTIC_SEGMENTATION: [sly.Bitmap, sly.Polygon, sly.AnyGeometry],
}


Expand Down
79 changes: 41 additions & 38 deletions src/ui/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,8 @@
import src.workflow as w
import supervisely as sly
import supervisely.app.widgets as widgets
from supervisely._utils import rand_str
from supervisely.nn import TaskType
from supervisely.nn.benchmark import (
InstanceSegmentationBenchmark,
ObjectDetectionBenchmark,
)
from supervisely.nn.benchmark.evaluation.instance_segmentation_evaluator import (
InstanceSegmentationEvaluator,
)
from supervisely.nn.benchmark.evaluation.object_detection_evaluator import (
ObjectDetectionEvaluator,
)
from supervisely.nn.inference.session import SessionJSON


no_classes_label = widgets.Text(
"Not found any classes in the project that are present in the model", status="error"
)
Expand Down Expand Up @@ -83,7 +70,7 @@ def run_evaluation(
project_id: Optional[int] = None,
params: Optional[Union[str, Dict]] = None,
):
work_dir = g.STORAGE_DIR + "/benchmark_" + rand_str(6)
work_dir = g.STORAGE_DIR + "/benchmark_" + sly.rand_str(6)

if session_id is not None:
g.session_id = session_id
Expand All @@ -102,7 +89,6 @@ def run_evaluation(
if len(dataset_ids) == 0:
raise ValueError("No datasets selected")


# ==================== Workflow input ====================
w.workflow_input(g.api, project, g.session_id)
# =======================================================
Expand All @@ -116,38 +102,52 @@ def run_evaluation(
eval_pbar.show()
sec_eval_pbar.show()

evaluation_params = eval_params.get_value() or params
if isinstance(evaluation_params, str):
evaluation_params = yaml.safe_load(evaluation_params)
params = eval_params.get_value() or params
if isinstance(params, str):
params = yaml.safe_load(params)

if task_type == TaskType.OBJECT_DETECTION:
if evaluation_params is None:
evaluation_params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
evaluation_params = yaml.safe_load(evaluation_params)
bm = ObjectDetectionBenchmark(
if task_type == sly.nn.TaskType.OBJECT_DETECTION:
if params is None:
params = sly.nn.benchmark.ObjectDetectionEvaluator.load_yaml_evaluation_params()
params = yaml.safe_load(params)
bm = sly.nn.benchmark.ObjectDetectionBenchmark(
g.api,
project.id,
gt_dataset_ids=dataset_ids,
output_dir=work_dir,
progress=eval_pbar,
progress_secondary=sec_eval_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=params,
)
elif task_type == sly.nn.TaskType.INSTANCE_SEGMENTATION:
if params is None:
params = sly.nn.benchmark.InstanceSegmentationEvaluator.load_yaml_evaluation_params()
params = yaml.safe_load(params)
bm = sly.nn.benchmark.InstanceSegmentationBenchmark(
g.api,
project.id,
gt_dataset_ids=dataset_ids,
output_dir=work_dir,
progress=eval_pbar,
progress_secondary=sec_eval_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=evaluation_params,
evaluation_params=params,
)
elif task_type == TaskType.INSTANCE_SEGMENTATION:
if evaluation_params is None:
evaluation_params = InstanceSegmentationEvaluator.load_yaml_evaluation_params()
evaluation_params = yaml.safe_load(evaluation_params)
bm = InstanceSegmentationBenchmark(
elif task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
params = sly.nn.benchmark.SemanticSegmentationEvaluator.load_yaml_evaluation_params()
bm = sly.nn.benchmark.SemanticSegmentationBenchmark(
g.api,
project.id,
gt_dataset_ids=dataset_ids,
output_dir=work_dir,
progress=eval_pbar,
progress_secondary=sec_eval_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=evaluation_params,
evaluation_params=params,
)

bm.evaluator_app_info = g.api.task.get_info_by_id(g.task_id)
sly.logger.info(f"{g.session_id = }")

task_info = g.api.task.get_info_by_id(g.session_id)
Expand Down Expand Up @@ -201,9 +201,9 @@ def run_evaluation(
f"Predictions project: "
f" name {bm.dt_project_info.name}, "
f" workspace_id {bm.dt_project_info.workspace_id}. "
f"Differences project: "
f" name {bm.diff_project_info.name}, "
f" workspace_id {bm.diff_project_info.workspace_id}"
# f"Differences project: "
# f" name {bm.diff_project_info.name}, "
# f" workspace_id {bm.diff_project_info.workspace_id}"
)

eval_button.loading = False
Expand Down Expand Up @@ -233,10 +233,12 @@ def update_eval_params():
if g.session is None:
g.session = SessionJSON(g.api, g.session_id)
task_type = g.session.get_deploy_info()["task_type"]
if task_type == TaskType.OBJECT_DETECTION:
params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
elif task_type == TaskType.INSTANCE_SEGMENTATION:
params = InstanceSegmentationEvaluator.load_yaml_evaluation_params()
if task_type == sly.nn.TaskType.OBJECT_DETECTION:
params = sly.nn.benchmark.ObjectDetectionEvaluator.load_yaml_evaluation_params()
elif task_type == sly.nn.TaskType.INSTANCE_SEGMENTATION:
params = sly.nn.benchmark.InstanceSegmentationEvaluator.load_yaml_evaluation_params()
elif task_type == sly.nn.TaskType.SEMANTIC_SEGMENTATION:
params = "# Semantic Segmentation evaluation parameters are not available yet."
eval_params.set_text(params, language_mode="yaml")
eval_params_card.uncollapse()

Expand Down Expand Up @@ -269,9 +271,10 @@ def handle_sel_app_session(session_id: Optional[int]):
if g.session_id:
update_eval_params()


@all_datasets_checkbox.value_changed
def handle_all_datasets_checkbox(checked: bool):
    """Sync dataset-selector visibility with the 'all datasets' checkbox.

    When the checkbox is ticked every dataset is implied, so the explicit
    selector is hidden; unticking it brings the selector back.
    """
    # Dispatch to the matching widget method instead of an if/else ladder.
    toggle = sel_dataset.hide if checked else sel_dataset.show
    toggle()

0 comments on commit 2d00efa

Please sign in to comment.