Skip to content

Commit

Permalink
Refactor start_client (#2548)
Browse files Browse the repository at this point in the history
Co-authored-by: Charles Beauville <[email protected]>
  • Loading branch information
danieljanes and charlesbvll authored Oct 30, 2023
1 parent 3c6c297 commit 95b0d3b
Showing 1 changed file with 57 additions and 34 deletions.
91 changes: 57 additions & 34 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import time
import warnings
from logging import INFO
from typing import Optional, Union
from typing import Callable, ContextManager, Optional, Tuple, Union

from flwr.client.client import Client
from flwr.client.typing import ClientFn
Expand All @@ -33,6 +33,7 @@
TRANSPORT_TYPES,
)
from flwr.common.logger import log
from flwr.proto.task_pb2 import TaskIns, TaskRes

from .grpc_client.connection import grpc_connection
from .grpc_rere_client.connection import grpc_request_response
Expand Down Expand Up @@ -134,44 +135,15 @@ def single_client_factory(

client_fn = single_client_factory

# Parse IP address
parsed_address = parse_address(server_address)
if not parsed_address:
sys.exit(f"Server address ({server_address}) cannot be parsed.")
host, port, is_v6 = parsed_address
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"

# Set the default transport layer
if transport is None:
transport = TRANSPORT_TYPE_GRPC_BIDI

# Use either gRPC bidirectional streaming or REST request/response
if transport == TRANSPORT_TYPE_REST:
try:
from .rest_client.connection import http_request_response
except ModuleNotFoundError:
sys.exit(MISSING_EXTRA_REST)
if server_address[:4] != "http":
sys.exit(
"When using the REST API, please provide `https://` or "
"`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
)
connection = http_request_response
elif transport == TRANSPORT_TYPE_GRPC_RERE:
connection = grpc_request_response
elif transport == TRANSPORT_TYPE_GRPC_BIDI:
connection = grpc_connection
else:
raise ValueError(
f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"
)
# Initialize connection context manager
connection, address = _init_connection(transport, server_address)

while True:
sleep_duration: int = 0
with connection(
address,
max_message_length=grpc_max_message_length,
root_certificates=root_certificates,
grpc_max_message_length,
root_certificates,
) as conn:
receive, send, create_node, delete_node = conn

Expand Down Expand Up @@ -285,3 +257,54 @@ def start_numpy_client(
root_certificates=root_certificates,
transport=transport,
)


def _init_connection(
transport: Optional[str], server_address: str
) -> Tuple[
Callable[
[str, int, Union[bytes, str, None]],
ContextManager[
Tuple[
Callable[[], Optional[TaskIns]],
Callable[[TaskRes], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
]
],
],
str,
]:
# Parse IP address
parsed_address = parse_address(server_address)
if not parsed_address:
sys.exit(f"Server address ({server_address}) cannot be parsed.")
host, port, is_v6 = parsed_address
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"

# Set the default transport layer
if transport is None:
transport = TRANSPORT_TYPE_GRPC_BIDI

# Use either gRPC bidirectional streaming or REST request/response
if transport == TRANSPORT_TYPE_REST:
try:
from .rest_client.connection import http_request_response
except ModuleNotFoundError:
sys.exit(MISSING_EXTRA_REST)
if server_address[:4] != "http":
sys.exit(
"When using the REST API, please provide `https://` or "
"`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
)
connection = http_request_response
elif transport == TRANSPORT_TYPE_GRPC_RERE:
connection = grpc_request_response
elif transport == TRANSPORT_TYPE_GRPC_BIDI:
connection = grpc_connection
else:
raise ValueError(
f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"
)

return connection, address

0 comments on commit 95b0d3b

Please sign in to comment.