diff --git a/openfl-workspace/workspace/plan/defaults/aggregator.yaml b/openfl-workspace/workspace/plan/defaults/aggregator.yaml
index 8b32cc986d..0bb76e099d 100644
--- a/openfl-workspace/workspace/plan/defaults/aggregator.yaml
+++ b/openfl-workspace/workspace/plan/defaults/aggregator.yaml
@@ -1,4 +1,4 @@
 template : openfl.component.Aggregator
 settings :
-    db_store_rounds : 2
-    write_logs : true
+  db_store_rounds : 2
+  write_logs : true
diff --git a/openfl-workspace/workspace/plan/defaults/tasks_xgb.yaml b/openfl-workspace/workspace/plan/defaults/tasks_xgb.yaml
new file mode 100644
index 0000000000..7b14010eaa
--- /dev/null
+++ b/openfl-workspace/workspace/plan/defaults/tasks_xgb.yaml
@@ -0,0 +1,21 @@
+aggregated_model_validation:
+  function : validate_task
+  kwargs   :
+    apply   : global
+    metrics :
+      - acc
+
+locally_tuned_model_validation:
+  function : validate_task
+  kwargs   :
+    apply   : local
+    metrics :
+      - acc
+
+train:
+  function : train_task
+  kwargs   :
+    metrics :
+      - loss
+  aggregation_type :
+    template : openfl.interface.aggregation_functions.FedBaggingXGBoost
\ No newline at end of file
diff --git a/openfl-workspace/xgb_higgs/plan/cols.yaml b/openfl-workspace/xgb_higgs/plan/cols.yaml
new file mode 100644
index 0000000000..b085067f50
--- /dev/null
+++ b/openfl-workspace/xgb_higgs/plan/cols.yaml
@@ -0,0 +1,5 @@
+# Copyright (C) 2024 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
+
+# DO NOT EDIT: This file lists the collaborators associated with the federation. The list will be auto-populated during collaborator creation.
+collaborators:
diff --git a/openfl-workspace/xgb_higgs/plan/data.yaml b/openfl-workspace/xgb_higgs/plan/data.yaml
new file mode 100644
index 0000000000..4b9d070127
--- /dev/null
+++ b/openfl-workspace/xgb_higgs/plan/data.yaml
@@ -0,0 +1,5 @@
+# Copyright (C) 2024 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
+
+# DO NOT EDIT: This file specifies the local data directory associated with the respective collaborator. This will be auto-populated during collaborator creation
+# collaborator_name,data_directory_path
\ No newline at end of file
diff --git a/openfl-workspace/xgb_higgs/plan/plan.yaml b/openfl-workspace/xgb_higgs/plan/plan.yaml
new file mode 100644
index 0000000000..cab8710cb4
--- /dev/null
+++ b/openfl-workspace/xgb_higgs/plan/plan.yaml
@@ -0,0 +1,51 @@
+# Copyright (C) 2024 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
+
+aggregator :
+  defaults : plan/defaults/aggregator.yaml
+  template : openfl.component.aggregator.Aggregator
+  settings :
+    init_state_path : save/init.pbuf
+    best_state_path : save/best.pbuf
+    last_state_path : save/last.pbuf
+    rounds_to_train : 10
+    write_logs : false
+    use_delta_updates : false
+
+collaborator :
+  defaults : plan/defaults/collaborator.yaml
+  template : openfl.component.collaborator.Collaborator
+  settings :
+    delta_updates : false
+    opt_treatment : RESET
+
+data_loader :
+  defaults : plan/defaults/data_loader.yaml
+  template : src.dataloader.HiggsDataLoader
+  settings :
+    input_shape : 28
+
+task_runner :
+  defaults : plan/defaults/task_runner.yaml
+  template : src.taskrunner.XGBoostRunner
+  settings :
+    params :
+      objective: binary:logistic
+      eval_metric: logloss
+      max_depth: 6
+      eta: 0.3
+      num_parallel_tree: 1
+
+network :
+  defaults : plan/defaults/network.yaml
+  settings :
+    {}
+
+assigner :
+  defaults : plan/defaults/assigner.yaml
+
+tasks :
+  defaults : plan/defaults/tasks_xgb.yaml
+
+compression_pipeline :
+  defaults : plan/defaults/compression_pipeline.yaml
\ No newline at end of file
diff --git a/openfl-workspace/xgb_higgs/requirements.txt b/openfl-workspace/xgb_higgs/requirements.txt
new file mode 100644
index 0000000000..26a78d72ac
--- /dev/null
+++ b/openfl-workspace/xgb_higgs/requirements.txt
@@ -0,0 +1,3 @@
+modin[all]
+scikit-learn
+xgboost
diff --git a/openfl-workspace/xgb_higgs/src/__init__.py b/openfl-workspace/xgb_higgs/src/__init__.py
new file mode 100644
index 0000000000..916f3a44b2
--- /dev/null
+++ b/openfl-workspace/xgb_higgs/src/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
diff --git a/openfl-workspace/xgb_higgs/src/dataloader.py b/openfl-workspace/xgb_higgs/src/dataloader.py
new file mode 100644
index 0000000000..550ddcc47d
--- /dev/null
+++ b/openfl-workspace/xgb_higgs/src/dataloader.py
@@ -0,0 +1,60 @@
+# Copyright (C) 2024 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between
+# Intel Corporation and you.
+
+from openfl.federated import XGBoostDataLoader
+import os
+import modin.pandas as pd
+
+
+class HiggsDataLoader(XGBoostDataLoader):
+    """
+    DataLoader for the Higgs dataset.
+
+    This class inherits from XGBoostDataLoader and is responsible for loading
+    the Higgs dataset for training and validation.
+
+    Attributes:
+        X_train (numpy.ndarray): Training features.
+        y_train (numpy.ndarray): Training labels.
+        X_valid (numpy.ndarray): Validation features.
+        y_valid (numpy.ndarray): Validation labels.
+    """
+
+    def __init__(self, data_path, **kwargs):
+        super().__init__(**kwargs)
+        X_train, y_train, X_valid, y_valid = load_Higgs(
+            data_path, **kwargs
+        )
+        self.X_train = X_train
+        self.y_train = y_train
+        self.X_valid = X_valid
+        self.y_valid = y_valid
+
+
+def load_Higgs(data_path, **kwargs):
+    """
+    Load the Higgs dataset from CSV files.
+
+    The dataset is expected to be in two CSV files: 'train.csv' and 'valid.csv'.
+    The first column in each file represents the labels, and the remaining
+    columns represent the features.
+
+    Args:
+        data_path (str): The directory path where the CSV files are located.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        tuple: A tuple containing four elements:
+            - X_train (numpy.ndarray): Training features.
+            - y_train (numpy.ndarray): Training labels.
+            - X_valid (numpy.ndarray): Validation features.
+            - y_valid (numpy.ndarray): Validation labels.
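+
+    Example:
+        Assuming the per-collaborator splits created by src/setup_data.py:
+
+        X_train, y_train, X_valid, y_valid = load_Higgs('data/1')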
+ """ + train_data = pd.read_csv(os.path.join(data_path, 'train.csv'), header=None) + X_train = train_data.iloc[:, 1:].values + y_train = train_data.iloc[:, 0].values + + valid_data = pd.read_csv(os.path.join(data_path, 'valid.csv'), header=None) + X_valid = valid_data.iloc[:, 1:].values + y_valid = valid_data.iloc[:, 0].values + + return X_train, y_train, X_valid, y_valid diff --git a/openfl-workspace/xgb_higgs/src/setup_data.py b/openfl-workspace/xgb_higgs/src/setup_data.py new file mode 100644 index 0000000000..116d00cf66 --- /dev/null +++ b/openfl-workspace/xgb_higgs/src/setup_data.py @@ -0,0 +1,95 @@ +import sys +import os +import shutil +from logging import getLogger +from urllib.request import urlretrieve +from hashlib import sha384 +from os import path, makedirs +from tqdm import tqdm +import modin.pandas as pd +import gzip +from sklearn.model_selection import train_test_split +import numpy as np + +logger = getLogger(__name__) + +"""HIGGS Dataset.""" + +URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz" +FILENAME = "HIGGS.csv.gz" +CSV_FILENAME = "HIGGS.csv" +CSV_SHA384 = 'b8b82e11a78b81601381420878ad42ba557291f394a88dc5293e4077c8363c87429639b120e299a2a9939c1f943b6a63' +DEFAULT_PATH = path.join(os.getcwd(), 'data') + +pbar = tqdm(total=None) + +def report_hook(count, block_size, total_size): + """Update progressbar.""" + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + +def verify_sha384(file_path, expected_hash): + """Verify the SHA-384 hash of a file.""" + sha384_hash = sha384() + with open(file_path, 'rb') as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha384_hash.update(byte_block) + computed_hash = sha384_hash.hexdigest() + if computed_hash != expected_hash: + raise ValueError(f"SHA-384 hash mismatch: expected {expected_hash}, got {computed_hash}") + print(f"SHA-384 hash verified: {computed_hash}") + +def setup_data(root: str = DEFAULT_PATH, **kwargs): + """Initialize.""" + makedirs(root, exist_ok=True) + filepath = path.join(root, FILENAME) + csv_filepath = path.join(root, CSV_FILENAME) + if not path.exists(filepath): + urlretrieve(URL, filepath, report_hook) # nosec + verify_sha384(filepath, CSV_SHA384) + # Extract the CSV file from the gzip file + with gzip.open(filepath, 'rb') as f_in: + with open(csv_filepath, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + +def main(): + if len(sys.argv) < 2: + raise ValueError("Provide the number of collaborators") + src = 'higgs_data' + if os.path.exists(src): + shutil.rmtree(src) + setup_data(src) + collaborators = int(sys.argv[1]) + print("Creating splits for {} collaborators".format(collaborators)) + + # Load the dataset + higgs_data = pd.read_csv(path.join(src, CSV_FILENAME), header=None) + + # Split the dataset into features and labels + X = higgs_data.iloc[:, 1:].values + y = higgs_data.iloc[:, 0].values + + # Split the dataset into training and testing sets + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + # Combine X and y for train and test sets + train_data = pd.DataFrame(data=np.column_stack((y_train, X_train))) + test_data = pd.DataFrame(data=np.column_stack((y_test, X_test))) + + # Split the training data into parts for each collaborator + for i in range(collaborators): + dst = f'data/{i+1}' + makedirs(dst, exist_ok=True) + + # Split the training data for the current collaborator + split_train_data = train_data.iloc[i::collaborators] 
+        split_train_data.to_csv(path.join(dst, 'train.csv'), index=False, header=False)
+
+        # Split the test data for the current collaborator
+        split_test_data = test_data.iloc[i::collaborators]
+        split_test_data.to_csv(path.join(dst, 'valid.csv'), index=False, header=False)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/openfl-workspace/xgb_higgs/src/taskrunner.py b/openfl-workspace/xgb_higgs/src/taskrunner.py
new file mode 100644
index 0000000000..3e030fb7ac
--- /dev/null
+++ b/openfl-workspace/xgb_higgs/src/taskrunner.py
@@ -0,0 +1,77 @@
+# Copyright (C) 2024 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between
+# Intel Corporation and you.
+
+"""You may copy this file as the starting point of your own model."""
+import numpy as np
+import xgboost as xgb
+
+from openfl.federated import XGBoostTaskRunner
+from openfl.utilities import Metric
+from sklearn.metrics import accuracy_score
+
+
+class XGBoostRunner(XGBoostTaskRunner):
+    """
+    A class to run XGBoost training and validation tasks.
+
+    This class inherits from XGBoostTaskRunner and provides methods to train and validate
+    an XGBoost model using federated learning.
+
+    Attributes:
+        bst (xgb.Booster): The XGBoost model.
+        params (dict): Parameters for the XGBoost model.
+        num_rounds (int): Number of boosting rounds.
+    """
+
+    def __init__(self, params=None, num_rounds=1, **kwargs):
+        """
+        Initialize the XGBoostRunner.
+
+        Args:
+            params (dict, optional): Parameters for the XGBoost model. Defaults to None.
+            num_rounds (int, optional): Number of boosting rounds. Defaults to 1.
+            **kwargs: Additional arguments to pass to the function.
+        """
+        super().__init__(**kwargs)
+
+        self.bst = None
+        self.params = params
+        self.num_rounds = num_rounds
+
+    def train_(self, data) -> Metric:
+        """
+        Train the XGBoost model.
+
+        Args:
+            data (dict): A dictionary containing the training data under the key 'dmatrix'.
+
+        Returns:
+            Metric: A Metric object containing the training loss.
+        """
+        dtrain = data['dmatrix']
+        evals = [(dtrain, 'train')]
+        evals_result = {}
+
+        self.bst = xgb.train(self.params, dtrain, self.num_rounds, xgb_model=self.bst,
+                             evals=evals, evals_result=evals_result, verbose_eval=False)
+
+        loss = evals_result['train']['logloss'][-1]
+        return Metric(name=self.params['eval_metric'], value=np.array(loss))
+
+    def validate_(self, data) -> Metric:
+        """
+        Validate the XGBoost model.
+
+        Args:
+            data (dict): A dictionary containing the validation data under the keys
+                'dmatrix' and 'labels'.
+
+        Returns:
+            Metric: A Metric object containing the validation accuracy.
+        """
+        dtest = data['dmatrix']
+        y_test = data['labels']
+        preds = self.bst.predict(dtest)
+        y_pred_binary = np.where(preds > 0.5, 1, 0)
+        acc = accuracy_score(y_test, y_pred_binary)
+
+        return Metric(name="accuracy", value=np.array(acc))
diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py
index 0ec816276b..e1c61f8ce3 100644
--- a/openfl/component/aggregator/aggregator.py
+++ b/openfl/component/aggregator/aggregator.py
@@ -69,6 +69,7 @@ def __init__(
         best_state_path,
         last_state_path,
         assigner,
+        use_delta_updates=True,
         straggler_handling_policy=None,
         rounds_to_train=256,
         single_col_cert_common_name=None,
@@ -186,6 +187,8 @@ def __init__(
         # Initialize a lock for thread safety
         self.lock = Lock()
 
+        self.use_delta_updates = use_delta_updates
+
     def _load_initial_tensors(self):
         """Load all of the tensors required to begin federated learning.
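Note: a minimal sketch (names taken from the hunks around this point; the else branch is assumed from the existing aggregator code) of what `use_delta_updates=False` changes — delta generation is skipped and the full tensor is passed through, which the XGBoost workspace requires because its "weights" are byte-encoded JSON trees rather than subtractable numeric arrays:

    if base_model_nparray is not None and self.use_delta_updates:
        # Delta path: transmit (new - old); compresses well for numeric weights.
        delta_tk, delta_nparray = self.tensor_codec.generate_delta(
            agg_tag_tk, agg_results, base_model_nparray
        )
    else:
        # Full-tensor path: byte-encoded trees cannot be meaningfully subtracted.
        delta_tk, delta_nparray = agg_tag_tk, agg_results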
@@ -801,7 +804,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
         # Create delta and save it in TensorDB
         base_model_tk = TensorKey(tensor_name, origin, round_number, report, ("model",))
         base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tk)
-        if base_model_nparray is not None:
+        if base_model_nparray is not None and self.use_delta_updates:
             delta_tk, delta_nparray = self.tensor_codec.generate_delta(
                 agg_tag_tk, agg_results, base_model_nparray
             )
@@ -830,7 +833,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
         self.tensor_db.cache_tensor({decompressed_delta_tk: decompressed_delta_nparray})
 
         # Apply delta (unless delta couldn't be created)
-        if base_model_nparray is not None:
+        if base_model_nparray is not None and self.use_delta_updates:
             self.logger.debug("Applying delta for layer %s", decompressed_delta_tk[0])
             new_model_tk, new_model_nparray = self.tensor_codec.apply_delta(
                 decompressed_delta_tk,
diff --git a/openfl/federated/__init__.py b/openfl/federated/__init__.py
index 07c0ef8e6e..ea24e0ddfa 100644
--- a/openfl/federated/__init__.py
+++ b/openfl/federated/__init__.py
@@ -20,6 +20,11 @@
     from openfl.federated.data import PyTorchDataLoader
     from openfl.federated.task import FederatedModel  # NOQA
     from openfl.federated.task import PyTorchTaskRunner
+if importlib.util.find_spec("xgboost") is not None:
+    from openfl.federated.data import FederatedDataSet  # NOQA
+    from openfl.federated.data import XGBoostDataLoader
+    from openfl.federated.task import FederatedModel  # NOQA
+    from openfl.federated.task import XGBoostTaskRunner
 
 __all__ = [
     "Plan",
diff --git a/openfl/federated/data/__init__.py b/openfl/federated/data/__init__.py
index b61d6d24a3..91bb604b62 100644
--- a/openfl/federated/data/__init__.py
+++ b/openfl/federated/data/__init__.py
@@ -23,3 +23,7 @@
 if importlib.util.find_spec("torch") is not None:
     from openfl.federated.data.federated_data import FederatedDataSet  # NOQA
     from openfl.federated.data.loader_pt import PyTorchDataLoader  # NOQA
+
+if importlib.util.find_spec("xgboost") is not None:
+    from openfl.federated.data.federated_data import FederatedDataSet  # NOQA
+    from openfl.federated.data.loader_xgb import XGBoostDataLoader  # NOQA
diff --git a/openfl/federated/data/loader_xgb.py b/openfl/federated/data/loader_xgb.py
new file mode 100644
index 0000000000..46087392b4
--- /dev/null
+++ b/openfl/federated/data/loader_xgb.py
@@ -0,0 +1,88 @@
+import xgboost as xgb
+
+from openfl.federated.data.loader import DataLoader
+
+
+class XGBoostDataLoader(DataLoader):
+    """A class used to represent a Data Loader for XGBoost models.
+
+    Attributes:
+        batch_size (int): Size of batches used for all data loaders.
+        X_train (np.array): Training features.
+        y_train (np.array): Training labels.
+        X_valid (np.array): Validation features.
+        y_valid (np.array): Validation labels.
+        random_seed (int, optional): Random seed for data shuffling.
+    """
+
+    def __init__(self, batch_size=None, random_seed=None, **kwargs):
+        """Initializes the XGBoostDataLoader object with the batch size, random
+        seed, and any additional arguments.
+
+        Args:
+            batch_size (int): The size of batches used for all data loaders.
+            random_seed (int, optional): Random seed for data shuffling.
+            kwargs: Additional arguments to pass to the function.
+ """ + self.batch_size = batch_size + self.X_train = None + self.y_train = None + self.X_valid = None + self.y_valid = None + self.random_seed = random_seed + + # Child classes should have init signature: + # (self, batch_size, **kwargs), should call this __init__ and then + # define self.X_train, self.y_train, self.X_valid, and self.y_valid + + def get_feature_shape(self): + """Returns the shape of an example feature array. + + Returns: + tuple: The shape of an example feature array. + """ + return self.X_train[0].shape + + def get_train_data_size(self): + """Returns the total number of training samples. + + Returns: + int: The total number of training samples. + """ + return self.X_train.shape[0] + + def get_valid_data_size(self): + """Returns the total number of validation samples. + + Returns: + int: The total number of validation samples. + """ + return self.X_valid.shape[0] + + def get_dmatrix(self, X, y): + """Returns the DMatrix for the given data. + + Args: + X (np.array): The input data. + y (np.array): The label data. + + Returns: + xgb.DMatrix: The DMatrix object for the given data. + """ + return xgb.DMatrix(data=X, label=y) + + def get_train_dmatrix(self): + """Returns the DMatrix for the training data. + + Returns: + xgb.DMatrix: The DMatrix object for the training data. + """ + return {"dmatrix": self.get_dmatrix(self.X_train, self.y_train), "labels": self.y_train} + + def get_valid_dmatrix(self): + """Returns the DMatrix for the validation data. + + Returns: + xgb.DMatrix: The DMatrix object for the validation data. + """ + return {"dmatrix": self.get_dmatrix(self.X_valid, self.y_valid), "labels": self.y_valid} diff --git a/openfl/federated/task/__init__.py b/openfl/federated/task/__init__.py index cc5bb9429b..8b29264128 100644 --- a/openfl/federated/task/__init__.py +++ b/openfl/federated/task/__init__.py @@ -22,3 +22,6 @@ if importlib.util.find_spec("torch") is not None: from openfl.federated.task.fl_model import FederatedModel # NOQA from openfl.federated.task.runner_pt import PyTorchTaskRunner # NOQA +if importlib.util.find_spec("xgboost") is not None: + from openfl.federated.task.fl_model import FederatedModel # NOQA + from openfl.federated.task.runner_xgb import XGBoostTaskRunner # NOQA diff --git a/openfl/federated/task/runner_xgb.py b/openfl/federated/task/runner_xgb.py new file mode 100644 index 0000000000..222b8c613f --- /dev/null +++ b/openfl/federated/task/runner_xgb.py @@ -0,0 +1,391 @@ +# Copyright 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +"""XGBoostTaskRunner module.""" + +import json + +import numpy as np +import xgboost as xgb +from sklearn.metrics import accuracy_score + +from openfl.federated.task.runner import TaskRunner +from openfl.utilities import Metric, TensorKey, change_tags +from openfl.utilities.split import split_tensor_dict_for_holdouts + + +def check_precision_loss(logger, converted_data, original_data): + """ + Checks for precision loss during conversion to float32 and back. + + Parameters: + logger (Logger): The logger object to log warnings. + converted_data (np.ndarray): The data that has been converted to float32. + original_data (list): The original data to be checked for precision loss. 
+ """ + # Convert the float32 array back to bytes and decode to JSON + reconstructed_bytes = converted_data.astype(np.uint8).tobytes() + reconstructed_json = reconstructed_bytes.decode("utf-8") + reconstructed_data = json.loads(reconstructed_json) + + assert type(original_data) is type( + reconstructed_data + ), "Reconstructed datatype does not match original." + + # Compare the original and reconstructed data + if original_data != reconstructed_data: + logger.warn("Precision loss detected during conversion.") + + +class XGBoostTaskRunner(TaskRunner): + def __init__(self, **kwargs): + """ + A class to manage XGBoost tasks in a federated learning environment. + + This class inherits from TaskRunner and provides methods to initialize and manage + the global model and required tensor keys for XGBoost tasks. + + Attributes: + global_model (xgb.Booster): The global XGBoost model. + required_tensorkeys_for_function (dict): A dictionary to store required tensor keys for each function. + """ + super().__init__(**kwargs) + self.global_model = None + self.required_tensorkeys_for_function = {} + + def rebuild_model(self, input_tensor_dict): + """ + Rebuilds the model using the provided input tensor dictionary. + + This method checks if the 'local_tree' key in the input tensor dictionary is either a non-empty numpy array + If this condition is met, it updates the internal tensor dictionary with the provided input. + + Parameters: + input_tensor_dict (dict): A dictionary containing tensor data. It must include the key 'local_tree' + + Returns: + None + """ + if ( + isinstance(input_tensor_dict["local_tree"], np.ndarray) + and input_tensor_dict["local_tree"].size != 0 + ): + self.set_tensor_dict(input_tensor_dict) + + def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs): + """Validate Task. + + Run validation of the model on the local data. + + Args: + col_name (str): Name of the collaborator. + round_num (int): What round is it. + input_tensor_dict (dict): Required input tensors (for model). + **kwargs: Additional parameters. + + Returns: + global_output_dict (dict): Tensors to send back to the aggregator. + local_output_dict (dict): Tensors to maintain in the local TensorDB. + """ + data = self.data_loader.get_valid_dmatrix() + + # during agg validation, self.bst will still be None. during local validation, it will have a value - no need to rebuild + if self.bst is None: + self.rebuild_model(input_tensor_dict) + + # if self.bst is still None after rebuilding, then there was no initial global model, so set metric to 0 + if self.bst is None: + # for first round agg validation, there is no model so set metric to 0 + # TODO: this is not robust, especially if using a loss metric + metric = Metric(name="accuracy", value=np.array(0)) + else: + metric = self.validate_(data) + + origin = col_name + suffix = "validate" + if kwargs["apply"] == "local": + suffix += "_local" + else: + suffix += "_agg" + tags = ("metric",) + tags = change_tags(tags, add_field=suffix) + + # validate function + output_tensor_dict = {TensorKey(metric.name, origin, round_num, True, tags): metric.value} + + # Empty list represents metrics that should only be stored locally + return output_tensor_dict, {} + + def train_task( + self, + col_name, + round_num, + input_tensor_dict, + **kwargs, + ): + """Train batches task. + + Train the model on the requested number of batches. + + Args: + col_name (str): Name of the collaborator. + round_num (int): What round is it. 
+            input_tensor_dict (dict): Required input tensors (for model).
+            **kwargs: Additional parameters.
+
+        Returns:
+            global_output_dict (dict): Tensors to send back to the aggregator.
+            local_output_dict (dict): Tensors to maintain in the local
+                TensorDB.
+        """
+        self.rebuild_model(input_tensor_dict)
+        data = self.data_loader.get_train_dmatrix()
+        metric = self.train_(data)
+        # Output metric tensors (scalar)
+        origin = col_name
+        tags = ("trained",)
+        output_metric_dict = {
+            TensorKey(metric.name, origin, round_num, True, ("metric",)): metric.value
+        }
+
+        # output model tensors (Doesn't include TensorKey)
+        output_model_dict = self.get_tensor_dict()
+        global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(
+            self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs
+        )
+
+        # Create global tensorkeys
+        global_tensorkey_model_dict = {
+            TensorKey(tensor_name, origin, round_num, False, tags): nparray
+            for tensor_name, nparray in global_model_dict.items()
+        }
+        # Create tensorkeys that should stay local
+        local_tensorkey_model_dict = {
+            TensorKey(tensor_name, origin, round_num, False, tags): nparray
+            for tensor_name, nparray in local_model_dict.items()
+        }
+        # The train/validate aggregated function of the next round will look
+        # for the updated model parameters.
+        # This ensures they will be resolved locally
+        next_local_tensorkey_model_dict = {
+            TensorKey(tensor_name, origin, round_num + 1, False, ("model",)): nparray
+            for tensor_name, nparray in local_model_dict.items()
+        }
+
+        global_tensor_dict = {
+            **output_metric_dict,
+            **global_tensorkey_model_dict,
+        }
+        local_tensor_dict = {
+            **local_tensorkey_model_dict,
+            **next_local_tensorkey_model_dict,
+        }
+
+        return global_tensor_dict, local_tensor_dict
+
+    def get_tensor_dict(self, with_opt_vars=False):
+        """
+        Retrieves the tensor dictionary containing the model's tree structure.
+
+        This method returns a dictionary with the key 'local_tree', which contains
+        the model's tree structure as a numpy array. If the model has not been
+        initialized (`self.bst` is None), it returns an empty numpy array. If the
+        global model is not set or is empty, it returns the entire model as a numpy
+        array. Otherwise, it returns only the trees added in the latest training
+        session.
+
+        Parameters:
+            with_opt_vars (bool): N/A for XGBoost (Default=False).
+
+        Returns:
+            dict: A dictionary with the key 'local_tree' containing the model's
+                tree structure as a numpy array.
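+
+        Example:
+            {'local_tree': np.array([123., 34., 108., ...], dtype=np.float32)}
+            where the values are the UTF-8 byte values of the JSON tree dump
+            (123, 34, 108 are '{', '"', 'l').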
+ """ + + if self.bst is None: + # For initializing tensor dict + return {"local_tree": np.array([], dtype=np.float32)} + + booster_array = self.bst.save_raw("json") + booster_dict = json.loads(booster_array) + + if ( + isinstance(self.global_model, np.ndarray) and self.global_model.size == 0 + ) or self.global_model is None: + booster_float32_array = np.frombuffer(booster_array, dtype=np.uint8).astype(np.float32) + return {"local_tree": booster_float32_array} + + global_model_byte_array = bytearray(self.global_model.astype(np.uint8).tobytes()) + global_model_booster_dict = json.loads(global_model_byte_array) + num_global_trees = int( + global_model_booster_dict["learner"]["gradient_booster"]["model"]["gbtree_model_param"][ + "num_trees" + ] + ) + num_total_trees = int( + booster_dict["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"] + ) + + # Calculate the number of trees added in the latest training + num_latest_trees = num_total_trees - num_global_trees + latest_trees = booster_dict["learner"]["gradient_booster"]["model"]["trees"][ + -num_latest_trees: + ] + + latest_trees_json = json.dumps(latest_trees) + latest_trees_bytes = latest_trees_json.encode("utf-8") + latest_trees_float32_array = np.frombuffer(latest_trees_bytes, dtype=np.uint8).astype( + np.float32 + ) + + check_precision_loss(self.logger, latest_trees_float32_array, original_data=latest_trees) + + return {"local_tree": latest_trees_float32_array} + + def get_required_tensorkeys_for_function(self, func_name, **kwargs): + """Get the required tensors for specified function that could be called + as part of a task. By default, this is just all of the layers and + optimizer of the model. + + Args: + func_name (str): The function name. + + Returns: + list : [TensorKey]. + """ + if func_name == "validate_task": + local_model = "apply=" + str(kwargs["apply"]) + return self.required_tensorkeys_for_function[func_name][local_model] + else: + return self.required_tensorkeys_for_function[func_name] + + def initialize_tensorkeys_for_functions(self, with_opt_vars=False): + """Set the required tensors for all publicly accessible task methods. + + By default, this is just all of the layers and optimizer of the model. + Custom tensors should be added to this function. + + Args: + with_opt_vars (bool): with_opt_vars (bool): N/A for XGBoost (Default=False). + + Returns: + None + """ + output_model_dict = self.get_tensor_dict() + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( + self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs + ) + global_model_dict_val = global_model_dict + local_model_dict_val = local_model_dict + + self.required_tensorkeys_for_function["train_task"] = [ + TensorKey(tensor_name, "GLOBAL", 0, False, ("model",)) + for tensor_name in global_model_dict + ] + self.required_tensorkeys_for_function["train_task"] += [ + TensorKey(tensor_name, "LOCAL", 0, False, ("model",)) + for tensor_name in local_model_dict + ] + + self.required_tensorkeys_for_function["train_task"] = [ + TensorKey(tensor_name, "GLOBAL", 0, False, ("model",)) + for tensor_name in global_model_dict + ] + self.required_tensorkeys_for_function["train_task"] += [ + TensorKey(tensor_name, "LOCAL", 0, False, ("model",)) + for tensor_name in local_model_dict + ] + + # Validation may be performed on local or aggregated (global) model, + # so there is an extra lookup dimension for kwargs + self.required_tensorkeys_for_function["validate_task"] = {} + # TODO This is not stateless. 
+        self.required_tensorkeys_for_function["validate_task"]["apply=local"] = [
+            TensorKey(tensor_name, "LOCAL", 0, False, ("trained",))
+            for tensor_name in {**global_model_dict_val, **local_model_dict_val}
+        ]
+        self.required_tensorkeys_for_function["validate_task"]["apply=global"] = [
+            TensorKey(tensor_name, "GLOBAL", 0, False, ("model",))
+            for tensor_name in global_model_dict_val
+        ]
+        self.required_tensorkeys_for_function["validate_task"]["apply=global"] += [
+            TensorKey(tensor_name, "LOCAL", 0, False, ("model",))
+            for tensor_name in local_model_dict_val
+        ]
+
+    def set_tensor_dict(self, tensor_dict, with_opt_vars=False):
+        """Set the tensor dictionary.
+
+        Args:
+            tensor_dict (dict): The tensor dictionary.
+            with_opt_vars (bool): N/A for XGBoost (Default=False).
+        """
+        # The with_opt_vars argument is not used in this method
+        self.global_model = tensor_dict["local_tree"]
+        if (
+            isinstance(self.global_model, np.ndarray) and self.global_model.size == 0
+        ) or self.global_model is None:
+            raise ValueError("The model does not exist or is empty.")
+        else:
+            global_model_byte_array = bytearray(self.global_model.astype(np.uint8).tobytes())
+            self.bst = xgb.Booster()
+            self.bst.load_model(global_model_byte_array)
+
+    def save_native(
+        self,
+        filepath,
+        **kwargs,
+    ):
+        """Save the XGB booster to a file.
+
+        Args:
+            filepath (str): Path to the model file to be created by booster.save_model().
+            **kwargs: Additional parameters.
+
+        Returns:
+            None
+        """
+        self.bst.save_model(filepath)
+
+    def train_(self, data) -> Metric:
+        """
+        Train the XGBoost model.
+
+        Args:
+            data (dict): A dictionary containing the training data under the key 'dmatrix'.
+
+        Returns:
+            Metric: A Metric object containing the training loss.
+        """
+        dtrain = data["dmatrix"]
+        evals = [(dtrain, "train")]
+        evals_result = {}
+
+        self.bst = xgb.train(
+            self.params,
+            dtrain,
+            self.num_rounds,
+            xgb_model=self.bst,
+            evals=evals,
+            evals_result=evals_result,
+            verbose_eval=False,
+        )
+
+        loss = evals_result["train"]["logloss"][-1]
+        return Metric(name=self.params["eval_metric"], value=np.array(loss))
+
+    def validate_(self, data) -> Metric:
+        """
+        Validate the XGBoost model.
+
+        Args:
+            data (dict): A dictionary containing the validation data under the keys
+                'dmatrix' and 'labels'.
+
+        Returns:
+            Metric: A Metric object containing the validation accuracy.
+ """ + dtest = data["dmatrix"] + y_test = data["labels"] + preds = self.bst.predict(dtest) + y_pred_binary = np.where(preds > 0.5, 1, 0) + acc = accuracy_score(y_test, y_pred_binary) + + return Metric(name="accuracy", value=np.array(acc)) diff --git a/openfl/interface/aggregation_functions/__init__.py b/openfl/interface/aggregation_functions/__init__.py index 39132eb9f6..0ee32655c6 100644 --- a/openfl/interface/aggregation_functions/__init__.py +++ b/openfl/interface/aggregation_functions/__init__.py @@ -7,6 +7,7 @@ ) from openfl.interface.aggregation_functions.adam_adaptive_aggregation import AdamAdaptiveAggregation from openfl.interface.aggregation_functions.core import AggregationFunction +from openfl.interface.aggregation_functions.fed_bagging import FedBaggingXGBoost from openfl.interface.aggregation_functions.fedcurv_weighted_average import FedCurvWeightedAverage from openfl.interface.aggregation_functions.geometric_median import GeometricMedian from openfl.interface.aggregation_functions.median import Median diff --git a/openfl/interface/aggregation_functions/fed_bagging.py b/openfl/interface/aggregation_functions/fed_bagging.py new file mode 100644 index 0000000000..2e42072c66 --- /dev/null +++ b/openfl/interface/aggregation_functions/fed_bagging.py @@ -0,0 +1,142 @@ +# Copyright 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +"""Federated Boostrap Aggregation for XGBoost module.""" + +import json + +import numpy as np + +from openfl.interface.aggregation_functions.core import AggregationFunction + + +def get_global_model(iterator, target_round): + """ + Retrieves the global model for the specific round from an iterator. + + Parameters: + iterator (iterable): An iterable containing items with 'tags' and 'round' keys. + target_round (int): The round number for which the global model is to be retrieved. + + Returns: + np.ndarray: The numpy array representing the global model for the specified round. + """ + for item in iterator: + # Items tagged with ('model',) are the global model of that round + if "tags" in item and item["tags"] == ("model",) and item["round"] == target_round: + return item["nparray"] + raise ValueError(f"No item found with tag 'model' and round {target_round}") + + +def append_trees(global_model, local_trees): + """ + Appends local trees to the global model. + + Parameters: + global_model (dict): A dictionary representing the global model. + local_trees (list): A list of dictionaries representing the local trees to be appended to the global model. + + Returns: + dict: The updated global model with the local trees appended. + """ + num_global_trees = int( + global_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"] + ) + num_local_trees = len(local_trees) + + global_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"] = str( + num_global_trees + num_local_trees + ) + global_model["learner"]["gradient_booster"]["model"]["iteration_indptr"].append( + num_global_trees + num_local_trees + ) + for new_tree in range(num_local_trees): + local_trees[new_tree]["id"] = num_global_trees + new_tree + global_model["learner"]["gradient_booster"]["model"]["trees"].append(local_trees[new_tree]) + global_model["learner"]["gradient_booster"]["model"]["tree_info"].append(0) + + return global_model + + +class FedBaggingXGBoost(AggregationFunction): + """ + Federated Bootstrap Aggregation for XGBoost. 
+
+    This class implements a federated learning aggregation function specifically
+    designed for XGBoost models. It aggregates local model updates (trees) from
+    multiple collaborators into a global model using a bagging approach.
+    """
+
+    def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
+        """Aggregate tensors.
+
+        Args:
+            local_tensors (list[openfl.utilities.LocalTensor]): List of local
+                tensors to aggregate.
+            db_iterator: iterator over history of all tensors. Columns:
+                - 'tensor_name': name of the tensor.
+                  Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'.
+                - 'round': 0-based number of round corresponding to this
+                  tensor.
+                - 'tags': tuple of tensor tags. Tags that can appear:
+                    - 'model' indicates that the tensor is a model parameter.
+                    - 'trained' indicates that the tensor is part of a training
+                      result. These tensors are passed to the aggregator node
+                      after local learning.
+                    - 'aggregated' indicates that the tensor is a result of
+                      aggregation. These tensors are sent to collaborators for
+                      the next round.
+                    - 'delta' indicates that the value is a difference between
+                      rounds for a specific tensor.
+                  Also, one of the tags is a collaborator name if the tensor
+                  corresponds to a result of a local task.
+                - 'nparray': value of the tensor.
+            tensor_name (str): name of the tensor.
+            fl_round (int): round number.
+            *_: additional positional arguments (e.g. tags), ignored.
+
+        Returns:
+            np.ndarray: aggregated tensor (float32-encoded model bytes).
+        """
+        global_model = get_global_model(db_iterator, fl_round)
+
+        if (
+            isinstance(global_model, np.ndarray) and global_model.size == 0
+        ) or global_model is None:
+            # If there is no global model yet, build one from the local models
+            for local_tensor in local_tensors:
+                local_tree_bytearray = bytearray(local_tensor.tensor.astype(np.uint8).tobytes())
+                local_tree_json = json.loads(local_tree_bytearray)
+
+                if (
+                    isinstance(global_model, np.ndarray) and global_model.size == 0
+                ) or global_model is None:
+                    # the first local model becomes the global model
+                    global_model = local_tree_json
+                else:
+                    # append subsequent trees to the global model
+                    local_model = local_tree_json
+                    local_trees = local_model["learner"]["gradient_booster"]["model"]["trees"]
+                    global_model = append_trees(global_model, local_trees)
+        else:
+            global_model_bytearray = bytearray(global_model.astype(np.uint8).tobytes())
+            # convert the global model to a dictionary
+            global_model = json.loads(global_model_bytearray)
+
+            for local_tensor in local_tensors:
+                # append local trees to the global model
+                local_tree_bytearray = bytearray(local_tensor.tensor.astype(np.uint8).tobytes())
+                local_trees = json.loads(local_tree_bytearray)
+                global_model = append_trees(global_model, local_trees)
+
+        global_model_json = json.dumps(global_model)
+        global_model_bytes = global_model_json.encode("utf-8")
+
+        global_model_float32_array = np.frombuffer(global_model_bytes, dtype=np.uint8).astype(
+            np.float32
+        )
+
+        return global_model_float32_array
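For reference, a minimal illustrative sketch (not part of the patch) of the byte-level convention shared by XGBoostTaskRunner and FedBaggingXGBoost: model JSON travels through the numpy-based TensorDB as a float32 array holding one UTF-8 byte value per element, a lossless round trip because every byte value (0-255) is exactly representable in float32.

    import json
    import numpy as np

    # Encode: JSON -> UTF-8 bytes -> float32 array (one element per byte),
    # as in XGBoostTaskRunner.get_tensor_dict
    model = {"learner": {"gradient_booster": {"model": {"trees": []}}}}
    payload = np.frombuffer(
        json.dumps(model).encode("utf-8"), dtype=np.uint8
    ).astype(np.float32)

    # Decode: float32 array -> bytes -> JSON, as in set_tensor_dict and
    # FedBaggingXGBoost.call
    restored = json.loads(bytearray(payload.astype(np.uint8).tobytes()))
    assert restored == model  # exact round trip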