diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md
new file mode 100644
index 000000000000..53cd37e18aa3
--- /dev/null
+++ b/examples/xgboost-quickstart/README.md
@@ -0,0 +1,86 @@
+# Flower Example using XGBoost
+
+This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using the `xgboost` package.
+A tree-based bagging method is used for aggregation on the server.
+
+This project provides a minimal code example to enable you to get started quickly. For a more comprehensive code example, take a look at [xgboost-comprehensive](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive).
+
+## Project Setup
+
+Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-quickstart . && rm -rf flower && cd xgboost-quickstart
+```
+
+This will create a new directory called `xgboost-quickstart` containing the following files:
+
+```
+-- README.md         <- You're reading this right now
+-- server.py         <- Defines the server-side logic
+-- client.py         <- Defines the client-side logic
+-- pyproject.toml    <- Example dependencies (if you use Poetry)
+-- requirements.txt  <- Example dependencies
+```
+
+### Installing Dependencies
+
+Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
+
+#### Poetry
+
+```shell
+poetry install
+poetry shell
+```
+
+Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly, you can run the following command:
+
+```shell
+poetry run python3 -c "import flwr"
+```
+
+If you don't see any errors, you're good to go!
+
+#### pip
+
+Run the command below in your terminal to install the dependencies listed in `requirements.txt`:
+
+```shell
+pip install -r requirements.txt
+```
+
+## Run Federated Learning with XGBoost and Flower
+
+Afterwards, you are ready to start the Flower server as well as the clients.
+You can simply start the server in a terminal as follows:
+
+```shell
+python3 server.py
+```
+
+Now you are ready to start the Flower clients which will participate in the learning.
+To do so, simply open two more terminal windows and run the following commands.
+
+Start client 1 in the first terminal:
+
+```shell
+python3 client.py --node-id=0
+```
+
+Start client 2 in the second terminal:
+
+```shell
+python3 client.py --node-id=1
+```
+
+You will see that the federated XGBoost training starts.
+
+Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows:
+
+```shell
+bash run.sh
+```
+
+Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart)
+and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
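+
+### How bagging aggregation works (sketch)
+
+Conceptually, each client boosts a few new trees on its local data every round, and the server concatenates the trees it receives from all clients into one growing global ensemble that is broadcast back for the next round. The snippet below is a minimal sketch of that idea using plain Python lists and made-up names; it is not Flower's actual `FedXgbBagging` implementation, which operates on serialized XGBoost models instead:
+
+```python
+from typing import List
+
+
+def aggregate_bagging(
+    global_trees: List[str], client_updates: List[List[str]]
+) -> List[str]:
+    """Append every client's newly boosted trees to the global ensemble."""
+    for new_trees in client_updates:
+        global_trees.extend(new_trees)
+    return global_trees
+
+
+# Round 1: two clients each contribute one new tree (strings stand in for trees)
+ensemble = aggregate_bagging([], [["tree_a1"], ["tree_b1"]])
+# Round 2: the global ensemble keeps growing round by round
+ensemble = aggregate_bagging(ensemble, [["tree_a2"], ["tree_b2"]])
+print(ensemble)  # ['tree_a1', 'tree_b1', 'tree_a2', 'tree_b2']
+```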
diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py
new file mode 100644
index 000000000000..ede4a2bba764
--- /dev/null
+++ b/examples/xgboost-quickstart/client.py
@@ -0,0 +1,173 @@
+import argparse
+import warnings
+from typing import Union
+from logging import INFO
+from datasets import Dataset, DatasetDict
+import xgboost as xgb
+
+import flwr as fl
+from flwr_datasets import FederatedDataset
+from flwr.common.logger import log
+from flwr.common import (
+    Code,
+    EvaluateIns,
+    EvaluateRes,
+    FitIns,
+    FitRes,
+    GetParametersIns,
+    GetParametersRes,
+    Parameters,
+    Status,
+)
+from flwr_datasets.partitioner import IidPartitioner
+
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+# Define the argument parser for the client/node ID
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--node-id",
+    default=0,
+    type=int,
+    help="Node ID used for the current client.",
+)
+args = parser.parse_args()
+
+
+# Define data partitioning related functions
+def train_test_split(partition: Dataset, test_fraction: float, seed: int):
+    """Split the data into train and validation sets given the split rate."""
+    train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
+    partition_train = train_test["train"]
+    partition_test = train_test["test"]
+
+    num_train = len(partition_train)
+    num_test = len(partition_test)
+
+    return partition_train, partition_test, num_train, num_test
+
+
+def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
+    """Transform the dataset to DMatrix format for xgboost."""
+    x = data["inputs"]
+    y = data["label"]
+    new_data = xgb.DMatrix(x, label=y)
+    return new_data
+
+
+# Load the (HIGGS) dataset and conduct partitioning
+partitioner = IidPartitioner(num_partitions=2)
+fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})
+
+# Load the partition for this `node_id`
+partition = fds.load_partition(idx=args.node_id, split="train")
+partition.set_format("numpy")
+
+# Train/test splitting
+train_data, valid_data, num_train, num_val = train_test_split(
+    partition, test_fraction=0.2, seed=42
+)
+
+# Reformat data to DMatrix for xgboost
+train_dmatrix = transform_dataset_to_dmatrix(train_data)
+valid_dmatrix = transform_dataset_to_dmatrix(valid_data)
+
+# Hyper-parameters for xgboost training
+num_local_round = 1
+params = {
+    "objective": "binary:logistic",
+    "eta": 0.1,  # Learning rate
+    "max_depth": 8,
+    "eval_metric": "auc",
+    "nthread": 16,
+    "num_parallel_tree": 1,
+    "subsample": 1,
+    "tree_method": "hist",
+}
+
+
+# Define Flower client
+class FlowerClient(fl.client.Client):
+    def __init__(self):
+        self.bst = None
+        self.config = None
+
+    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
+        _ = (self, ins)
+        return GetParametersRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[]),
+        )
+
+    def _local_boost(self):
+        # Update trees based on local training data.
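+        # Each call to `update` boosts one more round on the local training
+        # data, so the loop below stacks `num_local_round` new trees on top
+        # of the global model loaded in `fit`; the slice that follows then
+        # extracts exactly those new trees, so only the local update (not
+        # the whole ensemble) is sent back for server-side bagging.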
+        for i in range(num_local_round):
+            self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())
+
+        # Extract the last N=num_local_round trees for server aggregation
+        bst = self.bst[
+            self.bst.num_boosted_rounds()
+            - num_local_round : self.bst.num_boosted_rounds()
+        ]
+
+        return bst
+
+    def fit(self, ins: FitIns) -> FitRes:
+        if not self.bst:
+            # First round of local training
+            log(INFO, "Start training at round 1")
+            bst = xgb.train(
+                params,
+                train_dmatrix,
+                num_boost_round=num_local_round,
+                evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
+            )
+            self.config = bst.save_config()
+            self.bst = bst
+        else:
+            for item in ins.parameters.tensors:
+                global_model = bytearray(item)
+
+            # Load global model into booster
+            self.bst.load_model(global_model)
+            self.bst.load_config(self.config)
+
+            bst = self._local_boost()
+
+        local_model = bst.save_raw("json")
+        local_model_bytes = bytes(local_model)
+
+        return FitRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
+            num_examples=num_train,
+            metrics={},
+        )
+
+    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
+        eval_results = self.bst.eval_set(
+            evals=[(valid_dmatrix, "valid")],
+            iteration=self.bst.num_boosted_rounds() - 1,
+        )
+        auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
+
+        return EvaluateRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            loss=0.0,
+            num_examples=num_val,
+            metrics={"AUC": auc},
+        )
+
+
+# Start Flower client
+fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient())
diff --git a/examples/xgboost-quickstart/pyproject.toml b/examples/xgboost-quickstart/pyproject.toml
new file mode 100644
index 000000000000..74256846c693
--- /dev/null
+++ b/examples/xgboost-quickstart/pyproject.toml
@@ -0,0 +1,15 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "xgboost-quickstart"
+version = "0.1.0"
+description = "Federated XGBoost with Flower (quickstart)"
+authors = ["The Flower Authors <hello@flower.dev>"]
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.11"
+flwr = ">=1.0,<2.0"
+flwr-datasets = ">=0.0.1,<1.0.0"
+xgboost = ">=2.0.0,<3.0.0"
diff --git a/examples/xgboost-quickstart/requirements.txt b/examples/xgboost-quickstart/requirements.txt
new file mode 100644
index 000000000000..9596a8d6cd02
--- /dev/null
+++ b/examples/xgboost-quickstart/requirements.txt
@@ -0,0 +1,3 @@
+flwr>=1.0, <2.0
+flwr-datasets>=0.0.1, <1.0.0
+xgboost>=2.0.0, <3.0.0
diff --git a/examples/xgboost-quickstart/run.sh b/examples/xgboost-quickstart/run.sh
new file mode 100755
index 000000000000..6287145bfb5f
--- /dev/null
+++ b/examples/xgboost-quickstart/run.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+set -e
+cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/
+
+echo "Starting server"
+python3 server.py &
+sleep 5  # Sleep for 5s to give the server enough time to start
+
+for i in $(seq 0 1); do
+    echo "Starting client $i"
+    python3 client.py --node-id=$i &
+done
+
+# Enable CTRL+C to stop all background processes
+trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
+# Wait for all background processes to complete
+wait
diff --git a/examples/xgboost-quickstart/server.py b/examples/xgboost-quickstart/server.py
new file mode 100644
index 000000000000..b45a375ce94f
--- /dev/null
+++ b/examples/xgboost-quickstart/server.py
@@ -0,0 +1,37 @@
+import flwr as fl
+from flwr.server.strategy import FedXgbBagging
+
+
+# FL experimental settings
+pool_size = 2
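+# `pool_size` is the total number of clients in the federation; it matches the
+# two clients (node IDs 0 and 1) started in the README instructions and run.sh.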
+num_rounds = 5
+num_clients_per_round = 2
+num_evaluate_clients = 2
+
+
+def evaluate_metrics_aggregation(eval_metrics):
+    """Return an aggregated metric (AUC) for evaluation."""
+    total_num = sum([num for num, _ in eval_metrics])
+    auc_aggregated = (
+        sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
+    )
+    metrics_aggregated = {"AUC": auc_aggregated}
+    return metrics_aggregated
+
+
+# Define strategy
+strategy = FedXgbBagging(
+    fraction_fit=(float(num_clients_per_round) / pool_size),
+    min_fit_clients=num_clients_per_round,
+    min_available_clients=pool_size,
+    min_evaluate_clients=num_evaluate_clients,
+    fraction_evaluate=1.0,
+    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
+)
+
+# Start Flower server
+fl.server.start_server(
+    server_address="0.0.0.0:8080",
+    config=fl.server.ServerConfig(num_rounds=num_rounds),
+    strategy=strategy,
+)
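+
+# A hypothetical worked example of the weighted AUC aggregation above:
+# with eval_metrics = [(8000, {"AUC": 0.80}), (8000, {"AUC": 0.84})],
+# the aggregated AUC is (0.80 * 8000 + 0.84 * 8000) / 16000 = 0.82.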