From 73ff3e034646586c6872d44231b2c0ab2e8c1b4a Mon Sep 17 00:00:00 2001
From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com>
Date: Mon, 16 Oct 2023 11:07:58 +0200
Subject: [PATCH 1/4] Add 'subset' keyword to FederatedDataset (#2420)

---
 datasets/flwr_datasets/federated_dataset.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py
index 8d171db2afa4..c4691da397fe 100644
--- a/datasets/flwr_datasets/federated_dataset.py
+++ b/datasets/flwr_datasets/federated_dataset.py
@@ -35,6 +35,9 @@ class FederatedDataset:
     ----------
     dataset: str
         The name of the dataset in the Hugging Face Hub.
+    subset: str
+        Secondary information regarding the dataset, most often the subset or version
+        (this is passed as the `name` argument to `datasets.load_dataset`).
     partitioners: Dict[str, Union[Partitioner, int]]
         A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
         (representing the number of IID partitions that this split should be partitioned
@@ -59,10 +62,12 @@ def __init__(
         self,
         *,
         dataset: str,
+        subset: Optional[str] = None,
         partitioners: Dict[str, Union[Partitioner, int]],
     ) -> None:
         _check_if_dataset_tested(dataset)
         self._dataset_name: str = dataset
+        self._subset: Optional[str] = subset
         self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
             partitioners
         )
@@ -121,7 +126,9 @@ def load_full(self, split: str) -> Dataset:
     def _download_dataset_if_none(self) -> None:
         """Lazily load (and potentially download) the Dataset instance into memory."""
         if self._dataset is None:
-            self._dataset = datasets.load_dataset(self._dataset_name)
+            self._dataset = datasets.load_dataset(
+                path=self._dataset_name, name=self._subset
+            )
 
     def _check_if_split_present(self, split: str) -> None:
         """Check if the split (for partitioning or full return) is in the dataset."""

From 3ec5ee836bd16b024066b912cfd901e9dd198d46 Mon Sep 17 00:00:00 2001
From: Javier
Date: Mon, 16 Oct 2023 10:21:43 +0100
Subject: [PATCH 2/4] Update FedMLB baseline README (#2507)

---
 baselines/fedmlb/README.md  | 28 ++++++++++++++--------------
 doc/source/ref-changelog.md |  2 ++
 2 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/baselines/fedmlb/README.md b/baselines/fedmlb/README.md
index f2816637221b..47cc69b48e09 100644
--- a/baselines/fedmlb/README.md
+++ b/baselines/fedmlb/README.md
@@ -2,18 +2,18 @@
 title: Multi-Level Branched Regularization for Federated Learning
 url: https://proceedings.mlr.press/v162/kim22a.html
 labels: [data heterogeneity, knowledge distillation, image classification]
-dataset: [cifar100, tiny-imagenet]
+dataset: [CIFAR-100, Tiny-ImageNet]
 ---
 
-# *_FedMLB_*
+# FedMLB: Multi-Level Branched Regularization for Federated Learning
 
 > Note: If you use this baseline in your work, please remember to cite the original authors of the paper as well as the Flower paper.
 
-****Paper:**** [proceedings.mlr.press/v162/kim22a.html](https://proceedings.mlr.press/v162/kim22a.html)
+**Paper:** [proceedings.mlr.press/v162/kim22a.html](https://proceedings.mlr.press/v162/kim22a.html)
 
-****Authors:**** Jinkyu Kim, Geeho Kim, Bohyung Han
+**Authors:** Jinkyu Kim, Geeho Kim, Bohyung Han
 
-****Abstract:**** *_A critical challenge of federated learning is data
+**Abstract:** *_A critical challenge of federated learning is data
 heterogeneity and imbalance across clients, which leads to inconsistency
 between local networks and unstable convergence of global models.
 To alleviate this issue, we propose a novel architectural regularization
 technique that constructs multiple auxiliary branches in each local model
 by grafting local and global subnetworks at several different levels and
 that learns the representations of the main pathway in the local model
 congruent to the auxiliary hybrid pathways via online knowledge
 distillation. The proposed technique is effective to robustify the global
 model even in the heterogeneous setups and is applicable to various
 federated learning frameworks. We perform comprehensive empirical studies
 and demonstrate remarkable performance gains in terms of accuracy and
 efficiency compared to existing methods.
 The source code is available in our project page._*
 
 ## About this baseline
 
-****What’s implemented:**** The code in this directory reproduces the results for FedMLB, FedAvg, and FedAvg+KD.
+**What’s implemented:** The code in this directory reproduces the results for FedMLB, FedAvg, and FedAvg+KD.
 The reproduced results use the CIFAR-100 dataset or the Tiny-ImageNet dataset. Four settings are available for
 both datasets:
 1. Moderate-scale with Dir(0.3), 100 clients, 5% participation, balanced dataset.
 2. Large-scale experiments with Dir(0.3), 500 clients, 2% participation rate, balanced dataset.
 3. Moderate-scale with Dir(0.6), 100 clients, 5% participation rate, balanced dataset.
 4. Large-scale experiments with Dir(0.6), 500 clients, 2% participation rate, balanced dataset.
 
-****Datasets:**** CIFAR-100, Tiny-ImageNet.
+**Datasets:** CIFAR-100, Tiny-ImageNet.
 
-****Hardware Setup:**** The code in this repository has been tested on a Linux machine with 64GB RAM.
+**Hardware Setup:** The code in this repository has been tested on a Linux machine with 64GB RAM.
 Be aware that in the default config the memory usage can exceed 10GB.
 
-****Contributors:**** Alessio Mora (University of Bologna, PhD, alessio.mora@unibo.it).
+**Contributors:** Alessio Mora (University of Bologna, PhD, alessio.mora@unibo.it).
 
 ## Experimental Setup
 
-****Task:**** Image classification
+**Task:** Image classification
 
-****Model:**** ResNet-18.
+**Model:** ResNet-18.
 
-****Dataset:**** Four settings are available for CIFAR-100,
+**Dataset:** Four settings are available for CIFAR-100:
 1. Moderate-scale with Dir(0.3), 100 clients, 5% participation, balanced dataset (500 examples per client).
 2. Large-scale experiments with Dir(0.3), 500 clients, 2% participation rate, balanced dataset (100 examples per client).
 3. Moderate-scale with Dir(0.6), 100 clients, 5% participation rate, balanced dataset (500 examples per client).
 4. Large-scale experiments with Dir(0.6), 500 clients, 2% participation rate, balanced dataset (100 examples per client).
 
-****Dataset:**** Four settings are available for Tiny-Imagenet,
+**Dataset:** Four settings are available for Tiny-ImageNet:
 1. Moderate-scale with Dir(0.3), 100 clients, 5% participation, balanced dataset (1000 examples per client).
 2. Large-scale experiments with Dir(0.3), 500 clients, 2% participation rate, balanced dataset (200 examples per client).
 3. Moderate-scale with Dir(0.6), 100 clients, 5% participation rate, balanced dataset (1000 examples per client).
 4. Large-scale experiments with Dir(0.6), 500 clients, 2% participation rate, balanced dataset (200 examples per client).
 
-****Training Hyperparameters:****
+**Training Hyperparameters:**
 
 | Hyperparameter | Description | Default Value |
 | ------------- | ------------- | ------------- |
diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md
index b2b339924f28..d0a29336acf1 100644
--- a/doc/source/ref-changelog.md
+++ b/doc/source/ref-changelog.md
@@ -22,6 +22,8 @@
 
   - Baselines Docs ([#2290](https://github.com/adap/flower/pull/2290), [#2400](https://github.com/adap/flower/pull/2400))
 
+  - FedMLB ([#2340](https://github.com/adap/flower/pull/2340), [#2507](https://github.com/adap/flower/pull/2507))
+
   - TAMUNA ([#2254](https://github.com/adap/flower/pull/2254), [#2508](https://github.com/adap/flower/pull/2508))
 
 - **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384)), ([#2425](https://github.com/adap/flower/pull/2425))

From 76081b0d19f149168162d8aefcfac610c75c0cdf Mon Sep 17 00:00:00 2001
From: Javier
Date: Mon, 16 Oct 2023 11:36:32 +0100
Subject: [PATCH 3/4] Update baseline creation instruction (#2435)

---
 baselines/README.md                           |  2 +-
 ...es.rst => how-to-contribute-baselines.rst} | 40 ++++++++++++-------
 ...baselines.rst => how-to-use-baselines.rst} |  7 ----
 baselines/doc/source/index.rst                | 21 +++++-----
 doc/source/ref-changelog.md                   |  2 +-
 5 files changed, 39 insertions(+), 33 deletions(-)
 rename baselines/doc/source/{tutorial-contribute-baselines.rst => how-to-contribute-baselines.rst} (64%)
 rename baselines/doc/source/{tutorial-use-baselines.rst => how-to-use-baselines.rst} (95%)

diff --git a/baselines/README.md b/baselines/README.md
index 4b97bedb3a1c..17f4d405cba3 100644
--- a/baselines/README.md
+++ b/baselines/README.md
@@ -49,7 +49,7 @@ Do you have a new federated learning paper and want to add a new baseline to Flo
 The steps to follow are:
 
 1. Fork the Flower repo and clone it onto your machine.
-2. Navigate to the `baselines/` directory and from there run:
+2. Navigate to the `baselines/` directory, choose a single-word (and **lowercase**) name for your baseline, and from there run:
 
    ```bash
    # This will create a new directory with the same structure as `baseline_template`.
diff --git a/baselines/doc/source/tutorial-contribute-baselines.rst b/baselines/doc/source/how-to-contribute-baselines.rst
similarity index 64%
rename from baselines/doc/source/tutorial-contribute-baselines.rst
rename to baselines/doc/source/how-to-contribute-baselines.rst
index b627fff75fda..b568e73f1c11 100644
--- a/baselines/doc/source/tutorial-contribute-baselines.rst
+++ b/baselines/doc/source/how-to-contribute-baselines.rst
@@ -7,12 +7,20 @@
 The goal of Flower Baselines is to reproduce experiments from popular papers to accelerate researchers by enabling faster comparisons to new ideas.
 
 Before you start to work on a new baseline or experiment, please check the `Flower Issues `_ or `Flower Pull Requests `_ to see if someone else is already working on it. Please open a new issue if you are planning to work on a new baseline or experiment with a short description of the corresponding paper and the experiment you want to contribute.
 
-TL;DR: Add a new Flower Baseline
---------------------------------
-.. warning::
-  We are in the process of changing how Flower Baselines are structured and updating the instructions for new contributors. Bear with us until we have finalised this transition. For now, follow the steps described below and reach out to us if something is not clear. We look forward to welcoming your baseline into Flower!!
+Requirements
+------------
+
+Contributing a new baseline is really easy. You only have to make sure that your federated learning experiments are running with Flower and replicate the results of a paper. Flower baselines need to make use of:
+
+* `Poetry `_ to manage the Python environment.
+* `Hydra `_ to manage the configuration files for your experiments.
+
+You can find more information about how to set up Poetry on your machine in the ``EXTENDED_README.md`` that is generated when you prepare your baseline.
+
+Add a new Flower Baseline
+-------------------------
 .. note::
-    For a detailed set of steps to follow, check the `Baselines README on GitHub `_.
+    The instructions below are a more verbose version of what's present in the `Baselines README on GitHub `_.
 
 Let's say you want to contribute the code of your most recent Federated Learning publication, *FedAwesome*. There are only three steps necessary to create a new *FedAwesome* Flower Baseline:
 
 #. Fork the Flower codebase: go to the `Flower GitHub repo `_ and fork the code (click the *Fork* button in the top-right corner and follow the instructions)
 #. Clone the (forked) Flower source code: :code:`git clone git@github.com:[your_github_username]/flower.git`
 #. Open the code in your favorite editor.
-#. **Create a directory for your baseline and add the FedAwesome code**
+#. **Use the provided script to create your baseline directory**
    #. Navigate to the baselines directory and run :code:`./dev/create-baseline.sh fedawesome`
    #. A new directory in :code:`baselines/fedawesome` is created.
-   #. Follow the instructions in :code:`EXTENDED_README.md` and :code:`README.md` in :code:`baselines/fedawesome/`.
+   #. Follow the instructions in :code:`EXTENDED_README.md` and :code:`README.md` in your baseline directory.
 #. **Open a pull request**
    #. Stage your changes: :code:`git add .`
    #. Commit & push: :code:`git commit -m "Create new FedAwesome baseline" ; git push`
 
 Further reading:
 
 * `GitHub docs: Creating a pull request `_
 * `GitHub docs: Creating a pull request from a fork `_
 
-Requirements
-------------
-
-Contributing a new baseline is really easy. You only have to make sure that your federated learning experiments are running with Flower and replicate the results of a paper.
-
-The only requirement you need in your system in order to create a baseline is to have `Poetry `_ installed. This is our package manager tool of choice.
-We are adopting `Hydra `_ as the default mechanism to manage everything related to config files and the parameterisation of the Flower baseline.
 
 Usability
 ---------
 
-Flower is known and loved for its usability. Therefore, make sure that your baseline or experiment can be executed with a single command such as :code:`conda run -m .main` or :code:`python main.py` (when sourced into your environment). We provide you with a `template-baseline `_ to use as guidance when contributing your baseline. Having all baselines follow a homogenous structure helps users to tryout many baselines without the overheads of having to understand each individual codebase. Similarly, by using Hydra throughout, users will immediately know how to parameterise your experiments directly from the command line.
+Flower is known and loved for its usability. Therefore, make sure that your baseline or experiment can be executed with a single command such as:
+
+.. code-block:: bash
+
+   poetry run python -m .main
+
+   # or, once sourced into your environment
+   python -m .main
+
+We provide you with a `template-baseline `_ to use as guidance when contributing your baseline. Having all baselines follow a homogeneous structure helps users to try out many baselines without the overheads of having to understand each individual codebase. Similarly, by using Hydra throughout, users will immediately know how to parameterise your experiments directly from the command line.
 
 We look forward to your contribution!
diff --git a/baselines/doc/source/tutorial-use-baselines.rst b/baselines/doc/source/how-to-use-baselines.rst
similarity index 95%
rename from baselines/doc/source/tutorial-use-baselines.rst
rename to baselines/doc/source/how-to-use-baselines.rst
index 80978d419c51..23e21b74dedc 100644
--- a/baselines/doc/source/tutorial-use-baselines.rst
+++ b/baselines/doc/source/how-to-use-baselines.rst
@@ -45,10 +45,3 @@ To install Poetry on a different OS, to customise your installation, or to furth
     poetry install
 
 3. Run the baseline as indicated in the :code:`[Running the Experiments]` section in the :code:`README.md`
-
-
-Available Baselines
--------------------
-
-.. note::
-    To be updated soon once the existing baselines are adjusted to the new format.
diff --git a/baselines/doc/source/index.rst b/baselines/doc/source/index.rst
index 0aacbc28d117..335cfacef1ab 100644
--- a/baselines/doc/source/index.rst
+++ b/baselines/doc/source/index.rst
@@ -19,29 +19,32 @@ The Flower Community is growing quickly - we're a friendly group of researchers,
 Flower Baselines
 ----------------
 
-Flower Baselines are a collection of organised scripts used to reproduce results from well-known publications or benchmarks. You can check which baselines already exist and/or contribute your own baseline.
+Flower Baselines are a collection of organised directories used to reproduce results from well-known publications or benchmarks. You can check which baselines already exist and/or contribute your own baseline.
 
 .. BASELINES_TABLE_ANCHOR
 
+
 Tutorials
 ~~~~~~~~~
 
 A learning-oriented series of tutorials, the best place to start.
 
-.. toctree::
-   :maxdepth: 1
-   :caption: Tutorials
-
-   tutorial-use-baselines
-   tutorial-contribute-baselines
+.. note::
+    Coming soon
+
 
 How-to guides
 ~~~~~~~~~~~~~
 
 Problem-oriented how-to guides show step-by-step how to achieve a specific goal.
 
-.. note::
-    Coming soon
+.. toctree::
+   :maxdepth: 1
+   :caption: How-to Guides
+
+   how-to-use-baselines
+   how-to-contribute-baselines
+
 
 Explanations
 ~~~~~~~~~~~~
diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md
index d0a29336acf1..cf482521fa73 100644
--- a/doc/source/ref-changelog.md
+++ b/doc/source/ref-changelog.md
@@ -28,7 +28,7 @@
 
 - **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384)), ([#2425](https://github.com/adap/flower/pull/2425))
 
-- **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301), [#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327))
+- **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301), [#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327), [#2435](https://github.com/adap/flower/pull/2435))
 
 - **General updates to the simulation engine** ([#2331](https://github.com/adap/flower/pull/2331), [#2447](https://github.com/adap/flower/pull/2447), [#2448](https://github.com/adap/flower/pull/2448))

From 7f8a7e25c5ff21f8169beb02e97be144902c0820 Mon Sep 17 00:00:00 2001
From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com>
Date: Mon, 16 Oct 2023 15:11:29 +0200
Subject: [PATCH 4/4] Add resplitting functionality to Flower Datasets (#2427)

Co-authored-by: Daniel J. Beutel
---
 datasets/flwr_datasets/common/__init__.py     |  20 +++
 datasets/flwr_datasets/common/typing.py       |  22 +++
 datasets/flwr_datasets/federated_dataset.py   |  34 +++-
 .../flwr_datasets/federated_dataset_test.py   |  62 ++++++++
 datasets/flwr_datasets/merge_resplitter.py    | 110 +++++++++++++
 .../flwr_datasets/merge_resplitter_test.py    | 145 ++++++++++++++++++
 datasets/flwr_datasets/utils.py               |  13 +-
 7 files changed, 402 insertions(+), 4 deletions(-)
 create mode 100644 datasets/flwr_datasets/common/__init__.py
 create mode 100644 datasets/flwr_datasets/common/typing.py
 create mode 100644 datasets/flwr_datasets/merge_resplitter.py
 create mode 100644 datasets/flwr_datasets/merge_resplitter_test.py

diff --git a/datasets/flwr_datasets/common/__init__.py b/datasets/flwr_datasets/common/__init__.py
new file mode 100644
index 000000000000..a6468bcf7fda
--- /dev/null
+++ b/datasets/flwr_datasets/common/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common components in Flower Datasets."""
+
+
+from .typing import Resplitter
+
+__all__ = ["Resplitter"]
diff --git a/datasets/flwr_datasets/common/typing.py b/datasets/flwr_datasets/common/typing.py
new file mode 100644
index 000000000000..28e6bae4a505
--- /dev/null
+++ b/datasets/flwr_datasets/common/typing.py
@@ -0,0 +1,22 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower Datasets type definitions."""
+
+
+from typing import Callable
+
+from datasets import DatasetDict
+
+Resplitter = Callable[[DatasetDict], DatasetDict]
diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py
index c4691da397fe..dd332107dc74 100644
--- a/datasets/flwr_datasets/federated_dataset.py
+++ b/datasets/flwr_datasets/federated_dataset.py
@@ -15,12 +15,17 @@
 """FederatedDataset."""
 
 
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Tuple, Union
 
 import datasets
 from datasets import Dataset, DatasetDict
+from flwr_datasets.common import Resplitter
 from flwr_datasets.partitioner import Partitioner
-from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners
+from flwr_datasets.utils import (
+    _check_if_dataset_tested,
+    _instantiate_partitioners,
+    _instantiate_resplitter_if_needed,
+)
 
 
 class FederatedDataset:
@@ -38,10 +43,13 @@ class FederatedDataset:
     subset: str
         Secondary information regarding the dataset, most often the subset or version
         (this is passed as the `name` argument to `datasets.load_dataset`).
+    resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
+        `Callable` that transforms `DatasetDict` splits, or a configuration dict for
+        `MergeResplitter`.
     partitioners: Dict[str, Union[Partitioner, int]]
         A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
         (representing the number of IID partitions that this split should be partitioned
-        into).
+        into).
 
     Examples
     --------
@@ -63,16 +71,21 @@
     def __init__(
         self,
         *,
         dataset: str,
         subset: Optional[str] = None,
+        resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
         partitioners: Dict[str, Union[Partitioner, int]],
     ) -> None:
         _check_if_dataset_tested(dataset)
         self._dataset_name: str = dataset
         self._subset: Optional[str] = subset
+        self._resplitter: Optional[Resplitter] = _instantiate_resplitter_if_needed(
+            resplitter
+        )
         self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
             partitioners
         )
         # Init (download) lazily on the first call to `load_partition` or `load_full`
         self._dataset: Optional[DatasetDict] = None
+        self._resplit: bool = False  # Indicate if the resplit happened
 
     def load_partition(self, idx: int, split: str) -> Dataset:
         """Load the partition specified by the idx in the selected split.
@@ -93,6 +106,7 @@ def load_partition(self, idx: int, split: str) -> Dataset:
             Single partition from the dataset split.
         """
         self._download_dataset_if_none()
+        self._resplit_dataset_if_needed()
         if self._dataset is None:
             raise ValueError("Dataset is not loaded yet.")
         self._check_if_split_present(split)
@@ -118,6 +132,7 @@ def load_full(self, split: str) -> Dataset:
             Part of the dataset identified by its split name.
""" self._download_dataset_if_none() + self._resplit_dataset_if_needed() if self._dataset is None: raise ValueError("Dataset is not loaded yet.") self._check_if_split_present(split) @@ -160,3 +175,16 @@ def _assign_dataset_to_partitioner(self, split: str) -> None: raise ValueError("Dataset is not loaded yet.") if not self._partitioners[split].is_dataset_assigned(): self._partitioners[split].dataset = self._dataset[split] + + def _resplit_dataset_if_needed(self) -> None: + # The actual re-splitting can't be done more than once. + # The attribute `_resplit` indicates that the resplit happened. + + # Resplit only once + if self._resplit: + return + if self._dataset is None: + raise ValueError("The dataset resplit should happen after the download.") + if self._resplitter: + self._dataset = self._resplitter(self._dataset) + self._resplit = True diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 15485dfa7a95..ca2bc97a33ee 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -23,6 +23,7 @@ from parameterized import parameterized, parameterized_class import datasets +from datasets import DatasetDict, concatenate_datasets from flwr_datasets.federated_dataset import FederatedDataset from flwr_datasets.partitioner import IidPartitioner, Partitioner @@ -93,6 +94,55 @@ def test_multiple_partitioners(self) -> None: len(dataset[self.test_split]) // num_test_partitions, ) + def test_resplit_dataset_into_one(self) -> None: + """Test resplit into a single dataset.""" + dataset = datasets.load_dataset(self.dataset_name) + dataset_length = sum([len(ds) for ds in dataset.values()]) + fds = FederatedDataset( + dataset=self.dataset_name, + partitioners={"train": 100}, + resplitter={"full": ("train", self.test_split)}, + ) + full = fds.load_full("full") + self.assertEqual(dataset_length, len(full)) + + # pylint: disable=protected-access + def test_resplit_dataset_to_change_names(self) -> None: + """Test resplitter to change the names of the partitions.""" + fds = FederatedDataset( + dataset=self.dataset_name, + partitioners={"new_train": 100}, + resplitter={ + "new_train": ("train",), + "new_" + self.test_split: (self.test_split,), + }, + ) + _ = fds.load_partition(0, "new_train") + assert fds._dataset is not None + self.assertEqual( + set(fds._dataset.keys()), {"new_train", "new_" + self.test_split} + ) + + def test_resplit_dataset_by_callable(self) -> None: + """Test resplitter to change the names of the partitions.""" + + def resplit(dataset: DatasetDict) -> DatasetDict: + return DatasetDict( + { + "full": concatenate_datasets( + [dataset["train"], dataset[self.test_split]] + ) + } + ) + + fds = FederatedDataset( + dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit + ) + full = fds.load_full("full") + dataset = datasets.load_dataset(self.dataset_name) + dataset_length = sum([len(ds) for ds in dataset.values()]) + self.assertEqual(len(full), dataset_length) + class PartitionersSpecificationForFederatedDatasets(unittest.TestCase): """Test the specifications of partitioners for `FederatedDataset`.""" @@ -190,6 +240,18 @@ def test_unsupported_dataset(self) -> None: # pylint: disable=R0201 with pytest.warns(UserWarning): FederatedDataset(dataset="food101", partitioners={"train": 100}) + def test_cannot_use_the_old_split_names(self) -> None: + """Test if the initial split names can not be used.""" + dataset = datasets.load_dataset("mnist") + sum([len(ds) for 
ds in dataset.values()]) + fds = FederatedDataset( + dataset="mnist", + partitioners={"train": 100}, + resplitter={"full": ("train", "test")}, + ) + with self.assertRaises(ValueError): + fds.load_partition(0, "train") + if __name__ == "__main__": unittest.main() diff --git a/datasets/flwr_datasets/merge_resplitter.py b/datasets/flwr_datasets/merge_resplitter.py new file mode 100644 index 000000000000..995b0e8e5602 --- /dev/null +++ b/datasets/flwr_datasets/merge_resplitter.py @@ -0,0 +1,110 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MergeResplitter class for Flower Datasets.""" + + +import collections +import warnings +from functools import reduce +from typing import Dict, List, Tuple + +import datasets +from datasets import Dataset, DatasetDict + + +class MergeResplitter: + """Merge existing splits of the dataset and assign them custom names. + + Create new `DatasetDict` with new split names corresponding to the merged existing + splits (e.g. "train", "valid" and "test"). + + Parameters + ---------- + merge_config: Dict[str, Tuple[str, ...]] + Dictionary with keys - the desired split names to values - tuples of the current + split names that will be merged together + + Examples + -------- + Create new `DatasetDict` with a split name "new_train" that is created as a merger + of the "train" and "valid" splits. Keep the "test" split. 
+
+    >>> # Assuming there is a dataset_dict of type `DatasetDict`
+    >>> # dataset_dict is {"train": train-data, "valid": valid-data, "test": test-data}
+    >>> merge_resplitter = MergeResplitter(
+    >>>     merge_config={
+    >>>         "new_train": ("train", "valid"),
+    >>>         "test": ("test", )
+    >>>     }
+    >>> )
+    >>> new_dataset_dict = merge_resplitter(dataset_dict)
+    >>> # new_dataset_dict is
+    >>> # {"new_train": concatenation of train-data and valid-data, "test": test-data}
+    """
+
+    def __init__(
+        self,
+        merge_config: Dict[str, Tuple[str, ...]],
+    ) -> None:
+        self._merge_config: Dict[str, Tuple[str, ...]] = merge_config
+        self._check_duplicate_merge_splits()
+
+    def __call__(self, dataset: DatasetDict) -> DatasetDict:
+        """Resplit the dataset according to the `merge_config`."""
+        self._check_correct_keys_in_merge_config(dataset)
+        return self.resplit(dataset)
+
+    def resplit(self, dataset: DatasetDict) -> DatasetDict:
+        """Resplit the dataset according to the `merge_config`."""
+        resplit_dataset = {}
+        for divide_to, divided_from_list in self._merge_config.items():
+            datasets_from_list: List[Dataset] = []
+            for divide_from in divided_from_list:
+                datasets_from_list.append(dataset[divide_from])
+            if len(datasets_from_list) > 1:
+                resplit_dataset[divide_to] = datasets.concatenate_datasets(
+                    datasets_from_list
+                )
+            else:
+                resplit_dataset[divide_to] = datasets_from_list[0]
+        return datasets.DatasetDict(resplit_dataset)
+
+    def _check_correct_keys_in_merge_config(self, dataset: DatasetDict) -> None:
+        """Check if the keys in merge_config are existing dataset splits."""
+        dataset_keys = dataset.keys()
+        specified_dataset_keys = self._merge_config.values()
+        for key_list in specified_dataset_keys:
+            for key in key_list:
+                if key not in dataset_keys:
+                    raise ValueError(
+                        f"The given dataset key '{key}' is not present in the given "
+                        f"dataset object. Make sure to use only the keywords that are "
+                        f"available in your dataset."
+                    )
+
+    def _check_duplicate_merge_splits(self) -> None:
+        """Check if the original splits are duplicated for new splits creation."""
+        merge_splits = reduce(lambda x, y: x + y, self._merge_config.values())
+        duplicates = [
+            item
+            for item, count in collections.Counter(merge_splits).items()
+            if count > 1
+        ]
+        if duplicates:
+            warnings.warn(
+                f"More than one desired split uses the original split "
+                f"'{duplicates[0]}' in `merge_config`. Make sure that is the "
+                f"intended behavior.",
+                stacklevel=1,
+            )
diff --git a/datasets/flwr_datasets/merge_resplitter_test.py b/datasets/flwr_datasets/merge_resplitter_test.py
new file mode 100644
index 000000000000..096bd7efac27
--- /dev/null
+++ b/datasets/flwr_datasets/merge_resplitter_test.py
@@ -0,0 +1,145 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Resplitter tests."""
+
+
+import unittest
+from typing import Dict, Tuple
+
+import pytest
+
+from datasets import Dataset, DatasetDict
+from flwr_datasets.merge_resplitter import MergeResplitter
+
+
+class TestResplitter(unittest.TestCase):
+    """Resplitter tests."""
+
+    def setUp(self) -> None:
+        """Set up the dataset with 3 splits for tests."""
+        self.dataset_dict = DatasetDict(
+            {
+                "train": Dataset.from_dict({"data": [1, 2, 3]}),
+                "valid": Dataset.from_dict({"data": [4, 5]}),
+                "test": Dataset.from_dict({"data": [6]}),
+            }
+        )
+
+    def test_resplitting_train_size(self) -> None:
+        """Test if resplitting for just renaming keeps the lengths correct."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)}
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["new_train"]), 3)
+
+    def test_resplitting_valid_size(self) -> None:
+        """Test if resplitting for just renaming keeps the lengths correct."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_valid": ("valid",)}
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["new_valid"]), 2)
+
+    def test_resplitting_test_size(self) -> None:
+        """Test if resplitting for just renaming keeps the lengths correct."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_test": ("test",)}
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["new_test"]), 1)
+
+    def test_resplitting_train_the_same(self) -> None:
+        """Test if resplitting for just renaming keeps the dataset the same."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)}
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertTrue(
+            datasets_are_equal(self.dataset_dict["train"], new_dataset["new_train"])
+        )
+
+    def test_combined_train_valid_size(self) -> None:
+        """Test if the resplitting that combines the datasets has correct size."""
+        strategy: Dict[str, Tuple[str, ...]] = {
+            "train_valid_combined": ("train", "valid")
+        }
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["train_valid_combined"]), 5)
+
+    def test_resplitting_test_with_combined_strategy_size(self) -> None:
+        """Test if the resplitting that combines the datasets has correct size."""
+        strategy: Dict[str, Tuple[str, ...]] = {
+            "train_valid_combined": ("train", "valid"),
+            "test": ("test",),
+        }
+        resplitter = MergeResplitter(strategy)
+        new_dataset = resplitter(self.dataset_dict)
+        self.assertEqual(len(new_dataset["test"]), 1)
+
+    def test_invalid_resplit_strategy_exception_message(self) -> None:
+        """Test if the resplitting raises error when non-existing split is given."""
+        strategy: Dict[str, Tuple[str, ...]] = {
+            "new_train": ("invalid_split",),
+            "new_test": ("test",),
+        }
+        resplitter = MergeResplitter(strategy)
+        with self.assertRaisesRegex(
+            ValueError, "The given dataset key 'invalid_split' is not present"
+        ):
+            resplitter(self.dataset_dict)
+
+    def test_nonexistent_split_in_strategy(self) -> None:
+        """Test if the exception is raised when the nonexistent split name is given."""
+        strategy: Dict[str, Tuple[str, ...]] = {"new_split": ("nonexistent_split",)}
+        resplitter = MergeResplitter(strategy)
+        with self.assertRaisesRegex(
+            ValueError, "The given dataset key 'nonexistent_split' is not present"
+        ):
+            resplitter(self.dataset_dict)
+
+    def test_duplicate_merge_split_name(self) -> None:  # pylint: disable=R0201
+        """Test that the new split names are not the same."""
+        strategy: Dict[str, Tuple[str, ...]] = {
+            "new_train": ("train", "valid"),
+            "test": ("train",),
+        }
+        with pytest.warns(UserWarning):
+            _ = MergeResplitter(strategy)
+
+    def test_empty_dataset_dict(self) -> None:
+        """Test that the error is raised when the empty DatasetDict is given."""
+        empty_dataset = DatasetDict({})
+        strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)}
+        resplitter = MergeResplitter(strategy)
+        with self.assertRaisesRegex(
+            ValueError, "The given dataset key 'train' is not present"
+        ):
+            resplitter(empty_dataset)
+
+
+def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
+    """Check if two Datasets have the same values."""
+    # Check if both datasets have the same length
+    if len(ds1) != len(ds2):
+        return False
+
+    # Iterate over each row and check for equality
+    for row1, row2 in zip(ds1, ds2):
+        if row1 != row2:
+            return False
+
+    return True
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py
index 056a12442d75..7badb1445460 100644
--- a/datasets/flwr_datasets/utils.py
+++ b/datasets/flwr_datasets/utils.py
@@ -16,8 +16,10 @@
 
 
 import warnings
-from typing import Dict, Union
+from typing import Dict, Optional, Tuple, Union, cast
 
+from flwr_datasets.common import Resplitter
+from flwr_datasets.merge_resplitter import MergeResplitter
 from flwr_datasets.partitioner import IidPartitioner, Partitioner
 
 tested_datasets = [
@@ -67,6 +69,15 @@ def _instantiate_partitioners(
     return instantiated_partitioners
 
 
+def _instantiate_resplitter_if_needed(
+    resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
+) -> Optional[Resplitter]:
+    """Instantiate `MergeResplitter` if resplitter is merge_config."""
+    if resplitter and isinstance(resplitter, Dict):
+        resplitter = MergeResplitter(merge_config=resplitter)
+    return cast(Optional[Resplitter], resplitter)
+
+
+def _check_if_dataset_tested(dataset: str) -> None:
+    """Check if the dataset is in the narrowed down list of the tested datasets."""
+    if dataset not in tested_datasets:
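
---

Editor's note: taken together, patches 1 and 4 let `FederatedDataset` load a specific Hugging Face subset/version and rearrange splits before partitioning. The following is a minimal usage sketch, not part of the patch series; it assumes the `flwr_datasets` package built from these patches is installed, and that the Hugging Face "cifar100" dataset (which is listed in `tested_datasets`) exposes "train" and "test" splits.

```python
from flwr_datasets import FederatedDataset

# The dict form of `resplitter` is instantiated as a `MergeResplitter`
# (see `_instantiate_resplitter_if_needed` in utils.py): here the original
# "train" and "test" splits are merged into a single "full" split, which
# is then partitioned into 100 IID partitions.
fds = FederatedDataset(
    dataset="cifar100",
    resplitter={"full": ("train", "test")},
    partitioners={"full": 100},
)

partition = fds.load_partition(0, "full")  # partition for client 0
full = fds.load_full("full")               # the entire merged split
```

Any `Callable[[DatasetDict], DatasetDict]` (the `Resplitter` type from patch 4) can be passed as `resplitter` instead of the dict, as exercised by `test_resplit_dataset_by_callable` above. Note that resplitting happens once, lazily, after download, and the original split names are no longer available afterwards (see `test_cannot_use_the_old_split_names`).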