Skip to content

Commit

Permalink
update driver app
Browse files: browse the repository at this point in the history
  • Loading branch information
panh99 committed Oct 25, 2023
1 parent 05454ae commit 162237a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 41 deletions.
41 changes: 14 additions & 27 deletions src/py/flwr/driver/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@
import threading
import time
from logging import INFO
from typing import Dict, Optional
from typing import Dict, List, Optional, cast

from flwr.common import EventType, event
from flwr.common.address import parse_address
from flwr.common.logger import log
from flwr.proto import driver_pb2
from flwr.server.app import ServerConfig, init_defaults, run_fl
from flwr.server.client_manager import ClientManager
from flwr.server.history import History
from flwr.server.server import Server
from flwr.server.strategy import Strategy

from .driver import Driver
from .driver_client_proxy import DriverClientProxy
from .grpc_driver import GrpcDriver

Expand Down Expand Up @@ -111,11 +111,6 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
host, port, is_v6 = parsed_address
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"

# Create the Driver
driver = GrpcDriver(driver_service_address=address, certificates=certificates)
driver.connect()
lock = threading.Lock()

# Initialize the Driver API server and config
initialized_server, initialized_config = init_defaults(
server=server,
Expand All @@ -130,12 +125,13 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
)

# Start the thread updating nodes
bool_ref = [True]
thread = threading.Thread(
target=update_client_manager,
args=(
driver,
Driver(driver_service_address=address, certificates=certificates),
initialized_server.client_manager(),
lock,
bool_ref,
),
)
thread.start()
Expand All @@ -147,8 +143,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
)

# Stop the Driver API server and the thread
with lock:
driver.disconnect()
bool_ref[0] = False
thread.join()

event(EventType.START_SERVER_LEAVE)
Expand All @@ -157,9 +152,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals


def update_client_manager(
driver: GrpcDriver,
driver: Driver,
client_manager: ClientManager,
lock: threading.Lock,
bool_ref: List[bool],
) -> None:
"""Update the nodes list in the client manager.
Expand All @@ -171,20 +166,11 @@ def update_client_manager(
and dead nodes will be removed from the ClientManager via
`client_manager.unregister()`.
"""
# Request for workload_id
workload_id = driver.create_workload(driver_pb2.CreateWorkloadRequest()).workload_id

# Loop until the caller sets bool_ref[0] to False to signal shutdown
registered_nodes: Dict[int, DriverClientProxy] = {}
while True:
with lock:
# End the while loop if the driver is disconnected
if driver.stub is None:
break
get_nodes_res = driver.get_nodes(
req=driver_pb2.GetNodesRequest(workload_id=workload_id)
)
all_node_ids = {node.node_id for node in get_nodes_res.nodes}
while bool_ref[0]:
nodes = driver.get_nodes()
all_node_ids = {node.node_id for node in nodes}
dead_nodes = set(registered_nodes).difference(all_node_ids)
new_nodes = all_node_ids.difference(registered_nodes)

Expand All @@ -198,9 +184,9 @@ def update_client_manager(
for node_id in new_nodes:
client_proxy = DriverClientProxy(
node_id=node_id,
driver=driver,
driver=cast(GrpcDriver, driver.grpc_driver),
anonymous=False,
workload_id=workload_id,
workload_id=cast(int, driver.workload_id),
)
if client_manager.register(client_proxy):
registered_nodes[node_id] = client_proxy
Expand All @@ -209,3 +195,4 @@ def update_client_manager(

# Sleep for 3 seconds
time.sleep(3)
del driver
36 changes: 22 additions & 14 deletions src/py/flwr/driver/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import threading
import time
import unittest
from unittest.mock import MagicMock
from unittest.mock import Mock, patch

from flwr.driver.app import update_client_manager
from flwr.proto.driver_pb2 import CreateWorkloadResponse, GetNodesResponse
from flwr.proto.node_pb2 import Node
from flwr.server.client_manager import SimpleClientManager

from .driver import Driver


class TestClientManagerWithDriver(unittest.TestCase):
"""Tests for ClientManager.
Expand All @@ -37,24 +39,28 @@ class TestClientManagerWithDriver(unittest.TestCase):
def test_simple_client_manager_update(self) -> None:
"""Tests if the node update works correctly."""
# Prepare
mock_grpc_driver = Mock()
mock_grpc_driver.create_workload.return_value = CreateWorkloadResponse(
workload_id=61016
)
expected_nodes = [Node(node_id=i, anonymous=False) for i in range(100)]
expected_updated_nodes = [
Node(node_id=i, anonymous=False) for i in range(80, 120)
]
driver = MagicMock()
driver.stub = "driver stub"
driver.create_workload.return_value = CreateWorkloadResponse(workload_id=1)
driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes)
mock_grpc_driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes)
client_manager = SimpleClientManager()
lock = threading.Lock()
patcher = patch("flwr.driver.driver.GrpcDriver", return_value=mock_grpc_driver)
patcher.start()
driver = Driver()
bool_ref = [True]

# Execute
thread = threading.Thread(
target=update_client_manager,
args=(
driver,
client_manager,
lock,
bool_ref,
),
daemon=True,
)
Expand All @@ -64,21 +70,23 @@ def test_simple_client_manager_update(self) -> None:
# Retrieve all nodes in `client_manager`
node_ids = {proxy.node_id for proxy in client_manager.all().values()}
# Update the GetNodesResponse and wait until the `client_manager` is updated
driver.get_nodes.return_value = GetNodesResponse(nodes=expected_updated_nodes)
mock_grpc_driver.get_nodes.return_value = GetNodesResponse(
nodes=expected_updated_nodes
)
while True:
with lock:
if len(client_manager.all()) == len(expected_updated_nodes):
break
if len(client_manager) == len(expected_updated_nodes):
break
time.sleep(1.3)
# Retrieve all nodes in `client_manager`
updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()}
# Simulate `driver.disconnect()`
driver.stub = None
# Stop client manager update
bool_ref[0] = False

# Assert
driver.create_workload.assert_called_once()
mock_grpc_driver.create_workload.assert_called_once()
assert node_ids == {node.node_id for node in expected_nodes}
assert updated_node_ids == {node.node_id for node in expected_updated_nodes}

# Exit
patcher.stop()
thread.join()

0 comments on commit 162237a

Please sign in to comment.