From 875b6c740cc6add24cb6bdb05c3017b429ad094e Mon Sep 17 00:00:00 2001 From: Heng Pan <134433891+panh99@users.noreply.github.com> Date: Tue, 28 Nov 2023 17:44:46 +0000 Subject: [PATCH] Make `flower-client` HTTPS by default (#2636) Co-authored-by: Taner Topal Co-authored-by: Daniel J. Beutel --- e2e/bare-https/client.py | 6 +- e2e/bare/client.py | 5 +- e2e/fastai/client.py | 7 +- e2e/jax/client.py | 6 +- e2e/mxnet/client.py | 6 +- e2e/opacus/client.py | 6 +- e2e/pandas/client.py | 6 +- e2e/pytorch-lightning/client.py | 6 +- e2e/pytorch/client.py | 6 +- e2e/scikit-learn/client.py | 6 +- e2e/strategies/client.py | 9 +++ e2e/tabnet/client.py | 6 +- e2e/tensorflow/client.py | 6 +- e2e/test_driver.sh | 14 ++-- src/py/flwr/client/app.py | 78 +++++++++++++++++-- src/py/flwr/client/grpc_client/connection.py | 2 + .../client/grpc_client/connection_test.py | 2 +- .../client/grpc_rere_client/connection.py | 6 +- src/py/flwr/client/rest_client/connection.py | 1 + src/py/flwr/common/grpc.py | 17 +++- src/py/flwr/driver/grpc_driver.py | 1 + 21 files changed, 169 insertions(+), 33 deletions(-) diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index da04a320b1d2..20a5b4875ddf 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -23,8 +23,11 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": # Start Flower client @@ -32,4 +35,5 @@ def client_fn(cid): server_address="127.0.0.1:8080", client=FlowerClient(), root_certificates=Path("certificates/ca.crt").read_bytes(), + insecure=False, ) diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 5010d1810387..05b997ff4133 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -24,8 +24,11 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": # Start Flower client diff --git a/e2e/fastai/client.py b/e2e/fastai/client.py index 0f83cc330c45..4425fed25277 100644 --- a/e2e/fastai/client.py +++ b/e2e/fastai/client.py @@ -50,7 +50,12 @@ def evaluate(self, parameters, config): def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() + + +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": diff --git a/e2e/jax/client.py b/e2e/jax/client.py index 466a829575f6..495d6a671981 100644 --- a/e2e/jax/client.py +++ b/e2e/jax/client.py @@ -51,7 +51,11 @@ def evaluate( return float(loss), num_examples, {"loss": float(loss)} def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() + +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": # Start Flower client diff --git a/e2e/mxnet/client.py b/e2e/mxnet/client.py index 1907d47f7c53..2f0b714e708c 100644 --- a/e2e/mxnet/client.py +++ b/e2e/mxnet/client.py @@ -130,7 +130,11 @@ def evaluate(self, parameters, config): def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() + +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": # Start Flower client diff --git a/e2e/opacus/client.py b/e2e/opacus/client.py index 552060916154..2e5c363381fa 100644 --- a/e2e/opacus/client.py +++ b/e2e/opacus/client.py @@ -135,7 +135,11 @@ def evaluate(self, parameters, config): def client_fn(cid): model = Net() - return FlowerClient(model) + return FlowerClient(model).to_client() + +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": fl.client.start_numpy_client( diff --git a/e2e/pandas/client.py b/e2e/pandas/client.py index f7ff6fc2bccb..5b8670091cb3 100644 --- a/e2e/pandas/client.py +++ b/e2e/pandas/client.py @@ -34,7 +34,11 @@ def fit( ) def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() + +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": # Start Flower client diff --git a/e2e/pytorch-lightning/client.py b/e2e/pytorch-lightning/client.py index e05caf0b93f4..71b178eca8c3 100644 --- a/e2e/pytorch-lightning/client.py +++ b/e2e/pytorch-lightning/client.py @@ -53,7 +53,11 @@ def client_fn(cid): train_loader, val_loader, test_loader = mnist.load_data() # Flower client - return FlowerClient(model, train_loader, val_loader, test_loader) + return FlowerClient(model, train_loader, val_loader, test_loader).to_client() + +flower = fl.flower.Flower( + client_fn=client_fn, +) def main() -> None: # Model and data diff --git a/e2e/pytorch/client.py b/e2e/pytorch/client.py index ae6c40c329ac..f4e7e0300a06 100644 --- a/e2e/pytorch/client.py +++ b/e2e/pytorch/client.py @@ -107,7 +107,11 @@ def set_parameters(model, parameters): return def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() + +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": diff --git a/e2e/scikit-learn/client.py b/e2e/scikit-learn/client.py index 1f2f0291c1ec..fdca96c1697a 100644 --- a/e2e/scikit-learn/client.py +++ b/e2e/scikit-learn/client.py @@ -44,7 +44,11 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() + +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": # Start Flower client diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index de321658c40f..eb4598cb5439 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -43,6 +43,15 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} +def client_fn(cid): + return FlowerClient().to_client() + + +flower = fl.flower.Flower( + client_fn=client_fn, +) + + if __name__ == "__main__": # Start Flower client fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) diff --git a/e2e/tabnet/client.py b/e2e/tabnet/client.py index 58982543b8bb..3c10df0c79f1 100644 --- a/e2e/tabnet/client.py +++ b/e2e/tabnet/client.py @@ -79,7 +79,11 @@ def evaluate(self, parameters, config): def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() + +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": # Start Flower client diff --git a/e2e/tensorflow/client.py b/e2e/tensorflow/client.py index fe5b2a3351fc..4ad2d5ebda57 100644 --- a/e2e/tensorflow/client.py +++ b/e2e/tensorflow/client.py @@ -32,7 +32,11 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} def client_fn(cid): - return FlowerClient() + return FlowerClient().to_client() + +flower = fl.flower.Flower( + client_fn=client_fn, +) if __name__ == "__main__": # Start Flower client diff --git a/e2e/test_driver.sh b/e2e/test_driver.sh index e5baac20fa1f..ca54dbf4852f 100755 --- a/e2e/test_driver.sh +++ b/e2e/test_driver.sh @@ -4,20 +4,22 @@ set -e case "$1" in bare-https) ./generate.sh - cert_arg="--certificates certificates/ca.crt certificates/server.pem certificates/server.key" + server_arg="--certificates certificates/ca.crt certificates/server.pem certificates/server.key" + client_arg="--root-certificates certificates/ca.crt" ;; *) - cert_arg="--insecure" + server_arg="--insecure" + client_arg="--insecure" ;; esac -timeout 2m flower-server $cert_arg --grpc-bidi --grpc-bidi-fleet-api-address 0.0.0.0:8080 & +timeout 2m flower-server $server_arg & sleep 3 -python client.py & +timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 & sleep 3 -python client.py & +timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 & sleep 3 timeout 2m python driver.py & @@ -27,7 +29,7 @@ wait $pid res=$? if [[ "$res" = "0" ]]; - then echo "Training worked correctly" && pkill python; + then echo "Training worked correctly" && pkill flower-client && pkill flower-server; else echo "Training had an issue" && exit 1; fi diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index b39dbbfc33c0..81bbee148c95 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -18,7 +18,8 @@ import argparse import sys import time -from logging import INFO +from logging import INFO, WARN +from pathlib import Path from typing import Callable, ContextManager, Optional, Tuple, Union from flwr.client.client import Client @@ -50,6 +51,26 @@ def run_client() -> None: args = _parse_args_client().parse_args() + # Obtain certificates + if args.insecure: + if args.root_certificates is not None: + sys.exit( + "Conflicting options: The '--insecure' flag disables HTTPS, " + "but '--root-certificates' was also specified. Please remove " + "the '--root-certificates' option when running in insecure mode, " + "or omit '--insecure' to use HTTPS." + ) + log(WARN, "Option `--insecure` was set. Starting insecure HTTP client.") + root_certificates = None + else: + # Load the certificates if provided, or load the system certificates + cert_path = args.root_certificates + if cert_path is None: + root_certificates = None + else: + root_certificates = Path(cert_path).read_bytes() + + print(args.root_certificates) print(args.server) print(args.callable_dir) print(args.callable) @@ -66,6 +87,8 @@ def _load() -> Flower: server_address=args.server, load_callable_fn=_load, transport="grpc-rere", # Only + root_certificates=root_certificates, + insecure=args.insecure, ) @@ -75,6 +98,19 @@ def _parse_args_client() -> argparse.ArgumentParser: description="Start a long-running Flower client", ) + parser.add_argument( + "--insecure", + action="store_true", + help="Run the client without HTTPS. By default, the client runs with " + "HTTPS enabled. Use this flag only if you understand the risks.", + ) + parser.add_argument( + "--root-certificates", + metavar="ROOT_CERT", + type=str, + help="Specifies the path to the PEM-encoded root certificate file for " + "establishing secure HTTPS connections.", + ) parser.add_argument( "--server", default="0.0.0.0:9092", @@ -118,6 +154,7 @@ def start_client( client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, + insecure: Optional[bool] = None, transport: Optional[str] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -146,6 +183,9 @@ class `flwr.client.Client` (default: None) The PEM-encoded root certificates as a byte string or a path string. If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. + insecure : bool (default: True) + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. transport : Optional[str] (default: None) Configure the transport layer. Allowed values: - 'grpc-bidi': gRPC, bidirectional streaming @@ -156,19 +196,25 @@ class `flwr.client.Client` (default: None) -------- Starting a gRPC client with an insecure server connection: + >>> start_client( + >>> server_address=localhost:8080, + >>> client_fn=client_fn, + >>> ) + + Starting an SSL-enabled gRPC client using system certificates: + >>> def client_fn(cid: str): >>> return FlowerClient() >>> >>> start_client( >>> server_address=localhost:8080, >>> client_fn=client_fn, + >>> insecure=False, >>> ) - Starting an SSL-enabled gRPC client: + Starting an SSL-enabled gRPC client using provided certificates: >>> from pathlib import Path - >>> def client_fn(cid: str): - >>> return FlowerClient() >>> >>> start_client( >>> server_address=localhost:8080, @@ -178,6 +224,9 @@ class `flwr.client.Client` (default: None) """ event(EventType.START_CLIENT_ENTER) + if insecure is None: + insecure = root_certificates is None + if load_callable_fn is None: _check_actionable_client(client, client_fn) @@ -211,6 +260,7 @@ def _load_app() -> Flower: sleep_duration: int = 0 with connection( address, + insecure, grpc_max_message_length, root_certificates, ) as conn: @@ -270,6 +320,7 @@ def start_numpy_client( client: NumPyClient, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[bytes] = None, + insecure: Optional[bool] = None, transport: Optional[str] = None, ) -> None: """Start a Flower NumPyClient which connects to a gRPC server. @@ -293,6 +344,9 @@ def start_numpy_client( The PEM-encoded root certificates as a byte string or a path string. If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. + insecure : Optional[bool] (default: None) + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. transport : Optional[str] (default: None) Configure the transport layer. Allowed values: - 'grpc-bidi': gRPC, bidirectional streaming @@ -301,16 +355,25 @@ def start_numpy_client( Examples -------- - Starting a client with an insecure server connection: + Starting a gRPC client with an insecure server connection: + + >>> start_numpy_client( + >>> server_address=localhost:8080, + >>> client=FlowerClient(), + >>> ) + + Starting an SSL-enabled gRPC client using system certificates: >>> start_numpy_client( >>> server_address=localhost:8080, >>> client=FlowerClient(), + >>> insecure=False, >>> ) - Starting an SSL-enabled gRPC client: + Starting an SSL-enabled gRPC client using provided certificates: >>> from pathlib import Path + >>> >>> start_numpy_client( >>> server_address=localhost:8080, >>> client=FlowerClient(), @@ -340,6 +403,7 @@ def start_numpy_client( client=wrp_client, grpc_max_message_length=grpc_max_message_length, root_certificates=root_certificates, + insecure=insecure, transport=transport, ) @@ -348,7 +412,7 @@ def _init_connection( transport: Optional[str], server_address: str ) -> Tuple[ Callable[ - [str, int, Union[bytes, str, None]], + [str, bool, int, Union[bytes, str, None]], ContextManager[ Tuple[ Callable[[], Optional[TaskIns]], diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index cbef4ef99051..335d28e72828 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -45,6 +45,7 @@ def on_channel_state_change(channel_connectivity: str) -> None: @contextmanager def grpc_connection( server_address: str, + insecure: bool, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, ) -> Iterator[ @@ -100,6 +101,7 @@ def grpc_connection( channel = create_channel( server_address=server_address, + insecure=insecure, root_certificates=root_certificates, max_message_length=max_message_length, ) diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index 0485fa41db35..e5944230e5af 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -93,7 +93,7 @@ def test_integration_connection() -> None: def run_client() -> int: messages_received: int = 0 - with grpc_connection(server_address=f"[::]:{port}") as conn: + with grpc_connection(server_address=f"[::]:{port}", insecure=True) as conn: receive, send, _, _ = conn # Setup processing loop diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 424e413dc484..30d407a52c53 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -51,10 +51,9 @@ def on_channel_state_change(channel_connectivity: str) -> None: @contextmanager def grpc_request_response( server_address: str, + insecure: bool, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 - root_certificates: Optional[ - Union[bytes, str] - ] = None, # pylint: disable=unused-argument + root_certificates: Optional[Union[bytes, str]] = None, ) -> Iterator[ Tuple[ Callable[[], Optional[TaskIns]], @@ -95,6 +94,7 @@ def grpc_request_response( channel = create_channel( server_address=server_address, + insecure=insecure, root_certificates=root_certificates, max_message_length=max_message_length, ) diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 092e543bf55b..d22b246dbd61 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -61,6 +61,7 @@ # pylint: disable-next=too-many-statements def http_request_response( server_address: str, + insecure: bool, # pylint: disable=unused-argument max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[ Union[bytes, str] diff --git a/src/py/flwr/common/grpc.py b/src/py/flwr/common/grpc.py index 2857048f62a0..9d0543ea8c75 100644 --- a/src/py/flwr/common/grpc.py +++ b/src/py/flwr/common/grpc.py @@ -27,10 +27,19 @@ def create_channel( server_address: str, + insecure: bool, root_certificates: Optional[bytes] = None, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, ) -> grpc.Channel: """Create a gRPC channel, either secure or insecure.""" + # Check for conflicting parameters + if insecure and root_certificates is not None: + raise ValueError( + "Invalid configuration: 'root_certificates' should not be provided " + "when 'insecure' is set to True. For an insecure connection, omit " + "'root_certificates', or set 'insecure' to False for a secure connection." + ) + # Possible options: # https://github.com/grpc/grpc/blob/v1.43.x/include/grpc/impl/codegen/grpc_types.h channel_options = [ @@ -38,14 +47,14 @@ def create_channel( ("grpc.max_receive_message_length", max_message_length), ] - if root_certificates is not None: + if insecure: + channel = grpc.insecure_channel(server_address, options=channel_options) + log(INFO, "Opened insecure gRPC connection (no certificates were passed)") + else: ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates) channel = grpc.secure_channel( server_address, ssl_channel_credentials, options=channel_options ) log(INFO, "Opened secure gRPC connection using certificates") - else: - channel = grpc.insecure_channel(server_address, options=channel_options) - log(INFO, "Opened insecure gRPC connection (no certificates were passed)") return channel diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py index a25de6f9f666..7dd0a0f501c5 100644 --- a/src/py/flwr/driver/grpc_driver.py +++ b/src/py/flwr/driver/grpc_driver.py @@ -66,6 +66,7 @@ def connect(self) -> None: return self.channel = create_channel( server_address=self.driver_service_address, + insecure=(self.certificates is None), root_certificates=self.certificates, ) self.stub = DriverStub(self.channel)