Skip to content

Commit

Permalink
Make DriverClientProxy work with RecordSet and task_type. (#2853)
Browse files: browse the repository at this point in the history.
Co-authored-by: jafermarq <[email protected]>
Co-authored-by: Heng Pan <[email protected]>
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
4 people authored Jan 29, 2024
1 parent eeb9ea2 commit 599a974
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 190 deletions.
2 changes: 1 addition & 1 deletion src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:
Returns
-------
task_res : Optional[TaskRes]
TaskRes to be returned to the server. If None, the client should
TaskRes to be sent back to the server. If None, the client should
continue to process messages from the server.
sleep_duration : int
Number of seconds that the client should disconnect from the server.
Expand Down
89 changes: 44 additions & 45 deletions src/py/flwr/driver/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@


import time
from typing import List, Optional, cast
from typing import List, Optional

from flwr import common
from flwr.common import recordset_compat as compat
from flwr.common import serde
from flwr.proto import ( # pylint: disable=E0611
driver_pb2,
node_pb2,
task_pb2,
transport_pb2,
from flwr.common.constant import (
TASK_TYPE_EVALUATE,
TASK_TYPE_FIT,
TASK_TYPE_GET_PARAMETERS,
TASK_TYPE_GET_PROPERTIES,
)
from flwr.common.recordset import RecordSet
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
from flwr.server.client_proxy import ClientProxy

from .grpc_driver import GrpcDriver
Expand All @@ -47,67 +50,64 @@ def get_properties(
self, ins: common.GetPropertiesIns, timeout: Optional[float]
) -> common.GetPropertiesRes:
"""Return client's properties."""
server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
serde.server_message_to_proto(
server_message=common.ServerMessage(get_properties_ins=ins)
)
)
return cast(
common.GetPropertiesRes,
self._send_receive_msg(server_message_proto, timeout).get_properties_res,
# Ins to RecordSet
out_recordset = compat.getpropertiesins_to_recordset(ins)
# Fetch response
in_recordset = self._send_receive_recordset(
out_recordset, TASK_TYPE_GET_PROPERTIES, timeout
)
# RecordSet to Res
return compat.recordset_to_getpropertiesres(in_recordset)

def get_parameters(
self, ins: common.GetParametersIns, timeout: Optional[float]
) -> common.GetParametersRes:
"""Return the current local model parameters."""
server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
serde.server_message_to_proto(
server_message=common.ServerMessage(get_parameters_ins=ins)
)
)
return cast(
common.GetParametersRes,
self._send_receive_msg(server_message_proto, timeout).get_parameters_res,
# Ins to RecordSet
out_recordset = compat.getparametersins_to_recordset(ins)
# Fetch response
in_recordset = self._send_receive_recordset(
out_recordset, TASK_TYPE_GET_PARAMETERS, timeout
)
# RecordSet to Res
return compat.recordset_to_getparametersres(in_recordset, False)

def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
"""Train model parameters on the locally held dataset."""
server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
serde.server_message_to_proto(
server_message=common.ServerMessage(fit_ins=ins)
)
)
return cast(
common.FitRes,
self._send_receive_msg(server_message_proto, timeout).fit_res,
# Ins to RecordSet
out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
# Fetch response
in_recordset = self._send_receive_recordset(
out_recordset, TASK_TYPE_FIT, timeout
)
# RecordSet to Res
return compat.recordset_to_fitres(in_recordset, keep_input=False)

def evaluate(
self, ins: common.EvaluateIns, timeout: Optional[float]
) -> common.EvaluateRes:
"""Evaluate model parameters on the locally held dataset."""
server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101
serde.server_message_to_proto(
server_message=common.ServerMessage(evaluate_ins=ins)
)
)
return cast(
common.EvaluateRes,
self._send_receive_msg(server_message_proto, timeout).evaluate_res,
# Ins to RecordSet
out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
# Fetch response
in_recordset = self._send_receive_recordset(
out_recordset, TASK_TYPE_EVALUATE, timeout
)
# RecordSet to Res
return compat.recordset_to_evaluateres(in_recordset)

def reconnect(
    self, ins: common.ReconnectIns, timeout: Optional[float]
) -> common.DisconnectRes:
    """Disconnect and (optionally) reconnect later.

    Currently a placeholder: neither ``ins`` nor ``timeout`` is read, and
    the method only builds and returns a ``DisconnectRes`` whose ``reason``
    is the empty string.
    """
    # Nothing to do here (yet)
    disconnect_res = common.DisconnectRes(reason="")
    return disconnect_res

def _send_receive_msg(
def _send_receive_recordset(
self,
server_message: transport_pb2.ServerMessage, # pylint: disable=E1101
recordset: RecordSet,
task_type: str,
timeout: Optional[float],
) -> transport_pb2.ClientMessage: # pylint: disable=E1101
) -> RecordSet:
task_ins = task_pb2.TaskIns( # pylint: disable=E1101
task_id="",
group_id="",
Expand All @@ -121,7 +121,8 @@ def _send_receive_msg(
node_id=self.node_id,
anonymous=self.anonymous,
),
legacy_server_message=server_message,
task_type=task_type,
recordset=serde.recordset_to_proto(recordset),
),
)
push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
Expand Down Expand Up @@ -155,9 +156,7 @@ def _send_receive_msg(
)
if len(task_res_list) == 1:
task_res = task_res_list[0]
return serde.client_message_from_proto( # type: ignore
task_res.task.legacy_client_message
)
return serde.recordset_from_proto(task_res.task.recordset)

if timeout is not None and time.time() > start_time + timeout:
raise RuntimeError("Timeout reached")
Expand Down
98 changes: 68 additions & 30 deletions src/py/flwr/driver/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,63 @@


import unittest
from typing import Union, cast
from unittest.mock import MagicMock

import numpy as np

import flwr
from flwr.common.typing import Config, GetParametersIns
from flwr.driver.driver_client_proxy import DriverClientProxy
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
from flwr.common import recordset_compat as compat
from flwr.common import serde
from flwr.common.constant import (
TASK_TYPE_EVALUATE,
TASK_TYPE_FIT,
TASK_TYPE_GET_PARAMETERS,
TASK_TYPE_GET_PROPERTIES,
)
from flwr.common.typing import (
Code,
Config,
EvaluateIns,
EvaluateRes,
FitRes,
GetParametersIns,
GetParametersRes,
GetPropertiesRes,
Parameters,
Scalar,
Properties,
Status,
)
from flwr.driver.driver_client_proxy import DriverClientProxy
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611

MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np")

CLIENT_PROPERTIES = {"tensor_type": Scalar(string="numpy.ndarray")}
CLIENT_PROPERTIES = cast(Properties, {"tensor_type": "numpy.ndarray"})
CLIENT_STATUS = Status(code=Code.OK, message="OK")


def _make_task(
    res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes]
) -> task_pb2.Task:  # pylint: disable=E1101
    """Build a ``Task`` proto that carries ``res`` as a serialized RecordSet.

    The concrete response type selects both the ``TASK_TYPE_*`` constant
    stored in ``task_type`` and the ``recordset_compat`` converter used to
    turn ``res`` into a RecordSet. That RecordSet is then serialized with
    ``serde.recordset_to_proto`` and stored in the task's ``recordset`` field.

    Raises
    ------
    ValueError
        If ``res`` is not an instance of one of the four supported
        response types.
    """
    # (response class, task type, converter) triples, tried in this order.
    # NOTE(review): the positional ``True`` passed to two of the converters
    # is presumably their ``keep_input`` flag -- confirm against
    # ``recordset_compat``.
    dispatch = (
        (
            GetParametersRes,
            TASK_TYPE_GET_PARAMETERS,
            lambda r: compat.getparametersres_to_recordset(r, True),
        ),
        (
            GetPropertiesRes,
            TASK_TYPE_GET_PROPERTIES,
            compat.getpropertiesres_to_recordset,
        ),
        (
            FitRes,
            TASK_TYPE_FIT,
            lambda r: compat.fitres_to_recordset(r, True),
        ),
        (
            EvaluateRes,
            TASK_TYPE_EVALUATE,
            compat.evaluateres_to_recordset,
        ),
    )

    for res_cls, task_type, to_recordset in dispatch:
        if isinstance(res, res_cls):
            return task_pb2.Task(  # pylint: disable=E1101
                task_type=task_type,
                recordset=serde.recordset_to_proto(to_recordset(res)),
            )

    raise ValueError(f"Unsupported type: {type(res)}")


