
Xgboost-comprehensive: move xgb_bagging strategy to core framework and update readme #2611

Merged · 13 commits · Nov 17, 2023
8 changes: 4 additions & 4 deletions examples/xgboost-comprehensive/README.md
@@ -8,17 +8,17 @@ Tree-based bagging is used for aggregation on the server.
Start by cloning the example project. We've prepared a single-line command that you can copy into your shell; it will check out the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-xgboost . && rm -rf flower && cd quickstart-xgboost
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-comprehensive . && rm -rf flower && cd xgboost-comprehensive
```

This will create a new directory called `quickstart-xgboost` containing the following files:
This will create a new directory called `xgboost-comprehensive` containing the following files:

```
-- README.md <- You're reading this right now
-- server.py <- Defines the server-side logic
-- strategy.py <- Defines the tree-based bagging aggregation
-- client.py <- Defines the client-side logic
-- dataset.py <- Defines the functions of data loading and partitioning
-- utils.py <- Defines the argument parser for clients and server.
-- pyproject.toml <- Example dependencies (if you use Poetry)
-- requirements.txt <- Example dependencies
```
@@ -83,5 +83,5 @@ bash run.sh
```

In addition, we provide options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`); an illustrative sketch of this pattern appears right after this diff.
Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-xgboost)
Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive)
and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
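As a rough illustration of that pattern (flag names and defaults below are illustrative assumptions, not the example's actual interface), such a parser could look like:

```python
# Hypothetical sketch of an argument parser in the spirit of utils.py.
# All flag names and defaults are illustrative, not the real interface.
import argparse


def client_args_parser() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Flower XGBoost client")
    parser.add_argument(
        "--partitioner-type",
        type=str,
        default="uniform",
        help="Strategy used to partition the dataset across clients.",
    )
    parser.add_argument(
        "--centralised-eval",
        action="store_true",
        help="Evaluate on the server instead of on the clients.",
    )
    return parser.parse_args()
```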
2 changes: 1 addition & 1 deletion examples/xgboost-comprehensive/dataset.py
@@ -1,5 +1,5 @@
import xgboost as xgb
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Union
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.partitioner import (
IidPartitioner,
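For context, the partitioners imported above come from the flwr-datasets package. A minimal usage sketch (the `jxie/higgs` dataset name matches this example; the exact call signatures are assumptions based on the flwr-datasets API of the time):

```python
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

# Split the training set into 10 IID partitions, one per client.
partitioner = IidPartitioner(num_partitions=10)
fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})

# Each client loads only its own partition (client 0 here).
partition = fds.load_partition(0, "train")
```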
2 changes: 1 addition & 1 deletion examples/xgboost-comprehensive/server.py
@@ -6,8 +6,8 @@
from flwr.common.logger import log
from flwr.common import Parameters, Scalar
from flwr_datasets import FederatedDataset
from flwr.server.strategy import FedXgbBagging

from strategy import FedXgbBagging
from utils import server_args_parser
from dataset import resplit, transform_dataset_to_dmatrix

2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/__init__.py
@@ -28,6 +28,7 @@
from .fedopt import FedOpt as FedOpt
from .fedprox import FedProx as FedProx
from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg
from .fedxgb_bagging import FedXgbBagging as FedXgbBagging
from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg
from .fedyogi import FedYogi as FedYogi
from .krum import Krum as Krum
@@ -40,6 +41,7 @@
"FedAdam",
"FedAvg",
"FedXgbNnAvg",
"FedXgbBagging",
"FedAvgAndroid",
"FedAvgM",
"FedOpt",
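With `FedXgbBagging` exported here, user code can import it straight from the core framework instead of a local `strategy.py`. A minimal server sketch (the constructor keywords are FedAvg parameters forwarded via `**kwargs`; the values are illustrative):

```python
import flwr as fl
from flwr.server.strategy import FedXgbBagging

# Bagging aggregation over client-boosted trees; FedAvg-style
# keyword arguments are forwarded to the parent class.
strategy = FedXgbBagging(
    fraction_fit=1.0,
    min_fit_clients=2,
    min_available_clients=2,
)

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)
```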
examples/xgboost-comprehensive/strategy.py → src/py/flwr/server/strategy/fedxgb_bagging.py
@@ -1,19 +1,35 @@
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union
import flwr as fl
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated XGBoost bagging aggregation strategy."""


import json
from logging import WARNING
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

from flwr.common import (
EvaluateRes,
FitRes,
Parameters,
Scalar,
)
from flwr.server.client_proxy import ClientProxy
from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from .fedavg import FedAvg


class FedXgbBagging(fl.server.strategy.FedAvg):
class FedXgbBagging(FedAvg):
"""Configurable FedXgbBagging strategy implementation."""

# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
def __init__(
self,
evaluate_function: Optional[
@@ -22,10 +38,10 @@ def __init__(
Optional[Tuple[float, Dict[str, Scalar]]],
]
] = None,
**kwargs,
**kwargs: Any,
):
self.evaluate_function = evaluate_function
self.global_model = None
self.global_model: Optional[bytes] = None
super().__init__(**kwargs)
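The collapsed lines above hide the middle of the `evaluate_function` type; by analogy with FedAvg's `evaluate_fn`, the hidden part is `Callable[[int, Parameters, Dict[str, Scalar]], ...]`. A sketch of a centralised evaluation function with that shape, with an illustrative body:

```python
from typing import Dict, Optional, Tuple

import xgboost as xgb

from flwr.common import Parameters, Scalar


def evaluate_fn(
    server_round: int,
    parameters: Parameters,
    config: Dict[str, Scalar],
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    # Load the aggregated model bytes into a booster (illustrative body;
    # a real implementation would score a held-out validation DMatrix).
    bst = xgb.Booster()
    bst.load_model(bytearray(parameters.tensors[0]))
    auc = 0.0  # placeholder for an actual metric computation
    return 1.0 - auc, {"AUC": auc}
```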

def aggregate_fit(
@@ -42,17 +58,16 @@ def aggregate_fit(
return None, {}

# Aggregate all the client trees
global_model = self.global_model
for _, fit_res in results:
update = fit_res.parameters.tensors
for item in update:
self.global_model = aggregate(
self.global_model, json.loads(bytearray(item))
)
for bst in update:
global_model = aggregate(global_model, bst)

weights_avg = json.dumps(self.global_model)
self.global_model = global_model

return (
Parameters(tensor_type="", tensors=[bytes(weights_avg, "utf-8")]),
Parameters(tensor_type="", tensors=[cast(bytes, global_model)]),
{},
)
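For orientation, each element of `fit_res.parameters.tensors` consumed above is an XGBoost model serialized to raw JSON bytes on the client. A minimal sketch of that client-side step (`save_raw` is standard xgboost API; its exact use in this example's `client.py` is an assumption):

```python
import xgboost as xgb


def booster_to_update(bst: xgb.Booster) -> bytes:
    # Serialize the locally boosted model to the JSON bytes that
    # the server-side aggregate() helper below expects.
    return bytes(bst.save_raw("json"))
```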

@@ -93,37 +108,47 @@ def evaluate(
return loss, metrics


def aggregate(bst_prev: Optional[Dict], bst_curr: Dict) -> Dict:
def aggregate(
bst_prev_org: Optional[bytes],
bst_curr_org: bytes,
) -> bytes:
"""Conduct bagging aggregation for given trees."""
if not bst_prev:
return bst_curr
else:
# Get the tree numbers
tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev)
tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr)

bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
"num_trees"
] = str(tree_num_prev + paral_tree_num_curr)
iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
"iteration_indptr"
]
bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
iteration_indptr[-1] + paral_tree_num_curr
if not bst_prev_org:
return bst_curr_org

# Get the tree numbers
tree_num_prev, _ = _get_tree_nums(bst_prev_org)
_, paral_tree_num_curr = _get_tree_nums(bst_curr_org)

bst_prev = json.loads(bytearray(bst_prev_org))
bst_curr = json.loads(bytearray(bst_curr_org))

bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
"num_trees"
] = str(tree_num_prev + paral_tree_num_curr)
iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
"iteration_indptr"
]
bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
iteration_indptr[-1] + paral_tree_num_curr
)

# Aggregate new trees
trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
for tree_count in range(paral_tree_num_curr):
trees_curr[tree_count]["id"] = tree_num_prev + tree_count
bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
trees_curr[tree_count]
)
bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)

bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8")

# Aggregate new trees
trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
for tree_count in range(paral_tree_num_curr):
trees_curr[tree_count]["id"] = tree_num_prev + tree_count
bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
trees_curr[tree_count]
)
bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
return bst_prev
return bst_prev_bytes


def _get_tree_nums(xgb_model: Dict) -> (int, int):
def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]:
xgb_model = json.loads(bytearray(xgb_model_org))
# Get the number of trees
tree_num = int(
xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
…
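Finally, a hedged end-to-end sketch of the bagging helper: train two one-round boosters on synthetic data, merge their raw JSON models with `aggregate`, and load the result back into XGBoost. It assumes an xgboost version whose JSON model carries `iteration_indptr` (2.0+) and one boosting iteration per call, matching the strategy's per-round usage:

```python
import json

import numpy as np
import xgboost as xgb

# Module-level helper introduced by this PR (not part of the public API).
from flwr.server.strategy.fedxgb_bagging import aggregate

# Two tiny "client" boosters, one boosting round each, on toy data.
rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.normal(size=(64, 4)), label=rng.integers(0, 2, size=64))
params = {"objective": "binary:logistic", "tree_method": "hist"}
bst_a = xgb.train(params, dtrain, num_boost_round=1)
bst_b = xgb.train(params, dtrain, num_boost_round=1)

# Fold both raw JSON models into a global model, as aggregate_fit does.
global_model = aggregate(None, bytes(bst_a.save_raw("json")))
global_model = aggregate(global_model, bytes(bst_b.save_raw("json")))

# The merged model now reports both trees.
model = json.loads(global_model)
print(model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"])
# expected: "2"

# Load the bagged model back into a booster for prediction.
bst_global = xgb.Booster()
bst_global.load_model(bytearray(global_model))
```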