Skip to content

Commit

Permalink
Unify client types (#2390)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
jafermarq and danieljanes authored Oct 14, 2023
1 parent 4df413f commit bfd5656
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 97 deletions.
6 changes: 3 additions & 3 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

The types of the return values in the docstrings in two methods (`aggregate_fit` and `aggregate_evaluate`) now match the hint types in the code.

- **Unify client API** ([#2303](https://github.com/adap/flower/pull/2303))
- **Unify client API** ([#2303](https://github.com/adap/flower/pull/2303), [#2390](https://github.com/adap/flower/pull/2390), [#2493](https://github.com/adap/flower/pull/2493))

Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated.
Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated. Calling `start_numpy_client` is now deprecated.

- **Update Flower Baselines**

Expand All @@ -28,7 +28,7 @@

- **General updates to the simulation engine** ([#2331](https://github.com/adap/flower/pull/2331), [#2447](https://github.com/adap/flower/pull/2447), [#2448](https://github.com/adap/flower/pull/2448))

- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317),[#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446), [#2493](https://github.com/adap/flower/pull/2493))
- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317),[#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446))

Flower received many improvements under the hood, too many to list here.

Expand Down
4 changes: 0 additions & 4 deletions src/py/flwr/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@
from .app import start_numpy_client as start_numpy_client
from .client import Client as Client
from .numpy_client import NumPyClient as NumPyClient
from .numpy_client_wrapper import to_client as to_client
from .run import run_client as run_client
from .typing import ClientFn as ClientFn
from .typing import ClientLike as ClientLike

__all__ = [
"Client",
"ClientFn",
"ClientLike",
"NumPyClient",
"run_client",
"start_client",
"start_numpy_client",
"to_client",
]
54 changes: 25 additions & 29 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

import sys
import time
import warnings
from logging import INFO
from typing import Callable, Optional, Union
from typing import Optional, Union

from flwr.client.client import Client
from flwr.client.typing import ClientFn, ClientLike
from flwr.client.typing import ClientFn
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.constant import (
Expand All @@ -40,7 +41,7 @@


def _check_actionable_client(
client: Optional[ClientLike], client_fn: Optional[ClientFn]
client: Optional[Client], client_fn: Optional[ClientFn]
) -> None:
if client_fn is None and client is None:
raise Exception("Both `client_fn` and `client` are `None`, but one is required")
Expand All @@ -57,7 +58,7 @@ def start_client(
*,
server_address: str,
client_fn: Optional[ClientFn] = None,
client: Optional[ClientLike] = None,
client: Optional[Client] = None,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[Union[bytes, str]] = None,
transport: Optional[str] = None,
Expand Down Expand Up @@ -124,7 +125,7 @@ class `flwr.client.Client` (default: None)
# Wrap `Client` instance in `client_fn`
def single_client_factory(
cid: str, # pylint: disable=unused-argument
) -> ClientLike:
) -> Client:
if client is None: # Added this to keep mypy happy
raise Exception(
"Both `client_fn` and `client` are `None`, but one is required"
Expand Down Expand Up @@ -209,8 +210,7 @@ def single_client_factory(
def start_numpy_client(
*,
server_address: str,
client_fn: Optional[Callable[[str], NumPyClient]] = None,
client: Optional[NumPyClient] = None,
client: NumPyClient,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[bytes] = None,
transport: Optional[str] = None,
Expand All @@ -223,9 +223,7 @@ def start_numpy_client(
The IPv4 or IPv6 address of the server. If the Flower server runs on
the same machine on port 8080, then `server_address` would be
`"[::]:8080"`.
client_fn : Optional[Callable[[str], NumPyClient]]
A callable that instantiates a NumPyClient. (default: None)
client : Optional[flwr.client.NumPyClient]
client : flwr.client.NumPyClient
An implementation of the abstract base class `flwr.client.NumPyClient`.
grpc_max_message_length : int (default: 536_870_912, this equals 512MB)
The maximum length of gRPC messages that can be exchanged with the
Expand All @@ -248,42 +246,40 @@ def start_numpy_client(
--------
Starting a client with an insecure server connection:
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_numpy_client(
>>> server_address=localhost:8080,
>>> client_fn=client_fn,
>>> client=FlowerClient(),
>>> )
Starting an SSL-enabled gRPC client:
>>> from pathlib import Path
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_numpy_client(
>>> server_address=localhost:8080,
>>> client_fn=client_fn,
>>> client=FlowerClient(),
>>> root_certificates=Path("/crts/root.pem").read_bytes(),
>>> )
"""
# Start
_check_actionable_client(client, client_fn)

wrp_client = client.to_client() if client else None
wrp_clientfn = None
if client_fn:
warnings.warn(
"flwr.client.start_numpy_client() is deprecated and will "
"be removed in a future version of Flower. Instead, pass "
"your client to `flwr.client.start_client()` by calling "
"first the `.to_client()` method as shown below: \n"
"\tflwr.client.start_client(\n"
"\t\tserver_address='<IP>:<PORT>',\n"
"\t\tclient=FlowerClient().to_client()\n"
"\t)",
DeprecationWarning,
stacklevel=2,
)

def convert(cid: str) -> Client:
"""Convert `NumPyClient` to `Client` upon instantiation."""
return client_fn(cid).to_client()
# Calling this function is deprecated. A warning is thrown.
# We first need to convert either the supplied client to `Client.`

wrp_clientfn = convert
wrp_client = client.to_client()

start_client(
server_address=server_address,
client_fn=wrp_clientfn,
client=wrp_client,
grpc_max_message_length=grpc_max_message_length,
root_certificates=root_certificates,
Expand Down
17 changes: 4 additions & 13 deletions src/py/flwr/client/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from typing import Dict, Tuple

from flwr.client import ClientLike, to_client
from flwr.common import (
Config,
EvaluateIns,
Expand Down Expand Up @@ -83,26 +82,18 @@ def evaluate(

def test_to_client_with_client() -> None:
"""Test to_client."""
# Prepare
client_like: ClientLike = PlainClient()

# Execute
actual = to_client(client_like=client_like)
client = PlainClient().to_client()

# Assert
assert isinstance(actual, Client)
assert isinstance(client, Client)


def test_to_client_with_numpyclient() -> None:
"""Test fit_clients."""
# Prepare
client_like: ClientLike = NeedsWrappingClient()

# Execute
actual = to_client(client_like=client_like)
client = NeedsWrappingClient().to_client()

# Assert
assert isinstance(actual, Client)
assert isinstance(client, Client)


def test_start_client_transport_invalid() -> None:
Expand Down
9 changes: 3 additions & 6 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@
get_server_message_from_task_ins,
wrap_client_message_in_task_res,
)
from flwr.client.numpy_client_wrapper import to_client
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn, ClientLike
from flwr.client.typing import ClientFn
from flwr.common import serde
from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
Expand Down Expand Up @@ -64,8 +63,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)
if server_msg is None:
# Instantiate the client
client_like: ClientLike = client_fn("-1")
client = to_client(client_like)
client = client_fn("-1")
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand Down Expand Up @@ -120,8 +118,7 @@ def handle_legacy_message(
return disconnect_msg, sleep_duration, False

# Instantiate the client
client_like: ClientLike = client_fn("-1")
client = to_client(client_like)
client = client_fn("-1")
# Execute task
if field == "get_properties_ins":
return _get_properties(client, server_msg.get_properties_ins), 0, True
Expand Down
27 changes: 0 additions & 27 deletions src/py/flwr/client/numpy_client_wrapper.py

This file was deleted.

6 changes: 2 additions & 4 deletions src/py/flwr/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
# ==============================================================================
"""Custom types for Flower clients."""

from typing import Callable, Union
from typing import Callable

from .client import Client as Client
from .numpy_client import NumPyClient as NumPyClient

ClientLike = Union[Client, NumPyClient]
ClientFn = Callable[[str], ClientLike]
ClientFn = Callable[[str], Client]
2 changes: 1 addition & 1 deletion src/py/flwr/simulation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def start_simulation(
client_fn : ClientFn
A function creating client instances. The function must take a single
`str` argument called `cid`. It should return a single client instance
of type ClientLike. Note that the created client instances are ephemeral
of type Client. Note that the created client instances are ephemeral
and will often be destroyed after a single method invocation. Since client
instances are not long-lived, they should not attempt to carry state over
method invocations. Any state required by the instance (model, dataset,
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/simulation/ray_transport/ray_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
from ray.util.actor_pool import ActorPool

from flwr import common
from flwr.client import Client, ClientFn, to_client
from flwr.client import Client, ClientFn
from flwr.common.logger import log
from flwr.simulation.ray_transport.utils import check_clientfn_returns_client

# All possible returns by a client
ClientRes = Union[
Expand Down Expand Up @@ -65,9 +66,8 @@ def run(
# return also cid which is needed to ensure results
# from the pool are correctly assigned to each ClientProxy
try:
# Instantiate client
client_like = client_fn(cid)
client = to_client(client_like=client_like)
# Instantiate client (check 'Client' type is returned)
client = check_clientfn_returns_client(client_fn(cid))
# Run client job
job_results = job_fn(client)
except Exception as ex:
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import ray

from flwr import common
from flwr.client import Client, ClientFn, ClientLike, to_client
from flwr.client import Client, ClientFn
from flwr.client.client import (
maybe_call_evaluate,
maybe_call_fit,
Expand Down Expand Up @@ -275,5 +275,5 @@ def launch_and_evaluate(

def _create_client(client_fn: ClientFn, cid: str) -> Client:
"""Create a client instance."""
client_like: ClientLike = client_fn(cid)
return to_client(client_like=client_like)
# Materialize client
return client_fn(cid)
6 changes: 3 additions & 3 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def __init__(self, cid: str) -> None:
self.cid = int(cid)


def get_dummy_client(cid: str) -> DummyClient:
"""Return a DummyClient."""
return DummyClient(cid)
def get_dummy_client(cid: str) -> Client:
"""Return a DummyClient converted to Client type."""
return DummyClient(cid).to_client()


# A dummy workload
Expand Down
27 changes: 27 additions & 0 deletions src/py/flwr/simulation/ray_transport/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,20 @@
"""Utilities for Actors in the Virtual Client Engine."""

import traceback
import warnings
from logging import ERROR

from flwr.client import Client
from flwr.common.logger import log

try:
import tensorflow as TF
except ModuleNotFoundError:
TF = None

# Display Deprecation warning once
warnings.filterwarnings("once", category=DeprecationWarning)


def enable_tf_gpu_growth() -> None:
"""Enable GPU memory growth to prevent premature OOM."""
Expand Down Expand Up @@ -55,3 +60,25 @@ def enable_tf_gpu_growth() -> None:
log(ERROR, traceback.format_exc())
log(ERROR, ex)
raise ex


def check_clientfn_returns_client(client: Client) -> Client:
"""Warn once that clients returned in `clinet_fn` should be of type Client.
This is here for backwards compatibility. If a ClientFn is provided returning
a different type of client (e.g. NumPyClient) we'll warn the user but convert
the client internally to `Client` by calling `.to_client()`.
"""
if not isinstance(client, Client):
mssg = (
" Ensure your client is of type `Client`. Please convert it"
" using the `.to_client()` method before returning it"
" in the `client_fn` you pass to `start_simulation`."
" We have applied this conversion on your behalf."
" Not returning a `Client` might trigger an error in future"
" versions of Flower."
)

warnings.warn(mssg, DeprecationWarning, stacklevel=2)
client = client.to_client()
return client

0 comments on commit bfd5656

Please sign in to comment.