class DriverClientProxyTestCase(unittest.TestCase):
Expand Down Expand Up @@ -64,11 +104,9 @@ def test_get_properties(self) -> None:
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
run_id=0,
task=task_pb2.Task( # pylint: disable=E1101
legacy_client_message=ClientMessage(
get_properties_res=ClientMessage.GetPropertiesRes(
properties=CLIENT_PROPERTIES
)
task=_make_task(
GetPropertiesRes(
status=CLIENT_STATUS, properties=CLIENT_PROPERTIES
)
),
)
Expand Down Expand Up @@ -104,11 +142,10 @@ def test_get_parameters(self) -> None:
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
run_id=0,
task=task_pb2.Task( # pylint: disable=E1101
legacy_client_message=ClientMessage(
get_parameters_res=ClientMessage.GetParametersRes(
parameters=MESSAGE_PARAMETERS,
)
task=_make_task(
GetParametersRes(
status=CLIENT_STATUS,
parameters=MESSAGE_PARAMETERS,
)
),
)
Expand Down Expand Up @@ -143,12 +180,12 @@ def test_fit(self) -> None:
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
run_id=0,
task=task_pb2.Task( # pylint: disable=E1101
legacy_client_message=ClientMessage(
fit_res=ClientMessage.FitRes(
parameters=MESSAGE_PARAMETERS,
num_examples=10,
)
task=_make_task(
FitRes(
status=CLIENT_STATUS,
parameters=MESSAGE_PARAMETERS,
num_examples=10,
metrics={},
)
),
)
Expand Down Expand Up @@ -184,11 +221,12 @@ def test_evaluate(self) -> None:
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
run_id=0,
task=task_pb2.Task( # pylint: disable=E1101
legacy_client_message=ClientMessage(
evaluate_res=ClientMessage.EvaluateRes(
loss=0.0, num_examples=0
)
task=_make_task(
EvaluateRes(
status=CLIENT_STATUS,
loss=0.0,
num_examples=0,
metrics={},
)
),
)
Expand All @@ -198,8 +236,8 @@ def test_evaluate(self) -> None:
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, run_id=0
)
parameters = flwr.common.Parameters(tensors=[], tensor_type="np")
evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {})
parameters = Parameters(tensors=[], tensor_type="np")
evaluate_ins = EvaluateIns(parameters, {})

# Execute
evaluate_res = client.evaluate(evaluate_ins, timeout=None)
Expand Down
15 changes: 5 additions & 10 deletions src/py/flwr/server/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,8 @@
from uuid import uuid4

from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
ServerMessage,
)
from flwr.server.state import InMemoryState, SqliteState, State


Expand Down Expand Up @@ -421,9 +418,8 @@ def create_task_ins(
delivered_at=delivered_at,
producer=Node(node_id=0, anonymous=True),
consumer=consumer,
legacy_server_message=ServerMessage(
reconnect_ins=ServerMessage.ReconnectIns()
),
task_type="mock",
recordset=RecordSet(parameters={}, metrics={}, configs={}),
),
)
return task
Expand All @@ -444,9 +440,8 @@ def create_task_res(
producer=Node(node_id=producer_node_id, anonymous=anonymous),
consumer=Node(node_id=0, anonymous=True),
ancestry=ancestry,
legacy_client_message=ClientMessage(
disconnect_res=ClientMessage.DisconnectRes()
),
task_type="mock",
recordset=RecordSet(parameters={}, metrics={}, configs={}),
),
)
return task_res
Expand Down
38 changes: 8 additions & 30 deletions src/py/flwr/server/utils/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,10 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
validation_errors.append("non-anonymous consumer MUST provide a `node_id`")

