Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resplitting functionality to Flower Datasets #2427

Merged
merged 25 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
a4a295f
Add resplitter
adam-narozniak Sep 27, 2023
9aebcdb
Update FederatedDataset to work with resplitter
adam-narozniak Sep 27, 2023
c2866ac
Merge branch 'main' into fds-add-resplitting-functionality
adam-narozniak Oct 12, 2023
3a6e2e4
Rename Resplitter to MergeSplitter and custom Resplitter type
adam-narozniak Oct 12, 2023
de52f97
Fix MergeSplitter tests
adam-narozniak Oct 12, 2023
87ed594
Fix mypy problems for in tests
adam-narozniak Oct 12, 2023
d996737
Merge branch 'fds-add-resplitting-functionality'
adam-narozniak Oct 12, 2023
3c1af53
Fix spaces
adam-narozniak Oct 12, 2023
d01feb8
Merge branch 'main' into fds-add-resplitting-functionality
adam-narozniak Oct 12, 2023
261fe4c
Fix new lines
adam-narozniak Oct 12, 2023
b625ae8
Apply suggestions from code review
adam-narozniak Oct 13, 2023
f790f1b
Clarify the documentation of MergeResplitter
adam-narozniak Oct 13, 2023
f6adaa5
Update resplitter parameter docstring
adam-narozniak Oct 16, 2023
a9b6329
Fix new lines between docstring and imports
adam-narozniak Oct 16, 2023
71e9036
Fix new lines between docstring and imports
adam-narozniak Oct 16, 2023
d61350f
Add copyright notice
adam-narozniak Oct 16, 2023
f589f02
Check for duplicated split names in MergeResplitter merge_config
adam-narozniak Oct 16, 2023
2614fcb
Fix formatting
adam-narozniak Oct 16, 2023
809beea
Simply the _resplit_data_if_needed method
adam-narozniak Oct 16, 2023
10e2a69
Fix pylint tests
adam-narozniak Oct 16, 2023
dbad439
Switch the keys and values of the merge_config
adam-narozniak Oct 16, 2023
31fd770
Merge branch 'main' into fds-add-resplitting-functionality
adam-narozniak Oct 16, 2023
18a8c19
Update datasets/flwr_datasets/common/__init__.py
danieljanes Oct 16, 2023
b7249af
Update datasets/flwr_datasets/common/typing.py
danieljanes Oct 16, 2023
dd7bcbd
Merge branch 'main' into fds-add-resplitting-functionality
danieljanes Oct 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
"""FederatedDataset."""


from typing import Dict, Optional, Union
from typing import Callable, Dict, Optional, Tuple, Union

import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.merge_resplitter import MergeResplitter
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners

Resplitter = Callable[[DatasetDict], DatasetDict]


class FederatedDataset:
"""Representation of a dataset for federated learning/evaluation/analytics.
Expand All @@ -35,6 +38,8 @@ class FederatedDataset:
----------
dataset: str
The name of the dataset in the Hugging Face Hub.
resplitter: Optional[Union[Resplitter, Dict[Tuple[str, ...], str]]]
Resplit strategy or `Callable` that transforms splits in the `DatasetDict`.
adam-narozniak marked this conversation as resolved.
Show resolved Hide resolved
partitioners: Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
Expand All @@ -59,15 +64,18 @@ def __init__(
self,
*,
dataset: str,
resplitter: Optional[Union[Resplitter, Dict[Tuple[str, ...], str]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._resplitter = resplitter
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
# Init (download) lazily on the first call to `load_partition` or `load_full`
self._dataset: Optional[DatasetDict] = None
self._resplit: bool = False # Indicate if the resplit happened

def load_partition(self, idx: int, split: str) -> Dataset:
"""Load the partition specified by the idx in the selected split.
Expand All @@ -88,6 +96,7 @@ def load_partition(self, idx: int, split: str) -> Dataset:
Single partition from the dataset split.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
Expand All @@ -113,6 +122,7 @@ def load_full(self, split: str) -> Dataset:
Part of the dataset identified by its split name.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
Expand Down Expand Up @@ -153,3 +163,21 @@ def _assign_dataset_to_partitioner(self, split: str) -> None:
raise ValueError("Dataset is not loaded yet.")
if not self._partitioners[split].is_dataset_assigned():
self._partitioners[split].dataset = self._dataset[split]

def _resplit_dataset_if_needed(self) -> None:
# The actual re-splitting can't be done more than once.
# The attribute `_resplit` indicates that the resplit happened.

# Resplit only once
if self._resplit:
return
if self._dataset is None:
raise ValueError("The dataset resplit should happen after the download.")
if self._resplitter:
resplitter: Resplitter
if isinstance(self._resplitter, Dict):
resplitter = MergeResplitter(merge_config=self._resplitter)
else:
resplitter = self._resplitter
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
self._dataset = resplitter(self._dataset)
self._resplit = True
62 changes: 62 additions & 0 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from parameterized import parameterized, parameterized_class

