diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md
index da002a10d301..11c4c3f9a08b 100644
--- a/examples/xgboost-comprehensive/README.md
+++ b/examples/xgboost-comprehensive/README.md
@@ -9,6 +9,7 @@ It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/mai
 - Customised number of partitions.
 - Customised partitioner type (uniform, linear, square, exponential).
 - Centralised/distributed evaluation.
+- Bagging/cyclic training methods.
 
 ## Project Setup
 
@@ -26,7 +27,8 @@ This will create a new directory called `xgboost-comprehensive` containing the f
 -- client.py        <- Defines the client-side logic
 -- dataset.py       <- Defines the functions of data loading and partitioning
 -- utils.py         <- Defines the arguments parser for clients and server
--- run.sh           <- Commands to run experiments
+-- run_bagging.sh   <- Commands to run bagging experiments
+-- run_cyclic.sh    <- Commands to run cyclic experiments
 -- pyproject.toml   <- Example dependencies (if you use Poetry)
 -- requirements.txt <- Example dependencies
@@ -60,24 +62,31 @@ pip install -r requirements.txt
 
 ## Run Federated Learning with XGBoost and Flower
 
-The included `run.sh` will start the Flower server (using `server.py`) with centralised evaluation,
+We have two scripts to run bagging and cyclic (client-by-client) experiments.
+The included `run_bagging.sh` or `run_cyclic.sh` will start the Flower server (using `server.py`),
 sleep for 15 seconds to ensure that the server is up,
 and then start 5 Flower clients (using `client.py`) with a small subset of the data from exponential partition distribution.
 You can simply start everything in a terminal as follows:
 
 ```shell
-poetry run ./run.sh
+poetry run ./run_bagging.sh
 ```
 
-The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows.
+Or
+
+```shell
+poetry run ./run_cyclic.sh
+```
+
+The script starts processes in the background so that you don't have to open eleven terminal windows.
 If you experiment with the code example and something goes wrong, simply using `CTRL + C`
 on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes,
 which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`.
 This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`).
 If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`)
 to kill all background processes (or a more specific command if you have other Python processes
 running that you don't want to kill).
 
-You can also manually run `poetry run python3 server.py --pool-size=N --num-clients-per-round=N`
-and `poetry run python3 client.py --node-id=NODE_ID --num-partitions=N` for as many clients as you want,
+You can also manually run `poetry run python3 server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N`
+and `poetry run python3 client.py --train-method=bagging/cyclic --node-id=NODE_ID --num-partitions=N` for as many clients as you want,
 but you have to make sure that each command is run in a different terminal window (or a different computer on the network).
 In addition, we provide more options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
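If you would rather drive the same workflow from Python than from the shell scripts, the sketch below launches the server and the five clients with the flags documented above. The use of `subprocess` and the 15-second grace period mirror the shell scripts; everything else (names, counts) is an illustrative assumption, not part of the example.

```python
# Hypothetical Python launcher mirroring run_bagging.sh / run_cyclic.sh;
# the CLI flags are the ones documented in this README.
import subprocess
import time

train_method = "cyclic"  # or "bagging"

# Start the server first, then give it time to come up
procs = [
    subprocess.Popen(
        ["python3", "server.py", f"--train-method={train_method}",
         "--pool-size=5", "--num-rounds=100"]
    )
]
time.sleep(15)  # same grace period as the shell scripts

# Start the five clients, one per partition
for node_id in range(5):
    procs.append(
        subprocess.Popen(
            ["python3", "client.py", f"--node-id={node_id}",
             f"--train-method={train_method}", "--num-partitions=5",
             "--partitioner-type=exponential", "--centralised-eval"]
        )
    )

# Wait for everything to finish
for proc in procs:
    proc.wait()
```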
@@ -86,6 +95,8 @@ and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.htm
 
 ### Expected Experimental Results
 
+#### Bagging aggregation experiment
+
 ![](_static/xgboost_flower_auc.png)
 
 The figure above shows the centralised tested AUC performance over FL rounds on 4 experimental settings.
diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py
index a37edac32648..ff7a4adf7977 100644
--- a/examples/xgboost-comprehensive/client.py
+++ b/examples/xgboost-comprehensive/client.py
@@ -101,11 +101,16 @@ def _local_boost(self):
         for i in range(num_local_round):
             self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())
 
-        # Extract the last N=num_local_round trees for sever aggregation
-        bst = self.bst[
-            self.bst.num_boosted_rounds()
-            - num_local_round : self.bst.num_boosted_rounds()
-        ]
+        # Bagging: extract the last N=num_local_round trees for server aggregation
+        # Cyclic: return the entire model
+        bst = (
+            self.bst[
+                self.bst.num_boosted_rounds()
+                - num_local_round : self.bst.num_boosted_rounds()
+            ]
+            if args.train_method == "bagging"
+            else self.bst
+        )
 
         return bst
diff --git a/examples/xgboost-comprehensive/run.sh b/examples/xgboost-comprehensive/run_bagging.sh
similarity index 100%
rename from examples/xgboost-comprehensive/run.sh
rename to examples/xgboost-comprehensive/run_bagging.sh
diff --git a/examples/xgboost-comprehensive/run_cyclic.sh b/examples/xgboost-comprehensive/run_cyclic.sh
new file mode 100755
index 000000000000..47e09fd8faef
--- /dev/null
+++ b/examples/xgboost-comprehensive/run_cyclic.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+set -e
+cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/
+
+echo "Starting server"
+python3 server.py --train-method=cyclic --pool-size=5 --num-rounds=100 &
+sleep 15  # Sleep for 15s to give the server enough time to start
+
+for i in `seq 0 4`; do
+    echo "Starting client $i"
+    python3 client.py --node-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval &
+done
+
+# Enable CTRL+C to stop all background processes
+trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
+# Wait for all background processes to complete
+wait
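The bagging branch of `_local_boost` above leans on XGBoost's `Booster` slicing, where `bst[a:b]` returns a new booster containing only trees `a` to `b`. A self-contained sketch of that semantics, assuming a reasonably recent XGBoost (slicing landed in 1.3) and purely synthetic data:

```python
# Standalone sketch of the Booster slicing used in _local_boost.
# The data and training parameters here are illustrative assumptions.
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.random((64, 4))
y = rng.integers(0, 2, 64)
dtrain = xgb.DMatrix(X, label=y)

# Train 8 trees, then pretend the last 2 were added in this FL round
bst = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=8)
num_local_round = 2

# Same slice expression as in client.py: keep only the newest trees
last_trees = bst[
    bst.num_boosted_rounds() - num_local_round : bst.num_boosted_rounds()
]
print(last_trees.num_boosted_rounds())  # expected: 2
```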
diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py
index 3da7e8d9865c..1cf4ba79fa50 100644
--- a/examples/xgboost-comprehensive/server.py
+++ b/examples/xgboost-comprehensive/server.py
@@ -1,4 +1,5 @@
-from typing import Dict
+import warnings
+from typing import Dict, List, Optional
 from logging import INFO
 
 import xgboost as xgb
@@ -7,13 +8,21 @@
 from flwr.common import Parameters, Scalar
 from flwr_datasets import FederatedDataset
 from flwr.server.strategy import FedXgbBagging
+from flwr.server.strategy import FedXgbCyclic
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.criterion import Criterion
+from flwr.server.client_manager import SimpleClientManager
 
 from utils import server_args_parser, BST_PARAMS
 from dataset import resplit, transform_dataset_to_dmatrix
 
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
 # Parse arguments for experimental settings
 args = server_args_parser()
+train_method = args.train_method
 pool_size = args.pool_size
 num_rounds = args.num_rounds
 num_clients_per_round = args.num_clients_per_round
@@ -80,23 +89,72 @@ def evaluate_fn(
     return evaluate_fn
 
 
+class CyclicClientManager(SimpleClientManager):
+    """Provides a cyclic client selection rule."""
+
+    def sample(
+        self,
+        num_clients: int,
+        min_num_clients: Optional[int] = None,
+        criterion: Optional[Criterion] = None,
+    ) -> List[ClientProxy]:
+        """Sample a number of Flower ClientProxy instances."""
+
+        # Block until at least num_clients are connected.
+        if min_num_clients is None:
+            min_num_clients = num_clients
+        self.wait_for(min_num_clients)
+
+        # Sample clients which meet the criterion
+        available_cids = list(self.clients)
+        if criterion is not None:
+            available_cids = [
+                cid for cid in available_cids if criterion.select(self.clients[cid])
+            ]
+
+        if num_clients > len(available_cids):
+            log(
+                INFO,
+                "Sampling failed: number of available clients"
+                " (%s) is less than number of requested clients (%s).",
+                len(available_cids),
+                num_clients,
+            )
+            return []
+
+        # Return all available clients
+        return [self.clients[cid] for cid in available_cids]
+
+
 # Define strategy
-strategy = FedXgbBagging(
-    evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None,
-    fraction_fit=(float(num_clients_per_round) / pool_size),
-    min_fit_clients=num_clients_per_round,
-    min_available_clients=pool_size,
-    min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0,
-    fraction_evaluate=1.0 if not centralised_eval else 0.0,
-    on_evaluate_config_fn=eval_config,
-    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation
-    if not centralised_eval
-    else None,
-)
+if train_method == "bagging":
+    # Bagging training
+    strategy = FedXgbBagging(
+        evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None,
+        fraction_fit=(float(num_clients_per_round) / pool_size),
+        min_fit_clients=num_clients_per_round,
+        min_available_clients=pool_size,
+        min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0,
+        fraction_evaluate=1.0 if not centralised_eval else 0.0,
+        on_evaluate_config_fn=eval_config,
+        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation
+        if not centralised_eval
+        else None,
+    )
+else:
+    # Cyclic training
+    strategy = FedXgbCyclic(
+        fraction_fit=1.0,
+        min_available_clients=pool_size,
+        fraction_evaluate=1.0,
+        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
+        on_evaluate_config_fn=eval_config,
+    )
 
 # Start Flower server
 fl.server.start_server(
     server_address="0.0.0.0:8080",
     config=fl.server.ServerConfig(num_rounds=num_rounds),
     strategy=strategy,
+    client_manager=CyclicClientManager() if train_method == "cyclic" else None,
 )
diff --git a/examples/xgboost-comprehensive/utils.py b/examples/xgboost-comprehensive/utils.py
index 000def370752..8acdbbb88a7e 100644
--- a/examples/xgboost-comprehensive/utils.py
+++ b/examples/xgboost-comprehensive/utils.py
@@ -17,6 +17,13 @@ def client_args_parser():
     """Parse arguments to define experimental settings on client side."""
     parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+        "--train-method",
+        default="bagging",
+        type=str,
+        choices=["bagging", "cyclic"],
+        help="Training method selected from bagging aggregation or cyclic training.",
+    )
     parser.add_argument(
         "--num-partitions", default=10, type=int, help="Number of partitions."
     )
@@ -56,6 +63,13 @@ def server_args_parser():
     """Parse arguments to define experimental settings on server side."""
     parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+        "--train-method",
+        default="bagging",
+        type=str,
+        choices=["bagging", "cyclic"],
+        help="Training method selected from bagging aggregation or cyclic training.",
+    )
     parser.add_argument(
         "--pool-size", default=2, type=int, help="Number of total clients."
     )
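Taken together, `CyclicClientManager.sample` hands every available client back to the strategy, and `FedXgbCyclic.configure_fit` (further below) then picks exactly one of them per round with `(server_round - 1) % len(clients)`. A toy illustration of that arithmetic, assuming the client list keeps a stable order across rounds:

```python
# Toy illustration of the cyclic selection rule: the client manager returns
# every available client, and the strategy picks one per round.
clients = [f"client-{i}" for i in range(5)]

for server_round in range(1, 8):
    sampled_idx = (server_round - 1) % len(clients)
    print(f"round {server_round} -> {clients[sampled_idx]}")
# round 1 -> client-0
# round 2 -> client-1
# ...
# round 6 -> client-0  (the cycle restarts)
```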
diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py
index 0772aa1ff13a..1750a7522379 100644
--- a/src/py/flwr/server/strategy/__init__.py
+++ b/src/py/flwr/server/strategy/__init__.py
@@ -29,6 +29,7 @@
 from .fedprox import FedProx as FedProx
 from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg
 from .fedxgb_bagging import FedXgbBagging as FedXgbBagging
+from .fedxgb_cyclic import FedXgbCyclic as FedXgbCyclic
 from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg
 from .fedyogi import FedYogi as FedYogi
 from .krum import Krum as Krum
@@ -42,6 +43,7 @@
     "FedAvg",
     "FedXgbNnAvg",
     "FedXgbBagging",
+    "FedXgbCyclic",
     "FedAvgAndroid",
     "FedAvgM",
     "FedOpt",
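With the export in place, the new strategy is importable alongside `FedXgbBagging`. A minimal sketch; since `FedXgbCyclic.__init__` forwards `**kwargs` to `FedAvg` (see the file below), the usual `FedAvg` keywords apply:

```python
# Assumes a Flower version that ships FedXgbCyclic (i.e. this change).
from flwr.server.strategy import FedXgbBagging, FedXgbCyclic

# FedXgbCyclic forwards its keyword arguments to FedAvg, so the
# standard FedAvg knobs are available here.
strategy = FedXgbCyclic(fraction_fit=1.0, fraction_evaluate=1.0)
```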
diff --git a/src/py/flwr/server/strategy/fedxgb_cyclic.py b/src/py/flwr/server/strategy/fedxgb_cyclic.py
new file mode 100644
index 000000000000..e2707b02d19d
--- /dev/null
+++ b/src/py/flwr/server/strategy/fedxgb_cyclic.py
@@ -0,0 +1,142 @@
+# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Federated XGBoost cyclic aggregation strategy."""
+
+
+from logging import WARNING
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
+
+from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
+from flwr.common.logger import log
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+
+from .fedavg import FedAvg
+
+
+class FedXgbCyclic(FedAvg):
+    """Configurable FedXgbCyclic strategy implementation."""
+
+    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
+    def __init__(
+        self,
+        **kwargs: Any,
+    ):
+        self.global_model: Optional[bytes] = None
+        super().__init__(**kwargs)
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+        """Aggregate fit results using cyclic training."""
+        if not results:
+            return None, {}
+        # Do not aggregate if there are failures and failures are not accepted
+        if not self.accept_failures and failures:
+            return None, {}
+
+        # Keep the last client model received this round as the new global model
+        for _, fit_res in results:
+            update = fit_res.parameters.tensors
+            for bst in update:
+                self.global_model = bst
+
+        return (
+            Parameters(tensor_type="", tensors=[cast(bytes, self.global_model)]),
+            {},
+        )
+
+    def aggregate_evaluate(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, EvaluateRes]],
+        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
+    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
+        """Aggregate evaluation metrics using average."""
+        if not results:
+            return None, {}
+        # Do not aggregate if there are failures and failures are not accepted
+        if not self.accept_failures and failures:
+            return None, {}
+
+        # Aggregate custom metrics if aggregation fn was provided
+        metrics_aggregated = {}
+        if self.evaluate_metrics_aggregation_fn:
+            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
+            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
+        elif server_round == 1:  # Only log this warning once
+            log(WARNING, "No evaluate_metrics_aggregation_fn provided")
+
+        return 0, metrics_aggregated
+
+    def configure_fit(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, FitIns]]:
+        """Configure the next round of training."""
+        config = {}
+        if self.on_fit_config_fn is not None:
+            # Custom fit config function provided
+            config = self.on_fit_config_fn(server_round)
+        fit_ins = FitIns(parameters, config)
+
+        # Sample clients
+        sample_size, min_num_clients = self.num_fit_clients(
+            client_manager.num_available()
+        )
+        clients = client_manager.sample(
+            num_clients=sample_size,
+            min_num_clients=min_num_clients,
+        )
+
+        # Sample the clients sequentially given server_round
+        sampled_idx = (server_round - 1) % len(clients)
+        sampled_clients = [clients[sampled_idx]]
+
+        # Return client/config pairs
+        return [(client, fit_ins) for client in sampled_clients]
+
+    def configure_evaluate(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
+        """Configure the next round of evaluation."""
+        # Do not configure federated evaluation if fraction eval is 0.
+        if self.fraction_evaluate == 0.0:
+            return []
+
+        # Parameters and config
+        config = {}
+        if self.on_evaluate_config_fn is not None:
+            # Custom evaluation config function provided
+            config = self.on_evaluate_config_fn(server_round)
+        evaluate_ins = EvaluateIns(parameters, config)
+
+        # Sample clients
+        sample_size, min_num_clients = self.num_evaluation_clients(
+            client_manager.num_available()
+        )
+        clients = client_manager.sample(
+            num_clients=sample_size,
+            min_num_clients=min_num_clients,
+        )
+
+        # Sample the clients sequentially given server_round
+        sampled_idx = (server_round - 1) % len(clients)
+        sampled_clients = [clients[sampled_idx]]
+
+        # Return client/config pairs
+        return [(client, evaluate_ins) for client in sampled_clients]
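Stripped of the Flower plumbing, `aggregate_fit` above reduces to "keep the last serialized model received": with one client per round there is nothing to merge. A hypothetical standalone sketch (names are illustrative, not Flower API):

```python
# What FedXgbCyclic.aggregate_fit reduces to: the single model returned
# this round replaces the global model wholesale.
from typing import List, Optional


def cyclic_aggregate(
    global_model: Optional[bytes], updates: List[bytes]
) -> Optional[bytes]:
    # Keep the last serialized model received (an XGBoost JSON blob in
    # this example); earlier entries are simply overwritten.
    for update in updates:
        global_model = update
    return global_model


assert cyclic_aggregate(b"old-model", [b"new-model"]) == b"new-model"
```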