diff --git a/src/py/flwr/driver/__init__.py b/src/py/flwr/driver/__init__.py
index 1c3b09cc334b..5dd55b6246d9 100644
--- a/src/py/flwr/driver/__init__.py
+++ b/src/py/flwr/driver/__init__.py
@@ -15,12 +15,13 @@
 """Flower driver SDK."""
 
-from .app import start_driver
+from .app import DriverConfig, start_driver
 from .driver import Driver
 from .grpc_driver import GrpcDriver
 
 __all__ = [
     "Driver",
+    "DriverConfig",
     "GrpcDriver",
     "start_driver",
 ]
diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/driver/app.py
index 2c0576bde8ff..2e324b4caa91 100644
--- a/src/py/flwr/driver/app.py
+++ b/src/py/flwr/driver/app.py
@@ -18,22 +18,36 @@
 import sys
 import threading
 import time
+import timeit
+from dataclasses import dataclass
 from logging import INFO
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union, cast
 
 from flwr.common import EventType, event
 from flwr.common.address import parse_address
 from flwr.common.logger import log
-from flwr.proto import driver_pb2  # pylint: disable=E0611
-from flwr.server.app import ServerConfig, init_defaults, run_fl
-from flwr.server.client_manager import ClientManager
+from flwr.common.typing import Parameters
+from flwr.proto.driver_pb2 import (
+    CreateRunRequest,
+    GetNodesRequest,
+    PullTaskResRequest,
+    PushTaskInsRequest,
+)
+from flwr.proto.node_pb2 import Node
+from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
+from flwr.server import ServerConfig
+from flwr.server.client_manager import ClientManager, SimpleClientManager
 from flwr.server.history import History
-from flwr.server.server import Server
-from flwr.server.strategy import Strategy
+from flwr.server.strategy import FedAvg, Strategy
 
 from .driver_client_proxy import DriverClientProxy
 from .grpc_driver import GrpcDriver
+from .workflow.workflow_factory import (
+    FlowerWorkflowFactory,
+    FLWorkflowFactory,
+    WorkflowState,
+)
 
 DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
 
@@ -45,14 +59,26 @@
 """
 
 
+@dataclass
+class DriverConfig:
+    """Flower driver config.
+
+    All attributes have default values, which allows users to configure just the ones
+    they care about.
+    """
+
+    num_rounds: int = 1
+    round_timeout: Optional[float] = None
+
+
 def start_driver(  # pylint: disable=too-many-arguments, too-many-locals
     *,
     server_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
-    server: Optional[Server] = None,
-    config: Optional[ServerConfig] = None,
+    config: Optional[Union[DriverConfig, ServerConfig]] = None,
     strategy: Optional[Strategy] = None,
     client_manager: Optional[ClientManager] = None,
     root_certificates: Optional[Union[bytes, str]] = None,
+    fl_workflow_factory: Optional[FlowerWorkflowFactory] = None,
 ) -> History:
     """Start a Flower Driver API server.
 
@@ -65,7 +91,7 @@ def start_driver(  # pylint: disable=too-many-arguments, too-many-locals
         A server implementation, either `flwr.server.Server` or a subclass
         thereof. If no instance is provided, then `start_driver` will create
         one.
-    config : Optional[ServerConfig] (default: None)
+    config : Optional[DriverConfig] (default: None)
         Currently supported values are `num_rounds` (int, default: 1) and
         `round_timeout` in seconds (float, default: None).
     strategy : Optional[flwr.server.Strategy] (default: None).
@@ -100,6 +126,12 @@ def start_driver(  # pylint: disable=too-many-arguments, too-many-locals
     """
     event(EventType.START_DRIVER_ENTER)
 
+    # Backward compatibility
+    if isinstance(config, ServerConfig):
+        config = DriverConfig(
+            num_rounds=config.num_rounds, round_timeout=config.round_timeout
+        )
+
     # Parse IP address
     parsed_address = parse_address(server_address)
     if not parsed_address:
@@ -116,17 +148,33 @@ def start_driver(  # pylint: disable=too-many-arguments, too-many-locals
     driver.connect()
     lock = threading.Lock()
 
-    # Initialize the Driver API server and config
-    initialized_server, initialized_config = init_defaults(
-        server=server,
-        config=config,
+    # Request run_id
+    run_id = driver.create_run(CreateRunRequest()).run_id
+
+    # Initialization
+    hist = History()
+    if client_manager is None:
+        client_manager = SimpleClientManager()
+    if strategy is None:
+        strategy = FedAvg()
+    if config is None:
+        config = DriverConfig()
+    if fl_workflow_factory is None:
+        fl_workflow_factory = cast(FlowerWorkflowFactory, FLWorkflowFactory())
+    workflow_state = WorkflowState(
+        num_rounds=config.num_rounds,
+        current_round=0,  # This field will be set inside the workflow
         strategy=strategy,
+        parameters=Parameters(
+            tensors=[], tensor_type=""
+        ),  # This field will be set inside the workflow
        client_manager=client_manager,
+        history=hist,
     )
     log(
         INFO,
-        "Starting Flower server, config: %s",
-        initialized_config,
+        "Starting Flower driver, config: %s",
+        config,
     )
 
     # Start the thread updating nodes
@@ -134,30 +182,48 @@ def start_driver(  # pylint: disable=too-many-arguments, too-many-locals
         target=update_client_manager,
         args=(
             driver,
-            initialized_server.client_manager(),
+            run_id,
+            client_manager,
             lock,
         ),
     )
     thread.start()
 
     # Start training
-    hist = run_fl(
-        server=initialized_server,
-        config=initialized_config,
-    )
+    fl_workflow = fl_workflow_factory(workflow_state)
+
+    instructions = next(fl_workflow)
+    while True:
+        node_responses = fetch_responses(
+            driver, run_id, instructions, config.round_timeout
+        )
+        try:
+            instructions = fl_workflow.send(node_responses)
+        except StopIteration:
+            break
+
+    fl_workflow.close()
 
     # Stop the Driver API server and the thread
     with lock:
         driver.disconnect()
     thread.join()
 
+    # Log history
+    log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
+    log(INFO, "app_fit: metrics_distributed_fit %s", str(hist.metrics_distributed_fit))
+    log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
+    log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
+    log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
+
     event(EventType.START_SERVER_LEAVE)
     return hist
 
 
 def update_client_manager(
     driver: GrpcDriver,
+    run_id: int,
     client_manager: ClientManager,
     lock: threading.Lock,
 ) -> None:
@@ -171,11 +237,6 @@ def update_client_manager(
     and dead nodes will be removed from the ClientManager via
     `client_manager.unregister()`.
""" - # Request for run_id - run_id = driver.create_run( - driver_pb2.CreateRunRequest() # pylint: disable=E1101 - ).run_id - # Loop until the driver is disconnected registered_nodes: Dict[int, DriverClientProxy] = {} while True: @@ -184,7 +244,7 @@ def update_client_manager( if driver.stub is None: break get_nodes_res = driver.get_nodes( - req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101 + req=GetNodesRequest(workload_id=workload_id) ) all_node_ids = {node.node_id for node in get_nodes_res.nodes} dead_nodes = set(registered_nodes).difference(all_node_ids) @@ -211,3 +271,59 @@ def update_client_manager( # Sleep for 3 seconds time.sleep(3) + + +# pylint: disable-next=too-many-locals +def fetch_responses( + driver: Driver, + workload_id: int, + instructions: Dict[int, Task], + timeout: Optional[float], +) -> Dict[int, Task]: + """Send instructions to clients and return their responses.""" + # Build the list of TaskIns + task_ins_list: List[TaskIns] = [] + driver_node = Node(node_id=0, anonymous=True) + for node_id, task in instructions.items(): + # Set the `consumer` and `producer` fields in Task + # + # Note that protobuf API `protobuf.message.MergeFrom(other_msg)` + # does NOT always overwrite fields that are set in `other_msg`. + # Please refer to: + # https://googleapis.dev/python/protobuf/latest/google/protobuf/message.html + consumer = Node(node_id=node_id, anonymous=False) + task.MergeFrom(Task(producer=driver_node, consumer=consumer)) + # Create TaskIns and add it to the list + task_ins = TaskIns( + task_id="", # Do not set, will be created and set by the DriverAPI + group_id="", + workload_id=workload_id, + task=task, + ) + task_ins_list.append(task_ins) + + # Push TaskIns + push_res = driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) + task_ids = [task_id for task_id in push_res.task_ids if task_id != ""] + + time.sleep(1.0) + + # Pull TaskRes + task_res_list: List[TaskRes] = [] + if timeout: + start_time = timeit.default_timer() + while timeout is None or timeit.default_timer() - start_time < timeout: + pull_res = driver.pull_task_res( + PullTaskResRequest(node=driver_node, task_ids=task_ids) + ) + task_res_list.extend(pull_res.task_res_list) + if len(task_res_list) == len(task_ids): + break + + time.sleep(3.0) + + # Build and return response dictionary + node_responses: Dict[int, Task] = { + task_res.task.producer.node_id: task_res.task for task_res in task_res_list + } + return node_responses diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py index bfa0098f68e2..ee07438a0909 100644 --- a/src/py/flwr/driver/app_test.py +++ b/src/py/flwr/driver/app_test.py @@ -55,6 +55,7 @@ def test_simple_client_manager_update(self) -> None: target=update_client_manager, args=( driver, + 1, client_manager, lock, ), diff --git a/src/py/flwr/driver/task_utils.py b/src/py/flwr/driver/task_utils.py new file mode 100644 index 000000000000..4c1b13a89229 --- /dev/null +++ b/src/py/flwr/driver/task_utils.py @@ -0,0 +1,65 @@ +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions wrapping contents in Task.""" + + +from typing import Dict, Union + +from flwr.common import serde +from flwr.common.typing import ( + EvaluateIns, + FitIns, + GetParametersIns, + GetPropertiesIns, + ServerMessage, + Value, +) +from flwr.proto import task_pb2, transport_pb2 + + +def wrap_server_message_in_task( + message: Union[ + ServerMessage, GetPropertiesIns, GetParametersIns, FitIns, EvaluateIns + ] +) -> task_pb2.Task: + """Wrap any server message/instruction in Task.""" + if isinstance(message, ServerMessage): + server_message_proto = serde.server_message_to_proto(message) + elif isinstance(message, GetPropertiesIns): + server_message_proto = transport_pb2.ServerMessage( + get_properties_ins=serde.get_properties_ins_to_proto(message) + ) + elif isinstance(message, GetParametersIns): + server_message_proto = transport_pb2.ServerMessage( + get_parameters_ins=serde.get_parameters_ins_to_proto(message) + ) + elif isinstance(message, FitIns): + server_message_proto = transport_pb2.ServerMessage( + fit_ins=serde.fit_ins_to_proto(message) + ) + elif isinstance(message, EvaluateIns): + server_message_proto = transport_pb2.ServerMessage( + evaluate_ins=serde.evaluate_ins_to_proto(message) + ) + return task_pb2.Task(legacy_server_message=server_message_proto) + + +def wrap_named_values_in_task(named_values: Dict[str, Value]) -> task_pb2.Task: + """Wrap the `named_values` dictionary in SecureAggregation in Task.""" + return task_pb2.Task( + sa=task_pb2.SecureAggregation( + named_values=serde.named_values_to_proto(named_values) + ) + ) diff --git a/src/py/flwr/driver/workflow/__init__.py b/src/py/flwr/driver/workflow/__init__.py new file mode 100644 index 000000000000..41901c6627d1 --- /dev/null +++ b/src/py/flwr/driver/workflow/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower workflow.""" + + +from .workflow_factory import ( + FlowerWorkflow, + FlowerWorkflowFactory, + FLWorkflowFactory, + WorkflowState, +) + +__all__ = [ + "FlowerWorkflow", + "FlowerWorkflowFactory", + "FLWorkflowFactory", + "WorkflowState", +] diff --git a/src/py/flwr/driver/workflow/workflow_factory.py b/src/py/flwr/driver/workflow/workflow_factory.py new file mode 100644 index 000000000000..34d62c1a8d7f --- /dev/null +++ b/src/py/flwr/driver/workflow/workflow_factory.py @@ -0,0 +1,276 @@ +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Default workflow factories.""" + + +import timeit +from dataclasses import dataclass +from logging import DEBUG, INFO +from typing import Callable, Dict, Generator, Optional, Tuple + +from flwr.common import serde +from flwr.common.logger import log +from flwr.common.typing import GetParametersIns, Parameters, Scalar +from flwr.driver.task_utils import wrap_server_message_in_task +from flwr.proto.task_pb2 import Task +from flwr.server.client_manager import ClientManager +from flwr.server.history import History +from flwr.server.strategy.strategy import Strategy + + +@dataclass +class WorkflowState: + """State of the workflow.""" + + num_rounds: int + current_round: int + strategy: Strategy + parameters: Parameters + client_manager: ClientManager + history: History + + +FlowerWorkflow = Generator[Dict[int, Task], Dict[int, Task], None] +FlowerWorkflowFactory = Callable[[WorkflowState], FlowerWorkflow] + + +class FLWorkflowFactory: + """Default FL workflow factory in Flower.""" + + def __init__( + self, + fit_workflow_factory: Optional[FlowerWorkflowFactory] = None, + evaluate_workflow_factory: Optional[FlowerWorkflowFactory] = None, + ): + self.fit_workflow_factory = ( + fit_workflow_factory + if fit_workflow_factory is not None + else default_fit_workflow_factory + ) + self.evaluate_workflow_factory = ( + evaluate_workflow_factory + if evaluate_workflow_factory is not None + else default_evaluate_workflow_factory + ) + + def __call__(self, state: WorkflowState) -> FlowerWorkflow: + """Create the workflow.""" + # Initialize parameters + yield from default_init_params_workflow_factory(state) + + # Run federated learning for num_rounds + log(INFO, "FL starting") + start_time = timeit.default_timer() + + for current_round in range(1, state.num_rounds + 1): + state.current_round = current_round + + # Fit round + yield from self.fit_workflow_factory(state) + + # Centralized evaluation + res_cen = state.strategy.evaluate( + current_round, parameters=state.parameters + ) + if res_cen is not None: + loss_cen, metrics_cen = res_cen + log( + INFO, + "fit progress: (%s, %s, %s, %s)", + current_round, + loss_cen, + metrics_cen, + timeit.default_timer() - start_time, + ) + state.history.add_loss_centralized( + server_round=current_round, loss=loss_cen + ) + state.history.add_metrics_centralized( + server_round=current_round, metrics=metrics_cen + ) + + # Evaluate round + yield from self.evaluate_workflow_factory(state) + + # Bookkeeping + end_time = timeit.default_timer() + elapsed = end_time - start_time + log(INFO, "FL finished in %s", elapsed) + + +def default_init_params_workflow_factory(state: WorkflowState) -> FlowerWorkflow: + """Create the default workflow for parameters initialization.""" + log(INFO, "Initializing global parameters") + parameters = state.strategy.initialize_parameters( + client_manager=state.client_manager + ) + if parameters is not None: + log(INFO, "Using initial parameters provided by strategy") + state.parameters = parameters + # Get initial parameters from one of the clients + 
else: + log(INFO, "Requesting initial parameters from one random client") + random_client = state.client_manager.sample(1)[0] + # Send GetParametersIns and get the response + node_responses = yield { + random_client.node_id: wrap_server_message_in_task( + GetParametersIns(config={}) + ) + } + get_parameters_res = serde.get_parameters_res_from_proto( + node_responses[ + random_client.node_id + ].legacy_client_message.get_parameters_res + ) + log(INFO, "Received initial parameters from one random client") + state.parameters = get_parameters_res.parameters + + # Evaluate initial parameters + log(INFO, "Evaluating initial parameters") + res = state.strategy.evaluate(0, parameters=state.parameters) + if res is not None: + log( + INFO, + "initial parameters (loss, other metrics): %s, %s", + res[0], + res[1], + ) + state.history.add_loss_centralized(server_round=0, loss=res[0]) + state.history.add_metrics_centralized(server_round=0, metrics=res[1]) + + +def default_fit_workflow_factory(state: WorkflowState) -> FlowerWorkflow: + """Create the default workflow for a single fit round.""" + # Get clients and their respective instructions from strategy + client_instructions = state.strategy.configure_fit( + server_round=state.current_round, + parameters=state.parameters, + client_manager=state.client_manager, + ) + + if not client_instructions: + log(INFO, "fit_round %s: no clients selected, cancel", state.current_round) + return + log( + DEBUG, + "fit_round %s: strategy sampled %s clients (out of %s)", + state.current_round, + len(client_instructions), + state.client_manager.num_available(), + ) + + # Build dictionary mapping node_id to ClientProxy + node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions} + + # Send instructions to clients and + # collect `fit` results from all clients participating in this round + node_responses = yield { + proxy.node_id: wrap_server_message_in_task(fit_ins) + for proxy, fit_ins in client_instructions + } + + # No exception/failure handling currently + log( + DEBUG, + "fit_round %s received %s results and %s failures", + state.current_round, + len(node_responses), + 0, + ) + + # Aggregate training results + results = [ + ( + node_id_to_proxy[node_id], + serde.fit_res_from_proto(res.legacy_client_message.fit_res), + ) + for node_id, res in node_responses.items() + ] + aggregated_result: Tuple[ + Optional[Parameters], + Dict[str, Scalar], + ] = state.strategy.aggregate_fit(state.current_round, results, []) + parameters_aggregated, metrics_aggregated = aggregated_result + + # Update the parameters and write history + if parameters_aggregated: + state.parameters = parameters_aggregated + state.history.add_metrics_distributed_fit( + server_round=state.current_round, metrics=metrics_aggregated + ) + return + + +def default_evaluate_workflow_factory(state: WorkflowState) -> FlowerWorkflow: + """Create the default workflow for a single evaluate round.""" + # Get clients and their respective instructions from strategy + client_instructions = state.strategy.configure_evaluate( + server_round=state.current_round, + parameters=state.parameters, + client_manager=state.client_manager, + ) + if not client_instructions: + log(INFO, "evaluate_round %s: no clients selected, cancel", state.current_round) + return + log( + DEBUG, + "evaluate_round %s: strategy sampled %s clients (out of %s)", + state.current_round, + len(client_instructions), + state.client_manager.num_available(), + ) + + # Build dictionary mapping node_id to ClientProxy + node_id_to_proxy = 
{proxy.node_id: proxy for proxy, _ in client_instructions} + + # Send instructions to clients and + # collect `evaluate` results from all clients participating in this round + node_responses = yield { + proxy.node_id: wrap_server_message_in_task(evaluate_ins) + for proxy, evaluate_ins in client_instructions + } + # No exception/failure handling currently + log( + DEBUG, + "evaluate_round %s received %s results and %s failures", + state.current_round, + len(node_responses), + 0, + ) + + # Aggregate the evaluation results + results = [ + ( + node_id_to_proxy[node_id], + serde.evaluate_res_from_proto(res.legacy_client_message.evaluate_res), + ) + for node_id, res in node_responses.items() + ] + aggregated_result: Tuple[ + Optional[float], + Dict[str, Scalar], + ] = state.strategy.aggregate_evaluate(state.current_round, results, []) + + loss_aggregated, metrics_aggregated = aggregated_result + + # Write history + if loss_aggregated is not None: + state.history.add_loss_distributed( + server_round=state.current_round, loss=loss_aggregated + ) + state.history.add_metrics_distributed( + server_round=state.current_round, metrics=metrics_aggregated + ) + return
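
Usage sketch (not part of the changeset): the snippet below shows how the `DriverConfig` and the revised `start_driver` signature introduced above could be called. It assumes a Driver API is already running at the default address with client nodes connected; when `strategy`, `client_manager`, and `fl_workflow_factory` are omitted, `start_driver` falls back to `FedAvg`, `SimpleClientManager`, and `FLWorkflowFactory` as shown in the diff.

    from flwr.driver import DriverConfig, start_driver
    from flwr.server.strategy import FedAvg

    # Run three federated rounds through the Driver API, waiting up to
    # 600 seconds for each batch of task responses.
    history = start_driver(
        server_address="[::]:9091",
        config=DriverConfig(num_rounds=3, round_timeout=600.0),
        strategy=FedAvg(),  # optional; FedAvg() is also the default
    )
    print(history.losses_distributed)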