New XGBoost strategy: cyclic training #2666

Merged · 7 commits · Dec 6, 2023
23 changes: 17 additions & 6 deletions examples/xgboost-comprehensive/README.md
@@ -9,6 +9,7 @@ It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/mai
- Customised number of partitions.
- Customised partitioner type (uniform, linear, square, exponential).
- Centralised/distributed evaluation.
- Bagging/cyclic training methods.

## Project Setup

@@ -26,7 +27,8 @@ This will create a new directory called `xgboost-comprehensive` containing the f
-- client.py <- Defines the client-side logic
-- dataset.py <- Defines the functions of data loading and partitioning
-- utils.py <- Defines the arguments parser for clients and server
-- run.sh <- Commands to run experiments
-- run_bagging.sh <- Commands to run bagging experiments
-- run_cyclic.sh <- Commands to run cyclic experiments
-- pyproject.toml <- Example dependencies (if you use Poetry)
-- requirements.txt <- Example dependencies
```
@@ -60,24 +62,31 @@ pip install -r requirements.txt

## Run Federated Learning with XGBoost and Flower

The included `run.sh` will start the Flower server (using `server.py`) with centralised evaluation,
We provide two scripts, one for bagging and one for cyclic (client-by-client) training.
The chosen script (`run_bagging.sh` or `run_cyclic.sh`) will start the Flower server (using `server.py`),
sleep for 15 seconds to ensure that the server is up,
and then start 5 Flower clients (using `client.py`), each holding a small subset of the data drawn from an exponential partition distribution.
You can simply start everything in a terminal as follows:

```shell
poetry run ./run.sh
poetry run ./run_bagging.sh
```

The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows.
Or

```shell
poetry run ./run_cyclic.sh
```

The script starts processes in the background so that you don't have to open six terminal windows.
If you experiment with the code example and something goes wrong, simply pressing `CTRL + C` on Linux (or `CMD + C` on macOS) would not normally kill all of these processes,
which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`.
Together, these allow you to stop the experiment using `CTRL + C` (or `CMD + C`).
If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`)
to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).

You can also manually run `poetry run python3 server.py --pool-size=N --num-clients-per-round=N`
and `poetry run python3 client.py --node-id=NODE_ID --num-partitions=N` for as many clients as you want,
You can also manually run `poetry run python3 server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N`
and `poetry run python3 client.py --train-method=bagging/cyclic --node-id=NODE_ID --num-partitions=N` for as many clients as you want,
but you have to make sure that each command is run in a different terminal window (or a different computer on the network).
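
For example, a manual cyclic run with three clients might look like this (illustrative flag values, not the only valid combination; each command goes in its own terminal):

```shell
# Terminal 1 - the server (hypothetical flag values)
poetry run python3 server.py --train-method=cyclic --pool-size=3 --num-rounds=30

# Terminals 2-4 - one client each
poetry run python3 client.py --train-method=cyclic --node-id=0 --num-partitions=3
poetry run python3 client.py --train-method=cyclic --node-id=1 --num-partitions=3
poetry run python3 client.py --train-method=cyclic --node-id=2 --num-partitions=3
```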

In addition, we provide more options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
@@ -86,6 +95,8 @@ and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.htm

### Expected Experimental Results

#### Bagging aggregation experiment

![](_static/xgboost_flower_auc.png)

The figure above shows the centralised test AUC over FL rounds for 4 experimental settings.
15 changes: 10 additions & 5 deletions examples/xgboost-comprehensive/client.py
@@ -101,11 +101,16 @@ def _local_boost(self):
for i in range(num_local_round):
self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

# Extract the last N=num_local_round trees for server aggregation
bst = self.bst[
self.bst.num_boosted_rounds()
- num_local_round : self.bst.num_boosted_rounds()
]
# Bagging: extract the last N=num_local_round trees for server aggregation
# Cyclic: return the entire model
bst = (
self.bst[
self.bst.num_boosted_rounds()
- num_local_round : self.bst.num_boosted_rounds()
]
if args.train_method == "bagging"
else self.bst
)

return bst

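As a side note, the bagging branch above relies on XGBoost's `Booster` slicing. A minimal standalone sketch of that mechanism, on synthetic data and assuming `xgboost >= 1.3` (where `Booster` slicing and `num_boosted_rounds()` are available):

```python
import numpy as np
import xgboost as xgb

# Synthetic binary classification data (illustrative only)
X = np.random.rand(128, 4)
y = np.random.randint(0, 2, 128)
dtrain = xgb.DMatrix(X, label=y)

# Train an initial model, then boost two more local rounds in place
bst = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=5)
num_local_round = 2
for _ in range(num_local_round):
    bst.update(dtrain, bst.num_boosted_rounds())

# Bagging-style extraction: keep only the newest num_local_round trees
last_trees = bst[bst.num_boosted_rounds() - num_local_round : bst.num_boosted_rounds()]
print(last_trees.num_boosted_rounds())  # expected: 2
```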
17 changes: 17 additions & 0 deletions examples/xgboost-comprehensive/run_cyclic.sh
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

echo "Starting server"
python3 server.py --train-method=cyclic --pool-size=5 --num-rounds=100 &
sleep 15 # Sleep for 15s to give the server enough time to start

for i in `seq 0 4`; do
echo "Starting client $i"
python3 client.py --node-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval &
done

# Enable CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
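Note that the script launches 5 clients but runs 100 server rounds: since cyclic training fits one client per round, each client should be visited about 20 times over the course of the experiment.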
82 changes: 69 additions & 13 deletions examples/xgboost-comprehensive/server.py
@@ -1,4 +1,5 @@
from typing import Dict
import warnings
from typing import Dict, List, Optional
from logging import INFO
import xgboost as xgb

@@ -7,13 +8,21 @@
from flwr.common import Parameters, Scalar
from flwr_datasets import FederatedDataset
from flwr.server.strategy import FedXgbBagging
from flwr.server.strategy import FedXgbCyclic
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from flwr.server.client_manager import SimpleClientManager

from utils import server_args_parser, BST_PARAMS
from dataset import resplit, transform_dataset_to_dmatrix


warnings.filterwarnings("ignore", category=UserWarning)


# Parse arguments for experimental settings
args = server_args_parser()
train_method = args.train_method
pool_size = args.pool_size
num_rounds = args.num_rounds
num_clients_per_round = args.num_clients_per_round
@@ -80,23 +89,70 @@ def evaluate_fn(
return evaluate_fn


class CyclicClientManager(SimpleClientManager):
"""Provides a cyclic client selection rule."""

def sample(
self,
num_clients: int,
min_num_clients: Optional[int] = None,
criterion: Optional[Criterion] = None,
) -> List[ClientProxy]:
"""Sample a number of Flower ClientProxy instances."""
# Block until at least num_clients are connected.
if min_num_clients is None:
min_num_clients = num_clients
self.wait_for(min_num_clients)
# Sample clients which meet the criterion
available_cids = list(self.clients)
if criterion is not None:
available_cids = [
cid for cid in available_cids if criterion.select(self.clients[cid])
]

if num_clients > len(available_cids):
log(
INFO,
"Sampling failed: number of available clients"
" (%s) is less than number of requested clients (%s).",
len(available_cids),
num_clients,
)
return []

# Return all available clients (the cyclic strategy then picks one per round)
return [self.clients[cid] for cid in available_cids]


# Define strategy
strategy = FedXgbBagging(
evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None,
fraction_fit=(float(num_clients_per_round) / pool_size),
min_fit_clients=num_clients_per_round,
min_available_clients=pool_size,
min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0,
fraction_evaluate=1.0 if not centralised_eval else 0.0,
on_evaluate_config_fn=eval_config,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation
if not centralised_eval
else None,
)
if train_method == "bagging":
# Bagging training
strategy = FedXgbBagging(
evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None,
fraction_fit=(float(num_clients_per_round) / pool_size),
min_fit_clients=num_clients_per_round,
min_available_clients=pool_size,
min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0,
fraction_evaluate=1.0 if not centralised_eval else 0.0,
on_evaluate_config_fn=eval_config,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation
if not centralised_eval
else None,
)
else:
# Cyclic training
strategy = FedXgbCyclic(
fraction_fit=1.0,
min_available_clients=pool_size,
fraction_evaluate=1.0,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
on_evaluate_config_fn=eval_config,
)

# Start Flower server
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=num_rounds),
strategy=strategy,
client_manager=CyclicClientManager() if train_method == "cyclic" else None,
)
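
Because `CyclicClientManager.sample` returns every available client, the per-round selection is left to the strategy. A toy sketch of the round-robin behaviour this enables (an illustration of the idea, not the actual `FedXgbCyclic` internals):

```python
# Round-robin over a fixed client pool: each round trains exactly one client,
# so the model travels client-by-client around the ring.
clients = ["client-0", "client-1", "client-2", "client-3", "client-4"]

for server_round in range(1, 11):
    selected = clients[(server_round - 1) % len(clients)]
    print(f"round {server_round}: fit on {selected}")
```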
14 changes: 14 additions & 0 deletions examples/xgboost-comprehensive/utils.py
@@ -17,6 +17,13 @@ def client_args_parser():
"""Parse arguments to define experimental settings on client side."""
parser = argparse.ArgumentParser()

parser.add_argument(
"--train-method",
default="bagging",
type=str,
choices=["bagging", "cyclic"],
help="Training methods selected from bagging aggregation or cyclic training.",
)
parser.add_argument(
"--num-partitions", default=10, type=int, help="Number of partitions."
)
@@ -56,6 +63,13 @@ def server_args_parser():
"""Parse arguments to define experimental settings on server side."""
parser = argparse.ArgumentParser()

parser.add_argument(
"--train-method",
default="bagging",
type=str,
choices=["bagging", "cyclic"],
help="Training methods selected from bagging aggregation or cyclic training.",
)
parser.add_argument(
"--pool-size", default=2, type=int, help="Number of total clients."
)
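A quick sanity check of the new flag (a hypothetical snippet, assuming `server_args_parser()` returns the parsed `argparse` namespace, as its use in `server.py` suggests):

```python
from utils import server_args_parser

args = server_args_parser()  # parses sys.argv, e.g. --train-method=cyclic
assert args.train_method in ("bagging", "cyclic")  # enforced by `choices`
print(args.train_method)  # "bagging" unless overridden on the command line
```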
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/__init__.py
@@ -29,6 +29,7 @@
from .fedprox import FedProx as FedProx
from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg
from .fedxgb_bagging import FedXgbBagging as FedXgbBagging
from .fedxgb_cyclic import FedXgbCyclic as FedXgbCyclic
from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg
from .fedyogi import FedYogi as FedYogi
from .krum import Krum as Krum
@@ -42,6 +43,7 @@
"FedAvg",
"FedXgbNnAvg",
"FedXgbBagging",
"FedXgbCyclic",
"FedAvgAndroid",
"FedAvgM",
"FedOpt",
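With the re-export in place, the new strategy becomes importable at the package level. A minimal usage sketch (constructor arguments mirror the cyclic branch in `server.py` above):

```python
from flwr.server.strategy import FedXgbCyclic

# Same core arguments as the cyclic branch in server.py
strategy = FedXgbCyclic(
    fraction_fit=1.0,
    min_available_clients=5,
    fraction_evaluate=1.0,
)
```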