From f20f86383025ac4489a195ace0220f6de30600a9 Mon Sep 17 00:00:00 2001
From: yan-gao-GY
Date: Mon, 27 Nov 2023 16:49:43 +0000
Subject: [PATCH 1/6] Initialise cyclic training

---
 examples/xgboost-comprehensive/README.md      |  23 ++-
 examples/xgboost-comprehensive/client.py      |   5 +-
 .../{run.sh => run_bagging.sh}                |   0
 examples/xgboost-comprehensive/run_cyclic.sh  |  17 +++
 examples/xgboost-comprehensive/server.py      |  83 ++++++++--
 examples/xgboost-comprehensive/utils.py       |  14 ++
 src/py/flwr/server/strategy/__init__.py       |   2 +
 src/py/flwr/server/strategy/fedxgb_cyclic.py  | 142 ++++++++++++++++++
 8 files changed, 263 insertions(+), 23 deletions(-)
 rename examples/xgboost-comprehensive/{run.sh => run_bagging.sh} (100%)
 create mode 100755 examples/xgboost-comprehensive/run_cyclic.sh
 create mode 100644 src/py/flwr/server/strategy/fedxgb_cyclic.py

diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md
index da002a10d301..7af843b8cfe3 100644
--- a/examples/xgboost-comprehensive/README.md
+++ b/examples/xgboost-comprehensive/README.md
@@ -9,6 +9,7 @@ It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/mai
 - Customised number of partitions.
 - Customised partitioner type (uniform, linear, square, exponential).
 - Centralised/distributed evaluation.
+- Bagging/cyclic training methods.
 
 ## Project Setup
 
@@ -26,7 +27,8 @@ This will create a new directory called `xgboost-comprehensive` containing the f
 -- client.py         <- Defines the client-side logic
 -- dataset.py        <- Defines the functions for data loading and partitioning
 -- utils.py          <- Defines the arguments parser for clients and server
--- run.sh             <- Commands to run experiments
+-- run_bagging.sh     <- Commands to run bagging experiments
+-- run_cyclic.sh      <- Commands to run cyclic experiments
 -- pyproject.toml     <- Example dependencies (if you use Poetry)
 -- requirements.txt   <- Example dependencies
 ```
@@ -60,24 +62,29 @@
 pip install -r requirements.txt
 ```
 
 ## Run Federated Learning with XGBoost and Flower
 
-The included `run.sh` will start the Flower server (using `server.py`) with centralised evaluation,
+We have two scripts to run bagging and cyclic (client-by-client) experiments.
+The included `run_bagging.sh` or `run_cyclic.sh` will start the Flower server (using `server.py`),
 sleep for 15 seconds to ensure that the server is up,
 and then start 5 Flower clients (using `client.py`) with a small subset of the data from an exponential partition distribution.
 You can simply start everything in a terminal as follows:
 
 ```shell
-poetry run ./run.sh
+poetry run ./run_bagging.sh
+```
+Or
+```shell
+poetry run ./run_cyclic.sh
 ```
 
 The script starts processes in the background so that you don't have to open six terminal windows.
 If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes,
 which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`.
 This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`).
 If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`)
 to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).
-You can also manually run `poetry run python3 server.py --pool-size=N --num-clients-per-round=N`
-and `poetry run python3 client.py --node-id=NODE_ID --num-partitions=N` for as many clients as you want,
+You can also manually run `poetry run python3 server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N`
+and `poetry run python3 client.py --train-method=bagging/cyclic --node-id=NODE_ID --num-partitions=N` for as many clients as you want,
 but you have to make sure that each command is run in a different terminal window (or a different computer on the network).
 In addition, we provide more options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
@@ -86,9 +93,13 @@ and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.htm
 
 ### Expected Experimental Results
 
+#### Bagging aggregation experiment
+
 ![](_static/xgboost_flower_auc.png)
 
 The figure above shows the centralised test AUC performance over FL rounds on 4 experimental settings.
 One can see that all settings obtain stable performance boost over FL rounds (especially noticeable at the start of training).
 As expected, uniform client distribution shows higher AUC values (beyond 83% at the end) than square/exponential setup.
 Feel free to explore more interesting experiments by yourself!
+
+#### Cyclic training experiment
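The two modes in this patch differ in how clients participate: with bagging aggregation, every sampled client boosts locally each round and the server merges the returned trees into one ensemble, whereas with cyclic training exactly one client trains per round and hands the updated model to the next. A minimal, dependency-free sketch of that selection order (the client names are illustrative; the strategy added later in this patch uses the same `(server_round - 1) % len(clients)` indexing):

```python
# Round-robin schedule used by cyclic training: one client trains per round.
clients = [f"client_{i}" for i in range(5)]  # pool_size=5, as in run_cyclic.sh

for server_round in range(1, 8):
    sampled_idx = (server_round - 1) % len(clients)
    print(f"round {server_round}: {clients[sampled_idx]} trains")
# round 1: client_0, round 2: client_1, ..., round 6: client_0 again
```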
diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py
index a37edac32648..5af15c575bd0 100644
--- a/examples/xgboost-comprehensive/client.py
+++ b/examples/xgboost-comprehensive/client.py
@@ -101,11 +101,12 @@ def _local_boost(self):
         for i in range(num_local_round):
             self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())
 
-        # Extract the last N=num_local_round trees for sever aggregation
+        # Bagging: extract the last N=num_local_round trees for server aggregation
+        # Cyclic: return the entire model
         bst = self.bst[
             self.bst.num_boosted_rounds()
             - num_local_round : self.bst.num_boosted_rounds()
-        ]
+        ] if args.train_method == "bagging" else self.bst
 
         return bst
diff --git a/examples/xgboost-comprehensive/run.sh b/examples/xgboost-comprehensive/run_bagging.sh
similarity index 100%
rename from examples/xgboost-comprehensive/run.sh
rename to examples/xgboost-comprehensive/run_bagging.sh
diff --git a/examples/xgboost-comprehensive/run_cyclic.sh b/examples/xgboost-comprehensive/run_cyclic.sh
new file mode 100755
index 000000000000..47e09fd8faef
--- /dev/null
+++ b/examples/xgboost-comprehensive/run_cyclic.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+set -e
+cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/
+
+echo "Starting server"
+python3 server.py --train-method=cyclic --pool-size=5 --num-rounds=100 &
+sleep 15  # Sleep for 15s to give the server enough time to start
+
+for i in `seq 0 4`; do
+    echo "Starting client $i"
+    python3 client.py --node-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval &
+done
+
+# Enable CTRL+C to stop all background processes
+trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
+# Wait for all background processes to complete
+wait
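The `_local_boost` hunk above is the core client-side switch: bagging ships only the `num_local_round` trees added during local boosting, while cyclic ships the whole `Booster`. A self-contained toy sketch of that slicing, assuming only `xgboost` and `numpy` (random data stands in for the example's HIGGS partitions, and the parameter values are placeholders):

```python
import numpy as np
import xgboost as xgb

X, y = np.random.rand(64, 4), np.random.randint(2, size=64)
train_dmatrix = xgb.DMatrix(X, label=y)
bst = xgb.train({"objective": "binary:logistic"}, train_dmatrix, num_boost_round=4)

num_local_round = 2
for _ in range(num_local_round):
    # Boost one more round on local data, as in _local_boost()
    bst.update(train_dmatrix, bst.num_boosted_rounds())

# Bagging: slice out only the last N trees; cyclic: send bst itself
last_n = bst[bst.num_boosted_rounds() - num_local_round : bst.num_boosted_rounds()]
print(bst.num_boosted_rounds(), last_n.num_boosted_rounds())  # prints: 6 2
```

Slicing a `Booster` with `bst[a:b]` returns a new model containing only those boosting rounds, which is what keeps the bagging payload small.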
diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py
index 3da7e8d9865c..91d62819fe31 100644
--- a/examples/xgboost-comprehensive/server.py
+++ b/examples/xgboost-comprehensive/server.py
@@ -1,4 +1,5 @@
-from typing import Dict
+import warnings
+from typing import Dict, List, Optional
 from logging import INFO
 
 import xgboost as xgb
@@ -7,13 +8,21 @@
 from flwr.common import Parameters, Scalar
 from flwr_datasets import FederatedDataset
 from flwr.server.strategy import FedXgbBagging
+from flwr.server.strategy import FedXgbCyclic
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.criterion import Criterion
+from flwr.server.client_manager import SimpleClientManager
 
 from utils import server_args_parser, BST_PARAMS
 from dataset import resplit, transform_dataset_to_dmatrix
 
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
 # Parse arguments for experimental settings
 args = server_args_parser()
+train_method = args.train_method
 pool_size = args.pool_size
 num_rounds = args.num_rounds
 num_clients_per_round = args.num_clients_per_round
@@ -53,7 +62,6 @@ def evaluate_metrics_aggregation(eval_metrics):
 
 def get_evaluate_fn(test_data):
     """Return a function for centralised evaluation."""
-
     def evaluate_fn(
         server_round: int, parameters: Parameters, config: Dict[str, Scalar]
     ):
@@ -76,27 +84,72 @@ def evaluate_fn(
         log(INFO, f"AUC = {auc} at round {server_round}")
         return 0, {"AUC": auc}
-
     return evaluate_fn
+
+
+class CyclicClientManager(SimpleClientManager):
+    """Provides a cyclic client selection rule."""
+    def sample(
+        self,
+        num_clients: int,
+        min_num_clients: Optional[int] = None,
+        criterion: Optional[Criterion] = None,
+    ) -> List[ClientProxy]:
+        """Sample a number of Flower ClientProxy instances."""
+        # Block until at least num_clients are connected.
+        if min_num_clients is None:
+            min_num_clients = num_clients
+        self.wait_for(min_num_clients)
+        # Sample clients which meet the criterion
+        available_cids = list(self.clients)
+        if criterion is not None:
+            available_cids = [
+                cid for cid in available_cids if criterion.select(self.clients[cid])
+            ]
+
+        if num_clients > len(available_cids):
+            log(
+                INFO,
+                "Sampling failed: number of available clients"
+                " (%s) is less than number of requested clients (%s).",
+                len(available_cids),
+                num_clients,
+            )
+            return []
+
+        # Return all available clients
+        return [self.clients[cid] for cid in available_cids]
+
+
 # Define strategy
-strategy = FedXgbBagging(
-    evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None,
-    fraction_fit=(float(num_clients_per_round) / pool_size),
-    min_fit_clients=num_clients_per_round,
-    min_available_clients=pool_size,
-    min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0,
-    fraction_evaluate=1.0 if not centralised_eval else 0.0,
-    on_evaluate_config_fn=eval_config,
-    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation
-    if not centralised_eval
-    else None,
-)
+if train_method == "bagging":
+    # Bagging training
+    strategy = FedXgbBagging(
+        evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None,
+        fraction_fit=(float(num_clients_per_round) / pool_size),
+        min_fit_clients=num_clients_per_round,
+        min_available_clients=pool_size,
+        min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0,
+        fraction_evaluate=1.0 if not centralised_eval else 0.0,
+        on_evaluate_config_fn=eval_config,
+        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation
+        if not centralised_eval
+        else None,
+    )
+else:
+    # Cyclic training
+    strategy = FedXgbCyclic(
+        fraction_fit=1.0,
+        min_available_clients=pool_size,
+        fraction_evaluate=1.0,
+        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
+        on_evaluate_config_fn=eval_config,
+    )
 
 # Start Flower server
 fl.server.start_server(
     server_address="0.0.0.0:8080",
     config=fl.server.ServerConfig(num_rounds=num_rounds),
     strategy=strategy,
+    client_manager=CyclicClientManager() if train_method == "cyclic" else None,
 )
diff --git a/examples/xgboost-comprehensive/utils.py b/examples/xgboost-comprehensive/utils.py
index 000def370752..8acdbbb88a7e 100644
--- a/examples/xgboost-comprehensive/utils.py
+++ b/examples/xgboost-comprehensive/utils.py
@@ -17,6 +17,13 @@ def client_args_parser():
     """Parse arguments to define experimental settings on client side."""
     parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+        "--train-method",
+        default="bagging",
+        type=str,
+        choices=["bagging", "cyclic"],
+        help="Training methods selected from bagging aggregation or cyclic training.",
+    )
     parser.add_argument(
         "--num-partitions", default=10, type=int, help="Number of partitions."
     )
@@ -56,6 +63,13 @@ def server_args_parser():
     """Parse arguments to define experimental settings on server side."""
     parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+        "--train-method",
+        default="bagging",
+        type=str,
+        choices=["bagging", "cyclic"],
+        help="Training methods selected from bagging aggregation or cyclic training.",
+    )
     parser.add_argument(
         "--pool-size", default=2, type=int, help="Number of total clients."
     )
diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py
index 0772aa1ff13a..1750a7522379 100644
--- a/src/py/flwr/server/strategy/__init__.py
+++ b/src/py/flwr/server/strategy/__init__.py
@@ -29,6 +29,7 @@
 from .fedprox import FedProx as FedProx
 from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg
 from .fedxgb_bagging import FedXgbBagging as FedXgbBagging
+from .fedxgb_cyclic import FedXgbCyclic as FedXgbCyclic
 from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg
 from .fedyogi import FedYogi as FedYogi
 from .krum import Krum as Krum
@@ -42,6 +43,7 @@
     "FedAvg",
     "FedXgbNnAvg",
     "FedXgbBagging",
+    "FedXgbCyclic",
     "FedAvgAndroid",
     "FedAvgM",
     "FedOpt",
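Both parsers gain an identical `--train-method` flag, so the client and server can be launched in the same mode. A standalone check of how the flag parses, using only `argparse` and mirroring the `add_argument` call above:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train-method",
    default="bagging",
    type=str,
    choices=["bagging", "cyclic"],
    help="Training methods selected from bagging aggregation or cyclic training.",
)

# argparse converts "--train-method" into the attribute name "train_method"
args = parser.parse_args(["--train-method=cyclic"])
assert args.train_method == "cyclic"
assert parser.parse_args([]).train_method == "bagging"  # the default
```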
diff --git a/src/py/flwr/server/strategy/fedxgb_cyclic.py b/src/py/flwr/server/strategy/fedxgb_cyclic.py
new file mode 100644
index 000000000000..7fefebe65c4a
--- /dev/null
+++ b/src/py/flwr/server/strategy/fedxgb_cyclic.py
@@ -0,0 +1,142 @@
+# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Federated XGBoost bagging aggregation strategy."""
+
+
+from logging import WARNING
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
+
+from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
+from flwr.common.logger import log
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+
+from .fedavg import FedAvg
+
+
+class FedXgbCyclic(FedAvg):
+    """Configurable FedXgbBagging strategy implementation."""
+
+    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
+    def __init__(
+        self,
+        **kwargs: Any,
+    ):
+        self.global_model: Optional[bytes] = None
+        super().__init__(**kwargs)
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+        """Aggregate fit results by keeping the last client's model."""
+        if not results:
+            return None, {}
+        # Do not aggregate if there are failures and failures are not accepted
+        if not self.accept_failures and failures:
+            return None, {}
+
+        # Fetch the client model from last round as global model
+        for _, fit_res in results:
+            update = fit_res.parameters.tensors
+            for bst in update:
+                self.global_model = bst
+
+        return (
+            Parameters(tensor_type="", tensors=[cast(bytes, self.global_model)]),
+            {},
+        )
+
+    def aggregate_evaluate(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, EvaluateRes]],
+        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
+    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
+        """Aggregate evaluation metrics using average."""
+        if not results:
+            return None, {}
+        # Do not aggregate if there are failures and failures are not accepted
+        if not self.accept_failures and failures:
+            return None, {}
+
+        # Aggregate custom metrics if aggregation fn was provided
+        metrics_aggregated = {}
+        if self.evaluate_metrics_aggregation_fn:
+            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
+            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
+        elif server_round == 1:  # Only log this warning once
+            log(WARNING, "No evaluate_metrics_aggregation_fn provided")
+
+        return 0, metrics_aggregated
+
+    def configure_fit(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, FitIns]]:
+        """Configure the next round of training."""
+        config = {}
+        if self.on_fit_config_fn is not None:
+            # Custom fit config function provided
+            config = self.on_fit_config_fn(server_round)
+        fit_ins = FitIns(parameters, config)
+
+        # Sample clients
+        sample_size, min_num_clients = self.num_fit_clients(
+            client_manager.num_available()
+        )
+        clients = client_manager.sample(
+            num_clients=sample_size,
+            min_num_clients=min_num_clients,
+        )
+
+        # Sample the clients sequentially given server_round
+        sampled_idx = (server_round - 1) % len(clients)
+        sampled_clients = [clients[sampled_idx]]
+
+        # Return client/config pairs
+        return [(client, fit_ins) for client in sampled_clients]
+
+    def configure_evaluate(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
+        """Configure the next round of evaluation."""
+        # Do not configure federated evaluation if fraction eval is 0.
+        if self.fraction_evaluate == 0.0:
+            return []
+
+        # Parameters and config
+        config = {}
+        if self.on_evaluate_config_fn is not None:
+            # Custom evaluation config function provided
+            config = self.on_evaluate_config_fn(server_round)
+        evaluate_ins = EvaluateIns(parameters, config)
+
+        # Sample clients
+        sample_size, min_num_clients = self.num_evaluation_clients(
+            client_manager.num_available()
+        )
+        clients = client_manager.sample(
+            num_clients=sample_size,
+            min_num_clients=min_num_clients,
+        )
+
+        # Sample the clients sequentially given server_round
+        sampled_idx = (server_round - 1) % len(clients)
+        sampled_clients = [clients[sampled_idx]]
+
+        # Return client/config pairs
+        return [(client, evaluate_ins) for client in sampled_clients]

From 542543bc2e9401740b9e0a3759abd06eee5f00bb Mon Sep 17 00:00:00 2001
From: yan-gao-GY
Date: Tue, 28 Nov 2023 09:38:07 +0000
Subject: [PATCH 2/6] Update readme

---
 examples/xgboost-comprehensive/README.md | 2 --
 1 file changed, 2 deletions(-)

diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md
index 7af843b8cfe3..b8e4f60ac13d 100644
--- a/examples/xgboost-comprehensive/README.md
+++ b/examples/xgboost-comprehensive/README.md
@@ -101,5 +101,3 @@ The figure above shows the centralised test AUC performance over FL rounds on
 One can see that all settings obtain stable performance boost over FL rounds (especially noticeable at the start of training).
 As expected, uniform client distribution shows higher AUC values (beyond 83% at the end) than square/exponential setup.
 Feel free to explore more interesting experiments by yourself!
-
-#### Cyclic training experiment

From d265c27c57d35873f5389e95d188c4fadf04355b Mon Sep 17 00:00:00 2001
From: yan-gao-GY
Date: Tue, 28 Nov 2023 09:40:59 +0000
Subject: [PATCH 3/6] Formatting

---
 examples/xgboost-comprehensive/README.md |  2 ++
 examples/xgboost-comprehensive/client.py | 12 ++++++++----
 examples/xgboost-comprehensive/server.py |  3 +++
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md
index b8e4f60ac13d..11c4c3f9a08b 100644
--- a/examples/xgboost-comprehensive/README.md
+++ b/examples/xgboost-comprehensive/README.md
@@ -71,7 +71,9 @@ You can simply start everything in a terminal as follows:
 ```shell
 poetry run ./run_bagging.sh
 ```
+
 Or
+
 ```shell
 poetry run ./run_cyclic.sh
 ```
diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py
index 5af15c575bd0..ff7a4adf7977 100644
--- a/examples/xgboost-comprehensive/client.py
+++ b/examples/xgboost-comprehensive/client.py
@@ -103,10 +103,14 @@ def _local_boost(self):
 
         # Bagging: extract the last N=num_local_round trees for server aggregation
         # Cyclic: return the entire model
-        bst = self.bst[
-            self.bst.num_boosted_rounds()
-            - num_local_round : self.bst.num_boosted_rounds()
-        ] if args.train_method == "bagging" else self.bst
+        bst = (
+            self.bst[
+                self.bst.num_boosted_rounds()
+                - num_local_round : self.bst.num_boosted_rounds()
+            ]
+            if args.train_method == "bagging"
+            else self.bst
+        )
 
         return bst
diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py
index 91d62819fe31..869c70936383 100644
--- a/examples/xgboost-comprehensive/server.py
+++ b/examples/xgboost-comprehensive/server.py
@@ -62,6 +62,7 @@ def evaluate_metrics_aggregation(eval_metrics):
 
 def get_evaluate_fn(test_data):
     """Return a function for centralised evaluation."""
+
     def evaluate_fn(
         server_round: int, parameters: Parameters, config: Dict[str, Scalar]
     ):
@@ -84,11 +85,13 @@ def evaluate_fn(
         log(INFO, f"AUC = {auc} at round {server_round}")
         return 0, {"AUC": auc}
+
     return evaluate_fn
 
 
 class CyclicClientManager(SimpleClientManager):
     """Provides a cyclic client selection rule."""
+
     def sample(

From f8361803531faf08ff703342368eb1d0d48ba806 Mon Sep 17 00:00:00 2001
From: yan-gao-GY
Date: Tue, 5 Dec 2023 17:17:19 +0000
Subject: [PATCH 4/6] Fix docstring

---
 src/py/flwr/server/strategy/fedxgb_cyclic.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/py/flwr/server/strategy/fedxgb_cyclic.py b/src/py/flwr/server/strategy/fedxgb_cyclic.py
index 7fefebe65c4a..e2707b02d19d 100644
--- a/src/py/flwr/server/strategy/fedxgb_cyclic.py
+++ b/src/py/flwr/server/strategy/fedxgb_cyclic.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Federated XGBoost bagging aggregation strategy."""
+"""Federated XGBoost cyclic aggregation strategy."""
 
 
 from logging import WARNING
@@ -27,7 +27,7 @@
 
 
 class FedXgbCyclic(FedAvg):
-    """Configurable FedXgbBagging strategy implementation."""
+    """Configurable FedXgbCyclic strategy implementation."""
 
     # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
     def __init__(

From fa99f768577d7114f3f667dbd916d3f1d6c25595 Mon Sep 17 00:00:00 2001
From: "Daniel J. Beutel"
Date: Wed, 6 Dec 2023 08:47:52 +0100
Subject: [PATCH 5/6] Update examples/xgboost-comprehensive/server.py

---
 examples/xgboost-comprehensive/server.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py
index 869c70936383..3ca91da28ef0 100644
--- a/examples/xgboost-comprehensive/server.py
+++ b/examples/xgboost-comprehensive/server.py
@@ -103,6 +103,7 @@ def sample(
         if min_num_clients is None:
             min_num_clients = num_clients
         self.wait_for(min_num_clients)
+
         # Sample clients which meet the criterion
         available_cids = list(self.clients)
         if criterion is not None:

From 509e3588d1449bbea55f207e516f86a566d20d4a Mon Sep 17 00:00:00 2001
From: "Daniel J. Beutel"
Date: Wed, 6 Dec 2023 08:47:59 +0100
Subject: [PATCH 6/6] Update examples/xgboost-comprehensive/server.py

---
 examples/xgboost-comprehensive/server.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py
index 3ca91da28ef0..1cf4ba79fa50 100644
--- a/examples/xgboost-comprehensive/server.py
+++ b/examples/xgboost-comprehensive/server.py
@@ -99,6 +99,7 @@ def sample(
         criterion: Optional[Criterion] = None,
     ) -> List[ClientProxy]:
         """Sample a number of Flower ClientProxy instances."""
+
         # Block until at least num_clients are connected.
         if min_num_clients is None:
             min_num_clients = num_clients
         self.wait_for(min_num_clients)
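Taken together, a cyclic run wires `FedXgbCyclic` and the `CyclicClientManager` into `start_server` as in the final server.py. A condensed sketch of that wiring; it assumes the `CyclicClientManager` class from the patch is defined in the same module, and the metrics helper and config callback here are illustrative stand-ins for the ones already present in server.py:

```python
import flwr as fl
from flwr.server.strategy import FedXgbCyclic


def evaluate_metrics_aggregation(eval_metrics):
    # Illustrative stand-in: weighted-average the clients' AUC values.
    total = sum(num for num, _ in eval_metrics)
    auc = sum(metrics["AUC"] * num for num, metrics in eval_metrics) / total
    return {"AUC": auc}


strategy = FedXgbCyclic(
    fraction_fit=1.0,
    min_available_clients=5,  # pool_size, matching run_cyclic.sh
    fraction_evaluate=1.0,
    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
    on_evaluate_config_fn=lambda rnd: {"global_round": str(rnd)},  # assumed config shape
)

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=100),
    strategy=strategy,
    client_manager=CyclicClientManager(),  # the class added to server.py above
)
```

Because `configure_fit` picks `clients[(server_round - 1) % len(clients)]` from the manager's full list, the manager must return every available client in a stable order, which is exactly why the patch overrides `sample` instead of relying on `SimpleClientManager`'s random sampling.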