From 6aa983867a1b8ea6889a8febe8265a45f0383651 Mon Sep 17 00:00:00 2001
From: kta-intel
Date: Mon, 18 Nov 2024 11:44:37 -0800
Subject: [PATCH] add conversion check

Signed-off-by: kta-intel
---
 openfl/federated/task/runner_xgb.py                 | 23 +++++++++++++++++++
 .../aggregation_functions/fed_bagging.py            | 11 +++++----
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/openfl/federated/task/runner_xgb.py b/openfl/federated/task/runner_xgb.py
index 27ea9486ad..cda2563005 100644
--- a/openfl/federated/task/runner_xgb.py
+++ b/openfl/federated/task/runner_xgb.py
@@ -15,6 +15,27 @@
 from openfl.utilities.split import split_tensor_dict_for_holdouts
 
 
+def check_precision_loss(logger, converted_data, original_data):
+    """
+    Checks for precision loss during conversion to float32 and back.
+
+    Parameters:
+        logger (Logger): The logger object to log warnings.
+        converted_data (np.ndarray): The data that has been converted to float32.
+        original_data (list): The original data to be checked for precision loss.
+    """
+    # Convert the float32 array back to bytes and decode to JSON
+    reconstructed_bytes = converted_data.astype(np.uint8).tobytes()
+    reconstructed_json = reconstructed_bytes.decode("utf-8")
+    reconstructed_data = json.loads(reconstructed_json)
+
+    assert type(original_data) == type(reconstructed_data), "Reconstructed datatype does not match original."
+
+    # Compare the original and reconstructed data
+    if original_data != reconstructed_data:
+        logger.warning("Precision loss detected during conversion.")
+
+
 class XGBoostTaskRunner(TaskRunner):
     def __init__(self, **kwargs):
         """
@@ -209,6 +230,8 @@ def get_tensor_dict(self, with_opt_vars=False):
             np.float32
         )
 
+        check_precision_loss(self.logger, latest_trees_float32_array, original_data=latest_trees)
+
         return {"local_tree": latest_trees_float32_array}
 
     def get_required_tensorkeys_for_function(self, func_name, **kwargs):
diff --git a/openfl/interface/aggregation_functions/fed_bagging.py b/openfl/interface/aggregation_functions/fed_bagging.py
index aaec86fb24..081fa91a2e 100644
--- a/openfl/interface/aggregation_functions/fed_bagging.py
+++ b/openfl/interface/aggregation_functions/fed_bagging.py
@@ -5,11 +5,11 @@
 """Federated Boostrap Aggregation for XGBoost module."""
 
 import json
-
+from logging import getLogger
 import numpy as np
 
 from openfl.interface.aggregation_functions.core import AggregationFunction
-
+from openfl.federated.task.runner_xgb import check_precision_loss
 
 def get_global_model(iterator, target_round):
     """
@@ -95,7 +95,7 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
         Returns:
             bytearray: aggregated tensor
         """
-
+        logger = getLogger(__name__)
         global_model = get_global_model(db_iterator, fl_round)
 
         if (
@@ -127,4 +127,7 @@
         global_model_json = json.dumps(global_model)
         global_model_bytes = global_model_json.encode("utf-8")
 
-        return np.frombuffer(global_model_bytes, dtype=np.uint8).astype(np.float32)
+        global_model_float32_array = np.frombuffer(global_model_bytes, dtype=np.uint8).astype(np.float32)
+        check_precision_loss(logger, global_model_float32_array, global_model)
+
+        return global_model_float32_array
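
For reference, a minimal standalone sketch of the byte-level round trip that check_precision_loss validates, runnable outside OpenFL with only numpy and the standard library. The tree payload is a hypothetical stand-in for a serialized XGBoost tree dump, not taken from a real model:

import json
import logging

import numpy as np

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("precision_check")

# Hypothetical stand-in for a serialized XGBoost tree dump.
original_data = [{"nodeid": 0, "leaf": 0.5}]

# Forward conversion, as in get_tensor_dict():
# JSON -> UTF-8 bytes -> uint8 array -> float32 array.
data_bytes = json.dumps(original_data).encode("utf-8")
float32_array = np.frombuffer(data_bytes, dtype=np.uint8).astype(np.float32)

# Round trip, as in check_precision_loss():
# float32 -> uint8 -> bytes -> JSON.
reconstructed = json.loads(float32_array.astype(np.uint8).tobytes().decode("utf-8"))

# UTF-8 byte values lie in 0-255, and float32 represents all integers up to
# 2**24 exactly, so the round trip should be lossless and no warning fires.
assert reconstructed == original_data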