Skip to content

Commit

Permalink
Simplify imports in driver module (#2483)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Oct 7, 2023
1 parent 857db30 commit e559ef2
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions src/py/flwr/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,17 @@
from flwr.common import EventType, event
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.proto import driver_pb2, driver_pb2_grpc
from flwr.proto.driver_pb2 import (
CreateWorkloadRequest,
CreateWorkloadResponse,
GetNodesRequest,
GetNodesResponse,
PullTaskResRequest,
PullTaskResResponse,
PushTaskInsRequest,
PushTaskInsResponse,
)
from flwr.proto.driver_pb2_grpc import DriverStub

DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"

Expand All @@ -46,7 +56,7 @@ def __init__(
self.driver_service_address = driver_service_address
self.certificates = certificates
self.channel: Optional[grpc.Channel] = None
self.stub: Optional[driver_pb2_grpc.DriverStub] = None
self.stub: Optional[DriverStub] = None

def connect(self) -> None:
"""Connect to the Driver API."""
Expand All @@ -58,7 +68,7 @@ def connect(self) -> None:
server_address=self.driver_service_address,
root_certificates=self.certificates,
)
self.stub = driver_pb2_grpc.DriverStub(self.channel)
self.stub = DriverStub(self.channel)
log(INFO, "[Driver] Connected to %s", self.driver_service_address)

def disconnect(self) -> None:
Expand All @@ -73,52 +83,46 @@ def disconnect(self) -> None:
channel.close()
log(INFO, "[Driver] Disconnected")

def create_workload(
self, req: driver_pb2.CreateWorkloadRequest
) -> driver_pb2.CreateWorkloadResponse:
def create_workload(self, req: CreateWorkloadRequest) -> CreateWorkloadResponse:
"""Request for workload ID."""
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")

# Call Driver API
res: driver_pb2.CreateWorkloadResponse = self.stub.CreateWorkload(request=req)
res: CreateWorkloadResponse = self.stub.CreateWorkload(request=req)
return res

def get_nodes(self, req: driver_pb2.GetNodesRequest) -> driver_pb2.GetNodesResponse:
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
"""Get client IDs."""
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")

# Call Driver API
res: driver_pb2.GetNodesResponse = self.stub.GetNodes(request=req)
res: GetNodesResponse = self.stub.GetNodes(request=req)
return res

def push_task_ins(
self, req: driver_pb2.PushTaskInsRequest
) -> driver_pb2.PushTaskInsResponse:
def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
"""Schedule tasks."""
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")

# Call Driver API
res: driver_pb2.PushTaskInsResponse = self.stub.PushTaskIns(request=req)
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
return res

def pull_task_res(
self, req: driver_pb2.PullTaskResRequest
) -> driver_pb2.PullTaskResResponse:
def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
"""Get task results."""
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")

# Call Driver API
res: driver_pb2.PullTaskResResponse = self.stub.PullTaskRes(request=req)
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
return res

0 comments on commit e559ef2

Please sign in to comment.