# Content check
has_fields = {
"sa": tasks_ins_res.task.HasField("sa"),
"legacy_server_message": tasks_ins_res.task.HasField(
"legacy_server_message"
),
}
if not (has_fields["sa"] or has_fields["legacy_server_message"]):
err_msg = ", ".join([f"`{field}`" for field in has_fields])
validation_errors.append(
f"`task` in `TaskIns` must set at least one of fields {{{err_msg}}}"
)
if has_fields[
"legacy_server_message"
] and not tasks_ins_res.task.legacy_server_message.HasField("msg"):
validation_errors.append("`legacy_server_message` does not set field `msg`")
if tasks_ins_res.task.task_type == "":
validation_errors.append("`task_type` MUST be set")
if not tasks_ins_res.task.HasField("recordset"):
validation_errors.append("`recordset` MUST be set")

# Ancestors
if len(tasks_ins_res.task.ancestry) != 0:
Expand Down Expand Up @@ -115,21 +104,10 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
validation_errors.append("non-anonymous consumer MUST provide a `node_id`")

# Content check
has_fields = {
"sa": tasks_ins_res.task.HasField("sa"),
"legacy_client_message": tasks_ins_res.task.HasField(
"legacy_client_message"
),
}
if not (has_fields["sa"] or has_fields["legacy_client_message"]):
err_msg = ", ".join([f"`{field}`" for field in has_fields])
validation_errors.append(
f"`task` in `TaskRes` must set at least one of fields {{{err_msg}}}"
)
if has_fields[
"legacy_client_message"
] and not tasks_ins_res.task.legacy_client_message.HasField("msg"):
validation_errors.append("`legacy_client_message` does not set field `msg`")
if tasks_ins_res.task.task_type == "":
validation_errors.append("`task_type` MUST be set")
if not tasks_ins_res.task.HasField("recordset"):
validation_errors.append("`recordset` MUST be set")

# Ancestors
if len(tasks_ins_res.task.ancestry) == 0:
Expand Down
Loading

0 comments on commit 599a974

Please sign in to comment.