From 5380456894b50e153545ef08279caf9a6d3cd1aa Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 10:19:35 +0000 Subject: [PATCH 01/10] Move xgb_bagging strategy to core framework; update readme --- examples/xgboost-comprehensive/README.md | 8 ++++---- examples/xgboost-comprehensive/server.py | 2 +- src/py/flwr/server/strategy/__init__.py | 2 ++ .../py/flwr/server/strategy/fedxgb_bagging.py | 0 4 files changed, 7 insertions(+), 5 deletions(-) rename examples/xgboost-comprehensive/strategy.py => src/py/flwr/server/strategy/fedxgb_bagging.py (100%) diff --git a/examples/xgboost-comprehensive/README.md b/examples/xgboost-comprehensive/README.md index 3801d4813a26..3b31e23cb321 100644 --- a/examples/xgboost-comprehensive/README.md +++ b/examples/xgboost-comprehensive/README.md @@ -8,17 +8,17 @@ Tree-based with bagging method is used for aggregation on the server. Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: ```shell -git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-xgboost . && rm -rf flower && cd quickstart-xgboost +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/xgboost-comprehensive . && rm -rf flower && cd xgboost-comprehensive ``` -This will create a new directory called `quickstart-xgboost` containing the following files: +This will create a new directory called `xgboost-comprehensive` containing the following files: ``` -- README.md <- Your're reading this right now -- server.py <- Defines the server-side logic --- strategy.py <- Defines the tree-based bagging aggregation -- client.py <- Defines the client-side logic -- dataset.py <- Defines the functions of data loading and partitioning +-- utils.py <- Defines the arguments parser for clients and server. -- pyproject.toml <- Example dependencies (if you use Poetry) -- requirements.txt <- Example dependencies ``` @@ -83,5 +83,5 @@ bash run.sh ``` Besides, we provide options to customise the experimental settings, including data partitioning and centralised/distributed evaluation (see `utils.py`). -Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-xgboost) +Look at the [code](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive) and [tutorial](https://flower.dev/docs/framework/tutorial-quickstart-xgboost.html) for a detailed explanation. diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py index 857e99528013..4e54859b186f 100644 --- a/examples/xgboost-comprehensive/server.py +++ b/examples/xgboost-comprehensive/server.py @@ -7,7 +7,7 @@ from flwr.common import Parameters, Scalar from flwr_datasets import FederatedDataset -from strategy import FedXgbBagging +from flwr.server.strategy import FedXgbBagging from utils import server_args_parser from dataset import resplit, transform_dataset_to_dmatrix diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py index 72429694bfe7..15970e90f067 100644 --- a/src/py/flwr/server/strategy/__init__.py +++ b/src/py/flwr/server/strategy/__init__.py @@ -28,6 +28,7 @@ from .fedprox import FedProx as FedProx from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg +from .fedxgb_bagging import FedXgbBagging as FedXgbBagging from .fedyogi import FedYogi as FedYogi from .krum import Krum as Krum from .qfedavg import QFedAvg as QFedAvg @@ -39,6 +40,7 @@ "FedAdam", "FedAvg", "FedXgbNnAvg", + "FedXgbBagging", "FedAvgAndroid", "FedAvgM", "FedOpt", diff --git a/examples/xgboost-comprehensive/strategy.py b/src/py/flwr/server/strategy/fedxgb_bagging.py similarity index 100% rename from examples/xgboost-comprehensive/strategy.py rename to src/py/flwr/server/strategy/fedxgb_bagging.py From 58ff6d708035da96d36eb66749991bdd70be23be Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 10:32:26 +0000 Subject: [PATCH 02/10] Formatting --- src/py/flwr/server/strategy/__init__.py | 2 +- src/py/flwr/server/strategy/fedxgb_bagging.py | 32 +++++++++++++------ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py index 15970e90f067..1651b31a2648 100644 --- a/src/py/flwr/server/strategy/__init__.py +++ b/src/py/flwr/server/strategy/__init__.py @@ -27,8 +27,8 @@ from .fedopt import FedOpt as FedOpt from .fedprox import FedProx as FedProx from .fedtrimmedavg import FedTrimmedAvg as FedTrimmedAvg -from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg from .fedxgb_bagging import FedXgbBagging as FedXgbBagging +from .fedxgb_nn_avg import FedXgbNnAvg as FedXgbNnAvg from .fedyogi import FedYogi as FedYogi from .krum import Krum as Krum from .qfedavg import QFedAvg as QFedAvg diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index 814010720a77..db6a807c0a96 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -1,19 +1,33 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Federated XGBoost bagging aggregation strategy.""" + + +import json from logging import WARNING from typing import Callable, Dict, List, Optional, Tuple, Union -import flwr as fl -import json -from flwr.common import ( - EvaluateRes, - FitRes, - Parameters, - Scalar, -) -from flwr.server.client_proxy import ClientProxy +import flwr as fl +from flwr.common import EvaluateRes, FitRes, Parameters, Scalar from flwr.common.logger import log +from flwr.server.client_proxy import ClientProxy class FedXgbBagging(fl.server.strategy.FedAvg): + """Configurable FedXgbBagging strategy implementation.""" + def __init__( self, evaluate_function: Optional[ From dc5f11ea435159fce065f96770819b8ee579cd90 Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 10:45:43 +0000 Subject: [PATCH 03/10] Correct type hints --- src/py/flwr/server/strategy/fedxgb_bagging.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index db6a807c0a96..a1336a2d3082 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -28,6 +28,7 @@ class FedXgbBagging(fl.server.strategy.FedAvg): """Configurable FedXgbBagging strategy implementation.""" + # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long def __init__( self, evaluate_function: Optional[ @@ -107,7 +108,10 @@ def evaluate( return loss, metrics -def aggregate(bst_prev: Optional[Dict], bst_curr: Dict) -> Dict: +def aggregate( + bst_prev: Optional[Dict[str, Union[int, slice]]], + bst_curr: Dict[str, Union[int, slice]], +) -> Dict[str, Union[int, slice]]: """Conduct bagging aggregation for given trees.""" if not bst_prev: return bst_curr @@ -137,7 +141,7 @@ def aggregate(bst_prev: Optional[Dict], bst_curr: Dict) -> Dict: return bst_prev -def _get_tree_nums(xgb_model: Dict) -> (int, int): +def _get_tree_nums(xgb_model: Dict[str, Union[int, slice]]) -> Tuple[int, int]: # Get the number of trees tree_num = int( xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ From 05e75deab256546486bc1ce8aba09dbd1d3e94d7 Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 14:09:39 +0000 Subject: [PATCH 04/10] Correct type hints --- examples/xgboost-comprehensive/server.py | 2 +- src/py/flwr/server/strategy/fedxgb_bagging.py | 25 +++++++++++-------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py index 4e54859b186f..e4c597ee17eb 100644 --- a/examples/xgboost-comprehensive/server.py +++ b/examples/xgboost-comprehensive/server.py @@ -6,8 +6,8 @@ from flwr.common.logger import log from flwr.common import Parameters, Scalar from flwr_datasets import FederatedDataset - from flwr.server.strategy import FedXgbBagging + from utils import server_args_parser from dataset import resplit, transform_dataset_to_dmatrix diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index a1336a2d3082..c44a0d22ebaa 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -59,15 +59,11 @@ def aggregate_fit( # Aggregate all the client trees for _, fit_res in results: update = fit_res.parameters.tensors - for item in update: - self.global_model = aggregate( - self.global_model, json.loads(bytearray(item)) - ) - - weights_avg = json.dumps(self.global_model) + for bst in update: + self.global_model = aggregate(self.global_model, bst) return ( - Parameters(tensor_type="", tensors=[bytes(weights_avg, "utf-8")]), + Parameters(tensor_type="", tensors=[self.global_model]), {}, ) @@ -109,9 +105,9 @@ def evaluate( def aggregate( - bst_prev: Optional[Dict[str, Union[int, slice]]], - bst_curr: Dict[str, Union[int, slice]], -) -> Dict[str, Union[int, slice]]: + bst_prev: Optional[bytes], + bst_curr: bytes, +) -> bytes: """Conduct bagging aggregation for given trees.""" if not bst_prev: return bst_curr @@ -120,6 +116,9 @@ def aggregate( tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev) tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr) + bst_prev = json.loads(bytearray(bst_prev)) + bst_curr = json.loads(bytearray(bst_curr)) + bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ "num_trees" ] = str(tree_num_prev + paral_tree_num_curr) @@ -138,10 +137,14 @@ def aggregate( trees_curr[tree_count] ) bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0) + + bst_prev = bytes(json.dumps(bst_prev), "utf-8") + return bst_prev -def _get_tree_nums(xgb_model: Dict[str, Union[int, slice]]) -> Tuple[int, int]: +def _get_tree_nums(xgb_model: bytes) -> Tuple[int, int]: + xgb_model = json.loads(bytearray(xgb_model)) # Get the number of trees tree_num = int( xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ From 3211f9654c3498f899eb1e56e3a4897f411bacf0 Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 14:23:17 +0000 Subject: [PATCH 05/10] Correct types hints --- src/py/flwr/server/strategy/fedxgb_bagging.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index c44a0d22ebaa..a6db39827967 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -17,7 +17,7 @@ import json from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import flwr as fl from flwr.common import EvaluateRes, FitRes, Parameters, Scalar @@ -37,7 +37,7 @@ def __init__( Optional[Tuple[float, Dict[str, Scalar]]], ] ] = None, - **kwargs, + **kwargs: Any, ): self.evaluate_function = evaluate_function self.global_model = None @@ -62,8 +62,10 @@ def aggregate_fit( for bst in update: self.global_model = aggregate(self.global_model, bst) + weights = self.global_model + return ( - Parameters(tensor_type="", tensors=[self.global_model]), + Parameters(tensor_type="", tensors=[weights]), {}, ) @@ -105,19 +107,19 @@ def evaluate( def aggregate( - bst_prev: Optional[bytes], - bst_curr: bytes, -) -> bytes: + bst_prev_org: Optional[bytes], + bst_curr_org: bytes, +) -> Optional[bytes]: """Conduct bagging aggregation for given trees.""" - if not bst_prev: - return bst_curr + if not bst_prev_org: + return bst_curr_org else: # Get the tree numbers - tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev) - tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr) + tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev_org) + tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr_org) - bst_prev = json.loads(bytearray(bst_prev)) - bst_curr = json.loads(bytearray(bst_curr)) + bst_prev = json.loads(bytearray(bst_prev_org)) + bst_curr = json.loads(bytearray(bst_curr_org)) bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ "num_trees" @@ -143,8 +145,8 @@ def aggregate( return bst_prev -def _get_tree_nums(xgb_model: bytes) -> Tuple[int, int]: - xgb_model = json.loads(bytearray(xgb_model)) +def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]: + xgb_model = json.loads(bytearray(xgb_model_org)) # Get the number of trees tree_num = int( xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ From 25af27ed1b6ee307c449462972c00f2346a036b8 Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 15:41:56 +0000 Subject: [PATCH 06/10] Correct types hints --- src/py/flwr/server/strategy/fedxgb_bagging.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index a6db39827967..5b8330d73a33 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -57,15 +57,16 @@ def aggregate_fit( return None, {} # Aggregate all the client trees + global_model = self.global_model for _, fit_res in results: update = fit_res.parameters.tensors for bst in update: - self.global_model = aggregate(self.global_model, bst) + global_model = aggregate(global_model, bst) - weights = self.global_model + self.global_model = global_model return ( - Parameters(tensor_type="", tensors=[weights]), + Parameters(tensor_type="", tensors=[global_model]), {}, ) From e839bb88c827a74312192b81dfe46d968450fd1e Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 15:47:50 +0000 Subject: [PATCH 07/10] Correct types hints --- src/py/flwr/server/strategy/fedxgb_bagging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index 5b8330d73a33..cbbac660c49e 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -110,7 +110,7 @@ def evaluate( def aggregate( bst_prev_org: Optional[bytes], bst_curr_org: bytes, -) -> Optional[bytes]: +) -> bytes: """Conduct bagging aggregation for given trees.""" if not bst_prev_org: return bst_curr_org From 0b6bed4c1c5348875c31bf71743caa4fc45f97eb Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 17:40:12 +0000 Subject: [PATCH 08/10] Remove unuseful packages --- examples/xgboost-comprehensive/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/xgboost-comprehensive/dataset.py b/examples/xgboost-comprehensive/dataset.py index 80c978f1077b..bcf2e00b30af 100644 --- a/examples/xgboost-comprehensive/dataset.py +++ b/examples/xgboost-comprehensive/dataset.py @@ -1,5 +1,5 @@ import xgboost as xgb -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Union from datasets import Dataset, DatasetDict, concatenate_datasets from flwr_datasets.partitioner import ( IidPartitioner, From e66994710fae0a382d0b5ae52636495eb81d16ab Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 19:27:10 +0000 Subject: [PATCH 09/10] Correct type hints --- src/py/flwr/server/strategy/fedxgb_bagging.py | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index cbbac660c49e..578328ca701b 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -17,7 +17,7 @@ import json from logging import WARNING -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast import flwr as fl from flwr.common import EvaluateRes, FitRes, Parameters, Scalar @@ -40,7 +40,7 @@ def __init__( **kwargs: Any, ): self.evaluate_function = evaluate_function - self.global_model = None + self.global_model: Optional[bytes] = None super().__init__(**kwargs) def aggregate_fit( @@ -66,7 +66,7 @@ def aggregate_fit( self.global_model = global_model return ( - Parameters(tensor_type="", tensors=[global_model]), + Parameters(tensor_type="", tensors=[cast(bytes, global_model)]), {}, ) @@ -114,36 +114,36 @@ def aggregate( """Conduct bagging aggregation for given trees.""" if not bst_prev_org: return bst_curr_org - else: - # Get the tree numbers - tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev_org) - tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr_org) - bst_prev = json.loads(bytearray(bst_prev_org)) - bst_curr = json.loads(bytearray(bst_curr_org)) + # Get the tree numbers + tree_num_prev, _ = _get_tree_nums(bst_prev_org) + _, paral_tree_num_curr = _get_tree_nums(bst_curr_org) + + bst_prev = json.loads(bytearray(bst_prev_org)) + bst_curr = json.loads(bytearray(bst_curr_org)) + + bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ + "num_trees" + ] = str(tree_num_prev + paral_tree_num_curr) + iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][ + "iteration_indptr" + ] + bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append( + iteration_indptr[-1] + paral_tree_num_curr + ) - bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ - "num_trees" - ] = str(tree_num_prev + paral_tree_num_curr) - iteration_indptr = bst_prev["learner"]["gradient_booster"]["model"][ - "iteration_indptr" - ] - bst_prev["learner"]["gradient_booster"]["model"]["iteration_indptr"].append( - iteration_indptr[-1] + paral_tree_num_curr + # Aggregate new trees + trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"] + for tree_count in range(paral_tree_num_curr): + trees_curr[tree_count]["id"] = tree_num_prev + tree_count + bst_prev["learner"]["gradient_booster"]["model"]["trees"].append( + trees_curr[tree_count] ) + bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0) - # Aggregate new trees - trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"] - for tree_count in range(paral_tree_num_curr): - trees_curr[tree_count]["id"] = tree_num_prev + tree_count - bst_prev["learner"]["gradient_booster"]["model"]["trees"].append( - trees_curr[tree_count] - ) - bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0) - - bst_prev = bytes(json.dumps(bst_prev), "utf-8") + bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8") - return bst_prev + return bst_prev_bytes def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]: From 62d113e9937674432ec004c3d85ed5fdd9da8756 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 16 Nov 2023 20:52:07 +0000 Subject: [PATCH 10/10] fix cyclic import --- src/py/flwr/server/strategy/fedxgb_bagging.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index 578328ca701b..cafb466c2e8b 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -19,13 +19,14 @@ from logging import WARNING from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast -import flwr as fl from flwr.common import EvaluateRes, FitRes, Parameters, Scalar from flwr.common.logger import log from flwr.server.client_proxy import ClientProxy +from .fedavg import FedAvg -class FedXgbBagging(fl.server.strategy.FedAvg): + +class FedXgbBagging(FedAvg): """Configurable FedXgbBagging strategy implementation.""" # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long