Skip to content

Commit

Permalink
add docstring and more descriptive comments
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <[email protected]>
  • Loading branch information
kta-intel committed Nov 18, 2024
1 parent 63be874 commit b346b24
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions openfl/interface/aggregation_functions/fed_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
"""Federated Boostrap Aggregation for XGBoost module."""

import json
from logging import getLogger
import numpy as np

from openfl.interface.aggregation_functions.core import AggregationFunction

def get_global_model(iterator, target_round):
Expand Down Expand Up @@ -59,8 +57,13 @@ def append_trees(global_model, local_trees):


class FedBaggingXGBoost(AggregationFunction):
"""Federated Boostrap Aggregation for XGBoost."""
"""
Federated Bootstrap Aggregation for XGBoost.
This class implements a federated learning aggregation function specifically
designed for XGBoost models. It aggregates local model updates (trees) from
multiple collaborators into a global model using a bagging approach.
"""
def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
"""Aggregate tensors.
Expand Down Expand Up @@ -94,12 +97,12 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
Returns:
bytearray: aggregated tensor
"""
logger = getLogger(__name__)
global_model = get_global_model(db_iterator, fl_round)

if (
isinstance(global_model, np.ndarray) and global_model.size == 0
) or global_model is None:
# if there is no global model, use the first local model as the global model
for local_tensor in local_tensors:
local_tree_bytearray = bytearray(local_tensor.tensor.astype(np.uint8).tobytes())
local_tree_json = json.loads(local_tree_bytearray)
Expand All @@ -116,9 +119,11 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
global_model = append_trees(global_model, local_trees)
else:
global_model_bytearray = bytearray(global_model.astype(np.uint8).tobytes())
# convert the global model to a dictionary
global_model = json.loads(global_model_bytearray)

for local_tensor in local_tensors:
# append trees to global model
local_tree_bytearray = bytearray(local_tensor.tensor.astype(np.uint8).tobytes())
local_trees = json.loads(local_tree_bytearray)
global_model = append_trees(global_model, local_trees)
Expand Down

0 comments on commit b346b24

Please sign in to comment.