From 73ff3e034646586c6872d44231b2c0ab2e8c1b4a Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:07:58 +0200 Subject: [PATCH] Add 'subset' keyword to FederatedDataset (#2420) --- datasets/flwr_datasets/federated_dataset.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 8d171db2afa4..c4691da397fe 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -35,6 +35,9 @@ class FederatedDataset: ---------- dataset: str The name of the dataset in the Hugging Face Hub. + subset: Optional[str] + Secondary information regarding the dataset, most often subset or version + (it is passed as the `name` argument to datasets.load_dataset). partitioners: Dict[str, Union[Partitioner, int]] A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int` (representing the number of IID partitions that this split should be partitioned @@ -59,10 +62,12 @@ def __init__( self, *, dataset: str, + subset: Optional[str] = None, partitioners: Dict[str, Union[Partitioner, int]], ) -> None: _check_if_dataset_tested(dataset) self._dataset_name: str = dataset + self._subset: Optional[str] = subset self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners ) @@ -121,7 +126,9 @@ def load_full(self, split: str) -> Dataset: def _download_dataset_if_none(self) -> None: """Lazily load (and potentially download) the Dataset instance into memory.""" if self._dataset is None: - self._dataset = datasets.load_dataset(self._dataset_name) + self._dataset = datasets.load_dataset( + path=self._dataset_name, name=self._subset + ) def _check_if_split_present(self, split: str) -> None: """Check if the split (for partitioning or full return) is in the dataset."""