Skip to content

Commit

Permalink
remove need to convert to float64
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <[email protected]>
  • Loading branch information
kta-intel committed Nov 14, 2024
1 parent c7e2d76 commit 9d385a7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 93 deletions.
2 changes: 1 addition & 1 deletion openfl-workspace/xgb/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ aggregator :
init_state_path : save/init.pbuf
best_state_path : save/best.pbuf
last_state_path : save/last.pbuf
rounds_to_train : 10
rounds_to_train : 100
write_logs : false
delta_updates : false

Expand Down
40 changes: 7 additions & 33 deletions openfl/federated/task/runner_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,6 @@
import json
from sklearn.metrics import accuracy_score

import base64

def convert_back_to_json(booster_float32_array):
# Convert np.float32 array back to base64 string
booster_uint8_array = booster_float32_array.view(np.uint8)
booster_base64 = booster_uint8_array.tobytes().decode('utf-8')

# Decode base64 string back to original JSON string
booster_bytes = base64.b64decode(booster_base64)
booster_array = booster_bytes.decode('utf-8')
return booster_array

class XGBoostTaskRunner(TaskRunner):
def __init__(self, **kwargs):
Expand All @@ -40,23 +29,17 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
# This is a map of all the required tensors for each of the public
# functions in XGBoostTaskRunner
# self.bst = None # TODO
self.global_model = None # TODO
# self.params = kwargs['params'] # TODO
# self.num_rounds = kwargs['num_rounds'] # TODO

self.required_tensorkeys_for_function = {}
self.training_round_completed = False

def rebuild_model(self, input_tensor_dict):
if (isinstance(input_tensor_dict['local_tree'], np.ndarray) and input_tensor_dict['local_tree'].size != 0) \
or (not isinstance(input_tensor_dict['local_tree'], np.ndarray) and input_tensor_dict['local_tree'] is not None):
# if input_tensor_dict['local_tree'].size != 0: # check if it is empty (i.e. no model to build)
self.global_model = input_tensor_dict['local_tree'].view(np.uint8).tobytes().decode('utf-8')
self.global_model = base64.b64decode(self.global_model)
# self.global_model = bytearray(input_tensor_dict['local_tree']) #TODO
self.global_model = bytearray(input_tensor_dict['local_tree'].astype(np.uint8).tobytes())
self.bst = xgb.Booster()
self.bst.load_model(bytearray(self.global_model))
self.bst.load_model(self.global_model)

def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs):
"""Validate Task.
Expand Down Expand Up @@ -128,10 +111,6 @@ def train_task(
local_output_dict (dict): Tensors to maintain in the local
TensorDB.
"""
# self.rebuild_model(round_num, input_tensor_dict)
# set to "training" mode
# if round_num != 0:
# self.global_model = bytearray(input_tensor_dict)
self.rebuild_model(input_tensor_dict)
loader = self.data_loader.get_train_dmatrix()
metric = self.train_(loader)
Expand Down Expand Up @@ -202,31 +181,26 @@ def get_tensor_dict(self, with_opt_vars=False):
# For initializing tensor dict
return {'local_tree': np.array([], dtype=np.float32)}

booster_array = self.bst.save_raw('json').decode('utf-8')
booster_array = self.bst.save_raw('json')
booster_dict = json.loads(booster_array)

if (isinstance(self.global_model, np.ndarray) and self.global_model.size == 0) or self.global_model is None:
booster_bytes = booster_array.encode('utf-8')
booster_base64 = base64.b64encode(booster_bytes).decode('utf-8')

# Convert base64 string to np.float32 array
booster_float32_array = np.frombuffer(booster_base64.encode('utf-8'), dtype=np.uint8).view(np.float32)

booster_float32_array = np.frombuffer(booster_array, dtype=np.uint8).astype(np.float32)
return {'local_tree': booster_float32_array}

global_model_booster_dict = json.loads(self.global_model)
num_global_trees = int(global_model_booster_dict["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"])
num_total_trees = int(booster_dict["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"])

# Calculate the number of trees added in the latest training
num_latest_trees = num_total_trees - num_global_trees
latest_trees = booster_dict['learner']['gradient_booster']['model']['trees'][-num_latest_trees:]

# Convert latest_trees to a JSON string
latest_trees_json = json.dumps(latest_trees)

# Convert JSON string to np.float32 array
latest_trees_bytes = latest_trees_json.encode('utf-8')
latest_trees_base64 = base64.b64encode(latest_trees_bytes).decode('utf-8')
latest_trees_float32_array = np.frombuffer(latest_trees_base64.encode('utf-8'), dtype=np.uint8).view(np.float32)
latest_trees_float32_array = np.frombuffer(latest_trees_bytes, dtype=np.uint8).astype(np.float32)

