Commit 2b6bef6

SDK 6.73.208: Custom evaluation_params, key metrics in the report and JSON file (#10)

Co-authored-by: Nikolai Petukhov <[email protected]>
almazgimaev and NikolaiPetukhov authored Oct 10, 2024
1 parent bdcf333 · commit 2b6bef6
Showing 4 changed files with 58 additions and 9 deletions.
config.json: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
   "task_location": "workspace_tasks",
   "entrypoint": "python -m uvicorn src.main:app --host 0.0.0.0 --port 8000",
   "port": 8000,
-  "docker_image": "supervisely/model-benchmark:0.0.10",
+  "docker_image": "supervisely/model-benchmark:1.0.13",
   "instance_version": "6.11.19",
   "context_menu": {
     "target": ["images_project"]
dev_requirements.txt: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 # git+https://github.com/supervisely/supervisely.git@model-benchmark
-supervisely[model-benchmark]==6.73.201
+supervisely[model-benchmark]==6.73.208
docker/Dockerfile: 3 additions & 3 deletions

@@ -1,5 +1,5 @@
-FROM supervisely/base-py-sdk:6.73.201
+FROM supervisely/base-py-sdk:6.73.208
 
-RUN python3 -m pip install supervisely[model-benchmark]==6.73.201
+RUN python3 -m pip install supervisely[model-benchmark]==6.73.208
 
-LABEL python_sdk_version=6.73.201
+LABEL python_sdk_version=6.73.208
src/main.py: 53 additions & 4 deletions

@@ -1,14 +1,24 @@
 from typing import Optional
 
+import yaml
+
 import src.functions as f
 import src.globals as g
 import src.workflow as w
 import supervisely as sly
 import supervisely.app.widgets as widgets
+from supervisely.nn import TaskType
 from supervisely.nn.benchmark import (
     InstanceSegmentationBenchmark,
     ObjectDetectionBenchmark,
 )
+from supervisely.nn.benchmark.evaluation.base_evaluator import BaseEvaluator
+from supervisely.nn.benchmark.evaluation.instance_segmentation_evaluator import (
+    InstanceSegmentationEvaluator,
+)
+from supervisely.nn.benchmark.evaluation.object_detection_evaluator import (
+    ObjectDetectionEvaluator,
+)
 from supervisely.nn.inference.session import SessionJSON


@@ -31,6 +41,7 @@ def main_func():
 
     pbar.show()
     sec_pbar.show()
+    evaluation_parameters = yaml.safe_load(eval_params.get_value())
     if task_type == "object detection":
         bm = ObjectDetectionBenchmark(
             api,
@@ -39,6 +50,7 @@ def main_func():
             progress=pbar,
             progress_secondary=sec_pbar,
             classes_whitelist=g.selected_classes,
+            evaluation_params=evaluation_parameters,
         )
     elif task_type == "instance segmentation":
         bm = InstanceSegmentationBenchmark(
@@ -48,6 +60,7 @@ def main_func():
             progress=pbar,
             progress_secondary=sec_pbar,
             classes_whitelist=g.selected_classes,
+            evaluation_params=evaluation_parameters,
         )
     sly.logger.info(f"{g.session_id = }")

@@ -58,7 +71,6 @@ def main_func():
     res_dir = api.storage.get_free_dir_name(g.team_id, res_dir)
 
     bm.run_evaluation(model_session=g.session_id)
-    bm.upload_eval_results(res_dir + "/evaluation/")
 
     try:
         batch_sizes = (1, 8, 16)
@@ -73,6 +85,8 @@ def main_func():
         sly.logger.warn(f"Speedtest failed. Skipping. {e}")
 
     bm.visualize()
+
+    bm.upload_eval_results(res_dir + "/evaluation/")
     remote_dir = bm.upload_visualizations(res_dir + "/visualizations/")
 
     report = bm.upload_report_link(remote_dir)
@@ -113,6 +127,19 @@ def main_func():
 sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
 sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)
 
+eval_params = widgets.Editor(
+    initial_text=None,
+    language_mode="yaml",
+    height_lines=16,
+)
+eval_params_card = widgets.Card(
+    title="Evaluation parameters",
+    content=eval_params,
+    collapsable=True,
+)
+eval_params_card.collapse()
+
+
 button = widgets.Button("Evaluate")
 button.disable()

@@ -126,7 +153,15 @@ def main_func():
     title="Settings",
     description="Select Ground Truth project and deployed model session",
     content=widgets.Container(
-        [sel_project, sel_app_session, button, report_model_benchmark, pbar, sec_pbar]
+        [
+            sel_project,
+            sel_app_session,
+            eval_params_card,
+            button,
+            report_model_benchmark,
+            pbar,
+            sec_pbar,
+        ]
     ),
 )

@@ -160,6 +195,18 @@ def set_selected_classes_and_show_info():
         no_classes_label.show()
 
 
+def update_eval_params():
+    if g.session is None:
+        g.session = SessionJSON(g.api, g.session_id)
+    task_type = g.session.get_deploy_info()["task_type"]
+    if task_type == TaskType.OBJECT_DETECTION:
+        params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
+    elif task_type == TaskType.INSTANCE_SEGMENTATION:
+        params = InstanceSegmentationEvaluator.load_yaml_evaluation_params()
+    eval_params.set_text(params, language_mode="yaml")
+    eval_params_card.uncollapse()
+
+
 def handle_selectors(active: bool):
     no_classes_label.hide()
     selected_matched_text.hide()
@@ -183,11 +230,12 @@ def handle_sel_app_session(session_id: Optional[int]):
     active = session_id is not None and g.project_id is not None
     handle_selectors(active)
 
+    if g.session_id:
+        update_eval_params()
+
 
 @button.click
 def start_evaluation():
-    # select_classes.hide()
-    # not_matched_classes.hide()
     main_func()


@@ -198,6 +246,7 @@ def start_evaluation():
 
 if g.session_id:
     sel_app_session.set_session_id(g.session_id)
+    update_eval_params()
 
 if g.autostart:
     start_evaluation()
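
Taken together, the src/main.py changes add a collapsible "Evaluation parameters" card with a YAML editor, pre-fill it with the evaluator's default parameters once a model session is known, and pass the parsed values to the benchmark through the new evaluation_params keyword. Below is a minimal sketch of the same flow outside the widget UI, using only SDK calls that appear in this diff. The project and session IDs are placeholders, and treating the constructor argument after api as the Ground Truth project ID is an assumption (that line is folded out of the diff):

import yaml

import supervisely as sly
from supervisely.nn.benchmark import ObjectDetectionBenchmark
from supervisely.nn.benchmark.evaluation.object_detection_evaluator import (
    ObjectDetectionEvaluator,
)

api = sly.Api.from_env()

# Default evaluation parameters as a YAML string; update_eval_params() loads
# these to pre-fill the Editor widget.
params_yaml = ObjectDetectionEvaluator.load_yaml_evaluation_params()

# In the app the user may edit the YAML first; parsing it mirrors
# yaml.safe_load(eval_params.get_value()) in main_func().
evaluation_parameters = yaml.safe_load(params_yaml)

bm = ObjectDetectionBenchmark(
    api,
    123,  # assumed: Ground Truth project ID (this argument is folded out of the diff)
    evaluation_params=evaluation_parameters,  # new keyword wired in by this commit
)
bm.run_evaluation(model_session=456)  # placeholder ID of the deployed model session
bm.visualize()
# As of this commit, evaluation results are uploaded after visualize(), not before:
bm.upload_eval_results("/model-benchmark/evaluation/")  # placeholder remote directory

In the app itself, eval_params_card stays collapsed until update_eval_params() fills the editor, so the defaults only appear once a model session has been selected.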