Skip to content

Commit

Permalink
refactor(framework) Make LinkState.delete_tasks method delete all provided tasks without conditions (#4691)
Browse files Browse the repository at this point in the history

Co-authored-by: Chong Shen Ng <[email protected]>
  • Loading branch information
panh99 and chongshenng authored Dec 13, 2024
1 parent 22c1ca3 commit 5c27a00
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 112 deletions.
6 changes: 5 additions & 1 deletion src/py/flwr/server/driver/inmemory_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
# Pull TaskRes
task_res_list = self.state.get_task_res(task_ids=msg_ids)
# Delete tasks in state
self.state.delete_tasks(msg_ids)
# Delete the TaskIns/TaskRes pairs if TaskRes is found
task_ins_ids_to_delete = {
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
}
self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
# Convert TaskRes to Message
msgs = [message_from_taskres(taskres) for taskres in task_res_list]
return msgs
Expand Down
24 changes: 6 additions & 18 deletions src/py/flwr/server/superlink/driver/serverappio_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,27 +190,15 @@ def PullTaskRes(
# Convert each task_id str to UUID
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}

# Register callback
def on_rpc_done() -> None:
log(
DEBUG,
"ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
)

if context.is_active():
return
if context.code() != grpc.StatusCode.OK:
return

# Delete delivered TaskIns and TaskRes
state.delete_tasks(task_ids=task_ids)

context.add_callback(on_rpc_done)

# Read from state
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)

context.set_code(grpc.StatusCode.OK)
# Delete the TaskIns/TaskRes pairs if TaskRes is found
task_ins_ids_to_delete = {
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
}
state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)

return PullTaskResResponse(task_res_list=task_res_list)

