From dc3f8f0a233ef3d1f66324a8867592fe65eed7e8 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 25 Jan 2024 15:21:22 +0000 Subject: [PATCH 1/7] update unittests --- src/py/flwr/driver/driver_client_proxy.py | 83 ++++++++----------- .../flwr/driver/driver_client_proxy_test.py | 76 ++++++++++++----- 2 files changed, 90 insertions(+), 69 deletions(-) diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index 8b2e51c17ea0..483d17fbc9a9 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -16,16 +16,13 @@ import time -from typing import List, Optional, cast +from typing import List, Optional from flwr import common +from flwr.common import recordset_compat as compat from flwr.common import serde -from flwr.proto import ( # pylint: disable=E0611 - driver_pb2, - node_pb2, - task_pb2, - transport_pb2, -) +from flwr.common.recordset import RecordSet +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 from flwr.server.client_proxy import ClientProxy from .grpc_driver import GrpcDriver @@ -47,55 +44,43 @@ def get_properties( self, ins: common.GetPropertiesIns, timeout: Optional[float] ) -> common.GetPropertiesRes: """Return client's properties.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(get_properties_ins=ins) - ) - ) - return cast( - common.GetPropertiesRes, - self._send_receive_msg(server_message_proto, timeout).get_properties_res, - ) + # Ins to RecordSet + out_recordset = compat.getpropertiesins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset(out_recordset, timeout) + # RecordSet to Res + return compat.recordset_to_getpropertiesres(in_recordset) def get_parameters( self, ins: common.GetParametersIns, timeout: Optional[float] ) -> common.GetParametersRes: """Return the current local model parameters.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(get_parameters_ins=ins) - ) - ) - return cast( - common.GetParametersRes, - self._send_receive_msg(server_message_proto, timeout).get_parameters_res, - ) + # Ins to RecordSet + out_recordset = compat.getparametersins_to_recordset(ins) + # Fetch response + in_recordset = self._send_receive_recordset(out_recordset, timeout) + # RecordSet to Res + return compat.recordset_to_getparametersres(in_recordset) def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(fit_ins=ins) - ) - ) - return cast( - common.FitRes, - self._send_receive_msg(server_message_proto, timeout).fit_res, - ) + # Ins to RecordSet + out_recordset = compat.fitins_to_recordset(ins, keep_input=False) + # Fetch response + in_recordset = self._send_receive_recordset(out_recordset, timeout) + # RecordSet to Res + return compat.recordset_to_fitres(in_recordset, keep_input=False) def evaluate( self, ins: common.EvaluateIns, timeout: Optional[float] ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" - server_message_proto: transport_pb2.ServerMessage = ( # pylint: disable=E1101 - serde.server_message_to_proto( - server_message=common.ServerMessage(evaluate_ins=ins) - ) - ) - return cast( - common.EvaluateRes, - self._send_receive_msg(server_message_proto, timeout).evaluate_res, - ) + # Ins to RecordSet + out_recordset = compat.evaluateins_to_recordset(ins, keep_input=False) + # Fetch response + in_recordset = self._send_receive_recordset(out_recordset, timeout) + # RecordSet to Res + return compat.recordset_to_evaluateres(in_recordset) def reconnect( self, ins: common.ReconnectIns, timeout: Optional[float] @@ -103,11 +88,11 @@ def reconnect( """Disconnect and (optionally) reconnect later.""" return common.DisconnectRes(reason="") # Nothing to do here (yet) - def _send_receive_msg( + def _send_receive_recordset( self, - server_message: transport_pb2.ServerMessage, # pylint: disable=E1101 + recordset: RecordSet, timeout: Optional[float], - ) -> transport_pb2.ClientMessage: # pylint: disable=E1101 + ) -> RecordSet: task_ins = task_pb2.TaskIns( # pylint: disable=E1101 task_id="", group_id="", @@ -121,7 +106,7 @@ def _send_receive_msg( node_id=self.node_id, anonymous=self.anonymous, ), - legacy_server_message=server_message, + recordset=serde.recordset_to_proto(recordset), ), ) push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 @@ -155,9 +140,7 @@ def _send_receive_msg( ) if len(task_res_list) == 1: task_res = task_res_list[0] - return serde.client_message_from_proto( # type: ignore - task_res.task.legacy_client_message - ) + return serde.recordset_from_proto(task_res.task.recordset) if timeout is not None and time.time() > start_time + timeout: raise RuntimeError("Timeout reached") diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index d3cab152e4db..132bec494710 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -16,23 +16,55 @@ import unittest +from typing import Union, cast from unittest.mock import MagicMock import numpy as np import flwr -from flwr.common.typing import Config, GetParametersIns -from flwr.driver.driver_client_proxy import DriverClientProxy -from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, +from flwr.common import recordset_compat as compat +from flwr.common import serde +from flwr.common.typing import ( + Code, + Config, + EvaluateIns, + EvaluateRes, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesRes, Parameters, - Scalar, + Properties, + Status, +) +from flwr.driver.driver_client_proxy import DriverClientProxy +from flwr.proto import ( # pylint: disable=E0611 + driver_pb2, + node_pb2, + recordset_pb2, + task_pb2, ) MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") -CLIENT_PROPERTIES = {"tensor_type": Scalar(string="numpy.ndarray")} +CLIENT_PROPERTIES = cast(Properties, {"tensor_type": "numpy.ndarray"}) +CLIENT_STATUS = Status(code=Code.OK, message="OK") + + +def _make_recordset_proto( + res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes] +) -> recordset_pb2.RecordSet: # pylint: disable=E1101 + if isinstance(res, GetParametersRes): + recordset = compat.getparametersres_to_recordset(res) + elif isinstance(res, GetPropertiesRes): + recordset = compat.getpropertiesres_to_recordset(res) + elif isinstance(res, FitRes): + recordset = compat.fitres_to_recordset(res, keep_input=True) + elif isinstance(res, EvaluateRes): + recordset = compat.evaluateres_to_recordset(res) + else: + raise ValueError(f"Unsupported type: {type(res)}") + return serde.recordset_to_proto(recordset) class DriverClientProxyTestCase(unittest.TestCase): @@ -65,9 +97,9 @@ def test_get_properties(self) -> None: group_id="", run_id=0, task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - get_properties_res=ClientMessage.GetPropertiesRes( - properties=CLIENT_PROPERTIES + recordset=_make_recordset_proto( + GetPropertiesRes( + status=CLIENT_STATUS, properties=CLIENT_PROPERTIES ) ) ), @@ -105,8 +137,9 @@ def test_get_parameters(self) -> None: group_id="", run_id=0, task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - get_parameters_res=ClientMessage.GetParametersRes( + recordset=_make_recordset_proto( + GetParametersRes( + status=CLIENT_STATUS, parameters=MESSAGE_PARAMETERS, ) ) @@ -144,10 +177,12 @@ def test_fit(self) -> None: group_id="", run_id=0, task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - fit_res=ClientMessage.FitRes( + recordset=_make_recordset_proto( + FitRes( + status=CLIENT_STATUS, parameters=MESSAGE_PARAMETERS, num_examples=10, + metrics={}, ) ) ), @@ -185,9 +220,12 @@ def test_evaluate(self) -> None: group_id="", run_id=0, task=task_pb2.Task( # pylint: disable=E1101 - legacy_client_message=ClientMessage( - evaluate_res=ClientMessage.EvaluateRes( - loss=0.0, num_examples=0 + recordset=_make_recordset_proto( + EvaluateRes( + status=CLIENT_STATUS, + loss=0.0, + num_examples=0, + metrics={}, ) ) ), @@ -198,8 +236,8 @@ def test_evaluate(self) -> None: client = DriverClientProxy( node_id=1, driver=self.driver, anonymous=True, run_id=0 ) - parameters = flwr.common.Parameters(tensors=[], tensor_type="np") - evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) + parameters = Parameters(tensors=[], tensor_type="np") + evaluate_ins = EvaluateIns(parameters, {}) # Execute evaluate_res = client.evaluate(evaluate_ins, timeout=None) From c4d5671891dfcb025f3be8eccb17f87d540ef8e5 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 25 Jan 2024 15:29:56 +0000 Subject: [PATCH 2/7] add task_type --- src/py/flwr/driver/driver_client_proxy.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index 483d17fbc9a9..dbba0046eb96 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -47,7 +47,9 @@ def get_properties( # Ins to RecordSet out_recordset = compat.getpropertiesins_to_recordset(ins) # Fetch response - in_recordset = self._send_receive_recordset(out_recordset, timeout) + in_recordset = self._send_receive_recordset( + out_recordset, "get_properties_ins", timeout + ) # RecordSet to Res return compat.recordset_to_getpropertiesres(in_recordset) @@ -58,7 +60,9 @@ def get_parameters( # Ins to RecordSet out_recordset = compat.getparametersins_to_recordset(ins) # Fetch response - in_recordset = self._send_receive_recordset(out_recordset, timeout) + in_recordset = self._send_receive_recordset( + out_recordset, "get_parameters_ins", timeout + ) # RecordSet to Res return compat.recordset_to_getparametersres(in_recordset) @@ -67,7 +71,7 @@ def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: # Ins to RecordSet out_recordset = compat.fitins_to_recordset(ins, keep_input=False) # Fetch response - in_recordset = self._send_receive_recordset(out_recordset, timeout) + in_recordset = self._send_receive_recordset(out_recordset, "fit_ins", timeout) # RecordSet to Res return compat.recordset_to_fitres(in_recordset, keep_input=False) @@ -78,7 +82,9 @@ def evaluate( # Ins to RecordSet out_recordset = compat.evaluateins_to_recordset(ins, keep_input=False) # Fetch response - in_recordset = self._send_receive_recordset(out_recordset, timeout) + in_recordset = self._send_receive_recordset( + out_recordset, "evaluate_ins", timeout + ) # RecordSet to Res return compat.recordset_to_evaluateres(in_recordset) @@ -91,6 +97,7 @@ def reconnect( def _send_receive_recordset( self, recordset: RecordSet, + task_type: str, timeout: Optional[float], ) -> RecordSet: task_ins = task_pb2.TaskIns( # pylint: disable=E1101 @@ -106,6 +113,7 @@ def _send_receive_recordset( node_id=self.node_id, anonymous=self.anonymous, ), + task_type=task_type, recordset=serde.recordset_to_proto(recordset), ), ) From 2e8d911ece60f38eee1178cf3a9ac83fcf9c4bd7 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 26 Jan 2024 21:33:37 +0000 Subject: [PATCH 3/7] update names and validator.py --- src/py/flwr/driver/driver_client_proxy.py | 18 +++++++--- .../flwr/driver/driver_client_proxy_test.py | 4 +-- src/py/flwr/server/utils/validator.py | 34 +++---------------- 3 files changed, 19 insertions(+), 37 deletions(-) diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index dbba0046eb96..b691c6416728 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -21,6 +21,12 @@ from flwr import common from flwr.common import recordset_compat as compat from flwr.common import serde +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, +) from flwr.common.recordset import RecordSet from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 from flwr.server.client_proxy import ClientProxy @@ -48,7 +54,7 @@ def get_properties( out_recordset = compat.getpropertiesins_to_recordset(ins) # Fetch response in_recordset = self._send_receive_recordset( - out_recordset, "get_properties_ins", timeout + out_recordset, TASK_TYPE_GET_PROPERTIES, timeout ) # RecordSet to Res return compat.recordset_to_getpropertiesres(in_recordset) @@ -61,17 +67,19 @@ def get_parameters( out_recordset = compat.getparametersins_to_recordset(ins) # Fetch response in_recordset = self._send_receive_recordset( - out_recordset, "get_parameters_ins", timeout + out_recordset, TASK_TYPE_GET_PARAMETERS, timeout ) # RecordSet to Res - return compat.recordset_to_getparametersres(in_recordset) + return compat.recordset_to_getparametersres(in_recordset, False) def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" # Ins to RecordSet out_recordset = compat.fitins_to_recordset(ins, keep_input=False) # Fetch response - in_recordset = self._send_receive_recordset(out_recordset, "fit_ins", timeout) + in_recordset = self._send_receive_recordset( + out_recordset, TASK_TYPE_FIT, timeout + ) # RecordSet to Res return compat.recordset_to_fitres(in_recordset, keep_input=False) @@ -83,7 +91,7 @@ def evaluate( out_recordset = compat.evaluateins_to_recordset(ins, keep_input=False) # Fetch response in_recordset = self._send_receive_recordset( - out_recordset, "evaluate_ins", timeout + out_recordset, TASK_TYPE_EVALUATE, timeout ) # RecordSet to Res return compat.recordset_to_evaluateres(in_recordset) diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index 132bec494710..a76ef1c97cf5 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -55,11 +55,11 @@ def _make_recordset_proto( res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes] ) -> recordset_pb2.RecordSet: # pylint: disable=E1101 if isinstance(res, GetParametersRes): - recordset = compat.getparametersres_to_recordset(res) + recordset = compat.getparametersres_to_recordset(res, True) elif isinstance(res, GetPropertiesRes): recordset = compat.getpropertiesres_to_recordset(res) elif isinstance(res, FitRes): - recordset = compat.fitres_to_recordset(res, keep_input=True) + recordset = compat.fitres_to_recordset(res, True) elif isinstance(res, EvaluateRes): recordset = compat.evaluateres_to_recordset(res) else: diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index 01dbcf982cce..1b39a8388a4c 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -64,21 +64,8 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check - has_fields = { - "sa": tasks_ins_res.task.HasField("sa"), - "legacy_server_message": tasks_ins_res.task.HasField( - "legacy_server_message" - ), - } - if not (has_fields["sa"] or has_fields["legacy_server_message"]): - err_msg = ", ".join([f"`{field}`" for field in has_fields]) - validation_errors.append( - f"`task` in `TaskIns` must set at least one of fields {{{err_msg}}}" - ) - if has_fields[ - "legacy_server_message" - ] and not tasks_ins_res.task.legacy_server_message.HasField("msg"): - validation_errors.append("`legacy_server_message` does not set field `msg`") + if not tasks_ins_res.task.HasField("recordset"): + validation_errors.append("`recordset` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) != 0: @@ -115,21 +102,8 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check - has_fields = { - "sa": tasks_ins_res.task.HasField("sa"), - "legacy_client_message": tasks_ins_res.task.HasField( - "legacy_client_message" - ), - } - if not (has_fields["sa"] or has_fields["legacy_client_message"]): - err_msg = ", ".join([f"`{field}`" for field in has_fields]) - validation_errors.append( - f"`task` in `TaskRes` must set at least one of fields {{{err_msg}}}" - ) - if has_fields[ - "legacy_client_message" - ] and not tasks_ins_res.task.legacy_client_message.HasField("msg"): - validation_errors.append("`legacy_client_message` does not set field `msg`") + if not tasks_ins_res.task.HasField("recordset"): + validation_errors.append("`recordset` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) == 0: From 175337ab01fe90b077bb975aae3eff16ab6b18a3 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sat, 27 Jan 2024 11:03:00 +0000 Subject: [PATCH 4/7] keep input when *ins -> recordset --- src/py/flwr/driver/driver_client_proxy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index b691c6416728..e0ff26c035f7 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -75,7 +75,7 @@ def get_parameters( def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes: """Train model parameters on the locally held dataset.""" # Ins to RecordSet - out_recordset = compat.fitins_to_recordset(ins, keep_input=False) + out_recordset = compat.fitins_to_recordset(ins, keep_input=True) # Fetch response in_recordset = self._send_receive_recordset( out_recordset, TASK_TYPE_FIT, timeout @@ -88,7 +88,7 @@ def evaluate( ) -> common.EvaluateRes: """Evaluate model parameters on the locally held dataset.""" # Ins to RecordSet - out_recordset = compat.evaluateins_to_recordset(ins, keep_input=False) + out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True) # Fetch response in_recordset = self._send_receive_recordset( out_recordset, TASK_TYPE_EVALUATE, timeout From d5aef366f9102bd089f3847c60d88f3fab3bb099 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sat, 27 Jan 2024 11:15:53 +0000 Subject: [PATCH 5/7] fix unittests --- src/py/flwr/server/state/state_test.py | 13 +--- src/py/flwr/server/utils/validator_test.py | 82 +++------------------- 2 files changed, 11 insertions(+), 84 deletions(-) diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index 7f9094625765..fa654ff26ab0 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -23,11 +23,8 @@ from uuid import uuid4 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) from flwr.server.state import InMemoryState, SqliteState, State @@ -421,9 +418,7 @@ def create_task_ins( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns() - ), + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task @@ -444,9 +439,7 @@ def create_task_res( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, - legacy_client_message=ClientMessage( - disconnect_res=ClientMessage.DisconnectRes() - ), + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task_res diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index a93e4fb4d457..e6762c1e5ac0 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -19,16 +19,8 @@ from typing import List, Tuple from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import ( # pylint: disable=E0611 - SecureAggregation, - Task, - TaskIns, - TaskRes, -) -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - ServerMessage, -) +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from .validator import validate_task_ins_or_res @@ -45,16 +37,12 @@ def test_task_ins(self) -> None: # Execute & Assert for consumer_node_id, anonymous in valid_ins: - msg = create_task_ins( - consumer_node_id, anonymous, has_legacy_server_message=True - ) + msg = create_task_ins(consumer_node_id, anonymous) val_errors = validate_task_ins_or_res(msg) self.assertFalse(val_errors) for consumer_node_id, anonymous in invalid_ins: - msg = create_task_ins( - consumer_node_id, anonymous, has_legacy_server_message=True - ) + msg = create_task_ins(consumer_node_id, anonymous) val_errors = validate_task_ins_or_res(msg) self.assertTrue(val_errors) @@ -78,61 +66,19 @@ def test_is_valid_task_res(self) -> None: # Execute & Assert for producer_node_id, anonymous, ancestry in valid_res: - msg = create_task_res( - producer_node_id, anonymous, ancestry, has_legacy_client_message=True - ) + msg = create_task_res(producer_node_id, anonymous, ancestry) val_errors = validate_task_ins_or_res(msg) self.assertFalse(val_errors) for producer_node_id, anonymous, ancestry in invalid_res: - msg = create_task_res( - producer_node_id, anonymous, ancestry, has_legacy_client_message=True - ) + msg = create_task_res(producer_node_id, anonymous, ancestry) val_errors = validate_task_ins_or_res(msg) self.assertTrue(val_errors, (producer_node_id, anonymous, ancestry)) - def test_task_ins_secure_aggregation(self) -> None: - """Test is_valid task_ins for Secure Aggregation.""" - # Prepare - # (has_legacy_server_message, has_sa) - valid_ins = [(True, True), (False, True)] - invalid_ins = [(False, False)] - - # Execute & Assert - for has_legacy_server_message, has_sa in valid_ins: - msg = create_task_ins(1, False, has_legacy_server_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertFalse(val_errors) - - for has_legacy_server_message, has_sa in invalid_ins: - msg = create_task_ins(1, False, has_legacy_server_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertTrue(val_errors) - - def test_task_res_secure_aggregation(self) -> None: - """Test is_valid task_res for Secure Aggregation.""" - # Prepare - # (has_legacy_server_message, has_sa) - valid_res = [(True, True), (False, True)] - invalid_res = [(False, False)] - - # Execute & Assert - for has_legacy_client_message, has_sa in valid_res: - msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertFalse(val_errors) - - for has_legacy_client_message, has_sa in invalid_res: - msg = create_task_res(0, True, ["1"], has_legacy_client_message, has_sa) - val_errors = validate_task_ins_or_res(msg) - self.assertTrue(val_errors) - def create_task_ins( consumer_node_id: int, anonymous: bool, - has_legacy_server_message: bool = False, - has_sa: bool = False, delivered_at: str = "", ) -> TaskIns: """Create a TaskIns for testing.""" @@ -148,12 +94,7 @@ def create_task_ins( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, - legacy_server_message=ServerMessage( - reconnect_ins=ServerMessage.ReconnectIns() - ) - if has_legacy_server_message - else None, - sa=SecureAggregation(named_values={}) if has_sa else None, + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task @@ -163,8 +104,6 @@ def create_task_res( producer_node_id: int, anonymous: bool, ancestry: List[str], - has_legacy_client_message: bool = False, - has_sa: bool = False, ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( @@ -175,12 +114,7 @@ def create_task_res( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, - legacy_client_message=ClientMessage( - disconnect_res=ClientMessage.DisconnectRes() - ) - if has_legacy_client_message - else None, - sa=SecureAggregation(named_values={}) if has_sa else None, + recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) return task_res From 931e419747139e138ba08584226405425fdf54d9 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 29 Jan 2024 09:48:16 +0000 Subject: [PATCH 6/7] add task_type checks to validator.py --- .../flwr/driver/driver_client_proxy_test.py | 72 +++++++++---------- src/py/flwr/server/state/state_test.py | 2 + src/py/flwr/server/utils/validator.py | 4 ++ src/py/flwr/server/utils/validator_test.py | 2 + 4 files changed, 44 insertions(+), 36 deletions(-) diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index a76ef1c97cf5..4e9a02a6cbf9 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -24,6 +24,12 @@ import flwr from flwr.common import recordset_compat as compat from flwr.common import serde +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, +) from flwr.common.typing import ( Code, Config, @@ -38,12 +44,7 @@ Status, ) from flwr.driver.driver_client_proxy import DriverClientProxy -from flwr.proto import ( # pylint: disable=E0611 - driver_pb2, - node_pb2, - recordset_pb2, - task_pb2, -) +from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") @@ -51,20 +52,27 @@ CLIENT_STATUS = Status(code=Code.OK, message="OK") -def _make_recordset_proto( +def _make_task( res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes] -) -> recordset_pb2.RecordSet: # pylint: disable=E1101 +) -> task_pb2.Task: # pylint: disable=E1101 if isinstance(res, GetParametersRes): + task_type = TASK_TYPE_GET_PARAMETERS recordset = compat.getparametersres_to_recordset(res, True) elif isinstance(res, GetPropertiesRes): + task_type = TASK_TYPE_GET_PROPERTIES recordset = compat.getpropertiesres_to_recordset(res) elif isinstance(res, FitRes): + task_type = TASK_TYPE_FIT recordset = compat.fitres_to_recordset(res, True) elif isinstance(res, EvaluateRes): + task_type = TASK_TYPE_EVALUATE recordset = compat.evaluateres_to_recordset(res) else: raise ValueError(f"Unsupported type: {type(res)}") - return serde.recordset_to_proto(recordset) + return task_pb2.Task( # pylint: disable=E1101 + task_type=task_type, + recordset=serde.recordset_to_proto(recordset), + ) class DriverClientProxyTestCase(unittest.TestCase): @@ -96,11 +104,9 @@ def test_get_properties(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - recordset=_make_recordset_proto( - GetPropertiesRes( - status=CLIENT_STATUS, properties=CLIENT_PROPERTIES - ) + task=_make_task( + GetPropertiesRes( + status=CLIENT_STATUS, properties=CLIENT_PROPERTIES ) ), ) @@ -136,12 +142,10 @@ def test_get_parameters(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - recordset=_make_recordset_proto( - GetParametersRes( - status=CLIENT_STATUS, - parameters=MESSAGE_PARAMETERS, - ) + task=_make_task( + GetParametersRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, ) ), ) @@ -176,14 +180,12 @@ def test_fit(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - recordset=_make_recordset_proto( - FitRes( - status=CLIENT_STATUS, - parameters=MESSAGE_PARAMETERS, - num_examples=10, - metrics={}, - ) + task=_make_task( + FitRes( + status=CLIENT_STATUS, + parameters=MESSAGE_PARAMETERS, + num_examples=10, + metrics={}, ) ), ) @@ -219,14 +221,12 @@ def test_evaluate(self) -> None: task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", run_id=0, - task=task_pb2.Task( # pylint: disable=E1101 - recordset=_make_recordset_proto( - EvaluateRes( - status=CLIENT_STATUS, - loss=0.0, - num_examples=0, - metrics={}, - ) + task=_make_task( + EvaluateRes( + status=CLIENT_STATUS, + loss=0.0, + num_examples=0, + metrics={}, ) ), ) diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index fa654ff26ab0..95d764792ff3 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -418,6 +418,7 @@ def create_task_ins( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, + task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) @@ -439,6 +440,7 @@ def create_task_res( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, + task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index 1b39a8388a4c..f9b271beafdc 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -64,6 +64,8 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check + if tasks_ins_res.task.task_type == "": + validation_errors.append("`task_type` MUST be set") if not tasks_ins_res.task.HasField("recordset"): validation_errors.append("`recordset` MUST be set") @@ -102,6 +104,8 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("non-anonymous consumer MUST provide a `node_id`") # Content check + if tasks_ins_res.task.task_type == "": + validation_errors.append("`task_type` MUST be set") if not tasks_ins_res.task.HasField("recordset"): validation_errors.append("`recordset` MUST be set") diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index e6762c1e5ac0..8e0849508020 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -94,6 +94,7 @@ def create_task_ins( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), consumer=consumer, + task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) @@ -114,6 +115,7 @@ def create_task_res( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, + task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), ), ) From 8e8e7d73fa21e059112055e4351d4eab8f93076f Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 29 Jan 2024 10:26:50 +0000 Subject: [PATCH 7/7] misc --- src/py/flwr/client/message_handler/message_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index b9abf9bf142b..ea87c35c83f7 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -72,7 +72,7 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: Returns ------- task_res : Optional[TaskRes] - TaskRes to be returned to the server. If None, the client should + TaskRes to be sent back to the server. If None, the client should continue to process messages from the server. sleep_duration : int Number of seconds that the client should disconnect from the server.