Skip to content

Commit

Permalink
add conversion check
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <[email protected]>
  • Loading branch information
kta-intel committed Nov 18, 2024
1 parent 4c03932 commit 6aa9838
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
23 changes: 23 additions & 0 deletions openfl/federated/task/runner_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,27 @@
from openfl.utilities.split import split_tensor_dict_for_holdouts


def check_precision_loss(logger, converted_data, original_data):
    """Check for precision loss after a JSON -> bytes -> float32 round trip.

    The model payload is serialized to JSON, encoded to UTF-8 bytes, and
    stored as a float32 array. This function reverses that conversion and
    compares the reconstructed value with the original, logging a warning
    if they differ.

    Args:
        logger (Logger): Logger used to report a detected mismatch.
        converted_data (np.ndarray): float32 array holding the UTF-8 bytes
            of the JSON-serialized data (one byte value per element).
        original_data: The original JSON-serializable data (e.g. a list)
            to compare against.

    Raises:
        AssertionError: If the reconstructed value's type differs from the
            original's.
    """
    # Reverse the conversion: float32 -> uint8 bytes -> UTF-8 text -> JSON.
    reconstructed_bytes = converted_data.astype(np.uint8).tobytes()
    reconstructed_json = reconstructed_bytes.decode("utf-8")
    reconstructed_data = json.loads(reconstructed_json)

    # Type identity check uses `is`; `==` on types is non-idiomatic.
    assert type(original_data) is type(
        reconstructed_data
    ), "Reconstructed datatype does not match original."

    # Any difference means the float32 round trip corrupted the payload.
    if original_data != reconstructed_data:
        # `logger.warn` is a deprecated alias; use `logger.warning`.
        logger.warning("Precision loss detected during conversion.")


class XGBoostTaskRunner(TaskRunner):
def __init__(self, **kwargs):
"""
Expand Down Expand Up @@ -209,6 +230,8 @@ def get_tensor_dict(self, with_opt_vars=False):
np.float32
)

check_precision_loss(self.logger, latest_trees_float32_array, original_data=latest_trees)

return {"local_tree": latest_trees_float32_array}

def get_required_tensorkeys_for_function(self, func_name, **kwargs):
Expand Down
11 changes: 7 additions & 4 deletions openfl/interface/aggregation_functions/fed_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
"""Federated Boostrap Aggregation for XGBoost module."""

import json

from logging import getLogger
import numpy as np

from openfl.interface.aggregation_functions.core import AggregationFunction

from openfl.federated.task.runner_xgb import check_precision_loss

def get_global_model(iterator, target_round):
"""
Expand Down Expand Up @@ -95,7 +95,7 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
Returns:
bytearray: aggregated tensor
"""

logger = getLogger(__name__)
global_model = get_global_model(db_iterator, fl_round)

if (
Expand Down Expand Up @@ -127,4 +127,7 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
global_model_json = json.dumps(global_model)
global_model_bytes = global_model_json.encode("utf-8")

return np.frombuffer(global_model_bytes, dtype=np.uint8).astype(np.float32)
global_model_float32_array = np.frombuffer(global_model_bytes, dtype=np.uint8).astype(np.float32)
check_precision_loss(logger, global_model_float32_array, global_model)

return global_model_float32_array

0 comments on commit 6aa9838

Please sign in to comment.