diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 4dcf3804099c..8d171db2afa4 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -15,7 +15,7 @@ """FederatedDataset.""" -from typing import Dict, Optional +from typing import Dict, Optional, Union import datasets from datasets import Dataset, DatasetDict @@ -35,8 +35,10 @@ class FederatedDataset: ---------- dataset: str The name of the dataset in the Hugging Face Hub. - partitioners: Dict[str, int] - Dataset split to the number of IID partitions. + partitioners: Dict[str, Union[Partitioner, int]] + A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int` + (representing the number of IID partitions that this split should be partitioned + into). Examples -------- @@ -53,7 +55,12 @@ class FederatedDataset: >>> centralized = mnist_fds.load_full("test") """ - def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None: + def __init__( + self, + *, + dataset: str, + partitioners: Dict[str, Union[Partitioner, int]], + ) -> None: _check_if_dataset_tested(dataset) self._dataset_name: str = dataset self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 57483dfba79f..15485dfa7a95 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -13,15 +13,18 @@ # limitations under the License. # ============================================================================== """Federated Dataset tests.""" +# pylint: disable=W0212, C0103, C0206 import unittest +from typing import Dict, Union import pytest from parameterized import parameterized, parameterized_class import datasets from flwr_datasets.federated_dataset import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner, Partitioner @parameterized_class( @@ -91,6 +94,78 @@ def test_multiple_partitioners(self) -> None: ) +class PartitionersSpecificationForFederatedDatasets(unittest.TestCase): + """Test the specifications of partitioners for `FederatedDataset`.""" + + dataset_name = "cifar10" + test_split = "test" + + def test_dict_of_partitioners_passes_partitioners(self) -> None: + """Test if partitioners are passed directly (no recreation).""" + num_train_partitions = 100 + num_test_partitions = 100 + partitioners: Dict[str, Union[Partitioner, int]] = { + "train": IidPartitioner(num_partitions=num_train_partitions), + "test": IidPartitioner(num_partitions=num_test_partitions), + } + fds = FederatedDataset( + dataset=self.dataset_name, + partitioners=partitioners, + ) + + self.assertTrue( + all(fds._partitioners[key] == partitioners[key] for key in partitioners) + ) + + def test_dict_str_int_produces_correct_partitioners(self) -> None: + """Test if dict partitioners have the same keys.""" + num_train_partitions = 100 + num_test_partitions = 100 + fds = FederatedDataset( + dataset=self.dataset_name, + partitioners={ + "train": num_train_partitions, + "test": num_test_partitions, + }, + ) + self.assertTrue( + len(fds._partitioners) == 2 + and "train" in fds._partitioners + and "test" in fds._partitioners + ) + + def test_mixed_type_partitioners_passes_instantiated_partitioners(self) -> None: + """Test if an instantiated partitioner is passed directly.""" + num_train_partitions = 100 + num_test_partitions = 100 + partitioners: Dict[str, Union[Partitioner, int]] = { + "train": IidPartitioner(num_partitions=num_train_partitions), + "test": num_test_partitions, + } + fds = FederatedDataset( + dataset=self.dataset_name, + partitioners=partitioners, + ) + self.assertIs(fds._partitioners["train"], partitioners["train"]) + + def test_mixed_type_partitioners_creates_from_int(self) -> None: + """Test if an IidPartitioner partitioner is created.""" + num_train_partitions = 100 + num_test_partitions = 100 + partitioners: Dict[str, Union[Partitioner, int]] = { + "train": IidPartitioner(num_partitions=num_train_partitions), + "test": num_test_partitions, + } + fds = FederatedDataset( + dataset=self.dataset_name, + partitioners=partitioners, + ) + self.assertTrue( + isinstance(fds._partitioners["test"], IidPartitioner) + and fds._partitioners["test"]._num_partitions == num_test_partitions + ) + + class IncorrectUsageFederatedDatasets(unittest.TestCase): """Test incorrect usages in FederatedDatasets.""" diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 60290c461203..056a12442d75 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -16,7 +16,7 @@ import warnings -from typing import Dict +from typing import Dict, Union from flwr_datasets.partitioner import IidPartitioner, Partitioner @@ -29,13 +29,15 @@ ] -def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partitioner]: +def _instantiate_partitioners( + partitioners: Dict[str, Union[Partitioner, int]] +) -> Dict[str, Partitioner]: """Transform the partitioners from the initial format to instantiated objects. Parameters ---------- - partitioners: Dict[str, int] - Partitioners specified as split to the number of partitions format. + partitioners: Dict[str, Union[Partitioner, int]] + Dataset split to the Partitioner or a number of IID partitions. Returns ------- @@ -43,9 +45,24 @@ def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partiti Partitioners specified as split to Partitioner object. """ instantiated_partitioners: Dict[str, Partitioner] = {} - for split_name, num_partitions in partitioners.items(): - instantiated_partitioners[split_name] = IidPartitioner( - num_partitions=num_partitions + if isinstance(partitioners, Dict): + for split, partitioner in partitioners.items(): + if isinstance(partitioner, Partitioner): + instantiated_partitioners[split] = partitioner + elif isinstance(partitioner, int): + instantiated_partitioners[split] = IidPartitioner( + num_partitions=partitioner + ) + else: + raise ValueError( + f"Incorrect type of the 'partitioners' value encountered. " + f"Expected Partitioner or int. Given {type(partitioner)}" + ) + else: + raise ValueError( + f"Incorrect type of the 'partitioners' encountered. " + f"Expected Dict[str, Union[int, Partitioner]]. " + f"Given {type(partitioners)}." ) return instantiated_partitioners