Skip to content

Commit

Permalink
Xgboost-comprehensive: move xgb_bagging strategy to core framework an…
Browse files Browse the repository at this point in the history
…d update readme (#2611)

Co-authored-by: yan-gao-GY <[email protected]>
Co-authored-by: Heng Pan <[email protected]>
  • Loading branch information
3 people authored Nov 17, 2023
1 parent c3cf430 commit 3b54412
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 51 deletions.
8 changes: 4 additions & 4 deletions examples/xgboost-comprehensive/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ Tree-based with bagging method is used for aggregation on the server.
Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-xgboost . && rm -rf flower && cd quickstart-xgboost
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-comprehensive . && rm -rf flower && cd xgboost-comprehensive
```

This will create a new directory called `quickstart-xgboost` containing the following files:
This will create a new directory called `xgboost-comprehensive` containing the following files:

```
-- README.md <- Your're reading this right now
-- server.py <- Defines the server-side logic
-- strategy.py <- Defines the tree-based bagging aggregation
-- client.py <- Defines the client-side logic
-- dataset.py <- Defines the functions of data loading and partitioning
-- utils.py <- Defines the arguments parser for clients and server.
-- pyproject.toml <- Example dependencies (if you use Poetry)
-- requirements.txt <- Example dependencies
```
Expand Down Expand Up @@ -83,5 +83,5 @@ bash run.sh
```

Besides, we provide options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-xgboost)
Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive)
and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
2 changes: 1 addition & 1 deletion examples/xgboost-comprehensive/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import xgboost as xgb
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Union
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.partitioner import (
IidPartitioner,
Expand Down
2 changes: 1 addition & 1 deletion examples/xgboost-comprehensive/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from flwr.common.logger import log
from flwr.common import Parameters, Scalar
from flwr_datasets import FederatedDataset
from flwr.server.strategy import FedXgbBagging

from strategy import FedXgbBagging
from utils import server_args_parser
from dataset import resplit, transform_dataset_to_dmatrix

Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .fedopt import FedOpt as FedOpt
from .fedprox import FedProx as FedProx
from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg
from .fedxgb_bagging import FedXgbBagging as FedXgbBagging
from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg
from .fedyogi import FedYogi as FedYogi
from .krum import Krum as Krum
Expand All @@ -40,6 +41,7 @@
"FedAdam",
"FedAvg",
"FedXgbNnAvg",
"FedXgbBagging",
"FedAvgAndroid",
"FedAvgM",
"FedOpt",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,35 @@
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union
import flwr as fl
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated XGBoost bagging aggregation strategy."""


import json
from logging import WARNING
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

from flwr.common import (
EvaluateRes,
FitRes,
Parameters,
Scalar,
)
from flwr.server.client_proxy import ClientProxy
from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from .fedavg import FedAvg


class FedXgbBagging(fl.server.strategy.FedAvg):
class FedXgbBagging(FedAvg):
"""Configurable FedXgbBagging strategy implementation."""

# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
def __init__(
self,
evaluate_function: Optional[
Expand All @@ -22,10 +38,10 @@ def __init__(
Optional[Tuple[float, Dict[str, Scalar]]],
]
] = None,
**kwargs,
**kwargs: Any,
):
self.evaluate_function = evaluate_function
self.global_model = None
self.global_model: Optional[bytes] = None
super().__init__(**kwargs)

def aggregate_fit(
Expand All @@ -42,17 +58,16 @@ def aggregate_fit(
return None, {}

# Aggregate all the client trees
global_model = self.global_model
for _, fit_res in results:
update = fit_res.parameters.tensors
for item in update:
self.global_model = aggregate(
self.global_model, json.loads(bytearray(item))
)
for bst in update:
global_model = aggregate(global_model, bst)

weights_avg = json.dumps(self.global_model)
self.global_model = global_model

return (
Parameters(tensor_type="", tensors=[bytes(weights_avg, "utf-8")]),
Parameters(tensor_type="", tensors=[cast(bytes, global_model)]),
{},
)

Expand Down Expand Up @@ -93,37 +108,47 @@ def evaluate(
return loss, metrics


def aggregate(bst_prev: Optional[Dict], bst_curr: Dict) -> Dict:
def aggregate(
bst_prev_org: Optional[bytes],
bst_curr_org: bytes,
) -> bytes:
"""Conduct bagging aggregation for given trees."""
if not bst_prev:
return bst_curr
else:
# Get the tree numbers
tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev)
tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr)

bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
"num_trees"
] = str(tree_num_prev + paral_tree_num_curr)
iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
"iteration_indptr"
]
bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
iteration_indptr[-1] + paral_tree_num_curr
if not bst_prev_org:
return bst_curr_org

# Get the tree numbers
tree_num_prev, _ = _get_tree_nums(bst_prev_org)
_, paral_tree_num_curr = _get_tree_nums(bst_curr_org)

bst_prev = json.loads(bytearray(bst_prev_org))
bst_curr = json.loads(bytearray(bst_curr_org))

bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
"num_trees"
] = str(tree_num_prev + paral_tree_num_curr)
iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
"iteration_indptr"
]
bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
iteration_indptr[-1] + paral_tree_num_curr
)

# Aggregate new trees
trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
for tree_count in range(paral_tree_num_curr):
trees_curr[tree_count]["id"] = tree_num_prev + tree_count
bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
trees_curr[tree_count]
)
bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)

bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8")

# Aggregate new trees
trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
for tree_count in range(paral_tree_num_curr):
trees_curr[tree_count]["id"] = tree_num_prev + tree_count
bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
trees_curr[tree_count]
)
bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
return bst_prev
return bst_prev_bytes


def _get_tree_nums(xgb_model: Dict) -> (int, int):
def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]:
xgb_model = json.loads(bytearray(xgb_model_org))
# Get the number of trees
tree_num = int(
xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
Expand Down

0 comments on commit 3b54412

Please sign in to comment.