Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the split keyword optional for load_partition #2423

Merged
merged 9 commits into from
Nov 15, 2023
20 changes: 17 additions & 3 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
self._dataset: Optional[DatasetDict] = None
self._resplit: bool = False # Indicate if the resplit happened

def load_partition(self, idx: int, split: str) -> Dataset:
def load_partition(self, idx: int, split: Optional[str] = None) -> Dataset:
"""Load the partition specified by the idx in the selected split.

The dataset is downloaded only when the first call to `load_partition` or
Expand All @@ -97,8 +97,12 @@ def load_partition(self, idx: int, split: str) -> Dataset:
----------
idx: int
Partition index for the selected split, idx in {0, ..., num_partitions - 1}.
split: str
Name of the (partitioned) split (e.g. "train", "test").
split: Optional[str]
Name of the (partitioned) split (e.g. "train", "test"). You can skip this
parameter if there is only one partitioner for the dataset. The name will be
inferred automatically. (e.g. if partitioners={"train": 10} you do not need
to give this parameter, but if partitioners={"train": 10, "test": 100} you
need to give it, to differentiate which partitioner should be used).
adam-narozniak marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Expand All @@ -109,6 +113,9 @@ def load_partition(self, idx: int, split: str) -> Dataset:
self._resplit_dataset_if_needed()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
if split is None:
self._check_if_no_split_keyword_possible()
split = list(self._partitioners.keys())[0]
self._check_if_split_present(split)
self._check_if_split_possible_to_federate(split)
partitioner: Partitioner = self._partitioners[split]
Expand Down Expand Up @@ -188,3 +195,10 @@ def _resplit_dataset_if_needed(self) -> None:
if self._resplitter:
self._dataset = self._resplitter(self._dataset)
self._resplit = True

def _check_if_no_split_keyword_possible(self) -> None:
if len(self._partitioners) != 1:
raise ValueError(
"Please give the split argument. You can omit the split keyword only if"
" there is a single partitioner specified."
adam-narozniak marked this conversation as resolved.
Show resolved Hide resolved
)
28 changes: 27 additions & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from parameterized import parameterized, parameterized_class

import datasets
from datasets import DatasetDict, concatenate_datasets
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner

Expand Down Expand Up @@ -94,6 +94,18 @@ def test_multiple_partitioners(self) -> None:
len(dataset[self.test_split]) // num_test_partitions,
)

def test_no_need_for_split_keyword_if_one_partitioner(self) -> None:
    """Test if partitions got with and without split args are the same."""
    # With a single partitioner, load_partition should infer the split name.
    fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
    implicit_partition = fds.load_partition(0)
    explicit_partition = fds.load_partition(0, "train")
    self.assertTrue(datasets_are_equal(implicit_partition, explicit_partition))

def test_resplit_dataset_into_one(self) -> None:
"""Test resplit into a single dataset."""
dataset = datasets.load_dataset(self.dataset_name)
Expand Down Expand Up @@ -253,5 +265,19 @@ def test_cannot_use_the_old_split_names(self) -> None:
fds.load_partition(0, "train")


def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
    """Check if two Datasets have the same values."""
    # Differing lengths can never be equal; zip below would silently truncate.
    if len(ds1) != len(ds2):
        return False
    # Compare row by row; short-circuits on the first mismatch.
    return all(row1 == row2 for row1, row2 in zip(ds1, ds2))


# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()