diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py index 683a4b2975a6..fed760f021af 100644 --- a/examples/mt-pytorch/driver.py +++ b/examples/mt-pytorch/driver.py @@ -2,7 +2,7 @@ import random import time -from flwr.driver import Driver +from flwr.driver import GrpcDriver from flwr.common import ( ServerMessage, FitIns, @@ -43,7 +43,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index 4e0a53ed1c91..d9f795766f6d 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -6,7 +6,7 @@ from workflows import get_workflow_factory from flwr.common import Metrics, ndarrays_to_parameters -from flwr.driver import Driver +from flwr.driver import GrpcDriver from flwr.proto import driver_pb2, node_pb2, task_pb2 from flwr.server import History @@ -71,7 +71,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK -driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None) +driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None) # -------------------------------------------------------------------------- Driver SDK anonymous_client_nodes = False diff --git a/src/py/flwr/driver/__init__.py b/src/py/flwr/driver/__init__.py index 8c100e935f70..17b3a1ea236f 100644 --- a/src/py/flwr/driver/__init__.py +++ b/src/py/flwr/driver/__init__.py @@ -16,9 +16,9 @@ from .app import start_driver -from .driver import Driver +from .driver import GrpcDriver __all__ = [ "start_driver", - "Driver", + "GrpcDriver", ] diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/driver/app.py index eeacfc3d9ede..ca9e7b13084b 100644 --- a/src/py/flwr/driver/app.py +++ b/src/py/flwr/driver/app.py @@ -31,7 +31,7 @@ from flwr.server.server import Server from flwr.server.strategy import Strategy -from .driver import Driver +from .driver import GrpcDriver from .driver_client_proxy import DriverClientProxy DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -112,7 +112,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" # Create the Driver - driver = Driver(driver_service_address=address, certificates=certificates) + driver = GrpcDriver(driver_service_address=address, certificates=certificates) driver.connect() lock = threading.Lock() @@ -157,7 +157,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals def update_client_manager( - driver: Driver, + driver: GrpcDriver, client_manager: ClientManager, lock: threading.Lock, ) -> None: diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index 8e029c1e1be1..4b189b9ce290 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -40,13 +40,13 @@ ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """ [Driver] Error: Not connected. -Call `connect()` on the `Driver` instance before calling any of the other `Driver` -methods. +Call `connect()` on the `GrpcDriver` instance before calling any of the other +`GrpcDriver` methods. """ -class Driver: - """`Driver` provides access to the Driver API.""" +class GrpcDriver: + """`GrpcDriver` provides access to the Driver API/service.""" def __init__( self, @@ -88,7 +88,7 @@ def create_workload(self, req: CreateWorkloadRequest) -> CreateWorkloadResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`Driver` instance not connected") + raise Exception("`GrpcDriver` instance not connected") # Call Driver API res: CreateWorkloadResponse = self.stub.CreateWorkload(request=req) @@ -99,9 +99,9 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`Driver` instance not connected") + raise Exception("`GrpcDriver` instance not connected") - # Call Driver API + # Call gRPC Driver API res: GetNodesResponse = self.stub.GetNodes(request=req) return res @@ -110,9 +110,9 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`Driver` instance not connected") + raise Exception("`GrpcDriver` instance not connected") - # Call Driver API + # Call gRPC Driver API res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) return res @@ -121,7 +121,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`Driver` instance not connected") + raise Exception("`GrpcDriver` instance not connected") # Call Driver API res: PullTaskResResponse = self.stub.PullTaskRes(request=req) diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index b732cf66c220..ae04ec5c1512 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -23,7 +23,7 @@ from flwr.proto import driver_pb2, node_pb2, task_pb2, transport_pb2 from flwr.server.client_proxy import ClientProxy -from .driver import Driver +from .driver import GrpcDriver SLEEP_TIME = 1 @@ -31,7 +31,9 @@ class DriverClientProxy(ClientProxy): """Flower client proxy which delegates work using the Driver API.""" - def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: int): + def __init__( + self, node_id: int, driver: GrpcDriver, anonymous: bool, workload_id: int + ): super().__init__(str(node_id)) self.node_id = node_id self.driver = driver