Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support workflow in start_driver() #2379

Closed
wants to merge 18 commits into from
Closed
3 changes: 2 additions & 1 deletion src/py/flwr/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
"""Flower driver SDK."""


from .app import start_driver
from .app import DriverConfig, start_driver
from .driver import Driver
from .grpc_driver import GrpcDriver

__all__ = [
"Driver",
"DriverConfig",
"GrpcDriver",
"start_driver",
]
172 changes: 144 additions & 28 deletions src/py/flwr/driver/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,35 @@
import sys
import threading
import time
import timeit
from dataclasses import dataclass
from logging import INFO
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union, Union, cast

from flwr.common import EventType, event
from flwr.common.address import parse_address
from flwr.common.logger import log
from flwr.proto import driver_pb2 # pylint: disable=E0611
from flwr.server.app import ServerConfig, init_defaults, run_fl
from flwr.server.client_manager import ClientManager
from flwr.common.typing import Parameters
from flwr.proto.driver_pb2 import (
CreateRunRequest,
GetNodesRequest,
PullTaskResRequest,
PushTaskInsRequest,
)
from flwr.proto.node_pb2 import Node
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
from flwr.server import ServerConfig
from flwr.server.client_manager import ClientManager, SimpleClientManager
from flwr.server.history import History
from flwr.server.server import Server
from flwr.server.strategy import Strategy
from flwr.server.strategy import FedAvg, Strategy

from .driver_client_proxy import DriverClientProxy
from .grpc_driver import GrpcDriver
from .workflow.workflow_factory import (
FlowerWorkflowFactory,
FLWorkflowFactory,
WorkflowState,
)

DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"

Expand All @@ -45,14 +58,26 @@
"""


@dataclass
class DriverConfig:
    """Flower driver config.

    All attributes have default values which allows users to configure just the ones
    they care about.

    `start_driver` reads `num_rounds` when it builds the workflow state and passes
    `round_timeout` as the timeout used while waiting for node responses.
    """

    # Number of rounds handed to the workflow state (default: 1).
    num_rounds: int = 1
    # Timeout in seconds for collecting node responses; `None` (the default)
    # means no time limit is applied.
    round_timeout: Optional[float] = None


def start_driver( # pylint: disable=too-many-arguments, too-many-locals
*,
server_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
server: Optional[Server] = None,
config: Optional[ServerConfig] = None,
config: Optional[Union[DriverConfig, ServerConfig]] = None,
strategy: Optional[Strategy] = None,
client_manager: Optional[ClientManager] = None,
root_certificates: Optional[Union[bytes, str]] = None,
fl_workflow_factory: Optional[FlowerWorkflowFactory] = None,
) -> History:
"""Start a Flower Driver API server.

Expand All @@ -65,7 +90,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
A server implementation, either `flwr.server.Server` or a subclass
thereof. If no instance is provided, then `start_driver` will create
one.
config : Optional[ServerConfig] (default: None)
config : Optional[DriverConfig] (default: None)
Currently supported values are `num_rounds` (int, default: 1) and
`round_timeout` in seconds (float, default: None).
strategy : Optional[flwr.server.Strategy] (default: None).
Expand Down Expand Up @@ -100,6 +125,12 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
"""
event(EventType.START_DRIVER_ENTER)

# Backward compatibility
if isinstance(config, ServerConfig):
config = DriverConfig(
num_rounds=config.num_rounds, round_timeout=config.round_timeout
)

# Parse IP address
parsed_address = parse_address(server_address)
if not parsed_address:
Expand All @@ -116,48 +147,82 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
driver.connect()
lock = threading.Lock()

# Initialize the Driver API server and config
initialized_server, initialized_config = init_defaults(
server=server,
config=config,
# Request workload_id
workload_id = driver.create_workload(CreateWorkloadRequest()).workload_id

# Initialization
hist = History()
if client_manager is None:
client_manager = SimpleClientManager()
if strategy is None:
strategy = FedAvg()
if config is None:
config = DriverConfig()
if fl_workflow_factory is None:
fl_workflow_factory = cast(FlowerWorkflowFactory, FLWorkflowFactory())
workflow_state = WorkflowState(
num_rounds=config.num_rounds,
current_round=0, # This field will be set inside the workflow
strategy=strategy,
parameters=Parameters(
tensors=[], tensor_type=""
), # This field will be set inside the workflow,
client_manager=client_manager,
history=hist,
)
log(
INFO,
"Starting Flower server, config: %s",
initialized_config,
"Starting Flower driver, config: %s",
config,
)

# Start the thread updating nodes
thread = threading.Thread(
target=update_client_manager,
args=(
driver,
initialized_server.client_manager(),
workload_id,
client_manager,
lock,
),
)
thread.start()

# Start training
hist = run_fl(
server=initialized_server,
config=initialized_config,
)
fl_workflow = fl_workflow_factory(workflow_state)

instructions = next(fl_workflow)
while True:
node_responses = fetch_responses(
driver, workload_id, instructions, config.round_timeout
)
try:
instructions = fl_workflow.send(node_responses)
except StopIteration:
break

fl_workflow.close()

# Stop the Driver API server and the thread
with lock:
driver.disconnect()
thread.join()

# Log history
log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
log(INFO, "app_fit: metrics_distributed_fit %s", str(hist.metrics_distributed_fit))
log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))

event(EventType.START_SERVER_LEAVE)

return hist


def update_client_manager(
driver: GrpcDriver,
driver: Driver,
workload_id: int,
client_manager: ClientManager,
lock: threading.Lock,
) -> None:
Expand All @@ -171,11 +236,6 @@ def update_client_manager(
and dead nodes will be removed from the ClientManager via
`client_manager.unregister()`.
"""
# Request for run_id
run_id = driver.create_run(
driver_pb2.CreateRunRequest() # pylint: disable=E1101
).run_id

