Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inject WorkloadState into instantiated client #2632

Merged
merged 12 commits into from
Nov 29, 2023
2 changes: 2 additions & 0 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated. Calling `start_numpy_client` is now deprecated.

- **Introduced `WorkloadState`** ([#2564](https://github.com/adap/flower/pull/2564), [#2632](https://github.com/adap/flower/pull/2632))
jafermarq marked this conversation as resolved.
Show resolved Hide resolved

- **Update Flower Baselines**

- FedProx ([#2210](https://github.com/adap/flower/pull/2210), [#2286](https://github.com/adap/flower/pull/2286), [#2509](https://github.com/adap/flower/pull/2509))
Expand Down
11 changes: 11 additions & 0 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from abc import ABC

from flwr.client.workload_state import WorkloadState
from flwr.common import (
Code,
EvaluateIns,
Expand All @@ -37,6 +38,8 @@
class Client(ABC):
"""Abstract base class for Flower clients."""

state: WorkloadState
danieljanes marked this conversation as resolved.
Show resolved Hide resolved

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.

Expand Down Expand Up @@ -138,6 +141,14 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> WorkloadState:
"""Get the workload state from this client."""
return self.state

def set_state(self, state: WorkloadState) -> None:
"""Apply a workload state to this client."""
self.state = state

def to_client(self) -> Client:
"""Return client (itself)."""
return self
Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,14 @@ def __init__(
def __call__(self, fwd: Fwd) -> Bwd:
"""."""
# Execute the task
task_res = handle(
task_res, state_updated = handle(
client_fn=self.client_fn,
state=fwd.state,
task_ins=fwd.task_ins,
)
return Bwd(
task_res=task_res,
state=WorkloadState(state={}),
state=state_updated,
)


Expand Down
32 changes: 22 additions & 10 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common import serde
from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
Expand Down Expand Up @@ -77,13 +78,17 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:
return None, 0


def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
def handle(
client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns
) -> Tuple[TaskRes, WorkloadState]:
"""Handle incoming TaskIns from the server.

Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : WorkloadState
A dataclass storing the state for the workload being executed by the client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.

Expand All @@ -96,6 +101,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
if server_msg is None:
# Instantiate the client
client = client_fn("-1")
client.set_state(state)
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand All @@ -112,22 +118,24 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
),
)
return task_res
return task_res, client.get_state()
raise NotImplementedError()
client_msg = handle_legacy_message(client_fn, server_msg)
client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg)
task_res = wrap_client_message_in_task_res(client_msg)
return task_res
return task_res, updated_state


def handle_legacy_message(
client_fn: ClientFn, server_msg: ServerMessage
) -> ClientMessage:
client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage
) -> Tuple[ClientMessage, WorkloadState]:
"""Handle incoming messages from the server.

Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : WorkloadState
A dataclass storing the state for the workload being executed by the client.
server_msg: ServerMessage
The message coming from the server, to be processed by the client.

Expand All @@ -144,15 +152,19 @@ def handle_legacy_message(

# Instantiate the client
client = client_fn("-1")
client.set_state(state)
# Execute task
message = None
if field == "get_properties_ins":
return _get_properties(client, server_msg.get_properties_ins)
message = _get_properties(client, server_msg.get_properties_ins)
if field == "get_parameters_ins":
return _get_parameters(client, server_msg.get_parameters_ins)
message = _get_parameters(client, server_msg.get_parameters_ins)
if field == "fit_ins":
return _fit(client, server_msg.fit_ins)
message = _fit(client, server_msg.fit_ins)
if field == "evaluate_ins":
return _evaluate(client, server_msg.evaluate_ins)
message = _evaluate(client, server_msg.evaluate_ins)
if message:
return message, client.get_state()
raise UnknownServerMessage()


Expand Down
13 changes: 11 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from flwr.client import Client
from flwr.client.typing import ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common import (
EvaluateIns,
EvaluateRes,
Expand Down Expand Up @@ -133,7 +134,11 @@ def test_client_without_get_properties() -> None:
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)
task_res, state_updated = handle( # pylint: disable=unused-variable
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
task_ins=task_ins,
)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down Expand Up @@ -197,7 +202,11 @@ def test_client_with_get_properties() -> None:
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)
task_res, updated_state = handle( # pylint: disable=unused-variable
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
task_ins=task_ins,
)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down