diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index b9abf9bf142b..ea87c35c83f7 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -72,7 +72,7 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: Returns ------- task_res : Optional[TaskRes] - TaskRes to be returned to the server. If None, the client should + TaskRes to be sent back to the server. If None, the client should continue to process messages from the server. sleep_duration : int Number of seconds that the client should disconnect from the server. diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index 8b2e51c17ea0..e0ff26c035f7 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -16,16 +16,19 @@ import time -from typing import List, Optional, cast +from typing import List, Optional from flwr import common +from flwr.common import recordset_compat as compat from flwr.common import serde -from flwr.proto import ( # pylint: disable=E0611 - driver_pb2, - node_pb2, - task_pb2, - transport_pb2, +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, ) +from flwr.common.recordset import RecordSet +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 from flwr.server.client_proxy import ClientProxy from .grpc_driver import GrpcDriver @@ -47,55 +50,51 @@ def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float] ) -> common.GetPropertiesRes: """Return client's properties.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(get_properties_ins=ins) - ) - ) - return cast( - common.GetPropertiesRes, - self._send_receive_msg(server_message_proto, timeout).get_properties_res, + # Ins to RecordSet + out_recordset = compat.getpropertiesins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, TASK_TYPE_GET_PROPERTIES, timeout ) + # RecordSet to Res + return compat.recordset_to_getpropertiesres(in_recordset) def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float] ) -> common.GetParametersRes: """Return the current local model parameters.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(get_parameters_ins=ins) - ) - ) - return cast( - common.GetParametersRes, - self._send_receive_msg(server_message_proto, timeout).get_parameters_res, + # Ins to RecordSet + out_recordset = compat.getparametersins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, TASK_TYPE_GET_PARAMETERS, timeout ) + # RecordSet to Res + return compat.recordset_to_getparametersres(in_recordset, False) def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(fit_ins=ins) - ) - ) - return cast( - common.FitRes, - self._send_receive_msg(server_message_proto, timeout).fit_res, + # Ins to RecordSet + out_recordset = compat.fitins_to_recordset(ins, keep_input=True) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, TASK_TYPE_FIT, timeout ) + # RecordSet to Res + return compat.recordset_to_fitres(in_recordset, keep_input=False) def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(evaluate_ins=ins) - ) - ) - return cast( - common.EvaluateRes, - self._send_receive_msg(server_message_proto, timeout).evaluate_res, + # Ins to RecordSet + out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True) + # Fetch response + in_recordset = self._send_receive_recordset( + out_recordset, TASK_TYPE_EVALUATE, timeout ) + # RecordSet to Res + return compat.recordset_to_evaluateres(in_recordset) def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float] @@ -103,11 +102,12 @@ def reconnect( """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) - def _send_receive_msg( + def _send_receive_recordset( self, - server_message: transport_pb2.ServerMessage, # pylint: disable=E1101 + recordset: RecordSet, + task_type: str, timeout: Optional[float], - ) -> transport_pb2.ClientMessage: # pylint: disable=E1101 + ) -> RecordSet: task_ins = task_pb2.TaskIns( # pylint: disable=E1101 task_id="", group_id="", @@ -121,7 +121,8 @@ def _send_receive_msg( node_id=self.node_id, anonymous=self.anonymous, ), - legacy_server_message=server_message, + task_type=task_type, + recordset=serde.recordset_to_proto(recordset), ), ) push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 @@ -155,9 +156,7 @@ def _send_receive_msg( ) if len(task_res_list) == 1: task_res = task_res_list[0] - return serde.client_message_from_proto( # type: ignore - task_res.task.legacy_client_message - ) + return serde.recordset_from_proto(task_res.task.recordset) if timeout is not None and time.time() > start_time + timeout: raise RuntimeError("Timeout reached") diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index d3cab152e4db..4e9a02a6cbf9 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -16,23 +16,63 @@ import unittest +from typing import Union, cast from unittest.mock import MagicMock import numpy as np import flwr -from flwr.common.typing import Config, GetParametersIns -from flwr.driver.driver_client_proxy import DriverClientProxy -from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, +from flwr.common import recordset_compat as compat +from flwr.common import serde +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, +) +from flwr.common.typing import ( + Code, + Config, + EvaluateIns, + EvaluateRes, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesRes, Parameters, - Scalar, + Properties, + Status, ) +from flwr.driver.driver_client_proxy import DriverClientProxy +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") -CLIENT_PROPERTIES = {"tensor_type": Scalar(string="numpy.ndarray")} +CLIENT_PROPERTIES = cast(Properties, {"tensor_type": "numpy.ndarray"}) +CLIENT_STATUS = Status(code=Code.OK, message="OK") + + +def _make_task( + res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes] +) -> task_pb2.Task: # pylint: disable=E1101 + if isinstance(res, GetParametersRes): + task_type = TASK_TYPE_GET_PARAMETERS + recordset = compat.getparametersres_to_recordset(res, True) + elif isinstance(res, GetPropertiesRes): + task_type = TASK_TYPE_GET_PROPERTIES + recordset = compat.getpropertiesres_to_recordset(res) + elif isinstance(res, FitRes): + task_type = TASK_TYPE_FIT + recordset = compat.fitres_to_recordset(res, True) + elif isinstance(res, EvaluateRes): + task_type = TASK_TYPE_EVALUATE + recordset = compat.evaluateres_to_recordset(res) + else: + raise ValueError(f"Unsupported type: {type(res)}") + return task_pb2.Task( # pylint: disable=E1101 + task_type=task_type, + recordset=serde.recordset_to_proto(recordset), + ) class DriverClientProxyTestCase(unittest.TestCase): @@ -64,11 +104,9 @@ def test_get_properties(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - get_properties_res=ClientMessage.GetPropertiesRes( - properties=CLIENT_PROPERTIES - ) + task=_make_task( + GetPropertiesRes( + status=CLIENT_STATUS, properties=CLIENT_PROPERTIES ) ), ) @@ -104,11 +142,10 @@ def test_get_parameters(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - get_parameters_res=ClientMessage.GetParametersRes( - parameters=MESSAGE_PARAMETERS, - ) + task=_make_task( + GetParametersRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, ) ), ) @@ -143,12 +180,12 @@ def test_fit(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - fit_res=ClientMessage.FitRes( - parameters=MESSAGE_PARAMETERS, - num_examples=10, - ) + task=_make_task( + FitRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, + num_examples=10, + metrics={}, ) ), ) @@ -184,11 +221,12 @@ def test_evaluate(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - evaluate_res=ClientMessage.EvaluateRes( - loss=0.0, num_examples=0 - ) + task=_make_task( + EvaluateRes( + status=CLIENT_STATUS, + loss=0.0, + num_examples=0, + metrics={}, ) ), ) @@ -198,8 +236,8 @@ def test_evaluate(self) -> None: client = DriverClientProxy( node_id=1, driver=self.driver, anonymous=True, run_id=0 ) - parameters = flwr.common.Parameters(tensors=[], tensor_type="np") - evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) + parameters = Parameters(tensors=[], tensor_type="np") + evaluate_ins = EvaluateIns(parameters, {}) # Execute evaluate_res = client.evaluate(evaluate_ins, timeout=None) diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index 7f9094625765..95d764792ff3 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -23,11 +23,8 @@ from uuid import uuid4 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) from flwr.server.state import InMemoryState, SqliteState, State @@ -421,9 +418,8 @@ def create_task_ins( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns() - ), + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task @@ -444,9 +440,8 @@ def create_task_res( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, - legacy_client_message=ClientMessage( - disconnect_res=ClientMessage.DisconnectRes() - ), + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task_res diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index 01dbcf982cce..f9b271beafdc 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -64,21 +64,10 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check - has_fields = { - "sa": tasks_ins_res.task.HasField("sa"), - "legacy_server_message": tasks_ins_res.task.HasField( - "legacy_server_message" - ), - } - if not (has_fields["sa"] or has_fields["legacy_server_message"]): - err_msg = ", ".join([f"`{field}`" for field in has_fields]) - validation_errors.append( - f"`task` in `TaskIns` must set at least one of fields {{{err_msg}}}" - ) - if has_fields[ - "legacy_server_message" - ] and not tasks_ins_res.task.legacy_server_message.HasField("msg"): - validation_errors.append("`legacy_server_message` does not set field `msg`") + if tasks_ins_res.task.task_type == "": + validation_errors.append("`task_type` MUST be set") + if not tasks_ins_res.task.HasField("recordset"): + validation_errors.append("`recordset` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) != 0: @@ -115,21 +104,10 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check - has_fields = { - "sa": tasks_ins_res.task.HasField("sa"), - "legacy_client_message": tasks_ins_res.task.HasField( - "legacy_client_message" - ), - } - if not (has_fields["sa"] or has_fields["legacy_client_message"]): - err_msg = ", ".join([f"`{field}`" for field in has_fields]) - validation_errors.append( - f"`task` in `TaskRes` must set at least one of fields {{{err_msg}}}" - ) - if has_fields[ - "legacy_client_message" - ] and not tasks_ins_res.task.legacy_client_message.HasField("msg"): - validation_errors.append("`legacy_client_message` does not set field `msg`") + if tasks_ins_res.task.task_type == "": + validation_errors.append("`task_type` MUST be set") + if not tasks_ins_res.task.HasField("recordset"): + validation_errors.append("`recordset` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) == 0: diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index a93e4fb4d457..8e0849508020 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -19,16 +19,8 @@ from typing import List, Tuple from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import ( # pylint: disable=E0611 - SecureAggregation, - Task, - TaskIns, - TaskRes, -) -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from .validator import validate_task_ins_or_res @@ -45,16 +37,12 @@ def test_task_ins(self) -> None: # Execute & Assert for consumer_node_id, anonymous in valid_ins: - msg = create_task_ins( - consumer_node_id, anonymous, has_legacy_server_message=True - ) + msg = create_task_ins(consumer_node_id, anonymous) val_errors = validate_task_ins_or_res(msg) self.assertFalse(val_errors) for consumer_node_id, anonymous in invalid_ins: - msg = create_task_ins( - consumer_node_id, anonymous, has_legacy_server_message=True - ) + msg = create_task_ins(consumer_node_id, anonymous) val_errors = validate_task_ins_or_res(msg) self.assertTrue(val_errors) @@ -78,61 +66,19 @@ def test_is_valid_task_res(self) -> None: # Execute & Assert for producer_node_id, anonymous, ancestry in valid_res: - msg = create_task_res( - producer_node_id, anonymous, ancestry, has_legacy_client_message=True - ) + msg = create_task_res(producer_node_id, anonymous, ancestry) val_errors = validate_task_ins_or_res(msg) self.assertFalse(val_errors) for producer_node_id, anonymous, ancestry in invalid_res: - msg = create_task_res( - producer_node_id, anonymous, ancestry, has_legacy_client_message=True - ) + msg = create_task_res(producer_node_id, anonymous, ancestry) val_errors = validate_task_ins_or_res(msg) self.assertTrue(val_errors, (producer_node_id, anonymous, ancestry)) - def test_task_ins_secure_aggregation(self) -> None: - """Test is_valid task_ins for Secure Aggregation.""" - # Prepare - # (has_legacy_server_message, has_sa) - valid_ins = [(True, True), (False, True)] - invalid_ins = [(False, False)] - - # Execute & Assert - for has_legacy_server_message, has_sa in valid_ins: - msg = create_task_ins(1, False, has_legacy_server_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertFalse(val_errors) - - for has_legacy_server_message, has_sa in invalid_ins: - msg = create_task_ins(1, False, has_legacy_server_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertTrue(val_errors) - - def test_task_res_secure_aggregation(self) -> None: - """Test is_valid task_res for Secure Aggregation.""" - # Prepare - # (has_legacy_server_message, has_sa) - valid_res = [(True, True), (False, True)] - invalid_res = [(False, False)] - - # Execute & Assert - for has_legacy_client_message, has_sa in valid_res: - msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertFalse(val_errors) - - for has_legacy_client_message, has_sa in invalid_res: - msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertTrue(val_errors) - def create_task_ins( consumer_node_id: int, anonymous: bool, - has_legacy_server_message: bool = False, - has_sa: bool = False, delivered_at: str = "", ) -> TaskIns: """Create a TaskIns for testing.""" @@ -148,12 +94,8 @@ def create_task_ins( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns() - ) - if has_legacy_server_message - else None, - sa=SecureAggregation(named_values={}) if has_sa else None, + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task @@ -163,8 +105,6 @@ def create_task_res( producer_node_id: int, anonymous: bool, ancestry: List[str], - has_legacy_client_message: bool = False, - has_sa: bool = False, ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( @@ -175,12 +115,8 @@ def create_task_res( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, - legacy_client_message=ClientMessage( - disconnect_res=ClientMessage.DisconnectRes() - ) - if has_legacy_client_message - else None, - sa=SecureAggregation(named_values={}) if has_sa else None, + task_type="mock", + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task_res