def GetRun(
Expand Down
50 changes: 12 additions & 38 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,33 +265,22 @@ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
for task_res in task_res_found:
task_res.task.delivered_at = delivered_at

# Cleanup
self._force_delete_tasks_by_ids(set(ret.keys()))

return list(ret.values())

def delete_tasks(self, task_ids: set[UUID]) -> None:
"""Delete all delivered TaskIns/TaskRes pairs."""
task_ins_to_be_deleted: set[UUID] = set()
task_res_to_be_deleted: set[UUID] = set()
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
if not task_ins_ids:
return

with self.lock:
for task_ins_id in task_ids:
# Find the task_id of the matching task_res
for task_res_id, task_res in self.task_res_store.items():
if UUID(task_res.task.ancestry[0]) != task_ins_id:
continue
if task_res.task.delivered_at == "":
continue

task_ins_to_be_deleted.add(task_ins_id)
task_res_to_be_deleted.add(task_res_id)

for task_id in task_ins_to_be_deleted:
del self.task_ins_store[task_id]
del self.task_ins_id_to_task_res_id[task_id]
for task_id in task_res_to_be_deleted:
del self.task_res_store[task_id]
for task_id in task_ins_ids:
# Delete TaskIns
if task_id in self.task_ins_store:
del self.task_ins_store[task_id]
# Delete TaskRes
if task_id in self.task_ins_id_to_task_res_id:
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
del self.task_res_store[task_res_id]

def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
"""Get all TaskIns IDs for the given run_id."""
Expand All @@ -303,21 +292,6 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:

return task_id_list

def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
"""Delete tasks based on a set of TaskIns IDs."""
if not task_ids:
return

with self.lock:
for task_id in task_ids:
# Delete TaskIns
if task_id in self.task_ins_store:
del self.task_ins_store[task_id]
# Delete TaskRes
if task_id in self.task_ins_id_to_task_res_id:
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
del self.task_res_store[task_res_id]

def num_task_ins(self) -> int:
"""Calculate the number of task_ins in store.
Expand Down
11 changes: 9 additions & 2 deletions src/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,15 @@ def num_task_res(self) -> int:
"""

@abc.abstractmethod
def delete_tasks(self, task_ids: set[UUID]) -> None:
"""Delete all delivered TaskIns/TaskRes pairs."""
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.

Parameters
----------
task_ins_ids : set[UUID]
A set of TaskIns IDs. For each ID in the set, the corresponding
TaskIns and its associated TaskRes will be deleted.
"""

@abc.abstractmethod
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
Expand Down
14 changes: 13 additions & 1 deletion src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,21 @@ def test_store_and_delete_tasks(self) -> None:
# Situation now:
# - State has three TaskIns, all of them delivered
# - State has two TaskRes, one of them delivered, the other not
assert state.num_task_ins() == 3
assert state.num_task_res() == 2

state.delete_tasks({task_id_0})
assert state.num_task_ins() == 2
assert state.num_task_res() == 1

state.delete_tasks({task_id_1})
assert state.num_task_ins() == 1
assert state.num_task_res() == 0

state.delete_tasks({task_id_2})
assert state.num_task_ins() == 0
assert state.num_task_res() == 0

def test_get_task_ids_from_run_id(self) -> None:
"""Test get_task_ids_from_run_id."""
# Prepare
Expand Down Expand Up @@ -993,7 +1004,8 @@ def test_get_task_res_expired_task_ins(self) -> None:
# Assert
assert len(task_res_list) == 1
assert task_res_list[0].task.HasField("error")
assert state.num_task_ins() == state.num_task_res() == 0
assert state.num_task_ins() == 1
assert state.num_task_res() == 0

def test_get_task_res_returns_empty_for_missing_taskins(self) -> None:
"""Test that get_task_res returns an empty result when the corresponding TaskIns
Expand Down
64 changes: 12 additions & 52 deletions src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,9 +566,6 @@ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
data: list[Any] = [delivered_at] + task_res_ids
self.query(query, data)

# Cleanup
self._force_delete_tasks_by_ids(set(ret.keys()))

return list(ret.values())

def num_task_ins(self) -> int:
Expand All @@ -592,43 +589,32 @@ def num_task_res(self) -> int:
result: dict[str, int] = rows[0]
return result["num"]

def delete_tasks(self, task_ids: set[UUID]) -> None:
"""Delete all delivered TaskIns/TaskRes pairs."""
ids = list(task_ids)
if len(ids) == 0:
return None
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
if not task_ins_ids:
return
if self.conn is None:
raise AttributeError("LinkState not initialized")

placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
placeholders = ",".join(["?"] * len(task_ins_ids))
data = tuple(str(task_id) for task_id in task_ins_ids)

# 1. Query: Delete task_ins which have a delivered task_res
# Delete task_ins
query_1 = f"""
DELETE FROM task_ins
WHERE delivered_at != ''
AND task_id IN (
SELECT ancestry
FROM task_res
WHERE ancestry IN ({placeholders})
AND delivered_at != ''
);
WHERE task_id IN ({placeholders});
"""

# 2. Query: Delete delivered task_res to be run after 1. Query
# Delete task_res
query_2 = f"""
DELETE FROM task_res
WHERE ancestry IN ({placeholders})
AND delivered_at != '';
WHERE ancestry IN ({placeholders});
"""

if self.conn is None:
raise AttributeError("LinkState not intitialized")

with self.conn:
self.conn.execute(query_1, data)
self.conn.execute(query_2, data)

return None

def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
"""Get all TaskIns IDs for the given run_id."""
if self.conn is None:
Expand All @@ -648,32 +634,6 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:

return {UUID(row["task_id"]) for row in rows}

def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
"""Delete tasks based on a set of TaskIns IDs."""
if not task_ids:
return
if self.conn is None:
raise AttributeError("LinkState not initialized")

placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}

# Delete task_ins
query_1 = f"""
DELETE FROM task_ins
WHERE task_id IN ({placeholders});
"""

# Delete task_res
query_2 = f"""
DELETE FROM task_res
WHERE ancestry IN ({placeholders});
"""

with self.conn:
self.conn.execute(query_1, data)
self.conn.execute(query_2, data)

def create_node(
self, ping_interval: float, public_key: Optional[bytes] = None
) -> int:
Expand Down

0 comments on commit 5c27a00

Please sign in to comment.