Skip to content

Commit

Permalink
Add 'subset' keyword to FederatedDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Sep 25, 2023
1 parent dca3102 commit a8fc4cd
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
"""FederatedDataset."""


from typing import Dict, Optional

import datasets
Expand All @@ -35,6 +34,9 @@ class FederatedDataset:
----------
dataset: str
The name of the dataset in the Hugging Face Hub.
subset: Optional[str]
Secondary information regarding the dataset, most often a subset or version
(passed as the `name` argument of datasets.load_dataset). Defaults to None.
partitioners: Dict[str, int]
Mapping from dataset split name to the number of IID partitions for that split.
Expand All @@ -53,9 +55,16 @@ class FederatedDataset:
>>> centralized = mnist_fds.load_full("test")
"""

def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None:
def __init__(
self,
*,
dataset: str,
subset: Optional[str] = None,
partitioners: Dict[str, int],
) -> None:
_check_if_dataset_supported(dataset)
self._dataset_name: str = dataset
self._subset: Optional[str] = subset
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
Expand Down Expand Up @@ -114,7 +123,9 @@ def load_full(self, split: str) -> Dataset:
def _download_dataset_if_none(self) -> None:
"""Lazily load (and potentially download) the Dataset instance into memory."""
if self._dataset is None:
self._dataset = datasets.load_dataset(self._dataset_name)
self._dataset = datasets.load_dataset(
path=self._dataset_name, name=self._subset
)

def _check_if_split_present(self, split: str) -> None:
"""Check if the split (for partitioning or full return) is in the dataset."""
Expand Down

0 comments on commit a8fc4cd

Please sign in to comment.