xgboost-comprehensive with bagging aggregation #2554

Merged · 35 commits · Nov 15, 2023
Changes shown from 18 commits.

Commits (35)
8529119
Initialise XGBoost
Nov 1, 2023
fe2d3e0
Upload readme
Nov 1, 2023
7efcdbf
Pass the number of examples to server and do formatting
Nov 2, 2023
6ddb448
Add flwr_datasets as required package
Nov 2, 2023
34452a5
Change dataset loading structure
Nov 6, 2023
b272f11
Clean up env; add weighted AUC aggregation
Nov 7, 2023
0fb9d36
Replace print with log
Nov 7, 2023
d0b33be
Add file description in readme
Nov 7, 2023
91af548
Add arguments parser on client side; Do formatting
Nov 8, 2023
b965843
Update required package flwr-datasets==0.02; pull back run.sh
Nov 9, 2023
6bbee08
Move argument parser to utils; Modify comments
Nov 9, 2023
4b92c81
Add feature of centralised/client evaluation
Nov 9, 2023
57bde18
Correct aggregation and match Navida results
Nov 12, 2023
727a028
formatting
Nov 12, 2023
237ea18
Merge branch 'main' into xgboost
yan-gao-GY Nov 12, 2023
c91f7ad
Add type hints
Nov 13, 2023
054f17f
Update readme and run.sh
Nov 14, 2023
6973ecb
Merge branch 'main' into xgboost
danieljanes Nov 15, 2023
28b7305
Format readme
Nov 15, 2023
39ced8b
Merge branch 'xgboost' of https://github.com/adap/flower into xgboost
Nov 15, 2023
ebe35a5
Update examples/quickstart-xgboost/client.py
yan-gao-GY Nov 15, 2023
749a960
Update examples/quickstart-xgboost/README.md
yan-gao-GY Nov 15, 2023
40d0bed
Update examples/quickstart-xgboost/README.md
yan-gao-GY Nov 15, 2023
5cdfc37
Update examples/quickstart-xgboost/utils.py
yan-gao-GY Nov 15, 2023
c2f85c0
Update examples/quickstart-xgboost/utils.py
yan-gao-GY Nov 15, 2023
b63a0b0
Update examples/quickstart-xgboost/strategy.py
yan-gao-GY Nov 15, 2023
5484a47
Format arguments parser
Nov 15, 2023
0ad1c43
Update run.sh
Nov 15, 2023
b948294
Change strategy name to FedXgbBagging
Nov 15, 2023
68d065e
Rename to xgboost-comprehensive
Nov 15, 2023
b0348c9
Recover readme
Nov 15, 2023
f10d4b5
Update examples/xgboost-comprehensive/pyproject.toml
danieljanes Nov 15, 2023
a0c65eb
Update examples/xgboost-comprehensive/pyproject.toml
danieljanes Nov 15, 2023
5656bb4
Update examples/xgboost-comprehensive/pyproject.toml
danieljanes Nov 15, 2023
35207f5
Merge branch 'main' into xgboost
danieljanes Nov 15, 2023
87 changes: 87 additions & 0 deletions examples/quickstart-xgboost/README.md
@@ -0,0 +1,87 @@
# Flower Example using XGBoost

This example demonstrates how to perform EXtreme Gradient Boosting (XGBoost) within Flower using the `xgboost` package.
Tree-based bagging is used for aggregation on the server.
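
Conceptually, the server merges the new trees each client grew in a round into one global ensemble. The snippet below is a simplified sketch of that idea, assuming XGBoost's JSON model layout (`learner.gradient_booster.model`) and omitting bookkeeping fields such as `iteration_indptr`; the shipped `strategy.py` is the authoritative implementation.

```python
import json


def bagging_aggregate(global_model: bytes, local_model: bytes) -> bytes:
    """Append a client's newly boosted trees to the global ensemble (sketch)."""
    if not global_model:
        return local_model

    glob, loc = json.loads(global_model), json.loads(local_model)
    g_model = glob["learner"]["gradient_booster"]["model"]
    l_model = loc["learner"]["gradient_booster"]["model"]
    num_trees = int(g_model["gbtree_model_param"]["num_trees"])

    # Re-index the incoming trees and append them after the existing ones
    for tree in l_model["trees"]:
        tree["id"] = num_trees
        g_model["trees"].append(tree)
        g_model["tree_info"].append(0)  # single-class (binary) trees
        num_trees += 1
    g_model["gbtree_model_param"]["num_trees"] = str(num_trees)

    return json.dumps(glob).encode("utf-8")
```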

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-xgboost . && rm -rf flower && cd quickstart-xgboost
```

This will create a new directory called `quickstart-xgboost` containing the following files:

```
-- README.md <- You're reading this right now
-- server.py <- Defines the server-side logic
-- strategy.py <- Defines the tree-based bagging aggregation
-- client.py <- Defines the client-side logic
-- dataset.py <- Defines the functions of data loading and partitioning
-- pyproject.toml <- Example dependencies (if you use Poetry)
-- requirements.txt <- Example dependencies
```

### Installing Dependencies

Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Run the command below in your terminal to install the dependencies listed in `requirements.txt`.

```shell
pip install -r requirements.txt
```

## Run Federated Learning with XGBoost and Flower

Afterwards you are ready to start the Flower server as well as the clients.
You can simply start the server in a terminal as follows:

```shell
python3 server.py
```

Now you are ready to start the Flower clients which will participate in the learning.
To do so simply open two more terminal windows and run the following commands.

Start client 1 in the first terminal:

```shell
python3 client.py --partition_id=0
```

Start client 2 in the second terminal:

```shell
python3 client.py --partition_id=1
```

You will see that federated XGBoost training starts.

Alternatively, you can use `run.sh` to run the same experiment in a single terminal as follows:

```shell
bash run.sh
```

In addition, we provide options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`); an example invocation is shown below.
Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-xgboost)
and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
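
For instance, a run with five clients, exponential partitioning, and a 20% client-side test split might look like this (flag spellings follow the argument parser in `utils.py`; check it if a flag differs in your copy):

```shell
# Server with centralised evaluation (as discussed in the review thread)
python3 server.py --centralised_eval &

# One of five clients, exponential partitioning, 20% local test split
python3 client.py --num_partitions=5 --partitioner_type=exponential \
    --partition_id=0 --test_fraction=0.2 --seed=42
```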
173 changes: 173 additions & 0 deletions examples/quickstart-xgboost/client.py
@@ -0,0 +1,173 @@
import warnings
from logging import INFO
import xgboost as xgb

import flwr as fl
from flwr_datasets import FederatedDataset
from flwr.common.logger import log
from flwr.common import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
Parameters,
Status,
)

from dataset import (
instantiate_partitioner,
train_test_split,
transform_dataset_to_dmatrix,
resplit,
)
from utils import client_args_parser


warnings.filterwarnings("ignore", category=UserWarning)


# Parse arguments for experimental settings
args = client_args_parser()

# Load (HIGGS) dataset and conduct partitioning
num_partitions = args.num_partitions
# Partitioner type is chosen from ["uniform", "linear", "square", "exponential"]
partitioner_type = args.partitioner_type

# Instantiate partitioner
partitioner = instantiate_partitioner(
partitioner_type=partitioner_type, num_partitions=num_partitions
)
fds = FederatedDataset(
dataset="jxie/higgs", partitioners={"train": partitioner}, resplitter=resplit
)

# Let's use the first partition as an example
partition_id = args.partition_id
partition = fds.load_partition(idx=partition_id, split="train")
partition.set_format("numpy")

if args.centralised_eval:
# [Inline review thread — resolved]
# Member: In the case of centralised eval, does each (federated) client also
# use the centralised dataset for the federated evaluation? Is that intended,
# or is it controlled by the server?
# Author: Whether to run centralised eval is controlled by the server via
# --centralised_eval. Without it, the user can still choose either the
# centralised test set or a client test set (split from the client's training
# data) for client evaluation; e.g., running `client.py --centralised_eval`
# enables client evaluation on the centralised test set.
# Use centralised test set for evaluation
train_data = partition
valid_data = fds.load_full("test")
valid_data.set_format("numpy")
num_train = train_data.shape[0]
num_val = valid_data.shape[0]
else:
# Train/test splitting
SEED = args.seed
test_fraction = args.test_fraction
train_data, valid_data, num_train, num_val = train_test_split(
partition, test_fraction=test_fraction, seed=SEED
)

# Reformat data to DMatrix for xgboost
train_dmatrix = transform_dataset_to_dmatrix(train_data)
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)


# Hyper-parameters for xgboost training
num_local_round = 1
params = {
"objective": "binary:logistic",
"eta": 0.1, # Learning rate
"max_depth": 8,
"eval_metric": "auc",
"nthread": 16,
"num_parallel_tree": 1,
"subsample": 1,
"tree_method": "hist",
}
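
# Note: with num_parallel_tree=1 and num_local_round=1, each client grows
# exactly one new tree per federated round on top of the received global
# model; the server then bags all clients' new trees into the global ensemble.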


# Define Flower client
class FlowerClient(fl.client.Client):
def __init__(self):
self.bst = None
self.config = None

def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
_ = (self, ins)
return GetParametersRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[]),
)

def _local_boost(self):
# Update trees based on local training data.
for i in range(num_local_round):
self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

# Extract the last N=num_local_round trees for server aggregation
bst = self.bst[
self.bst.num_boosted_rounds()
- num_local_round : self.bst.num_boosted_rounds()
]

return bst

def fit(self, ins: FitIns) -> FitRes:
if not self.bst:
# First round local training
log(INFO, "Start training at round 1")
bst = xgb.train(
params,
train_dmatrix,
num_boost_round=num_local_round,
evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
)
self.config = bst.save_config()
self.bst = bst
else:
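# The global model arrives as a single serialised booster (JSON bytes)
# packed into ins.parameters.tensors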
for item in ins.parameters.tensors:
global_model = bytearray(item)

# Load global model into booster
self.bst.load_model(global_model)
self.bst.load_config(self.config)

bst = self._local_boost()

local_model = bst.save_raw("json")
local_model_bytes = bytes(local_model)

return FitRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
num_examples=num_train,
metrics={},
)

def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
eval_results = self.bst.eval_set(
evals=[(valid_dmatrix, "valid")],
iteration=self.bst.num_boosted_rounds() - 1,
)
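# eval_set returns a string such as "[0]\tvalid-auc:0.8123"; parse out the AUC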
auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)

global_round = ins.config["global_round"]
log(INFO, f"AUC = {auc} at round {global_round}")

return EvaluateRes(
status=Status(
code=Code.OK,
message="OK",
),
loss=0.0,
num_examples=num_val,
metrics={"AUC": auc},
)


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient())
67 changes: 67 additions & 0 deletions examples/quickstart-xgboost/dataset.py
@@ -0,0 +1,67 @@
import xgboost as xgb
from typing import Union
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.partitioner import (
IidPartitioner,
LinearPartitioner,
SquarePartitioner,
ExponentialPartitioner,
)

CORRELATION_TO_PARTITIONER = {
"uniform": IidPartitioner,
"linear": LinearPartitioner,
"square": SquarePartitioner,
"exponential": ExponentialPartitioner,
}


def instantiate_partitioner(partitioner_type: str, num_partitions: int):
"""Initialise partitioner based on selected partitioner type and number of
partitions."""
partitioner = CORRELATION_TO_PARTITIONER[partitioner_type](
num_partitions=num_partitions
)
return partitioner


def train_test_split(partition: Dataset, test_fraction: float, seed: int):
"""Split the data into train and validation set given split rate."""
train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
partition_train = train_test["train"]
partition_test = train_test["test"]

num_train = len(partition_train)
num_test = len(partition_test)

return partition_train, partition_test, num_train, num_test


def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
"""Transform dataset to DMatrix format for xgboost."""
x = data["inputs"]
y = data["label"]
new_data = xgb.DMatrix(x, label=y)
return new_data


def resplit(dataset: DatasetDict) -> DatasetDict:
"""Increase the quantity of centralised test samples from 500K to 1M."""
return DatasetDict(
{
"train": dataset["train"].select(
range(0, dataset["train"].num_rows - 500_000)
),
"test": concatenate_datasets(
[
dataset["train"].select(
range(
dataset["train"].num_rows - 500_000,
dataset["train"].num_rows,
)
),
dataset["test"],
]
),
}
)
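
# Example usage (a sketch mirroring client.py; the "jxie/higgs" dataset and
# the FederatedDataset wiring come from this example's client code):
#
#   from flwr_datasets import FederatedDataset
#   partitioner = instantiate_partitioner("exponential", num_partitions=5)
#   fds = FederatedDataset(
#       dataset="jxie/higgs", partitioners={"train": partitioner}, resplitter=resplit
#   )
#   partition = fds.load_partition(idx=0, split="train")
#   partition.set_format("numpy")
#   train, valid, n_train, n_valid = train_test_split(partition, 0.2, seed=42)
#   train_dmatrix = transform_dataset_to_dmatrix(train)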
15 changes: 15 additions & 0 deletions examples/quickstart-xgboost/pyproject.toml
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "quickstart-xgboost"
version = "0.1.0"
description = "Federated XGBoost Quickstart with Flower"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = ">=0.0.2,<1.0.0"
xgboost = ">=2.0.0,<3.0.0"
3 changes: 3 additions & 0 deletions examples/quickstart-xgboost/requirements.txt
@@ -0,0 +1,3 @@
flwr>=1.0, <2.0
flwr-datasets>=0.0.2, <1.0.0
xgboost>=2.0.0, <3.0.0
17 changes: 17 additions & 0 deletions examples/quickstart-xgboost/run.sh
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

echo "Starting server"
python server.py &
sleep 15 # Sleep for 15s to give the server enough time to start

for i in `seq 0 1`; do
echo "Starting client $i"
python client.py --partition_id=$i &
done

# Enable CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait