diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index f692b09567b8..496a49e5721b 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -16,7 +16,7 @@ from logging import ERROR, INFO, WARNING -from typing import Iterable, List, Optional, Set +from typing import Iterable, List, Optional, Tuple import grpc @@ -136,65 +136,47 @@ class Driver: def __init__(self) -> None: self.grpc_driver: Optional[GrpcDriver] = None self.workload_id: Optional[int] = None - self.task_id_pool: Set[str] = set() self.node = Node(node_id=0, anonymous=True) - def _check_and_init_grpc_driver(self) -> None: + def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]: # Check if the GrpcDriver is initialized - if self.grpc_driver is not None: - return + if self.grpc_driver is None or self.workload_id is None: + # Connect and create workload + self.grpc_driver = GrpcDriver() + self.grpc_driver.connect() + res = self.grpc_driver.create_workload(CreateWorkloadRequest()) + self.workload_id = res.workload_id - # Connect and create workload - self.grpc_driver = GrpcDriver() - self.grpc_driver.connect() - res = self.grpc_driver.create_workload(CreateWorkloadRequest()) - self.workload_id = res.workload_id + return self.grpc_driver, self.workload_id def get_nodes(self) -> List[Node]: """Get node IDs.""" - self._check_and_init_grpc_driver() + grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() # Call GrpcDriver method - res = self.grpc_driver.get_nodes( # type: ignore - GetNodesRequest(workload_id=self.workload_id) # type: ignore - ) + res = grpc_driver.get_nodes(GetNodesRequest(workload_id=workload_id)) return list(res.nodes) - def push_task_ins(self, task_ins_list: Iterable[TaskIns]) -> List[str]: + def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]: """Schedule tasks.""" - self._check_and_init_grpc_driver() + grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() # Set workload_id for task_ins in task_ins_list: - task_ins.workload_id = self.workload_id # type: ignore + task_ins.workload_id = workload_id # Call GrpcDriver method - res = self.grpc_driver.push_task_ins( # type: ignore - PushTaskInsRequest(task_ins_list=task_ins_list) - ) - - # Cache received task_ids - self.task_id_pool.update(res.task_ids) + res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) return list(res.task_ids) - def pull_task_res(self, task_ids: Optional[Iterable[str]] = None) -> List[TaskRes]: - """Get task results. - - Retrieve all task results if `task_ids` is None. - """ - self._check_and_init_grpc_driver() - - # Check if task_ids is None - if task_ids is None: - task_ids = list(self.task_id_pool) + def pull_task_res(self, task_ids: Optional[Iterable[str]]) -> List[TaskRes]: + """Get task results.""" + grpc_driver, _ = self._get_grpc_driver_and_workload_id() # Call GrpcDriver method - res = self.grpc_driver.pull_task_res( # type: ignore + res = grpc_driver.pull_task_res( PullTaskResRequest(node=self.node, task_ids=task_ids) ) - self.task_id_pool.difference_update( - [task_res.task.ancestry[0] for task_res in res.task_res_list] - ) return list(res.task_res_list) def __del__(self) -> None: diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/driver/driver_test.py index cff068fdb546..820018788a8f 100644 --- a/src/py/flwr/driver/driver_test.py +++ b/src/py/flwr/driver/driver_test.py @@ -50,10 +50,11 @@ def test_check_and_init_grpc_driver_already_initialized(self) -> None: """Test that GrpcDriver doesn't initialize if workload is created.""" # Prepare self.driver.grpc_driver = self.mock_grpc_driver + self.driver.workload_id = 61016 # Execute # pylint: disable-next=protected-access - self.driver._check_and_init_grpc_driver() + self.driver._get_grpc_driver_and_workload_id() # Assert self.mock_grpc_driver.connect.assert_not_called() @@ -62,7 +63,7 @@ def test_check_and_init_grpc_driver_needs_initialization(self) -> None: """Test GrpcDriver initialization when workload is not created.""" # Execute # pylint: disable-next=protected-access - self.driver._check_and_init_grpc_driver() + self.driver._get_grpc_driver_and_workload_id() # Assert self.mock_grpc_driver.connect.assert_called_once() @@ -111,7 +112,6 @@ def test_push_task_ins(self) -> None: def test_pull_task_res_with_given_task_ids(self) -> None: """Test pulling task results with specific task IDs.""" # Prepare - self.driver.task_id_pool = {"id1", "id2", "id3", "id4"} mock_response = Mock() mock_response.task_res_list = [ TaskRes(task=Task(ancestry=["id2"])), @@ -131,38 +131,12 @@ def test_pull_task_res_with_given_task_ids(self) -> None: self.assertIsInstance(args[0], PullTaskResRequest) self.assertEqual(args[0].task_ids, task_ids) self.assertEqual(task_res_list, mock_response.task_res_list) - self.assertEqual(self.driver.task_id_pool, {"id1", "id4"}) - - def test_pull_task_res_without_given_task_ids(self) -> None: - """Test pulling all task results when no task IDs are provided.""" - # Prepare - mock_response = Mock() - mock_response.task_res_list = [ - TaskRes(task=Task(ancestry=["id1"])), - TaskRes(task=Task(ancestry=["id2"])), - ] - self.mock_grpc_driver.pull_task_res.return_value = mock_response - self.driver.task_id_pool = {"id1", "id2", "id3"} - task_ids = {"id1", "id2", "id3"} - - # Execute - task_res_list = self.driver.pull_task_res() - args, kwargs = self.mock_grpc_driver.pull_task_res.call_args - - # Assert - self.mock_grpc_driver.connect.assert_called_once() - self.assertEqual(len(args), 1) - self.assertEqual(len(kwargs), 0) - self.assertIsInstance(args[0], PullTaskResRequest) - self.assertEqual(set(args[0].task_ids), task_ids) - self.assertEqual(task_res_list, mock_response.task_res_list) - self.assertEqual(self.driver.task_id_pool, {"id3"}) def test_del_with_initialized_driver(self) -> None: """Test cleanup behavior when Driver is initialized.""" # Prepare # pylint: disable-next=protected-access - self.driver._check_and_init_grpc_driver() + self.driver._get_grpc_driver_and_workload_id() # Execute self.driver.__del__()