Skip to content

Commit

Permalink
Create XGBoost-quickstart example (#2612)
Browse files Browse the repository at this point in the history
Co-authored-by: yan-gao-GY <[email protected]>
  • Loading branch information
yan-gao-GY and yan-gao-GY authored Nov 17, 2023
1 parent bc346ac commit 5cf5c82
Show file tree
Hide file tree
Showing 6 changed files with 331 additions and 0 deletions.
86 changes: 86 additions & 0 deletions examples/xgboost-quickstart/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Flower Example using XGBoost

This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using `xgboost` package.
Tree-based with bagging method is used for aggregation on the server.

This project provides a minimal code example to enable you to get stated quickly. For a more comprehensive code example, take a look at [xgboost-comprehensive](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive).

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-quickstart . && rm -rf flower && cd xgboost-quickstart
```

This will create a new directory called `xgboost-quickstart` containing the following files:

```
-- README.md <- Your're reading this right now
-- server.py <- Defines the server-side logic
-- client.py <- Defines the client-side logic
-- pyproject.toml <- Example dependencies (if you use Poetry)
-- requirements.txt <- Example dependencies
```

### Installing Dependencies

Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt.

```shell
pip install -r requirements.txt
```

## Run Federated Learning with XGBoost and Flower

Afterwards you are ready to start the Flower server as well as the clients.
You can simply start the server in a terminal as follows:

```shell
python3 server.py
```

Now you are ready to start the Flower clients which will participate in the learning.
To do so simply open two more terminal windows and run the following commands.

Start client 1 in the first terminal:

```shell
python3 client.py --node-id=0
```

Start client 2 in the second terminal:

```shell
python3 client.py --node-id=1
```

You will see that XGBoost is starting a federated training.

Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows:

```shell
bash run.sh
```

Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart)
and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
173 changes: 173 additions & 0 deletions examples/xgboost-quickstart/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import argparse
import warnings
from typing import Union
from logging import INFO
from datasets import Dataset, DatasetDict
import xgboost as xgb

import flwr as fl
from flwr_datasets import FederatedDataset
from flwr.common.logger import log
from flwr.common import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
Parameters,
Status,
)
from flwr_datasets.partitioner import IidPartitioner


warnings.filterwarnings("ignore", category=UserWarning)

# Define arguments parser for the client/node ID.
parser = argparse.ArgumentParser()
parser.add_argument(
"--node-id",
default=0,
type=int,
help="Node ID used for the current client.",
)
args = parser.parse_args()


# Define data partitioning related functions
def train_test_split(partition: Dataset, test_fraction: float, seed: int):
"""Split the data into train and validation set given split rate."""
train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
partition_train = train_test["train"]
partition_test = train_test["test"]

num_train = len(partition_train)
num_test = len(partition_test)

return partition_train, partition_test, num_train, num_test


def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
"""Transform dataset to DMatrix format for xgboost."""
x = data["inputs"]
y = data["label"]
new_data = xgb.DMatrix(x, label=y)
return new_data


# Load (HIGGS) dataset and conduct partitioning
partitioner = IidPartitioner(num_partitions=2)
fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})

# Load the partition for this `node_id`
partition = fds.load_partition(idx=args.node_id, split="train")
partition.set_format("numpy")

# Train/test splitting
train_data, valid_data, num_train, num_val = train_test_split(
partition, test_fraction=0.2, seed=42
)

# Reformat data to DMatrix for xgboost
train_dmatrix = transform_dataset_to_dmatrix(train_data)
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)

# Hyper-parameters for xgboost training
num_local_round = 1
params = {
"objective": "binary:logistic",
"eta": 0.1, # Learning rate
"max_depth": 8,
"eval_metric": "auc",
"nthread": 16,
"num_parallel_tree": 1,
"subsample": 1,
"tree_method": "hist",
}


# Define Flower client
class FlowerClient(fl.client.Client):
def __init__(self):
self.bst = None
self.config = None

def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
_ = (self, ins)
return GetParametersRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[]),
)

def _local_boost(self):
# Update trees based on local training data.
for i in range(num_local_round):
self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

# Extract the last N=num_local_round trees for sever aggregation
bst = self.bst[
self.bst.num_boosted_rounds()
- num_local_round : self.bst.num_boosted_rounds()
]

return bst

def fit(self, ins: FitIns) -> FitRes:
if not self.bst:
# First round local training
log(INFO, "Start training at round 1")
bst = xgb.train(
params,
train_dmatrix,
num_boost_round=num_local_round,
evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
)
self.config = bst.save_config()
self.bst = bst
else:
for item in ins.parameters.tensors:
global_model = bytearray(item)

# Load global model into booster
self.bst.load_model(global_model)
self.bst.load_config(self.config)

bst = self._local_boost()

local_model = bst.save_raw("json")
local_model_bytes = bytes(local_model)

return FitRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
num_examples=num_train,
metrics={},
)

def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
eval_results = self.bst.eval_set(
evals=[(valid_dmatrix, "valid")],
iteration=self.bst.num_boosted_rounds() - 1,
)
auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)

return EvaluateRes(
status=Status(
code=Code.OK,
message="OK",
),
loss=0.0,
num_examples=num_val,
metrics={"AUC": auc},
)


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient())
15 changes: 15 additions & 0 deletions examples/xgboost-quickstart/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "xgboost-quickstart"
version = "0.1.0"
description = "Federated XGBoost with Flower (quickstart)"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = ">=0.0.1,<1.0.0"
xgboost = ">=2.0.0,<3.0.0"
3 changes: 3 additions & 0 deletions examples/xgboost-quickstart/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
flwr>=1.0, <2.0
flwr-datasets>=0.0.1, <1.0.0
xgboost>=2.0.0, <3.0.0
17 changes: 17 additions & 0 deletions examples/xgboost-quickstart/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

echo "Starting server"
python server.py &
sleep 5 # Sleep for 5s to give the server enough time to start

for i in `seq 0 1`; do
echo "Starting client $i"
python3 client.py --node-id=$i &
done

# Enable CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
37 changes: 37 additions & 0 deletions examples/xgboost-quickstart/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import flwr as fl
from flwr.server.strategy import FedXgbBagging


# FL experimental settings
pool_size = 2
num_rounds = 5
num_clients_per_round = 2
num_evaluate_clients = 2


def evaluate_metrics_aggregation(eval_metrics):
"""Return an aggregated metric (AUC) for evaluation."""
total_num = sum([num for num, _ in eval_metrics])
auc_aggregated = (
sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
)
metrics_aggregated = {"AUC": auc_aggregated}
return metrics_aggregated


# Define strategy
strategy = FedXgbBagging(
fraction_fit=(float(num_clients_per_round) / pool_size),
min_fit_clients=num_clients_per_round,
min_available_clients=pool_size,
min_evaluate_clients=num_evaluate_clients,
fraction_evaluate=1.0,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
)

# Start Flower server
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=num_rounds),
strategy=strategy,
)

0 comments on commit 5cf5c82

Please sign in to comment.