Skip to content

Commit

Permalink
Rename Driver to GrpcDriver (#2511)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
panh99 and danieljanes authored Oct 13, 2023
1 parent d7d451a commit 4df413f
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 21 deletions.
4 changes: 2 additions & 2 deletions examples/mt-pytorch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import random
import time

from flwr.driver import Driver
from flwr.driver import GrpcDriver
from flwr.common import (
ServerMessage,
FitIns,
Expand Down Expand Up @@ -43,7 +43,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:


# -------------------------------------------------------------------------- Driver SDK
driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None)
driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None)
# -------------------------------------------------------------------------- Driver SDK

anonymous_client_nodes = False
Expand Down
4 changes: 2 additions & 2 deletions examples/secaggplus-mt/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from workflows import get_workflow_factory

from flwr.common import Metrics, ndarrays_to_parameters
from flwr.driver import Driver
from flwr.driver import GrpcDriver
from flwr.proto import driver_pb2, node_pb2, task_pb2
from flwr.server import History

Expand Down Expand Up @@ -71,7 +71,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:


# -------------------------------------------------------------------------- Driver SDK
driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None)
driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None)
# -------------------------------------------------------------------------- Driver SDK

anonymous_client_nodes = False
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@


from .app import start_driver
from .driver import Driver
from .driver import GrpcDriver

__all__ = [
"start_driver",
"Driver",
"GrpcDriver",
]
6 changes: 3 additions & 3 deletions src/py/flwr/driver/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from flwr.server.server import Server
from flwr.server.strategy import Strategy

from .driver import Driver
from .driver import GrpcDriver
from .driver_client_proxy import DriverClientProxy

DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
Expand Down Expand Up @@ -112,7 +112,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"

# Create the Driver
driver = Driver(driver_service_address=address, certificates=certificates)
driver = GrpcDriver(driver_service_address=address, certificates=certificates)
driver.connect()
lock = threading.Lock()

Expand Down Expand Up @@ -157,7 +157,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals


def update_client_manager(
driver: Driver,
driver: GrpcDriver,
client_manager: ClientManager,
lock: threading.Lock,
) -> None:
Expand Down
20 changes: 10 additions & 10 deletions src/py/flwr/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
[Driver] Error: Not connected.
Call `connect()` on the `Driver` instance before calling any of the other `Driver`
methods.
Call `connect()` on the `GrpcDriver` instance before calling any of the other
`GrpcDriver` methods.
"""


class Driver:
"""`Driver` provides access to the Driver API."""
class GrpcDriver:
"""`GrpcDriver` provides access to the Driver API/service."""

def __init__(
self,
Expand Down Expand Up @@ -88,7 +88,7 @@ def create_workload(self, req: CreateWorkloadRequest) -> CreateWorkloadResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")
raise Exception("`GrpcDriver` instance not connected")

# Call Driver API
res: CreateWorkloadResponse = self.stub.CreateWorkload(request=req)
Expand All @@ -99,9 +99,9 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")
raise Exception("`GrpcDriver` instance not connected")

# Call Driver API
# Call gRPC Driver API
res: GetNodesResponse = self.stub.GetNodes(request=req)
return res

Expand All @@ -110,9 +110,9 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")
raise Exception("`GrpcDriver` instance not connected")

# Call Driver API
# Call gRPC Driver API
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
return res

Expand All @@ -121,7 +121,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")
raise Exception("`GrpcDriver` instance not connected")

# Call Driver API
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
Expand Down
6 changes: 4 additions & 2 deletions src/py/flwr/driver/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@
from flwr.proto import driver_pb2, node_pb2, task_pb2, transport_pb2
from flwr.server.client_proxy import ClientProxy

from .driver import Driver
from .driver import GrpcDriver

SLEEP_TIME = 1


class DriverClientProxy(ClientProxy):
"""Flower client proxy which delegates work using the Driver API."""

def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: int):
def __init__(
self, node_id: int, driver: GrpcDriver, anonymous: bool, workload_id: int
):
super().__init__(str(node_id))
self.node_id = node_id
self.driver = driver
Expand Down

0 comments on commit 4df413f

Please sign in to comment.