
New XGBoost strategy: cyclic training (#2666)
Co-authored-by: yan-gao-GY <[email protected]>
yan-gao-GY authored Dec 6, 2023
1 parent c3347a4 commit 5be6671
Showing 8 changed files with 273 additions and 24 deletions.
23 changes: 17 additions & 6 deletions examples/xgboost-comprehensive/README.md
@@ -9,6 +9,7 @@ It differs from the [xgboost-quickstart](https://github.com/adap/flower/tree/mai
- Customised number of partitions.
- Customised partitioner type (uniform, linear, square, exponential).
- Centralised/distributed evaluation.
- Bagging/cyclic training methods (contrasted in the sketch below).
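
As a rough mental model of the difference (a toy sketch, not Flower's actual control flow; every name below is made up):

```python
# Toy contrast of the two training flows; all names are hypothetical.
clients = ["client-0", "client-1", "client-2"]

# Bagging: every sampled client boosts locally in parallel each round,
# and the server aggregates the newly added trees from all of them.
for rnd in range(3):
    updates = [f"new trees from {c}" for c in clients]
    global_model = f"aggregate({updates})"

# Cyclic: exactly one client boosts per round, in round-robin order,
# handing the full model to the next client via the server.
for rnd in range(3):
    trainer = clients[rnd % len(clients)]
    global_model = f"model boosted by {trainer}"
```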

## Project Setup

@@ -26,7 +27,8 @@ This will create a new directory called `xgboost-comprehensive` containing the f
-- client.py <- Defines the client-side logic
-- dataset.py <- Defines the functions of data loading and partitioning
-- utils.py <- Defines the arguments parser for clients and server
-- run_bagging.sh <- Commands to run bagging experiments
-- run_cyclic.sh <- Commands to run cyclic experiments
-- pyproject.toml <- Example dependencies (if you use Poetry)
-- requirements.txt <- Example dependencies
```
@@ -60,24 +62,31 @@ pip install -r requirements.txt

## Run Federated Learning with XGBoost and Flower

We provide two scripts, one for bagging and one for cyclic (client-by-client) training.
`run_bagging.sh` and `run_cyclic.sh` each start the Flower server (using `server.py`),
sleep for 15 seconds to make sure the server is up,
and then start 5 Flower clients (using `client.py`), each holding a small subset of the data drawn from an exponential partition distribution.
You can simply start everything in a terminal as follows:

```shell
poetry run ./run_bagging.sh
```

Or, for the cyclic experiment:

```shell
poetry run ./run_cyclic.sh
```

The scripts start the processes in the background so that you don't have to open a separate terminal window for each one.
If you experiment with the code example and something goes wrong, simply pressing `CTRL + C` wouldn't normally kill all these processes,
which is why each script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`.
This allows you to stop the experiment with a single `CTRL + C` (on both Linux and macOS).
If you change the script and anything goes wrong, you can still use `killall python` (or `killall python3`)
to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).

You can also manually run `poetry run python3 server.py --train-method=bagging/cyclic --pool-size=N --num-clients-per-round=N`
and `poetry run python3 client.py --train-method=bagging/cyclic --node-id=NODE_ID --num-partitions=N` for as many clients as you want,
but you have to make sure that each command is run in a different terminal window (or a different computer on the network).

In addition, we provide more options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
@@ -86,6 +95,8 @@ and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.htm

### Expected Experimental Results

#### Bagging aggregation experiment

![](_static/xgboost_flower_auc.png)

The figure above shows the AUC measured with centralised evaluation over FL rounds under 4 experimental settings.
15 changes: 10 additions & 5 deletions examples/xgboost-comprehensive/client.py
@@ -101,11 +101,16 @@ def _local_boost(self):
        for i in range(num_local_round):
            self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

        # Bagging: extract the last N=num_local_round trees for server aggregation
        # Cyclic: return the entire model
        bst = (
            self.bst[
                self.bst.num_boosted_rounds()
                - num_local_round : self.bst.num_boosted_rounds()
            ]
            if args.train_method == "bagging"
            else self.bst
        )

        return bst

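For context, the bagging branch above relies on XGBoost's model slicing, where `bst[a:b]` returns a new `Booster` holding only trees `a` through `b-1`. A minimal standalone sketch with synthetic data (assumes XGBoost >= 1.4, which introduced Booster slicing):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.random((64, 4)), rng.integers(0, 2, 64)
train_dmatrix = xgb.DMatrix(X, label=y)

# Train a few rounds, then boost locally for num_local_round more
bst = xgb.train({"objective": "binary:logistic"}, train_dmatrix, num_boost_round=3)
num_local_round = 2
for _ in range(num_local_round):
    bst.update(train_dmatrix, bst.num_boosted_rounds())

# Bagging sends only the newly added trees; cyclic would send `bst` whole
last_n = bst[bst.num_boosted_rounds() - num_local_round : bst.num_boosted_rounds()]
print(last_n.num_boosted_rounds())  # 2
```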
File renamed without changes.
17 changes: 17 additions & 0 deletions examples/xgboost-comprehensive/run_cyclic.sh
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

echo "Starting server"
python3 server.py --train-method=cyclic --pool-size=5 --num-rounds=100 &
sleep 15 # Sleep for 15s to give the server enough time to start

for i in `seq 0 4`; do
    echo "Starting client $i"
    python3 client.py --node-id=$i --train-method=cyclic --num-partitions=5 --partitioner-type=exponential --centralised-eval &
done

# Enable CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
84 changes: 71 additions & 13 deletions examples/xgboost-comprehensive/server.py
@@ -1,4 +1,5 @@
import warnings
from typing import Dict, List, Optional
from logging import INFO
import xgboost as xgb

@@ -7,13 +8,21 @@
from flwr.common import Parameters, Scalar
from flwr_datasets import FederatedDataset
from flwr.server.strategy import FedXgbBagging
from flwr.server.strategy import FedXgbCyclic
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from flwr.server.client_manager import SimpleClientManager

from utils import server_args_parser, BST_PARAMS
from dataset import resplit, transform_dataset_to_dmatrix


warnings.filterwarnings("ignore", category=UserWarning)


# Parse arguments for experimental settings
args = server_args_parser()
train_method = args.train_method
pool_size = args.pool_size
num_rounds = args.num_rounds
num_clients_per_round = args.num_clients_per_round
@@ -80,23 +89,72 @@ def evaluate_fn(
return evaluate_fn


class CyclicClientManager(SimpleClientManager):
    """Provides a cyclic client selection rule."""

    def sample(
        self,
        num_clients: int,
        min_num_clients: Optional[int] = None,
        criterion: Optional[Criterion] = None,
    ) -> List[ClientProxy]:
        """Sample a number of Flower ClientProxy instances."""
        # Block until at least num_clients are connected.
        if min_num_clients is None:
            min_num_clients = num_clients
        self.wait_for(min_num_clients)

        # Sample clients which meet the criterion
        available_cids = list(self.clients)
        if criterion is not None:
            available_cids = [
                cid for cid in available_cids if criterion.select(self.clients[cid])
            ]

        if num_clients > len(available_cids):
            log(
                INFO,
                "Sampling failed: number of available clients"
                " (%s) is less than number of requested clients (%s).",
                len(available_cids),
                num_clients,
            )
            return []

        # Return all available clients
        return [self.clients[cid] for cid in available_cids]

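Unlike the default `SimpleClientManager`, whose `sample` draws a random subset (via `random.sample`), this manager hands the strategy the full list of available clients; the round-robin pick of a single client per round happens inside `FedXgbCyclic`. A toy illustration of the resulting schedule (plain strings, not real `ClientProxy` objects; the strategy's exact selection logic is assumed here):

```python
# Hypothetical sketch of the client-by-client schedule cyclic training produces.
pool = ["client-0", "client-1", "client-2", "client-3", "client-4"]

for server_round in range(1, 11):
    # One client boosts per round, visited in order, wrapping around the pool
    trainer = pool[(server_round - 1) % len(pool)]
    print(f"Round {server_round}: {trainer} updates the global model")
```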

# Define strategy
if train_method == "bagging":
    # Bagging training
    strategy = FedXgbBagging(
        evaluate_function=get_evaluate_fn(test_dmatrix) if centralised_eval else None,
        fraction_fit=(float(num_clients_per_round) / pool_size),
        min_fit_clients=num_clients_per_round,
        min_available_clients=pool_size,
        min_evaluate_clients=num_evaluate_clients if not centralised_eval else 0,
        fraction_evaluate=1.0 if not centralised_eval else 0.0,
        on_evaluate_config_fn=eval_config,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation
        if not centralised_eval
        else None,
    )
else:
    # Cyclic training
    strategy = FedXgbCyclic(
        fraction_fit=1.0,
        min_available_clients=pool_size,
        fraction_evaluate=1.0,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
        on_evaluate_config_fn=eval_config,
    )

# Start Flower server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=num_rounds),
    strategy=strategy,
    client_manager=CyclicClientManager() if train_method == "cyclic" else None,
)
14 changes: 14 additions & 0 deletions examples/xgboost-comprehensive/utils.py
@@ -17,6 +17,13 @@ def client_args_parser():
"""Parse arguments to define experimental settings on client side."""
parser = argparse.ArgumentParser()

    parser.add_argument(
        "--train-method",
        default="bagging",
        type=str,
        choices=["bagging", "cyclic"],
        help="Training method: bagging aggregation or cyclic training.",
    )
    parser.add_argument(
        "--num-partitions", default=10, type=int, help="Number of partitions."
    )
@@ -56,6 +63,13 @@ def server_args_parser():
"""Parse arguments to define experimental settings on server side."""
parser = argparse.ArgumentParser()

    parser.add_argument(
        "--train-method",
        default="bagging",
        type=str,
        choices=["bagging", "cyclic"],
        help="Training method: bagging aggregation or cyclic training.",
    )
    parser.add_argument(
        "--pool-size", default=2, type=int, help="Number of total clients."
    )
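A quick, hypothetical smoke test for the new flag (assumes `utils.py` is importable from the example directory and that the parser has no other required arguments):

```python
import sys
from utils import server_args_parser

# Simulate a command line; any flag not given falls back to its default
sys.argv = ["server.py", "--train-method=cyclic", "--pool-size=5"]
args = server_args_parser()
print(args.train_method, args.pool_size)  # cyclic 5
```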
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/__init__.py
@@ -29,6 +29,7 @@
from .fedprox import FedProx as FedProx
from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg
from .fedxgb_bagging import FedXgbBagging as FedXgbBagging
from .fedxgb_cyclic import FedXgbCyclic as FedXgbCyclic
from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg
from .fedyogi import FedYogi as FedYogi
from .krum import Krum as Krum
@@ -42,6 +43,7 @@
"FedAvg",
"FedXgbNnAvg",
"FedXgbBagging",
"FedXgbCyclic",
"FedAvgAndroid",
"FedAvgM",
"FedOpt",
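With the export above, the new strategy is importable from the package root alongside the existing bagging strategy:

```python
from flwr.server.strategy import FedXgbBagging, FedXgbCyclic
```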
