-
Notifications
You must be signed in to change notification settings - Fork 903
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'fds-add-resplitting-functionality'
- Loading branch information
Showing
4 changed files
with
296 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Resplitter class for Flower Datasets.""" | ||
import collections | ||
from typing import Dict, List, Tuple | ||
|
||
import datasets | ||
from datasets import Dataset, DatasetDict | ||
|
||
|
||
class MergeSplitter: | ||
"""Create a new dataset splits according to the `resplit_strategy`. | ||
The dataset comes with some predefined splits e.g. "train", "valid" and "test". This | ||
class allows you to create a new dataset with splits created according to your needs | ||
specified in `resplit_strategy`. | ||
Parameters | ||
---------- | ||
resplit_strategy: ResplitStrategy | ||
Dictionary with keys - tuples of the current split names to values - the desired | ||
split names | ||
""" | ||
|
||
def __init__( | ||
self, | ||
resplit_strategy: Dict[Tuple[str, ...], str], | ||
) -> None: | ||
self._resplit_strategy: Dict[Tuple[str, ...], str] = resplit_strategy | ||
self._check_duplicate_desired_splits() | ||
|
||
def __call__(self, dataset: DatasetDict) -> DatasetDict: | ||
"""Resplit the dataset according to the `resplit_strategy`.""" | ||
self._check_correct_keys_in_resplit_strategy(dataset) | ||
return self.resplit(dataset) | ||
|
||
def resplit(self, dataset: DatasetDict) -> DatasetDict: | ||
"""Resplit the dataset according to the `resplit_strategy`.""" | ||
resplit_dataset = {} | ||
for divided_from__list, divide_to in self._resplit_strategy.items(): | ||
datasets_from_list: List[Dataset] = [] | ||
for divide_from in divided_from__list: | ||
datasets_from_list.append(dataset[divide_from]) | ||
if len(datasets_from_list) > 1: | ||
resplit_dataset[divide_to] = datasets.concatenate_datasets( | ||
datasets_from_list | ||
) | ||
else: | ||
resplit_dataset[divide_to] = datasets_from_list[0] | ||
return datasets.DatasetDict(resplit_dataset) | ||
|
||
def _check_correct_keys_in_resplit_strategy(self, dataset: DatasetDict) -> None: | ||
"""Check if the keys in resplit_strategy are existing dataset splits.""" | ||
dataset_keys = dataset.keys() | ||
specified_dataset_keys = self._resplit_strategy.keys() | ||
for key_list in specified_dataset_keys: | ||
for key in key_list: | ||
if key not in dataset_keys: | ||
raise ValueError( | ||
f"The given dataset key '{key}' is not present in the given " | ||
f"dataset object. Make sure to use only the keywords that are " | ||
f"available in your dataset." | ||
) | ||
|
||
def _check_duplicate_desired_splits(self) -> None: | ||
"""Check for duplicate desired split names.""" | ||
desired_splits = list(self._resplit_strategy.values()) | ||
duplicates = [ | ||
item | ||
for item, count in collections.Counter(desired_splits).items() | ||
if count > 1 | ||
] | ||
if duplicates: | ||
print(f"Duplicate desired split name '{duplicates[0]}' in resplit strategy") | ||
raise ValueError( | ||
f"Duplicate desired split name '{duplicates[0]}' in resplit strategy" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
"""Resplitter tests.""" | ||
import unittest | ||
from typing import Dict, Tuple | ||
|
||
from datasets import Dataset, DatasetDict | ||
from flwr_datasets.merge_splitter import MergeSplitter | ||
|
||
|
||
class TestResplitter(unittest.TestCase): | ||
"""Resplitter tests.""" | ||
|
||
def setUp(self) -> None: | ||
"""Set up the dataset with 3 splits for tests.""" | ||
self.dataset_dict = DatasetDict( | ||
{ | ||
"train": Dataset.from_dict({"data": [1, 2, 3]}), | ||
"valid": Dataset.from_dict({"data": [4, 5]}), | ||
"test": Dataset.from_dict({"data": [6]}), | ||
} | ||
) | ||
|
||
def test_resplitting_train_size(self) -> None: | ||
"""Test if resplitting for just renaming keeps the lengths correct.""" | ||
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"} | ||
resplitter = MergeSplitter(strategy) | ||
new_dataset = resplitter(self.dataset_dict) | ||
self.assertEqual(len(new_dataset["new_train"]), 3) | ||
|
||
def test_resplitting_valid_size(self) -> None: | ||
"""Test if resplitting for just renaming keeps the lengths correct.""" | ||
strategy: Dict[Tuple[str, ...], str] = {("valid",): "new_valid"} | ||
resplitter = MergeSplitter(strategy) | ||
new_dataset = resplitter(self.dataset_dict) | ||
self.assertEqual(len(new_dataset["new_valid"]), 2) | ||
|
||
def test_resplitting_test_size(self) -> None: | ||
"""Test if resplitting for just renaming keeps the lengths correct.""" | ||
strategy: Dict[Tuple[str, ...], str] = {("test",): "new_test"} | ||
resplitter = MergeSplitter(strategy) | ||
new_dataset = resplitter(self.dataset_dict) | ||
self.assertEqual(len(new_dataset["new_test"]), 1) | ||
|
||
def test_resplitting_train_the_same(self) -> None: | ||
"""Test if resplitting for just renaming keeps the dataset the same.""" | ||
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"} | ||
resplitter = MergeSplitter(strategy) | ||
new_dataset = resplitter(self.dataset_dict) | ||
self.assertTrue( | ||
datasets_are_equal(self.dataset_dict["train"], new_dataset["new_train"]) | ||
) | ||
|
||
def test_combined_train_valid_size(self) -> None: | ||
"""Test if the resplitting that combines the datasets has correct size.""" | ||
strategy: Dict[Tuple[str, ...], str] = { | ||
("train", "valid"): "train_valid_combined" | ||
} | ||
resplitter = MergeSplitter(strategy) | ||
new_dataset = resplitter(self.dataset_dict) | ||
self.assertEqual(len(new_dataset["train_valid_combined"]), 5) | ||
|
||
def test_resplitting_test_with_combined_strategy_size(self) -> None: | ||
"""Test if the resplitting that combines the datasets has correct size.""" | ||
strategy: Dict[Tuple[str, ...], str] = { | ||
("train", "valid"): "train_valid_combined", | ||
("test",): "test", | ||
} | ||
resplitter = MergeSplitter(strategy) | ||
new_dataset = resplitter(self.dataset_dict) | ||
self.assertEqual(len(new_dataset["test"]), 1) | ||
|
||
def test_invalid_resplit_strategy_exception_message(self) -> None: | ||
"""Test if the resplitting raises error when non-existing split is given.""" | ||
strategy: Dict[Tuple[str, ...], str] = { | ||
("invalid_split",): "new_train", | ||
("test",): "new_test", | ||
} | ||
resplitter = MergeSplitter(strategy) | ||
with self.assertRaisesRegex( | ||
ValueError, "The given dataset key 'invalid_split' is not present" | ||
): | ||
resplitter(self.dataset_dict) | ||
|
||
def test_nonexistent_split_in_strategy(self) -> None: | ||
"""Test if the exception is raised when the nonexistent split name is given.""" | ||
strategy: Dict[Tuple[str, ...], str] = {("nonexistent_split",): "new_split"} | ||
resplitter = MergeSplitter(strategy) | ||
with self.assertRaisesRegex( | ||
ValueError, "The given dataset key 'nonexistent_split' is not present" | ||
): | ||
resplitter(self.dataset_dict) | ||
|
||
def test_duplicate_desired_split_name(self) -> None: | ||
"""Test that the new split names are not the same.""" | ||
strategy: Dict[Tuple[str, ...], str] = { | ||
("train",): "new_train", | ||
("valid",): "new_train", | ||
} | ||
with self.assertRaisesRegex( | ||
ValueError, "Duplicate desired split name 'new_train' in resplit strategy" | ||
): | ||
_ = MergeSplitter(strategy) | ||
|
||
def test_empty_dataset_dict(self) -> None: | ||
"""Test that the error is raised when the empty DatasetDict is given.""" | ||
empty_dataset = DatasetDict({}) | ||
strategy: Dict[Tuple[str, ...], str] = {("train",): "new_train"} | ||
resplitter = MergeSplitter(strategy) | ||
with self.assertRaisesRegex( | ||
ValueError, "The given dataset key 'train' is not present" | ||
): | ||
resplitter(empty_dataset) | ||
|
||
|
||
def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool: | ||
"""Check if two Datasets have the same values.""" | ||
# Check if both datasets have the same length | ||
if len(ds1) != len(ds2): | ||
return False | ||
|
||
# Iterate over each row and check for equality | ||
for row1, row2 in zip(ds1, ds2): | ||
if row1 != row2: | ||
return False | ||
|
||
return True | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |