Skip to content

Commit

Permalink
Clarify the documentation of MergeResplitter
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Oct 13, 2023
1 parent b625ae8 commit f790f1b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 30 deletions.
6 changes: 3 additions & 3 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.merge_splitter import MergeSplitter
from flwr_datasets.merge_resplitter import MergeResplitter
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.utils import _check_if_dataset_tested, _instantiate_partitioners

Expand All @@ -39,7 +39,7 @@ class FederatedDataset:
dataset: str
The name of the dataset in the Hugging Face Hub.
resplitter: Optional[Union[Resplitter, Dict[Tuple[str, ...], str]]]
Resplit strategy or custom `Callable` that transforms splits in the `DatasetDict`.
Resplit strategy or `Callable` that transforms splits in the `DatasetDict`.
partitioners: Dict[str, Union[Partitioner, int]]
A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int`
(representing the number of IID partitions that this split should be partitioned
Expand Down Expand Up @@ -176,7 +176,7 @@ def _resplit_dataset_if_needed(self) -> None:
if self._resplitter:
resplitter: Resplitter
if isinstance(self._resplitter, Dict):
resplitter = MergeSplitter(resplit_strategy=self._resplitter)
resplitter = MergeResplitter(merge_config=self._resplitter)
else:
resplitter = self._resplitter
self._dataset = resplitter(self._dataset)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Resplitter class for Flower Datasets."""
"""MergeResplitter class for Flower Datasets."""
import collections
from typing import Dict, List, Tuple

Expand All @@ -7,35 +7,51 @@


class MergeResplitter:
"""Create a new dataset splits according to the `resplit_strategy`.
"""Merge existing splits of the dataset and assign them custom names.
The dataset comes with some predefined splits e.g. "train", "valid" and "test". This
class allows you to create a new dataset with splits created according to your needs
specified in `resplit_strategy`.
Create a new `DatasetDict` whose splits are formed by merging existing splits
(e.g. "train", "valid" and "test") and assigning the merged result a new name.
Parameters
----------
resplit_strategy: ResplitStrategy
merge_config: Dict[Tuple[str, ...], str]
Dictionary mapping tuples of the current split names (keys) to the desired
name of the resulting merged split (values).
Examples
--------
Create a new `DatasetDict` with a split named "new_train" that is the concatenation
of the "train" and "valid" splits. Keep the "test" split unchanged.
>>> # Assuming there is a dataset_dict of type `DatasetDict`
>>> # dataset_dict is {"train": train-data, "valid": valid-data, "test": test-data}
>>> merge_resplitter = MergeResplitter(
...     merge_config={
...         ("train", "valid"): "new_train",
...         ("test", ): "test"
...     }
... )
>>> new_dataset_dict = merge_resplitter(dataset_dict)
>>> # new_dataset_dict is
>>> # {"new_train": concatenation of train-data and valid-data, "test": test-data}
"""

def __init__(
self,
resplit_strategy: Dict[Tuple[str, ...], str],
merge_config: Dict[Tuple[str, ...], str],
) -> None:
self._resplit_strategy: Dict[Tuple[str, ...], str] = resplit_strategy
self._merge_config: Dict[Tuple[str, ...], str] = merge_config
self._check_duplicate_desired_splits()

def __call__(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `resplit_strategy`."""
self._check_correct_keys_in_resplit_strategy(dataset)
"""Resplit the dataset according to the `merge_config`."""
self._check_correct_keys_in_merge_config(dataset)
return self.resplit(dataset)

def resplit(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `resplit_strategy`."""
"""Resplit the dataset according to the `merge_config`."""
resplit_dataset = {}
for divided_from__list, divide_to in self._resplit_strategy.items():
for divided_from__list, divide_to in self._merge_config.items():
datasets_from_list: List[Dataset] = []
for divide_from in divided_from__list:
datasets_from_list.append(dataset[divide_from])
Expand All @@ -47,10 +63,10 @@ def resplit(self, dataset: DatasetDict) -> DatasetDict:
resplit_dataset[divide_to] = datasets_from_list[0]
return datasets.DatasetDict(resplit_dataset)

def _check_correct_keys_in_resplit_strategy(self, dataset: DatasetDict) -> None:
"""Check if the keys in resplit_strategy are existing dataset splits."""
def _check_correct_keys_in_merge_config(self, dataset: DatasetDict) -> None:
"""Check if the keys in merge_config are existing dataset splits."""
dataset_keys = dataset.keys()
specified_dataset_keys = self._resplit_strategy.keys()
specified_dataset_keys = self._merge_config.keys()
for key_list in specified_dataset_keys:
for key in key_list:
if key not in dataset_keys:
Expand All @@ -62,7 +78,7 @@ def _check_correct_keys_in_resplit_strategy(self, dataset: DatasetDict) -> None:

def _check_duplicate_desired_splits(self) -> None:
"""Check for duplicate desired split names."""
desired_splits = list(self._resplit_strategy.values())
desired_splits = list(self._merge_config.values())
duplicates = [
item
for item, count in collections.Counter(desired_splits).items()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, Tuple

from datasets import Dataset, DatasetDict
from flwr_datasets.merge_splitter import MergeSplitter
from flwr_datasets.merge_resplitter import MergeResplitter


class TestResplitter(unittest.TestCase):
Expand All @@ -22,28 +22,28 @@ def setUp(self) -> None:
def test_resplitting_train_size(self) -> None:
"""Test if resplitting for just renaming keeps the lengths correct."""
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"}
resplitter = MergeSplitter(strategy)
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["new_train"]), 3)

def test_resplitting_valid_size(self) -> None:
"""Test if resplitting for just renaming keeps the lengths correct."""
strategy: Dict[Tuple[str, ...], str] = {("valid",): "new_valid"}
resplitter = MergeSplitter(strategy)
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["new_valid"]), 2)

def test_resplitting_test_size(self) -> None:
"""Test if resplitting for just renaming keeps the lengths correct."""
strategy: Dict[Tuple[str, ...], str] = {("test",): "new_test"}
resplitter = MergeSplitter(strategy)
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["new_test"]), 1)

def test_resplitting_train_the_same(self) -> None:
"""Test if resplitting for just renaming keeps the dataset the same."""
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"}
resplitter = MergeSplitter(strategy)
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertTrue(
datasets_are_equal(self.dataset_dict["train"], new_dataset["new_train"])
Expand All @@ -54,7 +54,7 @@ def test_combined_train_valid_size(self) -> None:
strategy: Dict[Tuple[str, ...], str] = {
("train", "valid"): "train_valid_combined"
}
resplitter = MergeSplitter(strategy)
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["train_valid_combined"]), 5)

Expand All @@ -64,7 +64,7 @@ def test_resplitting_test_with_combined_strategy_size(self) -> None:
("train", "valid"): "train_valid_combined",
("test",): "test",
}
resplitter = MergeSplitter(strategy)
resplitter = MergeResplitter(strategy)
new_dataset = resplitter(self.dataset_dict)
self.assertEqual(len(new_dataset["test"]), 1)

Expand All @@ -74,7 +74,7 @@ def test_invalid_resplit_strategy_exception_message(self) -> None:
("invalid_split",): "new_train",
("test",): "new_test",
}
resplitter = MergeSplitter(strategy)
resplitter = MergeResplitter(strategy)
with self.assertRaisesRegex(
ValueError, "The given dataset key 'invalid_split' is not present"
):
Expand All @@ -83,7 +83,7 @@ def test_invalid_resplit_strategy_exception_message(self) -> None:
def test_nonexistent_split_in_strategy(self) -> None:
"""Test if the exception is raised when the nonexistent split name is given."""
strategy: Dict[Tuple[str, ...], str] = {("nonexistent_split",): "new_split"}
resplitter = MergeSplitter(strategy)
resplitter = MergeResplitter(strategy)
with self.assertRaisesRegex(
ValueError, "The given dataset key 'nonexistent_split' is not present"
):
Expand All @@ -98,13 +98,13 @@ def test_duplicate_desired_split_name(self) -> None:
with self.assertRaisesRegex(
ValueError, "Duplicate desired split name 'new_train' in resplit strategy"
):
_ = MergeSplitter(strategy)
_ = MergeResplitter(strategy)

def test_empty_dataset_dict(self) -> None:
"""Test that the error is raised when the empty DatasetDict is given."""
empty_dataset = DatasetDict({})
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"}
resplitter = MergeSplitter(strategy)
resplitter = MergeResplitter(strategy)
with self.assertRaisesRegex(
ValueError, "The given dataset key 'train' is not present"
):
Expand Down

0 comments on commit f790f1b

Please sign in to comment.