From b346b24b4c757840213a3b01fe1c898c60171cfd Mon Sep 17 00:00:00 2001 From: kta-intel Date: Mon, 18 Nov 2024 13:28:40 -0800 Subject: [PATCH] add docstring and more descriptive comments Signed-off-by: kta-intel --- .../interface/aggregation_functions/fed_bagging.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/openfl/interface/aggregation_functions/fed_bagging.py b/openfl/interface/aggregation_functions/fed_bagging.py index 4c05fd0829..a5eed3dd06 100644 --- a/openfl/interface/aggregation_functions/fed_bagging.py +++ b/openfl/interface/aggregation_functions/fed_bagging.py @@ -5,9 +5,7 @@ """Federated Boostrap Aggregation for XGBoost module.""" import json -from logging import getLogger import numpy as np - from openfl.interface.aggregation_functions.core import AggregationFunction def get_global_model(iterator, target_round): @@ -59,8 +57,13 @@ def append_trees(global_model, local_trees): class FedBaggingXGBoost(AggregationFunction): - """Federated Boostrap Aggregation for XGBoost.""" + """ + Federated Bootstrap Aggregation for XGBoost. + This class implements a federated learning aggregation function specifically + designed for XGBoost models. It aggregates local model updates (trees) from + multiple collaborators into a global model using a bagging approach. + """ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_): """Aggregate tensors. @@ -94,12 +97,12 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_): Returns: bytearray: aggregated tensor """ - logger = getLogger(__name__) global_model = get_global_model(db_iterator, fl_round) if ( isinstance(global_model, np.ndarray) and global_model.size == 0 ) or global_model is None: + # if there is no global model, use the first local model as the global model for local_tensor in local_tensors: local_tree_bytearray = bytearray(local_tensor.tensor.astype(np.uint8).tobytes()) local_tree_json = json.loads(local_tree_bytearray) @@ -116,9 +119,11 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_): global_model = append_trees(global_model, local_trees) else: global_model_bytearray = bytearray(global_model.astype(np.uint8).tobytes()) + # convert the global model to a dictionary global_model = json.loads(global_model_bytearray) for local_tensor in local_tensors: + # append trees to global model local_tree_bytearray = bytearray(local_tensor.tensor.astype(np.uint8).tobytes()) local_trees = json.loads(local_tree_bytearray) global_model = append_trees(global_model, local_trees)