Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inject WorkloadState into instantiated client #2632

Merged
merged 12 commits into from
Nov 29, 2023
2 changes: 2 additions & 0 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ We would like to give our special thanks to all the contributors who made the ne

- **Add new** `XGB Bagging` **strategy** ([#2611](https://github.com/adap/flower/pull/2611))

- **Introduce `WorkloadState`** ([#2564](https://github.com/adap/flower/pull/2564), [#2632](https://github.com/adap/flower/pull/2632))

- **Update Flower Baselines**

- FedProx ([#2210](https://github.com/adap/flower/pull/2210), [#2286](https://github.com/adap/flower/pull/2286), [#2509](https://github.com/adap/flower/pull/2509))
Expand Down
11 changes: 11 additions & 0 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from abc import ABC

from flwr.client.workload_state import WorkloadState
from flwr.common import (
Code,
EvaluateIns,
Expand All @@ -37,6 +38,8 @@
class Client(ABC):
"""Abstract base class for Flower clients."""

state: WorkloadState
danieljanes marked this conversation as resolved.
Show resolved Hide resolved

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.

Expand Down Expand Up @@ -138,6 +141,14 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> WorkloadState:
"""Get the workload state from this client."""
return self.state

def set_state(self, state: WorkloadState) -> None:
"""Apply a workload state to this client."""
self.state = state

def to_client(self) -> Client:
"""Return client (itself)."""
return self
Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,14 @@ def __init__(
def __call__(self, fwd: Fwd) -> Bwd:
"""."""
# Execute the task
task_res = handle(
task_res, state_updated = handle(
client_fn=self.client_fn,
state=fwd.state,
task_ins=fwd.task_ins,
)
return Bwd(
task_res=task_res,
state=WorkloadState(state={}),
state=state_updated,
)


Expand Down
32 changes: 22 additions & 10 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common import serde
from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
Expand Down Expand Up @@ -77,13 +78,17 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:
return None, 0


def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
def handle(
client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns
) -> Tuple[TaskRes, WorkloadState]:
"""Handle incoming TaskIns from the server.

Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : WorkloadState
A dataclass storing the state for the workload being executed by the client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.

Expand All @@ -96,6 +101,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
if server_msg is None:
# Instantiate the client
client = client_fn("-1")
client.set_state(state)
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand All @@ -112,22 +118,24 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
),
)
return task_res
return task_res, client.get_state()
raise NotImplementedError()
client_msg = handle_legacy_message(client_fn, server_msg)
client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg)
task_res = wrap_client_message_in_task_res(client_msg)
return task_res
return task_res, updated_state


def handle_legacy_message(
client_fn: ClientFn, server_msg: ServerMessage
) -> ClientMessage:
client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage
) -> Tuple[ClientMessage, WorkloadState]:
"""Handle incoming messages from the server.

Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : WorkloadState
A dataclass storing the state for the workload being executed by the client.
server_msg: ServerMessage
The message coming from the server, to be processed by the client.

Expand All @@ -144,15 +152,19 @@ def handle_legacy_message(

# Instantiate the client
client = client_fn("-1")
client.set_state(state)
# Execute task
message = None
if field == "get_properties_ins":
return _get_properties(client, server_msg.get_properties_ins)
message = _get_properties(client, server_msg.get_properties_ins)
if field == "get_parameters_ins":
return _get_parameters(client, server_msg.get_parameters_ins)
message = _get_parameters(client, server_msg.get_parameters_ins)
if field == "fit_ins":
return _fit(client, server_msg.fit_ins)
message = _fit(client, server_msg.fit_ins)
if field == "evaluate_ins":
return _evaluate(client, server_msg.evaluate_ins)
message = _evaluate(client, server_msg.evaluate_ins)
if message:
return message, client.get_state()
raise UnknownServerMessage()


Expand Down
13 changes: 11 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from flwr.client import Client
from flwr.client.typing import ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common import (
EvaluateIns,
EvaluateRes,
Expand Down Expand Up @@ -133,7 +134,11 @@ def test_client_without_get_properties() -> None:
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
task_ins=task_ins,
)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down Expand Up @@ -197,7 +202,11 @@ def test_client_with_get_properties() -> None:
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
task_ins=task_ins,
)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down
23 changes: 23 additions & 0 deletions src/py/flwr/client/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Callable, Dict, Tuple

from flwr.client.client import Client
from flwr.client.workload_state import WorkloadState
from flwr.common import (
Config,
NDArrays,
Expand Down Expand Up @@ -69,6 +70,8 @@
class NumPyClient(ABC):
"""Abstract base class for Flower clients using NumPy."""

state: WorkloadState

def get_properties(self, config: Config) -> Dict[str, Scalar]:
"""Return a client's set of properties.

Expand Down Expand Up @@ -171,6 +174,14 @@ def evaluate(
_ = (self, parameters, config)
return 0.0, 0, {}

def get_state(self) -> WorkloadState:
"""Get the workload state from this client."""
return self.state

def set_state(self, state: WorkloadState) -> None:
"""Apply a workload state to this client."""
self.state = state

def to_client(self) -> Client:
"""Convert to object to Client type and return it."""
return _wrap_numpy_client(client=self)
Expand Down Expand Up @@ -267,9 +278,21 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
)


def _get_state(self: Client) -> WorkloadState:
"""Return state of underlying NumPyClient."""
return self.numpy_client.get_state() # type: ignore


def _set_state(self: Client, state: WorkloadState) -> None:
"""Apply state to underlying NumPyClient."""
self.numpy_client.set_state(state) # type: ignore


def _wrap_numpy_client(client: NumPyClient) -> Client:
member_dict: Dict[str, Callable] = { # type: ignore
"__init__": _constructor,
"get_state": _get_state,
"set_state": _set_state,
}

# Add wrapper type methods (if overridden)
Expand Down