Skip to content

Commit

Permalink
Refactor flower-client internals (#2689)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Taner Topal <[email protected]>
  • Loading branch information
danieljanes and tanertopal authored Dec 6, 2023
1 parent 30bc20c commit b7e7bde
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 15 deletions.
88 changes: 74 additions & 14 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from flwr.common.logger import log, warn_experimental_feature
from flwr.proto.task_pb2 import TaskIns, TaskRes

from .flower import load_callable
from .flower import load_flower_callable
from .grpc_client.connection import grpc_connection
from .grpc_rere_client.connection import grpc_request_response
from .message_handler.message_handler import handle_control_message
Expand All @@ -47,6 +47,8 @@

def run_client() -> None:
"""Run Flower client."""
event(EventType.RUN_CLIENT_ENTER)

log(INFO, "Long-running Flower client starting")

args = _parse_args_client().parse_args()
Expand Down Expand Up @@ -80,16 +82,17 @@ def run_client() -> None:
sys.path.insert(0, callable_dir)

def _load() -> Flower:
flower: Flower = load_callable(args.callable)
flower: Flower = load_flower_callable(args.callable)
return flower

return start_client(
_start_client_internal(
server_address=args.server,
load_callable_fn=_load,
load_flower_callable_fn=_load,
transport="grpc-rere", # Only
root_certificates=root_certificates,
insecure=args.insecure,
)
event(EventType.RUN_CLIENT_LEAVE)


def _parse_args_client() -> argparse.ArgumentParser:
Expand Down Expand Up @@ -149,7 +152,6 @@ def _check_actionable_client(
def start_client(
*,
server_address: str,
load_callable_fn: Optional[Callable[[], Flower]] = None,
client_fn: Optional[ClientFn] = None,
client: Optional[Client] = None,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
Expand All @@ -165,8 +167,6 @@ def start_client(
The IPv4 or IPv6 address of the server. If the Flower
server runs on the same machine on port 8080, then `server_address`
would be `"[::]:8080"`.
load_callable_fn : Optional[Callable[[], Flower]] (default: None)
...
client_fn : Optional[ClientFn]
A callable that instantiates a Client. (default: None)
client : Optional[flwr.client.Client]
Expand Down Expand Up @@ -223,11 +223,73 @@ class `flwr.client.Client` (default: None)
>>> )
"""
event(EventType.START_CLIENT_ENTER)
_start_client_internal(
server_address=server_address,
load_flower_callable_fn=None,
client_fn=client_fn,
client=client,
grpc_max_message_length=grpc_max_message_length,
root_certificates=root_certificates,
insecure=insecure,
transport=transport,
)
event(EventType.START_CLIENT_LEAVE)


# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-branches
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
def _start_client_internal(
*,
server_address: str,
load_flower_callable_fn: Optional[Callable[[], Flower]] = None,
client_fn: Optional[ClientFn] = None,
client: Optional[Client] = None,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[Union[bytes, str]] = None,
insecure: Optional[bool] = None,
transport: Optional[str] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.
Parameters
----------
server_address : str
The IPv4 or IPv6 address of the server. If the Flower
server runs on the same machine on port 8080, then `server_address`
would be `"[::]:8080"`.
load_flower_callable_fn : Optional[Callable[[], Flower]] (default: None)
A function that can be used to load a `Flower` callable instance.
client_fn : Optional[ClientFn]
A callable that instantiates a Client. (default: None)
client : Optional[flwr.client.Client]
An implementation of the abstract base
class `flwr.client.Client` (default: None)
grpc_max_message_length : int (default: 536_870_912, this equals 512MB)
The maximum length of gRPC messages that can be exchanged with the
Flower server. The default should be sufficient for most models.
Users who train very large models might need to increase this
value. Note that the Flower server needs to be started with the
same value (see `flwr.server.start_server`), otherwise it will not
know about the increased limit and block larger messages.
root_certificates : Optional[Union[bytes, str]] (default: None)
The PEM-encoded root certificates as a byte string or a path string.
If provided, a secure connection using the certificates will be
established to an SSL-enabled Flower server.
insecure : bool (default: True)
Starts an insecure gRPC connection when True. Enables HTTPS connection
when False, using system certificates if `root_certificates` is None.
transport : Optional[str] (default: None)
Configure the transport layer. Allowed values:
- 'grpc-bidi': gRPC, bidirectional streaming
- 'grpc-rere': gRPC, request-response (experimental)
- 'rest': HTTP (experimental)
"""
if insecure is None:
insecure = root_certificates is None

if load_callable_fn is None:
if load_flower_callable_fn is None:
_check_actionable_client(client, client_fn)

if client_fn is None:
Expand All @@ -246,11 +308,11 @@ def single_client_factory(
def _load_app() -> Flower:
return Flower(client_fn=client_fn)

load_callable_fn = _load_app
load_flower_callable_fn = _load_app
else:
warn_experimental_feature("`load_callable_fn`")
warn_experimental_feature("`load_flower_callable_fn`")

# At this point, only `load_callable_fn` should be used
# At this point, only `load_flower_callable_fn` should be used
# Both `client` and `client_fn` must not be used directly

# Initialize connection context manager
Expand Down Expand Up @@ -284,7 +346,7 @@ def _load_app() -> Flower:
break

# Load app
app: Flower = load_callable_fn()
app: Flower = load_flower_callable_fn()

# Handle task message
fwd_msg: Fwd = Fwd(
Expand All @@ -311,8 +373,6 @@ def _load_app() -> Flower:
)
time.sleep(sleep_duration)

event(EventType.START_CLIENT_LEAVE)


def start_numpy_client(
*,
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class LoadCallableError(Exception):
"""."""


def load_callable(module_attribute_str: str) -> Flower:
def load_flower_callable(module_attribute_str: str) -> Flower:
"""Load the `Flower` object specified in a module attribute string.
The module/attribute string should have the form <module>:<attribute>. Valid
Expand Down
4 changes: 4 additions & 0 deletions src/py/flwr/common/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A
START_DRIVER_ENTER = auto()
START_DRIVER_LEAVE = auto()

# SuperNode: flower-client
RUN_CLIENT_ENTER = auto()
RUN_CLIENT_LEAVE = auto()


# Use the ThreadPoolExecutor with max_workers=1 to have a queue
# and also ensure that telemetry calls are not blocking.
Expand Down

0 comments on commit b7e7bde

Please sign in to comment.