Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyTorch integration tests with FDS #2351

Merged
merged 21 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
53d37e2
Add PyTorch integration tests with FDS
adam-narozniak Sep 12, 2023
a11fa4f
Add ML training to the tests
adam-narozniak Sep 15, 2023
bb41282
Reformat
adam-narozniak Sep 15, 2023
99f0779
Merge branch 'main' into fds-pytorch-integration-tests
adam-narozniak Sep 15, 2023
7ab43cc
Rename integration_tests to e2e
adam-narozniak Sep 20, 2023
7efa66f
Fix the Flower Dataset library name in the pyproject.toml
adam-narozniak Sep 20, 2023
589d880
Fix test class name
adam-narozniak Sep 20, 2023
07399c0
Fix name in tool.poetry
adam-narozniak Sep 20, 2023
04001a6
Remove develop = true in pyproject.toml for flower-datasets
adam-narozniak Sep 20, 2023
c9e5727
Merge branch 'main' into fds-pytorch-integration-tests
danieljanes Sep 22, 2023
f8a0599
Merge branch 'main' into fds-pytorch-integration-tests
tanertopal Sep 22, 2023
240d018
Merge branch 'main' into fds-pytorch-integration-tests
charlesbvll Oct 12, 2023
ade5904
Merge branch 'main' into fds-pytorch-integration-tests
tanertopal Nov 10, 2023
87b3cf2
Merge branch 'main' into fds-pytorch-integration-tests
adam-narozniak Nov 10, 2023
cf69233
Merge branch 'main' into fds-pytorch-integration-tests
adam-narozniak Nov 10, 2023
f8f6fd3
Merge branch 'main' into fds-pytorch-integration-tests
adam-narozniak Nov 14, 2023
7c7da7d
Merge branch 'main' into fds-pytorch-integration-tests
adam-narozniak Dec 20, 2023
ef788ce
Merge branch 'main' into fds-pytorch-integration-tests
tanertopal Dec 20, 2023
6ed2681
Merge branch 'main' into fds-pytorch-integration-tests
tanertopal Dec 20, 2023
9c5f30f
Merge branch 'main' into fds-pytorch-integration-tests
tanertopal Dec 20, 2023
ab6cc97
Merge branch 'main' into fds-pytorch-integration-tests
tanertopal Dec 20, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions datasets/integration_tests/pytorch/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "fds-integration-pytorch"
version = "0.1.0"
description = "Flower Datasets with PyTorch"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.8"
flwr = "^1.5"
torch = "^1.12.0"
torchvision = "^0.14.1"
flwr_datasets = { path = "./../../", develop = true, extras = ["vision"] }
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
parameterized = "==0.9.0"
131 changes: 131 additions & 0 deletions datasets/integration_tests/pytorch/pytorch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets.utils.logging import disable_progress_bar
from parameterized import parameterized_class, parameterized
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize

from flwr_datasets import FederatedDataset


class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x


# Using parameterized testing, two different sets of parameters are specified:
# 1. CIFAR10 dataset with the simple ToTensor transform.
# 2. CIFAR10 dataset with a composed transform that first converts an image to a tensor
# and then normalizes it.
@parameterized_class(
[
{"dataset_name": "cifar10", "test_split": "test", "transforms": ToTensor()},
{"dataset_name": "cifar10", "test_split": "test", "transforms": Compose(
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)},
]
)
class FDSToPyTorchCorrectUsage(unittest.TestCase):
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
"""Test the conversion from FDS to PyTorch Dataset and Dataloader."""

dataset_name = ""
test_split = ""
transforms = None
trainloader = None
expected_img_shape_after_transform = [3, 32, 32]

@classmethod
def setUpClass(cls):
"""Disable progress bar to keep the log clean.
"""
disable_progress_bar()

def _create_trainloader(self, batch_size: int) -> DataLoader:
"""Create a trainloader from the federated dataset."""
partition_id = 0
fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 100})
partition = fds.load_partition(partition_id, "train")
partition_train_test = partition.train_test_split(test_size=0.2)
partition_train_test = partition_train_test.map(
lambda img: {"img": self.transforms(img)}, input_columns="img"
)
trainloader = DataLoader(
partition_train_test["train"].with_format("torch"), batch_size=batch_size,
shuffle=True
)
return trainloader

def test_create_partition_dataloader_with_transforms_shape(self) -> None:
"""Test if the DataLoader returns batches with the expected shape."""
batch_size = 16
trainloader = self._create_trainloader(batch_size)
batch = next(iter(trainloader))
images = batch["img"]
self.assertEqual(tuple(images.shape),
(batch_size, *self.expected_img_shape_after_transform))

def test_create_partition_dataloader_with_transforms_batch_type(self) -> None:
"""Test if the DataLoader returns batches of type dictionary."""
batch_size = 16
trainloader = self._create_trainloader(batch_size)
batch = next(iter(trainloader))
self.assertIsInstance(batch, dict)

def test_create_partition_dataloader_with_transforms_data_type(self) -> None:
"""Test to verify if the data in the DataLoader batches are of type Tensor."""
batch_size = 16
trainloader = self._create_trainloader(batch_size)
batch = next(iter(trainloader))
images = batch["img"]
self.assertIsInstance(images, Tensor)

@parameterized.expand([
("not_nan", torch.isnan),
("not_inf", torch.isinf),
])
def test_train_model_loss_value(self, name, condition_func):
"""Test if the model trains and if the loss is a correct number."""
trainloader = self._create_trainloader(16)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create the model, criterion, and optimizer
net = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Training loop for one epoch
net.train()
loss = None
for i, data in enumerate(trainloader, 0):
inputs, labels = data['img'].to(device), data['label'].to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

self.assertFalse(condition_func(loss).item())


if __name__ == '__main__':
unittest.main()