diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 6189794e8f69..a5121ad71b38 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -60,6 +60,8 @@ jobs: include: - directory: bare + - directory: bare-https + - directory: jax - directory: pytorch @@ -135,7 +137,7 @@ jobs: - name: Run virtual client test run: python simulation.py - name: Run driver test - run: ./../test_driver.sh + run: ./../test_driver.sh "${{ matrix.directory }}" strategies: runs-on: ubuntu-22.04 diff --git a/e2e/bare-https/README.md b/e2e/bare-https/README.md new file mode 100644 index 000000000000..2b7fb953a24b --- /dev/null +++ b/e2e/bare-https/README.md @@ -0,0 +1,3 @@ +# Bare Flower testing + +This directory is used for testing Flower in a bare minimum scenario, that is, with a dummy model and dummy operations. This is mainly to test the core functionnality of Flower independently from any framework. It can easily be extendended to test more complex communication set-ups. diff --git a/e2e/bare-https/certificate.conf b/e2e/bare-https/certificate.conf new file mode 100644 index 000000000000..ea97fcbb700d --- /dev/null +++ b/e2e/bare-https/certificate.conf @@ -0,0 +1,20 @@ +[req] +default_bits = 4096 +prompt = no +default_md = sha256 +req_extensions = req_ext +distinguished_name = dn + +[dn] +C = DE +ST = HH +O = Flower +CN = localhost + +[req_ext] +subjectAltName = @alt_names + +[alt_names] +DNS.1 = localhost +IP.1 = ::1 +IP.2 = 127.0.0.1 diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py new file mode 100644 index 000000000000..da04a320b1d2 --- /dev/null +++ b/e2e/bare-https/client.py @@ -0,0 +1,35 @@ +import flwr as fl +import numpy as np +from pathlib import Path + + +model_params = np.array([1]) +objective = 5 + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config): + return model_params + + def fit(self, parameters, config): + model_params = parameters + model_params = [param * (objective/np.mean(param)) for param in model_params] + return model_params, 1, {} + + def evaluate(self, parameters, config): + model_params = parameters + loss = min(np.abs(1 - np.mean(model_params)/objective), 1) + accuracy = 1 - loss + return loss, 1, {"accuracy": accuracy} + +def client_fn(cid): + return FlowerClient() + + +if __name__ == "__main__": + # Start Flower client + fl.client.start_numpy_client( + server_address="127.0.0.1:8080", + client=FlowerClient(), + root_certificates=Path("certificates/ca.crt").read_bytes(), + ) diff --git a/e2e/bare-https/driver.py b/e2e/bare-https/driver.py new file mode 100644 index 000000000000..5c44e4c641ae --- /dev/null +++ b/e2e/bare-https/driver.py @@ -0,0 +1,12 @@ +import flwr as fl +from pathlib import Path + + +# Start Flower server +hist = fl.driver.start_driver( + server_address="127.0.0.1:9091", + config=fl.server.ServerConfig(num_rounds=3), + root_certificates=Path("certificates/ca.crt").read_bytes(), +) + +assert hist.losses_distributed[-1][1] == 0 diff --git a/e2e/bare-https/generate.sh b/e2e/bare-https/generate.sh new file mode 100755 index 000000000000..72362656e756 --- /dev/null +++ b/e2e/bare-https/generate.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# This script will generate all certificates if ca.crt does not exist + +set -e +# Change directory to the script's directory +cd "$(dirname "${BASH_SOURCE[0]}")" + +CA_PASSWORD=notsafe + +CERT_DIR=certificates + +# Generate directories if not exists +mkdir -p $CERT_DIR + +# Uncomment the below block if you want to skip certificate generation if they already exist. +# if [ -f "$CERT_DIR/ca.crt" ]; then +# echo "Skipping certificate generation as they already exist." +# exit 0 +# fi + +# Clearing any existing files in the certificates directory +rm -f $CERT_DIR/* + +# Generate the root certificate authority key and certificate based on key +openssl genrsa -out $CERT_DIR/ca.key 4096 +openssl req \ + -new \ + -x509 \ + -key $CERT_DIR/ca.key \ + -sha256 \ + -subj "/C=DE/ST=HH/O=CA, Inc." \ + -days 365 -out $CERT_DIR/ca.crt + +# Generate a new private key for the server +openssl genrsa -out $CERT_DIR/server.key 4096 + +# Create a signing CSR +openssl req \ + -new \ + -key $CERT_DIR/server.key \ + -out $CERT_DIR/server.csr \ + -config certificate.conf + +# Generate a certificate for the server +openssl x509 \ + -req \ + -in $CERT_DIR/server.csr \ + -CA $CERT_DIR/ca.crt \ + -CAkey $CERT_DIR/ca.key \ + -CAcreateserial \ + -out $CERT_DIR/server.pem \ + -days 365 \ + -sha256 \ + -extfile certificate.conf \ + -extensions req_ext diff --git a/e2e/bare-https/pyproject.toml b/e2e/bare-https/pyproject.toml new file mode 100644 index 000000000000..9489a43195f9 --- /dev/null +++ b/e2e/bare-https/pyproject.toml @@ -0,0 +1,13 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "bare_https_test" +version = "0.1.0" +description = "HTTPS-enabled bare Federated Learning test with Flower" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = "^3.8" +flwr = { path = "../../", develop = true } diff --git a/e2e/bare-https/server.py b/e2e/bare-https/server.py new file mode 100644 index 000000000000..fcad7a3e4522 --- /dev/null +++ b/e2e/bare-https/server.py @@ -0,0 +1,15 @@ +import flwr as fl +from pathlib import Path + + +hist = fl.server.start_server( + server_address="127.0.0.1:8080", + config=fl.server.ServerConfig(num_rounds=3), + certificates=( + Path("certificates/ca.crt").read_bytes(), + Path("certificates/server.pem").read_bytes(), + Path("certificates/server.key").read_bytes(), + ) +) + +assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98 diff --git a/e2e/bare-https/simulation.py b/e2e/bare-https/simulation.py new file mode 100644 index 000000000000..b7268c98dcbc --- /dev/null +++ b/e2e/bare-https/simulation.py @@ -0,0 +1 @@ +# No simulation test for bare-https diff --git a/e2e/test.sh b/e2e/test.sh index c1f8d3177113..4ea17a4f994b 100755 --- a/e2e/test.sh +++ b/e2e/test.sh @@ -5,6 +5,10 @@ case "$1" in pandas) server_file="server.py" ;; + bare-https) + ./generate.sh + server_file="server.py" + ;; *) server_file="../server.py" ;; diff --git a/e2e/test_driver.sh b/e2e/test_driver.sh index 3ca95e90d321..e5baac20fa1f 100755 --- a/e2e/test_driver.sh +++ b/e2e/test_driver.sh @@ -1,7 +1,17 @@ #!/bin/bash set -e -timeout 2m flower-server --grpc-bidi --grpc-bidi-fleet-api-address 0.0.0.0:8080 & +case "$1" in + bare-https) + ./generate.sh + cert_arg="--certificates certificates/ca.crt certificates/server.pem certificates/server.key" + ;; + *) + cert_arg="--insecure" + ;; +esac + +timeout 2m flower-server $cert_arg --grpc-bidi --grpc-bidi-fleet-api-address 0.0.0.0:8080 & sleep 3 python client.py & diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 8d1317aba403..63c24c37a685 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -22,6 +22,7 @@ from dataclasses import dataclass from logging import ERROR, INFO, WARN from os.path import isfile +from pathlib import Path from signal import SIGINT, SIGTERM, signal from types import FrameType from typing import List, Optional, Tuple @@ -247,6 +248,9 @@ def run_driver_api() -> None: host, port, is_v6 = parsed_address address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + # Obtain certificates + certificates = _try_obtain_certificates(args) + # Initialize StateFactory state_factory = StateFactory(args.database) @@ -254,6 +258,7 @@ def run_driver_api() -> None: grpc_server: grpc.Server = _run_driver_api_grpc( address=address, state_factory=state_factory, + certificates=certificates, ) # Graceful shutdown @@ -273,6 +278,9 @@ def run_fleet_api() -> None: event(EventType.RUN_FLEET_API_ENTER) args = _parse_args_fleet().parse_args() + # Obtain certificates + certificates = _try_obtain_certificates(args) + # Initialize StateFactory state_factory = StateFactory(args.database) @@ -315,6 +323,7 @@ def run_fleet_api() -> None: fleet_server = _run_fleet_api_grpc_bidi( address=address, state_factory=state_factory, + certificates=certificates, ) grpc_servers.append(fleet_server) elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: @@ -327,6 +336,7 @@ def run_fleet_api() -> None: fleet_server = _run_fleet_api_grpc_rere( address=address, state_factory=state_factory, + certificates=certificates, ) grpc_servers.append(fleet_server) else: @@ -346,7 +356,7 @@ def run_fleet_api() -> None: bckg_threads[0].join() -# pylint: disable=too-many-branches +# pylint: disable=too-many-branches, too-many-locals, too-many-statements def run_server() -> None: """Run Flower server (Driver API and Fleet API).""" log(INFO, "Starting Flower server") @@ -360,6 +370,9 @@ def run_server() -> None: host, port, is_v6 = parsed_address address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + # Obtain certificates + certificates = _try_obtain_certificates(args) + # Initialize StateFactory state_factory = StateFactory(args.database) @@ -367,6 +380,7 @@ def run_server() -> None: driver_server: grpc.Server = _run_driver_api_grpc( address=address, state_factory=state_factory, + certificates=certificates, ) grpc_servers = [driver_server] @@ -408,6 +422,7 @@ def run_server() -> None: fleet_server = _run_fleet_api_grpc_bidi( address=address, state_factory=state_factory, + certificates=certificates, ) grpc_servers.append(fleet_server) elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: @@ -420,6 +435,7 @@ def run_server() -> None: fleet_server = _run_fleet_api_grpc_rere( address=address, state_factory=state_factory, + certificates=certificates, ) grpc_servers.append(fleet_server) else: @@ -441,6 +457,29 @@ def run_server() -> None: driver_server.wait_for_termination(timeout=1) +def _try_obtain_certificates( + args: argparse.Namespace, +) -> Optional[Tuple[bytes, bytes, bytes]]: + # Obtain certificates + if args.insecure: + log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.") + certificates = None + # Check if certificates are provided + elif args.certificates: + certificates = ( + Path(args.certificates[0]).read_bytes(), # CA certificate + Path(args.certificates[1]).read_bytes(), # server certificate + Path(args.certificates[2]).read_bytes(), # server private key + ) + else: + sys.exit( + "Certificates are required unless running in insecure mode. " + "Please provide certificate paths with '--certificates' or run the server " + "in insecure mode using '--insecure' if you understand the risks." + ) + return certificates + + def _register_exit_handlers( grpc_servers: List[grpc.Server], bckg_threads: List[threading.Thread], @@ -490,6 +529,7 @@ def graceful_exit_handler( # type: ignore def _run_driver_api_grpc( address: str, state_factory: StateFactory, + certificates: Optional[Tuple[bytes, bytes, bytes]], ) -> grpc.Server: """Run Driver API (gRPC, request-response).""" # Create Driver API gRPC server @@ -501,7 +541,7 @@ def _run_driver_api_grpc( servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn), server_address=address, max_message_length=GRPC_MAX_MESSAGE_LENGTH, - certificates=None, + certificates=certificates, ) log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address) @@ -513,6 +553,7 @@ def _run_driver_api_grpc( def _run_fleet_api_grpc_bidi( address: str, state_factory: StateFactory, + certificates: Optional[Tuple[bytes, bytes, bytes]], ) -> grpc.Server: """Run Fleet API (gRPC, bidirectional streaming).""" # DriverClientManager @@ -529,7 +570,7 @@ def _run_fleet_api_grpc_bidi( servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn), server_address=address, max_message_length=GRPC_MAX_MESSAGE_LENGTH, - certificates=None, + certificates=certificates, ) log(INFO, "Flower ECE: Starting Fleet API (gRPC-bidi) on %s", address) @@ -541,6 +582,7 @@ def _run_fleet_api_grpc_bidi( def _run_fleet_api_grpc_rere( address: str, state_factory: StateFactory, + certificates: Optional[Tuple[bytes, bytes, bytes]], ) -> grpc.Server: """Run Fleet API (gRPC, request-response).""" # Create Fleet API gRPC server @@ -552,7 +594,7 @@ def _run_fleet_api_grpc_rere( servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn), server_address=address, max_message_length=GRPC_MAX_MESSAGE_LENGTH, - certificates=None, + certificates=certificates, ) log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address) @@ -684,6 +726,22 @@ def _parse_args_server() -> argparse.ArgumentParser: def _add_args_common(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--insecure", + action="store_true", + help="Run the server without HTTPS, regardless of whether certificate " + "paths are provided. By default, the server runs with HTTPS enabled. " + "Use this flag only if you understand the risks.", + ) + parser.add_argument( + "--certificates", + nargs=3, + metavar=("CA_CERT", "SERVER_CERT", "PRIVATE_KEY"), + type=str, + help="Paths to the CA certificate, server certificate, and server private " + "key, in that order. Note: The server can only be started without " + "certificates by enabling the `--insecure` flag.", + ) parser.add_argument( "--database", help="A string representing the path to the database "