Skip to content

Commit

Permalink
set global model attribute to np array for consistency
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <[email protected]>
  • Loading branch information
kta-intel committed Nov 18, 2024
1 parent 6aa9838 commit ac2a925
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions openfl/federated/task/runner_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def get_tensor_dict(self, with_opt_vars=False):
booster_float32_array = np.frombuffer(booster_array, dtype=np.uint8).astype(np.float32)
return {"local_tree": booster_float32_array}

global_model_booster_dict = json.loads(self.global_model)
global_model_byte_array = bytearray(self.global_model.astype(np.uint8).tobytes())
global_model_booster_dict = json.loads(global_model_byte_array)
num_global_trees = int(
global_model_booster_dict["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
"num_trees"
Expand Down Expand Up @@ -313,9 +314,10 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars=False):
with_opt_vars (bool): N/A for XGBoost (Default=False).
"""
# The with_opt_vars argument is not used in this method
self.global_model = bytearray(tensor_dict["local_tree"].astype(np.uint8).tobytes())
self.global_model = tensor_dict["local_tree"]
global_model_byte_array = bytearray(self.global_model.astype(np.uint8).tobytes())
self.bst = xgb.Booster()
self.bst.load_model(self.global_model)
self.bst.load_model(global_model_byte_array)

def save_native(
self,
Expand Down

0 comments on commit ac2a925

Please sign in to comment.