From 8fbfaf86169e827047707eaf5ae1d81a7b77ce8b Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Mon, 9 Oct 2023 21:08:43 +0200 Subject: [PATCH] Enable more datasets in FDS (#2392) --- datasets/flwr_datasets/federated_dataset.py | 4 +-- .../flwr_datasets/federated_dataset_test.py | 2 +- datasets/flwr_datasets/utils.py | 26 ++++++++++--------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index bfd7ceff70b1..4dcf3804099c 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -20,7 +20,7 @@ import datasets from datasets import Dataset, DatasetDict from flwr_datasets.partitioner import Partitioner -from flwr_datasets.utils import _check_if_dataset_supported, _instantiate_partitioners +from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners class FederatedDataset: @@ -54,7 +54,7 @@ class FederatedDataset: """ def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None: - _check_if_dataset_supported(dataset) + _check_if_dataset_tested(dataset) self._dataset_name: str = dataset self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 467f9d856ed4..57483dfba79f 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -112,7 +112,7 @@ def test_no_split_in_the_dataset(self) -> None: # pylint: disable=R0201 def test_unsupported_dataset(self) -> None: # pylint: disable=R0201 """Test creating FederatedDataset for unsupported dataset.""" - with pytest.raises(ValueError): + with pytest.warns(UserWarning): FederatedDataset(dataset="food101", partitioners={"train": 100}) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 734594b77033..60290c461203 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -15,10 +15,19 @@ """Utils for FederatedDataset.""" +import warnings from typing import Dict from flwr_datasets.partitioner import IidPartitioner, Partitioner +tested_datasets = [ + "mnist", + "cifar10", + "fashion_mnist", + "sasha/dog-food", + "zh-plus/tiny-imagenet", +] + def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partitioner]: """Transform the partitioners from the initial format to instantiated objects. @@ -41,17 +50,10 @@ def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partiti return instantiated_partitioners -def _check_if_dataset_supported(dataset: str) -> None: +def _check_if_dataset_tested(dataset: str) -> None: """Check if the dataset is in the narrowed down list of the tested datasets.""" - supported_datasets = [ - "mnist", - "cifar10", - "fashion_mnist", - "sasha/dog-food", - "zh-plus/tiny-imagenet", - ] - if dataset not in supported_datasets: - raise ValueError( - f"The currently tested and supported dataset are {supported_datasets}. " - f"Given: {dataset}" + if dataset not in tested_datasets: + warnings.warn( + f"The currently tested dataset are {tested_datasets}. Given: {dataset}.", + stacklevel=1, )