Skip to content

Commit

Permalink
Add example
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Jan 16, 2024
1 parent 2f31ef2 commit bd49e8b
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions datasets/flwr_datasets/partitioner/dirichlet_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ class DirichletPartitioner(Partitioner): # pylint: disable=R0902
----------
num_partitions : int
The total number of partitions that the data will be divided into.
alpha : Union[float, List[float], NDArrayFloat]
Concentration parameter to the Dirichlet distribution
partition_by : str
Column name of the labels (targets) based on which Dirichlet sampling works.
alpha : Union[float, List[float], NDArrayFloat]
Concentration parameter to the Dirichlet distribution
min_partition_size : int
The minimum number of samples that each partitions will have (the sampling
process is repeated if any partition is too small).
Expand All @@ -61,14 +61,31 @@ class DirichletPartitioner(Partitioner): # pylint: disable=R0902
samples assignment to nodes.
seed: int
Seed used for dataset shuffling. It has no effect if `shuffle` is False.
Examples
--------
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import DirichletPartitioner
>>>
>>> partitioner = DirichletPartitioner(num_partitions=10, partition_by="label",
>>> alpha=0.5, min_partition_size=10,
>>> self_balancing=True)
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)
>>> print(partition[0]) # Print the first example
{'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x127B92170>,
'label': 4}
>>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)]
>>> print(sorted(partition_sizes))
[2134, 2615, 3646, 6011, 6170, 6386, 6715, 7653, 8435, 10235]
"""

def __init__( # pylint: disable=R0913
self,
num_partitions: int,
alpha: Union[float, List[float], NDArrayFloat],
partition_by: str,
min_partition_size: Optional[int] = None,
alpha: Union[float, List[float], NDArrayFloat],
min_partition_size: int = 10,
self_balancing: bool = True,
shuffle: bool = True,
seed: Optional[int] = 42,
Expand All @@ -79,9 +96,6 @@ def __init__( # pylint: disable=R0913
self._check_num_partitions_greater_than_zero()
self._alpha: NDArrayFloat = self._initialize_alpha(alpha)
self._partition_by = partition_by
if min_partition_size is None:
# Note that zero might make problems with the training
min_partition_size = 0
self._min_partition_size: int = min_partition_size
self._self_balancing = self_balancing
self._shuffle = shuffle
Expand Down

0 comments on commit bd49e8b

Please sign in to comment.