import datasets
from datasets import DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner

Expand Down Expand Up @@ -93,6 +94,55 @@ def test_multiple_partitioners(self) -> None:
len(dataset[self.test_split]) // num_test_partitions,
)

def test_resplit_dataset_into_one(self) -> None:
"""Test resplit into a single dataset."""
dataset = datasets.load_dataset(self.dataset_name)
dataset_length = sum([len(ds) for ds in dataset.values()])
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"train": 100},
resplitter={("train", self.test_split): "full"},
)
full = fds.load_full("full")
self.assertEqual(dataset_length, len(full))

# pylint: disable=protected-access
def test_resplit_dataset_to_change_names(self) -> None:
"""Test resplitter to change the names of the partitions."""
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"new_train": 100},
resplitter={
("train",): "new_train",
(self.test_split,): "new_" + self.test_split,
},
)
_ = fds.load_partition(0, "new_train")
assert fds._dataset is not None
self.assertEqual(
set(fds._dataset.keys()), {"new_train", "new_" + self.test_split}
)

def test_resplit_dataset_by_callable(self) -> None:
"""Test resplitter to change the names of the partitions."""

def resplit(dataset: DatasetDict) -> DatasetDict:
return DatasetDict(
{
"full": concatenate_datasets(
[dataset["train"], dataset[self.test_split]]
)
}
)

fds = FederatedDataset(
dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit
)
full = fds.load_full("full")
dataset = datasets.load_dataset(self.dataset_name)
dataset_length = sum([len(ds) for ds in dataset.values()])
self.assertEqual(len(full), dataset_length)


class PartitionersSpecificationForFederatedDatasets(unittest.TestCase):
"""Test the specifications of partitioners for `FederatedDataset`."""
Expand Down Expand Up @@ -190,6 +240,18 @@ def test_unsupported_dataset(self) -> None: # pylint: disable=R0201
with pytest.warns(UserWarning):
FederatedDataset(dataset="food101", partitioners={"train": 100})

def test_cannot_use_the_old_split_names(self) -> None:
"""Test if the initial split names can not be used."""
dataset = datasets.load_dataset("mnist")
sum([len(ds) for ds in dataset.values()])
fds = FederatedDataset(
dataset="mnist",
partitioners={"train": 100},
resplitter={("train", "test"): "full"},
)
with self.assertRaises(ValueError):
fds.load_partition(0, "train")


if __name__ == "__main__":
unittest.main()
91 changes: 91 additions & 0 deletions datasets/flwr_datasets/merge_resplitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""MergeResplitter class for Flower Datasets."""
import collections
from typing import Dict, List, Tuple
adam-narozniak marked this conversation as resolved.
Show resolved Hide resolved

import datasets
from datasets import Dataset, DatasetDict


class MergeResplitter:
"""Merge existing splits of the dataset and assign them custom names.

Create new `DatasetDict` with new split names corresponding to the merged existing
splits (e.g. "train", "valid" and "test").

Parameters
----------
merge_config: Dict[Tuple[str, ...], str]
Dictionary with keys - tuples of the current split names to values - the desired
split names

Examples
--------
Create new `DatasetDict` with a split name "new_train" that is created as a merger
of the "train" and "valid" splits. Keep the "test" split.

>>> # Assuming there is a dataset_dict of type `DatasetDict`
>>> # dataset_dict is {"train": train-data, "valid": valid-data, "test": test-data}
>>> merge_resplitter = MergeResplitter(
>>> merge_config={
>>> ("train", "valid"): "new_train",
>>> ("test", ): "test"
>>> }
>>> )
>>> new_dataset_dict = merge_resplitter(dataset_dict)
>>> # new_dataset_dict is
>>> # {"new_train": concatenation of train-data and valid-data, "test": test-data}
"""

def __init__(
self,
merge_config: Dict[Tuple[str, ...], str],
) -> None:
self._merge_config: Dict[Tuple[str, ...], str] = merge_config
self._check_duplicate_desired_splits()

def __call__(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `merge_config`."""
self._check_correct_keys_in_merge_config(dataset)
return self.resplit(dataset)

