diff --git a/datasets/e2e/scikit-learn/pyproject.toml b/datasets/e2e/scikit-learn/pyproject.toml
new file mode 100644
index 000000000000..7e22644566cf
--- /dev/null
+++ b/datasets/e2e/scikit-learn/pyproject.toml
@@ -0,0 +1,15 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "fds-e2e-sklearn"
+version = "0.1.0"
+description = "Flower Datasets with scikit-learn"
+authors = ["The Flower Authors "]
+
+[tool.poetry.dependencies]
+python = "^3.8"
+flwr-datasets = { path = "./../../", extras = ["vision"] }
+scikit-learn = "^1.2.0"
+parameterized = "==0.9.0"
diff --git a/datasets/e2e/scikit-learn/sklearn_test.py b/datasets/e2e/scikit-learn/sklearn_test.py
new file mode 100644
index 000000000000..e5e6d347ee37
--- /dev/null
+++ b/datasets/e2e/scikit-learn/sklearn_test.py
@@ -0,0 +1,94 @@
+import unittest
+
+import numpy as np
+from parameterized import parameterized_class
+from sklearn.linear_model import LogisticRegression
+from sklearn.preprocessing import StandardScaler
+
+from flwr_datasets import FederatedDataset
+
+
+# Parameterized testing with two different preprocessing setups:
+# 1. Without scaling.
+# 2. With standard scaling.
+@parameterized_class(
+    [
+        {"dataset_name": "mnist", "preprocessing": None},
+        {"dataset_name": "mnist", "preprocessing": StandardScaler()},
+    ]
+)
+class FdsWithSKLearn(unittest.TestCase):
+    """Test Flower Datasets with scikit-learn's Logistic Regression."""
+
+    dataset_name = ""
+    preprocessing = None
+
+    def _get_partition_data(self):
+        """Retrieve partition data."""
+        partition_id = 0
+        fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 10})
+        partition = fds.load_partition(partition_id, "train")
+        partition.set_format("numpy")
+        partition_train_test = partition.train_test_split(test_size=0.2)
+        X_train, y_train = partition_train_test["train"]["image"], partition_train_test[
+            "train"]["label"]
+        X_test, y_test = partition_train_test["test"]["image"], partition_train_test[
+            "test"]["label"]
+        X_train = X_train.reshape(-1, 28 * 28)
+        X_test = X_test.reshape(-1, 28 * 28)
+        if self.preprocessing:
+            self.preprocessing.fit(X_train)
+            X_train = self.preprocessing.transform(X_train)
+            X_test = self.preprocessing.transform(X_test)
+
+        return X_train, X_test, y_train, y_test
+
+    def test_data_shape(self):
+        """Test if the data shape is maintained after preprocessing."""
+        X_train, _, _, _ = self._get_partition_data()
+        self.assertEqual(X_train.shape, (4_800, 28 * 28))
+
+    def test_X_train_type(self):
+        """Test if X_train is a NumPy array."""
+        X_train, _, _, _ = self._get_partition_data()
+        self.assertIsInstance(X_train, np.ndarray)
+
+    def test_y_train_type(self):
+        """Test if y_train is a NumPy array."""
+        _, _, y_train, _ = self._get_partition_data()
+        self.assertIsInstance(y_train, np.ndarray)
+
+    def test_X_test_type(self):
+        """Test if X_test is a NumPy array."""
+        _, X_test, _, _ = self._get_partition_data()
+        self.assertIsInstance(X_test, np.ndarray)
+
+    def test_y_test_type(self):
+        """Test if y_test is a NumPy array."""
+        _, _, _, y_test = self._get_partition_data()
+        self.assertIsInstance(y_test, np.ndarray)
+
+    def test_train_classifier(self):
+        """Test if the classifier trains without errors."""
+        X_train, X_test, y_train, y_test = self._get_partition_data()
+        try:
+            clf = LogisticRegression()
+            clf.fit(X_train, y_train)
+        except Exception as e:
+            self.fail(f"Fitting Logistic Regression raised {type(e)} unexpectedly!")
+
+    def test_predict_from_classifier(self):
+        """Test if the classifier predicts without errors."""
+        X_train, X_test, y_train, y_test = self._get_partition_data()
+        clf = LogisticRegression()
+        clf.fit(X_train, y_train)
+        try:
+            _ = clf.predict(X_test)
+        except Exception as e:
+            self.fail(
+                f"Predicting using Logistic Regression model raised {type(e)} "
+                f"unexpectedly!")
+
+
+if __name__ == '__main__':
+    unittest.main()
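
Note for reviewers: the diff above only adds the Poetry project file and the end-to-end test. The snippet below is a minimal standalone sketch of the same flow the tests exercise (partition MNIST with FederatedDataset, optionally scale, then fit scikit-learn's LogisticRegression). The partition index, number of partitions, and the final accuracy print are illustrative choices for this sketch, not part of the diff; running it assumes the flwr-datasets and scikit-learn dependencies declared in pyproject.toml are installed (e.g. via poetry install), and the test file itself can also be executed directly with python sklearn_test.py.

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from flwr_datasets import FederatedDataset

# Partition MNIST across 10 clients and load partition 0 as NumPy arrays.
fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
partition = fds.load_partition(0, "train")
partition.set_format("numpy")
split = partition.train_test_split(test_size=0.2)

# Flatten the 28x28 images into 784-dimensional feature vectors.
X_train = split["train"]["image"].reshape(-1, 28 * 28)
y_train = split["train"]["label"]
X_test = split["test"]["image"].reshape(-1, 28 * 28)
y_test = split["test"]["label"]

# Optional standard scaling, mirroring the second parameterized test case.
scaler = StandardScaler().fit(X_train)
X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

# Fit the classifier on the local partition and report held-out accuracy.
clf = LogisticRegression()
clf.fit(X_train, y_train)
print("Partition 0 test accuracy:", clf.score(X_test, y_test))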