From f0532abf2026b392fb2e24ac7f3c5131246681ca Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 26 Sep 2023 12:35:59 +0200 Subject: [PATCH 1/5] Make the split keyword optional for load_partition --- datasets/flwr_datasets/federated_dataset.py | 20 +++++++++++--- .../flwr_datasets/federated_dataset_test.py | 27 +++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index bfd7ceff70b1..0dd4580517d8 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -62,7 +62,7 @@ def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None: # Init (download) lazily on the first call to `load_partition` or `load_full` self._dataset: Optional[DatasetDict] = None - def load_partition(self, idx: int, split: str) -> Dataset: + def load_partition(self, idx: int, split: Optional[str] = None) -> Dataset: """Load the partition specified by the idx in the selected split. The dataset is downloaded only when the first call to `load_partition` or @@ -72,8 +72,12 @@ def load_partition(self, idx: int, split: str) -> Dataset: ---------- idx: int Partition index for the selected split, idx in {0, ..., num_partitions - 1}. - split: str - Name of the (partitioned) split (e.g. "train", "test"). + split: Optional[str] + Name of the (partitioned) split (e.g. "train", "test"). You can skip this + parameter if there is only one partitioner for the dataset. The name will be + inferred automatically. (e.g. if partitioners={"train": 10} you do not need + to give this parameter, but if partitioners={"train": 10, "test": 100} you + need to give it, to differentiate which partitioner should be used. Returns ------- @@ -83,6 +87,9 @@ def load_partition(self, idx: int, split: str) -> Dataset: self._download_dataset_if_none() if self._dataset is None: raise ValueError("Dataset is not loaded yet.") + if split is None: + self._check_if_no_split_keyword_possible() + split = list(self._partitioners.keys())[0] self._check_if_split_present(split) self._check_if_split_possible_to_federate(split) partitioner: Partitioner = self._partitioners[split] @@ -146,3 +153,10 @@ def _assign_dataset_to_partitioner(self, split: str) -> None: raise ValueError("Dataset is not loaded yet.") if not self._partitioners[split].is_dataset_assigned(): self._partitioners[split].dataset = self._dataset[split] + + def _check_if_no_split_keyword_possible(self) -> None: + if len(self._partitioners) != 1: + raise ValueError( + "Please give the split argument. You can omit the split keyword only if" + " there is a single partitioner specified." + ) diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 467f9d856ed4..a29faaaf9132 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -21,6 +21,7 @@ from parameterized import parameterized, parameterized_class import datasets +from datasets import Dataset from flwr_datasets.federated_dataset import FederatedDataset @@ -90,6 +91,18 @@ def test_multiple_partitioners(self) -> None: len(dataset[self.test_split]) // num_test_partitions, ) + def test_no_need_for_split_keyword_if_one_partitioner(self) -> None: + """Test if partitions got with and without split args are the same.""" + fds = FederatedDataset(dataset="mnist", partitioners={"train": 10}) + partition_loaded_with_no_split_arg = fds.load_partition(0) + partition_loaded_with_verbose_split_arg = fds.load_partition(0, "train") + self.assertTrue( + datasets_are_equal( + partition_loaded_with_no_split_arg, + partition_loaded_with_verbose_split_arg, + ) + ) + class IncorrectUsageFederatedDatasets(unittest.TestCase): """Test incorrect usages in FederatedDatasets.""" @@ -116,5 +129,19 @@ def test_unsupported_dataset(self) -> None: # pylint: disable=R0201 FederatedDataset(dataset="food101", partitioners={"train": 100}) +def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool: + """Check if two Datasets have the same values.""" + # Check if both datasets have the same length + if len(ds1) != len(ds2): + return False + + # Iterate over each row and check for equality + for row1, row2 in zip(ds1, ds2): + if row1 != row2: + return False + + return True + + if __name__ == "__main__": unittest.main() From 235ebe30ffafe7001d03b2d54ca47a6eb4d4cafe Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Wed, 15 Nov 2023 11:10:21 +0100 Subject: [PATCH 2/5] Apply suggestions from code review Co-authored-by: Daniel J. Beutel --- datasets/flwr_datasets/federated_dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 60125d8bc9bc..de04fd2f5a9a 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -100,9 +100,9 @@ def load_partition(self, idx: int, split: Optional[str] = None) -> Dataset: split: Optional[str] Name of the (partitioned) split (e.g. "train", "test"). You can skip this parameter if there is only one partitioner for the dataset. The name will be - inferred automatically. (e.g. if partitioners={"train": 10} you do not need - to give this parameter, but if partitioners={"train": 10, "test": 100} you - need to give it, to differentiate which partitioner should be used. + inferred automatically. For example, if `partitioners={"train": 10}`, you do not need + to provide this argument, but if `partitioners={"train": 10, "test": 100}`, you + need to set it to differentiate which partitioner should be used. Returns ------- @@ -199,6 +199,6 @@ def _resplit_dataset_if_needed(self) -> None: def _check_if_no_split_keyword_possible(self) -> None: if len(self._partitioners) != 1: raise ValueError( - "Please give the split argument. You can omit the split keyword only if" - " there is a single partitioner specified." + "Please set the `split` argument. You can only omit the split keyword if" + " there is exactly one partitioner specified." ) From fdf47a35a72fb5d96e7017da7db6674d3d6c8d16 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Wed, 15 Nov 2023 11:15:15 +0100 Subject: [PATCH 3/5] Fix formatting --- datasets/flwr_datasets/federated_dataset.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index de04fd2f5a9a..56881a003a3a 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -100,9 +100,10 @@ def load_partition(self, idx: int, split: Optional[str] = None) -> Dataset: split: Optional[str] Name of the (partitioned) split (e.g. "train", "test"). You can skip this parameter if there is only one partitioner for the dataset. The name will be - inferred automatically. For example, if `partitioners={"train": 10}`, you do not need - to provide this argument, but if `partitioners={"train": 10, "test": 100}`, you - need to set it to differentiate which partitioner should be used. + inferred automatically. For example, if `partitioners={"train": 10}`, you do + not need to provide this argument, but if `partitioners={"train": 10, + "test": 100}`, you need to set it to differentiate which partitioner should + be used. Returns ------- @@ -199,6 +200,6 @@ def _resplit_dataset_if_needed(self) -> None: def _check_if_no_split_keyword_possible(self) -> None: if len(self._partitioners) != 1: raise ValueError( - "Please set the `split` argument. You can only omit the split keyword if" - " there is exactly one partitioner specified." + "Please set the `split` argument. You can only omit the split keyword " + "if there is exactly one partitioner specified." ) From cc3c6d55c1460a85c46c6e884516203e9ad22fec Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Wed, 15 Nov 2023 14:02:34 +0100 Subject: [PATCH 4/5] Fix spacing around colons in params in docs --- datasets/flwr_datasets/federated_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 33cf6d0db611..4afd51e57ba8 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -111,7 +111,7 @@ def load_partition(self, idx: int, split: Optional[str] = None) -> Dataset: ---------- idx : int Partition index for the selected split, idx in {0, ..., num_partitions - 1}. - split: Optional[str] + split : Optional[str] Name of the (partitioned) split (e.g. "train", "test"). You can skip this parameter if there is only one partitioner for the dataset. The name will be inferred automatically. For example, if `partitioners={"train": 10}`, you do @@ -121,7 +121,7 @@ def load_partition(self, idx: int, split: Optional[str] = None) -> Dataset: Returns ------- - partition: Dataset + partition : Dataset Single partition from the dataset split. """ if not self._dataset_prepared: From e2b897125ca5db0e1f5309fca6d653fdcadf2473 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Wed, 15 Nov 2023 14:03:12 +0100 Subject: [PATCH 5/5] Fix spacing around colons in params in docs --- datasets/flwr_datasets/federated_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 4afd51e57ba8..52acc4d40100 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -150,7 +150,7 @@ def load_full(self, split: str) -> Dataset: Returns ------- - dataset_split: Dataset + dataset_split : Dataset Part of the dataset identified by its split name. """ if not self._dataset_prepared: