Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resplitting functionality to Flower Datasets #2427

Merged
merged 25 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
a4a295f
Add resplitter
adam-narozniak Sep 27, 2023
9aebcdb
Update FederatedDataset to work with resplitter
adam-narozniak Sep 27, 2023
c2866ac
Merge branch 'main' into fds-add-resplitting-functionality
adam-narozniak Oct 12, 2023
3a6e2e4
Rename Resplitter to MergeSplitter and custom Resplitter type
adam-narozniak Oct 12, 2023
de52f97
Fix MergeSplitter tests
adam-narozniak Oct 12, 2023
87ed594
Fix mypy problems in tests
adam-narozniak Oct 12, 2023
d996737
Merge branch 'fds-add-resplitting-functionality'
adam-narozniak Oct 12, 2023
3c1af53
Fix spaces
adam-narozniak Oct 12, 2023
d01feb8
Merge branch 'main' into fds-add-resplitting-functionality
adam-narozniak Oct 12, 2023
261fe4c
Fix new lines
adam-narozniak Oct 12, 2023
b625ae8
Apply suggestions from code review
adam-narozniak Oct 13, 2023
f790f1b
Clarify the documentation of MergeResplitter
adam-narozniak Oct 13, 2023
f6adaa5
Update resplitter parameter docstring
adam-narozniak Oct 16, 2023
a9b6329
Fix new lines between docstring and imports
adam-narozniak Oct 16, 2023
71e9036
Fix new lines between docstring and imports
adam-narozniak Oct 16, 2023
d61350f
Add copyright notice
adam-narozniak Oct 16, 2023
f589f02
Check for duplicated split names in MergeResplitter merge_config
adam-narozniak Oct 16, 2023
2614fcb
Fix formatting
adam-narozniak Oct 16, 2023
809beea
Simplify the _resplit_data_if_needed method
adam-narozniak Oct 16, 2023
10e2a69
Fix pylint tests
adam-narozniak Oct 16, 2023
dbad439
Switch the keys and values of the merge_config
adam-narozniak Oct 16, 2023
31fd770
Merge branch 'main' into fds-add-resplitting-functionality
adam-narozniak Oct 16, 2023
18a8c19
Update datasets/flwr_datasets/common/__init__.py
danieljanes Oct 16, 2023
b7249af
Update datasets/flwr_datasets/common/typing.py
danieljanes Oct 16, 2023
dd7bcbd
Merge branch 'main' into fds-add-resplitting-functionality
danieljanes Oct 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions datasets/flwr_datasets/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common components in Flower Datasets."""


from .typing import Resplitter

__all__ = ["Resplitter"]
22 changes: 22 additions & 0 deletions datasets/flwr_datasets/common/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower Datasets type definitions."""


from typing import Callable

from datasets import DatasetDict

# A resplitter consumes the freshly loaded `DatasetDict` and returns a new
# `DatasetDict` with transformed (e.g. merged or renamed) splits.
Resplitter = Callable[[DatasetDict], DatasetDict]
34 changes: 31 additions & 3 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
"""FederatedDataset."""


from typing import Dict, Optional, Union
from typing import Dict, Optional, Tuple, Union

import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.common import Resplitter
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners
from flwr_datasets.utils import (
_check_if_dataset_tested,
_instantiate_partitioners,
_instantiate_resplitter_if_needed,
)


class FederatedDataset:
Expand All @@ -38,10 +43,13 @@ class FederatedDataset:
subset: str
Secondary information regarding the dataset, most often subset or version
(that is passed to the name in datasets.load_dataset).
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
`Callable` that transforms `DatasetDict` splits, or configuration dict for
`MergeResplitter`.
partitioners: Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
into).
into).

Examples
--------
Expand All @@ -63,16 +71,21 @@ def __init__(
*,
dataset: str,
subset: Optional[str] = None,
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
) -> None:
_check_if_dataset_tested(dataset)
self._dataset_name: str = dataset
self._subset: Optional[str] = subset
self._resplitter: Optional[Resplitter] = _instantiate_resplitter_if_needed(
resplitter
)
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
# Init (download) lazily on the first call to `load_partition` or `load_full`
self._dataset: Optional[DatasetDict] = None
self._resplit: bool = False # Indicate if the resplit happened

