Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the split keyword optional for load_partition #2423

Merged
merged 9 commits into from
Nov 15, 2023
25 changes: 20 additions & 5 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
# Indicate if the dataset is prepared for `load_partition` or `load_full`
self._dataset_prepared: bool = False

def load_partition(self, idx: int, split: str) -> Dataset:
def load_partition(self, idx: int, split: Optional[str] = None) -> Dataset:
"""Load the partition specified by the idx in the selected split.

The dataset is downloaded only when the first call to `load_partition` or
Expand All @@ -111,18 +111,26 @@ def load_partition(self, idx: int, split: str) -> Dataset:
----------
idx : int
Partition index for the selected split, idx in {0, ..., num_partitions - 1}.
split : str
Name of the (partitioned) split (e.g. "train", "test").
split : Optional[str]
Name of the (partitioned) split (e.g. "train", "test"). You can skip this
parameter if there is only one partitioner for the dataset. The name will be
inferred automatically. For example, if `partitioners={"train": 10}`, you do
not need to provide this argument, but if `partitioners={"train": 10,
"test": 100}`, you need to set it to differentiate which partitioner should
be used.

Returns
-------
partition: Dataset
partition : Dataset
Single partition from the dataset split.
"""
if not self._dataset_prepared:
self._prepare_dataset()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
if split is None:
self._check_if_no_split_keyword_possible()
split = list(self._partitioners.keys())[0]
self._check_if_split_present(split)
self._check_if_split_possible_to_federate(split)
partitioner: Partitioner = self._partitioners[split]
Expand All @@ -142,7 +150,7 @@ def load_full(self, split: str) -> Dataset:

Returns
-------
dataset_split: Dataset
dataset_split : Dataset
Part of the dataset identified by its split name.
"""
if not self._dataset_prepared:
Expand Down Expand Up @@ -214,3 +222,10 @@ def _prepare_dataset(self) -> None:
if self._resplitter:
self._dataset = self._resplitter(self._dataset)
self._dataset_prepared = True

def _check_if_no_split_keyword_possible(self) -> None:
    """Verify that the `split` argument may be left out.

    Omitting `split` is only unambiguous when exactly one partitioner is
    registered in `self._partitioners`.

    Raises
    ------
    ValueError
        If the number of registered partitioners is not exactly one.
    """
    # Exactly one partitioner: the split name can be inferred, nothing to do.
    if len(self._partitioners) == 1:
        return
    message = (
        "Please set the `split` argument. You can only omit the split keyword "
        "if there is exactly one partitioner specified."
    )
    raise ValueError(message)
26 changes: 26 additions & 0 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ def test_multiple_partitioners(self) -> None:
len(dataset[self.test_split]) // num_test_partitions,
)

def test_no_need_for_split_keyword_if_one_partitioner(self) -> None:
    """Check that omitting `split` yields the same partition as passing it.

    With a single partitioner registered for "train", loading partition 0
    without a split name must match loading it with `"train"` given
    explicitly.
    """
    fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
    implicit_split_partition = fds.load_partition(0)
    explicit_split_partition = fds.load_partition(0, "train")
    partitions_match = datasets_are_equal(
        implicit_split_partition,
        explicit_split_partition,
    )
    self.assertTrue(partitions_match)

def test_resplit_dataset_into_one(self) -> None:
"""Test resplit into a single dataset."""
dataset = datasets.load_dataset(self.dataset_name)
Expand Down Expand Up @@ -340,5 +352,19 @@ def test_cannot_use_the_old_split_names(self) -> None:
fds.load_partition(0, "train")


def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
    """Tell whether two datasets contain equal rows in the same order.

    Parameters
    ----------
    ds1 : Dataset
        First dataset to compare.
    ds2 : Dataset
        Second dataset to compare.

    Returns
    -------
    equal : bool
        False if the lengths differ or any pair of rows at the same position
        compares unequal, True otherwise.
    """
    # Different lengths can never be equal; this also keeps `zip` from
    # silently ignoring trailing rows of the longer dataset.
    if len(ds1) != len(ds2):
        return False

    # Stops at the first mismatching pair of rows.
    return not any(left_row != right_row for left_row, right_row in zip(ds1, ds2))


if __name__ == "__main__":
unittest.main()