# Loop until the driver is disconnected
registered_nodes: Dict[int, DriverClientProxy] = {}
while True:
Expand All @@ -184,7 +244,7 @@ def update_client_manager(
if driver.stub is None:
break
get_nodes_res = driver.get_nodes(
req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
req=GetNodesRequest(workload_id=workload_id)
)
all_node_ids = {node.node_id for node in get_nodes_res.nodes}
dead_nodes = set(registered_nodes).difference(all_node_ids)
Expand All @@ -211,3 +271,59 @@ def update_client_manager(

# Sleep for 3 seconds
time.sleep(3)


# pylint: disable-next=too-many-locals
def fetch_responses(
    driver: GrpcDriver,
    workload_id: int,
    instructions: Dict[int, Task],
    timeout: Optional[float],
) -> Dict[int, Task]:
    """Send instructions to clients and return their responses.

    Parameters
    ----------
    driver : GrpcDriver
        Driver used to push the `TaskIns` messages and pull the `TaskRes`
        messages. It is called with request protos
        (`PushTaskInsRequest` / `PullTaskResRequest`).
    workload_id : int
        Identifier written into every `TaskIns` that is pushed.
    instructions : Dict[int, Task]
        Mapping from node ID to the `Task` destined for that node. Each `Task`
        is modified in place: its `producer` and `consumer` fields are merged
        in before it is wrapped in a `TaskIns`.
    timeout : Optional[float]
        Maximum time in seconds to keep polling for responses. `None` polls
        until a response has been collected for every pushed task. A timeout
        of `0` returns without polling.

    Returns
    -------
    Dict[int, Task]
        Mapping from the producer node ID of each received `TaskRes` to its
        `Task`. Nodes whose response was not collected before the timeout
        expired are absent from the mapping.
    """
    # Build the list of TaskIns
    task_ins_list: List[TaskIns] = []
    driver_node = Node(node_id=0, anonymous=True)
    for node_id, task in instructions.items():
        # Set the `consumer` and `producer` fields in Task
        #
        # Note that protobuf API `protobuf.message.MergeFrom(other_msg)`
        # does NOT always overwrite fields that are set in `other_msg`.
        # Please refer to:
        # https://googleapis.dev/python/protobuf/latest/google/protobuf/message.html
        consumer = Node(node_id=node_id, anonymous=False)
        task.MergeFrom(Task(producer=driver_node, consumer=consumer))
        # Create TaskIns and add it to the list
        task_ins = TaskIns(
            task_id="",  # Do not set, will be created and set by the DriverAPI
            group_id="",
            workload_id=workload_id,
            task=task,
        )
        task_ins_list.append(task_ins)

    # Push TaskIns; an empty task ID in the response is skipped because there
    # is nothing to pull for it.
    push_res = driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
    task_ids = [task_id for task_id in push_res.task_ids if task_id != ""]

    # Give the nodes a moment to pick up the tasks before the first pull.
    time.sleep(1.0)

    # Pull TaskRes
    task_res_list: List[TaskRes] = []
    # Always record the start time. It used to be set only when `timeout` was
    # truthy, so `timeout=0` reached the loop condition below with
    # `start_time` unbound and raised `UnboundLocalError`.
    start_time = timeit.default_timer()
    while timeout is None or timeit.default_timer() - start_time < timeout:
        pull_res = driver.pull_task_res(
            PullTaskResRequest(node=driver_node, task_ids=task_ids)
        )
        task_res_list.extend(pull_res.task_res_list)
        # Stop as soon as one response per pushed task has been collected.
        if len(task_res_list) == len(task_ids):
            break

        time.sleep(3.0)

    # Build and return response dictionary
    node_responses: Dict[int, Task] = {
        task_res.task.producer.node_id: task_res.task for task_res in task_res_list
    }
    return node_responses
1 change: 1 addition & 0 deletions src/py/flwr/driver/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_simple_client_manager_update(self) -> None:
target=update_client_manager,
args=(
driver,
1,
client_manager,
lock,
),
Expand Down
65 changes: 65 additions & 0 deletions src/py/flwr/driver/task_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2022 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions wrapping contents in Task."""


from typing import Dict, Union

from flwr.common import serde
from flwr.common.typing import (
EvaluateIns,
FitIns,
GetParametersIns,
GetPropertiesIns,
ServerMessage,
Value,
)
from flwr.proto import task_pb2, transport_pb2


def wrap_server_message_in_task(
    message: Union[
        ServerMessage, GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns
    ]
) -> task_pb2.Task:
    """Wrap any server message/instruction in Task.

    Parameters
    ----------
    message : Union[ServerMessage, GetPropertiesIns, GetParametersIns, FitIns,
        EvaluateIns]
        Either a complete `ServerMessage` or one of the individual
        instructions. An individual instruction is first placed in the
        matching field of a `transport_pb2.ServerMessage`.

    Returns
    -------
    task_pb2.Task
        A `Task` whose `legacy_server_message` field holds the serialized
        message.

    Raises
    ------
    TypeError
        If `message` is not one of the supported types.
    """
    if isinstance(message, ServerMessage):
        server_message_proto = serde.server_message_to_proto(message)
    elif isinstance(message, GetPropertiesIns):
        server_message_proto = transport_pb2.ServerMessage(
            get_properties_ins=serde.get_properties_ins_to_proto(message)
        )
    elif isinstance(message, GetParametersIns):
        server_message_proto = transport_pb2.ServerMessage(
            get_parameters_ins=serde.get_parameters_ins_to_proto(message)
        )
    elif isinstance(message, FitIns):
        server_message_proto = transport_pb2.ServerMessage(
            fit_ins=serde.fit_ins_to_proto(message)
        )
    elif isinstance(message, EvaluateIns):
        server_message_proto = transport_pb2.ServerMessage(
            evaluate_ins=serde.evaluate_ins_to_proto(message)
        )
    else:
        # Without this branch an unsupported type left `server_message_proto`
        # unbound and the function failed with `UnboundLocalError` on return.
        raise TypeError(
            f"Cannot wrap message of type {type(message).__name__} in Task"
        )
    return task_pb2.Task(legacy_server_message=server_message_proto)


def wrap_named_values_in_task(named_values: Dict[str, Value]) -> task_pb2.Task:
    """Wrap the `named_values` dictionary in SecureAggregation in Task.

    The dictionary is serialized with `serde.named_values_to_proto`, stored in
    a `SecureAggregation` message, and that message becomes the `sa` field of
    the returned `Task`.
    """
    serialized_values = serde.named_values_to_proto(named_values)
    sa_message = task_pb2.SecureAggregation(named_values=serialized_values)
    return task_pb2.Task(sa=sa_message)
30 changes: 30 additions & 0 deletions src/py/flwr/driver/workflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2022 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower workflow."""


from .workflow_factory import (
FlowerWorkflow,
FlowerWorkflowFactory,
FLWorkflowFactory,
WorkflowState,
)

__all__ = [
"FlowerWorkflow",
"FlowerWorkflowFactory",
"FLWorkflowFactory",
"WorkflowState",
]
Loading
Loading