From 008ebc97d35ba9acdb723066d2125da985e7c1ad Mon Sep 17 00:00:00 2001
From: almaz
Date: Fri, 25 Oct 2024 11:38:27 +0200
Subject: [PATCH] temp fix: dummy progress to allow the app to receive requests

---
 src/functions.py     |  15 ++++++
 src/ui/compare.py    |  61 +++++++++++-----------
 src/ui/evaluation.py | 117 ++++++++++++++++++++++---------------------
 3 files changed, 104 insertions(+), 89 deletions(-)

diff --git a/src/functions.py b/src/functions.py
index 6c02a53..9580ad0 100644
--- a/src/functions.py
+++ b/src/functions.py
@@ -87,3 +87,18 @@ def get_res_dir(eval_dirs: List[str]) -> str:
     res_dir = g.api.file.get_free_dir_name(g.team_id, res_dir)
 
     return res_dir
+
+
+# ! temp fix (to allow the app to receive requests)
+def with_clean_up_progress(pbar):
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            finally:
+                with pbar(message="Application is started ...", total=1) as pb:
+                    pb.update(1)
+
+        return wrapper
+
+    return decorator
diff --git a/src/ui/compare.py b/src/ui/compare.py
index 8b19862..fd6cd81 100644
--- a/src/ui/compare.py
+++ b/src/ui/compare.py
@@ -7,12 +7,38 @@ from supervisely._utils import rand_str
 from supervisely.nn.benchmark.comparison.model_comparison import ModelComparison
 
+compare_button = widgets.Button("Compare")
+comp_pbar = widgets.SlyTqdm()
+models_comparison_report = widgets.ReportThumbnail(
+    title="Models Comparison Report",
+    color="#ffc084",
+    bg_color="#fff2e6",
+)
+models_comparison_report.hide()
+team_files_selector = widgets.TeamFilesSelector(
+    g.team_id,
+    multiple_selection=True,
+    selection_file_type="folder",
+    max_height=350,
+    initial_folder="/model-benchmark",
+)
+
+compare_contatiner = widgets.Container(
+    [
+        team_files_selector,
+        compare_button,
+        models_comparison_report,
+        comp_pbar,
+    ]
+)
+
+@f.with_clean_up_progress(comp_pbar)
 def run_compare(eval_dirs: List[str] = None):
     workdir = g.STORAGE_DIR + "/model-comparison-" + rand_str(6)
     team_files_selector.disable()
     models_comparison_report.hide()
-    pbar.show()
+    comp_pbar.show()
 
     g.eval_dirs = eval_dirs or team_files_selector.get_selected_paths()
     f.validate_paths(g.eval_dirs)
@@ -21,10 +47,10 @@ def run_compare(eval_dirs: List[str] = None):
     w.workflow_input(g.api, team_files_dirs=g.eval_dirs)
     # =======================================================
 
-    comp = ModelComparison(g.api, g.eval_dirs, progress=pbar, workdir=workdir)
+    comp = ModelComparison(g.api, g.eval_dirs, progress=comp_pbar, workdir=workdir)
     comp.visualize()
     res_dir = f.get_res_dir(g.eval_dirs)
-    res_dir = comp.upload_results(g.team_id, remote_dir=res_dir, progress=pbar)
+    res_dir = comp.upload_results(g.team_id, remote_dir=res_dir, progress=comp_pbar)
 
     report = g.api.file.get_info_by_path(g.team_id, comp.get_report_link())
     g.api.task.set_output_report(g.task_id, report.id, report.name)
@@ -36,34 +62,7 @@ def run_compare(eval_dirs: List[str] = None):
     w.workflow_output(g.api, model_comparison_report=report)
     # =======================================================
 
-    pbar.hide()
-
+    comp_pbar.hide()
     compare_button.loading = False
 
     return res_dir
-
-
-compare_button = widgets.Button("Compare")
-pbar = widgets.SlyTqdm()
-models_comparison_report = widgets.ReportThumbnail(
-    title="Models Comparison Report",
-    color="#ffc084",
-    bg_color="#fff2e6",
-)
-models_comparison_report.hide()
-team_files_selector = widgets.TeamFilesSelector(
-    g.team_id,
-    multiple_selection=True,
-    selection_file_type="folder",
-    max_height=350,
-    initial_folder="/model-benchmark",
-)
-
-compare_contatiner = widgets.Container(
-    [
-        team_files_selector,
-        compare_button,
-        models_comparison_report,
-        pbar,
-    ]
-)
diff --git a/src/ui/evaluation.py b/src/ui/evaluation.py
index 274e98a..8ac69ee 100644
--- a/src/ui/evaluation.py
+++ b/src/ui/evaluation.py
@@ -22,6 +22,57 @@ from supervisely.nn.inference.session import SessionJSON
 
 
+no_classes_label = widgets.Text(
+    "Not found any classes in the project that are present in the model", status="error"
+)
+no_classes_label.hide()
+total_classes_text = widgets.Text(status="info")
+selected_matched_text = widgets.Text(status="success")
+not_matched_text = widgets.Text(status="warning")
+
+sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
+sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)
+
+eval_params = widgets.Editor(
+    initial_text=None,
+    language_mode="yaml",
+    height_lines=16,
+)
+eval_params_card = widgets.Card(
+    title="Evaluation parameters",
+    content=eval_params,
+    collapsable=True,
+)
+eval_params_card.collapse()
+
+
+eval_button = widgets.Button("Evaluate")
+eval_button.disable()
+
+eval_pbar = widgets.SlyTqdm()
+sec_eval_pbar = widgets.Progress("")
+
+report_model_benchmark = widgets.ReportThumbnail()
+report_model_benchmark.hide()
+
+evaluation_container = widgets.Container(
+    [
+        sel_project,
+        sel_app_session,
+        eval_params_card,
+        eval_button,
+        report_model_benchmark,
+        eval_pbar,
+        sec_eval_pbar,
+        total_classes_text,
+        selected_matched_text,
+        not_matched_text,
+        no_classes_label,
+    ]
+)
+
+
+@f.with_clean_up_progress(eval_pbar)
 def run_evaluation(
     session_id: Optional[int] = None,
     project_id: Optional[int] = None,
@@ -49,8 +100,8 @@ def run_evaluation(
     if g.selected_classes is None or len(g.selected_classes) == 0:
         return
 
-    pbar.show()
-    sec_pbar.show()
+    eval_pbar.show()
+    sec_eval_pbar.show()
 
     evaluation_params = eval_params.get_value() or params
     if isinstance(evaluation_params, str):
@@ -64,8 +115,8 @@ def run_evaluation(
             g.api,
             project.id,
             output_dir=work_dir,
-            progress=pbar,
-            progress_secondary=sec_pbar,
+            progress=eval_pbar,
+            progress_secondary=sec_eval_pbar,
             classes_whitelist=g.selected_classes,
             evaluation_params=evaluation_params,
         )
@@ -77,8 +128,8 @@ def run_evaluation(
            g.api,
            project.id,
            output_dir=work_dir,
-           progress=pbar,
-           progress_secondary=sec_pbar,
+           progress=eval_pbar,
+           progress_secondary=sec_eval_pbar,
            classes_whitelist=g.selected_classes,
            evaluation_params=evaluation_params,
        )
@@ -107,7 +158,7 @@ def run_evaluation(
             elif max_batch_size is not None:
                 batch_sizes = tuple([bs for bs in batch_sizes if bs <= max_batch_size])
             bm.run_speedtest(g.session_id, g.project_id, batch_sizes=batch_sizes)
-            sec_pbar.hide()
+            sec_eval_pbar.hide()
             bm.upload_speedtest_results(res_dir + "/speedtest/")
         except Exception as e:
             sly.logger.warning(f"Speedtest failed. Skipping. {e}")
@@ -125,7 +176,7 @@ def run_evaluation(
     )
     report_model_benchmark.set(template_vis_file)
     report_model_benchmark.show()
-    pbar.hide()
+    eval_pbar.hide()
 
     # ==================== Workflow output ====================
     w.workflow_output(g.api, res_dir, template_vis_file)
     # =======================================================
@@ -145,56 +196,6 @@ def run_evaluation(
     return res_dir
 
 
-no_classes_label = widgets.Text(
-    "Not found any classes in the project that are present in the model", status="error"
-)
-no_classes_label.hide()
-total_classes_text = widgets.Text(status="info")
-selected_matched_text = widgets.Text(status="success")
-not_matched_text = widgets.Text(status="warning")
-
-sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
-sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)
-
-eval_params = widgets.Editor(
-    initial_text=None,
-    language_mode="yaml",
-    height_lines=16,
-)
-eval_params_card = widgets.Card(
-    title="Evaluation parameters",
-    content=eval_params,
-    collapsable=True,
-)
-eval_params_card.collapse()
-
-
-eval_button = widgets.Button("Evaluate")
-eval_button.disable()
-
-pbar = widgets.SlyTqdm()
-sec_pbar = widgets.Progress("")
-
-report_model_benchmark = widgets.ReportThumbnail()
-report_model_benchmark.hide()
-
-evaluation_container = widgets.Container(
-    [
-        sel_project,
-        sel_app_session,
-        eval_params_card,
-        eval_button,
-        report_model_benchmark,
-        pbar,
-        sec_pbar,
-        total_classes_text,
-        selected_matched_text,
-        not_matched_text,
-        no_classes_label,
-    ]
-)
-
-
 def set_selected_classes_and_show_info():
     matched, not_matched = f.get_classes()
     _, matched_model_classes = matched
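
Note (not part of the patch): the sketch below illustrates the behavior of the with_clean_up_progress decorator added in src/functions.py. The decorator body is copied from the patch; dummy_pbar and run_something are hypothetical stand-ins for the widgets.SlyTqdm() progress widget and for an app handler. The finally block always opens a one-step progress bar with the "Application is started ..." message and completes it, even when the wrapped function raises, which, per the commit subject, is the workaround that lets the app keep receiving requests.

# --- illustrative sketch, not part of the patch ---
# dummy_pbar and run_something are hypothetical stand-ins; the decorator body
# is the one introduced in src/functions.py.
from contextlib import contextmanager


@contextmanager
def dummy_pbar(message, total):
    # Stand-in for a widgets.SlyTqdm() instance: it is called like
    # pbar(message=..., total=...) and yields an object with .update().
    print(f"{message} (total={total})")

    class _Bar:
        def update(self, n):
            print(f"progress +{n}")

    yield _Bar()


def with_clean_up_progress(pbar):
    # Same decorator as in the patch: the finally block always drives the
    # progress bar to completion, even if the wrapped function raises.
    def decorator(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            finally:
                with pbar(message="Application is started ...", total=1) as pb:
                    pb.update(1)

        return wrapper

    return decorator


@with_clean_up_progress(dummy_pbar)
def run_something():
    # Hypothetical handler that fails; the progress bar is still completed.
    raise RuntimeError("handler failed")


if __name__ == "__main__":
    try:
        run_something()
    except RuntimeError:
        pass  # the exception still propagates after the progress bar finishes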