Add NaturalIdPartitioner to FDS (#2404)
Commit a76364d (1 parent: ab9a0cb), 3 changed files with 202 additions and 0 deletions.
datasets/flwr_datasets/partitioner/natural_id_partitioner.py (81 additions, 0 deletions)
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Natural id partitioner class that works with Hugging Face Datasets."""


from typing import Dict

import datasets
from flwr_datasets.partitioner.partitioner import Partitioner


class NaturalIdPartitioner(Partitioner):
    """Partitioner for a dataset that can be divided by a reference to an id column."""

    def __init__(
        self,
        partition_by: str,
    ):
        super().__init__()
        self._node_id_to_natural_id: Dict[int, str] = {}
        self._partition_by = partition_by

    def _create_int_node_id_to_natural_id(self) -> None:
        """Create a mapping from int indices to unique client ids from the dataset.

        Natural ids come from the column specified in `partition_by`.
        """
        unique_natural_ids = self.dataset.unique(self._partition_by)
        self._node_id_to_natural_id = dict(
            zip(range(len(unique_natural_ids)), unique_natural_ids)
        )

    def load_partition(self, idx: int) -> datasets.Dataset:
        """Load a single partition corresponding to a single `node_id`.

        The choice of the partition is based on unique integers assigned to each
        natural id present in the dataset in the `partition_by` column.

        Parameters
        ----------
        idx : int
            The index that corresponds to the requested partition.

        Returns
        -------
        dataset_partition : Dataset
            A single dataset partition.
        """
        if len(self._node_id_to_natural_id) == 0:
            self._create_int_node_id_to_natural_id()

        return self.dataset.filter(
            lambda row: row[self._partition_by] == self._node_id_to_natural_id[idx]
        )

    @property
    def node_id_to_natural_id(self) -> Dict[int, str]:
        """Node id to the corresponding natural id present in the dataset.

        Natural ids are the unique values in the `partition_by` column of the dataset.
        """
        return self._node_id_to_natural_id

    # pylint: disable=R0201
    @node_id_to_natural_id.setter
    def node_id_to_natural_id(self, value: Dict[int, str]) -> None:
        raise AttributeError(
            "Setting the node_id_to_natural_id dictionary is not allowed."
        )
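For orientation, here is a minimal usage sketch of the new partitioner (illustrative, not part of this commit). The toy rows and the "writer_id" column name are assumptions chosen only for the example; the `partition_by` argument, the `dataset` assignment, and `load_partition` are the pieces added above, and the node-id mapping is created lazily on first use.

# Illustrative sketch, not part of this commit; the toy data and the
# "writer_id" column are made up for the example.
from datasets import Dataset
from flwr_datasets.partitioner.natural_id_partitioner import NaturalIdPartitioner

# Each row carries the natural id of the client it belongs to.
dataset = Dataset.from_dict(
    {
        "text": ["a", "b", "c", "d", "e", "f"],
        "writer_id": ["alice", "bob", "alice", "carol", "bob", "alice"],
    }
)

partitioner = NaturalIdPartitioner(partition_by="writer_id")
partitioner.dataset = dataset  # assign the dataset, as the tests below do

# Integer node ids are assigned lazily to the unique "writer_id" values
# on the first load_partition call.
partition = partitioner.load_partition(0)
print(partitioner.node_id_to_natural_id)  # {0: 'alice', 1: 'bob', 2: 'carol'}
print(len(partition))  # 3 rows share the natural id mapped to node 0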
datasets/flwr_datasets/partitioner/natural_id_partitioner_test.py (119 additions, 0 deletions)
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""NaturalIdPartitioner partitioner tests."""


import itertools
import math
import unittest
from typing import Tuple

from parameterized import parameterized

from datasets import Dataset
from flwr_datasets.partitioner.natural_id_partitioner import NaturalIdPartitioner


def _dummy_setup(
    num_rows: int, n_unique_natural_ids: int
) -> Tuple[Dataset, NaturalIdPartitioner]:
    """Create a dummy dataset and partitioner based on the given arguments.

    The dataset is automatically assigned to the partitioner.
    """
    dataset = _create_dataset(num_rows, n_unique_natural_ids)
    partitioner = NaturalIdPartitioner(partition_by="natural_id")
    partitioner.dataset = dataset
    return dataset, partitioner


def _create_dataset(num_rows: int, n_unique_natural_ids: int) -> Dataset:
    """Create a dataset based on the number of rows and unique natural ids."""
    data = {
        "features": list(range(num_rows)),
        "natural_id": [f"{i % n_unique_natural_ids}" for i in range(num_rows)],
        "labels": [i % 2 for i in range(num_rows)],
    }
    dataset = Dataset.from_dict(data)
    return dataset


class TestNaturalIdPartitioner(unittest.TestCase):
    """Test NaturalIdPartitioner."""

    @parameterized.expand(  # type: ignore
        # num_rows, num_unique_natural_ids
        list(itertools.product([10, 30, 100, 1000], [2, 3, 4, 5]))
    )
    def test_load_partition_num_partitions(
        self, num_rows: int, num_unique_natural_ids: int
    ) -> None:
        """Test if the number of partitions matches the number of unique natural ids.

        Only the correct data is tested in this method.
        """
        _, partitioner = _dummy_setup(num_rows, num_unique_natural_ids)
        # Simulate usage to start lazy node_id_to_natural_id creation
        _ = partitioner.load_partition(0)
        self.assertEqual(
            len(partitioner.node_id_to_natural_id), num_unique_natural_ids
        )

    @parameterized.expand(  # type: ignore
        # num_rows, num_unique_natural_ids
        list(itertools.product([10, 30, 100, 1000], [2, 3, 4, 5]))
    )
    def test_load_partition_max_partition_size(
        self, num_rows: int, num_unique_natural_ids: int
    ) -> None:
        """Test if the largest partition has the expected size.

        Only the correct data is tested in this method.
        """
        _, partitioner = _dummy_setup(num_rows, num_unique_natural_ids)
        max_size = max(
            [len(partitioner.load_partition(i)) for i in range(num_unique_natural_ids)]
        )
        self.assertEqual(max_size, math.ceil(num_rows / num_unique_natural_ids))

    def test_partitioner_with_non_existing_column_partition_by(self) -> None:
        """Test error when the partition_by column does not exist."""
        dataset = _create_dataset(10, 2)
        partitioner = NaturalIdPartitioner(partition_by="not-existing")
        partitioner.dataset = dataset
        with self.assertRaises(ValueError):
            partitioner.load_partition(0)

    @parameterized.expand(  # type: ignore
        # num_rows, num_unique_natural_ids
        list(itertools.product([10, 30, 100, 1000], [2, 3, 4, 5]))
    )
    def test_correct_number_of_partitions(
        self, num_rows: int, num_unique_natural_ids: int
    ) -> None:
        """Test if the # of available partitions is equal to # of unique clients."""
        _, partitioner = _dummy_setup(num_rows, num_unique_natural_ids)
        _ = partitioner.load_partition(idx=0)
        self.assertEqual(
            len(partitioner.node_id_to_natural_id), num_unique_natural_ids
        )

    def test_cannot_set_node_id_to_natural_id(self) -> None:
        """Test the lack of ability to set node_id_to_natural_id."""
        _, partitioner = _dummy_setup(num_rows=10, n_unique_natural_ids=2)
        with self.assertRaises(AttributeError):
            partitioner.node_id_to_natural_id = {0: "0"}


if __name__ == "__main__":
    unittest.main()
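As a worked example of the arithmetic behind test_load_partition_max_partition_size (a sketch, not part of the commit), the helpers above with num_rows=10 and n_unique_natural_ids=3 produce natural ids "0", "1", "2" with 4, 3, and 3 rows respectively, so the largest partition holds math.ceil(10 / 3) == 4 rows, which is exactly the bound the test asserts:

# Illustrative only; reuses _dummy_setup and the math import from the test module above.
# The values num_rows=10 and n_unique_natural_ids=3 are chosen just for this example.
_, partitioner = _dummy_setup(num_rows=10, n_unique_natural_ids=3)
sizes = [len(partitioner.load_partition(i)) for i in range(3)]
print(sizes)              # [4, 3, 3] for natural ids "0", "1", "2"
print(max(sizes))         # 4
print(math.ceil(10 / 3))  # 4, the bound asserted by the test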