From 439e6d0e43a2c172fa382a8ae49147a6d499d292 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 22 Jan 2024 16:27:14 +0000 Subject: [PATCH] change name to root_certificates --- examples/mt-pytorch/driver.py | 2 +- examples/secaggplus-mt/driver.py | 3 +-- src/py/flwr/driver/app.py | 2 +- src/py/flwr/driver/driver.py | 9 +++++---- src/py/flwr/driver/grpc_driver.py | 8 ++++---- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py index ad4d5e1caabe..184ee683818d 100644 --- a/examples/mt-pytorch/driver.py +++ b/examples/mt-pytorch/driver.py @@ -43,7 +43,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", root_certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index f5871f1b44e4..79e4efd4157e 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -24,7 +24,6 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task: task_id="", # Do not set, will be created and set by the DriverAPI group_id="", run_id=run_id, - run_id=run_id, task=merge( task, task_pb2.Task( @@ -72,7 +71,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", root_certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/driver/app.py index 1a3cb6a16034..a06375be0044 100644 --- a/src/py/flwr/driver/app.py +++ b/src/py/flwr/driver/app.py @@ -129,7 +129,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals thread = threading.Thread( target=update_client_manager, args=( - Driver(driver_service_address=address, certificates=root_certificates), + Driver(driver_service_address=address, root_certificates=root_certificates), initialized_server.client_manager(), ref_exit_flag, ), diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index 512a2001165e..13da7435ab37 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -36,7 +36,7 @@ class Driver: driver_service_address : Optional[str] The IPv4 or IPv6 address of the Driver API server. Defaults to `"[::]:9091"`. - certificates : bytes (default: None) + root_certificates : bytes (default: None) Tuple containing root certificate, server certificate, and private key to start a secure SSL-enabled server. The tuple is expected to have three bytes elements in the following order: @@ -49,10 +49,10 @@ class Driver: def __init__( self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - certificates: Optional[bytes] = None, + root_certificates: Optional[bytes] = None, ) -> None: self.addr = driver_service_address - self.certificates = certificates + self.root_certificates = root_certificates self.grpc_driver: Optional[GrpcDriver] = None self.run_id: Optional[int] = None self.node = Node(node_id=0, anonymous=True) @@ -62,7 +62,8 @@ def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]: if self.grpc_driver is None or self.run_id is None: # Connect and create run self.grpc_driver = GrpcDriver( - driver_service_address=self.addr, certificates=self.certificates + driver_service_address=self.addr, + root_certificates=self.root_certificates, ) self.grpc_driver.connect() res = self.grpc_driver.create_run(CreateRunRequest()) diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py index 23d449790092..c3f66f7343db 100644 --- a/src/py/flwr/driver/grpc_driver.py +++ b/src/py/flwr/driver/grpc_driver.py @@ -51,10 +51,10 @@ class GrpcDriver: def __init__( self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - certificates: Optional[bytes] = None, + root_certificates: Optional[bytes] = None, ) -> None: self.driver_service_address = driver_service_address - self.certificates = certificates + self.root_certificates = root_certificates self.channel: Optional[grpc.Channel] = None self.stub: Optional[DriverStub] = None @@ -66,8 +66,8 @@ def connect(self) -> None: return self.channel = create_channel( server_address=self.driver_service_address, - insecure=(self.certificates is None), - root_certificates=self.certificates, + insecure=(self.root_certificates is None), + root_certificates=self.root_certificates, ) self.stub = DriverStub(self.channel) log(INFO, "[Driver] Connected to %s", self.driver_service_address)