Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Oct 24, 2023
1 parent 1cb1c50 commit 428eef6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 67 deletions.
56 changes: 19 additions & 37 deletions src/py/flwr/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


from logging import ERROR, INFO, WARNING
from typing import Iterable, List, Optional, Set
from typing import Iterable, List, Optional, Tuple

import grpc

Expand Down Expand Up @@ -136,65 +136,47 @@ class Driver:
def __init__(self) -> None:
self.grpc_driver: Optional[GrpcDriver] = None
self.workload_id: Optional[int] = None
self.task_id_pool: Set[str] = set()
self.node = Node(node_id=0, anonymous=True)

def _check_and_init_grpc_driver(self) -> None:
def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]:
# Check if the GrpcDriver is initialized
if self.grpc_driver is not None:
return
if self.grpc_driver is None or self.workload_id is None:
# Connect and create workload
self.grpc_driver = GrpcDriver()
self.grpc_driver.connect()
res = self.grpc_driver.create_workload(CreateWorkloadRequest())
self.workload_id = res.workload_id

# Connect and create workload
self.grpc_driver = GrpcDriver()
self.grpc_driver.connect()
res = self.grpc_driver.create_workload(CreateWorkloadRequest())
self.workload_id = res.workload_id
return self.grpc_driver, self.workload_id

def get_nodes(self) -> List[Node]:
"""Get node IDs."""
self._check_and_init_grpc_driver()
grpc_driver, workload_id = self._get_grpc_driver_and_workload_id()

# Call GrpcDriver method
res = self.grpc_driver.get_nodes( # type: ignore
GetNodesRequest(workload_id=self.workload_id) # type: ignore
)
res = grpc_driver.get_nodes(GetNodesRequest(workload_id=workload_id))
return list(res.nodes)

def push_task_ins(self, task_ins_list: Iterable[TaskIns]) -> List[str]:
def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]:
"""Schedule tasks."""
self._check_and_init_grpc_driver()
grpc_driver, workload_id = self._get_grpc_driver_and_workload_id()

# Set workload_id
for task_ins in task_ins_list:
task_ins.workload_id = self.workload_id # type: ignore
task_ins.workload_id = workload_id

# Call GrpcDriver method
res = self.grpc_driver.push_task_ins( # type: ignore
PushTaskInsRequest(task_ins_list=task_ins_list)
)

# Cache received task_ids
self.task_id_pool.update(res.task_ids)
res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
return list(res.task_ids)

def pull_task_res(self, task_ids: Optional[Iterable[str]] = None) -> List[TaskRes]:
"""Get task results.
Retrieve all task results if `task_ids` is None.
"""
self._check_and_init_grpc_driver()

# Check if task_ids is None
if task_ids is None:
task_ids = list(self.task_id_pool)
def pull_task_res(self, task_ids: Optional[Iterable[str]]) -> List[TaskRes]:
"""Get task results."""
grpc_driver, _ = self._get_grpc_driver_and_workload_id()

# Call GrpcDriver method
res = self.grpc_driver.pull_task_res( # type: ignore
res = grpc_driver.pull_task_res(
PullTaskResRequest(node=self.node, task_ids=task_ids)
)
self.task_id_pool.difference_update(
[task_res.task.ancestry[0] for task_res in res.task_res_list]
)
return list(res.task_res_list)

def __del__(self) -> None:
Expand Down
34 changes: 4 additions & 30 deletions src/py/flwr/driver/driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ def test_check_and_init_grpc_driver_already_initialized(self) -> None:
"""Test that GrpcDriver doesn't initialize if workload is created."""
# Prepare
self.driver.grpc_driver = self.mock_grpc_driver
self.driver.workload_id = 61016

# Execute
# pylint: disable-next=protected-access
self.driver._check_and_init_grpc_driver()
self.driver._get_grpc_driver_and_workload_id()

# Assert
self.mock_grpc_driver.connect.assert_not_called()
Expand All @@ -62,7 +63,7 @@ def test_check_and_init_grpc_driver_needs_initialization(self) -> None:
"""Test GrpcDriver initialization when workload is not created."""
# Execute
# pylint: disable-next=protected-access
self.driver._check_and_init_grpc_driver()
self.driver._get_grpc_driver_and_workload_id()

# Assert
self.mock_grpc_driver.connect.assert_called_once()
Expand Down Expand Up @@ -111,7 +112,6 @@ def test_push_task_ins(self) -> None:
def test_pull_task_res_with_given_task_ids(self) -> None:
"""Test pulling task results with specific task IDs."""
# Prepare
self.driver.task_id_pool = {"id1", "id2", "id3", "id4"}
mock_response = Mock()
mock_response.task_res_list = [
TaskRes(task=Task(ancestry=["id2"])),
Expand All @@ -131,38 +131,12 @@ def test_pull_task_res_with_given_task_ids(self) -> None:
self.assertIsInstance(args[0], PullTaskResRequest)
self.assertEqual(args[0].task_ids, task_ids)
self.assertEqual(task_res_list, mock_response.task_res_list)
self.assertEqual(self.driver.task_id_pool, {"id1", "id4"})

def test_pull_task_res_without_given_task_ids(self) -> None:
"""Test pulling all task results when no task IDs are provided."""
# Prepare
mock_response = Mock()
mock_response.task_res_list = [
TaskRes(task=Task(ancestry=["id1"])),
TaskRes(task=Task(ancestry=["id2"])),
]
self.mock_grpc_driver.pull_task_res.return_value = mock_response
self.driver.task_id_pool = {"id1", "id2", "id3"}
task_ids = {"id1", "id2", "id3"}

# Execute
task_res_list = self.driver.pull_task_res()
args, kwargs = self.mock_grpc_driver.pull_task_res.call_args

# Assert
self.mock_grpc_driver.connect.assert_called_once()
self.assertEqual(len(args), 1)
self.assertEqual(len(kwargs), 0)
self.assertIsInstance(args[0], PullTaskResRequest)
self.assertEqual(set(args[0].task_ids), task_ids)
self.assertEqual(task_res_list, mock_response.task_res_list)
self.assertEqual(self.driver.task_id_pool, {"id3"})

def test_del_with_initialized_driver(self) -> None:
"""Test cleanup behavior when Driver is initialized."""
# Prepare
# pylint: disable-next=protected-access
self.driver._check_and_init_grpc_driver()
self.driver._get_grpc_driver_and_workload_id()

# Execute
self.driver.__del__()
Expand Down

0 comments on commit 428eef6

Please sign in to comment.