Skip to content

Commit

Permalink
Add Driver class (#2531)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
panh99 and danieljanes authored Oct 24, 2023
1 parent f61942f commit dd4779f
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 2 deletions.
65 changes: 63 additions & 2 deletions src/py/flwr/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


from logging import ERROR, INFO, WARNING
from typing import Optional
from typing import Iterable, List, Optional, Tuple

import grpc

Expand All @@ -34,6 +34,8 @@
PushTaskInsResponse,
)
from flwr.proto.driver_pb2_grpc import DriverStub
from flwr.proto.node_pb2 import Node
from flwr.proto.task_pb2 import TaskIns, TaskRes

DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"

Expand All @@ -46,7 +48,7 @@


class GrpcDriver:
"""`GrpcDriver` provides access to the Driver API/service."""
"""`GrpcDriver` provides access to the gRPC Driver API/service."""

def __init__(
self,
Expand Down Expand Up @@ -126,3 +128,62 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
# Call Driver API
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
return res


class Driver:
"""`Driver` class provides an interface to the Driver API."""

def __init__(self) -> None:
self.grpc_driver: Optional[GrpcDriver] = None
self.workload_id: Optional[int] = None
self.node = Node(node_id=0, anonymous=True)

def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]:
# Check if the GrpcDriver is initialized
if self.grpc_driver is None or self.workload_id is None:
# Connect and create workload
self.grpc_driver = GrpcDriver()
self.grpc_driver.connect()
res = self.grpc_driver.create_workload(CreateWorkloadRequest())
self.workload_id = res.workload_id

return self.grpc_driver, self.workload_id

def get_nodes(self) -> List[Node]:
"""Get node IDs."""
grpc_driver, workload_id = self._get_grpc_driver_and_workload_id()

# Call GrpcDriver method
res = grpc_driver.get_nodes(GetNodesRequest(workload_id=workload_id))
return list(res.nodes)

def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]:
"""Schedule tasks."""
grpc_driver, workload_id = self._get_grpc_driver_and_workload_id()

# Set workload_id
for task_ins in task_ins_list:
task_ins.workload_id = workload_id

# Call GrpcDriver method
res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
return list(res.task_ids)

def pull_task_res(self, task_ids: Iterable[str]) -> List[TaskRes]:
"""Get task results."""
grpc_driver, _ = self._get_grpc_driver_and_workload_id()

# Call GrpcDriver method
res = grpc_driver.pull_task_res(
PullTaskResRequest(node=self.node, task_ids=task_ids)
)
return list(res.task_res_list)

def __del__(self) -> None:
"""Disconnect GrpcDriver if connected."""
# Check if GrpcDriver is initialized
if self.grpc_driver is None:
return

# Disconnect
self.grpc_driver.disconnect()
138 changes: 138 additions & 0 deletions src/py/flwr/driver/driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,141 @@
# limitations under the License.
# ==============================================================================
"""Tests for driver SDK."""


import unittest
from unittest.mock import Mock, patch

from flwr.driver.driver import Driver
from flwr.proto.driver_pb2 import (
GetNodesRequest,
PullTaskResRequest,
PushTaskInsRequest,
)
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes


class TestDriver(unittest.TestCase):
"""Tests for `Driver` class."""

def setUp(self) -> None:
"""Initialize mock GrpcDriver and Driver instance before each test."""
mock_response = Mock()
mock_response.workload_id = 61016
self.mock_grpc_driver = Mock()
self.mock_grpc_driver.create_workload.return_value = mock_response
self.patcher = patch(
"flwr.driver.driver.GrpcDriver", return_value=self.mock_grpc_driver
)
self.patcher.start()
self.driver = Driver()

def tearDown(self) -> None:
"""Cleanup after each test."""
self.patcher.stop()

def test_check_and_init_grpc_driver_already_initialized(self) -> None:
"""Test that GrpcDriver doesn't initialize if workload is created."""
# Prepare
self.driver.grpc_driver = self.mock_grpc_driver
self.driver.workload_id = 61016

# Execute
# pylint: disable-next=protected-access
self.driver._get_grpc_driver_and_workload_id()

# Assert
self.mock_grpc_driver.connect.assert_not_called()

def test_check_and_init_grpc_driver_needs_initialization(self) -> None:
"""Test GrpcDriver initialization when workload is not created."""
# Execute
# pylint: disable-next=protected-access
self.driver._get_grpc_driver_and_workload_id()

# Assert
self.mock_grpc_driver.connect.assert_called_once()
self.assertEqual(self.driver.workload_id, 61016)

def test_get_nodes(self) -> None:
"""Test retrieval of nodes."""
# Prepare
mock_response = Mock()
mock_response.nodes = [Mock(), Mock()]
self.mock_grpc_driver.get_nodes.return_value = mock_response

# Execute
nodes = self.driver.get_nodes()
args, kwargs = self.mock_grpc_driver.get_nodes.call_args

# Assert
self.mock_grpc_driver.connect.assert_called_once()
self.assertEqual(len(args), 1)
self.assertEqual(len(kwargs), 0)
self.assertIsInstance(args[0], GetNodesRequest)
self.assertEqual(args[0].workload_id, 61016)
self.assertEqual(nodes, mock_response.nodes)

def test_push_task_ins(self) -> None:
"""Test pushing task instructions."""
# Prepare
mock_response = Mock()
mock_response.task_ids = ["id1", "id2"]
self.mock_grpc_driver.push_task_ins.return_value = mock_response
task_ins_list = [TaskIns(), TaskIns()]

# Execute
task_ids = self.driver.push_task_ins(task_ins_list)
args, kwargs = self.mock_grpc_driver.push_task_ins.call_args

# Assert
self.mock_grpc_driver.connect.assert_called_once()
self.assertEqual(len(args), 1)
self.assertEqual(len(kwargs), 0)
self.assertIsInstance(args[0], PushTaskInsRequest)
self.assertEqual(task_ids, mock_response.task_ids)
for task_ins in args[0].task_ins_list:
self.assertEqual(task_ins.workload_id, 61016)

def test_pull_task_res_with_given_task_ids(self) -> None:
"""Test pulling task results with specific task IDs."""
# Prepare
mock_response = Mock()
mock_response.task_res_list = [
TaskRes(task=Task(ancestry=["id2"])),
TaskRes(task=Task(ancestry=["id3"])),
]
self.mock_grpc_driver.pull_task_res.return_value = mock_response
task_ids = ["id1", "id2", "id3"]

# Execute
task_res_list = self.driver.pull_task_res(task_ids)
args, kwargs = self.mock_grpc_driver.pull_task_res.call_args

# Assert
self.mock_grpc_driver.connect.assert_called_once()
self.assertEqual(len(args), 1)
self.assertEqual(len(kwargs), 0)
self.assertIsInstance(args[0], PullTaskResRequest)
self.assertEqual(args[0].task_ids, task_ids)
self.assertEqual(task_res_list, mock_response.task_res_list)

def test_del_with_initialized_driver(self) -> None:
"""Test cleanup behavior when Driver is initialized."""
# Prepare
# pylint: disable-next=protected-access
self.driver._get_grpc_driver_and_workload_id()

# Execute
self.driver.__del__()

# Assert
self.mock_grpc_driver.disconnect.assert_called_once()

def test_del_with_uninitialized_driver(self) -> None:
"""Test cleanup behavior when Driver is not initialized."""
# Execute
self.driver.__del__()

# Assert
self.mock_grpc_driver.disconnect.assert_not_called()

0 comments on commit dd4779f

Please sign in to comment.