Skip to content

Commit

Permalink
clean up methods
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <[email protected]>
  • Loading branch information
kta-intel committed Nov 18, 2024
1 parent eecffe0 commit 238448f
Showing 1 changed file with 3 additions and 81 deletions.
84 changes: 3 additions & 81 deletions openfl/federated/data/loader_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import numpy as np
import xgboost as xgb
from openfl.federated.data.loader import DataLoader


class XGBoostDataLoader:
class XGBoostDataLoader(DataLoader):
"""A class used to represent a Data Loader for XGBoost models.
Attributes:
Expand Down Expand Up @@ -44,37 +45,6 @@ def get_feature_shape(self):
"""
return self.X_train[0].shape

def get_train_loader(self, batch_size=None, num_batches=None):
    """Build a batch generator over the training split.

    Args:
        batch_size (int, optional): Samples per batch; falls back to the
            loader's configured default when None.
        num_batches (int, optional): Cap on the number of batches yielded;
            derived from the data size when None.

    Returns:
        generator: Yields (inputs, labels) batches from the training data.
    """
    loader_kwargs = {
        "X": self.X_train,
        "y": self.y_train,
        "batch_size": batch_size,
        "num_batches": num_batches,
    }
    return self._get_batch_generator(**loader_kwargs)

def get_valid_loader(self, batch_size=None):
    """Build a batch generator over the validation split.

    Args:
        batch_size (int, optional): Samples per batch; falls back to the
            loader's configured default when None.

    Returns:
        generator: Yields (inputs, labels) batches from the validation data.
    """
    return self._get_batch_generator(
        X=self.X_valid,
        y=self.y_valid,
        batch_size=batch_size,
    )

def get_train_data_size(self):
"""Returns the total number of training samples.
Expand All @@ -90,55 +60,7 @@ def get_valid_data_size(self):
int: The total number of validation samples.
"""
return self.X_valid.shape[0]

@staticmethod
def _batch_generator(X, y, idxs, batch_size, num_batches):
"""Generates batches of data.
Args:
X (np.array): The input data.
y (np.array): The label data.
idxs (np.array): The index of the dataset.
batch_size (int): The batch size for the data loader.
num_batches (int): The number of batches.
Yields:
tuple: The input data and label data for each batch.
"""
for i in range(num_batches):
a = i * batch_size
b = a + batch_size
yield X[idxs[a:b]], y[idxs[a:b]]

def _get_batch_generator(self, X, y, batch_size, num_batches=None):
"""Returns the dataset generator.
Args:
X (np.array): The input data.
y (np.array): The label data.
batch_size (int): The batch size for the data loader.
num_batches (int, optional): The number of batches (default is
None).
Returns:
generator: The dataset generator.
"""
if batch_size is None:
batch_size = self.batch_size

# shuffle data indices
if self.random_seed is not None:
np.random.seed(self.random_seed)

idxs = np.random.permutation(np.arange(X.shape[0]))

# compute the number of batches
if num_batches is None:
num_batches = ceil(X.shape[0] / batch_size)

# build the generator and return it
return self._batch_generator(X, y, idxs, batch_size, num_batches)


def get_dmatrix(self, X, y):
"""Returns the DMatrix for the given data.
Expand Down

0 comments on commit 238448f

Please sign in to comment.