Skip to content

Commit

Permalink
Add Scikit Learn integration tests with FDS (#2387)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak authored Nov 14, 2023
1 parent de023a1 commit c19a8d6
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
15 changes: 15 additions & 0 deletions datasets/e2e/scikit-learn/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "fds-e2e-sklearn"
version = "0.1.0"
description = "Flower Datasets with scikit-learn"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.8"
flwr-datasets = { path = "./../../", extras = ["vision"] }
scikit-learn = "^1.2.0"
parameterized = "==0.9.0"
94 changes: 94 additions & 0 deletions datasets/e2e/scikit-learn/sklearn_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import unittest

import numpy as np
from parameterized import parameterized_class
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from flwr_datasets import FederatedDataset


# Using parameterized testing, two different sets of preprocessing:
# 1. Without scaling.
# 2. With standard scaling.
@parameterized_class(
[
{"dataset_name": "mnist", "preprocessing": None},
{"dataset_name": "mnist", "preprocessing": StandardScaler()},
]
)
class FdsWithSKLearn(unittest.TestCase):
"""Test Flower Datasets with Scikit-learn's Logistic Regression."""

dataset_name = ""
preprocessing = None

def _get_partition_data(self):
"""Retrieve partition data."""
partition_id = 0
fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 10})
partition = fds.load_partition(partition_id, "train")
partition.set_format("numpy")
partition_train_test = partition.train_test_split(test_size=0.2)
X_train, y_train = partition_train_test["train"]["image"], partition_train_test[
"train"]["label"]
X_test, y_test = partition_train_test["test"]["image"], partition_train_test[
"test"]["label"]
X_train = X_train.reshape(-1, 28 * 28)
X_test = X_test.reshape(-1, 28 * 28)
if self.preprocessing:
self.preprocessing.fit(X_train)
X_train = self.preprocessing.transform(X_train)
X_test = self.preprocessing.transform(X_test)

return X_train, X_test, y_train, y_test

def test_data_shape(self):
"""Test if the data shape is maintained after preprocessing."""
X_train, _, _, _ = self._get_partition_data()
self.assertEqual(X_train.shape, (4_800, 28 * 28))

def test_X_train_type(self):
"""Test if the data type is correct."""
X_train, _, _, _ = self._get_partition_data()
self.assertIsInstance(X_train, np.ndarray)

def test_y_train_type(self):
"""Test if the data type is correct."""
_, _, y_train, _ = self._get_partition_data()
self.assertIsInstance(y_train, np.ndarray)

def test_X_test_type(self):
"""Test if the data type is correct."""
_, X_test, _, _ = self._get_partition_data()
self.assertIsInstance(X_test, np.ndarray)

def test_y_test_type(self):
"""Test if the data type is correct."""
_, _, _, y_test = self._get_partition_data()
self.assertIsInstance(y_test, np.ndarray)

def test_train_classifier(self):
"""Test if the classifier trains without errors."""
X_train, X_test, y_train, y_test = self._get_partition_data()
try:
clf = LogisticRegression()
clf.fit(X_train, y_train)
except Exception as e:
self.fail(f"Fitting Logistic Regression raised {type(e)} unexpectedly!")

def test_predict_from_classifier(self):
"""Test if the classifier predicts without errors."""
X_train, X_test, y_train, y_test = self._get_partition_data()
clf = LogisticRegression()
clf.fit(X_train, y_train)
try:
_ = clf.predict(X_test)
except Exception as e:
self.fail(
f"Predicting using Logistic Regression model raised {type(e)} "
f"unexpectedly!")


if __name__ == '__main__':
unittest.main()

0 comments on commit c19a8d6

Please sign in to comment.