diff --git a/baselines/niid_bench/LICENSE b/baselines/niid_bench/LICENSE
new file mode 100644
index 000000000000..d64569567334
--- /dev/null
+++ b/baselines/niid_bench/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/baselines/niid_bench/README.md b/baselines/niid_bench/README.md
new file mode 100644
index 000000000000..6553a3cfdb1e
--- /dev/null
+++ b/baselines/niid_bench/README.md
@@ -0,0 +1,107 @@
+---
+title: "Federated Learning on Non-IID Data Silos: An Experimental Study"
+url: https://arxiv.org/abs/2102.02079
+labels: [data heterogeneity, image classification, benchmark]
+dataset: [CIFAR-10, MNIST, Fashion-MNIST]
+algorithms: [FedAvg, SCAFFOLD, FedProx, FedNova]
+---
+
+# Federated Learning on Non-IID Data Silos: An Experimental Study
+
+**Paper:** [arxiv.org/abs/2102.02079](https://arxiv.org/abs/2102.02079)
+
+**Authors:** Qinbin Li, Yiqun Diao, Quan Chen, Bingsheng He
+
+**Abstract:** Due to the increasing privacy concerns and data regulations, training data have been increasingly fragmented, forming distributed databases of multiple "data silos" (e.g., within different organizations and countries). To develop effective machine learning services, there is a must to exploit data from such distributed databases without exchanging the raw data. Recently, federated learning (FL) has been a solution with growing interests, which enables multiple parties to collaboratively train a machine learning model without exchanging their local data. A key and common challenge on distributed databases is the heterogeneity of the data distribution among the parties. The data of different parties are usually non-independently and identically distributed (i.e., non-IID). There have been many FL algorithms to address the learning effectiveness under non-IID data settings. However, there lacks an experimental study on systematically understanding their advantages and disadvantages, as previous studies have very rigid data partitioning strategies among parties, which are hardly representative and thorough. In this paper, to help researchers better understand and study the non-IID data setting in federated learning, we propose comprehensive data partitioning strategies to cover the typical non-IID data cases. Moreover, we conduct extensive experiments to evaluate state-of-the-art FL algorithms. We find that non-IID does bring significant challenges in learning accuracy of FL algorithms, and none of the existing state-of-the-art FL algorithms outperforms others in all cases. Our experiments provide insights for future studies of addressing the challenges in "data silos".
+
+
+## About this baseline
+
+**What’s implemented:** The code in this directory replicates many experiments from the aforementioned paper. Specifically, it contains implementations for four FL protocols, `FedAvg` (McMahan et al. 2017), `SCAFFOLD` (Karimireddy et al. 2019), `FedProx` (Li et al. 2018), and `FedNova` (Wang et al. 2020). The FL protocols are evaluated across various non-IID data partition strategies across clients on three image classification datasets MNIST, CIFAR10, and Fashion-mnist.
+
+**Datasets:** MNIST, CIFAR10, and Fashion-mnist from PyTorch's Torchvision
+
+**Hardware Setup:** These experiments were run on a linux server with 56 CPU threads with 250 GB Ram. There are 105 configurations to run per seed and at any time 7 configurations have been run parallely. The experiments required close to 12 hrs to finish for one seed. Nevertheless, to run a subset of configurations, such as only one FL protocol across all datasets and splits, a machine with 4-8 threads and 16 GB memory can run in reasonable time.
+
+**Contributors:** Aashish Kolluri, PhD Candidate, National University of Singapore
+
+
+## Experimental Setup
+
+**Task:** Image classification
+
+**Model:** This directory implements CNNs as mentioned in the paper (Section V, paragraph 1). Specifically, the CNNs have two 2D convolutional layers with 6 and 16 output channels, kernel size 5, and stride 1.
+
+**Dataset:** This directory has three image classification datasets that are used in the baseline, MNIST, CIFAR10, and Fashion-mnist. Further, five different data-splitting strategies are used including iid and four non-iid strategies based on label skewness. In the first non-iid strategy, for each label the data is split based on proportions sampled from a dirichlet distribution (with parameter 0.5). In the three remaining strategies, each client gets data from randomly chosen #C labels where #C is 1, 2, or 3. For the clients that are supposed to receive data from the same label the data is equally split between them. The baseline considers 10 clients. The following table shows all dataset and data splitting configurations.
+
+| Datasets | #classes | #partitions | partitioning method | partition settings |
+| :------ | :---: | :---: | :---: | :---: |
+| CIFAR10, MNIST, Fashion-mnist | 10 | 10 | IID
dirichlet
sort and partition
sort and partition
sort and partition | NA
distribution parameter 0.5
1 label per client
2 labels per client
3 labels per client |
+
+
+**Training Hyperparameters:** There are four FL algorithms and they have many common hyperparameters and few different ones. The following table shows the common hyperparameters and their default values.
+
+| Description | Default Value |
+| ----------- | ----- |
+| total clients | 10 |
+| clients per round | 10 |
+| number of rounds | 50 |
+| number of local epochs | 10 |
+| client resources | {'num_cpus': 4.0, 'num_gpus': 0.0 }|
+| dataset name | cifar10
+| data partition | Dirichlet (0.5) |
+| batch size | 64 |
+| momentum for SGD | 0.9 |
+
+For FedProx algorithm the proximal parameter is tuned from values {0.001, 0.01, 0.1, 1.0} in the experiments. The default value is 0.01.
+
+
+## Environment Setup
+
+```bash
+# Setup the base poetry enviroment from the niid_bench directory
+# Set python version
+pyenv local 3.10.6
+# Tell poetry to use python 3.10
+poetry env use 3.10.6
+# Now install the environment
+poetry install
+# Start the shell
+poetry shell
+```
+
+
+## Running the Experiments
+You can run four algorithms `FedAvg`, `SCAFFOLD`, `FedProx`, and `FedNova`. To run any of them, use any of the corresponding config files. For instance, the following command will run with the default config provided in the corresponding configuration files.
+
+```bash
+# Run with default config, it will run FedAvg on cpu-only mode
+python -m niid_bench.main
+# Below to enable GPU utilization by the server and the clients.
+python -m niid_bench.main server_device=cuda client_resources.num_gpus=0.2
+```
+
+To change the configuration such as dataset or hyperparameters, specify them as part of the command line arguments.
+
+```bash
+python -m niid_bench.main --config-name scaffold_base dataset_name=mnist partitioning=iid # iid
+python -m niid_bench.main --config-name fedprox_base dataset_name=mnist partitioning=dirichlet # dirichlet
+python -m niid_bench.main --config-name fednova_base dataset_name=mnist partitioning=label_quantity labels_per_client=3 # sort and partition
+```
+
+
+## Expected Results
+
+We provide the bash script run_exp.py that can be used to run all configurations. For instance, the following command runs all of them with 4 configurations running at the same time. Consider lowering `--num-processes` if your machine runs slow. With `--num-processes 1` one experiment will be run at a time.
+
+```bash
+python run_exp.py --seed 42 --num-processes 4
+```
+
+The above command generates results that can be parsed to get the best accuracies across all rounds for each configuration which can be presented in a table (similar to Table 3 in the paper).
+
+| Dataset | partitioning method | FedAvg | SCAFFOLD | FedProx | FedNova |
+| :------ | :------ | :---: | :---: | :---: | :---: |
+| MNIST | IID
Dirichlet (0.5)
Sort and Partition (1)
Sort and Partition (2)
Sort and Partition (3) | 99.09 ± 0.05
98.89 ± 0.07
19.33 ± 11.82
96.86 ± 0.30
97.86 ± 0.34 | 99.06 ± 0.15
99.07 ± 0.06
9.93 ± 0.12
96.92 ± 0.52
97.91 ± 0.10 | 99.16 ± 0.04
99.02 ± 0.02
51.79 ± 26.75
96.85 ± 0.15
97.85 ± 0.06 | 99.05 ± 0.06
98.03 ± 0.06
52.58 ± 14.08
96.65 ± 0.39
97.62 ± 0.07 |
+| FMNIST | IID
Dirichlet (0.5)
Sort and Partition (1)
Sort and Partition (2)
Sort and Partition (3) | 89.23 ± 0.45
88.09 ± 0.29
28.39 ± 17.09
78.10 ± 2.51
82.43 ± 1.52 | 89.33 ± 0.27
88.44 ± 0.25
10.00 ± 0.00
33.80 ± 41.22
80.32 ± 5.03 | 89.42 ± 0.09
88.15 ± 0.42
32.65 ± 6.68
78.05 ± 0.99
82.99 ± 0.48 | 89.36 ± 0.09
88.22 ± 0.12
16.86 ± 9.30
71.67 ± 2.34
81.97 ± 1.34 |
+| CIFAR10 | IID
Dirichlet (0.5)
Sort and Partition (1)
Sort and Partition (2)
Sort and Partition (3) | 71.32 ± 0.33
62.47 ± 0.43
10.00 ± 0.00
51.17 ± 1.09
59.11 ± 0.87 | 71.66 ± 1.13
68.08 ± 0.96
10.00 ± 0.00
49.42 ± 2.18
61.00 ± 0.91 | 71.26 ± 1.18
65.63 ± 0.08
12.71 ± 0.96
50.44 ± 0.79
59.20 ± 1.18 | 70.69 ± 1.14
63.89 ± 1.40
10.00 ± 0.00
46.9 ± 0.66
57.83 ± 0.42 |
diff --git a/baselines/niid_bench/niid_bench/__init__.py b/baselines/niid_bench/niid_bench/__init__.py
new file mode 100644
index 000000000000..a5e567b59135
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/__init__.py
@@ -0,0 +1 @@
+"""Template baseline package."""
diff --git a/baselines/niid_bench/niid_bench/client.py b/baselines/niid_bench/niid_bench/client.py
new file mode 100644
index 000000000000..1f05ee6d0fe9
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/client.py
@@ -0,0 +1,4 @@
+"""Client script.
+
+This is not used in this baseline. Please refer to the strategy-specific client files.
+"""
diff --git a/baselines/niid_bench/niid_bench/client_fedavg.py b/baselines/niid_bench/niid_bench/client_fedavg.py
new file mode 100644
index 000000000000..e2e298574c45
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/client_fedavg.py
@@ -0,0 +1,130 @@
+"""Defines the client class and support functions for FedAvg."""
+
+from typing import Callable, Dict, List, OrderedDict
+
+import flwr as fl
+import torch
+from flwr.common import Scalar
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+
+from niid_bench.models import test, train_fedavg
+
+
+# pylint: disable=too-many-instance-attributes
+class FlowerClientFedAvg(fl.client.NumPyClient):
+ """Flower client implementing FedAvg."""
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ net: torch.nn.Module,
+ trainloader: DataLoader,
+ valloader: DataLoader,
+ device: torch.device,
+ num_epochs: int,
+ learning_rate: float,
+ momentum: float,
+ weight_decay: float,
+ ) -> None:
+ self.net = net
+ self.trainloader = trainloader
+ self.valloader = valloader
+ self.device = device
+ self.num_epochs = num_epochs
+ self.learning_rate = learning_rate
+ self.momentum = momentum
+ self.weight_decay = weight_decay
+
+ def get_parameters(self, config: Dict[str, Scalar]):
+ """Return the current local model parameters."""
+ return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
+
+ def set_parameters(self, parameters):
+ """Set the local model parameters using given ones."""
+ params_dict = zip(self.net.state_dict().keys(), parameters)
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+ self.net.load_state_dict(state_dict, strict=True)
+
+ def fit(self, parameters, config: Dict[str, Scalar]):
+ """Implement distributed fit function for a given client for FedAvg."""
+ self.set_parameters(parameters)
+ train_fedavg(
+ self.net,
+ self.trainloader,
+ self.device,
+ self.num_epochs,
+ self.learning_rate,
+ self.momentum,
+ self.weight_decay,
+ )
+ final_p_np = self.get_parameters({})
+ return final_p_np, len(self.trainloader.dataset), {}
+
+ def evaluate(self, parameters, config: Dict[str, Scalar]):
+ """Evaluate using given parameters."""
+ self.set_parameters(parameters)
+ loss, acc = test(self.net, self.valloader, self.device)
+ return float(loss), len(self.valloader.dataset), {"accuracy": float(acc)}
+
+
+# pylint: disable=too-many-arguments
+def gen_client_fn(
+ trainloaders: List[DataLoader],
+ valloaders: List[DataLoader],
+ num_epochs: int,
+ learning_rate: float,
+ model: DictConfig,
+ momentum: float = 0.9,
+ weight_decay: float = 1e-5,
+) -> Callable[[str], FlowerClientFedAvg]: # pylint: disable=too-many-arguments
+ """Generate the client function that creates the FedAvg flower clients.
+
+ Parameters
+ ----------
+ trainloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset training partition
+ belonging to a particular client.
+ valloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset validation partition
+ belonging to a particular client.
+ num_epochs : int
+ The number of local epochs each client should run the training for before
+ sending it to the server.
+ learning_rate : float
+ The learning rate for the SGD optimizer of clients.
+ momentum : float
+ The momentum for SGD optimizer of clients
+ weight_decay : float
+ The weight decay for SGD optimizer of clients
+
+ Returns
+ -------
+ Callable[[str], FlowerClientFedAvg]
+ The client function that creates the FedAvg flower clients
+ """
+
+ def client_fn(cid: str) -> FlowerClientFedAvg:
+ """Create a Flower client representing a single organization."""
+ # Load model
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ net = instantiate(model).to(device)
+
+ # Note: each client gets a different trainloader/valloader, so each client
+ # will train and evaluate on their own unique data
+ trainloader = trainloaders[int(cid)]
+ valloader = valloaders[int(cid)]
+
+ return FlowerClientFedAvg(
+ net,
+ trainloader,
+ valloader,
+ device,
+ num_epochs,
+ learning_rate,
+ momentum,
+ weight_decay,
+ )
+
+ return client_fn
diff --git a/baselines/niid_bench/niid_bench/client_fednova.py b/baselines/niid_bench/niid_bench/client_fednova.py
new file mode 100644
index 000000000000..69b350fcd9f0
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/client_fednova.py
@@ -0,0 +1,133 @@
+"""Defines the client class and support functions for FedNova."""
+
+from typing import Callable, Dict, List, OrderedDict
+
+import flwr as fl
+import torch
+from flwr.common import Scalar
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+
+from niid_bench.models import test, train_fednova
+
+
+# pylint: disable=too-many-instance-attributes
+class FlowerClientFedNova(fl.client.NumPyClient):
+ """Flower client implementing FedNova."""
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ net: torch.nn.Module,
+ trainloader: DataLoader,
+ valloader: DataLoader,
+ device: torch.device,
+ num_epochs: int,
+ learning_rate: float,
+ momentum: float,
+ weight_decay: float,
+ ) -> None:
+ self.net = net
+ self.trainloader = trainloader
+ self.valloader = valloader
+ self.device = device
+ self.num_epochs = num_epochs
+ self.learning_rate = learning_rate
+ self.momentum = momentum
+ self.weight_decay = weight_decay
+
+ def get_parameters(self, config: Dict[str, Scalar]):
+ """Return the current local model parameters."""
+ return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
+
+ def set_parameters(self, parameters):
+ """Set the local model parameters using given ones."""
+ params_dict = zip(self.net.state_dict().keys(), parameters)
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+ self.net.load_state_dict(state_dict, strict=True)
+
+ def fit(self, parameters, config: Dict[str, Scalar]):
+ """Implement distributed fit function for a given client for FedNova."""
+ self.set_parameters(parameters)
+ a_i, g_i = train_fednova(
+ self.net,
+ self.trainloader,
+ self.device,
+ self.num_epochs,
+ self.learning_rate,
+ self.momentum,
+ self.weight_decay,
+ )
+ # final_p_np = self.get_parameters({})
+ g_i_np = [param.cpu().numpy() for param in g_i]
+ return g_i_np, len(self.trainloader.dataset), {"a_i": a_i}
+
+ def evaluate(self, parameters, config: Dict[str, Scalar]):
+ """Evaluate using given parameters."""
+ self.set_parameters(parameters)
+ loss, acc = test(self.net, self.valloader, self.device)
+ return float(loss), len(self.valloader.dataset), {"accuracy": float(acc)}
+
+
+# pylint: disable=too-many-arguments
+def gen_client_fn(
+ trainloaders: List[DataLoader],
+ valloaders: List[DataLoader],
+ num_epochs: int,
+ learning_rate: float,
+ model: DictConfig,
+ momentum: float = 0.9,
+ weight_decay: float = 1e-5,
+) -> Callable[[str], FlowerClientFedNova]: # pylint: disable=too-many-arguments
+ """Generate the client function that creates the FedNova flower clients.
+
+ Parameters
+ ----------
+ trainloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset training partition
+ belonging to a particular client.
+ valloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset validation partition
+ belonging to a particular client.
+ num_epochs : int
+ The number of local epochs each client should run the training for before
+ sending it to the server.
+ learning_rate : float
+ The learning rate for the SGD optimizer of clients.
+ model : DictConfig
+ The model configuration.
+ momentum : float
+ The momentum for SGD optimizer of clients
+ weight_decay : float
+ The weight decay for SGD optimizer of clients
+
+ Returns
+ -------
+ Callable[[str], FlowerClientFedNova]
+ The client function that creates the FedNova flower clients
+ """
+
+ def client_fn(cid: str) -> FlowerClientFedNova:
+ """Create a Flower client representing a single organization."""
+ # Load model
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ net = instantiate(model).to(device)
+
+ # Note: each client gets a different trainloader/valloader, so each client
+ # will train and evaluate on their own unique data
+ trainloader = trainloaders[int(cid)]
+ valloader = valloaders[int(cid)]
+
+ return FlowerClientFedNova(
+ net,
+ trainloader,
+ valloader,
+ device,
+ num_epochs,
+ learning_rate,
+ momentum,
+ weight_decay,
+ )
+
+ return client_fn
diff --git a/baselines/niid_bench/niid_bench/client_fedprox.py b/baselines/niid_bench/niid_bench/client_fedprox.py
new file mode 100644
index 000000000000..5b470eda901b
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/client_fedprox.py
@@ -0,0 +1,139 @@
+"""Defines the client class and support functions for FedProx."""
+
+from typing import Callable, Dict, List, OrderedDict
+
+import flwr as fl
+import torch
+from flwr.common import Scalar
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+
+from niid_bench.models import test, train_fedprox
+
+
+# pylint: disable=too-many-instance-attributes
+class FlowerClientFedProx(fl.client.NumPyClient):
+ """Flower client implementing FedProx."""
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ net: torch.nn.Module,
+ trainloader: DataLoader,
+ valloader: DataLoader,
+ device: torch.device,
+ num_epochs: int,
+ proximal_mu: float,
+ learning_rate: float,
+ momentum: float,
+ weight_decay: float,
+ ) -> None:
+ self.net = net
+ self.trainloader = trainloader
+ self.valloader = valloader
+ self.device = device
+ self.num_epochs = num_epochs
+ self.proximal_mu = proximal_mu
+ self.learning_rate = learning_rate
+ self.momentum = momentum
+ self.weight_decay = weight_decay
+
+ def get_parameters(self, config: Dict[str, Scalar]):
+ """Return the current local model parameters."""
+ return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
+
+ def set_parameters(self, parameters):
+ """Set the local model parameters using given ones."""
+ params_dict = zip(self.net.state_dict().keys(), parameters)
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+ self.net.load_state_dict(state_dict, strict=True)
+
+ def fit(self, parameters, config: Dict[str, Scalar]):
+ """Implement distributed fit function for a given client for FedProx."""
+ self.set_parameters(parameters)
+ train_fedprox(
+ self.net,
+ self.trainloader,
+ self.device,
+ self.num_epochs,
+ self.proximal_mu,
+ self.learning_rate,
+ self.momentum,
+ self.weight_decay,
+ )
+ final_p_np = self.get_parameters({})
+ return final_p_np, len(self.trainloader.dataset), {}
+
+ def evaluate(self, parameters, config: Dict[str, Scalar]):
+ """Evaluate using given parameters."""
+ self.set_parameters(parameters)
+ loss, acc = test(self.net, self.valloader, self.device)
+ return float(loss), len(self.valloader.dataset), {"accuracy": float(acc)}
+
+
+# pylint: disable=too-many-arguments
+def gen_client_fn(
+ trainloaders: List[DataLoader],
+ valloaders: List[DataLoader],
+ num_epochs: int,
+ learning_rate: float,
+ model: DictConfig,
+ proximal_mu: float,
+ momentum: float = 0.9,
+ weight_decay: float = 1e-5,
+) -> Callable[[str], FlowerClientFedProx]: # pylint: disable=too-many-arguments
+ """Generate the client function that creates the FedProx flower clients.
+
+ Parameters
+ ----------
+ trainloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset training partition
+ belonging to a particular client.
+ valloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset validation partition
+ belonging to a particular client.
+ num_epochs : int
+ The number of local epochs each client should run the training for before
+ sending it to the server.
+ learning_rate : float
+ The learning rate for the SGD optimizer of clients.
+ model : DictConfig
+ The model configuration.
+ proximal_mu : float
+ The proximal mu parameter.
+ momentum : float
+ The momentum for SGD optimizer of clients
+ weight_decay : float
+ The weight decay for SGD optimizer of clients
+
+ Returns
+ -------
+ Callable[[str], FlowerClientFedProx]
+ The client function that creates the FedProx flower clients
+ """
+
+ def client_fn(cid: str) -> FlowerClientFedProx:
+ """Create a Flower client representing a single organization."""
+ # Load model
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ net = instantiate(model).to(device)
+
+ # Note: each client gets a different trainloader/valloader, so each client
+ # will train and evaluate on their own unique data
+ trainloader = trainloaders[int(cid)]
+ valloader = valloaders[int(cid)]
+
+ return FlowerClientFedProx(
+ net,
+ trainloader,
+ valloader,
+ device,
+ num_epochs,
+ proximal_mu,
+ learning_rate,
+ momentum,
+ weight_decay,
+ )
+
+ return client_fn
diff --git a/baselines/niid_bench/niid_bench/client_scaffold.py b/baselines/niid_bench/niid_bench/client_scaffold.py
new file mode 100644
index 000000000000..b95315073ebb
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/client_scaffold.py
@@ -0,0 +1,186 @@
+"""Defines the client class and support functions for SCAFFOLD."""
+
+import os
+from typing import Callable, Dict, List, OrderedDict
+
+import flwr as fl
+import torch
+from flwr.common import Scalar
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+
+from niid_bench.models import test, train_scaffold
+
+
+# pylint: disable=too-many-instance-attributes
+class FlowerClientScaffold(fl.client.NumPyClient):
+ """Flower client implementing scaffold."""
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ cid: int,
+ net: torch.nn.Module,
+ trainloader: DataLoader,
+ valloader: DataLoader,
+ device: torch.device,
+ num_epochs: int,
+ learning_rate: float,
+ momentum: float,
+ weight_decay: float,
+ save_dir: str = "",
+ ) -> None:
+ self.cid = cid
+ self.net = net
+ self.trainloader = trainloader
+ self.valloader = valloader
+ self.device = device
+ self.num_epochs = num_epochs
+ self.learning_rate = learning_rate
+ self.momentum = momentum
+ self.weight_decay = weight_decay
+ # initialize client control variate with 0 and shape of the network parameters
+ self.client_cv = []
+ for param in self.net.parameters():
+ self.client_cv.append(torch.zeros(param.shape))
+ # save cv to directory
+ if save_dir == "":
+ save_dir = "client_cvs"
+ self.dir = save_dir
+ if not os.path.exists(self.dir):
+ os.makedirs(self.dir)
+
+ def get_parameters(self, config: Dict[str, Scalar]):
+ """Return the current local model parameters."""
+ return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
+
+ def set_parameters(self, parameters):
+ """Set the local model parameters using given ones."""
+ params_dict = zip(self.net.state_dict().keys(), parameters)
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+ self.net.load_state_dict(state_dict, strict=True)
+
+ def fit(self, parameters, config: Dict[str, Scalar]):
+ """Implement distributed fit function for a given client for SCAFFOLD."""
+ # the first half are model parameters and the second are the server_cv
+ server_cv = parameters[len(parameters) // 2 :]
+ parameters = parameters[: len(parameters) // 2]
+ self.set_parameters(parameters)
+ self.client_cv = []
+ for param in self.net.parameters():
+ self.client_cv.append(param.clone().detach())
+ # load client control variate
+ if os.path.exists(f"{self.dir}/client_cv_{self.cid}.pt"):
+ self.client_cv = torch.load(f"{self.dir}/client_cv_{self.cid}.pt")
+ # convert the server control variate to a list of tensors
+ server_cv = [torch.Tensor(cv) for cv in server_cv]
+ train_scaffold(
+ self.net,
+ self.trainloader,
+ self.device,
+ self.num_epochs,
+ self.learning_rate,
+ self.momentum,
+ self.weight_decay,
+ server_cv,
+ self.client_cv,
+ )
+ x = parameters
+ y_i = self.get_parameters(config={})
+ c_i_n = []
+ server_update_x = []
+ server_update_c = []
+ # update client control variate c_i_1 = c_i - c + 1/eta*K (x - y_i)
+ for c_i_j, c_j, x_j, y_i_j in zip(self.client_cv, server_cv, x, y_i):
+ c_i_n.append(
+ c_i_j
+ - c_j
+ + (1.0 / (self.learning_rate * self.num_epochs * len(self.trainloader)))
+ * (x_j - y_i_j)
+ )
+ # y_i - x, c_i_n - c_i for the server
+ server_update_x.append((y_i_j - x_j))
+ server_update_c.append((c_i_n[-1] - c_i_j).cpu().numpy())
+ self.client_cv = c_i_n
+ torch.save(self.client_cv, f"{self.dir}/client_cv_{self.cid}.pt")
+
+ combined_updates = server_update_x + server_update_c
+
+ return (
+ combined_updates,
+ len(self.trainloader.dataset),
+ {},
+ )
+
+ def evaluate(self, parameters, config: Dict[str, Scalar]):
+ """Evaluate using given parameters."""
+ self.set_parameters(parameters)
+ loss, acc = test(self.net, self.valloader, self.device)
+ return float(loss), len(self.valloader.dataset), {"accuracy": float(acc)}
+
+
+# pylint: disable=too-many-arguments
+def gen_client_fn(
+ trainloaders: List[DataLoader],
+ valloaders: List[DataLoader],
+ client_cv_dir: str,
+ num_epochs: int,
+ learning_rate: float,
+ model: DictConfig,
+ momentum: float = 0.9,
+ weight_decay: float = 0.0,
+) -> Callable[[str], FlowerClientScaffold]: # pylint: disable=too-many-arguments
+ """Generate the client function that creates the scaffold flower clients.
+
+ Parameters
+ ----------
+ trainloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset training partition
+ belonging to a particular client.
+ valloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset validation partition
+ belonging to a particular client.
+ client_cv_dir : str
+ The directory where the client control variates are stored (persistent storage).
+ num_epochs : int
+ The number of local epochs each client should run the training for before
+ sending it to the server.
+ learning_rate : float
+ The learning rate for the SGD optimizer of clients.
+ momentum : float
+ The momentum for SGD optimizer of clients.
+ weight_decay : float
+ The weight decay for SGD optimizer of clients.
+
+ Returns
+ -------
+ Callable[[str], FlowerClientScaffold]
+ The client function that creates the scaffold flower clients.
+ """
+
+ def client_fn(cid: str) -> FlowerClientScaffold:
+ """Create a Flower client representing a single organization."""
+ # Load model
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ net = instantiate(model).to(device)
+
+ # Note: each client gets a different trainloader/valloader, so each client
+ # will train and evaluate on their own unique data
+ trainloader = trainloaders[int(cid)]
+ valloader = valloaders[int(cid)]
+
+ return FlowerClientScaffold(
+ int(cid),
+ net,
+ trainloader,
+ valloader,
+ device,
+ num_epochs,
+ learning_rate,
+ momentum,
+ weight_decay,
+ save_dir=client_cv_dir,
+ )
+
+ return client_fn
diff --git a/baselines/niid_bench/niid_bench/conf/fedavg_base.yaml b/baselines/niid_bench/niid_bench/conf/fedavg_base.yaml
new file mode 100644
index 000000000000..5f076fe1d0c3
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/conf/fedavg_base.yaml
@@ -0,0 +1,61 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+num_clients: 10
+num_epochs: 10
+batch_size: 64
+clients_per_round: 10
+learning_rate: 0.01
+num_rounds: 50
+partitioning: "dirichlet"
+dataset_name: "cifar10"
+dataset_seed: 42
+alpha: 0.5
+labels_per_client: 2 # only used when partitioning is label quantity
+momentum: 0.9
+weight_decay: 0.00001
+
+client_fn:
+ _target_: niid_bench.client_fedavg.gen_client_fn
+ _recursive_: False
+ num_epochs: ${num_epochs}
+ learning_rate: ${learning_rate}
+ momentum: ${momentum}
+ weight_decay: ${weight_decay}
+
+dataset:
+ # dataset config
+ name: ${dataset_name}
+ partitioning: ${partitioning}
+ batch_size: ${batch_size} # batch_size = batch_size_ratio * total_local_data_size
+ val_split: 0.0
+ seed: ${dataset_seed}
+ alpha: ${alpha}
+ labels_per_client: ${labels_per_client}
+
+model:
+ # model config
+ _target_: niid_bench.models.CNN
+ input_dim: 400
+ hidden_dims: [120, 84]
+ num_classes: 10
+
+strategy:
+ _target_: flwr.server.strategy.FedAvg # points to your strategy (either custom or exiting in Flower)
+ # rest of strategy config
+ fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients
+ fraction_evaluate: 0.0
+ min_fit_clients: ${clients_per_round}
+ min_available_clients: ${clients_per_round}
+ min_evaluate_clients: 0
+
+client:
+ # client config
+
+server_device: cpu
+
+client_resources:
+ num_cpus: 4
+ num_gpus: 0.0
diff --git a/baselines/niid_bench/niid_bench/conf/fednova_base.yaml b/baselines/niid_bench/niid_bench/conf/fednova_base.yaml
new file mode 100644
index 000000000000..30a8939768c7
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/conf/fednova_base.yaml
@@ -0,0 +1,61 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+num_clients: 10
+num_epochs: 10
+batch_size: 64
+clients_per_round: 10
+learning_rate: 0.01
+num_rounds: 50
+partitioning: "dirichlet"
+dataset_name: "cifar10"
+dataset_seed: 42
+alpha: 0.5
+labels_per_client: 2 # only used when partitioning is label quantity
+momentum: 0.9
+weight_decay: 0.00001
+
+client_fn:
+ _target_: niid_bench.client_fednova.gen_client_fn
+ _recursive_: False
+ num_epochs: ${num_epochs}
+ learning_rate: ${learning_rate}
+ momentum: ${momentum}
+ weight_decay: ${weight_decay}
+
+dataset:
+ # dataset config
+ name: ${dataset_name}
+ partitioning: ${partitioning}
+ batch_size: ${batch_size} # batch_size = batch_size_ratio * total_local_data_size
+ val_split: 0.0
+ seed: ${dataset_seed}
+ alpha: ${alpha}
+ labels_per_client: ${labels_per_client}
+
+model:
+ # model config
+ _target_: niid_bench.models.CNN
+ input_dim: 400
+ hidden_dims: [120, 84]
+ num_classes: 10
+
+strategy:
+ _target_: niid_bench.strategy.FedNovaStrategy # points to your strategy (either custom or exiting in Flower)
+ # rest of strategy config
+ fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients
+ fraction_evaluate: 0.0
+ min_fit_clients: ${clients_per_round}
+ min_available_clients: ${clients_per_round}
+ min_evaluate_clients: 0
+
+client:
+ # client config
+
+server_device: cpu
+
+client_resources:
+ num_cpus: 4
+ num_gpus: 0.0
diff --git a/baselines/niid_bench/niid_bench/conf/fedprox_base.yaml b/baselines/niid_bench/niid_bench/conf/fedprox_base.yaml
new file mode 100644
index 000000000000..d3dbfb5f2761
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/conf/fedprox_base.yaml
@@ -0,0 +1,63 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+num_clients: 10
+num_epochs: 10
+batch_size: 64
+clients_per_round: 10
+learning_rate: 0.01
+num_rounds: 50
+mu: 0.01
+partitioning: "dirichlet"
+dataset_name: "cifar10"
+dataset_seed: 42
+alpha: 0.5
+labels_per_client: 2 # only used when partitioning is label quantity
+momentum: 0.9
+weight_decay: 0.00001
+
+client_fn:
+ _target_: niid_bench.client_fedprox.gen_client_fn
+ _recursive_: False
+ proximal_mu: ${mu}
+ num_epochs: ${num_epochs}
+ learning_rate: ${learning_rate}
+ momentum: ${momentum}
+ weight_decay: ${weight_decay}
+
+dataset:
+ # dataset config
+ name: ${dataset_name}
+ partitioning: ${partitioning}
+ batch_size: ${batch_size} # batch_size = batch_size_ratio * total_local_data_size
+ val_split: 0.0
+ seed: ${dataset_seed}
+ alpha: ${alpha}
+ labels_per_client: ${labels_per_client}
+
+model:
+ # model config
+ _target_: niid_bench.models.CNN
+ input_dim: 400
+ hidden_dims: [120, 84]
+ num_classes: 10
+
+strategy:
+ _target_: flwr.server.strategy.FedAvg # points to your strategy (either custom or exiting in Flower)
+ # rest of strategy config
+ fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients
+ fraction_evaluate: 0.0
+ min_fit_clients: ${clients_per_round}
+ min_available_clients: ${clients_per_round}
+ min_evaluate_clients: 0
+
+client:
+ # client config
+
+server_device: cpu
+
+client_resources:
+ num_cpus: 4
+ num_gpus: 0.0
diff --git a/baselines/niid_bench/niid_bench/conf/scaffold_base.yaml b/baselines/niid_bench/niid_bench/conf/scaffold_base.yaml
new file mode 100644
index 000000000000..b4b5e3577f4c
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/conf/scaffold_base.yaml
@@ -0,0 +1,61 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+num_clients: 10
+num_epochs: 10
+batch_size: 64
+clients_per_round: 10
+learning_rate: 0.01
+num_rounds: 50
+partitioning: "dirichlet"
+dataset_name: "cifar10"
+dataset_seed: 42
+alpha: 0.5
+labels_per_client: 2 # only used when partitioning is label quantity
+momentum: 0.9
+weight_decay: 0.00001
+
+client_fn:
+ _target_: niid_bench.client_scaffold.gen_client_fn
+ _recursive_: False
+ num_epochs: ${num_epochs}
+ learning_rate: ${learning_rate}
+ momentum: ${momentum}
+ weight_decay: ${weight_decay}
+
+dataset:
+ # dataset config
+ name: ${dataset_name}
+ partitioning: ${partitioning}
+ batch_size: ${batch_size} # batch_size = batch_size_ratio * total_local_data_size
+ val_split: 0.0
+ seed: ${dataset_seed}
+ alpha: ${alpha}
+ labels_per_client: ${labels_per_client}
+
+model:
+ # model config
+ _target_: niid_bench.models.CNN
+ input_dim: 400
+ hidden_dims: [120, 84]
+ num_classes: 10
+
+strategy:
+ _target_: niid_bench.strategy.ScaffoldStrategy # points to your strategy (either custom or exiting in Flower)
+ # rest of strategy config
+ fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients
+ fraction_evaluate: 0.0
+ min_fit_clients: ${clients_per_round}
+ min_available_clients: ${clients_per_round}
+ min_evaluate_clients: 0
+
+client:
+ # client config
+
+server_device: cpu
+
+client_resources:
+ num_cpus: 4
+ num_gpus: 0.0
diff --git a/baselines/niid_bench/niid_bench/dataset.py b/baselines/niid_bench/niid_bench/dataset.py
new file mode 100644
index 000000000000..5e3780e61390
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/dataset.py
@@ -0,0 +1,106 @@
+"""Partition the data and create the dataloaders."""
+
+from typing import List, Optional, Tuple
+
+import torch
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader, random_split
+
+from niid_bench.dataset_preparation import (
+ partition_data,
+ partition_data_dirichlet,
+ partition_data_label_quantity,
+)
+
+
+# pylint: disable=too-many-locals, too-many-branches
+def load_datasets(
+ config: DictConfig,
+ num_clients: int,
+ val_ratio: float = 0.1,
+ seed: Optional[int] = 42,
+) -> Tuple[List[DataLoader], List[DataLoader], DataLoader]:
+ """Create the dataloaders to be fed into the model.
+
+ Parameters
+ ----------
+ config: DictConfig
+ Parameterises the dataset partitioning process
+ num_clients : int
+ The number of clients that hold a part of the data
+ val_ratio : float, optional
+ The ratio of training data that will be used for validation (between 0 and 1),
+ by default 0.1
+ seed : int, optional
+ Used to set a fix seed to replicate experiments, by default 42
+
+ Returns
+ -------
+ Tuple[DataLoader, DataLoader, DataLoader]
+ The DataLoaders for training, validation, and testing.
+ """
+ print(f"Dataset partitioning config: {config}")
+ partitioning = ""
+ if "partitioning" in config:
+ partitioning = config.partitioning
+ # partition the data
+ if partitioning == "dirichlet":
+ alpha = 0.5
+ if "alpha" in config:
+ alpha = config.alpha
+ datasets, testset = partition_data_dirichlet(
+ num_clients,
+ alpha=alpha,
+ seed=seed,
+ dataset_name=config.name,
+ )
+ elif partitioning == "label_quantity":
+ labels_per_client = 2
+ if "labels_per_client" in config:
+ labels_per_client = config.labels_per_client
+ datasets, testset = partition_data_label_quantity(
+ num_clients,
+ labels_per_client=labels_per_client,
+ seed=seed,
+ dataset_name=config.name,
+ )
+ elif partitioning == "iid":
+ datasets, testset = partition_data(
+ num_clients,
+ similarity=1.0,
+ seed=seed,
+ dataset_name=config.name,
+ )
+ elif partitioning == "iid_noniid":
+ similarity = 0.5
+ if "similarity" in config:
+ similarity = config.similarity
+ datasets, testset = partition_data(
+ num_clients,
+ similarity=similarity,
+ seed=seed,
+ dataset_name=config.name,
+ )
+
+ batch_size = -1
+ if "batch_size" in config:
+ batch_size = config.batch_size
+ elif "batch_size_ratio" in config:
+ batch_size_ratio = config.batch_size_ratio
+ else:
+ raise ValueError
+
+ # split each partition into train/val and create DataLoader
+ trainloaders = []
+ valloaders = []
+ for dataset in datasets:
+ len_val = int(len(dataset) / (1 / val_ratio)) if val_ratio > 0 else 0
+ lengths = [len(dataset) - len_val, len_val]
+ ds_train, ds_val = random_split(
+ dataset, lengths, torch.Generator().manual_seed(seed)
+ )
+ if batch_size == -1:
+ batch_size = int(len(ds_train) * batch_size_ratio)
+ trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True))
+ valloaders.append(DataLoader(ds_val, batch_size=batch_size))
+ return trainloaders, valloaders, DataLoader(testset, batch_size=len(testset))
diff --git a/baselines/niid_bench/niid_bench/dataset_preparation.py b/baselines/niid_bench/niid_bench/dataset_preparation.py
new file mode 100644
index 000000000000..3110191cbd6a
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/dataset_preparation.py
@@ -0,0 +1,315 @@
+"""Download data and partition data with different partitioning strategies."""
+
+from typing import List, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from torch.autograd import Variable
+from torch.utils.data import ConcatDataset, Dataset, Subset
+from torchvision.datasets import CIFAR10, MNIST, FashionMNIST
+
+
+def _download_data(dataset_name="emnist") -> Tuple[Dataset, Dataset]:
+ """Download the requested dataset. Currently supports cifar10, mnist, and fmnist.
+
+ Returns
+ -------
+ Tuple[Dataset, Dataset]
+ The training dataset, the test dataset.
+ """
+ trainset, testset = None, None
+ if dataset_name == "cifar10":
+ transform_train = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Lambda(
+ lambda x: F.pad(
+ Variable(x.unsqueeze(0), requires_grad=False),
+ (4, 4, 4, 4),
+ mode="reflect",
+ ).data.squeeze()
+ ),
+ transforms.ToPILImage(),
+ transforms.RandomCrop(32),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ ]
+ )
+ transform_test = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ ]
+ )
+ trainset = CIFAR10(
+ root="data",
+ train=True,
+ download=True,
+ transform=transform_train,
+ )
+ testset = CIFAR10(
+ root="data",
+ train=False,
+ download=True,
+ transform=transform_test,
+ )
+ elif dataset_name == "mnist":
+ transform_train = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ ]
+ )
+ transform_test = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ ]
+ )
+ trainset = MNIST(
+ root="data",
+ train=True,
+ download=True,
+ transform=transform_train,
+ )
+ testset = MNIST(
+ root="data",
+ train=False,
+ download=True,
+ transform=transform_test,
+ )
+ elif dataset_name == "fmnist":
+ transform_train = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ ]
+ )
+ transform_test = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ ]
+ )
+ trainset = FashionMNIST(
+ root="data",
+ train=True,
+ download=True,
+ transform=transform_train,
+ )
+ testset = FashionMNIST(
+ root="data",
+ train=False,
+ download=True,
+ transform=transform_test,
+ )
+ else:
+ raise NotImplementedError
+
+ return trainset, testset
+
+
+# pylint: disable=too-many-locals
+def partition_data(
+ num_clients, similarity=1.0, seed=42, dataset_name="cifar10"
+) -> Tuple[List[Dataset], Dataset]:
+ """Partition the dataset into subsets for each client.
+
+ Parameters
+ ----------
+ num_clients : int
+ The number of clients that hold a part of the data
+ similarity: float
+ Parameter to sample similar data
+ seed : int, optional
+ Used to set a fix seed to replicate experiments, by default 42
+
+ Returns
+ -------
+ Tuple[List[Subset], Dataset]
+ The list of datasets for each client, the test dataset.
+ """
+ trainset, testset = _download_data(dataset_name)
+ trainsets_per_client = []
+ # for s% similarity sample iid data per client
+ s_fraction = int(similarity * len(trainset))
+ prng = np.random.default_rng(seed)
+ idxs = prng.choice(len(trainset), s_fraction, replace=False)
+ iid_trainset = Subset(trainset, idxs)
+ rem_trainset = Subset(trainset, np.setdiff1d(np.arange(len(trainset)), idxs))
+
+ # sample iid data per client from iid_trainset
+ all_ids = np.arange(len(iid_trainset))
+ splits = np.array_split(all_ids, num_clients)
+ for i in range(num_clients):
+ c_ids = splits[i]
+ d_ids = iid_trainset.indices[c_ids]
+ trainsets_per_client.append(Subset(iid_trainset.dataset, d_ids))
+
+ if similarity == 1.0:
+ return trainsets_per_client, testset
+
+ tmp_t = rem_trainset.dataset.targets
+ if isinstance(tmp_t, list):
+ tmp_t = np.array(tmp_t)
+ if isinstance(tmp_t, torch.Tensor):
+ tmp_t = tmp_t.numpy()
+ targets = tmp_t[rem_trainset.indices]
+ num_remaining_classes = len(set(targets))
+ remaining_classes = list(set(targets))
+ client_classes: List[List] = [[] for _ in range(num_clients)]
+ times = [0 for _ in range(num_remaining_classes)]
+
+ for i in range(num_clients):
+ client_classes[i] = [remaining_classes[i % num_remaining_classes]]
+ times[i % num_remaining_classes] += 1
+ j = 1
+ while j < 2:
+ index = prng.choice(num_remaining_classes)
+ class_t = remaining_classes[index]
+ if class_t not in client_classes[i]:
+ client_classes[i].append(class_t)
+ times[index] += 1
+ j += 1
+
+ rem_trainsets_per_client: List[List] = [[] for _ in range(num_clients)]
+
+ for i in range(num_remaining_classes):
+ class_t = remaining_classes[i]
+ idx_k = np.where(targets == i)[0]
+ prng.shuffle(idx_k)
+ idx_k_split = np.array_split(idx_k, times[i])
+ ids = 0
+ for j in range(num_clients):
+ if class_t in client_classes[j]:
+ act_idx = rem_trainset.indices[idx_k_split[ids]]
+ rem_trainsets_per_client[j].append(
+ Subset(rem_trainset.dataset, act_idx)
+ )
+ ids += 1
+
+ for i in range(num_clients):
+ trainsets_per_client[i] = ConcatDataset(
+ [trainsets_per_client[i]] + rem_trainsets_per_client[i]
+ )
+
+ return trainsets_per_client, testset
+
+
+def partition_data_dirichlet(
+ num_clients, alpha, seed=42, dataset_name="cifar10"
+) -> Tuple[List[Dataset], Dataset]:
+ """Partition according to the Dirichlet distribution.
+
+ Parameters
+ ----------
+ num_clients : int
+ The number of clients that hold a part of the data
+ alpha: float
+ Parameter of the Dirichlet distribution
+ seed : int, optional
+ Used to set a fix seed to replicate experiments, by default 42
+ dataset_name : str
+ Name of the dataset to be used
+
+ Returns
+ -------
+ Tuple[List[Subset], Dataset]
+ The list of datasets for each client, the test dataset.
+ """
+ trainset, testset = _download_data(dataset_name)
+ min_required_samples_per_client = 10
+ min_samples = 0
+ prng = np.random.default_rng(seed)
+
+ # get the targets
+ tmp_t = trainset.targets
+ if isinstance(tmp_t, list):
+ tmp_t = np.array(tmp_t)
+ if isinstance(tmp_t, torch.Tensor):
+ tmp_t = tmp_t.numpy()
+ num_classes = len(set(tmp_t))
+ total_samples = len(tmp_t)
+ while min_samples < min_required_samples_per_client:
+ idx_clients: List[List] = [[] for _ in range(num_clients)]
+ for k in range(num_classes):
+ idx_k = np.where(tmp_t == k)[0]
+ prng.shuffle(idx_k)
+ proportions = prng.dirichlet(np.repeat(alpha, num_clients))
+ proportions = np.array(
+ [
+ p * (len(idx_j) < total_samples / num_clients)
+ for p, idx_j in zip(proportions, idx_clients)
+ ]
+ )
+ proportions = proportions / proportions.sum()
+ proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
+ idx_k_split = np.split(idx_k, proportions)
+ idx_clients = [
+ idx_j + idx.tolist() for idx_j, idx in zip(idx_clients, idx_k_split)
+ ]
+ min_samples = min([len(idx_j) for idx_j in idx_clients])
+
+ trainsets_per_client = [Subset(trainset, idxs) for idxs in idx_clients]
+ return trainsets_per_client, testset
+
+
+def partition_data_label_quantity(
+ num_clients, labels_per_client, seed=42, dataset_name="cifar10"
+) -> Tuple[List[Dataset], Dataset]:
+ """Partition the data according to the number of labels per client.
+
+ Logic from https://github.com/Xtra-Computing/NIID-Bench/.
+
+ Parameters
+ ----------
+ num_clients : int
+ The number of clients that hold a part of the data
+ num_labels_per_client: int
+ Number of labels per client
+ seed : int, optional
+ Used to set a fix seed to replicate experiments, by default 42
+ dataset_name : str
+ Name of the dataset to be used
+
+ Returns
+ -------
+ Tuple[List[Subset], Dataset]
+ The list of datasets for each client, the test dataset.
+ """
+ trainset, testset = _download_data(dataset_name)
+ prng = np.random.default_rng(seed)
+
+ targets = trainset.targets
+ if isinstance(targets, list):
+ targets = np.array(targets)
+ if isinstance(targets, torch.Tensor):
+ targets = targets.numpy()
+ num_classes = len(set(targets))
+ times = [0 for _ in range(num_classes)]
+ contains = []
+
+ for i in range(num_clients):
+ current = [i % num_classes]
+ times[i % num_classes] += 1
+ j = 1
+ while j < labels_per_client:
+ index = prng.choice(num_classes, 1)[0]
+ if index not in current:
+ current.append(index)
+ times[index] += 1
+ j += 1
+ contains.append(current)
+ idx_clients: List[List] = [[] for _ in range(num_clients)]
+ for i in range(num_classes):
+ idx_k = np.where(targets == i)[0]
+ prng.shuffle(idx_k)
+ idx_k_split = np.array_split(idx_k, times[i])
+ ids = 0
+ for j in range(num_clients):
+ if i in contains[j]:
+ idx_clients[j] += idx_k_split[ids].tolist()
+ ids += 1
+ trainsets_per_client = [Subset(trainset, idxs) for idxs in idx_clients]
+ return trainsets_per_client, testset
+
+
+if __name__ == "__main__":
+ partition_data(100, 0.1)
diff --git a/baselines/niid_bench/niid_bench/main.py b/baselines/niid_bench/niid_bench/main.py
new file mode 100644
index 000000000000..c845925c66ed
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/main.py
@@ -0,0 +1,110 @@
+"""Create and connect the building blocks for your experiments; start the simulation.
+
+It includes processioning the dataset, instantiate strategy, specify how the global
+model is going to be evaluated, etc. At the end, this script saves the results.
+"""
+import os
+import pickle
+
+import flwr as fl
+import hydra
+from flwr.server.client_manager import SimpleClientManager
+from flwr.server.server import Server
+from hydra.core.hydra_config import HydraConfig
+from hydra.utils import call, instantiate
+from omegaconf import DictConfig, OmegaConf
+
+from niid_bench.dataset import load_datasets
+from niid_bench.server_fednova import FedNovaServer
+from niid_bench.server_scaffold import ScaffoldServer, gen_evaluate_fn
+from niid_bench.strategy import FedNovaStrategy, ScaffoldStrategy
+
+
+@hydra.main(config_path="conf", config_name="fedavg_base", version_base=None)
+def main(cfg: DictConfig) -> None:
+ """Run the baseline.
+
+ Parameters
+ ----------
+ cfg : DictConfig
+ An omegaconf object that stores the hydra config.
+ """
+ # 1. Print parsed config
+ if "mnist" in cfg.dataset_name:
+ cfg.model.input_dim = 256
+ # pylint: disable=protected-access
+ cfg.model._target_ = "niid_bench.models.CNNMnist"
+ print(OmegaConf.to_yaml(cfg))
+
+ # 2. Prepare your dataset
+ trainloaders, valloaders, testloader = load_datasets(
+ config=cfg.dataset,
+ num_clients=cfg.num_clients,
+ val_ratio=cfg.dataset.val_split,
+ )
+
+ # 3. Define your clients
+ client_fn = None
+ # pylint: disable=protected-access
+ if cfg.client_fn._target_ == "niid_bench.client_scaffold.gen_client_fn":
+ save_path = HydraConfig.get().runtime.output_dir
+ client_cv_dir = os.path.join(save_path, "client_cvs")
+ print("Local cvs for scaffold clients are saved to: ", client_cv_dir)
+ client_fn = call(
+ cfg.client_fn,
+ trainloaders,
+ valloaders,
+ model=cfg.model,
+ client_cv_dir=client_cv_dir,
+ )
+ else:
+ client_fn = call(
+ cfg.client_fn,
+ trainloaders,
+ valloaders,
+ model=cfg.model,
+ )
+
+ device = cfg.server_device
+ evaluate_fn = gen_evaluate_fn(testloader, device=device, model=cfg.model)
+
+ # 4. Define your strategy
+ strategy = instantiate(
+ cfg.strategy,
+ evaluate_fn=evaluate_fn,
+ )
+
+ # 5. Define your server
+ server = Server(strategy=strategy, client_manager=SimpleClientManager())
+ if isinstance(strategy, FedNovaStrategy):
+ server = FedNovaServer(strategy=strategy, client_manager=SimpleClientManager())
+ elif isinstance(strategy, ScaffoldStrategy):
+ server = ScaffoldServer(
+ strategy=strategy, model=cfg.model, client_manager=SimpleClientManager()
+ )
+
+ # 6. Start Simulation
+ history = fl.simulation.start_simulation(
+ server=server,
+ client_fn=client_fn,
+ num_clients=cfg.num_clients,
+ config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
+ client_resources={
+ "num_cpus": cfg.client_resources.num_cpus,
+ "num_gpus": cfg.client_resources.num_gpus,
+ },
+ strategy=strategy,
+ )
+
+ print(history)
+
+ save_path = HydraConfig.get().runtime.output_dir
+ print(save_path)
+
+ # 7. Save your results
+ with open(os.path.join(save_path, "history.pkl"), "wb") as f_ptr:
+ pickle.dump(history, f_ptr)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/baselines/niid_bench/niid_bench/models.py b/baselines/niid_bench/niid_bench/models.py
new file mode 100644
index 000000000000..dc2ee6633600
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/models.py
@@ -0,0 +1,414 @@
+"""Implement the neural network models and training functions."""
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+from torch.optim import SGD, Optimizer
+from torch.utils.data import DataLoader
+
+
+class CNN(nn.Module):
+ """Implement a CNN model for CIFAR-10.
+
+ Parameters
+ ----------
+ input_dim : int
+ The input dimension for classifier.
+ hidden_dims : List[int]
+ The hidden dimensions for classifier.
+ num_classes : int
+ The number of classes in the dataset.
+ """
+
+ def __init__(self, input_dim, hidden_dims, num_classes):
+ super().__init__()
+ self.conv1 = nn.Conv2d(3, 6, 5)
+ self.pool = nn.MaxPool2d(2, 2)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+
+ self.fc1 = nn.Linear(input_dim, hidden_dims[0])
+ self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
+ self.fc3 = nn.Linear(hidden_dims[1], num_classes)
+
+ def forward(self, x):
+ """Implement forward pass."""
+ x = self.pool(F.relu(self.conv1(x)))
+ x = self.pool(F.relu(self.conv2(x)))
+ x = x.view(-1, 16 * 5 * 5)
+
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+
+class CNNMnist(nn.Module):
+ """Implement a CNN model for MNIST and Fashion-MNIST.
+
+ Parameters
+ ----------
+ input_dim : int
+ The input dimension for classifier.
+ hidden_dims : List[int]
+ The hidden dimensions for classifier.
+ num_classes : int
+ The number of classes in the dataset.
+ """
+
+ def __init__(self, input_dim, hidden_dims, num_classes) -> None:
+ super().__init__()
+ self.conv1 = nn.Conv2d(1, 6, 5)
+ self.pool = nn.MaxPool2d(2, 2)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+
+ self.fc1 = nn.Linear(input_dim, hidden_dims[0])
+ self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
+ self.fc3 = nn.Linear(hidden_dims[1], num_classes)
+
+ def forward(self, x):
+ """Implement forward pass."""
+ x = self.pool(F.relu(self.conv1(x)))
+ x = self.pool(F.relu(self.conv2(x)))
+
+ x = x.view(-1, 16 * 4 * 4)
+
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+
+class ScaffoldOptimizer(SGD):
+ """Implements SGD optimizer step function as defined in the SCAFFOLD paper."""
+
+ def __init__(self, grads, step_size, momentum, weight_decay):
+ super().__init__(
+ grads, lr=step_size, momentum=momentum, weight_decay=weight_decay
+ )
+
+ def step_custom(self, server_cv, client_cv):
+ """Implement the custom step function fo SCAFFOLD."""
+ # y_i = y_i - \eta * (g_i + c - c_i) -->
+ # y_i = y_i - \eta*(g_i + \mu*b_{t}) - \eta*(c - c_i)
+ self.step()
+ for group in self.param_groups:
+ for par, s_cv, c_cv in zip(group["params"], server_cv, client_cv):
+ par.data.add_(s_cv - c_cv, alpha=-group["lr"])
+
+
+def train_scaffold(
+ net: nn.Module,
+ trainloader: DataLoader,
+ device: torch.device,
+ epochs: int,
+ learning_rate: float,
+ momentum: float,
+ weight_decay: float,
+ server_cv: torch.Tensor,
+ client_cv: torch.Tensor,
+) -> None:
+ # pylint: disable=too-many-arguments
+ """Train the network on the training set using SCAFFOLD.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to train.
+ trainloader : DataLoader
+ The training set dataloader object.
+ device : torch.device
+ The device on which to train the network.
+ epochs : int
+ The number of epochs to train the network.
+ learning_rate : float
+ The learning rate.
+ momentum : float
+ The momentum for SGD optimizer.
+ weight_decay : float
+ The weight decay for SGD optimizer.
+ server_cv : torch.Tensor
+ The server's control variate.
+ client_cv : torch.Tensor
+ The client's control variate.
+ """
+ criterion = nn.CrossEntropyLoss()
+ optimizer = ScaffoldOptimizer(
+ net.parameters(), learning_rate, momentum, weight_decay
+ )
+ net.train()
+ for _ in range(epochs):
+ net = _train_one_epoch_scaffold(
+ net, trainloader, device, criterion, optimizer, server_cv, client_cv
+ )
+
+
+def _train_one_epoch_scaffold(
+ net: nn.Module,
+ trainloader: DataLoader,
+ device: torch.device,
+ criterion: nn.Module,
+ optimizer: ScaffoldOptimizer,
+ server_cv: torch.Tensor,
+ client_cv: torch.Tensor,
+) -> nn.Module:
+ # pylint: disable=too-many-arguments
+ """Train the network on the training set for one epoch."""
+ for data, target in trainloader:
+ data, target = data.to(device), target.to(device)
+ optimizer.zero_grad()
+ output = net(data)
+ loss = criterion(output, target)
+ loss.backward()
+ optimizer.step_custom(server_cv, client_cv)
+ return net
+
+
+def train_fedavg(
+ net: nn.Module,
+ trainloader: DataLoader,
+ device: torch.device,
+ epochs: int,
+ learning_rate: float,
+ momentum: float,
+ weight_decay: float,
+) -> None:
+ # pylint: disable=too-many-arguments
+ """Train the network on the training set using FedAvg.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to train.
+ trainloader : DataLoader
+ The training set dataloader object.
+ device : torch.device
+ The device on which to train the network.
+ epochs : int
+ The number of epochs to train the network.
+ learning_rate : float
+ The learning rate.
+ momentum : float
+ The momentum for SGD optimizer.
+ weight_decay : float
+ The weight decay for SGD optimizer.
+
+ Returns
+ -------
+ None
+ """
+ criterion = nn.CrossEntropyLoss()
+ optimizer = SGD(
+ net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay
+ )
+ net.train()
+ for _ in range(epochs):
+ net = _train_one_epoch(net, trainloader, device, criterion, optimizer)
+
+
+def _train_one_epoch(
+ net: nn.Module,
+ trainloader: DataLoader,
+ device: torch.device,
+ criterion: nn.Module,
+ optimizer: Optimizer,
+) -> nn.Module:
+ """Train the network on the training set for one epoch."""
+ for data, target in trainloader:
+ data, target = data.to(device), target.to(device)
+ optimizer.zero_grad()
+ output = net(data)
+ loss = criterion(output, target)
+ loss.backward()
+ optimizer.step()
+ return net
+
+
+def train_fedprox(
+ net: nn.Module,
+ trainloader: DataLoader,
+ device: torch.device,
+ epochs: int,
+ proximal_mu: float,
+ learning_rate: float,
+ momentum: float,
+ weight_decay: float,
+) -> None:
+ # pylint: disable=too-many-arguments
+ """Train the network on the training set using FedAvg.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to train.
+ trainloader : DataLoader
+ The training set dataloader object.
+ device : torch.device
+ The device on which to train the network.
+ epochs : int
+ The number of epochs to train the network.
+ proximal_mu : float
+ The proximal mu parameter.
+ learning_rate : float
+ The learning rate.
+ momentum : float
+ The momentum for SGD optimizer.
+ weight_decay : float
+ The weight decay for SGD optimizer.
+
+ Returns
+ -------
+ None
+ """
+ criterion = nn.CrossEntropyLoss()
+ optimizer = SGD(
+ net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay
+ )
+ global_params = [param.detach().clone() for param in net.parameters()]
+ net.train()
+ for _ in range(epochs):
+ net = _train_one_epoch_fedprox(
+ net, global_params, trainloader, device, criterion, optimizer, proximal_mu
+ )
+
+
+def _train_one_epoch_fedprox(
+ net: nn.Module,
+ global_params: List[Parameter],
+ trainloader: DataLoader,
+ device: torch.device,
+ criterion: nn.Module,
+ optimizer: Optimizer,
+ proximal_mu: float,
+) -> nn.Module:
+ # pylint: disable=too-many-arguments
+ """Train the network on the training set for one epoch."""
+ for data, target in trainloader:
+ data, target = data.to(device), target.to(device)
+ optimizer.zero_grad()
+ output = net(data)
+ loss = criterion(output, target)
+ proximal_term = 0.0
+ for param, global_param in zip(net.parameters(), global_params):
+ proximal_term += torch.norm(param - global_param) ** 2
+ loss += (proximal_mu / 2) * proximal_term
+ loss.backward()
+ optimizer.step()
+ return net
+
+
+def train_fednova(
+ net: nn.Module,
+ trainloader: DataLoader,
+ device: torch.device,
+ epochs: int,
+ learning_rate: float,
+ momentum: float,
+ weight_decay: float,
+) -> Tuple[float, List[torch.Tensor]]:
+ # pylint: disable=too-many-arguments
+ """Train the network on the training set using FedNova.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to train.
+ trainloader : DataLoader
+ The training set dataloader object.
+ device : torch.device
+ The device on which to train the network.
+ epochs : int
+ The number of epochs to train the network.
+ learning_rate : float
+ The learning rate.
+ momentum : float
+ The momentum for SGD optimizer.
+ weight_decay : float
+ The weight decay for SGD optimizer.
+
+ Returns
+ -------
+ tuple[float, List[torch.Tensor]]
+ The a_i and g_i values.
+ """
+ criterion = nn.CrossEntropyLoss()
+ optimizer = SGD(
+ net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay
+ )
+ net.train()
+ local_steps = 0
+ # clone all the parameters
+ prev_net = [param.detach().clone() for param in net.parameters()]
+ for _ in range(epochs):
+ net, local_steps = _train_one_epoch_fednova(
+ net, trainloader, device, criterion, optimizer, local_steps
+ )
+ # compute ||a_i||_1
+ a_i = (
+ local_steps - (momentum * (1 - momentum**local_steps) / (1 - momentum))
+ ) / (1 - momentum)
+ # compute g_i
+ g_i = [
+ torch.div(prev_param - param.detach(), a_i)
+ for prev_param, param in zip(prev_net, net.parameters())
+ ]
+
+ return a_i, g_i
+
+
+def _train_one_epoch_fednova(
+ net: nn.Module,
+ trainloader: DataLoader,
+ device: torch.device,
+ criterion: nn.Module,
+ optimizer: Optimizer,
+ local_steps: int,
+) -> Tuple[nn.Module, int]:
+ # pylint: disable=too-many-arguments
+ """Train the network on the training set for one epoch."""
+ for data, target in trainloader:
+ data, target = data.to(device), target.to(device)
+ optimizer.zero_grad()
+ output = net(data)
+ loss = criterion(output, target)
+ loss.backward()
+ optimizer.step()
+ local_steps += 1
+ return net, local_steps
+
+
+def test(
+ net: nn.Module, testloader: DataLoader, device: torch.device
+) -> Tuple[float, float]:
+ """Evaluate the network on the test set.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to evaluate.
+ testloader : DataLoader
+ The test set dataloader object.
+ device : torch.device
+ The device on which to evaluate the network.
+
+ Returns
+ -------
+ Tuple[float, float]
+ The loss and accuracy of the network on the test set.
+ """
+ criterion = nn.CrossEntropyLoss(reduction="sum")
+ net.eval()
+ correct, total, loss = 0, 0, 0.0
+ with torch.no_grad():
+ for data, target in testloader:
+ data, target = data.to(device), target.to(device)
+ output = net(data)
+ loss += criterion(output, target).item()
+ _, predicted = torch.max(output.data, 1)
+ total += target.size(0)
+ correct += (predicted == target).sum().item()
+ loss = loss / total
+ acc = correct / total
+ return loss, acc
diff --git a/baselines/niid_bench/niid_bench/server.py b/baselines/niid_bench/niid_bench/server.py
new file mode 100644
index 000000000000..83ed6bb122da
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/server.py
@@ -0,0 +1,4 @@
+"""Server script.
+
+This is not used in this baseline. Please refer to the strategy-specific server files.
+"""
diff --git a/baselines/niid_bench/niid_bench/server_fednova.py b/baselines/niid_bench/niid_bench/server_fednova.py
new file mode 100644
index 000000000000..2fdea9a76c41
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/server_fednova.py
@@ -0,0 +1,78 @@
+"""Server class for FedNova."""
+
+from logging import DEBUG, INFO
+
+from flwr.common import parameters_to_ndarrays
+from flwr.common.logger import log
+from flwr.common.typing import Dict, Optional, Parameters, Scalar, Tuple
+from flwr.server.client_manager import ClientManager
+from flwr.server.server import FitResultsAndFailures, Server, fit_clients
+
+from niid_bench.strategy import FedNovaStrategy
+
+
+class FedNovaServer(Server):
+ """Implement server for FedNova."""
+
+ def __init__(
+ self,
+ *,
+ client_manager: ClientManager,
+ strategy: Optional[FedNovaStrategy] = None,
+ ) -> None:
+ super().__init__(client_manager=client_manager, strategy=strategy)
+ self.strategy: FedNovaStrategy = (
+ strategy if strategy is not None else FedNovaStrategy()
+ )
+
+ def fit_round(
+ self,
+ server_round: int,
+ timeout: Optional[float],
+ ) -> Optional[
+ Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]
+ ]:
+ """Perform a single round of federated averaging."""
+ # Get clients and their respective instructions from strategy
+ client_instructions = self.strategy.configure_fit(
+ server_round=server_round,
+ parameters=self.parameters,
+ client_manager=self._client_manager,
+ )
+
+ if not client_instructions:
+ log(INFO, "fit_round %s: no clients selected, cancel", server_round)
+ return None
+ log(
+ DEBUG,
+ "fit_round %s: strategy sampled %s clients (out of %s)",
+ server_round,
+ len(client_instructions),
+ self._client_manager.num_available(),
+ )
+
+ # Collect `fit` results from all clients participating in this round
+ results, failures = fit_clients(
+ client_instructions=client_instructions,
+ max_workers=self.max_workers,
+ timeout=timeout,
+ )
+ log(
+ DEBUG,
+ "fit_round %s received %s results and %s failures",
+ server_round,
+ len(results),
+ len(failures),
+ )
+
+ params_np = parameters_to_ndarrays(self.parameters)
+ # Aggregate training results
+ aggregated_result: Tuple[
+ Optional[Parameters],
+ Dict[str, Scalar],
+ ] = self.strategy.aggregate_fit_custom(
+ server_round, params_np, results, failures
+ )
+
+ parameters_aggregated, metrics_aggregated = aggregated_result
+ return parameters_aggregated, metrics_aggregated, (results, failures)
diff --git a/baselines/niid_bench/niid_bench/server_scaffold.py b/baselines/niid_bench/niid_bench/server_scaffold.py
new file mode 100644
index 000000000000..b55bc538d081
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/server_scaffold.py
@@ -0,0 +1,268 @@
+"""Server class for SCAFFOLD."""
+
+import concurrent.futures
+from logging import DEBUG, INFO
+from typing import OrderedDict
+
+import torch
+from flwr.common import (
+ Code,
+ FitIns,
+ FitRes,
+ Parameters,
+ Scalar,
+ ndarrays_to_parameters,
+ parameters_to_ndarrays,
+)
+from flwr.common.logger import log
+from flwr.common.typing import (
+ Callable,
+ Dict,
+ GetParametersIns,
+ List,
+ NDArrays,
+ Optional,
+ Tuple,
+ Union,
+)
+from flwr.server import Server
+from flwr.server.client_manager import ClientManager, SimpleClientManager
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.strategy import Strategy
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+
+from niid_bench.models import test
+
+FitResultsAndFailures = Tuple[
+ List[Tuple[ClientProxy, FitRes]],
+ List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+]
+
+
+class ScaffoldServer(Server):
+ """Implement server for SCAFFOLD."""
+
+ def __init__(
+ self,
+ strategy: Strategy,
+ model: DictConfig,
+ client_manager: Optional[ClientManager] = None,
+ ):
+ if client_manager is None:
+ client_manager = SimpleClientManager()
+ super().__init__(client_manager=client_manager, strategy=strategy)
+ self.model_params = instantiate(model)
+ self.server_cv: List[torch.Tensor] = []
+
+ def _get_initial_parameters(self, timeout: Optional[float]) -> Parameters:
+ """Get initial parameters from one of the available clients."""
+ # Server-side parameter initialization
+ parameters: Optional[Parameters] = self.strategy.initialize_parameters(
+ client_manager=self._client_manager
+ )
+ if parameters is not None:
+ log(INFO, "Using initial parameters provided by strategy")
+ return parameters
+
+ # Get initial parameters from one of the clients
+ log(INFO, "Requesting initial parameters from one random client")
+ random_client = self._client_manager.sample(1)[0]
+ ins = GetParametersIns(config={})
+ get_parameters_res = random_client.get_parameters(ins=ins, timeout=timeout)
+ log(INFO, "Received initial parameters from one random client")
+ self.server_cv = [
+ torch.from_numpy(t)
+ for t in parameters_to_ndarrays(get_parameters_res.parameters)
+ ]
+ return get_parameters_res.parameters
+
+ # pylint: disable=too-many-locals
+ def fit_round(
+ self,
+ server_round: int,
+ timeout: Optional[float],
+ ) -> Optional[
+ Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]
+ ]:
+ """Perform a single round of federated averaging."""
+ # Get clients and their respective instructions from strateg
+ client_instructions = self.strategy.configure_fit(
+ server_round=server_round,
+ parameters=update_parameters_with_cv(self.parameters, self.server_cv),
+ client_manager=self._client_manager,
+ )
+
+ if not client_instructions:
+ log(INFO, "fit_round %s: no clients selected, cancel", server_round)
+ return None
+ log(
+ DEBUG,
+ "fit_round %s: strategy sampled %s clients (out of %s)",
+ server_round,
+ len(client_instructions),
+ self._client_manager.num_available(),
+ )
+
+ # Collect `fit` results from all clients participating in this round
+ results, failures = fit_clients(
+ client_instructions=client_instructions,
+ max_workers=self.max_workers,
+ timeout=timeout,
+ )
+ log(
+ DEBUG,
+ "fit_round %s received %s results and %s failures",
+ server_round,
+ len(results),
+ len(failures),
+ )
+
+ # Aggregate training results
+ aggregated_result: Tuple[
+ Optional[Parameters], Dict[str, Scalar]
+ ] = self.strategy.aggregate_fit(server_round, results, failures)
+
+ aggregated_result_arrays_combined = []
+ if aggregated_result[0] is not None:
+ aggregated_result_arrays_combined = parameters_to_ndarrays(
+ aggregated_result[0]
+ )
+ aggregated_parameters = aggregated_result_arrays_combined[
+ : len(aggregated_result_arrays_combined) // 2
+ ]
+ aggregated_cv_update = aggregated_result_arrays_combined[
+ len(aggregated_result_arrays_combined) // 2 :
+ ]
+
+ # convert server cv into ndarrays
+ server_cv_np = [cv.numpy() for cv in self.server_cv]
+ # update server cv
+ total_clients = len(self._client_manager.all())
+ cv_multiplier = len(results) / total_clients
+ self.server_cv = [
+ torch.from_numpy(cv + cv_multiplier * aggregated_cv_update[i])
+ for i, cv in enumerate(server_cv_np)
+ ]
+
+ # update parameters x = x + 1* aggregated_update
+ curr_params = parameters_to_ndarrays(self.parameters)
+ updated_params = [
+ x + aggregated_parameters[i] for i, x in enumerate(curr_params)
+ ]
+ parameters_updated = ndarrays_to_parameters(updated_params)
+
+ # metrics
+ metrics_aggregated = aggregated_result[1]
+ return parameters_updated, metrics_aggregated, (results, failures)
+
+
+def update_parameters_with_cv(
+ parameters: Parameters, s_cv: List[torch.Tensor]
+) -> Parameters:
+ """Extend the list of parameters with the server control variate."""
+ # extend the list of parameters arrays with the cv arrays
+ cv_np = [cv.numpy() for cv in s_cv]
+ parameters_np = parameters_to_ndarrays(parameters)
+ parameters_np.extend(cv_np)
+ return ndarrays_to_parameters(parameters_np)
+
+
+def fit_clients(
+ client_instructions: List[Tuple[ClientProxy, FitIns]],
+ max_workers: Optional[int],
+ timeout: Optional[float],
+) -> FitResultsAndFailures:
+ """Refine parameters concurrently on all selected clients."""
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+ submitted_fs = {
+ executor.submit(fit_client, client_proxy, ins, timeout)
+ for client_proxy, ins in client_instructions
+ }
+ finished_fs, _ = concurrent.futures.wait(
+ fs=submitted_fs,
+ timeout=None, # Handled in the respective communication stack
+ )
+
+ # Gather results
+ results: List[Tuple[ClientProxy, FitRes]] = []
+ failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []
+ for future in finished_fs:
+ _handle_finished_future_after_fit(
+ future=future, results=results, failures=failures
+ )
+ return results, failures
+
+
+def fit_client(
+ client: ClientProxy, ins: FitIns, timeout: Optional[float]
+) -> Tuple[ClientProxy, FitRes]:
+ """Refine parameters on a single client."""
+ fit_res = client.fit(ins, timeout=timeout)
+ return client, fit_res
+
+
+def _handle_finished_future_after_fit(
+ future: concurrent.futures.Future, # type: ignore
+ results: List[Tuple[ClientProxy, FitRes]],
+ failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+) -> None:
+ """Convert finished future into either a result or a failure."""
+ # Check if there was an exception
+ failure = future.exception()
+ if failure is not None:
+ failures.append(failure)
+ return
+
+ # Successfully received a result from a client
+ result: Tuple[ClientProxy, FitRes] = future.result()
+ _, res = result
+
+ # Check result status code
+ if res.status.code == Code.OK:
+ results.append(result)
+ return
+
+ # Not successful, client returned a result where the status code is not OK
+ failures.append(result)
+
+
+def gen_evaluate_fn(
+ testloader: DataLoader,
+ device: torch.device,
+ model: DictConfig,
+) -> Callable[
+ [int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]
+]:
+ """Generate the function for centralized evaluation.
+
+ Parameters
+ ----------
+ testloader : DataLoader
+ The dataloader to test the model with.
+ device : torch.device
+ The device to test the model on.
+
+ Returns
+ -------
+ Callable[ [int, NDArrays, Dict[str, Scalar]],
+ Optional[Tuple[float, Dict[str, Scalar]]] ]
+ The centralized evaluation function.
+ """
+
+ def evaluate(
+ server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar]
+ ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
+ # pylint: disable=unused-argument
+ """Use the entire Emnist test set for evaluation."""
+ net = instantiate(model)
+ params_dict = zip(net.state_dict().keys(), parameters_ndarrays)
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+ net.load_state_dict(state_dict, strict=True)
+ net.to(device)
+
+ loss, accuracy = test(net, testloader, device=device)
+ return loss, {"accuracy": accuracy}
+
+ return evaluate
diff --git a/baselines/niid_bench/niid_bench/strategy.py b/baselines/niid_bench/niid_bench/strategy.py
new file mode 100644
index 000000000000..3867772cd5e6
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/strategy.py
@@ -0,0 +1,141 @@
+"""FedNova and SCAFFOLD strategies."""
+
+from functools import reduce
+from logging import WARNING
+
+import numpy as np
+from flwr.common import (
+ FitRes,
+ NDArrays,
+ Parameters,
+ Scalar,
+ ndarrays_to_parameters,
+ parameters_to_ndarrays,
+)
+from flwr.common.logger import log
+from flwr.common.typing import Dict, List, Optional, Tuple, Union
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.strategy import FedAvg
+from flwr.server.strategy.aggregate import aggregate
+
+
+class FedNovaStrategy(FedAvg):
+ """Custom FedAvg strategy with fednova based configuration and aggregation."""
+
+ def aggregate_fit_custom(
+ self,
+ server_round: int,
+ server_params: NDArrays,
+ results: List[Tuple[ClientProxy, FitRes]],
+ failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+ ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+ """Aggregate fit results using weighted average."""
+ if not results:
+ return None, {}
+ # Do not aggregate if there are failures and failures are not accepted
+ if not self.accept_failures and failures:
+ return None, {}
+
+ # Convert results
+ weights_results = [
+ (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
+ for _, fit_res in results
+ ]
+ total_samples = sum([fit_res.num_examples for _, fit_res in results])
+ c_fact = sum(
+ [
+ float(fit_res.metrics["a_i"]) * fit_res.num_examples / total_samples
+ for _, fit_res in results
+ ]
+ )
+ new_weights_results = [
+ (result[0], c_fact * (fit_res.num_examples / total_samples))
+ for result, (_, fit_res) in zip(weights_results, results)
+ ]
+
+ # Aggregate grad updates, t_eff*(sum_i(p_i*\eta*d_i))
+ grad_updates_aggregated = aggregate_fednova(new_weights_results)
+ # Final parameters = server_params - grad_updates_aggregated
+ aggregated = [
+ server_param - grad_update
+ for server_param, grad_update in zip(server_params, grad_updates_aggregated)
+ ]
+
+ parameters_aggregated = ndarrays_to_parameters(aggregated)
+ # Aggregate custom metrics if aggregation fn was provided
+ metrics_aggregated = {}
+ if self.fit_metrics_aggregation_fn:
+ fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
+ metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
+ elif server_round == 1: # Only log this warning once
+ log(WARNING, "No fit_metrics_aggregation_fn provided")
+
+ return parameters_aggregated, metrics_aggregated
+
+
+def aggregate_fednova(results: List[Tuple[NDArrays, float]]) -> NDArrays:
+ """Implement custom aggregate function for FedNova."""
+ # Create a list of weights, each multiplied by the weight_factor
+ weighted_weights = [
+ [layer * factor for layer in weights] for weights, factor in results
+ ]
+
+ # Compute average weights of each layer
+ weights_prime: NDArrays = [
+ reduce(np.add, layer_updates) for layer_updates in zip(*weighted_weights)
+ ]
+ return weights_prime
+
+
+class ScaffoldStrategy(FedAvg):
+ """Implement custom strategy for SCAFFOLD based on FedAvg class."""
+
+ def aggregate_fit(
+ self,
+ server_round: int,
+ results: List[Tuple[ClientProxy, FitRes]],
+ failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+ ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+ """Aggregate fit results using weighted average."""
+ if not results:
+ return None, {}
+ # Do not aggregate if there are failures and failures are not accepted
+ if not self.accept_failures and failures:
+ return None, {}
+
+ combined_parameters_all_updates = [
+ parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results
+ ]
+ len_combined_parameter = len(combined_parameters_all_updates[0])
+ num_examples_all_updates = [fit_res.num_examples for _, fit_res in results]
+ # Zip parameters and num_examples
+ weights_results = [
+ (update[: len_combined_parameter // 2], num_examples)
+ for update, num_examples in zip(
+ combined_parameters_all_updates, num_examples_all_updates
+ )
+ ]
+ # Aggregate parameters
+ parameters_aggregated = aggregate(weights_results)
+
+ # Zip client_cv_updates and num_examples
+ client_cv_updates_and_num_examples = [
+ (update[len_combined_parameter // 2 :], num_examples)
+ for update, num_examples in zip(
+ combined_parameters_all_updates, num_examples_all_updates
+ )
+ ]
+ aggregated_cv_update = aggregate(client_cv_updates_and_num_examples)
+
+ # Aggregate custom metrics if aggregation fn was provided
+ metrics_aggregated = {}
+ if self.fit_metrics_aggregation_fn:
+ fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
+ metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
+ elif server_round == 1: # Only log this warning once
+ log(WARNING, "No fit_metrics_aggregation_fn provided")
+
+ return (
+ ndarrays_to_parameters(parameters_aggregated + aggregated_cv_update),
+ metrics_aggregated,
+ )
diff --git a/baselines/niid_bench/niid_bench/utils.py b/baselines/niid_bench/niid_bench/utils.py
new file mode 100644
index 000000000000..9a831719d623
--- /dev/null
+++ b/baselines/niid_bench/niid_bench/utils.py
@@ -0,0 +1,6 @@
+"""Define any utility function.
+
+They are not directly relevant to the other (more FL specific) python modules. For
+example, you may define here things like: loading a model from a checkpoint, saving
+results, plotting.
+"""
diff --git a/baselines/niid_bench/pyproject.toml b/baselines/niid_bench/pyproject.toml
new file mode 100644
index 000000000000..adb001e031ff
--- /dev/null
+++ b/baselines/niid_bench/pyproject.toml
@@ -0,0 +1,140 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.masonry.api"
+
+[tool.poetry]
+name = "niid_bench" # <----- Ensure it matches the name of your baseline directory containing all the source code
+version = "1.0.0"
+description = "Federated Learning on Non-IID Data Silos: An Experimental Study"
+license = "Apache-2.0"
+authors = ["Aashish Kolluri "]
+readme = "README.md"
+homepage = "https://flower.dev"
+repository = "https://github.com/adap/flower"
+documentation = "https://flower.dev"
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: MacOS :: MacOS X",
+ "Operating System :: POSIX :: Linux",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: Implementation :: CPython",
+ "Topic :: Scientific/Engineering",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Scientific/Engineering :: Mathematics",
+ "Topic :: Software Development",
+ "Topic :: Software Development :: Libraries",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+ "Typing :: Typed",
+]
+
+[tool.poetry.dependencies]
+python = ">=3.10.0, <3.11.0" # don't change this
+flwr = { extras = ["simulation"], version = "1.5.0" }
+hydra-core = "1.3.2" # don't change this
+torch = { url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp310-cp310-linux_x86_64.whl"}
+torchvision = { url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp310-cp310-linux_x86_64.whl"}
+tqdm = "4.66.1"
+
+[tool.poetry.dev-dependencies]
+isort = "==5.11.5"
+black = "==23.1.0"
+docformatter = "==1.5.1"
+mypy = "==1.4.1"
+pylint = "==2.8.2"
+flake8 = "==3.9.2"
+pytest = "==6.2.4"
+pytest-watch = "==4.2.0"
+ruff = "==0.0.272"
+types-requests = "==2.27.7"
+
+[tool.isort]
+line_length = 88
+indent = " "
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+
+[tool.black]
+line-length = 88
+target-version = ["py38", "py39", "py310", "py311"]
+
+[tool.pytest.ini_options]
+minversion = "6.2"
+addopts = "-qq"
+testpaths = [
+ "flwr_baselines",
+]
+
+[tool.mypy]
+ignore_missing_imports = true
+strict = false
+plugins = "numpy.typing.mypy_plugin"
+
+[tool.pylint."MESSAGES CONTROL"]
+disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias"
+good-names = "i,j,k,_,x,y,X,Y"
+signature-mutators="hydra.main.main"
+
+[tool.pylint.typecheck]
+generated-members="numpy.*, torch.*, tensorflow.*"
+
+[[tool.mypy.overrides]]
+module = [
+ "importlib.metadata.*",
+ "importlib_metadata.*",
+]
+follow_imports = "skip"
+follow_imports_for_stubs = true
+disallow_untyped_calls = false
+
+[[tool.mypy.overrides]]
+module = "torch.*"
+follow_imports = "skip"
+follow_imports_for_stubs = true
+
+[tool.docformatter]
+wrap-summaries = 88
+wrap-descriptions = 88
+
+[tool.ruff]
+target-version = "py38"
+line-length = 88
+select = ["D", "E", "F", "W", "B", "ISC", "C4"]
+fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
+ignore = ["B024", "B027"]
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".hg",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pytype",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "venv",
+ "proto",
+]
+
+[tool.ruff.pydocstyle]
+convention = "numpy"
diff --git a/baselines/niid_bench/run_exp.py b/baselines/niid_bench/run_exp.py
new file mode 100644
index 000000000000..75ef2833f5a6
--- /dev/null
+++ b/baselines/niid_bench/run_exp.py
@@ -0,0 +1,109 @@
+"""Script to run all experiments in parallel."""
+
+import argparse
+import subprocess
+import time
+from collections import deque
+from itertools import product
+from typing import List
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--seed", type=int, default=42)
+parser.add_argument("--num-processes", type=int, default=5)
+parser_args = parser.parse_args()
+
+
+def get_commands(dataset_name, partitioning, labels_per_client, seed):
+ """Get commands for all experiments."""
+ cmds = [
+ (
+ f"python -m niid_bench.main --config-name fedavg_base "
+ f"partitioning={partitioning} "
+ f"dataset_seed={seed} "
+ f"dataset_name={dataset_name} "
+ f"labels_per_client={labels_per_client}"
+ ),
+ (
+ f"python -m niid_bench.main --config-name scaffold_base "
+ f"partitioning={partitioning} "
+ f"dataset_seed={seed} "
+ f"dataset_name={dataset_name} "
+ f"labels_per_client={labels_per_client}"
+ ),
+ (
+ f"python -m niid_bench.main --config-name fedprox_base "
+ f"partitioning={partitioning} "
+ f"dataset_seed={seed} "
+ f"dataset_name={dataset_name} "
+ f"labels_per_client={labels_per_client}"
+ ),
+ (
+ f"python -m niid_bench.main --config-name fedprox_base "
+ f"partitioning={partitioning} "
+ f"mu=0.1 "
+ f"dataset_seed={seed} "
+ f"dataset_name={dataset_name} "
+ f"labels_per_client={labels_per_client}"
+ ),
+ (
+ f"python -m niid_bench.main --config-name fedprox_base "
+ f"partitioning={partitioning} "
+ f"mu=0.001 "
+ f"dataset_seed={seed} "
+ f"dataset_name={dataset_name} "
+ f"labels_per_client={labels_per_client}"
+ ),
+ (
+ f"python -m niid_bench.main --config-name fedprox_base "
+ f"partitioning={partitioning} "
+ f"mu=1.0 "
+ f"dataset_seed={seed} "
+ f"dataset_name={dataset_name} "
+ f"labels_per_client={labels_per_client}"
+ ),
+ (
+ f"python -m niid_bench.main --config-name fednova_base "
+ f"partitioning={partitioning} "
+ f"dataset_seed={seed} "
+ f"dataset_name={dataset_name} "
+ f"labels_per_client={labels_per_client}"
+ ),
+ ]
+ return cmds
+
+
+dataset_names = ["cifar10", "mnist", "fmnist"]
+partitionings = [
+ "iid",
+ "dirichlet",
+ "label_quantity_1",
+ "label_quantity_2",
+ "label_quantity_3",
+]
+
+commands: deque = deque()
+for partitioning, dataset_name in product(partitionings, dataset_names):
+ labels_per_client = -1
+ if "label_quantity" in partitioning:
+ labels_per_client = int(partitioning.split("_")[-1])
+ partitioning = "label_quantity"
+ args = (dataset_name, partitioning, labels_per_client, parser_args.seed)
+ cmds = get_commands(*args)
+ for cmd in cmds:
+ commands.append(cmd)
+
+MAX_PROCESSES_AT_ONCE = parser_args.num_processes
+
+# run max_processes_at_once processes at once with 10 second sleep interval
+# in between those processes until all commands are done
+processes: List = []
+while len(commands) > 0:
+ while len(processes) < MAX_PROCESSES_AT_ONCE and len(commands) > 0:
+ cmd = commands.popleft()
+ print(cmd)
+ processes.append(subprocess.Popen(cmd, shell=True))
+ # sleep for 10 seconds to give the process time to start
+ time.sleep(10)
+ for p in processes:
+ if p.poll() is not None:
+ processes.remove(p)
diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md
index 06e77fefedf0..6b2432ab34fa 100644
--- a/doc/source/ref-changelog.md
+++ b/doc/source/ref-changelog.md
@@ -44,6 +44,8 @@
- FedWav2vec [#2551](https://github.com/adap/flower/pull/2551)
+ - niid-Bench [#2428](https://github.com/adap/flower/pull/2428)
+
- **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384),[#2425](https://github.com/adap/flower/pull/2425), [#2526](https://github.com/adap/flower/pull/2526))
- **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301), [#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327), [#2435](https://github.com/adap/flower/pull/2435))