Skip to content

Commit

Permalink
use Driver in DriverClientProxy
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Dec 1, 2023
1 parent 5709a85 commit 769459a
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 121 deletions.
10 changes: 2 additions & 8 deletions src/py/flwr/driver/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
from logging import INFO
from pathlib import Path
from typing import Dict, List, Optional, Union, cast
from typing import Dict, List, Optional, Union

from flwr.common import EventType, event
from flwr.common.address import parse_address
Expand All @@ -33,7 +33,6 @@

from .driver import Driver
from .driver_client_proxy import DriverClientProxy
from .grpc_driver import GrpcDriver

DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"

Expand Down Expand Up @@ -111,8 +110,6 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
# Create the Driver
if isinstance(root_certificates, str):
root_certificates = Path(root_certificates).read_bytes()
driver = GrpcDriver(driver_service_address=address, certificates=root_certificates)
driver.connect()

# Initialize the Driver API server and config
initialized_server, initialized_config = init_defaults(
Expand Down Expand Up @@ -186,9 +183,8 @@ def update_client_manager(
for node_id in new_nodes:
client_proxy = DriverClientProxy(
node_id=node_id,
driver=cast(GrpcDriver, driver.grpc_driver),
driver=driver,
anonymous=False,
workload_id=cast(int, driver.workload_id),
)
if client_manager.register(client_proxy):
registered_nodes[node_id] = client_proxy
Expand All @@ -197,5 +193,3 @@ def update_client_manager(

# Sleep for 3 seconds
time.sleep(3)
# Exit
del driver
31 changes: 8 additions & 23 deletions src/py/flwr/driver/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,25 @@


import time
from typing import List, Optional, cast
from typing import Optional, cast

from flwr import common
from flwr.common import serde
from flwr.proto import driver_pb2, node_pb2, task_pb2, transport_pb2
from flwr.proto import node_pb2, task_pb2, transport_pb2
from flwr.server.client_proxy import ClientProxy

from .grpc_driver import GrpcDriver
from .driver import Driver

SLEEP_TIME = 1


class DriverClientProxy(ClientProxy):
"""Flower client proxy which delegates work using the Driver API."""

def __init__(
self, node_id: int, driver: GrpcDriver, anonymous: bool, workload_id: int
):
def __init__(self, node_id: int, driver: Driver, anonymous: bool):
super().__init__(str(node_id))
self.node_id = node_id
self.driver = driver
self.workload_id = workload_id
self.anonymous = anonymous

def get_properties(
Expand Down Expand Up @@ -104,9 +101,6 @@ def _send_receive_msg(
self, server_message: transport_pb2.ServerMessage, timeout: Optional[float]
) -> transport_pb2.ClientMessage:
task_ins = task_pb2.TaskIns(
task_id="",
group_id="",
workload_id=self.workload_id,
task=task_pb2.Task(
producer=node_pb2.Node(
node_id=0,
Expand All @@ -119,33 +113,24 @@ def _send_receive_msg(
legacy_server_message=server_message,
),
)
push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=[task_ins])

# Send TaskIns to Driver API
push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req)
task_ids = self.driver.push_task_ins([task_ins])

if len(push_task_ins_res.task_ids) != 1:
if len(task_ids) != 1:
raise ValueError("Unexpected number of task_ids")

task_id = push_task_ins_res.task_ids[0]
task_id = task_ids[0]
if task_id == "":
raise ValueError(f"Failed to schedule task for node {self.node_id}")

if timeout:
start_time = time.time()

while True:
pull_task_res_req = driver_pb2.PullTaskResRequest(
node=node_pb2.Node(node_id=0, anonymous=True),
task_ids=[task_id],
)

# Ask Driver API for TaskRes
pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req)
task_res_list = self.driver.pull_task_res([task_id])

task_res_list: List[task_pb2.TaskRes] = list(
pull_task_res_res.task_res_list
)
if len(task_res_list) == 1:
task_res = task_res_list[0]
return serde.client_message_from_proto( # type: ignore
Expand Down
160 changes: 70 additions & 90 deletions src/py/flwr/driver/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import flwr
from flwr.common.typing import Config, GetParametersIns
from flwr.driver.driver_client_proxy import DriverClientProxy
from flwr.proto import driver_pb2, node_pb2, task_pb2
from flwr.proto import node_pb2, task_pb2
from flwr.proto.transport_pb2 import ClientMessage, Parameters, Scalar

MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np")
Expand All @@ -37,35 +37,29 @@ class DriverClientProxyTestCase(unittest.TestCase):
def setUp(self) -> None:
"""Set up mocks for tests."""
self.driver = MagicMock()
self.driver.get_nodes.return_value = driver_pb2.GetNodesResponse(
nodes=[node_pb2.Node(node_id=1, anonymous=False)]
)
self.driver.get_nodes.return_value = [node_pb2.Node(node_id=1, anonymous=False)]

def test_get_properties(self) -> None:
"""Test positive case."""
# Prepare
self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse(
task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"]
)
self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse(
task_res_list=[
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
get_properties_res=ClientMessage.GetPropertiesRes(
properties=CLIENT_PROPERTIES
)
self.driver.push_task_ins.return_value = [
"19341fd7-62e1-4eb4-beb4-9876d3acda32"
]
self.driver.pull_task_res.return_value = [
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
get_properties_res=ClientMessage.GetPropertiesRes(
properties=CLIENT_PROPERTIES
)
),
)
]
)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=0
)
)
),
)
]
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
request_properties: Config = {"tensor_type": "str"}
ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns(
config=request_properties
Expand All @@ -80,28 +74,24 @@ def test_get_properties(self) -> None:
def test_get_parameters(self) -> None:
"""Test positive case."""
# Prepare
self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse(
task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"]
)
self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse(
task_res_list=[
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
get_parameters_res=ClientMessage.GetParametersRes(
parameters=MESSAGE_PARAMETERS,
)
self.driver.push_task_ins.return_value = [
"19341fd7-62e1-4eb4-beb4-9876d3acda32"
]
self.driver.pull_task_res.return_value = [
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
get_parameters_res=ClientMessage.GetParametersRes(
parameters=MESSAGE_PARAMETERS,
)
),
)
]
)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=0
)
)
),
)
]
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
get_parameters_ins = GetParametersIns(config={})

