Skip to content

Commit

Permalink
hide class selection
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Aug 29, 2024
1 parent eb939fe commit 7a30cbb
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 53 deletions.
6 changes: 5 additions & 1 deletion src/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from supervisely.nn.inference import Session

geometry_to_task_type = {
"object detection": [sly.Rectangle, sly.Bitmap, sly.Polygon, sly.AlphaMask],
"object detection": [sly.Rectangle],
"instance segmentation": [sly.Bitmap, sly.Polygon, sly.AlphaMask],
# "semantic segmentation": [sly.Bitmap, sly.Polygon, sly.AlphaMask],
}
Expand Down Expand Up @@ -45,3 +45,7 @@ def get_classes():
not_matched_model_cls.append(obj_class)

return (matched_proj_cls, matched_model_cls), (not_matched_proj_cls, not_matched_model_cls)

def get_matched_class_names():
(cls_collection, _), _ = get_classes()
return [obj_cls.name for obj_cls in cls_collection]
2 changes: 2 additions & 0 deletions src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@
if session_id is not None:
session_id = int(session_id)
autostart = bool(strtobool(os.environ.get("modal.state.autoStart", "false")))

selected_classes = None
138 changes: 86 additions & 52 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@ def main_func():
pbar.show()
report_model_benchmark.hide()

selected_classes = select_classes.get_selected()
# selected_classes = select_classes.get_selected()

# matches = [x[1] for x in select_classes.get_stat()["match"]]
# selected_classes = [x[0] for x in selected_classes if x[0] in matches]
if g.selected_classes is None:
g.selected_classes = f.get_matched_class_names()

matches = [x[1] for x in select_classes.get_stat()["match"]]
selected_classes = [x[0] for x in selected_classes if x[0] in matches]
bm = ObjectDetectionBenchmark(
api,
project.id,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
classes_whitelist=selected_classes,
classes_whitelist=g.selected_classes,
)
sly.logger.info(f"{g.session_id = }")
bm.run_evaluation(model_session=g.session_id)
Expand All @@ -46,8 +49,6 @@ def main_func():
report = bm.upload_report_link(remote_dir)
api.task.set_output_report(g.task_id, report.id, report.name)

creating_report_f.hide()

template_vis_file = api.file.get_info_by_path(
sly.env.team_id(), eval_res_dir + "/visualizations/template.vue"
)
Expand All @@ -72,22 +73,24 @@ def main_func():
app.stop()


select_classes = widgets.MatchObjClasses(
selectable=True,
left_name="Model classes",
right_name="GT project classes",
)
select_classes.hide()
not_matched_classes = widgets.MatchObjClasses(
left_name="Model classes",
right_name="GT project classes",
)
not_matched_classes.hide()
# select_classes = widgets.MatchObjClasses(
# selectable=True,
# left_name="Model classes",
# right_name="GT project classes",
# )
# select_classes.hide()
# not_matched_classes = widgets.MatchObjClasses(
# left_name="Model classes",
# right_name="GT project classes",
# )
# not_matched_classes.hide()
no_classes_label = widgets.Text(
"Not found any classes in the project that are present in the model",
status="error",
"Not found any classes in the project that are present in the model", status="error"
)
no_classes_label.hide()
total_classes_text = widgets.Text(status="info")
selected_matched_text = widgets.Text(status="success")
not_matched_text = widgets.Text(status="warning")

sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)
Expand All @@ -99,46 +102,78 @@ def main_func():

report_model_benchmark = widgets.ReportThumbnail()
report_model_benchmark.hide()
creating_report_f = widgets.Field(widgets.Empty(), "", "Creating report on model...")
creating_report_f.hide()

controls_card = widgets.Card(
title="Settings",
description="Select Ground Truth project and deployed model session",
content=widgets.Container(
[sel_project, sel_app_session, button, report_model_benchmark, pbar]
),
)


# matched_card = widgets.Card(
# title="✅ Available classes",
# description="Select classes that are present in the model and in the project",
# content=select_classes,
# )
# matched_card.lock(message="Select project and model session to enable")

# not_matched_card = widgets.Card(
# title="❌ Not available classes",
# description="List of classes that are not matched between the model and the project",
# content=not_matched_classes,
# )
# not_matched_card.lock(message="Select project and model session to enable")


layout = widgets.Container(
widgets=[
widgets.Text("Select GT Project"),
sel_project,
sel_app_session,
button,
select_classes,
not_matched_classes,
no_classes_label,
creating_report_f,
report_model_benchmark,
pbar,
],
widgets=[controls_card, widgets.Empty(), widgets.Empty()], # , matched_card, not_matched_card],
direction="horizontal",
fractions=[1, 1, 1],
)

main_layout = widgets.Container(
widgets=[layout, total_classes_text, selected_matched_text, not_matched_text, no_classes_label]
)


def handle_selectors(active: bool):
no_classes_label.hide()
select_classes.hide()
button.loading = True
# select_classes.hide()
# not_matched_classes.hide()
if active:
metched, not_matched = f.get_classes()
project_classes, model_classes = metched
select_classes.set(left_collection=model_classes, right_collection=project_classes)

project_classes, model_classes = not_matched
not_matched_classes.set(left_collection=model_classes, right_collection=project_classes)
stats = select_classes.get_stat()
if len(stats["match"]) > 0:
select_classes.show()
not_matched_classes.show()
button.loading = False
matched, not_matched = f.get_classes()
_, matched_model = matched
_, not_matched_model = not_matched

g.selected_classes = [obj_cls.name for obj_cls in matched_model]
not_matched_classes = len(not_matched_model)
total_classes = len(matched_model) + len(not_matched_model)

total_classes_text.text = f"{total_classes} total classes available in the model."
selected_matched_text.text = f"{len(matched_model)} classes can be used for evaluation."
not_matched_text.text = f"{not_matched_classes} classes are not available for evaluation (not found in the GT project)."

if len(matched_model) > 0:
button.enable()
return
else:
no_classes_label.show()
button.loading = False
# select_classes.set(left_collection=model_classes, right_collection=matched_model)

# matched_model, model_classes = not_matched
# not_matched_classes.set(left_collection=model_classes, right_collection=matched_model)
# stats = select_classes.get_stat()
# if len(stats["match"]) > 0:
# select_classes.show()
# not_matched_classes.show()
# button.enable()
# # matched_card.unlock()
# # not_matched_card.unlock()
# return
# else:
# no_classes_label.show()
button.disable()


Expand All @@ -158,13 +193,12 @@ def handle_sel_app_session(session_id: Optional[int]):

@button.click
def start_evaluation():
creating_report_f.show()
select_classes.hide()
not_matched_classes.hide()
# select_classes.hide()
# not_matched_classes.hide()
main_func()


app = sly.Application(layout=layout, static_dir=g.STATIC_DIR)
app = sly.Application(layout=main_layout, static_dir=g.STATIC_DIR)

if g.project_id:
sel_project.set_project_id(g.project_id)
Expand All @@ -173,4 +207,4 @@ def start_evaluation():
sel_app_session.set_session_id(g.session_id)

if g.autostart:
start_evaluation()
start_evaluation()

0 comments on commit 7a30cbb

Please sign in to comment.