Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorFlow integration tests with FDS #2350

Merged
merged 11 commits into from
Sep 21, 2023
15 changes: 15 additions & 0 deletions datasets/e2e/tensorflow/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "fds-e2-tensorflow"
version = "0.1.0"
description = "Flower Datasets with TensorFlow"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.8"
flwr-datasets = { path = "./../../", extras = ["vision"] }
tensorflow-cpu = "^2.9.1, !=2.11.1"
parameterized = "==0.9.0"
102 changes: 102 additions & 0 deletions datasets/e2e/tensorflow/tensorflow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import unittest

import numpy as np
import tensorflow as tf
from datasets.utils.logging import disable_progress_bar
from parameterized import parameterized_class, parameterized
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Sequential

from flwr_datasets import FederatedDataset


def SimpleCNN():
model = Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
MaxPooling2D(2, 2),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Flatten(),
Dense(64, activation='relu'),
Dense(10, activation='softmax')
])
return model


@parameterized_class(
[
{"dataset_name": "cifar10", "test_split": "test"},
{"dataset_name": "cifar10", "test_split": "test"},
]
)
class FdsToTensorFlow(unittest.TestCase):
"""Test the conversion from FDS to PyTorch Dataset and Dataloader."""

dataset_name = ""
test_split = ""
expected_img_shape_after_transform = [32, 32, 3]

@classmethod
def setUpClass(cls):
"""Disable progress bar to keep the log clean.
"""
disable_progress_bar()

def _create_tensorflow_dataset(self, batch_size: int) -> tf.data.Dataset:
"""Create a tensorflow dataset from the FederatedDataset."""
partition_id = 0
fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 100})
partition = fds.load_partition(partition_id, "train")
tf_dataset = partition.to_tf_dataset(columns="img", label_cols="label",
batch_size=batch_size,
shuffle=False)
return tf_dataset

def test_create_partition_dataset_shape(self) -> None:
"""Test if the DataLoader returns batches with the expected shape."""
batch_size = 16
dataset = self._create_tensorflow_dataset(batch_size)
batch = next(iter(dataset))
images = batch[0]
self.assertEqual(tuple(images.shape),
(batch_size, *self.expected_img_shape_after_transform))

def test_create_partition_dataloader_with_transforms_batch_type(self) -> None:
"""Test if the DataLoader returns batches of type dictionary."""
batch_size = 16
dataset = self._create_tensorflow_dataset(batch_size)
batch = next(iter(dataset))
self.assertIsInstance(batch, tuple)

def test_create_partition_dataloader_with_transforms_data_type(self) -> None:
"""Test to verify if the data in the DataLoader batches are of type Tensor."""
batch_size = 16
dataset = self._create_tensorflow_dataset(batch_size)
batch = next(iter(dataset))
images = batch[0]
self.assertIsInstance(images, tf.Tensor)

@parameterized.expand([
("not_nan", np.isnan),
("not_inf", np.isinf),
])
def test_train_model_loss_value(self, name, condition_func):
model = SimpleCNN()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

dataset = self._create_tensorflow_dataset(16)

# Perform a single epoch of training
history = model.fit(dataset, epochs=1, verbose=0)

# Fetch the last loss from history
last_loss = history.history['loss'][-1]

# Check if the last loss is NaN or Infinity
self.assertFalse(condition_func(last_loss))


if __name__ == '__main__':
unittest.main()