Skip to content

Commit

Permalink
Merge branch 'main' into backoff
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Oct 25, 2023
2 parents 8899ea8 + dd4779f commit c92992b
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 4 deletions.
2 changes: 1 addition & 1 deletion doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

- FedMeta [#2438](https://github.com/adap/flower/pull/2438)

- **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384)), ([#2425](https://github.com/adap/flower/pull/2425))
- **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384),[#2425](https://github.com/adap/flower/pull/2425), [#2526](https://github.com/adap/flower/pull/2526))

- **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301), [#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327), [#2435](https://github.com/adap/flower/pull/2435))

Expand Down
2 changes: 1 addition & 1 deletion examples/embedded-devices/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ This tutorial allows for a variety of settings (some shown in the diagrams above

- For Flower server: A machine running Linux/macOS/Windows (e.g. your laptop). You can run the server on an embedded device too!
- For Flower clients (one or more): Raspberry Pi 4 (or Zero 2), or an NVIDIA Jetson Xavier-NX (or Nano), or anything similar to these.
- A uSD card with 32GB or more.
- A uSD card with 32GB or more. While 32GB is enough for the RPi, a larger 64GB uSD card works best for the NVIDIA Jetson.
- Software to flash the images to a uSD card:
- For Raspberry Pi we recommend the [Raspberry Pi Imager](https://www.raspberrypi.com/software/)
- For other devices [balenaEtcher](https://www.balena.io/etcher/) it's a great option.
Expand Down
65 changes: 63 additions & 2 deletions src/py/flwr/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


from logging import ERROR, INFO, WARNING
from typing import Optional
from typing import Iterable, List, Optional, Tuple

import grpc

Expand All @@ -34,6 +34,8 @@
PushTaskInsResponse,
)
from flwr.proto.driver_pb2_grpc import DriverStub
from flwr.proto.node_pb2 import Node
from flwr.proto.task_pb2 import TaskIns, TaskRes

DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"

Expand All @@ -46,7 +48,7 @@


class GrpcDriver:
"""`GrpcDriver` provides access to the Driver API/service."""
"""`GrpcDriver` provides access to the gRPC Driver API/service."""

def __init__(
self,
Expand Down Expand Up @@ -126,3 +128,62 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
# Call Driver API
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
return res


class Driver:
"""`Driver` class provides an interface to the Driver API."""

def __init__(self) -> None:
self.grpc_driver: Optional[GrpcDriver] = None
self.workload_id: Optional[int] = None
self.node = Node(node_id=0, anonymous=True)

def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]:
# Check if the GrpcDriver is initialized
if self.grpc_driver is None or self.workload_id is None:
# Connect and create workload
self.grpc_driver = GrpcDriver()
self.grpc_driver.connect()
res = self.grpc_driver.create_workload(CreateWorkloadRequest())
self.workload_id = res.workload_id

return self.grpc_driver, self.workload_id

def get_nodes(self) -> List[Node]:
"""Get node IDs."""
grpc_driver, workload_id = self._get_grpc_driver_and_workload_id()

# Call GrpcDriver method
res = grpc_driver.get_nodes(GetNodesRequest(workload_id=workload_id))
return list(res.nodes)

def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]:
"""Schedule tasks."""
grpc_driver, workload_id = self._get_grpc_driver_and_workload_id()

# Set workload_id
for task_ins in task_ins_list:
task_ins.workload_id = workload_id

# Call GrpcDriver method
res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
return list(res.task_ids)

def pull_task_res(self, task_ids: Iterable[str]) -> List[TaskRes]:
"""Get task results."""
grpc_driver, _ = self._get_grpc_driver_and_workload_id()

# Call GrpcDriver method
res = grpc_driver.pull_task_res(
PullTaskResRequest(node=self.node, task_ids=task_ids)
)
return list(res.task_res_list)

def __del__(self) -> None:
"""Disconnect GrpcDriver if connected."""
# Check if GrpcDriver is initialized
if self.grpc_driver is None:
return

# Disconnect
self.grpc_driver.disconnect()
138 changes: 138 additions & 0 deletions src/py/flwr/driver/driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,141 @@
# limitations under the License.
# ==============================================================================
"""Tests for driver SDK."""


import unittest
from unittest.mock import Mock, patch

from flwr.driver.driver import Driver
from flwr.proto.driver_pb2 import (
GetNodesRequest,
PullTaskResRequest,
PushTaskInsRequest,
)
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes


class TestDriver(unittest.TestCase):
"""Tests for `Driver` class."""

def setUp(self) -> None:
"""Initialize mock GrpcDriver and Driver instance before each test."""
mock_response = Mock()
mock_response.workload_id = 61016
self.mock_grpc_driver = Mock()
self.mock_grpc_driver.create_workload.return_value = mock_response
self.patcher = patch(
"flwr.driver.driver.GrpcDriver", return_value=self.mock_grpc_driver
)
self.patcher.start()
self.driver = Driver()

def tearDown(self) -> None:
"""Cleanup after each test."""
self.patcher.stop()

def test_check_and_init_grpc_driver_already_initialized(self) -> None:
"""Test that GrpcDriver doesn't initialize if workload is created."""
# Prepare
self.driver.grpc_driver = self.mock_grpc_driver
self.driver.workload_id = 61016

# Execute
# pylint: disable-next=protected-access
self.driver._get_grpc_driver_and_workload_id()

# Assert
self.mock_grpc_driver.connect.assert_not_called()

def test_check_and_init_grpc_driver_needs_initialization(self) -> None:
"""Test GrpcDriver initialization when workload is not created."""
# Execute
# pylint: disable-next=protected-access
self.driver._get_grpc_driver_and_workload_id()

# Assert
self.mock_grpc_driver.connect.assert_called_once()
self.assertEqual(self.driver.workload_id, 61016)

def test_get_nodes(self) -> None:
"""Test retrieval of nodes."""
# Prepare
mock_response = Mock()
mock_response.nodes = [Mock(), Mock()]
self.mock_grpc_driver.get_nodes.return_value = mock_response

# Execute
nodes = self.driver.get_nodes()
args, kwargs = self.mock_grpc_driver.get_nodes.call_args

# Assert
self.mock_grpc_driver.connect.assert_called_once()
self.assertEqual(len(args), 1)
self.assertEqual(len(kwargs), 0)
self.assertIsInstance(args[0], GetNodesRequest)
self.assertEqual(args[0].workload_id, 61016)
self.assertEqual(nodes, mock_response.nodes)

def test_push_task_ins(self) -> None:
"""Test pushing task instructions."""
# Prepare
mock_response = Mock()
mock_response.task_ids = ["id1", "id2"]
self.mock_grpc_driver.push_task_ins.return_value = mock_response
task_ins_list = [TaskIns(), TaskIns()]

# Execute
task_ids = self.driver.push_task_ins(task_ins_list)
args, kwargs = self.mock_grpc_driver.push_task_ins.call_args

# Assert
self.mock_grpc_driver.connect.assert_called_once()
self.assertEqual(len(args), 1)
self.assertEqual(len(kwargs), 0)
self.assertIsInstance(args[0], PushTaskInsRequest)
self.assertEqual(task_ids, mock_response.task_ids)
for task_ins in args[0].task_ins_list:
self.assertEqual(task_ins.workload_id, 61016)

def test_pull_task_res_with_given_task_ids(self) -> None:
"""Test pulling task results with specific task IDs."""
# Prepare
mock_response = Mock()
mock_response.task_res_list = [
TaskRes(task=Task(ancestry=["id2"])),
TaskRes(task=Task(ancestry=["id3"])),
]
self.mock_grpc_driver.pull_task_res.return_value = mock_response
task_ids = ["id1", "id2", "id3"]

# Execute
task_res_list = self.driver.pull_task_res(task_ids)
args, kwargs = self.mock_grpc_driver.pull_task_res.call_args

# Assert
self.mock_grpc_driver.connect.assert_called_once()
self.assertEqual(len(args), 1)
self.assertEqual(len(kwargs), 0)
self.assertIsInstance(args[0], PullTaskResRequest)
self.assertEqual(args[0].task_ids, task_ids)
self.assertEqual(task_res_list, mock_response.task_res_list)

def test_del_with_initialized_driver(self) -> None:
"""Test cleanup behavior when Driver is initialized."""
# Prepare
# pylint: disable-next=protected-access
self.driver._get_grpc_driver_and_workload_id()

# Execute
self.driver.__del__()

# Assert
self.mock_grpc_driver.disconnect.assert_called_once()

def test_del_with_uninitialized_driver(self) -> None:
"""Test cleanup behavior when Driver is not initialized."""
# Execute
self.driver.__del__()

# Assert
self.mock_grpc_driver.disconnect.assert_not_called()

0 comments on commit c92992b

Please sign in to comment.