# Execute
Expand All @@ -115,29 +105,25 @@ def test_get_parameters(self) -> None:
def test_fit(self) -> None:
"""Test positive case."""
# Prepare
self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse(
task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"]
)
self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse(
task_res_list=[
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
fit_res=ClientMessage.FitRes(
parameters=MESSAGE_PARAMETERS,
num_examples=10,
)
self.driver.push_task_ins.return_value = [
"19341fd7-62e1-4eb4-beb4-9876d3acda32"
]
self.driver.pull_task_res.return_value = [
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
fit_res=ClientMessage.FitRes(
parameters=MESSAGE_PARAMETERS,
num_examples=10,
)
),
)
]
)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=0
)
)
),
)
]
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))])
ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {})

Expand All @@ -152,28 +138,22 @@ def test_fit(self) -> None:
def test_evaluate(self) -> None:
"""Test positive case."""
# Prepare
self.driver.push_task_ins.return_value = driver_pb2.PushTaskInsResponse(
task_ids=["19341fd7-62e1-4eb4-beb4-9876d3acda32"]
)
self.driver.pull_task_res.return_value = driver_pb2.PullTaskResResponse(
task_res_list=[
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
evaluate_res=ClientMessage.EvaluateRes(
loss=0.0, num_examples=0
)
)
),
)
]
)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=0
)
self.driver.push_task_ins.return_value = [
"19341fd7-62e1-4eb4-beb4-9876d3acda32"
]
self.driver.pull_task_res.return_value = [
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
evaluate_res=ClientMessage.EvaluateRes(loss=0.0, num_examples=0)
)
),
)
]
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
parameters = flwr.common.Parameters(tensors=[], tensor_type="np")
evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {})

Expand Down

0 comments on commit 769459a

Please sign in to comment.