From da9b3269cb2fb32a8154a83c9966610cb5184331 Mon Sep 17 00:00:00 2001 From: Peterpan828 <59055419+Peterpan828@users.noreply.github.com> Date: Mon, 30 Oct 2023 18:50:17 +0900 Subject: [PATCH] Add DepthFL baseline (#2295) Co-authored-by: Peterpan828 <59055419+Peterpan828@users.noreply.github.com> Co-authored-by: jafermarq --- baselines/depthfl/.gitignore | 4 + baselines/depthfl/LICENSE | 202 +++++++++ baselines/depthfl/README.md | 171 ++++++++ baselines/depthfl/depthfl/__init__.py | 1 + baselines/depthfl/depthfl/client.py | 181 ++++++++ baselines/depthfl/depthfl/conf/config.yaml | 42 ++ baselines/depthfl/depthfl/conf/heterofl.yaml | 43 ++ baselines/depthfl/depthfl/dataset.py | 60 +++ .../depthfl/depthfl/dataset_preparation.py | 125 ++++++ baselines/depthfl/depthfl/main.py | 135 ++++++ baselines/depthfl/depthfl/models.py | 301 ++++++++++++++ baselines/depthfl/depthfl/resnet.py | 386 ++++++++++++++++++ baselines/depthfl/depthfl/resnet_hetero.py | 280 +++++++++++++ baselines/depthfl/depthfl/server.py | 209 ++++++++++ baselines/depthfl/depthfl/strategy.py | 136 ++++++ baselines/depthfl/depthfl/strategy_hetero.py | 136 ++++++ baselines/depthfl/depthfl/utils.py | 66 +++ baselines/depthfl/pyproject.toml | 141 +++++++ doc/source/ref-changelog.md | 2 + 19 files changed, 2621 insertions(+) create mode 100644 baselines/depthfl/.gitignore create mode 100644 baselines/depthfl/LICENSE create mode 100644 baselines/depthfl/README.md create mode 100644 baselines/depthfl/depthfl/__init__.py create mode 100644 baselines/depthfl/depthfl/client.py create mode 100644 baselines/depthfl/depthfl/conf/config.yaml create mode 100644 baselines/depthfl/depthfl/conf/heterofl.yaml create mode 100644 baselines/depthfl/depthfl/dataset.py create mode 100644 baselines/depthfl/depthfl/dataset_preparation.py create mode 100644 baselines/depthfl/depthfl/main.py create mode 100644 baselines/depthfl/depthfl/models.py create mode 100644 baselines/depthfl/depthfl/resnet.py create mode 100644 baselines/depthfl/depthfl/resnet_hetero.py create mode 100644 baselines/depthfl/depthfl/server.py create mode 100644 baselines/depthfl/depthfl/strategy.py create mode 100644 baselines/depthfl/depthfl/strategy_hetero.py create mode 100644 baselines/depthfl/depthfl/utils.py create mode 100644 baselines/depthfl/pyproject.toml diff --git a/baselines/depthfl/.gitignore b/baselines/depthfl/.gitignore new file mode 100644 index 000000000000..fb7448bbcb01 --- /dev/null +++ b/baselines/depthfl/.gitignore @@ -0,0 +1,4 @@ +dataset/ +outputs/ +prev_grads/ +multirun/ \ No newline at end of file diff --git a/baselines/depthfl/LICENSE b/baselines/depthfl/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/baselines/depthfl/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/baselines/depthfl/README.md b/baselines/depthfl/README.md new file mode 100644 index 000000000000..b8ab7ed18571 --- /dev/null +++ b/baselines/depthfl/README.md @@ -0,0 +1,171 @@ +--- +title: DepthFL:Depthwise Federated Learning for Heterogeneous Clients +url: https://openreview.net/forum?id=pf8RIZTMU58 +labels: [image classification, system heterogeneity, cross-device, knowledge distillation] +dataset: [CIFAR-100] +--- + +# DepthFL: Depthwise Federated Learning for Heterogeneous Clients + +> Note: If you use this baseline in your work, please remember to cite the original authors of the paper as well as the Flower paper. + +**Paper:** [openreview.net/forum?id=pf8RIZTMU58](https://openreview.net/forum?id=pf8RIZTMU58) + +**Authors:** Minjae Kim, Sangyoon Yu, Suhyun Kim, Soo-Mook Moon + +**Abstract:** Federated learning is for training a global model without collecting private local data from clients. As they repeatedly need to upload locally-updated weights or gradients instead, clients require both computation and communication resources enough to participate in learning, but in reality their resources are heterogeneous. To enable resource-constrained clients to train smaller local models, width scaling techniques have been used, which reduces the channels of a global model. Unfortunately, width scaling suffers from heterogeneity of local models when averaging them, leading to a lower accuracy than when simply excluding resource-constrained clients from training. This paper proposes a new approach based on depth scaling called DepthFL. DepthFL defines local models of different depths by pruning the deepest layers off the global model, and allocates them to clients depending on their available resources. Since many clients do not have enough resources to train deep local models, this would make deep layers partially-trained with insufficient data, unlike shallow layers that are fully trained. DepthFL alleviates this problem by mutual self-distillation of knowledge among the classifiers of various depths within a local model. Our experiments show that depth-scaled local models build a global model better than width-scaled ones, and that self-distillation is highly effective in training data-insufficient deep layers. + + +## About this baseline + +**What’s implemented:** The code in this directory replicates the experiments in DepthFL: Depthwise Federated Learning for Heterogeneous Clients (Kim et al., 2023) for CIFAR100, which proposed the DepthFL algorithm. Concretely, it replicates the results for CIFAR100 dataset in Table 2, 3 and 4. + +**Datasets:** CIFAR100 from PyTorch's Torchvision + +**Hardware Setup:** These experiments were run on a server with Nvidia 3090 GPUs. Any machine with 1x 8GB GPU or more would be able to run it in a reasonable amount of time. With the default settings, clients make use of 1.3GB of VRAM. Lower `num_gpus` in `client_resources` to train more clients in parallel on your GPU(s). + +**Contributors:** Minjae Kim + + +## Experimental Setup + +**Task:** Image Classification + +**Model:** ResNet18 + +**Dataset:** This baseline only includes the CIFAR100 dataset. By default it will be partitioned into 100 clients following IID distribution. The settings are as follow: + +| Dataset | #classes | #partitions | partitioning method | +| :------ | :---: | :---: | :---: | +| CIFAR100 | 100 | 100 | IID or Non-IID | + +**Training Hyperparameters:** +The following table shows the main hyperparameters for this baseline with their default value (i.e. 
the value used if you run `python -m depthfl.main` directly) + +| Description | Default Value | +| ----------- | ----- | +| total clients | 100 | +| local epoch | 5 | +| batch size | 50 | +| number of rounds | 1000 | +| participation ratio | 10% | +| learning rate | 0.1 | +| learning rate decay | 0.998 | +| client resources | {'num_cpus': 1.0, 'num_gpus': 0.5 }| +| data partition | IID | +| optimizer | SGD with dynamic regularization | +| alpha | 0.1 | + + +## Environment Setup + +To construct the Python environment follow these steps: + +```bash +# Set python version +pyenv install 3.10.6 +pyenv local 3.10.6 + +# Tell poetry to use python 3.10 +poetry env use 3.10.6 + +# Install the base Poetry environment +poetry install + +# Activate the environment +poetry shell +``` + + +## Running the Experiments + +To run this DepthFL, first ensure you have activated your Poetry environment (execute `poetry shell` from this directory), then: + +```bash +# this will run using the default settings in the `conf/config.yaml` +python -m depthfl.main # 'accuracy' : accuracy of the ensemble model, 'accuracy_single' : accuracy of each classifier. + +# you can override settings directly from the command line +python -m depthfl.main exclusive_learning=true model_size=1 # exclusive learning - 100% (a) +python -m depthfl.main exclusive_learning=true model_size=4 # exclusive learning - 25% (d) +python -m depthfl.main fit_config.feddyn=false fit_config.kd=false # DepthFL (FedAvg) +python -m depthfl.main fit_config.feddyn=false fit_config.kd=false fit_config.extended=false # InclusiveFL +``` + +To run using HeteroFL: +```bash +# since sbn takes too long, we test global model every 50 rounds. +python -m depthfl.main --config-name="heterofl" # HeteroFL +python -m depthfl.main --config-name="heterofl" exclusive_learning=true model_size=1 # exclusive learning - 100% (a) +``` + +### Stateful clients comment + +To implement `feddyn`, stateful clients that store prev_grads information are needed. Since flwr does not yet officially support stateful clients, it was implemented as a temporary measure by loading `prev_grads` from disk when creating a client, and then storing it again on disk after learning. Specifically, there are files that store the state of each client in the `prev_grads` folder. When the strategy is instantiated (for both `FedDyn` and `HeteroFL`) the content of `prev_grads` is reset. + + +## Expected Results + +With the following command we run DepthFL (FedDyn / FedAvg), InclusiveFL, and HeteroFL to replicate the results of table 2,3,4 in DepthFL paper. Tables 2, 3, and 4 may contain results from the same experiment in multiple tables. + +```bash +# table 2 (HeteroFL row) +python -m depthfl.main --config-name="heterofl" +python -m depthfl.main --config-name="heterofl" --multirun exclusive_learning=true model.scale=false model_size=1,2,3,4 + +# table 2 (DepthFL(FedAvg) row) +python -m depthfl.main fit_config.feddyn=false fit_config.kd=false +python -m depthfl.main --multirun fit_config.feddyn=false fit_config.kd=false exclusive_learning=true model_size=1,2,3,4 + +# table 2 (DepthFL row) +python -m depthfl.main +python -m depthfl.main --multirun exclusive_learning=true model_size=1,2,3,4 +``` + +**Table 2** + +100% (a), 75%(b), 50%(c), 25% (d) cases are exclusive learning scenario. 100% (a) exclusive learning means, the global model and every local model are equal to the smallest local model, and 100% clients participate in learning. 
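+
+How each case maps to the number of participating clients and to the local model depth follows the logic in `depthfl/main.py`; the snippet below is a minimal standalone sketch of that mapping (illustration only, with the default values from `conf/config.yaml`):
+
+```python
+# Minimal sketch of the exclusive-learning setup in depthfl/main.py
+# (illustration only; the baseline reads these values from the Hydra config).
+num_clients = 100  # cfg.num_clients
+model_size = 4     # cfg.model_size: 1 = shallowest sub-model (a) ... 4 = full model (d)
+
+# model_size 1 -> 100 clients, 2 -> 75, 3 -> 50, 4 -> 25
+num_clients = num_clients - (model_size - 1) * (num_clients // 4)
+
+# in exclusive learning, every participating client trains the same depth
+n_blocks = model_size
+
+print(f"{num_clients} clients, each training a {n_blocks}-block model")
+```
+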
Likewise, 25% (d) exclusive learning means the global model and every local model are equal to the largest local model, and only 25% of clients participate in learning.
+
+| Scaling Method | Dataset | Global Model | 100% (a) | 75% (b) | 50% (c) | 25% (d) |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| HeteroFL<br>DepthFL (FedAvg)<br>DepthFL | CIFAR100 | 57.61<br>72.67<br>76.06 | 64.39<br>67.08<br>69.68 | 66.08<br>70.78<br>73.21 | 62.03<br>68.41<br>70.29 | 51.99<br>59.17<br>60.32 |
+
+```bash
+# table 3 (Width Scaling rows - duplicate results from table 2)
+python -m depthfl.main --config-name="heterofl"
+python -m depthfl.main --config-name="heterofl" --multirun exclusive_learning=true model.scale=false model_size=1,2,3,4
+
+# table 3 (Depth Scaling: Exclusive Learning and DepthFL (FedAvg) rows - duplicate results from table 2)
+python -m depthfl.main fit_config.feddyn=false fit_config.kd=false
+python -m depthfl.main --multirun fit_config.feddyn=false fit_config.kd=false exclusive_learning=true model_size=1,2,3,4
+
+# table 3 (Depth Scaling - InclusiveFL row)
+python -m depthfl.main fit_config.feddyn=false fit_config.kd=false fit_config.extended=false
+```
+
+**Table 3**
+
+Accuracy of global sub-models compared to exclusive learning on CIFAR-100.
+
+| Method | Algorithm | Classifier 1/4 | Classifier 2/4 | Classifier 3/4 | Classifier 4/4 |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| Width Scaling | Exclusive Learning<br>HeteroFL | 64.39<br>51.08 | 66.08<br>55.89 | 62.03<br>58.29 | 51.99<br>57.61 |
+
+| Method | Algorithm | Classifier 1/4 | Classifier 2/4 | Classifier 3/4 | Classifier 4/4 |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| Depth Scaling | Exclusive Learning<br>InclusiveFL<br>DepthFL (FedAvg) | 67.08<br>47.61<br>66.18 | 68.00<br>53.88<br>67.56 | 66.19<br>59.48<br>67.97 | 56.78<br>60.46<br>68.01 |
+
+```bash
+# table 4
+python -m depthfl.main --multirun fit_config.kd=true,false dataset_config.iid=true,false
+```
+
+**Table 4**
+
+Accuracy of the global model with/without self-distillation on CIFAR-100.
+
+| Distribution | Dataset | KD | Classifier 1/4 | Classifier 2/4 | Classifier 3/4 | Classifier 4/4 | Ensemble |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| IID | CIFAR100 | ✗<br>✓ | 70.13<br>71.74 | 69.63<br>73.35 | 68.92<br>73.57 | 68.92<br>73.55 | 74.48<br>76.06 |
+| non-IID | CIFAR100 | ✗<br>✓ | 67.94<br>70.33 | 68.68<br>71.88 | 68.46<br>72.43 | 67.78<br>72.34 | 73.18<br>
74.92 | + diff --git a/baselines/depthfl/depthfl/__init__.py b/baselines/depthfl/depthfl/__init__.py new file mode 100644 index 000000000000..3343905e1879 --- /dev/null +++ b/baselines/depthfl/depthfl/__init__.py @@ -0,0 +1 @@ +"""Flower summer of reproducibility : DepthFL (ICLR' 23).""" diff --git a/baselines/depthfl/depthfl/client.py b/baselines/depthfl/depthfl/client.py new file mode 100644 index 000000000000..481ac90f1c79 --- /dev/null +++ b/baselines/depthfl/depthfl/client.py @@ -0,0 +1,181 @@ +"""Defines the DepthFL Flower Client and a function to instantiate it.""" + +import copy +import pickle +from collections import OrderedDict +from typing import Callable, Dict, List, Tuple + +import flwr as fl +import numpy as np +import torch +from flwr.common.typing import NDArrays, Scalar +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from depthfl.models import test, train + + +def prune(state_dict, param_idx): + """Prune width of DNN (for HeteroFL).""" + ret_dict = {} + for k in state_dict.keys(): + if "num" not in k: + ret_dict[k] = state_dict[k][torch.meshgrid(param_idx[k])] + else: + ret_dict[k] = state_dict[k] + return copy.deepcopy(ret_dict) + + +class FlowerClient( + fl.client.NumPyClient +): # pylint: disable=too-many-instance-attributes + """Standard Flower client for CNN training.""" + + def __init__( + self, + net: torch.nn.Module, + trainloader: DataLoader, + valloader: DataLoader, + device: torch.device, + num_epochs: int, + learning_rate: float, + learning_rate_decay: float, + prev_grads: Dict, + cid: int, + ): # pylint: disable=too-many-arguments + self.net = net + self.trainloader = trainloader + self.valloader = valloader + self.device = device + self.num_epochs = num_epochs + self.learning_rate = learning_rate + self.learning_rate_decay = learning_rate_decay + self.prev_grads = prev_grads + self.cid = cid + self.param_idx = {} + state_dict = net.state_dict() + + # for HeteroFL + for k in state_dict.keys(): + self.param_idx[k] = [ + torch.arange(size) for size in state_dict[k].shape + ] # store client's weights' shape (for HeteroFL) + + def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + """Return the parameters of the current net.""" + return [val.cpu().numpy() for _, val in self.net.state_dict().items()] + + def set_parameters(self, parameters: NDArrays) -> None: + """Change the parameters of the model using the given ones.""" + params_dict = zip(self.net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + self.net.load_state_dict(prune(state_dict, self.param_idx), strict=True) + + def fit( + self, parameters: NDArrays, config: Dict[str, Scalar] + ) -> Tuple[NDArrays, int, Dict]: + """Implement distributed fit function for a given client.""" + self.set_parameters(parameters) + num_epochs = self.num_epochs + + curr_round = int(config["curr_round"]) - 1 + + # consistency weight for self distillation in DepthFL + consistency_weight_constant = 300 + current = np.clip(curr_round, 0.0, consistency_weight_constant) + phase = 1.0 - current / consistency_weight_constant + consistency_weight = float(np.exp(-5.0 * phase * phase)) + + train( + self.net, + self.trainloader, + self.device, + epochs=num_epochs, + learning_rate=self.learning_rate * self.learning_rate_decay**curr_round, + config=config, + consistency_weight=consistency_weight, + prev_grads=self.prev_grads, + ) + + with open(f"prev_grads/client_{self.cid}", "wb") as prev_grads_file: + 
pickle.dump(self.prev_grads, prev_grads_file) + + return self.get_parameters({}), len(self.trainloader), {"cid": self.cid} + + def evaluate( + self, parameters: NDArrays, config: Dict[str, Scalar] + ) -> Tuple[float, int, Dict]: + """Implement distributed evaluation for a given client.""" + self.set_parameters(parameters) + loss, accuracy, accuracy_single = test(self.net, self.valloader, self.device) + return ( + float(loss), + len(self.valloader), + {"accuracy": float(accuracy), "accuracy_single": accuracy_single}, + ) + + +def gen_client_fn( # pylint: disable=too-many-arguments + num_epochs: int, + trainloaders: List[DataLoader], + valloaders: List[DataLoader], + learning_rate: float, + learning_rate_decay: float, + models: List[DictConfig], +) -> Callable[[str], FlowerClient]: + """Generate the client function that creates the Flower Clients. + + Parameters + ---------- + num_epochs : int + The number of local epochs each client should run the training for before + sending it to the server. + trainloaders: List[DataLoader] + A list of DataLoaders, each pointing to the dataset training partition + belonging to a particular client. + valloaders: List[DataLoader] + A list of DataLoaders, each pointing to the dataset validation partition + belonging to a particular client. + learning_rate : float + The learning rate for the SGD optimizer of clients. + learning_rate_decay : float + The learning rate decay ratio per round for the SGD optimizer of clients. + models : List[DictConfig] + A list of DictConfigs, each pointing to the model config of client's local model + + Returns + ------- + Callable[[str], FlowerClient] + client function that creates Flower Clients + """ + + def client_fn(cid: str) -> FlowerClient: + """Create a Flower client representing a single organization.""" + # Load model + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # each client gets a different model config (different width / depth) + net = instantiate(models[int(cid)]).to(device) + + # Note: each client gets a different trainloader/valloader, so each client + # will train and evaluate on their own unique data + trainloader = trainloaders[int(cid)] + valloader = valloaders[int(cid)] + + with open(f"prev_grads/client_{int(cid)}", "rb") as prev_grads_file: + prev_grads = pickle.load(prev_grads_file) + + return FlowerClient( + net, + trainloader, + valloader, + device, + num_epochs, + learning_rate, + learning_rate_decay, + prev_grads, + int(cid), + ) + + return client_fn diff --git a/baselines/depthfl/depthfl/conf/config.yaml b/baselines/depthfl/depthfl/conf/config.yaml new file mode 100644 index 000000000000..5a126229956e --- /dev/null +++ b/baselines/depthfl/depthfl/conf/config.yaml @@ -0,0 +1,42 @@ +--- + +num_clients: 100 # total number of clients +num_epochs: 5 # number of local epochs +batch_size: 50 +num_rounds: 1000 +fraction: 0.1 # participation ratio +learning_rate: 0.1 +learning_rate_decay : 0.998 # per round +static_bn: false # static batch normalization (HeteroFL) +exclusive_learning: false # exclusive learning baseline in DepthFL paper +model_size: 1 # model size for exclusive learning + +client_resources: + num_cpus: 1 + num_gpus: 0.5 + +server_device: cuda + +dataset_config: + iid: true + beta: 0.5 + +fit_config: + feddyn: true + kd: true + alpha: 0.1 # alpha for FedDyn + extended: true # if not extended : InclusiveFL + drop_client: false # with FedProx, clients shouldn't be dropped even if they are stragglers + +model: + _target_: depthfl.resnet.multi_resnet18 + n_blocks: 
4 # depth (1 ~ 4) + num_classes: 100 + +strategy: + _target_: depthfl.strategy.FedDyn + fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients + fraction_evaluate: 0.0 + # min_fit_clients: ${clients_per_round} + min_evaluate_clients: 0 + # min_available_clients: ${clients_per_round} \ No newline at end of file diff --git a/baselines/depthfl/depthfl/conf/heterofl.yaml b/baselines/depthfl/depthfl/conf/heterofl.yaml new file mode 100644 index 000000000000..ad0bb8c8f8b8 --- /dev/null +++ b/baselines/depthfl/depthfl/conf/heterofl.yaml @@ -0,0 +1,43 @@ +--- + +num_clients: 100 # total number of clients +num_epochs: 5 # number of local epochs +batch_size: 50 +num_rounds: 1000 +fraction: 0.1 # participation ratio +learning_rate: 0.1 +learning_rate_decay : 0.998 # per round +static_bn: true # static batch normalization (HeteroFL) +exclusive_learning: false # exclusive learning baseline in DepthFL paper +model_size: 1 # model size for exclusive learning + +client_resources: + num_cpus: 1 + num_gpus: 0.5 + +server_device: cuda + +dataset_config: + iid: true + beta: 0.5 + +fit_config: + feddyn: false + kd: false + alpha: 0.1 # unused + extended: false # unused + drop_client: false # with FedProx, clients shouldn't be dropped even if they are stragglers + +model: + _target_: depthfl.resnet_hetero.resnet18 + n_blocks: 4 # width (1 ~ 4) + num_classes: 100 + scale: true # scaler module in HeteroFL + +strategy: + _target_: depthfl.strategy_hetero.HeteroFL + fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients + fraction_evaluate: 0.0 + # min_fit_clients: ${clients_per_round} + min_evaluate_clients: 0 + # min_available_clients: ${clients_per_round} \ No newline at end of file diff --git a/baselines/depthfl/depthfl/dataset.py b/baselines/depthfl/depthfl/dataset.py new file mode 100644 index 000000000000..c2024fe068a0 --- /dev/null +++ b/baselines/depthfl/depthfl/dataset.py @@ -0,0 +1,60 @@ +"""CIFAR100 dataset utilities for federated learning.""" + +from typing import Optional, Tuple + +import torch +from omegaconf import DictConfig +from torch.utils.data import DataLoader, random_split + +from depthfl.dataset_preparation import _partition_data + + +def load_datasets( # pylint: disable=too-many-arguments + config: DictConfig, + num_clients: int, + val_ratio: float = 0.0, + batch_size: Optional[int] = 32, + seed: Optional[int] = 41, +) -> Tuple[DataLoader, DataLoader, DataLoader]: + """Create the dataloaders to be fed into the model. + + Parameters + ---------- + config: DictConfig + Parameterises the dataset partitioning process + num_clients : int + The number of clients that hold a part of the data + val_ratio : float, optional + The ratio of training data that will be used for validation (between 0 and 1), + by default 0.1 + batch_size : int, optional + The size of the batches to be fed into the model, by default 32 + seed : int, optional + Used to set a fix seed to replicate experiments, by default 42 + + Returns + ------- + Tuple[DataLoader, DataLoader, DataLoader] + The DataLoader for training, validation, and testing. 
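+
+    Examples
+    --------
+    A typical call, mirroring how ``depthfl/main.py`` uses this function
+    (``cfg`` stands for the loaded Hydra config; shown for illustration)::
+
+        trainloaders, valloaders, testloader = load_datasets(
+            config=cfg.dataset_config,
+            num_clients=cfg.num_clients,
+            batch_size=cfg.batch_size,
+        )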
+ """ + print(f"Dataset partitioning config: {config}") + datasets, testset = _partition_data( + num_clients, + iid=config.iid, + beta=config.beta, + seed=seed, + ) + # Split each partition into train/val and create DataLoader + trainloaders = [] + valloaders = [] + for dataset in datasets: + len_val = 0 + if val_ratio > 0: + len_val = int(len(dataset) / (1 / val_ratio)) + lengths = [len(dataset) - len_val, len_val] + ds_train, ds_val = random_split( + dataset, lengths, torch.Generator().manual_seed(seed) + ) + trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True)) + valloaders.append(DataLoader(ds_val, batch_size=batch_size)) + return trainloaders, valloaders, DataLoader(testset, batch_size=batch_size) diff --git a/baselines/depthfl/depthfl/dataset_preparation.py b/baselines/depthfl/depthfl/dataset_preparation.py new file mode 100644 index 000000000000..006491c7679e --- /dev/null +++ b/baselines/depthfl/depthfl/dataset_preparation.py @@ -0,0 +1,125 @@ +"""Dataset(CIFAR100) preparation for DepthFL.""" + +from typing import List, Optional, Tuple + +import numpy as np +import torchvision.transforms as transforms +from torch.utils.data import Dataset, Subset +from torchvision.datasets import CIFAR100 + + +def _download_data() -> Tuple[Dataset, Dataset]: + """Download (if necessary) and returns the CIFAR-100 dataset. + + Returns + ------- + Tuple[CIFAR100, CIFAR100] + The dataset for training and the dataset for testing CIFAR100. + """ + transform_train = transforms.Compose( + [ + transforms.ToTensor(), + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ] + ) + + transform_test = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ] + ) + + trainset = CIFAR100( + "./dataset", train=True, download=True, transform=transform_train + ) + testset = CIFAR100( + "./dataset", train=False, download=True, transform=transform_test + ) + return trainset, testset + + +def _partition_data( + num_clients, + iid: Optional[bool] = True, + beta=0.5, + seed=41, +) -> Tuple[List[Dataset], Dataset]: + """Split training set to simulate the federated setting. + + Parameters + ---------- + num_clients : int + The number of clients that hold a part of the data + iid : bool, optional + Whether the data should be independent and identically distributed + or if the data should first be sorted by labels and distributed by + noniid manner to each client, by default true + beta : hyperparameter for dirichlet distribution + seed : int, optional + Used to set a fix seed to replicate experiments, by default 42 + + Returns + ------- + Tuple[List[Dataset], Dataset] + A list of dataset for each client and a + single dataset to be use for testing the model. 
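+
+    Examples
+    --------
+    Illustrative IID split: with the 50,000-image CIFAR-100 training set and
+    100 clients, each client receives a disjoint 500-image ``Subset``::
+
+        client_datasets, testset = _partition_data(num_clients=100, iid=True)
+        # len(client_datasets) == 100, len(client_datasets[0]) == 500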
+ """ + trainset, testset = _download_data() + + datasets: List[Subset] = [] + + if iid: + distribute_iid(num_clients, seed, trainset, datasets) + + else: + distribute_noniid(num_clients, beta, seed, trainset, datasets) + + return datasets, testset + + +def distribute_iid(num_clients, seed, trainset, datasets): + """Distribute dataset in iid manner.""" + np.random.seed(seed) + num_sample = int(len(trainset) / (num_clients)) + index = list(range(len(trainset))) + for _ in range(num_clients): + sample_idx = np.random.choice(index, num_sample, replace=False) + index = list(set(index) - set(sample_idx)) + datasets.append(Subset(trainset, sample_idx)) + + +def distribute_noniid(num_clients, beta, seed, trainset, datasets): + """Distribute dataset in non-iid manner.""" + labels = np.array([label for _, label in trainset]) + min_size = 0 + np.random.seed(seed) + + while min_size < 10: + idx_batch = [[] for _ in range(num_clients)] + # for each class in the dataset + for k in range(np.max(labels) + 1): + idx_k = np.where(labels == k)[0] + np.random.shuffle(idx_k) + proportions = np.random.dirichlet(np.repeat(beta, num_clients)) + # Balance + proportions = np.array( + [ + p * (len(idx_j) < labels.shape[0] / num_clients) + for p, idx_j in zip(proportions, idx_batch) + ] + ) + proportions = proportions / proportions.sum() + proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] + idx_batch = [ + idx_j + idx.tolist() + for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions)) + ] + min_size = min([len(idx_j) for idx_j in idx_batch]) + + for j in range(num_clients): + np.random.shuffle(idx_batch[j]) + # net_dataidx_map[j] = np.array(idx_batch[j]) + datasets.append(Subset(trainset, np.array(idx_batch[j]))) diff --git a/baselines/depthfl/depthfl/main.py b/baselines/depthfl/depthfl/main.py new file mode 100644 index 000000000000..7bf1d9563eae --- /dev/null +++ b/baselines/depthfl/depthfl/main.py @@ -0,0 +1,135 @@ +"""DepthFL main.""" + +import copy + +import flwr as fl +import hydra +from flwr.common import ndarrays_to_parameters +from flwr.server.client_manager import SimpleClientManager +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf + +from depthfl import client, server +from depthfl.dataset import load_datasets +from depthfl.utils import save_results_as_pickle + + +@hydra.main(config_path="conf", config_name="config", version_base=None) +def main(cfg: DictConfig) -> None: + """Run the baseline. + + Parameters + ---------- + cfg : DictConfig + An omegaconf object that stores the hydra config. 
+ """ + print(OmegaConf.to_yaml(cfg)) + + # partition dataset and get dataloaders + trainloaders, valloaders, testloader = load_datasets( + config=cfg.dataset_config, + num_clients=cfg.num_clients, + batch_size=cfg.batch_size, + ) + + # exclusive learning baseline in DepthFL paper + # (model_size, % of clients) = (a,100), (b,75), (c,50), (d,25) + if cfg.exclusive_learning: + cfg.num_clients = int( + cfg.num_clients - (cfg.model_size - 1) * (cfg.num_clients // 4) + ) + + models = [] + for i in range(cfg.num_clients): + model = copy.deepcopy(cfg.model) + + # each client gets different model depth / width + model.n_blocks = i // (cfg.num_clients // 4) + 1 + + # In exclusive learning, every client has same model depth / width + if cfg.exclusive_learning: + model.n_blocks = cfg.model_size + + models.append(model) + + # prepare function that will be used to spawn each client + client_fn = client.gen_client_fn( + num_epochs=cfg.num_epochs, + trainloaders=trainloaders, + valloaders=valloaders, + learning_rate=cfg.learning_rate, + learning_rate_decay=cfg.learning_rate_decay, + models=models, + ) + + # get function that will executed by the strategy's evaluate() method + # Set server's device + device = cfg.server_device + + # Static Batch Normalization for HeteroFL + if cfg.static_bn: + evaluate_fn = server.gen_evaluate_fn_hetero( + trainloaders, testloader, device=device, model_cfg=model + ) + else: + evaluate_fn = server.gen_evaluate_fn(testloader, device=device, model=model) + + # get a function that will be used to construct the config that the client's + # fit() method will received + def get_on_fit_config(): + def fit_config_fn(server_round): + # resolve and convert to python dict + fit_config = OmegaConf.to_container(cfg.fit_config, resolve=True) + fit_config["curr_round"] = server_round # add round info + return fit_config + + return fit_config_fn + + net = instantiate(cfg.model) + # instantiate strategy according to config. Here we pass other arguments + # that are only defined at run time. + strategy = instantiate( + cfg.strategy, + cfg, + net, + evaluate_fn=evaluate_fn, + on_fit_config_fn=get_on_fit_config(), + initial_parameters=ndarrays_to_parameters( + [val.cpu().numpy() for _, val in net.state_dict().items()] + ), + min_fit_clients=int(cfg.num_clients * cfg.fraction), + min_available_clients=int(cfg.num_clients * cfg.fraction), + ) + + # Start simulation + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.num_clients, + config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + client_resources={ + "num_cpus": cfg.client_resources.num_cpus, + "num_gpus": cfg.client_resources.num_gpus, + }, + strategy=strategy, + server=server.ServerFedDyn( + client_manager=SimpleClientManager(), strategy=strategy + ), + ) + + # Experiment completed. 
Now we save the results and + # generate plots using the `history` + print("................") + print(history) + + # Hydra automatically creates an output directory + # Let's retrieve it and save some results there + save_path = HydraConfig.get().runtime.output_dir + + # save results as a Python pickle using a file_path + # the directory created by Hydra for each run + save_results_as_pickle(history, file_path=save_path, extra_results={}) + + +if __name__ == "__main__": + main() diff --git a/baselines/depthfl/depthfl/models.py b/baselines/depthfl/depthfl/models.py new file mode 100644 index 000000000000..df3eebf9f9ce --- /dev/null +++ b/baselines/depthfl/depthfl/models.py @@ -0,0 +1,301 @@ +"""ResNet18 model architecutre, training, and testing functions for CIFAR100.""" + + +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import DictConfig +from torch.utils.data import DataLoader + + +class KLLoss(nn.Module): + """KL divergence loss for self distillation.""" + + def __init__(self): + super().__init__() + self.temperature = 1 + + def forward(self, pred, label): + """KL loss forward.""" + predict = F.log_softmax(pred / self.temperature, dim=1) + target_data = F.softmax(label / self.temperature, dim=1) + target_data = target_data + 10 ** (-7) + with torch.no_grad(): + target = target_data.detach().clone() + + loss = ( + self.temperature + * self.temperature + * ((target * (target.log() - predict)).sum(1).sum() / target.size()[0]) + ) + return loss + + +def train( # pylint: disable=too-many-arguments + net: nn.Module, + trainloader: DataLoader, + device: torch.device, + epochs: int, + learning_rate: float, + config: dict, + consistency_weight: float, + prev_grads: dict, +) -> None: + """Train the network on the training set. + + Parameters + ---------- + net : nn.Module + The neural network to train. + trainloader : DataLoader + The DataLoader containing the data to train the network on. + device : torch.device + The device on which the model should be trained, either 'cpu' or 'cuda'. + epochs : int + The number of epochs the model should be trained for. + learning_rate : float + The learning rate for the SGD optimizer. 
+ config : dict + training configuration + consistency_weight : float + hyperparameter for self distillation + prev_grads : dict + control variate for feddyn + """ + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=1e-3) + global_params = { + k: val.detach().clone().flatten() for (k, val) in net.named_parameters() + } + + for k, _ in net.named_parameters(): + prev_grads[k] = prev_grads[k].to(device) + + net.train() + for _ in range(epochs): + _train_one_epoch( + net, + global_params, + trainloader, + device, + criterion, + optimizer, + config, + consistency_weight, + prev_grads, + ) + + # update prev_grads for FedDyn + if config["feddyn"]: + update_prev_grads(config, net, prev_grads, global_params) + + +def update_prev_grads(config, net, prev_grads, global_params): + """Update prev_grads for FedDyn.""" + for k, param in net.named_parameters(): + curr_param = param.detach().clone().flatten() + prev_grads[k] = prev_grads[k] - config["alpha"] * ( + curr_param - global_params[k] + ) + prev_grads[k] = prev_grads[k].to(torch.device(torch.device("cpu"))) + + +def _train_one_epoch( # pylint: disable=too-many-locals, too-many-arguments + net: nn.Module, + global_params: dict, + trainloader: DataLoader, + device: torch.device, + criterion: torch.nn.CrossEntropyLoss, + optimizer: torch.optim.SGD, + config: dict, + consistency_weight: float, + prev_grads: dict, +): + """Train for one epoch. + + Parameters + ---------- + net : nn.Module + The neural network to train. + global_params : List[Parameter] + The parameters of the global model (from the server). + trainloader : DataLoader + The DataLoader containing the data to train the network on. + device : torch.device + The device on which the model should be trained, either 'cpu' or 'cuda'. + criterion : torch.nn.CrossEntropyLoss + The loss function to use for training + optimizer : torch.optim.Adam + The optimizer to use for training + config : dict + training configuration + consistency_weight : float + hyperparameter for self distillation + prev_grads : dict + control variate for feddyn + """ + criterion_kl = KLLoss().cuda() + + for images, labels in trainloader: + images, labels = images.to(device), labels.to(device) + loss = torch.zeros(1).to(device) + optimizer.zero_grad() + output_lst = net(images) + + for i, branch_output in enumerate(output_lst): + # only trains last classifier in InclusiveFL + if not config["extended"] and i != len(output_lst) - 1: + continue + + loss += criterion(branch_output, labels) + + # self distillation term + if config["kd"] and len(output_lst) > 1: + for j, output in enumerate(output_lst): + if j == i: + continue + + loss += ( + consistency_weight + * criterion_kl(branch_output, output.detach()) + / (len(output_lst) - 1) + ) + + # Dynamic regularization in FedDyn + if config["feddyn"]: + for k, param in net.named_parameters(): + curr_param = param.flatten() + + lin_penalty = torch.dot(curr_param, prev_grads[k]) + loss -= lin_penalty + + quad_penalty = ( + config["alpha"] + / 2.0 + * torch.sum(torch.square(curr_param - global_params[k])) + ) + loss += quad_penalty + + loss.backward() + optimizer.step() + + +def test( # pylint: disable=too-many-locals + net: nn.Module, testloader: DataLoader, device: torch.device +) -> Tuple[float, float, List[float]]: + """Evaluate the network on the entire test set. + + Parameters + ---------- + net : nn.Module + The neural network to test. 
+ testloader : DataLoader + The DataLoader containing the data to test the network on. + device : torch.device + The device on which the model should be tested, either 'cpu' or 'cuda'. + + Returns + ------- + Tuple[float, float, List[float]] + The loss and the accuracy of the global model + and the list of accuracy for each classifier on the given data. + """ + criterion = torch.nn.CrossEntropyLoss() + correct, total, loss = 0, 0, 0.0 + correct_single = [0] * 4 # accuracy of each classifier within model + net.eval() + with torch.no_grad(): + for images, labels in testloader: + images, labels = images.to(device), labels.to(device) + output_lst = net(images) + + # ensemble classfiers' output + ensemble_output = torch.stack(output_lst, dim=2) + ensemble_output = torch.sum(ensemble_output, dim=2) / len(output_lst) + + loss += criterion(ensemble_output, labels).item() + _, predicted = torch.max(ensemble_output, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + for i, single in enumerate(output_lst): + _, predicted = torch.max(single, 1) + correct_single[i] += (predicted == labels).sum().item() + + if len(testloader.dataset) == 0: + raise ValueError("Testloader can't be 0, exiting...") + loss /= len(testloader.dataset) + accuracy = correct / total + accuracy_single = [correct / total for correct in correct_single] + return loss, accuracy, accuracy_single + + +def test_sbn( # pylint: disable=too-many-locals + nets: List[nn.Module], + trainloaders: List[DictConfig], + testloader: DataLoader, + device: torch.device, +) -> Tuple[float, float, List[float]]: + """Evaluate the networks on the entire test set. + + Parameters + ---------- + nets : List[nn.Module] + The neural networks to test. Each neural network has different width + trainloaders : List[DataLoader] + The List of dataloaders containing the data to train the network on + testloader : DataLoader + The DataLoader containing the data to test the network on. + device : torch.device + The device on which the model should be tested, either 'cpu' or 'cuda'. + + Returns + ------- + Tuple[float, float, List[float]] + The loss and the accuracy of the global model + and the list of accuracy for each classifier on the given data. 
+ """ + # static batch normalization + for trainloader in trainloaders: + with torch.no_grad(): + for model in nets: + model.train() + for _batch_idx, (images, labels) in enumerate(trainloader): + images, labels = images.to(device), labels.to(device) + output = model(images) + + model.eval() + + criterion = torch.nn.CrossEntropyLoss() + correct, total, loss = 0, 0, 0.0 + correct_single = [0] * 4 + + # test each network of different width + with torch.no_grad(): + for images, labels in testloader: + images, labels = images.to(device), labels.to(device) + + output_lst = [] + + for model in nets: + output_lst.append(model(images)[0]) + + output = output_lst[-1] + + loss += criterion(output, labels).item() + _, predicted = torch.max(output, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + for i, single in enumerate(output_lst): + _, predicted = torch.max(single, 1) + correct_single[i] += (predicted == labels).sum().item() + + if len(testloader.dataset) == 0: + raise ValueError("Testloader can't be 0, exiting...") + loss /= len(testloader.dataset) + accuracy = correct / total + accuracy_single = [correct / total for correct in correct_single] + return loss, accuracy, accuracy_single diff --git a/baselines/depthfl/depthfl/resnet.py b/baselines/depthfl/depthfl/resnet.py new file mode 100644 index 000000000000..04348ae17441 --- /dev/null +++ b/baselines/depthfl/depthfl/resnet.py @@ -0,0 +1,386 @@ +"""ResNet18 for DepthFL.""" + +import torch.nn as nn + + +class MyGroupNorm(nn.Module): + """Group Normalization layer.""" + + def __init__(self, num_channels): + super().__init__() + # change num_groups to 32 + self.norm = nn.GroupNorm( + num_groups=16, num_channels=num_channels, eps=1e-5, affine=True + ) + + def forward(self, x): + """GN forward.""" + x = self.norm(x) + return x + + +class MyBatchNorm(nn.Module): + """Batch Normalization layer.""" + + def __init__(self, num_channels): + super().__init__() + self.norm = nn.BatchNorm2d(num_channels, track_running_stats=True) + + def forward(self, x): + """BN forward.""" + x = self.norm(x) + return x + + +def conv3x3(in_planes, out_planes, stride=1): + """Convolution layer 3x3.""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + + +def conv1x1(in_planes, planes, stride=1): + """Convolution layer 1x1.""" + return nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) + + +class SepConv(nn.Module): + """Bottleneck layer module.""" + + def __init__( # pylint: disable=too-many-arguments + self, + channel_in, + channel_out, + kernel_size=3, + stride=2, + padding=1, + norm_layer=MyGroupNorm, + ): + super().__init__() + self.operations = nn.Sequential( + nn.Conv2d( + channel_in, + channel_in, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=channel_in, + bias=False, + ), + nn.Conv2d(channel_in, channel_in, kernel_size=1, padding=0, bias=False), + norm_layer(channel_in), + nn.ReLU(inplace=False), + nn.Conv2d( + channel_in, + channel_in, + kernel_size=kernel_size, + stride=1, + padding=padding, + groups=channel_in, + bias=False, + ), + nn.Conv2d(channel_in, channel_out, kernel_size=1, padding=0, bias=False), + norm_layer(channel_out), + nn.ReLU(inplace=False), + ) + + def forward(self, x): + """SepConv forward.""" + return self.operations(x) + + +class BasicBlock(nn.Module): + """Basic Block for ResNet18.""" + + expansion = 1 + + def __init__( + self, inplanes, planes, stride=1, downsample=None, norm_layer=None + ): # pylint: 
disable=too-many-arguments + super().__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + """BasicBlock forward.""" + residual = x + + output = self.conv1(x) + output = self.bn1(output) + output = self.relu(output) + + output = self.conv2(output) + output = self.bn2(output) + + if self.downsample is not None: + residual = self.downsample(x) + + output += residual + output = self.relu(output) + return output + + +class MultiResnet(nn.Module): # pylint: disable=too-many-instance-attributes + """Resnet model. + + Args: + block (class): block type, BasicBlock or BottleneckBlock + layers (int list): layer num in each block + n_blocks (int) : Depth of network + num_classes (int): class num. + norm_layer (class): type of normalization layer. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + block, + layers, + n_blocks, + num_classes=1000, + norm_layer=MyBatchNorm, + ): + super().__init__() + self.n_blocks = n_blocks + self.inplanes = 64 + self.norm_layer = norm_layer + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn1 = norm_layer(self.inplanes) + + self.relu = nn.ReLU(inplace=True) + # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + + self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes) + # self.feature_fc1 = nn.Linear(512 * block.expansion, 512 * block.expansion) + self.scala1 = nn.Sequential( + SepConv( + channel_in=64 * block.expansion, + channel_out=128 * block.expansion, + norm_layer=norm_layer, + ), + SepConv( + channel_in=128 * block.expansion, + channel_out=256 * block.expansion, + norm_layer=norm_layer, + ), + SepConv( + channel_in=256 * block.expansion, + channel_out=512 * block.expansion, + norm_layer=norm_layer, + ), + nn.AdaptiveAvgPool2d(1), + ) + + self.attention1 = nn.Sequential( + SepConv( + channel_in=64 * block.expansion, + channel_out=64 * block.expansion, + norm_layer=norm_layer, + ), + norm_layer(64 * block.expansion), + nn.ReLU(), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Sigmoid(), + ) + + if n_blocks > 1: + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes) + # self.feature_fc2 = nn.Linear(512 * block.expansion, 512 * block.expansion) + self.scala2 = nn.Sequential( + SepConv( + channel_in=128 * block.expansion, + channel_out=256 * block.expansion, + norm_layer=norm_layer, + ), + SepConv( + channel_in=256 * block.expansion, + channel_out=512 * block.expansion, + norm_layer=norm_layer, + ), + nn.AdaptiveAvgPool2d(1), + ) + self.attention2 = nn.Sequential( + SepConv( + channel_in=128 * block.expansion, + channel_out=128 * block.expansion, + norm_layer=norm_layer, + ), + norm_layer(128 * block.expansion), + nn.ReLU(), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Sigmoid(), + ) + + if n_blocks > 2: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes) + # self.feature_fc3 = nn.Linear(512 * block.expansion, 512 * block.expansion) + self.scala3 = nn.Sequential( + SepConv( + channel_in=256 * block.expansion, + channel_out=512 * block.expansion, + norm_layer=norm_layer, + ), + 
nn.AdaptiveAvgPool2d(1), + ) + self.attention3 = nn.Sequential( + SepConv( + channel_in=256 * block.expansion, + channel_out=256 * block.expansion, + norm_layer=norm_layer, + ), + norm_layer(256 * block.expansion), + nn.ReLU(), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Sigmoid(), + ) + + if n_blocks > 3: + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.fc_layer = nn.Linear(512 * block.expansion, num_classes) + self.scala4 = nn.AdaptiveAvgPool2d(1) + + for module in self.modules(): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_( + module.weight, mode="fan_out", nonlinearity="relu" + ) + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + def _make_layer( + self, block, planes, layers, stride=1, norm_layer=None + ): # pylint: disable=too-many-arguments + """Create a block with layers. + + Args: + block (class): block type + planes (int): output channels = planes * expansion + layers (int): layer num in the block + stride (int): the first layer stride in the block. + norm_layer (class): type of normalization layer. + """ + norm_layer = self.norm_layer + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + layer = [] + layer.append( + block( + self.inplanes, + planes, + stride=stride, + downsample=downsample, + norm_layer=norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _i in range(1, layers): + layer.append(block(self.inplanes, planes, norm_layer=norm_layer)) + + return nn.Sequential(*layer) + + def forward(self, x): + """Resnet forward.""" + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + # x = self.maxpool(x) + + x = self.layer1(x) + fea1 = self.attention1(x) + fea1 = fea1 * x + out1_feature = self.scala1(fea1).view(x.size(0), -1) + middle_output1 = self.middle_fc1(out1_feature) + # out1_feature = self.feature_fc1(out1_feature) + + if self.n_blocks == 1: + return [middle_output1] + + x = self.layer2(x) + fea2 = self.attention2(x) + fea2 = fea2 * x + out2_feature = self.scala2(fea2).view(x.size(0), -1) + middle_output2 = self.middle_fc2(out2_feature) + # out2_feature = self.feature_fc2(out2_feature) + if self.n_blocks == 2: + return [middle_output1, middle_output2] + + x = self.layer3(x) + fea3 = self.attention3(x) + fea3 = fea3 * x + out3_feature = self.scala3(fea3).view(x.size(0), -1) + middle_output3 = self.middle_fc3(out3_feature) + # out3_feature = self.feature_fc3(out3_feature) + + if self.n_blocks == 3: + return [middle_output1, middle_output2, middle_output3] + + x = self.layer4(x) + out4_feature = self.scala4(x).view(x.size(0), -1) + output4 = self.fc_layer(out4_feature) + + return [middle_output1, middle_output2, middle_output3, output4] + + +def multi_resnet18(n_blocks=1, norm="bn", num_classes=100): + """Create resnet18 for HeteroFL. 
+ + Parameters + ---------- + n_blocks: int + depth of network + norm: str + normalization layer type + num_classes: int + # of labels + + Returns + ------- + Callable [ [nn.Module,List[int],int,int,nn.Module], nn.Module] + """ + if norm == "gn": + norm_layer = MyGroupNorm + + elif norm == "bn": + norm_layer = MyBatchNorm + + return MultiResnet( + BasicBlock, + [2, 2, 2, 2], + n_blocks, + num_classes=num_classes, + norm_layer=norm_layer, + ) + + +# if __name__ == "__main__": +# from ptflops import get_model_complexity_info + +# model = MultiResnet18(n_blocks=4, num_classes=100) + +# with torch.cuda.device(0): +# macs, params = get_model_complexity_info( +# model, +# (3, 32, 32), +# as_strings=True, +# print_per_layer_stat=False, +# verbose=True, +# units="MMac", +# ) + +# print("{:<30} {:<8}".format("Computational complexity: ", macs)) +# print("{:<30} {:<8}".format("Number of parameters: ", params)) diff --git a/baselines/depthfl/depthfl/resnet_hetero.py b/baselines/depthfl/depthfl/resnet_hetero.py new file mode 100644 index 000000000000..a84c07b881b2 --- /dev/null +++ b/baselines/depthfl/depthfl/resnet_hetero.py @@ -0,0 +1,280 @@ +"""ResNet18 for HeteroFL.""" + +import numpy as np +import torch.nn as nn + + +class Scaler(nn.Module): + """Scaler module for HeteroFL.""" + + def __init__(self, rate, scale): + super().__init__() + if scale: + self.rate = rate + else: + self.rate = 1 + + def forward(self, x): + """Scaler forward.""" + output = x / self.rate if self.training else x + return output + + +class MyBatchNorm(nn.Module): + """Static Batch Normalization for HeteroFL.""" + + def __init__(self, num_channels, track=True): + super().__init__() + # change num_groups to 32 + self.norm = nn.BatchNorm2d(num_channels, track_running_stats=track) + + def forward(self, x): + """BatchNorm forward.""" + x = self.norm(x) + return x + + +def conv3x3(in_planes, out_planes, stride=1): + """Convolution layer 3x3.""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + + +def conv1x1(in_planes, planes, stride=1): + """Convolution layer 1x1.""" + return nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): # pylint: disable=too-many-instance-attributes + """Basic Block for ResNet18.""" + + expansion = 1 + + def __init__( # pylint: disable=too-many-arguments + self, + inplanes, + planes, + stride=1, + scaler_rate=1, + downsample=None, + track=True, + scale=True, + ): + super().__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.scaler = Scaler(scaler_rate, scale) + self.bn1 = MyBatchNorm(planes, track) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = MyBatchNorm(planes, track) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + """BasicBlock forward.""" + residual = x + + output = self.conv1(x) + output = self.scaler(output) + output = self.bn1(output) + output = self.relu(output) + + output = self.conv2(output) + output = self.scaler(output) + output = self.bn2(output) + + if self.downsample is not None: + residual = self.downsample(x) + + output += residual + output = self.relu(output) + return output + + +class Resnet(nn.Module): # pylint: disable=too-many-instance-attributes + """Resnet model.""" + + def __init__( # pylint: disable=too-many-arguments + self, hidden_size, block, layers, num_classes, scaler_rate, track, scale + ): + super().__init__() + + self.inplanes = hidden_size[0] + self.norm_layer = MyBatchNorm + 
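+        # hidden_size is already scaled by this client's width ratio (see
+        # resnet18() below); the Scaler module divides activations by that ratio
+        # during local training, and track=False disables running statistics in
+        # MyBatchNorm ("static" batch normalization)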
self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + ) + self.scaler = Scaler(scaler_rate, scale) + self.bn1 = self.norm_layer(self.inplanes, track) + + self.relu = nn.ReLU(inplace=True) + # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer( + block, + hidden_size[0], + layers[0], + scaler_rate=scaler_rate, + track=track, + scale=scale, + ) + self.layer2 = self._make_layer( + block, + hidden_size[1], + layers[1], + stride=2, + scaler_rate=scaler_rate, + track=track, + scale=scale, + ) + self.layer3 = self._make_layer( + block, + hidden_size[2], + layers[2], + stride=2, + scaler_rate=scaler_rate, + track=track, + scale=scale, + ) + self.layer4 = self._make_layer( + block, + hidden_size[3], + layers[3], + stride=2, + scaler_rate=scaler_rate, + track=track, + scale=scale, + ) + self.fc_layer = nn.Linear(hidden_size[3] * block.expansion, num_classes) + self.scala = nn.AdaptiveAvgPool2d(1) + + for module in self.modules(): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_( + module.weight, mode="fan_out", nonlinearity="relu" + ) + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + def _make_layer( # pylint: disable=too-many-arguments + self, block, planes, layers, stride=1, scaler_rate=1, track=True, scale=True + ): + """Create a block with layers. + + Args: + block (class): block type + planes (int): output channels = planes * expansion + layers (int): layer num in the block + stride (int): the first layer stride in the block. + scaler_rate (float): for scaler module + track (bool): static batch normalization + scale (bool): for scaler module. + """ + norm_layer = self.norm_layer + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion, track), + ) + layer = [] + layer.append( + block( + self.inplanes, + planes, + stride=stride, + scaler_rate=scaler_rate, + downsample=downsample, + track=track, + scale=scale, + ) + ) + self.inplanes = planes * block.expansion + for _i in range(1, layers): + layer.append( + block( + self.inplanes, + planes, + scaler_rate=scaler_rate, + track=track, + scale=scale, + ) + ) + + return nn.Sequential(*layer) + + def forward(self, x): + """Resnet forward.""" + x = self.conv1(x) + x = self.scaler(x) + x = self.bn1(x) + x = self.relu(x) + # x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + out = self.scala(x).view(x.size(0), -1) + out = self.fc_layer(out) + + return [out] + + +def resnet18(n_blocks=4, track=False, scale=True, num_classes=100): + """Create resnet18 for HeteroFL. 
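+
+    Every layer width is scaled by n_blocks / 4, i.e. n_blocks = 1, 2, 3, 4 gives
+    width ratios 0.25, 0.50, 0.75 and 1.00 of the full [64, 128, 256, 512]
+    channel sizes (rounded up with np.ceil), so resnet18(n_blocks=2) builds a
+    half-width network with hidden sizes [32, 64, 128, 256].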
+ + Parameters + ---------- + n_blocks: int + corresponds to width (divided by 4) + track: bool + static batch normalization + scale: bool + scaler module + num_classes: int + # of labels + + Returns + ------- + Callable [ [List[int],nn.Module,List[int],int,float,bool,bool], nn.Module] + """ + # width pruning ratio : (0.25, 0.50, 0.75, 0.10) + model_rate = n_blocks / 4 + classes_size = num_classes + + hidden_size = [64, 128, 256, 512] + hidden_size = [int(np.ceil(model_rate * x)) for x in hidden_size] + + scaler_rate = model_rate + + return Resnet( + hidden_size, + BasicBlock, + [2, 2, 2, 2], + num_classes=classes_size, + scaler_rate=scaler_rate, + track=track, + scale=scale, + ) + + +# if __name__ == "__main__": +# from ptflops import get_model_complexity_info + +# model = resnet18(100, 1.0) + +# with torch.cuda.device(0): +# macs, params = get_model_complexity_info( +# model, +# (3, 32, 32), +# as_strings=True, +# print_per_layer_stat=False, +# verbose=True, +# units="MMac", +# ) + +# print("{:<30} {:<8}".format("Computational complexity: ", macs)) +# print("{:<30} {:<8}".format("Number of parameters: ", params)) diff --git a/baselines/depthfl/depthfl/server.py b/baselines/depthfl/depthfl/server.py new file mode 100644 index 000000000000..dc99ae2fc5de --- /dev/null +++ b/baselines/depthfl/depthfl/server.py @@ -0,0 +1,209 @@ +"""Server for DepthFL baseline.""" + +import copy +from collections import OrderedDict +from logging import DEBUG, INFO +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from flwr.common import FitRes, Parameters, Scalar, parameters_to_ndarrays +from flwr.common.logger import log +from flwr.common.typing import NDArrays +from flwr.server.client_proxy import ClientProxy +from flwr.server.server import Server, fit_clients +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from depthfl.client import prune +from depthfl.models import test, test_sbn +from depthfl.strategy import aggregate_fit_depthfl +from depthfl.strategy_hetero import aggregate_fit_hetero + +FitResultsAndFailures = Tuple[ + List[Tuple[ClientProxy, FitRes]], + List[Union[Tuple[ClientProxy, FitRes], BaseException]], +] + + +def gen_evaluate_fn( + testloader: DataLoader, + device: torch.device, + model: DictConfig, +) -> Callable[ + [int, NDArrays, Dict[str, Scalar]], + Tuple[float, Dict[str, Union[Scalar, List[float]]]], +]: + """Generate the function for centralized evaluation. + + Parameters + ---------- + testloader : DataLoader + The dataloader to test the model with. + device : torch.device + The device to test the model on. + model : DictConfig + model configuration for instantiating + + Returns + ------- + Callable[ [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]] ] + The centralized evaluation function. 
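+        The returned callable is typically registered as the strategy's
+        evaluate_fn, so the aggregated global model is evaluated on the full
+        CIFAR-100 test set after every round.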
+ """ + + def evaluate( + server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar] + ) -> Tuple[float, Dict[str, Union[Scalar, List[float]]]]: + # pylint: disable=unused-argument + """Use the entire CIFAR-100 test set for evaluation.""" + net = instantiate(model) + params_dict = zip(net.state_dict().keys(), parameters_ndarrays) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) + net.to(device) + + loss, accuracy, accuracy_single = test(net, testloader, device=device) + # return statistics + return loss, {"accuracy": accuracy, "accuracy_single": accuracy_single} + + return evaluate + + +def gen_evaluate_fn_hetero( + trainloaders: List[DataLoader], + testloader: DataLoader, + device: torch.device, + model_cfg: DictConfig, +) -> Callable[ + [int, NDArrays, Dict[str, Scalar]], + Tuple[float, Dict[str, Union[Scalar, List[float]]]], +]: + """Generate the function for centralized evaluation. + + Parameters + ---------- + trainloaders : List[DataLoader] + The list of dataloaders to calculate statistics for BN + testloader : DataLoader + The dataloader to test the model with. + device : torch.device + The device to test the model on. + model_cfg : DictConfig + model configuration for instantiating + + Returns + ------- + Callable[ [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]] ] + The centralized evaluation function. + """ + + def evaluate( # pylint: disable=too-many-locals + server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar] + ) -> Tuple[float, Dict[str, Union[Scalar, List[float]]]]: + # pylint: disable=unused-argument + """Use the entire CIFAR-100 test set for evaluation.""" + # test per 50 rounds (sbn takes a long time) + if server_round % 50 != 0: + return 0.0, {"accuracy": 0.0, "accuracy_single": [0] * 4} + + # models with different width + models = [] + for i in range(4): + model_tmp = copy.deepcopy(model_cfg) + model_tmp.n_blocks = i + 1 + models.append(model_tmp) + + # load global parameters + param_idx_lst = [] + nets = [] + net_tmp = instantiate(models[-1], track=False) + for model in models: + net = instantiate(model, track=True, scale=False) + nets.append(net) + param_idx = {} + for k in net_tmp.state_dict().keys(): + param_idx[k] = [ + torch.arange(size) for size in net.state_dict()[k].shape + ] + param_idx_lst.append(param_idx) + + params_dict = zip(net_tmp.state_dict().keys(), parameters_ndarrays) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + + for net, param_idx in zip(nets, param_idx_lst): + net.load_state_dict(prune(state_dict, param_idx), strict=False) + net.to(device) + net.train() + + loss, accuracy, accuracy_single = test_sbn( + nets, trainloaders, testloader, device=device + ) + # return statistics + return loss, {"accuracy": accuracy, "accuracy_single": accuracy_single} + + return evaluate + + +class ServerFedDyn(Server): + """Sever for FedDyn.""" + + def fit_round( + self, + server_round: int, + timeout: Optional[float], + ) -> Optional[ + Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures] + ]: + """Perform a single round.""" + # Get clients and their respective instructions from strategy + client_instructions = self.strategy.configure_fit( + server_round=server_round, + parameters=self.parameters, + client_manager=self._client_manager, + ) + + if not client_instructions: + log(INFO, "fit_round %s: no clients selected, cancel", server_round) + return None + log( + DEBUG, + 
"fit_round %s: strategy sampled %s clients (out of %s)", + server_round, + len(client_instructions), + self._client_manager.num_available(), + ) + + # Collect `fit` results from all clients participating in this round + results, failures = fit_clients( + client_instructions=client_instructions, + max_workers=self.max_workers, + timeout=timeout, + ) + log( + DEBUG, + "fit_round %s received %s results and %s failures", + server_round, + len(results), + len(failures), + ) + + if "HeteroFL" in str(type(self.strategy)): + aggregate_fit = aggregate_fit_hetero + else: + aggregate_fit = aggregate_fit_depthfl + + aggregated_result: Tuple[ + Optional[Parameters], + Dict[str, Scalar], + ] = aggregate_fit( + self.strategy, + server_round, + results, + failures, + parameters_to_ndarrays(self.parameters), + ) + + parameters_aggregated, metrics_aggregated = aggregated_result + return parameters_aggregated, metrics_aggregated, (results, failures) diff --git a/baselines/depthfl/depthfl/strategy.py b/baselines/depthfl/depthfl/strategy.py new file mode 100644 index 000000000000..3414c28c4518 --- /dev/null +++ b/baselines/depthfl/depthfl/strategy.py @@ -0,0 +1,136 @@ +"""Strategy for DepthFL.""" + +import os +import pickle +from logging import WARNING +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from flwr.common import ( + NDArrays, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.logger import log +from flwr.common.typing import FitRes +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy import FedAvg +from omegaconf import DictConfig + + +class FedDyn(FedAvg): + """Applying dynamic regularization in FedDyn paper.""" + + def __init__(self, cfg: DictConfig, net: nn.Module, *args, **kwargs): + self.cfg = cfg + self.h_variate = [np.zeros(v.shape) for (k, v) in net.state_dict().items()] + + # tagging real weights / biases + self.is_weight = [] + for k in net.state_dict().keys(): + if "weight" not in k and "bias" not in k: + self.is_weight.append(False) + else: + self.is_weight.append(True) + + # prev_grads file for each client + prev_grads = [ + {k: torch.zeros(v.numel()) for (k, v) in net.named_parameters()} + ] * cfg.num_clients + + if not os.path.exists("prev_grads"): + os.makedirs("prev_grads") + + for idx in range(cfg.num_clients): + with open(f"prev_grads/client_{idx}", "wb") as prev_grads_file: + pickle.dump(prev_grads[idx], prev_grads_file) + + super().__init__(*args, **kwargs) + + +def aggregate_fit_depthfl( + strategy, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + origin: NDArrays, +) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using weighted average.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not strategy.accept_failures and failures: + return None, {} + + # Convert results + weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for _, fit_res in results + ] + parameters_aggregated = ndarrays_to_parameters( + aggregate( + weights_results, + origin, + strategy.h_variate, + strategy.is_weight, + strategy.cfg, + ) + ) + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if strategy.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + 
metrics_aggregated = strategy.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") + + return parameters_aggregated, metrics_aggregated + + +def aggregate( + results: List[Tuple[NDArrays, int]], + origin: NDArrays, + h_list: List, + is_weight: List, + cfg: DictConfig, +) -> NDArrays: + """Aggregate model parameters with different depths.""" + param_count = [0] * len(origin) + weights_sum = [np.zeros(v.shape) for v in origin] + + # summation & counting of parameters + for parameters, _ in results: + for i, layer in enumerate(parameters): + weights_sum[i] += layer + param_count[i] += 1 + + # update parameters + for i, weight in enumerate(weights_sum): + if param_count[i] > 0: + weight = weight / param_count[i] + # print(np.isscalar(weight)) + + # update h variable for FedDyn + h_list[i] = ( + h_list[i] + - cfg.fit_config.alpha + * param_count[i] + * (weight - origin[i]) + / cfg.num_clients + ) + + # applying h only for weights / biases + if is_weight[i] and cfg.fit_config.feddyn: + weights_sum[i] = weight - h_list[i] / cfg.fit_config.alpha + else: + weights_sum[i] = weight + + else: + weights_sum[i] = origin[i] + + return weights_sum diff --git a/baselines/depthfl/depthfl/strategy_hetero.py b/baselines/depthfl/depthfl/strategy_hetero.py new file mode 100644 index 000000000000..7544204cde2f --- /dev/null +++ b/baselines/depthfl/depthfl/strategy_hetero.py @@ -0,0 +1,136 @@ +"""Strategy for HeteroFL.""" + +import os +import pickle +from logging import WARNING +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from flwr.common import ( + NDArrays, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.logger import log +from flwr.common.typing import FitRes +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy import FedAvg +from hydra.utils import instantiate +from omegaconf import DictConfig + + +class HeteroFL(FedAvg): + """Custom FedAvg for HeteroFL.""" + + def __init__(self, cfg: DictConfig, net: nn.Module, *args, **kwargs): + self.cfg = cfg + self.parameters = [np.zeros(v.shape) for (k, v) in net.state_dict().items()] + self.param_idx_lst = [] + + model = cfg.model + # store parameter shapes of different width + for i in range(4): + model.n_blocks = i + 1 + net_tmp = instantiate(model) + param_idx = [] + for k in net_tmp.state_dict().keys(): + param_idx.append( + [torch.arange(size) for size in net_tmp.state_dict()[k].shape] + ) + + # print(net_tmp.state_dict()['conv1.weight'].shape[0]) + self.param_idx_lst.append(param_idx) + + self.is_weight = [] + + # tagging real weights / biases + for k in net.state_dict().keys(): + if "num" in k: + self.is_weight.append(False) + else: + self.is_weight.append(True) + + # prev_grads file for each client + prev_grads = [ + {k: torch.zeros(v.numel()) for (k, v) in net.named_parameters()} + ] * cfg.num_clients + + if not os.path.exists("prev_grads"): + os.makedirs("prev_grads") + + for idx in range(cfg.num_clients): + with open(f"prev_grads/client_{idx}", "wb") as prev_grads_file: + pickle.dump(prev_grads[idx], prev_grads_file) + + super().__init__(*args, **kwargs) + + def aggregate_hetero( + self, results: List[Tuple[NDArrays, Union[bool, bytes, float, int, str]]] + ): + """Aggregate function for HeteroFL.""" + for i, params in enumerate(self.parameters): + count = np.zeros(params.shape) + tmp_v = np.zeros(params.shape) + if 
self.is_weight[i]: + for weights, cid in results: + if self.cfg.exclusive_learning: + cid = self.cfg.model_size * (self.cfg.num_clients // 4) - 1 + + tmp_v[ + torch.meshgrid( + self.param_idx_lst[cid // (self.cfg.num_clients // 4)][i] + ) + ] += weights[i] + count[ + torch.meshgrid( + self.param_idx_lst[cid // (self.cfg.num_clients // 4)][i] + ) + ] += 1 + tmp_v[count > 0] = np.divide(tmp_v[count > 0], count[count > 0]) + params[count > 0] = tmp_v[count > 0] + + else: + for weights, _ in results: + tmp_v += weights[i] + count += 1 + tmp_v = np.divide(tmp_v, count) + params = tmp_v + + +def aggregate_fit_hetero( + strategy, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + origin: NDArrays, +) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using weighted average.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not strategy.accept_failures and failures: + return None, {} + + # Convert results + weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.metrics["cid"]) + for _, fit_res in results + ] + + strategy.parameters = origin + strategy.aggregate_hetero(weights_results) + parameters_aggregated = ndarrays_to_parameters(strategy.parameters) + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if strategy.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = strategy.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") + + return parameters_aggregated, metrics_aggregated diff --git a/baselines/depthfl/depthfl/utils.py b/baselines/depthfl/depthfl/utils.py new file mode 100644 index 000000000000..fad2afcad4be --- /dev/null +++ b/baselines/depthfl/depthfl/utils.py @@ -0,0 +1,66 @@ +"""Contains utility functions for CNN FL on MNIST.""" + +import pickle +from pathlib import Path +from secrets import token_hex +from typing import Dict, Union + +from flwr.server.history import History + + +def save_results_as_pickle( + history: History, + file_path: Union[str, Path], + extra_results: Dict, + default_filename: str = "results.pkl", +) -> None: + """Save results from simulation to pickle. + + Parameters + ---------- + history: History + History returned by start_simulation. + file_path: Union[str, Path] + Path to file to create and store both history and extra_results. + If path is a directory, the default_filename will be used. + path doesn't exist, it will be created. If file exists, a + randomly generated suffix will be added to the file name. This + is done to avoid overwritting results. + extra_results : Dict + A dictionary containing additional results you would like + to be saved to disk. Default: {} (an empty dictionary) + default_filename: Optional[str] + File used by default if file_path points to a directory instead + to a file. Default: "results.pkl" + """ + path = Path(file_path) + + # ensure path exists + path.mkdir(exist_ok=True, parents=True) + + def _add_random_suffix(path_: Path): + """Add a randomly generated suffix to the file name.""" + print(f"File `{path_}` exists! 
") + suffix = token_hex(4) + print(f"New results to be saved with suffix: {suffix}") + return path_.parent / (path_.stem + "_" + suffix + ".pkl") + + def _complete_path_with_default_name(path_: Path): + """Append the default file name to the path.""" + print("Using default filename") + return path_ / default_filename + + if path.is_dir(): + path = _complete_path_with_default_name(path) + + if path.is_file(): + # file exists already + path = _add_random_suffix(path) + + print(f"Results will be saved into: {path}") + + data = {"history": history, **extra_results} + + # save results to pickle + with open(str(path), "wb") as handle: + pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/baselines/depthfl/pyproject.toml b/baselines/depthfl/pyproject.toml new file mode 100644 index 000000000000..2f928c2d3553 --- /dev/null +++ b/baselines/depthfl/pyproject.toml @@ -0,0 +1,141 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.masonry.api" + +[tool.poetry] +name = "depthfl" # <----- Ensure it matches the name of your baseline directory containing all the source code +version = "1.0.0" +description = "DepthFL: Depthwise Federated Learning for Heterogeneous Clients" +license = "Apache-2.0" +authors = ["Minjae Kim "] +readme = "README.md" +homepage = "https://flower.dev" +repository = "https://github.com/adap/flower" +documentation = "https://flower.dev" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[tool.poetry.dependencies] +python = ">=3.10.0, <3.11.0" +flwr = { extras = ["simulation"], version = "1.5.0" } +hydra-core = "1.3.2" # don't change this +matplotlib = "3.7.1" +torch = { url = "https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp310-cp310-linux_x86_64.whl"} +torchvision = { url = "https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp310-cp310-linux_x86_64.whl"} + + +[tool.poetry.dev-dependencies] +isort = "==5.11.5" +black = "==23.1.0" +docformatter = "==1.5.1" +mypy = "==1.4.1" +pylint = "==2.8.2" +flake8 = "==3.9.2" +pytest = "==6.2.4" +pytest-watch = "==4.2.0" +ruff = "==0.0.272" +types-requests = "==2.27.7" + +[tool.isort] +line_length = 88 +indent = " " +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311"] + +[tool.pytest.ini_options] +minversion = "6.2" +addopts = "-qq" +testpaths = [ + "flwr_baselines", +] + +[tool.mypy] +ignore_missing_imports = true +strict = false +plugins = "numpy.typing.mypy_plugin" + +[tool.pylint."MESSAGES CONTROL"] +disable = 
"bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y" +signature-mutators="hydra.main.main" + +[tool.pylint.typecheck] +generated-members="numpy.*, torch.*, tensorflow.*" + +[[tool.mypy.overrides]] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] +follow_imports = "skip" +follow_imports_for_stubs = true +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch.*" +follow_imports = "skip" +follow_imports_for_stubs = true + +[tool.docformatter] +wrap-summaries = 88 +wrap-descriptions = 88 + +[tool.ruff] +target-version = "py38" +line-length = 88 +select = ["D", "E", "F", "W", "B", "ISC", "C4"] +fixable = ["D", "E", "F", "W", "B", "ISC", "C4"] +ignore = ["B024", "B027"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "proto", +] + +[tool.ruff.pydocstyle] +convention = "numpy" diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index a85e40815a55..313ba97b75d8 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -36,6 +36,8 @@ - MOON [#2421](https://github.com/adap/flower/pull/2421) + - DepthFL [#2295](https://github.com/adap/flower/pull/2295) + - **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384),[#2425](https://github.com/adap/flower/pull/2425), [#2526](https://github.com/adap/flower/pull/2526)) - **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301), [#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327), [#2435](https://github.com/adap/flower/pull/2435))