Skip to content

Commit

Permalink
Correct type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
yan-gao-GY committed Nov 16, 2023
1 parent dc5f11e commit 05e75de
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/xgboost-comprehensive/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from flwr.common.logger import log
from flwr.common import Parameters, Scalar
from flwr_datasets import FederatedDataset

from flwr.server.strategy import FedXgbBagging

from utils import server_args_parser
from dataset import resplit, transform_dataset_to_dmatrix

Expand Down
25 changes: 14 additions & 11 deletions src/py/flwr/server/strategy/fedxgb_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,11 @@ def aggregate_fit(
# Aggregate all the client trees
for _, fit_res in results:
update = fit_res.parameters.tensors
for item in update:
self.global_model = aggregate(
self.global_model, json.loads(bytearray(item))
)

weights_avg = json.dumps(self.global_model)
for bst in update:
self.global_model = aggregate(self.global_model, bst)

return (
Parameters(tensor_type="", tensors=[bytes(weights_avg, "utf-8")]),
Parameters(tensor_type="", tensors=[self.global_model]),
{},
)

Expand Down Expand Up @@ -109,9 +105,9 @@ def evaluate(


def aggregate(
bst_prev: Optional[Dict[str, Union[int, slice]]],
bst_curr: Dict[str, Union[int, slice]],
) -> Dict[str, Union[int, slice]]:
bst_prev: Optional[bytes],
bst_curr: bytes,
) -> bytes:
"""Conduct bagging aggregation for given trees."""
if not bst_prev:
return bst_curr
Expand All @@ -120,6 +116,9 @@ def aggregate(
tree_num_prev, paral_tree_num_prev = _get_tree_nums(bst_prev)
tree_num_curr, paral_tree_num_curr = _get_tree_nums(bst_curr)

bst_prev = json.loads(bytearray(bst_prev))
bst_curr = json.loads(bytearray(bst_curr))

bst_prev["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
"num_trees"
] = str(tree_num_prev + paral_tree_num_curr)
Expand All @@ -138,10 +137,14 @@ def aggregate(
trees_curr[tree_count]
)
bst_prev["learner"]["gradient_booster"]["model"]["tree_info"].append(0)

bst_prev = bytes(json.dumps(bst_prev), "utf-8")

return bst_prev


def _get_tree_nums(xgb_model: Dict[str, Union[int, slice]]) -> Tuple[int, int]:
def _get_tree_nums(xgb_model: bytes) -> Tuple[int, int]:
xgb_model = json.loads(bytearray(xgb_model))
# Get the number of trees
tree_num = int(
xgb_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
Expand Down

0 comments on commit 05e75de

Please sign in to comment.