Skip to content

Commit

Permalink
Enable more partitioners specification in more types (#2403)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
adam-narozniak and danieljanes authored Oct 12, 2023
1 parent 9f02177 commit 3d75c12
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 11 deletions.
15 changes: 11 additions & 4 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""FederatedDataset."""


from typing import Dict, Optional
from typing import Dict, Optional, Union

import datasets
from datasets import Dataset, DatasetDict
Expand All @@ -35,8 +35,10 @@ class FederatedDataset:
----------
dataset: str
The name of the dataset in the Hugging Face Hub.
partitioners: Dict[str, int]
Dataset split to the number of IID partitions.
partitioners: Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
into).
Examples
--------
Expand All @@ -53,7 +55,12 @@ class FederatedDataset:
>>> centralized = mnist_fds.load_full("test")
"""

def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None:
def __init__(
self,
*,
dataset: str,
partitioners: Dict[str, Union[Partitioner, int]],
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
Expand Down
75 changes: 75 additions & 0 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@
# limitations under the License.
# ==============================================================================
"""Federated Dataset tests."""
# pylint: disable=W0212, C0103, C0206


import unittest
from typing import Dict, Union

import pytest
from parameterized import parameterized, parameterized_class

import datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner


@parameterized_class(
Expand Down Expand Up @@ -91,6 +94,78 @@ def test_multiple_partitioners(self) -> None:
)


class PartitionersSpecificationForFederatedDatasets(unittest.TestCase):
"""Test the specifications of partitioners for `FederatedDataset`."""

dataset_name = "cifar10"
test_split = "test"

def test_dict_of_partitioners_passes_partitioners(self) -> None:
"""Test if partitioners are passed directly (no recreation)."""
num_train_partitions = 100
num_test_partitions = 100
partitioners: Dict[str, Union[Partitioner, int]] = {
"train": IidPartitioner(num_partitions=num_train_partitions),
"test": IidPartitioner(num_partitions=num_test_partitions),
}
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners=partitioners,
)

self.assertTrue(
all(fds._partitioners[key] == partitioners[key] for key in partitioners)
)

def test_dict_str_int_produces_correct_partitioners(self) -> None:
"""Test if dict partitioners have the same keys."""
num_train_partitions = 100
num_test_partitions = 100
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={
"train": num_train_partitions,
"test": num_test_partitions,
},
)
self.assertTrue(
len(fds._partitioners) == 2
and "train" in fds._partitioners
and "test" in fds._partitioners
)

def test_mixed_type_partitioners_passes_instantiated_partitioners(self) -> None:
"""Test if an instantiated partitioner is passed directly."""
num_train_partitions = 100
num_test_partitions = 100
partitioners: Dict[str, Union[Partitioner, int]] = {
"train": IidPartitioner(num_partitions=num_train_partitions),
"test": num_test_partitions,
}
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners=partitioners,
)
self.assertIs(fds._partitioners["train"], partitioners["train"])

def test_mixed_type_partitioners_creates_from_int(self) -> None:
"""Test if an IidPartitioner partitioner is created."""
num_train_partitions = 100
num_test_partitions = 100
partitioners: Dict[str, Union[Partitioner, int]] = {
"train": IidPartitioner(num_partitions=num_train_partitions),
"test": num_test_partitions,
}
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners=partitioners,
)
self.assertTrue(
isinstance(fds._partitioners["test"], IidPartitioner)
and fds._partitioners["test"]._num_partitions == num_test_partitions
)


class IncorrectUsageFederatedDatasets(unittest.TestCase):
"""Test incorrect usages in FederatedDatasets."""

Expand Down
31 changes: 24 additions & 7 deletions datasets/flwr_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


import warnings
from typing import Dict
from typing import Dict, Union

from flwr_datasets.partitioner import IidPartitioner, Partitioner

Expand All @@ -29,23 +29,40 @@
]


def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partitioner]:
def _instantiate_partitioners(
partitioners: Dict[str, Union[Partitioner, int]]
) -> Dict[str, Partitioner]:
"""Transform the partitioners from the initial format to instantiated objects.
Parameters
----------
partitioners: Dict[str, int]
Partitioners specified as split to the number of partitions format.
partitioners: Dict[str, Union[Partitioner, int]]
Dataset split to the Partitioner or a number of IID partitions.
Returns
-------
partitioners: Dict[str, Partitioner]
Partitioners specified as split to Partitioner object.
"""
instantiated_partitioners: Dict[str, Partitioner] = {}
for split_name, num_partitions in partitioners.items():
instantiated_partitioners[split_name] = IidPartitioner(
num_partitions=num_partitions
if isinstance(partitioners, Dict):
for split, partitioner in partitioners.items():
if isinstance(partitioner, Partitioner):
instantiated_partitioners[split] = partitioner
elif isinstance(partitioner, int):
instantiated_partitioners[split] = IidPartitioner(
num_partitions=partitioner
)
else:
raise ValueError(
f"Incorrect type of the 'partitioners' value encountered. "
f"Expected Partitioner or int. Given {type(partitioner)}"
)
else:
raise ValueError(
f"Incorrect type of the 'partitioners' encountered. "
f"Expected Dict[str, Union[int, Partitioner]]. "
f"Given {type(partitioners)}."
)
return instantiated_partitioners

Expand Down

0 comments on commit 3d75c12

Please sign in to comment.