Add NaturalIdPartitioner to FDS #2404

Merged
13 commits merged on Nov 10, 2023
2 changes: 2 additions & 0 deletions datasets/flwr_datasets/partitioner/__init__.py
@@ -18,13 +18,15 @@
from .exponential_partitioner import ExponentialPartitioner
from .iid_partitioner import IidPartitioner
from .linear_partitioner import LinearPartitioner
from .natural_id_partitioner import NaturalIdPartitioner
from .partitioner import Partitioner
from .size_partitioner import SizePartitioner
from .square_partitioner import SquarePartitioner

__all__ = [
"IidPartitioner",
"Partitioner",
"NaturalIdPartitioner",
"SizePartitioner",
"LinearPartitioner",
"SquarePartitioner",
81 changes: 81 additions & 0 deletions datasets/flwr_datasets/partitioner/natural_id_partitioner.py
@@ -0,0 +1,81 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Natural id partitioner class that works with Hugging Face Datasets."""


from typing import Dict

import datasets
from flwr_datasets.partitioner.partitioner import Partitioner


class NaturalIdPartitioner(Partitioner):
"""Partitioner for dataset that can be divided by a reference to id in dataset."""

def __init__(
self,
partition_by: str,
):
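"""Create a partitioner that groups rows by the values of a dataset column.

Parameters
----------
partition_by: str
    The name of the dataset column whose unique values define the partitions.
"""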
super().__init__()
self._node_id_to_natural_id: Dict[int, str] = {}
self._partition_by = partition_by

def _create_int_node_id_to_natural_id(self) -> None:
"""Create a mapping from int indices to unique client ids from dataset.

Natural ids come from the column specified in `partition_by`.
"""
unique_natural_ids = self.dataset.unique(self._partition_by)
self._node_id_to_natural_id = dict(
zip(range(len(unique_natural_ids)), unique_natural_ids)
)

def load_partition(self, idx: int) -> datasets.Dataset:
"""Load a single partition corresponding to a single `node_id`.

The choice of the partition is based on unique integers assigned to each
natural id present in the dataset in the `partition_by` column.

Parameters
----------
idx : int
    The index that corresponds to the requested partition.

Returns
-------
dataset_partition : Dataset
    A single dataset partition.
"""
if len(self._node_id_to_natural_id) == 0:
self._create_int_node_id_to_natural_id()

return self.dataset.filter(
lambda row: row[self._partition_by] == self._node_id_to_natural_id[idx]
)

@property
def node_id_to_natural_id(self) -> Dict[int, str]:
"""Node id to corresponding natural id present.

Natural ids are the unique values in `partition_by` column in dataset.
"""
return self._node_id_to_natural_id

# pylint: disable=R0201
@node_id_to_natural_id.setter
def node_id_to_natural_id(self, value: Dict[int, str]) -> None:
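"""Block setting node_id_to_natural_id; the mapping is derived from the dataset."""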
raise AttributeError(
"Setting the node_id_to_natural_id dictionary is not allowed."
)
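A minimal usage sketch (not part of the diff) showing how the new partitioner can be exercised directly; the toy data and the column name "user_id" are illustrative only:

from datasets import Dataset
from flwr_datasets.partitioner import NaturalIdPartitioner

# Toy dataset with a natural id column ("user_id" is a hypothetical column name).
dataset = Dataset.from_dict(
    {
        "features": [1, 2, 3, 4],
        "user_id": ["alice", "bob", "alice", "bob"],
    }
)

partitioner = NaturalIdPartitioner(partition_by="user_id")
partitioner.dataset = dataset

# Loading a partition lazily builds the node_id -> natural_id mapping.
partition = partitioner.load_partition(0)
print(partitioner.node_id_to_natural_id)  # e.g. {0: "alice", 1: "bob"}
print(partition["features"])  # features of the rows sharing the first natural id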
119 changes: 119 additions & 0 deletions datasets/flwr_datasets/partitioner/natural_id_partitioner_test.py
@@ -0,0 +1,119 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""NaturalIdPartitioner partitioner tests."""


import itertools
import math
import unittest
from typing import Tuple

from parameterized import parameterized

from datasets import Dataset
from flwr_datasets.partitioner.natural_id_partitioner import NaturalIdPartitioner


def _dummy_setup(
num_rows: int, n_unique_natural_ids: int
) -> Tuple[Dataset, NaturalIdPartitioner]:
"""Create a dummy dataset and partitioner based on given arguments.

The partitioner has the dataset automatically assigned to it.
"""
dataset = _create_dataset(num_rows, n_unique_natural_ids)
partitioner = NaturalIdPartitioner(partition_by="natural_id")
partitioner.dataset = dataset
return dataset, partitioner


def _create_dataset(num_rows: int, n_unique_natural_ids: int) -> Dataset:
"""Create dataset based on the number of rows and unique natural ids."""
data = {
"features": list(range(num_rows)),
"natural_id": [f"{i % n_unique_natural_ids}" for i in range(num_rows)],
"labels": [i % 2 for i in range(num_rows)],
}
dataset = Dataset.from_dict(data)
return dataset


class TestNaturalIdPartitioner(unittest.TestCase):
"""Test IidPartitioner."""

@parameterized.expand( # type: ignore
# num_rows, num_unique_natural_ids
list(itertools.product([10, 30, 100, 1000], [2, 3, 4, 5]))
)
def test_load_partition_num_partitions(
self, num_rows: int, num_unique_natural_id: int
) -> None:
"""Test if the number of partitions match the number of unique natural ids.

Only the correct data is tested in this method.
"""
_, partitioner = _dummy_setup(num_rows, num_unique_natural_id)
# Simulate usage to start lazy node_id_to_natural_id creation
_ = partitioner.load_partition(0)
self.assertEqual(len(partitioner.node_id_to_natural_id), num_unique_natural_id)

@parameterized.expand( # type: ignore
# num_rows, num_unique_natural_ids
list(itertools.product([10, 30, 100, 1000], [2, 3, 4, 5]))
)
def test_load_partition_max_partition_size(
self, num_rows: int, num_unique_natural_ids: int
) -> None:
"""Test if the number of partitions match the number of unique natural ids.

Only the correct data is tested in this method.
"""
_, partitioner = _dummy_setup(num_rows, num_unique_natural_ids)
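# Natural ids are assigned round-robin (i % n_unique_natural_ids), so the
# largest partition has ceil(num_rows / num_unique_natural_ids) rows.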
max_size = max(
[len(partitioner.load_partition(i)) for i in range(num_unique_natural_ids)]
)
self.assertEqual(max_size, math.ceil(num_rows / num_unique_natural_ids))

def test_partitioner_with_non_existing_column_partition_by(self) -> None:
"""Test error when the partition_by columns does not exist."""
dataset = _create_dataset(10, 2)
partitioner = NaturalIdPartitioner(partition_by="not-existing")
partitioner.dataset = dataset
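# The underlying `datasets` library raises a ValueError when the column does not exist.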
with self.assertRaises(ValueError):
partitioner.load_partition(0)

@parameterized.expand( # type: ignore
# num_rows, num_unique_natural_ids
list(itertools.product([10, 30, 100, 1000], [2, 3, 4, 5]))
)
def test_correct_number_of_partitions(
self, num_rows: int, num_unique_natural_ids: int
) -> None:
"""Test if the # of available partitions is equal to # of unique clients."""
_, partitioner = _dummy_setup(num_rows, num_unique_natural_ids)
_ = partitioner.load_partition(idx=0)
self.assertEqual(len(partitioner.node_id_to_natural_id), num_unique_natural_ids)

def test_cannot_set_node_id_to_natural_id(self) -> None:
"""Test the lack of ability to set node_id_to_natural_id."""
_, partitioner = _dummy_setup(num_rows=10, n_unique_natural_ids=2)
with self.assertRaises(AttributeError):
partitioner.node_id_to_natural_id = {0: "0"}


if __name__ == "__main__":
unittest.main()