Skip to content

Commit

Permalink
Enable more datasets in FDS (#2392)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak authored Oct 9, 2023
1 parent 00f0d77 commit 8fbfaf8
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
4 changes: 2 additions & 2 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.utils import _check_if_dataset_supported, _instantiate_partitioners
from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners


class FederatedDataset:
Expand Down Expand Up @@ -54,7 +54,7 @@ class FederatedDataset:
"""

def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None:
_check_if_dataset_supported(dataset)
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
Expand Down
2 changes: 1 addition & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_no_split_in_the_dataset(self) -> None: # pylint: disable=R0201

def test_unsupported_dataset(self) -> None: # pylint: disable=R0201
"""Test creating FederatedDataset for unsupported dataset."""
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
FederatedDataset(dataset="food101", partitioners={"train": 100})


Expand Down
26 changes: 14 additions & 12 deletions datasets/flwr_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@
"""Utils for FederatedDataset."""


import warnings
from typing import Dict

from flwr_datasets.partitioner import IidPartitioner, Partitioner

tested_datasets = [
"mnist",
"cifar10",
"fashion_mnist",
"sasha/dog-food",
"zh-plus/tiny-imagenet",
]


def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partitioner]:
"""Transform the partitioners from the initial format to instantiated objects.
Expand All @@ -41,17 +50,10 @@ def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partiti
return instantiated_partitioners


def _check_if_dataset_supported(dataset: str) -> None:
def _check_if_dataset_tested(dataset: str) -> None:
"""Check if the dataset is in the narrowed down list of the tested datasets."""
supported_datasets = [
"mnist",
"cifar10",
"fashion_mnist",
"sasha/dog-food",
"zh-plus/tiny-imagenet",
]
if dataset not in supported_datasets:
raise ValueError(
f"The currently tested and supported dataset are {supported_datasets}. "
f"Given: {dataset}"
if dataset not in tested_datasets:
warnings.warn(
f"The currently tested dataset are {tested_datasets}. Given: {dataset}.",
stacklevel=1,
)

0 comments on commit 8fbfaf8

Please sign in to comment.