Add shuffle seed parameters to FederatedDataset (#2588)
adam-narozniak authored Nov 15, 2023
1 parent d00a579 commit 04e2ea3
Showing 2 changed files with 144 additions and 31 deletions.
86 changes: 56 additions & 30 deletions datasets/flwr_datasets/federated_dataset.py
@@ -38,18 +38,25 @@ class FederatedDataset:
Parameters
----------
dataset: str
dataset : str
The name of the dataset in the Hugging Face Hub.
subset: str
subset : str
Secondary information regarding the dataset, most often a subset or version
(it is passed as the `name` argument to datasets.load_dataset).
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
resplitter : Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
`Callable` that transforms `DatasetDict` splits, or configuration dict for
`MergeResplitter`.
partitioners: Dict[str, Union[Partitioner, int]]
partitioners : Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
into).
shuffle : bool
Whether to randomize the order of samples. Applied prior to resplitting,
separately to each of the splits present in the dataset. It uses the `seed`
argument. Defaults to True.
seed : Optional[int]
Seed used for dataset shuffling. It has no effect if `shuffle` is False. The
seed cannot be changed at a later stage. Defaults to 42.
Examples
--------
@@ -66,13 +73,16 @@ class FederatedDataset:
>>> centralized = mnist_fds.load_full("test")
"""

# pylint: disable=too-many-instance-attributes
def __init__(
self,
*,
dataset: str,
subset: Optional[str] = None,
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
shuffle: bool = True,
seed: Optional[int] = 42,
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
@@ -83,9 +93,13 @@ def __init__(
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
# Init (download) lazily on the first call to `load_partition` or `load_full`
self._shuffle = shuffle
self._seed = seed
# _dataset is prepared lazily on the first call to `load_partition`
# or `load_full`. See `_prepare_dataset` for more details.
self._dataset: Optional[DatasetDict] = None
self._resplit: bool = False # Indicate if the resplit happened
# Indicate if the dataset is prepared for `load_partition` or `load_full`
self._dataset_prepared: bool = False
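The comments above describe the lazy-initialization contract: the constructor only records configuration, and the download/shuffle/resplit pipeline runs on the first `load_*` call. A small illustrative sketch of that behavior:

fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
# Nothing has been downloaded yet; the first load triggers _prepare_dataset().
partition = fds.load_partition(0, "train")
# Subsequent loads reuse the prepared dataset (guarded by _dataset_prepared).
centralized = fds.load_full("test")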

def load_partition(self, idx: int, split: str) -> Dataset:
"""Load the partition specified by the idx in the selected split.
@@ -95,18 +109,18 @@ def load_partition(self, idx: int, split: str) -> Dataset:
Parameters
----------
idx: int
idx : int
Partition index for the selected split, idx in {0, ..., num_partitions - 1}.
split: str
split : str
Name of the (partitioned) split (e.g. "train", "test").
Returns
-------
partition: Dataset
Single partition from the dataset split.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if not self._dataset_prepared:
self._prepare_dataset()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
@@ -123,28 +137,21 @@ def load_full(self, split: str) -> Dataset:
Parameters
----------
split: str
split : str
Split name of the downloaded dataset (e.g. "train", "test").
Returns
-------
dataset_split: Dataset
Part of the dataset identified by its split name.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if not self._dataset_prepared:
self._prepare_dataset()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
return self._dataset[split]
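For orientation, `load_full` returns an entire (already shuffled and, if configured, resplit) split, while `load_partition` returns one shard of it as defined by that split's partitioner. A sketch, assuming MNIST and 100 IID partitions:

fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
centralized = fds.load_full("test")        # the whole test split
client_0 = fds.load_partition(0, "train")  # one of 100 IID train shards
# With MNIST's 60,000 train rows, each of the 100 shards holds ~600 rows.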

def _download_dataset_if_none(self) -> None:
"""Lazily load (and potentially download) the Dataset instance into memory."""
if self._dataset is None:
self._dataset = datasets.load_dataset(
path=self._dataset_name, name=self._subset
)

def _check_if_split_present(self, split: str) -> None:
"""Check if the split (for partitioning or full return) is in the dataset."""
if self._dataset is None:
@@ -176,15 +183,34 @@ def _assign_dataset_to_partitioner(self, split: str) -> None:
if not self._partitioners[split].is_dataset_assigned():
self._partitioners[split].dataset = self._dataset[split]

def _resplit_dataset_if_needed(self) -> None:
# The actual re-splitting can't be done more than once.
# The attribute `_resplit` indicates that the resplit happened.

# Resplit only once
if self._resplit:
return
if self._dataset is None:
raise ValueError("The dataset resplit should happen after the download.")
def _prepare_dataset(self) -> None:
"""Prepare the dataset (prior to partitioning) by download, shuffle, replit.
Run only ONCE when triggered by load_* function. (In future more control whether
this should happen lazily or not can be added). The operations done here should
not happen more than once.
It is controlled by a single flag, `_dataset_prepared` that is set True at the
end of the function.
Notes
-----
The shuffling must happen before the resplitting. Here is why: suppose the
dataset has a non-random order of samples, e.g. each split contains first
only label 0 and then only label 1, and a resplit creates "train" =
train[:int(0.75 * len(train))] and "test" = concat(train[int(0.75 *
len(train)):], test). The new test then absorbs the last 25% of the old
train, which (assuming equal label counts) consists of a single label only.
To avoid such skew, the shuffling therefore happens before the resplitting
(a concrete sketch follows after this method).
"""
self._dataset = datasets.load_dataset(
path=self._dataset_name, name=self._subset
)
if self._shuffle:
# Note: self._dataset is a DatasetDict, e.g. {"train": train_data,
# "test": test_data}; shuffling is applied to each split independently.
self._dataset = self._dataset.shuffle(seed=self._seed)
if self._resplitter:
self._dataset = self._resplitter(self._dataset)
self._resplit = True
self._dataset_prepared = True
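The ordering argument from the Notes can be made concrete. A toy sketch (illustrative, not part of the diff) with a label-sorted train split:

from datasets import Dataset, DatasetDict, concatenate_datasets

# Label-sorted "train": first all label 0, then all label 1.
train = Dataset.from_dict({"x": list(range(8)), "label": [0] * 4 + [1] * 4})
test = Dataset.from_dict({"x": [8, 9], "label": [0, 1]})
ds = DatasetDict({"train": train, "test": test})

def resplit(dataset: DatasetDict) -> DatasetDict:
    # Move the last 25% of train into test.
    cut = int(0.75 * dataset["train"].num_rows)
    moved = dataset["train"].select(range(cut, dataset["train"].num_rows))
    return DatasetDict(
        {
            "train": dataset["train"].select(range(cut)),
            "test": concatenate_datasets([moved, dataset["test"]]),
        }
    )

print(resplit(ds)["test"]["label"])  # [1, 1, 0, 1] - the moved rows are all label 1
# Shuffling each split first removes the single-label skew in the moved rows:
print(resplit(ds.shuffle(seed=42))["test"]["label"])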
89 changes: 88 additions & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
@@ -18,12 +18,13 @@

import unittest
from typing import Dict, Union
from unittest.mock import Mock, patch

import pytest
from parameterized import parameterized, parameterized_class

import datasets
from datasets import DatasetDict, concatenate_datasets
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner

@@ -144,6 +145,92 @@ def resplit(dataset: DatasetDict) -> DatasetDict:
self.assertEqual(len(full), dataset_length)


class ArtificialDatasetTest(unittest.TestCase):
"""Test using small artificial dataset, mocked load_dataset."""

# pylint: disable=no-self-use
def _dummy_setup(self, train_rows: int = 10, test_rows: int = 5) -> DatasetDict:
"""Create a dummy DatasetDict with train, test splits."""
data_train = {
"features": list(range(train_rows)),
"labels": list(range(100, 100 + train_rows)),
}
data_test = {
"features": [200] + [201] * (test_rows - 1),
"labels": [202] + [203] * (test_rows - 1),
}
train_dataset = Dataset.from_dict(data_train)
test_dataset = Dataset.from_dict(data_test)
return DatasetDict({"train": train_dataset, "test": test_dataset})

@patch("datasets.load_dataset")
def test_shuffling_applied(self, mock_func: Mock) -> None:
"""Test if argument is used."""
dummy_ds = self._dummy_setup()
mock_func.return_value = dummy_ds

expected_result = dummy_ds.shuffle(seed=42)["train"]["features"]
fds = FederatedDataset(
dataset="does-not-matter", partitioners={"train": 10}, shuffle=True, seed=42
)
train = fds.load_full("train")
# This should be shuffled
result = train["features"]

self.assertEqual(expected_result, result)

@patch("datasets.load_dataset")
def test_shuffling_not_applied(self, mock_func: Mock) -> None:
"""Test if argument is not used."""
dummy_ds = self._dummy_setup()
mock_func.return_value = dummy_ds

expected_result = dummy_ds["train"]["features"]
fds = FederatedDataset(
dataset="does-not-matter",
partitioners={"train": 10},
shuffle=False,
)
train = fds.load_full("train")
# This should not be shuffled
result = train["features"]

self.assertEqual(expected_result, result)

@patch("datasets.load_dataset")
def test_shuffling_applied_before_resplitting(self, mock_func: Mock) -> None:
"""Check that shuffling happens and that it precedes the resplitting."""

def resplit(dataset: DatasetDict) -> DatasetDict:
# "Move" the last sample from test to train
return DatasetDict(
{
"train": concatenate_datasets(
[dataset["train"], dataset["test"].select([0])]
),
"test": dataset["test"].select(range(1, dataset["test"].num_rows)),
}
)

dummy_ds = self._dummy_setup()
mock_func.return_value = dummy_ds

expected_result = concatenate_datasets(
[dummy_ds["train"].shuffle(42), dummy_ds["test"].shuffle(42).select([0])]
)["features"]
fds = FederatedDataset(
dataset="does-not-matter",
partitioners={"train": 10},
resplitter=resplit,
shuffle=True,
)
train = fds.load_full("train")
# This should be shuffled and then resplit
result = train["features"]

self.assertEqual(expected_result, result)
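A note on the mocking approach used in this class: patching `datasets.load_dataset` works because `FederatedDataset` resolves it as a module attribute (`datasets.load_dataset(...)` inside `_prepare_dataset`), so the patched attribute is what gets called. A minimal, self-contained sketch of the same pattern (names are illustrative):

from unittest.mock import Mock, patch

import datasets
from datasets import Dataset, DatasetDict

@patch("datasets.load_dataset")
def demo(mock_func: Mock) -> None:
    mock_func.return_value = DatasetDict({"train": Dataset.from_dict({"x": [0, 1]})})
    # Resolved through the module attribute, hence intercepted by the patch:
    ds = datasets.load_dataset(path="anything", name=None)
    assert ds["train"]["x"] == [0, 1]

demo()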


class PartitionersSpecificationForFederatedDatasets(unittest.TestCase):
"""Test the specifications of partitioners for `FederatedDataset`."""

