Merge branch 'fds-add-resplitting-functionality'
adam-narozniak committed Oct 12, 2023
2 parents c19a235 + 87ed594 commit d996737
Showing 4 changed files with 296 additions and 3 deletions.
32 changes: 30 additions & 2 deletions datasets/flwr_datasets/federated_dataset.py
@@ -15,13 +15,15 @@
"""FederatedDataset."""


from typing import Dict, Optional, Union

from typing import Callable, Dict, Optional, Tuple, Union
import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.merge_splitter import MergeSplitter
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners

Resplitter = Callable[[DatasetDict], DatasetDict]


class FederatedDataset:
"""Representation of a dataset for federated learning/evaluation/analytics.
@@ -35,6 +37,8 @@ class FederatedDataset:
----------
dataset: str
The name of the dataset in the Hugging Face Hub.
resplitter: Optional[Union[Resplitter, Dict[Tuple[str, ...], str]]]
Resplit strategy (a dictionary mapping tuples of existing split names to new split names) or a custom callable that transforms the whole `DatasetDict` before partitioning.
partitioners: Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
@@ -59,15 +63,18 @@ def __init__(
self,
*,
dataset: str,
resplitter: Optional[Union[Resplitter, Dict[Tuple[str, ...], str]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._resplitter = resplitter
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
# Init (download) lazily on the first call to `load_partition` or `load_full`
self._dataset: Optional[DatasetDict] = None
self._resplit: bool = False # Indicate if the resplit happened

def load_partition(self, idx: int, split: str) -> Dataset:
"""Load the partition specified by the idx in the selected split.
@@ -88,6 +95,7 @@ def load_partition(self, idx: int, split: str) -> Dataset:
Single partition from the dataset split.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
@@ -113,6 +121,7 @@ def load_full(self, split: str) -> Dataset:
Part of the dataset identified by its split name.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
@@ -153,3 +162,22 @@ def _assign_dataset_to_partitioner(self, split: str) -> None:
raise ValueError("Dataset is not loaded yet.")
if not self._partitioners[split].is_dataset_assigned():
self._partitioners[split].dataset = self._dataset[split]

def _resplit_dataset_if_needed(self) -> None:
# Resplit only once; the `_resplit` flag guards against repeated calls.
if self._resplit:
return
if self._dataset is None:
raise ValueError("The dataset resplit should happen after the download.")
if self._resplitter:
resplitter: Resplitter
if isinstance(self._resplitter, Dict):
resplitter = MergeSplitter(resplit_strategy=self._resplitter)
else:
resplitter = self._resplitter
self._dataset = resplitter(self._dataset)
self._resplit = True
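
For context, here is a minimal usage sketch of the new `resplitter` argument. It mirrors the renaming strategy exercised by the tests added in this commit (the "mnist" dataset name, the split names, and the import path are taken from those tests):

from flwr_datasets.federated_dataset import FederatedDataset

# Rename the original splits before partitioning; the renamed "new_train" split
# is then divided into 100 IID partitions.
fds = FederatedDataset(
    dataset="mnist",
    partitioners={"new_train": 100},
    resplitter={("train",): "new_train", ("test",): "new_test"},
)
partition = fds.load_partition(0, "new_train")  # one of the 100 partitions
centralized = fds.load_full("new_test")  # the whole renamed test split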
63 changes: 62 additions & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
@@ -15,14 +15,14 @@
"""Federated Dataset tests."""
# pylint: disable=W0212, C0103, C0206


import unittest
from typing import Dict, Union

import pytest
from parameterized import parameterized, parameterized_class

import datasets
from datasets import DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner

@@ -93,6 +93,55 @@ def test_multiple_partitioners(self) -> None:
len(dataset[self.test_split]) // num_test_partitions,
)

def test_resplit_dataset_into_one(self) -> None:
"""Test resplit into a single dataset."""
dataset = datasets.load_dataset(self.dataset_name)
dataset_length = sum([len(ds) for ds in dataset.values()])
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"train": 100},
resplitter={("train", self.test_split): "full"},
)
full = fds.load_full("full")
self.assertEqual(dataset_length, len(full))

# pylint: disable=protected-access
def test_resplit_dataset_to_change_names(self) -> None:
"""Test resplitter to change the names of the partitions."""
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"new_train": 100},
resplitter={
("train",): "new_train",
(self.test_split,): "new_" + self.test_split,
},
)
_ = fds.load_partition(0, "new_train")
assert fds._dataset is not None
self.assertEqual(
set(fds._dataset.keys()), {"new_train", "new_" + self.test_split}
)

def test_resplit_dataset_by_callable(self) -> None:
"""Test resplitter to change the names of the partitions."""

def resplit(dataset: DatasetDict) -> DatasetDict:
return DatasetDict(
{
"full": concatenate_datasets(
[dataset["train"], dataset[self.test_split]]
)
}
)

fds = FederatedDataset(
dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit
)
full = fds.load_full("full")
dataset = datasets.load_dataset(self.dataset_name)
dataset_length = sum([len(ds) for ds in dataset.values()])
self.assertEqual(len(full), dataset_length)


class PartitionersSpecificationForFederatedDatasets(unittest.TestCase):
"""Test the specifications of partitioners for `FederatedDataset`."""
@@ -190,6 +239,18 @@ def test_unsupported_dataset(self) -> None: # pylint: disable=R0201
with pytest.warns(UserWarning):
FederatedDataset(dataset="food101", partitioners={"train": 100})

def test_cannot_use_the_old_split_names(self) -> None:
"""Test if the initial split names can not be used."""
dataset = datasets.load_dataset("mnist")
sum([len(ds) for ds in dataset.values()])
fds = FederatedDataset(
dataset="mnist",
partitioners={"train": 100},
resplitter={("train", "test"): "full"},
)
with self.assertRaises(ValueError):
fds.load_partition(0, "train")


if __name__ == "__main__":
unittest.main()
75 changes: 75 additions & 0 deletions datasets/flwr_datasets/merge_splitter.py
@@ -0,0 +1,75 @@
"""Resplitter class for Flower Datasets."""
import collections
from typing import Dict, List, Tuple

import datasets
from datasets import Dataset, DatasetDict


class MergeSplitter:
"""Create a new dataset splits according to the `resplit_strategy`.
The dataset comes with some predefined splits e.g. "train", "valid" and "test". This
class allows you to create a new dataset with splits created according to your needs
specified in `resplit_strategy`.
Parameters
----------
resplit_strategy: ResplitStrategy
Dictionary with keys - tuples of the current split names to values - the desired
split names
"""

def __init__(
self,
resplit_strategy: Dict[Tuple[str, ...], str],
) -> None:
self._resplit_strategy: Dict[Tuple[str, ...], str] = resplit_strategy
self._check_duplicate_desired_splits()

def __call__(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `resplit_strategy`."""
self._check_correct_keys_in_resplit_strategy(dataset)
return self.resplit(dataset)

def resplit(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `resplit_strategy`."""
resplit_dataset = {}
for divided_from_list, divide_to in self._resplit_strategy.items():
datasets_from_list: List[Dataset] = []
for divide_from in divided_from_list:
datasets_from_list.append(dataset[divide_from])
if len(datasets_from_list) > 1:
resplit_dataset[divide_to] = datasets.concatenate_datasets(
datasets_from_list
)
else:
resplit_dataset[divide_to] = datasets_from_list[0]
return datasets.DatasetDict(resplit_dataset)

def _check_correct_keys_in_resplit_strategy(self, dataset: DatasetDict) -> None:
"""Check if the keys in resplit_strategy are existing dataset splits."""
dataset_keys = dataset.keys()
specified_dataset_keys = self._resplit_strategy.keys()
for key_list in specified_dataset_keys:
for key in key_list:
if key not in dataset_keys:
raise ValueError(
f"The given dataset key '{key}' is not present in the given "
f"dataset object. Make sure to use only the keywords that are "
f"available in your dataset."
)

def _check_duplicate_desired_splits(self) -> None:
"""Check for duplicate desired split names."""
desired_splits = list(self._resplit_strategy.values())
duplicates = [
item
for item, count in collections.Counter(desired_splits).items()
if count > 1
]
if duplicates:
print(f"Duplicate desired split name '{duplicates[0]}' in resplit strategy")
raise ValueError(
f"Duplicate desired split name '{duplicates[0]}' in resplit strategy"
)
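
MergeSplitter can also be used on its own. Here is a minimal sketch with a toy `DatasetDict`, mirroring the tests in the next file:

from datasets import Dataset, DatasetDict
from flwr_datasets.merge_splitter import MergeSplitter

dataset = DatasetDict(
    {
        "train": Dataset.from_dict({"data": [1, 2, 3]}),
        "valid": Dataset.from_dict({"data": [4, 5]}),
    }
)
# Merge "train" and "valid" into a single "full" split; splits not listed in the
# strategy are dropped from the resulting DatasetDict.
merger = MergeSplitter(resplit_strategy={("train", "valid"): "full"})
new_dataset = merger(dataset)
assert len(new_dataset["full"]) == 5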
129 changes: 129 additions & 0 deletions datasets/flwr_datasets/merge_splitter_test.py
@@ -0,0 +1,129 @@
"""Resplitter tests."""
import unittest
from typing import Dict, Tuple

from datasets import Dataset, DatasetDict
from flwr_datasets.merge_splitter import MergeSplitter


class TestResplitter(unittest.TestCase):
"""Resplitter tests."""

def setUp(self) -> None:
"""Set up the dataset with 3 splits for tests."""
self.dataset_dict = DatasetDict(
{
"train": Dataset.from_dict({"data": [1, 2, 3]}),
"valid": Dataset.from_dict({"data": [4, 5]}),
"test": Dataset.from_dict({"data": [6]}),
}
)

def test_resplitting_train_size(self) -> None:
"""Test if resplitting for just renaming keeps the lengths correct."""
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"}
resplitter = MergeSplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["new_train"]), 3)

def test_resplitting_valid_size(self) -> None:
"""Test if resplitting for just renaming keeps the lengths correct."""
strategy: Dict[Tuple[str, ...], str] = {("valid",): "new_valid"}
resplitter = MergeSplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["new_valid"]), 2)

def test_resplitting_test_size(self) -> None:
"""Test if resplitting for just renaming keeps the lengths correct."""
strategy: Dict[Tuple[str, ...], str] = {("test",): "new_test"}
resplitter = MergeSplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["new_test"]), 1)

def test_resplitting_train_the_same(self) -> None:
"""Test if resplitting for just renaming keeps the dataset the same."""
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"}
resplitter = MergeSplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertTrue(
datasets_are_equal(self.dataset_dict["train"], new_dataset["new_train"])
)

def test_combined_train_valid_size(self) -> None:
"""Test if the resplitting that combines the datasets has correct size."""
strategy: Dict[Tuple[str, ...], str] = {
("train", "valid"): "train_valid_combined"
}
resplitter = MergeSplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["train_valid_combined"]), 5)

def test_resplitting_test_with_combined_strategy_size(self) -> None:
"""Test if the resplitting that combines the datasets has correct size."""
strategy: Dict[Tuple[str, ...], str] = {
("train", "valid"): "train_valid_combined",
("test",): "test",
}
resplitter = MergeSplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["test"]), 1)

def test_invalid_resplit_strategy_exception_message(self) -> None:
"""Test if the resplitting raises error when non-existing split is given."""
strategy: Dict[Tuple[str, ...], str] = {
("invalid_split",): "new_train",
("test",): "new_test",
}
resplitter = MergeSplitter(strategy)
with self.assertRaisesRegex(
ValueError, "The given dataset key 'invalid_split' is not present"
):
resplitter(self.dataset_dict)

def test_nonexistent_split_in_strategy(self) -> None:
"""Test if the exception is raised when the nonexistent split name is given."""
strategy: Dict[Tuple[str, ...], str] = {("nonexistent_split",): "new_split"}
resplitter = MergeSplitter(strategy)
with self.assertRaisesRegex(
ValueError, "The given dataset key 'nonexistent_split' is not present"
):
resplitter(self.dataset_dict)

def test_duplicate_desired_split_name(self) -> None:
"""Test that the new split names are not the same."""
strategy: Dict[Tuple[str, ...], str] = {
("train",): "new_train",
("valid",): "new_train",
}
with self.assertRaisesRegex(
ValueError, "Duplicate desired split name 'new_train' in resplit strategy"
):
_ = MergeSplitter(strategy)

def test_empty_dataset_dict(self) -> None:
"""Test that the error is raised when the empty DatasetDict is given."""
empty_dataset = DatasetDict({})
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"}
resplitter = MergeSplitter(strategy)
with self.assertRaisesRegex(
ValueError, "The given dataset key 'train' is not present"
):
resplitter(empty_dataset)


def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
"""Check if two Datasets have the same values."""
# Check if both datasets have the same length
if len(ds1) != len(ds2):
return False

# Iterate over each row and check for equality
for row1, row2 in zip(ds1, ds2):
if row1 != row2:
return False

return True


if __name__ == "__main__":
unittest.main()
