Skip to content

Commit

Permalink
Merge branch 'main' into fedmeta
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Oct 16, 2023
2 parents deda708 + 73ff3e0 commit c3fdd01
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 8 additions & 1 deletion datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class FederatedDataset:
----------
dataset: str
The name of the dataset in the Hugging Face Hub.
subset: str
Secondary information regarding the dataset, most often subset or version
(that is passed to the name in datasets.load_dataset).
partitioners: Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
Expand All @@ -59,10 +62,12 @@ def __init__(
self,
*,
dataset: str,
subset: Optional[str] = None,
partitioners: Dict[str, Union[Partitioner, int]],
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._subset: Optional[str] = subset
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
Expand Down Expand Up @@ -121,7 +126,9 @@ def load_full(self, split: str) -> Dataset:
def _download_dataset_if_none(self) -> None:
"""Lazily load (and potentially download) the Dataset instance into memory."""
if self._dataset is None:
self._dataset = datasets.load_dataset(self._dataset_name)
self._dataset = datasets.load_dataset(
path=self._dataset_name, name=self._subset
)

def _check_if_split_present(self, split: str) -> None:
"""Check if the split (for partitioning or full return) is in the dataset."""
Expand Down
2 changes: 1 addition & 1 deletion e2e/fastai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ authors = ["The Flower Authors <[email protected]>"]
python = ">=3.8,<3.10"
flwr = { path = "../../", develop = true, extras = ["simulation"] }
fastai = "^2.7.12"
torch = ">=2.0.0, !=2.0.1"
torch = ">=2.0.0, !=2.0.1, < 2.1.0"

0 comments on commit c3fdd01

Please sign in to comment.