From ce4b34fe1e11c3a9add6db3ce09cbae8fdc15e10 Mon Sep 17 00:00:00 2001 From: kta-intel Date: Fri, 15 Nov 2024 10:54:27 -0800 Subject: [PATCH] fix model save Signed-off-by: kta-intel --- openfl/federated/task/runner_xgb.py | 97 +++++++++++-------- .../aggregation_functions/fed_bagging.py | 29 ++++-- 2 files changed, 81 insertions(+), 45 deletions(-) diff --git a/openfl/federated/task/runner_xgb.py b/openfl/federated/task/runner_xgb.py index cff8fcd662..bcdbe1ec16 100644 --- a/openfl/federated/task/runner_xgb.py +++ b/openfl/federated/task/runner_xgb.py @@ -27,19 +27,28 @@ def __init__(self, **kwargs): **kwargs: Additional parameters to pass to the functions. """ super().__init__(**kwargs) - # This is a map of all the required tensors for each of the public - # functions in XGBoostTaskRunner - self.global_model = None # TODO - + self.global_model = None self.required_tensorkeys_for_function = {} self.training_round_completed = False def rebuild_model(self, input_tensor_dict): + """ + Rebuilds the model using the provided input tensor dictionary. + + This method checks if the 'local_tree' key in the input tensor dictionary is either a non-empty numpy array + or a non-None value. If this condition is met, it updates the internal tensor dictionary with the provided input. + + Parameters: + input_tensor_dict (dict): A dictionary containing tensor data. 
It must include the key 'local_tree', which can be: + - A non-empty numpy array + - Any non-None value + + Returns: + None + """ if (isinstance(input_tensor_dict['local_tree'], np.ndarray) and input_tensor_dict['local_tree'].size != 0) \ or (not isinstance(input_tensor_dict['local_tree'], np.ndarray) and input_tensor_dict['local_tree'] is not None): - self.global_model = bytearray(input_tensor_dict['local_tree'].astype(np.uint8).tobytes()) - self.bst = xgb.Booster() - self.bst.load_model(self.global_model) + self.set_tensor_dict(input_tensor_dict) def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs): """Validate Task. @@ -50,7 +59,6 @@ def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs): col_name (str): Name of the collaborator. round_num (int): What round is it. input_tensor_dict (dict): Required input tensors (for model). - use_tqdm (bool): Use tqdm to print a progress bar (Default=True). **kwargs: Additional parameters. Returns: @@ -102,8 +110,6 @@ def train_task( col_name (str): Name of the collaborator. round_num (int): What round is it. input_tensor_dict (dict): Required input tensors (for model). - use_tqdm (bool): Use tqdm to print a progress bar (Default=True). - epochs (int): The number of epochs to train. **kwargs: Additional parameters. Returns: @@ -177,6 +183,21 @@ def train_task( return global_tensor_dict, local_tensor_dict def get_tensor_dict(self, with_opt_vars=False): + """ + Retrieves the tensor dictionary containing the model's tree structure. + + This method returns a dictionary with the key 'local_tree', which contains the model's tree structure as a numpy array. + If the model has not been initialized (`self.bst` is None), it returns an empty numpy array. + If the global model is not set or is empty, it returns the entire model as a numpy array. + Otherwise, it returns only the trees added in the latest training session. + + Parameters: + with_opt_vars (bool): N/A for XGBoost (Default=False). 
+ + Returns: + dict: A dictionary with the key 'local_tree' containing the model's tree structure as a numpy array. + """ + if self.bst is None: # For initializing tensor dict return {'local_tree': np.array([], dtype=np.float32)} @@ -196,9 +217,7 @@ def get_tensor_dict(self, with_opt_vars=False): num_latest_trees = num_total_trees - num_global_trees latest_trees = booster_dict['learner']['gradient_booster']['model']['trees'][-num_latest_trees:] - # Convert latest_trees to a JSON string latest_trees_json = json.dumps(latest_trees) - # Convert JSON string to np.float32 array latest_trees_bytes = latest_trees_json.encode('utf-8') latest_trees_float32_array = np.frombuffer(latest_trees_bytes, dtype=np.uint8).astype(np.float32) @@ -289,33 +308,33 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): for tensor_name in local_model_dict_val ] - # def save_native( - # self, - # filepath, - # model_state_dict_key="model_state_dict", - # optimizer_state_dict_key="optimizer_state_dict", - # **kwargs, - # ): - # """Save model and optimizer states in a picked file specified by the - # filepath. model_/optimizer_state_dicts are stored in the keys provided. - # Uses pt.save(). - - # Args: - # filepath (str): Path to pickle file to be created by pt.save(). - # model_state_dict_key (str): key for model state dict in pickled - # file. - # optimizer_state_dict_key (str): key for optimizer state dict in - # picked file. - # **kwargs: Additional parameters. - - # Returns: - # None - # """ - # pickle_dict = { - # model_state_dict_key: self.state_dict(), - # optimizer_state_dict_key: self.optimizer.state_dict(), - # } - # torch.save(pickle_dict, filepath) + def set_tensor_dict(self, tensor_dict, with_opt_vars=False): + """Set the tensor dictionary. + + Args: + tensor_dict (dict): The tensor dictionary. + with_opt_vars (bool): N/A for XGBoost (Default=False). 
+ """ + # The with_opt_vars argument is not used in this method + self.global_model = bytearray(tensor_dict['local_tree'].astype(np.uint8).tobytes()) + self.bst = xgb.Booster() + self.bst.load_model(self.global_model) + + def save_native( + self, + filepath, + **kwargs, + ): + """Save XGB booster to file. + + Args: + filepath (str): Path to the model file to be created by booster.save_model(). + **kwargs: Additional parameters. + + Returns: + None + """ + self.bst.save_model(filepath) def train_(self, train_dataloader) -> Metric: """Train model.""" diff --git a/openfl/interface/aggregation_functions/fed_bagging.py b/openfl/interface/aggregation_functions/fed_bagging.py index 336a14dd75..d67c977fbd 100644 --- a/openfl/interface/aggregation_functions/fed_bagging.py +++ b/openfl/interface/aggregation_functions/fed_bagging.py @@ -5,10 +5,20 @@ """Federated Boostrap Aggregation for XGBoost module.""" import json -from openfl.interface.aggregation_functions.core import AggregationFunction import numpy as np +from openfl.interface.aggregation_functions.core import AggregationFunction def get_global_model(iterator, target_round): + """ + Retrieves the global model for the specific round from an iterator. + + Parameters: + iterator (iterable): An iterable containing items with 'tags' and 'round' keys. + target_round (int): The round number for which the global model is to be retrieved. + + Returns: + np.ndarray: The numpy array representing the global model for the specified round. + """ for item in iterator: # Items tagged with ('model',) are the global model of that round if 'tags' in item and item['tags'] == ('model',) and item['round'] == target_round: @@ -17,7 +27,16 @@ def get_global_model(iterator, target_round): def append_trees(global_model, local_trees): + """ + Appends local trees to the global model. + + Parameters: + global_model (dict): A dictionary representing the global model. 
+ local_trees (list): A list of dictionaries representing the local trees to be appended to the global model. + Returns: + dict: The updated global model with the local trees appended. + """ num_global_trees = int(global_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"]) num_local_trees = len(local_trees) @@ -71,7 +90,7 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_): Returns: bytearray: aggregated tensor """ - # global_model = None + global_model = get_global_model(db_iterator, fl_round) if (isinstance(global_model, np.ndarray) and global_model.size == 0) or global_model is None: @@ -80,10 +99,10 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_): local_tree_json = json.loads(local_tree_bytearray) if (isinstance(global_model, np.ndarray) and global_model.size == 0) or global_model is None: - # the first tree becomes the global model to append to + # the first tree becomes the global model global_model = local_tree_json else: - # append subsequent trees + # append subsequent trees to global model local_model = local_tree_json local_trees = local_model['learner']['gradient_booster']['model']['trees'] global_model = append_trees(global_model, local_trees) @@ -96,8 +115,6 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_): local_trees = json.loads(local_tree_bytearray) global_model = append_trees(global_model, local_trees) - ## Ensures that model is recoverable. TODO put in function - # Convert latest_trees to a JSON string global_model_json = json.dumps(global_model) global_model_bytes = global_model_json.encode('utf-8')