diff --git a/config.json b/config.json
index 91f79c6..29b3ff6 100644
--- a/config.json
+++ b/config.json
@@ -11,7 +11,7 @@
   "task_location": "workspace_tasks",
   "entrypoint": "python -m uvicorn src.main:app --host 0.0.0.0 --port 8000",
   "port": 8000,
-  "docker_image": "supervisely/model-benchmark:0.0.10",
+  "docker_image": "supervisely/model-benchmark:1.0.13",
   "instance_version": "6.11.19",
   "context_menu": {
     "target": ["images_project"]
diff --git a/dev_requirements.txt b/dev_requirements.txt
index a0390e8..19db2b7 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -1,2 +1,2 @@
 # git+https://github.com/supervisely/supervisely.git@model-benchmark
-supervisely[model-benchmark]==6.73.201
\ No newline at end of file
+supervisely[model-benchmark]==6.73.208
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 3632a4a..b6cba3d 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,5 +1,5 @@
-FROM supervisely/base-py-sdk:6.73.201
+FROM supervisely/base-py-sdk:6.73.208
 
-RUN python3 -m pip install supervisely[model-benchmark]==6.73.201
+RUN python3 -m pip install supervisely[model-benchmark]==6.73.208
 
-LABEL python_sdk_version=6.73.201
\ No newline at end of file
+LABEL python_sdk_version=6.73.208
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index f37c40f..1467f52 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,14 +1,24 @@
 from typing import Optional
 
+import yaml
+
 import src.functions as f
 import src.globals as g
 import src.workflow as w
 import supervisely as sly
 import supervisely.app.widgets as widgets
+from supervisely.nn import TaskType
 from supervisely.nn.benchmark import (
     InstanceSegmentationBenchmark,
     ObjectDetectionBenchmark,
 )
+from supervisely.nn.benchmark.evaluation.base_evaluator import BaseEvaluator
+from supervisely.nn.benchmark.evaluation.instance_segmentation_evaluator import (
+    InstanceSegmentationEvaluator,
+)
+from supervisely.nn.benchmark.evaluation.object_detection_evaluator import (
+    ObjectDetectionEvaluator,
+)
 from supervisely.nn.inference.session import SessionJSON
 
 
@@ -31,6 +41,7 @@ def main_func():
 
     pbar.show()
     sec_pbar.show()
+    evaluation_parameters = yaml.safe_load(eval_params.get_value())
     if task_type == "object detection":
         bm = ObjectDetectionBenchmark(
             api,
@@ -39,6 +50,7 @@ def main_func():
             progress=pbar,
             progress_secondary=sec_pbar,
             classes_whitelist=g.selected_classes,
+            evaluation_params=evaluation_parameters,
         )
     elif task_type == "instance segmentation":
         bm = InstanceSegmentationBenchmark(
@@ -48,6 +60,7 @@ def main_func():
             progress=pbar,
             progress_secondary=sec_pbar,
             classes_whitelist=g.selected_classes,
+            evaluation_params=evaluation_parameters,
         )
 
     sly.logger.info(f"{g.session_id = }")
@@ -58,7 +71,6 @@ def main_func():
     res_dir = api.storage.get_free_dir_name(g.team_id, res_dir)
 
     bm.run_evaluation(model_session=g.session_id)
-    bm.upload_eval_results(res_dir + "/evaluation/")
 
     try:
         batch_sizes = (1, 8, 16)
@@ -73,6 +85,8 @@ def main_func():
         sly.logger.warn(f"Speedtest failed. Skipping. \n{e}")
 
     bm.visualize()
+
+    bm.upload_eval_results(res_dir + "/evaluation/")
     remote_dir = bm.upload_visualizations(res_dir + "/visualizations/")
 
     report = bm.upload_report_link(remote_dir)
@@ -113,6 +127,19 @@ def main_func():
 
 sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
 sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)
+eval_params = widgets.Editor(
+    initial_text=None,
+    language_mode="yaml",
+    height_lines=16,
+)
+eval_params_card = widgets.Card(
+    title="Evaluation parameters",
+    content=eval_params,
+    collapsable=True,
+)
+eval_params_card.collapse()
+
+
 button = widgets.Button("Evaluate")
 button.disable()
 
@@ -126,7 +153,15 @@ def main_func():
     title="Settings",
     description="Select Ground Truth project and deployed model session",
     content=widgets.Container(
-        [sel_project, sel_app_session, button, report_model_benchmark, pbar, sec_pbar]
+        [
+            sel_project,
+            sel_app_session,
+            eval_params_card,
+            button,
+            report_model_benchmark,
+            pbar,
+            sec_pbar,
+        ]
     ),
 )
 
@@ -160,6 +195,18 @@ def set_selected_classes_and_show_info():
         no_classes_label.show()
 
 
+def update_eval_params():
+    if g.session is None:
+        g.session = SessionJSON(g.api, g.session_id)
+    task_type = g.session.get_deploy_info()["task_type"]
+    if task_type == TaskType.OBJECT_DETECTION:
+        params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
+    elif task_type == TaskType.INSTANCE_SEGMENTATION:
+        params = InstanceSegmentationEvaluator.load_yaml_evaluation_params()
+    eval_params.set_text(params, language_mode="yaml")
+    eval_params_card.uncollapse()
+
+
 def handle_selectors(active: bool):
     no_classes_label.hide()
     selected_matched_text.hide()
@@ -183,11 +230,12 @@ def handle_sel_app_session(session_id: Optional[int]):
     active = session_id is not None and g.project_id is not None
     handle_selectors(active)
 
+    if g.session_id:
+        update_eval_params()
+
 
 @button.click
 def start_evaluation():
-    # select_classes.hide()
-    # not_matched_classes.hide()
     main_func()
 
 
@@ -198,6 +246,7 @@ def start_evaluation():
 
 if g.session_id:
     sel_app_session.set_session_id(g.session_id)
+    update_eval_params()
 
 if g.autostart:
     start_evaluation()