Quickstart-xgboost with bagging aggregation (#2554)
Co-authored-by: yan-gao-GY <[email protected]>
1 parent 760a456, commit f056175
Showing 9 changed files with 685 additions and 0 deletions.
@@ -0,0 +1,87 @@
# Flower Example using XGBoost

This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using the `xgboost` package.
A tree-based bagging method is used for aggregation on the server.

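To make "tree-based bagging" more concrete, here is a minimal sketch of how a server might merge the trees of one client model into the global model. `strategy.py` is not shown in this diff, so this is not the author's implementation: the function name `aggregate_bagging` is hypothetical, and the JSON layout (`learner.gradient_booster.model` with `trees`, `tree_info`, `iteration_indptr` and `gbtree_model_param.num_trees`) is assumed from the XGBoost 2.x JSON model format.

```python
import json


def aggregate_bagging(global_model: bytes, client_model: bytes) -> bytes:
    """Append the trees of one client model to the global model (sketch)."""
    if not global_model:
        # First round: the first client model becomes the global model.
        return client_model

    global_json = json.loads(global_model)
    client_json = json.loads(client_model)

    # Assumed location of the tree ensemble inside an XGBoost JSON model.
    g = global_json["learner"]["gradient_booster"]["model"]
    c = client_json["learner"]["gradient_booster"]["model"]

    num_global = int(g["gbtree_model_param"]["num_trees"])
    new_trees = c["trees"]

    for offset, tree in enumerate(new_trees):
        tree = dict(tree)
        # Re-number the tree so that ids stay unique in the merged ensemble.
        tree["id"] = num_global + offset
        g["trees"].append(tree)
        g["tree_info"].append(0)

    # Book-keeping: total tree count, plus one new boosting iteration
    # that covers all appended trees.
    g["gbtree_model_param"]["num_trees"] = str(num_global + len(new_trees))
    g["iteration_indptr"].append(g["iteration_indptr"][-1] + len(new_trees))

    return json.dumps(global_json).encode("utf-8")
```

Calling this once per client result in a round grows the global ensemble by the trees each client trained locally, which is the bagging behaviour described above.
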
## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell to check out the example:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-xgboost . && rm -rf flower && cd quickstart-xgboost
```

This will create a new directory called `quickstart-xgboost` containing the following files:

```
-- README.md         <- You're reading this right now
-- server.py         <- Defines the server-side logic
-- strategy.py       <- Defines the tree-based bagging aggregation
-- client.py         <- Defines the client-side logic
-- dataset.py        <- Defines the functions of data loading and partitioning
-- utils.py          <- Defines the argument parser for experimental settings
-- run.sh            <- Starts the server and two clients in a single terminal
-- pyproject.toml    <- Example dependencies (if you use Poetry)
-- requirements.txt  <- Example dependencies
```

### Installing Dependencies

Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/) to install those dependencies and manage your virtual environment, but feel free to use a different way of installing dependencies and managing virtual environments if you prefer.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Run the command below in your terminal to install the dependencies listed in `requirements.txt`.

```shell
pip install -r requirements.txt
```

## Run Federated Learning with XGBoost and Flower

You are now ready to start the Flower server and the clients.
Start the server in a terminal as follows:

```shell
python3 server.py
```

Next, start the Flower clients that will participate in the learning.
To do so, open two more terminal windows and run the following commands.

Start client 1 in the first terminal:

```shell
python3 client.py --node-id=0
```

Start client 2 in the second terminal:

```shell
python3 client.py --node-id=1
```

You will see XGBoost start federated training.

Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows:

```shell
bash run.sh
```

We also provide options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-xgboost)
and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
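
`utils.py` is not shown in this diff, so the following is only a sketch of what a client-side argument parser could look like. The attribute names (`num_partitions`, `partitioner_type`, `node_id`, `centralised_eval`, `seed`, `test_fraction`) are the ones `client.py` reads, and `--node-id` appears in the commands above. Every other flag spelling, every default value, and the optional `argv` parameter are assumptions made for illustration.

```python
import argparse


def client_args_parser(argv=None):
    """Parse client-side experimental settings (hypothetical sketch)."""
    parser = argparse.ArgumentParser(description="Flower XGBoost client settings")
    parser.add_argument(
        "--num-partitions", type=int, default=10,  # default is an assumption
        help="Number of partitions the training set is split into.",
    )
    parser.add_argument(
        "--partitioner-type", type=str, default="uniform",
        choices=["uniform", "linear", "square", "exponential"],
        help="How partition sizes depend on the partition id.",
    )
    parser.add_argument(
        "--node-id", type=int, default=0,
        help="Id of the partition used by this client.",
    )
    parser.add_argument(
        "--seed", type=int, default=42,  # default is an assumption
        help="Seed for the local train/test split.",
    )
    parser.add_argument(
        "--test-fraction", type=float, default=0.2,  # default is an assumption
        help="Fraction of the partition held out for validation.",
    )
    parser.add_argument(
        "--centralised-eval", action="store_true",
        help="Evaluate on the centralised test set instead of a local split.",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
```

With such a parser, `python3 client.py --node-id=1 --partitioner-type=linear` would select the second partition of a linearly sized split.
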
@@ -0,0 +1,174 @@
import warnings
from logging import INFO
import xgboost as xgb

import flwr as fl
from flwr_datasets import FederatedDataset
from flwr.common.logger import log
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Parameters,
    Status,
)

from dataset import (
    instantiate_partitioner,
    train_test_split,
    transform_dataset_to_dmatrix,
    resplit,
)
from utils import client_args_parser


warnings.filterwarnings("ignore", category=UserWarning)


# Parse arguments for experimental settings
args = client_args_parser()

# Load (HIGGS) dataset and conduct partitioning
num_partitions = args.num_partitions

# Partitioner type is chosen from ["uniform", "linear", "square", "exponential"]
partitioner_type = args.partitioner_type

# Instantiate partitioner
partitioner = instantiate_partitioner(
    partitioner_type=partitioner_type, num_partitions=num_partitions
)
fds = FederatedDataset(
    dataset="jxie/higgs", partitioners={"train": partitioner}, resplitter=resplit
)

# Load the partition that belongs to this client (selected via --node-id)
node_id = args.node_id
partition = fds.load_partition(idx=node_id, split="train")
partition.set_format("numpy")

if args.centralised_eval:
    # Use centralised test set for evaluation
    train_data = partition
    valid_data = fds.load_full("test")
    valid_data.set_format("numpy")
    num_train = train_data.shape[0]
    num_val = valid_data.shape[0]
else:
    # Train/test splitting
    SEED = args.seed
    test_fraction = args.test_fraction
    train_data, valid_data, num_train, num_val = train_test_split(
        partition, test_fraction=test_fraction, seed=SEED
    )

# Reformat data to DMatrix for xgboost
train_dmatrix = transform_dataset_to_dmatrix(train_data)
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)


# Hyper-parameters for xgboost training
num_local_round = 1
params = {
    "objective": "binary:logistic",
    "eta": 0.1,  # Learning rate
    "max_depth": 8,
    "eval_metric": "auc",
    "nthread": 16,
    "num_parallel_tree": 1,
    "subsample": 1,
    "tree_method": "hist",
}


# Define Flower client
class FlowerClient(fl.client.Client):
    def __init__(self):
        self.bst = None
        self.config = None

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        _ = (self, ins)
        return GetParametersRes(
            status=Status(
                code=Code.OK,
                message="OK",
            ),
            parameters=Parameters(tensor_type="", tensors=[]),
        )

    def _local_boost(self):
        # Update trees based on local training data.
        for _ in range(num_local_round):
            self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

        # Extract the last N=num_local_round trees for server aggregation
        bst = self.bst[
            self.bst.num_boosted_rounds()
            - num_local_round : self.bst.num_boosted_rounds()
        ]

        return bst

    def fit(self, ins: FitIns) -> FitRes:
        if not self.bst:
            # First round local training
            log(INFO, "Start training at round 1")
            bst = xgb.train(
                params,
                train_dmatrix,
                num_boost_round=num_local_round,
                evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
            )
            self.config = bst.save_config()
            self.bst = bst
        else:
            for item in ins.parameters.tensors:
                global_model = bytearray(item)

            # Load global model into booster
            self.bst.load_model(global_model)
            self.bst.load_config(self.config)

            bst = self._local_boost()

        local_model = bst.save_raw("json")
        local_model_bytes = bytes(local_model)

        return FitRes(
            status=Status(
                code=Code.OK,
                message="OK",
            ),
            parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
            num_examples=num_train,
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        eval_results = self.bst.eval_set(
            evals=[(valid_dmatrix, "valid")],
            iteration=self.bst.num_boosted_rounds() - 1,
        )
        auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)

        global_round = ins.config["global_round"]
        log(INFO, f"AUC = {auc} at round {global_round}")

        return EvaluateRes(
            status=Status(
                code=Code.OK,
                message="OK",
            ),
            loss=0.0,
            num_examples=num_val,
            metrics={"AUC": auc},
        )


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient())
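
The `evaluate` method above extracts the AUC by indexing into the tab-separated string returned by `eval_set`. A slightly more defensive variant, sketched below, parses every `name:value` field into a dictionary. The helper name `parse_eval_results` is hypothetical, and the input format (an iteration tag followed by tab-separated `name:value` pairs) is assumed from the parsing logic in `evaluate`.

```python
from typing import Dict


def parse_eval_results(eval_results: str) -> Dict[str, float]:
    """Turn an eval_set result string into a {metric_name: value} dict."""
    metrics: Dict[str, float] = {}
    # The first tab-separated field is the iteration tag, e.g. "[3]"; skip it.
    for field in eval_results.split("\t")[1:]:
        # Split on the last colon so the metric name is kept whole.
        name, _, value = field.rpartition(":")
        metrics[name] = float(value)
    return metrics


# Usage with a hand-written example string (not real training output).
example = "[3]\tvalid-auc:0.812345"
print(round(parse_eval_results(example)["valid-auc"], 4))  # 0.8123
```

Looking a metric up by name avoids depending on the position of the AUC field if more metrics or evaluation sets are added later.
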
@@ -0,0 +1,67 @@
import xgboost as xgb
from typing import Callable, Dict, List, Optional, Tuple, Union
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.partitioner import (
    IidPartitioner,
    LinearPartitioner,
    SquarePartitioner,
    ExponentialPartitioner,
)

CORRELATION_TO_PARTITIONER = {
    "uniform": IidPartitioner,
    "linear": LinearPartitioner,
    "square": SquarePartitioner,
    "exponential": ExponentialPartitioner,
}


def instantiate_partitioner(partitioner_type: str, num_partitions: int):
    """Initialise partitioner based on selected partitioner type and number of
    partitions."""
    partitioner = CORRELATION_TO_PARTITIONER[partitioner_type](
        num_partitions=num_partitions
    )
    return partitioner


def train_test_split(partition: Dataset, test_fraction: float, seed: int):
    """Split the data into train and validation set given split rate."""
    train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
    partition_train = train_test["train"]
    partition_test = train_test["test"]

    num_train = len(partition_train)
    num_test = len(partition_test)

    return partition_train, partition_test, num_train, num_test


def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
    """Transform dataset to DMatrix format for xgboost."""
    x = data["inputs"]
    y = data["label"]
    new_data = xgb.DMatrix(x, label=y)
    return new_data


def resplit(dataset: DatasetDict) -> DatasetDict:
    """Increase the quantity of centralised test samples from 500K to 1M."""
    return DatasetDict(
        {
            "train": dataset["train"].select(
                range(0, dataset["train"].num_rows - 500_000)
            ),
            "test": concatenate_datasets(
                [
                    dataset["train"].select(
                        range(
                            dataset["train"].num_rows - 500_000,
                            dataset["train"].num_rows,
                        )
                    ),
                    dataset["test"],
                ]
            ),
        }
    )
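
The index arithmetic in `resplit` can be checked on plain row counts without loading the dataset. The sketch below mirrors the two `range(...)` expressions. The helper name `resplit_ranges` is hypothetical, and the train size of 10.5M rows is an assumption used only for illustration (the 500K test rows come from the docstring).

```python
def resplit_ranges(num_train: int, num_test: int, moved: int = 500_000):
    """Mirror the index arithmetic of `resplit` on plain row counts."""
    # Rows that stay in the training split.
    kept_train = range(0, num_train - moved)
    # The last `moved` training rows are appended to the test split.
    moved_to_test = range(num_train - moved, num_train)
    new_test_size = len(moved_to_test) + num_test
    return kept_train, moved_to_test, new_test_size


# Assumed split sizes, for illustration only.
kept, tail, test_size = resplit_ranges(num_train=10_500_000, num_test=500_000)
print(len(kept), len(tail), test_size)  # 10000000 500000 1000000
```

The two ranges are disjoint and together cover the original training split, so no row ends up in both the new train and test sets.
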
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "xgboost-comprehensive"
version = "0.1.0"
description = "Federated XGBoost with Flower (comprehensive)"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = ">=0.0.2,<1.0.0"
xgboost = ">=2.0.0,<3.0.0"
@@ -0,0 +1,3 @@
flwr>=1.0, <2.0
flwr-datasets>=0.0.2, <1.0.0
xgboost>=2.0.0, <3.0.0
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

echo "Starting server"
python3 server.py &
sleep 15  # Sleep for 15s to give the server enough time to start

for i in `seq 0 1`; do
    echo "Starting client $i"
    python3 client.py --node-id=$i &
done

# Enable CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
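
The script above waits a fixed 15 seconds before starting the clients. As an alternative technique (not part of this commit), a launcher could poll the server port until it accepts TCP connections. The sketch below shows one way to do that in Python. The helper name `wait_for_port` and its default values are hypothetical, and the address `127.0.0.1:8080` in the usage comment is taken from `client.py`.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 15.0,
                  interval: float = 0.2) -> bool:
    """Return True once host:port accepts TCP connections, False on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Not listening yet (or unreachable): wait briefly and retry.
            time.sleep(interval)
    return False


# Example: block until the Flower server from server.py is reachable.
# if wait_for_port("127.0.0.1", 8080):
#     ...start the clients...
```

Polling starts the clients as soon as the server is ready and fails explicitly when it never comes up, at the cost of a small amount of extra launcher code.
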