From dd4779f2919aa9ce64b9f60f0ce9c7d699821621 Mon Sep 17 00:00:00 2001 From: Heng Pan <134433891+panh99@users.noreply.github.com> Date: Tue, 24 Oct 2023 22:04:30 +0100 Subject: [PATCH] Add `Driver` class (#2531) Co-authored-by: Daniel J. Beutel --- src/py/flwr/driver/driver.py | 65 +++++++++++++- src/py/flwr/driver/driver_test.py | 138 ++++++++++++++++++++++++++++++ 2 files changed, 201 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index 4b189b9ce290..3fb4ac346ccf 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -16,7 +16,7 @@ from logging import ERROR, INFO, WARNING -from typing import Optional +from typing import Iterable, List, Optional, Tuple import grpc @@ -34,6 +34,8 @@ PushTaskInsResponse, ) from flwr.proto.driver_pb2_grpc import DriverStub +from flwr.proto.node_pb2 import Node +from flwr.proto.task_pb2 import TaskIns, TaskRes DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -46,7 +48,7 @@ class GrpcDriver: - """`GrpcDriver` provides access to the Driver API/service.""" + """`GrpcDriver` provides access to the gRPC Driver API/service.""" def __init__( self, @@ -126,3 +128,62 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: # Call Driver API res: PullTaskResResponse = self.stub.PullTaskRes(request=req) return res + + +class Driver: + """`Driver` class provides an interface to the Driver API.""" + + def __init__(self) -> None: + self.grpc_driver: Optional[GrpcDriver] = None + self.workload_id: Optional[int] = None + self.node = Node(node_id=0, anonymous=True) + + def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]: + # Check if the GrpcDriver is initialized + if self.grpc_driver is None or self.workload_id is None: + # Connect and create workload + self.grpc_driver = GrpcDriver() + self.grpc_driver.connect() + res = self.grpc_driver.create_workload(CreateWorkloadRequest()) + self.workload_id = res.workload_id + + return self.grpc_driver, self.workload_id + + def get_nodes(self) -> List[Node]: + """Get node IDs.""" + grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() + + # Call GrpcDriver method + res = grpc_driver.get_nodes(GetNodesRequest(workload_id=workload_id)) + return list(res.nodes) + + def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]: + """Schedule tasks.""" + grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() + + # Set workload_id + for task_ins in task_ins_list: + task_ins.workload_id = workload_id + + # Call GrpcDriver method + res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) + return list(res.task_ids) + + def pull_task_res(self, task_ids: Iterable[str]) -> List[TaskRes]: + """Get task results.""" + grpc_driver, _ = self._get_grpc_driver_and_workload_id() + + # Call GrpcDriver method + res = grpc_driver.pull_task_res( + PullTaskResRequest(node=self.node, task_ids=task_ids) + ) + return list(res.task_res_list) + + def __del__(self) -> None: + """Disconnect GrpcDriver if connected.""" + # Check if GrpcDriver is initialized + if self.grpc_driver is None: + return + + # Disconnect + self.grpc_driver.disconnect() diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/driver/driver_test.py index ef2a17e8538d..820018788a8f 100644 --- a/src/py/flwr/driver/driver_test.py +++ b/src/py/flwr/driver/driver_test.py @@ -13,3 +13,141 @@ # limitations under the License. # ============================================================================== """Tests for driver SDK.""" + + +import unittest +from unittest.mock import Mock, patch + +from flwr.driver.driver import Driver +from flwr.proto.driver_pb2 import ( + GetNodesRequest, + PullTaskResRequest, + PushTaskInsRequest, +) +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes + + +class TestDriver(unittest.TestCase): + """Tests for `Driver` class.""" + + def setUp(self) -> None: + """Initialize mock GrpcDriver and Driver instance before each test.""" + mock_response = Mock() + mock_response.workload_id = 61016 + self.mock_grpc_driver = Mock() + self.mock_grpc_driver.create_workload.return_value = mock_response + self.patcher = patch( + "flwr.driver.driver.GrpcDriver", return_value=self.mock_grpc_driver + ) + self.patcher.start() + self.driver = Driver() + + def tearDown(self) -> None: + """Cleanup after each test.""" + self.patcher.stop() + + def test_check_and_init_grpc_driver_already_initialized(self) -> None: + """Test that GrpcDriver doesn't initialize if workload is created.""" + # Prepare + self.driver.grpc_driver = self.mock_grpc_driver + self.driver.workload_id = 61016 + + # Execute + # pylint: disable-next=protected-access + self.driver._get_grpc_driver_and_workload_id() + + # Assert + self.mock_grpc_driver.connect.assert_not_called() + + def test_check_and_init_grpc_driver_needs_initialization(self) -> None: + """Test GrpcDriver initialization when workload is not created.""" + # Execute + # pylint: disable-next=protected-access + self.driver._get_grpc_driver_and_workload_id() + + # Assert + self.mock_grpc_driver.connect.assert_called_once() + self.assertEqual(self.driver.workload_id, 61016) + + def test_get_nodes(self) -> None: + """Test retrieval of nodes.""" + # Prepare + mock_response = Mock() + mock_response.nodes = [Mock(), Mock()] + self.mock_grpc_driver.get_nodes.return_value = mock_response + + # Execute + nodes = self.driver.get_nodes() + args, kwargs = self.mock_grpc_driver.get_nodes.call_args + + # Assert + self.mock_grpc_driver.connect.assert_called_once() + self.assertEqual(len(args), 1) + self.assertEqual(len(kwargs), 0) + self.assertIsInstance(args[0], GetNodesRequest) + self.assertEqual(args[0].workload_id, 61016) + self.assertEqual(nodes, mock_response.nodes) + + def test_push_task_ins(self) -> None: + """Test pushing task instructions.""" + # Prepare + mock_response = Mock() + mock_response.task_ids = ["id1", "id2"] + self.mock_grpc_driver.push_task_ins.return_value = mock_response + task_ins_list = [TaskIns(), TaskIns()] + + # Execute + task_ids = self.driver.push_task_ins(task_ins_list) + args, kwargs = self.mock_grpc_driver.push_task_ins.call_args + + # Assert + self.mock_grpc_driver.connect.assert_called_once() + self.assertEqual(len(args), 1) + self.assertEqual(len(kwargs), 0) + self.assertIsInstance(args[0], PushTaskInsRequest) + self.assertEqual(task_ids, mock_response.task_ids) + for task_ins in args[0].task_ins_list: + self.assertEqual(task_ins.workload_id, 61016) + + def test_pull_task_res_with_given_task_ids(self) -> None: + """Test pulling task results with specific task IDs.""" + # Prepare + mock_response = Mock() + mock_response.task_res_list = [ + TaskRes(task=Task(ancestry=["id2"])), + TaskRes(task=Task(ancestry=["id3"])), + ] + self.mock_grpc_driver.pull_task_res.return_value = mock_response + task_ids = ["id1", "id2", "id3"] + + # Execute + task_res_list = self.driver.pull_task_res(task_ids) + args, kwargs = self.mock_grpc_driver.pull_task_res.call_args + + # Assert + self.mock_grpc_driver.connect.assert_called_once() + self.assertEqual(len(args), 1) + self.assertEqual(len(kwargs), 0) + self.assertIsInstance(args[0], PullTaskResRequest) + self.assertEqual(args[0].task_ids, task_ids) + self.assertEqual(task_res_list, mock_response.task_res_list) + + def test_del_with_initialized_driver(self) -> None: + """Test cleanup behavior when Driver is initialized.""" + # Prepare + # pylint: disable-next=protected-access + self.driver._get_grpc_driver_and_workload_id() + + # Execute + self.driver.__del__() + + # Assert + self.mock_grpc_driver.disconnect.assert_called_once() + + def test_del_with_uninitialized_driver(self) -> None: + """Test cleanup behavior when Driver is not initialized.""" + # Execute + self.driver.__del__() + + # Assert + self.mock_grpc_driver.disconnect.assert_not_called()