temp fix: dummy progress to allow the app to receive requests
almazgimaev committed Oct 25, 2024
1 parent d5f4668 commit 008ebc9
Showing 3 changed files with 104 additions and 89 deletions.
15 changes: 15 additions & 0 deletions src/functions.py
@@ -87,3 +87,18 @@ def get_res_dir(eval_dirs: List[str]) -> str:
res_dir = g.api.file.get_free_dir_name(g.team_id, res_dir)

return res_dir


# ! temp fix (to allow the app to receive requests)
def with_clean_up_progress(pbar):
def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
finally:
with pbar(message="Application is started ...", total=1) as pb:
pb.update(1)

return wrapper

return decorator
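
For context, a minimal usage sketch of the decorator above (the widget and handler names here are hypothetical; the real call sites are the run_compare and run_evaluation handlers in the files below, and the import paths are assumed from how src/ui/compare.py uses them). The wrapped function runs as usual, and the finally block then opens a one-step progress bar and completes it, so the "Application is started ..." progress is reported even if the handler raises and the app can keep receiving requests.

import supervisely.app.widgets as widgets  # assumed import, matching the widgets used in src/ui/*
import src.functions as f                  # assumed import path for the module patched above

demo_pbar = widgets.SlyTqdm()              # hypothetical stand-in for comp_pbar / eval_pbar

@f.with_clean_up_progress(demo_pbar)
def demo_handler():
    # long-running work; whether this returns or raises, the decorator's
    # finally block pushes the one-step "Application is started ..." progress
    ...
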
61 changes: 30 additions & 31 deletions src/ui/compare.py
@@ -7,12 +7,38 @@
from supervisely._utils import rand_str
from supervisely.nn.benchmark.comparison.model_comparison import ModelComparison

compare_button = widgets.Button("Compare")
comp_pbar = widgets.SlyTqdm()
models_comparison_report = widgets.ReportThumbnail(
title="Models Comparison Report",
color="#ffc084",
bg_color="#fff2e6",
)
models_comparison_report.hide()
team_files_selector = widgets.TeamFilesSelector(
g.team_id,
multiple_selection=True,
selection_file_type="folder",
max_height=350,
initial_folder="/model-benchmark",
)

compare_contatiner = widgets.Container(
[
team_files_selector,
compare_button,
models_comparison_report,
comp_pbar,
]
)


@f.with_clean_up_progress(comp_pbar)
def run_compare(eval_dirs: List[str] = None):
workdir = g.STORAGE_DIR + "/model-comparison-" + rand_str(6)
team_files_selector.disable()
models_comparison_report.hide()
pbar.show()
comp_pbar.show()

g.eval_dirs = eval_dirs or team_files_selector.get_selected_paths()
f.validate_paths(g.eval_dirs)
@@ -21,10 +47,10 @@ def run_compare(eval_dirs: List[str] = None):
w.workflow_input(g.api, team_files_dirs=g.eval_dirs)
# =======================================================

comp = ModelComparison(g.api, g.eval_dirs, progress=pbar, workdir=workdir)
comp = ModelComparison(g.api, g.eval_dirs, progress=comp_pbar, workdir=workdir)
comp.visualize()
res_dir = f.get_res_dir(g.eval_dirs)
res_dir = comp.upload_results(g.team_id, remote_dir=res_dir, progress=pbar)
res_dir = comp.upload_results(g.team_id, remote_dir=res_dir, progress=comp_pbar)

report = g.api.file.get_info_by_path(g.team_id, comp.get_report_link())
g.api.task.set_output_report(g.task_id, report.id, report.name)
@@ -36,34 +62,7 @@ def run_compare(eval_dirs: List[str] = None):
w.workflow_output(g.api, model_comparison_report=report)
# =======================================================

pbar.hide()

comp_pbar.hide()
compare_button.loading = False

return res_dir


compare_button = widgets.Button("Compare")
pbar = widgets.SlyTqdm()
models_comparison_report = widgets.ReportThumbnail(
title="Models Comparison Report",
color="#ffc084",
bg_color="#fff2e6",
)
models_comparison_report.hide()
team_files_selector = widgets.TeamFilesSelector(
g.team_id,
multiple_selection=True,
selection_file_type="folder",
max_height=350,
initial_folder="/model-benchmark",
)

compare_contatiner = widgets.Container(
[
team_files_selector,
compare_button,
models_comparison_report,
pbar,
]
)
117 changes: 59 additions & 58 deletions src/ui/evaluation.py
@@ -22,6 +22,57 @@
from supervisely.nn.inference.session import SessionJSON


no_classes_label = widgets.Text(
"Not found any classes in the project that are present in the model", status="error"
)
no_classes_label.hide()
total_classes_text = widgets.Text(status="info")
selected_matched_text = widgets.Text(status="success")
not_matched_text = widgets.Text(status="warning")

sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)

eval_params = widgets.Editor(
initial_text=None,
language_mode="yaml",
height_lines=16,
)
eval_params_card = widgets.Card(
title="Evaluation parameters",
content=eval_params,
collapsable=True,
)
eval_params_card.collapse()


eval_button = widgets.Button("Evaluate")
eval_button.disable()

eval_pbar = widgets.SlyTqdm()
sec_eval_pbar = widgets.Progress("")

report_model_benchmark = widgets.ReportThumbnail()
report_model_benchmark.hide()

evaluation_container = widgets.Container(
[
sel_project,
sel_app_session,
eval_params_card,
eval_button,
report_model_benchmark,
eval_pbar,
sec_eval_pbar,
total_classes_text,
selected_matched_text,
not_matched_text,
no_classes_label,
]
)


@f.with_clean_up_progress(eval_pbar)
def run_evaluation(
session_id: Optional[int] = None,
project_id: Optional[int] = None,
@@ -49,8 +100,8 @@ def run_evaluation(
if g.selected_classes is None or len(g.selected_classes) == 0:
return

pbar.show()
sec_pbar.show()
eval_pbar.show()
sec_eval_pbar.show()

evaluation_params = eval_params.get_value() or params
if isinstance(evaluation_params, str):
@@ -64,8 +115,8 @@ def run_evaluation(
g.api,
project.id,
output_dir=work_dir,
progress=pbar,
progress_secondary=sec_pbar,
progress=eval_pbar,
progress_secondary=sec_eval_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=evaluation_params,
)
@@ -77,8 +128,8 @@ def run_evaluation(
g.api,
project.id,
output_dir=work_dir,
progress=pbar,
progress_secondary=sec_pbar,
progress=eval_pbar,
progress_secondary=sec_eval_pbar,
classes_whitelist=g.selected_classes,
evaluation_params=evaluation_params,
)
@@ -107,7 +158,7 @@ def run_evaluation(
elif max_batch_size is not None:
batch_sizes = tuple([bs for bs in batch_sizes if bs <= max_batch_size])
bm.run_speedtest(g.session_id, g.project_id, batch_sizes=batch_sizes)
sec_pbar.hide()
sec_eval_pbar.hide()
bm.upload_speedtest_results(res_dir + "/speedtest/")
except Exception as e:
sly.logger.warning(f"Speedtest failed. Skipping. {e}")
@@ -125,7 +176,7 @@ def run_evaluation(
)
report_model_benchmark.set(template_vis_file)
report_model_benchmark.show()
pbar.hide()
eval_pbar.hide()

# ==================== Workflow output ====================
w.workflow_output(g.api, res_dir, template_vis_file)
@@ -145,56 +196,6 @@ def run_evaluation(
return res_dir


no_classes_label = widgets.Text(
"Not found any classes in the project that are present in the model", status="error"
)
no_classes_label.hide()
total_classes_text = widgets.Text(status="info")
selected_matched_text = widgets.Text(status="success")
not_matched_text = widgets.Text(status="warning")

sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)

eval_params = widgets.Editor(
initial_text=None,
language_mode="yaml",
height_lines=16,
)
eval_params_card = widgets.Card(
title="Evaluation parameters",
content=eval_params,
collapsable=True,
)
eval_params_card.collapse()


eval_button = widgets.Button("Evaluate")
eval_button.disable()

pbar = widgets.SlyTqdm()
sec_pbar = widgets.Progress("")

report_model_benchmark = widgets.ReportThumbnail()
report_model_benchmark.hide()

evaluation_container = widgets.Container(
[
sel_project,
sel_app_session,
eval_params_card,
eval_button,
report_model_benchmark,
pbar,
sec_pbar,
total_classes_text,
selected_matched_text,
not_matched_text,
no_classes_label,
]
)


def set_selected_classes_and_show_info():
matched, not_matched = f.get_classes()
_, matched_model_classes = matched
