Skip to content

Commit

Permalink
Add driver_service_address and certificates args to Driver clas…
Browse files Browse the repository at this point in the history
…s constructor (#2535)
  • Loading branch information
panh99 authored Oct 26, 2023
1 parent 262d5f7 commit 78163cb
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions src/py/flwr/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import Iterable, List, Optional, Tuple

from flwr.driver.grpc_driver import GrpcDriver
from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
from flwr.proto.driver_pb2 import (
CreateWorkloadRequest,
GetNodesRequest,
Expand All @@ -29,9 +29,30 @@


class Driver:
"""`Driver` class provides an interface to the Driver API."""

def __init__(self) -> None:
"""`Driver` class provides an interface to the Driver API.
Parameters
----------
driver_service_address : Optional[str]
The IPv4 or IPv6 address of the Driver API server.
Defaults to `"[::]:9091"`.
certificates : bytes (default: None)
Tuple containing root certificate, server certificate, and private key
to start a secure SSL-enabled server. The tuple is expected to have
three bytes elements in the following order:
* CA certificate.
* server certificate.
* server private key.
"""

def __init__(
self,
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
certificates: Optional[bytes] = None,
) -> None:
self.addr = driver_service_address
self.certificates = certificates
self.grpc_driver: Optional[GrpcDriver] = None
self.workload_id: Optional[int] = None
self.node = Node(node_id=0, anonymous=True)
Expand All @@ -40,7 +61,9 @@ def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]:
# Check if the GrpcDriver is initialized
if self.grpc_driver is None or self.workload_id is None:
# Connect and create workload
self.grpc_driver = GrpcDriver()
self.grpc_driver = GrpcDriver(
driver_service_address=self.addr, certificates=self.certificates
)
self.grpc_driver.connect()
res = self.grpc_driver.create_workload(CreateWorkloadRequest())
self.workload_id = res.workload_id
Expand Down

0 comments on commit 78163cb

Please sign in to comment.