Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable more partitioners specification in more types #2403

Merged
merged 8 commits into from
Oct 12, 2023
15 changes: 11 additions & 4 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""FederatedDataset."""


from typing import Dict, Optional
from typing import Dict, Optional, Union

import datasets
from datasets import Dataset, DatasetDict
Expand All @@ -35,8 +35,10 @@ class FederatedDataset:
----------
dataset: str
The name of the dataset in the Hugging Face Hub.
partitioners: Dict[str, int]
Dataset split to the number of IID partitions.
partitioners: Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
into).

Examples
--------
Expand All @@ -53,7 +55,12 @@ class FederatedDataset:
>>> centralized = mnist_fds.load_full("test")
"""

def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None:
def __init__(
self,
*,
dataset: str,
partitioners: Dict[str, Union[Partitioner, int]],
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
Expand Down
75 changes: 75 additions & 0 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@
# limitations under the License.
# ==============================================================================
"""Federated Dataset tests."""
# pylint: disable=W0212, C0103, C0206


import unittest
from typing import Dict, Union

import pytest
from parameterized import parameterized, parameterized_class

import datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner


@parameterized_class(
Expand Down Expand Up @@ -91,6 +94,78 @@ def test_multiple_partitioners(self) -> None:
)


class PartitionersSpecificationForFederatedDatasets(unittest.TestCase):
    """Test the specifications of partitioners for `FederatedDataset`.

    `FederatedDataset` accepts, for every split, either an already instantiated
    `Partitioner` or an `int`. These tests check that instantiated partitioners
    are stored as the very same objects and that an `int` is turned into an
    `IidPartitioner` with that number of partitions.
    """

    dataset_name = "cifar10"
    test_split = "test"

    def test_dict_of_partitioners_passes_partitioners(self) -> None:
        """Test if partitioners are passed directly (no recreation)."""
        num_train_partitions = 100
        num_test_partitions = 100
        partitioners: Dict[str, Union[Partitioner, int]] = {
            "train": IidPartitioner(num_partitions=num_train_partitions),
            "test": IidPartitioner(num_partitions=num_test_partitions),
        }
        fds = FederatedDataset(
            dataset=self.dataset_name,
            partitioners=partitioners,
        )

        # Identity (not equality) is what "no recreation" means; one subTest per
        # split so that a failure names the offending split.
        for split, partitioner in partitioners.items():
            with self.subTest(split=split):
                self.assertIs(fds._partitioners[split], partitioner)

    def test_dict_str_int_produces_correct_partitioners(self) -> None:
        """Test if dict partitioners have the same keys."""
        num_train_partitions = 100
        num_test_partitions = 100
        fds = FederatedDataset(
            dataset=self.dataset_name,
            partitioners={
                "train": num_train_partitions,
                "test": num_test_partitions,
            },
        )
        # Exactly the two requested splits, nothing more and nothing less.
        self.assertEqual(set(fds._partitioners), {"train", "test"})

    def test_mixed_type_partitioners_passes_instantiated_partitioners(self) -> None:
        """Test if an instantiated partitioner is passed directly."""
        num_train_partitions = 100
        num_test_partitions = 100
        partitioners: Dict[str, Union[Partitioner, int]] = {
            "train": IidPartitioner(num_partitions=num_train_partitions),
            "test": num_test_partitions,
        }
        fds = FederatedDataset(
            dataset=self.dataset_name,
            partitioners=partitioners,
        )
        self.assertIs(fds._partitioners["train"], partitioners["train"])

    def test_mixed_type_partitioners_creates_from_int(self) -> None:
        """Test if an IidPartitioner partitioner is created."""
        num_train_partitions = 100
        num_test_partitions = 100
        partitioners: Dict[str, Union[Partitioner, int]] = {
            "train": IidPartitioner(num_partitions=num_train_partitions),
            "test": num_test_partitions,
        }
        fds = FederatedDataset(
            dataset=self.dataset_name,
            partitioners=partitioners,
        )
        created = fds._partitioners["test"]
        # An explicit isinstance guard (rather than assertIsInstance) also
        # narrows the type for static checkers before the attribute access.
        if not isinstance(created, IidPartitioner):
            self.fail(f"Expected an IidPartitioner, got {type(created)}")
        self.assertEqual(created._num_partitions, num_test_partitions)


class IncorrectUsageFederatedDatasets(unittest.TestCase):
"""Test incorrect usages in FederatedDatasets."""

Expand Down
31 changes: 24 additions & 7 deletions datasets/flwr_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


import warnings
from typing import Dict
from typing import Dict, Union

from flwr_datasets.partitioner import IidPartitioner, Partitioner

Expand All @@ -29,23 +29,40 @@
]


def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partitioner]:
def _instantiate_partitioners(
partitioners: Dict[str, Union[Partitioner, int]]
) -> Dict[str, Partitioner]:
"""Transform the partitioners from the initial format to instantiated objects.

Parameters
----------
partitioners: Dict[str, int]
Partitioners specified as split to the number of partitions format.
partitioners: Dict[str, Union[Partitioner, int]]
Dataset split to the Partitioner or a number of IID partitions.

Returns
-------
partitioners: Dict[str, Partitioner]
Partitioners specified as split to Partitioner object.
"""
instantiated_partitioners: Dict[str, Partitioner] = {}
for split_name, num_partitions in partitioners.items():
instantiated_partitioners[split_name] = IidPartitioner(
num_partitions=num_partitions
if isinstance(partitioners, Dict):
for split, partitioner in partitioners.items():
if isinstance(partitioner, Partitioner):
instantiated_partitioners[split] = partitioner
elif isinstance(partitioner, int):
instantiated_partitioners[split] = IidPartitioner(
num_partitions=partitioner
)
else:
raise ValueError(
f"Incorrect type of the 'partitioners' value encountered. "
f"Expected Partitioner or int. Given {type(partitioner)}"
)
else:
raise ValueError(
f"Incorrect type of the 'partitioners' encountered. "
f"Expected Dict[str, Union[int, Partitioner]]. "
f"Given {type(partitioners)}."
)
return instantiated_partitioners

Expand Down