def resplit(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `merge_config`."""
resplit_dataset = {}
for divided_from__list, divide_to in self._merge_config.items():
datasets_from_list: List[Dataset] = []
for divide_from in divided_from__list:
datasets_from_list.append(dataset[divide_from])
if len(datasets_from_list) > 1:
resplit_dataset[divide_to] = datasets.concatenate_datasets(
datasets_from_list
)
else:
resplit_dataset[divide_to] = datasets_from_list[0]
return datasets.DatasetDict(resplit_dataset)

def _check_correct_keys_in_merge_config(self, dataset: DatasetDict) -> None:
"""Check if the keys in merge_config are existing dataset splits."""
dataset_keys = dataset.keys()
specified_dataset_keys = self._merge_config.keys()
for key_list in specified_dataset_keys:
for key in key_list:
if key not in dataset_keys:
raise ValueError(
f"The given dataset key '{key}' is not present in the given "
f"dataset object. Make sure to use only the keywords that are "
f"available in your dataset."
)

def _check_duplicate_desired_splits(self) -> None:
"""Check for duplicate desired split names."""
desired_splits = list(self._merge_config.values())
duplicates = [
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
item
for item, count in collections.Counter(desired_splits).items()
if count > 1
]
if duplicates:
print(f"Duplicate desired split name '{duplicates[0]}' in resplit strategy")
raise ValueError(
f"Duplicate desired split name '{duplicates[0]}' in resplit strategy"
)
129 changes: 129 additions & 0 deletions datasets/flwr_datasets/merge_resplitter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Resplitter tests."""
import unittest
from typing import Dict, Tuple
adam-narozniak marked this conversation as resolved.
Show resolved Hide resolved

from datasets import Dataset, DatasetDict
from flwr_datasets.merge_resplitter import MergeResplitter


class TestResplitter(unittest.TestCase):
"""Resplitter tests."""

def setUp(self) -> None:
"""Set up the dataset with 3 splits for tests."""
self.dataset_dict = DatasetDict(
{
"train": Dataset.from_dict({"data": [1, 2, 3]}),
"valid": Dataset.from_dict({"data": [4, 5]}),
"test": Dataset.from_dict({"data": [6]}),
}
)

def test_resplitting_train_size(self) -> None:
"""Test if resplitting for just renaming keeps the lengths correct."""
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"}
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["new_train"]), 3)

def test_resplitting_valid_size(self) -> None:
"""Test if resplitting for just renaming keeps the lengths correct."""
strategy: Dict[Tuple[str, ...], str] = {("valid",): "new_valid"}
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["new_valid"]), 2)

def test_resplitting_test_size(self) -> None:
"""Test if resplitting for just renaming keeps the lengths correct."""
strategy: Dict[Tuple[str, ...], str] = {("test",): "new_test"}
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["new_test"]), 1)

def test_resplitting_train_the_same(self) -> None:
"""Test if resplitting for just renaming keeps the dataset the same."""
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"}
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertTrue(
datasets_are_equal(self.dataset_dict["train"], new_dataset["new_train"])
)

def test_combined_train_valid_size(self) -> None:
"""Test if the resplitting that combines the datasets has correct size."""
strategy: Dict[Tuple[str, ...], str] = {
("train", "valid"): "train_valid_combined"
}
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["train_valid_combined"]), 5)

def test_resplitting_test_with_combined_strategy_size(self) -> None:
"""Test if the resplitting that combines the datasets has correct size."""
strategy: Dict[Tuple[str, ...], str] = {
("train", "valid"): "train_valid_combined",
("test",): "test",
}
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["test"]), 1)

def test_invalid_resplit_strategy_exception_message(self) -> None:
"""Test if the resplitting raises error when non-existing split is given."""
strategy: Dict[Tuple[str, ...], str] = {
("invalid_split",): "new_train",
("test",): "new_test",
}
resplitter = MergeResplitter(strategy)
with self.assertRaisesRegex(
ValueError, "The given dataset key 'invalid_split' is not present"
):
resplitter(self.dataset_dict)

def test_nonexistent_split_in_strategy(self) -> None:
"""Test if the exception is raised when the nonexistent split name is given."""
strategy: Dict[Tuple[str, ...], str] = {("nonexistent_split",): "new_split"}
resplitter = MergeResplitter(strategy)
with self.assertRaisesRegex(
ValueError, "The given dataset key 'nonexistent_split' is not present"
):
resplitter(self.dataset_dict)

def test_duplicate_desired_split_name(self) -> None:
"""Test that the new split names are not the same."""
strategy: Dict[Tuple[str, ...], str] = {
("train",): "new_train",
("valid",): "new_train",
}
with self.assertRaisesRegex(
ValueError, "Duplicate desired split name 'new_train' in resplit strategy"
):
_ = MergeResplitter(strategy)

def test_empty_dataset_dict(self) -> None:
"""Test that the error is raised when the empty DatasetDict is given."""
empty_dataset = DatasetDict({})
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"}
resplitter = MergeResplitter(strategy)
with self.assertRaisesRegex(
ValueError, "The given dataset key 'train' is not present"
):
resplitter(empty_dataset)


def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
"""Check if two Datasets have the same values."""
# Check if both datasets have the same length
if len(ds1) != len(ds2):
return False

# Iterate over each row and check for equality
for row1, row2 in zip(ds1, ds2):
if row1 != row2:
return False

return True


if __name__ == "__main__":
unittest.main()