Skip to content

Commit

Permalink
Initialise xgboost-quickstart
Browse files Browse the repository at this point in the history
  • Loading branch information
yan-gao-GY committed Nov 16, 2023
1 parent 7ae14a2 commit 6c842a8
Show file tree
Hide file tree
Showing 8 changed files with 487 additions and 0 deletions.
85 changes: 85 additions & 0 deletions examples/xgboost-quickstart/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Flower Example using XGBoost

This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using `xgboost` package.
Tree-based with bagging method is used for aggregation on the server.

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-xgboost . && rm -rf flower && cd quickstart-xgboost
```

This will create a new directory called `quickstart-xgboost` containing the following files:

```
-- README.md <- Your're reading this right now
-- server.py <- Defines the server-side logic
-- client.py <- Defines the client-side logic
-- dataset.py <- Defines the functions of data loading and partitioning
-- pyproject.toml <- Example dependencies (if you use Poetry)
-- requirements.txt <- Example dependencies
```

### Installing Dependencies

Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt.

```shell
pip install -r requirements.txt
```

## Run Federated Learning with XGBoost and Flower

Afterwards you are ready to start the Flower server as well as the clients.
You can simply start the server in a terminal as follows:

```shell
python3 server.py
```

Now you are ready to start the Flower clients which will participate in the learning.
To do so simply open two more terminal windows and run the following commands.

Start client 1 in the first terminal:

```shell
python3 client.py
```

Start client 2 in the second terminal:

```shell
python3 client.py
```

You will see that XGBoost is starting a federated training.

Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows:

```shell
bash run.sh
```

Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-xgboost)
and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
144 changes: 144 additions & 0 deletions examples/xgboost-quickstart/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import warnings
from logging import INFO
import xgboost as xgb

import flwr as fl
from flwr_datasets import FederatedDataset
from flwr.common.logger import log
from flwr.common import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
Parameters,
Status,
)
from flwr_datasets.partitioner import IidPartitioner

from dataset import (
train_test_split,
transform_dataset_to_dmatrix,
)


warnings.filterwarnings("ignore", category=UserWarning)


# Load (HIGGS) dataset and conduct partitioning
partitioner = IidPartitioner(num_partitions=10)
fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})

# Let's use the first partition as an example
partition = fds.load_partition(idx=0, split="train")
partition.set_format("numpy")

# Train/test splitting
train_data, valid_data, num_train, num_val = train_test_split(
partition, test_fraction=0.2, seed=42
)

# Reformat data to DMatrix for xgboost
train_dmatrix = transform_dataset_to_dmatrix(train_data)
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)

# Hyper-parameters for xgboost training
num_local_round = 1
params = {
"objective": "binary:logistic",
"eta": 0.1, # Learning rate
"max_depth": 8,
"eval_metric": "auc",
"nthread": 16,
"num_parallel_tree": 1,
"subsample": 1,
"tree_method": "hist",
}


# Define Flower client
class FlowerClient(fl.client.Client):
def __init__(self):
self.bst = None
self.config = None

def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
_ = (self, ins)
return GetParametersRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[]),
)

def _local_boost(self):
# Update trees based on local training data.
for i in range(num_local_round):
self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

# Extract the last N=num_local_round trees for sever aggregation
bst = self.bst[
self.bst.num_boosted_rounds()
- num_local_round : self.bst.num_boosted_rounds()
]

return bst

def fit(self, ins: FitIns) -> FitRes:
if not self.bst:
# First round local training
log(INFO, "Start training at round 1")
bst = xgb.train(
params,
train_dmatrix,
num_boost_round=num_local_round,
evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
)
self.config = bst.save_config()
self.bst = bst
else:
for item in ins.parameters.tensors:
global_model = bytearray(item)

# Load global model into booster
self.bst.load_model(global_model)
self.bst.load_config(self.config)

bst = self._local_boost()

local_model = bst.save_raw("json")
local_model_bytes = bytes(local_model)

return FitRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
num_examples=num_train,
metrics={},
)

def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
eval_results = self.bst.eval_set(
evals=[(valid_dmatrix, "valid")],
iteration=self.bst.num_boosted_rounds() - 1,
)
auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)

return EvaluateRes(
status=Status(
code=Code.OK,
message="OK",
),
loss=0.0,
num_examples=num_val,
metrics={"AUC": auc},
)


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient())
23 changes: 23 additions & 0 deletions examples/xgboost-quickstart/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import xgboost as xgb
from typing import Union
from datasets import Dataset, DatasetDict


def train_test_split(partition: Dataset, test_fraction: float, seed: int):
"""Split the data into train and validation set given split rate."""
train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
partition_train = train_test["train"]
partition_test = train_test["test"]

num_train = len(partition_train)
num_test = len(partition_test)

return partition_train, partition_test, num_train, num_test


def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
"""Transform dataset to DMatrix format for xgboost."""
x = data["inputs"]
y = data["label"]
new_data = xgb.DMatrix(x, label=y)
return new_data
15 changes: 15 additions & 0 deletions examples/xgboost-quickstart/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "xgboost-quickstart"
version = "0.1.0"
description = "Federated XGBoost with Flower (quickstart)"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = ">=0.0.1,<1.0.0"
xgboost = ">=2.0.0,<3.0.0"
3 changes: 3 additions & 0 deletions examples/xgboost-quickstart/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
flwr>=1.0, <2.0
flwr-datasets>=0.0.1, <1.0.0
xgboost>=2.0.0, <3.0.0
17 changes: 17 additions & 0 deletions examples/xgboost-quickstart/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

echo "Starting server"
python server.py &
sleep 5 # Sleep for 15s to give the server enough time to start

for i in `seq 0 1`; do
echo "Starting client $i"
python3 client.py &
done

# Enable CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
37 changes: 37 additions & 0 deletions examples/xgboost-quickstart/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import flwr as fl
from strategy import FedXgbBagging


# FL experimental settings
pool_size = 2
num_rounds = 5
num_clients_per_round = 2
num_evaluate_clients = 2


def evaluate_metrics_aggregation(eval_metrics):
"""Return an aggregated metric (AUC) for evaluation."""
total_num = sum([num for num, _ in eval_metrics])
auc_aggregated = (
sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
)
metrics_aggregated = {"AUC": auc_aggregated}
return metrics_aggregated


# Define strategy
strategy = FedXgbBagging(
fraction_fit=(float(num_clients_per_round) / pool_size),
min_fit_clients=num_clients_per_round,
min_available_clients=pool_size,
min_evaluate_clients=num_evaluate_clients,
fraction_evaluate=1.0,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
)

# Start Flower server
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=num_rounds),
strategy=strategy,
)
Loading

0 comments on commit 6c842a8

Please sign in to comment.