diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py index fed760f021af..5b7ec1e3b482 100644 --- a/examples/mt-pytorch/driver.py +++ b/examples/mt-pytorch/driver.py @@ -43,7 +43,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", root_certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index d9f795766f6d..4b38cc829b82 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -71,7 +71,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", root_certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/driver/app.py index 3cb8652365d8..8da197dcd665 100644 --- a/src/py/flwr/driver/app.py +++ b/src/py/flwr/driver/app.py @@ -111,7 +111,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals # Create the Driver if isinstance(root_certificates, str): root_certificates = Path(root_certificates).read_bytes() - driver = GrpcDriver(driver_service_address=address, certificates=root_certificates) + driver = GrpcDriver( + driver_service_address=address, root_certificates=root_certificates + ) driver.connect() lock = threading.Lock() diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index 20e149fe6024..4625728b3b24 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -43,14 +43,10 @@ class Driver: driver_service_address : Optional[str] The IPv4 or IPv6 address of the Driver API server. Defaults to `"[::]:9091"`. - certificates : bytes (default: None) - Tuple containing root certificate, server certificate, and private key - to start a secure SSL-enabled server. The tuple is expected to have - three bytes elements in the following order: - - * CA certificate. - * server certificate. - * server private key. + root_certificates : Optional[bytes] (default: None) + The PEM-encoded root certificates as a byte string. If provided, + a secure connection using the certificates will be established to + an SSL-enabled Flower server. invoker : Optional[RetryInvoker] (default: None) A `RetryInvoker` object to control the retry behavior on Driver API failures. If set to None, a default instance is created with an exponential backoff @@ -61,11 +57,11 @@ class Driver: def __init__( self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - certificates: Optional[bytes] = None, + root_certificates: Optional[bytes] = None, invoker: Optional[RetryInvoker] = None, ) -> None: self.addr = driver_service_address - self.certificates = certificates + self.root_certificates = root_certificates self.grpc_driver: Optional[GrpcDriver] = None self.workload_id: Optional[int] = None self.node = Node(node_id=0, anonymous=True) @@ -88,7 +84,8 @@ def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]: if self.grpc_driver is None or self.workload_id is None: # Connect and create workload self.grpc_driver = GrpcDriver( - driver_service_address=self.addr, certificates=self.certificates + driver_service_address=self.addr, + root_certificates=self.root_certificates, ) self.grpc_driver.connect() res = self.grpc_driver.create_workload(CreateWorkloadRequest()) diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py index 7dd0a0f501c5..af59c9695080 100644 --- a/src/py/flwr/driver/grpc_driver.py +++ b/src/py/flwr/driver/grpc_driver.py @@ -51,10 +51,10 @@ class GrpcDriver: def __init__( self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - certificates: Optional[bytes] = None, + root_certificates: Optional[bytes] = None, ) -> None: self.driver_service_address = driver_service_address - self.certificates = certificates + self.root_certificates = root_certificates self.channel: Optional[grpc.Channel] = None self.stub: Optional[DriverStub] = None @@ -66,8 +66,8 @@ def connect(self) -> None: return self.channel = create_channel( server_address=self.driver_service_address, - insecure=(self.certificates is None), - root_certificates=self.certificates, + insecure=(self.root_certificates is None), + root_certificates=self.root_certificates, ) self.stub = DriverStub(self.channel) log(INFO, "[Driver] Connected to %s", self.driver_service_address)