Skip to content

Commit

Permalink
break(framework) Remove setting Context as Client and `NumPyClien…
Browse files Browse the repository at this point in the history
…t` attribute (#4652)
  • Loading branch information
jafermarq authored Dec 17, 2024
1 parent 21e5577 commit ea7c194
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 96 deletions.
22 changes: 13 additions & 9 deletions e2e/e2e-bare/e2e_bare/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import ConfigsRecord, Context
from flwr.common import ConfigsRecord, Context, RecordSet

SUBSET_SIZE = 1000
STATE_VAR = "timestamp"
Expand All @@ -15,23 +15,24 @@

# Define Flower client
class FlowerClient(NumPyClient):
def __init__(self, state: RecordSet):
self.state = state

def get_parameters(self, config):
return model_params

def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
value = str(t_stamp)
if STATE_VAR in self.context.state.configs_records.keys():
value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
if STATE_VAR in self.state.configs_records.keys():
value = self.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.context.state.configs_records[STATE_VAR] = ConfigsRecord(
{STATE_VAR: value}
)
self.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value})

def _retrieve_timestamp_from_state(self):
return self.context.state.configs_records[STATE_VAR][STATE_VAR]
return self.state.configs_records[STATE_VAR][STATE_VAR]

def fit(self, parameters, config):
model_params = parameters
Expand All @@ -52,7 +53,7 @@ def evaluate(self, parameters, config):


def client_fn(context: Context):
return FlowerClient().to_client()
return FlowerClient(context.state).to_client()


app = ClientApp(
Expand All @@ -61,4 +62,7 @@ def client_fn(context: Context):

if __name__ == "__main__":
# Start Flower client
start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient(state=RecordSet()).to_client(),
)
20 changes: 11 additions & 9 deletions e2e/e2e-pytorch/e2e_pytorch/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tqdm import tqdm

from flwr.client import ClientApp, NumPyClient, start_client
from flwr.common import ConfigsRecord, Context
from flwr.common import ConfigsRecord, Context, RecordSet

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
Expand Down Expand Up @@ -90,23 +90,25 @@ def load_data():

# Define Flower client
class FlowerClient(NumPyClient):

def __init__(self, state: RecordSet):
self.state = state

def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
value = str(t_stamp)
if STATE_VAR in self.context.state.configs_records.keys():
value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
if STATE_VAR in self.state.configs_records.keys():
value = self.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.context.state.configs_records[STATE_VAR] = ConfigsRecord(
{STATE_VAR: value}
)
self.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value})

def _retrieve_timestamp_from_state(self):
return self.context.state.configs_records[STATE_VAR][STATE_VAR]
return self.state.configs_records[STATE_VAR][STATE_VAR]

def fit(self, parameters, config):
set_parameters(net, parameters)
Expand Down Expand Up @@ -137,7 +139,7 @@ def set_parameters(model, parameters):


def client_fn(context: Context):
return FlowerClient().to_client()
return FlowerClient(context.state).to_client()


app = ClientApp(
Expand All @@ -149,5 +151,5 @@ def client_fn(context: Context):
# Start Flower client
start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
client=FlowerClient(state=RecordSet()).to_client(),
)
32 changes: 0 additions & 32 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from flwr.common import (
Code,
Context,
EvaluateIns,
EvaluateRes,
FitIns,
Expand All @@ -34,14 +33,11 @@
Parameters,
Status,
)
from flwr.common.logger import warn_deprecated_feature_with_example


class Client(ABC):
"""Abstract base class for Flower clients."""

_context: Context

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.
Expand Down Expand Up @@ -143,34 +139,6 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

@property
def context(self) -> Context:
"""Getter for `Context` client attribute."""
warn_deprecated_feature_with_example(
"Accessing the context via the client's attribute is deprecated.",
example_message="Instead, pass it to the client's "
"constructor in your `client_fn()` which already "
"receives a context object.",
code_example="def client_fn(context: Context) -> Client:\n\n"
"\t\t# Your existing client_fn\n\n"
"\t\t# Pass `context` to the constructor\n"
"\t\treturn FlowerClient(context).to_client()",
)
return self._context

@context.setter
def context(self, context: Context) -> None:
"""Setter for `Context` client attribute."""
self._context = context

def get_context(self) -> Context:
"""Get the run context from this client."""
return self.context

def set_context(self, context: Context) -> None:
"""Apply a run context to this client."""
self.context = context

def to_client(self) -> Client:
"""Return client (itself)."""
return self
Expand Down
2 changes: 0 additions & 2 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ def handle_legacy_message_from_msgtype(
"Please use `NumPyClient.to_client()` method to convert it to `Client`.",
)

client.set_context(context)

message_type = message.metadata.message_type

# Handle GetPropertiesIns
Expand Down
44 changes: 0 additions & 44 deletions src/py/flwr/client/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@
from flwr.client.client import Client
from flwr.common import (
Config,
Context,
NDArrays,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import warn_deprecated_feature_with_example
from flwr.common.typing import (
Code,
EvaluateIns,
Expand Down Expand Up @@ -71,8 +69,6 @@
class NumPyClient(ABC):
"""Abstract base class for Flower clients using NumPy."""

_context: Context

def get_properties(self, config: Config) -> dict[str, Scalar]:
"""Return a client's set of properties.
Expand Down Expand Up @@ -175,34 +171,6 @@ def evaluate(
_ = (self, parameters, config)
return 0.0, 0, {}

@property
def context(self) -> Context:
"""Getter for `Context` client attribute."""
warn_deprecated_feature_with_example(
"Accessing the context via the client's attribute is deprecated.",
example_message="Instead, pass it to the client's "
"constructor in your `client_fn()` which already "
"receives a context object.",
code_example="def client_fn(context: Context) -> Client:\n\n"
"\t\t# Your existing client_fn\n\n"
"\t\t# Pass `context` to the constructor\n"
"\t\treturn FlowerClient(context).to_client()",
)
return self._context

@context.setter
def context(self, context: Context) -> None:
"""Setter for `Context` client attribute."""
self._context = context

def get_context(self) -> Context:
"""Get the run context from this client."""
return self.context

def set_context(self, context: Context) -> None:
"""Apply a run context to this client."""
self.context = context

def to_client(self) -> Client:
"""Convert to object to Client type and return it."""
return _wrap_numpy_client(client=self)
Expand Down Expand Up @@ -299,21 +267,9 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
)


def _get_context(self: Client) -> Context:
"""Return context of underlying NumPyClient."""
return self.numpy_client.get_context() # type: ignore


def _set_context(self: Client, context: Context) -> None:
"""Apply context to underlying NumPyClient."""
self.numpy_client.set_context(context) # type: ignore


def _wrap_numpy_client(client: NumPyClient) -> Client:
member_dict: dict[str, Callable] = { # type: ignore
"__init__": _constructor,
"get_context": _get_context,
"set_context": _set_context,
}

# Add wrapper type methods (if overridden)
Expand Down

0 comments on commit ea7c194

Please sign in to comment.