Skip to content

Commit

Permalink
add datasets selector
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolaiPetukhov committed Sep 28, 2024
1 parent a68dc5b commit c84e9be
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

workspace_id = sly.env.workspace_id()
project_id = sly.env.project_id(raise_not_found=False)
dataset_ids = None
team_id = sly.env.team_id()
task_id = sly.env.task_id(raise_not_found=False)
session_id = os.environ.get("modal.state.sessionId", None)
Expand All @@ -29,4 +30,4 @@
session = None
autostart = bool(strtobool(os.environ.get("modal.state.autoStart", "false")))

selected_classes = None
selected_classes = None
29 changes: 21 additions & 8 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional

import src.functions as f
import src.globals as g
Expand Down Expand Up @@ -35,6 +35,7 @@ def main_func():
bm = ObjectDetectionBenchmark(
api,
project.id,
gt_dataset_ids=g.dataset_ids,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
progress_secondary=sec_pbar,
Expand All @@ -44,6 +45,7 @@ def main_func():
bm = InstanceSegmentationBenchmark(
api,
project.id,
gt_dataset_ids=g.dataset_ids,
output_dir=g.STORAGE_DIR + "/benchmark",
progress=pbar,
progress_secondary=sec_pbar,
Expand Down Expand Up @@ -106,7 +108,14 @@ def main_func():
not_matched_text = widgets.Text(status="warning")

sel_app_session = widgets.SelectAppSession(g.team_id, tags=g.deployed_nn_tags, show_label=True)
sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)
sel_dataset = widgets.SelectDataset(
default_id=None,
project_id=None,
multiselect=True,
select_all_datasets=True,
allowed_project_types=[sly.ProjectType.IMAGES],
)
# sel_project = widgets.SelectProject(default_id=None, workspace_id=g.workspace_id)

button = widgets.Button("Evaluate")
button.disable()
Expand All @@ -121,7 +130,7 @@ def main_func():
title="Settings",
description="Select Ground Truth project and deployed model session",
content=widgets.Container(
[sel_project, sel_app_session, button, report_model_benchmark, pbar, sec_pbar]
[sel_dataset, sel_app_session, button, report_model_benchmark, pbar, sec_pbar]
),
)

Expand Down Expand Up @@ -165,10 +174,14 @@ def handle_selectors(active: bool):
button.disable()


@sel_project.value_changed
def handle_sel_project(project_id: Optional[int]):
g.project_id = project_id
active = project_id is not None and g.session_id is not None
@sel_dataset.value_changed
def handle_sel_dataset(dataset_ids: List[int]):
g.project_id = sel_dataset.get_selected_project_id()
if sel_dataset._all_datasets_checkbox.is_checked():
g.dataset_ids = None
else:
g.dataset_ids = dataset_ids
active = g.project_id is not None and g.session_id is not None
handle_selectors(active)


Expand All @@ -189,7 +202,7 @@ def start_evaluation():
app = sly.Application(layout=main_layout, static_dir=g.STATIC_DIR)

if g.project_id:
sel_project.set_project_id(g.project_id)
sel_dataset.set_project_id(g.project_id)

if g.session_id:
sel_app_session.set_session_id(g.session_id)
Expand Down

0 comments on commit c84e9be

Please sign in to comment.