add dataset_ids argument (when running from script)
almazgimaev committed Dec 16, 2024
1 parent 5dfe28e commit d23d119
Showing 3 changed files with 14 additions and 10 deletions.
2 changes: 1 addition & 1 deletion local.env
@@ -3,5 +3,5 @@ WORKSPACE_ID = 680
 # PROJECT_ID = 41021
 SLY_APP_DATA_DIR = "APP_DATA"
 
-TASK_ID = 68257
+TASK_ID = 68088
 # modal.state.sessionId=66693
5 changes: 4 additions & 1 deletion src/main.py
@@ -41,7 +41,10 @@ async def evaluate(request: Request):
     req = await request.json()
     try:
         state = req["state"]
-        return {"data": run_evaluation(state["session_id"], state["project_id"])}
+        session_id = state["session_id"]
+        project_id = state["project_id"]
+        dataset_ids = state.get("dataset_ids", None)
+        return {"data": run_evaluation(session_id, project_id, dataset_ids=dataset_ids)}
     except Exception as e:
         sly.logger.error(f"Error during model evaluation: {e}")
         return {"error": str(e)}
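For context, a rough sketch of how a script could call this handler with the new argument. The host, port, and "/evaluate" route below are assumptions for illustration (the FastAPI route decorator is not part of this diff); only the shape of the "state" payload follows the handler above.

import requests

# Hypothetical endpoint and IDs; only the "state" payload shape is taken from the handler.
payload = {
    "state": {
        "session_id": 12345,        # deployed model session (placeholder)
        "project_id": 41021,        # project to evaluate on (placeholder)
        "dataset_ids": [101, 102],  # optional; omit to fall back to the UI selection
    }
}
resp = requests.post("http://localhost:8000/evaluate", json=payload, timeout=600)
print(resp.json())  # {"data": ...} on success, {"error": "..."} on failure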
17 changes: 9 additions & 8 deletions src/ui/evaluation.py
@@ -113,6 +113,7 @@ def run_evaluation(
     session_id: Optional[int] = None,
     project_id: Optional[int] = None,
     params: Optional[Union[str, Dict]] = None,
+    dataset_ids: Optional[Tuple[int]] = None,
 ):
     work_dir = g.STORAGE_DIR + "/benchmark_" + sly.rand_str(6)
 
@@ -124,12 +125,13 @@
     g.session = SessionJSON(g.api, g.session_id)
     task_type = g.session.get_deploy_info()["task_type"]
 
-    if all_datasets_checkbox.is_checked():
-        dataset_ids = None
-    else:
-        dataset_ids = sel_dataset.get_selected_ids()
-        if len(dataset_ids) == 0:
-            raise ValueError("No datasets selected")
+    if dataset_ids is None:
+        if all_datasets_checkbox.is_checked():
+            dataset_ids = None
+        else:
+            dataset_ids = sel_dataset.get_selected_ids()
+            if len(dataset_ids) == 0:
+                raise ValueError("No datasets selected")
 
     # ==================== Workflow input ====================
     w.workflow_input(g.api, project, g.session_id)
@@ -239,8 +241,7 @@ def set_selected_classes_and_show_info():
 
 
 def update_eval_params():
-    if g.session is None:
-        g.session = SessionJSON(g.api, g.session_id)
+    g.session = SessionJSON(g.api, g.session_id)
     task_type = g.session.get_deploy_info()["task_type"]
     if task_type == sly.nn.TaskType.OBJECT_DETECTION:
         params = ObjectDetectionEvaluator.load_yaml_evaluation_params()
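A minimal usage sketch of the reworked signature, assuming run_evaluation is importable from this module and the surrounding globals (g.api, the session widgets, etc.) are already initialized; the import path and IDs are placeholders. Passing dataset_ids explicitly skips the checkbox/selector branch above, while omitting it preserves the old UI-driven behavior.

from src.ui.evaluation import run_evaluation  # assumed import path

# Script-driven run: the explicit tuple bypasses all_datasets_checkbox / sel_dataset.
run_evaluation(session_id=12345, project_id=41021, dataset_ids=(101, 102))

# UI-driven run: dataset_ids stays None, so the widgets decide what gets evaluated.
run_evaluation(session_id=12345, project_id=41021)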
