From d23d1199e83504ab83cf186505748830959a03a9 Mon Sep 17 00:00:00 2001 From: almaz Date: Mon, 16 Dec 2024 15:25:41 +0100 Subject: [PATCH] add dataset_ids argument (when running from script) --- local.env | 2 +- src/main.py | 5 ++++- src/ui/evaluation.py | 17 +++++++++-------- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/local.env b/local.env index abda3fa..66cb857 100644 --- a/local.env +++ b/local.env @@ -3,5 +3,5 @@ WORKSPACE_ID = 680 # PROJECT_ID = 41021 SLY_APP_DATA_DIR = "APP_DATA" -TASK_ID = 68257 +TASK_ID = 68088 # modal.state.sessionId=66693 \ No newline at end of file diff --git a/src/main.py b/src/main.py index a5b45fc..0d3f5cb 100644 --- a/src/main.py +++ b/src/main.py @@ -41,7 +41,10 @@ async def evaluate(request: Request): req = await request.json() try: state = req["state"] - return {"data": run_evaluation(state["session_id"], state["project_id"])} + session_id = state["session_id"] + project_id = state["project_id"] + dataset_ids = state.get("dataset_ids", None) + return {"data": run_evaluation(session_id, project_id, dataset_ids=dataset_ids)} except Exception as e: sly.logger.error(f"Error during model evaluation: {e}") return {"error": str(e)} diff --git a/src/ui/evaluation.py b/src/ui/evaluation.py index 0390eb8..986ae0e 100644 --- a/src/ui/evaluation.py +++ b/src/ui/evaluation.py @@ -113,6 +113,7 @@ def run_evaluation( session_id: Optional[int] = None, project_id: Optional[int] = None, params: Optional[Union[str, Dict]] = None, + dataset_ids: Optional[Tuple[int]] = None, ): work_dir = g.STORAGE_DIR + "/benchmark_" + sly.rand_str(6) @@ -124,12 +125,13 @@ def run_evaluation( g.session = SessionJSON(g.api, g.session_id) task_type = g.session.get_deploy_info()["task_type"] - if all_datasets_checkbox.is_checked(): - dataset_ids = None - else: - dataset_ids = sel_dataset.get_selected_ids() - if len(dataset_ids) == 0: - raise ValueError("No datasets selected") + if dataset_ids is None: + if all_datasets_checkbox.is_checked(): + dataset_ids = None + else: + dataset_ids = sel_dataset.get_selected_ids() + if len(dataset_ids) == 0: + raise ValueError("No datasets selected") # ==================== Workflow input ==================== w.workflow_input(g.api, project, g.session_id) @@ -239,8 +241,7 @@ def set_selected_classes_and_show_info(): def update_eval_params(): - if g.session is None: - g.session = SessionJSON(g.api, g.session_id) + g.session = SessionJSON(g.api, g.session_id) task_type = g.session.get_deploy_info()["task_type"] if task_type == sly.nn.TaskType.OBJECT_DETECTION: params = ObjectDetectionEvaluator.load_yaml_evaluation_params()