Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Enable more datasets in FDS #2392

Merged
merged 4 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.utils import _check_if_dataset_supported, _instantiate_partitioners
from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners


class FederatedDataset:
Expand Down Expand Up @@ -54,7 +54,7 @@ class FederatedDataset:
"""

def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None:
_check_if_dataset_supported(dataset)
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
Expand Down
2 changes: 1 addition & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_no_split_in_the_dataset(self) -> None: # pylint: disable=R0201

def test_unsupported_dataset(self) -> None: # pylint: disable=R0201
"""Test creating FederatedDataset for unsupported dataset."""
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
FederatedDataset(dataset="food101", partitioners={"train": 100})


Expand Down
26 changes: 14 additions & 12 deletions datasets/flwr_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@
"""Utils for FederatedDataset."""


import warnings
from typing import Dict

from flwr_datasets.partitioner import IidPartitioner, Partitioner

tested_datasets = [
"mnist",
"cifar10",
"fashion_mnist",
"sasha/dog-food",
"zh-plus/tiny-imagenet",
]


def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partitioner]:
"""Transform the partitioners from the initial format to instantiated objects.
Expand All @@ -41,17 +50,10 @@ def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partiti
return instantiated_partitioners


def _check_if_dataset_supported(dataset: str) -> None:
def _check_if_dataset_tested(dataset: str) -> None:
"""Check if the dataset is in the narrowed down list of the tested datasets."""
supported_datasets = [
"mnist",
"cifar10",
"fashion_mnist",
"sasha/dog-food",
"zh-plus/tiny-imagenet",
]
if dataset not in supported_datasets:
raise ValueError(
f"The currently tested and supported dataset are {supported_datasets}. "
f"Given: {dataset}"
if dataset not in tested_datasets:
warnings.warn(
f"The currently tested dataset are {tested_datasets}. Given: {dataset}.",
stacklevel=1,
)