diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index a74568b8e418..0b7bc19588d5 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -19,7 +19,7 @@ import time import warnings from logging import INFO -from typing import Optional, Union +from typing import Callable, ContextManager, Optional, Tuple, Union from flwr.client.client import Client from flwr.client.typing import ClientFn @@ -33,6 +33,7 @@ TRANSPORT_TYPES, ) from flwr.common.logger import log +from flwr.proto.task_pb2 import TaskIns, TaskRes from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response @@ -134,44 +135,15 @@ def single_client_factory( client_fn = single_client_factory - # Parse IP address - parsed_address = parse_address(server_address) - if not parsed_address: - sys.exit(f"Server address ({server_address}) cannot be parsed.") - host, port, is_v6 = parsed_address - address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" - - # Set the default transport layer - if transport is None: - transport = TRANSPORT_TYPE_GRPC_BIDI - - # Use either gRPC bidirectional streaming or REST request/response - if transport == TRANSPORT_TYPE_REST: - try: - from .rest_client.connection import http_request_response - except ModuleNotFoundError: - sys.exit(MISSING_EXTRA_REST) - if server_address[:4] != "http": - sys.exit( - "When using the REST API, please provide `https://` or " - "`http://` before the server address (e.g. `http://127.0.0.1:8080`)" - ) - connection = http_request_response - elif transport == TRANSPORT_TYPE_GRPC_RERE: - connection = grpc_request_response - elif transport == TRANSPORT_TYPE_GRPC_BIDI: - connection = grpc_connection - else: - raise ValueError( - f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})" - ) + # Initialize connection context manager + connection, address = _init_connection(transport, server_address) while True: sleep_duration: int = 0 with connection( address, - max_message_length=grpc_max_message_length, - root_certificates=root_certificates, + grpc_max_message_length, + root_certificates, ) as conn: receive, send, create_node, delete_node = conn @@ -285,3 +257,54 @@ def start_numpy_client( root_certificates=root_certificates, transport=transport, ) + + +def _init_connection( + transport: Optional[str], server_address: str +) -> Tuple[ + Callable[ + [str, int, Union[bytes, str, None]], + ContextManager[ + Tuple[ + Callable[[], Optional[TaskIns]], + Callable[[TaskRes], None], + Optional[Callable[[], None]], + Optional[Callable[[], None]], + ] + ], + ], + str, +]: + # Parse IP address + parsed_address = parse_address(server_address) + if not parsed_address: + sys.exit(f"Server address ({server_address}) cannot be parsed.") + host, port, is_v6 = parsed_address + address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + + # Set the default transport layer + if transport is None: + transport = TRANSPORT_TYPE_GRPC_BIDI + + # Use either gRPC bidirectional streaming or REST request/response + if transport == TRANSPORT_TYPE_REST: + try: + from .rest_client.connection import http_request_response + except ModuleNotFoundError: + sys.exit(MISSING_EXTRA_REST) + if server_address[:4] != "http": + sys.exit( + "When using the REST API, please provide `https://` or " + "`http://` before the server address (e.g. `http://127.0.0.1:8080`)" + ) + connection = http_request_response + elif transport == TRANSPORT_TYPE_GRPC_RERE: + connection = grpc_request_response + elif transport == TRANSPORT_TYPE_GRPC_BIDI: + connection = grpc_connection + else: + raise ValueError( + f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})" + ) + + return connection, address