diff --git a/datasets/flwr_datasets/common/__init__.py b/datasets/flwr_datasets/common/__init__.py
new file mode 100644
index 000000000000..a6468bcf7fda
--- /dev/null
+++ b/datasets/flwr_datasets/common/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common components in Flower Datasets."""
+
+
+from .typing import Resplitter
+
+__all__ = ["Resplitter"]
diff --git a/datasets/flwr_datasets/common/typing.py b/datasets/flwr_datasets/common/typing.py
new file mode 100644
index 000000000000..28e6bae4a505
--- /dev/null
+++ b/datasets/flwr_datasets/common/typing.py
@@ -0,0 +1,22 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower Datasets type definitions."""
+
+
+from typing import Callable
+
+from datasets import DatasetDict
+
+Resplitter = Callable[[DatasetDict], DatasetDict]
diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py
index c4691da397fe..dd332107dc74 100644
--- a/datasets/flwr_datasets/federated_dataset.py
+++ b/datasets/flwr_datasets/federated_dataset.py
@@ -15,12 +15,17 @@
 """FederatedDataset."""
 
 
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Tuple, Union
 
 import datasets
 from datasets import Dataset, DatasetDict
 
+from flwr_datasets.common import Resplitter
 from flwr_datasets.partitioner import Partitioner
-from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners
+from flwr_datasets.utils import (
+    _check_if_dataset_tested,
+    _instantiate_partitioners,
+    _instantiate_resplitter_if_needed,
+)
@@ -38,10 +43,13 @@ class FederatedDataset:
     subset: str
         Secondary information regarding the dataset, most often subset or version
        (that is passed to the name in datasets.load_dataset).
+    resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
+        `Callable` that transforms `DatasetDict` splits, or configuration dict for
+        `MergeResplitter`.
     partitioners: Dict[str, Union[Partitioner, int]]
         A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
         (representing the number of IID partitions that this split should be partitioned
-        into).
+        into).
 
     Examples
     --------
@@ -63,16 +71,21 @@ def __init__(
         *,
         dataset: str,
         subset: Optional[str] = None,
+        resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
         partitioners: Dict[str, Union[Partitioner, int]],
     ) -> None:
         _check_if_dataset_tested(dataset)
         self._dataset_name: str = dataset
         self._subset: Optional[str] = subset
+        self._resplitter: Optional[Resplitter] = _instantiate_resplitter_if_needed(
+            resplitter
+        )
         self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
             partitioners
         )
         # Init (download) lazily on the first call to `load_partition` or `load_full`
         self._dataset: Optional[DatasetDict] = None
+        self._resplit: bool = False  # Indicates whether the resplit happened
 
     def load_partition(self, idx: int, split: str) -> Dataset:
         """Load the partition specified by the idx in the selected split.
@@ -93,6 +106,7 @@ def load_partition(self, idx: int, split: str) -> Dataset:
             Single partition from the dataset split.
         """
         self._download_dataset_if_none()
+        self._resplit_dataset_if_needed()
         if self._dataset is None:
             raise ValueError("Dataset is not loaded yet.")
         self._check_if_split_present(split)
@@ -118,6 +132,7 @@ def load_full(self, split: str) -> Dataset:
             Part of the dataset identified by its split name.
         """
         self._download_dataset_if_none()
+        self._resplit_dataset_if_needed()
         if self._dataset is None:
             raise ValueError("Dataset is not loaded yet.")
         self._check_if_split_present(split)
@@ -160,3 +175,16 @@ def _assign_dataset_to_partitioner(self, split: str) -> None:
             raise ValueError("Dataset is not loaded yet.")
         if not self._partitioners[split].is_dataset_assigned():
             self._partitioners[split].dataset = self._dataset[split]
+
+    def _resplit_dataset_if_needed(self) -> None:
+        # Resplitting is performed at most once; the `_resplit` attribute
+        # records whether it has already happened.
+
+        # Skip if the dataset was already resplit
+        if self._resplit:
+            return
+        if self._dataset is None:
+            raise ValueError("The dataset resplit should happen after the download.")
+        if self._resplitter:
+            self._dataset = self._resplitter(self._dataset)
+        self._resplit = True
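
Taken together, the changes above wire resplitting in front of partitioning. A minimal usage sketch follows; the dataset name and partition count are illustrative, not part of this diff, and the import path follows the test file below:

>>> from flwr_datasets.federated_dataset import FederatedDataset
>>>
>>> # A dict value for `resplitter` is turned into a `MergeResplitter`:
>>> # here "train" and "test" are merged into one "full" split, which is
>>> # then partitioned into 100 IID partitions.
>>> fds = FederatedDataset(
>>>     dataset="mnist",
>>>     partitioners={"full": 100},
>>>     resplitter={"full": ("train", "test")},
>>> )
>>> partition = fds.load_partition(0, "full")
>>> # Any `Callable[[DatasetDict], DatasetDict]` may be passed instead of a dict.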
diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py
index 15485dfa7a95..ca2bc97a33ee 100644
--- a/datasets/flwr_datasets/federated_dataset_test.py
+++ b/datasets/flwr_datasets/federated_dataset_test.py
@@ -23,6 +23,7 @@
 from parameterized import parameterized, parameterized_class
 
 import datasets
+from datasets import DatasetDict, concatenate_datasets
 from flwr_datasets.federated_dataset import FederatedDataset
 from flwr_datasets.partitioner import IidPartitioner, Partitioner
 
@@ -93,6 +94,55 @@ def test_multiple_partitioners(self) -> None:
             len(dataset[self.test_split]) // num_test_partitions,
         )
 
+    def test_resplit_dataset_into_one(self) -> None:
+        """Test resplitting into a single dataset split."""
+        dataset = datasets.load_dataset(self.dataset_name)
+        dataset_length = sum([len(ds) for ds in dataset.values()])
+        fds = FederatedDataset(
+            dataset=self.dataset_name,
+            partitioners={"train": 100},
+            resplitter={"full": ("train", self.test_split)},
+        )
+        full = fds.load_full("full")
+        self.assertEqual(dataset_length, len(full))
+
+    # pylint: disable=protected-access
+    def test_resplit_dataset_to_change_names(self) -> None:
+        """Test resplitter that changes the names of the partitions."""
+        fds = FederatedDataset(
+            dataset=self.dataset_name,
+            partitioners={"new_train": 100},
+            resplitter={
+                "new_train": ("train",),
+                "new_" + self.test_split: (self.test_split,),
+            },
+        )
+        _ = fds.load_partition(0, "new_train")
+        assert fds._dataset is not None
+        self.assertEqual(
+            set(fds._dataset.keys()), {"new_train", "new_" + self.test_split}
+        )
+
+    def test_resplit_dataset_by_callable(self) -> None:
+        """Test resplitter passed as a callable that merges all splits."""
+
+        def resplit(dataset: DatasetDict) -> DatasetDict:
+            return DatasetDict(
+                {
+                    "full": concatenate_datasets(
+                        [dataset["train"], dataset[self.test_split]]
+                    )
+                }
+            )
+
+        fds = FederatedDataset(
+            dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit
+        )
+        full = fds.load_full("full")
+        dataset = datasets.load_dataset(self.dataset_name)
+        dataset_length = sum([len(ds) for ds in dataset.values()])
+        self.assertEqual(len(full), dataset_length)
+
 
 class PartitionersSpecificationForFederatedDatasets(unittest.TestCase):
     """Test the specifications of partitioners for `FederatedDataset`."""
@@ -190,6 +240,16 @@ def test_unsupported_dataset(self) -> None:  # pylint: disable=R0201
         with pytest.warns(UserWarning):
             FederatedDataset(dataset="food101", partitioners={"train": 100})
 
+    def test_cannot_use_the_old_split_names(self) -> None:
+        """Test that the original split names cannot be used after resplitting."""
+        fds = FederatedDataset(
+            dataset="mnist",
+            partitioners={"train": 100},
+            resplitter={"full": ("train", "test")},
+        )
+        with self.assertRaises(ValueError):
+            fds.load_partition(0, "train")
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/datasets/flwr_datasets/merge_resplitter.py b/datasets/flwr_datasets/merge_resplitter.py
new file mode 100644
index 000000000000..995b0e8e5602
--- /dev/null
+++ b/datasets/flwr_datasets/merge_resplitter.py
@@ -0,0 +1,110 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MergeResplitter class for Flower Datasets."""
+
+
+import collections
+import warnings
+from functools import reduce
+from typing import Dict, List, Tuple
+
+import datasets
+from datasets import Dataset, DatasetDict
+
+
+class MergeResplitter:
+    """Merge existing splits of the dataset and assign them custom names.
+
+    Create a new `DatasetDict` with new split names corresponding to the merged
+    existing splits (e.g. "train", "valid" and "test").
+
+    Parameters
+    ----------
+    merge_config: Dict[str, Tuple[str, ...]]
+        Dictionary mapping the desired split names to the tuples of the current
+        split names that will be merged together to form them.
+
+    Examples
+    --------
+    Create a new `DatasetDict` with a split named "new_train" that is a merger of
+    the "train" and "valid" splits. Keep the "test" split.
+
+    >>> # Assuming there is a dataset_dict of type `DatasetDict`
+    >>> # dataset_dict is {"train": train-data, "valid": valid-data, "test": test-data}
+    >>> merge_resplitter = MergeResplitter(
+    >>>     merge_config={
+    >>>         "new_train": ("train", "valid"),
+    >>>         "test": ("test", )
+    >>>     }
+    >>> )
+    >>> new_dataset_dict = merge_resplitter(dataset_dict)
+    >>> # new_dataset_dict is
+    >>> # {"new_train": concatenation of train-data and valid-data, "test": test-data}
+    """
+
+    def __init__(
+        self,
+        merge_config: Dict[str, Tuple[str, ...]],
+    ) -> None:
+        self._merge_config: Dict[str, Tuple[str, ...]] = merge_config
+        self._check_duplicate_merge_splits()
+
+    def __call__(self, dataset: DatasetDict) -> DatasetDict:
+        """Resplit the dataset according to the `merge_config`."""
+        self._check_correct_keys_in_merge_config(dataset)
+        return self.resplit(dataset)
+
+    def resplit(self, dataset: DatasetDict) -> DatasetDict:
+        """Resplit the dataset according to the `merge_config`."""
+        resplit_dataset = {}
+        for divide_to, divided_from_list in self._merge_config.items():
+            datasets_from_list: List[Dataset] = []
+            for divide_from in divided_from_list:
+                datasets_from_list.append(dataset[divide_from])
+            if len(datasets_from_list) > 1:
+                resplit_dataset[divide_to] = datasets.concatenate_datasets(
+                    datasets_from_list
+                )
+            else:
+                resplit_dataset[divide_to] = datasets_from_list[0]
+        return datasets.DatasetDict(resplit_dataset)
+
+    def _check_correct_keys_in_merge_config(self, dataset: DatasetDict) -> None:
+        """Check if the keys in merge_config are existing dataset splits."""
+        dataset_keys = dataset.keys()
+        specified_dataset_keys = self._merge_config.values()
+        for key_list in specified_dataset_keys:
+            for key in key_list:
+                if key not in dataset_keys:
+                    raise ValueError(
+                        f"The given dataset key '{key}' is not present in the given "
+                        f"dataset object. Make sure to use only the split names that "
+                        f"are available in your dataset."
+                    )
+
+    def _check_duplicate_merge_splits(self) -> None:
+        """Check if any original split is used in more than one new split."""
+        merge_splits = reduce(lambda x, y: x + y, self._merge_config.values())
+        duplicates = [
+            item
+            for item, count in collections.Counter(merge_splits).items()
+            if count > 1
+        ]
+        if duplicates:
+            warnings.warn(
+                f"More than one desired split uses '{duplicates[0]}' in "
+                f"`merge_config`. Make sure that this is the intended behavior.",
+                stacklevel=1,
+            )
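
One behavior worth noting from `_check_duplicate_merge_splits` above: reusing an original split in more than one new split is allowed but warns at construction time. A small sketch, with illustrative split names:

>>> from flwr_datasets.merge_resplitter import MergeResplitter
>>>
>>> # "train" feeds two new splits, so instantiation emits a UserWarning
>>> merge_resplitter = MergeResplitter(
>>>     merge_config={
>>>         "new_train": ("train", "valid"),
>>>         "extra": ("train",),
>>>     }
>>> )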
diff --git a/datasets/flwr_datasets/merge_resplitter_test.py b/datasets/flwr_datasets/merge_resplitter_test.py
new file mode 100644
index 000000000000..096bd7efac27
--- /dev/null
+++ b/datasets/flwr_datasets/merge_resplitter_test.py
@@ -0,0 +1,145 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Resplitter tests."""
+
+
+import unittest
+from typing import Dict, Tuple
+
+import pytest
+
+from datasets import Dataset, DatasetDict
+from flwr_datasets.merge_resplitter import MergeResplitter
+
+
+class TestResplitter(unittest.TestCase):
+    """Resplitter tests."""
+
+    def setUp(self) -> None:
+        """Set up the dataset with 3 splits for tests."""
+        self.dataset_dict = DatasetDict(
+            {
+                "train": Dataset.from_dict({"data": [1, 2, 3]}),
+                "valid": Dataset.from_dict({"data": [4, 5]}),
+                "test": Dataset.from_dict({"data": [6]}),
+            }
+        )
+
+    def test_resplitting_train_size(self) -> None:
+        """Test if resplitting for just renaming keeps the lengths correct."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)}
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["new_train"]), 3)
+
+    def test_resplitting_valid_size(self) -> None:
+        """Test if resplitting for just renaming keeps the lengths correct."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_valid": ("valid",)}
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["new_valid"]), 2)
+
+    def test_resplitting_test_size(self) -> None:
+        """Test if resplitting for just renaming keeps the lengths correct."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_test": ("test",)}
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["new_test"]), 1)
+
+    def test_resplitting_train_the_same(self) -> None:
+        """Test if resplitting for just renaming keeps the dataset the same."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)}
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertTrue(
+            datasets_are_equal(self.dataset_dict["train"], new_dataset["new_train"])
+        )
+
+    def test_combined_train_valid_size(self) -> None:
+        """Test if the resplitting that combines the datasets has correct size."""
+        strategy: Dict[str, Tuple[str, ...]] = {
+            "train_valid_combined": ("train", "valid")
+        }
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["train_valid_combined"]), 5)
+
+    def test_resplitting_test_with_combined_strategy_size(self) -> None:
+        """Test if the resplitting that combines the datasets has correct size."""
+        strategy: Dict[str, Tuple[str, ...]] = {
+            "train_valid_combined": ("train", "valid"),
+            "test": ("test",),
+        }
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["test"]), 1)
+
+    def test_invalid_resplit_strategy_exception_message(self) -> None:
+        """Test if the resplitting raises error when non-existing split is given."""
+        strategy: Dict[str, Tuple[str, ...]] = {
+            "new_train": ("invalid_split",),
+            "new_test": ("test",),
+        }
+        resplitter = MergeResplitter(strategy)
+        with self.assertRaisesRegex(
+            ValueError, "The given dataset key 'invalid_split' is not present"
+        ):
+            resplitter(self.dataset_dict)
+
+    def test_nonexistent_split_in_strategy(self) -> None:
+        """Test if the exception is raised when a nonexistent split name is given."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_split": ("nonexistent_split",)}
+        resplitter = MergeResplitter(strategy)
+        with self.assertRaisesRegex(
+            ValueError, "The given dataset key 'nonexistent_split' is not present"
+        ):
+            resplitter(self.dataset_dict)
+
+    def test_duplicate_merge_split_name(self) -> None:  # pylint: disable=R0201
+        """Test that a warning is raised when an original split is reused."""
+        strategy: Dict[str, Tuple[str, ...]] = {
+            "new_train": ("train", "valid"),
+            "test": ("train",),
+        }
+        with pytest.warns(UserWarning):
+            _ = MergeResplitter(strategy)
+
+    def test_empty_dataset_dict(self) -> None:
+        """Test that the error is raised when an empty DatasetDict is given."""
+        empty_dataset = DatasetDict({})
+        strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)}
+        resplitter = MergeResplitter(strategy)
+        with self.assertRaisesRegex(
+            ValueError, "The given dataset key 'train' is not present"
+        ):
+            resplitter(empty_dataset)
+
+
+def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
+    """Check if two Datasets have the same values."""
+    # Check if both datasets have the same length
+    if len(ds1) != len(ds2):
+        return False
+
+    # Iterate over each row and check for equality
+    for row1, row2 in zip(ds1, ds2):
+        if row1 != row2:
+            return False
+
+    return True
+
+
+if __name__ == "__main__":
+    unittest.main()
("train", "valid") + } + resplitter = MergeResplitter(strategy) + new_dataset = resplitter(self.dataset_dict) + self.assertEqual(len(new_dataset["train_valid_combined"]), 5) + + def test_resplitting_test_with_combined_strategy_size(self) -> None: + """Test if the resplitting that combines the datasets has correct size.""" + strategy: Dict[str, Tuple[str, ...]] = { + "train_valid_combined": ("train", "valid"), + "test": ("test",), + } + resplitter = MergeResplitter(strategy) + new_dataset = resplitter(self.dataset_dict) + self.assertEqual(len(new_dataset["test"]), 1) + + def test_invalid_resplit_strategy_exception_message(self) -> None: + """Test if the resplitting raises error when non-existing split is given.""" + strategy: Dict[str, Tuple[str, ...]] = { + "new_train": ("invalid_split",), + "new_test": ("test",), + } + resplitter = MergeResplitter(strategy) + with self.assertRaisesRegex( + ValueError, "The given dataset key 'invalid_split' is not present" + ): + resplitter(self.dataset_dict) + + def test_nonexistent_split_in_strategy(self) -> None: + """Test if the exception is raised when the nonexistent split name is given.""" + strategy: Dict[str, Tuple[str, ...]] = {"new_split": ("nonexistent_split",)} + resplitter = MergeResplitter(strategy) + with self.assertRaisesRegex( + ValueError, "The given dataset key 'nonexistent_split' is not present" + ): + resplitter(self.dataset_dict) + + def test_duplicate_merge_split_name(self) -> None: # pylint: disable=R0201 + """Test that the new split names are not the same.""" + strategy: Dict[str, Tuple[str, ...]] = { + "new_train": ("train", "valid"), + "test": ("train",), + } + with pytest.warns(UserWarning): + _ = MergeResplitter(strategy) + + def test_empty_dataset_dict(self) -> None: + """Test that the error is raised when the empty DatasetDict is given.""" + empty_dataset = DatasetDict({}) + strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)} + resplitter = MergeResplitter(strategy) + with self.assertRaisesRegex( + ValueError, "The given dataset key 'train' is not present" + ): + resplitter(empty_dataset) + + +def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool: + """Check if two Datasets have the same values.""" + # Check if both datasets have the same length + if len(ds1) != len(ds2): + return False + + # Iterate over each row and check for equality + for row1, row2 in zip(ds1, ds2): + if row1 != row2: + return False + + return True + + +if __name__ == "__main__": + unittest.main() diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 056a12442d75..7badb1445460 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -16,8 +16,10 @@ import warnings -from typing import Dict, Union +from typing import Dict, Optional, Tuple, Union, cast +from flwr_datasets.common import Resplitter +from flwr_datasets.merge_resplitter import MergeResplitter from flwr_datasets.partitioner import IidPartitioner, Partitioner tested_datasets = [ @@ -67,6 +69,15 @@ def _instantiate_partitioners( return instantiated_partitioners +def _instantiate_resplitter_if_needed( + resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] +) -> Optional[Resplitter]: + """Instantiate `MergeResplitter` if resplitter is merge_config.""" + if resplitter and isinstance(resplitter, Dict): + resplitter = MergeResplitter(merge_config=resplitter) + return cast(Optional[Resplitter], resplitter) + + def _check_if_dataset_tested(dataset: str) -> None: """Check if the dataset is in the narrowed down 
list of the tested datasets.""" if dataset not in tested_datasets:
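
For reference, the expected behavior of `_instantiate_resplitter_if_needed` under its three input forms, sketched from the code above (the merge config and the identity callable are illustrative):

>>> from flwr_datasets.merge_resplitter import MergeResplitter
>>> from flwr_datasets.utils import _instantiate_resplitter_if_needed
>>>
>>> # `None` passes through unchanged.
>>> assert _instantiate_resplitter_if_needed(None) is None
>>>
>>> # A merge-config dict is wrapped in a `MergeResplitter`.
>>> merged = _instantiate_resplitter_if_needed({"full": ("train", "test")})
>>> assert isinstance(merged, MergeResplitter)
>>>
>>> # A callable is returned as-is.
>>> identity = lambda dataset_dict: dataset_dict
>>> assert _instantiate_resplitter_if_needed(identity) is identity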