fix model save
Signed-off-by: kta-intel <[email protected]>
kta-intel committed Nov 15, 2024
1 parent 9d385a7 commit ce4b34f
Showing 2 changed files with 81 additions and 45 deletions.
97 changes: 58 additions & 39 deletions openfl/federated/task/runner_xgb.py
@@ -27,19 +27,28 @@ def __init__(self, **kwargs):
**kwargs: Additional parameters to pass to the functions.
"""
super().__init__(**kwargs)
# This is a map of all the required tensors for each of the public
# functions in XGBoostTaskRunner
self.global_model = None # TODO

self.global_model = None
self.required_tensorkeys_for_function = {}
self.training_round_completed = False

def rebuild_model(self, input_tensor_dict):
"""
Rebuilds the model using the provided input tensor dictionary.
This method checks whether the value under the 'local_tree' key of the input tensor dictionary is either a
non-empty numpy array or, if it is not an array, any non-None value. If that condition is met, it updates the
internal tensor dictionary (and the underlying booster) with the provided input.
Parameters:
input_tensor_dict (dict): A dictionary containing tensor data. It must include the key 'local_tree', whose value can be:
- A non-empty numpy array
- Any non-None, non-array value
Returns:
None
"""
if (isinstance(input_tensor_dict['local_tree'], np.ndarray) and input_tensor_dict['local_tree'].size != 0) \
or (not isinstance(input_tensor_dict['local_tree'], np.ndarray) and input_tensor_dict['local_tree'] is not None):
self.global_model = bytearray(input_tensor_dict['local_tree'].astype(np.uint8).tobytes())
self.bst = xgb.Booster()
self.bst.load_model(self.global_model)
self.set_tensor_dict(input_tensor_dict)

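For orientation, here is a minimal round-trip sketch of the 'local_tree' encoding that rebuild_model (and set_tensor_dict further down) consume. The helper name, the runner/booster variables, and the use of Booster.save_raw are illustrative assumptions, not part of this diff:

import numpy as np
import xgboost as xgb

def booster_to_local_tree(bst: xgb.Booster) -> np.ndarray:
    # Serialize the full model JSON, then widen the uint8 bytes to float32,
    # which is the format shipped as the 'local_tree' tensor.
    raw = bst.save_raw(raw_format="json")  # bytearray of model JSON
    return np.frombuffer(raw, dtype=np.uint8).astype(np.float32)

# some_trained_booster and runner are placeholders for illustration.
local_tree = booster_to_local_tree(some_trained_booster)
runner.rebuild_model({"local_tree": local_tree})
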
def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs):
"""Validate Task.
@@ -50,7 +59,6 @@ def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs):
col_name (str): Name of the collaborator.
round_num (int): What round is it.
input_tensor_dict (dict): Required input tensors (for model).
use_tqdm (bool): Use tqdm to print a progress bar (Default=True).
**kwargs: Additional parameters.
Returns:
@@ -102,8 +110,6 @@ def train_task(
col_name (str): Name of the collaborator.
round_num (int): What round is it.
input_tensor_dict (dict): Required input tensors (for model).
use_tqdm (bool): Use tqdm to print a progress bar (Default=True).
epochs (int): The number of epochs to train.
**kwargs: Additional parameters.
Returns:
@@ -177,6 +183,21 @@ def train_task(
return global_tensor_dict, local_tensor_dict

def get_tensor_dict(self, with_opt_vars=False):
"""
Retrieves the tensor dictionary containing the model's tree structure.
This method returns a dictionary with the key 'local_tree', which contains the model's tree structure as a numpy array.
If the model has not been initialized (`self.bst` is None), the 'local_tree' entry is an empty numpy array.
If the global model is not set or is empty, the entire model is serialized into that array.
Otherwise, only the trees added in the latest training session are serialized.
Parameters:
with_opt_vars (bool): N/A for XGBoost (Default=False).
Returns:
dict: A dictionary with the key 'local_tree' containing the model's tree structure as a numpy array.
"""

if self.bst is None:
# For initializing tensor dict
return {'local_tree': np.array([], dtype=np.float32)}
@@ -196,9 +217,7 @@ def get_tensor_dict(self, with_opt_vars=False):
num_latest_trees = num_total_trees - num_global_trees
latest_trees = booster_dict['learner']['gradient_booster']['model']['trees'][-num_latest_trees:]

# Convert latest_trees to a JSON string
latest_trees_json = json.dumps(latest_trees)
# Convert JSON string to np.float32 array
latest_trees_bytes = latest_trees_json.encode('utf-8')
latest_trees_float32_array = np.frombuffer(latest_trees_bytes, dtype=np.uint8).astype(np.float32)

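On the receiving side (see fed_bagging.py below), the inverse transformation is simply the reverse of these three lines; a sketch under the same encoding assumptions:

# Decode a 'local_tree' payload back into a list of tree dicts:
# float32 array -> uint8 bytes -> UTF-8 JSON -> Python objects.
latest_trees_decoded = json.loads(
    latest_trees_float32_array.astype(np.uint8).tobytes().decode("utf-8")
)
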
@@ -289,33 +308,33 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False):
for tensor_name in local_model_dict_val
]

# def save_native(
# self,
# filepath,
# model_state_dict_key="model_state_dict",
# optimizer_state_dict_key="optimizer_state_dict",
# **kwargs,
# ):
# """Save model and optimizer states in a picked file specified by the
# filepath. model_/optimizer_state_dicts are stored in the keys provided.
# Uses pt.save().

# Args:
# filepath (str): Path to pickle file to be created by pt.save().
# model_state_dict_key (str): key for model state dict in pickled
# file.
# optimizer_state_dict_key (str): key for optimizer state dict in
# picked file.
# **kwargs: Additional parameters.

# Returns:
# None
# """
# pickle_dict = {
# model_state_dict_key: self.state_dict(),
# optimizer_state_dict_key: self.optimizer.state_dict(),
# }
# torch.save(pickle_dict, filepath)
def set_tensor_dict(self, tensor_dict, with_opt_vars=False):
"""Set the tensor dictionary.
Args:
tensor_dict (dict): The tensor dictionary.
with_opt_vars (bool): N/A for XGBoost (Default=False).
"""
# The with_opt_vars argument is not used in this method
self.global_model = bytearray(tensor_dict['local_tree'].astype(np.uint8).tobytes())
self.bst = xgb.Booster()
self.bst.load_model(self.global_model)

def save_native(
self,
filepath,
**kwargs,
):
"""Save XGB booster to file.
Args:
filepath (str): Path to the model file to be created by booster.save_model().
**kwargs: Additional parameters.
Returns:
None
"""
self.bst.save_model(filepath)

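Typical usage of the new save_native, assuming a task runner whose booster has been trained (the filename is illustrative); booster.save_model infers the output format from the file extension:

runner.save_native("final_model.json")   # writes the booster via save_model()
restored = xgb.Booster()
restored.load_model("final_model.json")  # model is usable outside OpenFL
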
def train_(self, train_dataloader) -> Metric:
"""Train model."""
29 changes: 23 additions & 6 deletions openfl/interface/aggregation_functions/fed_bagging.py
@@ -5,10 +5,20 @@
"""Federated Bootstrap Aggregation for XGBoost module."""

import json
from openfl.interface.aggregation_functions.core import AggregationFunction
import numpy as np
from openfl.interface.aggregation_functions.core import AggregationFunction

def get_global_model(iterator, target_round):
"""
Retrieves the global model for the specified round from an iterator.
Parameters:
iterator (iterable): An iterable containing items with 'tags' and 'round' keys.
target_round (int): The round number for which the global model is to be retrieved.
Returns:
np.ndarray: The numpy array representing the global model for the specified round.
"""
for item in iterator:
# Items tagged with ('model',) are the global model of that round
if 'tags' in item and item['tags'] == ('model',) and item['round'] == target_round:
@@ -17,7 +27,16 @@ def get_global_model(iterator, target_round):


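The hunk above is truncated before get_global_model returns; a minimal sketch of the complete helper, assuming the serialized model is stored under an 'nparray' key (that key name is not visible in this diff):

def get_global_model(iterator, target_round):
    for item in iterator:
        # Items tagged with ('model',) are the global model of that round
        if "tags" in item and item["tags"] == ("model",) and item["round"] == target_round:
            return item["nparray"]  # assumed key holding the serialized model
    return None  # no global model recorded for this round yet
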
def append_trees(global_model, local_trees):
"""
Appends local trees to the global model.
Parameters:
global_model (dict): A dictionary representing the global model.
local_trees (list): A list of dictionaries representing the local trees to be appended to the global model.
Returns:
dict: The updated global model with the local trees appended.
"""
num_global_trees = int(global_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"])
num_local_trees = len(local_trees)

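The remainder of append_trees is also elided; a plausible sketch, assuming the standard XGBoost JSON model layout (renumber tree ids, extend the trees and tree_info lists, bump num_trees) rather than the exact committed lines:

def append_trees(global_model, local_trees):
    gbtree = global_model["learner"]["gradient_booster"]["model"]
    num_global_trees = int(gbtree["gbtree_model_param"]["num_trees"])
    num_local_trees = len(local_trees)

    gbtree["gbtree_model_param"]["num_trees"] = str(num_global_trees + num_local_trees)
    for offset, tree in enumerate(local_trees):
        tree["id"] = num_global_trees + offset  # keep tree ids unique and ordered
        gbtree["trees"].append(tree)
        gbtree["tree_info"].append(0)           # single output group assumed
    return global_model
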
@@ -71,7 +90,7 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
Returns:
bytearray: aggregated tensor
"""
# global_model = None

global_model = get_global_model(db_iterator, fl_round)

if (isinstance(global_model, np.ndarray) and global_model.size == 0) or global_model is None:
@@ -80,10 +99,10 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
local_tree_json = json.loads(local_tree_bytearray)

if (isinstance(global_model, np.ndarray) and global_model.size == 0) or global_model is None:
# the first tree becomes the global model to append to
# the first tree becomes the global model
global_model = local_tree_json
else:
# append subsequent trees
# append subsequent trees to global model
local_model = local_tree_json
local_trees = local_model['learner']['gradient_booster']['model']['trees']
global_model = append_trees(global_model, local_trees)
@@ -96,8 +115,6 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
local_trees = json.loads(local_tree_bytearray)
global_model = append_trees(global_model, local_trees)

## Ensures that model is recoverable. TODO put in function
# Convert latest_trees to a JSON string
global_model_json = json.dumps(global_model)
global_model_bytes = global_model_json.encode('utf-8')

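The tail of call() is elided here; presumably it re-encodes the merged JSON into the same float32 wire format and returns it, along the lines of this sketch:

# inside call(), after global_model_bytes is built:
global_model_float32_array = np.frombuffer(
    global_model_bytes, dtype=np.uint8
).astype(np.float32)
return global_model_float32_array
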
