Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Scikit Learn integration tests with FDS #2387

Merged
merged 9 commits into from
Nov 14, 2023
15 changes: 15 additions & 0 deletions datasets/e2e/scikit-learn/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "fds-e2e-sklearn"
version = "0.1.0"
description = "Flower Datasets with scikit-learn"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.8"
flwr-datasets = { path = "./../../", extras = ["vision"] }
scikit-learn = "^1.2.0"
parameterized = "==0.9.0"
94 changes: 94 additions & 0 deletions datasets/e2e/scikit-learn/sklearn_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import unittest

import numpy as np
from parameterized import parameterized_class
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from flwr_datasets import FederatedDataset


# Using parameterized testing, two different sets of preprocessing:
# 1. Without scaling.
# 2. With standard scaling.
@parameterized_class(
[
{"dataset_name": "mnist", "preprocessing": None},
{"dataset_name": "mnist", "preprocessing": StandardScaler()},
]
)
class FdsWithSKLearn(unittest.TestCase):
"""Test Flower Datasets with Scikit-learn's Logistic Regression."""

dataset_name = ""
preprocessing = None

def _get_partition_data(self):
"""Retrieve partition data."""
partition_id = 0
fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 10})
partition = fds.load_partition(partition_id, "train")
partition.set_format("numpy")
partition_train_test = partition.train_test_split(test_size=0.2)
X_train, y_train = partition_train_test["train"]["image"], partition_train_test[
"train"]["label"]
X_test, y_test = partition_train_test["test"]["image"], partition_train_test[
"test"]["label"]
X_train = X_train.reshape(-1, 28 * 28)
X_test = X_test.reshape(-1, 28 * 28)
if self.preprocessing:
self.preprocessing.fit(X_train)
X_train = self.preprocessing.transform(X_train)
X_test = self.preprocessing.transform(X_test)

return X_train, X_test, y_train, y_test

def test_data_shape(self):
"""Test if the data shape is maintained after preprocessing."""
X_train, _, _, _ = self._get_partition_data()
self.assertEqual(X_train.shape, (4_800, 28 * 28))

def test_X_train_type(self):
"""Test if the data type is correct."""
X_train, _, _, _ = self._get_partition_data()
self.assertIsInstance(X_train, np.ndarray)

def test_y_train_type(self):
"""Test if the data type is correct."""
_, _, y_train, _ = self._get_partition_data()
self.assertIsInstance(y_train, np.ndarray)

def test_X_test_type(self):
"""Test if the data type is correct."""
_, X_test, _, _ = self._get_partition_data()
self.assertIsInstance(X_test, np.ndarray)

def test_y_test_type(self):
"""Test if the data type is correct."""
_, _, _, y_test = self._get_partition_data()
self.assertIsInstance(y_test, np.ndarray)

def test_train_classifier(self):
"""Test if the classifier trains without errors."""
X_train, X_test, y_train, y_test = self._get_partition_data()
try:
clf = LogisticRegression()
clf.fit(X_train, y_train)
except Exception as e:
self.fail(f"Fitting Logistic Regression raised {type(e)} unexpectedly!")

def test_predict_from_classifier(self):
"""Test if the classifier predicts without errors."""
X_train, X_test, y_train, y_test = self._get_partition_data()
clf = LogisticRegression()
clf.fit(X_train, y_train)
try:
_ = clf.predict(X_test)
except Exception as e:
self.fail(
f"Predicting using Logistic Regression model raised {type(e)} "
f"unexpectedly!")


if __name__ == '__main__':
unittest.main()
Loading