return {'local_tree': latest_trees_float32_array}

Expand Down
66 changes: 7 additions & 59 deletions openfl/interface/aggregation_functions/fed_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

import json
from openfl.interface.aggregation_functions.core import AggregationFunction
from openfl.federated.task.runner_xgb import convert_back_to_json
import numpy as np
import base64

def get_global_model(iterator, target_round):
for item in iterator:
Expand All @@ -17,15 +15,6 @@ def get_global_model(iterator, target_round):
return item['nparray']
raise ValueError(f"No item found with tag 'model' and round {target_round}")

# def convert_back_to_json(booster_float32_array):
# # Convert np.float32 array back to base64 string
# booster_uint8_array = booster_float32_array.view(np.uint8)
# booster_base64 = booster_uint8_array.tobytes().decode('utf-8')

# # Decode base64 string back to original JSON string
# booster_bytes = base64.b64decode(booster_base64)
# booster_array = booster_bytes.decode('utf-8')
# return booster_array

def append_trees(global_model, local_trees):

Expand Down Expand Up @@ -87,7 +76,8 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):

if (isinstance(global_model, np.ndarray) and global_model.size == 0) or global_model is None:
for local_tensor in local_tensors:
local_tree_json = json.loads(convert_back_to_json(local_tensor.tensor))
local_tree_bytearray = bytearray(local_tensor.tensor.astype(np.uint8).tobytes())
local_tree_json = json.loads(local_tree_bytearray)

if (isinstance(global_model, np.ndarray) and global_model.size == 0) or global_model is None:
# the first tree becomes the global model to append to
Expand All @@ -98,59 +88,17 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
local_trees = local_model['learner']['gradient_booster']['model']['trees']
global_model = append_trees(global_model, local_trees)
else:
global_model = json.loads(convert_back_to_json(global_model))
global_model_bytearray = bytearray(global_model.astype(np.uint8).tobytes())
global_model = json.loads(global_model_bytearray)

for local_tensor in local_tensors:
local_trees = json.loads(convert_back_to_json(local_tensor.tensor))
local_tree_bytearray = bytearray(local_tensor.tensor.astype(np.uint8).tobytes())
local_trees = json.loads(local_tree_bytearray)
global_model = append_trees(global_model, local_trees)

## Ensures that model is recoverable. TODO put in function
# Convert latest_trees to a JSON string
global_model_json = json.dumps(global_model)

# Convert JSON string to np.float32 array
global_model_bytes = global_model_json.encode('utf-8')
global_model_base64 = base64.b64encode(global_model_bytes).decode('utf-8')
global_model_float32_array = np.frombuffer(global_model_base64.encode('utf-8'), dtype=np.uint8).view(np.float32)

return global_model_float32_array

# # global_model = None
# import pdb; pdb.set_trace()
# global_model = get_global_model(db_iterator, fl_round)

# for local_tensor in local_tensors:
# local_tree_np_array = local_tensor.tensor[:-2]
# # local_tree_np_array = local_tensor.tensor['local_tree']
# local_tree_json = convert_back_to_json(local_tree_np_array)

# if global_model.size == 0:
# # the first tree becomes the global model to append to
# global_model = local_tree_json
# else:
# # append subsequent trees
# local_model = local_tree_json
# # Assertion to check if the original trees in the local model match the global model trees
# num_global_trees = int(local_tensor.tensor[-2])
# # num_global_trees = local_tensor.tensor['num_global_trees']
# verify_global_model(global_model, local_model, num_global_trees)

# num_global_trees = int(global_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"])
# num_latest_trees = int(local_tensor.tensor[-1])
# # num_latest_trees = local_tensor.tensor['num_latest_trees']
# local_trees = local_model['learner']['gradient_booster']['model']['trees'][-num_latest_trees:]

# global_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"] = str(
# num_global_trees + num_latest_trees
# )
# global_model["learner"]["gradient_booster"]["model"]["iteration_indptr"].append(
# num_global_trees + num_latest_trees
# )

# for new_tree in range(num_latest_trees):
# local_trees[new_tree]["id"] = num_global_trees + new_tree
# global_model["learner"]["gradient_booster"]["model"]["trees"].append(local_trees[new_tree])
# global_model["learner"]["gradient_booster"]["model"]["tree_info"].append(0)

# # TODO: this will probably be problematic, make sure that the conversion is working
# return bytearray(json.dumps(global_model, default=int), "utf-8")
return np.frombuffer(global_model_bytes, dtype=np.uint8).astype(np.float32)

0 comments on commit 9d385a7

Please sign in to comment.