def load_partition(self, idx: int, split: str) -> Dataset:
"""Load the partition specified by the idx in the selected split.
Expand All @@ -93,6 +106,7 @@ def load_partition(self, idx: int, split: str) -> Dataset:
Single partition from the dataset split.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
Expand All @@ -118,6 +132,7 @@ def load_full(self, split: str) -> Dataset:
Part of the dataset identified by its split name.
"""
self._download_dataset_if_none()
self._resplit_dataset_if_needed()
if self._dataset is None:
raise ValueError("Dataset is not loaded yet.")
self._check_if_split_present(split)
Expand Down Expand Up @@ -160,3 +175,16 @@ def _assign_dataset_to_partitioner(self, split: str) -> None:
raise ValueError("Dataset is not loaded yet.")
if not self._partitioners[split].is_dataset_assigned():
self._partitioners[split].dataset = self._dataset[split]

def _resplit_dataset_if_needed(self) -> None:
# The actual re-splitting can't be done more than once.
# The attribute `_resplit` indicates that the resplit happened.

# Resplit only once
if self._resplit:
return
if self._dataset is None:
raise ValueError("The dataset resplit should happen after the download.")
if self._resplitter:
self._dataset = self._resplitter(self._dataset)
self._resplit = True
62 changes: 62 additions & 0 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from parameterized import parameterized, parameterized_class

import datasets
from datasets import DatasetDict, concatenate_datasets
from flwr_datasets.federated_dataset import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, Partitioner

Expand Down Expand Up @@ -93,6 +94,55 @@ def test_multiple_partitioners(self) -> None:
len(dataset[self.test_split]) // num_test_partitions,
)

def test_resplit_dataset_into_one(self) -> None:
    """Test resplit into a single dataset."""
    dataset = datasets.load_dataset(self.dataset_name)
    # The merged "full" split must contain every example from every split.
    # Use a generator (not a throwaway list) inside sum().
    dataset_length = sum(len(ds) for ds in dataset.values())
    fds = FederatedDataset(
        dataset=self.dataset_name,
        partitioners={"train": 100},
        resplitter={"full": ("train", self.test_split)},
    )
    full = fds.load_full("full")
    self.assertEqual(dataset_length, len(full))

# pylint: disable=protected-access
def test_resplit_dataset_to_change_names(self) -> None:
    """Test resplitter to change the names of the partitions."""
    renamed_test_split = "new_" + self.test_split
    fds = FederatedDataset(
        dataset=self.dataset_name,
        partitioners={"new_train": 100},
        resplitter={
            "new_train": ("train",),
            renamed_test_split: (self.test_split,),
        },
    )
    # Trigger lazy download + resplit.
    _ = fds.load_partition(0, "new_train")
    assert fds._dataset is not None
    # Only the renamed splits may remain after resplitting.
    self.assertEqual(
        set(fds._dataset.keys()), {"new_train", renamed_test_split}
    )

def test_resplit_dataset_by_callable(self) -> None:
    """Test resplitting the dataset with a custom callable resplitter."""
    # NOTE: the original docstring was copy-pasted from the rename test;
    # this test actually checks the callable-based resplitting path.

    def resplit(dataset: DatasetDict) -> DatasetDict:
        # Merge all original splits into a single "full" split.
        return DatasetDict(
            {
                "full": concatenate_datasets(
                    [dataset["train"], dataset[self.test_split]]
                )
            }
        )

    fds = FederatedDataset(
        dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit
    )
    full = fds.load_full("full")
    dataset = datasets.load_dataset(self.dataset_name)
    # Use a generator (not a throwaway list) inside sum().
    dataset_length = sum(len(ds) for ds in dataset.values())
    self.assertEqual(len(full), dataset_length)


class PartitionersSpecificationForFederatedDatasets(unittest.TestCase):
"""Test the specifications of partitioners for `FederatedDataset`."""
Expand Down Expand Up @@ -190,6 +240,18 @@ def test_unsupported_dataset(self) -> None: # pylint: disable=R0201
with pytest.warns(UserWarning):
FederatedDataset(dataset="food101", partitioners={"train": 100})

