From ded72a237d18aedc411abd37a8925d1add531d8f Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 29 Nov 2023 12:06:53 +0000 Subject: [PATCH] Inject `WorkloadState` into instantiated client (#2632) Co-authored-by: Daniel J. Beutel --- doc/source/ref-changelog.md | 2 ++ src/py/flwr/client/client.py | 11 +++++++ src/py/flwr/client/flower.py | 5 +-- .../client/message_handler/message_handler.py | 32 +++++++++++++------ .../message_handler/message_handler_test.py | 13 ++++++-- src/py/flwr/client/numpy_client.py | 23 +++++++++++++ 6 files changed, 72 insertions(+), 14 deletions(-) diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 0cf7aca63941..f73d3babda18 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -50,6 +50,8 @@ We would like to give our special thanks to all the contributors who made the ne - **Add new** `XGB Bagging` **strategy** ([#2611](https://github.com/adap/flower/pull/2611)) +- **Introduce `WorkloadState`** ([#2564](https://github.com/adap/flower/pull/2564), [#2632](https://github.com/adap/flower/pull/2632)) + - **Update Flower Baselines** - FedProx ([#2210](https://github.com/adap/flower/pull/2210), [#2286](https://github.com/adap/flower/pull/2286), [#2509](https://github.com/adap/flower/pull/2509)) diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 03769c6b8bcf..280e0a8ca989 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -19,6 +19,7 @@ from abc import ABC +from flwr.client.workload_state import WorkloadState from flwr.common import ( Code, EvaluateIns, @@ -37,6 +38,8 @@ class Client(ABC): """Abstract base class for Flower clients.""" + state: WorkloadState + def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Return set of client's properties. @@ -138,6 +141,14 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: metrics={}, ) + def get_state(self) -> WorkloadState: + """Get the workload state from this client.""" + return self.state + + def set_state(self, state: WorkloadState) -> None: + """Apply a workload state to this client.""" + self.state = state + def to_client(self) -> Client: """Return client (itself).""" return self diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index 9eeb41887e24..f0d4ce122524 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -79,13 +79,14 @@ def __init__( def __call__(self, fwd: Fwd) -> Bwd: """.""" # Execute the task - task_res = handle( + task_res, state_updated = handle( client_fn=self.client_fn, + state=fwd.state, task_ins=fwd.task_ins, ) return Bwd( task_res=task_res, - state=WorkloadState(state={}), + state=state_updated, ) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 971f2fe56e31..0f3070cfb01a 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -30,6 +30,7 @@ ) from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn +from flwr.client.workload_state import WorkloadState from flwr.common import serde from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage @@ -77,13 +78,17 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: return None, 0 -def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes: +def handle( + client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns +) -> Tuple[TaskRes, WorkloadState]: """Handle incoming TaskIns from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. + state : WorkloadState + A dataclass storing the state for the workload being executed by the client. task_ins: TaskIns The task instruction coming from the server, to be processed by the client. @@ -96,6 +101,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes: if server_msg is None: # Instantiate the client client = client_fn("-1") + client.set_state(state) # Secure Aggregation if task_ins.task.HasField("sa") and isinstance( client, SecureAggregationHandler @@ -112,22 +118,24 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes: sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), ), ) - return task_res + return task_res, client.get_state() raise NotImplementedError() - client_msg = handle_legacy_message(client_fn, server_msg) + client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg) task_res = wrap_client_message_in_task_res(client_msg) - return task_res + return task_res, updated_state def handle_legacy_message( - client_fn: ClientFn, server_msg: ServerMessage -) -> ClientMessage: + client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage +) -> Tuple[ClientMessage, WorkloadState]: """Handle incoming messages from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. + state : WorkloadState + A dataclass storing the state for the workload being executed by the client. server_msg: ServerMessage The message coming from the server, to be processed by the client. @@ -144,15 +152,19 @@ def handle_legacy_message( # Instantiate the client client = client_fn("-1") + client.set_state(state) # Execute task + message = None if field == "get_properties_ins": - return _get_properties(client, server_msg.get_properties_ins) + message = _get_properties(client, server_msg.get_properties_ins) if field == "get_parameters_ins": - return _get_parameters(client, server_msg.get_parameters_ins) + message = _get_parameters(client, server_msg.get_parameters_ins) if field == "fit_ins": - return _fit(client, server_msg.fit_ins) + message = _fit(client, server_msg.fit_ins) if field == "evaluate_ins": - return _evaluate(client, server_msg.evaluate_ins) + message = _evaluate(client, server_msg.evaluate_ins) + if message: + return message, client.get_state() raise UnknownServerMessage() diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 0183f161f873..d7f410d81fc0 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -19,6 +19,7 @@ from flwr.client import Client from flwr.client.typing import ClientFn +from flwr.client.workload_state import WorkloadState from flwr.common import ( EvaluateIns, EvaluateRes, @@ -133,7 +134,11 @@ def test_client_without_get_properties() -> None: disconnect_task_res, actual_sleep_duration = handle_control_message( task_ins=task_ins ) - task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins) + task_res, _ = handle( + client_fn=_get_client_fn(client), + state=WorkloadState(state={}), + task_ins=task_ins, + ) if not task_res.HasField("task"): raise ValueError("Task value not found") @@ -197,7 +202,11 @@ def test_client_with_get_properties() -> None: disconnect_task_res, actual_sleep_duration = handle_control_message( task_ins=task_ins ) - task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins) + task_res, _ = handle( + client_fn=_get_client_fn(client), + state=WorkloadState(state={}), + task_ins=task_ins, + ) if not task_res.HasField("task"): raise ValueError("Task value not found") diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index 883be2f3d554..8b0893ea30aa 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -19,6 +19,7 @@ from typing import Callable, Dict, Tuple from flwr.client.client import Client +from flwr.client.workload_state import WorkloadState from flwr.common import ( Config, NDArrays, @@ -69,6 +70,8 @@ class NumPyClient(ABC): """Abstract base class for Flower clients using NumPy.""" + state: WorkloadState + def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return a client's set of properties. @@ -171,6 +174,14 @@ def evaluate( _ = (self, parameters, config) return 0.0, 0, {} + def get_state(self) -> WorkloadState: + """Get the workload state from this client.""" + return self.state + + def set_state(self, state: WorkloadState) -> None: + """Apply a workload state to this client.""" + self.state = state + def to_client(self) -> Client: """Convert to object to Client type and return it.""" return _wrap_numpy_client(client=self) @@ -267,9 +278,21 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: ) +def _get_state(self: Client) -> WorkloadState: + """Return state of underlying NumPyClient.""" + return self.numpy_client.get_state() # type: ignore + + +def _set_state(self: Client, state: WorkloadState) -> None: + """Apply state to underlying NumPyClient.""" + self.numpy_client.set_state(state) # type: ignore + + def _wrap_numpy_client(client: NumPyClient) -> Client: member_dict: Dict[str, Callable] = { # type: ignore "__init__": _constructor, + "get_state": _get_state, + "set_state": _set_state, } # Add wrapper type methods (if overridden)