Skip to content

Commit

Permalink
Inject WorkloadState into instantiated client (#2632)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
jafermarq and danieljanes authored Nov 29, 2023
1 parent 02bb3ba commit ded72a2
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 14 deletions.
2 changes: 2 additions & 0 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ We would like to give our special thanks to all the contributors who made the ne

- **Add new** `XGB Bagging` **strategy** ([#2611](https://github.com/adap/flower/pull/2611))

- **Introduce `WorkloadState`** ([#2564](https://github.com/adap/flower/pull/2564), [#2632](https://github.com/adap/flower/pull/2632))

- **Update Flower Baselines**

- FedProx ([#2210](https://github.com/adap/flower/pull/2210), [#2286](https://github.com/adap/flower/pull/2286), [#2509](https://github.com/adap/flower/pull/2509))
Expand Down
11 changes: 11 additions & 0 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from abc import ABC

from flwr.client.workload_state import WorkloadState
from flwr.common import (
Code,
EvaluateIns,
Expand All @@ -37,6 +38,8 @@
class Client(ABC):
"""Abstract base class for Flower clients."""

state: WorkloadState

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.
Expand Down Expand Up @@ -138,6 +141,14 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> WorkloadState:
"""Get the workload state from this client."""
return self.state

def set_state(self, state: WorkloadState) -> None:
"""Apply a workload state to this client."""
self.state = state

def to_client(self) -> Client:
"""Return client (itself)."""
return self
Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,14 @@ def __init__(
def __call__(self, fwd: Fwd) -> Bwd:
"""."""
# Execute the task
task_res = handle(
task_res, state_updated = handle(
client_fn=self.client_fn,
state=fwd.state,
task_ins=fwd.task_ins,
)
return Bwd(
task_res=task_res,
state=WorkloadState(state={}),
state=state_updated,
)


Expand Down
32 changes: 22 additions & 10 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common import serde
from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
Expand Down Expand Up @@ -77,13 +78,17 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:
return None, 0


def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
def handle(
client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns
) -> Tuple[TaskRes, WorkloadState]:
"""Handle incoming TaskIns from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : WorkloadState
A dataclass storing the state for the workload being executed by the client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.
Expand All @@ -96,6 +101,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
if server_msg is None:
# Instantiate the client
client = client_fn("-1")
client.set_state(state)
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand All @@ -112,22 +118,24 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
),
)
return task_res
return task_res, client.get_state()
raise NotImplementedError()
client_msg = handle_legacy_message(client_fn, server_msg)
client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg)
task_res = wrap_client_message_in_task_res(client_msg)
return task_res
return task_res, updated_state


def handle_legacy_message(
client_fn: ClientFn, server_msg: ServerMessage
) -> ClientMessage:
client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage
) -> Tuple[ClientMessage, WorkloadState]:
"""Handle incoming messages from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : WorkloadState
A dataclass storing the state for the workload being executed by the client.
server_msg: ServerMessage
The message coming from the server, to be processed by the client.
Expand All @@ -144,15 +152,19 @@ def handle_legacy_message(

# Instantiate the client
client = client_fn("-1")
client.set_state(state)
# Execute task
message = None
if field == "get_properties_ins":
return _get_properties(client, server_msg.get_properties_ins)
message = _get_properties(client, server_msg.get_properties_ins)
if field == "get_parameters_ins":
return _get_parameters(client, server_msg.get_parameters_ins)
message = _get_parameters(client, server_msg.get_parameters_ins)
if field == "fit_ins":
return _fit(client, server_msg.fit_ins)
message = _fit(client, server_msg.fit_ins)
if field == "evaluate_ins":
return _evaluate(client, server_msg.evaluate_ins)
message = _evaluate(client, server_msg.evaluate_ins)
if message:
return message, client.get_state()
raise UnknownServerMessage()


Expand Down
13 changes: 11 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from flwr.client import Client
from flwr.client.typing import ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common import (
EvaluateIns,
EvaluateRes,
Expand Down Expand Up @@ -133,7 +134,11 @@ def test_client_without_get_properties() -> None:
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
task_ins=task_ins,
)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down Expand Up @@ -197,7 +202,11 @@ def test_client_with_get_properties() -> None:
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
task_ins=task_ins,
)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down
23 changes: 23 additions & 0 deletions src/py/flwr/client/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Callable, Dict, Tuple

from flwr.client.client import Client
from flwr.client.workload_state import WorkloadState
from flwr.common import (
Config,
NDArrays,
Expand Down Expand Up @@ -69,6 +70,8 @@
class NumPyClient(ABC):
"""Abstract base class for Flower clients using NumPy."""

state: WorkloadState

def get_properties(self, config: Config) -> Dict[str, Scalar]:
"""Return a client's set of properties.
Expand Down Expand Up @@ -171,6 +174,14 @@ def evaluate(
_ = (self, parameters, config)
return 0.0, 0, {}

def get_state(self) -> WorkloadState:
"""Get the workload state from this client."""
return self.state

def set_state(self, state: WorkloadState) -> None:
"""Apply a workload state to this client."""
self.state = state

def to_client(self) -> Client:
"""Convert to object to Client type and return it."""
return _wrap_numpy_client(client=self)
Expand Down Expand Up @@ -267,9 +278,21 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
)


def _get_state(self: Client) -> WorkloadState:
"""Return state of underlying NumPyClient."""
return self.numpy_client.get_state() # type: ignore


def _set_state(self: Client, state: WorkloadState) -> None:
"""Apply state to underlying NumPyClient."""
self.numpy_client.set_state(state) # type: ignore


def _wrap_numpy_client(client: NumPyClient) -> Client:
member_dict: Dict[str, Callable] = { # type: ignore
"__init__": _constructor,
"get_state": _get_state,
"set_state": _set_state,
}

# Add wrapper type methods (if overridden)
Expand Down

0 comments on commit ded72a2

Please sign in to comment.