def test_cannot_use_the_old_split_names(self) -> None:
    """Test if the initial split names can not be used."""
    # Removed dead code: the original loaded MNIST into an unused variable
    # and discarded a computed sum(), needlessly downloading the dataset.
    fds = FederatedDataset(
        dataset="mnist",
        partitioners={"train": 100},
        resplitter={"full": ("train", "test")},
    )
    # After resplitting only "full" exists, so "train" must raise.
    with self.assertRaises(ValueError):
        fds.load_partition(0, "train")


if __name__ == "__main__":
unittest.main()
110 changes: 110 additions & 0 deletions datasets/flwr_datasets/merge_resplitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MergeResplitter class for Flower Datasets."""


import collections
import warnings
from functools import reduce
from typing import Dict, List, Tuple

import datasets
from datasets import Dataset, DatasetDict


class MergeResplitter:
    """Merge existing splits of the dataset and assign them custom names.

    Create new `DatasetDict` with new split names corresponding to the merged existing
    splits (e.g. "train", "valid" and "test").

    Parameters
    ----------
    merge_config: Dict[str, Tuple[str, ...]]
        Dictionary mapping each desired (new) split name to the tuple of current
        split names that will be merged together to form it.

    Examples
    --------
    Create new `DatasetDict` with a split name "new_train" that is created as a merger
    of the "train" and "valid" splits. Keep the "test" split.

    >>> # Assuming there is a dataset_dict of type `DatasetDict`
    >>> # dataset_dict is {"train": train-data, "valid": valid-data, "test": test-data}
    >>> merge_resplitter = MergeResplitter(
    >>>     merge_config={
    >>>         "new_train": ("train", "valid"),
    >>>         "test": ("test", )
    >>>     }
    >>> )
    >>> new_dataset_dict = merge_resplitter(dataset_dict)
    >>> # new_dataset_dict is
    >>> # {"new_train": concatenation of train-data and valid-data, "test": test-data}
    """

    def __init__(
        self,
        merge_config: Dict[str, Tuple[str, ...]],
    ) -> None:
        self._merge_config: Dict[str, Tuple[str, ...]] = merge_config
        # Warn at construction time if one source split feeds several new splits.
        self._check_duplicate_merge_splits()

    def __call__(self, dataset: DatasetDict) -> DatasetDict:
        """Resplit the dataset according to the `merge_config`."""
        self._check_correct_keys_in_merge_config(dataset)
        return self.resplit(dataset)

    def resplit(self, dataset: DatasetDict) -> DatasetDict:
        """Resplit the dataset according to the `merge_config`.

        Parameters
        ----------
        dataset : DatasetDict
            Dataset whose current splits are merged into the new splits.

        Returns
        -------
        DatasetDict
            New dataset whose split names are the keys of `merge_config`.
        """
        resplit_dataset = {}
        for divide_to, divided_from__list in self._merge_config.items():
            datasets_from_list: List[Dataset] = []
            for divide_from in divided_from__list:
                datasets_from_list.append(dataset[divide_from])
            if len(datasets_from_list) > 1:
                resplit_dataset[divide_to] = datasets.concatenate_datasets(
                    datasets_from_list
                )
            else:
                # Single source split: reuse the dataset object (no concatenation).
                resplit_dataset[divide_to] = datasets_from_list[0]
        return datasets.DatasetDict(resplit_dataset)

    def _check_correct_keys_in_merge_config(self, dataset: DatasetDict) -> None:
        """Check if the keys in merge_config are existing dataset splits."""
        dataset_keys = dataset.keys()
        specified_dataset_keys = self._merge_config.values()
        for key_list in specified_dataset_keys:
            for key in key_list:
                if key not in dataset_keys:
                    raise ValueError(
                        f"The given dataset key '{key}' is not present in the given "
                        f"dataset object. Make sure to use only the keywords that are "
                        f"available in your dataset."
                    )

    def _check_duplicate_merge_splits(self) -> None:
        """Check if the original splits are duplicated for new splits creation."""
        # NOTE(review): removed GitHub review-UI text ("… marked this conversation
        # as resolved.") that had been pasted into the middle of this comprehension
        # and made the module unparsable.
        merge_splits = reduce(lambda x, y: x + y, self._merge_config.values())
        duplicates = [
            item
            for item, count in collections.Counter(merge_splits).items()
            if count > 1
        ]
        if duplicates:
            warnings.warn(
                f"More than one desired splits used '{duplicates[0]}' in "
                f"`merge_config`. Make sure that is the intended behavior.",
                stacklevel=1,
            )
Loading