From 99f27a460b68b9a9d82ae6c88f7b9ff9ca72281d Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 23 Nov 2023 22:31:28 +0000 Subject: [PATCH 1/8] prep; #TODO for next actions --- src/py/flwr/client/flower.py | 5 ++-- .../client/message_handler/message_handler.py | 28 +++++++++++++------ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index 9eeb41887e24..f0d4ce122524 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -79,13 +79,14 @@ def __init__( def __call__(self, fwd: Fwd) -> Bwd: """.""" # Execute the task - task_res = handle( + task_res, state_updated = handle( client_fn=self.client_fn, + state=fwd.state, task_ins=fwd.task_ins, ) return Bwd( task_res=task_res, - state=WorkloadState(state={}), + state=state_updated, ) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 971f2fe56e31..c18340c14b4a 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -33,6 +33,7 @@ from flwr.common import serde from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage +from flwr.client.workload_state import WorkloadState class UnexpectedServerMessage(Exception): @@ -77,13 +78,15 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: return None, 0 -def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes: +def handle(client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns) -> TaskRes: """Handle incoming TaskIns from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. + state : WorkloadState + A dataclass storing the state for the workload being executed by the client. task_ins: TaskIns The task instruction coming from the server, to be processed by the client. @@ -96,6 +99,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes: if server_msg is None: # Instantiate the client client = client_fn("-1") + # TODO: inject state into client object # Secure Aggregation if task_ins.task.HasField("sa") and isinstance( client, SecureAggregationHandler @@ -112,15 +116,15 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes: sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), ), ) - return task_res + return task_res # TODO: return updated state raise NotImplementedError() - client_msg = handle_legacy_message(client_fn, server_msg) + client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg) task_res = wrap_client_message_in_task_res(client_msg) - return task_res + return task_res, updated_state def handle_legacy_message( - client_fn: ClientFn, server_msg: ServerMessage + client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage ) -> ClientMessage: """Handle incoming messages from the server. @@ -128,6 +132,8 @@ def handle_legacy_message( ---------- client_fn : ClientFn A callable that instantiates a Client. + state : WorkloadState + A dataclass storing the state for the workload being executed by the client. server_msg: ServerMessage The message coming from the server, to be processed by the client. @@ -144,15 +150,19 @@ def handle_legacy_message( # Instantiate the client client = client_fn("-1") + # TODO: inject state into client object # Execute task + message = None if field == "get_properties_ins": - return _get_properties(client, server_msg.get_properties_ins) + message = _get_properties(client, server_msg.get_properties_ins) if field == "get_parameters_ins": - return _get_parameters(client, server_msg.get_parameters_ins) + message = _get_parameters(client, server_msg.get_parameters_ins) if field == "fit_ins": - return _fit(client, server_msg.fit_ins) + message = _fit(client, server_msg.fit_ins) if field == "evaluate_ins": - return _evaluate(client, server_msg.evaluate_ins) + message = _evaluate(client, server_msg.evaluate_ins) + if message: + return message # TODO: return updated state raise UnknownServerMessage() From a7babb74e804fee942f3a6bed4d66383b168c8f6 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 23 Nov 2023 23:02:34 +0000 Subject: [PATCH 2/8] propagating --- src/py/flwr/client/client.py | 3 +++ .../client/message_handler/message_handler.py | 16 +++++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 03769c6b8bcf..d221791b655e 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -19,6 +19,7 @@ from abc import ABC +from flwr.client.workload_state import WorkloadState from flwr.common import ( Code, EvaluateIns, @@ -37,6 +38,8 @@ class Client(ABC): """Abstract base class for Flower clients.""" + state: WorkloadState + def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Return set of client's properties. diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index c18340c14b4a..d240cdcda5d7 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -30,10 +30,10 @@ ) from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn +from flwr.client.workload_state import WorkloadState from flwr.common import serde from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage -from flwr.client.workload_state import WorkloadState class UnexpectedServerMessage(Exception): @@ -78,7 +78,9 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: return None, 0 -def handle(client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns) -> TaskRes: +def handle( + client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns +) -> Tuple[TaskRes, WorkloadState]: """Handle incoming TaskIns from the server. Parameters @@ -99,7 +101,7 @@ def handle(client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns) -> Task if server_msg is None: # Instantiate the client client = client_fn("-1") - # TODO: inject state into client object + client.state = state # TODO: inject state into client object # Secure Aggregation if task_ins.task.HasField("sa") and isinstance( client, SecureAggregationHandler @@ -116,7 +118,7 @@ def handle(client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns) -> Task sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), ), ) - return task_res # TODO: return updated state + return task_res, client.state # TODO: return updated state raise NotImplementedError() client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg) task_res = wrap_client_message_in_task_res(client_msg) @@ -125,7 +127,7 @@ def handle(client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns) -> Task def handle_legacy_message( client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage -) -> ClientMessage: +) -> Tuple[ClientMessage, WorkloadState]: """Handle incoming messages from the server. Parameters @@ -150,7 +152,7 @@ def handle_legacy_message( # Instantiate the client client = client_fn("-1") - # TODO: inject state into client object + client.state = state # TODO: inject state into client object # Execute task message = None if field == "get_properties_ins": @@ -162,7 +164,7 @@ def handle_legacy_message( if field == "evaluate_ins": message = _evaluate(client, server_msg.evaluate_ins) if message: - return message # TODO: return updated state + return message, client.state # TODO: return updated state raise UnknownServerMessage() From bb5644f26e0ececd76d3e2baa9912e8518cc284b Mon Sep 17 00:00:00 2001 From: jafermarq Date: Fri, 24 Nov 2023 09:56:00 +0000 Subject: [PATCH 3/8] tests; removed #TODO; v0 ready? --- .../flwr/client/message_handler/message_handler.py | 8 ++++---- .../client/message_handler/message_handler_test.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index d240cdcda5d7..422d96ec9188 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -101,7 +101,7 @@ def handle( if server_msg is None: # Instantiate the client client = client_fn("-1") - client.state = state # TODO: inject state into client object + client.state = state # Secure Aggregation if task_ins.task.HasField("sa") and isinstance( client, SecureAggregationHandler @@ -118,7 +118,7 @@ def handle( sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), ), ) - return task_res, client.state # TODO: return updated state + return task_res, client.state raise NotImplementedError() client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg) task_res = wrap_client_message_in_task_res(client_msg) @@ -152,7 +152,7 @@ def handle_legacy_message( # Instantiate the client client = client_fn("-1") - client.state = state # TODO: inject state into client object + client.state = state # Execute task message = None if field == "get_properties_ins": @@ -164,7 +164,7 @@ def handle_legacy_message( if field == "evaluate_ins": message = _evaluate(client, server_msg.evaluate_ins) if message: - return message, client.state # TODO: return updated state + return message, client.state raise UnknownServerMessage() diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 0183f161f873..c2d4def0a811 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -19,6 +19,7 @@ from flwr.client import Client from flwr.client.typing import ClientFn +from flwr.client.workload_state import WorkloadState from flwr.common import ( EvaluateIns, EvaluateRes, @@ -133,7 +134,11 @@ def test_client_without_get_properties() -> None: disconnect_task_res, actual_sleep_duration = handle_control_message( task_ins=task_ins ) - task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins) + task_res, state_updated = handle( # pylint: disable=unused-variable + client_fn=_get_client_fn(client), + state=WorkloadState(state={}), + task_ins=task_ins, + ) if not task_res.HasField("task"): raise ValueError("Task value not found") @@ -197,7 +202,11 @@ def test_client_with_get_properties() -> None: disconnect_task_res, actual_sleep_duration = handle_control_message( task_ins=task_ins ) - task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins) + task_res, updated_state = handle( # pylint: disable=unused-variable + client_fn=_get_client_fn(client), + state=WorkloadState(state={}), + task_ins=task_ins, + ) if not task_res.HasField("task"): raise ValueError("Task value not found") From 62b071b655ea4b608b8bd722f6dc2eb14d6281fd Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 27 Nov 2023 12:17:55 +0000 Subject: [PATCH 4/8] w/ set_stat()/get_state() --- src/py/flwr/client/client.py | 8 ++++++++ src/py/flwr/client/message_handler/message_handler.py | 8 ++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index d221791b655e..280e0a8ca989 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -141,6 +141,14 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: metrics={}, ) + def get_state(self) -> WorkloadState: + """Get the workload state from this client.""" + return self.state + + def set_state(self, state: WorkloadState) -> None: + """Apply a workload state to this client.""" + self.state = state + def to_client(self) -> Client: """Return client (itself).""" return self diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 422d96ec9188..0f3070cfb01a 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -101,7 +101,7 @@ def handle( if server_msg is None: # Instantiate the client client = client_fn("-1") - client.state = state + client.set_state(state) # Secure Aggregation if task_ins.task.HasField("sa") and isinstance( client, SecureAggregationHandler @@ -118,7 +118,7 @@ def handle( sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), ), ) - return task_res, client.state + return task_res, client.get_state() raise NotImplementedError() client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg) task_res = wrap_client_message_in_task_res(client_msg) @@ -152,7 +152,7 @@ def handle_legacy_message( # Instantiate the client client = client_fn("-1") - client.state = state + client.set_state(state) # Execute task message = None if field == "get_properties_ins": @@ -164,7 +164,7 @@ def handle_legacy_message( if field == "evaluate_ins": message = _evaluate(client, server_msg.evaluate_ins) if message: - return message, client.state + return message, client.get_state() raise UnknownServerMessage() From 423d70889d7a5b2f5129c4859c3eb1967fde178c Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 27 Nov 2023 12:29:32 +0000 Subject: [PATCH 5/8] in change-log --- doc/source/ref-changelog.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index dd8a6e225401..53e80d155e85 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -22,6 +22,8 @@ Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated. Calling `start_numpy_client` is now deprecated. +- **Introduced `WorkloadState`** ([#2564](https://github.com/adap/flower/pull/2564), [#2632](https://github.com/adap/flower/pull/2632)) + - **Update Flower Baselines** - FedProx ([#2210](https://github.com/adap/flower/pull/2210), [#2286](https://github.com/adap/flower/pull/2286), [#2509](https://github.com/adap/flower/pull/2509)) From 2804d9be592798f6c9af21a2130db30304da83f1 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 27 Nov 2023 13:42:12 +0000 Subject: [PATCH 6/8] Apply suggestions from code review Co-authored-by: Daniel J. Beutel --- doc/source/ref-changelog.md | 2 +- src/py/flwr/client/message_handler/message_handler_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 53e80d155e85..9b9c229901af 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -22,7 +22,7 @@ Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated. Calling `start_numpy_client` is now deprecated. -- **Introduced `WorkloadState`** ([#2564](https://github.com/adap/flower/pull/2564), [#2632](https://github.com/adap/flower/pull/2632)) +- **Introduce `WorkloadState`** ([#2564](https://github.com/adap/flower/pull/2564), [#2632](https://github.com/adap/flower/pull/2632)) - **Update Flower Baselines** diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index c2d4def0a811..d7f410d81fc0 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -134,7 +134,7 @@ def test_client_without_get_properties() -> None: disconnect_task_res, actual_sleep_duration = handle_control_message( task_ins=task_ins ) - task_res, state_updated = handle( # pylint: disable=unused-variable + task_res, _ = handle( client_fn=_get_client_fn(client), state=WorkloadState(state={}), task_ins=task_ins, @@ -202,7 +202,7 @@ def test_client_with_get_properties() -> None: disconnect_task_res, actual_sleep_duration = handle_control_message( task_ins=task_ins ) - task_res, updated_state = handle( # pylint: disable=unused-variable + task_res, _ = handle( client_fn=_get_client_fn(client), state=WorkloadState(state={}), task_ins=task_ins, From 0d710d21d49d395d3b5577b3d9069d30dc65122f Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 27 Nov 2023 14:38:53 +0000 Subject: [PATCH 7/8] works for numpyclient --- src/py/flwr/client/client.py | 2 ++ src/py/flwr/client/numpy_client.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 280e0a8ca989..d1efd6786190 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -143,10 +143,12 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: def get_state(self) -> WorkloadState: """Get the workload state from this client.""" + print("get_state()") return self.state def set_state(self, state: WorkloadState) -> None: """Apply a workload state to this client.""" + print("set_state()") self.state = state def to_client(self) -> Client: diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index 883be2f3d554..8b0893ea30aa 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -19,6 +19,7 @@ from typing import Callable, Dict, Tuple from flwr.client.client import Client +from flwr.client.workload_state import WorkloadState from flwr.common import ( Config, NDArrays, @@ -69,6 +70,8 @@ class NumPyClient(ABC): """Abstract base class for Flower clients using NumPy.""" + state: WorkloadState + def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return a client's set of properties. @@ -171,6 +174,14 @@ def evaluate( _ = (self, parameters, config) return 0.0, 0, {} + def get_state(self) -> WorkloadState: + """Get the workload state from this client.""" + return self.state + + def set_state(self, state: WorkloadState) -> None: + """Apply a workload state to this client.""" + self.state = state + def to_client(self) -> Client: """Convert to object to Client type and return it.""" return _wrap_numpy_client(client=self) @@ -267,9 +278,21 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: ) +def _get_state(self: Client) -> WorkloadState: + """Return state of underlying NumPyClient.""" + return self.numpy_client.get_state() # type: ignore + + +def _set_state(self: Client, state: WorkloadState) -> None: + """Apply state to underlying NumPyClient.""" + self.numpy_client.set_state(state) # type: ignore + + def _wrap_numpy_client(client: NumPyClient) -> Client: member_dict: Dict[str, Callable] = { # type: ignore "__init__": _constructor, + "get_state": _get_state, + "set_state": _set_state, } # Add wrapper type methods (if overridden) From fe6f2855b216cdf76c1e1697d1f8c4d309e7e200 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 27 Nov 2023 14:49:04 +0000 Subject: [PATCH 8/8] remove logs --- src/py/flwr/client/client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index d1efd6786190..280e0a8ca989 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -143,12 +143,10 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: def get_state(self) -> WorkloadState: """Get the workload state from this client.""" - print("get_state()") return self.state def set_state(self, state: WorkloadState) -> None: """Apply a workload state to this client.""" - print("set_state()") self.state = state def to_client(self) -> Client: