From 04e2ea3f085b8eb005f8ec4e4a6879fd52de97ed Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:52:32 +0100 Subject: [PATCH] Add shuffle seed parameters to FederatedDataset (#2588) --- datasets/flwr_datasets/federated_dataset.py | 86 +++++++++++------- .../flwr_datasets/federated_dataset_test.py | 89 ++++++++++++++++++- 2 files changed, 144 insertions(+), 31 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 6e4e9ca43ccb..b1e61f1f9231 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -38,18 +38,25 @@ class FederatedDataset: Parameters ---------- - dataset: str + dataset : str The name of the dataset in the Hugging Face Hub. - subset: str + subset : str Secondary information regarding the dataset, most often subset or version (that is passed to the name in datasets.load_dataset). - resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] + resplitter : Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] `Callable` that transforms `DatasetDict` splits, or configuration dict for `MergeResplitter`. - partitioners: Dict[str, Union[Partitioner, int]] + partitioners : Dict[str, Union[Partitioner, int]] A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int` (representing the number of IID partitions that this split should be partitioned into). + shuffle : bool + Whether to randomize the order of samples. Applied prior to resplitting, + speratelly to each of the present splits in the dataset. It uses the `seed` + argument. Defaults to True. + seed : Optional[int] + Seed used for dataset shuffling. It has no effect if `shuffle` is False. The + seed cannot be set in the later stages. Examples -------- @@ -66,6 +73,7 @@ class FederatedDataset: >>> centralized = mnist_fds.load_full("test") """ + # pylint: disable=too-many-instance-attributes def __init__( self, *, @@ -73,6 +81,8 @@ def __init__( subset: Optional[str] = None, resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None, partitioners: Dict[str, Union[Partitioner, int]], + shuffle: bool = True, + seed: Optional[int] = 42, ) -> None: _check_if_dataset_tested(dataset) self._dataset_name: str = dataset @@ -83,9 +93,13 @@ def __init__( self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners ) - # Init (download) lazily on the first call to `load_partition` or `load_full` + self._shuffle = shuffle + self._seed = seed + # _dataset is prepared lazily on the first call to `load_partition` + # or `load_full`. See _prepare_datasets for more details self._dataset: Optional[DatasetDict] = None - self._resplit: bool = False # Indicate if the resplit happened + # Indicate if the dataset is prepared for `load_partition` or `load_full` + self._dataset_prepared: bool = False def load_partition(self, idx: int, split: str) -> Dataset: """Load the partition specified by the idx in the selected split. @@ -95,9 +109,9 @@ def load_partition(self, idx: int, split: str) -> Dataset: Parameters ---------- - idx: int + idx : int Partition index for the selected split, idx in {0, ..., num_partitions - 1}. - split: str + split : str Name of the (partitioned) split (e.g. "train", "test"). Returns @@ -105,8 +119,8 @@ def load_partition(self, idx: int, split: str) -> Dataset: partition: Dataset Single partition from the dataset split. """ - self._download_dataset_if_none() - self._resplit_dataset_if_needed() + if not self._dataset_prepared: + self._prepare_dataset() if self._dataset is None: raise ValueError("Dataset is not loaded yet.") self._check_if_split_present(split) @@ -123,7 +137,7 @@ def load_full(self, split: str) -> Dataset: Parameters ---------- - split: str + split : str Split name of the downloaded dataset (e.g. "train", "test"). Returns @@ -131,20 +145,13 @@ def load_full(self, split: str) -> Dataset: dataset_split: Dataset Part of the dataset identified by its split name. """ - self._download_dataset_if_none() - self._resplit_dataset_if_needed() + if not self._dataset_prepared: + self._prepare_dataset() if self._dataset is None: raise ValueError("Dataset is not loaded yet.") self._check_if_split_present(split) return self._dataset[split] - def _download_dataset_if_none(self) -> None: - """Lazily load (and potentially download) the Dataset instance into memory.""" - if self._dataset is None: - self._dataset = datasets.load_dataset( - path=self._dataset_name, name=self._subset - ) - def _check_if_split_present(self, split: str) -> None: """Check if the split (for partitioning or full return) is in the dataset.""" if self._dataset is None: @@ -176,15 +183,34 @@ def _assign_dataset_to_partitioner(self, split: str) -> None: if not self._partitioners[split].is_dataset_assigned(): self._partitioners[split].dataset = self._dataset[split] - def _resplit_dataset_if_needed(self) -> None: - # The actual re-splitting can't be done more than once. - # The attribute `_resplit` indicates that the resplit happened. - - # Resplit only once - if self._resplit: - return - if self._dataset is None: - raise ValueError("The dataset resplit should happen after the download.") + def _prepare_dataset(self) -> None: + """Prepare the dataset (prior to partitioning) by download, shuffle, replit. + + Run only ONCE when triggered by load_* function. (In future more control whether + this should happen lazily or not can be added). The operations done here should + not happen more than once. + + It is controlled by a single flag, `_dataset_prepared` that is set True at the + end of the function. + + Notes + ----- + The shuffling should happen before the resplitting. Here is the explanation. + If the dataset has a non-random order of samples e.g. each split has first + only label 0, then only label 1. Then in case of resplitting e.g. + someone creates: "train" train[:int(0.75 * len(train))], test: concat( + train[int(0.75 * len(train)):], test). The new test took the 0.25 of e.g. + the train that is only label 0 (assuming the equal count of labels). + Therefore, for such edge cases (for which we have split) the split should + happen before the resplitting. + """ + self._dataset = datasets.load_dataset( + path=self._dataset_name, name=self._subset + ) + if self._shuffle: + # Note it shuffles all the splits. The self._dataset is DatasetDict + # so e.g. {"train": train_data, "test": test_data}. All splits get shuffled. + self._dataset = self._dataset.shuffle(seed=self._seed) if self._resplitter: self._dataset = self._resplitter(self._dataset) - self._resplit = True + self._dataset_prepared = True diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index ca2bc97a33ee..1e36fd565d06 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -18,12 +18,13 @@ import unittest from typing import Dict, Union +from unittest.mock import Mock, patch import pytest from parameterized import parameterized, parameterized_class import datasets -from datasets import DatasetDict, concatenate_datasets +from datasets import Dataset, DatasetDict, concatenate_datasets from flwr_datasets.federated_dataset import FederatedDataset from flwr_datasets.partitioner import IidPartitioner, Partitioner @@ -144,6 +145,92 @@ def resplit(dataset: DatasetDict) -> DatasetDict: self.assertEqual(len(full), dataset_length) +class ArtificialDatasetTest(unittest.TestCase): + """Test using small artificial dataset, mocked load_dataset.""" + + # pylint: disable=no-self-use + def _dummy_setup(self, train_rows: int = 10, test_rows: int = 5) -> DatasetDict: + """Create a dummy DatasetDict with train, test splits.""" + data_train = { + "features": list(range(train_rows)), + "labels": list(range(100, 100 + train_rows)), + } + data_test = { + "features": [200] + [201] * (test_rows - 1), + "labels": [202] + [203] * (test_rows - 1), + } + train_dataset = Dataset.from_dict(data_train) + test_dataset = Dataset.from_dict(data_test) + return DatasetDict({"train": train_dataset, "test": test_dataset}) + + @patch("datasets.load_dataset") + def test_shuffling_applied(self, mock_func: Mock) -> None: + """Test if argument is used.""" + dummy_ds = self._dummy_setup() + mock_func.return_value = dummy_ds + + expected_result = dummy_ds.shuffle(seed=42)["train"]["features"] + fds = FederatedDataset( + dataset="does-not-matter", partitioners={"train": 10}, shuffle=True, seed=42 + ) + train = fds.load_full("train") + # This should be shuffled + result = train["features"] + + self.assertEqual(expected_result, result) + + @patch("datasets.load_dataset") + def test_shuffling_not_applied(self, mock_func: Mock) -> None: + """Test if argument is not used.""" + dummy_ds = self._dummy_setup() + mock_func.return_value = dummy_ds + + expected_result = dummy_ds["train"]["features"] + fds = FederatedDataset( + dataset="does-not-matter", + partitioners={"train": 10}, + shuffle=False, + ) + train = fds.load_full("train") + # This should not be shuffled + result = train["features"] + + self.assertEqual(expected_result, result) + + @patch("datasets.load_dataset") + def test_shuffling_before_to_resplitting_applied(self, mock_func: Mock) -> None: + """Check if the order is met and if the shuffling happens.""" + + def resplit(dataset: DatasetDict) -> DatasetDict: + # "Move" the last sample from test to train + return DatasetDict( + { + "train": concatenate_datasets( + [dataset["train"], dataset["test"].select([0])] + ), + "test": dataset["test"].select(range(1, dataset["test"].num_rows)), + } + ) + + dummy_ds = self._dummy_setup() + mock_func.return_value = dummy_ds + + expected_result = concatenate_datasets( + [dummy_ds["train"].shuffle(42), dummy_ds["test"].shuffle(42).select([0])] + )["features"] + fds = FederatedDataset( + dataset="does-not-matter", + partitioners={"train": 10}, + resplitter=resplit, + shuffle=True, + ) + train = fds.load_full("train") + # This should not be shuffled + result = train["features"] + + self.assertEqual(expected_result, result) + + class PartitionersSpecificationForFederatedDatasets(unittest.TestCase): """Test the specifications of partitioners for `FederatedDataset`."""