Add NaturalIdPartitioner to FDS (#2404)
adam-narozniak authored Nov 10, 2023
1 parent ab9a0cb commit a76364d
Showing 3 changed files with 202 additions and 0 deletions.
2 changes: 2 additions & 0 deletions datasets/flwr_datasets/partitioner/__init__.py
@@ -18,13 +18,15 @@
from .exponential_partitioner import ExponentialPartitioner
from .iid_partitioner import IidPartitioner
from .linear_partitioner import LinearPartitioner
from .natural_id_partitioner import NaturalIdPartitioner
from .partitioner import Partitioner
from .size_partitioner import SizePartitioner
from .square_partitioner import SquarePartitioner

__all__ = [
"IidPartitioner",
"Partitioner",
"NaturalIdPartitioner",
"SizePartitioner",
"LinearPartitioner",
"SquarePartitioner",
81 changes: 81 additions & 0 deletions datasets/flwr_datasets/partitioner/natural_id_partitioner.py
@@ -0,0 +1,81 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Natural id partitioner class that works with Hugging Face Datasets."""


from typing import Dict

import datasets
from flwr_datasets.partitioner.partitioner import Partitioner


class NaturalIdPartitioner(Partitioner):
"""Partitioner for dataset that can be divided by a reference to id in dataset."""

def __init__(
self,
partition_by: str,
):
super().__init__()
self._node_id_to_natural_id: Dict[int, str] = {}
self._partition_by = partition_by

    def _create_int_node_id_to_natural_id(self) -> None:
        """Create a mapping from int indices to unique client ids from dataset.

        Natural ids come from the column specified in `partition_by`.
        """
unique_natural_ids = self.dataset.unique(self._partition_by)
self._node_id_to_natural_id = dict(
zip(range(len(unique_natural_ids)), unique_natural_ids)
)

    def load_partition(self, idx: int) -> datasets.Dataset:
        """Load a single partition corresponding to a single `node_id`.

        The choice of the partition is based on unique integers assigned to each
        natural id present in the dataset in the `partition_by` column.

        Parameters
        ----------
        idx: int
            the index that corresponds to the requested partition

        Returns
        -------
        dataset_partition: Dataset
            single dataset partition
        """
if len(self._node_id_to_natural_id) == 0:
self._create_int_node_id_to_natural_id()

return self.dataset.filter(
lambda row: row[self._partition_by] == self._node_id_to_natural_id[idx]
)

@property
def node_id_to_natural_id(self) -> Dict[int, str]:
"""Node id to corresponding natural id present.
Natural ids are the unique values in `partition_by` column in dataset.
"""
return self._node_id_to_natural_id

# pylint: disable=R0201
@node_id_to_natural_id.setter
def node_id_to_natural_id(self, value: Dict[int, str]) -> None:
raise AttributeError(
"Setting the node_id_to_natural_id dictionary is not allowed."
)
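
For context, a minimal usage sketch of the new partitioner (illustrative, not part of this commit; the column name and data below are made up, mirroring the test setup):

from datasets import Dataset
from flwr_datasets.partitioner import NaturalIdPartitioner

# Dummy data where the "writer_id" column acts as the natural id.
data = {
    "features": list(range(6)),
    "writer_id": ["a", "a", "b", "b", "c", "c"],
}
partitioner = NaturalIdPartitioner(partition_by="writer_id")
partitioner.dataset = Dataset.from_dict(data)

# Each integer node id maps to one unique natural id,
# so here there are 3 partitions of 2 rows each.
partition = partitioner.load_partition(idx=0)
print(partitioner.node_id_to_natural_id)  # e.g. {0: "a", 1: "b", 2: "c"}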
119 changes: 119 additions & 0 deletions datasets/flwr_datasets/partitioner/natural_id_partitioner_test.py
@@ -0,0 +1,119 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""NaturalIdPartitioner partitioner tests."""


import itertools
import math
import unittest
from typing import Tuple

from parameterized import parameterized

from datasets import Dataset
from flwr_datasets.partitioner.natural_id_partitioner import NaturalIdPartitioner


def _dummy_setup(
num_rows: int, n_unique_natural_ids: int
) -> Tuple[Dataset, NaturalIdPartitioner]:
"""Create a dummy dataset and partitioner based on given arguments.
The partitioner has automatically the dataset assigned to it.
"""
dataset = _create_dataset(num_rows, n_unique_natural_ids)
partitioner = NaturalIdPartitioner(partition_by="natural_id")
partitioner.dataset = dataset
return dataset, partitioner


def _create_dataset(num_rows: int, n_unique_natural_ids: int) -> Dataset:
"""Create dataset based on the number of rows and unique natural ids."""
data = {
"features": list(range(num_rows)),
"natural_id": [f"{i % n_unique_natural_ids}" for i in range(num_rows)],
"labels": [i % 2 for i in range(num_rows)],
}
dataset = Dataset.from_dict(data)
return dataset


class TestNaturalIdPartitioner(unittest.TestCase):
"""Test IidPartitioner."""

@parameterized.expand( # type: ignore
# num_rows, num_unique_natural_ids
list(itertools.product([10, 30, 100, 1000], [2, 3, 4, 5]))
)
def test_load_partition_num_partitions(
self, num_rows: int, num_unique_natural_id: int
) -> None:
"""Test if the number of partitions match the number of unique natural ids.
Only the correct data is tested in this method.
"""
_, partitioner = _dummy_setup(num_rows, num_unique_natural_id)
# Simulate usage to start lazy node_id_to_natural_id creation
_ = partitioner.load_partition(0)
self.assertEqual(len(partitioner.node_id_to_natural_id), num_unique_natural_id)

@parameterized.expand( # type: ignore
# num_rows, num_unique_natural_ids
list(itertools.product([10, 30, 100, 1000], [2, 3, 4, 5]))
)
def test_load_partition_max_partition_size(
self, num_rows: int, num_unique_natural_ids: int
) -> None:
"""Test if the number of partitions match the number of unique natural ids.
Only the correct data is tested in this method.
"""
print(num_rows)
print(num_unique_natural_ids)
_, partitioner = _dummy_setup(num_rows, num_unique_natural_ids)
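        # Natural ids in the dummy dataset are assigned round-robin, so the largest
        # partition holds ceil(num_rows / num_unique_natural_ids) rows.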
max_size = max(
[len(partitioner.load_partition(i)) for i in range(num_unique_natural_ids)]
)
self.assertEqual(max_size, math.ceil(num_rows / num_unique_natural_ids))

def test_partitioner_with_non_existing_column_partition_by(self) -> None:
"""Test error when the partition_by columns does not exist."""
dataset = _create_dataset(10, 2)
partitioner = NaturalIdPartitioner(partition_by="not-existing")
partitioner.dataset = dataset
with self.assertRaises(ValueError):
partitioner.load_partition(0)

@parameterized.expand( # type: ignore
# num_rows, num_unique_natural_ids
list(itertools.product([10, 30, 100, 1000], [2, 3, 4, 5]))
)
def test_correct_number_of_partitions(
self, num_rows: int, num_unique_natural_ids: int
) -> None:
"""Test if the # of available partitions is equal to # of unique clients."""
_, partitioner = _dummy_setup(num_rows, num_unique_natural_ids)
_ = partitioner.load_partition(idx=0)
self.assertEqual(len(partitioner.node_id_to_natural_id), num_unique_natural_ids)

def test_cannot_set_node_id_to_natural_id(self) -> None:
"""Test the lack of ability to set node_id_to_natural_id."""
_, partitioner = _dummy_setup(num_rows=10, n_unique_natural_ids=2)
with self.assertRaises(AttributeError):
partitioner.node_id_to_natural_id = {0: "0"}


if __name__ == "__main__":
unittest.main()
