
Enable federated XGBoost using bootstrap aggregation in Task Runner #1151

Merged (42 commits, Nov 20, 2024)

Changes shown are from 22 of the 42 commits.

Commits:
33d304f  initial xgboost workspace commit (kta-intel, Nov 8, 2024)
93dc8b4  updating taskrunner and aggregation function (kta-intel, Nov 8, 2024)
52fea84  runner updates (kta-intel, Nov 12, 2024)
1275fd6  logic for loader (kta-intel, Nov 12, 2024)
49f5cdf  enabling work (kta-intel, Nov 13, 2024)
ddece36  further enabling work (kta-intel, Nov 14, 2024)
c7e2d76  fix first round local validation (kta-intel, Nov 14, 2024)
9d385a7  remove need to convert to float64 (kta-intel, Nov 14, 2024)
ce4b34f  fix model save (kta-intel, Nov 15, 2024)
70e4171  remove set_trace and fix spacing (kta-intel, Nov 15, 2024)
3d2df78  rename workspace and fix plan (kta-intel, Nov 15, 2024)
54cdc5e  fix lint (kta-intel, Nov 15, 2024)
51a0afa  more formatting fixes (kta-intel, Nov 15, 2024)
d3937ef  revert space removal (kta-intel, Nov 15, 2024)
dd2027c  Revert "revert space removal" (kta-intel, Nov 15, 2024)
e008e4a  revert changes on interface.plan (kta-intel, Nov 15, 2024)
3cbd5e5  remove from history. unchanged (kta-intel, Nov 15, 2024)
051d8fc  reverting back to fresh state for interface.plan (kta-intel, Nov 15, 2024)
58172c1  Merge branch 'securefederatedai:develop' into xgboost-fedbagging (kta-intel, Nov 15, 2024)
a8d9b59  move delta_updates below assigner in args (kta-intel, Nov 15, 2024)
5f1d909  add delta_update default to True, remove from yaml (kta-intel, Nov 15, 2024)
3670bd0  enable modin pandas (kta-intel, Nov 16, 2024)
dcfdd70  add DO NOT EDIT notice (kta-intel, Nov 18, 2024)
bd03eac  added docstrings (kta-intel, Nov 18, 2024)
326069d  set DEFAULT_PATH to cwd (kta-intel, Nov 18, 2024)
8a75cc5  fix docstrings and remove commented out lines (kta-intel, Nov 18, 2024)
450d8c3  change to use_delta_updates for readibility (kta-intel, Nov 18, 2024)
eecffe0  split test data for collaborators (kta-intel, Nov 18, 2024)
238448f  clean up methods (kta-intel, Nov 18, 2024)
16cd7e1  clean up taskrunner (kta-intel, Nov 18, 2024)
4c03932  remove conditional for unused condition (kta-intel, Nov 18, 2024)
6aa9838  add conversion check (kta-intel, Nov 18, 2024)
ac2a925  set global model attribute to np array for consistency (kta-intel, Nov 18, 2024)
d65def1  raise value error when model is empty when trying to set tensor dict (kta-intel, Nov 18, 2024)
63be874  remove conversion checker to avoid circular import issue (kta-intel, Nov 18, 2024)
b346b24  add docstring and more descriptive comments (kta-intel, Nov 18, 2024)
809b69b  formatting fix (kta-intel, Nov 18, 2024)
cf67f62  fixing import sorting (kta-intel, Nov 18, 2024)
acb89d5  format fix (kta-intel, Nov 18, 2024)
5794b70  remove unnecessarly files (kta-intel, Nov 18, 2024)
34f7d8a  format fix, comparing datatype (kta-intel, Nov 18, 2024)
837031b  Merge branch 'securefederatedai:develop' into xgboost-fedbagging (kta-intel, Nov 19, 2024)
4 changes: 2 additions & 2 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
@@ -1,4 +1,4 @@
 template : openfl.component.Aggregator
 settings :
-  db_store_rounds : 2
-  write_logs : true
+  db_store_rounds : 2
+  write_logs : true

(The removed and re-added lines appear to differ only in whitespace.)
21 changes: 21 additions & 0 deletions openfl-workspace/workspace/plan/defaults/tasks_xgb.yaml
Collaborator:
General question: have you used specific formatters for the yaml files?

Collaborator (author):
I copied over the yaml from other workspaces as a template, then ran bash shell/format.sh on the whole repo. Is there something additional that you recommend?

@@ -0,0 +1,21 @@
aggregated_model_validation:
  function : validate_task
  kwargs :
    apply : global
    metrics :
      - acc

locally_tuned_model_validation:
  function : validate_task
  kwargs :
    apply : local
    metrics :
      - acc

train:
  function : train_task
  kwargs :
    metrics :
      - loss
  aggregation_type :
    template : openfl.interface.aggregation_functions.FedBaggingXGBoost
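
Aside: the aggregation_type override above is what swaps the default weighted averaging for bootstrap aggregation on the train task. As a rough conceptual sketch (not the actual FedBaggingXGBoost implementation, which operates on serialized boosters), bagging for gradient-boosted trees amounts to taking the union of the trees each collaborator boosted locally in a round:

# Conceptual sketch only; the function name and data shapes are illustrative.
def bagging_aggregate(global_trees, local_tree_lists):
    """Append every collaborator's newly boosted trees to the global ensemble."""
    for local_trees in local_tree_lists:
        global_trees.extend(local_trees)
    return global_trees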
1 change: 1 addition & 0 deletions openfl-workspace/xgb_higgs/.workspace
@@ -0,0 +1 @@
current_plan_name: default
5 changes: 5 additions & 0 deletions openfl-workspace/xgb_higgs/plan/cols.yaml
@@ -0,0 +1,5 @@
# Copyright (C) 2024 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

# This file lists the collaborators associated with the federation. The list will be auto-populated during collaborator creation.
collaborators:
5 changes: 5 additions & 0 deletions openfl-workspace/xgb_higgs/plan/data.yaml
@@ -0,0 +1,5 @@
# Copyright (C) 2024 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

# This file specifies the local data directory associated with the respective collaborator. This will be auto-populated during collaborator creation
# collaborator_name,data_directory_path
1 change: 1 addition & 0 deletions openfl-workspace/xgb_higgs/plan/defaults
@@ -0,0 +1 @@
../../workspace/plan/defaults
51 changes: 51 additions & 0 deletions openfl-workspace/xgb_higgs/plan/plan.yaml
@@ -0,0 +1,51 @@
# Copyright (C) 2024 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

aggregator :
  defaults : plan/defaults/aggregator.yaml
  template : openfl.component.aggregator.Aggregator
  settings :
    init_state_path : save/init.pbuf
    best_state_path : save/best.pbuf
    last_state_path : save/last.pbuf
    rounds_to_train : 10
    write_logs : false
    delta_updates : false

collaborator :
  defaults : plan/defaults/collaborator.yaml
  template : openfl.component.collaborator.Collaborator
  settings :
    delta_updates : false
    opt_treatment : RESET

data_loader :
  defaults : plan/defaults/data_loader.yaml
  template : src.dataloader.HiggsDataLoader
  settings :
    input_shape : 28

task_runner :
  defaults : plan/defaults/task_runner.yaml
  template : src.taskrunner.XGBoostRunner
  settings :
    params :
      objective: binary:logistic
      eval_metric: logloss
      max_depth: 6
      eta: 0.3
      num_parallel_tree: 1

network :
  defaults : plan/defaults/network.yaml
  settings :
    {}

assigner :
  defaults : plan/defaults/assigner.yaml

tasks :
  defaults : plan/defaults/tasks_xgb.yaml

compression_pipeline :
  defaults : plan/defaults/compression_pipeline.yaml
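
Note: delta_updates is set to false for both the aggregator and the collaborator; serialized XGBoost boosters are not additive parameter arrays, so the delta mechanism is bypassed (see the aggregator change further down). As a usage sketch, assuming the standard OpenFL Task Runner workflow: fx workspace create --prefix <dir> --template xgb_higgs to instantiate this workspace, then fx plan initialize once the data splits produced by src/setup_data.py are in place.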
3 changes: 3 additions & 0 deletions openfl-workspace/xgb_higgs/requirements.txt
@@ -0,0 +1,3 @@
scikit-learn
xgboost
modin[all]
2 changes: 2 additions & 0 deletions openfl-workspace/xgb_higgs/src/__init__.py
@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
30 changes: 30 additions & 0 deletions openfl-workspace/xgb_higgs/src/dataloader.py
@@ -0,0 +1,30 @@
# Copyright (C) 2024 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between
# Intel Corporation and you.

from openfl.federated import XGBoostDataLoader
import os
import modin.pandas as pd


class HiggsDataLoader(XGBoostDataLoader):
    def __init__(self, data_path, **kwargs):
        super().__init__(**kwargs)
        X_train, y_train, X_valid, y_valid = load_Higgs(
            data_path, **kwargs
        )
        self.X_train = X_train
        self.y_train = y_train
        self.X_valid = X_valid
        self.y_valid = y_valid


def load_Higgs(data_path, **kwargs):
    train_data = pd.read_csv(os.path.join(data_path, 'train.csv'), header=None)
    X_train = train_data.iloc[:, 1:].values
    y_train = train_data.iloc[:, 0].values

    valid_data = pd.read_csv(os.path.join(data_path, 'valid.csv'), header=None)
    X_valid = valid_data.iloc[:, 1:].values
    y_valid = valid_data.iloc[:, 0].values

    return X_train, y_train, X_valid, y_valid
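
Aside: the task runner below consumes an xgb.DMatrix plus a label array (it reads loader['dmatrix'] and loader['labels']), so the numpy arrays returned by load_Higgs are presumably wrapped by the base XGBoostDataLoader. A minimal sketch of that conversion, with made-up data standing in for the Higgs CSVs:

import numpy as np
import xgboost as xgb

# Hypothetical stand-in data: 28 features per row, matching input_shape in plan.yaml.
X_train = np.random.rand(100, 28)
y_train = np.random.randint(0, 2, size=100)

# xgb.DMatrix is the structure that train_/validate_ in taskrunner.py expect.
dtrain = xgb.DMatrix(X_train, label=y_train)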
94 changes: 94 additions & 0 deletions openfl-workspace/xgb_higgs/src/setup_data.py
@@ -0,0 +1,94 @@
import sys
import os
import shutil
from logging import getLogger
from urllib.request import urlretrieve
from hashlib import sha384
from os import path, makedirs
from tqdm import tqdm
import modin.pandas as pd
import gzip
from sklearn.model_selection import train_test_split
import numpy as np

logger = getLogger(__name__)

"""HIGGS Dataset."""

URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz"
FILENAME = "HIGGS.csv.gz"
CSV_FILENAME = "HIGGS.csv"
CSV_SHA384 = 'b8b82e11a78b81601381420878ad42ba557291f394a88dc5293e4077c8363c87429639b120e299a2a9939c1f943b6a63'
DEFAULT_PATH = path.join(path.expanduser('~'), '.openfl', 'data')

pbar = tqdm(total=None)


def report_hook(count, block_size, total_size):
    """Update progressbar."""
    if pbar.total is None and total_size:
        pbar.total = total_size
    progress_bytes = count * block_size
    pbar.update(progress_bytes - pbar.n)


def verify_sha384(file_path, expected_hash):
    """Verify the SHA-384 hash of a file."""
    sha384_hash = sha384()
    with open(file_path, 'rb') as f:
        for byte_block in iter(lambda: f.read(4096), b""):
            sha384_hash.update(byte_block)
    computed_hash = sha384_hash.hexdigest()
    if computed_hash != expected_hash:
        raise ValueError(f"SHA-384 hash mismatch: expected {expected_hash}, got {computed_hash}")
    print(f"SHA-384 hash verified: {computed_hash}")


def setup_data(root: str = DEFAULT_PATH, **kwargs):
    """Initialize."""
    makedirs(root, exist_ok=True)
    filepath = path.join(root, FILENAME)
    csv_filepath = path.join(root, CSV_FILENAME)
    if not path.exists(filepath):
        urlretrieve(URL, filepath, report_hook)  # nosec
        verify_sha384(filepath, CSV_SHA384)
    # Extract the CSV file from the gzip file
    with gzip.open(filepath, 'rb') as f_in:
        with open(csv_filepath, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)


def main():
    if len(sys.argv) < 2:
        raise ValueError("Provide the number of collaborators")
    src = 'higgs_data'
    if os.path.exists(src):
        shutil.rmtree(src)
    setup_data(src)
    collaborators = int(sys.argv[1])
    print("Creating splits for {} collaborators".format(collaborators))

    # Load the dataset
    higgs_data = pd.read_csv(path.join(src, CSV_FILENAME), header=None)

    # Split the dataset into features and labels
    X = higgs_data.iloc[:, 1:].values
    y = higgs_data.iloc[:, 0].values

    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Combine X and y for train and test sets
    train_data = pd.DataFrame(data=np.column_stack((y_train, X_train)))
    test_data = pd.DataFrame(data=np.column_stack((y_test, X_test)))

    # Split the training data into parts for each collaborator
    for i in range(collaborators):
        dst = f'data/{i+1}'
        makedirs(dst, exist_ok=True)

        # Split the training data for the current collaborator
        split_train_data = train_data.iloc[i::collaborators]
        split_train_data.to_csv(path.join(dst, 'train.csv'), index=False, header=False)

        # Copy the test data for the current collaborator
        test_data.to_csv(path.join(dst, 'valid.csv'), index=False, header=False)


if __name__ == '__main__':
    main()
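
Usage note: running python src/setup_data.py 2 downloads HIGGS.csv.gz, verifies its SHA-384 hash, extracts the CSV, and writes data/1/train.csv, data/1/valid.csv, data/2/train.csv, and data/2/valid.csv. With N collaborators, collaborator i receives every N-th row of the training split (train_data.iloc[i::collaborators]); at this revision the full validation split is copied to each collaborator.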
56 changes: 56 additions & 0 deletions openfl-workspace/xgb_higgs/src/taskrunner.py
@@ -0,0 +1,56 @@
# Copyright (C) 2024 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between
# Intel Corporation and you.

"""You may copy this file as the starting point of your own model."""
import numpy as np
import xgboost as xgb

from openfl.federated import XGBoostTaskRunner
from openfl.utilities import Metric
from sklearn.metrics import accuracy_score


class XGBoostRunner(XGBoostTaskRunner):
    """
    Federated XGBoost task runner for binary classification.

    XGBoostTaskRunner exchanges the serialized booster with the aggregator,
    so the model is configured via the params dict in plan.yaml rather than
    defined layer-by-layer as in the PyTorch runners.
    """

    def __init__(self, params=None, num_rounds=1, **kwargs):
        """Initialize.

        Args:
            params: xgboost training parameters (objective, eval_metric, etc.)
            num_rounds: number of boosting rounds per federated round
            **kwargs: Additional arguments to pass to the function
        """
        super().__init__(**kwargs)

        self.bst = None
        self.params = params
        self.num_rounds = num_rounds

    def train_(self, train_dataloader) -> Metric:
        """Train model."""
        dtrain = train_dataloader['dmatrix']
        evals = [(dtrain, 'train')]
        evals_result = {}

        self.bst = xgb.train(self.params, dtrain, self.num_rounds, xgb_model=self.bst,
                             evals=evals, evals_result=evals_result, verbose_eval=False)

        loss = evals_result['train']['logloss'][-1]
        return Metric(name=self.params['eval_metric'], value=np.array(loss))

    def validate_(self, validation_dataloader) -> Metric:
        """Validate model."""
        dtest = validation_dataloader['dmatrix']
        y_test = validation_dataloader['labels']
        preds = self.bst.predict(dtest)
        y_pred_binary = np.where(preds > 0.5, 1, 0)
        acc = accuracy_score(y_test, y_pred_binary)

        return Metric(name="accuracy", value=np.array(acc))
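
Aside: because xgb.train receives xgb_model=self.bst, each federated round continues boosting on top of the current global ensemble instead of training from scratch. For the ensemble to travel between collaborator and aggregator as a tensor, the booster has to round-trip through a byte buffer; a minimal sketch of that round-trip (illustrative only, not the exact code in the base XGBoostTaskRunner):

import numpy as np
import xgboost as xgb

# Train a tiny booster on random data just to have something to serialize.
X = np.random.rand(64, 28)
y = np.random.randint(0, 2, size=64)
bst = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y),
                num_boost_round=1)

# Booster -> numpy uint8 array (the commit "set global model attribute to
# np array for consistency" hints at a representation like this) and back.
raw = np.frombuffer(bytes(bst.save_raw(raw_format="json")), dtype=np.uint8)
restored = xgb.Booster()
restored.load_model(bytearray(raw.tobytes()))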
7 changes: 5 additions & 2 deletions openfl/component/aggregator/aggregator.py
@@ -69,6 +69,7 @@ def __init__(
         best_state_path,
         last_state_path,
         assigner,
+        delta_updates=True,
         straggler_handling_policy=None,
         rounds_to_train=256,
         single_col_cert_common_name=None,
@@ -186,6 +187,8 @@ def __init__(
         # Initialize a lock for thread safety
         self.lock = Lock()
 
+        self.delta_updates = delta_updates
+
     def _load_initial_tensors(self):
         """Load all of the tensors required to begin federated learning.
 
@@ -801,7 +804,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
         # Create delta and save it in TensorDB
         base_model_tk = TensorKey(tensor_name, origin, round_number, report, ("model",))
         base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tk)
-        if base_model_nparray is not None:
+        if base_model_nparray is not None and self.delta_updates:
             delta_tk, delta_nparray = self.tensor_codec.generate_delta(
                 agg_tag_tk, agg_results, base_model_nparray
             )
@@ -830,7 +833,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
         self.tensor_db.cache_tensor({decompressed_delta_tk: decompressed_delta_nparray})
 
         # Apply delta (unless delta couldn't be created)
-        if base_model_nparray is not None:
+        if base_model_nparray is not None and self.delta_updates:
             self.logger.debug("Applying delta for layer %s", decompressed_delta_tk[0])
             new_model_tk, new_model_nparray = self.tensor_codec.apply_delta(
                 decompressed_delta_tk,
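
Aside: the new delta_updates flag gates both halves of the delta path: computing a delta against the cached base model and applying it to produce the new model. For array-valued models the mechanism is conceptually elementwise subtraction and addition (a simplification of tensor_codec.generate_delta / apply_delta, which also manage TensorKey tags and compression). Serialized boosters are opaque, variable-length byte buffers, so elementwise deltas are not meaningful for them, which is presumably why the xgb_higgs plan sets delta_updates to false:

import numpy as np

# Simplified view of delta updates on an array-valued model.
base_model = np.array([0.10, 0.20, 0.30])
agg_result = np.array([0.15, 0.25, 0.35])

delta = agg_result - base_model   # roughly what generate_delta produces
new_model = base_model + delta    # roughly what apply_delta reconstructs
assert np.allclose(new_model, agg_result)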
5 changes: 5 additions & 0 deletions openfl/federated/__init__.py
@@ -20,6 +20,11 @@
     from openfl.federated.data import PyTorchDataLoader
     from openfl.federated.task import FederatedModel  # NOQA
     from openfl.federated.task import PyTorchTaskRunner
+if importlib.util.find_spec("xgboost") is not None:
+    from openfl.federated.data import FederatedDataSet  # NOQA
+    from openfl.federated.data import XGBoostDataLoader
+    from openfl.federated.task import FederatedModel  # NOQA
+    from openfl.federated.task import XGBoostTaskRunner
 
 __all__ = [
     "Plan",
4 changes: 4 additions & 0 deletions openfl/federated/data/__init__.py
@@ -23,3 +23,7 @@
 if importlib.util.find_spec("torch") is not None:
     from openfl.federated.data.federated_data import FederatedDataSet  # NOQA
     from openfl.federated.data.loader_pt import PyTorchDataLoader  # NOQA
+
+if importlib.util.find_spec("xgboost") is not None:
+    from openfl.federated.data.federated_data import FederatedDataSet  # NOQA
+    from openfl.federated.data.loader_xgb import XGBoostDataLoader  # NOQA