Commit

Merge branch 'main' into fedmeta

jafermarq authored Oct 16, 2023
2 parents 986c86c + 7f8a7e2 commit 0b9f5f6
Showing 7 changed files with 402 additions and 4 deletions.
20 changes: 20 additions & 0 deletions datasets/flwr_datasets/common/__init__.py
@@ -0,0 +1,20 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common components in Flower Datasets."""


from .typing import Resplitter

__all__ = ["Resplitter"]
22 changes: 22 additions & 0 deletions datasets/flwr_datasets/common/typing.py
@@ -0,0 +1,22 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower Datasets type definitions."""


from typing import Callable

from datasets import DatasetDict

Resplitter = Callable[[DatasetDict], DatasetDict]
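
Any callable that maps a `DatasetDict` to a new `DatasetDict` satisfies this alias, so custom resplit logic needs no special base class. A minimal sketch (the function name and split names below are illustrative, not part of this commit):

from datasets import DatasetDict

def rename_test_to_eval(dataset: DatasetDict) -> DatasetDict:
    # Keep "train" as-is and expose the "test" split under the name "eval".
    return DatasetDict({"train": dataset["train"], "eval": dataset["test"]})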
34 changes: 31 additions & 3 deletions datasets/flwr_datasets/federated_dataset.py
@@ -15,12 +15,17 @@
"""FederatedDataset."""


from typing import Dict, Optional, Union
from typing import Dict, Optional, Tuple, Union

import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.common import Resplitter
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners
from flwr_datasets.utils import (
_check_if_dataset_tested,
_instantiate_partitioners,
_instantiate_resplitter_if_needed,
)


class FederatedDataset:
@@ -38,10 +43,13 @@ class FederatedDataset:
subset: str
Secondary information regarding the dataset, most often subset or version
(that is passed to the name in datasets.load_dataset).
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
`Callable` that transforms `DatasetDict` splits, or configuration dict for
`MergeResplitter`.
partitioners: Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
into).
Examples
--------
@@ -63,16 +71,21 @@ def __init__(
*,
dataset: str,
subset: Optional[str] = None,
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._subset: Optional[str] = subset
self._resplitter: Optional[Resplitter] = _instantiate_resplitter_if_needed(
resplitter
)
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
# Init (download) lazily on the first call to `load_partition` or `load_full`
self._dataset: Optional[DatasetDict] = None
self._resplit: bool = False # Indicate if the resplit happened

def load_partition(self, idx: int, split: str) -> Dataset:
"""Load the partition specified by the idx in the selected split.
@@ -93,6 +106,7 @@ def load_partition(self, idx: int, split: str) -> Dataset:
Single partition from the dataset split.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
@@ -118,6 +132,7 @@ def load_full(self, split: str) -> Dataset:
Part of the dataset identified by its split name.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
@@ -160,3 +175,16 @@ def _assign_dataset_to_partitioner(self, split: str) -> None:
raise ValueError("Dataset is not loaded yet.")
if not self._partitioners[split].is_dataset_assigned():
self._partitioners[split].dataset = self._dataset[split]

def _resplit_dataset_if_needed(self) -> None:
# Resplit only once; the `_resplit` attribute records that it already happened.
if self._resplit:
return
if self._dataset is None:
raise ValueError("The dataset resplit should happen after the download.")
if self._resplitter:
self._dataset = self._resplitter(self._dataset)
self._resplit = True
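
The dict form of `resplitter` is converted by `_instantiate_resplitter_if_needed` into a `MergeResplitter` (shown further below). A minimal usage sketch, mirroring the tests in `federated_dataset_test.py` and assuming the "mnist" dataset with its standard "train" and "test" splits:

from flwr_datasets.federated_dataset import FederatedDataset

# Merge "train" and "test" into a single new "full" split, then load it.
# The download and the one-time resplit both happen lazily on this call.
fds = FederatedDataset(
    dataset="mnist",
    partitioners={"train": 100},
    resplitter={"full": ("train", "test")},
)
full = fds.load_full("full")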
62 changes: 62 additions & 0 deletions datasets/flwr_datasets/federated_dataset_test.py
@@ -23,6 +23,7 @@
from parameterized import parameterized, parameterized_class

import datasets
from datasets import DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner

@@ -93,6 +94,55 @@ def test_multiple_partitioners(self) -> None:
len(dataset[self.test_split]) // num_test_partitions,
)

def test_resplit_dataset_into_one(self) -> None:
"""Test resplit into a single dataset."""
dataset = datasets.load_dataset(self.dataset_name)
dataset_length = sum([len(ds) for ds in dataset.values()])
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"train": 100},
resplitter={"full": ("train", self.test_split)},
)
full = fds.load_full("full")
self.assertEqual(dataset_length, len(full))

# pylint: disable=protected-access
def test_resplit_dataset_to_change_names(self) -> None:
"""Test resplitter to change the names of the partitions."""
fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"new_train": 100},
resplitter={
"new_train": ("train",),
"new_" + self.test_split: (self.test_split,),
},
)
_ = fds.load_partition(0, "new_train")
assert fds._dataset is not None
self.assertEqual(
set(fds._dataset.keys()), {"new_train", "new_" + self.test_split}
)

def test_resplit_dataset_by_callable(self) -> None:
"""Test resplitter to change the names of the partitions."""

def resplit(dataset: DatasetDict) -> DatasetDict:
return DatasetDict(
{
"full": concatenate_datasets(
[dataset["train"], dataset[self.test_split]]
)
}
)

fds = FederatedDataset(
dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit
)
full = fds.load_full("full")
dataset = datasets.load_dataset(self.dataset_name)
dataset_length = sum([len(ds) for ds in dataset.values()])
self.assertEqual(len(full), dataset_length)


class PartitionersSpecificationForFederatedDatasets(unittest.TestCase):
"""Test the specifications of partitioners for `FederatedDataset`."""
@@ -190,6 +240,18 @@ def test_unsupported_dataset(self) -> None: # pylint: disable=R0201
with pytest.warns(UserWarning):
FederatedDataset(dataset="food101", partitioners={"train": 100})

def test_cannot_use_the_old_split_names(self) -> None:
"""Test if the initial split names can not be used."""
dataset = datasets.load_dataset("mnist")
sum([len(ds) for ds in dataset.values()])
fds = FederatedDataset(
dataset="mnist",
partitioners={"train": 100},
resplitter={"full": ("train", "test")},
)
with self.assertRaises(ValueError):
fds.load_partition(0, "train")


if __name__ == "__main__":
unittest.main()
110 changes: 110 additions & 0 deletions datasets/flwr_datasets/merge_resplitter.py
@@ -0,0 +1,110 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MergeResplitter class for Flower Datasets."""


import collections
import warnings
from functools import reduce
from typing import Dict, List, Tuple

import datasets
from datasets import Dataset, DatasetDict


class MergeResplitter:
"""Merge existing splits of the dataset and assign them custom names.
Create a new `DatasetDict` whose split names correspond to the merged existing
splits (e.g. "train", "valid" and "test").
Parameters
----------
merge_config: Dict[str, Tuple[str, ...]]
    Dictionary mapping each desired split name to a tuple of the current split
    names that will be merged into it.
Examples
--------
Create a new `DatasetDict` with a split named "new_train" that merges the "train"
and "valid" splits, and keep the "test" split.
>>> # Assuming there is a dataset_dict of type `DatasetDict`
>>> # dataset_dict is {"train": train-data, "valid": valid-data, "test": test-data}
>>> merge_resplitter = MergeResplitter(
>>> merge_config={
>>> "new_train": ("train", "valid"),
>>> "test": ("test", )
>>> }
>>> )
>>> new_dataset_dict = merge_resplitter(dataset_dict)
>>> # new_dataset_dict is
>>> # {"new_train": concatenation of train-data and valid-data, "test": test-data}
"""

def __init__(
self,
merge_config: Dict[str, Tuple[str, ...]],
) -> None:
self._merge_config: Dict[str, Tuple[str, ...]] = merge_config
self._check_duplicate_merge_splits()

def __call__(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `merge_config`."""
self._check_correct_keys_in_merge_config(dataset)
return self.resplit(dataset)

def resplit(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `merge_config`."""
resplit_dataset = {}
for divide_to, divided_from_list in self._merge_config.items():
datasets_from_list: List[Dataset] = []
for divide_from in divided_from_list:
datasets_from_list.append(dataset[divide_from])
if len(datasets_from_list) > 1:
resplit_dataset[divide_to] = datasets.concatenate_datasets(
datasets_from_list
)
else:
resplit_dataset[divide_to] = datasets_from_list[0]
return datasets.DatasetDict(resplit_dataset)

def _check_correct_keys_in_merge_config(self, dataset: DatasetDict) -> None:
"""Check if the keys in merge_config are existing dataset splits."""
dataset_keys = dataset.keys()
specified_dataset_keys = self._merge_config.values()
for key_list in specified_dataset_keys:
for key in key_list:
if key not in dataset_keys:
raise ValueError(
f"The given dataset key '{key}' is not present in the given "
f"dataset object. Make sure to use only the keywords that are "
f"available in your dataset."
)

def _check_duplicate_merge_splits(self) -> None:
"""Check if the original splits are duplicated for new splits creation."""
merge_splits = reduce(lambda x, y: x + y, self._merge_config.values())
duplicates = [
item
for item, count in collections.Counter(merge_splits).items()
if count > 1
]
if duplicates:
warnings.warn(
f"More than one desired splits used '{duplicates[0]}' in "
f"`merge_config`. Make sure that is the intended behavior.",
stacklevel=1,
)
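
Used on its own, `MergeResplitter` behaves like any other `Resplitter` callable. A short sketch, assuming the module path follows the file location above and a dataset with "train" and "test" splits such as "mnist":

import datasets
from flwr_datasets.merge_resplitter import MergeResplitter

dataset_dict = datasets.load_dataset("mnist")  # splits: "train", "test"
# "test" feeds two new splits here, so constructing the MergeResplitter
# triggers the duplicate-split warning defined above.
merge_resplitter = MergeResplitter(
    merge_config={
        "full": ("train", "test"),  # concatenated into one new split
        "eval": ("test",),  # single-source splits are reused as-is
    }
)
new_dataset_dict = merge_resplitter(dataset_dict)
# new_dataset_dict has exactly the splits "full" and "eval".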