diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md
index 3801d4813a26..3b31e23cb321 100644
--- a/examples/xgboost-comprehensive/README.md
+++ b/examples/xgboost-comprehensive/README.md
@@ -8,17 +8,17 @@ Tree-based with bagging method is used for aggregation on the server.
 Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
 
 ```shell
-git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-xgboost . && rm -rf flower && cd quickstart-xgboost
+git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-comprehensive . && rm -rf flower && cd xgboost-comprehensive
 ```
 
-This will create a new directory called `quickstart-xgboost` containing the following files:
+This will create a new directory called `xgboost-comprehensive` containing the following files:
 
 ```
 -- README.md <- Your're reading this right now
 -- server.py <- Defines the server-side logic
--- strategy.py <- Defines the tree-based bagging aggregation
 -- client.py <- Defines the client-side logic
 -- dataset.py <- Defines the functions of data loading and partitioning
+-- utils.py <- Defines the arguments parser for clients and server.
 -- pyproject.toml <- Example dependencies (if you use Poetry)
 -- requirements.txt <- Example dependencies
 ```
@@ -83,5 +83,5 @@ bash run.sh
 ```
 
 Besides, we provide options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`).
-Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-xgboost)
+Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive)
 and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation.
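For orientation before the code changes below: the README above points to `utils.py` as the home of the client- and server-side argument parsers, and `server.py` (further down in this diff) imports `server_args_parser` from it. Here is a minimal sketch of what such a parser could look like; the concrete flag names and defaults (`--pool-size`, `--centralised-eval`) are illustrative assumptions, not the example's actual options.

```python
# Hypothetical sketch of a server_args_parser helper as described in the README;
# the real utils.py in the example may expose different flags and defaults.
import argparse


def server_args_parser() -> argparse.Namespace:
    """Parse command-line settings for the server side of the example."""
    parser = argparse.ArgumentParser(
        description="Flower XGBoost comprehensive example (server)"
    )
    parser.add_argument(
        "--pool-size", type=int, default=2, help="Total number of clients (assumed flag)."
    )
    parser.add_argument(
        "--centralised-eval",
        action="store_true",
        help="Evaluate the global model on the server instead of on clients (assumed flag).",
    )
    return parser.parse_args()
```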
diff --git a/examples/xgboost-comprehensive/dataset.py b/examples/xgboost-comprehensive/dataset.py
index 80c978f1077b..bcf2e00b30af 100644
--- a/examples/xgboost-comprehensive/dataset.py
+++ b/examples/xgboost-comprehensive/dataset.py
@@ -1,5 +1,5 @@
 import xgboost as xgb
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Union
 from datasets import Dataset, DatasetDict, concatenate_datasets
 from flwr_datasets.partitioner import (
     IidPartitioner,
diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py
index 857e99528013..e4c597ee17eb 100644
--- a/examples/xgboost-comprehensive/server.py
+++ b/examples/xgboost-comprehensive/server.py
@@ -6,8 +6,8 @@
 from flwr.common.logger import log
 from flwr.common import Parameters, Scalar
 from flwr_datasets import FederatedDataset
+from flwr.server.strategy import FedXgbBagging
 
-from strategy import FedXgbBagging
 from utils import server_args_parser
 from dataset import resplit, transform_dataset_to_dmatrix
 
diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py
index 908267d04b3f..0772aa1ff13a 100644
--- a/src/py/flwr/server/strategy/__init__.py
+++ b/src/py/flwr/server/strategy/__init__.py
@@ -28,6 +28,7 @@
 from .fedopt import FedOpt as FedOpt
 from .fedprox import FedProx as FedProx
 from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg
+from .fedxgb_bagging import FedXgbBagging as FedXgbBagging
 from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg
 from .fedyogi import FedYogi as FedYogi
 from .krum import Krum as Krum
@@ -40,6 +41,7 @@
     "FedAdam",
     "FedAvg",
     "FedXgbNnAvg",
+    "FedXgbBagging",
     "FedAvgAndroid",
     "FedAvgM",
     "FedOpt",
diff --git a/examples/xgboost-comprehensive/strategy.py b/src/py/flwr/server/strategy/fedxgb_bagging.py
similarity index 52%
rename from examples/xgboost-comprehensive/strategy.py
rename to src/py/flwr/server/strategy/fedxgb_bagging.py
index 814010720a77..cafb466c2e8b 100644
--- a/examples/xgboost-comprehensive/strategy.py
+++ b/src/py/flwr/server/strategy/fedxgb_bagging.py
@@ -1,19 +1,35 @@
-from logging import WARNING
-from typing import Callable, Dict, List, Optional, Tuple, Union
-import flwr as fl
+# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Federated XGBoost bagging aggregation strategy."""
+
+
 import json
+from logging import WARNING
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 
-from flwr.common import (
-    EvaluateRes,
-    FitRes,
-    Parameters,
-    Scalar,
-)
-from flwr.server.client_proxy import ClientProxy
+from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
 from flwr.common.logger import log
+from flwr.server.client_proxy import ClientProxy
+
+from .fedavg import FedAvg
 
 
-class FedXgbBagging(fl.server.strategy.FedAvg):
+class FedXgbBagging(FedAvg):
+    """Configurable FedXgbBagging strategy implementation."""
+
+    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
     def __init__(
         self,
         evaluate_function: Optional[
@@ -22,10 +38,10 @@ def __init__(
             Optional[Tuple[float, Dict[str, Scalar]]],
         ]
     ] = None,
-        **kwargs,
+        **kwargs: Any,
    ):
         self.evaluate_function = evaluate_function
-        self.global_model = None
+        self.global_model: Optional[bytes] = None
         super().__init__(**kwargs)
 
     def aggregate_fit(
@@ -42,17 +58,16 @@
             return None, {}
 
         # Aggregate all the client trees
+        global_model = self.global_model
         for _, fit_res in results:
             update = fit_res.parameters.tensors
-            for item in update:
-                self.global_model = aggregate(
-                    self.global_model, json.loads(bytearray(item))
-                )
+            for bst in update:
+                global_model = aggregate(global_model, bst)
 
-        weights_avg = json.dumps(self.global_model)
+        self.global_model = global_model
 
         return (
-            Parameters(tensor_type="", tensors=[bytes(weights_avg, "utf-8")]),
+            Parameters(tensor_type="", tensors=[cast(bytes, global_model)]),
             {},
         )
 
@@ -93,37 +108,47 @@ def evaluate(
         return loss, metrics
 
 
-def aggregate(bst_prev: Optional[Dict], bst_curr: Dict) -> Dict:
+def aggregate(
+    bst_prev_org: Optional[bytes],
+    bst_curr_org: bytes,
+) -> bytes:
     """Conduct bagging aggregation for given trees."""
-    if not bst_prev:
-        return bst_curr
-    else:
-        # Get the tree numbers
-        tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev)
-        tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr)
-
-        bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
-            "num_trees"
-        ] = str(tree_num_prev + paral_tree_num_curr)
-        iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
-            "iteration_indptr"
-        ]
-        bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
-            iteration_indptr[-1] + paral_tree_num_curr
+    if not bst_prev_org:
+        return bst_curr_org
+
+    # Get the tree numbers
+    tree_num_prev, _ = _get_tree_nums(bst_prev_org)
+    _, paral_tree_num_curr = _get_tree_nums(bst_curr_org)
+
+    bst_prev = json.loads(bytearray(bst_prev_org))
+    bst_curr = json.loads(bytearray(bst_curr_org))
+
+    bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
+        "num_trees"
+    ] = str(tree_num_prev + paral_tree_num_curr)
+    iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][
+        "iteration_indptr"
+    ]
+    bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
+        iteration_indptr[-1] + paral_tree_num_curr
+    )
+
+    # Aggregate new trees
+    trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
+    for tree_count in range(paral_tree_num_curr):
+        trees_curr[tree_count]["id"] = tree_num_prev + tree_count
+        bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
+            trees_curr[tree_count]
         )
+        bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
+
+    bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8")
 
-        # Aggregate new trees
-        trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
-        for tree_count in range(paral_tree_num_curr):
-            trees_curr[tree_count]["id"] = tree_num_prev + tree_count
-            bst_prev["learner"]["gradient_booster"]["model"]["trees"].append(
-                trees_curr[tree_count]
-            )
-            bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
-    return bst_prev
+    return bst_prev_bytes
 
 
-def _get_tree_nums(xgb_model: Dict) -> (int, int):
+def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]:
+    xgb_model = json.loads(bytearray(xgb_model_org))
     # Get the number of trees
     tree_num = int(
         xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
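The rewritten `aggregate` helper above operates directly on the raw JSON bytes that XGBoost's `Booster.save_raw("json")` produces, which is also the payload clients ship back inside `Parameters.tensors`. The snippet below is a hedged, self-contained sketch of how two locally trained boosters could be bagged with it; the toy data, training parameters, and the single local boosting round are assumptions made for illustration, and it presumes a recent `xgboost` whose JSON model carries `iteration_indptr`.

```python
# Illustrative sketch: bag two locally trained XGBoost models with the helper above.
import numpy as np
import xgboost as xgb

from flwr.server.strategy.fedxgb_bagging import aggregate


def train_local_booster(seed: int) -> bytes:
    """Train a tiny booster on synthetic data and return its raw JSON bytes."""
    rng = np.random.default_rng(seed)
    features = rng.normal(size=(64, 4))
    labels = (features[:, 0] > 0).astype(int)
    dtrain = xgb.DMatrix(features, label=labels)
    # One boosting round per federated round, mirroring how the strategy consumes updates.
    bst = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=1)
    return bytes(bst.save_raw("json"))


update_a = train_local_booster(seed=0)
update_b = train_local_booster(seed=1)

# First call: there is no previous global model, so the first update is returned as-is.
global_model = aggregate(None, update_a)
# Second call: update_b's new trees are appended to the running global model.
global_model = aggregate(global_model, update_b)

# The bagged model is plain JSON bytes and can be loaded back into a Booster.
merged = xgb.Booster()
merged.load_model(bytearray(global_model))
```

On the wire, `aggregate_fit` wraps exactly such bytes in `Parameters(tensor_type="", tensors=[...])`, as the hunk above shows.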
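Since `FedXgbBagging` is now exported from `flwr.server.strategy` (see the `__init__.py` hunk), a server script can use it without carrying its own `strategy.py`. Below is a minimal, hypothetical server sketch; the client counts, round number, address, and the omitted `evaluate_function` are assumptions rather than the settings of the example's real `server.py`, which derives them from `server_args_parser`.

```python
# Hypothetical minimal server using the newly exported strategy; not the example's server.py.
import flwr as fl
from flwr.server.strategy import FedXgbBagging

strategy = FedXgbBagging(
    evaluate_function=None,   # plug in a centralised evaluation function if desired
    fraction_fit=1.0,         # sample every available client for training each round
    min_fit_clients=2,
    min_available_clients=2,
    fraction_evaluate=1.0,
)

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)
```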