diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md
new file mode 100644
index 000000000000..d2ef03ad478f
--- /dev/null
+++ b/examples/xgboost-quickstart/README.md
@@ -0,0 +1,85 @@
+# Flower Example using XGBoost
+
+This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using the `xgboost` package.
+Tree-based bagging is used for aggregation on the server.
+
+## Project Setup
+
+Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-quickstart . && rm -rf flower && cd xgboost-quickstart
+```
+
+This will create a new directory called `xgboost-quickstart` containing the following files:
+
+```
+-- README.md         <- You're reading this right now
+-- server.py         <- Defines the server-side logic
+-- strategy.py       <- Defines the tree-based bagging aggregation strategy
+-- client.py         <- Defines the client-side logic
+-- dataset.py        <- Defines the functions for data loading and partitioning
+-- run.sh            <- Runs the whole experiment in a single terminal
+-- pyproject.toml    <- Example dependencies (if you use Poetry)
+-- requirements.txt  <- Example dependencies
+```
+
+### Installing Dependencies
+
+Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
+
+#### Poetry
+
+```shell
+poetry install
+poetry shell
+```
+
+Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:
+
+```shell
+poetry run python3 -c "import flwr"
+```
+
+If you don't see any errors you're good to go!
+
+#### pip
+
+Run the command below in your terminal to install the dependencies listed in `requirements.txt`:
+
+```shell
+pip install -r requirements.txt
+```
+
+## Run Federated Learning with XGBoost and Flower
+
+Afterwards you are ready to start the Flower server as well as the clients.
+You can simply start the server in a terminal as follows:
+
+```shell
+python3 server.py
+```
+
+Now you are ready to start the Flower clients which will participate in the learning.
+To do so, simply open two more terminal windows and run the following commands.
+
+Start client 1 in the first terminal:
+
+```shell
+python3 client.py
+```
+
+Start client 2 in the second terminal:
+
+```shell
+python3 client.py
+```
+
+You will see that XGBoost starts a federated training run.
+
+Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows:
+
+```shell
+bash run.sh
+```
+
+Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart)
+and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
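A note on what "tree-based bagging aggregation" means here (not part of the diff): in every round each participating client boosts `num_local_round` new trees on top of the current global model and sends only those new trees back, and the server appends all of them to the global ensemble. With the defaults used below (2 clients, 1 local boosting round, 5 server rounds) the global model grows by two trees per round, as this minimal sketch of the bookkeeping shows:

```python
# Expected size of the global ensemble after each round under bagging aggregation,
# using this example's defaults (2 clients, num_local_round=1, 5 server rounds).
num_clients, num_local_round, num_rounds = 2, 1, 5
trees_added_per_round = num_clients * num_local_round
print([r * trees_added_per_round for r in range(1, num_rounds + 1)])  # [2, 4, 6, 8, 10]
```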
diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py
new file mode 100644
index 000000000000..f6c6f2e34722
--- /dev/null
+++ b/examples/xgboost-quickstart/client.py
@@ -0,0 +1,144 @@
+import warnings
+from logging import INFO
+import xgboost as xgb
+
+import flwr as fl
+from flwr_datasets import FederatedDataset
+from flwr.common.logger import log
+from flwr.common import (
+    Code,
+    EvaluateIns,
+    EvaluateRes,
+    FitIns,
+    FitRes,
+    GetParametersIns,
+    GetParametersRes,
+    Parameters,
+    Status,
+)
+from flwr_datasets.partitioner import IidPartitioner
+
+from dataset import (
+    train_test_split,
+    transform_dataset_to_dmatrix,
+)
+
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+# Load (HIGGS) dataset and conduct partitioning
+partitioner = IidPartitioner(num_partitions=10)
+fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})
+
+# Let's use the first partition as an example
+partition = fds.load_partition(idx=0, split="train")
+partition.set_format("numpy")
+
+# Train/test splitting
+train_data, valid_data, num_train, num_val = train_test_split(
+    partition, test_fraction=0.2, seed=42
+)
+
+# Reformat data to DMatrix for xgboost
+train_dmatrix = transform_dataset_to_dmatrix(train_data)
+valid_dmatrix = transform_dataset_to_dmatrix(valid_data)
+
+# Hyper-parameters for xgboost training
+num_local_round = 1
+params = {
+    "objective": "binary:logistic",
+    "eta": 0.1,  # Learning rate
+    "max_depth": 8,
+    "eval_metric": "auc",
+    "nthread": 16,
+    "num_parallel_tree": 1,
+    "subsample": 1,
+    "tree_method": "hist",
+}
+
+
+# Define Flower client
+class FlowerClient(fl.client.Client):
+    def __init__(self):
+        self.bst = None
+        self.config = None
+
+    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
+        _ = (self, ins)
+        return GetParametersRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[]),
+        )
+
+    def _local_boost(self):
+        # Update trees based on local training data.
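+        # With bagging aggregation, the server only needs the trees added in this
+        # round, so we boost num_local_round new trees on top of the global model
+        # and slice them out below before sending them back.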
+        for i in range(num_local_round):
+            self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())
+
+        # Extract the last N=num_local_round trees for server aggregation
+        bst = self.bst[
+            self.bst.num_boosted_rounds()
+            - num_local_round : self.bst.num_boosted_rounds()
+        ]
+
+        return bst
+
+    def fit(self, ins: FitIns) -> FitRes:
+        if not self.bst:
+            # First round local training
+            log(INFO, "Start training at round 1")
+            bst = xgb.train(
+                params,
+                train_dmatrix,
+                num_boost_round=num_local_round,
+                evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
+            )
+            self.config = bst.save_config()
+            self.bst = bst
+        else:
+            for item in ins.parameters.tensors:
+                global_model = bytearray(item)
+
+            # Load global model into booster
+            self.bst.load_model(global_model)
+            self.bst.load_config(self.config)
+
+            bst = self._local_boost()
+
+        local_model = bst.save_raw("json")
+        local_model_bytes = bytes(local_model)
+
+        return FitRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
+            num_examples=num_train,
+            metrics={},
+        )
+
+    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
+        eval_results = self.bst.eval_set(
+            evals=[(valid_dmatrix, "valid")],
+            iteration=self.bst.num_boosted_rounds() - 1,
+        )
+        auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
+
+        return EvaluateRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            loss=0.0,
+            num_examples=num_val,
+            metrics={"AUC": auc},
+        )
+
+
+# Start Flower client
+fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient())
diff --git a/examples/xgboost-quickstart/dataset.py b/examples/xgboost-quickstart/dataset.py
new file mode 100644
index 000000000000..95455618937b
--- /dev/null
+++ b/examples/xgboost-quickstart/dataset.py
@@ -0,0 +1,23 @@
+import xgboost as xgb
+from typing import Union
+from datasets import Dataset, DatasetDict
+
+
+def train_test_split(partition: Dataset, test_fraction: float, seed: int):
+    """Split the data into train and validation set given split rate."""
+    train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
+    partition_train = train_test["train"]
+    partition_test = train_test["test"]
+
+    num_train = len(partition_train)
+    num_test = len(partition_test)
+
+    return partition_train, partition_test, num_train, num_test
+
+
+def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
+    """Transform dataset to DMatrix format for xgboost."""
+    x = data["inputs"]
+    y = data["label"]
+    new_data = xgb.DMatrix(x, label=y)
+    return new_data
diff --git a/examples/xgboost-quickstart/pyproject.toml b/examples/xgboost-quickstart/pyproject.toml
new file mode 100644
index 000000000000..74256846c693
--- /dev/null
+++ b/examples/xgboost-quickstart/pyproject.toml
@@ -0,0 +1,15 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "xgboost-quickstart"
+version = "0.1.0"
+description = "Federated XGBoost with Flower (quickstart)"
+authors = ["The Flower Authors "]
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.11"
+flwr = ">=1.0,<2.0"
+flwr-datasets = ">=0.0.1,<1.0.0"
+xgboost = ">=2.0.0,<3.0.0"
diff --git a/examples/xgboost-quickstart/requirements.txt b/examples/xgboost-quickstart/requirements.txt
new file mode 100644
index 000000000000..9596a8d6cd02
--- /dev/null
+++ b/examples/xgboost-quickstart/requirements.txt
@@ -0,0 +1,3 @@
+flwr>=1.0, <2.0
+flwr-datasets>=0.0.1, <1.0.0
+xgboost>=2.0.0, <3.0.0
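One thing to note about `client.py` above: every client loads the same partition (`idx=0`), so the two clients started in this quickstart train on identical data. A hypothetical variation of the data-loading block (the `--partition-id` flag and its wiring are illustrative only, not part of the example) would let each client pick its own partition:

```python
# Hypothetical variant of the data-loading block in client.py: choose the
# partition via a command-line flag instead of hard-coding idx=0.
import argparse

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

parser = argparse.ArgumentParser()
parser.add_argument("--partition-id", type=int, default=0)  # illustrative flag
args = parser.parse_args()

partitioner = IidPartitioner(num_partitions=10)
fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})
partition = fds.load_partition(idx=args.partition_id, split="train")
partition.set_format("numpy")
```

Each client would then be started as, e.g., `python3 client.py --partition-id 0` and `python3 client.py --partition-id 1`.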
diff --git a/examples/xgboost-quickstart/run.sh b/examples/xgboost-quickstart/run.sh
new file mode 100755
index 000000000000..b3ab546022d8
--- /dev/null
+++ b/examples/xgboost-quickstart/run.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+set -e
+cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/
+
+echo "Starting server"
+python server.py &
+sleep 5  # Sleep for 5s to give the server enough time to start
+
+for i in `seq 0 1`; do
+    echo "Starting client $i"
+    python3 client.py &
+done
+
+# Enable CTRL+C to stop all background processes
+trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
+# Wait for all background processes to complete
+wait
diff --git a/examples/xgboost-quickstart/server.py b/examples/xgboost-quickstart/server.py
new file mode 100644
index 000000000000..d287eae596e7
--- /dev/null
+++ b/examples/xgboost-quickstart/server.py
@@ -0,0 +1,37 @@
+import flwr as fl
+from strategy import FedXgbBagging
+
+
+# FL experimental settings
+pool_size = 2
+num_rounds = 5
+num_clients_per_round = 2
+num_evaluate_clients = 2
+
+
+def evaluate_metrics_aggregation(eval_metrics):
+    """Return an aggregated metric (AUC) for evaluation."""
+    total_num = sum([num for num, _ in eval_metrics])
+    auc_aggregated = (
+        sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
+    )
+    metrics_aggregated = {"AUC": auc_aggregated}
+    return metrics_aggregated
+
+
+# Define strategy
+strategy = FedXgbBagging(
+    fraction_fit=(float(num_clients_per_round) / pool_size),
+    min_fit_clients=num_clients_per_round,
+    min_available_clients=pool_size,
+    min_evaluate_clients=num_evaluate_clients,
+    fraction_evaluate=1.0,
+    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
+)
+
+# Start Flower server
+fl.server.start_server(
+    server_address="0.0.0.0:8080",
+    config=fl.server.ServerConfig(num_rounds=num_rounds),
+    strategy=strategy,
+)
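A small aside on `evaluate_metrics_aggregation` above: it computes an example-weighted average of the client AUC values, which can be checked by hand. The numbers below are made up purely for illustration:

```python
# Worked example with made-up numbers: AUC weighted by number of evaluation examples.
eval_metrics = [(8000, {"AUC": 0.81}), (2000, {"AUC": 0.77})]
total_num = sum(num for num, _ in eval_metrics)
auc_aggregated = sum(metrics["AUC"] * num for num, metrics in eval_metrics) / total_num
print(auc_aggregated)  # (0.81 * 8000 + 0.77 * 2000) / 10000 ≈ 0.802
```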
diff --git a/examples/xgboost-quickstart/strategy.py b/examples/xgboost-quickstart/strategy.py
new file mode 100644
index 000000000000..578328ca701b
--- /dev/null
+++ b/examples/xgboost-quickstart/strategy.py
@@ -0,0 +1,163 @@
+# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Federated XGBoost bagging aggregation strategy."""
+
+
+import json
+from logging import WARNING
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+
+import flwr as fl
+from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
+from flwr.common.logger import log
+from flwr.server.client_proxy import ClientProxy
+
+
+class FedXgbBagging(fl.server.strategy.FedAvg):
+    """Configurable FedXgbBagging strategy implementation."""
+
+    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
+    def __init__(
+        self,
+        evaluate_function: Optional[
+            Callable[
+                [int, Parameters, Dict[str, Scalar]],
+                Optional[Tuple[float, Dict[str, Scalar]]],
+            ]
+        ] = None,
+        **kwargs: Any,
+    ):
+        self.evaluate_function = evaluate_function
+        self.global_model: Optional[bytes] = None
+        super().__init__(**kwargs)
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+        """Aggregate fit results using bagging."""
+        if not results:
+            return None, {}
+        # Do not aggregate if there are failures and failures are not accepted
+        if not self.accept_failures and failures:
+            return None, {}
+
+        # Aggregate all the client trees
+        global_model = self.global_model
+        for _, fit_res in results:
+            update = fit_res.parameters.tensors
+            for bst in update:
+                global_model = aggregate(global_model, bst)
+
+        self.global_model = global_model
+
+        return (
+            Parameters(tensor_type="", tensors=[cast(bytes, global_model)]),
+            {},
+        )
+
+    def aggregate_evaluate(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, EvaluateRes]],
+        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
+    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
+        """Aggregate evaluation metrics using average."""
+        if not results:
+            return None, {}
+        # Do not aggregate if there are failures and failures are not accepted
+        if not self.accept_failures and failures:
+            return None, {}
+
+        # Aggregate custom metrics if aggregation fn was provided
+        metrics_aggregated = {}
+        if self.evaluate_metrics_aggregation_fn:
+            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
+            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
+        elif server_round == 1:  # Only log this warning once
+            log(WARNING, "No evaluate_metrics_aggregation_fn provided")
+
+        return 0, metrics_aggregated
+
+    def evaluate(
+        self, server_round: int, parameters: Parameters
+    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
+        """Evaluate model parameters using an evaluation function."""
+        if self.evaluate_function is None:
+            # No evaluation function provided
+            return None
+        eval_res = self.evaluate_function(server_round, parameters, {})
+        if eval_res is None:
+            return None
+        loss, metrics = eval_res
+        return loss, metrics
+
+
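+# `aggregate` merges two serialized XGBoost models by editing their JSON
+# representation directly: the trees of the current model are appended to the
+# previous global model, tree ids are re-numbered, and the `num_trees` /
+# `iteration_indptr` bookkeeping is updated so the merged bytes can be loaded
+# back into a booster.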
+def aggregate(
+    bst_prev_org: Optional[bytes],
+    bst_curr_org: bytes,
+) -> bytes:
+    """Conduct bagging aggregation for given trees."""
+    if not bst_prev_org:
+        return bst_curr_org
+
+    # Get the tree numbers
+    tree_num_prev, _ = _get_tree_nums(bst_prev_org)
+    _, paral_tree_num_curr = _get_tree_nums(bst_curr_org)
+
+    bst_prev = json.loads(bytearray(bst_prev_org))
+    bst_curr = json.loads(bytearray(bst_curr_org))
+
+    bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
+        "num_trees"
+    ] = str(tree_num_prev + paral_tree_num_curr)
+    iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
+        "iteration_indptr"
+    ]
+    bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
+        iteration_indptr[-1] + paral_tree_num_curr
+    )
+
+    # Aggregate new trees
+    trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
+    for tree_count in range(paral_tree_num_curr):
+        trees_curr[tree_count]["id"] = tree_num_prev + tree_count
+        bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
+            trees_curr[tree_count]
+        )
+        bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
+
+    bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8")
+
+    return bst_prev_bytes
+
+
+def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]:
+    xgb_model = json.loads(bytearray(xgb_model_org))
+    # Get the number of trees
+    tree_num = int(
+        xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
+            "num_trees"
+        ]
+    )
+    # Get the number of parallel trees
+    paral_tree_num = int(
+        xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
+            "num_parallel_tree"
+        ]
+    )
+    return tree_num, paral_tree_num
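As a final note on `strategy.py`: the quickest way to see what `aggregate` produces is to replay one round of bagging aggregation locally. The sketch below is not part of the example; it assumes `numpy` and `xgboost` are installed and that it is run from the example directory (next to `strategy.py`). Two simulated clients each boost a single tree from scratch, exactly as in the first round, and the server-side `aggregate` folds them into one global model that can be loaded back into a booster, just as the clients do in `fit()`.

```python
# Minimal sketch (not part of the example): replay one round of bagging aggregation.
import json

import numpy as np
import xgboost as xgb

from strategy import aggregate

# A small synthetic binary classification problem
rng = np.random.default_rng(0)
X, y = rng.normal(size=(64, 5)), rng.integers(0, 2, size=64)
dtrain = xgb.DMatrix(X, label=y)
params = {"objective": "binary:logistic", "tree_method": "hist"}

# Two "clients" each boost a single tree from scratch, as in the first round
client_models = [
    bytes(xgb.train(params, dtrain, num_boost_round=1).save_raw("json"))
    for _ in range(2)
]

# Server-side bagging: fold each client's trees into the global model
global_model = None
for model in client_models:
    global_model = aggregate(global_model, model)

# The tree counter was updated by `aggregate` ...
merged = json.loads(global_model)
print(merged["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"])  # "2"

# ... and the merged bytes load back into a booster, as the clients do in fit()
bst = xgb.Booster(params=params)
bst.load_model(bytearray(global_model))
print(bst.num_boosted_rounds())  # 2
```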