Skip to content

Commit

Permalink
Quickstart-xgboost with bagging aggregation (#2554)
Browse files Browse the repository at this point in the history
Co-authored-by: yan-gao-GY <[email protected]>
  • Loading branch information
yan-gao-GY and yan-gao-GY authored Nov 15, 2023
1 parent 760a456 commit f056175
Show file tree
Hide file tree
Showing 9 changed files with 685 additions and 0 deletions.
87 changes: 87 additions & 0 deletions examples/xgboost-comprehensive/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Flower Example using XGBoost

This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using `xgboost` package.
Tree-based with bagging method is used for aggregation on the server.

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-xgboost . && rm -rf flower && cd quickstart-xgboost
```

This will create a new directory called `quickstart-xgboost` containing the following files:

```
-- README.md <- Your're reading this right now
-- server.py <- Defines the server-side logic
-- strategy.py <- Defines the tree-based bagging aggregation
-- client.py <- Defines the client-side logic
-- dataset.py <- Defines the functions of data loading and partitioning
-- pyproject.toml <- Example dependencies (if you use Poetry)
-- requirements.txt <- Example dependencies
```

### Installing Dependencies

Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt.

```shell
pip install -r requirements.txt
```

## Run Federated Learning with XGBoost and Flower

Afterwards you are ready to start the Flower server as well as the clients.
You can simply start the server in a terminal as follows:

```shell
python3 server.py
```

Now you are ready to start the Flower clients which will participate in the learning.
To do so simply open two more terminal windows and run the following commands.

Start client 1 in the first terminal:

```shell
python3 client.py --node-id=0
```

Start client 2 in the second terminal:

```shell
python3 client.py --node-id=1
```

You will see that XGBoost is starting a federated training.

Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows:

```shell
bash run.sh
```

Besides, we provide options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-xgboost)
and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
174 changes: 174 additions & 0 deletions examples/xgboost-comprehensive/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import warnings
from logging import INFO
import xgboost as xgb

import flwr as fl
from flwr_datasets import FederatedDataset
from flwr.common.logger import log
from flwr.common import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
Parameters,
Status,
)

from dataset import (
instantiate_partitioner,
train_test_split,
transform_dataset_to_dmatrix,
resplit,
)
from utils import client_args_parser


warnings.filterwarnings("ignore", category=UserWarning)


# Parse arguments for experimental settings
args = client_args_parser()

# Load (HIGGS) dataset and conduct partitioning
num_partitions = args.num_partitions

# Partitioner type is chosen from ["uniform", "linear", "square", "exponential"]
partitioner_type = args.partitioner_type

# Instantiate partitioner
partitioner = instantiate_partitioner(
partitioner_type=partitioner_type, num_partitions=num_partitions
)
fds = FederatedDataset(
dataset="jxie/higgs", partitioners={"train": partitioner}, resplitter=resplit
)

# Let's use the first partition as an example
node_id = args.node_id
partition = fds.load_partition(idx=node_id, split="train")
partition.set_format("numpy")

if args.centralised_eval:
# Use centralised test set for evaluation
train_data = partition
valid_data = fds.load_full("test")
valid_data.set_format("numpy")
num_train = train_data.shape[0]
num_val = valid_data.shape[0]
else:
# Train/test splitting
SEED = args.seed
test_fraction = args.test_fraction
train_data, valid_data, num_train, num_val = train_test_split(
partition, test_fraction=test_fraction, seed=SEED
)

# Reformat data to DMatrix for xgboost
train_dmatrix = transform_dataset_to_dmatrix(train_data)
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)


# Hyper-parameters for xgboost training
num_local_round = 1
params = {
"objective": "binary:logistic",
"eta": 0.1, # Learning rate
"max_depth": 8,
"eval_metric": "auc",
"nthread": 16,
"num_parallel_tree": 1,
"subsample": 1,
"tree_method": "hist",
}


# Define Flower client
class FlowerClient(fl.client.Client):
def __init__(self):
self.bst = None
self.config = None

def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
_ = (self, ins)
return GetParametersRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[]),
)

def _local_boost(self):
# Update trees based on local training data.
for i in range(num_local_round):
self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

# Extract the last N=num_local_round trees for sever aggregation
bst = self.bst[
self.bst.num_boosted_rounds()
- num_local_round : self.bst.num_boosted_rounds()
]

return bst

def fit(self, ins: FitIns) -> FitRes:
if not self.bst:
# First round local training
log(INFO, "Start training at round 1")
bst = xgb.train(
params,
train_dmatrix,
num_boost_round=num_local_round,
evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
)
self.config = bst.save_config()
self.bst = bst
else:
for item in ins.parameters.tensors:
global_model = bytearray(item)

# Load global model into booster
self.bst.load_model(global_model)
self.bst.load_config(self.config)

bst = self._local_boost()

local_model = bst.save_raw("json")
local_model_bytes = bytes(local_model)

return FitRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
num_examples=num_train,
metrics={},
)

def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
eval_results = self.bst.eval_set(
evals=[(valid_dmatrix, "valid")],
iteration=self.bst.num_boosted_rounds() - 1,
)
auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)

global_round = ins.config["global_round"]
log(INFO, f"AUC = {auc} at round {global_round}")

return EvaluateRes(
status=Status(
code=Code.OK,
message="OK",
),
loss=0.0,
num_examples=num_val,
metrics={"AUC": auc},
)


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient())
67 changes: 67 additions & 0 deletions examples/xgboost-comprehensive/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import xgboost as xgb
from typing import Callable, Dict, List, Optional, Tuple, Union
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.partitioner import (
IidPartitioner,
LinearPartitioner,
SquarePartitioner,
ExponentialPartitioner,
)

CORRELATION_TO_PARTITIONER = {
"uniform": IidPartitioner,
"linear": LinearPartitioner,
"square": SquarePartitioner,
"exponential": ExponentialPartitioner,
}


def instantiate_partitioner(partitioner_type: str, num_partitions: int):
"""Initialise partitioner based on selected partitioner type and number of
partitions."""
partitioner = CORRELATION_TO_PARTITIONER[partitioner_type](
num_partitions=num_partitions
)
return partitioner


def train_test_split(partition: Dataset, test_fraction: float, seed: int):
"""Split the data into train and validation set given split rate."""
train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
partition_train = train_test["train"]
partition_test = train_test["test"]

num_train = len(partition_train)
num_test = len(partition_test)

return partition_train, partition_test, num_train, num_test


def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
"""Transform dataset to DMatrix format for xgboost."""
x = data["inputs"]
y = data["label"]
new_data = xgb.DMatrix(x, label=y)
return new_data


def resplit(dataset: DatasetDict) -> DatasetDict:
"""Increase the quantity of centralised test samples from 500K to 1M."""
return DatasetDict(
{
"train": dataset["train"].select(
range(0, dataset["train"].num_rows - 500_000)
),
"test": concatenate_datasets(
[
dataset["train"].select(
range(
dataset["train"].num_rows - 500_000,
dataset["train"].num_rows,
)
),
dataset["test"],
]
),
}
)
15 changes: 15 additions & 0 deletions examples/xgboost-comprehensive/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "xgboost-comprehensive"
version = "0.1.0"
description = "Federated XGBoost with Flower (comprehensive)"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = ">=0.0.2,<1.0.0"
xgboost = ">=2.0.0,<3.0.0"
3 changes: 3 additions & 0 deletions examples/xgboost-comprehensive/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
flwr>=1.0, <2.0
flwr-datasets>=0.0.2, <1.0.0
xgboost>=2.0.0, <3.0.0
17 changes: 17 additions & 0 deletions examples/xgboost-comprehensive/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

echo "Starting server"
python server.py &
sleep 15 # Sleep for 15s to give the server enough time to start

for i in `seq 0 1`; do
echo "Starting client $i"
python3 client.py --node-id=$i &
done

# Enable CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
Loading

0 comments on commit f056175

Please sign in to comment.