From 4df413f8a1696f643fdde27bf7bac8c33b623f56 Mon Sep 17 00:00:00 2001 From: Heng Pan <134433891+panh99@users.noreply.github.com> Date: Fri, 13 Oct 2023 17:22:53 +0100 Subject: [PATCH 1/2] Rename Driver to GrpcDriver (#2511) Co-authored-by: Daniel J. Beutel --- examples/mt-pytorch/driver.py | 4 ++-- examples/secaggplus-mt/driver.py | 4 ++-- src/py/flwr/driver/__init__.py | 4 ++-- src/py/flwr/driver/app.py | 6 +++--- src/py/flwr/driver/driver.py | 20 ++++++++++---------- src/py/flwr/driver/driver_client_proxy.py | 6 ++++-- 6 files changed, 23 insertions(+), 21 deletions(-) diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py index 683a4b2975a6..fed760f021af 100644 --- a/examples/mt-pytorch/driver.py +++ b/examples/mt-pytorch/driver.py @@ -2,7 +2,7 @@ import random import time -from flwr.driver import Driver +from flwr.driver import GrpcDriver from flwr.common import ( ServerMessage, FitIns, @@ -43,7 +43,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index 4e0a53ed1c91..d9f795766f6d 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -6,7 +6,7 @@ from workflows import get_workflow_factory from flwr.common import Metrics, ndarrays_to_parameters -from flwr.driver import Driver +from flwr.driver import GrpcDriver from flwr.proto import driver_pb2, node_pb2, task_pb2 from flwr.server import History @@ -71,7 +71,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/src/py/flwr/driver/__init__.py b/src/py/flwr/driver/__init__.py index 8c100e935f70..17b3a1ea236f 100644 --- a/src/py/flwr/driver/__init__.py +++ b/src/py/flwr/driver/__init__.py @@ -16,9 +16,9 @@ from .app import start_driver -from .driver import Driver +from .driver import GrpcDriver __all__ = [ "start_driver", - "Driver", + "GrpcDriver", ] diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/driver/app.py index eeacfc3d9ede..ca9e7b13084b 100644 --- a/src/py/flwr/driver/app.py +++ b/src/py/flwr/driver/app.py @@ -31,7 +31,7 @@ from flwr.server.server import Server from flwr.server.strategy import Strategy -from .driver import Driver +from .driver import GrpcDriver from .driver_client_proxy import DriverClientProxy DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -112,7 +112,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" # Create the Driver - driver = Driver(driver_service_address=address, certificates=certificates) + driver = GrpcDriver(driver_service_address=address, certificates=certificates) driver.connect() lock = threading.Lock() @@ -157,7 +157,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals def update_client_manager( - driver: Driver, + driver: GrpcDriver, client_manager: ClientManager, lock: threading.Lock, ) -> None: diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index 8e029c1e1be1..4b189b9ce290 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -40,13 +40,13 @@ ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """ [Driver] Error: Not connected. -Call `connect()` on the `Driver` instance before calling any of the other `Driver` -methods. +Call `connect()` on the `GrpcDriver` instance before calling any of the other +`GrpcDriver` methods. """ -class Driver: - """`Driver` provides access to the Driver API.""" +class GrpcDriver: + """`GrpcDriver` provides access to the Driver API/service.""" def __init__( self, @@ -88,7 +88,7 @@ def create_workload(self, req: CreateWorkloadRequest) -> CreateWorkloadResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`Driver` instance not connected") + raise Exception("`GrpcDriver` instance not connected") # Call Driver API res: CreateWorkloadResponse = self.stub.CreateWorkload(request=req) @@ -99,9 +99,9 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`Driver` instance not connected") + raise Exception("`GrpcDriver` instance not connected") - # Call Driver API + # Call gRPC Driver API res: GetNodesResponse = self.stub.GetNodes(request=req) return res @@ -110,9 +110,9 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`Driver` instance not connected") + raise Exception("`GrpcDriver` instance not connected") - # Call Driver API + # Call gRPC Driver API res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) return res @@ -121,7 +121,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`Driver` instance not connected") + raise Exception("`GrpcDriver` instance not connected") # Call Driver API res: PullTaskResResponse = self.stub.PullTaskRes(request=req) diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index b732cf66c220..ae04ec5c1512 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -23,7 +23,7 @@ from flwr.proto import driver_pb2, node_pb2, task_pb2, transport_pb2 from flwr.server.client_proxy import ClientProxy -from .driver import Driver +from .driver import GrpcDriver SLEEP_TIME = 1 @@ -31,7 +31,9 @@ class DriverClientProxy(ClientProxy): """Flower client proxy which delegates work using the Driver API.""" - def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: int): + def __init__( + self, node_id: int, driver: GrpcDriver, anonymous: bool, workload_id: int + ): super().__init__(str(node_id)) self.node_id = node_id self.driver = driver From bfd56560c8ec508004391b72838622dbf7480190 Mon Sep 17 00:00:00 2001 From: Javier Date: Sat, 14 Oct 2023 11:04:34 +0100 Subject: [PATCH 2/2] Unify client types (#2390) Co-authored-by: Daniel J. Beutel --- doc/source/ref-changelog.md | 6 +-- src/py/flwr/client/__init__.py | 4 -- src/py/flwr/client/app.py | 54 +++++++++---------- src/py/flwr/client/app_test.py | 17 ++---- .../client/message_handler/message_handler.py | 9 ++-- src/py/flwr/client/numpy_client_wrapper.py | 27 ---------- src/py/flwr/client/typing.py | 6 +-- src/py/flwr/simulation/app.py | 2 +- .../simulation/ray_transport/ray_actor.py | 8 +-- .../ray_transport/ray_client_proxy.py | 6 +-- .../ray_transport/ray_client_proxy_test.py | 6 +-- src/py/flwr/simulation/ray_transport/utils.py | 27 ++++++++++ 12 files changed, 75 insertions(+), 97 deletions(-) delete mode 100644 src/py/flwr/client/numpy_client_wrapper.py diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 05ad1a64f1a2..d06f889fe4fb 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -12,9 +12,9 @@ The types of the return values in the docstrings in two methods (`aggregate_fit` and `aggregate_evaluate`) now match the hint types in the code. -- **Unify client API** ([#2303](https://github.com/adap/flower/pull/2303)) +- **Unify client API** ([#2303](https://github.com/adap/flower/pull/2303), [#2390](https://github.com/adap/flower/pull/2390), [#2493](https://github.com/adap/flower/pull/2493)) - Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated. + Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated. Calling `start_numpy_client` is now deprecated. - **Update Flower Baselines** @@ -28,7 +28,7 @@ - **General updates to the simulation engine** ([#2331](https://github.com/adap/flower/pull/2331), [#2447](https://github.com/adap/flower/pull/2447), [#2448](https://github.com/adap/flower/pull/2448)) -- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317),[#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446), [#2493](https://github.com/adap/flower/pull/2493)) +- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317),[#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446)) Flower received many improvements under the hood, too many to list here. diff --git a/src/py/flwr/client/__init__.py b/src/py/flwr/client/__init__.py index 4ec1082190f2..56bfadc558c3 100644 --- a/src/py/flwr/client/__init__.py +++ b/src/py/flwr/client/__init__.py @@ -19,18 +19,14 @@ from .app import start_numpy_client as start_numpy_client from .client import Client as Client from .numpy_client import NumPyClient as NumPyClient -from .numpy_client_wrapper import to_client as to_client from .run import run_client as run_client from .typing import ClientFn as ClientFn -from .typing import ClientLike as ClientLike __all__ = [ "Client", "ClientFn", - "ClientLike", "NumPyClient", "run_client", "start_client", "start_numpy_client", - "to_client", ] diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 880e2e80dbef..a74568b8e418 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -17,11 +17,12 @@ import sys import time +import warnings from logging import INFO -from typing import Callable, Optional, Union +from typing import Optional, Union from flwr.client.client import Client -from flwr.client.typing import ClientFn, ClientLike +from flwr.client.typing import ClientFn from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address from flwr.common.constant import ( @@ -40,7 +41,7 @@ def _check_actionable_client( - client: Optional[ClientLike], client_fn: Optional[ClientFn] + client: Optional[Client], client_fn: Optional[ClientFn] ) -> None: if client_fn is None and client is None: raise Exception("Both `client_fn` and `client` are `None`, but one is required") @@ -57,7 +58,7 @@ def start_client( *, server_address: str, client_fn: Optional[ClientFn] = None, - client: Optional[ClientLike] = None, + client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, transport: Optional[str] = None, @@ -124,7 +125,7 @@ class `flwr.client.Client` (default: None) # Wrap `Client` instance in `client_fn` def single_client_factory( cid: str, # pylint: disable=unused-argument - ) -> ClientLike: + ) -> Client: if client is None: # Added this to keep mypy happy raise Exception( "Both `client_fn` and `client` are `None`, but one is required" @@ -209,8 +210,7 @@ def single_client_factory( def start_numpy_client( *, server_address: str, - client_fn: Optional[Callable[[str], NumPyClient]] = None, - client: Optional[NumPyClient] = None, + client: NumPyClient, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[bytes] = None, transport: Optional[str] = None, @@ -223,9 +223,7 @@ def start_numpy_client( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. - client_fn : Optional[Callable[[str], NumPyClient]] - A callable that instantiates a NumPyClient. (default: None) - client : Optional[flwr.client.NumPyClient] + client : flwr.client.NumPyClient An implementation of the abstract base class `flwr.client.NumPyClient`. grpc_max_message_length : int (default: 536_870_912, this equals 512MB) The maximum length of gRPC messages that can be exchanged with the @@ -248,42 +246,40 @@ def start_numpy_client( -------- Starting a client with an insecure server connection: - >>> def client_fn(cid: str): - >>> return FlowerClient() - >>> >>> start_numpy_client( >>> server_address=localhost:8080, - >>> client_fn=client_fn, + >>> client=FlowerClient(), >>> ) Starting an SSL-enabled gRPC client: >>> from pathlib import Path - >>> def client_fn(cid: str): - >>> return FlowerClient() - >>> >>> start_numpy_client( >>> server_address=localhost:8080, - >>> client_fn=client_fn, + >>> client=FlowerClient(), >>> root_certificates=Path("/crts/root.pem").read_bytes(), >>> ) """ - # Start - _check_actionable_client(client, client_fn) - - wrp_client = client.to_client() if client else None - wrp_clientfn = None - if client_fn: + warnings.warn( + "flwr.client.start_numpy_client() is deprecated and will " + "be removed in a future version of Flower. Instead, pass " + "your client to `flwr.client.start_client()` by calling " + "first the `.to_client()` method as shown below: \n" + "\tflwr.client.start_client(\n" + "\t\tserver_address=':',\n" + "\t\tclient=FlowerClient().to_client()\n" + "\t)", + DeprecationWarning, + stacklevel=2, + ) - def convert(cid: str) -> Client: - """Convert `NumPyClient` to `Client` upon instantiation.""" - return client_fn(cid).to_client() + # Calling this function is deprecated. A warning is thrown. + # We first need to convert either the supplied client to `Client.` - wrp_clientfn = convert + wrp_client = client.to_client() start_client( server_address=server_address, - client_fn=wrp_clientfn, client=wrp_client, grpc_max_message_length=grpc_max_message_length, root_certificates=root_certificates, diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py index 6d1df4697a61..7ef6410debad 100644 --- a/src/py/flwr/client/app_test.py +++ b/src/py/flwr/client/app_test.py @@ -17,7 +17,6 @@ from typing import Dict, Tuple -from flwr.client import ClientLike, to_client from flwr.common import ( Config, EvaluateIns, @@ -83,26 +82,18 @@ def evaluate( def test_to_client_with_client() -> None: """Test to_client.""" - # Prepare - client_like: ClientLike = PlainClient() - - # Execute - actual = to_client(client_like=client_like) + client = PlainClient().to_client() # Assert - assert isinstance(actual, Client) + assert isinstance(client, Client) def test_to_client_with_numpyclient() -> None: """Test fit_clients.""" - # Prepare - client_like: ClientLike = NeedsWrappingClient() - - # Execute - actual = to_client(client_like=client_like) + client = NeedsWrappingClient().to_client() # Assert - assert isinstance(actual, Client) + assert isinstance(client, Client) def test_start_client_transport_invalid() -> None: diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index b64158ea3a6c..d2eecb83d71a 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -28,9 +28,8 @@ get_server_message_from_task_ins, wrap_client_message_in_task_res, ) -from flwr.client.numpy_client_wrapper import to_client from flwr.client.secure_aggregation import SecureAggregationHandler -from flwr.client.typing import ClientFn, ClientLike +from flwr.client.typing import ClientFn from flwr.common import serde from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage @@ -64,8 +63,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]: server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) if server_msg is None: # Instantiate the client - client_like: ClientLike = client_fn("-1") - client = to_client(client_like) + client = client_fn("-1") # Secure Aggregation if task_ins.task.HasField("sa") and isinstance( client, SecureAggregationHandler @@ -120,8 +118,7 @@ def handle_legacy_message( return disconnect_msg, sleep_duration, False # Instantiate the client - client_like: ClientLike = client_fn("-1") - client = to_client(client_like) + client = client_fn("-1") # Execute task if field == "get_properties_ins": return _get_properties(client, server_msg.get_properties_ins), 0, True diff --git a/src/py/flwr/client/numpy_client_wrapper.py b/src/py/flwr/client/numpy_client_wrapper.py deleted file mode 100644 index cfdfb6cf607c..000000000000 --- a/src/py/flwr/client/numpy_client_wrapper.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Wrapper for NumPyClient objects.""" - -from flwr.client.typing import ClientLike - -from .client import Client -from .numpy_client import NumPyClient, _wrap_numpy_client - - -def to_client(client_like: ClientLike) -> Client: - """Take any Client-like object and return it as a Client.""" - if isinstance(client_like, NumPyClient): - return _wrap_numpy_client(client=client_like) - return client_like diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 07d1da074f1b..7ee6f069768c 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -14,10 +14,8 @@ # ============================================================================== """Custom types for Flower clients.""" -from typing import Callable, Union +from typing import Callable from .client import Client as Client -from .numpy_client import NumPyClient as NumPyClient -ClientLike = Union[Client, NumPyClient] -ClientFn = Callable[[str], ClientLike] +ClientFn = Callable[[str], Client] diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 6e169493f5c8..0bb9290b6911 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -93,7 +93,7 @@ def start_simulation( client_fn : ClientFn A function creating client instances. The function must take a single `str` argument called `cid`. It should return a single client instance - of type ClientLike. Note that the created client instances are ephemeral + of type Client. Note that the created client instances are ephemeral and will often be destroyed after a single method invocation. Since client instances are not long-lived, they should not attempt to carry state over method invocations. Any state required by the instance (model, dataset, diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 63323f51368a..e6ddfd001ff6 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -26,8 +26,9 @@ from ray.util.actor_pool import ActorPool from flwr import common -from flwr.client import Client, ClientFn, to_client +from flwr.client import Client, ClientFn from flwr.common.logger import log +from flwr.simulation.ray_transport.utils import check_clientfn_returns_client # All possible returns by a client ClientRes = Union[ @@ -65,9 +66,8 @@ def run( # return also cid which is needed to ensure results # from the pool are correctly assigned to each ClientProxy try: - # Instantiate client - client_like = client_fn(cid) - client = to_client(client_like=client_like) + # Instantiate client (check 'Client' type is returned) + client = check_clientfn_returns_client(client_fn(cid)) # Run client job job_results = job_fn(client) except Exception as ex: diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 0365cce073b4..c4fc311b48f4 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -22,7 +22,7 @@ import ray from flwr import common -from flwr.client import Client, ClientFn, ClientLike, to_client +from flwr.client import Client, ClientFn from flwr.client.client import ( maybe_call_evaluate, maybe_call_fit, @@ -275,5 +275,5 @@ def launch_and_evaluate( def _create_client(client_fn: ClientFn, cid: str) -> Client: """Create a client instance.""" - client_like: ClientLike = client_fn(cid) - return to_client(client_like=client_like) + # Materialize client + return client_fn(cid) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 8c0aaee48af0..35a082678058 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -40,9 +40,9 @@ def __init__(self, cid: str) -> None: self.cid = int(cid) -def get_dummy_client(cid: str) -> DummyClient: - """Return a DummyClient.""" - return DummyClient(cid) +def get_dummy_client(cid: str) -> Client: + """Return a DummyClient converted to Client type.""" + return DummyClient(cid).to_client() # A dummy workload diff --git a/src/py/flwr/simulation/ray_transport/utils.py b/src/py/flwr/simulation/ray_transport/utils.py index ec3130bb1044..01e2257b429e 100644 --- a/src/py/flwr/simulation/ray_transport/utils.py +++ b/src/py/flwr/simulation/ray_transport/utils.py @@ -15,8 +15,10 @@ """Utilities for Actors in the Virtual Client Engine.""" import traceback +import warnings from logging import ERROR +from flwr.client import Client from flwr.common.logger import log try: @@ -24,6 +26,9 @@ except ModuleNotFoundError: TF = None +# Display Deprecation warning once +warnings.filterwarnings("once", category=DeprecationWarning) + def enable_tf_gpu_growth() -> None: """Enable GPU memory growth to prevent premature OOM.""" @@ -55,3 +60,25 @@ def enable_tf_gpu_growth() -> None: log(ERROR, traceback.format_exc()) log(ERROR, ex) raise ex + + +def check_clientfn_returns_client(client: Client) -> Client: + """Warn once that clients returned in `clinet_fn` should be of type Client. + + This is here for backwards compatibility. If a ClientFn is provided returning + a different type of client (e.g. NumPyClient) we'll warn the user but convert + the client internally to `Client` by calling `.to_client()`. + """ + if not isinstance(client, Client): + mssg = ( + " Ensure your client is of type `Client`. Please convert it" + " using the `.to_client()` method before returning it" + " in the `client_fn` you pass to `start_simulation`." + " We have applied this conversion on your behalf." + " Not returning a `Client` might trigger an error in future" + " versions of Flower." + ) + + warnings.warn(mssg, DeprecationWarning, stacklevel=2) + client = client.to_client() + return client