-
Notifications
You must be signed in to change notification settings - Fork 907
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
yan-gao-GY
committed
Nov 16, 2023
1 parent
7ae14a2
commit 6c842a8
Showing
8 changed files
with
487 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Flower Example using XGBoost | ||
|
||
This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using `xgboost` package. | ||
Tree-based with bagging method is used for aggregation on the server. | ||
|
||
## Project Setup | ||
|
||
Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: | ||
|
||
```shell | ||
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-xgboost . && rm -rf flower && cd quickstart-xgboost | ||
``` | ||
|
||
This will create a new directory called `quickstart-xgboost` containing the following files: | ||
|
||
``` | ||
-- README.md <- Your're reading this right now | ||
-- server.py <- Defines the server-side logic | ||
-- client.py <- Defines the client-side logic | ||
-- dataset.py <- Defines the functions of data loading and partitioning | ||
-- pyproject.toml <- Example dependencies (if you use Poetry) | ||
-- requirements.txt <- Example dependencies | ||
``` | ||
|
||
### Installing Dependencies | ||
|
||
Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. | ||
|
||
#### Poetry | ||
|
||
```shell | ||
poetry install | ||
poetry shell | ||
``` | ||
|
||
Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: | ||
|
||
```shell | ||
poetry run python3 -c "import flwr" | ||
``` | ||
|
||
If you don't see any errors you're good to go! | ||
|
||
#### pip | ||
|
||
Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. | ||
|
||
```shell | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Run Federated Learning with XGBoost and Flower | ||
|
||
Afterwards you are ready to start the Flower server as well as the clients. | ||
You can simply start the server in a terminal as follows: | ||
|
||
```shell | ||
python3 server.py | ||
``` | ||
|
||
Now you are ready to start the Flower clients which will participate in the learning. | ||
To do so simply open two more terminal windows and run the following commands. | ||
|
||
Start client 1 in the first terminal: | ||
|
||
```shell | ||
python3 client.py | ||
``` | ||
|
||
Start client 2 in the second terminal: | ||
|
||
```shell | ||
python3 client.py | ||
``` | ||
|
||
You will see that XGBoost is starting a federated training. | ||
|
||
Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows: | ||
|
||
```shell | ||
bash run.sh | ||
``` | ||
|
||
Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-xgboost) | ||
and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import warnings | ||
from logging import INFO | ||
import xgboost as xgb | ||
|
||
import flwr as fl | ||
from flwr_datasets import FederatedDataset | ||
from flwr.common.logger import log | ||
from flwr.common import ( | ||
Code, | ||
EvaluateIns, | ||
EvaluateRes, | ||
FitIns, | ||
FitRes, | ||
GetParametersIns, | ||
GetParametersRes, | ||
Parameters, | ||
Status, | ||
) | ||
from flwr_datasets.partitioner import IidPartitioner | ||
|
||
from dataset import ( | ||
train_test_split, | ||
transform_dataset_to_dmatrix, | ||
) | ||
|
||
|
||
warnings.filterwarnings("ignore", category=UserWarning) | ||
|
||
|
||
# Load (HIGGS) dataset and conduct partitioning | ||
partitioner = IidPartitioner(num_partitions=10) | ||
fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner}) | ||
|
||
# Let's use the first partition as an example | ||
partition = fds.load_partition(idx=0, split="train") | ||
partition.set_format("numpy") | ||
|
||
# Train/test splitting | ||
train_data, valid_data, num_train, num_val = train_test_split( | ||
partition, test_fraction=0.2, seed=42 | ||
) | ||
|
||
# Reformat data to DMatrix for xgboost | ||
train_dmatrix = transform_dataset_to_dmatrix(train_data) | ||
valid_dmatrix = transform_dataset_to_dmatrix(valid_data) | ||
|
||
# Hyper-parameters for xgboost training | ||
num_local_round = 1 | ||
params = { | ||
"objective": "binary:logistic", | ||
"eta": 0.1, # Learning rate | ||
"max_depth": 8, | ||
"eval_metric": "auc", | ||
"nthread": 16, | ||
"num_parallel_tree": 1, | ||
"subsample": 1, | ||
"tree_method": "hist", | ||
} | ||
|
||
|
||
# Define Flower client | ||
class FlowerClient(fl.client.Client): | ||
def __init__(self): | ||
self.bst = None | ||
self.config = None | ||
|
||
def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: | ||
_ = (self, ins) | ||
return GetParametersRes( | ||
status=Status( | ||
code=Code.OK, | ||
message="OK", | ||
), | ||
parameters=Parameters(tensor_type="", tensors=[]), | ||
) | ||
|
||
def _local_boost(self): | ||
# Update trees based on local training data. | ||
for i in range(num_local_round): | ||
self.bst.update(train_dmatrix, self.bst.num_boosted_rounds()) | ||
|
||
# Extract the last N=num_local_round trees for sever aggregation | ||
bst = self.bst[ | ||
self.bst.num_boosted_rounds() | ||
- num_local_round : self.bst.num_boosted_rounds() | ||
] | ||
|
||
return bst | ||
|
||
def fit(self, ins: FitIns) -> FitRes: | ||
if not self.bst: | ||
# First round local training | ||
log(INFO, "Start training at round 1") | ||
bst = xgb.train( | ||
params, | ||
train_dmatrix, | ||
num_boost_round=num_local_round, | ||
evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")], | ||
) | ||
self.config = bst.save_config() | ||
self.bst = bst | ||
else: | ||
for item in ins.parameters.tensors: | ||
global_model = bytearray(item) | ||
|
||
# Load global model into booster | ||
self.bst.load_model(global_model) | ||
self.bst.load_config(self.config) | ||
|
||
bst = self._local_boost() | ||
|
||
local_model = bst.save_raw("json") | ||
local_model_bytes = bytes(local_model) | ||
|
||
return FitRes( | ||
status=Status( | ||
code=Code.OK, | ||
message="OK", | ||
), | ||
parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), | ||
num_examples=num_train, | ||
metrics={}, | ||
) | ||
|
||
def evaluate(self, ins: EvaluateIns) -> EvaluateRes: | ||
eval_results = self.bst.eval_set( | ||
evals=[(valid_dmatrix, "valid")], | ||
iteration=self.bst.num_boosted_rounds() - 1, | ||
) | ||
auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) | ||
|
||
return EvaluateRes( | ||
status=Status( | ||
code=Code.OK, | ||
message="OK", | ||
), | ||
loss=0.0, | ||
num_examples=num_val, | ||
metrics={"AUC": auc}, | ||
) | ||
|
||
|
||
# Start Flower client | ||
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import xgboost as xgb | ||
from typing import Union | ||
from datasets import Dataset, DatasetDict | ||
|
||
|
||
def train_test_split(partition: Dataset, test_fraction: float, seed: int): | ||
"""Split the data into train and validation set given split rate.""" | ||
train_test = partition.train_test_split(test_size=test_fraction, seed=seed) | ||
partition_train = train_test["train"] | ||
partition_test = train_test["test"] | ||
|
||
num_train = len(partition_train) | ||
num_test = len(partition_test) | ||
|
||
return partition_train, partition_test, num_train, num_test | ||
|
||
|
||
def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix: | ||
"""Transform dataset to DMatrix format for xgboost.""" | ||
x = data["inputs"] | ||
y = data["label"] | ||
new_data = xgb.DMatrix(x, label=y) | ||
return new_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
[build-system] | ||
requires = ["poetry-core>=1.4.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.poetry] | ||
name = "xgboost-quickstart" | ||
version = "0.1.0" | ||
description = "Federated XGBoost with Flower (quickstart)" | ||
authors = ["The Flower Authors <[email protected]>"] | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.8,<3.11" | ||
flwr = ">=1.0,<2.0" | ||
flwr-datasets = ">=0.0.1,<1.0.0" | ||
xgboost = ">=2.0.0,<3.0.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
flwr>=1.0, <2.0 | ||
flwr-datasets>=0.0.1, <1.0.0 | ||
xgboost>=2.0.0, <3.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/bin/bash | ||
set -e | ||
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ | ||
|
||
echo "Starting server" | ||
python server.py & | ||
sleep 5 # Sleep for 15s to give the server enough time to start | ||
|
||
for i in `seq 0 1`; do | ||
echo "Starting client $i" | ||
python3 client.py & | ||
done | ||
|
||
# Enable CTRL+C to stop all background processes | ||
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM | ||
# Wait for all background processes to complete | ||
wait |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import flwr as fl | ||
from strategy import FedXgbBagging | ||
|
||
|
||
# FL experimental settings | ||
pool_size = 2 | ||
num_rounds = 5 | ||
num_clients_per_round = 2 | ||
num_evaluate_clients = 2 | ||
|
||
|
||
def evaluate_metrics_aggregation(eval_metrics): | ||
"""Return an aggregated metric (AUC) for evaluation.""" | ||
total_num = sum([num for num, _ in eval_metrics]) | ||
auc_aggregated = ( | ||
sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num | ||
) | ||
metrics_aggregated = {"AUC": auc_aggregated} | ||
return metrics_aggregated | ||
|
||
|
||
# Define strategy | ||
strategy = FedXgbBagging( | ||
fraction_fit=(float(num_clients_per_round) / pool_size), | ||
min_fit_clients=num_clients_per_round, | ||
min_available_clients=pool_size, | ||
min_evaluate_clients=num_evaluate_clients, | ||
fraction_evaluate=1.0, | ||
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, | ||
) | ||
|
||
# Start Flower server | ||
fl.server.start_server( | ||
server_address="0.0.0.0:8080", | ||
config=fl.server.ServerConfig(num_rounds=num_rounds), | ||
strategy=strategy, | ||
) |
Oops, something went wrong.