From 05e75deab256546486bc1ce8aba09dbd1d3e94d7 Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Thu, 16 Nov 2023 14:09:39 +0000 Subject: [PATCH] Correct type hints --- examples/xgboost-comprehensive/server.py | 2 +- src/py/flwr/server/strategy/fedxgb_bagging.py | 25 +++++++++++-------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/examples/xgboost-comprehensive/server.py b/examples/xgboost-comprehensive/server.py index 4e54859b186f..e4c597ee17eb 100644 --- a/examples/xgboost-comprehensive/server.py +++ b/examples/xgboost-comprehensive/server.py @@ -6,8 +6,8 @@ from flwr.common.logger import log from flwr.common import Parameters, Scalar from flwr_datasets import FederatedDataset - from flwr.server.strategy import FedXgbBagging + from utils import server_args_parser from dataset import resplit, transform_dataset_to_dmatrix diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index a1336a2d3082..c44a0d22ebaa 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -59,15 +59,11 @@ def aggregate_fit( # Aggregate all the client trees for _, fit_res in results: update = fit_res.parameters.tensors - for item in update: - self.global_model = aggregate( - self.global_model, json.loads(bytearray(item)) - ) - - weights_avg = json.dumps(self.global_model) + for bst in update: + self.global_model = aggregate(self.global_model, bst) return ( - Parameters(tensor_type="", tensors=[bytes(weights_avg, "utf-8")]), + Parameters(tensor_type="", tensors=[self.global_model]), {}, ) @@ -109,9 +105,9 @@ def evaluate( def aggregate( - bst_prev: Optional[Dict[str, Union[int, slice]]], - bst_curr: Dict[str, Union[int, slice]], -) -> Dict[str, Union[int, slice]]: + bst_prev: Optional[bytes], + bst_curr: bytes, +) -> bytes: """Conduct bagging aggregation for given trees.""" if not bst_prev: return bst_curr @@ -120,6 +116,9 @@ def aggregate( tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev) tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr) + bst_prev = json.loads(bytearray(bst_prev)) + bst_curr = json.loads(bytearray(bst_curr)) + bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ "num_trees" ] = str(tree_num_prev + paral_tree_num_curr) @@ -138,10 +137,14 @@ def aggregate( trees_curr[tree_count] ) bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0) + + bst_prev = bytes(json.dumps(bst_prev), "utf-8") + return bst_prev -def _get_tree_nums(xgb_model: Dict[str, Union[int, slice]]) -> Tuple[int, int]: +def _get_tree_nums(xgb_model: bytes) -> Tuple[int, int]: + xgb_model = json.loads(bytearray(xgb_model)) # Get the number of trees tree_num = int( xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][