diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index d9189002e5ad..bd534565049a 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -142,7 +142,11 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: # Pull TaskRes task_res_list = self.state.get_task_res(task_ids=msg_ids) # Delete tasks in state - self.state.delete_tasks(msg_ids) + # Delete the TaskIns/TaskRes pairs if TaskRes is found + task_ins_ids_to_delete = { + UUID(task_res.task.ancestry[0]) for task_res in task_res_list + } + self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete) # Convert TaskRes to Message msgs = [message_from_taskres(taskres) for taskres in task_res_list] return msgs diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index a3d06ac222de..f52129a2ba11 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -190,27 +190,15 @@ def PullTaskRes( # Convert each task_id str to UUID task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids} - # Register callback - def on_rpc_done() -> None: - log( - DEBUG, - "ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes", - ) - - if context.is_active(): - return - if context.code() != grpc.StatusCode.OK: - return - - # Delete delivered TaskIns and TaskRes - state.delete_tasks(task_ids=task_ids) - - context.add_callback(on_rpc_done) - # Read from state task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids) - context.set_code(grpc.StatusCode.OK) + # Delete the TaskIns/TaskRes pairs if TaskRes is found + task_ins_ids_to_delete = { + UUID(task_res.task.ancestry[0]) for task_res in task_res_list + } + state.delete_tasks(task_ins_ids=task_ins_ids_to_delete) + return PullTaskResResponse(task_res_list=task_res_list) def GetRun( diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index 6d719a7dd377..f26bb11a4bdb 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -265,33 +265,22 @@ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]: for task_res in task_res_found: task_res.task.delivered_at = delivered_at - # Cleanup - self._force_delete_tasks_by_ids(set(ret.keys())) - return list(ret.values()) - def delete_tasks(self, task_ids: set[UUID]) -> None: - """Delete all delivered TaskIns/TaskRes pairs.""" - task_ins_to_be_deleted: set[UUID] = set() - task_res_to_be_deleted: set[UUID] = set() + def delete_tasks(self, task_ins_ids: set[UUID]) -> None: + """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.""" + if not task_ins_ids: + return with self.lock: - for task_ins_id in task_ids: - # Find the task_id of the matching task_res - for task_res_id, task_res in self.task_res_store.items(): - if UUID(task_res.task.ancestry[0]) != task_ins_id: - continue - if task_res.task.delivered_at == "": - continue - - task_ins_to_be_deleted.add(task_ins_id) - task_res_to_be_deleted.add(task_res_id) - - for task_id in task_ins_to_be_deleted: - del self.task_ins_store[task_id] - del self.task_ins_id_to_task_res_id[task_id] - for task_id in task_res_to_be_deleted: - del self.task_res_store[task_id] + for task_id in task_ins_ids: + # Delete TaskIns + if task_id in self.task_ins_store: + del self.task_ins_store[task_id] + # Delete TaskRes + if task_id in self.task_ins_id_to_task_res_id: + task_res_id = self.task_ins_id_to_task_res_id.pop(task_id) + del self.task_res_store[task_res_id] def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: """Get all TaskIns IDs for the given run_id.""" @@ -303,21 +292,6 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: return task_id_list - def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None: - """Delete tasks based on a set of TaskIns IDs.""" - if not task_ids: - return - - with self.lock: - for task_id in task_ids: - # Delete TaskIns - if task_id in self.task_ins_store: - del self.task_ins_store[task_id] - # Delete TaskRes - if task_id in self.task_ins_id_to_task_res_id: - task_res_id = self.task_ins_id_to_task_res_id.pop(task_id) - del self.task_res_store[task_res_id] - def num_task_ins(self) -> int: """Calculate the number of task_ins in store. diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index c5ab7efa8cf2..ae9d1710f069 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -139,8 +139,15 @@ def num_task_res(self) -> int: """ @abc.abstractmethod - def delete_tasks(self, task_ids: set[UUID]) -> None: - """Delete all delivered TaskIns/TaskRes pairs.""" + def delete_tasks(self, task_ins_ids: set[UUID]) -> None: + """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs. + + Parameters + ---------- + task_ins_ids : set[UUID] + A set of TaskIns IDs. For each ID in the set, the corresponding + TaskIns and its associated TaskRes will be deleted. + """ @abc.abstractmethod def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 15b97ee1a0a1..93f5d94daef7 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -349,10 +349,21 @@ def test_store_and_delete_tasks(self) -> None: # Situation now: # - State has three TaskIns, all of them delivered # - State has two TaskRes, one of the delivered, the other not + assert state.num_task_ins() == 3 + assert state.num_task_res() == 2 + state.delete_tasks({task_id_0}) assert state.num_task_ins() == 2 assert state.num_task_res() == 1 + state.delete_tasks({task_id_1}) + assert state.num_task_ins() == 1 + assert state.num_task_res() == 0 + + state.delete_tasks({task_id_2}) + assert state.num_task_ins() == 0 + assert state.num_task_res() == 0 + def test_get_task_ids_from_run_id(self) -> None: """Test get_task_ids_from_run_id.""" # Prepare @@ -993,7 +1004,8 @@ def test_get_task_res_expired_task_ins(self) -> None: # Assert assert len(task_res_list) == 1 assert task_res_list[0].task.HasField("error") - assert state.num_task_ins() == state.num_task_res() == 0 + assert state.num_task_ins() == 1 + assert state.num_task_res() == 0 def test_get_task_res_returns_empty_for_missing_taskins(self) -> None: """Test that get_task_res returns an empty result when the corresponding TaskIns diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 8e4043582d14..bdb3bcf6c4db 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -566,9 +566,6 @@ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]: data: list[Any] = [delivered_at] + task_res_ids self.query(query, data) - # Cleanup - self._force_delete_tasks_by_ids(set(ret.keys())) - return list(ret.values()) def num_task_ins(self) -> int: @@ -592,43 +589,32 @@ def num_task_res(self) -> int: result: dict[str, int] = rows[0] return result["num"] - def delete_tasks(self, task_ids: set[UUID]) -> None: - """Delete all delivered TaskIns/TaskRes pairs.""" - ids = list(task_ids) - if len(ids) == 0: - return None + def delete_tasks(self, task_ins_ids: set[UUID]) -> None: + """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.""" + if not task_ins_ids: + return + if self.conn is None: + raise AttributeError("LinkState not initialized") - placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))]) - data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)} + placeholders = ",".join(["?"] * len(task_ins_ids)) + data = tuple(str(task_id) for task_id in task_ins_ids) - # 1. Query: Delete task_ins which have a delivered task_res + # Delete task_ins query_1 = f""" DELETE FROM task_ins - WHERE delivered_at != '' - AND task_id IN ( - SELECT ancestry - FROM task_res - WHERE ancestry IN ({placeholders}) - AND delivered_at != '' - ); + WHERE task_id IN ({placeholders}); """ - # 2. Query: Delete delivered task_res to be run after 1. Query + # Delete task_res query_2 = f""" DELETE FROM task_res - WHERE ancestry IN ({placeholders}) - AND delivered_at != ''; + WHERE ancestry IN ({placeholders}); """ - if self.conn is None: - raise AttributeError("LinkState not intitialized") - with self.conn: self.conn.execute(query_1, data) self.conn.execute(query_2, data) - return None - def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: """Get all TaskIns IDs for the given run_id.""" if self.conn is None: @@ -648,32 +634,6 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: return {UUID(row["task_id"]) for row in rows} - def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None: - """Delete tasks based on a set of TaskIns IDs.""" - if not task_ids: - return - if self.conn is None: - raise AttributeError("LinkState not initialized") - - placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))]) - data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)} - - # Delete task_ins - query_1 = f""" - DELETE FROM task_ins - WHERE task_id IN ({placeholders}); - """ - - # Delete task_res - query_2 = f""" - DELETE FROM task_res - WHERE ancestry IN ({placeholders}); - """ - - with self.conn: - self.conn.execute(query_1, data) - self.conn.execute(query_2, data) - def create_node( self, ping_interval: float, public_key: Optional[bytes] = None ) -> int: