From f0561759aafcbba2283ce308ef1eecad6de8dff2 Mon Sep 17 00:00:00 2001
From: Yan Gao
Date: Wed, 15 Nov 2023 18:42:18 +0000
Subject: [PATCH] Quickstart-xgboost with bagging aggregation (#2554)

Co-authored-by: yan-gao-GY
---
 examples/xgboost-comprehensive/README.md      |  87 +++++++++
 examples/xgboost-comprehensive/client.py      | 174 ++++++++++++++++++
 examples/xgboost-comprehensive/dataset.py     |  67 +++++++
 examples/xgboost-comprehensive/pyproject.toml |  15 ++
 .../xgboost-comprehensive/requirements.txt    |   3 +
 examples/xgboost-comprehensive/run.sh         |  17 ++
 examples/xgboost-comprehensive/server.py      | 111 +++++++++++
 examples/xgboost-comprehensive/strategy.py    | 139 ++++++++++++++
 examples/xgboost-comprehensive/utils.py       |  72 ++++++++
 9 files changed, 685 insertions(+)
 create mode 100644 examples/xgboost-comprehensive/README.md
 create mode 100644 examples/xgboost-comprehensive/client.py
 create mode 100644 examples/xgboost-comprehensive/dataset.py
 create mode 100644 examples/xgboost-comprehensive/pyproject.toml
 create mode 100644 examples/xgboost-comprehensive/requirements.txt
 create mode 100755 examples/xgboost-comprehensive/run.sh
 create mode 100644 examples/xgboost-comprehensive/server.py
 create mode 100644 examples/xgboost-comprehensive/strategy.py
 create mode 100644 examples/xgboost-comprehensive/utils.py

diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md
new file mode 100644
index 000000000000..3801d4813a26
--- /dev/null
+++ b/examples/xgboost-comprehensive/README.md
@@ -0,0 +1,87 @@
+# Flower Example using XGBoost
+
+This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using the `xgboost` package.
+Tree-based bagging is used for aggregation on the server.
+
+## Project Setup
+
+Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-comprehensive . && rm -rf flower && cd xgboost-comprehensive
+```
+
+This will create a new directory called `xgboost-comprehensive` containing the following files:
+
+```
+-- README.md         <- You're reading this right now
+-- server.py         <- Defines the server-side logic
+-- strategy.py       <- Defines the tree-based bagging aggregation
+-- client.py         <- Defines the client-side logic
+-- dataset.py        <- Defines the functions of data loading and partitioning
+-- pyproject.toml    <- Example dependencies (if you use Poetry)
+-- requirements.txt  <- Example dependencies
+```
+
+### Installing Dependencies
+
+Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
+
+#### Poetry
+
+```shell
+poetry install
+poetry shell
+```
+
+Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:
+
+```shell
+poetry run python3 -c "import flwr"
+```
+
+If you don't see any errors you're good to go!
+
+#### pip
+
+Run the command below in your terminal to install the dependencies according to the configuration file `requirements.txt`.
+
+```shell
+pip install -r requirements.txt
+```
+
+## Run Federated Learning with XGBoost and Flower
+
+Afterwards you are ready to start the Flower server as well as the clients.
+You can simply start the server in a terminal as follows:
+
+```shell
+python3 server.py
+```
+
+Now you are ready to start the Flower clients which will participate in the learning.
+To do so, simply open two more terminal windows and run the following commands.
+
+Start client 1 in the first terminal:
+
+```shell
+python3 client.py --node-id=0
+```
+
+Start client 2 in the second terminal:
+
+```shell
+python3 client.py --node-id=1
+```
+
+You will see that XGBoost is starting a federated training.
+
+Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows:
+
+```shell
+bash run.sh
+```
+
+In addition, we provide options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
+Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive)
+and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
diff --git a/examples/xgboost-comprehensive/client.py b/examples/xgboost-comprehensive/client.py
new file mode 100644
index 000000000000..5aba30266b5a
--- /dev/null
+++ b/examples/xgboost-comprehensive/client.py
@@ -0,0 +1,174 @@
+import warnings
+from logging import INFO
+import xgboost as xgb
+
+import flwr as fl
+from flwr_datasets import FederatedDataset
+from flwr.common.logger import log
+from flwr.common import (
+    Code,
+    EvaluateIns,
+    EvaluateRes,
+    FitIns,
+    FitRes,
+    GetParametersIns,
+    GetParametersRes,
+    Parameters,
+    Status,
+)
+
+from dataset import (
+    instantiate_partitioner,
+    train_test_split,
+    transform_dataset_to_dmatrix,
+    resplit,
+)
+from utils import client_args_parser
+
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+# Parse arguments for experimental settings
+args = client_args_parser()
+
+# Load (HIGGS) dataset and conduct partitioning
+num_partitions = args.num_partitions
+
+# Partitioner type is chosen from ["uniform", "linear", "square", "exponential"]
+partitioner_type = args.partitioner_type
+
+# Instantiate partitioner
+partitioner = instantiate_partitioner(
+    partitioner_type=partitioner_type, num_partitions=num_partitions
+)
+fds = FederatedDataset(
+    dataset="jxie/higgs", partitioners={"train": partitioner}, resplitter=resplit
+)
+
+# Load the partition corresponding to this client's node ID
+node_id = args.node_id
+partition = fds.load_partition(idx=node_id, split="train")
+partition.set_format("numpy")
+
+if args.centralised_eval:
+    # Use centralised test set for evaluation
+    train_data = partition
+    valid_data = fds.load_full("test")
+    valid_data.set_format("numpy")
+    num_train = train_data.shape[0]
+    num_val = valid_data.shape[0]
+else:
+    # Train/test splitting
+    SEED = args.seed
+    test_fraction = args.test_fraction
+    train_data, valid_data, num_train, num_val = train_test_split(
+        partition, test_fraction=test_fraction, seed=SEED
+    )
+
+# Reformat data to DMatrix for xgboost
+train_dmatrix = transform_dataset_to_dmatrix(train_data)
+valid_dmatrix = transform_dataset_to_dmatrix(valid_data)
+
+
+# Hyper-parameters for xgboost training
+num_local_round = 1
+params = {
+    "objective": "binary:logistic",
+    "eta": 0.1,  # Learning rate
+    "max_depth": 8,
+    "eval_metric": "auc",
+    "nthread": 16,
+    "num_parallel_tree": 1,
+    "subsample": 1,
+    "tree_method": "hist",
+}
+
+
+# Define Flower client
+class FlowerClient(fl.client.Client):
+    def __init__(self):
+        self.bst = None
+        self.config = None
+
+    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
+        _ = (self, ins)
+        return GetParametersRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[]),
+        )
+
+    def _local_boost(self):
+        # Update trees based on local training data.
+        for i in range(num_local_round):
+            self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())
+
+        # Extract the last N=num_local_round trees for server aggregation
+        bst = self.bst[
+            self.bst.num_boosted_rounds()
+            - num_local_round : self.bst.num_boosted_rounds()
+        ]
+
+        return bst
+
+    def fit(self, ins: FitIns) -> FitRes:
+        if not self.bst:
+            # First round local training
+            log(INFO, "Start training at round 1")
+            bst = xgb.train(
+                params,
+                train_dmatrix,
+                num_boost_round=num_local_round,
+                evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
+            )
+            self.config = bst.save_config()
+            self.bst = bst
+        else:
+            for item in ins.parameters.tensors:
+                global_model = bytearray(item)
+
+            # Load global model into booster
+            self.bst.load_model(global_model)
+            self.bst.load_config(self.config)
+
+            bst = self._local_boost()
+
+        local_model = bst.save_raw("json")
+        local_model_bytes = bytes(local_model)
+
+        return FitRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
+            num_examples=num_train,
+            metrics={},
+        )
+
+    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
+        eval_results = self.bst.eval_set(
+            evals=[(valid_dmatrix, "valid")],
+            iteration=self.bst.num_boosted_rounds() - 1,
+        )
+        auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
+
+        global_round = ins.config["global_round"]
+        log(INFO, f"AUC = {auc} at round {global_round}")
+
+        return EvaluateRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            loss=0.0,
+            num_examples=num_val,
+            metrics={"AUC": auc},
+        )
+
+
+# Start Flower client
+fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient())
diff --git a/examples/xgboost-comprehensive/dataset.py b/examples/xgboost-comprehensive/dataset.py
new file mode 100644
index 000000000000..80c978f1077b
--- /dev/null
+++ b/examples/xgboost-comprehensive/dataset.py
@@ -0,0 +1,67 @@
+import xgboost as xgb
+from typing import Callable, Dict, List, Optional, Tuple, Union
+from datasets import Dataset, DatasetDict, concatenate_datasets
+from flwr_datasets.partitioner import (
+    IidPartitioner,
+    LinearPartitioner,
+    SquarePartitioner,
+    ExponentialPartitioner,
+)
+
+CORRELATION_TO_PARTITIONER = {
+    "uniform": IidPartitioner,
+    "linear": LinearPartitioner,
+    "square": SquarePartitioner,
+    "exponential": ExponentialPartitioner,
+}
+
+
+def instantiate_partitioner(partitioner_type: str, num_partitions: int):
+    """Initialise partitioner based on selected partitioner type and number of
+    partitions."""
+    partitioner = CORRELATION_TO_PARTITIONER[partitioner_type](
+        num_partitions=num_partitions
+    )
+    return partitioner
+
+
+def train_test_split(partition: Dataset, test_fraction: float, seed: int):
+    """Split the data into train and validation sets given the test fraction."""
+    train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
+    partition_train = train_test["train"]
+    partition_test = train_test["test"]
+
+    num_train = len(partition_train)
+    num_test = len(partition_test)
+
+    return partition_train, partition_test, num_train, num_test
+
+
+def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
+    """Transform dataset to DMatrix format for xgboost."""
+    x = data["inputs"]
+    y = data["label"]
+    new_data = xgb.DMatrix(x, label=y)
+    return new_data
+
+
+def resplit(dataset: DatasetDict) -> DatasetDict:
+    """Increase the quantity of centralised test samples from 500K to 1M."""
+    return DatasetDict(
+        {
+            "train": dataset["train"].select(
+                range(0, dataset["train"].num_rows - 500_000)
+            ),
+            "test": concatenate_datasets(
+                [
+                    dataset["train"].select(
+                        range(
+                            dataset["train"].num_rows - 500_000,
+                            dataset["train"].num_rows,
+                        )
+                    ),
+                    dataset["test"],
+                ]
+            ),
+        }
+    )
diff --git a/examples/xgboost-comprehensive/pyproject.toml b/examples/xgboost-comprehensive/pyproject.toml
new file mode 100644
index 000000000000..5414b5122154
--- /dev/null
+++ b/examples/xgboost-comprehensive/pyproject.toml
@@ -0,0 +1,15 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "xgboost-comprehensive"
+version = "0.1.0"
+description = "Federated XGBoost with Flower (comprehensive)"
+authors = ["The Flower Authors "]
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.11"
+flwr = ">=1.0,<2.0"
+flwr-datasets = ">=0.0.2,<1.0.0"
+xgboost = ">=2.0.0,<3.0.0"
diff --git a/examples/xgboost-comprehensive/requirements.txt b/examples/xgboost-comprehensive/requirements.txt
new file mode 100644
index 000000000000..c6b9c1a67894
--- /dev/null
+++ b/examples/xgboost-comprehensive/requirements.txt
@@ -0,0 +1,3 @@
+flwr>=1.0, <2.0
+flwr-datasets>=0.0.2, <1.0.0
+xgboost>=2.0.0, <3.0.0
diff --git a/examples/xgboost-comprehensive/run.sh b/examples/xgboost-comprehensive/run.sh
new file mode 100755
index 000000000000..7cf65fa4d52d
--- /dev/null
+++ b/examples/xgboost-comprehensive/run.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+set -e
+cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/
+
+echo "Starting server"
+python server.py &
+sleep 15 # Sleep for 15s to give the server enough time to start
+
+for i in `seq 0 1`; do
+    echo "Starting client $i"
+    python3 client.py --node-id=$i &
+done
+
+# Enable CTRL+C to stop all background processes
+trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
+# Wait for all background processes to complete
+wait
diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py
new file mode 100644
index 000000000000..857e99528013
--- /dev/null
+++ b/examples/xgboost-comprehensive/server.py
@@ -0,0 +1,111 @@
+from typing import Dict
+from logging import INFO
+import xgboost as xgb
+
+import flwr as fl
+from flwr.common.logger import log
+from flwr.common import Parameters, Scalar
+from flwr_datasets import FederatedDataset
+
+from strategy import FedXgbBagging
+from utils import server_args_parser
+from dataset import resplit, transform_dataset_to_dmatrix
+
+
+# Parse arguments for experimental settings
+args = server_args_parser()
+pool_size = args.pool_size
+num_rounds = args.num_rounds
+num_clients_per_round = args.num_clients_per_round
+num_evaluate_clients = args.num_evaluate_clients
+centralised_eval = args.centralised_eval
+
+# Load centralised test set
+if centralised_eval:
+    fds = FederatedDataset(
+        dataset="jxie/higgs", partitioners={"train": 20}, resplitter=resplit
+    )
+    test_set = fds.load_full("test")
+    test_set.set_format("numpy")
+    test_dmatrix = transform_dataset_to_dmatrix(test_set)
+
+# Hyper-parameters used for initialisation
+params = {
+    "objective": "binary:logistic",
+    "eta": 0.1,  # Learning rate
+    "max_depth": 8,
+    "eval_metric": "auc",
+    "nthread": 16,
+    "num_parallel_tree": 1,
+    "subsample": 1,
+    "tree_method": "hist",
+}
+
+
+def eval_config(rnd: int) -> Dict[str, str]:
+    """Return a configuration with the current global round."""
+    config = {
+        "global_round": str(rnd),
+    }
+    return config
+
+
+def evaluate_metrics_aggregation(eval_metrics):
+    """Return an aggregated metric (AUC) for evaluation."""
+    total_num = sum([num for num, _ in eval_metrics])
+    auc_aggregated = (
+        sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
+    )
+    metrics_aggregated = {"AUC": auc_aggregated}
+    return metrics_aggregated
+
+
+def get_evaluate_fn(test_data):
+    """Return a function for centralised evaluation."""
+
+    def evaluate_fn(
+        server_round: int, parameters: Parameters, config: Dict[str, Scalar]
+    ):
+        # If at the first round, skip the evaluation
+        if server_round == 0:
+            return 0, {}
+        else:
+            bst = xgb.Booster(params=params)
+            for para in parameters.tensors:
+                para_b = bytearray(para)
+
+            # Load global model
+            bst.load_model(para_b)
+            # Run evaluation
+            eval_results = bst.eval_set(
+                evals=[(test_data, "valid")],
+                iteration=bst.num_boosted_rounds() - 1,
+            )
+            auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
+            log(INFO, f"AUC = {auc} at round {server_round}")
+
+            return 0, {"AUC": auc}
+
+    return evaluate_fn
+
+
+# Define strategy
+strategy = FedXgbBagging(
+    evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None,
+    fraction_fit=(float(num_clients_per_round) / pool_size),
+    min_fit_clients=num_clients_per_round,
+    min_available_clients=pool_size,
+    min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0,
+    fraction_evaluate=1.0 if not centralised_eval else 0.0,
+    on_evaluate_config_fn=eval_config,
+    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation
+    if not centralised_eval
+    else None,
+)
+
+# Start Flower server
+fl.server.start_server(
+    server_address="0.0.0.0:8080",
+    config=fl.server.ServerConfig(num_rounds=num_rounds),
+    strategy=strategy,
+)
diff --git a/examples/xgboost-comprehensive/strategy.py b/examples/xgboost-comprehensive/strategy.py
new file mode 100644
index 000000000000..814010720a77
--- /dev/null
+++ b/examples/xgboost-comprehensive/strategy.py
@@ -0,0 +1,139 @@
+from logging import WARNING
+from typing import Callable, Dict, List, Optional, Tuple, Union
+import flwr as fl
+import json
+
+from flwr.common import (
+    EvaluateRes,
+    FitRes,
+    Parameters,
+    Scalar,
+)
+from flwr.server.client_proxy import ClientProxy
+from flwr.common.logger import log
+
+
+class FedXgbBagging(fl.server.strategy.FedAvg):
+    def __init__(
+        self,
+        evaluate_function: Optional[
+            Callable[
+                [int, Parameters, Dict[str, Scalar]],
+                Optional[Tuple[float, Dict[str, Scalar]]],
+            ]
+        ] = None,
+        **kwargs,
+    ):
+        self.evaluate_function = evaluate_function
+        self.global_model = None
+        super().__init__(**kwargs)
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+        """Aggregate fit results using bagging."""
+        if not results:
+            return None, {}
+        # Do not aggregate if there are failures and failures are not accepted
+        if not self.accept_failures and failures:
+            return None, {}
+
+        # Aggregate all the client trees
+        for _, fit_res in results:
+            update = fit_res.parameters.tensors
+            for item in update:
+                self.global_model = aggregate(
+                    self.global_model, json.loads(bytearray(item))
+                )
+
+        weights_avg = json.dumps(self.global_model)
+
+        return (
+            Parameters(tensor_type="", tensors=[bytes(weights_avg, "utf-8")]),
+            {},
+        )
+
+    def aggregate_evaluate(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, EvaluateRes]],
+        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
+    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
+        """Aggregate evaluation metrics using average."""
+        if not results:
+            return None, {}
+        # Do not aggregate if there are failures and failures are not accepted
+        if not self.accept_failures and failures:
+            return None, {}
+
+        # Aggregate custom metrics if aggregation fn was provided
+        metrics_aggregated = {}
+        if self.evaluate_metrics_aggregation_fn:
+            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
+            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
+        elif server_round == 1:  # Only log this warning once
+            log(WARNING, "No evaluate_metrics_aggregation_fn provided")
+
+        return 0, metrics_aggregated
+
+    def evaluate(
+        self, server_round: int, parameters: Parameters
+    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
+        """Evaluate model parameters using an evaluation function."""
+        if self.evaluate_function is None:
+            # No evaluation function provided
+            return None
+        eval_res = self.evaluate_function(server_round, parameters, {})
+        if eval_res is None:
+            return None
+        loss, metrics = eval_res
+        return loss, metrics
+
+
+def aggregate(bst_prev: Optional[Dict], bst_curr: Dict) -> Dict:
+    """Conduct bagging aggregation for given trees."""
+    if not bst_prev:
+        return bst_curr
+    else:
+        # Get the tree numbers
+        tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev)
+        tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr)
+
+        bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
+            "num_trees"
+        ] = str(tree_num_prev + paral_tree_num_curr)
+        iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
+            "iteration_indptr"
+        ]
+        bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
+            iteration_indptr[-1] + paral_tree_num_curr
+        )
+
+        # Aggregate new trees
+        trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
+        for tree_count in range(paral_tree_num_curr):
+            trees_curr[tree_count]["id"] = tree_num_prev + tree_count
+            bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
+                trees_curr[tree_count]
+            )
+            bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
+        return bst_prev
+
+
+def _get_tree_nums(xgb_model: Dict) -> Tuple[int, int]:
+    # Get the number of trees
+    tree_num = int(
+        xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
+            "num_trees"
+        ]
+    )
+    # Get the number of parallel trees
+    paral_tree_num = int(
+        xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
+            "num_parallel_tree"
+        ]
+    )
+    return tree_num, paral_tree_num
diff --git a/examples/xgboost-comprehensive/utils.py b/examples/xgboost-comprehensive/utils.py
new file mode 100644
index 000000000000..51c1a1b9604d
--- /dev/null
+++ b/examples/xgboost-comprehensive/utils.py
@@ -0,0 +1,72 @@
+import argparse
+
+
+def client_args_parser():
+    """Parse arguments to define experimental settings on client side."""
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--num-partitions", default=10, type=int, help="Number of partitions."
+    )
+    parser.add_argument(
+        "--partitioner-type",
+        default="uniform",
+        type=str,
+        choices=["uniform", "linear", "square", "exponential"],
+        help="Partitioner type.",
+    )
+    parser.add_argument(
+        "--node-id",
+        default=0,
+        type=int,
+        help="Node ID used for the current client.",
+    )
+    parser.add_argument(
+        "--seed", default=42, type=int, help="Seed used for train/test splitting."
+    )
+    parser.add_argument(
+        "--test-fraction",
+        default=0.2,
+        type=float,
+        help="Test fraction for train/test splitting.",
+    )
+    parser.add_argument(
+        "--centralised-eval",
+        action="store_true",
+        help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).",
+    )
+
+    args = parser.parse_args()
+    return args
+
+
+def server_args_parser():
+    """Parse arguments to define experimental settings on server side."""
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--pool-size", default=2, type=int, help="Number of total clients."
+    )
+    parser.add_argument(
+        "--num-rounds", default=5, type=int, help="Number of FL rounds."
+    )
+    parser.add_argument(
+        "--num-clients-per-round",
+        default=2,
+        type=int,
+        help="Number of clients participating in training in each round.",
+    )
+    parser.add_argument(
+        "--num-evaluate-clients",
+        default=2,
+        type=int,
+        help="Number of clients selected for evaluation.",
+    )
+    parser.add_argument(
+        "--centralised-eval",
+        action="store_true",
+        help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).",
+    )
+
+    args = parser.parse_args()
+    return args
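
As a quick, self-contained illustration of what the bagging aggregation in `strategy.py` does, the sketch below (not part of the patch) trains two small boosters on synthetic data, merges their JSON models with the example's `aggregate()` helper, and loads the result back into a single booster. The synthetic data, the hyper-parameters, and running the snippet from inside the example directory (so that `strategy` is importable) are assumptions for illustration only; in the example itself these JSON bytes are exchanged between clients and server via `FitRes`/`Parameters`.

```python
# Standalone sketch of the tree-based bagging aggregation used in strategy.py.
# Assumes it runs from this example directory (so `strategy.aggregate` imports)
# and uses synthetic data instead of the HIGGS dataset.
import json

import numpy as np
import xgboost as xgb

from strategy import aggregate

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8))
y = rng.integers(0, 2, size=256)
dtrain = xgb.DMatrix(X, label=y)
params = {"objective": "binary:logistic", "tree_method": "hist"}

# Two "clients" each boost for one local round (cf. FlowerClient._local_boost)
bst_a = xgb.train(params, dtrain, num_boost_round=1)
bst_b = xgb.train(params, dtrain, num_boost_round=1)

# Server-side bagging: append client B's trees to client A's model JSON
model_a = json.loads(bst_a.save_raw("json"))
model_b = json.loads(bst_b.save_raw("json"))
merged = aggregate(model_a, model_b)

# The merged JSON loads back into a booster, as server.py and client.py do
bst_merged = xgb.Booster(params=params)
bst_merged.load_model(bytearray(json.dumps(merged), "utf-8"))
print(bst_merged.num_boosted_rounds())
```

Because the strategy appends each sampled client's newly boosted trees rather than averaging parameters, the global model grows by `num_local_round * num_parallel_tree` trees per client every round.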