From bd49e8b9e2b0197e5792d4591c7008c968929974 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 16 Jan 2024 17:21:19 +0100 Subject: [PATCH] Add example --- .../partitioner/dirichlet_partitioner.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index ab118ba27a80..ba05c3af2512 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -45,10 +45,10 @@ class DirichletPartitioner(Partitioner): # pylint: disable=R0902 ---------- num_partitions : int The total number of partitions that the data will be divided into. - alpha : Union[float, List[float], NDArrayFloat] - Concentration parameter to the Dirichlet distribution partition_by : str Column name of the labels (targets) based on which Dirichlet sampling works. + alpha : Union[float, List[float], NDArrayFloat] + Concentration parameter to the Dirichlet distribution min_partition_size : int The minimum number of samples that each partitions will have (the sampling process is repeated if any partition is too small). @@ -61,14 +61,31 @@ class DirichletPartitioner(Partitioner): # pylint: disable=R0902 samples assignment to nodes. seed: int Seed used for dataset shuffling. It has no effect if `shuffle` is False. + + Examples + -------- + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import DirichletPartitioner + >>> + >>> partitioner = DirichletPartitioner(num_partitions=10, partition_by="label", + >>> alpha=0.5, min_partition_size=10, + >>> self_balancing=True) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + >>> print(partition[0]) # Print the first example + {'image': , + 'label': 4} + >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)] + >>> print(sorted(partition_sizes)) + [2134, 2615, 3646, 6011, 6170, 6386, 6715, 7653, 8435, 10235] """ def __init__( # pylint: disable=R0913 self, num_partitions: int, - alpha: Union[float, List[float], NDArrayFloat], partition_by: str, - min_partition_size: Optional[int] = None, + alpha: Union[float, List[float], NDArrayFloat], + min_partition_size: int = 10, self_balancing: bool = True, shuffle: bool = True, seed: Optional[int] = 42, @@ -79,9 +96,6 @@ def __init__( # pylint: disable=R0913 self._check_num_partitions_greater_than_zero() self._alpha: NDArrayFloat = self._initialize_alpha(alpha) self._partition_by = partition_by - if min_partition_size is None: - # Note that zero might make problems with the training - min_partition_size = 0 self._min_partition_size: int = min_partition_size self._self_balancing = self_balancing self._shuffle = shuffle