From 238448fc0f73d6ec2d60e0efcffc1538a88e6206 Mon Sep 17 00:00:00 2001 From: kta-intel Date: Mon, 18 Nov 2024 09:04:40 -0800 Subject: [PATCH] clean up methods Signed-off-by: kta-intel --- openfl/federated/data/loader_xgb.py | 84 ++--------------------------- 1 file changed, 3 insertions(+), 81 deletions(-) diff --git a/openfl/federated/data/loader_xgb.py b/openfl/federated/data/loader_xgb.py index 73f80e49e4..cb8272af98 100644 --- a/openfl/federated/data/loader_xgb.py +++ b/openfl/federated/data/loader_xgb.py @@ -2,9 +2,10 @@ import numpy as np import xgboost as xgb +from openfl.federated.data.loader import DataLoader -class XGBoostDataLoader: +class XGBoostDataLoader(DataLoader): """A class used to represent a Data Loader for XGBoost models. Attributes: @@ -44,37 +45,6 @@ def get_feature_shape(self): """ return self.X_train[0].shape - def get_train_loader(self, batch_size=None, num_batches=None): - """Returns the data loader for the training data. - - Args: - batch_size (int, optional): The batch size for the data loader - (default is None). - num_batches (int, optional): The number of batches for the data - loader (default is None). - - Returns: - generator: The generator object for the training data. - """ - return self._get_batch_generator( - X=self.X_train, - y=self.y_train, - batch_size=batch_size, - num_batches=num_batches, - ) - - def get_valid_loader(self, batch_size=None): - """Returns the data loader for the validation data. - - Args: - batch_size (int, optional): The batch size for the data loader - (default is None). - - Returns: - generator: The generator object for the validation data. - """ - return self._get_batch_generator(X=self.X_valid, y=self.y_valid, batch_size=batch_size) - def get_train_data_size(self): """Returns the total number of training samples. @@ -90,55 +60,7 @@ def get_valid_data_size(self): int: The total number of validation samples. """ return self.X_valid.shape[0] - - @staticmethod - def _batch_generator(X, y, idxs, batch_size, num_batches): - """Generates batches of data. - - Args: - X (np.array): The input data. - y (np.array): The label data. - idxs (np.array): The index of the dataset. - batch_size (int): The batch size for the data loader. - num_batches (int): The number of batches. - - Yields: - tuple: The input data and label data for each batch. - """ - for i in range(num_batches): - a = i * batch_size - b = a + batch_size - yield X[idxs[a:b]], y[idxs[a:b]] - - def _get_batch_generator(self, X, y, batch_size, num_batches=None): - """Returns the dataset generator. - - Args: - X (np.array): The input data. - y (np.array): The label data. - batch_size (int): The batch size for the data loader. - num_batches (int, optional): The number of batches (default is - None). - - Returns: - generator: The dataset generator. - """ - if batch_size is None: - batch_size = self.batch_size - - # shuffle data indices - if self.random_seed is not None: - np.random.seed(self.random_seed) - - idxs = np.random.permutation(np.arange(X.shape[0])) - - # compute the number of batches - if num_batches is None: - num_batches = ceil(X.shape[0] / batch_size) - - # build the generator and return it - return self._batch_generator(X, y, idxs, batch_size, num_batches) - + def get_dmatrix(self, X, y): """Returns the DMatrix for the given data.