diff --git a/LICENSE b/LICENSE
index 82e353b70..039d55505 100644
--- a/LICENSE
+++ b/LICENSE
@@ -335,24 +335,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
-------------------------------------------------------------------------------
-
-Code in federatedscope/core/configs/yacs_config.py, the basic code of yacs
-adopts Apache 2.0 License
-
-Copyright (c) 2018-present, Facebook, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
--------------------------------------------------------------------------------
@@ -430,234 +412,3 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
-
-Code in federatedscope/contrib/model/resnet.py is adapted from
-https://github.com/kuangliu/pytorch-cifar (MIT License)
-
-Copyright (c) 2017 liukuang
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
---------------------------------------------------------------------------------
-
-Code in federatedscope/attack/auxiliary/create_edgeset.py and poisoning_data.py
-is adapted from https://github.com/ksreenivasan/OOD_Federated_Learning
-(MIT License)
-
-Copyright (c) 2017 liukuang
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
---------------------------------------------------------------------------------
-
-The function partition_by_category and subgraphing in
-federatedscope/gfl/dataset/recsys.py
-are borrowed from https://github.com/FedML-AI/FedGraphNN
-
-Copyright [FedML] [Chaoyang He, Salman Avestimehr]
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-
---------------------------------------------------------------------------------
-
-The function calculate_time_cost in federatedscope/core/auxiliaries/utils.py
-is adopted from https://github.com/SymbioticLab/FedScale
-
-Copyright 2022 The FedScale Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-
---------------------------------------------------------------------------------
-
-The registration mechanism for federatedscope/contrib/ and the code structure
-of federatedscope/core/configs/ are adapted from GraphGym:
-https://github.com/snap-stanford/GraphGym (MIT License)
-
-Copyright (c) 2021 Jiaxuan You
-Copyright (c) 2021 Jiaxuan You, Matthias Fey
-Copyright (c) 2020 Jiaxuan You, Rex Ying, Jonathan Gomes Selman
-Copyright (c) Facebook, Inc. and its affiliates.
-Additional copyrights are specified in relevant subdirectories.
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
---------------------------------------------------------------------------------
-
-Code in federatedscope/nlp/metric/bleu/bleu.py, federatedscope/nlp/metric/bleu/bleu_scorer.py,
-federatedscope/nlp/metric/meteor/meteor.py, and federatedscope/nlp/metric/eval.py
-is adapted from https://github.com/tylin/coco-caption
-
-Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
-
---------------------------------------------------------------------------------
-
-Code in federatedscope/nlp/metric/eval.py is adapted from
-https://github.com/hugochan/RL-based-Graph2Seq-for-NQG
-
-Copyright 2020 Yu (Hugo) Chen
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-
---------------------------------------------------------------------------------
-
-Code in federatedscope/nlp/metric/rouge/pyrouge.py and federatedscope/nlp/metric/rouge/utils.py
-is adapted from https://github.com/nlpyang/PreSumm (MIT License)
-
-Copyright (c) 2019 Yang Liu
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
---------------------------------------------------------------------------------
-
-The implementation of ROUGE-155 in federatedscope/nlp/metric/rouge/pyrouge.py
-is adapted from https://github.com/bheinzerling/pyrouge (MIT License)
-
-Copyright (c) 2014 Benjamin Heinzerling
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
----------------------------------------------------------------------------------
-The implementations of median aggregator in federatedscope/core/aggregators/median_aggregator.py
-and trimmedmean aggregator in federatedscope/core/aggregators/trimmedmean_aggregator.py
-are adapted from https://github.com/bladesteam/blades (Apache License)
-
-Copyright (c) 2022 lishenghui
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
diff --git a/README.md b/README.md
index eee914d41..9df6c2ab0 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,233 @@
-README for Backdoor Benchmark
\ No newline at end of file
+![](https://img.shields.io/badge/language-python-blue.svg)
+![](https://img.shields.io/badge/license-Apache-000000.svg)
+[![Website](https://img.shields.io/badge/website-FederatedScope-0000FF)](https://federatedscope.io/)
+[![Playground](https://shields.io/badge/JupyterLab-Enjoy%20Your%20FL%20Journey!-F37626?logo=jupyter)](https://try.federatedscope.io/)
+[![Contributing](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://federatedscope.io/docs/contributor/)
+
+FederatedScope is a comprehensive federated learning platform that provides convenient usage and flexible customization for various federated learning tasks in both academia and industry. Based on an event-driven architecture, FederatedScope integrates rich collections of functionalities to satisfy the burgeoning demands from federated learning, and aims to build up an easy-to-use platform for promoting federated learning safely and effectively.
+
+A detailed tutorial is provided on our [website](https://federatedscope.io/).
+
+## News
+- [05-25-2022] Our paper [FederatedScope-GNN](https://arxiv.org/abs/2204.05562) has been accepted by KDD'2022!
+- [05-06-2022] We release FederatedScope v0.1.0!
+
+## Quick Start
+
+We provide an end-to-end example for users to start running a standard FL course with FederatedScope.
+
+### Step 1. Installation
+
+First of all, users need to clone the source code and install the required packages (we suggest python version >= 3.9).
+
+```bash
+git clone https://github.com/alibaba/FederatedScope.git
+cd FederatedScope
+```
+You can install the dependencies from the requirement file:
+```
+# For minimal version
+conda install --file enviroment/requirements-torch1.10.txt -c pytorch -c conda-forge -c nvidia
+
+# For application version
+conda install --file enviroment/requirements-torch1.10-application.txt -c pytorch -c conda-forge -c nvidia -c pyg
+```
+or build a docker image and run within the docker environment (cuda 11 and torch 1.10):
+```
+docker build -f enviroment/docker_files/federatedscope-torch1.10.Dockerfile -t alibaba/federatedscope:base-env-torch1.10 .
+docker run --gpus device=all --rm -it --name "fedscope" -w $(pwd) alibaba/federatedscope:base-env-torch1.10 /bin/bash
+```
+If you need to run down-stream tasks such as graph FL, change the requirement/docker file name to the application version when executing the above commands:
+```
+# enviroment/requirements-torch1.10.txt ->
+enviroment/requirements-torch1.10-application.txt
+
+# enviroment/docker_files/federatedscope-torch1.10.Dockerfile ->
+enviroment/docker_files/federatedscope-torch1.10-application.Dockerfile
+```
+Note: You can choose to use cuda 10 and torch 1.8 by changing `torch1.10` to `torch1.8`.
+The docker images are based on nvidia-docker. Please pre-install the NVIDIA drivers and `nvidia-docker2` on the host machine. See more details [here](https://github.com/alibaba/FederatedScope/tree/master/enviroment/docker_files).
+
+Finally, after all the dependencies are installed, run:
+```bash
+python setup.py install
+```
+
+### Step 2. Prepare datasets
+
+To run an FL task, users should prepare a dataset.
+The DataZoo provided in FederatedScope can help to automatically download and preprocess widely-used public datasets for various FL applications, including CV, NLP, graph learning, recommendation, etc. Users can directly specify `cfg.data.type = DATASET_NAME` in the configuration. For example,
+
+```python
+cfg.data.type = 'femnist'
+```
+
+To use customized datasets, you need to prepare the dataset following a certain format and register it, as sketched below. Please refer to [Customized Datasets](https://federatedscope.io/docs/own-case/#data) for more details.
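+
+As a minimal sketch, assuming the `register_data` helper exposed by `federatedscope.register` (the name `'mydata'` and the builder below are illustrative, not part of the library):
+
+```python
+from federatedscope.register import register_data
+
+
+def call_my_data(config):
+    # Only respond when our registered type is requested.
+    if config.data.type == 'mydata':
+        # Map each client ID to its isolated local data; the exact
+        # per-client format should follow the documentation above.
+        data = {client_id: dict() for client_id in range(1, 11)}
+        return data, config
+
+
+register_data('mydata', call_my_data)
+```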
+
+### Step 3. Prepare models
+
+Then, users should specify the model architecture that will be trained in the FL course.
+FederatedScope provides a ModelZoo that contains the implementation of widely adopted model architectures for various FL applications. Users can set up `cfg.model.type = MODEL_NAME` to apply a specific model architecture in FL tasks. For example,
+
+```python
+cfg.model.type = 'convnet2'
+```
+
+FederatedScope allows users to use customized models via registration, as sketched below. Please refer to [Customized Models](https://federatedscope.io/docs/own-case/#model) for more details about how to customize a model architecture.
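+
+A hedged sketch of such a registration, assuming `register_model` in `federatedscope.register` accepts a name and a builder with signature `(model_config, local_data)` (`MyNet`, `'mynet'`, and the config fields read below are illustrative):
+
+```python
+import torch
+
+from federatedscope.register import register_model
+
+
+class MyNet(torch.nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(MyNet, self).__init__()
+        self.fc = torch.nn.Linear(in_channels, out_channels)
+
+    def forward(self, x):
+        return self.fc(x)
+
+
+def call_my_net(model_config, local_data):
+    if model_config.type == 'mynet':
+        # Infer the input width from the local data and read the output
+        # width from the model config (illustrative fields).
+        return MyNet(local_data[0].shape[-1], model_config.out_channels)
+
+
+register_model('mynet', call_my_net)
+```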
+
+### Step 4. Start running an FL task
+
+Note that FederatedScope provides a unified interface for both standalone mode and distributed mode, and allows users to switch between them via configuration.
+
+#### Standalone mode
+
+The standalone mode in FederatedScope simulates multiple participants (servers and clients) on a single device, while participants' data are isolated from each other and their models might be shared via message passing.
+
+Here we demonstrate how to run a standard FL task with FederatedScope, setting `cfg.data.type = 'FEMNIST'` and `cfg.model.type = 'ConvNet2'` to run vanilla FedAvg for an image classification task. Users can customize training configurations, such as `cfg.federated.total_round_num`, `cfg.data.batch_size`, and `cfg.optimizer.lr`, in the configuration (a .yaml file), and run a standard FL task as:
+
+```bash
+# Run with default configurations
+python federatedscope/main.py --cfg federatedscope/example_configs/femnist.yaml
+# Or with custom configurations
+python federatedscope/main.py --cfg federatedscope/example_configs/femnist.yaml federated.total_round_num 50 data.batch_size 128
+```
+
+Then you can observe some monitored metrics during the training process as:
+
+```
+INFO: Server #0 has been set up ...
+INFO: Model meta-info: .
+... ...
+INFO: Client has been set up ...
+INFO: Model meta-info: .
+... ...
+INFO: {'Role': 'Client #5', 'Round': 0, 'Results_raw': {'train_loss': 207.6341676712036, 'train_acc': 0.02, 'train_total': 50, 'train_loss_regular': 0.0, 'train_avg_loss': 4.152683353424072}}
+INFO: {'Role': 'Client #1', 'Round': 0, 'Results_raw': {'train_loss': 209.0940284729004, 'train_acc': 0.02, 'train_total': 50, 'train_loss_regular': 0.0, 'train_avg_loss': 4.1818805694580075}}
+INFO: {'Role': 'Client #8', 'Round': 0, 'Results_raw': {'train_loss': 202.24929332733154, 'train_acc': 0.04, 'train_total': 50, 'train_loss_regular': 0.0, 'train_avg_loss': 4.0449858665466305}}
+INFO: {'Role': 'Client #6', 'Round': 0, 'Results_raw': {'train_loss': 209.43883895874023, 'train_acc': 0.06, 'train_total': 50, 'train_loss_regular': 0.0, 'train_avg_loss': 4.1887767791748045}}
+INFO: {'Role': 'Client #9', 'Round': 0, 'Results_raw': {'train_loss': 208.83140087127686, 'train_acc': 0.0, 'train_total': 50, 'train_loss_regular': 0.0, 'train_avg_loss': 4.1766280174255375}}
+INFO: ----------- Starting a new training round (Round #1) -------------
+... ...
+INFO: Server #0: Training is finished! Starting evaluation.
+INFO: Client #1: (Evaluation (test set) at Round #20) test_loss is 163.029045
+... ...
+INFO: Server #0: Final evaluation is finished! Starting merging results.
+... ...
+```
+
+#### Distributed mode
+
+The distributed mode in FederatedScope denotes running multiple procedures to build up an FL course, where each procedure acts as a participant (server or client) that instantiates its model and loads its data. The communication between participants is handled by the communication module of FederatedScope.
+
+To run with distributed mode, you only need to:
+
+- Prepare an isolated data file and set up `cfg.distribute.data_file = PATH/TO/DATA` for each participant;
+- Change `cfg.federate.mode = 'distributed'`, and specify the role of each participant by `cfg.distribute.role = 'server'/'client'`;
+- Set up a valid address by `cfg.distribute.host = x.x.x.x` and `cfg.distribute.port = xxxx`. (Note that for a server, you need to set up server_host/server_port for listening for messages, while for a client, you need to set up client_host/client_port for listening and server_host/server_port for sending join-in applications when building up an FL course.)
+
+We prepare a synthetic example for running with distributed mode:
+
+```bash
+# For server
+python main.py --cfg federatedscope/example_configs/distributed_server.yaml distribute.data_file 'PATH/TO/DATA' distribute.server_host x.x.x.x distribute.server_port xxxx
+
+# For clients
+python main.py --cfg federatedscope/example_configs/distributed_client_1.yaml distribute.data_file 'PATH/TO/DATA' distribute.server_host x.x.x.x distribute.server_port xxxx distribute.client_host x.x.x.x distribute.client_port xxxx
+python main.py --cfg federatedscope/example_configs/distributed_client_2.yaml distribute.data_file 'PATH/TO/DATA' distribute.server_host x.x.x.x distribute.server_port xxxx distribute.client_host x.x.x.x distribute.client_port xxxx
+python main.py --cfg federatedscope/example_configs/distributed_client_3.yaml distribute.data_file 'PATH/TO/DATA' distribute.server_host x.x.x.x distribute.server_port xxxx distribute.client_host x.x.x.x distribute.client_port xxxx
+```
+
+An executable example with generated toy data can be run with:
+```bash
+# Generate the toy data
+python scripts/gen_data.py
+
+# Firstly start the server that is waiting for clients to join in
+python federatedscope/main.py --cfg federatedscope/example_configs/distributed_server.yaml distribute.data_file toy_data/server_data distribute.server_host 127.0.0.1 distribute.server_port 50051
+
+# Start the client #1 (with another process)
+python federatedscope/main.py --cfg federatedscope/example_configs/distributed_client_1.yaml distribute.data_file toy_data/client_1_data distribute.server_host 127.0.0.1 distribute.server_port 50051 distribute.client_host 127.0.0.1 distribute.client_port 50052
+# Start the client #2 (with another process)
+python federatedscope/main.py --cfg federatedscope/example_configs/distributed_client_2.yaml distribute.data_file toy_data/client_2_data distribute.server_host 127.0.0.1 distribute.server_port 50051 distribute.client_host 127.0.0.1 distribute.client_port 50053
+# Start the client #3 (with another process)
+python federatedscope/main.py --cfg federatedscope/example_configs/distributed_client_3.yaml distribute.data_file toy_data/client_3_data distribute.server_host 127.0.0.1 distribute.server_port 50051 distribute.client_host 127.0.0.1 distribute.client_port 50054
+```
+
+And you can observe the results as (the IP addresses are anonymized with 'x.x.x.x'):
+
+```
+INFO: Server #0: Listen to x.x.x.x:xxxx...
+INFO: Server #0 has been set up ...
+Model meta-info: .
+... ...
+INFO: Client: Listen to x.x.x.x:xxxx...
+INFO: Client (address x.x.x.x:xxxx) has been set up ...
+Client (address x.x.x.x:xxxx) is assigned with #1.
+INFO: Model meta-info: .
+... ...
+{'Role': 'Client #2', 'Round': 0, 'Results_raw': {'train_avg_loss': 5.215108394622803, 'train_loss': 333.7669372558594, 'train_total': 64}}
+{'Role': 'Client #1', 'Round': 0, 'Results_raw': {'train_total': 64, 'train_loss': 290.9668884277344, 'train_avg_loss': 4.54635763168335}}
+----------- Starting a new training round (Round #1) -------------
+... ...
+INFO: Server #0: Training is finished! Starting evaluation.
+INFO: Client #1: (Evaluation (test set) at Round #20) test_loss is 30.387419
+... ...
+INFO: Server #0: Final evaluation is finished! Starting merging results.
+... ...
+```
+
+
+## Advanced
+
+As a comprehensive FL platform, FederatedScope provides the fundamental implementation to support requirements of various FL applications and frontier studies, towards both convenient usage and flexible extension, including:
+
+- **Personalized Federated Learning**: Client-specific model architectures and training configurations are applied to handle the non-IID issues caused by the diverse data distributions and heterogeneous system resources.
+- **Federated Hyperparameter Optimization**: When hyperparameter optimization (HPO) comes to Federated Learning, each attempt is extremely costly due to multiple rounds of communication across participants. It is worth noting that HPO under the FL setting is unique, and techniques such as low-fidelity HPO deserve further study.
+- **Privacy Attacker**: Privacy attack algorithms offer an important and convenient way to verify the privacy protection strength of the designed FL systems and algorithms, and this direction is growing along with Federated Learning.
+- **Graph Federated Learning**: Working on ubiquitous graph data, Graph Federated Learning aims to exploit isolated sub-graph data to learn a global model, and has attracted increasing attention.
+- **Recommendation**: As a number of laws and regulations go into effect all over the world, more and more people are aware of the importance of privacy protection, which urges recommender systems to learn from user data in a privacy-preserving manner.
+- **Differential Privacy**: Different from encryption algorithms that require a large amount of computation resources, differential privacy is an economical yet flexible technique to protect privacy, which has achieved great success in databases and is ever-growing in federated learning.
+- ...
+
+More features are coming soon! We have prepared a [tutorial](https://federatedscope.io/) to provide more details about how to utilize FederatedScope to enjoy your journey of Federated Learning!
+
+Materials of related topics are constantly being updated, please refer to [FL-Recommendation](https://github.com/alibaba/FederatedScope/tree/master/materials/paper_list/FL-Recommendation), [Federated-HPO](https://github.com/alibaba/FederatedScope/tree/master/materials/paper_list/Federated_HPO), [Personalized FL](https://github.com/alibaba/FederatedScope/tree/master/materials/paper_list/Personalized_FL), [Federated Graph Learning](https://github.com/alibaba/FederatedScope/tree/master/materials/paper_list/Federated_Graph_Learning), [FL-NLP](https://github.com/alibaba/FederatedScope/tree/master/materials/paper_list/FL-NLP), and so on.
+
+## Documentation
+
+The classes and methods of FederatedScope have been well documented so that users can generate the API references by:
+
+```shell
+pip install -r requirements-doc.txt
+make html
+```
+
+We put the API references on our [website](https://federatedscope.io/refs/index).
+
+## License
+
+FederatedScope is released under Apache License 2.0.
+
+## Publications
+If you find FederatedScope useful for your research or development, please cite the following paper:
+```
+@article{federatedscope,
+ title = {FederatedScope: A Flexible Federated Learning Platform for Heterogeneity},
+ author = {Xie, Yuexiang and Wang, Zhen and Chen, Daoyuan and Gao, Dawei and Yao, Liuyi and Kuang, Weirui and Li, Yaliang and Ding, Bolin and Zhou, Jingren},
+ journal={arXiv preprint arXiv:2204.05011},
+ year = {2022},
+}
+```
+More publications can be found in the [Publications](https://federatedscope.io/year-archive/).
+
+## Contributing
+
+We **greatly appreciate** any contribution to FederatedScope! You can refer to [Contributing to FederatedScope](https://federatedscope.io/docs/contributor/) for more details.
+
+Welcome to join our [Slack channel](https://federatedscopeteam.slack.com/archives/C03E5LGQH7S), or DingDing group (please scan the following QR code) for discussion.
+
+
diff --git a/benchmark/FedHPOB/README.md b/benchmark/FedHPOB/README.md
new file mode 100644
index 000000000..3cb2f941d
--- /dev/null
+++ b/benchmark/FedHPOB/README.md
@@ -0,0 +1,2 @@
+# FedHPOB
+FedHPOB is a library providing federated learning benchmarks for (multi-fidelity) hyperparameter optimization.
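+
+A minimal usage sketch, assuming the tabular records have been downloaded to `data/tabular_data/` and that the configuration/fidelity keys below match the loaded `configuration_space`/`fidelity_space` (check `get_meta_info()` for the authoritative spaces):
+
+```python
+from fedhpob.benchmarks import TabularBenchmark
+
+# Look up a recorded FL trial instead of re-running federated training.
+benchmark = TabularBenchmark('gcn', 'cora', 'avg')
+res = benchmark(configuration={'lr': 0.01, 'wd': 0.0, 'dropout': 0.0, 'step': 1},
+                fidelity={'sample_rate': 1.0, 'round': 99})
+print(res['function_value'], res['cost'])
+```
+
+Each call returns a dict with the looked-up `function_value` and an (estimated) time `cost`, so optimizers can simulate wall-clock budgets without actually training.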
diff --git a/benchmark/FedHPOB/data/surrogate_model/README.md b/benchmark/FedHPOB/data/surrogate_model/README.md
new file mode 100644
index 000000000..49a19beb9
--- /dev/null
+++ b/benchmark/FedHPOB/data/surrogate_model/README.md
@@ -0,0 +1 @@
+This is where the pickled surrogate model is stored.
\ No newline at end of file
diff --git a/benchmark/FedHPOB/data/tabular_data/README.md b/benchmark/FedHPOB/data/tabular_data/README.md
new file mode 100644
index 000000000..77d039da4
--- /dev/null
+++ b/benchmark/FedHPOB/data/tabular_data/README.md
@@ -0,0 +1 @@
+This is where the logs and dataframes are stored.
\ No newline at end of file
diff --git a/benchmark/FedHPOB/fedhpob/__init__.py b/benchmark/FedHPOB/fedhpob/__init__.py
new file mode 100644
index 000000000..c82f0a503
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/__init__.py
@@ -0,0 +1,19 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+__version__ = '0.0.1'
+
+
+def _setup_logger():
+ import logging
+
+ logging_fmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+ logger = logging.getLogger("fedhpob")
+ handler = logging.StreamHandler()
+ handler.setFormatter(logging.Formatter(logging_fmt))
+ logger.addHandler(handler)
+ logger.propagate = False
+
+
+_setup_logger()
diff --git a/benchmark/FedHPOB/fedhpob/benchmarks/__init__.py b/benchmark/FedHPOB/fedhpob/benchmarks/__init__.py
new file mode 100644
index 000000000..d193e7517
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/benchmarks/__init__.py
@@ -0,0 +1,5 @@
+from fedhpob.benchmarks.raw_benchmark import RawBenchmark
+from fedhpob.benchmarks.tabular_benchmark import TabularBenchmark
+from fedhpob.benchmarks.surrogate_benchmark import SurrogateBenchmark
+
+__all__ = ['RawBenchmark', 'TabularBenchmark', 'SurrogateBenchmark']
diff --git a/benchmark/FedHPOB/fedhpob/benchmarks/base_benchmark.py b/benchmark/FedHPOB/fedhpob/benchmarks/base_benchmark.py
new file mode 100644
index 000000000..13fadc153
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/benchmarks/base_benchmark.py
@@ -0,0 +1,76 @@
+import abc
+import os
+import numpy as np
+from federatedscope.core.configs.config import global_cfg
+from federatedscope.core.auxiliaries.data_builder import get_data
+from fedhpob.utils.util import disable_fs_logger
+from fedhpob.utils.cost_model import get_cost_model
+
+
+class BaseBenchmark(abc.ABC):
+ def __init__(self, model, dname, algo, rng=None, **kwargs):
+ """
+
+ :param rng:
+ :param kwargs:
+ """
+ if rng is not None:
+ self.rng = rng
+ else:
+ self.rng = np.random.RandomState()
+ self.configuration_space = self.get_configuration_space()
+ self.fidelity_space = self.get_fidelity_space()
+
+ # Load data and modify cfg of FS.
+ self.cfg = global_cfg.clone()
+ filepath = os.path.join('scripts', model, f'{dname}.yaml')
+ self.cfg.merge_from_file(filepath)
+ self.cfg.data.type = dname
+ self.data, modified_cfg = get_data(config=self.cfg.clone())
+ self.cfg.merge_from_other_cfg(modified_cfg)
+ disable_fs_logger(self.cfg, True)
+
+ def __call__(self, configuration, fidelity, seed=1, **kwargs):
+ return self.objective_function(configuration=configuration,
+ fidelity=fidelity,
+ seed=seed,
+ **kwargs)
+
+ def _check(self, configuration, fidelity):
+ pass
+
+ def _cost(self, configuration, fidelity, **kwargs):
+ cost_model = get_cost_model(mode=self.cost_mode)
+ t = cost_model(self.cfg, configuration, fidelity, self.data, **kwargs)
+ return t
+
+ def _init_fidelity(self, fidelity):
+ if not fidelity:
+ fidelity = {
+ 'sample_client': 1.0,
+ 'round': self.get_fidelity_space()['round'][-1] //
+ self.eval_freq
+ }
+ elif 'round' not in fidelity:
+ fidelity['round'] = self.get_fidelity_space(
+ )['round'][-1] // self.eval_freq
+ return fidelity
+
+ @abc.abstractmethod
+ def objective_function(self, configuration, fidelity, seed):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_configuration_space(self):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_fidelity_space(self):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_meta_info(self):
+ raise NotImplementedError()
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}({self.get_meta_info()})'
diff --git a/benchmark/FedHPOB/fedhpob/benchmarks/raw_benchmark.py b/benchmark/FedHPOB/fedhpob/benchmarks/raw_benchmark.py
new file mode 100644
index 000000000..d346adb17
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/benchmarks/raw_benchmark.py
@@ -0,0 +1,74 @@
+import datetime
+from federatedscope.core.auxiliaries.utils import setup_seed
+from federatedscope.core.auxiliaries.worker_builder import get_client_cls, get_server_cls
+from federatedscope.core.fed_runner import FedRunner
+
+from fedhpob.benchmarks.base_benchmark import BaseBenchmark
+from fedhpob.utils.util import disable_fs_logger
+from fedhpob.utils.cost_model import merge_cfg
+
+
+class RawBenchmark(BaseBenchmark):
+ def __init__(self,
+ model,
+ dname,
+ algo,
+ rng=None,
+ cost_mode='estimated',
+ **kwargs):
+ self.model, self.dname, self.algo, self.cost_mode = model, dname, algo, cost_mode
+ self.device = kwargs['device']
+ super(RawBenchmark, self).__init__(model, dname, algo, rng, **kwargs)
+
+ def _run_fl(self, configuration, fidelity, key='val_avg_loss', seed=1):
+ init_cfg = self.cfg.clone()
+ disable_fs_logger(init_cfg, True)
+ setup_seed(seed)
+ modified_cfg = merge_cfg(init_cfg, configuration, fidelity)
+ init_cfg.merge_from_other_cfg(modified_cfg)
+ init_cfg.device = self.device
+ init_cfg.freeze()
+ runner = FedRunner(data=self.data,
+ server_class=get_server_cls(init_cfg),
+ client_class=get_client_cls(init_cfg),
+ config=init_cfg.clone())
+ results = runner.run()
+        # Defrost so that cfg can be modified in the next trial.
+ init_cfg.defrost()
+ if 'server_global_eval' in results:
+ return [results['server_global_eval'][key]]
+ else:
+ return [results['client_summarized_weighted_avg'][key]]
+
+ def objective_function(self,
+ configuration,
+ fidelity=None,
+ key='val_avg_loss',
+ seed=1,
+ **kwargs):
+ fidelity = self._init_fidelity(fidelity)
+ self._check(configuration, fidelity)
+ start_time = datetime.datetime.now()
+ function_value = self._run_fl(configuration, fidelity, key, seed)[0]
+ end_time = datetime.datetime.now()
+        cost = self._cost(configuration, fidelity, **kwargs)
+        if not cost:
+            # TODO: use time from FS monitor
+            cost = end_time - start_time
+
+ return {'function_value': function_value, 'cost': cost}
+
+ def get_configuration_space(self):
+ return []
+
+ def get_fidelity_space(self):
+ return []
+
+ def get_meta_info(self):
+ return {
+ 'model': self.model,
+ 'dname': self.dname,
+ 'configuration_space': self.configuration_space,
+ 'fidelity_space': self.fidelity_space
+ }
diff --git a/benchmark/FedHPOB/fedhpob/benchmarks/surrogate_benchmark.py b/benchmark/FedHPOB/fedhpob/benchmarks/surrogate_benchmark.py
new file mode 100644
index 000000000..887cdf7a7
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/benchmarks/surrogate_benchmark.py
@@ -0,0 +1,68 @@
+from fedhpob.benchmarks.base_benchmark import BaseBenchmark
+from fedhpob.utils.surrogate_dataloader import build_surrogate_model, load_surrogate_model
+
+
+class SurrogateBenchmark(BaseBenchmark):
+ def __init__(self,
+ model,
+ dname,
+ algo,
+ modeldir=None,
+ datadir='data/tabular_data/',
+ rng=None,
+ cost_mode='estimated',
+ **kwargs):
+ self.model, self.dname, self.algo, self.cost_mode = model, dname, algo, cost_mode
+ assert datadir or modeldir, 'Please provide at least one of `datadir` and `modeldir`.'
+ if not modeldir:
+ self.surrogate_models, self.meta_info, self.X, self.Y = build_surrogate_model(
+ datadir, model, dname, algo)
+ else:
+ self.surrogate_models, self.meta_info, self.X, self.Y = load_surrogate_model(
+ modeldir, model, dname, algo)
+ super(SurrogateBenchmark, self).__init__(model, dname, algo, rng,
+ **kwargs)
+
+ def _check(self, configuration, fidelity):
+ for key in configuration:
+ assert key in self.configuration_space, 'configuration invalid, check `configuration_space` for help.'
+ for key in fidelity:
+ assert key in self.fidelity_space, 'fidelity invalid, check `fidelity_space` for help.'
+
+ def _make_prediction(self, configuration, fidelity, seed):
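+        # Select one of the fitted surrogate models via a seeded random draw.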
+ model = self.surrogate_models[self.rng.randint(seed) %
+ len(self.surrogate_models)]
+ x_in = []
+ for key in self.configuration_space:
+ x_in.append(configuration[key])
+ for key in self.fidelity_space:
+ x_in.append(fidelity[key])
+ return model.predict([x_in])
+
+ # noinspection DuplicatedCode
+ def objective_function(self,
+ configuration,
+ fidelity=None,
+ seed=1,
+ **kwargs):
+ fidelity = self._init_fidelity(fidelity)
+ self._check(configuration, fidelity)
+ return {
+ 'function_value': self._make_prediction(configuration, fidelity,
+ seed),
+ 'cost': self._cost(configuration, fidelity, **kwargs)
+ }
+
+ def get_configuration_space(self):
+ return self.meta_info['configuration_space']
+
+ def get_fidelity_space(self):
+ return self.meta_info['fidelity_space']
+
+ def get_meta_info(self):
+ return {
+ 'model': self.model,
+ 'dname': self.dname,
+ 'configuration_space': self.configuration_space,
+ 'fidelity_space': self.fidelity_space
+ }
diff --git a/benchmark/FedHPOB/fedhpob/benchmarks/tabular_benchmark.py b/benchmark/FedHPOB/fedhpob/benchmarks/tabular_benchmark.py
new file mode 100644
index 000000000..d3303e1bd
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/benchmarks/tabular_benchmark.py
@@ -0,0 +1,91 @@
+import datetime
+import numpy as np
+from fedhpob.utils.tabular_dataloader import load_data
+
+from fedhpob.benchmarks.base_benchmark import BaseBenchmark
+
+
+class TabularBenchmark(BaseBenchmark):
+ def __init__(self,
+ model,
+ dname,
+ algo,
+ datadir='data/tabular_data/',
+ rng=None,
+ cost_mode='estimated',
+ **kwargs):
+ self.model, self.dname, self.algo, self.cost_mode = model, dname, algo, cost_mode
+ self.table, self.meta_info = load_data(datadir, model, dname, algo)
+ self.eval_freq = self.meta_info['eval_freq']
+ super(TabularBenchmark, self).__init__(model, dname, algo, rng,
+ **kwargs)
+
+ def _check(self, configuration, fidelity):
+ for key, value in configuration.items():
+ assert value in self.configuration_space[
+ key], 'configuration invalid, check `configuration_space` for help.'
+ for key, value in fidelity.items():
+ assert value in self.fidelity_space[
+ key], 'fidelity invalid, check `fidelity_space` for help.'
+
+ def _search(self, configuration, fidelity):
+ # For configuration
+ mask = np.array([True] * self.table.shape[0])
+ for col in configuration.keys():
+ mask *= (self.table[col].values == configuration[col])
+ idx = np.where(mask)
+ result = self.table.iloc[idx]
+
+ # For fidelity
+ mask = np.array([True] * result.shape[0])
+ for col in fidelity.keys():
+ if col == 'round':
+ continue
+ mask *= (result[col].values == fidelity[col])
+ idx = np.where(mask)
+ result = result.iloc[idx]["result"]
+ return result
+
+ def objective_function(self,
+ configuration,
+ fidelity,
+ key='val_acc',
+ seed=1,
+ **kwargs):
+ fidelity = self._init_fidelity(fidelity)
+ self._check(configuration, fidelity)
+ result = self._search(
+ {
+ 'seed': self.rng.randint(seed) %
+ len(self.configuration_space['seed']) + 1,
+ **configuration
+ }, fidelity)
+ index = list(result.keys())
+ assert len(index) == 1, 'Multiple results.'
+        filtered_result = eval(result[index[0]])
+        assert key in filtered_result.keys(
+        ), f'`key` should be in {filtered_result.keys()}.'
+        # Find the best val round.
+        val_loss = filtered_result['val_avg_loss']
+        best_round = np.argmin(val_loss[:fidelity['round'] + 1])
+        function_value = filtered_result[key][best_round]
+        cost = self._cost(configuration, fidelity, **kwargs)
+        if not cost:
+            cost = filtered_result['tol_time']
+
+ return {'function_value': function_value, 'cost': cost}
+
+ def get_configuration_space(self):
+ return self.meta_info['configuration_space']
+
+ def get_fidelity_space(self):
+ return self.meta_info['fidelity_space']
+
+ def get_meta_info(self):
+ return {
+ 'model': self.model,
+ 'dname': self.dname,
+ 'configuration_space': self.configuration_space,
+ 'fidelity_space': self.fidelity_space
+ }
diff --git a/benchmark/FedHPOB/fedhpob/config.py b/benchmark/FedHPOB/fedhpob/config.py
new file mode 100644
index 000000000..abaa62ace
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/config.py
@@ -0,0 +1,416 @@
+import ConfigSpace as CS
+from yacs.config import CfgNode as CN
+from fedhpob.benchmarks import TabularBenchmark
+from fedhpob.benchmarks import RawBenchmark
+from fedhpob.benchmarks import SurrogateBenchmark
+
+fhb_cfg = CN()
+
+
+def get_cs(dname, model, mode, alg='avg'):
+ # raw and surrogate are ONLY FOR NIPS2022
+ configuration_space = CS.ConfigurationSpace()
+ fidelity_space = CS.ConfigurationSpace()
+ # configuration_space
+ if dname in ['cora', 'citeseer', 'pubmed']:
+ # GNN tabular, raw and surrogate
+ fidelity_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('round',
+ choices=[x for x in range(500)]))
+ fidelity_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('sample_rate',
+ choices=[0.2, 0.4, 0.6, 0.8, 1.0]))
+ if mode == 'tabular':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('lr',
+ choices=[
+ 0.01, 0.01668, 0.02783,
+ 0.04642, 0.07743, 0.12915,
+ 0.21544, 0.35938, 0.59948, 1.0
+ ]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('wd',
+ choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('dropout', choices=[0.0, 0.5]))
+ if alg == 'avg':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'step', choices=[1, 2, 3, 4, 5, 6, 7, 8]))
+ else:
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('lrserver',
+ choices=[0.1, 0.5, 1.0]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('momentumsserver',
+ choices=[0.0, 0.9]))
+ elif mode in ['surrogate', 'raw']:
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lr',
+ lower=1e-2,
+ upper=1.0,
+ log=True))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('wd',
+ choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('dropout', lower=.0, upper=.5))
+ if alg == 'avg':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'step', choices=[1, 2, 3, 4, 5, 6, 7, 8]))
+ else:
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1]))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lrserver',
+ lower=1e-1,
+ upper=1.0,
+ log=True))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('momentumsserver',
+ lower=0.0,
+ upper=1.0))
+
+ elif dname in [
+ '10101', '53', '146818', '146821', '9952', '146822', '31', '3917'
+ ]:
+ # Openml tabular, raw and surrogate
+ fidelity_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('round',
+ choices=[x for x in range(250)]))
+ fidelity_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('sample_rate',
+ choices=[0.2, 0.4, 0.6, 0.8, 1.0]))
+ if model == 'lr':
+ if mode == 'tabular':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'lr', choices=[0.00001, 0.0001, 0.001, 0.01, 0.1,
+ 1.0]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'wd', choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'batch', choices=[8, 16, 32, 64, 128, 256]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('dropout', choices=[0.0]))
+ if alg == 'avg':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step',
+ choices=[1, 2, 3, 4]))
+ else:
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('lrserver',
+ choices=[0.1, 0.5, 1.0]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('momentumsserver',
+ choices=[0.0, 0.9]))
+ elif mode in ['surrogate', 'raw']:
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lr',
+ lower=1e-5,
+ upper=1.0,
+ log=True))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'wd', choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('dropout', choices=[0.0]))
+ if alg == 'avg':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'step', choices=[1, 2, 3, 4, 5, 6, 7, 8]))
+ else:
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1]))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lrserver',
+ lower=1e-1,
+ upper=1.0,
+ log=True))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('momentumsserver',
+ lower=0.0,
+ upper=1.0))
+ elif model == 'mlp':
+ if mode == 'tabular':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'lr', choices=[0.00001, 0.0001, 0.001, 0.01, 0.1,
+ 1.0]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'wd', choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('batch',
+ choices=[32, 64, 128, 256]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('dropout', choices=[0.0]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('layer', choices=[2, 3, 4]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('hidden',
+ choices=[16, 64, 256]))
+ if alg == 'avg':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step',
+ choices=[1, 2, 3, 4]))
+ else:
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('lrserver',
+ choices=[0.1, 0.5, 1.0]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('momentumsserver',
+ choices=[0.0, 0.9]))
+ elif mode in ['surrogate', 'raw']:
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lr',
+ lower=1e-5,
+ upper=1.0,
+ log=True))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'wd', choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('dropout', choices=[0.0]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('layer', choices=[2, 3, 4]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('hidden',
+ choices=[16, 64, 256]))
+ if alg == 'avg':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step',
+ choices=[1, 2, 3, 4]))
+ else:
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1]))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lrserver',
+ lower=1e-1,
+ upper=1.0,
+ log=True))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('momentumsserver',
+ lower=0.0,
+ upper=1.0))
+ elif dname in ['femnist', 'cifar10']:
+ # CNN tabular and surrogate
+ fidelity_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('round',
+ choices=[x for x in range(250)]))
+ fidelity_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('sample_rate',
+ choices=[0.2, 0.4, 0.6, 0.8, 1.0]))
+ if mode == 'tabular':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('lr',
+ choices=[
+ 0.01, 0.01668, 0.02783,
+ 0.04642, 0.07743, 0.12915,
+ 0.21544, 0.35938, 0.59948, 1.0
+ ]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('wd',
+ choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('dropout', choices=[0.0, 0.5]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('batch', choices=[16, 32, 64]))
+ if alg == 'avg':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1, 2, 3, 4]))
+ elif mode in ['surrogate', 'raw']:
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lr',
+ lower=1e-2,
+ upper=1.0,
+ log=True))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('wd',
+ choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('dropout', lower=.0, upper=.5))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('batch', choices=[16, 32, 64]))
+            if alg == 'avg':
+                configuration_space.add_hyperparameter(
+                    CS.CategoricalHyperparameter('step', choices=[1, 2, 3, 4]))
+ elif dname in ['sst2', 'cola']:
+ # Transformer tabular and surrogate
+ fidelity_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('round',
+ choices=[x for x in range(40)]))
+ fidelity_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('sample_rate',
+ choices=[0.2, 0.4, 0.6, 0.8, 1.0]))
+ if mode == 'tabular':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('lr',
+ choices=[
+ 0.01, 0.01668, 0.02783,
+ 0.04642, 0.07743, 0.12915,
+ 0.21544, 0.35938, 0.59948, 1.0
+ ]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('wd',
+ choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('dropout', choices=[0.0, 0.5]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('batch',
+ choices=[8, 16, 32, 64, 128]))
+ if alg == 'avg':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1, 2, 3, 4]))
+ else:
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('lrserver',
+ choices=[0.1, 0.5, 1.0]))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('momentumsserver',
+ choices=[0.0, 0.9]))
+ elif mode in ['surrogate', 'raw']:
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lr',
+ lower=1e-2,
+ upper=1.0,
+ log=True))
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('wd',
+ choices=[0.0, 0.001, 0.01, 0.1]))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('dropout', lower=.0, upper=.5))
+ if alg == 'avg':
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter(
+ 'step', choices=[1, 2, 3, 4, 5, 6, 7, 8]))
+ else:
+ configuration_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('step', choices=[1]))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lrserver',
+ lower=1e-1,
+ upper=1.0,
+ log=True))
+ configuration_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('momentumsserver',
+ lower=0.0,
+ upper=1.0))
+ return configuration_space, fidelity_space
+
+
+def initial_cfg(cfg):
+ # ------------------------------------------------------------------------ #
+ # benchmark related options
+ # ------------------------------------------------------------------------ #
+ cfg.benchmark = CN()
+ cfg.benchmark.cls = [{
+ 'raw': RawBenchmark,
+ 'tabular': TabularBenchmark,
+ 'surrogate': SurrogateBenchmark
+ }]
+
+ # ************************************************************************ #
+ cfg.benchmark.type = 'raw'
+ cfg.benchmark.model = 'gcn'
+ cfg.benchmark.data = 'cora'
+ cfg.benchmark.device = 0
+ cfg.benchmark.sample_client = 1.0 # only for optimizer
+ cfg.benchmark.algo = 'avg' # ['avg', 'opt']
+ cfg.benchmark.out_dir = 'exp_results'
+ # ************************************************************************ #
+
+ # ------------------------------------------------------------------------ #
+ # cost related options
+ # ------------------------------------------------------------------------ #
+ cfg.cost = CN()
+ cfg.cost.type = 'estimated' # in ['raw', 'estimated']
+ cfg.cost.c = 1 # lambda for exponential distribution, time consumed in client
+ cfg.cost.time_server = 0 # time consumed in server, `0` for real time
+
+ # bandwidth for estimated cost
+ cfg.cost.bandwidth = CN()
+ cfg.cost.bandwidth.client_up = 0.25 * 1024 * 1024 * 8 / 32
+ cfg.cost.bandwidth.client_down = 0.75 * 1024 * 1024 * 8 / 32
+ cfg.cost.bandwidth.server_up = 0.25 * 1024 * 1024 * 8 / 32
+ cfg.cost.bandwidth.server_down = 0.75 * 1024 * 1024 * 8 / 32
+
+ # ------------------------------------------------------------------------ #
+ # optimizer related options
+ # ------------------------------------------------------------------------ #
+ cfg.optimizer = CN()
+ cfg.optimizer.type = 'de'
+ cfg.optimizer.min_budget = 1
+ cfg.optimizer.max_budget = 243
+ cfg.optimizer.n_iterations = 100000000 # No limits
+ cfg.optimizer.seed = 1
+ cfg.optimizer.limit_time = 86400 # one day
+
+ # ------------------------------------------------------------------------ #
+ # hpbandster related options (rs, bo_kde, hb, bohb)
+ # ------------------------------------------------------------------------ #
+ cfg.optimizer.hpbandster = CN()
+ cfg.optimizer.hpbandster.eta = 3
+ cfg.optimizer.hpbandster.max_stages = 5
+
+ # ------------------------------------------------------------------------ #
+ # smac related options (bo_gp, bo_rf)
+ # ------------------------------------------------------------------------ #
+ cfg.optimizer.smac = CN()
+
+ # ------------------------------------------------------------------------ #
+ # dehb related options (dehb, de)
+ # ------------------------------------------------------------------------ #
+ cfg.optimizer.dehb = CN()
+ cfg.optimizer.dehb.strategy = 'rand1_bin'
+ cfg.optimizer.dehb.mutation_factor = 0.5
+ cfg.optimizer.dehb.crossover_prob = 0.5
+
+ # dehb.dehb
+ cfg.optimizer.dehb.dehb = CN()
+ cfg.optimizer.dehb.dehb.gens = 1
+ cfg.optimizer.dehb.dehb.eta = 3
+ cfg.optimizer.dehb.dehb.async_strategy = 'immediate'
+
+ # dehb.de
+ cfg.optimizer.dehb.de = CN()
+ cfg.optimizer.dehb.de.pop_size = 20
+
+ # ------------------------------------------------------------------------ #
+ # optuna related options (tpe_md, tpe_hb)
+ # ------------------------------------------------------------------------ #
+ cfg.optimizer.optuna = CN()
+ cfg.optimizer.optuna.reduction_factor = 3
+
+
+def add_configs(cfg):
+ # ------------------------------------------------------------------------ #
+    # HPO search space related options, which are fixed when mode is `raw`
+ # ------------------------------------------------------------------------ #
+ configuration_space, fidelity_space = get_cs(cfg.benchmark.data,
+ cfg.benchmark.model,
+ cfg.benchmark.type,
+ cfg.benchmark.algo)
+
+ cfg.benchmark.configuration_space = [configuration_space
+ ] # avoid invalid type
+ cfg.benchmark.fidelity_space = [fidelity_space] # avoid invalid type
+
+
+initial_cfg(fhb_cfg)
diff --git a/benchmark/FedHPOB/fedhpob/optimizers/__init__.py b/benchmark/FedHPOB/fedhpob/optimizers/__init__.py
new file mode 100644
index 000000000..663d31b55
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/optimizers/__init__.py
@@ -0,0 +1,6 @@
+from fedhpob.optimizers.dehb_optimizer import run_dehb
+from fedhpob.optimizers.hpbandster_optimizer import run_hpbandster
+from fedhpob.optimizers.optuna_optimizer import run_optuna
+from fedhpob.optimizers.smac_optimizer import run_smac
+
+__all__ = ['run_dehb', 'run_hpbandster', 'run_optuna', 'run_smac']
\ No newline at end of file
diff --git a/benchmark/FedHPOB/fedhpob/optimizers/dehb_optimizer.py b/benchmark/FedHPOB/fedhpob/optimizers/dehb_optimizer.py
new file mode 100644
index 000000000..67805e5b7
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/optimizers/dehb_optimizer.py
@@ -0,0 +1,100 @@
+"""
+https://github.com/automl/DEHB/blob/master/examples/00_interfacing_DEHB.ipynb
+How to use the DEHB Optimizer
+1) Download the Source Code
+git clone https://github.com/automl/DEHB.git
+# We are currently using the first version of it.
+cd DEHB
+git checkout b8dcba7b38bf6e7fc8ce3e84ea567b66132e0eb5
+2) Add the project to your Python Path
+export PYTHONPATH=~/DEHB:$PYTHONPATH
+3) Requirements
+- dask distributed:
+```
+conda install dask distributed -c conda-forge
+```
+OR
+```
+python -m pip install dask distributed --upgrade
+```
+- Other things to install:
+```
+pip install numpy ConfigSpace
+```
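+4) Minimal usage sketch (toy objective and search space, not from this repo;
+assumes the pinned commit above):
+```
+import ConfigSpace as CS
+from dehb.optimizers import DEHB
+
+cs = CS.ConfigurationSpace()
+cs.add_hyperparameter(
+    CS.UniformFloatHyperparameter('lr', 1e-4, 1e-1, log=True))
+
+def objective(config, budget=None):
+    # toy fitness: return (fitness, cost), as DE/DEHB expect
+    return config['lr'], budget
+
+dehb = DEHB(cs=cs,
+            dimensions=len(cs.get_hyperparameters()),
+            f=objective,
+            min_budget=1,
+            max_budget=243,
+            eta=3,
+            n_workers=1)
+traj, runtime, history = dehb.run(iterations=5, verbose=False)
+```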
+"""
+
+import time
+import random
+import logging
+from dehb.optimizers import DE, DEHB
+from fedhpob.config import fhb_cfg
+from fedhpob.utils.monitor import Monitor
+
+logging.basicConfig(level=logging.WARNING)
+
+
+def run_dehb(cfg):
+ def objective(config, budget=None):
+ if cfg.optimizer.type == 'de':
+ budget = cfg.optimizer.max_budget
+ main_fidelity = {
+ 'round': int(budget),
+ 'sample_client': cfg.benchmark.sample_client
+ }
+ t_start = time.time()
+ res = benchmark(config,
+ main_fidelity,
+ seed=random.randint(1, 99),
+ key='val_avg_loss',
+ fhb_cfg=cfg)
+ monitor(res=res, sim_time=time.time() - t_start, budget=budget)
+ fitness, cost = res['function_value'], res['cost']
+ return fitness, cost
+
+ monitor = Monitor(cfg)
+ benchmark = cfg.benchmark.cls[0][cfg.benchmark.type](
+ cfg.benchmark.model,
+ cfg.benchmark.data,
+ cfg.benchmark.algo,
+ device=cfg.benchmark.device)
+ if cfg.optimizer.type == 'de':
+ optimizer = DE(
+ cs=cfg.benchmark.configuration_space[0],
+ dimensions=len(
+ cfg.benchmark.configuration_space[0].get_hyperparameters()),
+ f=objective,
+ pop_size=cfg.optimizer.dehb.de.pop_size,
+ mutation_factor=cfg.optimizer.dehb.mutation_factor,
+ crossover_prob=cfg.optimizer.dehb.crossover_prob,
+ strategy=cfg.optimizer.dehb.strategy)
+ traj, runtime, history = optimizer.run(
+ generations=cfg.optimizer.n_iterations, verbose=False)
+ elif cfg.optimizer.type == 'dehb':
+ optimizer = DEHB(
+ cs=cfg.benchmark.configuration_space[0],
+ dimensions=len(
+ cfg.benchmark.configuration_space[0].get_hyperparameters()),
+ f=objective,
+ strategy=cfg.optimizer.dehb.strategy,
+ mutation_factor=cfg.optimizer.dehb.mutation_factor,
+ crossover_prob=cfg.optimizer.dehb.crossover_prob,
+ eta=cfg.optimizer.dehb.dehb.eta,
+ min_budget=cfg.optimizer.min_budget,
+ max_budget=cfg.optimizer.max_budget,
+ generations=cfg.optimizer.dehb.dehb.gens,
+ n_workers=1)
+ traj, runtime, history = optimizer.run(
+ iterations=cfg.optimizer.n_iterations, verbose=False)
+ else:
+ raise NotImplementedError
+
+ return monitor.history_results
+
+
+if __name__ == "__main__":
+    # Please specify args for the experiment.
+ results = []
+ for opt_name in ['de', 'dehb']:
+ fhb_cfg.optimizer.type = opt_name
+ results.append(run_dehb(fhb_cfg))
+ print(results)
diff --git a/benchmark/FedHPOB/fedhpob/optimizers/hpbandster_optimizer.py b/benchmark/FedHPOB/fedhpob/optimizers/hpbandster_optimizer.py
new file mode 100644
index 000000000..30d2f9fd2
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/optimizers/hpbandster_optimizer.py
@@ -0,0 +1,120 @@
+# Implement RS, BO_KDE, HB, BOHB in `hpbandster`.
+
+import time
+import random
+import logging
+import hpbandster.core.nameserver as hpns
+from hpbandster.core.worker import Worker
+from hpbandster.optimizers import BOHB, HyperBand, RandomSearch
+
+from fedhpob.config import fhb_cfg
+from fedhpob.utils.monitor import Monitor
+
+logging.basicConfig(level=logging.WARNING)
+
+
+class MyWorker(Worker):
+ def __init__(self,
+ benchmark,
+ monitor,
+ sleep_interval=0,
+ cfg=None,
+ **kwargs):
+ super(MyWorker, self).__init__(**kwargs)
+ self.benchmark = benchmark
+ self.monitor = monitor
+ self.sleep_interval = sleep_interval
+ self.cfg = cfg
+
+ def compute(self, config, budget, **kwargs):
+ """
+ Simple example for a compute function
+ The loss is just a the config + some noise (that decreases with the budget)
+ For dramatization, the function can sleep for a given interval to emphasizes
+ the speed ups achievable with parallel workers.
+ Args:
+ config: dictionary containing the sampled configurations by the optimizer
+ budget: (float) amount of time/epochs/etc. the model can use to train
+ Returns:
+ dictionary with mandatory fields:
+ 'loss' (scalar)
+ 'info' (dict)
+ """
+ main_fidelity = {
+ 'round': int(budget),
+ 'sample_client': self.cfg.benchmark.sample_client
+ }
+ t_start = time.time()
+ res = self.benchmark(config,
+ main_fidelity,
+ seed=random.randint(1, 99),
+ key='val_avg_loss',
+ fhb_cfg=self.cfg)
+ time.sleep(self.sleep_interval)
+ self.monitor(res=res, sim_time=time.time() - t_start, budget=budget)
+        return ({
+            'loss': float(res['function_value']), # mandatory field to run hyperband
+            'info': res # mandatory; can hold any user-defined information
+        })
+
+
+def run_hpbandster(cfg):
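+    # 'bo_kde' is plain BO with a KDE model: forcing min_budget == max_budget
+    # makes the BOHB optimizer below evaluate every trial at full fidelity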
+ if cfg.optimizer.type == 'bo_kde':
+ cfg.optimizer.min_budget = cfg.optimizer.max_budget
+ monitor = Monitor(cfg)
+ NS = hpns.NameServer(run_id=cfg.optimizer.type, host='127.0.0.1')
+ NS.start()
+ cfg = cfg.clone()
+ benchmark = cfg.benchmark.cls[0][cfg.benchmark.type](
+ cfg.benchmark.model,
+ cfg.benchmark.data,
+ cfg.benchmark.algo,
+ device=cfg.benchmark.device)
+ w = MyWorker(benchmark=benchmark,
+ monitor=monitor,
+ sleep_interval=0,
+ cfg=cfg,
+ nameserver='127.0.0.1',
+ run_id=cfg.optimizer.type)
+ w.run(background=True)
+
+    # Allow at most max_stages stages: raise min_budget to
+    # max_budget / eta^max_stages when it lies below that value
+    tmp = cfg.optimizer.max_budget
+    for i in range(cfg.optimizer.hpbandster.max_stages):
+        tmp /= cfg.optimizer.hpbandster.eta
+    if tmp > cfg.optimizer.min_budget:
+        cfg.optimizer.min_budget = tmp
+
+ opt_kwargs = {
+ 'configspace': cfg.benchmark.configuration_space[0],
+ 'run_id': cfg.optimizer.type,
+ 'nameserver': '127.0.0.1',
+ 'eta': cfg.optimizer.hpbandster.eta,
+ 'min_budget': cfg.optimizer.min_budget,
+ 'max_budget': cfg.optimizer.max_budget
+ }
+ if cfg.optimizer.type == 'rs':
+ optimizer = RandomSearch(**opt_kwargs)
+    elif cfg.optimizer.type in ('bo_kde', 'bohb'):
+        # bo_kde degenerates to plain BO because min_budget == max_budget
+        optimizer = BOHB(**opt_kwargs)
+    elif cfg.optimizer.type == 'hb':
+        optimizer = HyperBand(**opt_kwargs)
+ else:
+ raise NotImplementedError
+ res = optimizer.run(n_iterations=cfg.optimizer.n_iterations)
+
+ optimizer.shutdown(shutdown_workers=True)
+ NS.shutdown()
+ all_runs = res.get_all_runs()
+ return [x.info for x in all_runs]
+
+
+if __name__ == "__main__":
+ results = []
+ for opt_name in ['rs', 'bo_kde', 'hb', 'bohb']:
+ fhb_cfg.optimizer.type = opt_name
+ results.append(run_hpbandster(fhb_cfg))
+ print(results)
diff --git a/benchmark/FedHPOB/fedhpob/optimizers/optuna_optimizer.py b/benchmark/FedHPOB/fedhpob/optimizers/optuna_optimizer.py
new file mode 100644
index 000000000..aa8159e16
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/optimizers/optuna_optimizer.py
@@ -0,0 +1,142 @@
+# Implement TPE_MD, TPE_HB in `optuna`. Adapted from
+# https://raw.githubusercontent.com/automl/HPOBenchExperimentUtils/master/HPOBenchExperimentUtils/optimizer/optuna_optimizer.py
+
+import ConfigSpace as CS
+import numpy as np
+import time
+import random
+import optuna
+import logging
+from functools import partial
+from optuna.pruners import HyperbandPruner, MedianPruner
+from optuna.samplers import TPESampler
+from optuna.trial import Trial
+
+from fedhpob.config import fhb_cfg
+from fedhpob.utils.monitor import Monitor
+
+logging.basicConfig(level=logging.WARNING)
+
+
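+# The two helpers below reproduce hpbandster's Successive Halving schedule:
+# valid budgets are max_budget / eta**k for k = max_SH_iter - 1, ..., 0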
+def precompute_sh_iters(min_budget, max_budget, eta):
+ max_SH_iter = -int(np.log(min_budget / max_budget) / np.log(eta)) + 1
+ return max_SH_iter
+
+
+def precompute_budgets(max_budget, eta, max_SH_iter):
+ s0 = -np.linspace(start=max_SH_iter - 1, stop=0, num=max_SH_iter)
+ budgets = max_budget * np.power(eta, s0)
+ return budgets
+
+
+def sample_config_from_optuna(trial: Trial, cs: CS.ConfigurationSpace):
+ config = {}
+ for hp_name in cs:
+ hp = cs.get_hyperparameter(hp_name)
+
+ if isinstance(hp, CS.UniformFloatHyperparameter):
+ value = float(
+ trial.suggest_float(name=hp_name,
+ low=hp.lower,
+ high=hp.upper,
+ log=hp.log))
+
+ elif isinstance(hp, CS.UniformIntegerHyperparameter):
+ value = int(
+ trial.suggest_int(name=hp_name,
+ low=hp.lower,
+ high=hp.upper,
+ log=hp.log))
+
+ elif isinstance(hp, CS.CategoricalHyperparameter):
+ hp_type = type(hp.default_value)
+ value = hp_type(
+ trial.suggest_categorical(name=hp_name, choices=hp.choices))
+
+ elif isinstance(hp, CS.OrdinalHyperparameter):
+ num_vars = len(hp.sequence)
+ index = trial.suggest_int(hp_name,
+ low=0,
+ high=num_vars - 1,
+ log=False)
+ hp_type = type(hp.default_value)
+ value = hp.sequence[index]
+ value = hp_type(value)
+
+ else:
+ raise ValueError(
+ f'Please implement the support for hps of type {type(hp)}')
+
+ config[hp.name] = value
+ return config
+
+
+def run_optuna(cfg):
+ def objective(trial, benchmark, valid_budgets, configspace):
+ config = sample_config_from_optuna(trial, configspace)
+ res = None
+ for budget in valid_budgets:
+ main_fidelity = {
+ 'round': int(budget),
+ 'sample_client': cfg.benchmark.sample_client
+ }
+ t_start = time.time()
+ res = benchmark(config,
+ main_fidelity,
+ seed=random.randint(1, 99),
+ key='val_avg_loss',
+ fhb_cfg=cfg)
+ monitor(res=res, sim_time=time.time() - t_start, budget=budget)
+ trial.report(res['function_value'], step=budget)
+
+ if trial.should_prune():
+ raise optuna.TrialPruned()
+
+ assert res is not None
+ return res['function_value']
+
+ monitor = Monitor(cfg)
+ benchmark = cfg.benchmark.cls[0][cfg.benchmark.type](
+ cfg.benchmark.model,
+ cfg.benchmark.data,
+ cfg.benchmark.algo,
+ device=cfg.benchmark.device)
+ sampler = TPESampler(seed=cfg.optimizer.seed)
+ study = optuna.create_study(direction='minimize', sampler=sampler)
+ if cfg.optimizer.type == 'tpe_md':
+ pruner = MedianPruner()
+ sh_iters = precompute_sh_iters(cfg.optimizer.min_budget,
+ cfg.optimizer.max_budget,
+ cfg.optimizer.optuna.reduction_factor)
+ valid_budgets = precompute_budgets(
+ cfg.optimizer.max_budget, cfg.optimizer.optuna.reduction_factor,
+ sh_iters)
+ elif cfg.optimizer.type == 'tpe_hb':
+ pruner = HyperbandPruner(
+ min_resource=cfg.optimizer.min_budget,
+ max_resource=cfg.optimizer.max_budget,
+ reduction_factor=cfg.optimizer.optuna.reduction_factor)
+ pruner._try_initialization(study=None)
+ valid_budgets = [
+ cfg.optimizer.min_budget * cfg.optimizer.optuna.reduction_factor**i
+ for i in range(pruner._n_brackets)
+ ]
+ else:
+ raise NotImplementedError
+
+ study.optimize(func=partial(
+ objective,
+ benchmark=benchmark,
+ valid_budgets=valid_budgets,
+ configspace=cfg.benchmark.configuration_space[0]),
+ timeout=None,
+ n_trials=cfg.optimizer.n_iterations)
+ return monitor.history_results
+
+
+if __name__ == "__main__":
+ results = []
+ for opt_name in ['tpe_md', 'tpe_hb']:
+ fhb_cfg.optimizer.type = opt_name
+ results.append(run_optuna(fhb_cfg))
+ print(results)
diff --git a/benchmark/FedHPOB/fedhpob/optimizers/smac_optimizer.py b/benchmark/FedHPOB/fedhpob/optimizers/smac_optimizer.py
new file mode 100644
index 000000000..171bebdc9
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/optimizers/smac_optimizer.py
@@ -0,0 +1,72 @@
+# Implement BO_GP, BO_RF in `smac`.
+
+import time
+import random
+import logging
+from smac.facade.smac_bb_facade import SMAC4BB
+from smac.facade.smac_hpo_facade import SMAC4HPO
+from smac.scenario.scenario import Scenario
+
+from fedhpob.config import fhb_cfg
+from fedhpob.utils.monitor import Monitor
+
+logging.basicConfig(level=logging.WARNING)
+
+
+def run_smac(cfg):
+ def optimization_function_wrapper(config):
+ """ Helper-function: simple wrapper to use the benchmark with smac"""
+ budget = int(cfg.optimizer.max_budget)
+ main_fidelity = {
+ 'round': budget,
+ 'sample_client': cfg.benchmark.sample_client
+ }
+ t_start = time.time()
+ res = benchmark(config,
+ main_fidelity,
+ seed=random.randint(1, 99),
+ key='val_avg_loss',
+ fhb_cfg=cfg)
+ monitor(res=res, sim_time=time.time() - t_start, budget=budget)
+ return res['function_value']
+
+ monitor = Monitor(cfg)
+ benchmark = cfg.benchmark.cls[0][cfg.benchmark.type](
+ cfg.benchmark.model,
+ cfg.benchmark.data,
+ cfg.benchmark.algo,
+ device=cfg.benchmark.device)
+
+ scenario = Scenario({
+ "run_obj": "quality", # Optimize quality (alternatively runtime)
+ "runcount-limit": cfg.optimizer.
+ n_iterations, # Max number of function evaluations
+ "cs": cfg.benchmark.configuration_space[0],
+ "output_dir": cfg.benchmark.type,
+ "deterministic": "true",
+ "limit_resources": False
+ })
+ if cfg.optimizer.type == 'bo_gp':
+ smac = SMAC4BB(model_type='gp',
+ scenario=scenario,
+ tae_runner=optimization_function_wrapper)
+ elif cfg.optimizer.type == 'bo_rf':
+ smac = SMAC4HPO(scenario=scenario,
+ tae_runner=optimization_function_wrapper)
+ else:
+ raise NotImplementedError
+
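+    # Following the SMAC examples: retrieve the incumbent from the solver in
+    # a finally-block so an interrupted run still yields a configuration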
+    try:
+        smac.optimize()
+    finally:
+        incumbent = smac.solver.incumbent
+
+ return monitor.history_results
+
+
+if __name__ == "__main__":
+ results = []
+ for opt_name in ['bo_gp', 'bo_rf']:
+ fhb_cfg.optimizer.type = opt_name
+ results.append(run_smac(fhb_cfg))
+ print(results)
diff --git a/benchmark/FedHPOB/fedhpob/utils/__init__.py b/benchmark/FedHPOB/fedhpob/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/benchmark/FedHPOB/fedhpob/utils/cost_model.py b/benchmark/FedHPOB/fedhpob/utils/cost_model.py
new file mode 100644
index 000000000..3b380bd64
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/utils/cost_model.py
@@ -0,0 +1,88 @@
+from federatedscope.core.auxiliaries.model_builder import get_model
+
+
+def merge_cfg(cfg, configuration, fidelity):
+ init_cfg = cfg.clone()
+ # Configuration related
+ if 'lr' in configuration:
+ init_cfg.optimizer.lr = configuration['lr']
+ if 'wd' in configuration:
+ init_cfg.optimizer.weight_decay = configuration['wd']
+ if 'dropout' in configuration:
+ init_cfg.model.dropout = configuration['dropout']
+ if 'batch' in configuration:
+ init_cfg.data.batch_size = configuration['batch']
+ if 'layer' in configuration:
+ init_cfg.model.layer = configuration['layer']
+ if 'hidden' in configuration:
+ init_cfg.model.hidden = configuration['hidden']
+ if 'step' in configuration:
+ init_cfg.federate.local_update_steps = int(configuration['step'])
+ # FedOPT related
+ if 'momentumsserver' in configuration:
+ init_cfg.fedopt.momentum_server = configuration['momentumsserver']
+ if 'lrserver' in configuration:
+ init_cfg.fedopt.lr_server = configuration['lrserver']
+ # Fidelity related
+ if 'sample_client' in fidelity:
+ init_cfg.federate.sample_client_rate = fidelity['sample_client']
+ if 'round' in fidelity:
+ init_cfg.federate.total_round_num = fidelity['round']
+ return init_cfg
+
+
+def get_cost_model(mode='estimated'):
+ r"""
+    This function returns a cost model function.
+
+    :param mode: name of the cost model ('raw' or 'estimated').
+    :return: the function of the cost model
+ """
+ cost_dict = {
+ 'raw': raw_cost,
+ 'estimated': estimated_cost,
+ }
+ return cost_dict[mode]
+
+
+def communication_cost(cfg, model_size, fhb_cfg):
+ t_up = model_size / fhb_cfg.cost.bandwidth.client_up
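+    # The download is bounded by the slower of the server broadcasting to all
+    # sampled clients and a single client's own download bandwidth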
+ t_down = max(
+ cfg.federate.client_num * cfg.federate.sample_client_rate *
+ model_size / fhb_cfg.cost.bandwidth.server_up,
+ model_size / fhb_cfg.cost.bandwidth.client_down)
+ return t_up + t_down
+
+
+def computation_cost(cfg, fhb_cfg):
+ """
+    Assume each client's computation time is exponentially distributed with
+    rate c; return the expected maximum over the M sampled clients plus the
+    server time.
+ """
+ t_client = sum([
+ 1.0 / i for i in range(
+ 1,
+ int(cfg.federate.client_num * cfg.federate.sample_client_rate) + 1)
+ ]) / fhb_cfg.cost.c
+ return t_client + fhb_cfg.cost.time_server
+
+
+def raw_cost(**kwargs):
+ return None
+
+
+def estimated_cost(cfg, configuration, fidelity, data, **kwargs):
+ """
+    Works in raw, tabular and surrogate modes.
+ """
+ def get_info(cfg, configuration, fidelity, data):
+ cfg = merge_cfg(cfg, configuration, fidelity)
+ model = get_model(cfg.model, list(data.values())[0])
+ model_size = sum([param.nelement() for param in model.parameters()])
+ return cfg, model_size
+
+ cfg, num_param = get_info(cfg, configuration, fidelity, data)
+ t_comm = communication_cost(cfg, num_param, kwargs['fhb_cfg'])
+ t_comp = computation_cost(cfg, kwargs['fhb_cfg'])
+ t_round = t_comm + t_comp
+ return t_round * cfg.federate.total_round_num
diff --git a/benchmark/FedHPOB/fedhpob/utils/draw.py b/benchmark/FedHPOB/fedhpob/utils/draw.py
new file mode 100644
index 000000000..feed440ae
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/utils/draw.py
@@ -0,0 +1,91 @@
+import os
+import json
+import numpy as np
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+
+FONTSIZE = 40
+MARKSIZE = 25
+
+
+def logloader(file):
+    log = []
+    with open(file) as f:
+        lines = f.readlines()
+    for line in lines:
+        log.append(json.loads(s=line))
+    return log
+
+
+def ecdf(model, data_list, sample_client=None, key='test_acc'):
+    from fedhpob.benchmarks import TabularBenchmark
+
+ # Draw ECDF from target data_list
+ plt.figure(figsize=(10, 7.5))
+ plt.xticks(fontsize=FONTSIZE)
+ plt.yticks(fontsize=FONTSIZE)
+
+ plt.xlabel('Normalized regret', size=FONTSIZE)
+ plt.ylabel('P(X <= x)', size=FONTSIZE)
+
+ # Get target data (tabular only)
+ for data in data_list:
+ benchmark = TabularBenchmark(model, data, device=-1)
+ target = [0] # Init with zero
+ for idx in tqdm(range(len(benchmark.table))):
+ row = benchmark.table.iloc[idx]
+ if sample_client is not None and row[
+ 'sample_client'] != sample_client:
+ continue
+ result = eval(row['result'])
+ val_loss = result['val_avg_loss']
+ best_round = np.argmin(val_loss)
+ target.append(result[key][best_round])
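+        # Normalized regret: 1 - metric / best over all configurations; the
+        # ECDF shows the fraction of configurations at or below each regret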
+ norm_regret = np.sort(1 - (np.array(target) / np.max(target)))
+ y = np.arange(len(norm_regret)) / float(len(norm_regret) - 1)
+ plt.plot(norm_regret, y)
+ plt.legend(data_list, fontsize=23, loc='lower right')
+ plt.savefig(f'{model}_{sample_client}_cdf.pdf', bbox_inches='tight')
+ plt.close()
+
+ return target
+
+
+def rank_over_time(root):
+ # Please place these logs to one dir
+ target_opt = [
+ 'rs', 'bo_gp', 'bo_rf', 'bo_kde', 'de', 'hb', 'bohb', 'dehb', 'tpe_md',
+ 'tpe_hb'
+ ]
+ files = os.listdir(root)
+ logs = []
+ for opt in target_opt:
+ for file in files:
+ if file.startswith(opt):
+                logs.append(logloader(os.path.join(root, file)))
+ break
+
+ # Draw over time
+ plt.figure(figsize=(10, 7.5))
+ plt.xticks(fontsize=FONTSIZE)
+ plt.yticks(fontsize=FONTSIZE)
+
+ plt.xlabel('Fraction of budget', size=FONTSIZE)
+ plt.ylabel('Mean rank', size=FONTSIZE)
+
+ for data in logs:
+ tol_time = data[-1]['Consumed']
+ frac_budget = np.array([i['Consumed'] / tol_time for i in data])
+ # TODO: sort by rank
+ loss = np.array([i['best_value'] for i in data])
+ plt.plot(frac_budget, loss, linewidth=1, markersize=MARKSIZE)
+ plt.legend(target_opt, fontsize=23, loc='lower right')
+ plt.savefig(f'{root}_rank_over_time.pdf', bbox_inches='tight')
+ # plt.show()
+ plt.close()
+
+
+if __name__ == '__main__':
+ ecdf('gcn', ['cora', 'citeseer', 'pubmed'], sample_client=None)
diff --git a/benchmark/FedHPOB/fedhpob/utils/monitor.py b/benchmark/FedHPOB/fedhpob/utils/monitor.py
new file mode 100644
index 000000000..26539cbf0
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/utils/monitor.py
@@ -0,0 +1,53 @@
+import os
+import time
+import json
+import logging
+
+import numpy as np
+
+from fedhpob.utils.util import cfg2name
+
+logging.basicConfig(level=logging.WARNING)
+
+
+class Monitor(object):
+ def __init__(self, cfg):
+ self.limit_time = cfg.optimizer.limit_time
+ self.last_timestamp = time.time()
+ self.best_value = np.inf
+ self.consumed_time, self.budget, self.cnt = 0, 0, 0
+ self.logs = []
+ self.cfg = cfg
+
+ def __call__(self, res, sim_time=0, *args, **kwargs):
+ self._check_and_log(res['cost'])
+        # subtract wall-clock time spent in simulation, add the estimated cost
+ self.consumed_time += (time.time() - self.last_timestamp - sim_time +
+ res['cost'])
+ self.cnt += 1
+ if res['function_value'] < self.best_value or kwargs[
+ 'budget'] > self.budget:
+ self.budget = kwargs['budget']
+ self.best_value = res['function_value']
+ self.logs.append({
+ 'Try': self.cnt,
+ "Consumed": self.consumed_time,
+ 'best_value': self.best_value,
+ 'cur_results': res
+ })
+ logging.warning(
+ f'Try: {self.cnt}, Consumed: {self.consumed_time}, best_value: {self.best_value}, cur_results: {res}'
+ )
+ self.last_timestamp = time.time()
+
+ def _check_and_log(self, cost):
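+        # Once the simulated time budget would be exceeded, flush the logs to
+        # disk and force-exit, since the third-party optimizers expose no
+        # clean early-stop hook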
+ if self.consumed_time + cost > self.limit_time:
+ # TODO: record time and cost
+ logging.warning(
+                f'Time budget exhausted, no time left for the next trial (cost: {cost})!'
+ )
+ out_file = cfg2name(self.cfg)
+ with open(out_file, 'w') as f:
+ for line in self.logs:
+ f.write(json.dumps(line) + "\n")
+ os._exit(1)
\ No newline at end of file
diff --git a/benchmark/FedHPOB/fedhpob/utils/runner.py b/benchmark/FedHPOB/fedhpob/utils/runner.py
new file mode 100644
index 000000000..6ed62ab01
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/utils/runner.py
@@ -0,0 +1,30 @@
+from federatedscope.core.cmd_args import parse_args
+from fedhpob.config import fhb_cfg, add_configs
+from fedhpob.optimizers import run_dehb, run_hpbandster, run_optuna, run_smac
+
+
+def run(cfg):
+ if cfg.optimizer.type in ['de', 'dehb']:
+ results = run_dehb(cfg)
+ elif cfg.optimizer.type in ['rs', 'bo_kde', 'hb', 'bohb']:
+ results = run_hpbandster(cfg)
+ elif cfg.optimizer.type in ['tpe_md', 'tpe_hb']:
+ results = run_optuna(cfg)
+ elif cfg.optimizer.type in ['bo_gp', 'bo_rf']:
+ results = run_smac(cfg)
+ else:
+ raise NotImplementedError
+ return results
+
+
+def main():
+ init_cfg = fhb_cfg.clone()
+ args = parse_args()
+ init_cfg.merge_from_file(args.cfg_file)
+ init_cfg.merge_from_list(args.opts)
+ add_configs(init_cfg)
+ run(cfg=init_cfg)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/benchmark/FedHPOB/fedhpob/utils/surrogate_dataloader.py b/benchmark/FedHPOB/fedhpob/utils/surrogate_dataloader.py
new file mode 100644
index 000000000..be8eff614
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/utils/surrogate_dataloader.py
@@ -0,0 +1,131 @@
+import datetime
+import numpy as np
+import os
+import pickle
+
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.model_selection import cross_validate as sk_cross_validate
+from tqdm import tqdm
+
+from fedhpob.utils.tabular_dataloader import load_data
+
+
+def sampling(X, Y, over_rate=1, down_rate=1.0, cvg_score=0.5):
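+    # Over-sample rows whose score exceeds cvg_score and randomly down-sample
+    # the rest, to balance the training data for the surrogate model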
+ rel_score = Y
+ over_X = np.repeat(X[rel_score > cvg_score], over_rate, axis=0)
+ over_Y = np.repeat(Y[rel_score > cvg_score], over_rate, axis=0)
+
+ mask = np.random.choice(X[rel_score <= cvg_score].shape[0],
+ size=int(X[rel_score <= cvg_score].shape[0] *
+ down_rate),
+ replace=False)
+ down_X = np.array(X[rel_score <= cvg_score])[mask]
+ down_Y = np.array(Y[rel_score <= cvg_score])[mask]
+ return np.concatenate([over_X, down_X],
+ axis=0), np.concatenate([over_Y, down_Y], axis=0)
+
+
+def load_surrogate_model(modeldir, model, dname, algo):
+ model_list = []
+ path = os.path.join(modeldir, model, dname, algo)
+ file_names = os.listdir(path)
+ for fname in file_names:
+ if not fname.startswith('surrogate_model'):
+ continue
+ with open(os.path.join(path, fname), 'rb') as f:
+            model_state = f.read()
+        surrogate = pickle.loads(model_state)
+        model_list.append(surrogate)
+
+ infofile = os.path.join(path, 'info.pkl')
+ with open(infofile, 'rb') as f:
+ info = pickle.loads(f.read())
+
+ # TODO: remove X and Y
+ X = np.load(os.path.join(path, 'X.npy'))
+ Y = np.load(os.path.join(path, 'Y.npy'))
+
+ return model_list, info, X, Y
+
+
+def build_surrogate_model(datadir, model, dname, algo, key='val_acc'):
+ r"""
+ from TabularBenchmark to SurrogateBenchmark data format
+ """
+ table, meta_info = load_data(datadir, model, dname, algo)
+ savedir = os.path.join('data/surrogate_model', model, dname, algo)
+ os.makedirs(savedir, exist_ok=True)
+ # Build data to train the surrogate_model
+ X, Y = [], []
+ fidelity_space = sorted(['sample_client', 'round'])
+ configuration_space = sorted(
+ list(set(table.keys()) - {'result', 'seed'} - set(fidelity_space)))
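+    # Each row of X holds the sorted configuration values followed by
+    # [sample_client, round]; Y holds the chosen metric at the best
+    # validation round observed so far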
+
+    if not os.path.exists(os.path.join(savedir, 'X.npy')) or \
+            not os.path.exists(os.path.join(savedir, 'Y.npy')):
+ print('Building data mat...')
+ for idx in tqdm(range(len(table))):
+ row = table.iloc[idx]
+ x = [row[col]
+ for col in configuration_space] + [row['sample_client']]
+ result = eval(row['result'])
+ val_loss = result['val_avg_loss']
+ for rnd in range(len(val_loss)):
+ X.append(x + [rnd * meta_info['eval_freq']])
+ best_round = np.argmin(val_loss[:rnd + 1])
+ Y.append(result[key][best_round])
+ X, Y = np.array(X), np.array(Y)
+ np.save(os.path.join(savedir, 'X.npy'), X)
+ np.save(os.path.join(savedir, 'Y.npy'), Y)
+ else:
+ print('Loading cache...')
+ X = np.load(os.path.join(savedir, 'X.npy'))
+ Y = np.load(os.path.join(savedir, 'Y.npy'))
+
+ new_X, new_Y = sampling(X, Y, over_rate=1, down_rate=1)
+
+ perm = np.random.permutation(np.arange(len(new_Y)))
+ new_X, new_Y = new_X[perm], new_Y[perm]
+
+ best_res = -np.inf
+    # Ten-fold cross-validation to get ten surrogate models
+ for n_estimators in [10, 20]:
+ for max_depth in [10, 15, 20]:
+ regr = RandomForestRegressor(n_estimators=n_estimators,
+ max_depth=max_depth)
+ # dict_keys(['fit_time', 'score_time', 'estimator', 'test_score', 'train_score'])
+ res = sk_cross_validate(regr,
+ new_X,
+ new_Y,
+ cv=10,
+ n_jobs=-1,
+ scoring='neg_mean_absolute_error',
+ return_estimator=True,
+ return_train_score=True)
+ test_metric = np.mean(res['test_score'])
+ train_metric = np.mean(res['train_score'])
+ print(
+ f'n_estimators: {n_estimators}, max_depth: {max_depth}, train_metric: {train_metric}, test_metric: {test_metric}'
+ )
+ if test_metric > best_res:
+ best_res = test_metric
+ best_models = res['estimator']
+
+ # Save model
+ for i, rf in enumerate(best_models):
+ file_name = f'surrogate_model_{i}.pkl'
+ model_state = pickle.dumps(rf)
+ with open(os.path.join(savedir, file_name), 'wb') as f:
+ f.write(model_state)
+
+ # Save info
+ info = {
+ 'configuration_space': configuration_space,
+ 'fidelity_space': fidelity_space
+ }
+ pkl = pickle.dumps(info)
+ with open(os.path.join(savedir, 'info.pkl'), 'wb') as f:
+ f.write(pkl)
+
+ return best_models, info, X, Y
diff --git a/benchmark/FedHPOB/fedhpob/utils/tabular_dataloader.py b/benchmark/FedHPOB/fedhpob/utils/tabular_dataloader.py
new file mode 100644
index 000000000..26d288ba8
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/utils/tabular_dataloader.py
@@ -0,0 +1,165 @@
+import os
+import pickle
+import re
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+
+def load_data(root, model, dname, algo):
+ path = os.path.join(root, model, dname, algo)
+ datafile = os.path.join(path, 'tabular.csv.gz')
+ infofile = os.path.join(path, 'info.pkl')
+
+ if not os.path.exists(datafile):
+ df = logs2df(dname, path)
+ df.to_csv(datafile, index=False, compression='gzip')
+ if not os.path.exists(infofile):
+ info = logs2info(dname, path)
+ pkl = pickle.dumps(info)
+ with open(infofile, 'wb') as f:
+ f.write(pkl)
+
+ df = pd.read_csv(datafile)
+ with open(infofile, 'rb') as f:
+ info = pickle.loads(f.read())
+
+ return df, info
+
+
+# TODO: remove preprocessing
+def group_by_seed(names, repeat=3):
+ names = sorted(names)
+ if len(names) % repeat != 0:
+ raise FileNotFoundError('Missing file!')
+ index = np.arange(0, len(names), repeat, dtype=np.int32)
+ names_group = [names[i:i + repeat] for i in index]
+ return names_group
+
+
+# TODO: remove preprocessing
+def logs2info(dname, root, sample_client_rate=[0.2, 0.4, 0.6, 0.8, 1.0]):
+ sample_client_rate = set(sample_client_rate)
+ dir_names = [f'out_{dname}_' + str(x) for x in sample_client_rate]
+    trial_names = [x for x in os.listdir(os.path.join(root, dir_names[0]))]
+    split_names = [x.split('_') for x in trial_names if x.startswith('lr')]
+ args = [''.join(re.findall(r'[A-Za-z]', arg)) for arg in split_names[0]]
+
+ search_space = {
+ arg: set([float(x[i][len(arg):]) for x in split_names])
+ for i, arg in enumerate(args)
+ }
+ if dname in ['cola', 'sst2']:
+ fidelity_space = {
+ 'sample_client': set(sample_client_rate),
+ 'round': [x for x in range(40)]
+ }
+ eval_freq = 1
+ elif dname in ['femnist']:
+ fidelity_space = {
+ 'sample_client': set(sample_client_rate),
+ 'round': [x + 1 for x in range(0, 500, 2)]
+ }
+ eval_freq = 2
+ else:
+ fidelity_space = {
+ 'sample_client': set(sample_client_rate),
+ 'round': [x for x in range(500)]
+ }
+ eval_freq = 1
+ info = {
+ 'configuration_space': search_space,
+ 'fidelity_space': fidelity_space,
+ 'eval_freq': eval_freq
+ }
+
+ return info
+
+
+def logs2df(dname,
+ root='',
+ sample_client_rate=[0.2, 0.4, 0.6, 0.8, 1.0],
+ metrics=[
+ 'train_avg_loss', 'val_avg_loss', 'test_avg_loss', 'train_acc',
+ 'val_acc', 'test_acc', 'train_f1', 'val_f1', 'test_f1'
+ ]):
+ sample_client_rate = [str(round(x, 1)) for x in sample_client_rate]
+ dir_names = [f'out_{dname}_' + str(x) for x in sample_client_rate]
+
+    trial_names = [x for x in os.listdir(os.path.join(root, dir_names[0]))]
+    split_names = [x.split('_') for x in trial_names if x.startswith('lr')]
+
+ args = [''.join(re.findall(r'[A-Za-z]', arg)) for arg in split_names[0]]
+ df = pd.DataFrame(None, columns=['sample_client'] + args + ['result'])
+
+ print('Processing...')
+ cnt = 0
+ for name, rate in zip(dir_names, sample_client_rate):
+ path = os.path.join(root, name)
+        trial_names = sorted(
+            [x for x in os.listdir(path) if x.startswith('lr')])
+        # trial_names = group_by_seed(trial_names)
+        for file_name in tqdm(trial_names):
+ metrics_dict = {x: [] for x in metrics}
+ time_dict = {
+ x: []
+ for x in ['train_time', 'eval_time', 'tol_time']
+ }
+ with open(os.path.join(path, file_name, 'exp_print.log')) as f:
+ F = f.readlines()
+ start_time = datetime.strptime(F[0][:19], '%Y-%m-%d %H:%M:%S')
+ end_time = datetime.strptime(F[-1][:19], '%Y-%m-%d %H:%M:%S')
+ time_dict['tol_time'].append(end_time - start_time)
+
+ train_p = False
+
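+                # train_p marks whether we are inside a consecutive run of
+                # client training lines; its edges bound per-round train time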
+ for idx, line in enumerate(F):
+ # Time
+ try:
+ timestamp = datetime.strptime(line[:19],
+ '%Y-%m-%d %H:%M:%S')
+                    except ValueError:
+                        continue
+
+ if "'Role': 'Client #" in line and train_p == False:
+ train_start_time = previous_time
+ train_p = True
+
+ if "'Role': 'Client #" not in line and train_p == True:
+ train_time = previous_time - train_start_time
+ time_dict['train_time'].append(train_time)
+ train_p = False
+
+ if 'Starting evaluation' in line:
+ eval_start_time = timestamp
+ if 'Results_raw' in line and 'test' in line:
+ eval_time = timestamp - eval_start_time
+ time_dict['eval_time'].append(eval_time)
+ previous_time = timestamp
+
+ # Statistics
+ try:
+ results = eval(line.split('INFO: ')[1])
+                    except Exception:
+                        continue
+ for key in metrics_dict:
+ if results['Role'] == 'Global-Eval-Server #':
+ metrics_dict[key].append(
+ results['Results_raw'][key])
+ elif 'Results_weighted_avg' not in results:
+ continue
+ else:
+ metrics_dict[key].append(
+ results['Results_weighted_avg'][key])
+ value = [
+ float(file_name.split('_')[i][len(arg):])
+ for i, arg in enumerate(args)
+ ]
+ df.loc[cnt] = [float(rate)] + value + [{
+ **metrics_dict,
+ **time_dict
+ }]
+ cnt += 1
+ return df
\ No newline at end of file
diff --git a/benchmark/FedHPOB/fedhpob/utils/util.py b/benchmark/FedHPOB/fedhpob/utils/util.py
new file mode 100644
index 000000000..b1e2deff4
--- /dev/null
+++ b/benchmark/FedHPOB/fedhpob/utils/util.py
@@ -0,0 +1,74 @@
+import os
+import time
+import logging
+
+from datetime import datetime
+
+
+def merge_dict(dict1, dict2):
+ for key, value in dict2.items():
+ if key not in dict1:
+ if isinstance(value, dict):
+ dict1[key] = merge_dict({}, value)
+ else:
+ dict1[key] = [value]
+ else:
+ if isinstance(value, dict):
+ merge_dict(dict1[key], value)
+ else:
+ dict1[key].append(value)
+ return dict1
+
+
+def disable_fs_logger(cfg, clear_before_add=False):
+ # Disable FS logger
+ root_logger = logging.getLogger("federatedscope")
+ # clear all existing handlers and add the default stream
+ if clear_before_add:
+ root_logger.handlers = []
+ handler = logging.StreamHandler()
+ logging_fmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+ handler.setFormatter(logging.Formatter(logging_fmt))
+ root_logger.addHandler(handler)
+
+ root_logger.setLevel(logging.CRITICAL)
+
+    # ================ create outdir to save log, exp_config, models, etc.
+ if cfg.outdir == "":
+ cfg.outdir = os.path.join(os.getcwd(), "exp")
+ cfg.outdir = os.path.join(cfg.outdir, cfg.expname)
+
+ # if exist, make directory with given name and time
+ if os.path.isdir(cfg.outdir) and os.path.exists(cfg.outdir):
+ outdir = os.path.join(cfg.outdir, "sub_exp" +
+ datetime.now().strftime('_%Y%m%d%H%M%S')
+ ) # e.g., sub_exp_20220411030524
+ while os.path.exists(outdir):
+ time.sleep(1)
+ outdir = os.path.join(
+ cfg.outdir,
+ "sub_exp" + datetime.now().strftime('_%Y%m%d%H%M%S'))
+ cfg.outdir = outdir
+ # if not, make directory with given name
+ os.makedirs(cfg.outdir)
+
+ # create file handler which logs even debug messages
+ fh = logging.FileHandler(os.path.join(cfg.outdir, 'exp_print.log'))
+ fh.setLevel(logging.CRITICAL)
+ logger_formatter = logging.Formatter(
+ "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+ fh.setFormatter(logger_formatter)
+ root_logger.addHandler(fh)
+
+
+def cfg2name(cfg):
+    repeat = 0
+    out_dir = os.path.join(
+        cfg.benchmark.out_dir,
+        f'{cfg.benchmark.data}_{cfg.benchmark.model}_{cfg.benchmark.type}_{cfg.benchmark.algo}'
+    )
+    os.makedirs(out_dir, exist_ok=True)
+    while os.path.exists(
+            os.path.join(out_dir, f'{cfg.optimizer.type}_repeat{repeat}.txt')):
+        repeat += 1
+    return os.path.join(out_dir, f'{cfg.optimizer.type}_repeat{repeat}.txt')
diff --git a/benchmark/FedHPOB/scripts/bert_tiny/cola.yaml b/benchmark/FedHPOB/scripts/bert_tiny/cola.yaml
new file mode 100644
index 000000000..f60efe30f
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/bert_tiny/cola.yaml
@@ -0,0 +1,33 @@
+use_gpu: True
+device: 1
+federate:
+ mode: standalone
+ local_update_steps: 1
+ total_round_num: 40
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+data:
+ root: 'glue'
+ type: 'cola@huggingface_datasets'
+ args: [{'max_len': 128}]
+ batch_size: 128
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+ num_workers: 0
+model:
+ type: 'google/bert_uncased_L-2_H-128_A-2@transformers'
+ task: 'SequenceClassification'
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: 'CrossEntropyLoss'
+trainer:
+ type: 'nlptrainer'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['val', 'train']
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/bert_tiny/run.sh b/benchmark/FedHPOB/scripts/bert_tiny/run.sh
new file mode 100644
index 000000000..6e34184d1
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/bert_tiny/run.sh
@@ -0,0 +1,98 @@
+# --Seed 1 --
+# --1--
+bash run_hpo_glue.sh 0 0.2 sst2 1 &
+bash run_hpo_glue.sh 1 0.4 sst2 1 &
+bash run_hpo_glue.sh 2 0.6 sst2 1 &
+bash run_hpo_glue.sh 3 0.8 sst2 1 &
+
+# --2--
+bash run_hpo_glue.sh 0 1.0 sst2 1 &
+bash run_hpo_glue.sh 1 0.2 cola 1 &
+bash run_hpo_glue.sh 2 0.4 cola 1 &
+bash run_hpo_glue.sh 3 0.6 cola 1 &
+
+# --3--
+bash run_hpo_glue.sh 0 0.8 cola 1 &
+bash run_hpo_glue.sh 1 1.0 cola 1 &
+bash run_opt_glue.sh 2 0.2 sst2 1 &
+bash run_opt_glue.sh 3 0.4 sst2 1 &
+
+# --4--
+bash run_opt_glue.sh 0 0.6 sst2 1 &
+bash run_opt_glue.sh 1 0.8 sst2 1 &
+bash run_opt_glue.sh 2 1.0 sst2 1 &
+bash run_opt_glue.sh 3 0.2 cola 1 &
+
+# --5--
+bash run_opt_glue.sh 0 0.4 cola 1 &
+bash run_opt_glue.sh 1 0.6 cola 1 &
+bash run_opt_glue.sh 2 0.8 cola 1 &
+bash run_opt_glue.sh 3 1.0 cola 1 &
+
+
+
+
+
+# --Seed 2 --
+# --1--
+bash run_hpo_glue.sh 0 0.2 sst2 2 &
+bash run_hpo_glue.sh 1 0.4 sst2 2 &
+bash run_hpo_glue.sh 2 0.6 sst2 2 &
+bash run_hpo_glue.sh 3 0.8 sst2 2 &
+
+# --2--
+bash run_hpo_glue.sh 0 1.0 sst2 2 &
+bash run_hpo_glue.sh 1 0.2 cola 2 &
+bash run_hpo_glue.sh 2 0.4 cola 2 &
+bash run_hpo_glue.sh 3 0.6 cola 2 &
+
+# --3--
+bash run_hpo_glue.sh 0 0.8 cola 2 &
+bash run_hpo_glue.sh 1 1.0 cola 2 &
+bash run_opt_glue.sh 2 0.2 sst2 2 &
+bash run_opt_glue.sh 3 0.4 sst2 2 &
+
+# --4--
+bash run_opt_glue.sh 0 0.6 sst2 2 &
+bash run_opt_glue.sh 1 0.8 sst2 2 &
+bash run_opt_glue.sh 2 1.0 sst2 2 &
+bash run_opt_glue.sh 3 0.2 cola 2 &
+
+# --5--
+bash run_opt_glue.sh 0 0.4 cola 2 &
+bash run_opt_glue.sh 1 0.6 cola 2 &
+bash run_opt_glue.sh 2 0.8 cola 2 &
+bash run_opt_glue.sh 3 1.0 cola 2 &
+
+
+
+# --Seed 3 --
+# --1--
+bash run_hpo_glue.sh 0 0.2 sst2 3 &
+bash run_hpo_glue.sh 1 0.4 sst2 3 &
+bash run_hpo_glue.sh 2 0.6 sst2 3 &
+bash run_hpo_glue.sh 3 0.8 sst2 3 &
+
+# --2--
+bash run_hpo_glue.sh 0 1.0 sst2 3 &
+bash run_hpo_glue.sh 1 0.2 cola 3 &
+bash run_hpo_glue.sh 2 0.4 cola 3 &
+bash run_hpo_glue.sh 3 0.6 cola 3 &
+
+# --3--
+bash run_hpo_glue.sh 0 0.8 cola 3 &
+bash run_hpo_glue.sh 1 1.0 cola 3 &
+bash run_opt_glue.sh 2 0.2 sst2 3 &
+bash run_opt_glue.sh 3 0.4 sst2 3 &
+
+# --4--
+bash run_opt_glue.sh 0 0.6 sst2 3 &
+bash run_opt_glue.sh 1 0.8 sst2 3 &
+bash run_opt_glue.sh 2 1.0 sst2 3 &
+bash run_opt_glue.sh 3 0.2 cola 3 &
+
+# --5--
+bash run_opt_glue.sh 0 0.4 cola 3 &
+bash run_opt_glue.sh 1 0.6 cola 3 &
+bash run_opt_glue.sh 2 0.8 cola 3 &
+bash run_opt_glue.sh 3 1.0 cola 3 &
diff --git a/benchmark/FedHPOB/scripts/bert_tiny/run_hpo_glue.sh b/benchmark/FedHPOB/scripts/bert_tiny/run_hpo_glue.sh
new file mode 100644
index 000000000..dddd24ed4
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/bert_tiny/run_hpo_glue.sh
@@ -0,0 +1,33 @@
+# https://huggingface.co/google/bert_uncased_L-2_H-128_A-2
+set -e
+
+cudaid=$1
+sample_rate=$2
+dataset=$3
+k=$4
+
+cd ../..
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1 2 3 4)
+batch_sizes=(8 16 32 64 128)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for ((b = 0; b < ${#batch_sizes[@]}; b++)); do
+ python main.py --cfg fedhpo/glue/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} data.batch_size ${batch_sizes[$b]} federate.sample_client_rate $sample_rate seed $k outdir ${out_dir}_${sample_rate} expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_batch${batch_sizes[$b]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/bert_tiny/run_opt_glue.sh b/benchmark/FedHPOB/scripts/bert_tiny/run_opt_glue.sh
new file mode 100644
index 000000000..f2d2c65fe
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/bert_tiny/run_opt_glue.sh
@@ -0,0 +1,40 @@
+# https://huggingface.co/google/bert_uncased_L-2_H-128_A-2
+set -e
+
+cudaid=$1
+sample_rate=$2
+dataset=$3
+k=$4
+
+cd ../..
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1)
+batch_sizes=(8 16 32 64 128)
+
+lrs_server=(0.1 0.5 1.0)
+momentums_server=(0.0 0.9)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for ((b = 0; b < ${#batch_sizes[@]}; b++)); do
+ for ((sl = 0; sl < ${#lrs_server[@]}; sl++)); do
+ for ((ms = 0; ms < ${#momentums_server[@]}; ms++)); do
+ python main.py --cfg fedhpo/glue/${dataset}.yaml fedopt.use True federate.method FedOpt fedopt.lr_server ${lrs_server[$sl]} fedopt.momentum_server ${momentums_server[$ms]} device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} data.batch_size ${batch_sizes[$b]} federate.sample_client_rate $sample_rate seed $k outdir out_fedopt/${out_dir}_${sample_rate} federate.share_local_model False federate.online_aggr False expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_batch${batch_sizes[$b]}_lrserver${lrs_server[$sl]}_momentumsserver${momentums_server[$ms]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/bert_tiny/sst2.yaml b/benchmark/FedHPOB/scripts/bert_tiny/sst2.yaml
new file mode 100644
index 000000000..61e946a32
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/bert_tiny/sst2.yaml
@@ -0,0 +1,33 @@
+use_gpu: True
+device: 0
+federate:
+ mode: standalone
+ local_update_steps: 1
+ total_round_num: 40
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+data:
+ root: 'glue'
+ type: 'sst2@huggingface_datasets'
+ args: [{'max_len': 512}]
+ batch_size: 128
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+ num_workers: 0
+model:
+ type: 'google/bert_uncased_L-2_H-128_A-2@transformers'
+ task: 'SequenceClassification'
+ out_channels: 2
+optimizer:
+ lr: 0.3
+ weight_decay: 0.0
+criterion:
+ type: 'CrossEntropyLoss'
+trainer:
+ type: 'nlptrainer'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['val', 'train']
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/cnn/cifar10.yaml b/benchmark/FedHPOB/scripts/cnn/cifar10.yaml
new file mode 100644
index 000000000..c2514c62e
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/cnn/cifar10.yaml
@@ -0,0 +1,39 @@
+use_gpu: True
+device: 0
+early_stop:
+ patience: 100
+seed: 1
+federate:
+ mode: standalone
+ local_update_steps: 1
+ batch_or_epoch: epoch
+ total_round_num: 500
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+data:
+ root: data/
+ type: 'CIFAR10@torchvision'
+ splits: [0.8,0.2,0.0]
+ batch_size: 32
+ num_workers: 0
+ transform: [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]
+ args: [{'download': True}]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: convnet2
+ hidden: 128
+ out_channels: 10
+ dropout: 0.0
+optimizer:
+ lr: 0.01
+ weight_decay: 0.0
+ grad_clip: 5.0
+criterion:
+ type: CrossEntropyLoss
+trainer:
+ type: cvtrainer
+eval:
+ freq: 1
+ metrics: ['acc', 'correct']
diff --git a/benchmark/FedHPOB/scripts/cnn/femnist.yaml b/benchmark/FedHPOB/scripts/cnn/femnist.yaml
new file mode 100644
index 000000000..c64356b44
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/cnn/femnist.yaml
@@ -0,0 +1,38 @@
+use_gpu: True
+device: 0
+early_stop:
+ patience: 100
+seed: 12345
+federate:
+ mode: standalone
+ local_update_steps: 1
+ batch_or_epoch: epoch
+ total_round_num: 500
+ sample_client_rate: 0.2
+ share_local_model: True
+ online_aggr: True
+data:
+ root: data/
+ type: femnist
+ splits: [0.6,0.2,0.2]
+ batch_size: 16
+ subsample: 0.05
+ transform: [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]
+ num_workers: 0
+model:
+ type: convnet2
+ hidden: 2048
+ out_channels: 62
+ dropout: 0.5
+optimizer:
+ lr: 0.01
+ weight_decay: 0.0
+ grad_clip: 5.0
+criterion:
+ type: CrossEntropyLoss
+trainer:
+ type: cvtrainer
+eval:
+ freq: 2
+ metrics: ['acc', 'correct', 'f1']
+ split: ['test', 'val', 'train']
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/cnn/run.sh b/benchmark/FedHPOB/scripts/cnn/run.sh
new file mode 100644
index 000000000..8b1e93e83
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/cnn/run.sh
@@ -0,0 +1,54 @@
+# --1--
+bash run_hpo_femnist_48.sh 0 0.0 1 16 &
+bash run_hpo_femnist_48.sh 1 0.0 1 32 &
+bash run_hpo_femnist_48.sh 2 0.0 1 64 &
+bash run_hpo_femnist_48.sh 3 0.0 2 16 &
+bash run_hpo_femnist_48.sh 0 0.0 2 32 &
+bash run_hpo_femnist_48.sh 1 0.0 2 64 &
+bash run_hpo_femnist_48.sh 2 0.0 3 16 &
+bash run_hpo_femnist_48.sh 3 0.0 3 32 &
+# --2--
+bash run_hpo_femnist_48.sh 0 0.0 3 64 &
+bash run_hpo_femnist_48.sh 1 0.0 4 16 &
+bash run_hpo_femnist_48.sh 2 0.0 4 32 &
+bash run_hpo_femnist_48.sh 3 0.0 4 64 &
+bash run_hpo_femnist_48.sh 0 0.001 1 16 &
+bash run_hpo_femnist_48.sh 1 0.001 1 32 &
+bash run_hpo_femnist_48.sh 2 0.001 1 64 &
+bash run_hpo_femnist_48.sh 3 0.001 2 16 &
+# --3--
+bash run_hpo_femnist_48.sh 0 0.001 2 32 &
+bash run_hpo_femnist_48.sh 1 0.001 2 64 &
+bash run_hpo_femnist_48.sh 2 0.001 3 16 &
+bash run_hpo_femnist_48.sh 3 0.001 3 32 &
+bash run_hpo_femnist_48.sh 0 0.001 3 64 &
+bash run_hpo_femnist_48.sh 1 0.001 4 16 &
+bash run_hpo_femnist_48.sh 2 0.001 4 32 &
+bash run_hpo_femnist_48.sh 3 0.001 4 64 &
+# --4--
+bash run_hpo_femnist_48.sh 0 0.01 1 16 &
+bash run_hpo_femnist_48.sh 1 0.01 1 32 &
+bash run_hpo_femnist_48.sh 2 0.01 1 64 &
+bash run_hpo_femnist_48.sh 3 0.01 2 16 &
+bash run_hpo_femnist_48.sh 0 0.01 2 32 &
+bash run_hpo_femnist_48.sh 1 0.01 2 64 &
+bash run_hpo_femnist_48.sh 2 0.01 3 16 &
+bash run_hpo_femnist_48.sh 3 0.01 3 32 &
+# --5--
+bash run_hpo_femnist_48.sh 0 0.01 3 64 &
+bash run_hpo_femnist_48.sh 1 0.01 4 16 &
+bash run_hpo_femnist_48.sh 2 0.01 4 32 &
+bash run_hpo_femnist_48.sh 3 0.01 4 64 &
+bash run_hpo_femnist_48.sh 0 0.1 1 16 &
+bash run_hpo_femnist_48.sh 1 0.1 1 32 &
+bash run_hpo_femnist_48.sh 2 0.1 1 64 &
+bash run_hpo_femnist_48.sh 3 0.1 2 16 &
+# --6--
+bash run_hpo_femnist_48.sh 0 0.1 2 32 &
+bash run_hpo_femnist_48.sh 1 0.1 2 64 &
+bash run_hpo_femnist_48.sh 2 0.1 3 16 &
+bash run_hpo_femnist_48.sh 3 0.1 3 32 &
+bash run_hpo_femnist_48.sh 0 0.1 3 64 &
+bash run_hpo_femnist_48.sh 1 0.1 4 16 &
+bash run_hpo_femnist_48.sh 2 0.1 4 32 &
+bash run_hpo_femnist_48.sh 3 0.1 4 64 &
diff --git a/benchmark/FedHPOB/scripts/cnn/run_hpo_cifar10.sh b/benchmark/FedHPOB/scripts/cnn/run_hpo_cifar10.sh
new file mode 100644
index 000000000..ecab51e4c
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/cnn/run_hpo_cifar10.sh
@@ -0,0 +1,34 @@
+set -e
+
+cudaid=$1
+sample_rate=$2
+
+cd ../..
+
+dataset=cifar10
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1 2 3 4)
+batch_sizes=(16 32 64)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for ((b = 0; b < ${#batch_sizes[@]}; b++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/cnn/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} data.batch_size ${batch_sizes[$b]} federate.sample_client_rate $sample_rate seed $k outdir ${out_dir}_${sample_rate} expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_batch${batch_sizes[$b]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/cnn/run_hpo_femnist.sh b/benchmark/FedHPOB/scripts/cnn/run_hpo_femnist.sh
new file mode 100644
index 000000000..5f7f9126b
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/cnn/run_hpo_femnist.sh
@@ -0,0 +1,34 @@
+set -e
+
+cudaid=$1
+sample_rate=$2
+
+cd ../..
+
+dataset=femnist
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1 2 3 4)
+batch_sizes=(16 32 64)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for ((b = 0; b < ${#batch_sizes[@]}; b++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/cnn/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} data.batch_size ${batch_sizes[$b]} federate.sample_client_rate $sample_rate seed $k outdir ${out_dir}_${sample_rate} expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_batch${batch_sizes[$b]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/cnn/run_hpo_femnist_48.sh b/benchmark/FedHPOB/scripts/cnn/run_hpo_femnist_48.sh
new file mode 100644
index 000000000..97da980bc
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/cnn/run_hpo_femnist_48.sh
@@ -0,0 +1,34 @@
+set -e
+
+# wds=(0.0 0.001 0.01 0.1)
+# steps=(1 2 3 4)
+# batch_sizes=(16 32 64)
+
+cudaid=$1
+wd=$2
+step=$3
+batch_size=$4
+
+cd ../..
+
+dataset=femnist
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+dps=(0.0 0.5)
+sample_rates=(0.2 0.4 0.6 0.8 1.0)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#sample_rates[@]}; s++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/cnn/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wd} model.dropout ${dps[$d]} federate.local_update_steps ${step} data.batch_size ${batch_size} federate.sample_client_rate ${sample_rates[$s]} seed $k outdir ${out_dir}_${sample_rates[$s]} expname lr${lrs[$l]}_wd${wd}_dropout${dps[$d]}_step${step}_batch${batch_size}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/exp/graph.yaml b/benchmark/FedHPOB/scripts/exp/graph.yaml
new file mode 100644
index 000000000..711c4f5fa
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/exp/graph.yaml
@@ -0,0 +1,5 @@
+benchmark:
+ device: 0
+optimizer:
+ min_budget: 3
+ max_budget: 81
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/exp/run.sh b/benchmark/FedHPOB/scripts/exp/run.sh
new file mode 100644
index 000000000..5d55642bb
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/exp/run.sh
@@ -0,0 +1,11 @@
+nohup bash run_mode.sh cora tabular 0 &
+nohup bash run_mode.sh citeseer tabular 0 &
+nohup bash run_mode.sh pubmed tabular 0 &
+
+nohup bash run_mode.sh cora raw 0
+nohup bash run_mode.sh citeseer raw 1
+nohup bash run_mode.sh pubmed raw 2
+
+nohup bash run_mode.sh cora surrogate 0
+nohup bash run_mode.sh citeseer surrogate 0
+nohup bash run_mode.sh pubmed surrogate 0
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/exp/run_graph.sh b/benchmark/FedHPOB/scripts/exp/run_graph.sh
new file mode 100644
index 000000000..851a86e5f
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/exp/run_graph.sh
@@ -0,0 +1,124 @@
+# ******Cora*****
+
+# tabular
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type rs
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type bo_gp
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type bo_rf
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type bo_kde
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type de
+
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type hb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type bohb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type dehb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type tpe_md
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data cora optimizer.type tpe_hb
+
+# raw
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type rs benchmark.device 0
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type bo_gp benchmark.device 1
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type bo_rf benchmark.device 2
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type bo_kde benchmark.device 3
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type de benchmark.device 4
+
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type hb benchmark.device 5
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type bohb benchmark.device 6
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type dehb benchmark.device 7
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type tpe_md benchmark.device 7
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data cora optimizer.type tpe_hb benchmark.device 6
+
+# surrogate
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type rs
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type bo_gp
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type bo_rf
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type bo_kde
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type de
+
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type hb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type bohb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type dehb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type tpe_md
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data cora optimizer.type tpe_hb
+
+
+# ******CiteSeer*****
+
+# tabular
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type rs
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type bo_gp
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type bo_rf
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type bo_kde
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type de
+
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type hb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type bohb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type dehb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type tpe_md
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data citeseer optimizer.type tpe_hb
+
+# raw
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type rs benchmark.device 0
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type bo_gp benchmark.device 1
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type bo_rf benchmark.device 2
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type bo_kde benchmark.device 3
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type de benchmark.device 4
+
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type hb benchmark.device 5
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type bohb benchmark.device 6
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type dehb benchmark.device 7
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type tpe_md benchmark.device 7
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data citeseer optimizer.type tpe_hb benchmark.device 6
+
+# surrogate
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type rs
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type bo_gp
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type bo_rf
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type bo_kde
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type de
+
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type hb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type bohb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type dehb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type tpe_md
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data citeseer optimizer.type tpe_hb
+
+
+# ******Pubmed*****
+
+# tabular
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type rs
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type bo_gp
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type bo_rf
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type bo_kde
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type de
+
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type hb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type bohb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type dehb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type tpe_md
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type tabular benchmark.data pubmed optimizer.type tpe_hb
+
+# raw
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type rs benchmark.device 0
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type bo_gp benchmark.device 1
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type bo_rf benchmark.device 2
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type bo_kde benchmark.device 3
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type de benchmark.device 4
+
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type hb benchmark.device 5
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type bohb benchmark.device 6
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type dehb benchmark.device 7
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type tpe_md benchmark.device 7
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type raw benchmark.data pubmed optimizer.type tpe_hb benchmark.device 6
+
+# surrogate
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type rs
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type bo_gp
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type bo_rf
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type bo_kde
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type de
+
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type hb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type bohb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type dehb
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type tpe_md
+python runner.py --cfg scripts/exp/graph.yaml benchmark.type surrogate benchmark.data pubmed optimizer.type tpe_hb
diff --git a/benchmark/FedHPOB/scripts/exp/run_mode.sh b/benchmark/FedHPOB/scripts/exp/run_mode.sh
new file mode 100644
index 000000000..c7d477eb6
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/exp/run_mode.sh
@@ -0,0 +1,28 @@
+set -e
+
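+# Usage: bash run_mode.sh <dataset> <mode> <device>
+#   dataset: benchmark dataset, e.g. cora, citeseer or pubmed
+#   mode:    benchmark type, i.e. tabular, raw or surrogate
+#   device:  GPU id passed as benchmark.device
+# Each optimizer is run five times; a failed run is skipped.
+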
+dataset=$1
+mode=$2
+device=$3
+
+cd ../..
+cp fedhpob/utils/runner.py . || echo "Copy failed; continuing with the existing runner.py."
+
+for k in {1..5}; do
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type rs || echo "continue"
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type bo_gp || echo "continue"
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type bo_rf || echo "continue"
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type bo_kde || echo "continue"
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type de || echo "continue"
+
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type hb || echo "continue"
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type bohb || echo "continue"
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type dehb || echo "continue"
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type tpe_md || echo "continue"
+ python runner.py --cfg scripts/exp/graph.yaml benchmark.device ${device} benchmark.type ${mode} benchmark.data ${dataset} optimizer.type tpe_hb || echo "continue"
+done
diff --git a/benchmark/FedHPOB/scripts/format.sh b/benchmark/FedHPOB/scripts/format.sh
new file mode 100644
index 000000000..dd2a28a08
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/format.sh
@@ -0,0 +1,129 @@
+#!/usr/bin/env bash
+
+# Copyright 2017 The Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# YAPF formatter and flake8 linter. This script formats all files changed
+# since the merge base with origin/master.
+# You are encouraged to run this locally before pushing changes for review.
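+#
+# Usage:
+#   bash format.sh                  # format files changed since the merge base
+#   bash format.sh --files <paths>  # format only the given files
+#   bash format.sh --all            # print a formatting diff for fedhpob/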
+
+# Cause the script to exit if a single command fails
+set -eo pipefail
+
+ver=$(yapf --version)
+if ! echo "$ver" | grep -q 0.31.0; then
+ echo "Wrong YAPF version installed: 0.31.0 is required, not $ver"
+ exit 1
+fi
+
+# this stops git rev-parse from failing if we run this from the .git directory
+builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
+
+ROOT="$(git rev-parse --show-toplevel)"
+builtin cd "$ROOT" || exit 1
+
+FLAKE8_VERSION=$(flake8 --version | awk '{print $1}')
+YAPF_VERSION=$(yapf --version | awk '{print $2}')
+
+# params: tool name, tool version, required version
+tool_version_check() {
+ if [[ $2 != $3 ]]; then
+ echo "WARNING: FedHPOB uses $1 $3, You currently are using $2. This might generate different results."
+ fi
+}
+
+tool_version_check "flake8" $FLAKE8_VERSION "4.0.1"
+tool_version_check "yapf" $YAPF_VERSION "0.31.0"
+
+# Only fetch master since that's the branch we're diffing against.
+#git fetch upstream master || true
+git fetch origin master || true
+
+YAPF_FLAGS=(
+ '--style' "$ROOT/.style.yapf"
+ '--recursive'
+ '--parallel'
+)
+
+YAPF_EXCLUDES=(
+ '--exclude' 'scripts/*'
+)
+
+# Format specified files
+format() {
+ yapf --in-place "${YAPF_FLAGS[@]}" -- "$@"
+}
+
+# Format files that differ from main branch. Ignores dirs that are not slated
+# for autoformat yet.
+format_changed() {
+ # The `if` guard ensures that the list of filenames is not empty, which
+ # could cause yapf to receive 0 positional arguments, making it hang
+ # waiting for STDIN.
+ #
+ # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
+ # exist on both branches.
+ #MERGEBASE="$(git merge-base upstream/master HEAD)"
+ MERGEBASE="$(git merge-base origin/master HEAD)"
+
+ if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then
+ git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \
+ yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
+ if which flake8 >/dev/null; then
+ git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \
+ flake8 --inline-quotes '"' --no-avoid-escape --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
+ fi
+ fi
+
+ if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then
+ if which flake8 >/dev/null; then
+ git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \
+ flake8 --inline-quotes '"' --no-avoid-escape --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
+ fi
+ fi
+}
+
+# Format all files, and print the diff to stdout for travis.
+format_all() {
+ yapf --diff "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" fedhpob
+ #yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" fedhpob
+}
+
+# This flag formats individual files. --files *must* be the first command line
+# arg to use this option.
+if [[ "$1" == '--files' ]]; then
+ format "${@:2}"
+ # If `--all` is passed, then any further arguments are ignored and the
+ # entire python directory is formatted.
+elif [[ "$1" == '--all' ]]; then
+ format_all
+else
+    # Format only the files that changed relative to the merge base.
+ format_changed
+fi
+
+if ! git diff --quiet &>/dev/null; then
+ echo 'Reformatted changed files. Please review and stage the changes.'
+ echo 'Files updated:'
+ echo
+
+ git --no-pager diff --name-only
+
+ exit 1
+fi
diff --git a/benchmark/FedHPOB/scripts/gcn/citeseer.yaml b/benchmark/FedHPOB/scripts/gcn/citeseer.yaml
new file mode 100644
index 000000000..51da2333d
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/gcn/citeseer.yaml
@@ -0,0 +1,36 @@
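+# FedAvg base config for federated GCN node classification on CiteSeer;
+# the graph is partitioned over 5 clients with the Louvain splitter.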
+use_gpu: True
+device: 0
+early_stop:
+ patience: 100
+seed: 12345
+federate:
+ mode: standalone
+ make_global_eval: True
+ client_num: 5
+ local_update_steps: 1
+ total_round_num: 500
+ share_local_model: True
+ online_aggr: True
+data:
+ root: data/
+ type: citeseer
+ splitter: 'louvain'
+ batch_size: 1
+model:
+ type: gcn
+ hidden: 64
+ dropout: 0.5
+ out_channels: 6
+optimizer:
+ lr: 0.25
+ weight_decay: 0.0005
+criterion:
+ type: CrossEntropyLoss
+trainer:
+ type: nodefullbatch_trainer
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['test', 'val', 'train']
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/gcn/cora.yaml b/benchmark/FedHPOB/scripts/gcn/cora.yaml
new file mode 100644
index 000000000..e1d4e79b7
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/gcn/cora.yaml
@@ -0,0 +1,34 @@
+use_gpu: True
+device: 0
+early_stop:
+ patience: 100
+seed: 12345
+federate:
+ mode: standalone
+ make_global_eval: True
+ client_num: 5
+ local_update_steps: 1
+ total_round_num: 500
+ share_local_model: True
+ online_aggr: True
+data:
+ root: data/
+ type: cora
+ splitter: 'louvain'
+ batch_size: 1
+model:
+ type: gcn
+ hidden: 64
+ dropout: 0.5
+ out_channels: 7
+optimizer:
+ lr: 0.25
+ weight_decay: 0.0005
+criterion:
+ type: CrossEntropyLoss
+trainer:
+ type: nodefullbatch_trainer
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['test', 'val', 'train']
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/gcn/pubmed.yaml b/benchmark/FedHPOB/scripts/gcn/pubmed.yaml
new file mode 100644
index 000000000..b2763a353
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/gcn/pubmed.yaml
@@ -0,0 +1,34 @@
+use_gpu: True
+device: 0
+early_stop:
+ patience: 100
+seed: 12345
+federate:
+ mode: standalone
+ make_global_eval: True
+ client_num: 5
+ local_update_steps: 1
+ total_round_num: 500
+ share_local_model: True
+ online_aggr: True
+data:
+ root: data/
+ type: pubmed
+ splitter: 'louvain'
+ batch_size: 1
+model:
+ type: gcn
+ hidden: 64
+ dropout: 0.5
+ out_channels: 5
+optimizer:
+ lr: 0.25
+ weight_decay: 0.0005
+criterion:
+ type: CrossEntropyLoss
+trainer:
+ type: nodefullbatch_trainer
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['test', 'val', 'train']
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/gcn/run_hpo_citeseer.sh b/benchmark/FedHPOB/scripts/gcn/run_hpo_citeseer.sh
new file mode 100644
index 000000000..771c559bf
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/gcn/run_hpo_citeseer.sh
@@ -0,0 +1,35 @@
+set -e
+
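+# Usage: bash run_hpo_citeseer.sh <gpu-id> <sample_client_num>
+# Grid search over lr, weight decay, dropout and local update steps
+# (3 seeds per configuration) for GCN on CiteSeer.
+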
+cudaid=$1
+sample_num=$2
+
+cd ../..
+
+dataset=citeseer
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1 2 3 4 5 6 7 8)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/graph/${dataset}/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} federate.sample_client_num $sample_num seed $k outdir ${out_dir}_${sample_num} expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/gcn/run_hpo_cora.sh b/benchmark/FedHPOB/scripts/gcn/run_hpo_cora.sh
new file mode 100644
index 000000000..878e29ab0
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/gcn/run_hpo_cora.sh
@@ -0,0 +1,35 @@
+set -e
+
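+# Usage: bash run_hpo_cora.sh <gpu-id> <sample_client_num>
+# Grid search over lr, weight decay, dropout and local update steps
+# (3 seeds per configuration) for GCN on Cora.
+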
+cudaid=$1
+sample_num=$2
+
+cd ../..
+
+dataset=cora
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1 2 3 4 5 6 7 8)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/graph/${dataset}/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} federate.sample_client_num $sample_num seed $k outdir ${out_dir}_${sample_num} expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/gcn/run_hpo_pubmed.sh b/benchmark/FedHPOB/scripts/gcn/run_hpo_pubmed.sh
new file mode 100644
index 000000000..3deef9c7d
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/gcn/run_hpo_pubmed.sh
@@ -0,0 +1,35 @@
+set -e
+
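+# Usage: bash run_hpo_pubmed.sh <gpu-id> <sample_client_num>
+# Grid search over lr, weight decay, dropout and local update steps
+# (3 seeds per configuration) for GCN on PubMed.
+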
+cudaid=$1
+sample_num=$2
+
+cd ../..
+
+dataset=pubmed
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1 2 3 4 5 6 7 8)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/graph/${dataset}/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} federate.sample_client_num $sample_num seed $k outdir ${out_dir}_${sample_num} expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/gcn/run_opt_citeseer.sh b/benchmark/FedHPOB/scripts/gcn/run_opt_citeseer.sh
new file mode 100644
index 000000000..df8b44fb0
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/gcn/run_opt_citeseer.sh
@@ -0,0 +1,42 @@
+set -e
+
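+# Usage: bash run_opt_citeseer.sh <gpu-id> <sample_client_num>
+# Same client-side grid as run_hpo_citeseer.sh (single local step),
+# combined with a FedOpt server grid over lr_server and momentum_server.
+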
+cudaid=$1
+sample_num=$2
+
+cd ../..
+
+dataset=citeseer
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1)
+
+lrs_server=(0.1 0.5 1.0)
+momentums_server=(0.0 0.9)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for ((sl = 0; sl < ${#lrs_server[@]}; sl++)); do
+ for ((ms = 0; ms < ${#momentums_server[@]}; ms++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/graph/${dataset}/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} federate.sample_client_num $sample_num fedopt.use True federate.method FedOpt fedopt.lr_server ${lrs_server[$sl]} fedopt.momentum_server ${momentums_server[$ms]} seed $k outdir out_fedopt/${out_dir}_${sample_num} federate.share_local_model False federate.online_aggr False expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_lrserver${lrs_server[$sl]}_momentumsserver${momentums_server[$ms]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/gcn/run_opt_cora.sh b/benchmark/FedHPOB/scripts/gcn/run_opt_cora.sh
new file mode 100644
index 000000000..60d40de8a
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/gcn/run_opt_cora.sh
@@ -0,0 +1,42 @@
+set -e
+
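+# Usage: bash run_opt_cora.sh <gpu-id> <sample_client_num>
+# Same client-side grid as run_hpo_cora.sh (single local step),
+# combined with a FedOpt server grid over lr_server and momentum_server.
+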
+cudaid=$1
+sample_num=$2
+
+cd ../..
+
+dataset=cora
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1)
+
+lrs_server=(0.1 0.5 1.0)
+momentums_server=(0.0 0.9)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for ((sl = 0; sl < ${#lrs_server[@]}; sl++)); do
+ for ((ms = 0; ms < ${#momentums_server[@]}; ms++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/graph/${dataset}/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} federate.sample_client_num $sample_num fedopt.use True federate.method FedOpt fedopt.lr_server ${lrs_server[$sl]} fedopt.momentum_server ${momentums_server[$ms]} seed $k outdir out_fedopt/${out_dir}_${sample_num} federate.share_local_model False federate.online_aggr False expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_lrserver${lrs_server[$sl]}_momentumsserver${momentums_server[$ms]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/gcn/run_opt_pubmed.sh b/benchmark/FedHPOB/scripts/gcn/run_opt_pubmed.sh
new file mode 100644
index 000000000..911caa08e
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/gcn/run_opt_pubmed.sh
@@ -0,0 +1,42 @@
+set -e
+
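+# Usage: bash run_opt_pubmed.sh <gpu-id> <sample_client_num>
+# Same client-side grid as run_hpo_pubmed.sh (single local step),
+# combined with a FedOpt server grid over lr_server and momentum_server.
+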
+cudaid=$1
+sample_num=$2
+
+cd ../..
+
+dataset=pubmed
+
+out_dir=out_${dataset}
+
+echo "HPO starts..."
+
+lrs=(0.01 0.01668 0.02783 0.04642 0.07743 0.12915 0.21544 0.35938 0.59948 1.0)
+wds=(0.0 0.001 0.01 0.1)
+dps=(0.0 0.5)
+steps=(1)
+
+lrs_server=(0.1 0.5 1.0)
+momentums_server=(0.0 0.9)
+
+for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((d = 0; d < ${#dps[@]}; d++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for ((sl = 0; sl < ${#lrs_server[@]}; sl++)); do
+ for ((ms = 0; ms < ${#momentums_server[@]}; ms++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/graph/${dataset}/${dataset}.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} model.dropout ${dps[$d]} federate.local_update_steps ${steps[$s]} federate.sample_client_num $sample_num fedopt.use True federate.method FedOpt fedopt.lr_server ${lrs_server[$sl]} fedopt.momentum_server ${momentums_server[$ms]} seed $k outdir out_fedopt/${out_dir}_${sample_num} federate.share_local_model False federate.online_aggr False expname lr${lrs[$l]}_wd${wds[$w]}_dropout${dps[$d]}_step${steps[$s]}_lrserver${lrs_server[$sl]}_momentumsserver${momentums_server[$ms]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/lr/10101@openml.yaml b/benchmark/FedHPOB/scripts/lr/10101@openml.yaml
new file mode 100644
index 000000000..8008aca11
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/10101@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 10101@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/12@openml.yaml b/benchmark/FedHPOB/scripts/lr/12@openml.yaml
new file mode 100644
index 000000000..9e9ae1db9
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/12@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 12@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 10
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/146212@openml.yaml b/benchmark/FedHPOB/scripts/lr/146212@openml.yaml
new file mode 100644
index 000000000..ad749551b
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/146212@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146212@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 7
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/146606@openml.yaml b/benchmark/FedHPOB/scripts/lr/146606@openml.yaml
new file mode 100644
index 000000000..b443b8dff
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/146606@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146606@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/146818@openml.yaml b/benchmark/FedHPOB/scripts/lr/146818@openml.yaml
new file mode 100644
index 000000000..aa42f0eaf
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/146818@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146818@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/146821@openml.yaml b/benchmark/FedHPOB/scripts/lr/146821@openml.yaml
new file mode 100644
index 000000000..e04b20a6e
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/146821@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146821@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 4
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/146822@openml.yaml b/benchmark/FedHPOB/scripts/lr/146822@openml.yaml
new file mode 100644
index 000000000..cadcb8243
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/146822@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146822@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 7
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/14965@openml.yaml b/benchmark/FedHPOB/scripts/lr/14965@openml.yaml
new file mode 100644
index 000000000..9394a996a
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/14965@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 14965@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/167119@openml.yaml b/benchmark/FedHPOB/scripts/lr/167119@openml.yaml
new file mode 100644
index 000000000..3d86764fc
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/167119@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 167119@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 3
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/167120@openml.yaml b/benchmark/FedHPOB/scripts/lr/167120@openml.yaml
new file mode 100644
index 000000000..f925e5074
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/167120@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 167120@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/168911@openml.yaml b/benchmark/FedHPOB/scripts/lr/168911@openml.yaml
new file mode 100644
index 000000000..1e95c1c2d
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/168911@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 168911@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/168912@openml.yaml b/benchmark/FedHPOB/scripts/lr/168912@openml.yaml
new file mode 100644
index 000000000..c1cbed8ae
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/168912@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 168912@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/31@openml.yaml b/benchmark/FedHPOB/scripts/lr/31@openml.yaml
new file mode 100644
index 000000000..3f418a45e
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/31@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 31@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/3917@openml.yaml b/benchmark/FedHPOB/scripts/lr/3917@openml.yaml
new file mode 100644
index 000000000..c21fde7d2
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/3917@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 3917@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/3@openml.yaml b/benchmark/FedHPOB/scripts/lr/3@openml.yaml
new file mode 100644
index 000000000..d503d9e71
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/3@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 3@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/53@openml.yaml b/benchmark/FedHPOB/scripts/lr/53@openml.yaml
new file mode 100644
index 000000000..d351cdf2c
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/53@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 53@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 4
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/7592@openml.yaml b/benchmark/FedHPOB/scripts/lr/7592@openml.yaml
new file mode 100644
index 000000000..e5908f772
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/7592@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 7592@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/9952@openml.yaml b/benchmark/FedHPOB/scripts/lr/9952@openml.yaml
new file mode 100644
index 000000000..6dcd949d9
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/9952@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 9952@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/9977@openml.yaml b/benchmark/FedHPOB/scripts/lr/9977@openml.yaml
new file mode 100644
index 000000000..ebf1c4ff8
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/9977@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 9977@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/9981@openml.yaml b/benchmark/FedHPOB/scripts/lr/9981@openml.yaml
new file mode 100644
index 000000000..57582ef6d
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/9981@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 9981@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 9
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/openml_lr.yaml b/benchmark/FedHPOB/scripts/lr/openml_lr.yaml
new file mode 100644
index 000000000..1fe6fe182
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/openml_lr.yaml
@@ -0,0 +1,31 @@
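+# Base config for the logistic-regression (lr) benchmark on OpenML tasks;
+# run_hpo_openml_lr.sh overrides data.type and model.out_channels per task.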
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: lr
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/lr/run.sh b/benchmark/FedHPOB/scripts/lr/run.sh
new file mode 100644
index 000000000..b8538c37c
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/run.sh
@@ -0,0 +1,44 @@
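+# Launch HPO sweeps for all 20 OpenML tasks in the background, mapping
+# each task to one of eight GPU ids (MLP sweeps first, then LR sweeps).
+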
+bash run_hpo_openml_mlp.sh 0 10101 &
+bash run_hpo_openml_mlp.sh 1 53 &
+bash run_hpo_openml_mlp.sh 2 146818 &
+bash run_hpo_openml_mlp.sh 3 146821 &
+bash run_hpo_openml_mlp.sh 4 9952 &
+bash run_hpo_openml_mlp.sh 5 146822 &
+bash run_hpo_openml_mlp.sh 6 31 &
+bash run_hpo_openml_mlp.sh 7 3917 &
+bash run_hpo_openml_mlp.sh 0 168912 &
+bash run_hpo_openml_mlp.sh 1 3 &
+bash run_hpo_openml_mlp.sh 2 167119 &
+bash run_hpo_openml_mlp.sh 3 12 &
+bash run_hpo_openml_mlp.sh 4 146212 &
+bash run_hpo_openml_mlp.sh 5 168911 &
+bash run_hpo_openml_mlp.sh 6 9981 &
+bash run_hpo_openml_mlp.sh 7 167120 &
+bash run_hpo_openml_mlp.sh 0 14965 &
+bash run_hpo_openml_mlp.sh 1 146606 &
+bash run_hpo_openml_mlp.sh 2 7592 &
+bash run_hpo_openml_mlp.sh 3 9977 &
+
+bash run_hpo_openml_lr.sh 0 10101 &
+bash run_hpo_openml_lr.sh 1 53 &
+bash run_hpo_openml_lr.sh 2 146818 &
+bash run_hpo_openml_lr.sh 3 146821 &
+bash run_hpo_openml_lr.sh 4 9952 &
+bash run_hpo_openml_lr.sh 5 146822 &
+bash run_hpo_openml_lr.sh 6 31 &
+bash run_hpo_openml_lr.sh 7 3917 &
+bash run_hpo_openml_lr.sh 0 168912 &
+bash run_hpo_openml_lr.sh 1 3 &
+bash run_hpo_openml_lr.sh 2 167119 &
+bash run_hpo_openml_lr.sh 3 12 &
+bash run_hpo_openml_lr.sh 4 146212 &
+bash run_hpo_openml_lr.sh 5 168911 &
+bash run_hpo_openml_lr.sh 6 9981 &
+bash run_hpo_openml_lr.sh 7 167120 &
+bash run_hpo_openml_lr.sh 0 14965 &
+bash run_hpo_openml_lr.sh 1 146606 &
+bash run_hpo_openml_lr.sh 2 7592 &
+bash run_hpo_openml_lr.sh 3 9977 &
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/lr/run_hpo_openml_lr.sh b/benchmark/FedHPOB/scripts/lr/run_hpo_openml_lr.sh
new file mode 100644
index 000000000..b9b0da8b1
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/run_hpo_openml_lr.sh
@@ -0,0 +1,56 @@
+set -e
+
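+# Usage: bash run_hpo_openml_lr.sh <gpu-id> <openml-task-id>
+# Grid search over sample rate, lr, weight decay, local update steps and
+# batch size (3 seeds per configuration) for logistic regression.
+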
+cudaid=$1
+dataset=$2
+
+cd ../..
+
+out_dir=out_${dataset}
+
+if [ ! -d "$out_dir" ]; then
+  mkdir "$out_dir"
+fi
+
+# Map each OpenML task id to its number of classes; the rest are binary.
+case $dataset in
+  53 | 146821) out_channels=4 ;;
+  146212 | 146822) out_channels=7 ;;
+  167119) out_channels=3 ;;
+  12) out_channels=10 ;;
+  9981) out_channels=9 ;;
+  *) out_channels=2 ;;
+esac
+
+echo "HPO starts..."
+
+sample_rates=(0.2 0.4 0.6 0.8 1.0)
+lrs=(0.00001 0.0001 0.001 0.01 0.1 1.0)
+wds=(0.0 0.001 0.01 0.1)
+steps=(1 2 3 4)
+batch_sizes=(4 8 16 32 64 128 256)
+
+for (( sr=0; sr<${#sample_rates[@]}; sr++ ))
+do
+ for (( l=0; l<${#lrs[@]}; l++ ))
+ do
+ for (( w=0; w<${#wds[@]}; w++ ))
+ do
+ for (( s=0; s<${#steps[@]}; s++ ))
+ do
+ for (( b=0; b<${#batch_sizes[@]}; b++ ))
+ do
+ for k in {1..3}
+ do
+ python federatedscope/main.py --cfg fedhpo/openml/openml_lr.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} federate.local_update_steps ${steps[$s]} data.type ${dataset}@openml data.batch_size ${batch_sizes[$b]} federate.sample_client_rate ${sample_rates[$sr]} model.out_channels $out_channels seed $k outdir lr/${out_dir}_${sample_rates[$sr]} expname lr${lrs[$l]}_wd${wds[$w]}_dropout0_step${steps[$s]}_batch${batch_sizes[$b]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/lr/run_opt_openml_lr.sh b/benchmark/FedHPOB/scripts/lr/run_opt_openml_lr.sh
new file mode 100644
index 000000000..5c2b47bb8
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/lr/run_opt_openml_lr.sh
@@ -0,0 +1,57 @@
+set -e
+
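+# Usage: bash run_opt_openml_lr.sh <gpu-id> <openml-task-id>
+# Same client-side grid as run_hpo_openml_lr.sh (single local step),
+# combined with a FedOpt server grid over lr_server and momentum_server.
+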
+cudaid=$1
+dataset=$2
+
+cd ../..
+
+out_dir=out_${dataset}
+
+if [ ! -d "$out_dir" ]; then
+  mkdir "$out_dir"
+fi
+
+# Map each OpenML task id to its number of classes; the rest are binary.
+case $dataset in
+  53 | 146821) out_channels=4 ;;
+  146212 | 146822) out_channels=7 ;;
+  167119) out_channels=3 ;;
+  12) out_channels=10 ;;
+  9981) out_channels=9 ;;
+  *) out_channels=2 ;;
+esac
+
+echo "HPO starts..."
+
+sample_rates=(0.2 0.4 0.6 0.8 1.0)
+lrs=(0.00001 0.0001 0.001 0.01 0.1 1.0)
+wds=(0.0 0.001 0.01 0.1)
+steps=(1)
+batch_sizes=(4 8 16 32 64 128 256)
+
+lrs_server=(0.1 0.5 1.0)
+momentums_server=(0.0 0.9)
+
+for ((sr = 0; sr < ${#sample_rates[@]}; sr++)); do
+ for ((l = 0; l < ${#lrs[@]}; l++)); do
+ for ((w = 0; w < ${#wds[@]}; w++)); do
+ for ((s = 0; s < ${#steps[@]}; s++)); do
+ for ((sl = 0; sl < ${#lrs_server[@]}; sl++)); do
+ for ((ms = 0; ms < ${#momentums_server[@]}; ms++)); do
+ for ((b = 0; b < ${#batch_sizes[@]}; b++)); do
+ for k in {1..3}; do
+ python federatedscope/main.py --cfg fedhpo/openml/openml_lr.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} federate.local_update_steps ${steps[$s]} data.type ${dataset}@openml data.batch_size ${batch_sizes[$b]} federate.sample_client_rate ${sample_rates[$sr]} model.out_channels $out_channels federate.share_local_model False federate.online_aggr False fedopt.use True federate.method FedOpt fedopt.lr_server ${lrs_server[$sl]} fedopt.momentum_server ${momentums_server[$ms]} seed $k outdir out_fedopt/lr/${out_dir}_${sample_rates[$sr]} expname lr${lrs[$l]}_wd${wds[$w]}_dropout0_step${steps[$s]}_batch${batch_sizes[$b]}_lrserver${lrs_server[$sl]}_momentumsserver${momentums_server[$ms]}_seed${k} >/dev/null 2>&1
+ done
+ done
+ done
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
diff --git a/benchmark/FedHPOB/scripts/mlp/10101@openml.yaml b/benchmark/FedHPOB/scripts/mlp/10101@openml.yaml
new file mode 100644
index 000000000..1d8b036ff
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/10101@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 10101@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/12@openml.yaml b/benchmark/FedHPOB/scripts/mlp/12@openml.yaml
new file mode 100644
index 000000000..e18306338
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/12@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 12@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 10
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/146212@openml.yaml b/benchmark/FedHPOB/scripts/mlp/146212@openml.yaml
new file mode 100644
index 000000000..8e497661d
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/146212@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146212@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 7
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/146606@openml.yaml b/benchmark/FedHPOB/scripts/mlp/146606@openml.yaml
new file mode 100644
index 000000000..c8c0dd47c
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/146606@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146606@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/146818@openml.yaml b/benchmark/FedHPOB/scripts/mlp/146818@openml.yaml
new file mode 100644
index 000000000..5f4e85e67
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/146818@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146818@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/146821@openml.yaml b/benchmark/FedHPOB/scripts/mlp/146821@openml.yaml
new file mode 100644
index 000000000..6c8888b9b
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/146821@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146821@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 4
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/146822@openml.yaml b/benchmark/FedHPOB/scripts/mlp/146822@openml.yaml
new file mode 100644
index 000000000..a04f0bfe3
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/146822@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 146822@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 7
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/14965@openml.yaml b/benchmark/FedHPOB/scripts/mlp/14965@openml.yaml
new file mode 100644
index 000000000..0233d8435
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/14965@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 14965@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/167119@openml.yaml b/benchmark/FedHPOB/scripts/mlp/167119@openml.yaml
new file mode 100644
index 000000000..13536fb04
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/167119@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 167119@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 3
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/167120@openml.yaml b/benchmark/FedHPOB/scripts/mlp/167120@openml.yaml
new file mode 100644
index 000000000..a8b20fef8
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/167120@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 167120@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/168911@openml.yaml b/benchmark/FedHPOB/scripts/mlp/168911@openml.yaml
new file mode 100644
index 000000000..6a56eb617
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/168911@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 168911@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/168912@openml.yaml b/benchmark/FedHPOB/scripts/mlp/168912@openml.yaml
new file mode 100644
index 000000000..2373d715c
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/168912@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 168912@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/31@openml.yaml b/benchmark/FedHPOB/scripts/mlp/31@openml.yaml
new file mode 100644
index 000000000..e60ce4e83
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/31@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 31@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/3917@openml.yaml b/benchmark/FedHPOB/scripts/mlp/3917@openml.yaml
new file mode 100644
index 000000000..90d3cad6c
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/3917@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 3917@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/3@openml.yaml b/benchmark/FedHPOB/scripts/mlp/3@openml.yaml
new file mode 100644
index 000000000..eac42dc84
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/3@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 3@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/53@openml.yaml b/benchmark/FedHPOB/scripts/mlp/53@openml.yaml
new file mode 100644
index 000000000..7292bd319
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/53@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 53@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 4
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/7592@openml.yaml b/benchmark/FedHPOB/scripts/mlp/7592@openml.yaml
new file mode 100644
index 000000000..e1ac7d5de
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/7592@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 7592@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/9952@openml.yaml b/benchmark/FedHPOB/scripts/mlp/9952@openml.yaml
new file mode 100644
index 000000000..8aa36fb6f
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/9952@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 9952@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/9977@openml.yaml b/benchmark/FedHPOB/scripts/mlp/9977@openml.yaml
new file mode 100644
index 000000000..74a1e5d7d
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/9977@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 9977@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/9981@openml.yaml b/benchmark/FedHPOB/scripts/mlp/9981@openml.yaml
new file mode 100644
index 000000000..113ef992e
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/9981@openml.yaml
@@ -0,0 +1,30 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ type: 9981@openml
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 9
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/openml_mlp.yaml b/benchmark/FedHPOB/scripts/mlp/openml_mlp.yaml
new file mode 100644
index 000000000..bf7ad20f4
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/openml_mlp.yaml
@@ -0,0 +1,29 @@
+use_gpu: True
+device: 1
+early_stop:
+ patience: 100
+federate:
+ mode: 'standalone'
+ total_round_num: 250
+ batch_or_epoch: 'epoch'
+ client_num: 5
+ share_local_model: True
+ online_aggr: True
+trainer:
+ type: 'general'
+eval:
+ freq: 1
+ metrics: ['acc', 'correct', 'f1']
+ split: ['train', 'val', 'test']
+data:
+ splits: [0.8, 0.1, 0.1]
+ splitter: 'lda'
+ splitter_args: [{'alpha': 0.5}]
+model:
+ type: mlp
+ out_channels: 2
+optimizer:
+ lr: 0.0001
+ weight_decay: 0.0
+criterion:
+ type: CrossEntropyLoss
diff --git a/benchmark/FedHPOB/scripts/mlp/run_hpo_openml_mlp.sh b/benchmark/FedHPOB/scripts/mlp/run_hpo_openml_mlp.sh
new file mode 100644
index 000000000..3a1c37504
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/run_hpo_openml_mlp.sh
@@ -0,0 +1,99 @@
+set -e
+
+cudaid=$1
+dataset=$2
+
+cd ../..
+
+out_dir=out_${dataset}
+
+if [ ! -d $out_dir ]; then
+ mkdir $out_dir
+fi
+
+# Number of output classes per OpenML task; tasks not listed here default to 2 (binary).
+case $dataset in
+  53|146821) out_channels=4 ;;
+  146822|146212) out_channels=7 ;;
+  167119) out_channels=3 ;;
+  12) out_channels=10 ;;
+  9981) out_channels=9 ;;
+  *) out_channels=2 ;;
+esac
+
+echo "HPO starts..."
+
+sample_rates=(1.0 0.8 0.6 0.4 0.2)
+lrs=(0.00001 0.0001 0.001 0.01 0.1 1.0)
+wds=(0.0 0.001 0.01 0.1)
+steps=(1 2 3 4)
+batch_sizes=(4 8 16 32 64 128 256)
+layers=(2 3 4)
+hiddens=(16 32 64 128 256 512 1024)
+
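+# Exhaustive grid search over all combinations of the grids above, each
+# repeated with 3 seeds (over 200,000 runs in total); runs whose output
+# directory already exists are skipped, so the sweep can be resumed.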
+for (( sr=0; sr<${#sample_rates[@]}; sr++ ))
+do
+ for (( l=0; l<${#lrs[@]}; l++ ))
+ do
+ for (( w=0; w<${#wds[@]}; w++ ))
+ do
+ for (( s=0; s<${#steps[@]}; s++ ))
+ do
+ for (( b=0; b<${#batch_sizes[@]}; b++ ))
+ do
+ for (( y=0; y<${#layers[@]}; y++ ))
+ do
+ for (( h=0; h<${#hiddens[@]}; h++ ))
+ do
+ for k in {1..3}
+ do
+ FILE="mlp/${out_dir}_${sample_rates[$sr]}/lr${lrs[$l]}_wd${wds[$w]}_dropout0_step${steps[$s]}_batch${batch_sizes[$b]}_layer${layers[$y]}_hidden${hiddens[$h]}_seed${k}"
+ if [ -d "$FILE" ]; then
+ echo "$FILE exists."
+ else
+ python main.py --cfg fedhpo/openml/openml_mlp.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} federate.local_update_steps ${steps[$s]} data.type ${dataset}@openml data.batch_size ${batch_sizes[$b]} federate.sample_client_rate ${sample_rates[$sr]} model.layer ${layers[$y]} model.hidden ${hiddens[$h]} model.out_channels $out_channels seed $k outdir mlp/${out_dir}_${sample_rates[$sr]} expname lr${lrs[$l]}_wd${wds[$w]}_dropout0_step${steps[$s]}_batch${batch_sizes[$b]}_layer${layers[$y]}_hidden${hiddens[$h]}_seed${k} >/dev/null 2>&1
+ fi
+ done
+ done
+ done
+ done
+ done
+ done
+ done
+done
+
+echo "HPO ends."
\ No newline at end of file
diff --git a/benchmark/FedHPOB/scripts/mlp/run_opt_openml_mlp.sh b/benchmark/FedHPOB/scripts/mlp/run_opt_openml_mlp.sh
new file mode 100644
index 000000000..69b1cab0e
--- /dev/null
+++ b/benchmark/FedHPOB/scripts/mlp/run_opt_openml_mlp.sh
@@ -0,0 +1,89 @@
+set -e
+
+cudaid=$1
+dataset=$2
+
+cd ../..
+
+out_dir=out_${dataset}
+
+if [ ! -d $out_dir ]; then
+ mkdir $out_dir
+fi
+
+# Number of output classes per OpenML task; tasks not listed here default to 2 (binary).
+case $dataset in
+  53|146821) out_channels=4 ;;
+  146822|146212) out_channels=7 ;;
+  167119) out_channels=3 ;;
+  12) out_channels=10 ;;
+  9981) out_channels=9 ;;
+  *) out_channels=2 ;;
+esac
+
+echo "HPO starts..."
+
+sample_rates=(0.2 0.4 0.6 0.8 1.0)
+lrs=(0.00001 0.0001 0.001 0.01 0.1 1.0)
+wds=(0.0 0.001 0.01 0.1)
+steps=(1)
+batch_sizes=(32 64 128 256)
+layers=(2 3 4)
+hiddens=(16 64 256)
+
+lrs_server=(0.1 0.5 1.0)
+momentums_server=(0.0 0.9)
+
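+# FedOpt sweep: besides the client-side grids above, the server-side
+# learning rate (index $sl) and server momentum (index $ms) are searched.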
+for ((sr = 0; sr < ${#sample_rates[@]}; sr++)); do
+  for ((l = 0; l < ${#lrs[@]}; l++)); do
+    for ((w = 0; w < ${#wds[@]}; w++)); do
+      for ((s = 0; s < ${#steps[@]}; s++)); do
+        for ((b = 0; b < ${#batch_sizes[@]}; b++)); do
+          for ((y = 0; y < ${#layers[@]}; y++)); do
+            for ((h = 0; h < ${#hiddens[@]}; h++)); do
+              # the original script referenced $sl and $ms without defining
+              # them; loop over the server-side grids here
+              for ((sl = 0; sl < ${#lrs_server[@]}; sl++)); do
+                for ((ms = 0; ms < ${#momentums_server[@]}; ms++)); do
+                  for k in {1..3}; do
+                    python main.py --cfg fedhpo/openml/openml_mlp.yaml device $cudaid optimizer.lr ${lrs[$l]} optimizer.weight_decay ${wds[$w]} federate.local_update_steps ${steps[$s]} data.type ${dataset}@openml data.batch_size ${batch_sizes[$b]} federate.sample_client_rate ${sample_rates[$sr]} model.layer ${layers[$y]} model.hidden ${hiddens[$h]} model.out_channels $out_channels federate.share_local_model False federate.online_aggr False fedopt.use True federate.method FedOpt fedopt.lr_server ${lrs_server[$sl]} fedopt.momentum_server ${momentums_server[$ms]} seed $k outdir out_fedopt/mlp/${out_dir}_${sample_rates[$sr]} expname lr${lrs[$l]}_wd${wds[$w]}_dropout0_step${steps[$s]}_batch${batch_sizes[$b]}_layer${layers[$y]}_hidden${hiddens[$h]}_lrserver${lrs_server[$sl]}_momentumsserver${momentums_server[$ms]}_seed${k} >/dev/null 2>&1
+                  done
+                done
+              done
+            done
+          done
+        done
+      done
+    done
+  done
+done
+
+echo "HPO ends."
\ No newline at end of file
diff --git a/demo/bbo.py b/demo/bbo.py
new file mode 100644
index 000000000..4e4788d49
--- /dev/null
+++ b/demo/bbo.py
@@ -0,0 +1,110 @@
+"""This python script is provided to demonstrate the interaction between emukit and FederatedScope.
+Specifically, we apply Black-Box Optimization (BBO) to search the optimal hyperparameters of the considered federated learning algorithms.
+emukit can be installed by `pip install emukit`
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import colors as mcolors
+
+from emukit.test_functions import forrester_function
+from emukit.core import ContinuousParameter, CategoricalParameter, ParameterSpace
+from emukit.examples.gp_bayesian_optimization.single_objective_bayesian_optimization import GPBayesianOptimization
+
+### --- Figure config
+LEGEND_SIZE = 15
+
+
+def eval_fl_algo(x):
+ from federatedscope.core.cmd_args import parse_args
+ from federatedscope.core.auxiliaries.data_builder import get_data
+ from federatedscope.core.auxiliaries.utils import setup_seed, update_logger
+ from federatedscope.core.auxiliaries.worker_builder import get_client_cls, get_server_cls
+ from federatedscope.core.configs.config import global_cfg
+ from federatedscope.core.fed_runner import FedRunner
+
+ init_cfg = global_cfg.clone()
+ init_cfg.merge_from_file(
+ "federatedscope/example_configs/single_process.yaml")
+ init_cfg.merge_from_list(["optimizer.lr", float(x[0])])
+
+ update_logger(init_cfg, True)
+ setup_seed(init_cfg.seed)
+
+    # the federated dataset might change the number of clients,
+    # so we allow the dataset creation procedure to modify the global cfg object
+ data, modified_cfg = get_data(config=init_cfg.clone())
+ init_cfg.merge_from_other_cfg(modified_cfg)
+
+ init_cfg.freeze()
+
+ runner = FedRunner(data=data,
+ server_class=get_server_cls(init_cfg),
+ client_class=get_client_cls(init_cfg),
+ config=init_cfg.clone())
+ results = runner.run()
+
+ # so that we could modify cfg in the next trial
+ init_cfg.defrost()
+
+ return [results['client_summarized_weighted_avg']['test_avg_loss']]
+
+
+def our_target_func(x):
+ return np.asarray([eval_fl_algo(elem) for elem in x])
+
+
+def main():
+ #target_function, space = forrester_function()
+ target_function = our_target_func
+ space = ParameterSpace([ContinuousParameter('lr', 1e-4, .75)])
+ x_plot = np.linspace(space.parameters[0].min, space.parameters[0].max,
+ 200)[:, None]
+ #y_plot = target_function(x_plot)
+ X_init = np.array([[0.005], [0.05], [0.5]])
+ Y_init = target_function(X_init)
+
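+    # Fit a GP surrogate to the three initial (lr, loss) observations, then
+    # run 15 BO iterations; each iteration evaluates one full FL training run.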
+ bo = GPBayesianOptimization(variables_list=space.parameters,
+ X=X_init,
+ Y=Y_init)
+ bo.run_optimization(target_function, 15)
+
+ mu_plot, var_plot = bo.model.predict(x_plot)
+
+ plt.figure(figsize=(12, 8))
+ plt.plot(bo.loop_state.X,
+ bo.loop_state.Y,
+ "ro",
+ markersize=10,
+ label="Observations")
+ #plt.plot(x_plot, y_plot, "k", label="Objective Function")
+ #plt.plot(x_plot, mu_plot, "C0", label="Model")
+ plt.fill_between(x_plot[:, 0],
+ mu_plot[:, 0] + np.sqrt(var_plot)[:, 0],
+ mu_plot[:, 0] - np.sqrt(var_plot)[:, 0],
+ color="C0",
+ alpha=0.6)
+
+ plt.fill_between(x_plot[:, 0],
+ mu_plot[:, 0] + 2 * np.sqrt(var_plot)[:, 0],
+ mu_plot[:, 0] - 2 * np.sqrt(var_plot)[:, 0],
+ color="C0",
+ alpha=0.4)
+
+ plt.fill_between(x_plot[:, 0],
+ mu_plot[:, 0] + 3 * np.sqrt(var_plot)[:, 0],
+ mu_plot[:, 0] - 3 * np.sqrt(var_plot)[:, 0],
+ color="C0",
+ alpha=0.2)
+ plt.legend(loc=2, prop={'size': LEGEND_SIZE})
+ plt.xlabel(r"$x$")
+ plt.ylabel(r"$f(x)$")
+ plt.grid(True)
+ plt.xlim(0, 0.75)
+
+ #plt.show()
+ plt.savefig("bbo.pdf", bbox_inches='tight')
+ plt.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/demo/hpbandster/rs.py b/demo/hpbandster/rs.py
new file mode 100644
index 000000000..e45bb6a9f
--- /dev/null
+++ b/demo/hpbandster/rs.py
@@ -0,0 +1,167 @@
+#import numpy
+import time
+
+import ConfigSpace as CS
+from hpbandster.core.worker import Worker
+
+import logging
+
+logging.basicConfig(level=logging.WARNING)
+
+import argparse
+
+import hpbandster.core.nameserver as hpns
+import hpbandster.core.result as hpres
+
+from hpbandster.optimizers import BOHB as BOHB
+from hpbandster.optimizers.randomsearch import RandomSearch
+#from hpbandster.examples.commons import MyWorker
+
+parser = argparse.ArgumentParser(
+    description='Random search example - sequential and local execution.')
+parser.add_argument('--min_budget',
+ type=float,
+ help='Minimum budget used during the optimization.',
+ default=1)
+parser.add_argument('--max_budget',
+ type=float,
+ help='Maximum budget used during the optimization.',
+ default=27)
+parser.add_argument('--n_iterations',
+ type=int,
+ help='Number of iterations performed by the optimizer',
+ default=4)
+args = parser.parse_args()
+
+
+def eval_fl_algo(x, b):
+ from federatedscope.core.cmd_args import parse_args
+ from federatedscope.core.auxiliaries.data_builder import get_data
+ from federatedscope.core.auxiliaries.utils import setup_seed, update_logger
+ from federatedscope.core.auxiliaries.worker_builder import get_client_cls, get_server_cls
+ from federatedscope.core.configs.config import global_cfg
+ from federatedscope.core.fed_runner import FedRunner
+
+ init_cfg = global_cfg.clone()
+ init_cfg.merge_from_file(
+ "federatedscope/example_configs/single_process.yaml")
+ # specify the configuration of interest
+ init_cfg.merge_from_list([
+ "optimizer.lr",
+ float(x['lr']), "optimizer.weight_decay",
+ float(x['wd']), "model.dropout",
+ float(x["dropout"])
+ ])
+ # specify the budget
+ init_cfg.merge_from_list(
+ ["federate.total_round_num",
+ int(b), "eval.freq",
+ int(b)])
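+    # The budget is mapped to the number of FL communication rounds, and
+    # evaluation is triggered only once, at the final round.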
+
+ update_logger(init_cfg, True)
+ setup_seed(init_cfg.seed)
+
+    # the federated dataset might change the number of clients,
+    # so we allow the dataset creation procedure to modify the global cfg object
+ data, modified_cfg = get_data(config=init_cfg.clone())
+ init_cfg.merge_from_other_cfg(modified_cfg)
+
+ init_cfg.freeze()
+
+ runner = FedRunner(data=data,
+ server_class=get_server_cls(init_cfg),
+ client_class=get_client_cls(init_cfg),
+ config=init_cfg.clone())
+ results = runner.run()
+
+ # so that we could modify cfg in the next trial
+ init_cfg.defrost()
+
+ return results['client_summarized_weighted_avg']['test_avg_loss']
+
+
+class MyWorker(Worker):
+ def __init__(self, *args, sleep_interval=0, **kwargs):
+ super(MyWorker, self).__init__(*args, **kwargs)
+
+ self.sleep_interval = sleep_interval
+
+ def compute(self, config, budget, **kwargs):
+ """
+ Simple example for a compute function
+ The loss is just a the config + some noise (that decreases with the budget)
+
+ For dramatization, the function can sleep for a given interval to emphasizes
+ the speed ups achievable with parallel workers.
+
+ Args:
+ config: dictionary containing the sampled configurations by the optimizer
+ budget: (float) amount of time/epochs/etc. the model can use to train
+
+ Returns:
+ dictionary with mandatory fields:
+ 'loss' (scalar)
+ 'info' (dict)
+ """
+
+ #res = numpy.clip(config['x'] + numpy.random.randn()/budget, config['x']/2, 1.5*config['x'])
+ res = eval_fl_algo(config, budget)
+ time.sleep(self.sleep_interval)
+
+        return ({
+            'loss': float(res),  # a mandatory field for the optimizer
+            'info': res  # also mandatory; can carry any user-defined information
+        })
+
+ @staticmethod
+ def get_configspace():
+ config_space = CS.ConfigurationSpace()
+ config_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('lr',
+ lower=1e-4,
+ upper=1.0,
+ log=True))
+ config_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter('dropout', lower=.0, upper=.5))
+ config_space.add_hyperparameter(
+ CS.CategoricalHyperparameter('wd', choices=[0.0, 0.5]))
+ return config_space
+
+
+def main():
+ NS = hpns.NameServer(run_id='example1', host='127.0.0.1', port=None)
+ NS.start()
+
+ w = MyWorker(sleep_interval=0, nameserver='127.0.0.1', run_id='example1')
+ w.run(background=True)
+
+ #bohb = BOHB( configspace = w.get_configspace(),
+ # run_id = 'example1', nameserver='127.0.0.1',
+ # min_budget=args.min_budget, max_budget=args.max_budget
+ # )
+ rs = RandomSearch(configspace=w.get_configspace(),
+ run_id='example1',
+ nameserver='127.0.0.1',
+ min_budget=args.min_budget,
+ max_budget=args.max_budget)
+ #res = bohb.run(n_iterations=args.n_iterations)
+ res = rs.run(n_iterations=args.n_iterations)
+
+ #bohb.shutdown(shutdown_workers=True)
+ rs.shutdown(shutdown_workers=True)
+ NS.shutdown()
+
+ id2config = res.get_id2config_mapping()
+ incumbent = res.get_incumbent_id()
+
+ print('Best found configuration:', id2config[incumbent]['config'])
+    print('A total of %i unique configurations were sampled.' %
+          len(id2config.keys()))
+    print('A total of %i runs were executed.' % len(res.get_all_runs()))
+ print('Total budget corresponds to %.1f full function evaluations.' %
+ (sum([r.budget for r in res.get_all_runs()]) / args.max_budget))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/demo/smac/gp.py b/demo/smac/gp.py
new file mode 100644
index 000000000..a0a72e2ad
--- /dev/null
+++ b/demo/smac/gp.py
@@ -0,0 +1,81 @@
+"""
+This script demonstrates the usage of SMAC's black-box facade (SMAC4BB) with a Gaussian process surrogate model; we assume the related packages are available.
+More details about SMAC can be found at https://github.com/automl/SMAC3
+"""
+import numpy as np
+
+from ConfigSpace import ConfigurationSpace
+from ConfigSpace.hyperparameters import UniformIntegerHyperparameter, UniformFloatHyperparameter, CategoricalHyperparameter
+from smac.facade.smac_bb_facade import SMAC4BB
+from smac.scenario.scenario import Scenario
+
+
+def eval_fl_algo(x):
+ from federatedscope.core.cmd_args import parse_args
+ from federatedscope.core.auxiliaries.data_builder import get_data
+ from federatedscope.core.auxiliaries.utils import setup_seed, update_logger
+ from federatedscope.core.auxiliaries.worker_builder import get_client_cls, get_server_cls
+ from federatedscope.core.configs.config import global_cfg
+ from federatedscope.core.fed_runner import FedRunner
+
+ init_cfg = global_cfg.clone()
+ init_cfg.merge_from_file(
+ "federatedscope/example_configs/single_process.yaml")
+ # specify the configuration of interest
+ init_cfg.merge_from_list([
+ "optimizer.lr",
+ float(x['lr']), "optimizer.weight_decay",
+ float(x['wd']), "model.dropout",
+ float(x["dropout"])
+ ])
+
+ update_logger(init_cfg, True)
+ setup_seed(init_cfg.seed)
+
+    # the federated dataset might change the number of clients,
+    # so we allow the dataset creation procedure to modify the global cfg object
+ data, modified_cfg = get_data(config=init_cfg.clone())
+ init_cfg.merge_from_other_cfg(modified_cfg)
+
+ init_cfg.freeze()
+
+ runner = FedRunner(data=data,
+ server_class=get_server_cls(init_cfg),
+ client_class=get_client_cls(init_cfg),
+ config=init_cfg.clone())
+ results = runner.run()
+
+ # so that we could modify cfg in the next trial
+ init_cfg.defrost()
+
+ return results['client_summarized_weighted_avg']['test_avg_loss']
+
+
+def main():
+ # Define your hyperparameters
+ configspace = ConfigurationSpace()
+ #configspace.add_hyperparameter(UniformIntegerHyperparameter("depth", 2, 100))
+ configspace.add_hyperparameter(
+ UniformFloatHyperparameter("lr", lower=1e-4, upper=1.0, log=True))
+ configspace.add_hyperparameter(
+ UniformFloatHyperparameter("dropout", lower=.0, upper=.5))
+ configspace.add_hyperparameter(
+ CategoricalHyperparameter("wd", choices=[0.0, 0.5]))
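+    # Search space: lr is log-uniform in [1e-4, 1.0], dropout is uniform in
+    # [0, 0.5], and weight decay is chosen from {0.0, 0.5}.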
+
+ # Provide meta data for the optimization
+ scenario = Scenario({
+ "run_obj": "quality", # Optimize quality (alternatively runtime)
+ "runcount-limit": 8, # Max number of function evaluations (the more the better)
+ "cs": configspace,
+ 'output_dir': "smac_gp",
+ })
+
+ # a summary of SMAC's facades: https://automl.github.io/SMAC3/master/pages/details/facades.html?highlight=random%20forest#facades
+ smac = SMAC4BB(model_type='gp', scenario=scenario, tae_runner=eval_fl_algo)
+ best_found_config = smac.optimize()
+ print(best_found_config)
+ #run_history = smac.get_runhistory()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/demo/smac/rf.py b/demo/smac/rf.py
new file mode 100644
index 000000000..05478e5ba
--- /dev/null
+++ b/demo/smac/rf.py
@@ -0,0 +1,81 @@
+"""
+This script demonstrates the usage of SMAC's HPO facade (SMAC4HPO) with a random-forest surrogate model; we assume the related packages are available.
+More details about SMAC can be found at https://github.com/automl/SMAC3
+"""
+import numpy as np
+
+from ConfigSpace import ConfigurationSpace
+from ConfigSpace.hyperparameters import UniformIntegerHyperparameter, UniformFloatHyperparameter, CategoricalHyperparameter
+from smac.facade.smac_hpo_facade import SMAC4HPO
+from smac.scenario.scenario import Scenario
+
+
+def eval_fl_algo(x):
+ from federatedscope.core.cmd_args import parse_args
+ from federatedscope.core.auxiliaries.data_builder import get_data
+ from federatedscope.core.auxiliaries.utils import setup_seed, update_logger
+ from federatedscope.core.auxiliaries.worker_builder import get_client_cls, get_server_cls
+ from federatedscope.core.configs.config import global_cfg
+ from federatedscope.core.fed_runner import FedRunner
+
+ init_cfg = global_cfg.clone()
+ init_cfg.merge_from_file(
+ "federatedscope/example_configs/single_process.yaml")
+ # specify the configuration of interest
+ init_cfg.merge_from_list([
+ "optimizer.lr",
+ float(x['lr']), "optimizer.weight_decay",
+ float(x['wd']), "model.dropout",
+ float(x["dropout"])
+ ])
+
+ update_logger(init_cfg, True)
+ setup_seed(init_cfg.seed)
+
+    # the federated dataset might change the number of clients,
+    # so we allow the dataset creation procedure to modify the global cfg object
+ data, modified_cfg = get_data(config=init_cfg.clone())
+ init_cfg.merge_from_other_cfg(modified_cfg)
+
+ init_cfg.freeze()
+
+ runner = FedRunner(data=data,
+ server_class=get_server_cls(init_cfg),
+ client_class=get_client_cls(init_cfg),
+ config=init_cfg.clone())
+ results = runner.run()
+
+ # so that we could modify cfg in the next trial
+ init_cfg.defrost()
+
+ return results['client_summarized_weighted_avg']['test_avg_loss']
+
+
+def main():
+ # Define your hyperparameters
+ configspace = ConfigurationSpace()
+ #configspace.add_hyperparameter(UniformIntegerHyperparameter("depth", 2, 100))
+ configspace.add_hyperparameter(
+ UniformFloatHyperparameter("lr", lower=1e-4, upper=1.0, log=True))
+ configspace.add_hyperparameter(
+ UniformFloatHyperparameter("dropout", lower=.0, upper=.5))
+ configspace.add_hyperparameter(
+ CategoricalHyperparameter("wd", choices=[0.0, 0.5]))
+
+ # Provide meta data for the optimization
+ scenario = Scenario({
+ "run_obj": "quality", # Optimize quality (alternatively runtime)
+ "runcount-limit": 8, # Max number of function evaluations (the more the better)
+ "cs": configspace,
+ 'output_dir': "smac_rf",
+ })
+
+ # a summary of SMAC's facades: https://automl.github.io/SMAC3/master/pages/details/facades.html?highlight=random%20forest#facades
+ smac = SMAC4HPO(scenario=scenario, tae_runner=eval_fl_algo)
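+    # Unlike demo/smac/gp.py, SMAC4HPO uses a random-forest surrogate model
+    # by default.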
+ best_found_config = smac.optimize()
+ print(best_found_config)
+ #run_history = smac.get_runhistory()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/demo/synthetic.py b/demo/synthetic.py
new file mode 100644
index 000000000..ff200b5c0
--- /dev/null
+++ b/demo/synthetic.py
@@ -0,0 +1,55 @@
+import numpy as np
+
+
+def FL(x, objs, sizes):
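+    # Toy simulation of federated averaging: for 5 rounds, each client i
+    # takes one gradient step of size sizes[i] on its local objective, the
+    # server averages the updates, and the mean final objective is returned.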
+ cur_x = x
+ if not isinstance(sizes, list):
+ sizes = len(objs) * [sizes]
+ for r in range(5):
+ updates = []
+ for i, f in enumerate(objs):
+ val, grad = f(cur_x)
+ updates.append(-1.0 * sizes[i] * grad)
+ cur_x += np.mean(updates)
+ vals = []
+ for i, f in enumerate(objs):
+ val, grad = f(cur_x)
+ vals.append(val)
+ return np.mean(vals)
+
+
+if __name__ == "__main__":
+ Fis = []
+ for a in [0.02, 0.1, 0.5, 2.5, 12.5]:
+        # bind the current value of `a` via a default argument; otherwise all
+        # lambdas would close over the same loop variable
+        Fis.append(lambda x, a=a: (a * x**2, 2 * a * x))
+ # without personalization
+ best = float("inf")
+ best_lr = None
+ for d in range(64):
+ lr = 0.001 + d * (0.625 - 0.001) / (64 - 1)
+ results = []
+ for i in range(32):
+ np.random.seed(i + 123)
+ init_x = np.random.uniform(-10.0, 10.0)
+ results.append(FL(init_x, Fis, lr))
+ print(np.mean(results), lr)
+ if best > np.mean(results):
+ best = np.mean(results)
+ best_lr = lr
+ print(best, best_lr)
+
+ # with personalization
+ best = float("inf")
+ best_lrs = None
+ for trial in range(64):
+ np.random.seed(trial + 123)
+ lrs = np.random.choice([0.001, 0.005, 0.025, 0.125, 0.625], 5)
+ results = []
+        for i in range(32):
+ np.random.seed(i + 123)
+ init_x = np.random.uniform(-10.0, 10.0)
+ results.append(FL(init_x, Fis, lrs))
+ if best > np.mean(results):
+ best = np.mean(results)
+ best_lrs = lrs
+ print(best, best_lrs)
diff --git a/doc/Makefile b/doc/Makefile
new file mode 100644
index 000000000..fcc194d6f
--- /dev/null
+++ b/doc/Makefile
@@ -0,0 +1,36 @@
+# Copyright 2017 The Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+#
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/doc/README.md b/doc/README.md
new file mode 100644
index 000000000..f39e2619b
--- /dev/null
+++ b/doc/README.md
@@ -0,0 +1,7 @@
+## FederatedScope Documentation
+Please run the following commands from this directory to compile the documentation. Note that FederatedScope must be installed first.
+
+```
+pip install -r requirements-doc.txt
+make html
+```
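+
+The generated HTML pages can then be found under `build/html`.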
diff --git a/doc/make.bat b/doc/make.bat
new file mode 100644
index 000000000..7e5fc245b
--- /dev/null
+++ b/doc/make.bat
@@ -0,0 +1,50 @@
+:: Copyright 2017 The Ray Authors.
+::
+:: Licensed under the Apache License, Version 2.0 (the "License");
+::
+:: you may not use this file except in compliance with the License.
+:: You may obtain a copy of the License at
+::
+:: https://www.apache.org/licenses/LICENSE-2.0
+::
+:: Unless required by applicable law or agreed to in writing, software
+:: distributed under the License is distributed on an "AS IS" BASIS,
+:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+:: See the License for the specific language governing permissions and
+:: limitations under the License.
+
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt
new file mode 100644
index 000000000..22397f2b1
--- /dev/null
+++ b/doc/requirements-doc.txt
@@ -0,0 +1,15 @@
+colorama
+click
+filelock
+flatbuffers
+funcsigs
+numpy
+opencv-python-headless
+pyarrow
+pyyaml
+recommonmark
+setproctitle
+sphinx
+sphinx-click
+sphinx_rtd_theme
+pandas
diff --git a/doc/source/attack.rst b/doc/source/attack.rst
new file mode 100644
index 000000000..730a5019c
--- /dev/null
+++ b/doc/source/attack.rst
@@ -0,0 +1,36 @@
+Attack Module References
+============================================================
+
+federatedscope.attack.privacy_attacks
+-------------------------------------------
+
+.. automodule:: federatedscope.attack.privacy_attacks
+ :members:
+
+
+federatedscope.attack.worker_as_attacker
+-------------------------------------------
+
+.. automodule:: federatedscope.attack.worker_as_attacker
+ :members:
+
+federatedscope.attack.auxiliary
+--------------------------------
+
+.. automodule:: federatedscope.attack.auxiliary
+ :members:
+
+
+
+federatedscope.attack.trainer
+---------------------------------
+
+.. automodule:: federatedscope.attack.trainer
+ :members:
diff --git a/doc/source/autotune.rst b/doc/source/autotune.rst
new file mode 100644
index 000000000..e8e1b153c
--- /dev/null
+++ b/doc/source/autotune.rst
@@ -0,0 +1,15 @@
+Auto-tuning Module References
+============================================================
+
+federatedscope.autotune.choice_types
+----------------------------------------
+
+.. automodule:: federatedscope.autotune.choice_types
+ :members:
+
+federatedscope.autotune.algos
+----------------------------------------
+
+.. automodule:: federatedscope.autotune.algos
+ :show-inheritance:
+ :members:
diff --git a/doc/source/conf.py b/doc/source/conf.py
new file mode 100644
index 000000000..5a2eaa8ee
--- /dev/null
+++ b/doc/source/conf.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+#
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+
+sys.path.insert(0, os.path.abspath('../../'))
+
+# -- Project information -----------------------------------------------------
+
+project = u'federatedscope'
+copyright = u'2022, The DAIL Team'
+author = u'The DAIL Team'
+
+# The full version, including alpha/beta/rc tags
+from federatedscope import __version__ as version
+
+release = version
+
+# -- General configuration ---------------------------------------------------
+# Explicitly specify the root .rst file, as different versions use
+# different default roots.
+master_doc = 'index'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.viewcode',
+ 'sphinx.ext.napoleon',
+ #'sphinx_click.ext',
+]
+
+# source_suffix = '.rst'
+#source_suffix = ['.rst', '.md']
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['build']
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+import sphinx_rtd_theme
+
+html_theme = 'sphinx_rtd_theme'
+html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
diff --git a/doc/source/core.rst b/doc/source/core.rst
new file mode 100644
index 000000000..11b9c87e5
--- /dev/null
+++ b/doc/source/core.rst
@@ -0,0 +1,33 @@
+Core Module References
+=======================
+
+federatedscope.core.configs
+----------------------------------------
+
+.. automodule:: federatedscope.core.configs
+ :members:
+
+
+federatedscope.core.monitors
+----------------------------------------
+
+.. automodule:: federatedscope.core.monitors
+ :members:
+
+federatedscope.core.fed_runner
+----------------------------------------
+
+.. automodule:: federatedscope.core.fed_runner
+ :members:
+
+federatedscope.core.worker
+----------------------------------------
+
+.. automodule:: federatedscope.core.worker
+ :members:
+
+federatedscope.core.trainers
+----------------------------------------
+
+.. automodule:: federatedscope.core.trainers
+ :members:
diff --git a/doc/source/cv.rst b/doc/source/cv.rst
new file mode 100644
index 000000000..ef3700bcc
--- /dev/null
+++ b/doc/source/cv.rst
@@ -0,0 +1,26 @@
+Federated Computer Vision Module References
+============================================================
+
+federatedscope.cv.dataset
+----------------------------------------
+
+.. automodule:: federatedscope.cv.dataset
+ :members:
+
+federatedscope.cv.dataloader
+----------------------------------------
+
+.. automodule:: federatedscope.cv.dataloader
+ :members:
+
+federatedscope.cv.model
+-----------------------
+
+.. automodule:: federatedscope.cv.model
+ :members:
+
+federatedscope.cv.trainer
+----------------------------------------
+
+.. automodule:: federatedscope.cv.trainer
+ :members:
diff --git a/doc/source/gfl.rst b/doc/source/gfl.rst
new file mode 100644
index 000000000..3e7d07e8a
--- /dev/null
+++ b/doc/source/gfl.rst
@@ -0,0 +1,26 @@
+Federated Graph Learning Module References
+============================================================
+
+federatedscope.gfl.dataset
+----------------------------------------
+
+.. automodule:: federatedscope.gfl.dataset
+ :members:
+
+federatedscope.gfl.dataloader
+----------------------------------------
+
+.. automodule:: federatedscope.gfl.dataloader
+ :members:
+
+federatedscope.gfl.model
+----------------------------------------
+
+.. automodule:: federatedscope.gfl.model
+ :members:
+
+federatedscope.gfl.trainer
+----------------------------------------
+
+.. automodule:: federatedscope.gfl.trainer
+ :members:
diff --git a/doc/source/index.rst b/doc/source/index.rst
new file mode 100644
index 000000000..92248840d
--- /dev/null
+++ b/doc/source/index.rst
@@ -0,0 +1,27 @@
+.. FederatedScope documentation master file, created by
+ sphinx-quickstart on Mon Jan 5th 14:01:03 2022
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Welcome to FederatedScope's documentation!
+============================================================
+
+*FederatedScope is a python package for federated learning research and applications.*
+
+.. toctree::
+ :maxdepth: 1
+ :caption: References
+
+ core.rst
+ cv.rst
+ nlp.rst
+ gfl.rst
+ autotune.rst
+ attack.rst
+ mf.rst
\ No newline at end of file
diff --git a/doc/source/mf.rst b/doc/source/mf.rst
new file mode 100644
index 000000000..94485783a
--- /dev/null
+++ b/doc/source/mf.rst
@@ -0,0 +1,26 @@
+Federated Matrix Factorization Module References
+============================================================
+
+federatedscope.mf.dataset
+----------------------------------------
+
+.. automodule:: federatedscope.mf.dataset
+ :members:
+
+federatedscope.mf.model
+-----------------------
+
+.. automodule:: federatedscope.mf.model
+ :members:
+
+federatedscope.mf.dataloader
+----------------------------------------
+
+.. automodule:: federatedscope.mf.dataloader
+ :members:
+
+federatedscope.mf.trainer
+----------------------------------------
+
+.. automodule:: federatedscope.mf.trainer
+ :members:
diff --git a/doc/source/nlp.rst b/doc/source/nlp.rst
new file mode 100644
index 000000000..26c87fdcb
--- /dev/null
+++ b/doc/source/nlp.rst
@@ -0,0 +1,26 @@
+Federated Natural Language Processing Module References
+============================================================
+
+federatedscope.nlp.dataset
+----------------------------------------
+
+.. automodule:: federatedscope.nlp.dataset
+ :members:
+
+federatedscope.nlp.dataloader
+----------------------------------------
+
+.. automodule:: federatedscope.nlp.dataloader
+ :members:
+
+federatedscope.nlp.model
+----------------------------------------
+
+.. automodule:: federatedscope.nlp.model
+ :members:
+
+federatedscope.nlp.trainer
+----------------------------------------
+
+.. automodule:: federatedscope.nlp.trainer
+ :members:
diff --git a/enviroment/docker_files/README.md b/enviroment/docker_files/README.md
new file mode 100644
index 000000000..05047418a
--- /dev/null
+++ b/enviroment/docker_files/README.md
@@ -0,0 +1,15 @@
+# Intro
+We provide several docker files for easy environment set-up.
+The federatedscope images include everything needed at runtime, with a customized miniconda and the required packages installed.
+
+The docker images are based on nvidia-docker.
+Please pre-install the NVIDIA drivers and `nvidia-docker2` on the host machine;
+see details at https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html
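+
+As a quick sanity check, a typical build-and-run flow might look like the following (the image tag here is hypothetical; older setups may need `--runtime=nvidia` instead of `--gpus all`):
+
+```bash
+docker build -f federatedscope-torch1.10.Dockerfile -t federatedscope:torch1.10 .
+docker run --gpus all --rm -it federatedscope:torch1.10 /bin/bash
+```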
+
+# Images
+- `federatedscope-torch1.10.Dockerfile`: based on cuda:11.3 and ubuntu20.04; installs the basic env with torch 1.10.1
+- `federatedscope-torch1.10-application.Dockerfile`: based on cuda:11.3 and ubuntu20.04; installs torch 1.10.1 and down-stream applications such as graph and nlp
+- `federatedscope-torch1.8.Dockerfile`: based on cuda:10.2 and ubuntu18.04; installs torch 1.8.0, used in the initial version development
+- `federatedscope-jupyterhub`: based on cuda:11.3 and ubuntu20.04; installs torch 1.10.1 and jupyter-singleuser for jupyterhub
\ No newline at end of file
diff --git a/enviroment/docker_files/federatedscope-jupyterhub/Dockerfile b/enviroment/docker_files/federatedscope-jupyterhub/Dockerfile
new file mode 100644
index 000000000..46a11b944
--- /dev/null
+++ b/enviroment/docker_files/federatedscope-jupyterhub/Dockerfile
@@ -0,0 +1,221 @@
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+# The federatedscope-jupyterhub image includes everything needed to run federatedscope,
+# with a customized miniconda, the required packages installed, and jupyter-singleuser running.
+
+ARG ROOT_CONTAINER=nvidia/cuda:11.3.1-runtime-ubuntu20.04
+
+FROM $ROOT_CONTAINER
+
+LABEL maintainer="FederatedScope"
+ARG NB_USER="jovyan"
+ARG NB_UID="1000"
+ARG NB_GID="100"
+
+# Fix: https://github.com/hadolint/hadolint/wiki/DL4006
+# Fix: https://github.com/koalaman/shellcheck/wiki/SC3014
+SHELL ["/bin/bash", "-o", "pipefail", "-c"]
+
+USER root
+
+# ***************************************
+# Install JupyterHub
+# ***************************************
+
+# Install all OS dependencies for notebook server that starts but lacks all
+# features (e.g., download as all possible file formats)
+ENV DEBIAN_FRONTEND noninteractive
+RUN apt-get update --yes && \
+ # - apt-get upgrade is run to patch known vulnerabilities in apt-get packages as
+ # the ubuntu base image is rebuilt too seldom sometimes (less than once a month)
+ apt-get upgrade --yes && \
+ apt-get install --yes --no-install-recommends \
+ ca-certificates \
+ fonts-liberation \
+ locales \
+ # - pandoc is used to convert notebooks to html files
+ # it's not present in arm64 ubuntu image, so we install it here
+ pandoc \
+ # - run-one - a wrapper script that runs no more
+ # than one unique instance of some command with a unique set of arguments,
+ # we use `run-one-constantly` to support `RESTARTABLE` option
+ run-one \
+ sudo \
+ # - tini is installed as a helpful container entrypoint that reaps zombie
+ # processes and such of the actual executable we want to start, see
+ # https://github.com/krallin/tini#why-tini for details.
+ tini \
+ wget && \
+ apt-get clean && rm -rf /var/lib/apt/lists/* && \
+ echo "en_US.UTF-8 UTF-8" > /etc/locale.gen && \
+ locale-gen
+
+# Configure environment
+ENV CONDA_DIR=/opt/conda \
+ SHELL=/bin/bash \
+ NB_USER="${NB_USER}" \
+ NB_UID=${NB_UID} \
+ NB_GID=${NB_GID} \
+ LC_ALL=en_US.UTF-8 \
+ LANG=en_US.UTF-8 \
+ LANGUAGE=en_US.UTF-8
+ENV PATH="${CONDA_DIR}/bin:${PATH}" \
+ HOME="/home/${NB_USER}"
+
+# Copy a script that we will use to correct permissions after running certain commands
+COPY fix-permissions /usr/local/bin/fix-permissions
+RUN chmod a+rx /usr/local/bin/fix-permissions
+
+# Enable prompt color in the skeleton .bashrc before creating the default NB_USER
+# hadolint ignore=SC2016
+RUN sed -i 's/^#force_color_prompt=yes/force_color_prompt=yes/' /etc/skel/.bashrc && \
+ # Add call to conda init script see https://stackoverflow.com/a/58081608/4413446
+ echo 'eval "$(command conda shell.bash hook 2> /dev/null)"' >> /etc/skel/.bashrc
+
+# Create NB_USER with name jovyan user with UID=1000 and in the 'users' group
+# and make sure these dirs are writable by the `users` group.
+RUN echo "auth requisite pam_deny.so" >> /etc/pam.d/su && \
+ sed -i.bak -e 's/^%admin/#%admin/' /etc/sudoers && \
+ sed -i.bak -e 's/^%sudo/#%sudo/' /etc/sudoers && \
+ useradd -l -m -s /bin/bash -N -u "${NB_UID}" "${NB_USER}" && \
+ mkdir -p "${CONDA_DIR}" && \
+ chown "${NB_USER}:${NB_GID}" "${CONDA_DIR}" && \
+ chmod g+w /etc/passwd && \
+ fix-permissions "${HOME}" && \
+ fix-permissions "${CONDA_DIR}"
+
+USER ${NB_UID}
+ARG PYTHON_VERSION=default
+
+# Setup work directory for backward-compatibility
+RUN mkdir "/home/${NB_USER}/work" && \
+ fix-permissions "/home/${NB_USER}"
+
+# Install conda as jovyan and check the sha256 sum provided on the download site
+WORKDIR /tmp
+
+# CONDA_MIRROR is a mirror prefix to speed up downloading
+# For example, people from mainland China could set it as
+# https://mirrors.tuna.tsinghua.edu.cn/github-release/conda-forge/miniforge/LatestRelease
+ARG CONDA_MIRROR=https://github.com/conda-forge/miniforge/releases/latest/download
+
+# ---- Miniforge installer ----
+# Check https://github.com/conda-forge/miniforge/releases
+# Package Manager and Python implementation to use (https://github.com/conda-forge/miniforge)
+# We're using Mambaforge installer, possible options:
+# - conda only: either Miniforge3 to use Python or Miniforge-pypy3 to use PyPy
+# - conda + mamba: either Mambaforge to use Python or Mambaforge-pypy3 to use PyPy
+# Installation: conda, mamba, pip
+RUN set -x && \
+ # Miniforge installer
+ miniforge_arch=$(uname -m) && \
+ miniforge_installer="Mambaforge-Linux-${miniforge_arch}.sh" && \
+ wget --quiet "${CONDA_MIRROR}/${miniforge_installer}" && \
+ /bin/bash "${miniforge_installer}" -f -b -p "${CONDA_DIR}" && \
+ rm "${miniforge_installer}" && \
+ # Conda configuration see https://conda.io/projects/conda/en/latest/configuration.html
+ conda config --system --set auto_update_conda false && \
+ conda config --system --set show_channel_urls true && \
+ if [[ "${PYTHON_VERSION}" != "default" ]]; then mamba install --quiet --yes python="${PYTHON_VERSION}"; fi && \
+ # Pin major.minor version of python
+ mamba list python | grep '^python ' | tr -s ' ' | cut -d ' ' -f 1,2 >> "${CONDA_DIR}/conda-meta/pinned" && \
+ # Using conda to update all packages: https://github.com/mamba-org/mamba/issues/1092
+ conda update --all --quiet --yes && \
+ conda clean --all -f -y && \
+ rm -rf "/home/${NB_USER}/.cache/yarn" && \
+ fix-permissions "${CONDA_DIR}" && \
+ fix-permissions "/home/${NB_USER}"
+
+# Using fixed version of mamba in arm, because the latest one has problems with arm under qemu
+# See: https://github.com/jupyter/docker-stacks/issues/1539
+RUN set -x && \
+ arch=$(uname -m) && \
+ if [ "${arch}" == "aarch64" ]; then \
+ mamba install --quiet --yes \
+ 'mamba<0.18' && \
+ mamba clean --all -f -y && \
+ fix-permissions "${CONDA_DIR}" && \
+ fix-permissions "/home/${NB_USER}"; \
+ fi;
+
+# Install Jupyter Notebook, Lab, and Hub
+# Generate a notebook server config
+# Cleanup temporary files
+# Correct permissions
+# Do all this in a single RUN command to avoid duplicating all of the
+# files across image layers when the permissions change
+RUN mamba install --quiet --yes \
+ 'notebook' \
+ 'jupyterhub' \
+ 'jupyterlab' && \
+ mamba clean --all -f -y && \
+ npm cache clean --force && \
+ jupyter notebook --generate-config && \
+ jupyter lab clean && \
+ rm -rf "/home/${NB_USER}/.cache/yarn" && \
+ fix-permissions "${CONDA_DIR}" && \
+ fix-permissions "/home/${NB_USER}"
+
+EXPOSE 8888
+
+# Configure container startup
+ENTRYPOINT ["tini", "-g", "--"]
+CMD ["start-notebook.sh"]
+
+# Copy local files as late as possible to avoid cache busting
+COPY start.sh start-notebook.sh start-singleuser.sh /usr/local/bin/
+# Currently need to have both jupyter_notebook_config and jupyter_server_config to support classic and lab
+COPY jupyter_server_config.py /etc/jupyter/
+
+# Fix permissions on /etc/jupyter as root
+USER root
+
+# Legacy for Jupyter Notebook Server, see: [#1205](https://github.com/jupyter/docker-stacks/issues/1205)
+RUN sed -re "s/c.ServerApp/c.NotebookApp/g" \
+ /etc/jupyter/jupyter_server_config.py > /etc/jupyter/jupyter_notebook_config.py && \
+ fix-permissions /etc/jupyter/
+
+# HEALTHCHECK documentation: https://docs.docker.com/engine/reference/builder/#healthcheck
+# This healthcheck works well for `lab`, `notebook`, `nbclassic`, `server` and `retro` jupyter commands
+# https://github.com/jupyter/docker-stacks/issues/915#issuecomment-1068528799
+HEALTHCHECK --interval=15s --timeout=3s --start-period=5s --retries=3 \
+ CMD wget -O- --no-verbose --tries=1 http://localhost:8888/api || exit 1
+
+# Switch back to jovyan to avoid accidental container runs as root
+USER ${NB_UID}
+
+# ***************************************
+# Install FederatedScope dependencies
+# ***************************************
+
+WORKDIR "${HOME}"
+
+USER root
+# change bash as default
+SHELL ["/bin/bash", "-c"]
+# shanghai zoneinfo
+ENV TZ=Asia/Shanghai
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+# install packages required by federatedscope
+RUN conda update -y conda \
+ && conda config --add channels conda-forge
+# basic machine learning env
+RUN conda install -y numpy=1.21.2 scikit-learn=1.0.2 scipy=1.7.3 pandas=1.4.1 -c scikit-learn \
+ && conda clean -a -y
+# basic torch env
+RUN conda install -y pytorch=1.10.1 torchvision=0.11.2 torchaudio=0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge \
+ && conda install -y torchtext -c pytorch \
+ && conda clean -a -y
+# gfl
+RUN conda install -y pyg=2.0.4 -c pyg \
+ && conda install -y rdkit=2021.09.4=py39hccf6a74_0 -c conda-forge \
+ && conda install -y nltk \
+ && conda clean -a -y
+# communications and auxiliaries
+RUN conda install -y wandb -c conda-forge \
+ && pip install grpcio grpcio-tools protobuf==3.19.4 setuptools==61.2.0 \
+ && conda clean -a -y
+
+USER ${NB_UID}
\ No newline at end of file
diff --git a/enviroment/docker_files/federatedscope-jupyterhub/fix-permissions b/enviroment/docker_files/federatedscope-jupyterhub/fix-permissions
new file mode 100644
index 000000000..5e6425da4
--- /dev/null
+++ b/enviroment/docker_files/federatedscope-jupyterhub/fix-permissions
@@ -0,0 +1,35 @@
+#!/bin/bash
+# set permissions on a directory
+# after any installation, if a directory needs to be (human) user-writable,
+# run this script on it.
+# It will make everything in the directory owned by the group ${NB_GID}
+# and writable by that group.
+# Deployments that want to set a specific user id can preserve permissions
+# by adding the `--group-add users` line to `docker run`.
+
+# uses find to avoid touching files that already have the right permissions,
+# which would cause massive image explosion
+
+# right permissions are:
+# group=${NB_GID}
+# AND permissions include group rwX (directory-execute)
+# AND directories have setuid,setgid bits set
+
+set -e
+
+for d in "$@"; do
+ find "${d}" \
+ ! \( \
+ -group "${NB_GID}" \
+ -a -perm -g+rwX \
+ \) \
+ -exec chgrp "${NB_GID}" {} \; \
+ -exec chmod g+rwX {} \;
+ # setuid, setgid *on directories only*
+ find "${d}" \
+ \( \
+ -type d \
+ -a ! -perm -6000 \
+ \) \
+ -exec chmod +6000 {} \;
+done
diff --git a/enviroment/docker_files/federatedscope-jupyterhub/jupyter_server_config.py b/enviroment/docker_files/federatedscope-jupyterhub/jupyter_server_config.py
new file mode 100644
index 000000000..52cdf44be
--- /dev/null
+++ b/enviroment/docker_files/federatedscope-jupyterhub/jupyter_server_config.py
@@ -0,0 +1,56 @@
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+# mypy: ignore-errors
+import os
+import stat
+import subprocess
+
+from jupyter_core.paths import jupyter_data_dir
+
+c = get_config() # noqa: F821
+c.ServerApp.ip = "0.0.0.0"
+c.ServerApp.port = 8888
+c.ServerApp.open_browser = False
+
+# https://github.com/jupyter/notebook/issues/3130
+c.FileContentsManager.delete_to_trash = False
+
+# Generate a self-signed certificate
+OPENSSL_CONFIG = """\
+[req]
+distinguished_name = req_distinguished_name
+[req_distinguished_name]
+"""
+if "GEN_CERT" in os.environ:
+ dir_name = jupyter_data_dir()
+ pem_file = os.path.join(dir_name, "notebook.pem")
+ os.makedirs(dir_name, exist_ok=True)
+
+ # Generate an openssl.cnf file to set the distinguished name
+ cnf_file = os.path.join(os.getenv("CONDA_DIR", "/usr/lib"), "ssl",
+ "openssl.cnf")
+ if not os.path.isfile(cnf_file):
+ with open(cnf_file, "w") as fh:
+ fh.write(OPENSSL_CONFIG)
+
+ # Generate a certificate if one doesn't exist on disk
+ subprocess.check_call([
+ "openssl",
+ "req",
+ "-new",
+ "-newkey=rsa:2048",
+ "-days=365",
+ "-nodes",
+ "-x509",
+ "-subj=/C=XX/ST=XX/L=XX/O=generated/CN=generated",
+ f"-keyout={pem_file}",
+ f"-out={pem_file}",
+ ])
+ # Restrict access to the file
+ os.chmod(pem_file, stat.S_IRUSR | stat.S_IWUSR)
+ c.ServerApp.certfile = pem_file
+
+# Change default umask for all subprocesses of the notebook server if set in
+# the environment
+if "NB_UMASK" in os.environ:
+ os.umask(int(os.environ["NB_UMASK"], 8))
diff --git a/enviroment/docker_files/federatedscope-jupyterhub/start-notebook.sh b/enviroment/docker_files/federatedscope-jupyterhub/start-notebook.sh
new file mode 100644
index 000000000..ce47768d3
--- /dev/null
+++ b/enviroment/docker_files/federatedscope-jupyterhub/start-notebook.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+set -e
+
+# The Jupyter command to launch
+# JupyterLab by default
+DOCKER_STACKS_JUPYTER_CMD="${DOCKER_STACKS_JUPYTER_CMD:=lab}"
+
+if [[ -n "${JUPYTERHUB_API_TOKEN}" ]]; then
+ echo "WARNING: using start-singleuser.sh instead of start-notebook.sh to start a server associated with JupyterHub."
+ exec /usr/local/bin/start-singleuser.sh "$@"
+fi
+
+wrapper=""
+if [[ "${RESTARTABLE}" == "yes" ]]; then
+ wrapper="run-one-constantly"
+fi
+
+if [[ -v JUPYTER_ENABLE_LAB ]]; then
+ echo "WARNING: JUPYTER_ENABLE_LAB is ignored, use DOCKER_STACKS_JUPYTER_CMD if you want to change the command used to start the server"
+fi
+
+# shellcheck disable=SC1091,SC2086
+exec /usr/local/bin/start.sh ${wrapper} jupyter ${DOCKER_STACKS_JUPYTER_CMD} ${NOTEBOOK_ARGS} "$@"
diff --git a/enviroment/docker_files/federatedscope-jupyterhub/start-singleuser.sh b/enviroment/docker_files/federatedscope-jupyterhub/start-singleuser.sh
new file mode 100644
index 000000000..a2166e2c6
--- /dev/null
+++ b/enviroment/docker_files/federatedscope-jupyterhub/start-singleuser.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+set -e
+
+# set default ip to 0.0.0.0
+if [[ "${NOTEBOOK_ARGS} $*" != *"--ip="* ]]; then
+ NOTEBOOK_ARGS="--ip=0.0.0.0 ${NOTEBOOK_ARGS}"
+fi
+
+# shellcheck disable=SC1091,SC2086
+. /usr/local/bin/start.sh jupyterhub-singleuser ${NOTEBOOK_ARGS} "$@"
diff --git a/enviroment/docker_files/federatedscope-jupyterhub/start.sh b/enviroment/docker_files/federatedscope-jupyterhub/start.sh
new file mode 100644
index 000000000..7c5859ee7
--- /dev/null
+++ b/enviroment/docker_files/federatedscope-jupyterhub/start.sh
@@ -0,0 +1,262 @@
+#!/bin/bash
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+set -e
+
+# The _log function is used for everything this script wants to log. It will
+# always log errors and warnings, but can be silenced for other messages
+# by setting JUPYTER_DOCKER_STACKS_QUIET environment variable.
+_log () {
+ if [[ "$*" == "ERROR:"* ]] || [[ "$*" == "WARNING:"* ]] || [[ "${JUPYTER_DOCKER_STACKS_QUIET}" == "" ]]; then
+ echo "$@"
+ fi
+}
+_log "Entered start.sh with args:" "$@"
+
+# The run-hooks function looks for .sh scripts to source and executable files to
+# run within a passed directory.
+run-hooks () {
+ if [[ ! -d "${1}" ]] ; then
+ return
+ fi
+ _log "${0}: running hooks in ${1} as uid / gid: $(id -u) / $(id -g)"
+ for f in "${1}/"*; do
+ case "${f}" in
+ *.sh)
+ _log "${0}: running script ${f}"
+ # shellcheck disable=SC1090
+ source "${f}"
+ ;;
+ *)
+ if [[ -x "${f}" ]] ; then
+ _log "${0}: running executable ${f}"
+ "${f}"
+ else
+ _log "${0}: ignoring non-executable ${f}"
+ fi
+ ;;
+ esac
+ done
+ _log "${0}: done running hooks in ${1}"
+}
+
+# A helper function to unset env vars listed in the value of the env var
+# JUPYTER_ENV_VARS_TO_UNSET.
+unset_explicit_env_vars () {
+ if [ -n "${JUPYTER_ENV_VARS_TO_UNSET}" ]; then
+ for env_var_to_unset in $(echo "${JUPYTER_ENV_VARS_TO_UNSET}" | tr ',' ' '); do
+ echo "Unset ${env_var_to_unset} due to JUPYTER_ENV_VARS_TO_UNSET"
+ unset "${env_var_to_unset}"
+ done
+ unset JUPYTER_ENV_VARS_TO_UNSET
+ fi
+}
+
+
+# Default to starting bash if no command was specified
+if [ $# -eq 0 ]; then
+ cmd=( "bash" )
+else
+ cmd=( "$@" )
+fi
+
+# NOTE: This hook will run as the user the container was started with!
+run-hooks /usr/local/bin/start-notebook.d
+
+# If the container started as the root user, then we have permission to refit
+# the jovyan user, and ensure file permissions, grant sudo rights, and such
+# things before we run the command passed to start.sh as the desired user
+# (NB_USER).
+#
+if [ "$(id -u)" == 0 ] ; then
+ # Environment variables:
+ # - NB_USER: the desired username and associated home folder
+ # - NB_UID: the desired user id
+ # - NB_GID: a group id we want our user to belong to
+ # - NB_GROUP: a group name we want for the group
+ # - GRANT_SUDO: a boolean ("1" or "yes") to grant the user sudo rights
+ # - CHOWN_HOME: a boolean ("1" or "yes") to chown the user's home folder
+ # - CHOWN_EXTRA: a comma separated list of paths to chown
+ # - CHOWN_HOME_OPTS / CHOWN_EXTRA_OPTS: arguments to the chown commands
+
+    # Refit the jovyan user to the desired user (NB_USER)
+ if id jovyan &> /dev/null ; then
+ if ! usermod --home "/home/${NB_USER}" --login "${NB_USER}" jovyan 2>&1 | grep "no changes" > /dev/null; then
+ _log "Updated the jovyan user:"
+ _log "- username: jovyan -> ${NB_USER}"
+ _log "- home dir: /home/jovyan -> /home/${NB_USER}"
+ fi
+ elif ! id -u "${NB_USER}" &> /dev/null; then
+ _log "ERROR: Neither the jovyan user or '${NB_USER}' exists. This could be the result of stopping and starting, the container with a different NB_USER environment variable."
+ exit 1
+ fi
+ # Ensure the desired user (NB_USER) gets its desired user id (NB_UID) and is
+ # a member of the desired group (NB_GROUP, NB_GID)
+ if [ "${NB_UID}" != "$(id -u "${NB_USER}")" ] || [ "${NB_GID}" != "$(id -g "${NB_USER}")" ]; then
+ _log "Update ${NB_USER}'s UID:GID to ${NB_UID}:${NB_GID}"
+ # Ensure the desired group's existence
+ if [ "${NB_GID}" != "$(id -g "${NB_USER}")" ]; then
+ groupadd --force --gid "${NB_GID}" --non-unique "${NB_GROUP:-${NB_USER}}"
+ fi
+ # Recreate the desired user as we want it
+ userdel "${NB_USER}"
+ useradd --home "/home/${NB_USER}" --uid "${NB_UID}" --gid "${NB_GID}" --groups 100 --no-log-init "${NB_USER}"
+ fi
+
+    # Move or symlink the jovyan home directory to the desired user's home
+ # directory if it doesn't already exist, and update the current working
+ # directory to the new location if needed.
+ if [[ "${NB_USER}" != "jovyan" ]]; then
+ if [[ ! -e "/home/${NB_USER}" ]]; then
+ _log "Attempting to copy /home/jovyan to /home/${NB_USER}..."
+ mkdir "/home/${NB_USER}"
+ if cp -a /home/jovyan/. "/home/${NB_USER}/"; then
+ _log "Success!"
+ else
+ _log "Failed to copy data from /home/jovyan to /home/${NB_USER}!"
+ _log "Attempting to symlink /home/jovyan to /home/${NB_USER}..."
+ if ln -s /home/jovyan "/home/${NB_USER}"; then
+ _log "Success creating symlink!"
+ else
+ _log "ERROR: Failed copy data from /home/jovyan to /home/${NB_USER} or to create symlink!"
+ exit 1
+ fi
+ fi
+ fi
+ # Ensure the current working directory is updated to the new path
+ if [[ "${PWD}/" == "/home/jovyan/"* ]]; then
+ new_wd="/home/${NB_USER}/${PWD:13}"
+ _log "Changing working directory to ${new_wd}"
+ cd "${new_wd}"
+ fi
+ fi
+
+    # Optionally ensure the desired user gets filesystem ownership of its home
+    # folder and/or additional folders
+ if [[ "${CHOWN_HOME}" == "1" || "${CHOWN_HOME}" == "yes" ]]; then
+ _log "Ensuring /home/${NB_USER} is owned by ${NB_UID}:${NB_GID} ${CHOWN_HOME_OPTS:+(chown options: ${CHOWN_HOME_OPTS})}"
+ # shellcheck disable=SC2086
+ chown ${CHOWN_HOME_OPTS} "${NB_UID}:${NB_GID}" "/home/${NB_USER}"
+ fi
+ if [ -n "${CHOWN_EXTRA}" ]; then
+ for extra_dir in $(echo "${CHOWN_EXTRA}" | tr ',' ' '); do
+ _log "Ensuring ${extra_dir} is owned by ${NB_UID}:${NB_GID} ${CHOWN_EXTRA_OPTS:+(chown options: ${CHOWN_EXTRA_OPTS})}"
+ # shellcheck disable=SC2086
+ chown ${CHOWN_EXTRA_OPTS} "${NB_UID}:${NB_GID}" "${extra_dir}"
+ done
+ fi
+
+    # Update environment variables that may have become outdated since the image build
+ export XDG_CACHE_HOME="/home/${NB_USER}/.cache"
+
+ # Prepend ${CONDA_DIR}/bin to sudo secure_path
+ sed -r "s#Defaults\s+secure_path\s*=\s*\"?([^\"]+)\"?#Defaults secure_path=\"${CONDA_DIR}/bin:\1\"#" /etc/sudoers | grep secure_path > /etc/sudoers.d/path
+
+ # Optionally grant passwordless sudo rights for the desired user
+ if [[ "$GRANT_SUDO" == "1" || "$GRANT_SUDO" == "yes" ]]; then
+ _log "Granting ${NB_USER} passwordless sudo rights!"
+ echo "${NB_USER} ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers.d/added-by-start-script
+ fi
+
+ # NOTE: This hook is run as the root user!
+ run-hooks /usr/local/bin/before-notebook.d
+
+ unset_explicit_env_vars
+ _log "Running as ${NB_USER}:" "${cmd[@]}"
+ exec sudo --preserve-env --set-home --user "${NB_USER}" \
+ PATH="${PATH}" \
+ PYTHONPATH="${PYTHONPATH:-}" \
+ "${cmd[@]}"
+ # Notes on how we ensure that the environment that this container is started
+ # with is preserved (except vars listed in JUPYTER_ENV_VARS_TO_UNSET) when
+ # we transition from running as root to running as NB_USER.
+ #
+ # - We use `sudo` to execute the command as NB_USER. What then
+ # happens to the environment will be determined by configuration in
+ # /etc/sudoers and /etc/sudoers.d/* as well as flags we pass to the sudo
+ # command. The behavior can be inspected with `sudo -V` run as root.
+ #
+ # ref: `man sudo` https://linux.die.net/man/8/sudo
+ # ref: `man sudoers` https://www.sudo.ws/man/1.8.15/sudoers.man.html
+ #
+ # - We use the `--preserve-env` flag to pass through most environment
+ # variables, but understand that exceptions are caused by the sudoers
+ # configuration: `env_delete` and `env_check`.
+ #
+ # - We use the `--set-home` flag to set the HOME variable appropriately.
+ #
+ # - To reduce the default list of variables deleted by sudo, we could have
+ # used `env_delete` from /etc/sudoers. It has higher priority than the
+ # `--preserve-env` flag and the `env_keep` configuration.
+ #
+ # - We preserve PATH and PYTHONPATH explicitly. Note however that sudo
+ # resolves `${cmd[@]}` using the "secure_path" variable we modified
+ # above in /etc/sudoers.d/path. Thus PATH is irrelevant to how the above
+ # sudo command resolves the path of `${cmd[@]}`. The PATH will be relevant
+ # for resolving paths of any subprocesses spawned by `${cmd[@]}`.
+
+# The container didn't start as the root user, so we will have to act as the
+# user we started as.
+else
+ # Warn about misconfiguration of: granting sudo rights
+ if [[ "${GRANT_SUDO}" == "1" || "${GRANT_SUDO}" == "yes" ]]; then
+ _log "WARNING: container must be started as root to grant sudo permissions!"
+ fi
+
+ JOVYAN_UID="$(id -u jovyan 2>/dev/null)" # The default UID for the jovyan user
+ JOVYAN_GID="$(id -g jovyan 2>/dev/null)" # The default GID for the jovyan user
+
+    # Attempt to ensure the UID we currently run as has a named entry in
+    # /etc/passwd, since some software crashes on the hard assumption that
+    # such an entry exists. Write access to /etc/passwd was granted to the
+    # root group in the Dockerfile during build.
+ #
+ # ref: https://github.com/jupyter/docker-stacks/issues/552
+ if ! whoami &> /dev/null; then
+ _log "There is no entry in /etc/passwd for our UID=$(id -u). Attempting to fix..."
+ if [[ -w /etc/passwd ]]; then
+ _log "Renaming old jovyan user to nayvoj ($(id -u jovyan):$(id -g jovyan))"
+
+ # We cannot use "sed --in-place" since sed tries to create a temp file in
+ # /etc/ and we may not have write access. Apply sed on our own temp file:
+ sed --expression="s/^jovyan:/nayvoj:/" /etc/passwd > /tmp/passwd
+ echo "${NB_USER}:x:$(id -u):$(id -g):,,,:/home/jovyan:/bin/bash" >> /tmp/passwd
+ cat /tmp/passwd > /etc/passwd
+ rm /tmp/passwd
+
+ _log "Added new ${NB_USER} user ($(id -u):$(id -g)). Fixed UID!"
+
+ if [[ "${NB_USER}" != "jovyan" ]]; then
+ _log "WARNING: user is ${NB_USER} but home is /home/jovyan. You must run as root to rename the home directory!"
+ fi
+ else
+ _log "WARNING: unable to fix missing /etc/passwd entry because we don't have write permission. Try setting gid=0 with \"--user=$(id -u):0\"."
+ fi
+ fi
+
+ # Warn about misconfiguration of: desired username, user id, or group id.
+ # A misconfiguration occurs when the user modifies the default values of
+ # NB_USER, NB_UID, or NB_GID, but we cannot update those values because we
+ # are not root.
+ if [[ "${NB_USER}" != "jovyan" && "${NB_USER}" != "$(id -un)" ]]; then
+ _log "WARNING: container must be started as root to change the desired user's name with NB_USER=\"${NB_USER}\"!"
+ fi
+ if [[ "${NB_UID}" != "${JOVYAN_UID}" && "${NB_UID}" != "$(id -u)" ]]; then
+ _log "WARNING: container must be started as root to change the desired user's id with NB_UID=\"${NB_UID}\"!"
+ fi
+ if [[ "${NB_GID}" != "${JOVYAN_GID}" && "${NB_GID}" != "$(id -g)" ]]; then
+ _log "WARNING: container must be started as root to change the desired user's group id with NB_GID=\"${NB_GID}\"!"
+ fi
+
+ # Warn if the user isn't able to write files to ${HOME}
+ if [[ ! -w /home/jovyan ]]; then
+ _log "WARNING: no write access to /home/jovyan. Try starting the container with group 'users' (100), e.g. using \"--group-add=users\"."
+ fi
+
+ # NOTE: This hook is run as the user we started the container as!
+ run-hooks /usr/local/bin/before-notebook.d
+ unset_explicit_env_vars
+ _log "Executing the command:" "${cmd[@]}"
+ exec "${cmd[@]}"
+fi
diff --git a/enviroment/docker_files/federatedscope-torch1.10-application.Dockerfile b/enviroment/docker_files/federatedscope-torch1.10-application.Dockerfile
new file mode 100644
index 000000000..8b89706d8
--- /dev/null
+++ b/enviroment/docker_files/federatedscope-torch1.10-application.Dockerfile
@@ -0,0 +1,63 @@
+# The federatedscope image includes the full runtime environment of
+# federatedscope, with a customized miniconda and the required packages
+# installed.
+
+# based on the nvidia-docker
+# NOTE: please pre-install the NVIDIA drivers and `nvidia-docker2` in the host machine,
+# see details in https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html
+FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04
+
+# change bash as default
+SHELL ["/bin/bash", "-c"]
+
+# shanghai zoneinfo
+ENV TZ=Asia/Shanghai
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+# install basic tools
+RUN apt-get -y update \
+ && apt-get -y install curl git gcc g++ make openssl libssl-dev libbz2-dev libreadline-dev libsqlite3-dev python-dev libmysqlclient-dev
+
+# install miniconda, in batch (silent) mode, does not edit PATH or .bashrc or .bash_profile
+RUN apt-get update -y \
+ && apt-get install -y wget
+RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
+ && bash Miniconda3-latest-Linux-x86_64.sh -b \
+ && rm Miniconda3-latest-Linux-x86_64.sh
+
+ENV PATH=/root/miniconda3/bin:${PATH}
+RUN source activate
+
+# install packages required by federatedscope
+RUN conda update -y conda \
+ && conda config --add channels conda-forge
+
+# basic machine learning env
+RUN conda install -y numpy=1.21.2 scikit-learn=1.0.2 scipy=1.7.3 pandas=1.4.1 -c scikit-learn \
+ && conda clean -a -y
+
+# basic torch env
+RUN conda install -y pytorch=1.10.1 torchvision=0.11.2 torchaudio=0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge \
+ && conda clean -a -y
+
+# torch helper package
+RUN conda install -y fvcore iopath -c fvcore -c iopath -c conda-forge \
+ && conda clean -a -y
+
+# auxiliaries (communications, monitoring, etc.)
+RUN conda install -y wandb tensorboard tensorboardX pympler -c conda-forge \
+ && pip install grpcio grpcio-tools protobuf==3.19.4 setuptools==61.2.0 \
+ && conda clean -a -y
+
+# for graph
+RUN conda install -y pyg==2.0.4 -c pyg \
+ && conda install -y rdkit=2021.09.4=py39hccf6a74_0 -c conda-forge \
+ && conda install -y nltk \
+ && conda clean -a -y
+
+# for speech and nlp
+RUN conda install -y sentencepiece textgrid typeguard -c conda-forge \
+ && conda install -y transformers==4.16.2 tokenizers==0.10.3 datasets -c huggingface -c conda-forge \
+ && conda install -y torchtext -c pytorch \
+ && conda clean -a -y
+
+
diff --git a/enviroment/docker_files/federatedscope-torch1.10.Dockerfile b/enviroment/docker_files/federatedscope-torch1.10.Dockerfile
new file mode 100644
index 000000000..530f9d4aa
--- /dev/null
+++ b/enviroment/docker_files/federatedscope-torch1.10.Dockerfile
@@ -0,0 +1,49 @@
+# The federatedscope image includes the full runtime environment of
+# federatedscope, with a customized miniconda and the required packages
+# installed.
+
+# based on the nvidia-docker
+# NOTE: please pre-install the NVIDIA drivers and `nvidia-docker2` in the host machine,
+# see details in https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html
+FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04
+
+# change bash as default
+SHELL ["/bin/bash", "-c"]
+
+# shanghai zoneinfo
+ENV TZ=Asia/Shanghai
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+# install basic tools
+RUN apt-get -y update \
+ && apt-get -y install curl git gcc g++ make openssl libssl-dev libbz2-dev libreadline-dev libsqlite3-dev python-dev libmysqlclient-dev
+
+# install miniconda, in batch (silent) mode, does not edit PATH or .bashrc or .bash_profile
+RUN apt-get update -y \
+ && apt-get install -y wget
+RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
+ && bash Miniconda3-latest-Linux-x86_64.sh -b \
+ && rm Miniconda3-latest-Linux-x86_64.sh
+
+ENV PATH=/root/miniconda3/bin:${PATH}
+RUN source activate
+
+# install packages required by federatedscope
+RUN conda update -y conda \
+ && conda config --add channels conda-forge
+
+# basic machine learning env
+RUN conda install -y numpy=1.21.2 scikit-learn=1.0.2 scipy=1.7.3 pandas=1.4.1 -c scikit-learn \
+ && conda clean -a -y
+
+# basic torch env
+RUN conda install -y pytorch=1.10.1 torchvision=0.11.2 torchaudio=0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge \
+ && conda clean -a -y
+
+# torch helper package
+RUN conda install -y fvcore iopath -c fvcore -c iopath -c conda-forge \
+ && conda clean -a -y
+
+# auxiliaries (communications, monitoring, etc.)
+RUN conda install -y wandb tensorboard tensorboardX pympler -c conda-forge \
+ && pip install grpcio grpcio-tools protobuf==3.19.4 setuptools==61.2.0 \
+ && conda clean -a -y
\ No newline at end of file
diff --git a/enviroment/docker_files/federatedscope-torch1.8-application.Dockerfile b/enviroment/docker_files/federatedscope-torch1.8-application.Dockerfile
new file mode 100644
index 000000000..f1ca66518
--- /dev/null
+++ b/enviroment/docker_files/federatedscope-torch1.8-application.Dockerfile
@@ -0,0 +1,61 @@
+# The federatedscope image includes the full runtime environment of
+# federatedscope, with a customized miniconda and the required packages
+# installed.
+
+# based on the nvidia-docker
+# NOTE: please pre-install the NVIDIA drivers and `nvidia-docker2` in the host machine,
+# see details in https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html
+FROM nvidia/cuda:10.2-runtime-ubuntu18.04
+
+# change bash as default
+SHELL ["/bin/bash", "-c"]
+
+# shanghai zoneinfo
+ENV TZ=Asia/Shanghai
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+# install basic tools
+RUN apt-get -y update \
+ && apt-get -y install curl git gcc g++ make openssl libssl-dev libbz2-dev libreadline-dev libsqlite3-dev python-dev libmysqlclient-dev
+
+# install miniconda, in batch (silent) mode, does not edit PATH or .bashrc or .bash_profile
+RUN apt-get update -y \
+ && apt-get install -y wget
+RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
+ && bash Miniconda3-latest-Linux-x86_64.sh -b \
+ && rm Miniconda3-latest-Linux-x86_64.sh
+
+ENV PATH=/root/miniconda3/bin:${PATH}
+RUN source activate
+
+# install packages required by federatedscope
+RUN conda update -y conda \
+ && conda config --add channels conda-forge
+
+# basic machine learning env
+RUN conda install -y numpy=1.21.2 scikit-learn=1.0.2 scipy=1.7.3 pandas=1.4.1 -c scikit-learn \
+ && conda clean -a -y
+
+# basic torch env
+RUN conda install -y pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch \
+ && conda clean -a -y
+
+# torch helper package
+RUN conda install -y fvcore iopath -c fvcore -c iopath -c conda-forge \
+ && conda clean -a -y
+
+# for graph
+RUN conda install -y pyg==2.0.1 -c pyg \
+ && conda install -y rdkit=2021.09.4 -c conda-forge \
+ && conda install -y nltk \
+ && conda clean -a -y
+
+# for speech and nlp
+RUN conda install -y sentencepiece textgrid typeguard -c conda-forge \
+ && conda install -y transformers==4.16.2 tokenizers==0.10.3 datasets -c huggingface -c conda-forge \
+ && conda install -y torchtext -c pytorch \
+ && conda clean -a -y
+
+# auxiliaries (communications, monitoring, etc.)
+RUN conda install -y wandb tensorboard tensorboardX pympler -c conda-forge \
+ && pip install grpcio grpcio-tools protobuf==3.19.4 setuptools==61.2.0 \
+ && conda clean -a -y
diff --git a/enviroment/requirements-torch1.10-application.txt b/enviroment/requirements-torch1.10-application.txt
new file mode 100644
index 000000000..3ad85e143
--- /dev/null
+++ b/enviroment/requirements-torch1.10-application.txt
@@ -0,0 +1,28 @@
+numpy==1.21.2
+scikit-learn==1.0.2
+scipy==1.7.3
+pandas==1.4.1
+wandb
+tensorboard
+tensorboardX
+grpcio
+grpcio-tools
+protobuf==3.19.4
+setuptools==61.2.0
+pyg==2.0.4
+rdkit==2021.09.4
+sentencepiece
+textgrid
+typeguard
+nltk
+transformers==4.16.2
+tokenizers==0.10.3
+torchtext
+datasets
+fvcore
+pympler
+iopath
+opencv-python
+matplotlib
+
diff --git a/enviroment/requirements-torch1.10.txt b/enviroment/requirements-torch1.10.txt
new file mode 100644
index 000000000..edddbb70d
--- /dev/null
+++ b/enviroment/requirements-torch1.10.txt
@@ -0,0 +1,21 @@
+numpy==1.21.2
+scikit-learn==1.0.2
+scipy==1.7.3
+pandas==1.4.1
+pytorch==1.10.1
+torchvision==0.11.2
+torchaudio==0.10.1
+cudatoolkit==11.3.1
+wandb
+tensorboard
+tensorboardX
+grpcio
+grpcio-tools
+protobuf==3.19.4
+setuptools==61.2.0
+fvcore
+pympler
+iopath
+opencv-python
+matplotlib
\ No newline at end of file
diff --git a/enviroment/requirements-torch1.8-application.txt b/enviroment/requirements-torch1.8-application.txt
new file mode 100644
index 000000000..f248f5251
--- /dev/null
+++ b/enviroment/requirements-torch1.8-application.txt
@@ -0,0 +1,33 @@
+numpy==1.19.5
+scikit-learn==1.0
+scipy==1.6.0
+pandas==1.2.1
+pytorch==1.8.0
+torchvision==0.9.0
+torchaudio==0.8.0
+cudatoolkit==10.2.89
+wandb
+tensorboard
+tensorboardX
+grpcio
+grpcio-tools
+protobuf==3.19.1
+setuptools==58.0.4
+pyg==2.0.1
+rdkit==2021.09.4
+sentencepiece
+textgrid
+typeguard
+nltk
+transformers==4.16.2
+tokenizers==0.10.3
+torchtext
+datasets
+fvcore
+pympler
+iopath
+opencv-python
+matplotlib
+
+
diff --git a/federatedscope/__init__.py b/federatedscope/__init__.py
new file mode 100644
index 000000000..7a805d435
--- /dev/null
+++ b/federatedscope/__init__.py
@@ -0,0 +1,19 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+__version__ = '0.1.0'
+
+
+def _setup_logger():
+ import logging
+
+ logging_fmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+ logger = logging.getLogger("federatedscope")
+ handler = logging.StreamHandler()
+ handler.setFormatter(logging.Formatter(logging_fmt))
+ logger.addHandler(handler)
+ logger.propagate = False
+
+
+_setup_logger()
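+
+
+# Minimal usage sketch: downstream modules retrieve the logger configured
+# above by name, e.g.
+#   import logging
+#   logger = logging.getLogger("federatedscope")
+#   logger.info("routed through the handler set in _setup_logger()")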
diff --git a/federatedscope/attack/__init__.py b/federatedscope/attack/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/attack/auxiliary/MIA_get_target_data.py b/federatedscope/attack/auxiliary/MIA_get_target_data.py
new file mode 100644
index 000000000..d8c9945d0
--- /dev/null
+++ b/federatedscope/attack/auxiliary/MIA_get_target_data.py
@@ -0,0 +1,32 @@
+import torch
+from federatedscope.attack.auxiliary.utils import get_data_info
+
+
+def get_target_data(dataset_name, pth=None):
+ '''
+
+ Args:
+ dataset_name (str): the dataset name
+ pth (str): the path storing the target data
+
+    Returns:
+        list: [syn_data, syn_label], a showcase target dataset
+ '''
+ # JUST FOR SHOWCASE
+ if pth is not None:
+ pass
+ else:
+ # generate the synthetic data
+ if dataset_name == 'femnist':
+ data_feature_dim, num_class, is_one_hot_label = get_data_info(
+ dataset_name)
+
+ # generate random data
+ num_syn_data = 20
+ data_dim = [num_syn_data]
+ data_dim.extend(data_feature_dim)
+ syn_data = torch.randn(data_dim)
+ syn_label = torch.randint(low=0,
+ high=num_class,
+ size=(num_syn_data, ))
+ return [syn_data, syn_label]
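+
+
+# Usage sketch (synthetic branch): syn_data, syn_label = get_target_data('femnist')
+# yields 20 random samples of shape (20, 1, 28, 28) with integer labels drawn
+# from the class range reported by get_data_info('femnist').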
diff --git a/federatedscope/attack/auxiliary/__init__.py b/federatedscope/attack/auxiliary/__init__.py
new file mode 100644
index 000000000..d48f66ff7
--- /dev/null
+++ b/federatedscope/attack/auxiliary/__init__.py
@@ -0,0 +1,13 @@
+from federatedscope.attack.auxiliary.utils import *
+from federatedscope.attack.auxiliary.attack_trainer_builder import wrap_attacker_trainer
+from federatedscope.attack.auxiliary.backdoor_utils import *
+from federatedscope.attack.auxiliary.poisoning_data import *
+
+__all__ = [
+    'get_passive_PIA_auxiliary_dataset', 'iDLG_trick', 'cos_sim',
+    'get_classifier', 'get_data_info', 'get_data_sav_fn', 'get_info_diff_loss',
+    'sav_femnist_image', 'get_reconstructor', 'get_generator',
+    'get_data_property', 'load_poisoned_dataset_edgeset',
+    'load_poisoned_dataset_pixel', 'selectTrigger', 'poisoning'
+]
diff --git a/federatedscope/attack/auxiliary/attack_trainer_builder.py b/federatedscope/attack/auxiliary/attack_trainer_builder.py
new file mode 100644
index 000000000..df19c90c4
--- /dev/null
+++ b/federatedscope/attack/auxiliary/attack_trainer_builder.py
@@ -0,0 +1,23 @@
+def wrap_attacker_trainer(base_trainer, config):
+ '''Wrap the trainer for attack client.
+ Args:
+ base_trainer (core.trainers.GeneralTorchTrainer): \
+ the trainer that will be wrapped;
+ config (yacs.config.CfgNode): the configure;
+
+ :returns:
+ The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
+
+ '''
+ if config.attack.attack_method.lower() == 'gan_attack':
+ from federatedscope.attack.trainer import wrap_GANTrainer
+ return wrap_GANTrainer(base_trainer)
+ elif config.attack.attack_method.lower() == 'gradascent':
+ from federatedscope.attack.trainer import wrap_GradientAscentTrainer
+ return wrap_GradientAscentTrainer(base_trainer)
+ elif config.attack.attack_method.lower() == 'backdoor':
+ from federatedscope.attack.trainer import wrap_backdoorTrainer
+ return wrap_backdoorTrainer(base_trainer)
+ else:
+        raise ValueError('The attack method {} is not supported'.format(
+            config.attack.attack_method))
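+
+
+# Usage sketch (assuming a populated cfg with cfg.attack.attack_method set):
+#   trainer = wrap_attacker_trainer(base_trainer, cfg)
+# returns base_trainer wrapped with the attack-specific hooks selected above.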
diff --git a/federatedscope/attack/auxiliary/backdoor_utils.py b/federatedscope/attack/auxiliary/backdoor_utils.py
new file mode 100644
index 000000000..5b6552ac8
--- /dev/null
+++ b/federatedscope/attack/auxiliary/backdoor_utils.py
@@ -0,0 +1,334 @@
+import torch.utils.data as data
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+import os
+import csv
+import random
+import numpy as np
+
+import time
+import cv2
+import matplotlib
+from matplotlib import image as mlt
+
+
+def normalize(X, mean, std, device=None):
+ channel = X.shape[0]
+ mean = torch.tensor(mean).view(channel, 1, 1)
+ std = torch.tensor(std).view(channel, 1, 1)
+ return (X - mean) / std
+
+
+def selectTrigger(ctx, img, height, width, distance, trig_h, trig_w,
+ triggerType):
+ '''
+    Stamp the specified trigger onto img and return it as an np.array
+    in [0, 255] with shape (height, width, channel).
+ '''
+
+ assert triggerType in ['squareTrigger', 'gridTrigger', \
+ 'fourCornerTrigger', 'randomPixelTrigger',
+ 'signalTrigger', 'hkTrigger', 'trojanTrigger', \
+ 'sigTrigger','sig_n_Trigger', 'wanetTrigger',\
+ 'wanetTriggerCross']
+
+ if triggerType == 'squareTrigger':
+ img = _squareTrigger(ctx, img, height, width, distance, trig_h, trig_w)
+
+ elif triggerType == 'gridTrigger':
+        img = _gridTrigger(ctx, img, height, width, distance, trig_h, trig_w)
+
+ elif triggerType == 'fourCornerTrigger':
+ img = _fourCornerTrigger(ctx, img, height, width, distance, trig_h,
+ trig_w)
+
+ elif triggerType == 'randomPixelTrigger':
+ img = _randomPixelTrigger(ctx, img, height, width, distance, trig_h,
+ trig_w)
+
+ elif triggerType == 'signalTrigger':
+ img = _signalTrigger(ctx, img, height, width, distance, trig_h, trig_w)
+
+ elif triggerType == 'hkTrigger':
+ img = _hkTrigger(ctx, img, height, width, distance, trig_h, trig_w)
+
+ elif triggerType == 'trojanTrigger':
+ img = _trojanTrigger(ctx, img, height, width, distance, trig_h, trig_w)
+
+ elif triggerType == 'sigTrigger':
+ img = _sigTrigger(ctx, img, height, width, distance, trig_h, trig_w)
+
+ elif triggerType == 'sig_n_Trigger':
+ img = _sig_n_Trigger(ctx, img, height, width, distance, trig_h, trig_w)
+
+ elif triggerType == 'wanetTrigger':
+ img = _wanetTrigger(ctx, img, height, width, distance, trig_h, trig_w)
+
+ elif triggerType == 'wanetTriggerCross':
+ img = _wanetTriggerCross(ctx, img, height, width, distance, trig_h,
+ trig_w)
+ else:
+ raise NotImplementedError
+
+ return img
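+
+
+# Usage sketch (hypothetical arguments): stamping a grid trigger onto a 28x28
+# uint8 image in (height, width, channel) layout:
+#   img = selectTrigger(ctx, img, 28, 28, distance=1, trig_h=3, trig_w=3,
+#                       triggerType='gridTrigger')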
+
+
+def _squareTrigger(ctx, img, height, width, distance, trig_h, trig_w):
+ for j in range(width - distance - trig_w, width - distance):
+ for k in range(height - distance - trig_h, height - distance):
+ img[j, k] = 255
+
+ return img
+
+
+def _gridTrigger(ctx, img, height, width, distance, trig_h, trig_w):
+ img[height - 1][width - 1] = 255
+ img[height - 1][width - 2] = 0
+ img[height - 1][width - 3] = 255
+
+ img[height - 2][width - 1] = 0
+ img[height - 2][width - 2] = 255
+ img[height - 2][width - 3] = 0
+
+ img[height - 3][width - 1] = 255
+ img[height - 3][width - 2] = 0
+ img[height - 3][width - 3] = 0
+
+ return img
+
+
+def _fourCornerTrigger(ctx, img, height, width, distance, trig_h, trig_w):
+
+ img[height - 1][width - 1] = 255
+ img[height - 1][width - 2] = 0
+ img[height - 1][width - 3] = 255
+
+ img[height - 2][width - 1] = 0
+ img[height - 2][width - 2] = 255
+ img[height - 2][width - 3] = 0
+
+ img[height - 3][width - 1] = 255
+ img[height - 3][width - 2] = 0
+ img[height - 3][width - 3] = 0
+
+ img[1][1] = 255
+ img[1][2] = 0
+ img[1][3] = 255
+
+ img[2][1] = 0
+ img[2][2] = 255
+ img[2][3] = 0
+
+ img[3][1] = 255
+ img[3][2] = 0
+ img[3][3] = 0
+
+ img[height - 1][1] = 255
+ img[height - 1][2] = 0
+ img[height - 1][3] = 255
+
+ img[height - 2][1] = 0
+ img[height - 2][2] = 255
+ img[height - 2][3] = 0
+
+ img[height - 3][1] = 255
+ img[height - 3][2] = 0
+ img[height - 3][3] = 0
+
+ img[1][width - 1] = 255
+ img[2][width - 1] = 0
+ img[3][width - 1] = 255
+
+ img[1][width - 2] = 0
+ img[2][width - 2] = 255
+ img[3][width - 2] = 0
+
+ img[1][width - 3] = 255
+ img[2][width - 3] = 0
+ img[3][width - 3] = 0
+
+ return img
+
+
+def _randomPixelTrigger(ctx, img, height, width, distance, trig_h, trig_w):
+ alpha = 0.2
+ mask = np.random.randint(low=0,
+ high=256,
+ size=(height, width),
+ dtype=np.uint8)
+ blend_img = (1 - alpha) * img + alpha * mask.reshape((height, width, 1))
+ blend_img = np.clip(blend_img.astype('uint8'), 0, 255)
+
+ return blend_img
+
+
+def _signalTrigger(ctx, img, height, width, distance, trig_h, trig_w):
+ alpha = 0.2
+ file_name = os.path.join(ctx.data.root, 'triggers/signal_cifar10_mask.npy')
+ signal_mask = np.load(file_name)
+ blend_img = (1 - alpha) * img + alpha * signal_mask.reshape(
+ (height, width, 1))
+ blend_img = np.clip(blend_img.astype('uint8'), 0, 255)
+
+ return blend_img
+
+
+def _hkTrigger(ctx, img, height, width, distance, trig_h, trig_w):
+
+ alpha = 0.2
+
+ file_name = os.path.join(ctx.data.root, 'triggers/hello_kitty.png')
+ signal_mask = mlt.imread(file_name) * 255
+ signal_mask = cv2.resize(signal_mask, (height, width))
+ if img.shape[2] == 1:
+ signal_mask = cv2.cvtColor(signal_mask, cv2.COLOR_RGB2GRAY)
+ signal_mask = np.expand_dims(signal_mask, -1)
+ blend_img = (1 - alpha) * img + alpha * signal_mask
+ blend_img = np.clip(blend_img.astype('uint8'), 0, 255)
+
+ return blend_img
+
+
+def _trojanTrigger(ctx, img, height, width, distance, trig_h, trig_w):
+ file_name = os.path.join(ctx.data.root,
+ 'triggers/best_square_trigger_cifar10.npz')
+ trg = np.load(file_name)['x']
+ trg = np.transpose(trg, (1, 2, 0))
+ img_ = np.clip((img + trg).astype('uint8'), 0, 255)
+
+ return img_
+
+
+def _sigTrigger(ctx,
+ img,
+ height,
+ width,
+ distance,
+ trig_h,
+ trig_w,
+ delta=20,
+ f=6):
+ """
+ Implement paper:
+ > Barni, M., Kallas, K., & Tondi, B. (2019).
+ > A new Backdoor Attack in CNNs by training set corruption without label poisoning.
+ > arXiv preprint arXiv:1902.11237
+ superimposed sinusoidal backdoor signal with default parameters
+ """
+
+ img = np.float32(img)
+ pattern = np.zeros_like(img)
+ m = pattern.shape[1]
+ for i in range(int(img.shape[0])):
+ for j in range(int(img.shape[1])):
+ pattern[i, j] = delta * np.sin(2 * np.pi * j * f / m)
+ img = np.uint32(img) + pattern
+ img = np.uint8(np.clip(img, 0, 255))
+ return img
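+
+
+# Worked example for the pattern above: with delta=20, f=6 and image width
+# m=32, column j=4 receives delta * sin(2*pi*4*6/32) = 20 * sin(1.5*pi) = -20,
+# i.e. a vertical sinusoidal stripe superimposed over the whole image.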
+
+
+def _sig_n_Trigger(ctx,
+ img,
+ height,
+ width,
+ distance,
+ trig_h,
+ trig_w,
+ delta=40,
+ f=6):
+ """
+ Implement paper:
+ > Barni, M., Kallas, K., & Tondi, B. (2019).
+ > A new Backdoor Attack in CNNs by training set corruption without label poisoning.
+ > arXiv preprint arXiv:1902.11237
+ superimposed sinusoidal backdoor signal with default parameters
+ """
+    # NOTE: this hard-coded value overrides the delta=40 default above
+    delta = 10
+ img = np.float32(img)
+ pattern = np.zeros_like(img)
+ m = pattern.shape[1]
+ for i in range(int(img.shape[0])):
+ for j in range(int(img.shape[1])):
+ pattern[i, j] = delta * np.sin(2 * np.pi * j * f / m)
+ img = np.uint32(img) + pattern
+ img = np.uint8(np.clip(img, 0, 255))
+ return img
+
+
+def _wanetTrigger(ctx,
+ img,
+ height,
+ width,
+ distance,
+ trig_w,
+ trig_h,
+ delta=20,
+ f=6):
+ """
+ Implement paper:
+ """
+ k = 4
+ s = 0.5
+ input_height = height
+ grid_rescale = 1
+ ins = torch.rand(1, 2, k, k) * 2 - 1
+ ins = ins / torch.mean(torch.abs(ins))
+    # F.interpolate replaces the deprecated F.upsample
+    noise_grid = (F.interpolate(ins,
+                                size=input_height,
+                                mode="bicubic",
+                                align_corners=True).permute(0, 2, 3, 1))
+ array1d = torch.linspace(-1, 1, steps=input_height)
+ x, y = torch.meshgrid(array1d, array1d)
+ identity_grid = torch.stack((y, x), 2)[None, ...]
+ grid_temps = (identity_grid + s * noise_grid / input_height) * grid_rescale
+ grid_temps = torch.clamp(grid_temps, -1, 1)
+ img = np.float32(img)
+ img = torch.tensor(img).reshape(-1, height, width).unsqueeze(0)
+ img = F.grid_sample(img, grid_temps,
+ align_corners=True).squeeze(0).reshape(
+ height, width, -1)
+ img = np.uint8(np.clip(img.cpu().numpy(), 0, 255))
+
+ return img
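+
+
+# Note on the warping above (WaNet-style): identity_grid is the no-op sampling
+# grid for F.grid_sample; adding the smoothed noise_grid, scaled by
+# s / input_height, slightly displaces every sampling location, so the trigger
+# is a subtle elastic deformation rather than a visible patch.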
+
+
+def _wanetTriggerCross(ctx,
+ img,
+ height,
+ width,
+ distance,
+ trig_w,
+ trig_h,
+ delta=20,
+ f=6):
+ """
+ Implement paper:
+ """
+ k = 4
+ s = 0.5
+ input_height = height
+ grid_rescale = 1
+ ins = torch.rand(1, 2, k, k) * 2 - 1
+ ins = ins / torch.mean(torch.abs(ins))
+    # F.interpolate replaces the deprecated F.upsample
+    noise_grid = (F.interpolate(ins,
+                                size=input_height,
+                                mode="bicubic",
+                                align_corners=True).permute(0, 2, 3, 1))
+ array1d = torch.linspace(-1, 1, steps=input_height)
+ x, y = torch.meshgrid(array1d, array1d)
+ identity_grid = torch.stack((y, x), 2)[None, ...]
+ grid_temps = (identity_grid + s * noise_grid / input_height) * grid_rescale
+ grid_temps = torch.clamp(grid_temps, -1, 1)
+ ins = torch.rand(1, input_height, input_height, 2) * 2 - 1
+ grid_temps2 = grid_temps + ins / input_height
+ grid_temps2 = torch.clamp(grid_temps2, -1, 1)
+ img = np.float32(img)
+ img = torch.tensor(img).reshape(-1, height, width).unsqueeze(0)
+ img = F.grid_sample(img, grid_temps2,
+ align_corners=True).squeeze(0).reshape(
+ height, width, -1)
+ img = np.uint8(np.clip(img.cpu().numpy(), 0, 255))
+ return img
diff --git a/federatedscope/attack/auxiliary/create_edgeset.py b/federatedscope/attack/auxiliary/create_edgeset.py
new file mode 100644
index 000000000..023e8979e
--- /dev/null
+++ b/federatedscope/attack/auxiliary/create_edgeset.py
@@ -0,0 +1,145 @@
+import torch
+import torch.utils.data as data
+from PIL import Image
+import numpy as np
+from torchvision.datasets import MNIST, EMNIST, CIFAR10
+from torchvision.datasets import DatasetFolder
+from torchvision import transforms
+
+import os
+import sys
+import logging
+import pickle
+import copy
+
+
+
+def create_ardis_poisoned_dataset(base_label=7, target_label=1, fraction=0.1):
+    # NOTE: hard-coded absolute paths to the ARDIS csv files
+    ardis_images = np.loadtxt(
+        '/mnt/zeyuqin/FederatedScope/data/ARDIS/ARDIS_train_2828.csv',
+        dtype='float')
+    ardis_labels = np.loadtxt(
+        '/mnt/zeyuqin/FederatedScope/data/ARDIS/ARDIS_train_labels.csv',
+        dtype='float')
+
+
+ ardis_images = ardis_images.reshape(ardis_images.shape[0], 28,
+ 28).astype('float32')
+
+ indices_seven = np.where(ardis_labels[:, base_label] == 1)[0]
+ images_seven = ardis_images[indices_seven, :]
+ images_seven = torch.tensor(images_seven).type(torch.uint8)
+
+ if fraction < 1:
+        num_sampled_data_points = int(fraction * images_seven.size()[0])
+ perm = torch.randperm(images_seven.size()[0])
+ idx = perm[:num_sampled_data_points]
+ images_seven_cut = images_seven[idx]
+ images_seven_cut = images_seven_cut.unsqueeze(1)
+ print('size of images_seven_cut: ', images_seven_cut.size())
+ poisoned_labels_cut = (torch.zeros(images_seven_cut.size()[0]) +
+ target_label).long()
+
+ else:
+ images_seven_DA = copy.deepcopy(images_seven)
+
+ cand_angles = [180 / fraction * i for i in range(1, fraction + 1)]
+ print("Candidate angles for DA: {}".format(cand_angles))
+
+ for idx in range(len(images_seven)):
+ for cad_ang in cand_angles:
+ PIL_img = transforms.ToPILImage()(
+ images_seven[idx]).convert("L")
+ PIL_img_rotate = transforms.functional.rotate(PIL_img,
+ cad_ang,
+ fill=(0, ))
+
+ img_rotate = torch.from_numpy(np.array(PIL_img_rotate))
+ images_seven_DA = torch.cat((images_seven_DA, \
+ img_rotate.reshape(1,img_rotate.size()[0], \
+ img_rotate.size()[0])), 0)
+
+ print(images_seven_DA.size())
+
+ poisoned_labels_DA = (torch.zeros(images_seven_DA.size()[0]) +
+ target_label).long()
+
+ poisoned_edgeset = []
+ if fraction < 1:
+ for ii in range(len(images_seven_cut)):
+ poisoned_edgeset.append(
+ (images_seven_cut[ii], poisoned_labels_cut[ii]))
+
+ print("Shape of poisoned_edgeset dataset (poisoned): {}, \
+ shape of poisoned_edgeset labels: {}".format(
+ images_seven_cut.size(), poisoned_labels_cut.size()))
+
+ else:
+ for ii in range(len(images_seven_DA)):
+ poisoned_edgeset.append(
+ (images_seven_DA[ii], poisoned_labels_DA[ii]))
+
+ print("Shape of poisoned_edgeset dataset (poisoned): {}, \
+ shape of poisoned_edgeset labels: {}".format(
+ images_seven_DA.size(), poisoned_labels_DA.size()))
+
+ return poisoned_edgeset
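+
+
+# Usage sketch: with the default fraction=0.1, 10% of the ARDIS images of the
+# base digit are sampled and relabeled to the target class, e.g.
+#   poisoned = create_ardis_poisoned_dataset(base_label=7, target_label=1)
+# returns a list of (image_tensor, label) pairs.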
+
+
+
+def create_ardis_test_dataset(base_label=7, target_label=1):
+
+ ardis_images = np.loadtxt(
+ '/mnt/zeyuqin/FederatedScope/data/ARDIS/ARDIS_test_2828.csv',
+ dtype='float')
+ ardis_labels = np.loadtxt(
+ '/mnt/zeyuqin/FederatedScope/data/ARDIS/ARDIS_test_labels.csv',
+ dtype='float')
+
+
+ ardis_images = torch.tensor(ardis_images.reshape(ardis_images.shape[0], \
+ 28, 28).astype('float32')).type(torch.uint8)
+
+ indices_seven = np.where(ardis_labels[:, base_label] == 1)[0]
+ images_seven = ardis_images[indices_seven, :]
+    images_seven = images_seven.clone().detach().type(torch.uint8)
+ images_seven = images_seven.unsqueeze(1)
+
+    poisoned_labels = (torch.zeros(images_seven.size()[0]) +
+                       target_label).long()
+
+ ardis_test_dataset = []
+
+ for ii in range(len(images_seven)):
+ ardis_test_dataset.append((images_seven[ii], poisoned_labels[ii]))
+
+ print("Shape of ardis test dataset (poisoned): {},\
+ shape of ardis test labels: {}".format(images_seven.size(),
+ poisoned_labels.size()))
+
+ return ardis_test_dataset
+
+
+
+
+if __name__ == '__main__':
+
+ fraction = 0.1
+
+ poisoned_edgeset = create_ardis_poisoned_dataset(fraction=fraction)
+
+ ardis_test_dataset = create_ardis_test_dataset()
+
+ print("Writing poison_data to: ")
+ print("poisoned_edgeset_fraction_{}".format(fraction))
+
+ with open("poisoned_edgeset_fraction_{}".format(fraction),
+ "wb") as saved_data_file:
+ torch.save(poisoned_edgeset, saved_data_file)
+
+ with open("ardis_test_dataset.pt", "wb") as ardis_data_file:
+ torch.save(ardis_test_dataset, ardis_data_file)
diff --git a/federatedscope/attack/auxiliary/poisoning_data.py b/federatedscope/attack/auxiliary/poisoning_data.py
new file mode 100644
index 000000000..056834837
--- /dev/null
+++ b/federatedscope/attack/auxiliary/poisoning_data.py
@@ -0,0 +1,318 @@
+import torch
+import torch.utils.data as data
+from PIL import Image
+import numpy as np
+from torchvision.datasets import MNIST, EMNIST, CIFAR10
+from torchvision.datasets import DatasetFolder
+from torchvision import transforms
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+from federatedscope.attack.auxiliary.backdoor_utils import selectTrigger
+from torch.utils.data import DataLoader, Dataset
+from federatedscope.attack.auxiliary.backdoor_utils import normalize
+import matplotlib
+import pickle
+import logging
+import os
+
+
+logger = logging.getLogger(__name__)
+
+
+def load_poisoned_dataset_edgeset(data, ctx, mode):
+
+ transforms_funcs = get_transform(ctx, 'torchvision')['transform']
+
+ if "femnist" in ctx.data.type:
+ if mode == 'train':
+ file_name = os.path.join(ctx.data.root,
+ 'poisoned_edgeset_fraction_0.1')
+ with open(file_name, "rb") as saved_data_file:
+ poisoned_edgeset = torch.load(saved_data_file)
+ num_dps_poisoned_dataset = len(poisoned_edgeset)
+
+ for ii in range(num_dps_poisoned_dataset):
+ sample, label = poisoned_edgeset[ii]
+ sample = sample.numpy().transpose(1, 2, 0)
+
+ data['train'].dataset.append((transforms_funcs(sample), label))
+
+ if mode == 'test':
+ poison_testset = list()
+ file_name = os.path.join(ctx.data.root, 'ardis_test_dataset.pt')
+ with open(file_name, "rb") as saved_data_file:
+ poisoned_edgeset = torch.load(saved_data_file)
+ num_dps_poisoned_dataset = len(poisoned_edgeset)
+
+ for ii in range(num_dps_poisoned_dataset):
+ sample, label = poisoned_edgeset[ii]
+ sample = sample.numpy().transpose(1, 2, 0)
+ poison_testset.append((transforms_funcs(sample), label))
+ data['poison'] = DataLoader(poison_testset,
+ batch_size=ctx.data.batch_size,
+ shuffle=False,
+ num_workers=ctx.data.num_workers)
+
+ elif "CIFAR10" in ctx.data.type:
+ target_label = 9
+ label = target_label
+
+ num_poisoned = ctx.attack.edge_num
+
+ if mode == 'train':
+ file_name = os.path.join(ctx.data.root,
+ 'southwest_images_new_train.pkl')
+ with open(file_name, 'rb') as train_f:
+ saved_southwest_dataset_train = pickle.load(train_f)
+ num_poisoned_dataset = num_poisoned
+            sampled_poisoned_data_indices = np.random.choice(
+                saved_southwest_dataset_train.shape[0],
+                num_poisoned_dataset,
+                replace=False)
+            saved_southwest_dataset_train = saved_southwest_dataset_train[
+                sampled_poisoned_data_indices, :, :, :]
+
+ for ii in range(num_poisoned_dataset):
+ sample = saved_southwest_dataset_train[ii]
+ data['train'].dataset.append((transforms_funcs(sample), label))
+
+            logger.info('adding {:d} edge-case samples to CIFAR-10'.format(
+                num_poisoned))
+
+ if mode == 'test':
+ poison_testset = list()
+ file_name = os.path.join(ctx.data.root,
+ 'southwest_images_new_test.pkl')
+ with open(file_name, 'rb') as test_f:
+ saved_southwest_dataset_test = pickle.load(test_f)
+ num_poisoned_dataset = len(saved_southwest_dataset_test)
+
+ for ii in range(num_poisoned_dataset):
+ sample = saved_southwest_dataset_test[ii]
+ poison_testset.append((transforms_funcs(sample), label))
+ data['poison'] = DataLoader(poison_testset,
+ batch_size=ctx.data.batch_size,
+ shuffle=False,
+ num_workers=ctx.data.num_workers)
+
+ else:
+        raise RuntimeError(
+            'Currently, only the FEMNIST and CIFAR-10 datasets are supported')
+
+ return data
+
+
+def addTrigger(ctx,
+ dataset,
+ target_label,
+ inject_portion,
+ mode,
+ distance,
+ trig_h,
+ trig_w,
+ trigger_type,
+ label_type,
+ surrogate_model=None):
+
+ cnt_all = int(len(dataset) * inject_portion)
+ height = dataset[0][0].shape[-2]
+ width = dataset[0][0].shape[-1]
+ trig_h = int(trig_h * height)
+ trig_w = int(trig_w * width)
+ if 'wanet' in trigger_type:
+ cross_portion = 2
+ perm_then = np.random.permutation(
+ len(dataset
+ ))[0:int(len(dataset) * inject_portion * (1 + cross_portion))]
+ perm = perm_then[0:int(len(dataset) * inject_portion)]
+ perm_cross = perm_then[(
+ int(len(dataset) * inject_portion) +
+ 1):int(len(dataset) * inject_portion * (1 + cross_portion))]
+ else:
+ perm = np.random.permutation(
+ len(dataset))[0:int(len(dataset) * inject_portion)]
+
+ dataset_ = list()
+ '''
+ need to specify the form of (x, y) from dataset
+ Now, the form of x is torch.tensor [0:1] (channel, height, width)
+ return the x : np.array [0:255], (height, width, channel)
+ '''
+
+ ii = 0
+ for i in range(len(dataset)):
+ data = dataset[i]
+
+ if label_type == 'dirty':
+ if mode == 'train':
+ img = np.array(data[0]).transpose(1, 2, 0) * 255.0
+ img = np.clip(img.astype('uint8'), 0, 255)
+ height = img.shape[0]
+ width = img.shape[1]
+ if data[1] == 0 or data[1] == 1:
+ ii += 1
+
+ if i in perm:
+ img = selectTrigger(ctx, img, height, width, distance,
+ trig_h, trig_w, trigger_type)
+ dataset_.append((img, target_label))
+
+ elif 'wanet' in trigger_type and i in perm_cross:
+ img = selectTrigger(ctx, img, width, height, distance,
+ trig_w, trig_h, 'wanetTriggerCross')
+ dataset_.append((img, data[1]))
+
+ else:
+ dataset_.append((img, data[1]))
+
+ if mode == 'test':
+ if data[1] == target_label:
+ continue
+
+ img = np.array(data[0]).transpose(1, 2, 0) * 255.0
+ img = np.clip(img.astype('uint8'), 0, 255)
+ height = img.shape[0]
+ width = img.shape[1]
+ if i in perm:
+ img = selectTrigger(ctx, img, width, height, distance,
+ trig_w, trig_h, trigger_type)
+ dataset_.append((img, target_label))
+ else:
+ dataset_.append((img, data[1]))
+
+ elif label_type == 'clean_label':
+ pass
+
+ return dataset_
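+
+
+# Usage sketch (hypothetical values, mirroring the call sites below):
+#   poisoned = addTrigger(ctx, dataset, target_label=1, inject_portion=0.5,
+#                         mode='train', distance=1, trig_h=0.1, trig_w=0.1,
+#                         trigger_type='gridTrigger', label_type='dirty')
+# returns a list of (np.uint8 image, label) pairs in which inject_portion of
+# the samples carry the trigger and are relabeled to target_label.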
+
+
+def load_poisoned_dataset_pixel(data, ctx, mode):
+
+ trigger_type = ctx.attack.trigger_type
+ label_type = ctx.attack.label_type
+ target_label = int(ctx.attack.target_label_ind)
+ transforms_funcs = get_transform(ctx, 'torchvision')['transform']
+
+ if "femnist" in ctx.data.type:
+ inject_portion_train = ctx.attack.poison_ratio
+ target_label = torch.tensor(int(ctx.attack.target_label_ind)).long()
+
+ elif "CIFAR10" in ctx.data.type:
+ inject_portion_train = ctx.attack.poison_ratio
+ target_label = int(ctx.attack.target_label_ind)
+
+ else:
+        raise RuntimeError(
+            'Currently, only the FEMNIST and CIFAR-10 datasets are supported')
+
+ inject_portion_test = 1.0
+
+ if mode == 'train':
+        poisoned_dataset = addTrigger(
+            ctx, data['train'].dataset, target_label, inject_portion_train,
+            mode='train', distance=1, trig_h=0.1, trig_w=0.1,
+            trigger_type=trigger_type, label_type=label_type)
+ num_dps_poisoned_dataset = len(poisoned_dataset)
+ for iii in range(num_dps_poisoned_dataset):
+ sample, label = poisoned_dataset[iii]
+ poisoned_dataset[iii] = (transforms_funcs(sample), label)
+
+ data['train'] = DataLoader(poisoned_dataset,
+ batch_size=ctx.data.batch_size,
+ shuffle=True,
+ num_workers=ctx.data.num_workers)
+
+ if mode == 'test':
+        poisoned_dataset = addTrigger(
+            ctx, data[mode].dataset, target_label, inject_portion_test,
+            mode=mode, distance=1, trig_h=0.1, trig_w=0.1,
+            trigger_type=trigger_type, label_type=label_type)
+ num_dps_poisoned_dataset = len(poisoned_dataset)
+ for iii in range(num_dps_poisoned_dataset):
+ sample, label = poisoned_dataset[iii]
+ poisoned_dataset[iii] = (transforms_funcs(sample), label)
+
+ data['poison'] = DataLoader(poisoned_dataset,
+ batch_size=ctx.data.batch_size,
+ shuffle=False,
+ num_workers=ctx.data.num_workers)
+
+ return data
+
+
+def add_trans_normalize(data, ctx):
+ '''
+ data for each client is a dictionary.
+ '''
+
+ for key in data:
+ num_dataset = len(data[key].dataset)
+ mean, std = ctx.attack.mean, ctx.attack.std
+ if "CIFAR10" in ctx.data.type and key == 'train':
+ transforms_list = []
+ transforms_list.append(transforms.RandomHorizontalFlip())
+ transforms_list.append(transforms.ToTensor())
+ tran_train = transforms.Compose(transforms_list)
+ for iii in range(num_dataset):
+ sample = np.array(data[key].dataset[iii][0]).transpose(
+ 1, 2, 0) * 255.0
+ sample = np.clip(sample.astype('uint8'), 0, 255)
+ sample = Image.fromarray(sample)
+ sample = tran_train(sample)
+ data[key].dataset[iii] = (normalize(sample, mean, std),
+ data[key].dataset[iii][1])
+ else:
+ for iii in range(num_dataset):
+ data[key].dataset[iii] = (normalize(data[key].dataset[iii][0], mean, std), \
+ data[key].dataset[iii][1])
+
+ return data
+
+
+def select_poisoning(data, ctx, mode):
+
+ if 'edge' in ctx.attack.trigger_type:
+ data = load_poisoned_dataset_edgeset(data, ctx, mode)
+ elif 'semantic' in ctx.attack.trigger_type:
+ pass
+ else:
+ data = load_poisoned_dataset_pixel(data, ctx, mode)
+ return data
+
+
+def poisoning(data, ctx):
+ for i in range(1, len(data) + 1):
+ if i == ctx.attack.attacker_id:
+ logger.info(50 * '-')
+            logger.info('start poisoning!')
+ logger.info(50 * '-')
+ data[i] = select_poisoning(data[i], ctx, mode='train')
+ data[i] = select_poisoning(data[i], ctx, mode='test')
+
+ if 'cifar' in ctx.data.type.lower():
+ class_num_train = torch.zeros(10)
+ class_num_test = torch.zeros(10)
+ else:
+ class_num_train = torch.zeros(62)
+ class_num_test = torch.zeros(62)
+
+ for jj in data[i]['train'].dataset:
+ ind = jj[1]
+ class_num_train[ind] += 1
+            logger.info('training: client {}'.format(i))
+            logger.info(class_num_train)
+            logger.info('the training number of client {}: {}'.format(
+                i,
+                class_num_train.sum().item()))
+
+ for jj in data[i]['test'].dataset:
+ ind = jj[1]
+ class_num_test[ind] += 1
+            logger.info('testing: client {}'.format(i))
+            logger.info(class_num_test)
+            logger.info('the testing number of client {}: {}'.format(
+                i,
+                class_num_test.sum().item()))
+
+ data[i] = add_trans_normalize(data[i], ctx)
+ logger.info('finishing the clean and {} poisoning data processing for Client {:d}'\
+ .format(ctx.attack.trigger_type, i))
diff --git a/federatedscope/attack/auxiliary/utils.py b/federatedscope/attack/auxiliary/utils.py
new file mode 100644
index 000000000..4e6be90d5
--- /dev/null
+++ b/federatedscope/attack/auxiliary/utils.py
@@ -0,0 +1,321 @@
+import torch
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+import logging
+import os
+import numpy as np
+import federatedscope.register as register
+
+logger = logging.getLogger(__name__)
+
+
+def label_to_onehot(target, num_classes=100):
+ return torch.nn.functional.one_hot(target, num_classes)
+
+
+def cross_entropy_for_onehot(pred, target):
+ return torch.mean(torch.sum(-target * F.log_softmax(pred, dim=-1), 1))
+
+
+def iDLG_trick(original_gradient, num_class, is_one_hot_label=False):
+ '''
+ Using iDLG trick to recover the label. Paper: "iDLG: Improved Deep Leakage from Gradients", link: https://arxiv.org/abs/2001.02610
+
+ Args:
+ original_gradient: the gradient of the FL model; type: list
+ num_class: the total number of class in the data
+ is_one_hot_label: whether the dataset's label is in the form of one hot. Type: bool
+
+ Returns:
+ The recovered label by iDLG trick.
+
+ '''
+ last_weight_min = torch.argmin(torch.sum(original_gradient[-2], dim=-1),
+ dim=-1).detach()
+
+ if is_one_hot_label:
+ label = label_to_onehot(
+ last_weight_min.reshape((1, )).requires_grad_(False), num_class)
+ else:
+ label = last_weight_min
+ return label
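+
+
+# Usage sketch: given the gradient list of a single-sample update, the label
+# index is the row of the last fully-connected layer's weight gradient with
+# the smallest row-sum, e.g.
+#   label = iDLG_trick(grads, num_class=36, is_one_hot_label=False)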
+
+
+def cos_sim(input_gradient, gt_gradient):
+ total = 1 - torch.nn.functional.cosine_similarity(
+ input_gradient.flatten(), gt_gradient.flatten(), 0, 1e-10)
+
+ return total
+
+
+def total_variation(x):
+ """Anisotropic TV."""
+ dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
+ dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
+
+ total = x.size()[0]
+ for ind in range(1, len(x.size())):
+ total *= x.size()[ind]
+ return (dx + dy) / (total)
+
+
+def approximate_func(x, device, C1=20, C2=0.5):
+ '''
+ Approximate the function f(x) = 0 if x<0.5 otherwise 1
+ Args:
+ x: input data;
+ device:
+ C1:
+ C2:
+
+ Returns:
+ 1/(1+e^{-1*C1 (x-C2)})
+
+ '''
+ C1 = torch.tensor(C1).to(torch.device(device))
+ C2 = torch.tensor(C2).to(torch.device(device))
+
+ return 1 / (1 + torch.exp(-1 * C1 * (x - C2)))
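+
+
+# Worked example: with the defaults C1=20, C2=0.5 and x=0.6, the function
+# returns 1 / (1 + e^(-20 * (0.6 - 0.5))) = 1 / (1 + e^(-2)) ~= 0.881,
+# a smooth surrogate for the hard threshold at x = 0.5.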
+
+
+def get_classifier(classifier: str, model=None):
+ if model is not None:
+ return model
+
+ if classifier == 'lr':
+ from sklearn.linear_model import LogisticRegression
+ model = LogisticRegression(random_state=0)
+ return model
+ elif classifier.lower() == 'randomforest':
+ from sklearn.ensemble import RandomForestClassifier
+ model = RandomForestClassifier(random_state=0)
+ return model
+ elif classifier.lower() == 'svm':
+ from sklearn.svm import SVC
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.pipeline import make_pipeline
+ model = make_pipeline(StandardScaler(), SVC(gamma='auto'))
+ return model
+ else:
+        raise ValueError('classifier {} is not supported'.format(classifier))
+
+
+def get_data_info(dataset_name):
+ '''
+ Get the dataset information, including the feature dimension, number of total classes, whether the label is represented in one-hot version
+
+ Args:
+ dataset_name:dataset name; str
+
+ :returns:
+ data_feature_dim, num_class, is_one_hot_label
+
+ '''
+ if dataset_name.lower() == 'femnist':
+
+ return [1, 28, 28], 36, False
+ else:
+        raise ValueError(
+            'Please provide the data info of {}: data_feature_dim, num_class'.
+            format(dataset_name))
+
+
+def get_data_sav_fn(dataset_name):
+ if dataset_name.lower() == 'femnist':
+ return sav_femnist_image
+ else:
+ logger.info(
+ "Reconstructed data saving function is not provided for dataset: {}"
+ .format(dataset_name))
+ return None
+
+
+def sav_femnist_image(data, sav_pth, name):
+
+ fig = plt.figure(figsize=(4, 4))
+ # print(data.shape)
+
+ if len(data.shape) == 2:
+ data = torch.unsqueeze(data, 0)
+ data = torch.unsqueeze(data, 0)
+
+ ind = min(data.shape[0], 16)
+ # print(data.shape)
+
+ # plt.imshow(data * 127.5 + 127.5, cmap='gray')
+
+ for i in range(ind):
+ plt.subplot(4, 4, i + 1)
+
+ plt.imshow(data[i, 0, :, :] * 127.5 + 127.5, cmap='gray')
+ # plt.imshow(generated_data[i, 0, :, :] , cmap='gray')
+ # plt.imshow()
+ plt.axis('off')
+
+ plt.savefig(os.path.join(sav_pth, name))
+ plt.close()
+
+
+def get_info_diff_loss(info_diff_type):
+ if info_diff_type.lower() == 'l2':
+ info_diff_loss = torch.nn.MSELoss(reduction='sum')
+ elif info_diff_type.lower() == 'l1':
+ info_diff_loss = torch.nn.SmoothL1Loss(reduction='sum', beta=1e-5)
+ elif info_diff_type.lower() == 'sim':
+ info_diff_loss = cos_sim
+ else:
+        raise ValueError(
+            'info_diff_type: {} is not supported'.format(info_diff_type))
+ return info_diff_loss
+
+
+def get_reconstructor(atk_method, **kwargs):
+ '''
+
+ Args:
+ atk_method: the attack method name, and currently supporting
+ "DLG: deep leakage from gradient", and "IG: Inverting gradient" ; Type: str
+ **kwargs: other arguments
+
+ Returns:
+
+ '''
+
+ if atk_method.lower() == 'dlg':
+ from federatedscope.attack.privacy_attacks.reconstruction_opt import DLG
+ logger.info(
+ '--------- Getting reconstructor: DLG --------------------')
+
+ return DLG(max_ite=kwargs['max_ite'],
+ lr=kwargs['lr'],
+ federate_loss_fn=kwargs['federate_loss_fn'],
+ device=kwargs['device'],
+ federate_lr=kwargs['federate_lr'],
+ optim=kwargs['optim'],
+ info_diff_type=kwargs['info_diff_type'],
+ federate_method=kwargs['federate_method'])
+ elif atk_method.lower() == 'ig':
+ from federatedscope.attack.privacy_attacks.reconstruction_opt import InvertGradient
+ logger.info(
+ '--------- Getting reconstructor: InvertGradient --------------------'
+ )
+ return InvertGradient(max_ite=kwargs['max_ite'],
+ lr=kwargs['lr'],
+ federate_loss_fn=kwargs['federate_loss_fn'],
+ device=kwargs['device'],
+ federate_lr=kwargs['federate_lr'],
+ optim=kwargs['optim'],
+ info_diff_type=kwargs['info_diff_type'],
+ federate_method=kwargs['federate_method'],
+ alpha_TV=kwargs['alpha_TV'])
+ else:
+        raise ValueError(
+            "attack method: {} lacks reconstructor implementation".format(
+                atk_method))
+
+
+def get_generator(dataset_name):
+ '''
+ Get the dataset's corresponding generator.
+ Args:
+ dataset_name: The dataset name; Type: str
+
+ :returns:
+ The generator; Type: object
+
+ '''
+ if dataset_name == 'femnist':
+ from federatedscope.attack.models.gan_based_model import GeneratorFemnist
+ return GeneratorFemnist
+ else:
+        raise ValueError(
+            "The generator to generate data like {} is not defined!".format(
+                dataset_name))
+
+
+def get_data_property(ctx):
+ # A SHOWCASE for Femnist dataset: Property := whether contains a circle.
+ x, label = [_.to(ctx.device) for _ in ctx.data_batch]
+
+    prop = torch.zeros(label.size()[0])
+ positive_labels = [0, 6, 8]
+ for ind in range(label.size()[0]):
+ if label[ind] in positive_labels:
+ prop[ind] = 1
+    prop = prop.to(ctx.device)
+ return prop
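+
+
+# Usage sketch: inside a trainer hook where ctx.data_batch = (x, label),
+#   prop = get_data_property(ctx)
+# yields a 0/1 tensor marking the samples whose label falls in {0, 6, 8}.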
+
+
+def get_passive_PIA_auxiliary_dataset(dataset_name):
+ '''
+
+ Args:
+ dataset_name (str): dataset name
+
+ :returns:
+
+ the auxiliary dataset for property inference attack. Type: dict
+
+ {
+ 'x': array,
+ 'y': array,
+ 'prop': array
+ }
+
+ '''
+ for func in register.auxiliary_data_loader_PIA_dict.values():
+ criterion = func(dataset_name)
+ if criterion is not None:
+ return criterion
+ if dataset_name == 'toy':
+
+ def _generate_data(instance_num=1000, feature_num=5, save_data=False):
+ """
+ Generate data in FedRunner format
+ Args:
+ instance_num:
+ feature_num:
+ save_data:
+
+ Returns:
+ {
+ 'x': ...,
+ 'y': ...,
+ 'prop': ...
+ }
+
+ """
+ weights = np.random.normal(loc=0.0, scale=1.0, size=feature_num)
+ bias = np.random.normal(loc=0.0, scale=1.0)
+
+ prop_weights = np.random.normal(loc=0.0,
+ scale=1.0,
+ size=feature_num)
+ prop_bias = np.random.normal(loc=0.0, scale=1.0)
+
+ x = np.random.normal(loc=0.0,
+ scale=0.5,
+ size=(instance_num, feature_num))
+ y = np.sum(x * weights, axis=-1) + bias
+ y = np.expand_dims(y, -1)
+ prop = np.sum(x * prop_weights, axis=-1) + prop_bias
+ prop = 1.0 * ((1 / (1 + np.exp(-1 * prop))) > 0.5)
+ prop = np.expand_dims(prop, -1)
+
+ data_train = {'x': x, 'y': y, 'prop': prop}
+ return data_train
+
+ return _generate_data()
+ else:
+        raise ValueError(
+            'The data: {} cannot be loaded. Please specify the data load '
+            'function.'.format(dataset_name))
diff --git a/federatedscope/attack/models/__init__.py b/federatedscope/attack/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/attack/models/gan_based_model.py b/federatedscope/attack/models/gan_based_model.py
new file mode 100644
index 000000000..0665fd764
--- /dev/null
+++ b/federatedscope/attack/models/gan_based_model.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+from copy import deepcopy
+
+
+class GeneratorFemnist(nn.Module):
+ '''
+ The generator for Femnist dataset
+ '''
+ def __init__(self, noise_dim=100):
+ super(GeneratorFemnist, self).__init__()
+
+ module_list = []
+ module_list.append(
+ nn.Linear(in_features=noise_dim,
+ out_features=4 * 4 * 256,
+ bias=False))
+ module_list.append(nn.BatchNorm1d(num_features=4 * 4 * 256))
+ module_list.append(nn.ReLU())
+ self.body1 = nn.Sequential(*module_list)
+
+ # need to reshape the output of self.body1
+
+ module_list = []
+
+ module_list.append(
+ nn.ConvTranspose2d(in_channels=256,
+ out_channels=128,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ bias=False))
+ module_list.append(nn.BatchNorm2d(128))
+ module_list.append(nn.ReLU())
+ self.body2 = nn.Sequential(*module_list)
+
+ module_list = []
+ module_list.append(
+ nn.ConvTranspose2d(in_channels=128,
+ out_channels=64,
+ kernel_size=(3, 3),
+ stride=(2, 2),
+ bias=False))
+ module_list.append(nn.BatchNorm2d(64))
+ module_list.append(nn.ReLU())
+ self.body3 = nn.Sequential(*module_list)
+
+ module_list = []
+ module_list.append(
+ nn.ConvTranspose2d(in_channels=64,
+ out_channels=1,
+ kernel_size=(4, 4),
+ stride=(2, 2),
+ bias=False))
+ module_list.append(nn.BatchNorm2d(1))
+ module_list.append(nn.Tanh())
+ self.body4 = nn.Sequential(*module_list)
+
+ def forward(self, x):
+
+ tmp1 = self.body1(x).view(-1, 256, 4, 4)
+
+ assert tmp1.size()[1:4] == (256, 4, 4)
+
+ tmp2 = self.body2(tmp1)
+ assert tmp2.size()[1:4] == (128, 6, 6)
+
+ tmp3 = self.body3(tmp2)
+
+ assert tmp3.size()[1:4] == (64, 13, 13)
+
+ tmp4 = self.body4(tmp3)
+ assert tmp4.size()[1:4] == (1, 28, 28)
+
+ return tmp4
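+
+
+# Usage sketch: mapping a batch of noise vectors to Femnist-shaped images,
+#   gen = GeneratorFemnist(noise_dim=100)
+#   fake = gen(torch.randn(8, 100))  # -> shape (8, 1, 28, 28)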
diff --git a/federatedscope/attack/models/vision.py b/federatedscope/attack/models/vision.py
new file mode 100644
index 000000000..b5b2c12ad
--- /dev/null
+++ b/federatedscope/attack/models/vision.py
@@ -0,0 +1,209 @@
+"""This file is part of https://github.com/mit-han-lab/dlg.
+MIT License
+Copyright (c) 2019 Ildoo Kim
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import grad
+import torchvision
+from torchvision import models, datasets, transforms
+
+
+def weights_init(m):
+ if hasattr(m, "weight"):
+ m.weight.data.uniform_(-0.5, 0.5)
+ if hasattr(m, "bias"):
+ m.bias.data.uniform_(-0.5, 0.5)
+
+
+class LeNet(nn.Module):
+ def __init__(self):
+ super(LeNet, self).__init__()
+ act = nn.Sigmoid
+ self.body = nn.Sequential(
+ nn.Conv2d(3, 12, kernel_size=5, padding=5 // 2, stride=2),
+ act(),
+ nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
+ act(),
+ nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
+ act(),
+ )
+
+ self.fc = nn.Sequential(nn.Linear(768, 100))
+
+ def forward(self, x):
+ out = self.body(x)
+ out = out.view(out.size(0), -1)
+ # print(out.size())
+ out = self.fc(out)
+ return out
+
+
+'''ResNet in PyTorch.
+For Pre-activation ResNet, see 'preact_resnet.py'.
+Reference:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
+'''
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(BasicBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False), nn.BatchNorm2d(self.expansion * planes))
+
+ def forward(self, x):
+ # torch.nn.functional has no `Sigmoid`; use torch.sigmoid instead
+ out = torch.sigmoid(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = torch.sigmoid(out)
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes,
+ self.expansion * planes,
+ kernel_size=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False), nn.BatchNorm2d(self.expansion * planes))
+
+ def forward(self, x):
+ out = torch.sigmoid(self.bn1(self.conv1(x)))
+ out = torch.sigmoid(self.bn2(self.conv2(out)))
+ out = self.bn3(self.conv3(out))
+ out += self.shortcut(x)
+ out = torch.sigmoid(out)
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, num_blocks, num_classes=10):
+ super(ResNet, self).__init__()
+ self.in_planes = 64
+
+ self.conv1 = nn.Conv2d(3,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1)
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=1)
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=1)
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ out = torch.sigmoid(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ out = out.view(out.size(0), -1)
+ out = self.linear(out)
+ return out
+
+
+def ResNet18():
+ return ResNet(BasicBlock, [2, 2, 2, 2])
+
+
+def ResNet34():
+ return ResNet(BasicBlock, [3, 4, 6, 3])
+
+
+def ResNet50():
+ return ResNet(Bottleneck, [3, 4, 6, 3])
+
+
+def ResNet101():
+ return ResNet(Bottleneck, [3, 4, 23, 3])
+
+
+def ResNet152():
+
+ return ResNet(Bottleneck, [3, 8, 36, 3])
diff --git a/federatedscope/attack/privacy_attacks/GAN_based_attack.py b/federatedscope/attack/privacy_attacks/GAN_based_attack.py
new file mode 100644
index 000000000..4a404bbfe
--- /dev/null
+++ b/federatedscope/attack/privacy_attacks/GAN_based_attack.py
@@ -0,0 +1,179 @@
+import torch
+import torch.nn as nn
+from copy import deepcopy
+from federatedscope.attack.auxiliary.utils import get_generator
+import matplotlib.pyplot as plt
+
+
+class GANCRA():
+ '''
+ The implementation of the GAN-based class representative attack:
+ https://dl.acm.org/doi/abs/10.1145/3133956.3134012
+
+ References:
+
+ Hitaj, Briland, Giuseppe Ateniese, and Fernando Perez-Cruz.
+ "Deep models under the GAN: information leakage from collaborative deep learning."
+ Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security. 2017.
+
+ Args:
+ - target_label_ind (int): the index of the label whose class representative is to be reconstructed
+ - fl_model (object): the global FL model, used as the discriminator
+ - device (str or int): the device to run on; 'cpu' or the device index to select; default: 'cpu'
+ - dataset_name (str): the dataset name; default: None
+ - noise_dim (int): the dimension of the noise fed into the generator; default: 100
+ - batch_size (int): the number of generated samples injected into training; default: 16
+ - generator_train_epoch (int): the number of training steps when training the generator; default: 10
+ - lr (float): the learning rate of the generator training; default: 0.001
+ - sav_pth (str): the path to save the generated data; default: 'data/'
+ - round_num (int): the FL round at which the attack starts; default: -1
+
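+ Example (a minimal sketch; ``global_model`` is a placeholder for the
+ current global model received from the server):
+
+ >>> attacker = GANCRA(target_label_ind=3,
+ ... fl_model=global_model,
+ ... dataset_name='femnist')
+ >>> attacker.update_discriminator(global_model)
+ >>> attacker.generator_train()
+ >>> fake_x, fake_y = attacker.generate_fake_data()
+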
+ '''
+ def __init__(self,
+ target_label_ind,
+ fl_model,
+ device='cpu',
+ dataset_name=None,
+ noise_dim=100,
+ batch_size=16,
+ generator_train_epoch=10,
+ lr=0.001,
+ sav_pth='data/',
+ round_num=-1):
+
+ # get dataset's corresponding generator
+ self.generator = get_generator(dataset_name=dataset_name)().to(device)
+ self.target_label_ind = target_label_ind
+
+ self.discriminator = deepcopy(fl_model)
+
+ self.generator_loss_fun = nn.CrossEntropyLoss()
+
+ self.generator_train_epoch = generator_train_epoch
+
+ # the dimension of the noise input to generator
+ self.noise_dim = noise_dim
+ self.batch_size = batch_size
+
+ self.device = device
+
+ # define generator optimizer
+ self.generator_optimizer = torch.optim.SGD(
+ params=self.generator.parameters(), lr=lr)
+ self.sav_pth = sav_pth
+ self.round_num = round_num
+ self.generator_loss_summary = []
+
+ def update_discriminator(self, model):
+ '''
+ Copy the model of the server as the discriminator
+
+ Args:
+ model (object): the model in the server
+
+ Returns: the discriminator
+
+ '''
+
+ self.discriminator = deepcopy(model)
+
+ def discriminator_loss(self):
+ pass
+
+ def generator_loss(self, discriminator_output):
+ '''
+ Get the generator loss based on the discriminator's output
+
+ Args:
+ discriminator_output (Tensor): the discriminator's output; size: batch_size * n_class
+
+ Returns: generator_loss
+
+ '''
+
+ self.num_class = discriminator_output.size()[1]
+ ideal_results = self.target_label_ind * torch.ones(
+ discriminator_output.size()[0], dtype=torch.long)
+
+ # ideal_results[:] = self.target_label_ind
+
+ return self.generator_loss_fun(discriminator_output,
+ ideal_results.to(self.device))
+
+ def _gradient_closure(self, noise):
+ def closure():
+ generated_images = self.generator(noise)
+ discriminator_output = self.discriminator(generated_images)
+ generator_loss = self.generator_loss(discriminator_output)
+
+ generator_loss.backward()
+ return generator_loss
+
+ return closure
+
+ def generator_train(self):
+
+ for _ in range(self.generator_train_epoch):
+
+ self.generator.zero_grad()
+ self.generator_optimizer.zero_grad()
+ noise = torch.randn(size=(self.batch_size, self.noise_dim)).to(
+ torch.device(self.device))
+ closure = self._gradient_closure(noise)
+ tmp_loss = self.generator_optimizer.step(closure)
+ self.generator_loss_summary.append(
+ tmp_loss.detach().to('cpu').numpy())
+
+ def generate_fake_data(self, data_num=None):
+ if data_num is None:
+ data_num = self.batch_size
+ noise = torch.randn(size=(data_num, self.noise_dim)).to(
+ torch.device(self.device))
+ generated_images = self.generator(noise)
+
+ # the number of labels must match the number of generated samples
+ generated_label = torch.zeros(data_num, dtype=torch.long).to(
+ torch.device(self.device))
+ if self.target_label_ind + 1 > self.num_class - 1:
+ generated_label[:] = self.target_label_ind - 1
+ else:
+ generated_label[:] = self.target_label_ind + 1
+
+ return generated_images.detach(), generated_label.detach()
+
+ def sav_image(self, generated_data):
+
+ fig = plt.figure(figsize=(4, 4))
+
+ ind = min(generated_data.shape[0], 16)
+
+ for i in range(ind):
+ plt.subplot(4, 4, i + 1)
+
+ plt.imshow(generated_data[i, 0, :, :] * 127.5 + 127.5, cmap='gray')
+ # plt.imshow(generated_data[i, 0, :, :] , cmap='gray')
+ # plt.imshow()
+ plt.axis('off')
+
+ plt.savefig(self.sav_pth + '/' +
+ 'image_round_{}.png'.format(self.round_num))
+ plt.close()
+
+ def sav_plot_gan_loss(self):
+ plt.plot(self.generator_loss_summary)
+ plt.savefig(self.sav_pth + '/' +
+ 'generator_loss_round_{}.png'.format(self.round_num))
+ plt.close()
+
+ def generate_and_save_images(self):
+ '''
+
+ Save the generated data and the generator training loss
+
+ '''
+
+ generated_data, _ = self.generate_fake_data()
+ generated_data = generated_data.detach().to('cpu')
+
+ self.sav_image(generated_data)
+ self.sav_plot_gan_loss()
diff --git a/federatedscope/attack/privacy_attacks/__init__.py b/federatedscope/attack/privacy_attacks/__init__.py
new file mode 100644
index 000000000..5ef9a533c
--- /dev/null
+++ b/federatedscope/attack/privacy_attacks/__init__.py
@@ -0,0 +1,5 @@
+from federatedscope.attack.privacy_attacks.GAN_based_attack import *
+from federatedscope.attack.privacy_attacks.passive_PIA import *
+from federatedscope.attack.privacy_attacks.reconstruction_opt import *
+
+__all__ = ['DLG', 'InvertGradient', 'GANCRA', 'PassivePropertyInference']
diff --git a/federatedscope/attack/privacy_attacks/passive_PIA.py b/federatedscope/attack/privacy_attacks/passive_PIA.py
new file mode 100644
index 000000000..cae0b06c0
--- /dev/null
+++ b/federatedscope/attack/privacy_attacks/passive_PIA.py
@@ -0,0 +1,179 @@
+from federatedscope.attack.auxiliary.utils import get_classifier, get_passive_PIA_auxiliary_dataset
+import torch
+import numpy as np
+import copy
+from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class PassivePropertyInference():
+ '''
+ This is an implementation of passive property inference (Algorithm 3) in "Exploiting Unintended Feature Leakage in Collaborative Learning":
+ https://arxiv.org/pdf/1805.04049.pdf
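+
+ A sketched server-side workflow, where ``criterion``, ``global_model``,
+ ``prev_para``/``new_para`` and the string arguments are placeholders:
+
+ >>> pia = PassivePropertyInference(classifier='randomforest',
+ ... fl_model_criterion=criterion,
+ ... device='cpu',
+ ... grad_clip=-1,
+ ... dataset_name='toy',
+ ... fl_local_update_num=1,
+ ... fl_type_optimizer='SGD',
+ ... fl_lr=0.1)
+ >>> pia.get_data_for_dataset_prop_classifier(global_model)
+ >>> pia.train_property_classifier()
+ >>> pia.collect_updates(prev_para, new_para, round=0, client_id=1)
+ >>> results = pia.infer_collected()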
+ '''
+ def __init__(self,
+ classifier: str,
+ fl_model_criterion,
+ device,
+ grad_clip,
+ dataset_name,
+ fl_local_update_num,
+ fl_type_optimizer,
+ fl_lr,
+ batch_size=100):
+ # self.dataset_prop_classifier['x']: n * d_feature; the parameter updates
+ # self.dataset_prop_classifier['y']: n * 1; the property labels
+ self.dataset_prop_classifier = {'x': None, 'y': None}
+
+ self.classifier = get_classifier(classifier)
+
+ self.auxiliary_dataset = get_passive_PIA_auxiliary_dataset(
+ dataset_name)
+
+ self.fl_model_criterion = fl_model_criterion
+ self.fl_local_update_num = fl_local_update_num
+ self.fl_type_optimizer = fl_type_optimizer
+ self.fl_lr = fl_lr
+
+ self.device = device
+
+ self.batch_size = batch_size
+
+ self.grad_clip = grad_clip
+
+ self.collect_updates_summary = dict()
+
+ # def _get_batch_auxiliary(self):
+ # train_data_batch = self._get_batch(self.auxiliary_dataset['train'])
+ # test_data_batch = self._get_batch(self.auxiliary_dataset['test'])
+ #
+ # return train_data_batch, test_data_batch
+
+ def _get_batch(self, data):
+ prop_ind = np.random.choice(np.where(data['prop'] == 1)[0],
+ self.batch_size,
+ replace=True)
+ x_batch_prop = data['x'][prop_ind, :]
+ y_batch_prop = data['y'][prop_ind, :]
+
+ nprop_ind = np.random.choice(np.where(data['prop'] == 0)[0],
+ self.batch_size,
+ replace=True)
+ x_batch_nprop = data['x'][nprop_ind, :]
+ y_batch_nprop = data['y'][nprop_ind, :]
+
+ return [x_batch_prop, y_batch_prop, x_batch_nprop, y_batch_nprop]
+
+ def get_data_for_dataset_prop_classifier(self, model, local_runs=10):
+
+ previous_para = model.state_dict()
+ self.current_model_para = previous_para
+ for _ in range(local_runs):
+ x_batch_prop, y_batch_prop, x_batch_nprop, y_batch_nprop = self._get_batch(
+ self.auxiliary_dataset)
+ para_update_prop = self._get_parameter_updates(
+ model, previous_para, x_batch_prop, y_batch_prop)
+ prop = torch.tensor([[1]]).to(torch.device(self.device))
+ self.add_parameter_updates(para_update_prop, prop)
+
+ para_update_nprop = self._get_parameter_updates(
+ model, previous_para, x_batch_nprop, y_batch_nprop)
+ prop = torch.tensor([[0]]).to(torch.device(self.device))
+ self.add_parameter_updates(para_update_nprop, prop)
+
+ def _get_parameter_updates(self, model, previous_para, x_batch, y_batch):
+
+ model = copy.deepcopy(model)
+ # get last phase model parameters
+ # print(model)
+ model = model.to(torch.device(self.device))
+ model.load_state_dict(previous_para, strict=False)
+
+ optimizer = get_optimizer(type=self.fl_type_optimizer,
+ model=model,
+ lr=self.fl_lr)
+
+ for _ in range(self.fl_local_update_num):
+ optimizer.zero_grad()
+ loss_auxiliary_prop = self.fl_model_criterion(
+ model(torch.Tensor(x_batch).to(torch.device(self.device))),
+ torch.Tensor(y_batch).to(torch.device(self.device)))
+ loss_auxiliary_prop.backward()
+ if self.grad_clip > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(),
+ self.grad_clip)
+ optimizer.step()
+
+ para_prop = model.cpu().state_dict()
+ # print('update: ')
+ # print(para_prop)
+
+ updates_prop = torch.hstack([
+ (previous_para[name] - para_prop[name]).flatten().cpu()
+ for name in previous_para.keys()
+ ])
+ model.load_state_dict(previous_para, strict=False)
+ return updates_prop
+
+ def collect_updates(self, previous_para, updated_parameter, round,
+ client_id):
+
+ updates_prop = torch.hstack([
+ (previous_para[name] - updated_parameter[name]).flatten().cpu()
+ for name in previous_para.keys()
+ ])
+ if round not in self.collect_updates_summary.keys():
+ self.collect_updates_summary[round] = dict()
+ self.collect_updates_summary[round][client_id] = updates_prop
+
+ def add_parameter_updates(self, parameter_updates, prop):
+ '''
+
+ Args:
+ parameter_updates: Tensor with dimension n * d_feature
+ prop: Tensor with dimension n * 1
+
+ Returns:
+
+ '''
+ if self.dataset_prop_classifier['x'] is None:
+ self.dataset_prop_classifier['x'] = parameter_updates.cpu()
+ self.dataset_prop_classifier['y'] = prop.reshape([-1]).cpu()
+ else:
+ self.dataset_prop_classifier['x'] = torch.vstack(
+ (self.dataset_prop_classifier['x'], parameter_updates.cpu()))
+ self.dataset_prop_classifier['y'] = torch.vstack(
+ (self.dataset_prop_classifier['y'], prop.cpu()))
+
+ def train_property_classifier(self):
+ from sklearn.model_selection import train_test_split
+ x_train, x_test, y_train, y_test = train_test_split(
+ self.dataset_prop_classifier['x'],
+ self.dataset_prop_classifier['y'],
+ test_size=0.33,
+ random_state=42)
+ self.classifier.fit(x_train, y_train)
+
+ y_pred = self.property_inference(x_test)
+ from sklearn.metrics import accuracy_score
+ accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
+ logger.info(
+ '=============== PIA accuracy on auxiliary test dataset: {}'.
+ format(accuracy))
+
+ def property_inference(self, parameter_updates):
+ return self.classifier.predict(parameter_updates)
+
+ def infer_collected(self):
+ pia_results = dict()
+
+ for round in self.collect_updates_summary.keys():
+ for id in self.collect_updates_summary[round].keys():
+ if round not in pia_results.keys():
+ pia_results[round] = dict()
+ pia_results[round][id] = self.property_inference(
+ self.collect_updates_summary[round][id].reshape(1, -1))
+ return pia_results
diff --git a/federatedscope/attack/privacy_attacks/reconstruction_opt.py b/federatedscope/attack/privacy_attacks/reconstruction_opt.py
new file mode 100644
index 000000000..664434b33
--- /dev/null
+++ b/federatedscope/attack/privacy_attacks/reconstruction_opt.py
@@ -0,0 +1,270 @@
+import torch
+import torch.nn.functional as F
+from federatedscope.attack.auxiliary.utils import iDLG_trick, total_variation, get_info_diff_loss
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class DLG(object):
+ """Implementation of the paper "Deep Leakage from Gradients": https://papers.nips.cc/paper/2019/file/60a6c4002cc7b29142def8871531281a-Paper.pdf
+
+ References:
+
+ Zhu, Ligeng, Zhijian Liu, and Song Han. "Deep leakage from gradients." Advances in Neural Information Processing Systems 32 (2019).
+
+ Args:
+ - max_ite (int): the max iteration number;
+ - lr (float): learning rate in optimization based reconstruction;
+ - federate_loss_fn (object): The loss function used in FL training;
+ - device (str): the device running the reconstruction;
+ - federate_method (str): The federated learning method;
+ - federate_lr (float):The learning rate used in FL training; default None.
+ - optim (str): The optimization method used in reconstruction; default: "Adam"; supported: 'sgd', 'adam', 'lbfgs'
+ - info_diff_type (str): The type of loss between the ground-truth gradient/parameter updates info and the reconstructed info; default: "l2"
+ - is_one_hot_label (bool): whether the label is one-hot; default: False
+
+
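+ Example (a minimal sketch; ``model``, ``criterion`` and ``grads`` -- a
+ list of (name, gradient) pairs uploaded by the victim client -- are
+ placeholders):
+
+ >>> dlg = DLG(max_ite=100, lr=0.1, federate_loss_fn=criterion,
+ ... device='cpu', federate_method='FedSGD')
+ >>> x_rec, y_rec = dlg.reconstruct(model, grads,
+ ... data_feature_dim=[1, 28, 28],
+ ... num_class=10, batch_size=1)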
+ """
+ def __init__(self,
+ max_ite,
+ lr,
+ federate_loss_fn,
+ device,
+ federate_method,
+ federate_lr=None,
+ optim='Adam',
+ info_diff_type='l2',
+ is_one_hot_label=False):
+
+ if federate_method.lower() == "fedavg":
+ # check whether the received info is parameter. If yes, the reconstruction attack requires the learning rate of FL
+ assert federate_lr is not None
+
+ self.info_is_para = federate_method.lower() == "fedavg"
+ self.federate_lr = federate_lr
+
+ self.max_ite = max_ite
+ self.lr = lr
+ self.device = device
+ self.optim = optim
+ self.federate_loss_fn = federate_loss_fn
+ self.info_diff_type = info_diff_type
+ self.info_diff_loss = get_info_diff_loss(info_diff_type)
+
+ self.is_one_hot_label = is_one_hot_label
+
+ def eval(self):
+ pass
+
+ def _setup_optimizer(self, parameters):
+ if self.optim.lower() == 'adam':
+ optimizer = torch.optim.Adam(parameters, lr=self.lr)
+ elif self.optim.lower() == 'sgd': # actually gd
+ optimizer = torch.optim.SGD(parameters,
+ lr=self.lr,
+ momentum=0.9,
+ nesterov=True)
+ elif self.optim.lower() == 'lbfgs':
+ optimizer = torch.optim.LBFGS(parameters)
+ else:
+ raise ValueError()
+ return optimizer
+
+ def _gradient_closure(self, model, optimizer, dummy_data, dummy_label,
+ original_info):
+ def closure():
+ optimizer.zero_grad()
+ model.zero_grad()
+
+ loss = self.federate_loss_fn(
+ model(dummy_data),
+ dummy_label.view(-1, ).type(torch.LongTensor).to(
+ torch.device(self.device)))
+
+ gradient = torch.autograd.grad(loss,
+ model.parameters(),
+ create_graph=True)
+ info_diff = 0
+ for g_dummy, gt in zip(gradient, original_info):
+ info_diff += self.info_diff_loss(g_dummy, gt)
+ info_diff.backward()
+ return info_diff
+
+ return closure
+
+ def _run_simple_reconstruct(self, model, optimizer, dummy_data, label,
+ original_gradient, closure_fn):
+
+ for ite in range(self.max_ite):
+ closure = closure_fn(model, optimizer, dummy_data, label,
+ original_gradient)
+ info_diff = optimizer.step(closure)
+
+ if (ite + 1 == self.max_ite) or ite % 20 == 0:
+ logger.info('Ite: {}, gradient difference: {:.4f}'.format(
+ ite, info_diff))
+ return dummy_data.detach(), label.detach()
+
+ def get_original_gradient_from_para(self, model, original_info,
+ model_para_name):
+ '''
+
+ Transfer the model parameter updates to gradients based on:
+
+ .. math::
+
+ P_{t} = P - \eta g \quad \Rightarrow \quad g = (P - P_{t}) / \eta,
+
+ where
+ :math:`P_{t}` is the parameter updated by the client at the current round;
+ :math:`P` is the parameter of the global model at the end of the last round;
+ :math:`\eta` is the learning rate of the clients' local training;
+ :math:`g` is the recovered gradient.
+
+
+
+ Arguments:
+ - model (object): The model owned by the Server
+ - original_info (dict): The model parameter updates received by Server
+ - model_para_name (list): The list of model parameter names. Be sure that :attr:`model_para_name` is consistent with the key names in :attr:`original_info`
+
+ :returns:
+ - original_gradient (list): the list of the gradient corresponding to the model updates
+
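+ A scalar sketch of the rearrangement: with :math:`P = 1.5`,
+ :math:`P_{t} = 1.0` and :math:`\eta = 0.5`, the recovered gradient is
+ :math:`g = (P - P_{t}) / \eta = 1.0`.
+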
+ '''
+ original_gradient = [
+ ((original_para -
+ original_info[name].to(torch.device(self.device))) /
+ self.federate_lr).detach()
+ for original_para, name in zip(model.parameters(), model_para_name)
+ ]
+ return original_gradient
+
+ def reconstruct(self, model, original_info, data_feature_dim, num_class,
+ batch_size):
+ '''
+ Reconstruct the original training data and label.
+
+ Args:
+ model: The model used in FL; Type: object
+ original_info: The message received to perform reconstruction, usually the gradient/parameter updates; Type: list
+ data_feature_dim: The feature dimension of dataset; Type: list or Tensor.Size
+ num_class: the number of total classes in the dataset; Type: int
+ batch_size: the number of samples in the batch that generate the original_info; Type: int
+
+ :returns:
+ - The reconstructed data (Tensor); Size: [batch_size, data_feature_dim]
+ - The reconstructed label (Tensor): Size: [batch_size]
+
+
+ '''
+ # initialize the dummy data and label
+ dummy_data_dim = [batch_size]
+ dummy_data_dim.extend(data_feature_dim)
+ dummy_data = torch.randn(dummy_data_dim).to(torch.device(
+ self.device)).requires_grad_(True)
+
+ para_trainable_name = []
+ for p in model.named_parameters():
+ para_trainable_name.append(p[0])
+
+ if self.info_is_para:
+ original_gradient = self.get_original_gradient_from_para(
+ model, original_info, model_para_name=para_trainable_name)
+ else:
+ original_gradient = [
+ grad.to(torch.device(self.device)) for k, grad in original_info
+ ]
+
+ label = iDLG_trick(original_gradient,
+ num_class=num_class,
+ is_one_hot_label=self.is_one_hot_label)
+ label = label.to(torch.device(self.device))
+
+ # setup optimizer
+ optimizer = self._setup_optimizer([dummy_data])
+
+ self._run_simple_reconstruct(model,
+ optimizer,
+ dummy_data,
+ label=label,
+ original_gradient=original_gradient,
+ closure_fn=self._gradient_closure)
+
+ return dummy_data.detach(), label.detach()
+
+
+class InvertGradient(DLG):
+ '''
+ The implementation of "Inverting Gradients - How easy is it to break privacy in federated learning?".
+ Link: https://proceedings.neurips.cc/paper/2020/hash/c4ede56bbd98819ae6112b20ac6bf145-Abstract.html
+
+ References:
+
+ Geiping, Jonas, et al. "Inverting gradients-how easy is it to break privacy in federated learning?." Advances in Neural Information Processing Systems 33 (2020): 16937-16947.
+
+ Args:
+ - max_ite (int): the max iteration number;
+ - lr (float): learning rate in optimization based reconstruction;
+ - federate_loss_fn (object): The loss function used in FL training;
+ - device (str): the device running the reconstruction;
+ - federate_method (str): The federated learning method;
+ - federate_lr (float): The learning rate used in FL training; default: None.
+ - alpha_TV (float): the hyper-parameter of the total variance term; default: 0.001
+ - info_diff_type (str): The type of loss between the ground-truth gradient/parameter updates info and the reconstructed info; forced to cosine similarity ('sim') in this attack
+ - optim (str): The optimization method used in reconstruction; default: "Adam"; supported: 'sgd', 'adam', 'lbfgs'
+ - is_one_hot_label (bool): whether the label is one-hot; default: False
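+
+ Example (a minimal sketch; placeholders as in :class:`DLG`):
+
+ >>> ig = InvertGradient(max_ite=200, lr=0.1,
+ ... federate_loss_fn=criterion,
+ ... device='cpu', federate_method='FedSGD',
+ ... alpha_TV=0.001)
+ >>> x_rec, y_rec = ig.reconstruct(model, grads,
+ ... data_feature_dim=[1, 28, 28],
+ ... num_class=10, batch_size=1)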
+ '''
+ def __init__(self,
+ max_ite,
+ lr,
+ federate_loss_fn,
+ device,
+ federate_method,
+ federate_lr=None,
+ alpha_TV=0.001,
+ info_diff_type='sim',
+ optim='Adam',
+ is_one_hot_label=False):
+ super(InvertGradient, self).__init__(max_ite,
+ lr,
+ federate_loss_fn,
+ device,
+ federate_method,
+ federate_lr=federate_lr,
+ optim=optim,
+ info_diff_type=info_diff_type,
+ is_one_hot_label=is_one_hot_label)
+ self.alpha_TV = alpha_TV
+ if self.info_diff_type != 'sim':
+ logger.info(
+ 'Force the info_diff_type to be cosine similarity loss in InvertGradient attack method!'
+ )
+ self.info_diff_type = 'sim'
+ self.info_diff_loss = get_info_diff_loss(self.info_diff_type)
+
+ def _gradient_closure(self, model, optimizer, dummy_data, dummy_label,
+ original_gradient):
+ def closure():
+ optimizer.zero_grad()
+ model.zero_grad()
+ loss = self.federate_loss_fn(
+ model(dummy_data),
+ dummy_label.view(-1, ).type(torch.LongTensor).to(
+ torch.device(self.device)))
+
+ gradient = torch.autograd.grad(loss,
+ model.parameters(),
+ create_graph=True)
+ gradient_diff = 0
+
+ for g_dummy, gt in zip(gradient, original_gradient):
+ gradient_diff += self.info_diff_loss(g_dummy, gt)
+
+ # add total variance regularization
+ if self.alpha_TV > 0:
+ gradient_diff += self.alpha_TV * total_variation(dummy_data)
+ gradient_diff.backward()
+ return gradient_diff
+
+ return closure
diff --git a/federatedscope/attack/trainer/GAN_trainer.py b/federatedscope/attack/trainer/GAN_trainer.py
new file mode 100644
index 000000000..a1b2b0393
--- /dev/null
+++ b/federatedscope/attack/trainer/GAN_trainer.py
@@ -0,0 +1,103 @@
+import logging
+from typing import Type
+
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.attack.privacy_attacks.GAN_based_attack import GANCRA
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_GANTrainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ '''
+ Wrap the trainer for the GAN-based class representative attack.
+
+ Args:
+ base_trainer: Type: core.trainers.GeneralTorchTrainer
+
+ :returns:
+ The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
+
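+ Example (a sketch; ``base_trainer`` is a constructed
+ GeneralTorchTrainer whose cfg.attack section is set up for this attack):
+
+ >>> trainer = wrap_GANTrainer(base_trainer)
+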
+ '''
+
+ # ---------------- attribute-level plug-in -----------------------
+
+ base_trainer.ctx.target_label_ind = base_trainer.cfg.attack.target_label_ind
+ base_trainer.ctx.gan_cra = GANCRA(base_trainer.cfg.attack.target_label_ind,
+ base_trainer.ctx.model,
+ dataset_name=base_trainer.cfg.data.type,
+ device=base_trainer.ctx.device,
+ sav_pth=base_trainer.cfg.outdir)
+
+ # ---- action-level plug-in -------
+
+ base_trainer.register_hook_in_train(new_hook=hook_on_fit_start_generator,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ base_trainer.register_hook_in_train(new_hook=hook_on_gan_cra_train,
+ trigger='on_batch_start',
+ insert_pos=-1)
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_batch_injected_data_generation,
+ trigger='on_batch_start',
+ insert_pos=-1)
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_batch_forward_injected_data,
+ trigger='on_batch_forward',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_data_injection_sav_data,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ return base_trainer
+
+
+def hook_on_fit_start_generator(ctx):
+ '''
+ Count the FL training round before fitting.
+
+ Args:
+ ctx: the trainer context
+ '''
+ ctx.gan_cra.round_num += 1
+ logger.info('----- Round {}: GAN training ............'.format(
+ ctx.gan_cra.round_num))
+
+
+def hook_on_batch_forward_injected_data(ctx):
+ '''
+ Inject the generated data into the training batch loss.
+
+ Args:
+ ctx: the trainer context
+ '''
+ x, label = [_.to(ctx.device) for _ in ctx.injected_data]
+ pred = ctx.model(x)
+ if len(label.size()) == 0:
+ label = label.unsqueeze(0)
+ ctx.loss_task += ctx.criterion(pred, label)
+ ctx.y_true_injected = label
+ ctx.y_prob_injected = pred
+
+
+def hook_on_batch_injected_data_generation(ctx):
+ '''generate the injected data
+ '''
+ ctx.injected_data = ctx.gan_cra.generate_fake_data()
+
+
+def hook_on_gan_cra_train(ctx):
+
+ ctx.gan_cra.update_discriminator(ctx.model)
+ ctx.gan_cra.generator_train()
+
+
+def hook_on_data_injection_sav_data(ctx):
+
+ ctx.gan_cra.generate_and_save_images()
diff --git a/federatedscope/attack/trainer/MIA_invert_gradient_trainer.py b/federatedscope/attack/trainer/MIA_invert_gradient_trainer.py
new file mode 100644
index 000000000..171d8936f
--- /dev/null
+++ b/federatedscope/attack/trainer/MIA_invert_gradient_trainer.py
@@ -0,0 +1,124 @@
+import logging
+from typing import Type
+
+import torch
+
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
+from federatedscope.attack.auxiliary.MIA_get_target_data import get_target_data
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_GradientAscentTrainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ '''
+ Wrap the trainer to perform gradient ascent on the target data (for membership inference).
+
+ Args:
+ base_trainer: Type: core.trainers.GeneralTorchTrainer
+
+ :returns:
+ The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
+
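+ A sketched call site (``base_trainer`` as elsewhere in this module; the
+ wrapper replaces the default backward hook with the gradient-ascent one):
+
+ >>> trainer = wrap_GradientAscentTrainer(base_trainer)
+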
+ '''
+
+ # base_trainer.ctx.target_data = get_target_data()
+ base_trainer.ctx.target_data_dataloader = WrapDataset(
+ get_target_data(base_trainer.cfg.data.type))
+ base_trainer.ctx.target_data = get_target_data(base_trainer.cfg.data.type)
+
+ base_trainer.ctx.is_target_batch = False
+ base_trainer.ctx.finish_injected = False
+
+ base_trainer.ctx.target_data_loss = []
+
+ base_trainer.ctx.outdir = base_trainer.cfg.outdir
+ base_trainer.ctx.round = -1
+ base_trainer.ctx.inject_round = base_trainer.cfg.attack.inject_round
+
+ base_trainer.register_hook_in_train(new_hook=hook_on_fit_start_count_round,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_batch_start_replace_data_batch,
+ trigger='on_batch_start',
+ insert_pos=-1)
+
+ base_trainer.replace_hook_in_train(
+ new_hook=hook_on_batch_backward_invert_gradient,
+ target_trigger='on_batch_backward',
+ target_hook_name='_hook_on_batch_backward')
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_fit_start_loss_on_target_data,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ # plot the target data loss at the end of fitting
+
+ return base_trainer
+
+
+def hook_on_fit_start_count_round(ctx):
+ ctx.round += 1
+ logger.info("============== round: {} ====================".format(
+ ctx.round))
+
+
+def hook_on_batch_start_replace_data_batch(ctx):
+ # replace the data batch to the target data
+ # check whether need to replace the data; if yes, replace the current batch to target batch
+ if not ctx.finish_injected and ctx.round >= ctx.inject_round:
+ logger.info("---------- inject the target data ---------")
+ ctx["data_batch"] = ctx.target_data
+ ctx.is_target_batch = True
+ logger.info(ctx.target_data[0].size())
+ else:
+ ctx.is_target_batch = False
+
+
+def hook_on_batch_backward_invert_gradient(ctx):
+ if ctx.is_target_batch:
+ # if the current data batch is the target data, perform gradient ascent
+ ctx.optimizer.zero_grad()
+ ctx.loss_batch.backward()
+
+ # flip the gradient sign to perform gradient ascent on the target batch
+ for param in ctx["model"].parameters():
+ param.grad = -1 * param.grad
+
+ ctx["optimizer"].step()
+ logger.info('-------------- Gradient ascent finished -------------')
+ ctx.finish_injected = True
+
+ else:
+ # if current batch is not target data, perform regular backward step
+ ctx.optimizer.zero_grad()
+ ctx.loss_task.backward()
+ if ctx.grad_clip > 0:
+ torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
+ ctx.grad_clip)
+ ctx.optimizer.step()
+
+
+def hook_on_fit_start_loss_on_target_data(ctx):
+ if ctx.finish_injected:
+ tmp_loss = []
+ x, label = [_.to(ctx.device) for _ in ctx.target_data]
+ logger.info(x.size())
+ num_target = x.size()[0]
+
+ for i in range(num_target):
+ x_i = x[i, :].unsqueeze(0)
+ label_i = label[i].reshape(-1)
+ pred = ctx.model(x_i)
+ tmp_loss.append(
+ ctx.criterion(pred, label_i).detach().cpu().numpy())
+ ctx.target_data_loss.append(tmp_loss)
diff --git a/federatedscope/attack/trainer/PIA_trainer.py b/federatedscope/attack/trainer/PIA_trainer.py
new file mode 100644
index 000000000..d0826b30e
--- /dev/null
+++ b/federatedscope/attack/trainer/PIA_trainer.py
@@ -0,0 +1,18 @@
+from typing import Type
+
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.attack.auxiliary.utils import get_data_property
+
+
+def wrap_ActivePIATrainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ base_trainer.ctx.alpha_prop_loss = base_trainer._cfg.attack.alpha_prop_loss
+ return base_trainer
+
+
+def hook_on_batch_start_get_prop(ctx):
+ ctx.prop = get_data_property(ctx.data_batch)
+
+
+def hook_on_batch_forward_add_PIA_loss(ctx):
+ ctx.loss_batch = ctx.alpha_prop_loss * ctx.loss_batch + (
+ 1 - ctx.alpha_prop_loss) * ctx.criterion(ctx.y_prob, ctx.prop)
diff --git a/federatedscope/attack/trainer/__init__.py b/federatedscope/attack/trainer/__init__.py
new file mode 100644
index 000000000..37d4d78ef
--- /dev/null
+++ b/federatedscope/attack/trainer/__init__.py
@@ -0,0 +1,16 @@
+from federatedscope.attack.trainer.GAN_trainer import *
+from federatedscope.attack.trainer.MIA_invert_gradient_trainer import *
+from federatedscope.attack.trainer.PIA_trainer import *
+from federatedscope.attack.trainer.backdoor_trainer import *
+from federatedscope.attack.trainer.benign_trainer import *
+
+__all__ = [
+ 'wrap_GANTrainer', 'hook_on_fit_start_generator',
+ 'hook_on_batch_forward_injected_data',
+ 'hook_on_batch_injected_data_generation', 'hook_on_gan_cra_train',
+ 'hook_on_data_injection_sav_data', 'wrap_GradientAscentTrainer',
+ 'hook_on_fit_start_count_round', 'hook_on_batch_start_replace_data_batch',
+ 'hook_on_batch_backward_invert_gradient',
+ 'hook_on_fit_start_loss_on_target_data', 'wrap_backdoorTrainer',
+ 'wrap_benignTrainer'
+]
diff --git a/federatedscope/attack/trainer/backdoor_trainer.py b/federatedscope/attack/trainer/backdoor_trainer.py
new file mode 100644
index 000000000..30014faff
--- /dev/null
+++ b/federatedscope/attack/trainer/backdoor_trainer.py
@@ -0,0 +1,205 @@
+import logging
+from typing import Type
+import torch
+import numpy as np
+import copy
+
+from federatedscope.core.trainers import GeneralTorchTrainer
+from torch.nn.utils import parameters_to_vector, vector_to_parameters
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_backdoorTrainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ '''
+ Wrap the trainer for backdoor attacks:
+
+ poisoning data:
+ edge-case triggers
+ semantic triggers
+ pixel-wise triggers: badnet, blended(HK), sig, wanet, clean-label (narcissus)
+
+ poisoning model:
+ black-box attacks
+ PGD training
+ local regularization
+
+ Args:
+ base_trainer: Type: core.trainers.GeneralTorchTrainer
+
+ :returns:
+ The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
+
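+ A sketched call site; which hooks get registered depends on the
+ cfg.attack.* flags read below (self_opt, scale_poisoning, pgd_poisoning):
+
+ >>> trainer = wrap_backdoorTrainer(base_trainer)
+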
+ '''
+
+ # ---------------- attribute-level plug-in -----------------------
+ # for pFL, we need to know the type of used methods.
+ base_trainer.ctx.federate_method = base_trainer.cfg.federate.method
+ base_trainer.ctx.target_label_ind = base_trainer.cfg.attack.target_label_ind
+ base_trainer.ctx.trigger_type = base_trainer.cfg.attack.trigger_type
+ base_trainer.ctx.label_type = base_trainer.cfg.attack.label_type
+ # More trigger types (e.g., edge-case and semantic triggers) can be added here.
+
+ # ---- action-level plug-in -------
+
+ if base_trainer.cfg.attack.self_opt:
+
+ base_trainer.ctx.self_lr = base_trainer.cfg.attack.self_lr
+ base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_fit_start_init_local_opt,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_reset_opt,
+ trigger='on_fit_end',
+ insert_pos=0)
+
+ if base_trainer.cfg.attack.scale_poisoning or base_trainer.cfg.attack.pgd_poisoning:
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_fit_start_init_local_model,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ if base_trainer.cfg.attack.scale_poisoning:
+
+ base_trainer.ctx.scale_para = base_trainer.cfg.attack.scale_para
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_fit_end_scale_poisoning,
+ trigger="on_fit_end",
+ insert_pos=-1)
+
+ if base_trainer.cfg.attack.pgd_poisoning:
+
+ base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch
+ base_trainer.ctx.pgd_lr = base_trainer.cfg.attack.pgd_lr
+ base_trainer.ctx.pgd_eps = base_trainer.cfg.attack.pgd_eps
+ base_trainer.ctx.batch_index = 0
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_fit_start_init_local_pgd,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_batch_end_project_grad,
+ trigger='on_batch_end',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_epoch_end_project_grad,
+ trigger='on_epoch_end',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_reset_opt,
+ trigger='on_fit_end',
+ insert_pos=0)
+
+ return base_trainer
+
+
+def hook_on_fit_start_init_local_opt(ctx):
+
+ # need to check for ditto method
+ # ctx.original_optimizer = ctx.optimizer
+
+ if ctx.federate_method.lower() == "ditto":
+ ctx.original_epoch = ctx["num_train_epoch"]
+ ctx["num_train_epoch"] = ctx.self_epoch + ctx.num_train_epoch_for_local_model
+
+ elif ctx.federate_method.lower() == "fedrep":
+ ctx.original_epoch = ctx["num_train_epoch"]
+ ctx["num_train_epoch"] = ctx.self_epoch + ctx.epoch_linear
+ else:
+ ctx.original_epoch = ctx["num_train_epoch"]
+ ctx["num_train_epoch"] = ctx.self_epoch
+
+
+def hook_on_fit_end_reset_opt(ctx):
+
+ ctx["num_train_epoch"] = ctx.original_epoch
+
+
+def hook_on_fit_start_init_local_model(ctx):
+
+ ctx.original_model = copy.deepcopy(ctx.model) # the original global model
+
+
+def hook_on_fit_end_scale_poisoning(ctx):
+
+ # conduct the scale poisoning
+ scale_para = ctx.scale_para
+
+ v = torch.nn.utils.parameters_to_vector(ctx.original_model.parameters())
+ logger.info("the Norm of the original global model: {}".format(
+ torch.norm(v)))
+
+ v = torch.nn.utils.parameters_to_vector(ctx.model.parameters())
+ logger.info("Attacker before scaling : Norm = {}".format(torch.norm(v)))
+
+ ctx.original_model = list(ctx.original_model.parameters())
+
+ for idx, param in enumerate(ctx.model.parameters()):
+ param.data = (param.data - ctx.original_model[idx]
+ ) * scale_para + ctx.original_model[idx]
+
+ v = torch.nn.utils.parameters_to_vector(ctx.model.parameters())
+ logger.info("Attacker after scaling : Norm = {}".format(torch.norm(v)))
+
+ logger.info('finishing model scaling poisoning attack')
+
+
+def hook_on_fit_start_init_local_pgd(ctx):
+
+ ctx.original_optimizer = ctx.optimizer
+ ctx.original_epoch = ctx["num_train_epoch"]
+ ctx["num_train_epoch"] = ctx.self_epoch
+ ctx.optimizer = torch.optim.SGD(ctx.model.parameters(), \
+ lr=ctx.pgd_lr, momentum=0.9, weight_decay=1e-4)
+
+
+def hook_on_batch_end_project_grad(ctx):
+
+ eps = ctx.pgd_eps
+ project_frequency = 10
+ ctx.batch_index += 1
+ w = list(ctx.model.parameters())
+ w_vec = parameters_to_vector(w)
+ model_original_vec = parameters_to_vector(
+ list(ctx.original_model.parameters()))
+ # make sure you project on last iteration otherwise, high LR pushes you really far
+ if (ctx.batch_index % project_frequency
+ == 0) and (torch.norm(w_vec - model_original_vec) > eps):
+ # project back into norm ball
+ w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm(
+ w_vec - model_original_vec) + model_original_vec
+ # plug w_proj back into model
+ vector_to_parameters(w_proj_vec, w)
+
+
+def hook_on_epoch_end_project_grad(ctx):
+
+ ctx.batch_index = 0
+ eps = ctx.pgd_eps
+ w = list(ctx.model.parameters())
+ w_vec = parameters_to_vector(w)
+ model_original_vec = parameters_to_vector(
+ list(ctx.original_model.parameters()))
+ # make sure you project on last iteration otherwise, high LR pushes you really far
+ if (torch.norm(w_vec - model_original_vec) > eps):
+ # project back into norm ball
+ w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm(
+ w_vec - model_original_vec) + model_original_vec
+ # plug w_proj back into model
+ vector_to_parameters(w_proj_vec, w)
+
+
+def hook_on_fit_end_reset_pgd(ctx):
+
+ pass
diff --git a/federatedscope/attack/trainer/benign_trainer.py b/federatedscope/attack/trainer/benign_trainer.py
new file mode 100644
index 000000000..0bd093366
--- /dev/null
+++ b/federatedscope/attack/trainer/benign_trainer.py
@@ -0,0 +1,202 @@
+import logging
+from typing import Type
+import torch
+import numpy as np
+import copy
+
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+from federatedscope.attack.auxiliary.backdoor_utils import normalize
+from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
+from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
+from federatedscope.core.auxiliaries.ReIterator import ReIterator
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_benignTrainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ '''
+ Wrap the benign trainer for backdoor attacks:
+
+ We only add the norm-clipping and (optional) DP-noise operations.
+
+ Args:
+ base_trainer: Type: core.trainers.GeneralTorchTrainer
+
+ :returns:
+ The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
+
+ '''
+
+ if base_trainer.cfg.attack.norm_clip:
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_fit_start_init_local_model,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ base_trainer.ctx.norm_clip_value = base_trainer.cfg.attack.norm_clip_value
+
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_fit_end_clip_model,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ if base_trainer.cfg.attack.dp_noise > 0.0:
+ base_trainer.ctx.dp_noise = base_trainer.cfg.attack.dp_noise
+
+ else:
+ base_trainer.ctx.dp_noise = 0.0
+
+ return base_trainer
+
+
+def get_weight_difference(weight1, weight2):
+ difference = {}
+ res = []
+ if type(weight2) == dict:
+ for name, layer in weight1.items():
+ difference[name] = layer.data - weight2[name].data
+ res.append(difference[name].view(-1))
+ else:
+ for name, layer in weight2:
+ difference[name] = weight1[name].data - layer.data
+ res.append(difference[name].view(-1))
+
+ difference_flat = torch.cat(res)
+
+ return difference, difference_flat
+
+
+def get_l2_norm(weight1, weight2):
+ difference = {}
+ res = []
+ if type(weight2) == dict:
+ for name, layer in weight1.items():
+ difference[name] = layer.data - weight2[name].data
+ res.append(difference[name].view(-1))
+ else:
+ for name, layer in weight2:
+ difference[name] = weight1[name].data - layer.data
+ res.append(difference[name].view(-1))
+
+ difference_flat = torch.cat(res)
+
+ l2_norm = torch.norm(difference_flat.clone().detach().cuda())
+
+ l2_norm_np = np.linalg.norm(difference_flat.cpu().numpy())
+
+ return l2_norm, l2_norm_np
+
+
+def clip_grad(norm_bound, weight_difference, difference_flat):
+
+ l2_norm = torch.norm(difference_flat.clone().detach().cuda())
+ scale = max(1.0, float(torch.abs(l2_norm / norm_bound)))
+ for name in weight_difference.keys():
+ weight_difference[name].div_(scale)
+
+ return weight_difference, l2_norm
+
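+# A worked sketch of `clip_grad` above (hypothetical numbers): with
+# norm_bound = 1.0 and a weight difference of l2 norm 4.0, the scale is
+# max(1.0, 4.0 / 1.0) = 4.0, so every entry is divided by 4 and the clipped
+# difference has norm exactly 1.0; a difference already inside the ball
+# keeps scale = 1.0 and passes through unchanged.
+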
+
+def copy_params(model, target_params_variables):
+ for name, layer in model.named_parameters():
+ layer.data = copy.deepcopy(target_params_variables[name])
+
+
+def hook_on_fit_start_init_local_model(ctx):
+
+ ctx.global_model_copy = dict()
+ for name, param in ctx.model.named_parameters():
+ ctx.global_model_copy[name] = ctx.model.state_dict()[name].clone(
+ ).detach().requires_grad_(False)
+
+
+def hook_on_fit_end_clip_model(ctx):
+
+ l2_norm, l2_norm_np = get_l2_norm(ctx.global_model_copy,
+ ctx.model.named_parameters())
+ logger.info('l2 norm of local model (before server defense):{}'.format(
+ l2_norm.item()))
+ weight_difference, difference_flat = get_weight_difference(
+ ctx.global_model_copy, ctx.model.named_parameters())
+ clipped_weight_difference, _ = clip_grad(ctx.norm_clip_value,
+ weight_difference,
+ difference_flat)
+
+ for key_, para in clipped_weight_difference.items():
+ # rand_like only needs the shape/dtype of `para`, so no deepcopy is needed
+ clipped_weight_difference[key_] = \
+ para.data + ctx.dp_noise * torch.rand_like(para.data)
+
+ weight_difference, difference_flat = get_weight_difference(
+ ctx.global_model_copy, clipped_weight_difference)
+ copy_params(ctx.model, weight_difference)
+
+ l2_norm, l2_norm_np = get_l2_norm(ctx.global_model_copy,
+ ctx.model.named_parameters())
+ logger.info('l2 norm of local model (after server defense):{}'.format(
+ l2_norm.item()))
+
+
+def hook_on_fit_end_test_poison(ctx):
+ """Evaluate metrics of poisoning attacks.
+
+ """
+
+ ctx['poison_' + ctx.cur_data_split +
+ '_loader'] = ctx.data['poison_' + ctx.cur_data_split]
+ ctx['poison_' + ctx.cur_data_split +
+ '_data'] = ctx.data['poison_' + ctx.cur_data_split].dataset
+ ctx['num_poison_' + ctx.cur_data_split + '_data'] = len(
+ ctx.data['poison_' + ctx.cur_data_split].dataset)
+ setattr(ctx, "poison_{}_y_true".format(ctx.cur_data_split), [])
+ setattr(ctx, "poison_{}_y_prob".format(ctx.cur_data_split), [])
+ setattr(ctx, "poison_num_samples_{}".format(ctx.cur_data_split), 0)
+
+ for batch_idx, (samples, targets) in enumerate(
+ ctx['poison_' + ctx.cur_data_split + '_loader']):
+ samples, targets = samples.to(ctx.device), targets.to(ctx.device)
+ pred = ctx.model(samples)
+ if len(targets.size()) == 0:
+ targets = targets.unsqueeze(0)
+ ctx.poison_y_true = targets
+ ctx.poison_y_prob = pred
+ ctx.poison_batch_size = len(targets)
+
+ ctx.get("poison_{}_y_true".format(ctx.cur_data_split)).append(
+ ctx.poison_y_true.detach().cpu().numpy())
+
+ ctx.get("poison_{}_y_prob".format(ctx.cur_data_split)).append(
+ ctx.poison_y_prob.detach().cpu().numpy())
+
+ setattr(
+ ctx, "poison_num_samples_{}".format(ctx.cur_data_split),
+ ctx.get("poison_num_samples_{}".format(ctx.cur_data_split)) +
+ ctx.poison_batch_size)
+
+ setattr(
+ ctx, "poison_{}_y_true".format(ctx.cur_data_split),
+ np.concatenate(ctx.get("poison_{}_y_true".format(ctx.cur_data_split))))
+ setattr(
+ ctx, "poison_{}_y_prob".format(ctx.cur_data_split),
+ np.concatenate(ctx.get("poison_{}_y_prob".format(ctx.cur_data_split))))
+
+ logger.info('the {} poisoning samples: {:d}'.format(
+ ctx.cur_data_split,
+ ctx.get("poison_num_samples_{}".format(ctx.cur_data_split))))
+
+ poison_true = ctx['poison_' + ctx.cur_data_split + '_y_true']
+ poison_prob = ctx['poison_' + ctx.cur_data_split + '_y_prob']
+
+ poison_pred = np.argmax(poison_prob, axis=1)
+
+ correct = poison_true == poison_pred
+
+ poisoning_acc = float(np.sum(correct)) / len(correct)
+
+ logger.info('the {} poisoning accuracy: {:f}'.format(
+ ctx.cur_data_split, poisoning_acc))
diff --git a/federatedscope/attack/worker_as_attacker/__init__.py b/federatedscope/attack/worker_as_attacker/__init__.py
new file mode 100644
index 000000000..ee3a8f3a2
--- /dev/null
+++ b/federatedscope/attack/worker_as_attacker/__init__.py
@@ -0,0 +1,12 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.attack.worker_as_attacker.active_client import *
+from federatedscope.attack.worker_as_attacker.server_attacker import *
+
+__all__ = [
+ 'plot_target_loss', 'sav_target_loss', 'callback_funcs_for_finish',
+ 'add_atk_method_to_Client_GradAscent', 'PassiveServer', 'PassivePIAServer',
+ 'BackdoorServer'
+]
diff --git a/federatedscope/attack/worker_as_attacker/active_client.py b/federatedscope/attack/worker_as_attacker/active_client.py
new file mode 100644
index 000000000..ddd0912dc
--- /dev/null
+++ b/federatedscope/attack/worker_as_attacker/active_client.py
@@ -0,0 +1,54 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+from federatedscope.core.message import Message
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def plot_target_loss(loss_list, outdir):
+ '''
+
+ Args:
+ loss_list: the list of losses regarding the target data
+ outdir: the directory to store the loss
+
+ '''
+
+ target_data_loss = np.vstack(loss_list)
+ logger.info(target_data_loss.shape)
+ plt.plot(target_data_loss)
+ plt.savefig(os.path.join(outdir, 'target_loss.png'))
+ plt.close()
+
+
+def sav_target_loss(loss_list, outdir):
+ target_data_loss = np.vstack(loss_list)
+ np.savetxt(os.path.join(outdir, 'target_loss.txt'),
+ target_data_loss.transpose(),
+ delimiter=',')
+
+
+def callback_funcs_for_finish(self, message: Message):
+ logger.info(
+ "================= receiving Finish Message ============================"
+ )
+ if message.content is not None:
+ self.trainer.update(message.content)
+ if self.is_attacker and self._cfg.attack.attack_method.lower(
+ ) == "gradascent":
+ logger.info(
+ "================= start attack post-processing ======================="
+ )
+ plot_target_loss(self.trainer.ctx.target_data_loss,
+ self.trainer.ctx.outdir)
+ sav_target_loss(self.trainer.ctx.target_data_loss,
+ self.trainer.ctx.outdir)
+
+
+def add_atk_method_to_Client_GradAscent(client_class):
+
+ setattr(client_class, 'callback_funcs_for_finish',
+ callback_funcs_for_finish)
+ return client_class
diff --git a/federatedscope/attack/worker_as_attacker/server_attacker.py b/federatedscope/attack/worker_as_attacker/server_attacker.py
new file mode 100644
index 000000000..57499c030
--- /dev/null
+++ b/federatedscope/attack/worker_as_attacker/server_attacker.py
@@ -0,0 +1,332 @@
+from federatedscope.core.worker import Server
+from federatedscope.core.message import Message
+
+from federatedscope.core.auxiliaries.criterion_builder import get_criterion
+import copy
+from federatedscope.attack.auxiliary.utils import get_data_sav_fn, get_reconstructor
+
+import logging
+
+import numpy as np
+import torch
+from federatedscope.attack.privacy_attacks.passive_PIA import PassivePropertyInference
+
+logger = logging.getLogger(__name__)
+
+
+class BackdoorServer(Server):
+ '''
+ For backdoor attacks, the server chooses among different sampling strategies:
+ fixed-frequency sampling, single-shot sampling, all-round sampling, or random sampling.
+ '''
+ def __init__(self,
+ ID=-1,
+ state=0,
+ config=None,
+ data=None,
+ model=None,
+ client_num=5,
+ total_round_num=10,
+ device='cpu',
+ strategy=None,
+ unseen_clients_id=None,
+ **kwargs):
+ super(BackdoorServer, self).__init__(ID=ID,
+ state=state,
+ data=data,
+ model=model,
+ config=config,
+ client_num=client_num,
+ total_round_num=total_round_num,
+ device=device,
+ strategy=strategy,
+ **kwargs)
+
+ def broadcast_model_para(self,
+ msg_type='model_para',
+ sample_client_num=-1):
+ """
+ To broadcast the message to all clients or sampled clients
+
+ Arguments:
+ msg_type: 'model_para' or other user defined msg_type
+ sample_client_num: the number of sampled clients in the broadcast behavior.
+ And sample_client_num = -1 denotes to broadcast to all the clients.
+ """
+
+ if sample_client_num > 0: # only activated at training process
+
+ if self._cfg.attack.attacker_id == -1 or self._cfg.attack.attack_method == '':
+ receiver = np.random.choice(np.arange(1, self.client_num + 1),
+ size=sample_client_num,
+ replace=False).tolist()
+
+ elif self._cfg.attack.setting == 'fix':
+ if self.state % self._cfg.attack.freq == 0:
+ client_list = np.delete(np.arange(1, self.client_num + 1),
+ self._cfg.attack.attacker_id - 1)
+ receiver = np.random.choice(client_list,
+ size=sample_client_num - 1,
+ replace=False).tolist()
+ receiver.insert(0, self._cfg.attack.attacker_id)
+ logger.info('starting the fix-frequency poisoning attack')
+ logger.info(
+ 'starting the poisoning round: {:d}, the attacker ID: {:d}'
+ .format(self.state, self._cfg.attack.attacker_id))
+ else:
+ client_list = np.delete(np.arange(1, self.client_num + 1),
+ self._cfg.attack.attacker_id - 1)
+ receiver = np.random.choice(client_list,
+ size=sample_client_num,
+ replace=False).tolist()
+
+ elif self._cfg.attack.setting == 'single' and self.state == self._cfg.attack.insert_round:
+ # need to check this setting
+ client_list = np.delete(np.arange(1, self.client_num + 1),
+ self._cfg.attack.attacker_id - 1)
+ receiver = np.random.choice(client_list,
+ size=sample_client_num - 1,
+ replace=False).tolist()
+ receiver.insert(0, self._cfg.attack.attacker_id)
+ logger.info('starting the single-shot poisoning attack')
+ logger.info(
+ 'starting the poisoning round: {:d}, the attacker ID: {:d}'
+ .format(self.state, self._cfg.attack.attacker_id))
+
+ elif self._cfg.attack.setting == 'all':
+ client_list = np.delete(np.arange(1, self.client_num + 1),
+ self._cfg.attack.attacker_id - 1)
+ receiver = np.random.choice(client_list,
+ size=sample_client_num - 1,
+ replace=False).tolist()
+ receiver.insert(0, self._cfg.attack.attacker_id)
+ logger.info('starting the all-round poisoning attack')
+ logger.info(
+ 'starting the poisoning round: {:d}, the attacker ID: {:d}'
+ .format(self.state, self._cfg.attack.attacker_id))
+
+ else:
+ receiver = np.random.choice(np.arange(1, self.client_num + 1),
+ size=sample_client_num,
+ replace=False).tolist()
+
+ else:
+ # broadcast to all clients
+ receiver = list(self.comm_manager.neighbors.keys())
+
+ if self._noise_injector is not None and msg_type == 'model_para':
+ # Inject noise only when broadcast parameters
+ for model_idx_i in range(len(self.models)):
+ num_sample_clients = [
+ v["num_sample"] for v in self.join_in_info.values()
+ ]
+ self._noise_injector(self._cfg, num_sample_clients,
+ self.models[model_idx_i])
+
+ skip_broadcast = self._cfg.federate.method in ["local", "global"]
+ if self.model_num > 1:
+ model_para = [{} if skip_broadcast else model.state_dict()
+ for model in self.models]
+ else:
+ model_para = {} if skip_broadcast else self.model.state_dict()
+
+ self.comm_manager.send(
+ Message(msg_type=msg_type,
+ sender=self.ID,
+ receiver=receiver,
+ state=min(self.state, self.total_round_num),
+ content=model_para))
+ if self._cfg.federate.online_aggr:
+ for idx in range(self.model_num):
+ self.aggregators[idx].reset()
+
+
+class PassiveServer(Server):
+ '''
+ In the passive attack, the server stores the model and the messages collected from the clients, and performs optimization-based reconstruction, such as DLG and InvertGradient.
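+
+ After FL training, a sketched offline use, where the round and client
+ indices are illustrative:
+
+ >>> server.run_reconstruct(state_list=[0], sender_list=[1])
+ >>> dummy_x, dummy_y = server.reconstruct_data[0][1]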
+ '''
+ def __init__(self,
+ ID=-1,
+ state=0,
+ data=None,
+ model=None,
+ client_num=5,
+ total_round_num=10,
+ device='cpu',
+ strategy=None,
+ state_to_reconstruct=None,
+ client_to_reconstruct=None,
+ **kwargs):
+ super(PassiveServer, self).__init__(ID=ID,
+ state=state,
+ data=data,
+ model=model,
+ client_num=client_num,
+ total_round_num=total_round_num,
+ device=device,
+ strategy=strategy,
+ **kwargs)
+
+ # self.offline_reconstruct = offline_reconstruct
+ self.atk_method = self._cfg.attack.attack_method
+ self.state_to_reconstruct = state_to_reconstruct
+ self.client_to_reconstruct = client_to_reconstruct
+ self.reconstruct_data = dict()
+
+ # the loss function of the global model; the global model can be obtained in self.aggregator.model
+ self.model_criterion = get_criterion(self._cfg.criterion.type,
+ device=self.device)
+
+ from federatedscope.attack.auxiliary.utils import get_data_info
+ self.data_dim, self.num_class, self.is_one_hot_label = get_data_info(
+ self._cfg.data.type)
+
+ self.reconstructor = self._get_reconstructor()
+
+ self.reconstructed_data_sav_fn = get_data_sav_fn(self._cfg.data.type)
+
+ self.reconstruct_data_summary = dict()
+
+ def _get_reconstructor(self):
+
+ return get_reconstructor(
+ self.atk_method,
+ max_ite=self._cfg.attack.max_ite,
+ lr=self._cfg.attack.reconstruct_lr,
+ federate_loss_fn=self.model_criterion,
+ device=self.device,
+ federate_lr=self._cfg.optimizer.lr,
+ optim=self._cfg.attack.reconstruct_optim,
+ info_diff_type=self._cfg.attack.info_diff_type,
+ federate_method=self._cfg.federate.method,
+ alpha_TV=self._cfg.attack.alpha_TV)
+
+ def _reconstruct(self, state, sender):
+ # print(self.msg_buffer['train'][state].keys())
+ dummy_data, dummy_label = self.reconstructor.reconstruct(
+ model=copy.deepcopy(self.model).to(torch.device(self.device)),
+ original_info=self.msg_buffer['train'][state][sender][1],
+ data_feature_dim=self.data_dim,
+ num_class=self.num_class,
+ batch_size=self.msg_buffer['train'][state][sender][0])
+ if state not in self.reconstruct_data.keys():
+ self.reconstruct_data[state] = dict()
+ self.reconstruct_data[state][sender] = [
+ dummy_data.cpu(), dummy_label.cpu()
+ ]
+
+ def run_reconstruct(self, state_list=None, sender_list=None):
+
+        if state_list is None:
+            state_list = self.msg_buffer['train'].keys()
+
+        # After the FL course, use the gradient-based reconstruction
+        # method to recover the clients' private training data
+        for state in state_list:
+            # resolve the senders per state instead of overwriting the
+            # `sender_list` argument on the first iteration
+            senders = sender_list if sender_list is not None else \
+                self.msg_buffer['train'][state].keys()
+            for sender in senders:
+ logger.info(
+ '------------- reconstruct round:{}, client:{}-----------'.
+ format(state, sender))
+
+ # the context of buffer: self.model_buffer[state]: (sample_size, model_para)
+ self._reconstruct(state, sender)
+
+ def callback_funcs_model_para(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ # For a new round
+ if round not in self.msg_buffer['train'].keys():
+ self.msg_buffer['train'][round] = dict()
+
+ self.msg_buffer['train'][round][sender] = content
+
+ # run reconstruction before the clear of self.msg_buffer
+
+ if self.state_to_reconstruct is None or message.state in self.state_to_reconstruct:
+ if self.client_to_reconstruct is None or message.sender in self.client_to_reconstruct:
+ self.run_reconstruct(state_list=[message.state],
+ sender_list=[message.sender])
+ if self.reconstructed_data_sav_fn is not None:
+ self.reconstructed_data_sav_fn(
+ data=self.reconstruct_data[message.state][
+ message.sender][0],
+ sav_pth=self._cfg.outdir,
+ name='image_state_{}_client_{}.png'.format(
+ message.state, message.sender))
+
+ self.check_and_move_on()
+
+
+class PassivePIAServer(Server):
+ '''
+    The implementation of the batch property classifier, Algorithm 3 in the paper: Exploiting Unintended Feature Leakage in Collaborative Learning
+
+ References:
+
+ Melis, Luca, Congzheng Song, Emiliano De Cristofaro and Vitaly Shmatikov. “Exploiting Unintended Feature Leakage in Collaborative Learning.” 2019 IEEE Symposium on Security and Privacy (SP) (2019): 691-706
+ '''
+ def __init__(self,
+ ID=-1,
+ state=0,
+ data=None,
+ model=None,
+ client_num=5,
+ total_round_num=10,
+ device='cpu',
+ strategy=None,
+ **kwargs):
+ super(PassivePIAServer, self).__init__(ID=ID,
+ state=state,
+ data=data,
+ model=model,
+ client_num=client_num,
+ total_round_num=total_round_num,
+ device=device,
+ strategy=strategy,
+ **kwargs)
+
+ # self.offline_reconstruct = offline_reconstruct
+ self.atk_method = self._cfg.attack.attack_method
+ self.pia_attacker = PassivePropertyInference(
+ classier=self._cfg.attack.classifier_PIA,
+ fl_model_criterion=get_criterion(self._cfg.criterion.type,
+ device=self.device),
+ device=self.device,
+ grad_clip=self._cfg.grad.grad_clip,
+ dataset_name=self._cfg.data.type,
+ fl_local_update_num=self._cfg.federate.local_update_steps,
+ # fl_type_optimizer=self._cfg.fedopt.optimizer.type,fedopt.type_optimizer
+ fl_type_optimizer=self._cfg.optimizer.type,
+ fl_lr=self._cfg.optimizer.lr,
+ batch_size=100)
+
+ # self.optimizer = get_optimizer(type=self._cfg.fedopt.type_optimizer, model=self.model,lr=self._cfg.fedopt.optimizer.lr)
+ # print(self.optimizer)
+ def callback_funcs_model_para(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ # For a new round
+ if round not in self.msg_buffer['train'].keys():
+ self.msg_buffer['train'][round] = dict()
+
+ self.msg_buffer['train'][round][sender] = content
+
+ # collect the updates
+ self.pia_attacker.collect_updates(
+ previous_para=self.model.state_dict(),
+ updated_parameter=content[1],
+ round=round,
+ client_id=sender)
+ self.pia_attacker.get_data_for_dataset_prop_classifier(
+ model=self.model)
+
+ if self._cfg.federate.online_aggr:
+ # TODO: put this line to `check_and_move_on`
+ # currently, no way to know the latest `sender`
+ self.aggregator.inc(content)
+ self.check_and_move_on()
+
+ if self.state == self.total_round_num:
+ self.pia_attacker.train_property_classifier()
+ self.pia_results = self.pia_attacker.infer_collected()
+            logger.info(self.pia_results)
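+
+
+def _offline_reconstruct_sketch(server):
+    # A minimal usage sketch (an illustration, not part of the attack
+    # API): once an FL course has populated `server.msg_buffer['train']`,
+    # a PassiveServer can replay every buffered round/client offline and
+    # return the recovered dummy batches, keyed by round and client id.
+    server.run_reconstruct()
+    return server.reconstruct_data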
diff --git a/federatedscope/autotune/__init__.py b/federatedscope/autotune/__init__.py
new file mode 100644
index 000000000..b0c0109a3
--- /dev/null
+++ b/federatedscope/autotune/__init__.py
@@ -0,0 +1,8 @@
+from federatedscope.autotune.choice_types import Continuous, Discrete
+from federatedscope.autotune.utils import parse_search_space, config2cmdargs, config2str
+from federatedscope.autotune.algos import get_scheduler
+
+__all__ = [
+ 'Continuous', 'Discrete', 'parse_search_space', 'config2cmdargs',
+ 'config2str', 'get_scheduler'
+]
diff --git a/federatedscope/autotune/algos.py b/federatedscope/autotune/algos.py
new file mode 100644
index 000000000..dcf1923e5
--- /dev/null
+++ b/federatedscope/autotune/algos.py
@@ -0,0 +1,468 @@
+import os
+import logging
+from copy import deepcopy
+from contextlib import redirect_stdout
+import threading
+import math
+
+import ConfigSpace as CS
+from yacs.config import CfgNode as CN
+import yaml
+import numpy as np
+
+from federatedscope.core.auxiliaries.utils import setup_seed
+from federatedscope.core.auxiliaries.data_builder import get_data
+from federatedscope.core.auxiliaries.worker_builder import get_client_cls, get_server_cls
+from federatedscope.core.fed_runner import FedRunner
+from federatedscope.autotune.utils import parse_search_space, config2cmdargs, config2str, summarize_hpo_results
+
+logger = logging.getLogger(__name__)
+
+
+def make_trial(trial_cfg):
+ setup_seed(trial_cfg.seed)
+ data, modified_config = get_data(config=trial_cfg.clone())
+ trial_cfg.merge_from_other_cfg(modified_config)
+ trial_cfg.freeze()
+ # TODO: enable client-wise configuration
+ Fed_runner = FedRunner(data=data,
+ server_class=get_server_cls(trial_cfg),
+ client_class=get_client_cls(trial_cfg),
+ config=trial_cfg.clone())
+ results = Fed_runner.run()
+ key1, key2 = trial_cfg.hpo.metric.split('.')
+ return results[key1][key2]
+
+
+class TrialExecutor(threading.Thread):
+ """This class is responsible for executing the FL procedure with a given trial configuration in another thread.
+ """
+ def __init__(self, cfg_idx, signal, returns, trial_config):
+ threading.Thread.__init__(self)
+
+ self._idx = cfg_idx
+ self._signal = signal
+ self._returns = returns
+ self._trial_cfg = trial_config
+
+ def run(self):
+ setup_seed(self._trial_cfg.seed)
+ data, modified_config = get_data(config=self._trial_cfg.clone())
+ self._trial_cfg.merge_from_other_cfg(modified_config)
+ self._trial_cfg.freeze()
+ # TODO: enable client-wise configuration
+ Fed_runner = FedRunner(data=data,
+ server_class=get_server_cls(self._trial_cfg),
+ client_class=get_client_cls(self._trial_cfg),
+ config=self._trial_cfg.clone())
+ results = Fed_runner.run()
+ key1, key2 = self._trial_cfg.hpo.metric.split('.')
+ self._returns['perf'] = results[key1][key2]
+ self._returns['cfg_idx'] = self._idx
+ self._signal.set()
+
+
+def get_scheduler(init_cfg):
+    """To instantiate a scheduler object for conducting HPO
+    Arguments:
+        init_cfg (yacs.Node): configuration.
+    """
+
+    if init_cfg.hpo.scheduler == 'rs':
+        scheduler = ModelFreeBase(init_cfg)
+    elif init_cfg.hpo.scheduler == 'sha':
+        scheduler = SuccessiveHalvingAlgo(init_cfg)
+    elif init_cfg.hpo.scheduler == 'pbt':
+        scheduler = PBT(init_cfg)
+    elif init_cfg.hpo.scheduler == 'wrap_sha':
+        scheduler = SHAWrapFedex(init_cfg)
+    else:
+        raise ValueError("Unsupported hpo.scheduler: {}".format(
+            init_cfg.hpo.scheduler))
+    return scheduler
+
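+
+def _scheduler_usage_sketch(init_cfg):
+    # A hypothetical driver (a sketch, not an official entry point):
+    # assuming `init_cfg` is a fully populated yacs config (with
+    # `hpo.scheduler`, `hpo.ss`, `hpo.metric`, etc.), build the
+    # scheduler and run the whole HPO procedure.
+    scheduler = get_scheduler(init_cfg)
+    return scheduler.optimize()
+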
+
+class Scheduler(object):
+ """The base class for describing HPO algorithms
+ """
+ def __init__(self, cfg):
+ """
+ Arguments:
+ cfg (yacs.Node): dict like object, where each key-value pair corresponds to a field and its choices.
+ """
+
+ self._cfg = cfg
+ self._search_space = parse_search_space(self._cfg.hpo.ss)
+
+ self._init_configs = self._setup()
+
+ logger.info(self._init_configs)
+
+ def _setup(self):
+ """Prepare the initial configurations based on the search space.
+ """
+ raise NotImplementedError
+
+ def _evaluate(self, configs):
+ """To evaluate (i.e., conduct the FL procedure) for a given collection of configurations.
+ """
+ raise NotImplementedError
+
+ def optimize(self):
+ """To optimize the hyperparameters, that is, executing the HPO algorithm and then returning the results.
+ """
+ raise NotImplementedError
+
+
+class ModelFreeBase(Scheduler):
+ """To attempt a collection of configurations exhaustively.
+ """
+ def _setup(self):
+ self._search_space.seed(self._cfg.seed + 19)
+ return [
+ cfg.get_dictionary()
+ for cfg in self._search_space.sample_configuration(
+ size=self._cfg.hpo.init_cand_num)
+ ]
+
+ def _evaluate(self, configs):
+ if self._cfg.hpo.num_workers:
+ # execute FL in parallel by multi-threading
+ flags = [
+ threading.Event() for _ in range(self._cfg.hpo.num_workers)
+ ]
+ for i in range(len(flags)):
+ flags[i].set()
+ threads = [None for _ in range(len(flags))]
+ thread_results = [dict() for _ in range(len(flags))]
+
+ perfs = [None for _ in range(len(configs))]
+ for i, config in enumerate(configs):
+ available_worker = 0
+ while not flags[available_worker].is_set():
+ available_worker = (available_worker + 1) % len(threads)
+ if thread_results[available_worker]:
+ completed_trial_results = thread_results[available_worker]
+ cfg_idx = completed_trial_results['cfg_idx']
+ perfs[cfg_idx] = completed_trial_results['perf']
+ logger.info(
+ "Evaluate the {}-th config {} and get performance {}".
+ format(cfg_idx, configs[cfg_idx], perfs[cfg_idx]))
+ thread_results[available_worker].clear()
+
+ trial_cfg = self._cfg.clone()
+ trial_cfg.merge_from_list(config2cmdargs(config))
+ flags[available_worker].clear()
+ trial = TrialExecutor(i, flags[available_worker],
+ thread_results[available_worker],
+ trial_cfg)
+ trial.start()
+ threads[available_worker] = trial
+
+ for i in range(len(flags)):
+ if not flags[i].is_set():
+ threads[i].join()
+ for i in range(len(thread_results)):
+ if thread_results[i]:
+ completed_trial_results = thread_results[i]
+ cfg_idx = completed_trial_results['cfg_idx']
+ perfs[cfg_idx] = completed_trial_results['perf']
+ logger.info(
+ "Evaluate the {}-th config {} and get performance {}".
+ format(cfg_idx, configs[cfg_idx], perfs[cfg_idx]))
+ thread_results[i].clear()
+
+ else:
+ perfs = [None] * len(configs)
+ for i, config in enumerate(configs):
+ trial_cfg = self._cfg.clone()
+ trial_cfg.merge_from_list(config2cmdargs(config))
+ perfs[i] = make_trial(trial_cfg)
+ logger.info(
+ "Evaluate the {}-th config {} and get performance {}".
+ format(i, config, perfs[i]))
+
+ return perfs
+
+ def optimize(self):
+ perfs = self._evaluate(self._init_configs)
+
+ results = summarize_hpo_results(self._init_configs,
+ perfs,
+ white_list=set(
+ self._search_space.keys()),
+ desc=self._cfg.hpo.larger_better)
+ logger.info(
+ "====================================== HPO Final ========================================"
+ )
+ logger.info("\n{}".format(results))
+ logger.info(
+ "====================================================================================="
+ )
+
+ return results
+
+
+class IterativeScheduler(ModelFreeBase):
+ """The base class for HPO algorithms that divide the whole optimization procedure into iterations.
+ """
+ def _setup(self):
+ self._stage = 0
+ return super(IterativeScheduler, self)._setup()
+
+ def _stop_criterion(self, configs, last_results):
+ """To determine whether the algorithm should be terminated.
+
+ Arguments:
+ configs (list): each element is a trial configuration.
+ last_results (DataFrame): each row corresponds to a specific configuration as well as its latest performance.
+ :returns: whether to terminate.
+ :rtype: bool
+ """
+ raise NotImplementedError
+
+ def _iteration(self, configs):
+ """To evaluate the given collection of configurations at this stage.
+
+ Arguments:
+ configs (list): each element is a trial configuration.
+ :returns: the performances of the given configurations.
+ :rtype: list
+ """
+
+ perfs = self._evaluate(configs)
+ return perfs
+
+ def _generate_next_population(self, configs, perfs):
+ """To generate the configurations for the next stage.
+
+ Arguments:
+ configs (list): the configurations of last stage.
+ perfs (list): their corresponding performances.
+ :returns: configuration for the next stage.
+ :rtype: list
+ """
+
+ raise NotImplementedError
+
+ def optimize(self):
+ current_configs = deepcopy(self._init_configs)
+ last_results = None
+ while not self._stop_criterion(current_configs, last_results):
+ current_perfs = self._iteration(current_configs)
+ last_results = summarize_hpo_results(
+ current_configs,
+ current_perfs,
+ white_list=set(self._search_space.keys()),
+ desc=self._cfg.hpo.larger_better)
+ self._stage += 1
+ logger.info(
+ "====================================== Stage{} ========================================"
+ .format(self._stage))
+ logger.info("\n{}".format(last_results))
+ logger.info(
+ "======================================================================================="
+ )
+ current_configs = self._generate_next_population(
+ current_configs, current_perfs)
+
+ return current_configs
+
+
+class SuccessiveHalvingAlgo(IterativeScheduler):
+ """Successive Halving Algorithm (SHA) tailored to FL setting, where, in each iteration, just a limited number of communication rounds are allowed for each trial.
+ """
+ def _setup(self):
+ init_configs = super(SuccessiveHalvingAlgo, self)._setup()
+
+ for trial_cfg in init_configs:
+ trial_cfg['federate.save_to'] = os.path.join(
+ self._cfg.hpo.working_folder,
+ "{}.pth".format(config2str(trial_cfg)))
+
+ if self._cfg.hpo.sha.budgets:
+ for trial_cfg in init_configs:
+ trial_cfg[
+ 'federate.total_round_num'] = self._cfg.hpo.sha.budgets[
+ self._stage]
+ trial_cfg['eval.freq'] = self._cfg.hpo.sha.budgets[self._stage]
+
+ return init_configs
+
+ def _stop_criterion(self, configs, last_results):
+ return len(configs) <= 1
+
+ def _generate_next_population(self, configs, perfs):
+ indices = [(i, val) for i, val in enumerate(perfs)]
+ indices.sort(key=lambda x: x[1], reverse=self._cfg.hpo.larger_better)
+ next_population = [
+ configs[tp[0]] for tp in
+ indices[:math.
+ ceil(float(len(indices)) / self._cfg.hpo.sha.elim_rate)]
+ ]
+
+ for trial_cfg in next_population:
+ if 'federate.restore_from' not in trial_cfg:
+ trial_cfg['federate.restore_from'] = trial_cfg[
+ 'federate.save_to']
+ if self._cfg.hpo.sha.budgets and self._stage < len(
+ self._cfg.hpo.sha.budgets):
+ trial_cfg[
+ 'federate.total_round_num'] = self._cfg.hpo.sha.budgets[
+ self._stage]
+ trial_cfg['eval.freq'] = self._cfg.hpo.sha.budgets[self._stage]
+
+ return next_population
+
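+
+def _sha_population_sizes(init_cand_num, elim_rate):
+    # Illustration of the halving schedule implemented above (a sketch,
+    # not used by the scheduler): with the same ceil division,
+    # init_cand_num=8 and elim_rate=2 keep 8 -> 4 -> 2 -> 1 configs.
+    sizes = [init_cand_num]
+    while sizes[-1] > 1:
+        sizes.append(math.ceil(sizes[-1] / elim_rate))
+    return sizes
+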
+
+class SHAWrapFedex(SuccessiveHalvingAlgo):
+ """This SHA is customized as a wrapper for FedEx algorithm."""
+ def _make_local_perturbation(self, config):
+ neighbor = dict()
+ for k in config:
+ if 'fedex' in k or 'fedopt' in k or k in [
+ 'federate.save_to', 'federate.total_round_num', 'eval.freq'
+ ]:
+ # a workaround
+ continue
+ hyper = self._search_space.get(k)
+ if isinstance(hyper, CS.UniformFloatHyperparameter):
+ lb, ub = hyper.lower, hyper.upper
+ diameter = self._cfg.hpo.table.eps * (ub - lb)
+ new_val = (config[k] -
+ 0.5 * diameter) + np.random.uniform() * diameter
+ neighbor[k] = float(np.clip(new_val, lb, ub))
+ elif isinstance(hyper, CS.UniformIntegerHyperparameter):
+ lb, ub = hyper.lower, hyper.upper
+ diameter = self._cfg.hpo.table.eps * (ub - lb)
+ new_val = round(
+ float((config[k] - 0.5 * diameter) +
+ np.random.uniform() * diameter))
+ neighbor[k] = int(np.clip(new_val, lb, ub))
+ elif isinstance(hyper, CS.CategoricalHyperparameter):
+ if len(hyper.choices) == 1:
+ neighbor[k] = config[k]
+ else:
+ threshold = self._cfg.hpo.table.eps * len(
+ hyper.choices) / (len(hyper.choices) - 1)
+ rn = np.random.uniform()
+ new_val = np.random.choice(
+ hyper.choices) if rn <= threshold else config[k]
+ if type(new_val) in [np.int32, np.int64]:
+ neighbor[k] = int(new_val)
+ elif type(new_val) in [np.float32, np.float64]:
+ neighbor[k] = float(new_val)
+ else:
+ neighbor[k] = str(new_val)
+ else:
+ raise TypeError("Value of {} has an invalid type {}".format(
+ k, type(config[k])))
+
+ return neighbor
+
+ def _setup(self):
+ #self._cache_yaml()
+ init_configs = super(SHAWrapFedex, self)._setup()
+ new_init_configs = []
+ for idx, trial_cfg in enumerate(init_configs):
+ arms = dict(("arm{}".format(1 + j),
+ self._make_local_perturbation(trial_cfg))
+ for j in range(self._cfg.hpo.table.num - 1))
+ arms['arm0'] = dict(
+ (k, v) for k, v in trial_cfg.items() if k in arms['arm1'])
+ with open(
+ os.path.join(self._cfg.hpo.working_folder,
+ f'{idx}_tmp_grid_search_space.yaml'),
+ 'w') as f:
+ yaml.dump(arms, f)
+ new_trial_cfg = dict()
+ for k in trial_cfg:
+ if k not in arms['arm0']:
+ new_trial_cfg[k] = trial_cfg[k]
+ new_trial_cfg['hpo.table.idx'] = idx
+ new_trial_cfg['hpo.fedex.ss'] = os.path.join(
+ self._cfg.hpo.working_folder,
+ f"{new_trial_cfg['hpo.table.idx']}_tmp_grid_search_space.yaml")
+ new_trial_cfg['federate.save_to'] = os.path.join(
+ self._cfg.hpo.working_folder, "idx_{}.pth".format(idx))
+ new_init_configs.append(new_trial_cfg)
+
+ self._search_space.add_hyperparameter(
+ CS.CategoricalHyperparameter("hpo.table.idx",
+ choices=list(
+ range(len(new_init_configs)))))
+
+ return new_init_configs
+
+
+# TODO: refactor PBT to enable async parallel
+#class PBT(IterativeScheduler):
+# """Population-based training (the full paper "Population Based Training of Neural Networks" can be found at https://arxiv.org/abs/1711.09846) tailored to FL setting, where, in each iteration, just a limited number of communication rounds are allowed for each trial (We will provide the asynchornous version later).
+# """
+# def _setup(self, raw_search_space):
+# _ = super(PBT, self)._setup(raw_search_space)
+#
+# if global_cfg.hpo.init_strategy == 'random':
+# init_configs = random_search(
+# raw_search_space,
+# sample_size=global_cfg.hpo.sha.elim_rate**
+# global_cfg.hpo.sha.elim_round_num)
+# elif global_cfg.hpo.init_strategy == 'grid':
+# init_configs = grid_search(raw_search_space,
+# sample_size=global_cfg.hpo.sha.elim_rate
+# **global_cfg.hpo.sha.elim_round_num)
+# else:
+# raise ValueError(
+# "SHA needs to use random/grid search to pick {} configs from the search space as initial candidates, but `{}` is specified as `hpo.init_strategy`"
+# .format(
+# global_cfg.hpo.sha.elim_rate**
+# global_cfg.hpo.sha.elim_round_num,
+# global_cfg.hpo.init_strategy))
+#
+# for trial_cfg in init_configs:
+# trial_cfg['federate.save_to'] = os.path.join(
+# global_cfg.hpo.working_folder,
+# "{}.pth".format(config2str(trial_cfg)))
+#
+# return init_configs
+#
+# def _stop_criterion(self, configs, last_results):
+# if last_results is not None:
+# if (global_cfg.hpo.larger_better
+# and last_results.iloc[0]['performance'] >=
+# global_cfg.hpo.pbt.perf_threshold) or (
+# (not global_cfg.hpo.larger_better)
+# and last_results.iloc[0]['performance'] <=
+# global_cfg.hpo.pbt.perf_threshold):
+# return True
+# return self._stage >= global_cfg.hpo.pbt.max_stage
+#
+# def _generate_next_population(self, configs, perfs):
+# next_generation = []
+# for i in range(len(configs)):
+# new_cfg = deepcopy(configs[i])
+# # exploit
+# j = np.random.randint(len(configs))
+# if i != j and (
+# (global_cfg.hpo.larger_better and perfs[j] > perfs[i]) or
+# ((not global_cfg.hpo.larger_better) and perfs[j] < perfs[i])):
+# new_cfg['federate.restore_from'] = configs[j][
+# 'federate.save_to']
+# # explore
+# for k in new_cfg:
+# if isinstance(new_cfg[k], float):
+# # according to the exploration strategy of the PBT paper
+# new_cfg[k] *= float(np.random.choice([0.8, 1.2]))
+# else:
+# new_cfg['federate.restore_from'] = configs[i][
+# 'federate.save_to']
+#
+# # update save path
+# tmp_cfg = dict()
+# for k in new_cfg:
+# if k in self._original_search_space:
+# tmp_cfg[k] = new_cfg[k]
+# new_cfg['federate.save_to'] = os.path.join(
+# global_cfg.hpo.working_folder,
+# "{}.pth".format(config2str(tmp_cfg)))
+#
+# next_generation.append(new_cfg)
+#
+# return next_generation
diff --git a/federatedscope/autotune/choice_types.py b/federatedscope/autotune/choice_types.py
new file mode 100644
index 000000000..cf04360bc
--- /dev/null
+++ b/federatedscope/autotune/choice_types.py
@@ -0,0 +1,159 @@
+#import os
+#import sys
+#file_dir = os.path.join(os.path.dirname(__file__), '../..')
+#sys.path.append(file_dir)
+import logging
+import math
+import yaml
+
+import numpy as np
+
+from federatedscope.core.configs.config import global_cfg
+
+logger = logging.getLogger(__name__)
+
+
+def discretize(contd_choices, num_bkt):
+ '''Discretize a given continuous search space into the given number of buckets.
+
+ Arguments:
+ contd_choices (Continuous): continuous choices.
+ num_bkt (int): number of buckets.
+    :returns: discretized choices.
+ :rtype: Discrete
+ '''
+ if contd_choices[0] >= .0 and global_cfg.hpo.log_scale:
+ loglb, logub = math.log(
+ np.clip(contd_choices[0], 1e-8,
+ contd_choices[1])), math.log(contd_choices[1])
+ if num_bkt == 1:
+ choices = [math.exp(loglb + 0.5 * (logub - loglb))]
+ else:
+ bkt_size = (logub - loglb) / (num_bkt - 1)
+ choices = [math.exp(loglb + i * bkt_size) for i in range(num_bkt)]
+ else:
+ if num_bkt == 1:
+ choices = [
+ contd_choices[0] + 0.5 * (contd_choices[1] - contd_choices[0])
+ ]
+ else:
+ bkt_size = (contd_choices[1] - contd_choices[0]) / (num_bkt - 1)
+ choices = [contd_choices[0] + i * bkt_size for i in range(num_bkt)]
+ disc_choices = Discrete(*choices)
+ return disc_choices
+
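+
+def _example_discretize():
+    # A quick sketch of the bucketing above (Continuous is defined just
+    # below; it is only needed at call time): a range with a negative
+    # lower bound never takes the log-scale branch, so the buckets are
+    # evenly spaced regardless of `hpo.log_scale`:
+    # Continuous(-1.0, 1.0) with 5 buckets -> [-1.0, -0.5, 0.0, 0.5, 1.0]
+    return discretize(Continuous(-1.0, 1.0), 5)
+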
+
+class Continuous(tuple):
+ """Represents a continuous search space, e.g., in the range [0.001, 0.1].
+ """
+ def __new__(cls, lb, ub):
+ assert ub >= lb, "Invalid configuration where ub:{} is less than lb:{}".format(
+ ub, lb)
+ return tuple.__new__(cls, [lb, ub])
+
+ def __repr__(self):
+ return "Continuous(%s,%s)" % self
+
+ def sample(self):
+ """Sample a value from this search space.
+
+ :returns: the sampled value.
+ :rtype: float
+ """
+ if self[0] >= .0 and global_cfg.hpo.log_scale:
+ loglb, logub = math.log(np.clip(self[0], 1e-8,
+ self[1])), math.log(self[1])
+ return math.exp(loglb + np.random.rand() * (logub - loglb))
+ else:
+ return float(self[0] + np.random.rand() * (self[1] - self[0]))
+
+ def grid(self, grid_cnt):
+ """Generate a given nunber of grids from this search space.
+
+ Arguments:
+ grid_cnt (int): the number of grids.
+ :returns: the sampled value.
+ :rtype: float
+ """
+ discretized = discretize(self, grid_cnt)
+ return list(discretized)
+
+
+def contd_constructor(loader, node):
+ value = loader.construct_scalar(node)
+ lb, ub = map(float, value.split(','))
+ return Continuous(lb, ub)
+
+
+yaml.add_constructor(u'!contd', contd_constructor)
+
+
+class Discrete(tuple):
+ """Represents a discrete search space, e.g., {'abc', 'ijk', 'xyz'}.
+ """
+ def __new__(cls, *args):
+ return tuple.__new__(cls, args)
+
+ def __repr__(self):
+ return "Discrete(%s)" % ','.join(map(str, self))
+
+ def sample(self):
+ """Sample a value from this search space.
+
+ :returns: the sampled value.
+ :rtype: depends on the original choices.
+ """
+
+ return self[np.random.randint(len(self))]
+
+ def grid(self, grid_cnt):
+ num_original = len(self)
+ assert grid_cnt <= num_original, "There are only {} choices to produce grids, but {} required".format(
+ num_original, grid_cnt)
+ if grid_cnt == 1:
+ selected = [self[len(self) // 2]]
+ else:
+ optimistic_step_size = (num_original - 1) // grid_cnt
+ between_end_len = optimistic_step_size * (grid_cnt - 1)
+ remainder = (num_original - 1) - between_end_len
+ one_side_remainder = remainder // 2 if remainder % 2 == 0 else remainder // 2 + 1
+ if one_side_remainder <= optimistic_step_size // 2:
+ step_size = optimistic_step_size
+ else:
+ step_size = (num_original - 1) // (grid_cnt - 1)
+ covered_range = (grid_cnt - 1) * step_size
+ start_idx = (max(num_original - 1, 1) - covered_range) // 2
+ selected = [
+ self[j] for j in range(
+ start_idx,
+ min(start_idx +
+ grid_cnt * step_size, num_original), step_size)
+ ]
+ return selected
+
+
+def disc_constructor(loader, node):
+ value = loader.construct_sequence(node)
+ return Discrete(*value)
+
+
+yaml.add_constructor(u'!disc', disc_constructor)
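+
+
+def _example_yaml_choices():
+    # A sketch of the two YAML tags registered above (purely
+    # illustrative; the field names are made up): `!contd 'lb,ub'`
+    # yields a Continuous range and `!disc [...]` a Discrete choice set.
+    spec = ("learning_rate: !contd '0.001,0.1'\n"
+            "batch_size: !disc [16, 32, 64]\n")
+    space = yaml.load(spec, Loader=yaml.Loader)
+    # e.g., space['batch_size'].grid(3) -> [16, 32, 64]
+    return space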
+
+#if __name__=="__main__":
+# obj = Continuous(0.0, 0.01)
+# print(obj.grid(1), obj.grid(2), obj.grid(3))
+# for _ in range(3):
+# print(obj.sample())
+# cfg.merge_from_list(['hpo.log_scale', 'True'])
+# print(obj.grid(1), obj.grid(2), obj.grid(3))
+# for _ in range(3):
+# print(obj.sample())
+#
+# obj = Discrete('a', 'b', 'c')
+# print(obj.grid(1), obj.grid(2), obj.grid(3))
+# for _ in range(3):
+# print(obj.sample())
+# obj = Discrete(1, 2, 3, 4, 5)
+# print(obj.grid(1), obj.grid(2), obj.grid(3), obj.grid(4), obj.grid(5))
+# for _ in range(3):
+# print(obj.sample())
diff --git a/federatedscope/autotune/fedex/__init__.py b/federatedscope/autotune/fedex/__init__.py
new file mode 100644
index 000000000..ae2a87680
--- /dev/null
+++ b/federatedscope/autotune/fedex/__init__.py
@@ -0,0 +1,4 @@
+from federatedscope.autotune.fedex.server import FedExServer
+from federatedscope.autotune.fedex.client import FedExClient
+
+__all__ = ['FedExServer', 'FedExClient']
diff --git a/federatedscope/autotune/fedex/client.py b/federatedscope/autotune/fedex/client.py
new file mode 100644
index 000000000..c9c3f948e
--- /dev/null
+++ b/federatedscope/autotune/fedex/client.py
@@ -0,0 +1,88 @@
+import logging
+import json
+
+from federatedscope.core.message import Message
+from federatedscope.core.worker import Client
+
+logger = logging.getLogger(__name__)
+
+
+class FedExClient(Client):
+ """Some code snippets are borrowed from the open-sourced FedEx (https://github.com/mkhodak/FedEx)
+ """
+ def _apply_hyperparams(self, hyperparams):
+ """Apply the given hyperparameters
+ Arguments:
+ hyperparams (dict): keys are hyperparameter names and values are specific choices.
+ """
+
+ cmd_args = []
+ for k, v in hyperparams.items():
+ cmd_args.append(k)
+ cmd_args.append(v)
+
+ self._cfg.defrost()
+ self._cfg.merge_from_list(cmd_args)
+ self._cfg.freeze(inform=False)
+
+ self.trainer.ctx.setup_vars()
+
+ def callback_funcs_for_model_para(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ model_params, arms, hyperparams = content["model_param"], content[
+ "arms"], content["hyperparam"]
+ attempt = {
+ 'Role': 'Client #{:d}'.format(self.ID),
+ 'Round': self.state + 1,
+ 'Arms': arms,
+ 'Hyperparams': hyperparams
+ }
+ logger.info(json.dumps(attempt))
+
+ self._apply_hyperparams(hyperparams)
+
+ self.trainer.update(model_params)
+
+ #self.model.load_state_dict(content)
+ self.state = round
+ sample_size, model_para_all, results = self.trainer.train()
+ logger.info(
+ self._monitor.format_eval_res(results,
+ rnd=self.state,
+ role='Client #{}'.format(self.ID),
+ return_raw=True))
+
+ results['arms'] = arms
+ content = (sample_size, model_para_all, results)
+ self.comm_manager.send(
+ Message(msg_type='model_para',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=content))
+
+ def callback_funcs_for_evaluate(self, message: Message):
+ sender = message.sender
+ self.state = message.state
+        if message.content is not None:
+ model_params = message.content["model_param"]
+ self.trainer.update(model_params)
+ if self._cfg.trainer.finetune.before_eval:
+ self.trainer.finetune()
+ metrics = {}
+ for split in self._cfg.eval.split:
+ eval_metrics = self.trainer.evaluate(target_data_split_name=split)
+ for key in eval_metrics:
+
+ if self._cfg.federate.mode == 'distributed':
+ logger.info(
+ 'Client #{:d}: (Evaluation ({:s} set) at Round #{:d}) {:s} is {:.6f}'
+ .format(self.ID, split, self.state, key,
+ eval_metrics[key]))
+ metrics.update(**eval_metrics)
+ self.comm_manager.send(
+ Message(msg_type='metrics',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=metrics))
diff --git a/federatedscope/autotune/fedex/server.py b/federatedscope/autotune/fedex/server.py
new file mode 100644
index 000000000..2efa8779a
--- /dev/null
+++ b/federatedscope/autotune/fedex/server.py
@@ -0,0 +1,426 @@
+import os
+import logging
+from itertools import product
+
+import yaml
+
+import numpy as np
+from numpy.linalg import norm
+from scipy.special import logsumexp
+
+from federatedscope.core.message import Message
+from federatedscope.core.worker import Server
+from federatedscope.core.auxiliaries.utils import merge_dict
+
+logger = logging.getLogger(__name__)
+
+
+def discounted_mean(trace, factor=1.0):
+
+ weight = factor**np.flip(np.arange(len(trace)), axis=0)
+ return np.inner(trace, weight) / weight.sum()
+
+
+class FedExServer(Server):
+ """Some code snippets are borrowed from the open-sourced FedEx (https://github.com/mkhodak/FedEx)
+ """
+ def __init__(self,
+ ID=-1,
+ state=0,
+ config=None,
+ data=None,
+ model=None,
+ client_num=5,
+ total_round_num=10,
+ device='cpu',
+ strategy=None,
+ **kwargs):
+
+ # initialize action space and the policy
+ with open(config.hpo.fedex.ss, 'r') as ips:
+ ss = yaml.load(ips, Loader=yaml.FullLoader)
+
+ if next(iter(ss.keys())).startswith('arm'):
+ # This is a flattened action space
+ # ensure the order is unchanged
+ ss = sorted([(int(k[3:]), v) for k, v in ss.items()],
+ key=lambda x: x[0])
+ self._grid = []
+ self._cfsp = [[tp[1] for tp in ss]]
+ else:
+ # This is not a flat search space
+            # be careful about the order
+ self._grid = sorted(ss.keys())
+ self._cfsp = [ss[pn] for pn in self._grid]
+
+ sizes = [len(cand_set) for cand_set in self._cfsp]
+ eta0 = 'auto' if config.hpo.fedex.eta0 <= .0 else float(
+ config.hpo.fedex.eta0)
+ self._eta0 = [
+ np.sqrt(2.0 * np.log(size)) if eta0 == 'auto' else eta0
+ for size in sizes
+ ]
+ self._sched = config.hpo.fedex.sched
+ self._cutoff = config.hpo.fedex.cutoff
+ self._baseline = config.hpo.fedex.gamma
+ self._diff = config.hpo.fedex.diff
+ self._z = [np.full(size, -np.log(size)) for size in sizes]
+ self._theta = [np.exp(z) for z in self._z]
+ self._store = [0.0 for _ in sizes]
+ self._stop_exploration = False
+ self._trace = {
+ 'global': [],
+ 'refine': [],
+ 'entropy': [self.entropy()],
+ 'mle': [self.mle()]
+ }
+
+ super(FedExServer,
+ self).__init__(ID, state, config, data, model, client_num,
+ total_round_num, device, strategy, **kwargs)
+
+ if self._cfg.federate.restore_from != '':
+ pi_ckpt_path = self._cfg.federate.restore_from[:self._cfg.federate.
+ restore_from.rfind(
+ '.'
+ )] + "_fedex.yaml"
+ with open(pi_ckpt_path, 'r') as ips:
+ ckpt = yaml.load(ips, Loader=yaml.FullLoader)
+ self._z = [np.asarray(z) for z in ckpt['z']]
+ self._theta = [np.exp(z) for z in self._z]
+ self._store = ckpt['store']
+ self._stop_exploration = ckpt['stop']
+ self._trace = dict()
+ self._trace['global'] = ckpt['global']
+ self._trace['refine'] = ckpt['refine']
+ self._trace['entropy'] = ckpt['entropy']
+ self._trace['mle'] = ckpt['mle']
+
+ def entropy(self):
+
+ entropy = 0.0
+ for probs in product(*(theta[theta > 0.0] for theta in self._theta)):
+ prob = np.prod(probs)
+ entropy -= prob * np.log(prob)
+ return entropy
+
+ def mle(self):
+
+ return np.prod([theta.max() for theta in self._theta])
+
+ def trace(self, key):
+        '''returns the trace of one of the tracked quantities
+        Args:
+            key (str): 'entropy', 'global', 'refine', or 'mle'
+        Returns:
+            numpy vector with length equal to the number of rounds so far.
+        '''
+
+ return np.array(self._trace[key])
+
+ def sample(self):
+ """samples from configs using current probability vector"""
+
+ # determine index
+ if self._stop_exploration:
+ cfg_idx = [theta.argmax() for theta in self._theta]
+ else:
+ cfg_idx = [
+ np.random.choice(len(theta), p=theta) for theta in self._theta
+ ]
+
+ # get the sampled value(s)
+ if self._grid:
+ sampled_cfg = {
+ pn: cands[i]
+ for pn, cands, i in zip(self._grid, self._cfsp, cfg_idx)
+ }
+ else:
+ sampled_cfg = self._cfsp[0][cfg_idx[0]]
+
+ return cfg_idx, sampled_cfg
+
+ def broadcast_model_para(self,
+ msg_type='model_para',
+ sample_client_num=-1):
+ """
+ To broadcast the message to all clients or sampled clients
+ """
+
+ if sample_client_num > 0:
+ receiver = np.random.choice(np.arange(1, self.client_num + 1),
+ size=sample_client_num,
+ replace=False).tolist()
+ else:
+ # broadcast to all clients
+ receiver = list(self.comm_manager.neighbors.keys())
+
+ if self._noise_injector is not None and msg_type == 'model_para':
+            # Inject noise only when broadcasting parameters
+ for model_idx_i in range(len(self.models)):
+ num_sample_clients = [
+ v["num_sample"] for v in self.join_in_info.values()
+ ]
+ self._noise_injector(self._cfg, num_sample_clients,
+ self.models[model_idx_i])
+
+ if self.model_num > 1:
+ model_para = [model.state_dict() for model in self.models]
+ else:
+ model_para = self.model.state_dict()
+
+ # sample the hyper-parameter config specific to the clients
+
+ for rcv_idx in receiver:
+ cfg_idx, sampled_cfg = self.sample()
+ content = {
+ 'model_param': model_para,
+ "arms": cfg_idx,
+ 'hyperparam': sampled_cfg
+ }
+ self.comm_manager.send(
+ Message(msg_type=msg_type,
+ sender=self.ID,
+ receiver=[rcv_idx],
+ state=self.state,
+ content=content))
+ if self._cfg.federate.online_aggr:
+ for idx in range(self.model_num):
+ self.aggregators[idx].reset()
+
+ def callback_funcs_model_para(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ # For a new round
+ if round not in self.msg_buffer['train'].keys():
+ self.msg_buffer['train'][round] = dict()
+
+ self.msg_buffer['train'][round][sender] = content
+
+ if self._cfg.federate.online_aggr:
+ self.aggregator.inc(tuple(content[0:2]))
+
+ return self.check_and_move_on()
+
+ def update_policy(self, feedbacks):
+ """Update the policy. This implementation is borrowed from the open-sourced FedEx (https://github.com/mkhodak/FedEx/blob/150fac03857a3239429734d59d319da71191872e/hyper.py#L151)
+ Arguments:
+ feedbacks (list): each element is a dict containing "arms" and necessary feedback.
+ """
+
+ index = [elem['arms'] for elem in feedbacks]
+ before = np.asarray(
+ [elem['val_avg_loss_before'] for elem in feedbacks])
+ after = np.asarray([elem['val_avg_loss_after'] for elem in feedbacks])
+ weight = np.asarray([elem['val_total'] for elem in feedbacks],
+ dtype=np.float64)
+ weight /= np.sum(weight)
+
+ if self._trace['refine']:
+ trace = self.trace('refine')
+ if self._diff:
+ trace -= self.trace('global')
+ baseline = discounted_mean(trace, self._baseline)
+ else:
+ baseline = 0.0
+ self._trace['global'].append(np.inner(before, weight))
+ self._trace['refine'].append(np.inner(after, weight))
+ if self._stop_exploration:
+ self._trace['entropy'].append(0.0)
+ self._trace['mle'].append(1.0)
+ return
+
+ for i, (z, theta) in enumerate(zip(self._z, self._theta)):
+ grad = np.zeros(len(z))
+ for idx, s, w in zip(index,
+ after - before if self._diff else after,
+ weight):
+ grad[idx[i]] += w * (s - baseline) / theta[idx[i]]
+ if self._sched == 'adaptive':
+ self._store[i] += norm(grad, float('inf'))**2
+ denom = np.sqrt(self._store[i])
+ elif self._sched == 'aggressive':
+ denom = 1.0 if np.all(
+ grad == 0.0) else norm(grad, float('inf'))
+ elif self._sched == 'auto':
+ self._store[i] += 1.0
+ denom = np.sqrt(self._store[i])
+ elif self._sched == 'constant':
+ denom = 1.0
+ elif self._sched == 'scale':
+ denom = 1.0 / np.sqrt(
+ 2.0 * np.log(len(grad))) if len(grad) > 1 else float('inf')
+ else:
+ raise NotImplementedError
+ eta = self._eta0[i] / denom
+ z -= eta * grad
+ z -= logsumexp(z)
+ self._theta[i] = np.exp(z)
+
+ self._trace['entropy'].append(self.entropy())
+ self._trace['mle'].append(self.mle())
+ if self._trace['entropy'][-1] < self._cutoff:
+ self._stop_exploration = True
+
+ logger.info(
+ 'Server #{:d}: Updated policy as {} with entropy {:f} and mle {:f}'
+ .format(self.ID, self._theta, self._trace['entropy'][-1],
+ self._trace['mle'][-1]))
+
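+    # A stand-alone, full-information sketch of the exponentiated-
+    # gradient step performed above (illustrative only, not used by the
+    # server): arms with lower cost gain probability after the update.
+    @staticmethod
+    def _eg_update_sketch(theta, costs, eta=0.1):
+        z = np.log(theta) - eta * np.asarray(costs)
+        z -= logsumexp(z)  # renormalize in log-space
+        return np.exp(z)
+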
+ def check_and_move_on(self,
+ check_eval_result=False,
+ min_received_num=None):
+ """
+        To check the message_buffer; when enough messages have been received, trigger some events (such as performing aggregation, evaluation, and moving to the next training round)
+ """
+ if min_received_num is None:
+ min_received_num = self._cfg.federate.sample_client_num
+ assert min_received_num <= self.sample_client_num
+
+ if check_eval_result:
+ min_received_num = len(list(self.comm_manager.neighbors.keys()))
+
+ move_on_flag = True # To record whether moving to a new training round or finishing the evaluation
+ if self.check_buffer(self.state, min_received_num, check_eval_result):
+
+ if not check_eval_result: # in the training process
+ mab_feedbacks = list()
+ # Get all the message
+ train_msg_buffer = self.msg_buffer['train'][self.state]
+ for model_idx in range(self.model_num):
+ model = self.models[model_idx]
+ aggregator = self.aggregators[model_idx]
+ msg_list = list()
+ for client_id in train_msg_buffer:
+ if self.model_num == 1:
+ msg_list.append(
+ tuple(train_msg_buffer[client_id][0:2]))
+ else:
+ train_data_size, model_para_multiple = train_msg_buffer[
+ client_id][0:2]
+ msg_list.append((train_data_size,
+ model_para_multiple[model_idx]))
+
+ # collect feedbacks for updating the policy
+ if model_idx == 0:
+ mab_feedbacks.append(
+ train_msg_buffer[client_id][2])
+
+ # Trigger the monitor here (for training)
+ if 'dissim' in self._cfg.eval.monitoring:
+                    # pass the current global weights via
+                    # `model.state_dict()`; calling `load_state_dict`
+                    # here was a bug
+                    B_val = calc_blocal_dissim(model.state_dict(),
+                                               msg_list)
+ formatted_eval_res = self._monitor.format_eval_res(
+ B_val, rnd=self.state, role='Server #')
+ logger.info(formatted_eval_res)
+
+ # Aggregate
+ agg_info = {
+ 'client_feedback': msg_list,
+ 'recover_fun': self.recover_fun
+ }
+ result = aggregator.aggregate(agg_info)
+ model.load_state_dict(result, strict=False)
+ #aggregator.update(result)
+
+ # update the policy
+ self.update_policy(mab_feedbacks)
+
+ self.state += 1
+ if self.state % self._cfg.eval.freq == 0 and self.state != self.total_round_num:
+ # Evaluate
+ logger.info(
+ 'Server #{:d}: Starting evaluation at round {:d}.'.
+ format(self.ID, self.state))
+ self.eval()
+
+ if self.state < self.total_round_num:
+ # Move to next round of training
+ logger.info(
+ '----------- Starting a new training round (Round #{:d}) -------------'
+ .format(self.state))
+ # Clean the msg_buffer
+ self.msg_buffer['train'][self.state - 1].clear()
+
+ self.broadcast_model_para(
+ msg_type='model_para',
+ sample_client_num=self.sample_client_num)
+ else:
+ # Final Evaluate
+ logger.info(
+ 'Server #{:d}: Training is finished! Starting evaluation.'
+ .format(self.ID))
+ self.eval()
+
+ else: # in the evaluation process
+ # Get all the message & aggregate
+ formatted_eval_res = self.merge_eval_results_from_all_clients()
+ self.history_results = merge_dict(self.history_results,
+ formatted_eval_res)
+ self.check_and_save()
+ else:
+ move_on_flag = False
+
+ return move_on_flag
+
+ def check_and_save(self):
+ """
+ To save the results and save model after each evaluation
+ """
+ # early stopping
+ should_stop = False
+
+ if "Results_weighted_avg" in self.history_results and \
+ self._cfg.eval.best_res_update_round_wise_key in self.history_results['Results_weighted_avg']:
+ should_stop = self.early_stopper.track_and_check(
+ self.history_results['Results_weighted_avg'][
+ self._cfg.eval.best_res_update_round_wise_key])
+ elif "Results_avg" in self.history_results and \
+ self._cfg.eval.best_res_update_round_wise_key in self.history_results['Results_avg']:
+ should_stop = self.early_stopper.track_and_check(
+ self.history_results['Results_avg'][
+ self._cfg.eval.best_res_update_round_wise_key])
+ else:
+ should_stop = False
+
+ if should_stop:
+ self.state = self.total_round_num + 1
+
+ if should_stop or self.state == self.total_round_num:
+ logger.info(
+ 'Server #{:d}: Final evaluation is finished! Starting merging results.'
+ .format(self.ID))
+ # last round
+ self.save_best_results()
+
+ if self._cfg.federate.save_to != '':
+ # save the policy
+ ckpt = dict()
+ z_list = [z.tolist() for z in self._z]
+ ckpt['z'] = z_list
+ ckpt['store'] = self._store
+ ckpt['stop'] = self._stop_exploration
+ ckpt['global'] = self.trace('global').tolist()
+ ckpt['refine'] = self.trace('refine').tolist()
+ ckpt['entropy'] = self.trace('entropy').tolist()
+ ckpt['mle'] = self.trace('mle').tolist()
+ pi_ckpt_path = self._cfg.federate.save_to[:self._cfg.federate.
+ save_to.rfind(
+ '.'
+ )] + "_fedex.yaml"
+ with open(pi_ckpt_path, 'w') as ops:
+ yaml.dump(ckpt, ops)
+
+ if self.model_num > 1:
+ model_para = [model.state_dict() for model in self.models]
+ else:
+ model_para = self.model.state_dict()
+ self.comm_manager.send(
+ Message(msg_type='finish',
+ sender=self.ID,
+ receiver=list(self.comm_manager.neighbors.keys()),
+ state=self.state,
+ content=model_para))
+
+ if self.state == self.total_round_num:
+            # break out of the loop for the distributed mode
+ self.state += 1
diff --git a/federatedscope/autotune/utils.py b/federatedscope/autotune/utils.py
new file mode 100644
index 000000000..422301580
--- /dev/null
+++ b/federatedscope/autotune/utils.py
@@ -0,0 +1,133 @@
+from copy import deepcopy
+import math
+
+import yaml
+import pandas as pd
+import ConfigSpace as CS
+
+
+def parse_search_space(config_path):
+ """Parse yaml format configuration to generate search space
+ Arguments:
+ config_path (str): the path of the yaml file.
+ :returns: the search space.
+ :rtype: ConfigSpace object
+ """
+
+ ss = CS.ConfigurationSpace()
+
+ with open(config_path, 'r') as ips:
+ raw_ss_config = yaml.load(ips, Loader=yaml.FullLoader)
+
+ for k in raw_ss_config.keys():
+ name = k
+ v = raw_ss_config[k]
+ hyper_type = v['type']
+ del v['type']
+ v['name'] = name
+
+ if hyper_type == 'float':
+ hyper_config = CS.UniformFloatHyperparameter(**v)
+ elif hyper_type == 'int':
+ hyper_config = CS.UniformIntegerHyperparameter(**v)
+ elif hyper_type == 'cate':
+ hyper_config = CS.CategoricalHyperparameter(**v)
+ else:
+ raise ValueError("Unsupported hyper type {}".format(hyper_type))
+ ss.add_hyperparameter(hyper_config)
+
+ return ss
+
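+
+def _example_search_space():
+    # A minimal sketch of the YAML layout parse_search_space expects;
+    # the hyperparameter names below are made up, and the per-type
+    # fields ('lower'/'upper'/'log', 'choices') follow ConfigSpace's
+    # constructor arguments.
+    import tempfile
+    spec = ("optimizer.lr:\n"
+            "  type: float\n"
+            "  lower: 0.001\n"
+            "  upper: 0.1\n"
+            "  log: True\n"
+            "model.dropout:\n"
+            "  type: cate\n"
+            "  choices: [0.0, 0.5]\n")
+    with tempfile.NamedTemporaryFile('w', suffix='.yaml',
+                                     delete=False) as f:
+        f.write(spec)
+    return parse_search_space(f.name)
+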
+
+def config2cmdargs(config):
+ '''
+ Arguments:
+ config (dict): key is cfg node name, value is the specified value.
+ Returns:
+ results (list): cmd args
+ '''
+
+ results = []
+ for k, v in config.items():
+ results.append(k)
+ results.append(v)
+ return results
+
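+
+def _example_config2cmdargs():
+    # e.g., the flattened dict below becomes
+    # ['optimizer.lr', 0.1, 'federate.total_round_num', 50],
+    # which is ready for yacs' merge_from_list.
+    return config2cmdargs({'optimizer.lr': 0.1,
+                           'federate.total_round_num': 50})
+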
+
+def config2str(config):
+ '''
+ Arguments:
+ config (dict): key is cfg node name, value is the choice of hyper-parameter.
+ Returns:
+ name (str): the string representation of this config
+ '''
+
+ vals = []
+ for k in config:
+ idx = k.rindex('.')
+ vals.append(k[idx + 1:])
+ vals.append(str(config[k]))
+ name = '_'.join(vals)
+ return name
+
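+
+def _example_config2str():
+    # e.g., {'optimizer.lr': 0.1, 'model.dropout': 0.5} is rendered as
+    # 'lr_0.1_dropout_0.5', the stem the schedulers use for checkpoint
+    # file names.
+    return config2str({'optimizer.lr': 0.1, 'model.dropout': 0.5})
+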
+
+def summarize_hpo_results(configs, perfs, white_list=None, desc=False):
+ cols = [k for k in configs[0] if (white_list is None or k in white_list)
+ ] + ['performance']
+ d = [[
+ trial_cfg[k]
+ for k in trial_cfg if (white_list is None or k in white_list)
+ ] + [result] for trial_cfg, result in zip(configs, perfs)]
+ d = sorted(d, key=lambda ele: ele[-1], reverse=desc)
+ df = pd.DataFrame(d, columns=cols)
+ return df
+
+
+def parse_logs(file_list):
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ FONTSIZE = 40
+ MARKSIZE = 25
+
+ def process(file):
+ history = []
+ with open(file, 'r') as F:
+ for line in F:
+ try:
+ state, line = line.split('INFO: ')
+ config = eval(line[line.find('{'):line.find('}') + 1])
+ performance = float(
+ line[line.find('performance'):].split(' ')[1])
+ print(config, performance)
+ history.append((config, performance))
+                except Exception:
+                    continue
+        best_seen = np.inf
+        tol_budget = 0
+        tmp_b = 0  # round budget of the previous trial (avoids NameError)
+        x, y = [], []
+
+ for config, performance in history:
+ tol_budget += config['federate.total_round_num']
+ if best_seen > performance or config[
+ 'federate.total_round_num'] > tmp_b:
+ best_seen = performance
+ x.append(tol_budget)
+ y.append(best_seen)
+ tmp_b = config['federate.total_round_num']
+ return np.array(x) / tol_budget, np.array(y)
+
+ # Draw
+ plt.figure(figsize=(10, 7.5))
+ plt.xticks(fontsize=FONTSIZE)
+ plt.yticks(fontsize=FONTSIZE)
+
+ plt.xlabel('Fraction of budget', size=FONTSIZE)
+ plt.ylabel('Loss', size=FONTSIZE)
+
+ for file in file_list:
+ x, y = process(file)
+ plt.plot(x, y, linewidth=1, markersize=MARKSIZE)
+ plt.legend(file_list, fontsize=23, loc='lower right')
+    plt.savefig('exp2.pdf', bbox_inches='tight')
+ plt.close()
diff --git a/federatedscope/contrib/__init__.py b/federatedscope/contrib/__init__.py
new file mode 100644
index 000000000..f8e91f237
--- /dev/null
+++ b/federatedscope/contrib/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
diff --git a/federatedscope/contrib/configs/__init__.py b/federatedscope/contrib/configs/__init__.py
new file mode 100644
index 000000000..cef30fadb
--- /dev/null
+++ b/federatedscope/contrib/configs/__init__.py
@@ -0,0 +1,14 @@
+import copy
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
+
+# to ensure the sub-configs are registered before setting up the global config
+all_sub_configs_contrib = copy.copy(__all__)
+if "config" in all_sub_configs_contrib:
+ all_sub_configs_contrib.remove('config')
diff --git a/federatedscope/contrib/data/__init__.py b/federatedscope/contrib/data/__init__.py
new file mode 100644
index 000000000..c0b31382d
--- /dev/null
+++ b/federatedscope/contrib/data/__init__.py
@@ -0,0 +1,8 @@
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
diff --git a/federatedscope/contrib/data/example.py b/federatedscope/contrib/data/example.py
new file mode 100644
index 000000000..557236afc
--- /dev/null
+++ b/federatedscope/contrib/data/example.py
@@ -0,0 +1,30 @@
+from federatedscope.register import register_data
+
+
+def MyData(config):
+ r"""
+
+ Returns:
+ data:
+ {
+ '{client_id}': {
+ 'train': Dataset or DataLoader,
+ 'test': Dataset or DataLoader,
+ 'val': Dataset or DataLoader
+ }
+ }
+ config:
+ cfg_node
+ """
+ data = None
+    # modify the config here if the dataset requires it
+ return data, config
+
+
+def call_my_data(config):
+ if config.data.type == "mydata":
+ data, modified_config = MyData(config)
+ return data, modified_config
+
+
+register_data("mydata", call_my_data)
diff --git a/federatedscope/contrib/metrics/__init__.py b/federatedscope/contrib/metrics/__init__.py
new file mode 100644
index 000000000..c0b31382d
--- /dev/null
+++ b/federatedscope/contrib/metrics/__init__.py
@@ -0,0 +1,8 @@
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
diff --git a/federatedscope/contrib/metrics/example.py b/federatedscope/contrib/metrics/example.py
new file mode 100644
index 000000000..7fdb699bf
--- /dev/null
+++ b/federatedscope/contrib/metrics/example.py
@@ -0,0 +1,16 @@
+from federatedscope.register import register_metric
+
+METRIC_NAME = 'example'
+
+
+def MyMetric(ctx, **kwargs):
+ return ctx["num_train_data"]
+
+
+def call_my_metric(types):
+ if METRIC_NAME in types:
+ metric_builder = MyMetric
+ return METRIC_NAME, metric_builder
+
+
+register_metric(METRIC_NAME, call_my_metric)
diff --git a/federatedscope/contrib/metrics/poison_acc.py b/federatedscope/contrib/metrics/poison_acc.py
new file mode 100644
index 000000000..75f408052
--- /dev/null
+++ b/federatedscope/contrib/metrics/poison_acc.py
@@ -0,0 +1,31 @@
+from federatedscope.register import register_metric
+import numpy as np
+
+
+def compute_poison_metric(ctx):
+
+ poison_true = ctx['poison_' + ctx.cur_data_split + '_y_true']
+ poison_prob = ctx['poison_' + ctx.cur_data_split + '_y_prob']
+ poison_pred = np.argmax(poison_prob, axis=1)
+
+ correct = poison_true == poison_pred
+
+ return float(np.sum(correct)) / len(correct)
+
+
+def load_poison_metrics(ctx, y_true, y_pred, y_prob, **kwargs):
+
+ if ctx.cur_data_split == 'train':
+ results = None
+ else:
+ results = compute_poison_metric(ctx)
+
+ return results
+
+
+def call_poison_metric(types):
+ if 'poison_attack_acc' in types:
+ return 'poison_attack_acc', load_poison_metrics
+
+
+register_metric('poison_attack_acc', call_poison_metric)
diff --git a/federatedscope/contrib/model/__init__.py b/federatedscope/contrib/model/__init__.py
new file mode 100644
index 000000000..c0b31382d
--- /dev/null
+++ b/federatedscope/contrib/model/__init__.py
@@ -0,0 +1,8 @@
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
diff --git a/federatedscope/contrib/model/example.py b/federatedscope/contrib/model/example.py
new file mode 100644
index 000000000..899246a35
--- /dev/null
+++ b/federatedscope/contrib/model/example.py
@@ -0,0 +1,23 @@
+from federatedscope.register import register_model
+
+
+# Build your torch or tf model class here
+class MyNet(object):
+ pass
+
+
+# Instantiate your model class with config and data
+def ModelBuilder(model_config, local_data):
+
+ model = MyNet()
+
+ return model
+
+
+def call_my_net(model_config, local_data):
+ if model_config.type == "mynet":
+ model = ModelBuilder(model_config, local_data)
+ return model
+
+
+register_model("mynet", call_my_net)
diff --git a/federatedscope/contrib/model/resnet.py b/federatedscope/contrib/model/resnet.py
new file mode 100644
index 000000000..b89b4c46c
--- /dev/null
+++ b/federatedscope/contrib/model/resnet.py
@@ -0,0 +1,376 @@
+from federatedscope.register import register_model
+'''Pre-activation ResNet in PyTorch.
+
+Reference:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ Identity Mappings in Deep Residual Networks. arXiv:1603.05027
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PreActBlock(nn.Module):
+ '''Pre-activation version of the BasicBlock.'''
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(PreActBlock, self).__init__()
+ self.bn1 = nn.BatchNorm2d(in_planes)
+ self.conv1 = nn.Conv2d(in_planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(x))
+ shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
+ out = self.conv1(out)
+ out = self.conv2(F.relu(self.bn2(out)))
+ out += shortcut
+ return out
+
+
+class PreActBottleneck(nn.Module):
+ '''Pre-activation version of the original Bottleneck module.'''
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(PreActBottleneck, self).__init__()
+ self.bn1 = nn.BatchNorm2d(in_planes)
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes,
+ self.expansion * planes,
+ kernel_size=1,
+ bias=False)
+
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(x))
+ shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
+ out = self.conv1(out)
+ out = self.conv2(F.relu(self.bn2(out)))
+ out = self.conv3(F.relu(self.bn3(out)))
+ out += shortcut
+ return out
+
+
+class PreActResNet(nn.Module):
+ def __init__(self, block, num_blocks, num_classes=10):
+ super(PreActResNet, self).__init__()
+ self.in_planes = 64
+
+ self.conv1 = nn.Conv2d(3,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ out = out.view(out.size(0), -1)
+ out = self.linear(out)
+ return out
+
+
+def PreActResNet18():
+ return PreActResNet(PreActBlock, [2, 2, 2, 2])
+
+
+def PreActResNet34():
+ return PreActResNet(PreActBlock, [3, 4, 6, 3])
+
+
+def PreActResNet50():
+ return PreActResNet(PreActBottleneck, [3, 4, 6, 3])
+
+
+def PreActResNet101():
+ return PreActResNet(PreActBottleneck, [3, 4, 23, 3])
+
+
+def PreActResNet152():
+ return PreActResNet(PreActBottleneck, [3, 8, 36, 3])
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1, norm='bn'):
+ super(BasicBlock, self).__init__()
+ self.norm = norm
+
+ self.conv1 = nn.Conv2d(in_planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+
+ self.bn1 = nn.BatchNorm2d(
+ planes) if self.norm == 'bn' else nn.GroupNorm(
+ 64, planes, affine=True)
+
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+
+ self.bn2 = nn.BatchNorm2d(
+ planes) if self.norm == 'bn' else nn.GroupNorm(
+ 64, planes, affine=True)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(self.expansion * planes) if self.norm == 'bn'
+ else nn.GroupNorm(64, self.expansion * planes, affine=True))
+
+ def forward(self, x):
+ #
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1, norm='bn'):
+ super(Bottleneck, self).__init__()
+
+ self.norm = norm
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+
+ self.bn1 = nn.BatchNorm2d(
+ planes) if self.norm == 'bn' else nn.GroupNorm(
+ 64, planes, affine=True)
+
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+
+ self.bn2 = nn.BatchNorm2d(
+ planes) if self.norm == 'bn' else nn.GroupNorm(
+ 64, planes, affine=True)
+
+ self.conv3 = nn.Conv2d(planes,
+ self.expansion * planes,
+ kernel_size=1,
+ bias=False)
+
+ self.bn3 = nn.BatchNorm2d(
+ self.expansion * planes) if self.norm == 'bn' else nn.GroupNorm(
+ 64, self.expansion * planes, affine=True)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(self.expansion * planes) if self.norm == 'bn'
+ else nn.GroupNorm(64, self.expansion * planes, affine=True))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = F.relu(self.bn2(self.conv2(out)))
+ out = self.bn3(self.conv3(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, num_blocks, num_classes=10, norm='bn'):
+ super(ResNet, self).__init__()
+
+ self.norm = norm
+ self.in_planes = 64
+
+ self.conv1 = nn.Conv2d(3,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ if self.norm == 'bn':
+ self.bn1 = nn.BatchNorm2d(64)
+ elif self.norm == 'gn':
+ self.bn1 = nn.GroupNorm(64, 64, affine=True)
+
+ self.layer1 = self._make_layer(block,
+ 64,
+ num_blocks[0],
+ stride=1,
+ norm=self.norm)
+ self.layer2 = self._make_layer(block,
+ 128,
+ num_blocks[1],
+ stride=2,
+ norm=self.norm)
+ self.layer3 = self._make_layer(block,
+ 256,
+ num_blocks[2],
+ stride=2,
+ norm=self.norm)
+ self.layer4 = self._make_layer(block,
+ 512,
+ num_blocks[3],
+ stride=2,
+ norm=self.norm)
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+ def _make_layer(self, block, planes, num_blocks, stride, norm):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride, norm))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def feature(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ out = out.view(out.size(0), -1)
+
+ return out
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ out = out.view(out.size(0), -1)
+ out = self.linear(out)
+ return out
+
+
+def ResNet18():
+ return ResNet(BasicBlock, [2, 2, 2, 2])
+
+
+def ResNet18_GN():
+ return ResNet(BasicBlock, [2, 2, 2, 2], norm='gn')
+
+
+def ResNet34():
+ return ResNet(BasicBlock, [3, 4, 6, 3])
+
+
+def ResNet50():
+ return ResNet(Bottleneck, [3, 4, 6, 3])
+
+
+def ResNet101():
+ return ResNet(Bottleneck, [3, 4, 23, 3])
+
+
+def ResNet152():
+ return ResNet(Bottleneck, [3, 8, 36, 3])
+
+
+def preact_resnet(model_config):
+    if '18' in model_config.type:
+        net = PreActResNet18()
+    elif '50' in model_config.type:
+        net = PreActResNet50()
+    else:
+        net = None
+    return net
+
+
+def resnet(model_config):
+    if '18' in model_config.type and 'gn' in model_config.type:
+        net = ResNet18_GN()
+    elif '18' in model_config.type and 'ln' not in model_config.type \
+            and 'in' not in model_config.type:
+        net = ResNet18()
+    elif '50' in model_config.type and 'ln' not in model_config.type \
+            and 'in' not in model_config.type:
+        net = ResNet50()
+    else:
+        net = None
+    return net
+
+
+def call_resnet(model_config, local_data):
+ if 'resnet' in model_config.type and 'pre' in model_config.type:
+ model = preact_resnet(model_config)
+ return model
+ elif 'resnet' in model_config.type and 'pre' not in model_config.type:
+ model = resnet(model_config)
+ return model
+
+
+register_model('resnet', call_resnet)
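+
+
+# Rough usage sketch (hypothetical type strings; dispatch is by substring
+# matching in the helpers above):
+#
+#     cfg.model.type = 'resnet18'      # -> ResNet18 (BatchNorm)
+#     cfg.model.type = 'resnet18_gn'   # -> ResNet18_GN (GroupNorm)
+#     cfg.model.type = 'resnet18_pre'  # -> PreActResNet18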
diff --git a/federatedscope/contrib/model/resnet_in.py b/federatedscope/contrib/model/resnet_in.py
new file mode 100644
index 000000000..ef9406293
--- /dev/null
+++ b/federatedscope/contrib/model/resnet_in.py
@@ -0,0 +1,312 @@
+'''Pre-activation ResNet, plus a ResNet variant whose GroupNorm layers use
+one group per channel (instance-norm style), in PyTorch.
+
+Reference:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
+'''
+from federatedscope.register import register_model
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PreActBlock(nn.Module):
+ '''Pre-activation version of the BasicBlock.'''
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(PreActBlock, self).__init__()
+ self.bn1 = nn.BatchNorm2d(in_planes)
+ self.conv1 = nn.Conv2d(in_planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(x))
+ shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
+ out = self.conv1(out)
+ out = self.conv2(F.relu(self.bn2(out)))
+ out += shortcut
+ return out
+
+
+class PreActBottleneck(nn.Module):
+ '''Pre-activation version of the original Bottleneck module.'''
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(PreActBottleneck, self).__init__()
+ self.bn1 = nn.BatchNorm2d(in_planes)
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes,
+ self.expansion * planes,
+ kernel_size=1,
+ bias=False)
+
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(x))
+ shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
+ out = self.conv1(out)
+ out = self.conv2(F.relu(self.bn2(out)))
+ out = self.conv3(F.relu(self.bn3(out)))
+ out += shortcut
+ return out
+
+
+class PreActResNet(nn.Module):
+ def __init__(self, block, num_blocks, num_classes=10):
+ super(PreActResNet, self).__init__()
+ self.in_planes = 64
+
+ self.conv1 = nn.Conv2d(3,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ out = out.view(out.size(0), -1)
+ out = self.linear(out)
+ return out
+
+
+def PreActResNet18():
+ return PreActResNet(PreActBlock, [2, 2, 2, 2])
+
+
+def PreActResNet34():
+ return PreActResNet(PreActBlock, [3, 4, 6, 3])
+
+
+def PreActResNet50():
+ return PreActResNet(PreActBottleneck, [3, 4, 6, 3])
+
+
+def PreActResNet101():
+ return PreActResNet(PreActBottleneck, [3, 4, 23, 3])
+
+
+def PreActResNet152():
+ return PreActResNet(PreActBottleneck, [3, 8, 36, 3])
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(BasicBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+
+        # num_groups == num_channels, so GroupNorm behaves like InstanceNorm
+        self.bn1 = nn.GroupNorm(planes, planes, affine=True)
+
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+
+ self.bn2 = nn.GroupNorm(planes, planes, affine=True)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.GroupNorm(self.expansion * planes,
+ self.expansion * planes,
+ affine=True))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(Bottleneck, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+
+ self.bn1 = nn.GroupNorm(planes, planes, affine=True)
+
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+
+ self.bn2 = nn.GroupNorm(planes, planes, affine=True)
+
+ self.conv3 = nn.Conv2d(planes,
+ self.expansion * planes,
+ kernel_size=1,
+ bias=False)
+
+ self.bn3 = nn.GroupNorm(self.expansion * planes,
+ self.expansion * planes,
+ affine=True)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.GroupNorm(self.expansion * planes,
+ self.expansion * planes,
+ affine=True))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = F.relu(self.bn2(self.conv2(out)))
+ out = self.bn3(self.conv3(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, num_blocks, num_classes=10):
+ super(ResNet, self).__init__()
+
+ self.in_planes = 64
+
+ self.conv1 = nn.Conv2d(3,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+
+ self.bn1 = nn.GroupNorm(64, 64, affine=True)
+
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ out = out.view(out.size(0), -1)
+ out = self.linear(out)
+ return out
+
+
+def ResNet18_IN():
+ return ResNet(BasicBlock, [2, 2, 2, 2])
+
+
+def ResNet34_IN():
+ return ResNet(BasicBlock, [3, 4, 6, 3])
+
+
+def ResNet50_IN():
+ return ResNet(Bottleneck, [3, 4, 6, 3])
+
+
+def resnet_in(model_config):
+
+ if '18' in model_config.type:
+ net = ResNet18_IN()
+ elif '50' in model_config.type:
+ net = ResNet50_IN()
+ else:
+ net = None
+ return net
+
+
+def call_resnet_in(model_config, local_data):
+
+ if 'in' in model_config.type:
+ model = resnet_in(model_config)
+ return model
+
+
+register_model('resnet_in', call_resnet_in)
diff --git a/federatedscope/contrib/model/resnet_ln.py b/federatedscope/contrib/model/resnet_ln.py
new file mode 100644
index 000000000..8188a23c3
--- /dev/null
+++ b/federatedscope/contrib/model/resnet_ln.py
@@ -0,0 +1,227 @@
+'''ResNet with Filter Response Normalization (FRN) in place of BatchNorm,
+in PyTorch; registered below under the 'ln' model-type tag.
+Reference:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+    Deep Residual Learning for Image Recognition. arXiv:1512.03385
+[2] Saurabh Singh, Shankar Krishnan
+    Filter Response Normalization Layer: Eliminating Batch Dependence in
+    the Training of Deep Neural Networks. arXiv:1911.09737
+'''
+from federatedscope.register import register_model
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FilterResponseNormNd(nn.Module):
+ def __init__(self, ndim, num_features, eps=1e-6, learnable_eps=False):
+ """
+ Input Variables:
+ ----------------
+ ndim: An integer indicating the number of dimensions of the expected input tensor.
+ num_features: An integer indicating the number of input feature dimensions.
+ eps: A scalar constant or learnable variable.
+ learnable_eps: A bool value indicating whether the eps is learnable.
+ """
+ assert ndim in [3, 4, 5], \
+ 'FilterResponseNorm only supports 3d, 4d or 5d inputs.'
+ super(FilterResponseNormNd, self).__init__()
+ shape = (1, num_features) + (1, ) * (ndim - 2)
+ self.eps = nn.Parameter(torch.ones(*shape) * eps)
+ if not learnable_eps:
+ self.eps.requires_grad_(False)
+ self.gamma = nn.Parameter(torch.Tensor(*shape))
+ self.beta = nn.Parameter(torch.Tensor(*shape))
+ self.tau = nn.Parameter(torch.Tensor(*shape))
+ self.reset_parameters()
+
+ def forward(self, x):
+ avg_dims = tuple(range(2, x.dim()))
+ nu2 = torch.pow(x, 2).mean(dim=avg_dims, keepdim=True)
+ x = x * torch.rsqrt(nu2 + torch.abs(self.eps))
+ return torch.max(self.gamma * x + self.beta, self.tau)
+
+ def reset_parameters(self):
+ nn.init.ones_(self.gamma)
+ nn.init.zeros_(self.beta)
+ nn.init.zeros_(self.tau)
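+
+    # The forward pass above is Filter Response Normalization followed by a
+    # Thresholded Linear Unit (TLU), roughly:
+    #     nu2 = mean(x ** 2) over the spatial dims
+    #     y = max(gamma * x * rsqrt(nu2 + |eps|) + beta, tau)
+    # No batch statistics are involved, which sidesteps BatchNorm's
+    # dependence on client-local batch composition in federated training.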
+
+
+class FilterResponseNorm1d(FilterResponseNormNd):
+ def __init__(self, num_features, eps=1e-6, learnable_eps=False):
+ super(FilterResponseNorm1d, self).__init__(3,
+ num_features,
+ eps=eps,
+ learnable_eps=learnable_eps)
+
+
+class FilterResponseNorm2d(FilterResponseNormNd):
+ def __init__(self, num_features, eps=1e-6, learnable_eps=False):
+ super(FilterResponseNorm2d, self).__init__(4,
+ num_features,
+ eps=eps,
+ learnable_eps=learnable_eps)
+
+
+class FilterResponseNorm3d(FilterResponseNormNd):
+ def __init__(self, num_features, eps=1e-6, learnable_eps=False):
+ super(FilterResponseNorm3d, self).__init__(5,
+ num_features,
+ eps=eps,
+ learnable_eps=learnable_eps)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(in_planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+ self.bn1 = FilterResponseNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn2 = FilterResponseNorm2d(planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ FilterResponseNorm2d(self.expansion * planes))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+ self.bn1 = FilterResponseNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+ self.bn2 = FilterResponseNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes,
+ self.expansion * planes,
+ kernel_size=1,
+ bias=False)
+ self.bn3 = FilterResponseNorm2d(self.expansion * planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ FilterResponseNorm2d(self.expansion * planes))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = F.relu(self.bn2(self.conv2(out)))
+ out = self.bn3(self.conv3(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class ResNet_LN(nn.Module):
+ def __init__(self, block, num_blocks, num_classes=10):
+ super(ResNet_LN, self).__init__()
+ self.in_planes = 64
+
+ self.conv1 = nn.Conv2d(3,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn1 = FilterResponseNorm2d(64)
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ out = out.view(out.size(0), -1)
+ out = self.linear(out)
+ return out
+
+
+def ResNet18_LN():
+ return ResNet_LN(BasicBlock, [2, 2, 2, 2])
+
+
+def ResNet34_LN():
+ return ResNet_LN(BasicBlock, [3, 4, 6, 3])
+
+
+def ResNet50_LN():
+ return ResNet_LN(Bottleneck, [3, 4, 6, 3])
+
+
+def ResNet101_LN():
+ return ResNet_LN(Bottleneck, [3, 4, 23, 3])
+
+
+def ResNet152_LN():
+ return ResNet_LN(Bottleneck, [3, 8, 36, 3])
+
+
+def resnet_ln(model_config):
+    if '18' in model_config.type:
+        net = ResNet18_LN()
+    elif '50' in model_config.type:
+        net = ResNet50_LN()
+    else:
+        net = None
+    return net
+
+
+def call_resnet_ln(model_config, local_data):
+ if 'ln' in model_config.type:
+ model = resnet_ln(model_config)
+ return model
+
+
+register_model('resnet_ln', call_resnet_ln)
diff --git a/federatedscope/contrib/model/vgg.py b/federatedscope/contrib/model/vgg.py
new file mode 100644
index 000000000..7be85e622
--- /dev/null
+++ b/federatedscope/contrib/model/vgg.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn as nn
+
+from federatedscope.register import register_model
+
+cfg = {
+ 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+ 'VGG13': [
+ 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
+ ],
+ 'VGG16': [
+ 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
+ 512, 512, 512, 'M'
+ ],
+ 'VGG19': [
+ 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512,
+ 512, 'M', 512, 512, 512, 512, 'M'
+ ],
+}
+
+
+class VGG(nn.Module):
+ def __init__(self, vgg_name, channel, num_classes):
+ super(VGG, self).__init__()
+ self.channel = channel
+ self.features = self._make_layers(cfg[vgg_name])
+ self.linear = nn.Linear(512, num_classes)
+
+ def forward(self, x):
+ out = self.features(x)
+ out = out.view(out.size(0), -1)
+ out = self.linear(out)
+ return out
+
+ def _make_layers(self, cfg):
+ layers = []
+ in_channels = self.channel
+ for x in cfg:
+ if x == 'M':
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+ else:
+ layers += [
+ nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
+ nn.BatchNorm2d(x),
+ nn.ReLU(inplace=True)
+ ]
+ in_channels = x
+ layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
+ return nn.Sequential(*layers)
+
+
+def VGG11(channel=3, num_classes=10):
+    return VGG('VGG11', channel=channel, num_classes=num_classes)
+
+
+def VGG13(channel=3, num_classes=10):
+    return VGG('VGG13', channel=channel, num_classes=num_classes)
+
+
+def VGG16(channel, num_classes):
+ return VGG('VGG16', channel, num_classes)
+
+
+def VGG19(channel, num_classes):
+ return VGG('VGG19', channel, num_classes)
+
+
+def vgg(model_config):
+    if '11' in model_config.type:
+        net = VGG11()
+    elif '13' in model_config.type:
+        net = VGG13()
+    else:
+        net = None
+    return net
+
+
+def call_vgg(model_config, local_data):
+ if 'vgg' in model_config.type:
+ model = vgg(model_config)
+ return model
+
+
+register_model('vgg', call_vgg)
diff --git a/federatedscope/contrib/trainer/__init__.py b/federatedscope/contrib/trainer/__init__.py
new file mode 100644
index 000000000..c0b31382d
--- /dev/null
+++ b/federatedscope/contrib/trainer/__init__.py
@@ -0,0 +1,8 @@
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
diff --git a/federatedscope/contrib/trainer/example.py b/federatedscope/contrib/trainer/example.py
new file mode 100644
index 000000000..ffc207372
--- /dev/null
+++ b/federatedscope/contrib/trainer/example.py
@@ -0,0 +1,16 @@
+from federatedscope.register import register_trainer
+from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
+
+
+# Build your trainer here.
+class MyTrainer(GeneralTorchTrainer):
+ pass
+
+
+def call_my_trainer(trainer_type):
+ if trainer_type == 'mytrainer':
+ trainer_builder = MyTrainer
+ return trainer_builder
+
+
+register_trainer('mytrainer', call_my_trainer)
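+
+
+# Rough usage sketch (assuming the trainer builder resolves `cfg.trainer.type`
+# through the registered callables):
+#
+#     cfg.trainer.type = 'mytrainer'   # -> MyTrainer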
diff --git a/federatedscope/core/__init__.py b/federatedscope/core/__init__.py
new file mode 100644
index 000000000..f8e91f237
--- /dev/null
+++ b/federatedscope/core/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
diff --git a/federatedscope/core/aggregator.py b/federatedscope/core/aggregator.py
new file mode 100644
index 000000000..56610c148
--- /dev/null
+++ b/federatedscope/core/aggregator.py
@@ -0,0 +1,290 @@
+from abc import ABC, abstractmethod
+from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
+
+import torch
+import os
+import numpy as np
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def vectorize_net_dict(net):
+ return torch.cat([net[key].view(-1) for key in net])
+
+
+def load_model_weight_dict(net, weight):
+ index_bias = 0
+ for p_index, p in net.items():
+ net[p_index].data = weight[index_bias:index_bias + p.numel()].view(
+ p.size())
+ index_bias += p.numel()
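+
+# A minimal round-trip sketch for the two helpers above (illustrative only):
+#
+#     state = net.state_dict()
+#     flat = vectorize_net_dict(state)       # 1-D tensor of all parameters
+#     load_model_weight_dict(state, flat)    # writes them back in place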
+
+
+class Aggregator(ABC):
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def aggregate(self, agg_info):
+ pass
+
+
+class ClientsAvgAggregator(Aggregator):
+ def __init__(self, model=None, device='cpu', config=None):
+        super(ClientsAvgAggregator, self).__init__()
+ self.model = model
+ self.device = device
+ self.cfg = config
+
+ def aggregate(self, agg_info):
+ """
+        To perform aggregation
+
+ Arguments:
+ agg_info (dict): the feedbacks from clients
+ :returns: the aggregated results
+ :rtype: dict
+ """
+ models = agg_info["client_feedback"]
+ recover_fun = agg_info['recover_fun'] if (
+ 'recover_fun' in agg_info and self.cfg.federate.use_ss) else None
+
+ if self.cfg.attack.krum or self.cfg.attack.multi_krum:
+ avg_model = self._para_krum_avg(models)
+ else:
+ avg_model = self._para_weighted_avg(models,
+ recover_fun=recover_fun)
+
+ return avg_model
+
+ def update(self, model_parameters):
+ '''
+ Arguments:
+ model_parameters (dict): PyTorch Module object's state_dict.
+ '''
+ self.model.load_state_dict(model_parameters, strict=False)
+
+ def save_model(self, path, cur_round=-1):
+ assert self.model is not None
+ ckpt = {'cur_round': cur_round, 'model': self.model.state_dict()}
+ torch.save(ckpt, path)
+
+ def load_model(self, path):
+ assert self.model is not None
+
+ if os.path.exists(path):
+ ckpt = torch.load(path, map_location=self.device)
+ self.model.load_state_dict(ckpt['model'])
+ return ckpt['cur_round']
+ else:
+ raise ValueError("The file {} does NOT exist".format(path))
+
+ def _para_weighted_avg(self, models, recover_fun=None):
+ training_set_size = 0
+ for i in range(len(models)):
+ sample_size, _ = models[i]
+ training_set_size += sample_size
+
+ sample_size, avg_model = models[0]
+ for key in avg_model:
+ for i in range(len(models)):
+ local_sample_size, local_model = models[i]
+
+ if self.cfg.federate.ignore_weight:
+ weight = 1.0 / len(models)
+ elif self.cfg.federate.use_ss:
+ weight = 1.0
+ else:
+ weight = local_sample_size / training_set_size
+
+ if not self.cfg.federate.use_ss:
+ if isinstance(local_model[key], torch.Tensor):
+ local_model[key] = local_model[key].float()
+ else:
+ local_model[key] = torch.FloatTensor(local_model[key])
+
+ if i == 0:
+ avg_model[key] = local_model[key] * weight
+ else:
+ avg_model[key] += local_model[key] * weight
+
+ if self.cfg.federate.use_ss and recover_fun:
+ avg_model[key] = recover_fun(avg_model[key])
+ avg_model[key] /= training_set_size
+ avg_model[key] = torch.FloatTensor(avg_model[key])
+
+ return avg_model
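+
+    # E.g., with three clients holding 100, 300 and 600 samples, the default
+    # FedAvg weights are 0.1, 0.3 and 0.6. With federate.ignore_weight=True
+    # every client gets 1/3 instead; with federate.use_ss the unweighted sum
+    # is secret-shared and only normalized after recovery.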
+
+ def _para_krum_avg(self, models):
+
+ num_workers = len(models)
+ num_adv = 1
+
+ num_dps = []
+ vectorize_nets = []
+ for i in range(len(models)):
+ sample_size, local_model = models[i]
+ num_dps.append(sample_size)
+ vectorize_nets.append(
+ vectorize_net_dict(local_model).detach().cpu().numpy())
+
+ neighbor_distances = []
+ for i, g_i in enumerate(vectorize_nets):
+ distance = []
+ for j in range(i + 1, len(vectorize_nets)):
+ if i != j:
+ g_j = vectorize_nets[j]
+ distance.append(float(np.linalg.norm(g_i - g_j)**2))
+ neighbor_distances.append(distance)
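+
+        # Krum (cf. Blanchard et al., 2017), assuming at most `num_adv`
+        # Byzantine clients among n = num_workers: each update is scored by
+        # the sum of its n - num_adv - 2 smallest squared distances to the
+        # other updates. `krum` keeps the single lowest-scored update;
+        # `multi_krum` averages the n - num_adv lowest-scored updates,
+        # weighted by their sample counts.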
+
+ # compute scores
+ nb_in_score = num_workers - num_adv - 2
+ scores = []
+ for i, g_i in enumerate(vectorize_nets):
+ dists = []
+ for j, g_j in enumerate(vectorize_nets):
+ if j == i:
+ continue
+ if j < i:
+ dists.append(neighbor_distances[j][i - j - 1])
+ else:
+ dists.append(neighbor_distances[i][j - i - 1])
+ # alternative to topk in pytorch and tensorflow
+ topk_ind = np.argpartition(dists, nb_in_score)[:nb_in_score]
+ scores.append(sum(np.take(dists, topk_ind)))
+
+ if self.cfg.attack.krum:
+ i_star = scores.index(min(scores))
+            # use the first state_dict as a template; it is overwritten below
+            _, aggregated_model = models[0]
+ load_model_weight_dict(aggregated_model,
+ torch.from_numpy(vectorize_nets[i_star]))
+            logger.info("Norm of Aggregated Model: {}".format(
+                torch.norm(torch.from_numpy(vectorize_nets[i_star])).item()))
+ return aggregated_model
+
+ elif self.cfg.attack.multi_krum:
+ topk_ind = np.argpartition(scores,
+ nb_in_score + 2)[:nb_in_score + 2]
+
+ # we reconstruct the weighted averaging here:
+ selected_num_dps = np.array(num_dps)[topk_ind]
+ reconstructed_freq = [
+ snd / sum(selected_num_dps) for snd in selected_num_dps
+ ]
+
+ logger.info("Num data points: {}".format(num_dps))
+ logger.info(
+ "Num selected data points: {}".format(selected_num_dps))
+
+ aggregated_grad = np.average(np.array(vectorize_nets)[topk_ind, :],
+ weights=reconstructed_freq,
+ axis=0).astype(np.float32)
+            # use the first state_dict as a template; it is overwritten below
+            _, aggregated_model = models[0]
+ load_model_weight_dict(aggregated_model,
+ torch.from_numpy(aggregated_grad))
+            logger.info("Norm of Aggregated Model: {}".format(
+                torch.norm(torch.from_numpy(aggregated_grad)).item()))
+ return aggregated_model
+
+
+class NoCommunicationAggregator(Aggregator):
+ """"Clients do not communicate. Each client work locally
+ """
+ def aggregate(self, agg_info):
+ # do nothing
+ return {}
+
+
+class OnlineClientsAvgAggregator(ClientsAvgAggregator):
+ def __init__(self,
+ model=None,
+ device='cpu',
+ src_device='cpu',
+ config=None):
+ super(OnlineClientsAvgAggregator, self).__init__(model, device, config)
+ self.src_device = src_device
+
+ def reset(self):
+ self.maintained = self.model.state_dict()
+ for key in self.maintained:
+ self.maintained[key].data = torch.zeros_like(
+ self.maintained[key], device=self.src_device)
+ self.cnt = 0
+
+ def inc(self, content):
+ if isinstance(content, tuple):
+ sample_size, model_params = content
+ for key in self.maintained:
+ self.maintained[key] = (self.cnt * self.maintained[key] +
+ sample_size * model_params[key]) / (
+ self.cnt + sample_size)
+ self.cnt += sample_size
+ else:
+ raise TypeError(
+ "{} is not a tuple (sample_size, model_para)".format(content))
+
+ def aggregate(self, agg_info):
+ return self.maintained
+
+
+class ServerClientsInterpolateAggregator(ClientsAvgAggregator):
+ def __init__(self, model=None, device='cpu', config=None, beta=1.0):
+ super(ServerClientsInterpolateAggregator,
+ self).__init__(model, device, config)
+ self.beta = beta # the weight for local models used in interpolation
+
+ def aggregate(self, agg_info):
+ models = agg_info["client_feedback"]
+ global_model = self.model
+ elem_each_client = next(iter(models))
+        assert len(elem_each_client) == 2, \
+            "Require a (sample_size, model_para) tuple for each client, " \
+            f"i.e., len=2, but got len={len(elem_each_client)}"
+ avg_model_by_clients = self._para_weighted_avg(models)
+ global_local_models = [((1 - self.beta), global_model.state_dict()),
+ (self.beta, avg_model_by_clients)]
+
+ avg_model_by_interpolate = self._para_weighted_avg(global_local_models)
+ return avg_model_by_interpolate
+
+
+class FedOptAggregator(ClientsAvgAggregator):
+ def __init__(self, config, model, device='cpu'):
+ super(FedOptAggregator, self).__init__(model, device, config)
+ self.optimizer = get_optimizer(model=self.model,
+ **config.fedopt.optimizer)
+
+ def aggregate(self, agg_info):
+ new_model = super().aggregate(agg_info)
+
+ model = self.model.cpu().state_dict()
+ with torch.no_grad():
+ grads = {key: model[key] - new_model[key] for key in new_model}
+
+ self.optimizer.zero_grad()
+ for key, p in self.model.named_parameters():
+ if key in new_model.keys():
+ p.grad = grads[key]
+ self.optimizer.step()
+
+ return self.model.state_dict()
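+
+
+# A note on FedOptAggregator (cf. Reddi et al., "Adaptive Federated
+# Optimization", arXiv:2003.00295): the averaged client model is turned into
+# a server-side pseudo-gradient, delta = x_global - x_avg, which is applied
+# through a regular optimizer. Plain SGD with lr=1.0 recovers FedAvg, while
+# Adam/Adagrad yield adaptive server updates.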
diff --git a/federatedscope/core/auxiliaries/ReIterator.py b/federatedscope/core/auxiliaries/ReIterator.py
new file mode 100644
index 000000000..f8ce40874
--- /dev/null
+++ b/federatedscope/core/auxiliaries/ReIterator.py
@@ -0,0 +1,19 @@
+class ReIterator:
+ def __init__(self, loader):
+ self.loader = loader
+ self.iterator = iter(loader)
+ self.reset_flag = False
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ try:
+ item = next(self.iterator)
+ except StopIteration:
+ self.reset()
+ item = next(self.iterator)
+ return item
+
+ def reset(self):
+ self.iterator = iter(self.loader)
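+
+
+# Rough usage sketch (`train_loader` and `num_steps` are placeholder names):
+# ReIterator lets a local update loop draw more batches than one pass over
+# the loader provides, restarting transparently on StopIteration.
+#
+#     it = ReIterator(train_loader)
+#     for _ in range(num_steps):  # may exceed len(train_loader)
+#         x, y = next(it)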
diff --git a/federatedscope/core/auxiliaries/__init__.py b/federatedscope/core/auxiliaries/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/core/auxiliaries/aggregator_builder.py b/federatedscope/core/auxiliaries/aggregator_builder.py
new file mode 100644
index 000000000..15394ef1e
--- /dev/null
+++ b/federatedscope/core/auxiliaries/aggregator_builder.py
@@ -0,0 +1,49 @@
+import logging
+
+from federatedscope.core.configs import constants
+
+logger = logging.getLogger(__name__)
+
+
+def get_aggregator(method, model=None, device=None, online=False, config=None):
+ if config.backend == 'tensorflow':
+ from federatedscope.cross_backends import FedAvgAggregator
+ return FedAvgAggregator(model=model, device=device)
+ else:
+ from federatedscope.core.aggregator import ClientsAvgAggregator, \
+ OnlineClientsAvgAggregator, ServerClientsInterpolateAggregator, \
+ FedOptAggregator, NoCommunicationAggregator
+
+ if method.lower() in constants.AGGREGATOR_TYPE:
+ aggregator_type = constants.AGGREGATOR_TYPE[method.lower()]
+ else:
+ aggregator_type = "clients_avg"
+            logger.warning(
+                'Aggregator for method {} is not implemented. Will use the '
+                'default one instead.'.format(method))
+
+ if config.fedopt.use or aggregator_type == 'fedopt':
+ return FedOptAggregator(config=config, model=model, device=device)
+ elif aggregator_type == 'clients_avg':
+ if online:
+ return OnlineClientsAvgAggregator(
+ model=model,
+ device=device,
+ config=config,
+ src_device=device
+ if config.federate.share_local_model else 'cpu')
+ else:
+ return ClientsAvgAggregator(model=model,
+ device=device,
+ config=config)
+ elif aggregator_type == 'server_clients_interpolation':
+ return ServerClientsInterpolateAggregator(
+ model=model,
+ device=device,
+ config=config,
+ beta=config.personalization.beta)
+ elif aggregator_type == 'no_communication':
+ return NoCommunicationAggregator()
+ else:
+ raise NotImplementedError(
+ "Aggregator {} is not implemented.".format(aggregator_type))
diff --git a/federatedscope/core/auxiliaries/criterion_builder.py b/federatedscope/core/auxiliaries/criterion_builder.py
new file mode 100644
index 000000000..1502fc0d3
--- /dev/null
+++ b/federatedscope/core/auxiliaries/criterion_builder.py
@@ -0,0 +1,23 @@
+import federatedscope.register as register
+
+try:
+ from torch import nn
+ from federatedscope.nlp.loss import *
+except ImportError:
+ nn = None
+
+
+def get_criterion(type, device):
+ for func in register.criterion_dict.values():
+ criterion = func(type, device)
+ if criterion is not None:
+ return criterion
+
+    if isinstance(type, str):
+        if hasattr(nn, type):
+            return getattr(nn, type)()
+        else:
+            raise NotImplementedError(
+                'Criterion {} is not implemented'.format(type))
+    else:
+        raise TypeError('The criterion type should be a str.')
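+
+
+# E.g., get_criterion('CrossEntropyLoss', device) resolves to
+# torch.nn.CrossEntropyLoss() via the getattr fallback above, after any
+# registered custom criteria have been tried first.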
diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py
new file mode 100644
index 000000000..c105349ad
--- /dev/null
+++ b/federatedscope/core/auxiliaries/data_builder.py
@@ -0,0 +1,809 @@
+import math
+import os
+import pickle
+import logging
+import numpy as np
+from collections import defaultdict
+
+import federatedscope.register as register
+
+logger = logging.getLogger(__name__)
+
+try:
+ from federatedscope.contrib.data import *
+except ImportError as error:
+ logger.warning(
+ f'{error} in `federatedscope.contrib.data`, some modules are not available.'
+ )
+
+
+def load_toy_data(config=None):
+
+ generate = config.federate.mode.lower() == 'standalone'
+
+ def _generate_data(client_num=5,
+ instance_num=1000,
+ feature_num=5,
+ save_data=False):
+ """
+ Generate data in FedRunner format
+ Args:
+ client_num:
+ instance_num:
+ feature_num:
+ save_data:
+
+ Returns:
+ {
+ '{client_id}': {
+ 'train': {
+ 'x': ...,
+ 'y': ...
+ },
+ 'test': {
+ 'x': ...,
+ 'y': ...
+ },
+ 'val': {
+ 'x': ...,
+ 'y': ...
+ }
+ }
+ }
+
+ """
+ weights = np.random.normal(loc=0.0, scale=1.0, size=feature_num)
+ bias = np.random.normal(loc=0.0, scale=1.0)
+ data = dict()
+ for each_client in range(1, client_num + 1):
+ data[each_client] = dict()
+ client_x = np.random.normal(loc=0.0,
+ scale=0.5 * each_client,
+ size=(instance_num, feature_num))
+ client_y = np.sum(client_x * weights, axis=-1) + bias
+ client_y = np.expand_dims(client_y, -1)
+ client_data = {'x': client_x, 'y': client_y}
+ data[each_client]['train'] = client_data
+
+ # test data
+ test_x = np.random.normal(loc=0.0,
+ scale=1.0,
+ size=(instance_num, feature_num))
+ test_y = np.sum(test_x * weights, axis=-1) + bias
+ test_y = np.expand_dims(test_y, -1)
+ test_data = {'x': test_x, 'y': test_y}
+ for each_client in range(1, client_num + 1):
+ data[each_client]['test'] = test_data
+
+ # val data
+ val_x = np.random.normal(loc=0.0,
+ scale=1.0,
+ size=(instance_num, feature_num))
+ val_y = np.sum(val_x * weights, axis=-1) + bias
+ val_y = np.expand_dims(val_y, -1)
+ val_data = {'x': val_x, 'y': val_y}
+ for each_client in range(1, client_num + 1):
+ data[each_client]['val'] = val_data
+
+ # server_data
+ data[0] = dict()
+ data[0]['train'] = None
+ data[0]['val'] = val_data
+ data[0]['test'] = test_data
+
+ if save_data:
+ # server_data = dict()
+ save_client_data = dict()
+
+ for client_idx in range(0, client_num + 1):
+ if client_idx == 0:
+ filename = 'data/server_data'
+ else:
+ filename = 'data/client_{:d}_data'.format(client_idx)
+ with open(filename, 'wb') as f:
+ save_client_data['train'] = {
+ k: v.tolist()
+ for k, v in data[client_idx]['train'].items()
+ }
+ save_client_data['val'] = {
+ k: v.tolist()
+ for k, v in data[client_idx]['val'].items()
+ }
+ save_client_data['test'] = {
+ k: v.tolist()
+ for k, v in data[client_idx]['test'].items()
+ }
+ pickle.dump(save_client_data, f)
+
+ return data
+
+ if generate:
+ data = _generate_data(client_num=config.federate.client_num,
+ save_data=config.eval.save_data)
+ else:
+ with open(config.distribute.data_file, 'rb') as f:
+ data = pickle.load(f)
+ for key in data.keys():
+ data[key] = {k: np.asarray(v)
+ for k, v in data[key].items()
+ } if data[key] is not None else None
+
+ return data, config
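+
+# The toy task above is a shared linear model y = <w, x> + b: statistical
+# heterogeneity comes only from the per-client input scale (0.5 * client_id),
+# while the test and val sets are identical across all clients.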
+
+
+def load_external_data(config=None):
+ r""" Based on the configuration file, this function imports external datasets and applies train/valid/test splits
+ and split by some specific `splitter` into the standard FederatedScope input data format.
+
+ Args:
+ config: `CN` from `federatedscope/core/configs/config.py`
+
+ Returns:
+ data_local_dict: dict of split dataloader.
+ Format:
+ {
+ 'client_id': {
+ 'train': DataLoader(),
+ 'test': DataLoader(),
+ 'val': DataLoader()
+ }
+ }
+ modified_config: `CN` from `federatedscope/core/configs/config.py`, which might be modified in the function.
+
+ """
+
+ import torch
+ import inspect
+ from importlib import import_module
+ from torch.utils.data import DataLoader
+ from federatedscope.core.auxiliaries.splitter_builder import get_splitter
+ from federatedscope.core.auxiliaries.transform_builder import get_transform
+
+ def get_func_args(func):
+ sign = inspect.signature(func).parameters.values()
+ sign = set([val.name for val in sign])
+ return sign
+
+ def filter_dict(func, kwarg):
+ sign = get_func_args(func)
+ common_args = sign.intersection(kwarg.keys())
+ filtered_dict = {key: kwarg[key] for key in common_args}
+ return filtered_dict
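+
+    # E.g., with torchvision's CIFAR10, filter_dict(CIFAR10.__init__,
+    # {'download': True, 'max_len': 10}) keeps only {'download': True},
+    # since CIFAR10.__init__ takes no `max_len`; this lets a single
+    # `data.args` dict serve several dataset constructors.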
+
+ def load_torchvision_data(name, splits=None, config=None):
+ dataset_func = getattr(import_module('torchvision.datasets'), name)
+ transform_funcs = get_transform(config, 'torchvision')
+ if config.data.args:
+ raw_args = config.data.args[0]
+ else:
+ raw_args = {}
+ if 'download' not in raw_args.keys():
+ raw_args.update({'download': True})
+ filtered_args = filter_dict(dataset_func.__init__, raw_args)
+ func_args = get_func_args(dataset_func.__init__)
+ # Perform split on different dataset
+ if 'train' in func_args:
+ # Split train to (train, val)
+ dataset_train = dataset_func(root=config.data.root,
+ train=True,
+ **filtered_args,
+ **transform_funcs)
+ dataset_val = None
+ dataset_test = dataset_func(root=config.data.root,
+ train=False,
+ **filtered_args,
+ **transform_funcs)
+ if splits:
+ train_size = int(splits[0] * len(dataset_train))
+ val_size = len(dataset_train) - train_size
+ lengths = [train_size, val_size]
+ dataset_train, dataset_val = torch.utils.data.dataset.random_split(
+ dataset_train, lengths)
+
+ elif 'split' in func_args:
+ # Use raw split
+ dataset_train = dataset_func(root=config.data.root,
+ split='train',
+ **filtered_args,
+ **transform_funcs)
+ dataset_val = dataset_func(root=config.data.root,
+ split='valid',
+ **filtered_args,
+ **transform_funcs)
+ dataset_test = dataset_func(root=config.data.root,
+ split='test',
+ **filtered_args,
+ **transform_funcs)
+ elif 'classes' in func_args:
+ # Use raw split
+ dataset_train = dataset_func(root=config.data.root,
+ classes='train',
+ **filtered_args,
+ **transform_funcs)
+ dataset_val = dataset_func(root=config.data.root,
+ classes='valid',
+ **filtered_args,
+ **transform_funcs)
+ dataset_test = dataset_func(root=config.data.root,
+ classes='test',
+ **filtered_args,
+ **transform_funcs)
+ else:
+ # Use config.data.splits
+ dataset = dataset_func(root=config.data.root,
+ **filtered_args,
+ **transform_funcs)
+ train_size = int(splits[0] * len(dataset))
+ val_size = int(splits[1] * len(dataset))
+ test_size = len(dataset) - train_size - val_size
+ lengths = [train_size, val_size, test_size]
+ dataset_train, dataset_val, dataset_test = torch.utils.data.dataset.random_split(
+ dataset, lengths)
+
+ data_dict = {
+ 'train': dataset_train,
+ 'val': dataset_val,
+ 'test': dataset_test
+ }
+
+ return data_dict
+
+ def load_torchtext_data(name, splits=None, config=None):
+ from torch.nn.utils.rnn import pad_sequence
+ from federatedscope.nlp.dataset.utils import label_to_index
+
+ dataset_func = getattr(import_module('torchtext.datasets'), name)
+ if config.data.args:
+ raw_args = config.data.args[0]
+ else:
+ raw_args = {}
+        assert 'max_len' in raw_args, \
+            "Missing key 'max_len' in `config.data.args`."
+ filtered_args = filter_dict(dataset_func.__init__, raw_args)
+ dataset = dataset_func(root=config.data.root, **filtered_args)
+
+        # torchtext.transforms requires torchtext >= 0.12.0 and torch 1.11.0,
+        # so we do not use `get_transform` in torchtext.
+
+ # Merge all data and tokenize
+ x_list = []
+ y_list = []
+ for data_iter in dataset:
+ data, targets = [], []
+ for i, item in enumerate(data_iter):
+ data.append(item[1])
+ targets.append(item[0])
+ x_list.append(data)
+ y_list.append(targets)
+
+ x_all, y_all = [], []
+ for i in range(len(x_list)):
+ x_all += x_list[i]
+ y_all += y_list[i]
+
+ if config.model.type.endswith('transformers'):
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ config.model.type.split('@')[0])
+
+ x_all = tokenizer(x_all,
+ return_tensors='pt',
+ padding=True,
+ truncation=True,
+ max_length=raw_args['max_len'])
+ data = [{key: value[i]
+ for key, value in x_all.items()}
+ for i in range(len(next(iter(x_all.values()))))]
+ if 'classification' in config.model.task.lower():
+ targets = label_to_index(y_all)
+ else:
+ y_all = tokenizer(y_all,
+ return_tensors='pt',
+ padding=True,
+ truncation=True,
+ max_length=raw_args['max_len'])
+ targets = [{key: value[i]
+ for key, value in y_all.items()}
+ for i in range(len(next(iter(y_all.values()))))]
+ else:
+ from torchtext.data import get_tokenizer
+ tokenizer = get_tokenizer("basic_english")
+ if len(config.data.transform) == 0:
+                raise ValueError(
+                    "`transform` must be one of the pretrained word "
+                    "embeddings in ['GloVe', 'FastText', 'CharNGram']")
+ if len(config.data.transform) == 1:
+ config.data.transform.append({})
+ vocab = getattr(import_module('torchtext.vocab'),
+ config.data.transform[0])(
+ dim=config.model.in_channels,
+ **config.data.transform[1])
+
+ if 'classification' in config.model.task.lower():
+ data = [
+ vocab.get_vecs_by_tokens(tokenizer(x),
+ lower_case_backup=True)
+ for x in x_all
+ ]
+ targets = label_to_index(y_all)
+ else:
+ data = [
+ vocab.get_vecs_by_tokens(tokenizer(x),
+ lower_case_backup=True)
+ for x in x_all
+ ]
+ targets = [
+ vocab.get_vecs_by_tokens(tokenizer(y),
+ lower_case_backup=True)
+ for y in y_all
+ ]
+ targets = pad_sequence(targets).transpose(
+ 0, 1)[:, :raw_args['max_len'], :]
+ data = pad_sequence(data).transpose(0,
+ 1)[:, :raw_args['max_len'], :]
+ # Split data to raw
+ num_items = [len(ds) for ds in x_list]
+ data_list, cnt = [], 0
+ for num in num_items:
+ data_list.append([
+ (x, y)
+ for x, y in zip(data[cnt:cnt + num], targets[cnt:cnt + num])
+ ])
+ cnt += num
+
+ if len(data_list) == 3:
+ # Use raw splits
+ data_dict = {
+ 'train': data_list[0],
+ 'val': data_list[1],
+ 'test': data_list[2]
+ }
+ elif len(data_list) == 2:
+ # Split train to (train, val)
+ data_dict = {
+ 'train': data_list[0],
+ 'val': None,
+ 'test': data_list[1]
+ }
+ if splits:
+ train_size = int(splits[0] * len(data_dict['train']))
+ val_size = len(data_dict['train']) - train_size
+ lengths = [train_size, val_size]
+ data_dict['train'], data_dict[
+ 'val'] = torch.utils.data.dataset.random_split(
+ data_dict['train'], lengths)
+ else:
+ # Use config.data.splits
+ data_dict = {}
+ train_size = int(splits[0] * len(data_list[0]))
+ val_size = int(splits[1] * len(data_list[0]))
+ test_size = len(data_list[0]) - train_size - val_size
+ lengths = [train_size, val_size, test_size]
+ data_dict['train'], data_dict['val'], data_dict[
+ 'test'] = torch.utils.data.dataset.random_split(
+ data_list[0], lengths)
+
+ return data_dict
+
+ def load_torchaudio_data(name, splits=None, config=None):
+ import torchaudio
+
+ dataset_func = getattr(import_module('torchaudio.datasets'), name)
+ raise NotImplementedError
+
+ def load_torch_geometric_data(name, splits=None, config=None):
+ import torch_geometric
+
+ dataset_func = getattr(import_module('torch_geometric.datasets'), name)
+ raise NotImplementedError
+
+ def load_huggingface_datasets_data(name, splits=None, config=None):
+ from datasets import load_dataset
+
+ if config.data.args:
+ raw_args = config.data.args[0]
+ else:
+ raw_args = {}
+        assert 'max_len' in raw_args, \
+            "Missing key 'max_len' in `config.data.args`."
+ filtered_args = filter_dict(load_dataset, raw_args)
+ dataset = load_dataset(path=config.data.root,
+ name=name,
+ **filtered_args)
+ if config.model.type.endswith('transformers'):
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ config.model.type.split('@')[0])
+
+ for split in dataset:
+ x_all = [i['sentence'] for i in dataset[split]]
+ targets = [i['label'] for i in dataset[split]]
+
+ x_all = tokenizer(x_all,
+ return_tensors='pt',
+ padding=True,
+ truncation=True,
+ max_length=raw_args['max_len'])
+ data = [{key: value[i]
+ for key, value in x_all.items()}
+ for i in range(len(next(iter(x_all.values()))))]
+ dataset[split] = (data, targets)
+ data_dict = {
+ 'train': [(x, y)
+ for x, y in zip(dataset['train'][0], dataset['train'][1])
+ ],
+ 'val': [(x, y) for x, y in zip(dataset['validation'][0],
+ dataset['validation'][1])],
+ 'test': [
+ (x, y) for x, y in zip(dataset['test'][0], dataset['test'][1])
+ ] if (set(dataset['test'][1]) - set([-1])) else None,
+ }
+ return data_dict
+
+ def load_openml_data(tid, splits=None, config=None):
+ import openml
+ from sklearn.model_selection import train_test_split
+
+ task = openml.tasks.get_task(int(tid))
+ did = task.dataset_id
+ dataset = openml.datasets.get_dataset(did)
+ data, targets, _, _ = dataset.get_data(
+ dataset_format="array", target=dataset.default_target_attribute)
+
+ train_data, test_data, train_targets, test_targets = train_test_split(
+ data, targets, train_size=splits[0], random_state=config.seed)
+ val_data, test_data, val_targets, test_targets = train_test_split(
+ test_data,
+ test_targets,
+ train_size=splits[1] / (1. - splits[0]),
+ random_state=config.seed)
+ data_dict = {
+ 'train': [(x, y) for x, y in zip(train_data, train_targets)],
+ 'val': [(x, y) for x, y in zip(val_data, val_targets)],
+ 'test': [(x, y) for x, y in zip(test_data, test_targets)]
+ }
+ return data_dict
+
+ DATA_LOAD_FUNCS = {
+ 'torchvision': load_torchvision_data,
+ 'torchtext': load_torchtext_data,
+ 'torchaudio': load_torchaudio_data,
+ 'torch_geometric': load_torch_geometric_data,
+ 'huggingface_datasets': load_huggingface_datasets_data,
+ 'openml': load_openml_data
+ }
+
+ modified_config = config.clone()
+
+ # Load dataset
+ splits = modified_config.data.splits
+ name, package = modified_config.data.type.split('@')
+
+ dataset = DATA_LOAD_FUNCS[package.lower()](name, splits, modified_config)
+ splitter = get_splitter(modified_config)
+
+ data_local_dict = {
+ x: {}
+ for x in range(1, modified_config.federate.client_num + 1)
+ }
+
+ # Build dict of Dataloader
+ train_label_distribution = None
+ for split in dataset:
+ if dataset[split] is None or dataset[split].__len__() == 0:
+ continue
+ train_labels = list()
+ for i, ds in enumerate(
+ splitter(dataset[split], prior=train_label_distribution)):
+ labels = [x[1] for x in ds]
+ if split == 'train':
+ train_labels.append(labels)
+ data_local_dict[i + 1][split] = DataLoader(
+ ds,
+ batch_size=modified_config.data.batch_size,
+ shuffle=True,
+ num_workers=modified_config.data.num_workers)
+ else:
+ data_local_dict[i + 1][split] = DataLoader(
+ ds,
+ batch_size=modified_config.data.batch_size,
+ shuffle=False,
+ num_workers=modified_config.data.num_workers)
+
+ if modified_config.data.consistent_label_distribution and len(
+ train_labels) > 0:
+ train_label_distribution = train_labels
+
+ return data_local_dict, modified_config
+
+
+def get_data(config):
+ """Instantiate the dataset and update the configuration accordingly if necessary.
+ Arguments:
+ config (obj): a cfg node object.
+ Returns:
+ obj: The dataset object.
+ cfg.node: The updated configuration.
+ """
+ for func in register.data_dict.values():
+ data_and_config = func(config)
+ if data_and_config is not None:
+ return data_and_config
+ if config.data.type.lower() == 'toy':
+ data, modified_config = load_toy_data(config)
+ elif config.data.type.lower() == 'quadratic':
+ from federatedscope.tabular.dataloader import load_quadratic_dataset
+ data, modified_config = load_quadratic_dataset(config)
+ elif config.data.type.lower() in ['femnist', 'celeba']:
+ from federatedscope.cv.dataloader import load_cv_dataset
+ data, modified_config = load_cv_dataset(config)
+ elif config.data.type.lower() in [
+ 'shakespeare', 'twitter', 'subreddit', 'synthetic'
+ ]:
+ from federatedscope.nlp.dataloader import load_nlp_dataset
+ data, modified_config = load_nlp_dataset(config)
+ elif config.data.type.lower() in [
+ 'cora',
+ 'citeseer',
+ 'pubmed',
+ 'dblp_conf',
+ 'dblp_org',
+ ] or config.data.type.lower().startswith('csbm'):
+ from federatedscope.gfl.dataloader import load_nodelevel_dataset
+ data, modified_config = load_nodelevel_dataset(config)
+ elif config.data.type.lower() in ['ciao', 'epinions', 'fb15k-237', 'wn18']:
+ from federatedscope.gfl.dataloader import load_linklevel_dataset
+ data, modified_config = load_linklevel_dataset(config)
+ elif config.data.type.lower() in [
+ 'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace', 'sider',
+ 'clintox', 'esol', 'freesolv', 'lipo'
+ ] or config.data.type.startswith('graph_multi_domain'):
+ from federatedscope.gfl.dataloader import load_graphlevel_dataset
+ data, modified_config = load_graphlevel_dataset(config)
+ elif config.data.type.lower() == 'vertical_fl_data':
+ from federatedscope.vertical_fl.dataloader import load_vertical_data
+ data, modified_config = load_vertical_data(config, generate=True)
+ elif 'movielens' in config.data.type.lower():
+ from federatedscope.mf.dataloader import load_mf_dataset
+ data, modified_config = load_mf_dataset(config)
+ elif '@' in config.data.type.lower():
+ data, modified_config = load_external_data(config)
+ else:
+ raise ValueError('Data {} not found.'.format(config.data.type))
+
+ if config.data.do_sta:
+ do_data_statistics(config, data)
+
+ if 'backdoor' in config.attack.attack_method:
+ from federatedscope.attack.auxiliary import poisoning
+ poisoning(data, modified_config)
+
+ return data, modified_config
+
+
+def merge_data(all_data):
+ dataset_names = list(all_data[1].keys()) # e.g., train, test, val
+ assert isinstance(all_data[1]["test"], dict), \
+ "the data should be organized as the format similar to {data_id: {train: {x:ndarray, y:ndarray}} }"
+ data_elem_names = list(all_data[1]["test"].keys()) # e.g., x, y
+ merged_data = {name: defaultdict(list) for name in dataset_names}
+ for data_id in all_data.keys():
+ if data_id == 0:
+ continue
+ for d_name in dataset_names:
+ for elem_name in data_elem_names:
+ merged_data[d_name][elem_name].append(
+ all_data[data_id][d_name][elem_name])
+
+ for d_name in dataset_names:
+ for elem_name in data_elem_names:
+ merged_data[d_name][elem_name] = np.concatenate(
+ merged_data[d_name][elem_name])
+
+ return merged_data
+
+
+def do_data_statistics(config, data):
+ data_num_all_client = defaultdict(list)
+ label_dist_all_client = dict() # {client: client_dist}
+
+ logger.info(
+ f"For data={config.data.type} with subsample={config.data.subsample},"
+ f" the client_num is {len(data)}")
+ for client_id, ds_ci in data.items():
+ if client_id == 0:
+ # skip the data holds on server
+ continue
+ if config.data.probe_label_dist:
+ label_dist_all_client[client_id] = \
+ [0 for _ in range(config.model.out_channels)]
+ if isinstance(ds_ci, dict):
+ for split_name, ds in ds_ci.items():
+ try:
+ import torch
+ from federatedscope.mf.dataloader import MFDataLoader
+ if isinstance(
+ ds, (torch.utils.data.Dataset, list)) or \
+ issubclass(type(ds), torch.utils.data.Dataset):
+ data_num_all_client[split_name].append(len(ds))
+ if config.data.probe_label_dist:
+ for i in range(len(ds)):
+                            label = ds[i][1]  # each item is an (x, y) pair
+ label_dist_all_client[client_id][label] += 1
+ elif isinstance(
+ ds, (torch.utils.data.DataLoader, list)) or \
+ issubclass(type(ds), torch.utils.data.DataLoader):
+ data_num_all_client[split_name].append(len(ds.dataset))
+ if config.data.labelwise_boxplot:
+ from collections import Counter
+ all_labels = [
+ ds.dataset[i][1]
+ for i in range(len(ds.dataset))
+ ]
+ label_wise_cnt = Counter(all_labels)
+ for label, cnt in label_wise_cnt.items():
+ data_num_all_client[label].append(cnt)
+ if config.data.probe_label_dist:
+ for i in range(len(ds)):
+ label = ds.dataset[i][1]
+ label_dist_all_client[client_id][label] += 1
+ elif issubclass(type(ds), MFDataLoader):
+ data_num_all_client[split_name].append(ds.n_rating)
+            except Exception:
+ if isinstance(ds, list):
+ data_num_all_client[split_name].append(len(ds))
+ if config.data.type in ["cora", "citeseer", "pubmed"]:
+ # node-wise classification
+ from torch_geometric.data.data import Data
+ import torch
+ if isinstance(ds_ci, Data):
+ for split_name in ["train_mask", "val_mask", "test_mask"]:
+ num_nodes = sum(ds_ci[split_name]).item()
+ data_num_all_client[split_name.split("_")[0]].append(
+ num_nodes)
+ if config.data.plot_boxplot:
+ plot_data_statistics(config, data_num_all_client)
+ if config.data.probe_label_dist:
+ prob_label_dist(config, label_dist_all_client)
+ import random
+ unseen_clients_ids = random.choices(list(range(1, 50)), k=10)
+ prob_label_dist(config, label_dist_all_client, unseen_clients_ids)
+
+ from scipy import stats
+ all_split_merged_num = []
+ for k, v in data_num_all_client.items():
+ if all_split_merged_num == []:
+ all_split_merged_num.extend(v)
+ else:
+ all_split_merged_num = [
+ all_split_merged_num[i] + v[i] for i in range(len(v))
+ ]
+ data_num_all_client["all"] = all_split_merged_num
+ for k, v in data_num_all_client.items():
+ if len(v) == 0:
+            logger.warning(
+                "The data distribution statistics are not correctly "
+                "logged; maybe you used a data type we don't support yet")
+ else:
+ stats_res = stats.describe(v)
+ if stats_res.minmax[1] == 0:
+                logger.warning(
+                    f"For data split {k}, the max sample num over clients "
+                    f"is 0. Please check whether this is intended.")
+            logger.info(
+                f"For data split {k}, the stats_res over all clients is "
+                f"{stats_res}, the median is {sorted(v)[len(v) // 2]}, "
+                f"std is {math.sqrt(stats_res.variance)}")
+
+
+def prob_label_dist(config, label_dist_all_client, should_contain_ids=None):
+ pairwise_distance = dict()
+ from scipy.spatial.distance import jensenshannon
+ from scipy import stats
+ # normalize
+ for k, v in label_dist_all_client.items():
+ total = sum(v)
+ if total != 1:
+ label_dist_all_client[k] = [x / total for x in v]
+ # calculate the client-wise J-S distance
+ for i in range(1, len(list(label_dist_all_client.keys()))):
+ for j in range(1, i):
+ if should_contain_ids is not None and \
+ (i not in should_contain_ids or
+ j not in should_contain_ids):
+ continue
+ pairwise_distance[(i, j)] = jensenshannon(label_dist_all_client[i],
+ label_dist_all_client[j])
+ stats_res = stats.describe(list(pairwise_distance.values()))
+    logger.info(
+        f"The distribution of pair-wise JS-distance over all clients is "
+        f"{stats_res}")
+
+ import matplotlib.pyplot as plt
+ import matplotlib.pylab as pylab
+ plt.clf()
+ label_size = 18.5
+ ticks_size = 17
+ title_size = 22.5
+ legend_size = 17
+ params = {
+ 'legend.fontsize': legend_size,
+ 'axes.labelsize': label_size,
+ 'axes.titlesize': title_size,
+ 'xtick.labelsize': ticks_size,
+ 'ytick.labelsize': ticks_size
+ }
+ pylab.rcParams.update(params)
+ ax = plt.subplot()
+ plt.hist(list(pairwise_distance.values()))
+ ax.set_xlabel("Client-wise JS distance")
+ ax.set_ylabel("Count")
+ ax.set_xlim(0, 1)
+ # ax.set_ylim(0, 1500)
+ fig_name = f"{config.outdir}/visual_{config.data.type}_js_distance.pdf"
+ plt.savefig(fig_name, bbox_inches='tight', pad_inches=0)
+ plt.show()
+
+
+def plot_data_statistics(config, data_num_all_client):
+ index = []
+ data_num_list = []
+ for key, val in data_num_all_client.items():
+ if config.data.labelwise_boxplot and key in ["train", "test", "val"]:
+ continue
+ index.append(key)
+ data_num_list.append(val)
+ if len(index) > 3 and index[1] == "test" and index[2] == "val":
+ index[1], index[2] = index[2], index[1]
+ data_num_list[1], data_num_list[2] = data_num_list[2], data_num_list[1]
+ import matplotlib.pyplot as plt
+ import matplotlib.pylab as pylab
+ plt.clf()
+ label_size = 18.5
+ ticks_size = 17
+ title_size = 22.5
+ legend_size = 17
+ params = {
+ 'legend.fontsize': legend_size,
+ 'axes.labelsize': label_size,
+ 'axes.titlesize': title_size,
+ 'xtick.labelsize': ticks_size,
+ 'ytick.labelsize': ticks_size
+ }
+ if config.data.labelwise_boxplot:
+ index_order = np.argsort(np.array(index))
+ index = [index[i] for i in index_order]
+ data_num_list = [data_num_list[i] for i in index_order]
+ from scipy import stats
+ for i in index_order:
+ stats_res = stats.describe(data_num_list[i])
+ logger.info(f"The distribution label {index[i]} is {stats_res}")
+ pylab.rcParams.update(params)
+ ax = plt.subplot()
+ ax.violinplot(data_num_list)
+ ax.set_xticks(range(1, len(index) + 1))
+ ax.set_xticklabels(index)
+ ax.set_ylabel("#Samples Per Client")
+ fig_name = f"{config.outdir}/visual_{config.data.type}.pdf"
+ if config.data.labelwise_boxplot:
+ fig_name = f"{config.outdir}/visual_{config.data.type}_label.pdf"
+ plt.savefig(fig_name, bbox_inches='tight', pad_inches=0)
+ plt.show()
diff --git a/federatedscope/core/auxiliaries/dataloader_builder.py b/federatedscope/core/auxiliaries/dataloader_builder.py
new file mode 100644
index 000000000..858ba1a03
--- /dev/null
+++ b/federatedscope/core/auxiliaries/dataloader_builder.py
@@ -0,0 +1,41 @@
+try:
+ import torch
+ from torch.utils.data import Dataset
+except ImportError:
+ torch = None
+ Dataset = object
+
+
+def get_dataloader(dataset, config):
+ if config.backend == 'torch':
+ from torch.utils.data import DataLoader
+ dataloader = DataLoader(dataset,
+ batch_size=config.data.batch_size,
+ shuffle=config.data.shuffle,
+ num_workers=config.data.num_workers,
+ pin_memory=True)
+ return dataloader
+ else:
+ return None
+
+
+class WrapDataset(Dataset):
+ """Wrap raw data into pytorch Dataset
+
+ Arguments:
+ data (dict): raw data dictionary contains "x" and "y"
+
+ """
+ def __init__(self, data):
+ super(WrapDataset, self).__init__()
+ self.data = data
+
+ def __getitem__(self, idx):
+ if not isinstance(self.data["x"][idx], torch.Tensor):
+ return torch.from_numpy(
+ self.data["x"][idx]).float(), torch.from_numpy(
+ self.data["y"][idx]).float()
+ return self.data["x"][idx], self.data["y"][idx]
+
+ def __len__(self):
+ return len(self.data["y"])
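+
+
+# A minimal usage sketch (assuming a torch backend; note the indexed
+# elements must themselves be numpy arrays for the `from_numpy` conversion):
+# >>> import numpy as np
+# >>> raw = {'x': np.random.rand(8, 4), 'y': np.random.rand(8, 1)}
+# >>> dataset = WrapDataset(raw)
+# >>> len(dataset)
+# 8
+# >>> x0, y0 = dataset[0]  # both returned as float tensors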
diff --git a/federatedscope/core/auxiliaries/metric_builder.py b/federatedscope/core/auxiliaries/metric_builder.py
new file mode 100644
index 000000000..59d0bbd2d
--- /dev/null
+++ b/federatedscope/core/auxiliaries/metric_builder.py
@@ -0,0 +1,21 @@
+import logging
+import federatedscope.register as register
+
+logger = logging.getLogger(__name__)
+
+try:
+ from federatedscope.contrib.metrics import *
+except ImportError as error:
+ logger.warning(
+ f'{error} in `federatedscope.contrib.metrics`, some modules are not available.'
+ )
+
+
+def get_metric(types):
+ metrics = dict()
+ for func in register.metric_dict.values():
+ res = func(types)
+ if res is not None:
+ name, metric = res
+ metrics[name] = metric
+ return metrics
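+
+
+# Resolution sketch: each registered hook receives the requested metric names
+# and returns a (name, metric_fn) pair when it recognizes one. `my_acc` below
+# is a hypothetical metric function, and `register.metric_dict` is assumed to
+# be a plain dict populated via the project's register helpers:
+# >>> def call_my_metric(types):
+# ...     if 'my_acc' in types:
+# ...         return 'my_acc', my_acc
+# >>> register.metric_dict['my_acc'] = call_my_metric
+# >>> get_metric(['my_acc'])  # -> {'my_acc': my_acc}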
diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py
new file mode 100644
index 000000000..ed0449951
--- /dev/null
+++ b/federatedscope/core/auxiliaries/model_builder.py
@@ -0,0 +1,120 @@
+import logging
+import federatedscope.register as register
+
+logger = logging.getLogger(__name__)
+
+try:
+ from federatedscope.contrib.model import *
+except ImportError as error:
+ logger.warning(
+ f'{error} in `federatedscope.contrib.model`, some modules are not available.'
+ )
+
+
+def get_model(model_config, local_data, backend='torch'):
+ """
+ Arguments:
+ local_data (object): the data that the instantiated model will serve.
+ Returns:
+ model (torch.nn.Module): the instantiated model.
+ """
+ for func in register.model_dict.values():
+ model = func(model_config, local_data)
+ if model is not None:
+ return model
+
+ if model_config.type.lower() == 'lr':
+ if backend == 'torch':
+ from federatedscope.core.lr import LogisticRegression
+ # TODO: make the instantiation more general
+ if isinstance(
+ local_data, dict
+ ) and 'test' in local_data and 'x' in local_data['test']:
+ model = LogisticRegression(
+ in_channels=local_data['test']['x'].shape[-1],
+ class_num=1,
+ use_bias=model_config.use_bias)
+ else:
+ if isinstance(local_data, dict):
+ if 'data' in local_data.keys():
+ data = local_data['data']
+ elif 'train' in local_data.keys():
+ # local_data['train'] is Dataloader
+ data = next(iter(local_data['train']))
+ else:
+ raise TypeError('Unsupported data type.')
+ else:
+ data = local_data
+
+ x, _ = data
+ model = LogisticRegression(in_channels=x.shape[-1],
+ class_num=model_config.out_channels)
+ elif backend == 'tensorflow':
+ from federatedscope.cross_backends import LogisticRegression
+ model = LogisticRegression(
+ in_channels=local_data['test']['x'].shape[-1],
+ class_num=1,
+ use_bias=model_config.use_bias)
+ else:
+ raise ValueError
+
+ elif model_config.type.lower() == 'mlp':
+ from federatedscope.core.mlp import MLP
+ if isinstance(local_data, dict):
+ if 'data' in local_data.keys():
+ data = local_data['data']
+ elif 'train' in local_data.keys():
+ # local_data['train'] is Dataloader
+ data = next(iter(local_data['train']))
+ else:
+ raise TypeError('Unsupported data type.')
+ else:
+ data = local_data
+
+ x, _ = data
+ model = MLP(channel_list=[x.shape[-1]] + [model_config.hidden] *
+ (model_config.layer - 1) + [model_config.out_channels],
+ dropout=model_config.dropout)
+
+ elif model_config.type.lower() == 'quadratic':
+ from federatedscope.tabular.model import QuadraticModel
+ if isinstance(local_data, dict):
+ data = next(iter(local_data['train']))
+ else:
+ # TODO: complete the branch
+ data = local_data
+ x, _ = data
+ model = QuadraticModel(x.shape[-1], 1)
+
+ # note: 'lr' is already handled by the first branch above
+ elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11']:
+ from federatedscope.cv.model import get_cnn
+ model = get_cnn(model_config, local_data)
+ elif model_config.type.lower() in ['lstm']:
+ from federatedscope.nlp.model import get_rnn
+ model = get_rnn(model_config, local_data)
+ elif model_config.type.lower().endswith('transformers'):
+ from federatedscope.nlp.model import get_transformer
+ model = get_transformer(model_config, local_data)
+ elif model_config.type.lower() in [
+ 'gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn'
+ ]:
+ from federatedscope.gfl.model import get_gnn
+ model = get_gnn(model_config, local_data)
+ elif model_config.type.lower() in ['vmfnet', 'hmfnet']:
+ from federatedscope.mf.model.model_builder import get_mfnet
+ model = get_mfnet(model_config, local_data)
+ else:
+ raise ValueError('Model {} is not provided'.format(model_config.type))
+
+ return model
+
+
+def get_trainable_para_names(model):
+ return set(dict(list(model.named_parameters())).keys())
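+
+
+# For instance, a bare `torch.nn.Linear` exposes two trainable parameters:
+# >>> import torch
+# >>> get_trainable_para_names(torch.nn.Linear(4, 2))
+# {'weight', 'bias'}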
diff --git a/federatedscope/core/auxiliaries/optimizer_builder.py b/federatedscope/core/auxiliaries/optimizer_builder.py
new file mode 100644
index 000000000..6083dbaee
--- /dev/null
+++ b/federatedscope/core/auxiliaries/optimizer_builder.py
@@ -0,0 +1,21 @@
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
+def get_optimizer(model, type, lr, **kwargs):
+ if torch is None:
+ return None
+ if isinstance(type, str):
+ if hasattr(torch.optim, type):
+ if isinstance(model, torch.nn.Module):
+ return getattr(torch.optim, type)(model.parameters(), lr,
+ **kwargs)
+ else:
+ return getattr(torch.optim, type)(model, lr, **kwargs)
+ else:
+ raise NotImplementedError(
+ 'Optimizer {} not implement'.format(type))
+ else:
+ raise TypeError()
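+
+
+# Usage sketch: any optimizer name under `torch.optim` resolves directly,
+# with extra keyword arguments forwarded unchanged:
+# >>> import torch
+# >>> model = torch.nn.Linear(4, 2)
+# >>> opt = get_optimizer(model, 'SGD', lr=0.1, momentum=0.9)
+# >>> type(opt).__name__
+# 'SGD'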
diff --git a/federatedscope/core/auxiliaries/regularizer_builder.py b/federatedscope/core/auxiliaries/regularizer_builder.py
new file mode 100644
index 000000000..75af98cf9
--- /dev/null
+++ b/federatedscope/core/auxiliaries/regularizer_builder.py
@@ -0,0 +1,30 @@
+from federatedscope.register import regularizer_dict
+from federatedscope.core.regularizer.proximal_regularizer import *
+try:
+ from torch.nn import Module
+except ImportError:
+ Module = object
+
+
+def get_regularizer(type):
+ if type is None or type == '':
+ return DummyRegularizer()
+
+ for func in regularizer_dict.values():
+ regularizer = func(type)
+ if regularizer is not None:
+ return regularizer()
+
+ raise NotImplementedError(
+ "Regularizer {} is not implemented.".format(type))
+
+
+class DummyRegularizer(Module):
+ """Dummy regularizer that only returns zero.
+
+ """
+ def __init__(self):
+ super(DummyRegularizer, self).__init__()
+
+ def forward(self, ctx):
+ return 0.
diff --git a/federatedscope/core/auxiliaries/splitter_builder.py b/federatedscope/core/auxiliaries/splitter_builder.py
new file mode 100644
index 000000000..40555165f
--- /dev/null
+++ b/federatedscope/core/auxiliaries/splitter_builder.py
@@ -0,0 +1,48 @@
+import logging
+import federatedscope.register as register
+
+logger = logging.getLogger(__name__)
+
+
+def get_splitter(config):
+ client_num = config.federate.client_num
+ if config.data.splitter_args:
+ args = config.data.splitter_args[0]
+ else:
+ args = {}
+
+ for func in register.splitter_dict.values():
+ splitter = func(config)
+ if splitter is not None:
+ return splitter
+ # Delay import
+ # generic splitter
+ if config.data.splitter == 'lda':
+ from federatedscope.core.splitters.generic import LDASplitter
+ splitter = LDASplitter(client_num, **args)
+ # graph splitter
+ elif config.data.splitter == 'louvain':
+ from federatedscope.core.splitters.graph import LouvainSplitter
+ splitter = LouvainSplitter(client_num, **args)
+ elif config.data.splitter == 'random':
+ from federatedscope.core.splitters.graph import RandomSplitter
+ splitter = RandomSplitter(client_num, **args)
+ elif config.data.splitter == 'rel_type':
+ from federatedscope.core.splitters.graph import RelTypeSplitter
+ splitter = RelTypeSplitter(client_num, **args)
+ elif config.data.splitter == 'graph_type':
+ from federatedscope.core.splitters.graph import GraphTypeSplitter
+ splitter = GraphTypeSplitter(client_num, **args)
+ elif config.data.splitter == 'scaffold':
+ from federatedscope.core.splitters.graph import ScaffoldSplitter
+ splitter = ScaffoldSplitter(client_num, **args)
+ elif config.data.splitter == 'scaffold_lda':
+ from federatedscope.core.splitters.graph import ScaffoldLdaSplitter
+ splitter = ScaffoldLdaSplitter(client_num, **args)
+ elif config.data.splitter == 'rand_chunk':
+ from federatedscope.core.splitters.graph import RandChunkSplitter
+ splitter = RandChunkSplitter(client_num, **args)
+ else:
+ logger.warning('Splitter is none or not found.')
+ splitter = None
+ return splitter
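+
+
+# Configuration sketch: selecting the generic LDA splitter with a Dirichlet
+# concentration of 0.5, following the `splitter_args` convention documented
+# in `cfg_data.py`:
+# >>> cfg.data.splitter = 'lda'
+# >>> cfg.data.splitter_args = [{'alpha': 0.5}]
+# >>> splitter = get_splitter(cfg)  # -> LDASplitter(client_num, alpha=0.5)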
diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py
new file mode 100644
index 000000000..83f19fbe9
--- /dev/null
+++ b/federatedscope/core/auxiliaries/trainer_builder.py
@@ -0,0 +1,164 @@
+import logging
+import importlib
+
+import federatedscope.register as register
+
+logger = logging.getLogger(__name__)
+
+try:
+ from federatedscope.contrib.trainer import *
+except ImportError as error:
+ logger.warning(
+ f'{error} in `federatedscope.contrib.trainer`, some modules are not available.'
+ )
+
+TRAINER_CLASS_DICT = {
+ "cvtrainer": "CVTrainer",
+ "nlptrainer": "NLPTrainer",
+ "graphminibatch_trainer": "GraphMiniBatchTrainer",
+ "linkfullbatch_trainer": "LinkFullBatchTrainer",
+ "linkminibatch_trainer": "LinkMiniBatchTrainer",
+ "nodefullbatch_trainer": "NodeFullBatchTrainer",
+ "nodeminibatch_trainer": "NodeMiniBatchTrainer",
+ "flitplustrainer": "FLITPlusTrainer",
+ "flittrainer": "FLITTrainer",
+ "fedvattrainer": "FedVATTrainer",
+ "fedfocaltrainer": "FedFocalTrainer",
+ "mftrainer": "MFTrainer",
+}
+
+
+def get_trainer(model=None,
+ data=None,
+ device=None,
+ config=None,
+ only_for_eval=False,
+ is_attacker=False,
+ monitor=None):
+ if config.trainer.type == 'general':
+ if config.backend == 'torch':
+ from federatedscope.core.trainers import GeneralTorchTrainer
+ trainer = GeneralTorchTrainer(model=model,
+ data=data,
+ device=device,
+ config=config,
+ only_for_eval=only_for_eval,
+ monitor=monitor)
+ elif config.backend == 'tensorflow':
+ from federatedscope.core.trainers import GeneralTFTrainer
+ trainer = GeneralTFTrainer(model=model,
+ data=data,
+ device=device,
+ config=config,
+ only_for_eval=only_for_eval,
+ monitor=monitor)
+ else:
+ raise ValueError
+ elif config.trainer.type == 'none':
+ return None
+ elif config.trainer.type.lower() in TRAINER_CLASS_DICT:
+ if config.trainer.type.lower() in ['cvtrainer']:
+ dict_path = "federatedscope.cv.trainer.trainer"
+ elif config.trainer.type.lower() in ['nlptrainer']:
+ dict_path = "federatedscope.nlp.trainer.trainer"
+ elif config.trainer.type.lower() in [
+ 'graphminibatch_trainer',
+ ]:
+ dict_path = "federatedscope.gfl.trainer.graphtrainer"
+ elif config.trainer.type.lower() in [
+ 'linkfullbatch_trainer', 'linkminibatch_trainer'
+ ]:
+ dict_path = "federatedscope.gfl.trainer.linktrainer"
+ elif config.trainer.type.lower() in [
+ 'nodefullbatch_trainer', 'nodeminibatch_trainer'
+ ]:
+ dict_path = "federatedscope.gfl.trainer.nodetrainer"
+ elif config.trainer.type.lower() in [
+ 'flitplustrainer', 'flittrainer', 'fedvattrainer',
+ 'fedfocaltrainer'
+ ]:
+ dict_path = "federatedscope.gfl.flitplus.trainer"
+ elif config.trainer.type.lower() in ['mftrainer']:
+ dict_path = "federatedscope.mf.trainer.trainer"
+ else:
+ raise ValueError
+
+ trainer_cls = getattr(importlib.import_module(name=dict_path),
+ TRAINER_CLASS_DICT[config.trainer.type.lower()])
+ trainer = trainer_cls(model=model,
+ data=data,
+ device=device,
+ config=config,
+ only_for_eval=only_for_eval,
+ monitor=monitor)
+ else:
+ # try to find user registered trainer
+ trainer = None
+ for func in register.trainer_dict.values():
+ trainer_cls = func(config.trainer.type)
+ if trainer_cls is not None:
+ trainer = trainer_cls(model=model,
+ data=data,
+ device=device,
+ config=config,
+ only_for_eval=only_for_eval,
+ monitor=monitor)
+ if trainer is None:
+ raise ValueError('Trainer {} is not provided'.format(
+ config.trainer.type))
+
+ # differential privacy plug-in
+ if config.nbafl.use:
+ from federatedscope.core.trainers import wrap_nbafl_trainer
+ trainer = wrap_nbafl_trainer(trainer)
+ if config.sgdmf.use:
+ from federatedscope.mf.trainer import wrap_MFTrainer
+ trainer = wrap_MFTrainer(trainer)
+
+ # personalization plug-in
+ if config.federate.method.lower() == "pfedme":
+ from federatedscope.core.trainers import wrap_pFedMeTrainer
+ # wrap style: instance a (class A) -> instance a (class A)
+ trainer = wrap_pFedMeTrainer(trainer)
+ elif config.federate.method.lower() == "ditto":
+ from federatedscope.core.trainers import wrap_DittoTrainer
+ # wrap style: instance a (class A) -> instance a (class A)
+ trainer = wrap_DittoTrainer(trainer)
+ elif config.federate.method.lower() == "fedem":
+ from federatedscope.core.trainers import FedEMTrainer
+ # copy construct style: instance a (class A) -> instance b (class B)
+ trainer = FedEMTrainer(model_nums=config.model.model_num_per_trainer,
+ base_trainer=trainer)
+ elif config.federate.method.lower() == "fedrep":
+ from federatedscope.core.trainers import wrap_FedRepTrainer
+ # wrap style: instance a (class A) -> instance a (class A)
+ trainer = wrap_FedRepTrainer(trainer)
+
+ # attacker plug-in
+ if 'backdoor' in config.attack.attack_method:
+ from federatedscope.attack.trainer import wrap_benignTrainer
+ trainer = wrap_benignTrainer(trainer)
+
+ if is_attacker:
+ if 'backdoor' in config.attack.attack_method:
+ logger.info(
+ '---------------- This client is a backdoor attacker --------------------'
+ )
+ else:
+ logger.info(
+ '---------------- This client is a privacy attacker --------------------'
+ )
+ from federatedscope.attack.auxiliary.attack_trainer_builder import wrap_attacker_trainer
+ trainer = wrap_attacker_trainer(trainer, config)
+
+ elif 'backdoor' in config.attack.attack_method:
+ logger.info(
+ '---------------- This client is a benign client for backdoor attacks --------------------'
+ )
+
+ # fed algorithm plug-in
+ if config.fedprox.use:
+ from federatedscope.core.trainers import wrap_fedprox_trainer
+ trainer = wrap_fedprox_trainer(trainer)
+
+ return trainer
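+
+
+# Resolution sketch: a built-in type is looked up in TRAINER_CLASS_DICT and
+# imported lazily, then every enabled plug-in wraps the instance in place;
+# `model`, `data`, `device` and `monitor` are as elsewhere in the codebase:
+# >>> cfg.trainer.type = 'cvtrainer'
+# >>> cfg.fedprox.use = True
+# >>> trainer = get_trainer(model, data, device, config=cfg, monitor=monitor)
+# # -> a CVTrainer instance, wrapped by wrap_fedprox_trainer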
diff --git a/federatedscope/core/auxiliaries/transform_builder.py b/federatedscope/core/auxiliaries/transform_builder.py
new file mode 100644
index 000000000..d34bfe641
--- /dev/null
+++ b/federatedscope/core/auxiliaries/transform_builder.py
@@ -0,0 +1,54 @@
+from importlib import import_module
+import federatedscope.register as register
+
+
+def get_transform(config, package):
+ r"""
+
+ Args:
+ config: `CN` from `federatedscope/core/configs/config.py`
+ package: one of package from ['torchvision', 'torch_geometric', 'torchtext', 'torchaudio']
+
+ Returns:
+ dict of transform functions.
+
+ """
+ transform_funcs = {}
+ for name in ['transform', 'target_transform', 'pre_transform']:
+ if config.data[name]:
+ transform_funcs[name] = config.data[name]
+
+ # All transforms are None; skip importing the package and return the
+ # empty dict
+ if not transform_funcs:
+ return transform_funcs
+
+ transforms = getattr(import_module(package), 'transforms')
+
+ def convert(trans):
+ # Recursively converting expressions to functions
+ if isinstance(trans[0], str):
+ if len(trans) == 1:
+ trans.append({})
+ transform_type, transform_args = trans
+ for func in register.transform_dict.values():
+ transform_func = func(transform_type, transform_args)
+ if transform_func is not None:
+ return transform_func
+ transform_func = getattr(transforms,
+ transform_type)(**transform_args)
+ return transform_func
+ else:
+ transform = [convert(x) for x in trans]
+ if hasattr(transforms, 'Compose'):
+ return transforms.Compose(transform)
+ elif hasattr(transforms, 'Sequential'):
+ return transforms.Sequential(transform)
+ else:
+ return transform
+
+ # return composed transform or return list of transform
+ for key in transform_funcs:
+ transform_funcs[key] = convert(config.data[key])
+ return transform_funcs
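+
+
+# Expression-to-function sketch, using the `cfg.data.transform` example
+# documented in `cfg_data.py`:
+# >>> cfg.data.transform = [['ToTensor'],
+# ...                       ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]
+# >>> funcs = get_transform(cfg, 'torchvision')
+# >>> funcs['transform']  # transforms.Compose([ToTensor(), Normalize(...)])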
diff --git a/federatedscope/core/auxiliaries/utils.py b/federatedscope/core/auxiliaries/utils.py
new file mode 100644
index 000000000..314aa0d69
--- /dev/null
+++ b/federatedscope/core/auxiliaries/utils.py
@@ -0,0 +1,382 @@
+import copy
+import json
+import logging
+import math
+import os
+import random
+import signal
+import ssl
+import time
+import urllib.request
+from datetime import datetime
+from os import path as osp
+
+import numpy as np
+
+# Blind torch
+try:
+ import torch
+ import torchvision
+ import torch.distributions as distributions
+except ImportError:
+ torch = None
+ torchvision = None
+ distributions = None
+
+logger = logging.getLogger(__name__)
+
+
+def setup_seed(seed):
+ np.random.seed(seed)
+ random.seed(seed)
+ if torch is not None:
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ else:
+ import tensorflow as tf
+ tf.set_random_seed(seed)
+
+
+def update_logger(cfg, clear_before_add=False):
+
+ root_logger = logging.getLogger("federatedscope")
+
+ # clear all existing handlers and add the default stream
+ if clear_before_add:
+ root_logger.handlers = []
+ handler = logging.StreamHandler()
+ logging_fmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+ handler.setFormatter(logging.Formatter(logging_fmt))
+ root_logger.addHandler(handler)
+
+ # update level
+ if cfg.verbose > 0:
+ logging_level = logging.INFO
+ else:
+ logging_level = logging.WARN
+ logger.warning("Skip DEBUG/INFO messages")
+ root_logger.setLevel(logging_level)
+
+ # ================ create outdir to save log, exp_config, models, etc,.
+ if cfg.outdir == "":
+ cfg.outdir = os.path.join(os.getcwd(), "exp")
+ exp_path = f"{cfg.federate.method}_{cfg.model.type}_on_{cfg.data.type}_lr{cfg.optimizer.lr}_lepoch{cfg.federate.local_update_steps}"
+ cfg.outdir = os.path.join(cfg.outdir, exp_path)
+
+ if cfg.attack.attack_method != '':
+ expname = f"{cfg.attack.attack_method}_{cfg.attack.trigger_type}_{cfg.attack.setting}"
+ else:
+ expname = f"normal"
+
+ if cfg.expname == "":
+ # cfg.expname = f"{cfg.federate.method}_{cfg.model.type}_on_{cfg.data.type}"
+ cfg.expname = expname
+ else:
+ cfg.expname = os.path.join(expname, cfg.expname)
+
+ cfg.expname = os.path.join(cfg.expname, str(cfg.seed))
+
+ cfg.outdir = os.path.join(cfg.outdir, cfg.expname)
+
+ # if exist, make directory with given name and time
+ if os.path.isdir(cfg.outdir) and os.path.exists(cfg.outdir):
+ outdir = os.path.join(cfg.outdir, "sub_exp" +
+ datetime.now().strftime('_%Y%m%d%H%M%S')
+ ) # e.g., sub_exp_20220411030524
+ while os.path.exists(outdir):
+ time.sleep(1)
+ outdir = os.path.join(
+ cfg.outdir,
+ "sub_exp" + datetime.now().strftime('_%Y%m%d%H%M%S'))
+ cfg.outdir = outdir
+ # if not, make directory with given name
+ os.makedirs(cfg.outdir)
+
+ # create file handler which logs even debug messages
+ fh = logging.FileHandler(os.path.join(cfg.outdir, 'exp_print.log'))
+ fh.setLevel(logging.DEBUG)
+ logger_formatter = logging.Formatter(
+ "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+ fh.setFormatter(logger_formatter)
+ root_logger.addHandler(fh)
+ # sys.stderr = sys.stdout
+
+ root_logger.info(f"the output dir is {cfg.outdir}")
+
+ if cfg.wandb.use:
+ init_wandb(cfg)
+
+
+def init_wandb(cfg):
+ try:
+ import wandb
+ except ImportError:
+ logger.error("cfg.wandb.use=True but not install the wandb package")
+ exit()
+ dataset_name = cfg.data.type
+ method_name = cfg.federate.method
+ exp_name = cfg.expname
+
+ tmp_cfg = copy.deepcopy(cfg)
+ tmp_cfg.cfg_check_funcs = []
+ import yaml
+ cfg_yaml = yaml.safe_load(tmp_cfg.dump())
+
+ wandb.init(project=cfg.wandb.name_project,
+ entity=cfg.wandb.name_user,
+ config=cfg_yaml,
+ group=dataset_name,
+ job_type=method_name,
+ name=exp_name,
+ notes=f"{method_name}, {exp_name}")
+
+
+def get_dataset(type, root, transform, target_transform, download=True):
+ if isinstance(type, str):
+ if hasattr(torchvision.datasets, type):
+ return getattr(torchvision.datasets,
+ type)(root=root,
+ transform=transform,
+ target_transform=target_transform,
+ download=download)
+ else:
+ raise NotImplementedError('Dataset {} not implement'.format(type))
+ else:
+ raise TypeError()
+
+
+def save_local_data(dir_path,
+ train_data=None,
+ train_targets=None,
+ test_data=None,
+ test_targets=None,
+ val_data=None,
+ val_targets=None):
+ r"""
+ https://github.com/omarfoq/FedEM/blob/main/data/femnist/generate_data.py
+
+ save (`train_data`, `train_targets`) in {dir_path}/train.pt,
+ (`val_data`, `val_targets`) in {dir_path}/val.pt
+ and (`test_data`, `test_targets`) in {dir_path}/test.pt
+ :param dir_path:
+ :param train_data:
+ :param train_targets:
+ :param test_data:
+ :param test_targets:
+ :param val_data:
+ :param val_targets
+ """
+ if (train_data is not None) and (train_targets is not None):
+ torch.save((train_data, train_targets), osp.join(dir_path, "train.pt"))
+
+ if (test_data is not None) and (test_targets is not None):
+ torch.save((test_data, test_targets), osp.join(dir_path, "test.pt"))
+
+ if (val_data is not None) and (val_targets is not None):
+ torch.save((val_data, val_targets), osp.join(dir_path, "val.pt"))
+
+
+def filter_by_specified_keywords(param_name, filter_keywords):
+ '''
+ Arguments:
+ param_name (str): parameter name.
+ Returns:
+ preserve (bool): whether to preserve this parameter.
+ '''
+ preserve = True
+ for kw in filter_keywords:
+ if kw in param_name:
+ preserve = False
+ break
+ return preserve
+
+
+def get_random(type, sample_shape, params, device):
+ if not hasattr(distributions, type):
+ raise NotImplementedError("Distribution {} is not implemented, please refer to ```torch.distributions```" \
+ "(https://pytorch.org/docs/stable/distributions.html).".format(type))
+ generator = getattr(distributions, type)(**params)
+ return generator.sample(sample_shape=sample_shape).to(device)
+
+
+def batch_iter(data, batch_size=64, shuffled=True):
+ assert 'x' in data and 'y' in data
+ data_x = data['x']
+ data_y = data['y']
+ data_size = len(data_y)
+ num_batches_per_epoch = math.ceil(data_size / batch_size)
+
+ while True:
+ shuffled_index = np.random.permutation(
+ np.arange(data_size)) if shuffled else np.arange(data_size)
+ for batch in range(num_batches_per_epoch):
+ start_index = batch * batch_size
+ end_index = min(data_size, (batch + 1) * batch_size)
+ sample_index = shuffled_index[start_index:end_index]
+ yield {'x': data_x[sample_index], 'y': data_y[sample_index]}
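+
+
+# Sketch: an endless generator over mini-batches, e.g. with shuffling off:
+# >>> data = {'x': np.arange(10).reshape(5, 2), 'y': np.arange(5)}
+# >>> batches = batch_iter(data, batch_size=2, shuffled=False)
+# >>> next(batches)['y']
+# array([0, 1])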
+
+
+def merge_dict(dict1, dict2):
+ # Merge results for history
+ for key, value in dict2.items():
+ if key not in dict1:
+ if isinstance(value, dict):
+ dict1[key] = merge_dict({}, value)
+ else:
+ dict1[key] = [value]
+ else:
+ if isinstance(value, dict):
+ merge_dict(dict1[key], value)
+ else:
+ dict1[key].append(value)
+ return dict1
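+
+
+# Sketch: scalar leaves are accumulated into lists across repeated merges:
+# >>> hist = merge_dict({}, {'round': 0, 'eval': {'acc': 0.7}})
+# >>> hist = merge_dict(hist, {'round': 1, 'eval': {'acc': 0.8}})
+# >>> hist
+# {'round': [0, 1], 'eval': {'acc': [0.7, 0.8]}}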
+
+
+def download_url(url: str, folder='folder'):
+ r"""Downloads the content of an url to a folder.
+
+ Modified from `https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/data/download.py`
+
+ Args:
+ url (string): The url of target file.
+ folder (string): The target folder.
+
+ Returns:
+ path (string): File path of downloaded files.
+ """
+
+ file = url.rpartition('/')[2]
+ file = file if file[0] == '?' else file.split('?')[0]
+ path = osp.join(folder, file)
+ if osp.exists(path):
+ logger.info(f'File {file} exists, use existing file.')
+ return path
+
+ logger.info(f'Downloading {url}')
+ os.makedirs(folder, exist_ok=True)
+ ctx = ssl._create_unverified_context()
+ data = urllib.request.urlopen(url, context=ctx)
+ with open(path, 'wb') as f:
+ f.write(data.read())
+
+ return path
+
+
+def move_to(obj, device):
+ import torch
+ if torch.is_tensor(obj):
+ return obj.to(device)
+ elif isinstance(obj, dict):
+ res = {}
+ for k, v in obj.items():
+ res[k] = move_to(v, device)
+ return res
+ elif isinstance(obj, list):
+ res = []
+ for v in obj:
+ res.append(move_to(v, device))
+ return res
+ else:
+ raise TypeError("Invalid type for move_to")
+
+
+class Timeout(object):
+ def __init__(self, seconds, max_failure=5):
+ self.seconds = seconds
+ self.max_failure = max_failure
+
+ def __enter__(self):
+ def signal_handler(signum, frame):
+ raise TimeoutError()
+
+ if self.seconds > 0:
+ signal.signal(signal.SIGALRM, signal_handler)
+ signal.alarm(self.seconds)
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ signal.alarm(0)
+
+ def reset(self):
+ signal.alarm(self.seconds)
+
+ def block(self):
+ signal.alarm(0)
+
+ def exceed_max_failure(self, num_failure):
+ return num_failure > self.max_failure
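+
+
+# Usage sketch (SIGALRM-based, hence Unix-only); `comm_manager` below is a
+# hypothetical blocking receiver:
+# >>> with Timeout(30) as timer:
+# ...     try:
+# ...         msg = comm_manager.receive()
+# ...         timer.reset()  # re-arm the alarm after each message
+# ...     except TimeoutError:
+# ...         pass  # handle the missed deadline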
+
+
+def logfile_2_wandb_dict(exp_log_f, raw_out=True):
+ """
+ parse the logfiles [exp_print.log, eval_results.log] into a wandb_dict that contains non-nested dicts
+
+ :param exp_log_f: opened exp_log file
+ :param raw_out: True indicates "exp_print.log", otherwise "eval_results.log";
+ the difference is whether each line contains a logger header such as "2022-05-02 16:55:02,843 (client:197) INFO:"
+
+ :return: tuple including (all_log_res, exp_stop_normal, last_line, log_res_best)
+ """
+ log_res_best = {}
+ exp_stop_normal = False
+ all_log_res = []
+ last_line = None
+ for line in exp_log_f:
+ last_line = line
+ if " Find new best result" in line:
+ # e.g.,
+ # 2022-03-22 10:48:42,562 (server:459) INFO: Find new best result for client_individual.test_acc with value 0.5911787974683544
+ parse_res = line.split("INFO: ")[1].split("with value")
+ best_key, best_val = parse_res[-2], parse_res[-1]
+ # client_individual.test_acc -> client_individual/test_acc
+ best_key = best_key.replace("Find new best result for",
+ "").replace(".", "/")
+ log_res_best[best_key.strip()] = float(best_val.strip())
+
+ if "'Role': 'Server #'" in line:
+ if raw_out:
+ line = line.split("INFO: ")[1]
+ res = line.replace("\'", "\"")
+ res = json.loads(s=res)
+ if res['Role'] == 'Server #':
+ cur_round = res['Round']
+ res.pop('Role')
+ if cur_round != "Final" and 'Results_raw' in res:
+ res.pop('Results_raw')
+
+ log_res = {}
+ for key, val in res.items():
+ if not isinstance(val, dict):
+ log_res[key] = val
+ else:
+ if cur_round != "Final":
+ for key_inner, val_inner in val.items():
+ assert not isinstance(
+ val_inner, dict), "Un-expected log format"
+ log_res[f"{key}/{key_inner}"] = val_inner
+
+ else:
+ exp_stop_normal = True
+ if key == "Results_raw":
+ for final_type, final_type_dict in res[
+ "Results_raw"].items():
+ for inner_key, inner_val in final_type_dict.items(
+ ):
+ log_res_best[
+ f"{final_type}/{inner_key}"] = inner_val
+ all_log_res.append(log_res)
+ return all_log_res, exp_stop_normal, last_line, log_res_best
diff --git a/federatedscope/core/auxiliaries/worker_builder.py b/federatedscope/core/auxiliaries/worker_builder.py
new file mode 100644
index 000000000..b8e1b1f0b
--- /dev/null
+++ b/federatedscope/core/auxiliaries/worker_builder.py
@@ -0,0 +1,85 @@
+import logging
+
+from federatedscope.core.configs import constants
+from federatedscope.core.worker import Server, Client
+
+logger = logging.getLogger(__name__)
+
+
+def get_client_cls(cfg):
+ if cfg.hpo.fedex.use:
+ from federatedscope.autotune.fedex import FedExClient
+ return FedExClient
+
+ if cfg.vertical.use:
+ from federatedscope.vertical_fl.worker import vFLClient
+ return vFLClient
+
+ if cfg.federate.method.lower() in constants.CLIENTS_TYPE:
+ client_type = constants.CLIENTS_TYPE[cfg.federate.method.lower()]
+ else:
+ client_type = "normal"
+ logger.warning(
+ 'Client for method {} is not implemented. Will use the default one.'.
+ format(cfg.federate.method))
+
+ if client_type == 'fedsageplus':
+ from federatedscope.gfl.fedsageplus.worker import FedSagePlusClient
+ client_class = FedSagePlusClient
+ elif client_type == 'gcflplus':
+ from federatedscope.gfl.gcflplus.worker import GCFLPlusClient
+ client_class = GCFLPlusClient
+ else:
+ client_class = Client
+
+ # add attack related method to client_class
+
+ if cfg.attack.attack_method.lower() in constants.CLIENTS_TYPE:
+ client_atk_type = constants.CLIENTS_TYPE[
+ cfg.attack.attack_method.lower()]
+ else:
+ client_atk_type = None
+
+ if client_atk_type == 'gradascent':
+ from federatedscope.attack.worker_as_attacker.active_client import add_atk_method_to_Client_GradAscent
+ logger.info("=========== add method to current client class ")
+ client_class = add_atk_method_to_Client_GradAscent(client_class)
+ return client_class
+
+
+def get_server_cls(cfg):
+ if cfg.hpo.fedex.use:
+ from federatedscope.autotune.fedex import FedExServer
+ return FedExServer
+
+ if cfg.attack.attack_method.lower() in ['dlg', 'ig']:
+ from federatedscope.attack.worker_as_attacker.server_attacker import PassiveServer
+ return PassiveServer
+ elif cfg.attack.attack_method.lower() in ['passivepia']:
+ from federatedscope.attack.worker_as_attacker.server_attacker import PassivePIAServer
+ return PassivePIAServer
+
+ elif cfg.attack.attack_method.lower() in ['backdoor']:
+ from federatedscope.attack.worker_as_attacker.server_attacker import BackdoorServer
+ return BackdoorServer
+
+ if cfg.vertical.use:
+ from federatedscope.vertical_fl.worker import vFLServer
+ return vFLServer
+
+ if cfg.federate.method.lower() in constants.SERVER_TYPE:
+ server_type = constants.SERVER_TYPE[cfg.federate.method.lower()]
+ else:
+ server_type = "normal"
+ logger.warning(
+ 'Server for method {} is not implemented. Will use the default one.'.
+ format(cfg.federate.method))
+
+ if server_type == 'fedsageplus':
+ from federatedscope.gfl.fedsageplus.worker import FedSagePlusServer
+ return FedSagePlusServer
+ elif server_type == 'gcflplus':
+ from federatedscope.gfl.gcflplus.worker import GCFLPlusServer
+ return GCFLPlusServer
+ else:
+ return Server
diff --git a/federatedscope/core/cmd_args.py b/federatedscope/core/cmd_args.py
new file mode 100644
index 000000000..924a0196f
--- /dev/null
+++ b/federatedscope/core/cmd_args.py
@@ -0,0 +1,26 @@
+import argparse
+import sys
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='FederatedScope')
+ parser.add_argument('--cfg',
+ dest='cfg_file',
+ help='Config file path',
+ required=True,
+ type=str)
+ parser.add_argument('--client_cfg',
+ dest='client_cfg_file',
+ help='Config file path for clients',
+ required=False,
+ default=None,
+ type=str)
+ parser.add_argument('opts',
+ help='See federatedscope/core/configs for all options',
+ default=None,
+ nargs=argparse.REMAINDER)
+ if len(sys.argv) == 1:
+ parser.print_help()
+ sys.exit(1)
+
+ return parser.parse_args()
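+
+
+# Invocation sketch (the entry script name is illustrative); trailing
+# key-value pairs are collected into `opts` for the yacs-style config merge:
+#   python main.py --cfg my_cfg.yaml federate.total_round_num 20 seed 12345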
diff --git a/federatedscope/core/communication.py b/federatedscope/core/communication.py
new file mode 100644
index 000000000..35d4991dd
--- /dev/null
+++ b/federatedscope/core/communication.py
@@ -0,0 +1,136 @@
+import grpc
+from concurrent import futures
+
+from federatedscope.core.configs.config import global_cfg
+from federatedscope.core.proto import gRPC_comm_manager_pb2, gRPC_comm_manager_pb2_grpc
+from federatedscope.core.gRPC_server import gRPCComServeFunc
+from federatedscope.core.message import Message
+
+
+class StandaloneCommManager(object):
+ """
+ The communicator used for standalone mode
+ """
+ def __init__(self, comm_queue, monitor=None):
+ self.comm_queue = comm_queue
+ self.neighbors = dict()
+ self.monitor = monitor # used to track the communication related metrics
+
+ def receive(self):
+ # we don't need receive() in standalone
+ pass
+
+ def add_neighbors(self, neighbor_id, address=None):
+ self.neighbors[neighbor_id] = address
+
+ def get_neighbors(self, neighbor_id=None):
+ address = dict()
+ if neighbor_id:
+ if isinstance(neighbor_id, list):
+ for each_neighbor in neighbor_id:
+ address[each_neighbor] = self.get_neighbors(each_neighbor)
+ return address
+ else:
+ return self.neighbors[neighbor_id]
+ else:
+ # Get all neighbors
+ return self.neighbors
+
+ def send(self, message):
+ self.comm_queue.append(message)
+ download_bytes, upload_bytes = message.count_bytes()
+ if self.monitor is not None:
+ self.monitor.track_upload_bytes(upload_bytes)
+
+
+class gRPCCommManager(object):
+ """
+ The implementation of gRPCCommManager is referred to the tutorial on https://grpc.io/docs/languages/python/
+ """
+ def __init__(self, host='0.0.0.0', port='50050', client_num=2):
+ self.host = host
+ self.port = port
+ options = [
+ ("grpc.max_send_message_length",
+ global_cfg.distribute.grpc_max_send_message_length),
+ ("grpc.max_receive_message_length",
+ global_cfg.distribute.grpc_max_receive_message_length),
+ ("grpc.enable_http_proxy",
+ global_cfg.distribute.grpc_enable_http_proxy),
+ ]
+ self.server_funcs = gRPCComServeFunc()
+ self.grpc_server = self.serve(max_workers=client_num,
+ host=host,
+ port=port,
+ options=options)
+ self.neighbors = dict()
+ self.monitor = None # used to track the communication related metrics
+
+ def serve(self, max_workers, host, port, options):
+ """
+ This function is referred to https://grpc.io/docs/languages/python/basics/#starting-the-server
+ """
+ server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=max_workers),
+ options=options)
+ gRPC_comm_manager_pb2_grpc.add_gRPCComServeFuncServicer_to_server(
+ self.server_funcs, server)
+ server.add_insecure_port("{}:{}".format(host, port))
+ server.start()
+
+ return server
+
+ def add_neighbors(self, neighbor_id, address):
+ self.neighbors[neighbor_id] = '{}:{}'.format(address['host'],
+ address['port'])
+
+ def get_neighbors(self, neighbor_id=None):
+ address = dict()
+ if neighbor_id:
+ if isinstance(neighbor_id, list):
+ for each_neighbor in neighbor_id:
+ address[each_neighbor] = self.get_neighbors(each_neighbor)
+ return address
+ else:
+ return self.neighbors[neighbor_id]
+ else:
+ # Get all neighbors
+ return self.neighbors
+
+ def _send(self, receiver_address, message):
+ def _create_stub(receiver_address):
+ """
+ This part is referred to https://grpc.io/docs/languages/python/basics/#creating-a-stub
+ """
+ channel = grpc.insecure_channel(receiver_address,
+ options=(('grpc.enable_http_proxy',
+ 0), ))
+ stub = gRPC_comm_manager_pb2_grpc.gRPCComServeFuncStub(channel)
+ return stub, channel
+
+ stub, channel = _create_stub(receiver_address)
+ request = message.transform(to_list=True)
+ try:
+ stub.sendMessage(request)
+ except grpc._channel._InactiveRpcError:
+ # the receiver is unreachable or has shut down; drop the message
+ pass
+ channel.close()
+
+ def send(self, message):
+ receiver = message.receiver
+ if receiver is not None:
+ if not isinstance(receiver, list):
+ receiver = [receiver]
+ for each_receiver in receiver:
+ if each_receiver in self.neighbors:
+ receiver_address = self.neighbors[each_receiver]
+ self._send(receiver_address, message)
+ else:
+ for each_receiver in self.neighbors:
+ receiver_address = self.neighbors[each_receiver]
+ self._send(receiver_address, message)
+
+ def receive(self):
+ received_msg = self.server_funcs.receive()
+ message = Message()
+ message.parse(received_msg.msg)
+ return message
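+
+
+# Wiring sketch for the gRPC manager; host/port values are illustrative:
+# >>> comm = gRPCCommManager(host='127.0.0.1', port='50051', client_num=2)
+# >>> comm.add_neighbors(neighbor_id=1,
+# ...                    address={'host': '127.0.0.1', 'port': '50052'})
+# >>> comm.get_neighbors(1)
+# '127.0.0.1:50052'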
diff --git a/federatedscope/core/configs/__init__.py b/federatedscope/core/configs/__init__.py
new file mode 100644
index 000000000..5c3259a8d
--- /dev/null
+++ b/federatedscope/core/configs/__init__.py
@@ -0,0 +1,29 @@
+import copy
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
+
+# to ensure the sub-configs registered before set up the global config
+all_sub_configs = copy.copy(__all__)
+if "config" in all_sub_configs:
+ all_sub_configs.remove('config')
+
+from federatedscope.core.configs.config import CN, init_global_cfg
+__all__ = __all__ + \
+ [
+ 'CN',
+ 'init_global_cfg'
+ ]
+
+# reorder the config to ensure the base config will be registered first
+base_configs = [
+ 'cfg_data', 'cfg_fl_setting', 'cfg_model', 'cfg_training', 'cfg_evaluation'
+]
+for base_config in base_configs:
+ all_sub_configs.pop(all_sub_configs.index(base_config))
+ all_sub_configs.insert(0, base_config)
diff --git a/federatedscope/core/configs/cfg_asyn.py b/federatedscope/core/configs/cfg_asyn.py
new file mode 100644
index 000000000..7dc26a24c
--- /dev/null
+++ b/federatedscope/core/configs/cfg_asyn.py
@@ -0,0 +1,55 @@
+import logging
+
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+
+def extend_asyn_cfg(cfg):
+ # ------------------------------------------------------------------------ #
+ # Asynchronous related options
+ # ------------------------------------------------------------------------ #
+ cfg.asyn = CN()
+
+ cfg.asyn.use = True
+ cfg.asyn.timeout = 0
+ cfg.asyn.min_received_num = 2
+ cfg.asyn.min_received_rate = -1.0
+
+ # --------------- register corresponding check function ----------
+ cfg.register_cfg_check_fun(assert_asyn_cfg)
+
+
+def assert_asyn_cfg(cfg):
+ # to ensure a valid timeout seconds
+ assert isinstance(cfg.asyn.timeout, int) or isinstance(
+ cfg.asyn.timeout, float
+ ), "The timeout (seconds) must be an int or a float value, but {} is got".format(
+ type(cfg.asyn.timeout))
+
+ # min received num pre-process
+ min_received_num_valid = (0 < cfg.asyn.min_received_num <=
+ cfg.federate.sample_client_num)
+ min_received_rate_valid = (0 < cfg.asyn.min_received_rate <= 1)
+ # (a) sampling case
+ if min_received_rate_valid:
+ # (a.1) use min_received_rate
+ old_min_received_num = cfg.asyn.min_received_num
+ cfg.asyn.min_received_num = max(
+ 1,
+ int(cfg.asyn.min_received_rate * cfg.federate.sample_client_num))
+ if min_received_num_valid:
+ logging.warning(
+ f"Users specify both valid min_received_rate as {cfg.asyn.min_received_rate} "
+ f"and min_received_num as {old_min_received_num}.\n"
+ f"\t\tWe will use the min_received_rate value to calculate "
+ f"the actual number of participated clients as {cfg.asyn.min_received_num}."
+ )
+ # (a.2) use min_received_num, commented since the below two lines do not change anything
+ # elif min_received_rate:
+ # cfg.asyn.min_received_num = cfg.asyn.min_received_num
+ if not (min_received_num_valid or min_received_rate_valid):
+ # (b) non-sampling case, use all clients
+ cfg.asyn.min_received_num = cfg.federate.sample_client_num
+
+
+register_config("asyn", extend_asyn_cfg)
diff --git a/federatedscope/core/configs/cfg_attack.py b/federatedscope/core/configs/cfg_attack.py
new file mode 100644
index 000000000..c337876f5
--- /dev/null
+++ b/federatedscope/core/configs/cfg_attack.py
@@ -0,0 +1,70 @@
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+
+def extend_attack_cfg(cfg):
+
+ # ------------------------------------------------------------------------ #
+ # attack
+ # ------------------------------------------------------------------------ #
+ cfg.attack = CN()
+ cfg.attack.attack_method = ''
+ # for gan_attack and backdoor attack
+ cfg.attack.target_label_ind = -1
+ cfg.attack.attacker_id = -1
+
+ # for backdoor attack
+
+ cfg.attack.setting = 'fix'
+ cfg.attack.freq = 10
+ cfg.attack.insert_round = 100000
+ cfg.attack.mean = [0.1307]
+ cfg.attack.std = [0.3081]
+ cfg.attack.trigger_type = 'edge'
+ cfg.attack.label_type = 'dirty'
+ cfg.attack.edge_num = 100
+ cfg.attack.poison_ratio = 0.5
+ cfg.attack.scale_poisoning = False
+ cfg.attack.scale_para = 1.0
+ cfg.attack.pgd_poisoning = False
+ cfg.attack.pgd_lr = 0.1
+ cfg.attack.pgd_eps = 2
+ cfg.attack.self_opt = False
+ cfg.attack.self_lr = 0.05
+ cfg.attack.self_epoch = 6
+
+ # defense:
+
+ cfg.attack.norm_clip = False
+ cfg.attack.norm_clip_value = 5.0
+ cfg.attack.dp_noise = -1.0
+ cfg.attack.krum = False
+ cfg.attack.multi_krum = False
+
+ # Note: the mean and std should be the list type.
+
+ # for reconstruct_opt
+ cfg.attack.reconstruct_lr = 0.01
+ cfg.attack.reconstruct_optim = 'Adam'
+ cfg.attack.info_diff_type = 'l2'
+ cfg.attack.max_ite = 400
+ cfg.attack.alpha_TV = 0.001
+
+ # for active PIA attack
+ cfg.attack.alpha_prop_loss = 0
+
+ # for passive PIA attack
+ cfg.attack.classifier_PIA = 'randomforest'
+
+ # for gradient Ascent --- MIA attack
+ cfg.attack.inject_round = 0
+
+ # --------------- register corresponding check function ----------
+ cfg.register_cfg_check_fun(assert_attack_cfg)
+
+
+def assert_attack_cfg(cfg):
+ pass
+
+
+register_config("attack", extend_attack_cfg)
diff --git a/federatedscope/core/configs/cfg_data.py b/federatedscope/core/configs/cfg_data.py
new file mode 100644
index 000000000..956cf9407
--- /dev/null
+++ b/federatedscope/core/configs/cfg_data.py
@@ -0,0 +1,64 @@
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+
+def extend_data_cfg(cfg):
+ # ------------------------------------------------------------------------ #
+ # Dataset related options
+ # ------------------------------------------------------------------------ #
+ cfg.data = CN()
+
+ cfg.data.seed = 1
+ cfg.data.root = 'data'
+ cfg.data.type = 'toy'
+ cfg.data.args = [] # args for external dataset, e.g. [{'download': True}]
+ cfg.data.splitter = ''
+ cfg.data.splitter_args = [] # args for splitter, e.g. [{'alpha': 0.5}]
+ cfg.data.transform = [
+ ] # transform for x, eg. [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]
+ cfg.data.target_transform = [] # target_transform for y, use as above
+ cfg.data.pre_transform = [
+ ] # pre_transform for `torch_geometric` dataset, use as above
+ cfg.data.batch_size = 64
+ cfg.data.drop_last = False
+ cfg.data.sizes = [10, 5]
+ cfg.data.shuffle = True
+ cfg.data.subsample = 1.0
+ cfg.data.splits = [0.8, 0.1, 0.1] # Train, valid, test splits
+ cfg.data.consistent_label_distribution = True
+ # we need to keep the consistency between training split and testing split.
+ cfg.data.cSBM_phi = [0.5, 0.5, 0.5]
+ cfg.data.loader = ''
+ cfg.data.num_workers = 0
+ cfg.data.graphsaint = CN()
+ cfg.data.graphsaint.walk_length = 2
+ cfg.data.graphsaint.num_steps = 30
+ cfg.data.do_sta = False
+ cfg.data.plot_boxplot = False
+ cfg.data.probe_label_dist = False
+ cfg.data.labelwise_boxplot = False
+
+ # new config:
+ cfg.data.dataset = ['train', 'val', 'test']
+
+ # quadratic
+ cfg.data.quadratic = CN()
+ cfg.data.quadratic.dim = 1
+ cfg.data.quadratic.min_curv = 0.02
+ cfg.data.quadratic.max_curv = 12.5
+
+ # --------------- register corresponding check function ----------
+ cfg.register_cfg_check_fun(assert_data_cfg)
+
+
+def assert_data_cfg(cfg):
+ if cfg.data.loader == 'graphsaint-rw':
+ assert cfg.model.layer == cfg.data.graphsaint.walk_length, 'Sample size mismatch'
+ if cfg.data.loader == 'neighbor':
+ assert cfg.model.layer == len(cfg.data.sizes), 'Sample size mismatch'
+ if '@' in cfg.data.type:
+ assert cfg.federate.client_num > 0, '`federate.client_num` should be greater than 0 when using external data'
+ assert cfg.data.splitter, '`data.splitter` should not be empty when using external data'
+
+
+register_config("data", extend_data_cfg)
diff --git a/federatedscope/core/configs/cfg_differential_privacy.py b/federatedscope/core/configs/cfg_differential_privacy.py
new file mode 100644
index 000000000..f25461d22
--- /dev/null
+++ b/federatedscope/core/configs/cfg_differential_privacy.py
@@ -0,0 +1,38 @@
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+
+def extend_dp_cfg(cfg):
+ # ------------------------------------------------------------------------ #
+ # nbafl(dp) related options
+ # ------------------------------------------------------------------------ #
+ cfg.nbafl = CN()
+
+ # Params
+ cfg.nbafl.use = False
+ cfg.nbafl.mu = 0.
+ cfg.nbafl.epsilon = 100.
+ cfg.nbafl.w_clip = 1.
+ cfg.nbafl.constant = 30.
+
+ # ------------------------------------------------------------------------ #
+ # VFL-SGDMF(dp) related options
+ # ------------------------------------------------------------------------ #
+ cfg.sgdmf = CN()
+
+ cfg.sgdmf.use = False # if use sgdmf algorithm
+ cfg.sgdmf.R = 5. # The upper bound of rating
+ cfg.sgdmf.epsilon = 4. # \epsilon in dp
+ cfg.sgdmf.delta = 0.5 # \delta in dp
+ cfg.sgdmf.constant = 1. # constant
+ cfg.sgdmf.theta = -1 # -1 means per-rating privacy, otherwise per-user privacy
+
+ # --------------- register corresponding check function ----------
+ cfg.register_cfg_check_fun(assert_dp_cfg)
+
+
+def assert_dp_cfg(cfg):
+ pass
+
+
+register_config("dp", extend_dp_cfg)
diff --git a/federatedscope/core/configs/cfg_evaluation.py b/federatedscope/core/configs/cfg_evaluation.py
new file mode 100644
index 000000000..0496554b9
--- /dev/null
+++ b/federatedscope/core/configs/cfg_evaluation.py
@@ -0,0 +1,39 @@
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+
+def extend_evaluation_cfg(cfg):
+
+ # ------------------------------------------------------------------------ #
+ # Evaluation related options
+ # ------------------------------------------------------------------------ #
+ cfg.eval = CN()
+
+ cfg.eval.save_data = False
+ cfg.eval.freq = 1
+ cfg.eval.metrics = []
+ cfg.eval.split = ['test', 'val']
+ cfg.eval.report = ['weighted_avg', 'avg', 'fairness',
+ 'raw'] # by default, we report comprehensive results
+ cfg.eval.best_res_update_round_wise_key = "val_loss"
+
+ # Monitoring, e.g., 'dissim' for B-local dissimilarity
+ cfg.eval.monitoring = []
+
+ # ------------------------------------------------------------------------ #
+ # wandb related options
+ # ------------------------------------------------------------------------ #
+ cfg.wandb = CN()
+ cfg.wandb.use = False
+ cfg.wandb.name_user = ''
+ cfg.wandb.name_project = ''
+
+ # --------------- register corresponding check function ----------
+ cfg.register_cfg_check_fun(assert_evaluation_cfg)
+
+
+def assert_evaluation_cfg(cfg):
+ pass
+
+
+register_config("eval", extend_evaluation_cfg)
diff --git a/federatedscope/core/configs/cfg_fl_algo.py b/federatedscope/core/configs/cfg_fl_algo.py
new file mode 100644
index 000000000..de51b0b47
--- /dev/null
+++ b/federatedscope/core/configs/cfg_fl_algo.py
@@ -0,0 +1,102 @@
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+
+def extend_fl_algo_cfg(cfg):
+ # ------------------------------------------------------------------------ #
+ # fedopt related options, general fl
+ # ------------------------------------------------------------------------ #
+ cfg.fedopt = CN()
+
+ cfg.fedopt.use = False
+
+ cfg.fedopt.optimizer = CN(new_allowed=True)
+ cfg.fedopt.optimizer.type = 'SGD'
+ cfg.fedopt.optimizer.lr = 0.01
+
+ # ------------------------------------------------------------------------ #
+ # fedprox related options, general fl
+ # ------------------------------------------------------------------------ #
+ cfg.fedprox = CN()
+
+ cfg.fedprox.use = False
+ cfg.fedprox.mu = 0.
+
+ # ------------------------------------------------------------------------ #
+ # Personalization related options, pFL
+ # ------------------------------------------------------------------------ #
+ cfg.personalization = CN()
+
+ # client-distinct param names, e.g., ['pre', 'post']
+ cfg.personalization.local_param = []
+ cfg.personalization.share_non_trainable_para = False
+ cfg.personalization.local_update_steps = -1
+ # @regular_weight:
+ # The smaller the regular_weight is, the stronger the emphasis on the personalized model
+ # For Ditto, the default value=0.1, the search space is [0.05, 0.1, 0.2, 1, 2]
+ # For pFedMe, the default value=15
+ cfg.personalization.regular_weight = 0.1
+
+ # @lr:
+ # 1) For pFedME, the personalized learning rate to calculate theta approximately using K steps
+ # 2) 0.0 indicates using the value of optimizer.lr in case users have not specified a valid lr
+ cfg.personalization.lr = 0.0
+
+ cfg.personalization.K = 5 # the local approximation steps for pFedMe
+ cfg.personalization.beta = 1.0 # the average moving parameter for pFedMe
+
+ # parameters for FedRep:
+ cfg.personalization.lr_feature = 0.1
+ cfg.personalization.lr_linear = 0.1
+ cfg.personalization.epoch_feature = 1
+ cfg.personalization.epoch_linear = 2
+ cfg.personalization.weight_decay = 0.0
+ # ------------------------------------------------------------------------ #
+ # FedSage+ related options, gfl
+ # ------------------------------------------------------------------------ #
+ cfg.fedsageplus = CN()
+
+ cfg.fedsageplus.num_pred = 5
+ cfg.fedsageplus.gen_hidden = 128
+ cfg.fedsageplus.hide_portion = 0.5
+ cfg.fedsageplus.fedgen_epoch = 200
+ cfg.fedsageplus.loc_epoch = 1
+ cfg.fedsageplus.a = 1.0
+ cfg.fedsageplus.b = 1.0
+ cfg.fedsageplus.c = 1.0
+
+ # ------------------------------------------------------------------------ #
+ # GCFL+ related options, gfl
+ # ------------------------------------------------------------------------ #
+ cfg.gcflplus = CN()
+
+ cfg.gcflplus.EPS_1 = 0.05
+ cfg.gcflplus.EPS_2 = 0.1
+ cfg.gcflplus.seq_length = 5
+ cfg.gcflplus.standardize = False
+
+ # ------------------------------------------------------------------------ #
+ # FLIT+ related options, gfl
+ # ------------------------------------------------------------------------ #
+ cfg.flitplus = CN()
+
+ cfg.flitplus.tmpFed = 0.5 # gamma in focal loss (Eq.4)
+ cfg.flitplus.lambdavat = 0.5 # lambda in phi (Eq.10)
+ cfg.flitplus.factor_ema = 0.8 # beta in omega (Eq.12)
+ cfg.flitplus.weightReg = 1.0 # balance lossLocalLabel and lossLocalVAT
+
+ # --------------- register corresponding check function ----------
+ cfg.register_cfg_check_fun(assert_fl_algo_cfg)
+
+
+def assert_fl_algo_cfg(cfg):
+ if cfg.personalization.local_update_steps == -1:
+ # By default, use the same step to normal mode
+ cfg.personalization.local_update_steps = cfg.federate.local_update_steps
+
+ if cfg.personalization.lr <= 0.0:
+ # By default, use the same lr to normal mode
+ cfg.personalization.lr = cfg.optimizer.lr
+
+
+register_config("fl_algo", extend_fl_algo_cfg)
diff --git a/federatedscope/core/configs/cfg_fl_setting.py b/federatedscope/core/configs/cfg_fl_setting.py
new file mode 100644
index 000000000..98fda6df5
--- /dev/null
+++ b/federatedscope/core/configs/cfg_fl_setting.py
@@ -0,0 +1,155 @@
+import logging
+
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+logger = logging.getLogger(__name__)
+
+
+def extend_fl_setting_cfg(cfg):
+ # ------------------------------------------------------------------------ #
+ # Federate learning related options
+ # ------------------------------------------------------------------------ #
+ cfg.federate = CN()
+
+ cfg.federate.client_num = 0
+ cfg.federate.sample_client_num = -1
+ cfg.federate.sample_client_rate = -1.0
+ cfg.federate.unseen_clients_rate = 0.0
+ cfg.federate.total_round_num = 50
+ cfg.federate.mode = 'standalone'
+ cfg.federate.local_update_steps = 1
+ cfg.federate.batch_or_epoch = 'epoch'
+ cfg.federate.share_local_model = False
+ cfg.federate.data_weighted_aggr = False # If True, the weight of aggr is the number of training samples in dataset.
+ cfg.federate.online_aggr = False
+ cfg.federate.make_global_eval = False
+ cfg.federate.use_diff = False
+ cfg.federate.weight_avg = True
+
+ # the method name is used to internally determine composition of different aggregators, messages, handlers, etc.,
+ cfg.federate.method = "FedAvg"
+ cfg.federate.ignore_weight = False
+ cfg.federate.use_ss = False # Whether to apply Secret Sharing
+ cfg.federate.restore_from = ''
+ cfg.federate.save_to = ''
+ cfg.federate.join_in_info = [
+ ] # The information requirements (from server) for join_in
+
+ # ------------------------------------------------------------------------ #
+ # Distribute training related options
+ # ------------------------------------------------------------------------ #
+ cfg.distribute = CN()
+
+ cfg.distribute.use = False
+ cfg.distribute.server_host = '0.0.0.0'
+ cfg.distribute.server_port = 50050
+ cfg.distribute.client_host = '0.0.0.0'
+ cfg.distribute.client_port = 50050
+ cfg.distribute.role = 'client'
+ cfg.distribute.data_file = 'data'
+ cfg.distribute.grpc_max_send_message_length = 100 * 1024 * 1024
+ cfg.distribute.grpc_max_receive_message_length = 100 * 1024 * 1024
+ cfg.distribute.grpc_enable_http_proxy = False
+
+ # ------------------------------------------------------------------------ #
+ # Vertical FL related options (for demo)
+ # ------------------------------------------------------------------------ #
+ cfg.vertical = CN()
+ cfg.vertical.use = False
+ cfg.vertical.encryption = 'paillier'
+ cfg.vertical.dims = [5, 10]
+ cfg.vertical.key_size = 3072
+
+ # --------------- register corresponding check function ----------
+ cfg.register_cfg_check_fun(assert_fl_setting_cfg)
+
+
+def assert_fl_setting_cfg(cfg):
+ if cfg.federate.batch_or_epoch not in ['batch', 'epoch']:
+ raise ValueError(
+ "Value of 'cfg.federate.batch_or_epoch' must be chosen from ['batch', 'epoch']."
+ )
+
+ assert cfg.federate.mode in ["standalone", "distributed"], \
+ f"Please specify the cfg.federate.mode as the string standalone or distributed. But got {cfg.federate.mode}."
+
+ # ============= client num related ==============
+ assert not (cfg.federate.client_num == 0
+ and cfg.federate.mode == 'distributed'
+ ), "Please configure the cfg.federate. in distributed mode. "
+
+ assert 0 <= cfg.federate.unseen_clients_rate < 1, \
+ "You specified in-valid cfg.federate.unseen_clients_rate"
+ if 0 < cfg.federate.unseen_clients_rate < 1 and cfg.federate.method in [
+ "local", "global"
+ ]:
+ logger.warning(
+ "In local/global training mode, the unseen_clients_rate is "
+ "in-valid, plz check your config")
+ unseen_clients_rate = 0.0
+ cfg.federate.unseen_clients_rate = unseen_clients_rate
+ else:
+ unseen_clients_rate = cfg.federate.unseen_clients_rate
+ participated_client_num = max(
+ 1, int((1 - unseen_clients_rate) * cfg.federate.client_num))
+
+ # sample client num pre-process
+ sample_client_num_valid = (0 < cfg.federate.sample_client_num <=
+ cfg.federate.client_num)
+ sample_client_rate_valid = (0 < cfg.federate.sample_client_rate <= 1)
+
+ sample_cfg_valid = sample_client_rate_valid or sample_client_num_valid
+ non_sample_case = cfg.federate.method in ["local", "global"]
+ if non_sample_case and sample_cfg_valid:
+ logger.warning(
+ "In local/global training mode, the sampling related configs are in-valid, we will use all clients. "
+ )
+
+ if cfg.federate.method == "global":
+ cfg.federate.client_num = 1
+ logger.info(
+ "In global training mode, we will put all data in a proxy client. "
+ )
+ if cfg.federate.make_global_eval:
+ cfg.federate.make_global_eval = False
+ logger.warning(
+ "In global training mode, we will conduct global evaluation in a proxy client rather than the server. cfg.federate.make_global_eval has been set to False accordingly."
+ )
+
+ if non_sample_case or not sample_cfg_valid:
+ # (a) use all clients
+ cfg.federate.sample_client_num = cfg.federate.client_num
+ else:
+ # (b) sampling case
+ if sample_client_rate_valid:
+ # (b.1) use sample_client_rate
+ old_sample_client_num = cfg.federate.sample_client_num
+ cfg.federate.sample_client_num = max(
+ 1,
+ int(cfg.federate.sample_client_rate * cfg.federate.client_num))
+ if sample_client_num_valid:
+                logger.warning(
+                    f"Users specified both a valid sample_client_rate as {cfg.federate.sample_client_rate} "
+                    f"and a valid sample_client_num as {old_sample_client_num}.\n"
+                    f"\t\tWe will use the sample_client_rate value to calculate "
+                    f"the actual number of participated clients as {cfg.federate.sample_client_num}."
+                )
+ # (b.2) use sample_client_num, commented since the below two lines do not change anything
+ # elif sample_client_num_valid:
+ # cfg.federate.sample_client_num = cfg.federate.sample_client_num
+
+ if cfg.federate.use_ss:
+        assert cfg.federate.client_num == cfg.federate.sample_client_num, \
+            "Currently, we support secret sharing only in the all-client-participation case"
+
+ assert cfg.federate.method != "local", \
+ "Secret sharing is not supported in local training mode"
+
+ # ============= aggregator related ================
+    assert (not cfg.federate.online_aggr) or (
+        not cfg.federate.use_ss
+    ), "Using the online aggregator together with secret sharing is not supported yet"
+
+
+register_config("fl_setting", extend_fl_setting_cfg)
diff --git a/federatedscope/core/configs/cfg_hpo.py b/federatedscope/core/configs/cfg_hpo.py
new file mode 100644
index 000000000..b1a6f82eb
--- /dev/null
+++ b/federatedscope/core/configs/cfg_hpo.py
@@ -0,0 +1,87 @@
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+
+def extend_hpo_cfg(cfg):
+
+ # ------------------------------------------------------------------------ #
+ # hpo related options
+ # ------------------------------------------------------------------------ #
+ cfg.hpo = CN()
+ cfg.hpo.working_folder = 'hpo'
+ cfg.hpo.ss = ''
+ cfg.hpo.num_workers = 0
+ #cfg.hpo.init_strategy = 'random'
+ cfg.hpo.init_cand_num = 16
+ cfg.hpo.log_scale = False
+ cfg.hpo.larger_better = False
+ cfg.hpo.scheduler = 'rs'
+ # plot the performance
+ cfg.hpo.plot_interval = 1
+ cfg.hpo.metric = 'client_summarized_weighted_avg.val_loss'
+
+ # SHA
+ cfg.hpo.sha = CN()
+ cfg.hpo.sha.elim_round_num = 3
+ cfg.hpo.sha.elim_rate = 3
+ cfg.hpo.sha.budgets = []
+
+ # PBT
+ cfg.hpo.pbt = CN()
+ cfg.hpo.pbt.max_stage = 5
+ cfg.hpo.pbt.perf_threshold = 0.1
+
+ # FedEx
+ cfg.hpo.fedex = CN()
+ cfg.hpo.fedex.use = False
+ cfg.hpo.fedex.ss = ''
+ cfg.hpo.fedex.flatten_ss = True
+ # If <= .0, use 'auto'
+ cfg.hpo.fedex.eta0 = -1.0
+ cfg.hpo.fedex.sched = 'auto'
+ # cutoff: entropy level below which to stop updating the config probability and use MLE
+ cfg.hpo.fedex.cutoff = .0
+ # discount factor; 0.0 is most recent, 1.0 is mean
+ cfg.hpo.fedex.gamma = .0
+ cfg.hpo.fedex.num_arms = 16
+ cfg.hpo.fedex.diff = False
+
+ # Table
+ cfg.hpo.table = CN()
+ cfg.hpo.table.ss = ''
+ cfg.hpo.table.eps = 0.1
+ cfg.hpo.table.num = 27
+ #cfg.hpo.table.cand = 81
+ cfg.hpo.table.idx = 0
+
+
+def assert_hpo_cfg(cfg):
+ # HPO related
+ #assert cfg.hpo.init_strategy in [
+ # 'full', 'grid', 'random'
+ #], "initialization strategy for HPO should be \"full\", \"grid\", or \"random\", but the given choice is {}".format(
+ # cfg.hpo.init_strategy)
+ assert cfg.hpo.scheduler in ['rs', 'sha',
+ 'pbt'], "No HPO scheduler named {}".format(
+ cfg.hpo.scheduler)
+    assert cfg.hpo.num_workers >= 0, "The number of workers should be non-negative, but got {}".format(
+        cfg.hpo.num_workers)
+    assert len(cfg.hpo.sha.budgets) == 0 or len(
+        cfg.hpo.sha.budgets
+    ) == cfg.hpo.sha.elim_round_num, \
+        "Either do NOT specify the budgets or specify one budget for each SHA elimination round, but the given budgets are {}".\
+        format(cfg.hpo.sha.budgets)
+
+ assert not (cfg.hpo.fedex.use and cfg.federate.use_ss
+ ), "Cannot use secret sharing and FedEx at the same time"
+    assert cfg.optimizer.type == 'SGD' or not cfg.hpo.fedex.use, "SGD is required when FedEx is used"
+    assert cfg.hpo.fedex.sched in [
+        'adaptive', 'aggressive', 'auto', 'constant', 'scale'
+    ], "The schedule of FedEx must be chosen from {}".format(
+        ['adaptive', 'aggressive', 'auto', 'constant', 'scale'])
+ assert cfg.hpo.fedex.gamma >= .0 and cfg.hpo.fedex.gamma <= 1.0, "{} must be in [0, 1]".format(
+ cfg.hpo.fedex.gamma)
+ assert cfg.hpo.fedex.use == cfg.federate.use_diff, "Once FedEx is adopted, cfg.federate.use_diff must be True."
+
+
+register_config("hpo", extend_hpo_cfg)
diff --git a/federatedscope/core/configs/cfg_model.py b/federatedscope/core/configs/cfg_model.py
new file mode 100644
index 000000000..dcbd955de
--- /dev/null
+++ b/federatedscope/core/configs/cfg_model.py
@@ -0,0 +1,48 @@
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+
+def extend_model_cfg(cfg):
+ # ------------------------------------------------------------------------ #
+ # Model related options
+ # ------------------------------------------------------------------------ #
+ cfg.model = CN()
+
+ cfg.model.model_num_per_trainer = 1 # some methods may leverage more than one model in each trainer
+ cfg.model.type = 'lr'
+ cfg.model.use_bias = True
+ cfg.model.task = 'node'
+ cfg.model.hidden = 256
+ cfg.model.dropout = 0.0
+ cfg.model.in_channels = 0 # If 0, model will be built by data.shape
+ cfg.model.out_channels = 1
+ cfg.model.layer = 2 # In GPR-GNN, K = layer
+ cfg.model.graph_pooling = 'mean'
+ cfg.model.embed_size = 8
+ cfg.model.num_item = 0
+ cfg.model.num_user = 0
+
+ # ------------------------------------------------------------------------ #
+ # Criterion related options
+ # ------------------------------------------------------------------------ #
+ cfg.criterion = CN()
+
+ cfg.criterion.type = 'MSELoss'
+
+ # ------------------------------------------------------------------------ #
+ # regularizer related options
+ # ------------------------------------------------------------------------ #
+ cfg.regularizer = CN()
+
+ cfg.regularizer.type = ''
+ cfg.regularizer.mu = 0.
+
+ # --------------- register corresponding check function ----------
+ cfg.register_cfg_check_fun(assert_model_cfg)
+
+
+def assert_model_cfg(cfg):
+ pass
+
+
+register_config("model", extend_model_cfg)
diff --git a/federatedscope/core/configs/cfg_training.py b/federatedscope/core/configs/cfg_training.py
new file mode 100644
index 000000000..f47bacd43
--- /dev/null
+++ b/federatedscope/core/configs/cfg_training.py
@@ -0,0 +1,84 @@
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+
+def extend_training_cfg(cfg):
+ # ------------------------------------------------------------------------ #
+ # Trainer related options
+ # ------------------------------------------------------------------------ #
+ cfg.trainer = CN()
+
+ cfg.trainer.type = 'general'
+ cfg.trainer.finetune = CN()
+ cfg.trainer.finetune.before_eval = False
+ cfg.trainer.finetune.steps = 1
+ cfg.trainer.finetune.epochs = 1
+ cfg.trainer.finetune.lr = 0.01
+ cfg.trainer.finetune.freeze_param = "" # parameters frozen in fine-tuning stage
+ # cfg.trainer.finetune.only_psn = True
+
+ # ------------------------------------------------------------------------ #
+ # Optimizer related options
+ # ------------------------------------------------------------------------ #
+ cfg.optimizer = CN(new_allowed=True)
+
+ cfg.optimizer.type = 'SGD'
+ cfg.optimizer.lr = 0.1
+
+ # ------------------------------------------------------------------------ #
+ # Gradient related options
+ # ------------------------------------------------------------------------ #
+ cfg.grad = CN()
+ cfg.grad.grad_clip = -1.0 # negative numbers indicate we do not clip grad
+
+ # ------------------------------------------------------------------------ #
+ # lr_scheduler related options
+ # ------------------------------------------------------------------------ #
+ # cfg.lr_scheduler = CN()
+ # cfg.lr_scheduler.type = 'StepLR'
+ # cfg.lr_scheduler.schlr_params = dict()
+
+ # ------------------------------------------------------------------------ #
+ # Early stopping related options
+ # ------------------------------------------------------------------------ #
+ cfg.early_stop = CN()
+
+ # patience (int): How long to wait after last time the monitored metric improved.
+ # Note that the actual_checking_round = patience * cfg.eval.freq
+    # To disable early stopping, set early_stop.patience to an integer <= 0
+ cfg.early_stop.patience = 0
+ # delta (float): Minimum change in the monitored metric to indicate an improvement.
+ cfg.early_stop.delta = 0.0
+    # Early stop when there is no improvement within the last `patience` rounds; chosen from ['mean', 'best']
+ cfg.early_stop.improve_indicator_mode = 'best'
+ cfg.early_stop.the_smaller_the_better = True
+
+ # --------------- register corresponding check function ----------
+ cfg.register_cfg_check_fun(assert_training_cfg)
+
+
+def assert_training_cfg(cfg):
+ if cfg.backend not in ['torch', 'tensorflow']:
+ raise ValueError(
+ "Value of 'cfg.backend' must be chosen from ['torch', 'tensorflow']."
+ )
+    if cfg.backend == 'tensorflow' and cfg.federate.mode == 'standalone':
+        raise ValueError(
+            "We only support running in distributed mode when the backend is tensorflow"
+        )
+    if cfg.backend == 'tensorflow' and cfg.use_gpu is True:
+        raise ValueError(
+            "We only support running on CPU when the backend is tensorflow")
+
+    if cfg.trainer.finetune.before_eval is False and cfg.trainer.finetune.epochs <= 0:
+        raise ValueError(
+            f"When adopting fine-tuning, please set a valid number of local fine-tuning epochs, got {cfg.trainer.finetune.epochs}"
+        )
+
+ # if cfg.trainer.finetune.before_eval is False and cfg.trainer.finetune.steps <= 0:
+ # raise ValueError(
+ # f"When adopting fine-tuning, please set a valid local fine-tune steps, got {cfg.trainer.finetune.steps}"
+ # )
+
+
+register_config("fl_training", extend_training_cfg)
diff --git a/federatedscope/core/configs/config.py b/federatedscope/core/configs/config.py
new file mode 100644
index 000000000..4e05fc7af
--- /dev/null
+++ b/federatedscope/core/configs/config.py
@@ -0,0 +1,206 @@
+import copy
+import logging
+import os
+
+from yacs.config import CfgNode
+from yacs.config import _assert_with_logging
+from yacs.config import _check_and_coerce_cfg_value_type
+
+import federatedscope.register as register
+
+logger = logging.getLogger(__name__)
+
+
+class CN(CfgNode):
+ """
+ An extended configuration system based on [yacs](https://github.com/rbgirshick/yacs).
+ The two-level tree structure consists of several internal dict-like containers to allow simple key-value access and management.
+
+ """
+ def __init__(self, init_dict=None, key_list=None, new_allowed=False):
+ super().__init__(init_dict, key_list, new_allowed)
+ self.__dict__["cfg_check_funcs"] = list(
+ ) # to check the config values validity
+
+ def __getattr__(self, name):
+ if name in self:
+ return self[name]
+ else:
+ raise AttributeError(name)
+
+ def register_cfg_check_fun(self, cfg_check_fun):
+ self.cfg_check_funcs.append(cfg_check_fun)
+
+ def merge_from_file(self, cfg_filename):
+ """
+        load configs from a yaml file
+
+        :param cfg_filename (string): path to the yaml file
+        :return:
+ """
+ super(CN, self).merge_from_file(cfg_filename)
+ self.assert_cfg()
+
+ def merge_from_other_cfg(self, cfg_other):
+ """
+ load configs from another cfg instance
+
+ :param cfg_other (CN):
+ :return:
+ """
+ super(CN, self).merge_from_other_cfg(cfg_other)
+ self.assert_cfg()
+
+ def merge_from_list(self, cfg_list):
+ """
+        load configs from a list that stores the keys and values.
+        modified from `merge_from_list` in `yacs.config.py` to allow adding new keys if `is_new_allowed()` returns True
+
+ :param cfg_list (list):
+ :return:
+ """
+ _assert_with_logging(
+ len(cfg_list) % 2 == 0,
+ "Override list has odd length: {}; it must be a list of pairs".
+ format(cfg_list),
+ )
+ root = self
+ for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
+ if root.key_is_deprecated(full_key):
+ continue
+ if root.key_is_renamed(full_key):
+ root.raise_key_rename_error(full_key)
+ key_list = full_key.split(".")
+ d = self
+ for subkey in key_list[:-1]:
+ _assert_with_logging(subkey in d,
+ "Non-existent key: {}".format(full_key))
+ d = d[subkey]
+ subkey = key_list[-1]
+ _assert_with_logging(subkey in d or d.is_new_allowed(),
+ "Non-existent key: {}".format(full_key))
+ value = self._decode_cfg_value(v)
+ if subkey in d:
+ value = _check_and_coerce_cfg_value_type(
+ value, d[subkey], subkey, full_key)
+ d[subkey] = value
+
+ self.assert_cfg()
+
+ def assert_cfg(self):
+ """
+        check the validity of the configuration instance
+
+ :return:
+ """
+ for check_func in self.cfg_check_funcs:
+ check_func(self)
+
+ def clean_unused_sub_cfgs(self):
+ """
+        Clean the un-used secondary-level CfgNode, whose `.use` attribute is `False`
+
+ :return:
+ """
+ for v in self.values():
+ if isinstance(v, CfgNode) or isinstance(v, CN):
+ # sub-config
+ if hasattr(v, "use") and v.use is False:
+ for k in copy.deepcopy(v).keys():
+ # delete the un-used attributes
+ if k == "use":
+ continue
+ else:
+ del v[k]
+
+ def freeze(self, inform=True):
+ """
+ 1) make the cfg attributes immutable;
+        2) save the frozen cfg into "self.outdir/config.yaml" for better reproducibility;
+ 3) if self.wandb.use=True, update the frozen config
+
+ :return:
+ """
+ self.assert_cfg()
+ self.clean_unused_sub_cfgs()
+ # save the final cfg
+ with open(os.path.join(self.outdir, "config.yaml"), 'w') as outfile:
+ from contextlib import redirect_stdout
+ with redirect_stdout(outfile):
+ tmp_cfg = copy.deepcopy(self)
+ tmp_cfg.cfg_check_funcs.clear()
+ print(tmp_cfg.dump())
+ if self.wandb.use:
+ # update the frozen config
+ try:
+ import wandb
+ except ImportError:
+                logger.error(
+                    "cfg.wandb.use=True but the wandb package is not installed")
+ exit()
+
+ import yaml
+ cfg_yaml = yaml.safe_load(tmp_cfg.dump())
+ wandb.config.update(cfg_yaml, allow_val_change=True)
+
+ if inform:
+ logger.info("the used configs are: \n" + str(tmp_cfg))
+
+ super(CN, self).freeze()
+
+
+# to ensure the sub-configs are registered before setting up the global config
+from federatedscope.core.configs import all_sub_configs
+for sub_config in all_sub_configs:
+ __import__("federatedscope.core.configs." + sub_config)
+
+from federatedscope.contrib.configs import all_sub_configs_contrib
+for sub_config in all_sub_configs_contrib:
+ __import__("federatedscope.contrib.configs." + sub_config)
+
+# Global config object
+global_cfg = CN()
+
+
+def init_global_cfg(cfg):
+ r'''
+    This function sets the default config values.
+    1) Note that for an experiment, only part of the arguments will be used;
+    the remaining unused arguments won't affect anything.
+    So feel free to register any argument in federatedscope.contrib.configs
+    2) We support *at most* two levels of configs, e.g., cfg.dataset.name
+
+    :return: configuration used by the experiment.
+ '''
+
+ # ------------------------------------------------------------------------ #
+ # Basic options, first level configs
+ # ------------------------------------------------------------------------ #
+
+ cfg.backend = 'torch'
+
+ # Whether to use GPU
+ cfg.use_gpu = False
+
+ # Whether to print verbose logging info
+ cfg.verbose = 1
+
+ # Specify the device
+ cfg.device = -1
+
+ # Random seed
+ cfg.seed = 0
+
+ # Path of configuration file
+ cfg.cfg_file = ''
+
+    # The dir used to save logs, exp_config, models, etc.
+ cfg.outdir = 'exp'
+ cfg.expname = '' # detailed exp name to distinguish different sub-exp
+
+ # extend user customized configs
+ for func in register.config_dict.values():
+ func(cfg)
+
+
+init_global_cfg(global_cfg)
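+
+# A minimal usage sketch (editor's illustration; the option values are
+# arbitrary):
+#
+#   cfg = global_cfg.clone()
+#   cfg.merge_from_list(['use_gpu', 'True', 'seed', '123'])
+#   cfg.freeze()  # runs the registered check functions and dumps config.yaml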
diff --git a/federatedscope/core/configs/constants.py b/federatedscope/core/configs/constants.py
new file mode 100644
index 000000000..488384008
--- /dev/null
+++ b/federatedscope/core/configs/constants.py
@@ -0,0 +1,36 @@
+"""Configuration file for composition of different aggregators, messages, handlers, etc.
+
+ - The method `local` indicates that the clients only locally train their model without sharing any training related information
+ - The method `global` indicates that only one client locally trains using all the data
+
+"""
+
+AGGREGATOR_TYPE = {
+ "local": "no_communication", # the clients locally train their model without sharing any training related info
+ "global": "no_communication", # only one client locally train all data, i.e., totally global training
+ "fedavg": "clients_avg", # FedAvg
+ "pfedme": "server_clients_interpolation", # pFedMe, + server-clients interpolation
+ "ditto": "clients_avg", # Ditto
+ "fedsageplus": "clients_avg",
+ "gcflplus": "clients_avg",
+ "fedopt": "fedopt"
+}
+
+CLIENTS_TYPE = {
+ "local": "normal",
+ "fedavg": "normal", # FedAvg
+ "pfedme": "normal_loss_regular", # pFedMe, + regularization-based local loss
+ "ditto": "normal", # Ditto, + local training for distinct personalized models
+ "fedsageplus": "fedsageplus", # FedSage+ for graph data
+ "gcflplus": "gcflplus", # GCFL+ for graph data
+ "gradascent": "gradascent"
+}
+
+SERVER_TYPE = {
+ "local": "normal",
+ "fedavg": "normal", # FedAvg
+ "pfedme": "normal", # pFedMe, + regularization-based local loss
+ "ditto": "normal", # Ditto, + local training for distinct personalized models
+ "fedsageplus": "fedsageplus", # FedSage+ for graph data
+ "gcflplus": "gcflplus", # GCFL+ for graph data
+}
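+
+# A lookup sketch (editor's illustration; whether `.get` defaults are used
+# elsewhere in the codebase is an assumption):
+#
+#   method = "pFedMe".lower()
+#   AGGREGATOR_TYPE.get(method, "clients_avg")  # -> "server_clients_interpolation"
+#   CLIENTS_TYPE.get(method, "normal")          # -> "normal_loss_regular"
+#   SERVER_TYPE.get(method, "normal")           # -> "normal"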
diff --git a/federatedscope/core/fed_runner.py b/federatedscope/core/fed_runner.py
new file mode 100644
index 000000000..0ce955d7b
--- /dev/null
+++ b/federatedscope/core/fed_runner.py
@@ -0,0 +1,275 @@
+import logging
+
+from collections import deque
+
+from federatedscope.core.worker import Server, Client
+from federatedscope.core.gpu_manager import GPUManager
+from federatedscope.core.auxiliaries.model_builder import get_model
+
+logger = logging.getLogger(__name__)
+
+
+class FedRunner(object):
+ """
+    This class is used to construct an FL course, which sets up the server/clients and runs the course.
+
+ Arguments:
+ data: The data used in the FL courses, which are formatted as {'ID':data} for standalone mode. More details can be found in federatedscope.core.auxiliaries.data_builder .
+ server_class: The server class is used for instantiating a (customized) server.
+ client_class: The client class is used for instantiating a (customized) client.
+ config: The configurations of the FL course.
+ client_config: The clients' configurations.
+ """
+ def __init__(self,
+ data,
+ server_class=Server,
+ client_class=Client,
+ config=None,
+ client_config=None):
+ self.data = data
+ self.server_class = server_class
+ self.client_class = client_class
+ self.cfg = config
+ self.client_cfg = client_config
+
+ self.mode = self.cfg.federate.mode.lower()
+ self.gpu_manager = GPUManager(gpu_available=self.cfg.use_gpu,
+ specified_device=self.cfg.device)
+
+ if self.mode == 'standalone':
+ self.shared_comm_queue = deque()
+ self._setup_for_standalone()
+            # in standalone mode, by default, we print the trainer info only once for better log readability
+ trainer_representative = self.client[1].trainer
+ if trainer_representative is not None:
+ trainer_representative.print_trainer_meta_info()
+ elif self.mode == 'distributed':
+ self._setup_for_distributed()
+
+ def _setup_for_standalone(self):
+ """
+ To set up server and client for standalone mode.
+ """
+
+ if self.cfg.backend == 'torch':
+ import torch
+ torch.set_num_threads(1)
+
+ self.server = self._setup_server()
+
+ self.client = dict()
+        assert self.cfg.federate.client_num != 0, \
+            "In standalone mode, self.cfg.federate.client_num should be non-zero. " \
+            "This is usually caused by using synthetic data while not specifying a non-zero value for client_num"
+
+ # assume the client-wise data are consistent in their input&output shape
+ self._shared_client_model = get_model(
+ self.cfg.model, self.data[1], backend=self.cfg.backend
+ ) if self.cfg.federate.share_local_model else None
+
+ if self.cfg.federate.method == "global":
+            assert 0 in self.data and self.data[
+                0] is not None, "In global training mode, we will use a proxy client to hold all the data. Please put the whole dataset in data[0], i.e., the same style as the global evaluation mode"
+ from federatedscope.core.auxiliaries.data_builder import merge_data
+ self.data[1] = merge_data(all_data=self.data)
+
+ for client_id in range(1, self.cfg.federate.client_num + 1):
+ self.client[client_id] = self._setup_client(
+ client_id=client_id, client_model=self._shared_client_model)
+
+ def _setup_for_distributed(self):
+ """
+ To set up server or client for distributed mode.
+ """
+ self.server_address = {
+ 'host': self.cfg.distribute.server_host,
+ 'port': self.cfg.distribute.server_port
+ }
+ if self.cfg.distribute.role == 'server':
+ self.server = self._setup_server()
+ elif self.cfg.distribute.role == 'client':
+            # When we set up the client in the distributed mode, we assume the server has been set up and numbered as #0
+ self.client_address = {
+ 'host': self.cfg.distribute.client_host,
+ 'port': self.cfg.distribute.client_port
+ }
+ self.client = self._setup_client()
+
+ def run(self):
+ """
+        To run an FL course, which is called after the server/clients have been set up.
+        For the standalone mode, a shared message queue will be set up to simulate ``receiving messages``.
+ """
+ if self.mode == 'standalone':
+ # trigger the FL course
+ for each_client in self.client:
+ self.client[each_client].join_in()
+
+ if self.cfg.federate.online_aggr:
+                # any broadcast operation would be executed client-by-client to avoid the simultaneous existence of #clients messages.
+ # currently, only consider centralized topology
+ def is_broadcast(msg):
+ return len(msg.receiver) >= 1 and msg.sender == 0
+
+ cached_bc_msgs = []
+ cur_idx = 0
+ while True:
+ if len(self.shared_comm_queue) > 0:
+ msg = self.shared_comm_queue.popleft()
+ if is_broadcast(msg):
+ cached_bc_msgs.append(msg)
+ # assume there is at least one client
+ msg = cached_bc_msgs[0]
+ self._handle_msg(msg, rcv=msg.receiver[cur_idx])
+ cur_idx += 1
+ if cur_idx >= len(msg.receiver):
+ del cached_bc_msgs[0]
+ cur_idx = 0
+ else:
+ self._handle_msg(msg)
+ elif len(cached_bc_msgs) > 0:
+ msg = cached_bc_msgs[0]
+ self._handle_msg(msg, rcv=msg.receiver[cur_idx])
+ cur_idx += 1
+ if cur_idx >= len(msg.receiver):
+ del cached_bc_msgs[0]
+ cur_idx = 0
+ else:
+ # finished
+ break
+
+ else:
+ while len(self.shared_comm_queue) > 0:
+ msg = self.shared_comm_queue.popleft()
+ self._handle_msg(msg)
+
+ self.server._monitor.finish_fed_runner(fl_mode=self.mode)
+
+ return self.server.best_results
+
+ elif self.mode == 'distributed':
+ if self.cfg.distribute.role == 'server':
+ self.server.run()
+ return self.server.best_results
+ elif self.cfg.distribute.role == 'client':
+ self.client.join_in()
+ self.client.run()
+
+ def _setup_server(self):
+ """
+ Set up the server
+ """
+ self.server_id = 0
+ if self.mode == 'standalone':
+ if self.server_id in self.data:
+ server_data = self.data[self.server_id]
+ model = get_model(self.cfg.model,
+ server_data,
+ backend=self.cfg.backend)
+ else:
+ server_data = None
+ model = get_model(
+ self.cfg.model, self.data[1], backend=self.cfg.backend
+ ) # get the model according to client's data if the server does not own data
+ kw = {'shared_comm_queue': self.shared_comm_queue}
+ elif self.mode == 'distributed':
+ server_data = self.data
+ model = get_model(self.cfg.model,
+ server_data,
+ backend=self.cfg.backend)
+ kw = self.server_address
+ else:
+            raise ValueError('Mode {} is not supported'.format(self.mode))
+
+ if self.server_class:
+ self._server_device = self.gpu_manager.auto_choice()
+ server = self.server_class(
+ ID=self.server_id,
+ config=self.cfg,
+ data=server_data,
+ model=model,
+ client_num=self.cfg.federate.client_num,
+ total_round_num=self.cfg.federate.total_round_num,
+ device=self._server_device,
+ **kw)
+
+ if self.cfg.nbafl.use:
+ from federatedscope.core.trainers.trainer_nbafl import wrap_nbafl_server
+ wrap_nbafl_server(server)
+
+ else:
+ raise ValueError
+
+ logger.info('Server #{:d} has been set up ... '.format(self.server_id))
+
+ return server
+
+ def _setup_client(self, client_id=-1, client_model=None):
+ """
+ Set up the client
+ """
+ self.server_id = 0
+ if self.mode == 'standalone':
+ client_data = self.data[client_id]
+ kw = {'shared_comm_queue': self.shared_comm_queue}
+ elif self.mode == 'distributed':
+ client_data = self.data
+ kw = self.client_address
+ kw['server_host'] = self.server_address['host']
+ kw['server_port'] = self.server_address['port']
+ else:
+            raise ValueError('Mode {} is not supported'.format(self.mode))
+
+ if self.client_class:
+ client_specific_config = self.cfg.clone()
+ if self.client_cfg:
+ client_specific_config.defrost()
+ client_specific_config.merge_from_other_cfg(
+ self.client_cfg.get('client_{}'.format(client_id)))
+ client_specific_config.freeze()
+ client_device = self._server_device if self.cfg.federate.share_local_model else self.gpu_manager.auto_choice(
+ )
+ client = self.client_class(
+ ID=client_id,
+ server_id=self.server_id,
+ config=client_specific_config,
+ data=client_data,
+ model=client_model or get_model(client_specific_config.model,
+ client_data,
+ backend=self.cfg.backend),
+ device=client_device,
+ **kw)
+ else:
+ raise ValueError
+
+ if client_id == -1:
+ logger.info('Client (address {}:{}) has been set up ... '.format(
+ self.client_address['host'], self.client_address['port']))
+ else:
+ logger.info(f'Client {client_id} has been set up ... ')
+
+ return client
+
+ def _handle_msg(self, msg, rcv=-1):
+ """
+ To simulate the message handling process (used only for the standalone mode)
+ """
+ if rcv != -1:
+ # simulate broadcast one-by-one
+ self.client[rcv].msg_handlers[msg.msg_type](msg)
+ return
+
+ sender, receiver = msg.sender, msg.receiver
+ download_bytes, upload_bytes = msg.count_bytes()
+ if not isinstance(receiver, list):
+ receiver = [receiver]
+ for each_receiver in receiver:
+ if each_receiver == 0:
+ self.server.msg_handlers[msg.msg_type](msg)
+ self.server._monitor.track_download_bytes(download_bytes)
+ else:
+ self.client[each_receiver].msg_handlers[msg.msg_type](msg)
+ self.client[each_receiver]._monitor.track_download_bytes(
+ download_bytes)
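+
+# A minimal standalone usage sketch (editor's illustration; `get_data` is
+# assumed to return `(data, modified_cfg)` as in
+# federatedscope.core.auxiliaries.data_builder):
+#
+#   from federatedscope.core.auxiliaries.data_builder import get_data
+#   from federatedscope.core.configs.config import global_cfg
+#
+#   cfg = global_cfg.clone()
+#   data, cfg = get_data(config=cfg)
+#   runner = FedRunner(data=data, config=cfg)
+#   best_results = runner.run()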
diff --git a/federatedscope/core/gRPC_server.py b/federatedscope/core/gRPC_server.py
new file mode 100644
index 000000000..9cf6c27c0
--- /dev/null
+++ b/federatedscope/core/gRPC_server.py
@@ -0,0 +1,20 @@
+from collections import deque
+
+from federatedscope.core.proto import gRPC_comm_manager_pb2, gRPC_comm_manager_pb2_grpc
+
+
+class gRPCComServeFunc(gRPC_comm_manager_pb2_grpc.gRPCComServeFuncServicer):
+ def __init__(self):
+ self.msg_queue = deque()
+
+ def sendMessage(self, request, context):
+ self.msg_queue.append(request)
+
+ return gRPC_comm_manager_pb2.MessageResponse(msg='ACK')
+
+ def receive(self):
+ while len(self.msg_queue) == 0:
+ continue
+ msg = self.msg_queue.popleft()
+ return msg
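+
+# A sketch of attaching this servicer to a grpc server (editor's
+# illustration; `add_gRPCComServeFuncServicer_to_server` follows the
+# standard name generated by the grpc codegen for this service):
+#
+#   import grpc
+#   from concurrent import futures
+#
+#   servicer = gRPCComServeFunc()
+#   server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+#   gRPC_comm_manager_pb2_grpc.add_gRPCComServeFuncServicer_to_server(
+#       servicer, server)
+#   server.add_insecure_port('0.0.0.0:50050')
+#   server.start()
+#   msg = servicer.receive()  # note: receive() busy-waits until a message arrives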
diff --git a/federatedscope/core/gpu_manager.py b/federatedscope/core/gpu_manager.py
new file mode 100644
index 000000000..6deb06e50
--- /dev/null
+++ b/federatedscope/core/gpu_manager.py
@@ -0,0 +1,88 @@
+import os
+
+
+def check_gpus():
+    if 'NVIDIA System Management' not in os.popen('nvidia-smi -h').read():
+ print("'nvidia-smi' tool not found.")
+ return False
+ return True
+
+
+class GPUManager():
+ """
+    To automatically allocate a gpu: returns the gpu with the largest free-memory ratio, unless specified_device has been set.
+    When no gpu is available, returns 'cpu';
+    The implementation of GPUManager refers to https://github.com/QuantumLiu/tf_gpu_manager
+ """
+ def __init__(self, gpu_available=False, specified_device=-1):
+        self.gpu_available = gpu_available and check_gpus()
+        self.specified_device = specified_device
+        if self.gpu_available:
+ self.gpus = self._query_gpus()
+ for gpu in self.gpus:
+ gpu['allocated'] = False
+ else:
+ self.gpus = None
+
+ def _sort_by_memory(self, gpus, by_size=False):
+ if by_size:
+ return sorted(gpus, key=lambda d: d['memory.free'], reverse=True)
+ else:
+ print('Sorted by free memory rate')
+ return sorted(
+ gpus,
+ key=lambda d: float(d['memory.free']) / d['memory.total'],
+ reverse=True)
+
+ def _query_gpus(self):
+ args = ['index', 'gpu_name', 'memory.free', 'memory.total']
+ cmd = 'nvidia-smi --query-gpu={} --format=csv,noheader'.format(
+ ','.join(args))
+ results = os.popen(cmd).readlines()
+ return [self._parse(line, args) for line in results]
+
+ def _parse(self, line, args):
+        numeric_args = ['memory.free', 'memory.total']
+        to_numeric = lambda v: float(v.upper().strip().replace('MIB', '').
+                                     replace('W', ''))
+        process = lambda k, v: (int(to_numeric(v))
+                                if k in numeric_args else v.strip())
+ return {
+ k: process(k, v)
+ for k, v in zip(args,
+ line.strip().split(','))
+ }
+
+ def auto_choice(self):
+ """
+ To allocate a device
+ """
+        if self.gpus is None:
+ return 'cpu'
+ elif self.specified_device >= 0:
+ # allow users to specify the device
+ return 'cuda:{}'.format(self.specified_device)
+ else:
+ for old_infos, new_infos in zip(self.gpus, self._query_gpus()):
+ old_infos.update(new_infos)
+ unallocated_gpus = [
+ gpu for gpu in self.gpus if not gpu['allocated']
+ ]
+ if len(unallocated_gpus) == 0:
+ # reset when all gpus have been allocated
+ unallocated_gpus = self.gpus
+ for gpu in self.gpus:
+ gpu['allocated'] = False
+
+ chosen_gpu = self._sort_by_memory(unallocated_gpus, True)[0]
+ chosen_gpu['allocated'] = True
+ index = chosen_gpu['index']
+ return 'cuda:{:s}'.format(index)
+
+
+# for testing
+if __name__ == '__main__':
+
+ gpu_manager = GPUManager(gpu_available=True, specified_device=0)
+ for i in range(20):
+ print(gpu_manager.auto_choice())
diff --git a/federatedscope/core/lr.py b/federatedscope/core/lr.py
new file mode 100644
index 000000000..16e846218
--- /dev/null
+++ b/federatedscope/core/lr.py
@@ -0,0 +1,10 @@
+import torch
+
+
+class LogisticRegression(torch.nn.Module):
+ def __init__(self, in_channels, class_num, use_bias=True):
+ super(LogisticRegression, self).__init__()
+ self.fc = torch.nn.Linear(in_channels, class_num, bias=use_bias)
+
+ def forward(self, x):
+ return self.fc(x)
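+
+
+# for testing (editor's smoke test)
+if __name__ == '__main__':
+    model = LogisticRegression(in_channels=10, class_num=2)
+    logits = model(torch.randn(4, 10))
+    print(logits.shape)  # torch.Size([4, 2])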
diff --git a/federatedscope/core/message.py b/federatedscope/core/message.py
new file mode 100644
index 000000000..a7dab00d8
--- /dev/null
+++ b/federatedscope/core/message.py
@@ -0,0 +1,216 @@
+import sys
+import json
+import numpy as np
+from federatedscope.core.proto import gRPC_comm_manager_pb2
+
+
+class Message(object):
+ """
+ The data exchanged during an FL course are abstracted as 'Message' in FederatedScope.
+ A message object includes:
+ msg_type: The type of message, which is used to trigger the corresponding handlers of server/client
+ sender: The sender's ID
+ receiver: The receiver's ID
+ state: The training round of the message, which is determined by the sender and used to filter out the outdated messages.
+ strategy: redundant attribute
+ """
+ def __init__(self,
+ msg_type=None,
+ sender=0,
+ receiver=0,
+ state=0,
+ content=None,
+ strategy=None):
+ self._msg_type = msg_type
+ self._sender = sender
+ self._receiver = receiver
+ self._state = state
+ self._content = content
+ self._strategy = strategy
+
+ @property
+ def msg_type(self):
+ return self._msg_type
+
+ @msg_type.setter
+ def msg_type(self, value):
+ self._msg_type = value
+
+ @property
+ def sender(self):
+ return self._sender
+
+ @sender.setter
+ def sender(self, value):
+ self._sender = value
+
+ @property
+ def receiver(self):
+ return self._receiver
+
+ @receiver.setter
+ def receiver(self, value):
+ self._receiver = value
+
+ @property
+ def state(self):
+ return self._state
+
+ @state.setter
+ def state(self, value):
+ self._state = value
+
+ @property
+ def content(self):
+ return self._content
+
+ @content.setter
+ def content(self, value):
+ self._content = value
+
+ @property
+ def strategy(self):
+ return self._strategy
+
+ @strategy.setter
+ def strategy(self, value):
+ self._strategy = value
+
+ def transform_to_list(self, x):
+ if isinstance(x, list) or isinstance(x, tuple):
+ return [self.transform_to_list(each_x) for each_x in x]
+ elif isinstance(x, dict):
+ for key in x.keys():
+ x[key] = self.transform_to_list(x[key])
+ return x
+ else:
+ if hasattr(x, 'tolist'):
+ return x.tolist()
+ else:
+ return x
+
+ def msg_to_json(self, to_list=False):
+ if to_list:
+ self.content = self.transform_to_list(self.content)
+
+ json_msg = {
+ 'msg_type': self.msg_type,
+ 'sender': self.sender,
+ 'receiver': self.receiver,
+ 'state': self.state,
+ 'content': self.content,
+ 'strategy': self.strategy,
+ }
+ return json.dumps(json_msg)
+
+ def json_to_msg(self, json_string):
+ json_msg = json.loads(json_string)
+ self.msg_type = json_msg['msg_type']
+ self.sender = json_msg['sender']
+ self.receiver = json_msg['receiver']
+ self.state = json_msg['state']
+ self.content = json_msg['content']
+ self.strategy = json_msg['strategy']
+
+ def create_by_type(self, value, nested=False):
+ if isinstance(value, dict):
+ m_dict = gRPC_comm_manager_pb2.mDict()
+ for key in value.keys():
+ m_dict.dict_value[key].MergeFrom(
+ self.create_by_type(value[key], nested=True))
+ if nested:
+ msg_value = gRPC_comm_manager_pb2.MsgValue()
+ msg_value.dict_msg.MergeFrom(m_dict)
+ return msg_value
+ else:
+ return m_dict
+ elif isinstance(value, list) or isinstance(value, tuple):
+ m_list = gRPC_comm_manager_pb2.mList()
+ for each in value:
+ m_list.list_value.append(self.create_by_type(each,
+ nested=True))
+ if nested:
+ msg_value = gRPC_comm_manager_pb2.MsgValue()
+ msg_value.list_msg.MergeFrom(m_list)
+ return msg_value
+ else:
+ return m_list
+ else:
+ m_single = gRPC_comm_manager_pb2.mSingle()
+ if type(value) in [int, np.int32]:
+ m_single.int_value = value
+ elif type(value) in [str]:
+ m_single.str_value = value
+ elif type(value) in [float, np.float32]:
+ m_single.float_value = value
+ else:
+ raise ValueError(
+ 'The data type {} has not been supported.'.format(
+ type(value)))
+
+ if nested:
+ msg_value = gRPC_comm_manager_pb2.MsgValue()
+ msg_value.single_msg.MergeFrom(m_single)
+ return msg_value
+ else:
+ return m_single
+
+ def build_msg_value(self, value):
+ msg_value = gRPC_comm_manager_pb2.MsgValue()
+
+ if isinstance(value, list) or isinstance(value, tuple):
+ msg_value.list_msg.MergeFrom(self.create_by_type(value))
+ elif isinstance(value, dict):
+ msg_value.dict_msg.MergeFrom(self.create_by_type(value))
+ else:
+ msg_value.single_msg.MergeFrom(self.create_by_type(value))
+
+ return msg_value
+
+ def transform(self, to_list=False):
+ if to_list:
+ self.content = self.transform_to_list(self.content)
+
+        split_msg = gRPC_comm_manager_pb2.MessageRequest() # map/dict
+        split_msg.msg['sender'].MergeFrom(self.build_msg_value(self.sender))
+        split_msg.msg['receiver'].MergeFrom(
+            self.build_msg_value(self.receiver))
+        split_msg.msg['state'].MergeFrom(self.build_msg_value(self.state))
+        split_msg.msg['msg_type'].MergeFrom(
+            self.build_msg_value(self.msg_type))
+        split_msg.msg['content'].MergeFrom(self.build_msg_value(
+            self.content))
+        return split_msg
+
+ def _parse_msg(self, value):
+ if isinstance(value, gRPC_comm_manager_pb2.MsgValue) or isinstance(
+ value, gRPC_comm_manager_pb2.mSingle):
+ return self._parse_msg(getattr(value, value.WhichOneof("type")))
+ elif isinstance(value, gRPC_comm_manager_pb2.mList):
+ return [self._parse_msg(each) for each in value.list_value]
+ elif isinstance(value, gRPC_comm_manager_pb2.mDict):
+ return {
+ k: self._parse_msg(value.dict_value[k])
+ for k in value.dict_value
+ }
+ else:
+ return value
+
+ def parse(self, received_msg):
+ self.sender = self._parse_msg(received_msg['sender'])
+ self.receiver = self._parse_msg(received_msg['receiver'])
+ self.msg_type = self._parse_msg(received_msg['msg_type'])
+ self.state = self._parse_msg(received_msg['state'])
+ self.content = self._parse_msg(received_msg['content'])
+
+ def count_bytes(self):
+ """
+ calculate the message bytes to be sent/received
+ :return: tuple of bytes of the message to be sent and received
+ """
+ from pympler import asizeof
+ download_bytes = asizeof.asizeof(self.content)
+ upload_cnt = len(self.receiver) if isinstance(self.receiver,
+ list) else 1
+ upload_bytes = download_bytes * upload_cnt
+ return download_bytes, upload_bytes
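+
+
+# for testing (editor's smoke test): a round-trip through the json codec
+if __name__ == '__main__':
+    msg = Message(msg_type='model_para',
+                  sender=0,
+                  receiver=[1, 2],
+                  state=3,
+                  content={'w': [0.1, 0.2]})
+    restored = Message()
+    restored.json_to_msg(msg.msg_to_json(to_list=True))
+    assert restored.receiver == [1, 2]
+    assert restored.content == {'w': [0.1, 0.2]}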
diff --git a/federatedscope/core/mlp.py b/federatedscope/core/mlp.py
new file mode 100644
index 000000000..a71b76e03
--- /dev/null
+++ b/federatedscope/core/mlp.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Linear, ModuleList
+from torch.nn import BatchNorm1d, Identity
+
+
+class MLP(torch.nn.Module):
+ """
+ Multilayer Perceptron
+ """
+ def __init__(self,
+ channel_list,
+ dropout=0.,
+ batch_norm=True,
+ relu_first=False):
+ super().__init__()
+ assert len(channel_list) >= 2
+ self.channel_list = channel_list
+ self.dropout = dropout
+ self.relu_first = relu_first
+
+ self.linears = ModuleList()
+ self.norms = ModuleList()
+ for in_channel, out_channel in zip(channel_list[:-1],
+ channel_list[1:]):
+ self.linears.append(Linear(in_channel, out_channel))
+ self.norms.append(
+ BatchNorm1d(out_channel) if batch_norm else Identity())
+
+ def forward(self, x):
+ x = self.linears[0](x)
+ for layer, norm in zip(self.linears[1:], self.norms[:-1]):
+ if self.relu_first:
+ x = F.relu(x)
+ x = norm(x)
+ if not self.relu_first:
+ x = F.relu(x)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = layer.forward(x)
+ return x
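+
+
+# for testing (editor's smoke test)
+if __name__ == '__main__':
+    mlp = MLP(channel_list=[16, 32, 4], dropout=0.1)
+    out = mlp(torch.randn(8, 16))
+    print(out.shape)  # torch.Size([8, 4])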
diff --git a/federatedscope/core/monitors/__init__.py b/federatedscope/core/monitors/__init__.py
new file mode 100644
index 000000000..3f945b5b0
--- /dev/null
+++ b/federatedscope/core/monitors/__init__.py
@@ -0,0 +1,5 @@
+from federatedscope.core.monitors.early_stopper import EarlyStopper
+from federatedscope.core.monitors.metric_calculator import MetricCalculator
+from federatedscope.core.monitors.monitor import Monitor
+
+__all__ = ['EarlyStopper', 'MetricCalculator', 'Monitor']
diff --git a/federatedscope/core/monitors/early_stopper.py b/federatedscope/core/monitors/early_stopper.py
new file mode 100644
index 000000000..474809e0b
--- /dev/null
+++ b/federatedscope/core/monitors/early_stopper.py
@@ -0,0 +1,97 @@
+import operator
+import numpy as np
+
+
+# TODO: make this as a sub-module of monitor class
+class EarlyStopper(object):
+ """
+    Track the history of a metric (e.g., validation loss), and
+    check whether the (training) process should stop if the metric doesn't improve after a given patience.
+ """
+ def __init__(self,
+ patience=5,
+ delta=0,
+ improve_indicator_mode='best',
+ the_smaller_the_better=True):
+ """
+ Args:
+ patience (int): How long to wait after last time the monitored metric improved.
+ Note that the actual_checking_round = patience * cfg.eval.freq
+ Default: 5
+ delta (float): Minimum change in the monitored metric to indicate an improvement.
+ Default: 0
+            improve_indicator_mode (str): Early stop when there is no improvement within the last `patience` rounds; in ['mean', 'best']
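+            the_smaller_the_better (bool): Whether a smaller metric value indicates improvement.
+                Default: True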
+ """
+        assert 0 <= patience == int(
+            patience
+        ), "Please use a non-negative integer to indicate the patience"
+        assert delta >= 0, "Please use a non-negative value to indicate the change"
+        assert improve_indicator_mode in [
+            'mean', 'best'
+        ], "Please make sure `improve_indicator_mode` is 'mean' or 'best'"
+
+ self.patience = patience
+ self.counter_no_improve = 0
+ self.best_metric = None
+ self.early_stopped = False
+ self.the_smaller_the_better = the_smaller_the_better
+ self.delta = delta
+ self.improve_indicator_mode = improve_indicator_mode
+ # For expansion usages of comparisons
+ self.comparator = operator.lt
+ self.improvement_operator = operator.add
+
+ def track_and_check_dummy(self, new_result):
+ self.early_stopped = False
+ return self.early_stopped
+
+ def track_and_check_best(self, history_result):
+ new_result = history_result[-1]
+ if self.best_metric is None:
+ self.best_metric = new_result
+ elif self.the_smaller_the_better and self.comparator(
+ self.improvement_operator(self.best_metric, -self.delta),
+ new_result):
+ # by default: add(val_loss, -delta) < new_result
+ self.counter_no_improve += 1
+ elif not self.the_smaller_the_better and self.comparator(
+ self.improvement_operator(self.best_metric, self.delta),
+ new_result):
+ # typical case: add(eval_score, delta) > new_result
+ self.counter_no_improve += 1
+ else:
+ self.best_metric = new_result
+ self.counter_no_improve = 0
+
+ self.early_stopped = self.counter_no_improve >= self.patience
+ return self.early_stopped
+
+ def track_and_check_mean(self, history_result):
+ new_result = history_result[-1]
+ if len(history_result) > self.patience:
+ if self.the_smaller_the_better and self.comparator(
+ self.improvement_operator(
+ np.mean(history_result[-self.patience - 1:-1]),
+ -self.delta), new_result):
+ self.early_stopped = True
+ elif not self.the_smaller_the_better and self.comparator(
+ self.improvement_operator(
+ np.mean(history_result[-self.patience - 1:-1]),
+ self.delta), new_result):
+ self.early_stopped = True
+ else:
+ self.early_stopped = False
+
+ return self.early_stopped
+
+ def track_and_check(self, new_result):
+
+ track_method = self.track_and_check_dummy # do nothing
+ if self.patience == 0:
+ track_method = self.track_and_check_dummy
+ elif self.improve_indicator_mode == 'best':
+ track_method = self.track_and_check_best
+ elif self.improve_indicator_mode == 'mean':
+ track_method = self.track_and_check_mean
+
+ return track_method(new_result)
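+
+
+# for testing (editor's smoke test): with patience=2 and mode 'best', the
+# stopper fires after two consecutive results worse than the best so far
+if __name__ == '__main__':
+    stopper = EarlyStopper(patience=2, improve_indicator_mode='best')
+    history = []
+    for loss in [1.0, 0.8, 0.9, 0.95, 0.97]:
+        history.append(loss)
+        if stopper.track_and_check(history):
+            print('early stopped with history {}'.format(history))
+            break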
diff --git a/federatedscope/core/monitors/metric_calculator.py b/federatedscope/core/monitors/metric_calculator.py
new file mode 100644
index 000000000..41610e11e
--- /dev/null
+++ b/federatedscope/core/monitors/metric_calculator.py
@@ -0,0 +1,206 @@
+import logging
+from typing import Optional, Union, List, Set
+
+import numpy as np
+from scipy.special import softmax
+from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
+
+from federatedscope.core.auxiliaries.metric_builder import get_metric
+
+# Blind torch
+try:
+ import torch
+except ImportError:
+ torch = None
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: make this as a sub-module of monitor class
+class MetricCalculator(object):
+ def __init__(self, eval_metric: Union[Set[str], List[str], str]):
+
+ # Add personalized metrics
+ if isinstance(eval_metric, str):
+ eval_metric = {eval_metric}
+ elif isinstance(eval_metric, list):
+ eval_metric = set(eval_metric)
+
+ # Default metric is {'loss', 'avg_loss', 'total'}
+ self.eval_metric = self.get_metric_funcs(eval_metric)
+
+ def get_metric_funcs(self, eval_metric):
+ metric_buildin = {
+ metric: SUPPORT_METRICS[metric]
+ for metric in {'loss', 'avg_loss', 'total'} | eval_metric
+ if metric in SUPPORT_METRICS
+ }
+ metric_register = get_metric(eval_metric - set(SUPPORT_METRICS.keys()))
+ return {**metric_buildin, **metric_register}
+
+ def eval(self, ctx):
+ results = {}
+ y_true, y_pred, y_prob = self._check_and_parse(ctx)
+ for metric, func in self.eval_metric.items():
+ results["{}_{}".format(ctx.cur_data_split,
+ metric)] = func(ctx=ctx,
+ y_true=y_true,
+ y_pred=y_pred,
+ y_prob=y_prob,
+ metric=metric)
+ return results
+
+ def _check_and_parse(self, ctx):
+ if not '{}_y_true'.format(ctx.cur_data_split) in ctx:
+ raise KeyError('Missing key y_true!')
+ if not '{}_y_prob'.format(ctx.cur_data_split) in ctx:
+ raise KeyError('Missing key y_prob!')
+
+ y_true = ctx.get("{}_y_true".format(ctx.cur_data_split))
+ y_prob = ctx.get("{}_y_prob".format(ctx.cur_data_split))
+
+ if torch is not None and isinstance(y_true, torch.Tensor):
+ y_true = y_true.detach().cpu().numpy()
+ if torch is not None and isinstance(y_prob, torch.Tensor):
+ y_prob = y_prob.detach().cpu().numpy()
+
+ if y_true.ndim == 1:
+ y_true = np.expand_dims(y_true, axis=-1)
+ if y_prob.ndim == 2:
+ y_prob = np.expand_dims(y_prob, axis=-1)
+
+ y_pred = np.argmax(y_prob, axis=1)
+
+ # check shape and type
+ if not isinstance(y_true, np.ndarray):
+ raise RuntimeError('Type not support!')
+ if not y_true.shape == y_pred.shape:
+ raise RuntimeError('Shape not match!')
+ if not y_true.ndim == 2:
+ raise RuntimeError(
+                'y_true must be a 2-dim array, {}-dim given'.format(
+ y_true.ndim))
+
+ return y_true, y_pred, y_prob
+
+
+def eval_correct(y_true, y_pred, **kwargs):
+ correct_list = []
+
+ for i in range(y_true.shape[1]):
+ is_labeled = y_true[:, i] == y_true[:, i]
+ correct = y_true[is_labeled, i] == y_pred[is_labeled, i]
+ correct_list.append(np.sum(correct))
+ return sum(correct_list) / len(correct_list)
+
+
+def eval_acc(y_true, y_pred, **kwargs):
+ acc_list = []
+
+ for i in range(y_true.shape[1]):
+ is_labeled = y_true[:, i] == y_true[:, i]
+ correct = y_true[is_labeled, i] == y_pred[is_labeled, i]
+ acc_list.append(float(np.sum(correct)) / len(correct))
+ return sum(acc_list) / len(acc_list)
+
+
+def eval_ap(y_true, y_pred, **kwargs):
+ ap_list = []
+
+ for i in range(y_true.shape[1]):
+        # AP is only defined when there is at least one positive sample.
+ if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0:
+ # ignore nan values
+ is_labeled = y_true[:, i] == y_true[:, i]
+ ap = average_precision_score(y_true[is_labeled, i],
+ y_pred[is_labeled, i])
+
+ ap_list.append(ap)
+
+ if len(ap_list) == 0:
+ logger.warning('No positively labeled data available. ')
+ return 0.0
+
+ return sum(ap_list) / len(ap_list)
+
+
+def eval_f1_score(y_true, y_pred, **kwargs):
+ return f1_score(y_true, y_pred, average='macro')
+
+
+def eval_hits(y_true, y_prob, metric, **kwargs):
+ n = int(metric.split('@')[1])
+ hits_list = []
+ for i in range(y_true.shape[1]):
+ idx = np.argsort(-y_prob[:, :, i], axis=1)
+ pred_rank = idx.argsort(axis=1)
+ # Obtain the label rank
+ arg = np.arange(0, pred_rank.shape[0])
+ rank = pred_rank[arg, y_true[:, i]] + 1
+ hits_num = (rank <= n).sum().item()
+ hits_list.append(float(hits_num) / len(rank))
+
+ return sum(hits_list) / len(hits_list)
+
+
+def eval_roc_auc(y_true, y_prob, **kwargs):
+ rocauc_list = []
+
+ for i in range(y_true.shape[1]):
+ # AUC is only defined when there is at least one positive data.
+ if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0:
+ # ignore nan values
+ is_labeled = y_true[:, i] == y_true[:, i]
+ y_true_one_hot = np.eye(y_prob.shape[1])[y_true[is_labeled, i]]
+ rocauc_list.append(
+ roc_auc_score(y_true_one_hot,
+ softmax(y_prob[is_labeled, :, i], axis=-1)))
+ if len(rocauc_list) == 0:
+ logger.warning('No positively labeled data available.')
+ return 0.5
+
+ return sum(rocauc_list) / len(rocauc_list)
+
+
+def eval_rmse(y_true, y_pred, **kwargs):
+ rmse_list = []
+
+ for i in range(y_true.shape[1]):
+ # ignore nan values
+ is_labeled = y_true[:, i] == y_true[:, i]
+ rmse_list.append(
+ np.sqrt(((y_true[is_labeled] - y_pred[is_labeled])**2).mean()))
+
+ return sum(rmse_list) / len(rmse_list)
+
+
+def eval_loss(ctx, **kwargs):
+ return ctx.get('loss_batch_total_{}'.format(ctx.cur_data_split))
+
+
+def eval_avg_loss(ctx, **kwargs):
+ return ctx.get("loss_batch_total_{}".format(ctx.cur_data_split)) / ctx.get(
+ "num_samples_{}".format(ctx.cur_data_split))
+
+
+def eval_total(ctx, **kwargs):
+ return ctx.get("num_samples_{}".format(ctx.cur_data_split))
+
+
+def eval_regular(ctx, **kwargs):
+ return ctx.get("loss_regular_total_{}".format(ctx.cur_data_split))
+
+
+SUPPORT_METRICS = {
+ 'loss': eval_loss,
+ 'avg_loss': eval_avg_loss,
+ 'total': eval_total,
+ 'correct': eval_correct,
+ 'acc': eval_acc,
+ 'ap': eval_ap,
+ 'f1': eval_f1_score,
+ 'roc_auc': eval_roc_auc,
+ 'rmse': eval_rmse,
+ 'loss_regular': eval_regular,
+ **dict.fromkeys([f'hits@{n}' for n in range(1, 101)], eval_hits)
+}
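+
+
+# for testing (editor's smoke test): labels/predictions are shaped
+# (num_samples, num_tasks), as enforced by MetricCalculator._check_and_parse
+if __name__ == '__main__':
+    y_true = np.array([[0], [1], [1], [0]])
+    y_pred = np.array([[0], [1], [0], [0]])
+    print(eval_acc(y_true, y_pred))  # 0.75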
diff --git a/federatedscope/core/monitors/monitor.py b/federatedscope/core/monitors/monitor.py
new file mode 100644
index 000000000..6242ca582
--- /dev/null
+++ b/federatedscope/core/monitors/monitor.py
@@ -0,0 +1,450 @@
+import copy
+import json
+import logging
+import os
+import gzip
+import shutil
+import datetime
+from collections import defaultdict
+
+import numpy as np
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+logger = logging.getLogger(__name__)
+
+
+class Monitor(object):
+ """
+ Provide the monitoring functionalities such as formatting the evaluation results into diverse metrics.
+    Besides the prediction-related performance, the monitor can also track efficiency-related metrics for a worker
+ """
+ SUPPORTED_FORMS = ['weighted_avg', 'avg', 'fairness', 'raw']
+
+ def __init__(self, cfg, monitored_object=None):
+ self.outdir = cfg.outdir
+ self.use_wandb = cfg.wandb.use
+ # self.use_tensorboard = cfg.use_tensorboard
+
+ self.monitored_object = monitored_object
+
+ # ========= efficiency indicators of the worker to be monitored ================
+ # leveraged the flops counter provided by [fvcore](https://github.com/facebookresearch/fvcore)
+ self.total_model_size = 0 # model size used in the worker, in terms of number of parameters
+ self.flops_per_sample = 0 # average flops for forwarding each data sample
+        self.flop_count = 0 # used to calculate the running mean for flops_per_sample
+ self.total_flops = 0 # total computation flops to convergence until current fl round
+ self.total_upload_bytes = 0 # total upload space cost in bytes until current fl round
+ self.total_download_bytes = 0 # total download space cost in bytes until current fl round
+ self.fl_begin_wall_time = datetime.datetime.now()
+ self.fl_end_wall_time = 0
+        # for the metrics whose names include "convergence", 0 indicates that the worker has not converged yet
+        # Note:
+        # 1) the convergence wall time is prone to fluctuations due to possible resource competition during FL courses
+        # 2) the global/local prefix indicates whether early stopping is triggered by global-aggregation/local-training
+ self.global_convergence_round = 0 # total fl rounds to convergence
+ self.global_convergence_wall_time = 0
+ self.local_convergence_round = 0 # total fl rounds to convergence
+ self.local_convergence_wall_time = 0
+
+ def global_converged(self):
+ self.global_convergence_wall_time = datetime.datetime.now(
+ ) - self.fl_begin_wall_time
+ self.global_convergence_round = self.monitored_object.state
+
+ def local_converged(self):
+ self.local_convergence_wall_time = datetime.datetime.now(
+ ) - self.fl_begin_wall_time
+ self.local_convergence_round = self.monitored_object.state
+
+ def finish_fl(self):
+ self.fl_end_wall_time = datetime.datetime.now(
+ ) - self.fl_begin_wall_time
+
+ system_metrics = {
+ "id": self.monitored_object.ID,
+ "fl_end_time_minutes": self.fl_end_wall_time.total_seconds() /
+ 60 if isinstance(self.fl_end_wall_time, datetime.timedelta) else 0,
+ "total_model_size": self.total_model_size,
+ "total_flops": self.total_flops,
+ "total_upload_bytes": self.total_upload_bytes,
+ "total_download_bytes": self.total_download_bytes,
+ "global_convergence_round": self.global_convergence_round,
+ "local_convergence_round": self.local_convergence_round,
+ "global_convergence_time_minutes": self.
+ global_convergence_wall_time.total_seconds() / 60 if isinstance(
+ self.global_convergence_wall_time, datetime.timedelta) else 0,
+ "local_convergence_time_minutes": self.local_convergence_wall_time.
+ total_seconds() / 60 if isinstance(
+ self.local_convergence_wall_time, datetime.timedelta) else 0,
+ }
+ logger.info(
+ f"In worker #{self.monitored_object.ID}, the system-related metrics are: {str(system_metrics)}"
+ )
+ sys_metric_f_name = os.path.join(self.outdir, "system_metrics.log")
+ with open(sys_metric_f_name, "a") as f:
+ f.write(json.dumps(system_metrics) + "\n")
+
+ def merge_system_metrics_simulation_mode(self):
+ """
+        average the system metrics recorded in "system_metrics.log" by all workers
+ :return:
+ """
+ sys_metric_f_name = os.path.join(self.outdir, "system_metrics.log")
+ if not os.path.exists(sys_metric_f_name):
+            logger.warning(
+                "You have not tracked the workers' system metrics in $outdir$/system_metrics.log, "
+                "we will skip the merging. Please check whether monitor.finish_fl() has been called."
+            )
+ return
+
+ all_sys_metrics = defaultdict(list)
+ avg_sys_metrics = defaultdict()
+ std_sys_metrics = defaultdict()
+ with open(sys_metric_f_name, "r") as f:
+ for line in f:
+ res = json.loads(line)
+ if all_sys_metrics is None:
+ all_sys_metrics = res
+ all_sys_metrics["id"] = "all"
+ else:
+ for k, v in res.items():
+ all_sys_metrics[k].append(v)
+
+ for k, v in all_sys_metrics.items():
+ if k == "id":
+ avg_sys_metrics[k] = "sys_avg"
+ std_sys_metrics[k] = "sys_std"
+ else:
+ v = np.array(v)
+ avg_sys_metrics[f"sys_avg/{k}"] = np.mean(v)
+ std_sys_metrics[f"sys_std/{k}"] = np.std(v)
+
+        logger.info(
+            f"After merging the system metrics from all workers, we got avg: {avg_sys_metrics}"
+        )
+        logger.info(
+            f"After merging the system metrics from all workers, we got std: {std_sys_metrics}"
+        )
+ with open(sys_metric_f_name, "a") as f:
+ f.write(json.dumps(avg_sys_metrics) + "\n")
+ f.write(json.dumps(std_sys_metrics) + "\n")
+
+ def finish_fed_runner(self, fl_mode=None):
+ self.compress_raw_res_file()
+ if fl_mode == "standalone":
+ self.merge_system_metrics_simulation_mode()
+
+ if self.use_wandb:
+ try:
+ import wandb
+ except ImportError:
+                logger.error(
+                    "cfg.wandb.use=True but the wandb package is not installed")
+ exit()
+
+ from federatedscope.core.auxiliaries.utils import logfile_2_wandb_dict
+ with open(os.path.join(self.outdir, "eval_results.log"),
+ "r") as exp_log_f:
+ # track the prediction related performance
+ all_log_res, exp_stop_normal, last_line, log_res_best = \
+ logfile_2_wandb_dict(exp_log_f, raw_out=False)
+ for log_res in all_log_res:
+ wandb.log(log_res)
+ wandb.log(log_res_best)
+
+ # track the system related performance
+ sys_metric_f_name = os.path.join(self.outdir,
+ "system_metrics.log")
+ with open(sys_metric_f_name, "r") as f:
+ for line in f:
+ res = json.loads(line)
+ if res["id"] in ["sys_avg", "sys_std"]:
+ wandb.log(res)
+
+ def compress_raw_res_file(self):
+ old_f_name = os.path.join(self.outdir, "eval_results.raw")
+ if os.path.exists(old_f_name):
+ logger.info(
+ "We will compress the file eval_results.raw into a .gz file, and delete the old one"
+ )
+ with open(old_f_name, 'rb') as f_in:
+ with gzip.open(old_f_name + ".gz", 'wb') as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ os.remove(old_f_name)
+
+ def format_eval_res(self,
+ results,
+ rnd,
+ role=-1,
+ forms=None,
+ return_raw=False):
+ """
+ format the evaluation results from trainer.ctx.eval_results
+
+ Args:
+ results (dict): a dict to store the evaluation results {metric: value}
+ rnd (int|string): FL round
+ role (int|string): the output role
+ forms (list): format type
+            return_raw (bool): whether to return the raw results
+
+ Returns:
+ round_formatted_results (dict): a formatted results with different forms and roles,
+ e.g.,
+ {
+ 'Role': 'Server #',
+ 'Round': 200,
+ 'Results_weighted_avg': {
+ 'test_avg_loss': 0.58, 'test_acc': 0.67, 'test_correct': 3356, 'test_loss': 2892, 'test_total': 5000
+ },
+ 'Results_avg': {
+ 'test_avg_loss': 0.57, 'test_acc': 0.67, 'test_correct': 3356, 'test_loss': 2892, 'test_total': 5000
+ },
+ 'Results_fairness': {
+ 'test_correct': 3356, 'test_total': 5000,
+ 'test_avg_loss_std': 0.04, 'test_avg_loss_bottom_decile': 0.52, 'test_avg_loss_top_decile': 0.64,
+ 'test_acc_std': 0.06, 'test_acc_bottom_decile': 0.60, 'test_acc_top_decile': 0.75,
+ 'test_loss_std': 214.17, 'test_loss_bottom_decile': 2644.64, 'test_loss_top_decile': 3241.23
+ },
+ }
+ """
+ if forms is None:
+ forms = ['weighted_avg', 'avg', 'fairness', 'raw']
+ round_formatted_results = {'Role': role, 'Round': rnd}
+ round_formatted_results_raw = {'Role': role, 'Round': rnd}
+ for form in forms:
+ new_results = copy.deepcopy(results)
+ if not role.lower().startswith('server') or form == 'raw':
+ round_formatted_results_raw['Results_raw'] = new_results
+ elif form not in Monitor.SUPPORTED_FORMS:
+ continue
+ else:
+ for key in results.keys():
+ dataset_name = key.split("_")[0]
+ if f'{dataset_name}_total' not in results:
+                        raise ValueError(
+                            "Results to be formatted should include the dataset_num in the dict, "
+                            f"with key = {dataset_name}_total")
+ else:
+ dataset_num = np.array(
+ results[f'{dataset_name}_total'])
+ if key in [
+ f'{dataset_name}_total', f'{dataset_name}_correct'
+ ]:
+ new_results[key] = np.mean(new_results[key])
+ else:
+ all_res = np.array(copy.copy(results[key]))
+ if form == 'weighted_avg':
+ new_results[key] = np.sum(
+ np.array(new_results[key]) *
+ dataset_num) / np.sum(dataset_num)
+ if form == "avg":
+ new_results[key] = np.mean(new_results[key])
+ if form == "fairness" and all_res.size > 1:
+ # by default, log the std and decile
+ new_results.pop(
+ key, None) # delete the redundant original one
+ all_res.sort()
+ new_results[f"{key}_std"] = np.std(
+ np.array(all_res))
+ new_results[f"{key}_bottom_decile"] = all_res[
+ all_res.size // 10]
+ new_results[f"{key}_top_decile"] = all_res[
+ all_res.size * 9 // 10]
+ round_formatted_results[f'Results_{form}'] = new_results
+
+ with open(os.path.join(self.outdir, "eval_results.raw"),
+ "a") as outfile:
+ outfile.write(str(round_formatted_results_raw) + "\n")
+
+ return round_formatted_results_raw if return_raw else round_formatted_results
+
+ def calc_blocal_dissim(self, last_model, local_updated_models):
+        '''
+        Arguments:
+            last_model (dict): the model state of the last round.
+            local_updated_models (list): each element is a
+                (sample_size, updated_model_state) pair.
+        Returns:
+            b_local_dissimilarity (dict): the measurements proposed in
+            "Tian Li, Anit Kumar Sahu, Manzil Zaheer, and et al. Federated Optimization in Heterogeneous Networks".
+        '''
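+        # For each parameter k, the B-local dissimilarity is
+        #   B_k = sqrt( sum_i w_i * ||g_{i,k}||^2 / ||sum_i w_i * g_{i,k}||^2 ),
+        # where g_{i,k} is client i's update of parameter k and w_i its
+        # sample-count weight; B_k >= 1, with equality iff all updates agree.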
+ local_grads = []
+ weights = []
+ local_gnorms = []
+ for tp in local_updated_models:
+ weights.append(tp[0])
+ grads = dict()
+ gnorms = dict()
+ for k, v in tp[1].items():
+ grad = v - last_model[k]
+ grads[k] = grad
+ gnorms[k] = torch.sum(grad**2)
+ local_grads.append(grads)
+ local_gnorms.append(gnorms)
+ weights = np.asarray(weights)
+ weights = weights / np.sum(weights)
+ avg_gnorms = dict()
+ global_grads = dict()
+ for i in range(len(local_updated_models)):
+ gnorms = local_gnorms[i]
+ for k, v in gnorms.items():
+ if k not in avg_gnorms:
+ avg_gnorms[k] = .0
+ avg_gnorms[k] += weights[i] * v
+ grads = local_grads[i]
+ for k, v in grads.items():
+ if k not in global_grads:
+ global_grads[k] = torch.zeros_like(v)
+ global_grads[k] += weights[i] * v
+ b_local_dissimilarity = dict()
+ for k in avg_gnorms:
+ b_local_dissimilarity[k] = np.sqrt(
+ avg_gnorms[k].item() / torch.sum(global_grads[k]**2).item())
+ return b_local_dissimilarity
+
+ def track_model_size(self, models):
+        """
+        Calculate the total model size given the models held by the worker/trainer.
+
+        :param models: torch.nn.Module or list of torch.nn.Module
+        :return:
+        """
+        if self.total_model_size != 0:
+            logger.warning(
+                "the total_model_size is not zero. You may have calculated the total_model_size before"
+            )
+
+ if not hasattr(models, '__iter__'):
+ models = [models]
+ for model in models:
+ assert isinstance(model, torch.nn.Module), \
+ f"the `model` should be type torch.nn.Module when calculating its size, but got {type(model)}"
+ for name, para in model.named_parameters():
+ self.total_model_size += para.numel()
+
+ def track_avg_flops(self, flops, sample_num=1):
+        """
+        Update the average FLOPs for forwarding each data sample; for most
+        models and tasks the averaging is not needed, as the input shape is
+        fixed.
+
+        :param flops: the FLOPs consumed by forwarding `sample_num` samples
+        :param sample_num: the number of samples that `flops` corresponds to
+        :return:
+        """
+
+        self.flops_per_sample = (self.flops_per_sample * self.flop_count +
+                                 flops) / (self.flop_count + sample_num)
+        self.flop_count += sample_num
+
+ def track_upload_bytes(self, bytes):
+ self.total_upload_bytes += bytes
+
+ def track_download_bytes(self, bytes):
+ self.total_download_bytes += bytes
+
+
+def update_best_result(best_results,
+ new_results,
+ results_type,
+ round_wise_update_key="val_loss"):
+ """
+    Update the best evaluation results.
+    By default, the update is based on the validation loss, i.e.,
+    `round_wise_update_key="val_loss"`.
+ """
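+    # Two update modes: with round_wise_update_key=None every metric is
+    # updated independently (best values may come from different rounds);
+    # otherwise all metrics are refreshed together in the round where
+    # round_wise_update_key improves, keeping a consistent snapshot.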
+ update_best_this_round = False
+ if not isinstance(new_results, dict):
+ raise ValueError(
+ f"update best results require `results` a dict, but got {type(new_results)}"
+ )
+ else:
+ if results_type not in best_results:
+ best_results[results_type] = dict()
+ best_result = best_results[results_type]
+ # update different keys separately: the best values can be in different rounds
+ if round_wise_update_key is None:
+ for key in new_results:
+ cur_result = new_results[key]
+ if 'loss' in key or 'std' in key: # the smaller, the better
+ if results_type == "client_individual":
+ cur_result = min(cur_result)
+ if key not in best_result or cur_result < best_result[key]:
+ best_result[key] = cur_result
+ update_best_this_round = True
+
+ elif 'acc' in key: # the larger, the better
+ if results_type == "client_individual":
+ cur_result = max(cur_result)
+ if key not in best_result or cur_result > best_result[key]:
+ best_result[key] = cur_result
+ update_best_this_round = True
+ else:
+ # unconcerned metric
+ pass
+ # update different keys round-wise: if find better round_wise_update_key, update others at the same time
+ else:
+ if round_wise_update_key not in [
+ "val_loss", "val_acc", "val_std", "test_loss", "test_acc",
+ "test_std", "test_avg_loss", "loss"
+ ]:
+ raise NotImplementedError(
+ f"We currently support round_wise_update_key as one of "
+                    f"['val_loss', 'val_acc', 'val_std', 'test_loss', 'test_acc', "
+                    f"'test_std', 'test_avg_loss', 'loss'] "
+ f"for round-wise best results update, but got {round_wise_update_key}."
+ )
+
+ found_round_wise_update_key = False
+ sorted_keys = []
+ for key in new_results:
+ if round_wise_update_key in key:
+ sorted_keys.insert(0, key)
+ found_round_wise_update_key = True
+ else:
+ sorted_keys.append(key)
+ if not found_round_wise_update_key:
+ raise ValueError(
+ "Your specified eval.best_res_update_round_wise_key is not in target results, "
+ "use another key or check the name. \n"
+ f"Got eval.best_res_update_round_wise_key={round_wise_update_key}, "
+ f"the keys of results are {list(new_results.keys())}")
+
+ for key in sorted_keys:
+ cur_result = new_results[key]
+ if update_best_this_round or \
+ ('loss' in round_wise_update_key and 'loss' in key) or \
+ ('std' in round_wise_update_key and 'std' in key):
+ # The smaller the better
+ if results_type == "client_individual":
+ cur_result = min(cur_result)
+ if update_best_this_round or \
+ key not in best_result or cur_result < best_result[key]:
+ best_result[key] = cur_result
+ update_best_this_round = True
+                elif update_best_this_round or \
+                        ('acc' in round_wise_update_key and 'acc' in key):
+ # The larger the better
+ if results_type == "client_individual":
+ cur_result = max(cur_result)
+ if update_best_this_round or \
+ key not in best_result or cur_result > best_result[key]:
+ best_result[key] = cur_result
+ update_best_this_round = True
+ else:
+ # unconcerned metric
+ pass
+
+ if update_best_this_round:
+        logger.info(f"Found new best result: {best_results}")
diff --git a/federatedscope/core/optimizer.py b/federatedscope/core/optimizer.py
new file mode 100644
index 000000000..0736e245d
--- /dev/null
+++ b/federatedscope/core/optimizer.py
@@ -0,0 +1,56 @@
+import copy
+from typing import Dict, List
+
+
+def wrap_regularized_optimizer(base_optimizer, regular_weight):
+ base_optimizer_type = type(base_optimizer)
+ internal_base_optimizer = copy.copy(
+ base_optimizer) # shallow copy to link the underlying model para
+
+ class ParaRegularOptimizer(base_optimizer_type):
+ """
+ Regularization-based optimizer wrapper
+ """
+ def __init__(self, base_optimizer, regular_weight):
+ # inherit all the attributes of base optimizer
+ self.__dict__.update(base_optimizer.__dict__)
+
+ # attributes used in the wrapper
+ self.optimizer = base_optimizer # internal torch optimizer
+ self.param_groups = self.optimizer.param_groups # link the para of internal optimizer with the wrapper
+ self.regular_weight = regular_weight
+ self.compared_para_groups = None
+
+ def set_compared_para_group(self, compared_para_dict: List[Dict]):
+ if not (isinstance(compared_para_dict, list)
+ and isinstance(compared_para_dict[0], dict)
+ and 'params' in compared_para_dict[0]):
+ raise ValueError(
+                    "compared_para_dict should be a torch-style para group, i.e., list[dict], "
+                    "in which each dict stores parameters under the key `params`")
+ self.compared_para_groups = copy.deepcopy(compared_para_dict)
+
+ def reset_compared_para_group(self, target=None):
+ # by default, del stale compared_para to free memory
+ self.compared_para_groups = target
+
+ def regularize_by_para_diff(self):
+ """
+ before optim.step(), regularize the gradients based on para diff
+ """
+ for group, compared_group in zip(self.param_groups,
+ self.compared_para_groups):
+ for p, compared_weight in zip(group['params'],
+ compared_group['params']):
+ if p.grad is not None:
+ if compared_weight.device != p.device:
+                            # for a Tensor, to() is not an in-place operation
+ compared_weight = compared_weight.to(p.device)
+ p.grad.data = p.grad.data + self.regular_weight * (
+ p.data - compared_weight.data)
+
+ def step(self):
+ self.regularize_by_para_diff() # key action
+ self.optimizer.step()
+
+ return ParaRegularOptimizer(internal_base_optimizer, regular_weight)
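+
+
+# A minimal usage sketch (assuming a toy `model` and a computed `loss`;
+# illustrative only, not part of the shipped API surface):
+#   base = torch.optim.SGD(model.parameters(), lr=0.1)
+#   opt = wrap_regularized_optimizer(base, regular_weight=0.01)
+#   opt.set_compared_para_group(
+#       [{'params': [p.detach().clone() for p in model.parameters()]}])
+#   loss.backward()
+#   opt.step()  # gradients first get += regular_weight * (w - w_compared)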
diff --git a/federatedscope/core/optimizers/__init__.py b/federatedscope/core/optimizers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/core/proto/__init__.py b/federatedscope/core/proto/__init__.py
new file mode 100644
index 000000000..4deaa8ea3
--- /dev/null
+++ b/federatedscope/core/proto/__init__.py
@@ -0,0 +1,2 @@
+from federatedscope.core.proto.gRPC_comm_manager_pb2 import *
+from federatedscope.core.proto.gRPC_comm_manager_pb2_grpc import *
diff --git a/federatedscope/core/proto/gRPC_comm_manager_pb2.py b/federatedscope/core/proto/gRPC_comm_manager_pb2.py
new file mode 100644
index 000000000..94c80799c
--- /dev/null
+++ b/federatedscope/core/proto/gRPC_comm_manager_pb2.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: gRPC_comm_manager.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x17gRPC_comm_manager.proto\"n\n\x0eMessageRequest\x12%\n\x03msg\x18\x01 \x03(\x0b\x32\x18.MessageRequest.MsgEntry\x1a\x35\n\x08MsgEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x18\n\x05value\x18\x02 \x01(\x0b\x32\t.MsgValue:\x02\x38\x01\"j\n\x08MsgValue\x12\x1e\n\nsingle_msg\x18\x01 \x01(\x0b\x32\x08.mSingleH\x00\x12\x1a\n\x08list_msg\x18\x02 \x01(\x0b\x32\x06.mListH\x00\x12\x1a\n\x08\x64ict_msg\x18\x03 \x01(\x0b\x32\x06.mDictH\x00\x42\x06\n\x04type\"R\n\x07mSingle\x12\x15\n\x0b\x66loat_value\x18\x01 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x13\n\tstr_value\x18\x03 \x01(\tH\x00\x42\x06\n\x04type\"&\n\x05mList\x12\x1d\n\nlist_value\x18\x01 \x03(\x0b\x32\t.MsgValue\"o\n\x05mDict\x12)\n\ndict_value\x18\x01 \x03(\x0b\x32\x15.mDict.DictValueEntry\x1a;\n\x0e\x44ictValueEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x18\n\x05value\x18\x02 \x01(\x0b\x32\t.MsgValue:\x02\x38\x01\"\x1e\n\x0fMessageResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t2D\n\x10gRPCComServeFunc\x12\x30\n\x0bsendMessage\x12\x0f.MessageRequest\x1a\x10.MessageResponseb\x06proto3'
+)
+
+_MESSAGEREQUEST = DESCRIPTOR.message_types_by_name['MessageRequest']
+_MESSAGEREQUEST_MSGENTRY = _MESSAGEREQUEST.nested_types_by_name['MsgEntry']
+_MSGVALUE = DESCRIPTOR.message_types_by_name['MsgValue']
+_MSINGLE = DESCRIPTOR.message_types_by_name['mSingle']
+_MLIST = DESCRIPTOR.message_types_by_name['mList']
+_MDICT = DESCRIPTOR.message_types_by_name['mDict']
+_MDICT_DICTVALUEENTRY = _MDICT.nested_types_by_name['DictValueEntry']
+_MESSAGERESPONSE = DESCRIPTOR.message_types_by_name['MessageResponse']
+MessageRequest = _reflection.GeneratedProtocolMessageType(
+ 'MessageRequest',
+ (_message.Message, ),
+ {
+ 'MsgEntry': _reflection.GeneratedProtocolMessageType(
+ 'MsgEntry',
+ (_message.Message, ),
+ {
+ 'DESCRIPTOR': _MESSAGEREQUEST_MSGENTRY,
+ '__module__': 'gRPC_comm_manager_pb2'
+ # @@protoc_insertion_point(class_scope:MessageRequest.MsgEntry)
+ }),
+ 'DESCRIPTOR': _MESSAGEREQUEST,
+ '__module__': 'gRPC_comm_manager_pb2'
+ # @@protoc_insertion_point(class_scope:MessageRequest)
+ })
+_sym_db.RegisterMessage(MessageRequest)
+_sym_db.RegisterMessage(MessageRequest.MsgEntry)
+
+MsgValue = _reflection.GeneratedProtocolMessageType(
+ 'MsgValue',
+ (_message.Message, ),
+ {
+ 'DESCRIPTOR': _MSGVALUE,
+ '__module__': 'gRPC_comm_manager_pb2'
+ # @@protoc_insertion_point(class_scope:MsgValue)
+ })
+_sym_db.RegisterMessage(MsgValue)
+
+mSingle = _reflection.GeneratedProtocolMessageType(
+ 'mSingle',
+ (_message.Message, ),
+ {
+ 'DESCRIPTOR': _MSINGLE,
+ '__module__': 'gRPC_comm_manager_pb2'
+ # @@protoc_insertion_point(class_scope:mSingle)
+ })
+_sym_db.RegisterMessage(mSingle)
+
+mList = _reflection.GeneratedProtocolMessageType(
+ 'mList',
+ (_message.Message, ),
+ {
+ 'DESCRIPTOR': _MLIST,
+ '__module__': 'gRPC_comm_manager_pb2'
+ # @@protoc_insertion_point(class_scope:mList)
+ })
+_sym_db.RegisterMessage(mList)
+
+mDict = _reflection.GeneratedProtocolMessageType(
+ 'mDict',
+ (_message.Message, ),
+ {
+ 'DictValueEntry': _reflection.GeneratedProtocolMessageType(
+ 'DictValueEntry',
+ (_message.Message, ),
+ {
+ 'DESCRIPTOR': _MDICT_DICTVALUEENTRY,
+ '__module__': 'gRPC_comm_manager_pb2'
+ # @@protoc_insertion_point(class_scope:mDict.DictValueEntry)
+ }),
+ 'DESCRIPTOR': _MDICT,
+ '__module__': 'gRPC_comm_manager_pb2'
+ # @@protoc_insertion_point(class_scope:mDict)
+ })
+_sym_db.RegisterMessage(mDict)
+_sym_db.RegisterMessage(mDict.DictValueEntry)
+
+MessageResponse = _reflection.GeneratedProtocolMessageType(
+ 'MessageResponse',
+ (_message.Message, ),
+ {
+ 'DESCRIPTOR': _MESSAGERESPONSE,
+ '__module__': 'gRPC_comm_manager_pb2'
+ # @@protoc_insertion_point(class_scope:MessageResponse)
+ })
+_sym_db.RegisterMessage(MessageResponse)
+
+_GRPCCOMSERVEFUNC = DESCRIPTOR.services_by_name['gRPCComServeFunc']
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ _MESSAGEREQUEST_MSGENTRY._options = None
+ _MESSAGEREQUEST_MSGENTRY._serialized_options = b'8\001'
+ _MDICT_DICTVALUEENTRY._options = None
+ _MDICT_DICTVALUEENTRY._serialized_options = b'8\001'
+ _MESSAGEREQUEST._serialized_start = 27
+ _MESSAGEREQUEST._serialized_end = 137
+ _MESSAGEREQUEST_MSGENTRY._serialized_start = 84
+ _MESSAGEREQUEST_MSGENTRY._serialized_end = 137
+ _MSGVALUE._serialized_start = 139
+ _MSGVALUE._serialized_end = 245
+ _MSINGLE._serialized_start = 247
+ _MSINGLE._serialized_end = 329
+ _MLIST._serialized_start = 331
+ _MLIST._serialized_end = 369
+ _MDICT._serialized_start = 371
+ _MDICT._serialized_end = 482
+ _MDICT_DICTVALUEENTRY._serialized_start = 423
+ _MDICT_DICTVALUEENTRY._serialized_end = 482
+ _MESSAGERESPONSE._serialized_start = 484
+ _MESSAGERESPONSE._serialized_end = 514
+ _GRPCCOMSERVEFUNC._serialized_start = 516
+ _GRPCCOMSERVEFUNC._serialized_end = 584
+# @@protoc_insertion_point(module_scope)
diff --git a/federatedscope/core/proto/gRPC_comm_manager_pb2_grpc.py b/federatedscope/core/proto/gRPC_comm_manager_pb2_grpc.py
new file mode 100644
index 000000000..9a1bbeb41
--- /dev/null
+++ b/federatedscope/core/proto/gRPC_comm_manager_pb2_grpc.py
@@ -0,0 +1,68 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+import federatedscope.core.proto.gRPC_comm_manager_pb2 as gRPC__comm__manager__pb2
+
+
+class gRPCComServeFuncStub(object):
+ """Missing associated documentation comment in .proto file."""
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.sendMessage = channel.unary_unary(
+ '/gRPCComServeFunc/sendMessage',
+ request_serializer=gRPC__comm__manager__pb2.MessageRequest.
+ SerializeToString,
+ response_deserializer=gRPC__comm__manager__pb2.MessageResponse.
+ FromString,
+ )
+
+
+class gRPCComServeFuncServicer(object):
+ """Missing associated documentation comment in .proto file."""
+ def sendMessage(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_gRPCComServeFuncServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'sendMessage': grpc.unary_unary_rpc_method_handler(
+ servicer.sendMessage,
+ request_deserializer=gRPC__comm__manager__pb2.MessageRequest.
+ FromString,
+ response_serializer=gRPC__comm__manager__pb2.MessageResponse.
+ SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'gRPCComServeFunc', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler, ))
+
+
+# This class is part of an EXPERIMENTAL API.
+class gRPCComServeFunc(object):
+ """Missing associated documentation comment in .proto file."""
+ @staticmethod
+ def sendMessage(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request, target, '/gRPCComServeFunc/sendMessage',
+ gRPC__comm__manager__pb2.MessageRequest.SerializeToString,
+ gRPC__comm__manager__pb2.MessageResponse.FromString, options,
+ channel_credentials, insecure, call_credentials, compression,
+ wait_for_ready, timeout, metadata)
diff --git a/federatedscope/core/regularizer/__init__.py b/federatedscope/core/regularizer/__init__.py
new file mode 100644
index 000000000..5821bc72a
--- /dev/null
+++ b/federatedscope/core/regularizer/__init__.py
@@ -0,0 +1 @@
+from federatedscope.core.regularizer.proximal_regularizer import *
diff --git a/federatedscope/core/regularizer/proximal_regularizer.py b/federatedscope/core/regularizer/proximal_regularizer.py
new file mode 100644
index 000000000..aab1eb058
--- /dev/null
+++ b/federatedscope/core/regularizer/proximal_regularizer.py
@@ -0,0 +1,39 @@
+from federatedscope.register import register_regularizer
+try:
+ from torch.nn import Module
+ import torch
+except ImportError:
+ Module = object
+ torch = None
+
+REGULARIZER_NAME = "proximal_regularizer"
+
+
+class ProximalRegularizer(Module):
+    """Returns the norm of the weight update, i.e., the distance between the
+    current model parameters and the initial ones recorded in the context.
+
+    Arguments:
+        ctx: the trainer context, providing `weight_init` and `model`
+        p (int): The order of norm.
+
+    Returns:
+        Tensor: the norm of the given update.
+    """
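+    # The returned scalar is the FedProx-style proximal term
+    #   (1/p) * sum over parameters of ||w - w_init||_p^p,
+    # e.g., half the squared L2 distance for the default p=2.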
+ def __init__(self):
+ super(ProximalRegularizer, self).__init__()
+
+ def forward(self, ctx, p=2):
+ norm = 0.
+ for w_init, w in zip(ctx.weight_init, ctx.model.parameters()):
+ norm += torch.pow(torch.norm(w - w_init, p), p)
+ return norm * 1. / float(p)
+
+
+def call_proximal_regularizer(type):
+ if type == REGULARIZER_NAME:
+ regularizer = ProximalRegularizer
+ return regularizer
+
+
+register_regularizer(REGULARIZER_NAME, call_proximal_regularizer)
diff --git a/federatedscope/core/secret_sharing/__init__.py b/federatedscope/core/secret_sharing/__init__.py
new file mode 100644
index 000000000..cf802c21f
--- /dev/null
+++ b/federatedscope/core/secret_sharing/__init__.py
@@ -0,0 +1 @@
+from federatedscope.core.secret_sharing.secret_sharing import AdditiveSecretSharing
diff --git a/federatedscope/core/secret_sharing/secret_sharing.py b/federatedscope/core/secret_sharing/secret_sharing.py
new file mode 100644
index 000000000..5e7a22784
--- /dev/null
+++ b/federatedscope/core/secret_sharing/secret_sharing.py
@@ -0,0 +1,96 @@
+from abc import ABC, abstractmethod
+import numpy as np
+try:
+ import torch
+except ImportError:
+ torch = None
+from math import fmod
+
+
+class SecretSharing(ABC):
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def secret_split(self, secret):
+ pass
+
+ @abstractmethod
+ def secret_reconstruct(self, secret_seq):
+ pass
+
+
+class AdditiveSecretSharing(SecretSharing):
+ """
+    AdditiveSecretSharing class, which splits a number into additive shares
+    and recovers it by summing them up
+    """
+    def __init__(self, shared_party_num, size=60):
+        super(AdditiveSecretSharing, self).__init__()
+        assert shared_party_num > 1, "AdditiveSecretSharing requires shared_party_num > 1"
+ self.shared_party_num = shared_party_num
+ self.maximum = 2**size
+ self.mod_number = 2 * self.maximum + 1
+ self.epsilon = 1e8
+ self.mod_funs = np.vectorize(lambda x: x % self.mod_number)
+ self.float2fixedpoint = np.vectorize(self._float2fixedpoint)
+ self.fixedpoint2float = np.vectorize(self._fixedpoint2float)
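+        # Fixed-point encoding: floats are scaled by `epsilon` (1e8) and
+        # represented modulo `mod_number` (= 2 * 2**size + 1), so any value
+        # with |x * epsilon| < 2**size maps to a unique residue.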
+
+ def secret_split(self, secret):
+ """
+        To split the secret into shares according to the shared_party_num
+ """
+ if isinstance(secret, dict):
+ secret_list = [dict() for _ in range(self.shared_party_num)]
+ for key in secret:
+ for idx, each in enumerate(self.secret_split(secret[key])):
+ secret_list[idx][key] = each
+ return secret_list
+
+ if isinstance(secret, list) or isinstance(secret, np.ndarray):
+ secret = np.asarray(secret)
+ shape = [self.shared_party_num - 1] + list(secret.shape)
+        elif torch is not None and isinstance(secret, torch.Tensor):
+ secret = secret.numpy()
+ shape = [self.shared_party_num - 1] + list(secret.shape)
+ else:
+ shape = [self.shared_party_num - 1]
+
+ secret = self.float2fixedpoint(secret)
+ secret_seq = np.random.randint(low=0, high=self.mod_number, size=shape)
+ last_seq = self.mod_funs(secret -
+ self.mod_funs(np.sum(secret_seq, axis=0)))
+
+ secret_seq = np.append(secret_seq,
+ np.expand_dims(last_seq, axis=0),
+ axis=0)
+ return secret_seq
+
+ def secret_reconstruct(self, secret_seq):
+ """
+ To recover the secret
+ """
+ assert len(secret_seq) == self.shared_party_num
+ merge_model = secret_seq[0].copy()
+ if isinstance(merge_model, dict):
+ for key in merge_model:
+ for idx in range(len(secret_seq)):
+ if idx == 0:
+ merge_model[key] = secret_seq[idx][key]
+ else:
+ merge_model[key] += secret_seq[idx][key]
+ merge_model[key] = self.fixedpoint2float(merge_model[key])
+
+ return merge_model
+
+ def _float2fixedpoint(self, x):
+ x = round(x * self.epsilon, 0)
+ assert abs(x) < self.maximum
+ return x % self.mod_number
+
+ def _fixedpoint2float(self, x):
+ x = x % self.mod_number
+ if x > self.maximum:
+ return -1 * (self.mod_number - x) / self.epsilon
+ else:
+ return x / self.epsilon
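+
+
+# A minimal round-trip sketch (illustrative): each party holds one share and
+# the secret is only recoverable by summing all of them.
+#   ss = AdditiveSecretSharing(shared_party_num=3)
+#   shares = ss.secret_split({'w': [0.5, -1.25]})  # a list of 3 dict shares
+#   recovered = ss.secret_reconstruct(shares)      # {'w': ~[0.5, -1.25]}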
diff --git a/federatedscope/core/splitters/__init__.py b/federatedscope/core/splitters/__init__.py
new file mode 100644
index 000000000..f8e91f237
--- /dev/null
+++ b/federatedscope/core/splitters/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
diff --git a/federatedscope/core/splitters/generic/__init__.py b/federatedscope/core/splitters/generic/__init__.py
new file mode 100644
index 000000000..a01edd0da
--- /dev/null
+++ b/federatedscope/core/splitters/generic/__init__.py
@@ -0,0 +1,7 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.core.splitters.generic.lda_splitter import LDASplitter
+
+__all__ = ['LDASplitter']
\ No newline at end of file
diff --git a/federatedscope/core/splitters/generic/lda_splitter.py b/federatedscope/core/splitters/generic/lda_splitter.py
new file mode 100644
index 000000000..797d4e810
--- /dev/null
+++ b/federatedscope/core/splitters/generic/lda_splitter.py
@@ -0,0 +1,22 @@
+import numpy as np
+from federatedscope.core.splitters.utils import dirichlet_distribution_noniid_slice
+
+
+class LDASplitter(object):
+ def __init__(self, client_num, alpha=0.5):
+ self.client_num = client_num
+ self.alpha = alpha
+
+ def __call__(self, dataset, prior=None):
+ dataset = [ds for ds in dataset]
+ label = np.array([y for x, y in dataset])
+ idx_slice = dirichlet_distribution_noniid_slice(label,
+ self.client_num,
+ self.alpha,
+ prior=prior)
+ data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice]
+ return data_list
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(client_num={self.client_num}, ' \
+ f'alpha={self.alpha})'
diff --git a/federatedscope/core/splitters/graph/__init__.py b/federatedscope/core/splitters/graph/__init__.py
new file mode 100644
index 000000000..3c1a35bdf
--- /dev/null
+++ b/federatedscope/core/splitters/graph/__init__.py
@@ -0,0 +1,20 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.core.splitters.graph.louvain_splitter import LouvainSplitter
+from federatedscope.core.splitters.graph.random_splitter import RandomSplitter
+
+from federatedscope.core.splitters.graph.reltype_splitter import RelTypeSplitter
+
+from federatedscope.core.splitters.graph.scaffold_splitter import ScaffoldSplitter
+from federatedscope.core.splitters.graph.graphtype_splitter import GraphTypeSplitter
+from federatedscope.core.splitters.graph.randchunk_splitter import RandChunkSplitter
+
+from federatedscope.core.splitters.graph.analyzer import Analyzer
+from federatedscope.core.splitters.graph.scaffold_lda_splitter import ScaffoldLdaSplitter
+
+__all__ = [
+ 'LouvainSplitter', 'RandomSplitter', 'RelTypeSplitter', 'ScaffoldSplitter',
+ 'GraphTypeSplitter', 'RandChunkSplitter', 'Analyzer', 'ScaffoldLdaSplitter'
+]
diff --git a/federatedscope/core/splitters/graph/analyzer.py b/federatedscope/core/splitters/graph/analyzer.py
new file mode 100644
index 000000000..b131d5393
--- /dev/null
+++ b/federatedscope/core/splitters/graph/analyzer.py
@@ -0,0 +1,181 @@
+import torch
+
+from typing import List
+from torch_geometric.data import Data
+from torch_geometric.utils import to_networkx, to_dense_adj, dense_to_sparse
+
+
+class Analyzer(object):
+ r"""Analyzer for raw graph and split subgraphs.
+
+ Arguments:
+ raw_data (PyG.data): raw graph.
+ split_data (list): the list for subgraphs split by splitter.
+
+ """
+ def __init__(self, raw_data: Data, split_data: List[Data]):
+
+ self.raw_data = raw_data
+ self.split_data = split_data
+
+ self.raw_graph = to_networkx(raw_data, to_undirected=True)
+ self.sub_graphs = [
+ to_networkx(g, to_undirected=True) for g in split_data
+ ]
+
+ def num_missing_edge(self):
+ r"""
+
+ Returns:
+            the number of missing edges and the rate of missing edges.
+
+ """
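+        # fl_adj() returns a directed edge_index containing both directions,
+        # hence the // 2 when comparing with the undirected networkx count.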
+ missing_edge = len(self.raw_graph.edges) - self.fl_adj().shape[1] // 2
+ rate_missing_edge = missing_edge / len(self.raw_graph.edges)
+
+ return missing_edge, rate_missing_edge
+
+ def fl_adj(self):
+ r"""
+
+ Returns:
+            the edge index built from the union of all subgraph edges,
+            mapped back to the original node indices.
+
+ """
+ raw_adj = to_dense_adj(self.raw_data.edge_index)[0]
+ adj = torch.zeros_like(raw_adj)
+ if 'index_orig' in self.split_data[0]:
+ for sub_g in self.split_data:
+ for row, col in sub_g.edge_index.T:
+ adj[sub_g.index_orig[row.item()]][sub_g.index_orig[
+ col.item()]] = 1
+
+ else:
+            raise KeyError('index_orig not in Split Data.')
+
+ return dense_to_sparse(adj)[0]
+
+ def fl_data(self):
+ r"""
+
+ Returns:
+            a copy of the raw data whose edge_index only keeps the edges
+            preserved in the split subgraphs.
+
+ """
+ fl_data = Data()
+ for key, item in self.raw_data:
+ if key == 'edge_index':
+ fl_data[key] = self.fl_adj()
+ else:
+ fl_data[key] = item
+
+ return fl_data
+
+ def missing_data(self):
+ r"""
+
+ Returns:
+ the graph data built by missing edge index.
+
+ """
+ ms_data = Data()
+ raw_edge_set = {tuple(x) for x in self.raw_data.edge_index.T.numpy()}
+ split_edge_set = {
+ tuple(x)
+ for x in self.fl_data().edge_index.T.numpy()
+ }
+ ms_set = raw_edge_set - split_edge_set
+ for key, item in self.raw_data:
+ if key == 'edge_index':
+ ms_data[key] = torch.tensor([list(x) for x in ms_set],
+ dtype=torch.int64).T
+ else:
+ ms_data[key] = item
+
+ return ms_data
+
+ def portion_ms_node(self):
+ r"""
+
+ Returns:
+            the per-client proportion of nodes that miss at least one edge.
+
+ """
+ cnt_list = []
+ ms_set = {x.item() for x in set(self.missing_data().edge_index[0])}
+ for sub_data in self.split_data:
+ cnt = 0
+ for idx in sub_data.index_orig:
+ if idx.item() in ms_set:
+ cnt += 1
+ cnt_list.append(cnt / sub_data.num_nodes)
+ return cnt_list
+
+ def average_clustering(self):
+ r"""
+
+ Returns:
+ the average clustering coefficient for the raw G and split G
+
+ """
+ import networkx.algorithms.cluster as cluster
+
+ return cluster.average_clustering(
+ self.raw_graph), cluster.average_clustering(
+ to_networkx(self.fl_data()))
+
+ def homophily_value(self, edge_index, y):
+ r"""
+
+ Returns:
+            the homophily value, i.e., the fraction of edges that connect
+            same-label nodes.
+
+ """
+ from torch_sparse import SparseTensor
+
+ if isinstance(edge_index, SparseTensor):
+ row, col, _ = edge_index.coo()
+ else:
+ row, col = edge_index
+
+ return int((y[row] == y[col]).sum()) / row.size(0)
+
+ def homophily(self):
+ r"""
+
+ Returns:
+ the homophily for the raw G and split G
+
+ """
+
+ return self.homophily_value(self.raw_data.edge_index,
+ self.raw_data.y), self.homophily_value(
+ self.fl_data().edge_index,
+ self.fl_data().y)
+
+ def hamming_distance_graph(self, data):
+ r"""
+
+ Returns:
+            the average Hamming distance between the features of connected
+            nodes in the given graph data.
+
+ """
+ edge_index, x = data.edge_index, data.x
+ cnt = 0
+ for row, col in edge_index.T:
+ row, col = row.item(), col.item()
+ cnt += torch.sum(x[row] != x[col]).item()
+
+ return cnt / edge_index.shape[1]
+
+ def hamming(self):
+ r"""
+
+ Returns:
+            the average feature Hamming distance for the raw G, the split G
+            and the missing-edge G
+
+ """
+ return self.hamming_distance_graph(
+ self.raw_data), self.hamming_distance_graph(
+ self.fl_data()), self.hamming_distance_graph(
+ self.missing_data())
\ No newline at end of file
diff --git a/federatedscope/core/splitters/graph/graphtype_splitter.py b/federatedscope/core/splitters/graph/graphtype_splitter.py
new file mode 100644
index 000000000..a40bcfe8b
--- /dev/null
+++ b/federatedscope/core/splitters/graph/graphtype_splitter.py
@@ -0,0 +1,27 @@
+import numpy as np
+from federatedscope.core.splitters.utils import dirichlet_distribution_noniid_slice
+
+
+class GraphTypeSplitter:
+ def __init__(self, client_num, alpha=0.5):
+ self.client_num = client_num
+ self.alpha = alpha
+
+ def __call__(self, dataset):
+        r"""Split the dataset via the Dirichlet distribution to generate a non-i.i.d. data split.
+
+        Arguments:
+            dataset (List or PyG.dataset): The datasets.
+
+        Returns:
+            data_list (List(List(PyG.data))): the dataset split via the Dirichlet distribution.
+ """
+ dataset = [ds for ds in dataset]
+ label = np.array([ds.y.item() for ds in dataset])
+ idx_slice = dirichlet_distribution_noniid_slice(
+ label, self.client_num, self.alpha)
+ data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice]
+ return data_list
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
diff --git a/federatedscope/core/splitters/graph/louvain_splitter.py b/federatedscope/core/splitters/graph/louvain_splitter.py
new file mode 100644
index 000000000..50319dce6
--- /dev/null
+++ b/federatedscope/core/splitters/graph/louvain_splitter.py
@@ -0,0 +1,77 @@
+import torch
+
+from torch_geometric.transforms import BaseTransform
+from torch_geometric.utils import to_networkx, to_undirected, from_networkx
+
+import networkx as nx
+import community as community_louvain
+
+
+class LouvainSplitter(BaseTransform):
+ r"""
+    Split data into small pieces via the Louvain community-detection algorithm.
+
+    Args:
+        client_num (int): Split data into client_num of pieces.
+        delta (int): The allowed gap between the number of nodes on each client.
+
+ """
+ def __init__(self, client_num, delta=20):
+ self.client_num = client_num
+ self.delta = delta
+
+ def __call__(self, data):
+
+ data.index_orig = torch.arange(data.num_nodes)
+ G = to_networkx(
+ data,
+ node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'],
+ to_undirected=True)
+ nx.set_node_attributes(G,
+ dict([(nid, nid)
+ for nid in range(nx.number_of_nodes(G))]),
+ name="index_orig")
+ partition = community_louvain.best_partition(G)
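+        # Louvain communities can be highly unbalanced, so oversized clusters
+        # are first chopped into chunks of at most `max_len` nodes and then
+        # greedily packed into clients up to max_len_client + delta each.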
+
+ cluster2node = {}
+ for node in partition:
+ cluster = partition[node]
+ if cluster not in cluster2node:
+ cluster2node[cluster] = [node]
+ else:
+ cluster2node[cluster].append(node)
+
+ max_len = len(G) // self.client_num - self.delta
+ max_len_client = len(G) // self.client_num
+
+ tmp_cluster2node = {}
+ for cluster in cluster2node:
+ while len(cluster2node[cluster]) > max_len:
+ tmp_cluster = cluster2node[cluster][:max_len]
+ tmp_cluster2node[len(cluster2node) + len(tmp_cluster2node) +
+ 1] = tmp_cluster
+ cluster2node[cluster] = cluster2node[cluster][max_len:]
+ cluster2node.update(tmp_cluster2node)
+
+        orderedc2n = sorted(cluster2node.items(),
+                            key=lambda x: len(x[1]),
+                            reverse=True)
+
+ client_node_idx = {idx: [] for idx in range(self.client_num)}
+ client_list = [idx for idx in range(self.client_num)]
+ idx = 0
+ for (cluster, node_list) in orderedc2n:
+ while len(node_list) + len(
+ client_node_idx[idx]) > max_len_client + self.delta:
+ idx = (idx + 1) % self.client_num
+ client_node_idx[idx] += node_list
+ idx = (idx + 1) % self.client_num
+
+ graphs = []
+ for owner in client_node_idx:
+ nodes = client_node_idx[owner]
+ graphs.append(from_networkx(nx.subgraph(G, nodes)))
+
+ return graphs
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}({self.client_num})'
diff --git a/federatedscope/core/splitters/graph/randchunk_splitter.py b/federatedscope/core/splitters/graph/randchunk_splitter.py
new file mode 100644
index 000000000..af7d56ffd
--- /dev/null
+++ b/federatedscope/core/splitters/graph/randchunk_splitter.py
@@ -0,0 +1,35 @@
+import numpy as np
+
+
+class RandChunkSplitter:
+ def __init__(self, client_num):
+ self.client_num = client_num
+
+ def __call__(self, dataset):
+ r"""Split dataset via random chunk.
+
+ Arguments:
+ dataset (List or PyG.dataset): The datasets.
+
+ Returns:
+            data_list (List(List(PyG.data))): the dataset split via random chunks.
+ """
+ data_list = []
+ dataset = [ds for ds in dataset]
+ num_graph = len(dataset)
+
+ # Split dataset
+ min_size = min(50, int(num_graph / self.client_num))
+
+ for i in range(self.client_num):
+ data_list.append(dataset[i * min_size:(i + 1) * min_size])
+ for graph in dataset[self.client_num * min_size:]:
+ client_idx = np.random.randint(low=0, high=self.client_num,
+ size=1)[0]
+ data_list[client_idx].append(graph)
+
+ return data_list
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
diff --git a/federatedscope/core/splitters/graph/random_splitter.py b/federatedscope/core/splitters/graph/random_splitter.py
new file mode 100644
index 000000000..8f6cbcc4b
--- /dev/null
+++ b/federatedscope/core/splitters/graph/random_splitter.py
@@ -0,0 +1,106 @@
+import torch
+
+from torch_geometric.transforms import BaseTransform
+from torch_geometric.utils import to_networkx, from_networkx
+
+import numpy as np
+import networkx as nx
+
+EPSILON = 1e-5
+
+
+class RandomSplitter(BaseTransform):
+ r"""
+ Split Data into small data via random sampling.
+
+ Args:
+ client_num (int): Split data into client_num of pieces.
+        sampling_rate (str): Sampling ratios of the unique nodes for each client, e.g., '0.2,0.2,0.2'.
+        overlapping_rate (float): Additional ratio of overlapping samples, e.g., 0.4.
+        drop_edge (float): Ratio of edges to drop (drop_edge / client_num per client) within the overlapping part.
+
+ """
+ def __init__(self,
+ client_num,
+ sampling_rate=None,
+ overlapping_rate=0,
+ drop_edge=0):
+
+ self.ovlap = overlapping_rate
+
+ if sampling_rate is not None:
+ self.sampling_rate = np.array(
+ [float(val) for val in sampling_rate.split(',')])
+ else:
+ # Default: Average
+ self.sampling_rate = (np.ones(client_num) -
+ self.ovlap) / client_num
+
+ if len(self.sampling_rate) != client_num:
+ raise ValueError(
+                f'The client_num ({client_num}) should be equal to the length of sampling_rate.'
+ )
+
+ if abs((sum(self.sampling_rate) + self.ovlap) - 1) > EPSILON:
+ raise ValueError(
+ f'The sum of sampling_rate:{self.sampling_rate} and overlapping_rate({self.ovlap}) should be 1.'
+ )
+
+ self.client_num = client_num
+ self.drop_edge = drop_edge
+
+    def __call__(self, data, prior=None):
+
+ data.index_orig = torch.arange(data.num_nodes)
+ G = to_networkx(
+ data,
+ node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'],
+ to_undirected=True)
+ nx.set_node_attributes(G,
+ dict([(nid, nid)
+ for nid in range(nx.number_of_nodes(G))]),
+ name="index_orig")
+
+ client_node_idx = {idx: [] for idx in range(self.client_num)}
+
+ indices = np.random.permutation(data.num_nodes)
+ sum_rate = 0
+ for idx, rate in enumerate(self.sampling_rate):
+ client_node_idx[idx] = indices[round(sum_rate *
+ data.num_nodes):round(
+ (sum_rate + rate) *
+ data.num_nodes)]
+ sum_rate += rate
+
+ if self.ovlap:
+ ovlap_nodes = indices[round(sum_rate * data.num_nodes):]
+ for idx in client_node_idx:
+ client_node_idx[idx] = np.concatenate(
+ (client_node_idx[idx], ovlap_nodes))
+
+        # Drop edges within the overlapping part for each client (only
+        # meaningful when overlapping_rate > 0, since ovlap_nodes and
+        # ovlap_graph are defined in the branch above)
+ if self.drop_edge:
+ ovlap_graph = nx.Graph(nx.subgraph(G, ovlap_nodes))
+ ovlap_edge_ind = np.random.permutation(
+ ovlap_graph.number_of_edges())
+ drop_all = ovlap_edge_ind[:round(ovlap_graph.number_of_edges() *
+ self.drop_edge)]
+ drop_client = [
+ drop_all[s:s + round(len(drop_all) / self.client_num)]
+ for s in range(0, len(drop_all),
+ round(len(drop_all) / self.client_num))
+ ]
+
+ graphs = []
+ for owner in client_node_idx:
+ nodes = client_node_idx[owner]
+ sub_g = nx.Graph(nx.subgraph(G, nodes))
+ if self.drop_edge:
+ sub_g.remove_edges_from(
+ np.array(ovlap_graph.edges)[drop_client[owner]])
+ graphs.append(from_networkx(sub_g))
+
+ return graphs
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}({self.client_num})'
diff --git a/federatedscope/core/splitters/graph/reltype_splitter.py b/federatedscope/core/splitters/graph/reltype_splitter.py
new file mode 100644
index 000000000..7d38eb362
--- /dev/null
+++ b/federatedscope/core/splitters/graph/reltype_splitter.py
@@ -0,0 +1,67 @@
+import torch
+
+from torch_geometric.data import Data
+from torch_geometric.utils import from_networkx, to_undirected
+from torch_geometric.transforms import BaseTransform, RemoveIsolatedNodes
+
+from federatedscope.core.splitters.utils import dirichlet_distribution_noniid_slice
+
+
+class RelTypeSplitter(BaseTransform):
+ r"""
+ Split Data into small data via dirichlet distribution to
+ generate non-i.i.d data split.
+
+ Arguments:
+ client_num (int): Split data into client_num of pieces.
+        alpha (float): parameter controlling the similarity among clients; a smaller alpha yields a more heterogeneous split.
+
+ """
+ def __init__(self, client_num, alpha=0.5, realloc_mask=False):
+ self.client_num = client_num
+ self.alpha = alpha
+ self.realloc_mask = realloc_mask
+
+ def __call__(self, data):
+ data_list = []
+ label = data.edge_type.numpy()
+ idx_slice = dirichlet_distribution_noniid_slice(
+ label, self.client_num, self.alpha)
+        # Reallocate the train/val/test masks
+ train_ratio = data.train_edge_mask.sum().item() / data.num_edges
+ valid_ratio = data.valid_edge_mask.sum().item() / data.num_edges
+ test_ratio = data.test_edge_mask.sum().item() / data.num_edges
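+        # With realloc_mask, each client's masks are re-drawn at random so
+        # every subgraph keeps roughly the global train/val/test ratios.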
+ for idx_j in idx_slice:
+ edge_index = data.edge_index.T[idx_j].T
+ edge_type = data.edge_type[idx_j]
+ train_edge_mask = data.train_edge_mask[idx_j]
+ valid_edge_mask = data.valid_edge_mask[idx_j]
+ test_edge_mask = data.test_edge_mask[idx_j]
+ if self.realloc_mask:
+ num_edges = edge_index.size(-1)
+ indices = torch.randperm(num_edges)
+ train_edge_mask = torch.zeros(num_edges, dtype=torch.bool)
+ train_edge_mask[indices[:round(train_ratio *
+ num_edges)]] = True
+ valid_edge_mask = torch.zeros(num_edges, dtype=torch.bool)
+ valid_edge_mask[
+ indices[round(train_ratio *
+ num_edges):-round(test_ratio *
+ num_edges)]] = True
+ test_edge_mask = torch.zeros(num_edges, dtype=torch.bool)
+ test_edge_mask[indices[-round(test_ratio * num_edges):]] = True
+ sub_g = Data(x=data.x,
+ edge_index=edge_index,
+ index_orig=data.index_orig,
+ edge_type=edge_type,
+ train_edge_mask=train_edge_mask,
+ valid_edge_mask=valid_edge_mask,
+ test_edge_mask=test_edge_mask,
+ input_edge_index=to_undirected(
+ edge_index.T[train_edge_mask].T))
+ data_list.append(sub_g)
+
+ return data_list
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}({self.client_num})'
diff --git a/federatedscope/core/splitters/graph/scaffold_lda_splitter.py b/federatedscope/core/splitters/graph/scaffold_lda_splitter.py
new file mode 100644
index 000000000..da0e24a24
--- /dev/null
+++ b/federatedscope/core/splitters/graph/scaffold_lda_splitter.py
@@ -0,0 +1,178 @@
+import logging
+import numpy as np
+import torch
+
+from rdkit import Chem
+from rdkit import RDLogger
+from rdkit.Chem.Scaffolds import MurckoScaffold
+from federatedscope.core.splitters.utils import dirichlet_distribution_noniid_slice
+from federatedscope.core.splitters.graph.scaffold_splitter import generate_scaffold
+
+logger = logging.getLogger(__name__)
+
+RDLogger.DisableLog('rdApp.*')
+
+
+class GenFeatures:
+ r"""Implementation of 'CanonicalAtomFeaturizer' and 'CanonicalBondFeaturizer' in DGL.
+ Source: https://lifesci.dgl.ai/_modules/dgllife/utils/featurizers.html
+
+ Arguments:
+ data: PyG.data in PyG.dataset.
+
+ Returns:
+ data: PyG.data, data passing featurizer.
+
+ """
+ def __init__(self):
+ self.symbols = [
+ 'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
+ 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag',
+ 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni',
+ 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'other'
+ ]
+
+ self.hybridizations = [
+ Chem.rdchem.HybridizationType.SP,
+ Chem.rdchem.HybridizationType.SP2,
+ Chem.rdchem.HybridizationType.SP3,
+ Chem.rdchem.HybridizationType.SP3D,
+ Chem.rdchem.HybridizationType.SP3D2,
+ 'other',
+ ]
+
+ self.stereos = [
+ Chem.rdchem.BondStereo.STEREONONE,
+ Chem.rdchem.BondStereo.STEREOANY,
+ Chem.rdchem.BondStereo.STEREOZ,
+ Chem.rdchem.BondStereo.STEREOE,
+ Chem.rdchem.BondStereo.STEREOCIS,
+ Chem.rdchem.BondStereo.STEREOTRANS,
+ ]
+
+ def __call__(self, data):
+ mol = Chem.MolFromSmiles(data.smiles)
+
+ xs = []
+ for atom in mol.GetAtoms():
+ symbol = [0.] * len(self.symbols)
+ if atom.GetSymbol() in self.symbols:
+ symbol[self.symbols.index(atom.GetSymbol())] = 1.
+ else:
+ symbol[self.symbols.index('other')] = 1.
+ degree = [0.] * 10
+ degree[atom.GetDegree()] = 1.
+ implicit = [0.] * 6
+ implicit[atom.GetImplicitValence()] = 1.
+ formal_charge = atom.GetFormalCharge()
+ radical_electrons = atom.GetNumRadicalElectrons()
+ hybridization = [0.] * len(self.hybridizations)
+ if atom.GetHybridization() in self.hybridizations:
+ hybridization[self.hybridizations.index(
+ atom.GetHybridization())] = 1.
+ else:
+ hybridization[self.hybridizations.index('other')] = 1.
+ aromaticity = 1. if atom.GetIsAromatic() else 0.
+ hydrogens = [0.] * 5
+ hydrogens[atom.GetTotalNumHs()] = 1.
+
+ x = torch.tensor(symbol + degree + implicit + [formal_charge] +
+ [radical_electrons] + hybridization +
+ [aromaticity] + hydrogens)
+ xs.append(x)
+
+ data.x = torch.stack(xs, dim=0)
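+        # Node features: 44-d atom-symbol one-hot + 10-d degree + 6-d implicit
+        # valence + formal charge + radical electrons + 6-d hybridization +
+        # aromaticity flag + 5-d hydrogen count = 74 dims per atom.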
+
+ edge_attrs = []
+ for bond in mol.GetBonds():
+ bond_type = bond.GetBondType()
+ single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
+ double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
+ triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
+ aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
+ conjugation = 1. if bond.GetIsConjugated() else 0.
+ ring = 1. if bond.IsInRing() else 0.
+ stereo = [0.] * 6
+ stereo[self.stereos.index(bond.GetStereo())] = 1.
+
+ edge_attr = torch.tensor(
+ [single, double, triple, aromatic, conjugation, ring] + stereo)
+
+ edge_attrs += [edge_attr, edge_attr]
+
+ if len(edge_attrs) == 0:
+ data.edge_index = torch.zeros((2, 0), dtype=torch.long)
+ data.edge_attr = torch.zeros((0, 10), dtype=torch.float)
+ else:
+ num_atoms = mol.GetNumAtoms()
+ feats = torch.stack(edge_attrs, dim=0)
+ feats = torch.cat([feats, torch.zeros(feats.shape[0], 1)], dim=1)
+ self_loop_feats = torch.zeros(num_atoms, feats.shape[1])
+ self_loop_feats[:, -1] = 1
+ feats = torch.cat([feats, self_loop_feats], dim=0)
+ data.edge_attr = feats
+
+ return data
+
+
+def gen_scaffold_lda_split(dataset, client_num=5, alpha=0.1):
+ r"""
+    Return the sample index slices for each client (a list of index lists).
+ """
+ logger.info('Scaffold split might take minutes, please wait...')
+ scaffolds = {}
+ for idx, data in enumerate(dataset):
+ smiles = data.smiles
+ scaffold = generate_scaffold(smiles)
+ if scaffold not in scaffolds:
+ scaffolds[scaffold] = [idx]
+ else:
+ scaffolds[scaffold].append(idx)
+ # Sort from largest to smallest scaffold sets
+ scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
+ scaffold_list = [
+ list(scaffold_set)
+ for (scaffold,
+ scaffold_set) in sorted(scaffolds.items(),
+ key=lambda x: (len(x[1]), x[1][0]),
+ reverse=True)
+ ]
+ label = np.zeros(len(dataset))
+ for i in range(len(scaffold_list)):
+ label[scaffold_list[i]] = i + 1
+ label = torch.LongTensor(label)
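+    # Each scaffold group acts as a pseudo-class label, so the Dirichlet
+    # split below allocates scaffold-biased mixtures of molecules to clients.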
+ # Split data to list
+ idx_slice = dirichlet_distribution_noniid_slice(label, client_num, alpha)
+ return idx_slice
+
+
+class ScaffoldLdaSplitter:
+    r"""First adopt scaffold splitting and then assign the samples to clients via a Dirichlet (LDA) split over scaffold groups.
+
+    Arguments:
+        dataset (List or PyG.dataset): The molecular datasets.
+        alpha (float): Partition hyperparameter of the Dirichlet split; a smaller alpha generates a more heterogeneous scenario.
+
+    Returns:
+        data_list (List(List(PyG.data))): the dataset split via scaffold splitting.
+
+ """
+ def __init__(self, client_num, alpha):
+ self.client_num = client_num
+ self.alpha = alpha
+
+ def __call__(self, dataset):
+ featurizer = GenFeatures()
+ data = []
+ for ds in dataset:
+ ds = featurizer(ds)
+ data.append(ds)
+ dataset = data
+ idx_slice = gen_scaffold_lda_split(dataset, self.client_num,
+ self.alpha)
+ data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice]
+ return data_list
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
diff --git a/federatedscope/core/splitters/graph/scaffold_splitter.py b/federatedscope/core/splitters/graph/scaffold_splitter.py
new file mode 100644
index 000000000..9e8eec40c
--- /dev/null
+++ b/federatedscope/core/splitters/graph/scaffold_splitter.py
@@ -0,0 +1,58 @@
+import logging
+import numpy as np
+from rdkit import Chem
+from rdkit import RDLogger
+from rdkit.Chem.Scaffolds import MurckoScaffold
+
+logger = logging.getLogger(__name__)
+
+RDLogger.DisableLog('rdApp.*')
+
+
+def generate_scaffold(smiles, include_chirality=False):
+ """return scaffold string of target molecule"""
+ mol = Chem.MolFromSmiles(smiles)
+ scaffold = MurckoScaffold\
+ .MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality)
+ return scaffold
+
+
+def gen_scaffold_split(dataset, client_num=5):
+ r"""
+    Return the sample index slices for each client (a list of index arrays).
+ """
+ logger.info('Scaffold split might take minutes, please wait...')
+ scaffolds = {}
+ for idx, data in enumerate(dataset):
+ smiles = data.smiles
+ scaffold = generate_scaffold(smiles)
+ if scaffold not in scaffolds:
+ scaffolds[scaffold] = [idx]
+ else:
+ scaffolds[scaffold].append(idx)
+ # Sort from largest to smallest scaffold sets
+ scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
+ scaffold_list = [
+ list(scaffold_set)
+ for (scaffold,
+ scaffold_set) in sorted(scaffolds.items(),
+ key=lambda x: (len(x[1]), x[1][0]),
+ reverse=True)
+ ]
+ scaffold_idxs = sum(scaffold_list, [])
+ # Split data to list
+ splits = np.array_split(scaffold_idxs, client_num)
+ return [splits[ID] for ID in range(client_num)]
+
+
+class ScaffoldSplitter:
+ def __init__(self, client_num):
+ self.client_num = client_num
+
+ def __call__(self, dataset):
+ dataset = [ds for ds in dataset]
+        idx_slice = gen_scaffold_split(dataset, self.client_num)
+ data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice]
+ return data_list
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
diff --git a/federatedscope/core/splitters/utils.py b/federatedscope/core/splitters/utils.py
new file mode 100644
index 000000000..62e46153a
--- /dev/null
+++ b/federatedscope/core/splitters/utils.py
@@ -0,0 +1,130 @@
+import numpy as np
+
+
+def _split_according_to_prior(label, client_num, prior):
+ assert client_num == len(prior)
+ classes = len(np.unique(label))
+ assert classes == len(np.unique(np.concatenate(prior, 0)))
+
+ # counting
+ frequency = np.zeros(shape=(client_num, classes))
+ for idx, client_prior in enumerate(prior):
+ for each in client_prior:
+ frequency[idx][each] += 1
+ sum_frequency = np.sum(frequency, axis=0)
+
+ idx_slice = [[] for _ in range(client_num)]
+ for k in range(classes):
+ idx_k = np.where(label == k)[0]
+ np.random.shuffle(idx_k)
+ nums_k = np.ceil(frequency[:, k] / sum_frequency[k] *
+ len(idx_k)).astype(int)
+ while len(idx_k) < np.sum(nums_k):
+ random_client = np.random.choice(range(client_num))
+ if nums_k[random_client] > 0:
+ nums_k[random_client] -= 1
+ assert len(idx_k) == np.sum(nums_k)
+ idx_slice = [
+ idx_j + idx.tolist() for idx_j, idx in zip(
+ idx_slice, np.split(idx_k,
+ np.cumsum(nums_k)[:-1]))
+ ]
+
+ for i in range(len(idx_slice)):
+ np.random.shuffle(idx_slice[i])
+ return idx_slice
+
+
+def dirichlet_distribution_noniid_slice(label,
+ client_num,
+ alpha,
+ min_size=1,
+ prior=None):
+ r"""Get sample index list for each client from the Dirichlet distribution.
+    https://github.com/FedML-AI/FedML/blob/master/fedml_core/non_iid_partition/noniid_partition.py
+
+ Arguments:
+ label (np.array): Label list to be split.
+ client_num (int): Split label into client_num parts.
+ alpha (float): alpha of LDA.
+ min_size (int): min number of sample in each client
+ Returns:
+ idx_slice (List): List of splited label index slice.
+ """
+ if len(label.shape) != 1:
+ raise ValueError('Only support single-label tasks!')
+
+ if prior is not None:
+ return _split_according_to_prior(label, client_num, prior)
+
+ num = len(label)
+ classes = len(np.unique(label))
+ assert num > client_num * min_size, f'The number of sample should be ' \
+ f'greater than' \
+ f' {client_num * min_size}.'
+ size = 0
+ while size < min_size:
+ idx_slice = [[] for _ in range(client_num)]
+ for k in range(classes):
+ # for label k
+ idx_k = np.where(label == k)[0]
+ np.random.shuffle(idx_k)
+ prop = np.random.dirichlet(np.repeat(alpha, client_num))
+            # Note: the FedML reference additionally rescales `prop` so that
+            # no client exceeds num / client_num samples; leaving out that
+            # rescaling (as done here) can yield a more non-IID split.
+ prop = (np.cumsum(prop) * len(idx_k)).astype(int)[:-1]
+ idx_slice = [
+ idx_j + idx.tolist()
+ for idx_j, idx in zip(idx_slice, np.split(idx_k, prop))
+ ]
+ size = min([len(idx_j) for idx_j in idx_slice])
+ for i in range(client_num):
+ np.random.shuffle(idx_slice[i])
+ return idx_slice
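+
+
+# A small illustrative sketch (hypothetical numbers): with 1000 labels over
+# 10 classes, alpha controls heterogeneity -- a larger alpha approaches an
+# IID split, while a smaller alpha concentrates classes on few clients.
+#   label = np.random.randint(0, 10, size=1000)
+#   slices = dirichlet_distribution_noniid_slice(label, client_num=5,
+#                                                alpha=0.5)
+#   assert sorted(sum(slices, [])) == list(range(1000))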
diff --git a/federatedscope/core/strategy.py b/federatedscope/core/strategy.py
new file mode 100644
index 000000000..5c4a80bdc
--- /dev/null
+++ b/federatedscope/core/strategy.py
@@ -0,0 +1,23 @@
+class Strategy(object):
+ def __init__(self, stg_type=None, threshold=0):
+ self._stg_type = stg_type
+ self._threshold = threshold
+
+ @property
+ def stg_type(self):
+ return self._stg_type
+
+ @stg_type.setter
+ def stg_type(self, value):
+ self._stg_type = value
+
+ @property
+ def threshold(self):
+ return self._threshold
+
+ @threshold.setter
+ def threshold(self, value):
+ self._threshold = value
diff --git a/federatedscope/core/trainers/__init__.py b/federatedscope/core/trainers/__init__.py
new file mode 100644
index 000000000..c8cac5a89
--- /dev/null
+++ b/federatedscope/core/trainers/__init__.py
@@ -0,0 +1,18 @@
+from federatedscope.core.trainers.trainer import Trainer
+from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
+from federatedscope.core.trainers.trainer_multi_model import GeneralMultiModelTrainer
+from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer
+from federatedscope.core.trainers.trainer_Ditto import wrap_DittoTrainer
+from federatedscope.core.trainers.trainer_FedEM import FedEMTrainer
+from federatedscope.core.trainers.context import Context
+from federatedscope.core.trainers.trainer_fedprox import wrap_fedprox_trainer
+from federatedscope.core.trainers.trainer_nbafl import wrap_nbafl_trainer, wrap_nbafl_server
+from federatedscope.core.trainers.benign_trainer import wrap_benignTrainer
+from federatedscope.core.trainers.trainer_FedRep import wrap_FedRepTrainer
+
+__all__ = [
+ 'Trainer', 'Context', 'GeneralTorchTrainer', 'GeneralMultiModelTrainer',
+ 'wrap_pFedMeTrainer', 'wrap_DittoTrainer', 'FedEMTrainer',
+ 'wrap_fedprox_trainer', 'wrap_nbafl_trainer', 'wrap_nbafl_server',
+ 'wrap_benignTrainer', 'wrap_FedRepTrainer'
+]
diff --git a/federatedscope/core/trainers/benign_trainer.py b/federatedscope/core/trainers/benign_trainer.py
new file mode 100644
index 000000000..65ddaf6d3
--- /dev/null
+++ b/federatedscope/core/trainers/benign_trainer.py
@@ -0,0 +1,157 @@
+import logging
+from typing import Type
+import torch
+import numpy as np
+
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+from federatedscope.attack.auxiliary.backdoor_utils import normalize
+from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
+from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
+from federatedscope.core.auxiliaries.ReIterator import ReIterator
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_benignTrainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ '''
+    Wrap the benign trainer for backdoor attack evaluation:
+
+    we register evaluation hooks that additionally measure the model's
+    behavior on the poisoned test data.
+
+ Args:
+ base_trainer: Type: core.trainers.GeneralTorchTrainer
+
+ :returns:
+ The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
+
+ '''
+
+ base_trainer.register_hook_in_eval(new_hook=hook_on_fit_start_test_poison,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(
+ new_hook=hook_on_epoch_start_test_poison,
+ trigger='on_epoch_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(
+ new_hook=hook_on_batch_start_test_poison,
+ trigger='on_batch_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(
+ new_hook=hook_on_batch_forward_test_poison,
+ trigger='on_batch_forward',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(new_hook=hook_on_batch_end_test_poison,
+ trigger="on_batch_end",
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(new_hook=hook_on_fit_end_test_poison,
+ trigger='on_fit_end',
+ insert_pos=0)
+
+ return base_trainer
+
+
+def hook_on_fit_start_test_poison(ctx):
+
+ ctx['poison_' + ctx.cur_data_split +
+ '_loader'] = ctx.data['poison_' + ctx.cur_data_split]
+ ctx['poison_' + ctx.cur_data_split +
+ '_data'] = ctx.data['poison_' + ctx.cur_data_split].dataset
+ ctx['num_poison_' + ctx.cur_data_split + '_data'] = len(
+ ctx.data['poison_' + ctx.cur_data_split].dataset)
+ setattr(ctx, "poison_loss_batch_total_{}".format(ctx.cur_data_split), 0)
+ setattr(ctx, "poison_num_samples_{}".format(ctx.cur_data_split), 0)
+ setattr(ctx, "poison_{}_y_true".format(ctx.cur_data_split), [])
+ setattr(ctx, "poison_{}_y_prob".format(ctx.cur_data_split), [])
+
+
+def hook_on_epoch_start_test_poison(ctx):
+ if ctx.get("poison_{}_loader".format(ctx.cur_data_split)) is None:
+ loader = get_dataloader(
+ WrapDataset(ctx.get("poison_{}_data".format(ctx.cur_data_split))),
+ ctx.cfg)
+ setattr(ctx, "poison_{}_loader".format(ctx.cur_data_split),
+ ReIterator(loader))
+ elif not isinstance(ctx.get("poison_{}_loader".format(ctx.cur_data_split)),
+ ReIterator):
+ setattr(
+ ctx, "poison_{}_loader".format(ctx.cur_data_split),
+ ReIterator(ctx.get("poison_{}_loader".format(ctx.cur_data_split))))
+ else:
+ ctx.get("poison_{}_loader".format(ctx.cur_data_split)).reset()
+
+
+def hook_on_batch_start_test_poison(ctx):
+ try:
+ ctx.poison_data_batch = next(
+ ctx.get("poison_{}_loader".format(ctx.cur_data_split)))
+ except StopIteration:
+ raise StopIteration
+
+
+def hook_on_batch_forward_test_poison(ctx):
+
+ x, label = [_.to(ctx.device) for _ in ctx.poison_data_batch]
+ pred = ctx.model(x)
+ if len(label.size()) == 0:
+ label = label.unsqueeze(0)
+ ctx.poison_loss_batch = ctx.criterion(pred, label)
+ ctx.poison_y_true = label
+ ctx.poison_y_prob = pred
+
+ ctx.poison_batch_size = len(label)
+
+
+def hook_on_batch_end_test_poison(ctx):
+
+ setattr(
+ ctx, "poison_loss_batch_total_{}".format(ctx.cur_data_split),
+ ctx.get("poison_loss_batch_total_{}".format(ctx.cur_data_split)) +
+ ctx.poison_loss_batch.item() * ctx.poison_batch_size)
+
+ setattr(
+ ctx, "poison_num_samples_{}".format(ctx.cur_data_split),
+ ctx.get("poison_num_samples_{}".format(ctx.cur_data_split)) +
+ ctx.poison_batch_size)
+
+ ctx.get("poison_{}_y_true".format(ctx.cur_data_split)).append(
+ ctx.poison_y_true.detach().cpu().numpy())
+
+ ctx.get("poison_{}_y_prob".format(ctx.cur_data_split)).append(
+ ctx.poison_y_prob.detach().cpu().numpy())
+
+ ctx.poison_data_batch = None
+ ctx.poison_batch_size = None
+ ctx.poison_loss_task = None
+ ctx.poison_loss_batch = None
+ ctx.poison_loss_regular = None
+ ctx.poison_y_true = None
+ ctx.poison_y_prob = None
+
+
+def hook_on_fit_end_test_poison(ctx):
+ """Evaluate metrics of poisoning attacks.
+
+ """
+ setattr(
+ ctx, "poison_{}_y_true".format(ctx.cur_data_split),
+ np.concatenate(ctx.get("poison_{}_y_true".format(ctx.cur_data_split))))
+ setattr(
+ ctx, "poison_{}_y_prob".format(ctx.cur_data_split),
+ np.concatenate(ctx.get("poison_{}_y_prob".format(ctx.cur_data_split))))
+
+
+def hook_on_fit_start_addnormalize(ctx):
+    '''
+    A placeholder hook that intentionally does nothing; it marks where an
+    input-normalization step could be plugged in.
+    '''
+
+    pass
diff --git a/federatedscope/core/trainers/context.py b/federatedscope/core/trainers/context.py
new file mode 100644
index 000000000..fb222a1de
--- /dev/null
+++ b/federatedscope/core/trainers/context.py
@@ -0,0 +1,189 @@
+import logging
+
+import math
+
+from federatedscope.core.auxiliaries.criterion_builder import get_criterion
+from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
+from federatedscope.core.auxiliaries.model_builder import get_trainable_para_names
+from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer
+
+
+class Context(dict):
+ """Record and pass variables among different hook functions.
+
+ Arguments:
+ model (Module): training model
+ data (dict): a dict contains train/val/test dataset or dataloader
+ device: running device
+
+ Record attributes:
+ - model (Module): the training model
+ - data (dict): a dict contains train/val/test dataset or dataloader
+        - device (torch.device): the device to run on
+ - criterion: specific loss function
+ - optimizer: specific optimizer
+ - mode: maintain the current mode of the model
+
+ - data_batch: current batch data from train/test/val data loader
+
+ - trainable_para_names (list): a list of the names of the trainable parameters within ```ctx.model```
+ - train_data: training dataset
+ - train_loader: training dataloader
+ - num_train_data (int): the number of training samples within one epoch
+ - num_train_epoch (int): the number of total training epochs
+ - num_train_batch (int): the number of batches within one completed training epoch
+ - num_train_batch_last_epoch (int): the number of batches within the last epoch
+
+ - test_data: test data
+ - test_loader: test dataloader
+ - num_test_data (int): the number of test samples within one epoch
+ - num_test_epoch (int): the number of test epochs, default 1
+ - num_test_batch (int): the number of batches within one completed test epoch
+
+ - val_data: val data
+ - val_loader: val dataloader
+ - num_val_data (int): the number of val samples within one epoch
+ - num_val_epoch (int): the number of val epochs, default 1
+ - num_val_batch (int): the number of batches within one completed val epoch
+
+ Statistical variables:
+ - loss_batch (float): loss of the current data_batch, shared by train/test/val
+ - loss_regular (float): loss of the regularizer
+ - loss_task (float): the sum of loss_batch and loss_regular
+
+ - loss_total_batch_train (float): accumulated batch loss during training
+ - loss_total_regular_train (float): accumulated regular loss during training
+ - num_samples_train (int): accumulated number of training samples involved at present
+
+        - loss_total_test (float): accumulated batch loss during test
+        - num_samples_test (int): accumulated number of test samples involved at present
+
+        - loss_total_val (float): accumulated batch loss during val
+        - num_samples_val (int): accumulated number of val samples involved at present
+
+ - eval_metrics (dict): evaluation results
+ """
+
+ __setattr__ = dict.__setitem__
+ __delattr__ = dict.__delitem__
+
+ def __getattr__(self, item):
+ try:
+ return self[item]
+ except KeyError:
+ raise AttributeError("Attribute {} is not found".format(item))
+
+ def __init__(self,
+ model,
+ cfg,
+ data=None,
+ device=None,
+ init_dict=None,
+ init_attr=True):
+ if init_dict is None:
+ super(Context, self).__init__()
+ else:
+ super(Context, self).__init__(init_dict)
+
+ self.cfg = cfg
+ self.model = model
+ self.data = data
+ self.device = device
+ self.cur_mode = None
+ self.cur_data_split = None
+
+ if init_attr:
+ # setup static variables for training/evaluation
+ self.setup_vars()
+
+ def setup_vars(self):
+ if self.cfg.backend == 'torch':
+ self.trainable_para_names = get_trainable_para_names(self.model)
+ self.criterion = get_criterion(self.cfg.criterion.type,
+ self.device)
+ self.regularizer = get_regularizer(self.cfg.regularizer.type)
+ self.optimizer = get_optimizer(self.model, **self.cfg.optimizer)
+ self.grad_clip = self.cfg.grad.grad_clip
+ elif self.cfg.backend == 'tensorflow':
+ self.trainable_para_names = self.model.trainable_variables()
+ self.criterion = None
+ self.regularizer = None
+ self.optimizer = None
+ self.grad_clip = None
+
+ self.mode = list()
+ self.cur_data_splits_used_by_routine = list()
+
+ # Process training data
+ if self.train_data is not None or self.train_loader is not None:
+ # Calculate the number of update steps during training given the local_update_steps
+ num_train_batch, num_train_batch_last_epoch, num_train_epoch, num_total_train_batch = self.pre_calculate_batch_epoch_num(
+ self.cfg.federate.local_update_steps)
+
+ self.num_train_epoch = num_train_epoch
+ self.num_train_batch = num_train_batch
+ self.num_train_batch_last_epoch = num_train_batch_last_epoch
+ self.num_total_train_batch = num_total_train_batch
+
+        # Process evaluation data: iterate over the split names given in
+        # cfg.data.dataset instead of a hard-coded train/val/test triple
+        name_list = self.cfg.data.dataset
+ for mode in name_list:
+ if mode != 'train':
+ setattr(self, "num_{}_epoch".format(mode), 1)
+ if self.get("{}_data".format(mode)) is not None or self.get(
+ "{}_loader".format(mode)) is not None:
+ setattr(
+ self, "num_{}_batch".format(mode),
+ getattr(self, "num_{}_data".format(mode)) //
+ self.cfg.data.batch_size +
+ int(not self.cfg.data.drop_last and bool(
+ getattr(self, "num_{}_data".format(mode)) %
+ self.cfg.data.batch_size)))
+
+ def pre_calculate_batch_epoch_num(self, local_update_steps):
+ num_train_batch = self.num_train_data // self.cfg.data.batch_size + int(
+ not self.cfg.data.drop_last
+ and bool(self.num_train_data % self.cfg.data.batch_size))
+ if self.cfg.federate.batch_or_epoch == "epoch":
+ num_train_epoch = local_update_steps
+ num_train_batch_last_epoch = num_train_batch
+ num_total_train_batch = local_update_steps * num_train_batch
+        elif num_train_batch == 0:
+            raise RuntimeError(
+                "The number of training batches is 0; please check 'batch_size' or set 'drop_last' to False"
+            )
+ else:
+ num_train_epoch = math.ceil(local_update_steps / num_train_batch)
+ num_train_batch_last_epoch = local_update_steps % num_train_batch or num_train_batch
+ num_total_train_batch = local_update_steps
+ return num_train_batch, num_train_batch_last_epoch, num_train_epoch, num_total_train_batch
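A worked example of the arithmetic above (numbers are illustrative only):

    # 50 training samples, batch_size=8, drop_last=False:
    #   num_train_batch = 50 // 8 + int(bool(50 % 8)) = 6 + 1 = 7
    # batch_or_epoch == "batch", local_update_steps = 10:
    #   num_train_epoch            = ceil(10 / 7) = 2
    #   num_train_batch_last_epoch = 10 % 7 or 7  = 3
    #   num_total_train_batch      = 10
    # batch_or_epoch == "epoch", local_update_steps = 2:
    #   num_train_epoch = 2, num_train_batch_last_epoch = 7,
    #   num_total_train_batch = 2 * 7 = 14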
+
+ def append_mode(self, mode):
+ self.mode.append(mode)
+ self.cur_mode = self.mode[-1]
+ self.change_mode(self.cur_mode)
+
+ def pop_mode(self):
+ self.mode.pop()
+ self.cur_mode = self.mode[-1] if len(self.mode) != 0 else None
+ if len(self.mode) != 0:
+ self.change_mode(self.cur_mode)
+
+ def change_mode(self, mode):
+ # change state
+ if self.cfg.backend == 'torch':
+ getattr(self.model, mode if mode == 'train' else 'eval')()
+ else:
+ pass
+
+ def track_used_dataset(self, dataset):
+ # stack-style to enable mixture usage such as evaluation on train dataset
+ self.cur_data_splits_used_by_routine.append(dataset)
+ self.cur_data_split = self.cur_data_splits_used_by_routine[-1]
+
+ def reset_used_dataset(self):
+ self.cur_data_splits_used_by_routine.pop()
+ self.cur_data_split = self.cur_data_splits_used_by_routine[-1] if \
+ len(self.cur_data_splits_used_by_routine) != 0 else None
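A short sketch of how `Context` behaves as both a dict and an attribute namespace, and how the mode/data-split stacks are used by routines. Hedged: `_Cfg` is a minimal stand-in config, and the two stack lists are normally created by `setup_vars()`, which the sketch skips via `init_attr=False`.

    import torch.nn as nn

    class _Cfg:                          # minimal stand-in config (illustrative)
        backend = 'torch'

    ctx = Context(nn.Linear(4, 2), _Cfg(), device='cpu', init_attr=False)
    ctx.mode = []                        # normally created by setup_vars()
    ctx.cur_data_splits_used_by_routine = []

    ctx.loss_batch = 0.42                # attribute writes ...
    assert ctx['loss_batch'] == 0.42     # ... are dict writes, and vice versa

    ctx.append_mode('test')              # stacks the mode and calls model.eval()
    ctx.track_used_dataset('train')      # e.g., evaluation on the train split
    # ... hooks would run here ...
    ctx.reset_used_dataset()
    ctx.pop_mode()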
diff --git a/federatedscope/core/trainers/tf_trainer.py b/federatedscope/core/trainers/tf_trainer.py
new file mode 100644
index 000000000..ad372d16d
--- /dev/null
+++ b/federatedscope/core/trainers/tf_trainer.py
@@ -0,0 +1,171 @@
+import tensorflow as tf
+import numpy as np
+from federatedscope.core.trainers import Trainer
+from federatedscope.core.auxiliaries.utils import batch_iter
+
+
+class GeneralTFTrainer(Trainer):
+ def train(self, target_data_split_name="train", hooks_set=None):
+ hooks_set = self.hooks_in_train if hooks_set is None else hooks_set
+ if self.ctx.get(
+ f"{target_data_split_name}_data") is None and self.ctx.get(
+ f"{target_data_split_name}_loader") is None:
+ raise ValueError(
+ f"No {target_data_split_name}_data or {target_data_split_name}_loader in the trainer"
+ )
+ self._run_routine("train", hooks_set, target_data_split_name)
+
+ # TODO: The return values should be more flexible? Now: sample_num, model_para, results={k:v}
+
+ return self.ctx.num_samples_train, self.ctx.model.state_dict(
+ ), self.ctx.eval_metrics
+
+ def parse_data(self, data):
+ """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes
+
+ """
+ init_dict = dict()
+ if isinstance(data, dict):
+ for mode in ["train", "val", "test"]:
+ init_dict["{}_data".format(mode)] = None
+ init_dict["{}_loader".format(mode)] = None
+ init_dict["num_{}_data".format(mode)] = 0
+ if data.get(mode, None) is not None:
+ init_dict["{}_data".format(mode)] = data.get(mode)
+ init_dict["num_{}_data".format(mode)] = len(data.get(mode))
+ else:
+ raise TypeError("Type of data should be dict.")
+ return init_dict
+
+ def register_default_hooks_train(self):
+ self.register_hook_in_train(self._hook_on_fit_start_init,
+ "on_fit_start")
+ self.register_hook_in_train(self._hook_on_epoch_start,
+ "on_epoch_start")
+ self.register_hook_in_train(self._hook_on_batch_start_init,
+ "on_batch_start")
+ self.register_hook_in_train(self._hook_on_batch_forward,
+ "on_batch_forward")
+ self.register_hook_in_train(self._hook_on_batch_forward_regularizer,
+ "on_batch_forward")
+ self.register_hook_in_train(self._hook_on_batch_backward,
+ "on_batch_backward")
+ self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end")
+ self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end")
+
+ def register_default_hooks_eval(self):
+ # test/val
+ self.register_hook_in_eval(self._hook_on_fit_start_init,
+ "on_fit_start")
+ self.register_hook_in_eval(self._hook_on_epoch_start, "on_epoch_start")
+ self.register_hook_in_eval(self._hook_on_batch_start_init,
+ "on_batch_start")
+ self.register_hook_in_eval(self._hook_on_batch_forward,
+ "on_batch_forward")
+ self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end")
+ self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end")
+
+ def _hook_on_fit_start_init(self, ctx):
+ # prepare model
+ ctx.model.to(ctx.device)
+
+ # prepare statistics
+ setattr(ctx, "loss_batch_total_{}".format(ctx.cur_data_split), 0)
+ setattr(ctx, "loss_regular_total_{}".format(ctx.cur_data_split), 0)
+ setattr(ctx, "num_samples_{}".format(ctx.cur_data_split), 0)
+ setattr(ctx, "{}_y_true".format(ctx.cur_data_split), [])
+ setattr(ctx, "{}_y_prob".format(ctx.cur_data_split), [])
+
+ def _hook_on_epoch_start(self, ctx):
+ # prepare dataloader
+ setattr(ctx, "{}_loader".format(ctx.cur_data_split),
+ batch_iter(ctx.get("{}_data".format(ctx.cur_data_split))))
+
+ def _hook_on_batch_start_init(self, ctx):
+ # prepare data batch
+ try:
+ ctx.data_batch = next(
+ ctx.get("{}_loader".format(ctx.cur_data_split)))
+ except StopIteration:
+ raise StopIteration
+
+ def _hook_on_batch_forward(self, ctx):
+
+ ctx.optimizer = ctx.model.optimizer
+
+ ctx.batch_size = len(ctx.data_batch)
+
+ with ctx.model.graph.as_default():
+ with ctx.model.sess.as_default():
+ feed_dict = {
+ ctx.model.input_x: ctx.data_batch['x'],
+ ctx.model.input_y: ctx.data_batch['y']
+ }
+ _, batch_loss, y_true, y_prob = ctx.model.sess.run(
+ [
+ ctx.model.train_op, ctx.model.losses,
+ ctx.model.input_y, ctx.model.out
+ ],
+ feed_dict=feed_dict)
+ ctx.loss_batch = batch_loss
+ ctx.y_true = y_true
+ ctx.y_prob = y_prob
+
+ def _hook_on_batch_forward_regularizer(self, ctx):
+ pass
+
+ def _hook_on_batch_backward(self, ctx):
+ pass
+
+ def _hook_on_batch_end(self, ctx):
+ # update statistics
+ setattr(
+ ctx, "loss_batch_total_{}".format(ctx.cur_data_split),
+ ctx.get("loss_batch_total_{}".format(ctx.cur_data_split)) +
+ ctx.loss_batch * ctx.batch_size)
+
+ loss_regular = 0.
+ setattr(
+ ctx, "loss_regular_total_{}".format(ctx.cur_data_split),
+ ctx.get("loss_regular_total_{}".format(ctx.cur_data_split)) +
+ loss_regular)
+ setattr(
+ ctx, "num_samples_{}".format(ctx.cur_data_split),
+ ctx.get("num_samples_{}".format(ctx.cur_data_split)) +
+ ctx.batch_size)
+
+ # cache label for evaluate
+ ctx.get("{}_y_true".format(ctx.cur_data_split)).append(ctx.y_true)
+
+ ctx.get("{}_y_prob".format(ctx.cur_data_split)).append(ctx.y_prob)
+
+ # clean temp ctx
+ ctx.data_batch = None
+ ctx.batch_size = None
+ ctx.loss_task = None
+ ctx.loss_batch = None
+ ctx.loss_regular = None
+ ctx.y_true = None
+ ctx.y_prob = None
+
+ def _hook_on_fit_end(self, ctx):
+ """Evaluate metrics.
+
+ """
+ setattr(
+ ctx, "{}_y_true".format(ctx.cur_data_split),
+ np.concatenate(ctx.get("{}_y_true".format(ctx.cur_data_split))))
+ setattr(
+ ctx, "{}_y_prob".format(ctx.cur_data_split),
+ np.concatenate(ctx.get("{}_y_prob".format(ctx.cur_data_split))))
+ results = self.metric_calculator.eval(ctx)
+ setattr(ctx, 'eval_metrics', results)
+
+ def update(self, model_parameters):
+ self.ctx.model.load_state_dict(model_parameters)
+
+ def save_model(self, path, cur_round=-1):
+ pass
+
+ def load_model(self, path):
+ pass
diff --git a/federatedscope/core/trainers/torch_trainer.py b/federatedscope/core/trainers/torch_trainer.py
new file mode 100644
index 000000000..c02ab890c
--- /dev/null
+++ b/federatedscope/core/trainers/torch_trainer.py
@@ -0,0 +1,372 @@
+import os
+import logging
+
+import numpy as np
+
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+from federatedscope.core.trainers.trainer import Trainer
+from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
+from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
+from federatedscope.core.auxiliaries.ReIterator import ReIterator
+from federatedscope.core.monitors.monitor import Monitor
+
+logger = logging.getLogger(__name__)
+
+
+class GeneralTorchTrainer(Trainer):
+    def get_model_para(self):
+        return self._param_filter(
+            self.ctx.model.state_dict() if self.cfg.federate.share_local_model
+            else self.ctx.model.cpu().state_dict())
+
+ def parse_data(self, data):
+ """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes
+
+ """
+ # TODO: more robust for different data
+
+        # support arbitrary split names given in cfg.data.dataset
+        init_dict = dict()
+        name_list = self.cfg.data.dataset
+ if isinstance(data, dict):
+ for mode in name_list:
+ init_dict["{}_data".format(mode)] = None
+ init_dict["{}_loader".format(mode)] = None
+ init_dict["num_{}_data".format(mode)] = 0
+ if data.get(mode, None) is not None:
+ if isinstance(data.get(mode), Dataset):
+ init_dict["{}_data".format(mode)] = data.get(mode)
+ init_dict["num_{}_data".format(mode)] = len(
+ data.get(mode))
+ elif isinstance(data.get(mode), DataLoader):
+ init_dict["{}_loader".format(mode)] = data.get(mode)
+ init_dict["num_{}_data".format(mode)] = len(
+ data.get(mode).dataset)
+ elif isinstance(data.get(mode), dict):
+ init_dict["{}_data".format(mode)] = data.get(mode)
+ init_dict["num_{}_data".format(mode)] = len(
+ data.get(mode)['y'])
+ else:
+ raise TypeError("Type {} is not supported.".format(
+ type(data.get(mode))))
+ else:
+ raise TypeError("Type of data should be dict.")
+
+ return init_dict
+
+ def train(self, target_data_split_name="train", hooks_set=None):
+ hooks_set = hooks_set or self.hooks_in_train
+ if self.ctx.get(
+ f"{target_data_split_name}_data") is None and self.ctx.get(
+ f"{target_data_split_name}_loader") is None:
+ raise ValueError(
+ f"No {target_data_split_name}_data or {target_data_split_name}_loader in the trainer"
+ )
+ if self.cfg.federate.use_diff:
+ # TODO: any issue for subclasses?
+ before_metric = self.evaluate(target_data_split_name='val')
+
+ self._run_routine("train", hooks_set, target_data_split_name)
+ result_metric = self.ctx.eval_metrics
+
+ if self.cfg.federate.use_diff:
+ # TODO: any issue for subclasses?
+ after_metric = self.evaluate(target_data_split_name='val')
+ result_metric['val_total'] = before_metric['val_total']
+ result_metric['val_avg_loss_before'] = before_metric[
+ 'val_avg_loss']
+ result_metric['val_avg_loss_after'] = after_metric['val_avg_loss']
+
+        # Select the weight reported to the server for aggregation: either the
+        # number of local training samples (weighted averaging) or a constant 1.
+ if self.cfg.federate.weight_avg:
+ averaging_weight = self.ctx.num_train_data
+ else:
+ averaging_weight = 1
+
+ return averaging_weight, self.get_model_para(), result_metric
+
+ def update(self, model_parameters):
+ '''
+ Called by the FL client to update the model parameters
+ Arguments:
+ model_parameters (dict): PyTorch Module object's state_dict.
+ '''
+ for key in model_parameters:
+ if isinstance(model_parameters[key], list):
+ model_parameters[key] = torch.FloatTensor(
+ model_parameters[key])
+ self.ctx.model.load_state_dict(self._param_filter(model_parameters),
+ strict=False)
+
+ def evaluate(self, target_data_split_name="test"):
+ with torch.no_grad():
+ super(GeneralTorchTrainer, self).evaluate(target_data_split_name)
+
+ return self.ctx.eval_metrics
+
+ def finetune(self, target_data_split_name="train", hooks_set=None):
+
+ # freeze the parameters during the fine-tune stage
+ require_grad_changed_paras = set()
+ if self.cfg.trainer.finetune.freeze_param != "":
+ preserved_paras = self._param_filter(
+ self.ctx.model.state_dict(),
+ self.cfg.trainer.finetune.freeze_param)
+ for name, param in self.ctx.model.named_parameters():
+ if name not in preserved_paras and param.requires_grad is True:
+ param.requires_grad = False
+ require_grad_changed_paras.add(name)
+
+ # change the optimization configs
+ original_lrs = []
+ for g in self.ctx.optimizer.param_groups:
+ original_lrs.append(g['lr'])
+ g['lr'] = self.cfg.trainer.finetune.lr
+ original_epoch_num = self.ctx["num_train_epoch"]
+ original_batch_num = self.ctx["num_train_batch"]
+ self.ctx["num_train_epoch"] = self.cfg.trainer.finetune.epochs
+ # self.ctx["num_train_batch"] = self.cfg.trainer.finetune.steps
+
+ # self.ctx["num_train_epoch"] = 1
+ # self.ctx["num_train_batch"] = self.cfg.trainer.finetune.steps
+
+ # do the fine-tuning process
+ self.train(target_data_split_name, hooks_set)
+
+ # restore the state before fine-tuning
+ if len(require_grad_changed_paras) > 0:
+ for name, param in self.ctx.model.named_parameters():
+ if name in require_grad_changed_paras:
+ param.requires_grad = True
+
+ for i, g in enumerate(self.ctx.optimizer.param_groups):
+ g['lr'] = original_lrs[i]
+
+ self.ctx["num_train_epoch"] = original_epoch_num
+ self.ctx["num_train_batch"] = original_batch_num
+
+ def register_default_hooks_train(self):
+ self.register_hook_in_train(self._hook_on_fit_start_init,
+ "on_fit_start")
+ self.register_hook_in_train(
+ self._hook_on_fit_start_calculate_model_size, "on_fit_start")
+ self.register_hook_in_train(self._hook_on_epoch_start,
+ "on_epoch_start")
+ self.register_hook_in_train(self._hook_on_batch_start_init,
+ "on_batch_start")
+ self.register_hook_in_train(self._hook_on_batch_forward,
+ "on_batch_forward")
+ self.register_hook_in_train(self._hook_on_batch_forward_regularizer,
+ "on_batch_forward")
+ self.register_hook_in_train(self._hook_on_batch_forward_flop_count,
+ "on_batch_forward")
+ self.register_hook_in_train(self._hook_on_batch_backward,
+ "on_batch_backward")
+ self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end")
+ self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end")
+
+ def register_default_hooks_eval(self):
+ # test/val
+ self.register_hook_in_eval(self._hook_on_fit_start_init,
+ "on_fit_start")
+ self.register_hook_in_eval(self._hook_on_epoch_start, "on_epoch_start")
+ self.register_hook_in_eval(self._hook_on_batch_start_init,
+ "on_batch_start")
+ self.register_hook_in_eval(self._hook_on_batch_forward,
+ "on_batch_forward")
+ self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end")
+ self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end")
+
+ def _hook_on_fit_start_init(self, ctx):
+ # prepare model
+ ctx.model.to(ctx.device)
+
+ # prepare statistics
+ setattr(ctx, "loss_batch_total_{}".format(ctx.cur_data_split), 0)
+ setattr(ctx, "loss_regular_total_{}".format(ctx.cur_data_split), 0)
+ setattr(ctx, "num_samples_{}".format(ctx.cur_data_split), 0)
+ setattr(ctx, "{}_y_true".format(ctx.cur_data_split), [])
+ setattr(ctx, "{}_y_prob".format(ctx.cur_data_split), [])
+
+ def _hook_on_fit_start_calculate_model_size(self, ctx):
+ if not isinstance(self.ctx.monitor, Monitor):
+            logger.warning(
+                f"The trainer {type(self)} does not contain a valid monitor; this may be caused by "
+                f"initializing trainer subclasses without passing a valid monitor instance. "
+                f"Please check whether this is what you want.")
+ return
+ if self.ctx.monitor.total_model_size == 0:
+ self.ctx.monitor.track_model_size(ctx.models)
+
+ def _hook_on_epoch_start(self, ctx):
+ # prepare dataloader
+ if ctx.get("{}_loader".format(ctx.cur_data_split)) is None:
+ loader = get_dataloader(
+ WrapDataset(ctx.get("{}_data".format(ctx.cur_data_split))),
+ self.cfg)
+ setattr(ctx, "{}_loader".format(ctx.cur_data_split),
+ ReIterator(loader))
+ elif not isinstance(ctx.get("{}_loader".format(ctx.cur_data_split)),
+ ReIterator):
+ setattr(
+ ctx, "{}_loader".format(ctx.cur_data_split),
+ ReIterator(ctx.get("{}_loader".format(ctx.cur_data_split))))
+ else:
+ ctx.get("{}_loader".format(ctx.cur_data_split)).reset()
+
+ def _hook_on_batch_start_init(self, ctx):
+ # prepare data batch
+ try:
+ ctx.data_batch = next(
+ ctx.get("{}_loader".format(ctx.cur_data_split)))
+ except StopIteration:
+ raise StopIteration
+
+ def _hook_on_batch_forward(self, ctx):
+ x, label = [_.to(ctx.device) for _ in ctx.data_batch]
+ pred = ctx.model(x)
+ if len(label.size()) == 0:
+ label = label.unsqueeze(0)
+ ctx.loss_batch = ctx.criterion(pred, label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+ ctx.batch_size = len(label)
+
+ def _hook_on_batch_forward_flop_count(self, ctx):
+ """
+ the monitoring hook to calculate the flops during the fl course
+
+ Note: for customized cases that the forward process is not only based on ctx.model,
+ please override this function (inheritance case) or replace this hook (plug-in case)
+
+ :param ctx:
+ :return:
+ """
+ if not isinstance(self.ctx.monitor, Monitor):
+            logger.warning(
+                f"The trainer {type(self)} does not contain a valid monitor; this may be caused by "
+                f"initializing trainer subclasses without passing a valid monitor instance. "
+                f"Please check whether this is what you want.")
+ return
+
+ if self.ctx.monitor.flops_per_sample == 0:
+ # calculate the flops_per_sample
+ try:
+ x, y = [_.to(ctx.device) for _ in ctx.data_batch]
+ from fvcore.nn import FlopCountAnalysis
+ flops_one_batch = FlopCountAnalysis(ctx.model, x).total()
+ if self.model_nums > 1 and ctx.mirrored_models:
+ flops_one_batch *= self.model_nums
+                    logger.warning(
+                        "the flops_per_batch is multiplied by the number of internal models since self.mirrored_models=True. "
+                        "If this is not what you want, please customize the count hook."
+                    )
+ self.ctx.monitor.track_avg_flops(flops_one_batch,
+ ctx.batch_size)
+            except Exception:
+                logger.error(
+                    "current flop count implementation is for the general trainer case: "
+                    "1) ctx.data_batch = [x, y]; and "
+                    "2) the ctx.model takes only x as input. "
+                    "Please check the forward format or implement your own flop_count function."
+                )
+
+ # by default, we assume the data has the same input shape,
+ # thus simply multiply the flops to avoid redundant forward
+ self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * ctx.batch_size
+
+ def _hook_on_batch_forward_regularizer(self, ctx):
+ ctx.loss_regular = float(
+ self.cfg.regularizer.mu) * ctx.regularizer(ctx)
+ ctx.loss_task = ctx.loss_batch + ctx.loss_regular
+
+ def _hook_on_batch_backward(self, ctx):
+ ctx.optimizer.zero_grad()
+ ctx.loss_task.backward()
+ if ctx.grad_clip > 0:
+ torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
+ ctx.grad_clip)
+ ctx.optimizer.step()
+
+ def _hook_on_batch_end(self, ctx):
+ # update statistics
+ setattr(
+ ctx, "loss_batch_total_{}".format(ctx.cur_data_split),
+ ctx.get("loss_batch_total_{}".format(ctx.cur_data_split)) +
+ ctx.loss_batch.item() * ctx.batch_size)
+
+ if ctx.get("loss_regular", None) is None or ctx.loss_regular == 0:
+ loss_regular = 0.
+ else:
+ loss_regular = ctx.loss_regular.item()
+ setattr(
+ ctx, "loss_regular_total_{}".format(ctx.cur_data_split),
+ ctx.get("loss_regular_total_{}".format(ctx.cur_data_split)) +
+ loss_regular)
+ setattr(
+ ctx, "num_samples_{}".format(ctx.cur_data_split),
+ ctx.get("num_samples_{}".format(ctx.cur_data_split)) +
+ ctx.batch_size)
+
+ # cache label for evaluate
+ ctx.get("{}_y_true".format(ctx.cur_data_split)).append(
+ ctx.y_true.detach().cpu().numpy())
+
+ ctx.get("{}_y_prob".format(ctx.cur_data_split)).append(
+ ctx.y_prob.detach().cpu().numpy())
+
+ # clean temp ctx
+ ctx.data_batch = None
+ ctx.batch_size = None
+ ctx.loss_task = None
+ ctx.loss_batch = None
+ ctx.loss_regular = None
+ ctx.y_true = None
+ ctx.y_prob = None
+
+ def _hook_on_fit_end(self, ctx):
+ """Evaluate metrics.
+
+ """
+ setattr(
+ ctx, "{}_y_true".format(ctx.cur_data_split),
+ np.concatenate(ctx.get("{}_y_true".format(ctx.cur_data_split))))
+ setattr(
+ ctx, "{}_y_prob".format(ctx.cur_data_split),
+ np.concatenate(ctx.get("{}_y_prob".format(ctx.cur_data_split))))
+ results = self.metric_calculator.eval(ctx)
+
+ setattr(ctx, 'eval_metrics', results)
+
+ def save_model(self, path, cur_round=-1):
+ assert self.ctx.model is not None
+
+ ckpt = {'cur_round': cur_round, 'model': self.ctx.model.state_dict()}
+ torch.save(ckpt, path)
+
+ def load_model(self, path):
+ assert self.ctx.model is not None
+
+ if os.path.exists(path):
+ ckpt = torch.load(path, map_location=self.ctx.device)
+ self.ctx.model.load_state_dict(ckpt['model'])
+ return ckpt['cur_round']
+ else:
+ raise ValueError("The file {} does NOT exist".format(path))
diff --git a/federatedscope/core/trainers/trainer.py b/federatedscope/core/trainers/trainer.py
new file mode 100644
index 000000000..f59f17536
--- /dev/null
+++ b/federatedscope/core/trainers/trainer.py
@@ -0,0 +1,343 @@
+import collections
+import copy
+import logging
+import os
+import numpy as np
+
+from federatedscope.core.auxiliaries import utils
+from federatedscope.core.trainers.context import Context
+from federatedscope.core.monitors.metric_calculator import MetricCalculator
+
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+logger = logging.getLogger(__name__)
+
+
+class Trainer(object):
+ """
+ Register, organize and run the train/test/val procedures
+ """
+
+ HOOK_TRIGGER = [
+ "on_fit_start", "on_epoch_start", "on_batch_start", "on_batch_forward",
+ "on_batch_backward", "on_batch_end", "on_epoch_end", "on_fit_end"
+ ]
+
+ def __init__(self,
+ model,
+ data,
+ device,
+ config,
+ only_for_eval=False,
+ monitor=None):
+ self.cfg = config
+ self.metric_calculator = MetricCalculator(config.eval.metrics)
+
+ self.ctx = Context(model,
+ self.cfg,
+ data,
+ device,
+ init_dict=self.parse_data(data))
+
+ if monitor is None:
+ logger.warning(
+ f"Will not use monitor in trainer with class {type(self)}")
+ self.ctx.monitor = monitor
+ # the "model_nums", and "models" are used for multi-model case and model size calculation
+ self.model_nums = 1
+ self.ctx.models = [model]
+ # "mirrored_models": whether the internal multi-models adopt the same architects and almost the same behaviors,
+ # which is used to simply the flops, model size calculation
+ self.ctx.mirrored_models = False
+
+ # Atomic operation during training/evaluation
+ self.hooks_in_train = collections.defaultdict(list)
+
+ # By default, use the same trigger keys
+ self.hooks_in_eval = copy.deepcopy(self.hooks_in_train)
+
+ # register necessary hooks into self.hooks_in_train and self.hooks_in_eval
+ if not only_for_eval:
+ self.register_default_hooks_train()
+ self.register_default_hooks_eval()
+
+ if self.cfg.federate.mode == 'distributed':
+ self.print_trainer_meta_info()
+ else:
+ # in standalone mode, by default, we print the trainer info only once for better logs readability
+ pass
+
+ def parse_data(self, data):
+ pass
+
+ def register_default_hooks_train(self):
+ pass
+
+ def register_default_hooks_eval(self):
+ pass
+
+ def reset_hook_in_train(self, target_trigger, target_hook_name=None):
+ hooks_dict = self.hooks_in_train
+ del_one_hook_idx = self._reset_hook_in_trigger(hooks_dict,
+ target_hook_name,
+ target_trigger)
+ return del_one_hook_idx
+
+ def reset_hook_in_eval(self, target_trigger, target_hook_name=None):
+ hooks_dict = self.hooks_in_eval
+ del_one_hook_idx = self._reset_hook_in_trigger(hooks_dict,
+ target_hook_name,
+ target_trigger)
+ return del_one_hook_idx
+
+ def replace_hook_in_train(self, new_hook, target_trigger,
+ target_hook_name):
+ del_one_hook_idx = self.reset_hook_in_train(
+ target_trigger=target_trigger, target_hook_name=target_hook_name)
+ self.register_hook_in_train(new_hook=new_hook,
+ trigger=target_trigger,
+ insert_pos=del_one_hook_idx)
+
+ def replace_hook_in_eval(self, new_hook, target_trigger, target_hook_name):
+ del_one_hook_idx = self.reset_hook_in_eval(
+ target_trigger=target_trigger, target_hook_name=target_hook_name)
+ self.register_hook_in_eval(new_hook=new_hook,
+ trigger=target_trigger,
+ insert_pos=del_one_hook_idx)
+
+ def _reset_hook_in_trigger(self, hooks_dict, target_hook_name,
+ target_trigger):
+ # clean/delete existing hooks for a specific trigger,
+ # if target_hook_name given, will clean only the specific one; otherwise, will clean all hooks for the trigger.
+ assert target_trigger in self.HOOK_TRIGGER, \
+ f"Got {target_trigger} as hook trigger, you should specify a string within {self.HOOK_TRIGGER}."
+ del_one_hook_idx = None
+ if target_hook_name is None:
+ hooks_dict[target_trigger] = []
+ del_one_hook_idx = -1 # -1 indicates del the whole list
+ else:
+ for hook_idx in range(len(hooks_dict[target_trigger])):
+ if target_hook_name == hooks_dict[target_trigger][
+ hook_idx].__name__:
+ del_one = hooks_dict[target_trigger].pop(hook_idx)
+ logger.info(
+ f"Remove the hook `{del_one}` from hooks_set at trigger `{target_trigger}`"
+ )
+ del_one_hook_idx = hook_idx
+ break
+ if del_one_hook_idx is None:
+ logger.warning(
+ f"In hook del procedure, can't find the target hook named {target_hook_name}"
+ )
+ return del_one_hook_idx
+
+ def register_hook_in_train(self,
+ new_hook,
+ trigger,
+ insert_pos=None,
+ base_hook=None,
+ insert_mode="before"):
+ hooks_dict = self.hooks_in_train
+ self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
+ new_hook, trigger)
+
+ def register_hook_in_eval(self,
+ new_hook,
+ trigger,
+ insert_pos=None,
+ base_hook=None,
+ insert_mode="before"):
+ hooks_dict = self.hooks_in_eval
+ self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
+ new_hook, trigger)
+
+ def _register_hook(self, base_hook, hooks_dict, insert_mode, insert_pos,
+ new_hook, trigger):
+ assert trigger in self.HOOK_TRIGGER, \
+ f"Got {trigger} as hook trigger, you should specify a string within {self.HOOK_TRIGGER}."
+ # parse the insertion position
+ target_hook_set = hooks_dict[trigger]
+ if insert_pos is not None:
+ assert (insert_pos == -1) or (insert_pos == len(target_hook_set) == 0) or \
+ (0 <= insert_pos <= (len(target_hook_set))), \
+ f"Got {insert_pos} as insert pos, you should specify a integer (1) =-1 " \
+ f"or (2) =0 for null target_hook_set;" \
+ f"or (3) within [0, {(len(target_hook_set))}]."
+ elif base_hook is not None:
+ base_hook_pos = target_hook_set.index(base_hook)
+ insert_pos = base_hook_pos - 1 if insert_mode == "before" else base_hook_pos + 1
+ # bounding the insert_pos in rational range
+ insert_pos = 0 if insert_pos < 0 else insert_pos
+ insert_pos = -1 if insert_pos > len(
+ target_hook_set) else insert_pos
+ else:
+ insert_pos = -1 # By default, the new hook is called finally
+ # register the new hook
+ if insert_pos == -1:
+ hooks_dict[trigger].append(new_hook)
+ else:
+ hooks_dict[trigger].insert(insert_pos, new_hook)
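To illustrate the registration semantics, a hedged sketch against a trainer instance (`trainer` and `log_start` are placeholders):

    def log_start(ctx):
        print('fit starts on split:', ctx.cur_data_split)

    # run first in the trigger list
    trainer.register_hook_in_eval(log_start, trigger='on_fit_start', insert_pos=0)
    # insert relative to an already-registered hook
    trainer.register_hook_in_eval(log_start, trigger='on_fit_start',
                                  base_hook=log_start, insert_mode='after')
    # remove a single hook by name (or all hooks of the trigger if name is None)
    trainer.reset_hook_in_eval(target_trigger='on_fit_start',
                               target_hook_name='log_start')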
+
+ def train(self, target_data_split_name="train", hooks_set=None):
+ pass
+
+ def evaluate(self, target_data_split_name="test", hooks_set=None):
+ hooks_set = hooks_set or self.hooks_in_eval
+ if self.ctx.get(
+ f"{target_data_split_name}_data") is None and self.ctx.get(
+ f"{target_data_split_name}_loader") is None:
+            logger.warning(
+                f"No {target_data_split_name}_data or {target_data_split_name}_loader in the trainer, will skip evaluation. "
+                f"If this is not what you want, please check whether there is a typo in the name."
+            )
+ self.ctx.eval_metrics = {}
+ else:
+ self._run_routine("test", hooks_set, target_data_split_name)
+
+ return self.ctx.eval_metrics
+
+ def _run_routine(self, mode, hooks_set, dataset_name=None):
+ """Run the hooks_set and maintain the mode
+
+ Arguments:
+ mode (str): running mode of client, chosen from train/test
+ hooks_set (dict): functions to be executed.
+ dataset_name (str): which split.
+
+ Note:
+ Considering evaluation could be in ```hooks_set["on_epoch_end"]```, there could be two data loaders in
+ self.ctx, we must tell the running hooks which data_loader to call and which num_samples to count
+
+ """
+ if dataset_name is None:
+ dataset_name = mode
+ self.ctx.append_mode(mode)
+ self.ctx.track_used_dataset(dataset_name)
+
+ for hook in hooks_set["on_fit_start"]:
+ hook(self.ctx)
+
+ for epoch_i in range(self.ctx.get(
+ "num_{}_epoch".format(dataset_name))):
+ self.ctx.cur_epoch_i = epoch_i
+ for hook in hooks_set["on_epoch_start"]:
+ hook(self.ctx)
+
+ for batch_i in range(
+ self.ctx.get("num_{}_batch".format(dataset_name))):
+ self.ctx.cur_batch_i = batch_i
+ for hook in hooks_set["on_batch_start"]:
+ hook(self.ctx)
+ for hook in hooks_set["on_batch_forward"]:
+ hook(self.ctx)
+ if self.ctx.cur_mode == 'train':
+ for hook in hooks_set["on_batch_backward"]:
+ hook(self.ctx)
+ for hook in hooks_set["on_batch_end"]:
+ hook(self.ctx)
+
+ # Break in the final epoch
+ if self.ctx.cur_mode == 'train' and epoch_i == self.ctx.num_train_epoch - 1:
+ if batch_i >= self.ctx.num_train_batch_last_epoch - 1:
+ break
+
+ for hook in hooks_set["on_epoch_end"]:
+ hook(self.ctx)
+ for hook in hooks_set["on_fit_end"]:
+ hook(self.ctx)
+
+ self.ctx.pop_mode()
+ self.ctx.reset_used_dataset()
+ # Avoid memory leak
+        if not self.cfg.federate.share_local_model:
+            self.ctx.model.to(torch.device("cpu"))
+
+ def update(self, model_parameters):
+ '''
+ Called by the FL client to update the model parameters
+ Arguments:
+ model_parameters (dict): {model_name: model_val}
+ '''
+ pass
+
+ def get_model_para(self):
+ '''
+
+ :return: model_parameters (dict): {model_name: model_val}
+ '''
+ pass
+
+ def print_trainer_meta_info(self):
+ '''
+        print some meta info for developers, e.g., the model type and which parameter names are filtered out in local updates
+ '''
+ logger.info(f"Model meta-info: {type(self.ctx.model)}.")
+ logger.debug(f"Model meta-info: {self.ctx.model}.")
+ # logger.info(f"Data meta-info: {self.ctx['data']}.")
+
+ ori_para_names = set(self.ctx.model.state_dict().keys())
+ preserved_paras = self._param_filter(self.ctx.model.state_dict())
+ preserved_para_names = set(preserved_paras.keys())
+ filtered_para_names = ori_para_names - preserved_para_names
+ logger.info(f"Num of original para names: {len(ori_para_names)}.")
+ logger.info(
+ f"Num of original trainable para names: {len(self.ctx['trainable_para_names'])}."
+ )
+ logger.info(
+ f"Num of preserved para names in local update: {len(preserved_para_names)}. \n"
+ f"Preserved para names in local update: {preserved_para_names}.")
+ logger.info(
+ f"Num of filtered para names in local update: {len(filtered_para_names)}. \n"
+ f"Filtered para names in local update: {filtered_para_names}.")
+
+ logger.info(f"After register default hooks,\n"
+ f"\tthe hooks_in_train is: {self.hooks_in_train};\n"
+ f"\tthe hooks_in_eval is {self.hooks_in_eval}")
+
+ def finetune(self):
+ pass
+
+ def _param_filter(self, state_dict, filter_keywords=None):
+ '''
+        model parameter filter for transmission between local and global, which is useful in personalization.
+ e.g., setting cfg.personalization.local_param= ['bn', 'norms'] indicates the implementation of
+ "FedBN: Federated Learning on Non-IID Features via Local Batch Normalization, ICML2021", which can be found in https://openreview.net/forum?id=6YEQUn0QICG
+
+ Arguments:
+ state_dict (dict): PyTorch Module object's state_dict.
+ Returns:
+ state_dict (dict): remove the keys that match any of the given keywords.
+ '''
+ if self.cfg.federate.method in ["local", "global"]:
+ return {}
+
+ if filter_keywords is None:
+ filter_keywords = self.cfg.personalization.local_param
+
+        # NOTE: select between two predicates; the former one-line conditional
+        # (`lambda p: True if ... else lambda ...`) returned a truthy lambda in
+        # the non-sharing branch and thus never filtered anything
+        trainable_filter = (lambda p: True) \
+            if self.cfg.personalization.share_non_trainable_para \
+            else (lambda p: p in self.ctx.trainable_para_names)
+        keyword_filter = utils.filter_by_specified_keywords
+        return dict(
+            filter(
+                lambda elem: trainable_filter(elem[1]) and keyword_filter(
+                    elem[0], filter_keywords), state_dict.items()))
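As a concrete illustration of the keyword part of the filtering (the trainable filter is omitted), with a hedged stand-in for `utils.filter_by_specified_keywords`, which is assumed to keep names that match none of the keywords:

    # FedBN-style personalization: cfg.personalization.local_param = ['bn']
    state_dict = {'conv1.weight': 1, 'bn1.weight': 2, 'bn1.bias': 3, 'fc.weight': 4}
    keep = lambda name, kws: all(kw not in name for kw in kws)  # assumed semantics
    shared = {k: v for k, v in state_dict.items() if keep(k, ['bn'])}
    # shared == {'conv1.weight': 1, 'fc.weight': 4}; the 'bn*' entries stay local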
+
+ def save_model(self, path, cur_round=-1):
+ raise NotImplementedError(
+ "The function `save_model` should be implemented according to the ML backend (Pytorch, Tensorflow ...)."
+ )
+
+ def load_model(self, path):
+ raise NotImplementedError(
+ "The function `load_model` should be implemented according to the ML backend (Pytorch, Tensorflow ...)."
+ )
diff --git a/federatedscope/core/trainers/trainer_Ditto.py b/federatedscope/core/trainers/trainer_Ditto.py
new file mode 100644
index 000000000..aefbcadcd
--- /dev/null
+++ b/federatedscope/core/trainers/trainer_Ditto.py
@@ -0,0 +1,196 @@
+import copy
+import logging
+
+import torch
+
+from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
+from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
+from federatedscope.core.optimizer import wrap_regularized_optimizer
+# from federatedscope.core.auxiliaries.utils import calculate_batch_epoch_num
+from typing import Type
+
+logger = logging.getLogger(__name__)
+
+DEBUG_DITTO = False
+
+
+def wrap_DittoTrainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ """
+ Build a `DittoTrainer` with a plug-in manner, by registering new
+ functions into specific `BaseTrainer`
+
+ The Ditto implementation, "Ditto: Fair and Robust Federated Learning
+ Through Personalization. (ICML2021)"
+ based on the Algorithm 2 in their paper and official codes:
+ https://github.com/litian96/ditto
+ """
+
+ # ---------------- attribute-level plug-in -----------------------
+ init_Ditto_ctx(base_trainer)
+
+ # ---------------- action-level plug-in -----------------------
+ base_trainer.register_hook_in_train(new_hook=_hook_on_fit_start_clean,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_fit_start_set_regularized_para,
+ trigger="on_fit_start",
+ insert_pos=0)
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_batch_start_switch_model,
+ trigger="on_batch_start",
+ insert_pos=0)
+ base_trainer.register_hook_in_train(new_hook=hook_on_batch_forward_cnt_num,
+ trigger="on_batch_forward",
+ insert_pos=-1)
+ base_trainer.register_hook_in_train(new_hook=_hook_on_batch_end_flop_count,
+ trigger="on_batch_end",
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(
+ new_hook=hook_on_fit_start_switch_local_model,
+ trigger="on_fit_start",
+ insert_pos=0)
+
+ base_trainer.register_hook_in_eval(
+ new_hook=hook_on_fit_end_switch_global_model,
+ trigger="on_fit_end",
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_free_cuda,
+ trigger="on_fit_end",
+ insert_pos=-1)
+ base_trainer.register_hook_in_eval(new_hook=hook_on_fit_end_free_cuda,
+ trigger="on_fit_end",
+ insert_pos=-1)
+
+ return base_trainer
+
+
+def init_Ditto_ctx(base_trainer):
+ """
+ init necessary attributes used in Ditto,
+ `global_model` acts as the shared global model in FedAvg;
+ `local_model` acts as personalized model will be optimized with
+ regularization based on weights of `global_model`
+
+ """
+ ctx = base_trainer.ctx
+ cfg = base_trainer.cfg
+
+ ctx.global_model = copy.deepcopy(ctx.model)
+ ctx.local_model = copy.deepcopy(ctx.model)
+ ctx.models = [ctx.local_model, ctx.global_model]
+
+ ctx.model = ctx.global_model
+ ctx.use_local_model_current = False
+
+ ctx.num_samples_local_model_train = 0
+
+ ctx.num_train_batch_for_local_model, \
+ ctx.num_train_batch_last_epoch_for_local_model, \
+ ctx.num_train_epoch_for_local_model, \
+ ctx.num_total_train_batch \
+ = ctx.pre_calculate_batch_epoch_num \
+ (cfg.personalization.local_update_steps)
+
+ if cfg.federate.batch_or_epoch == 'batch':
+ ctx.num_train_batch += ctx.num_train_batch_for_local_model
+ ctx.num_train_batch_last_epoch += \
+ ctx.num_train_batch_last_epoch_for_local_model
+ else:
+ ctx.num_train_epoch += ctx.num_train_epoch_for_local_model
+
+
+def hook_on_fit_start_set_regularized_para(ctx):
+ # set the compared model data for local personalized model
+ ctx.global_model.to(ctx.device)
+ ctx.local_model.to(ctx.device)
+ ctx.global_model.train()
+ ctx.local_model.train()
+ compared_global_model_para = [{
+ "params": list(ctx.global_model.parameters())
+ }]
+
+ ctx.optimizer_for_global_model = get_optimizer(ctx.global_model,
+ **ctx.cfg.optimizer)
+ ctx.optimizer_for_local_model = get_optimizer(ctx.local_model,
+ **ctx.cfg.optimizer)
+
+ ctx.optimizer_for_local_model = \
+ wrap_regularized_optimizer(ctx.optimizer_for_local_model, \
+ ctx.cfg.personalization.regular_weight)
+
+ ctx.optimizer_for_local_model.set_compared_para_group(
+ compared_global_model_para)
+
+
+def _hook_on_fit_start_clean(ctx):
+
+ del ctx.optimizer
+ ctx.num_samples_local_model_train = 0
+
+
+def _hook_on_batch_end_flop_count(ctx):
+
+ ctx.monitor.total_flops += ctx.monitor.total_model_size / 2
+
+
+def hook_on_batch_forward_cnt_num(ctx):
+ if ctx.use_local_model_current:
+ ctx.num_samples_local_model_train += ctx.batch_size
+
+
+def hook_on_batch_start_switch_model(ctx):
+ if ctx.cfg.federate.batch_or_epoch == 'batch':
+ if ctx.cur_epoch_i == (ctx.num_train_epoch - 1):
+ ctx.use_local_model_current = \
+ ctx.cur_batch_i < \
+ ctx.num_train_batch_last_epoch_for_local_model
+ else:
+ ctx.use_local_model_current = \
+ ctx.cur_batch_i < ctx.num_train_batch_for_local_model
+ else:
+ ctx.use_local_model_current = \
+ ctx.cur_epoch_i < ctx.num_train_epoch_for_local_model
+
+ if DEBUG_DITTO:
+ logger.info("====================================================")
+ logger.info(f"cur_epoch_i: {ctx.cur_epoch_i}")
+ logger.info(f"num_train_epoch: {ctx.num_train_epoch}")
+ logger.info(f"cur_batch_i: {ctx.cur_batch_i}")
+ logger.info(f"num_train_batch: {ctx.num_train_batch}")
+ logger.info(f"num_train_batch_for_local_model: "
+ f"{ctx.num_train_batch_for_local_model}")
+ logger.info(f"num_train_epoch_for_local_model: "
+ f"{ctx.num_train_epoch_for_local_model}")
+ logger.info(f"use_local_model: {ctx.use_local_model_current}")
+
+ if ctx.use_local_model_current:
+ ctx.model = ctx.local_model
+ ctx.optimizer = ctx.optimizer_for_local_model
+ else:
+ ctx.model = ctx.global_model
+ ctx.optimizer = ctx.optimizer_for_global_model
+
+
+def hook_on_fit_start_switch_local_model(ctx):
+ ctx.model = ctx.local_model
+ ctx.model.eval()
+
+
+def hook_on_fit_end_switch_global_model(ctx):
+ ctx.model = ctx.global_model
+
+
+def hook_on_fit_end_free_cuda(ctx):
+ ctx.global_model.to(torch.device("cpu"))
+ ctx.local_model.to(torch.device("cpu"))
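A worked illustration of the switching schedule above (batch mode; the numbers are arbitrary, with 7 batches per dataset epoch):

    # federate.local_update_steps        = 7 -> global model: 1 epoch, 7 batches
    # personalization.local_update_steps = 4 -> local model: 4 batches
    # after init_Ditto_ctx (single, hence final, epoch):
    #   num_train_batch_last_epoch = 7 + 4 = 11
    # hook_on_batch_start_switch_model then assigns batches 0..3 to the
    # personalized local model and batches 4..10 to the shared global model,
    # so both models are updated within one training routine.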
diff --git a/federatedscope/core/trainers/trainer_FedEM.py b/federatedscope/core/trainers/trainer_FedEM.py
new file mode 100644
index 000000000..d120045f7
--- /dev/null
+++ b/federatedscope/core/trainers/trainer_FedEM.py
@@ -0,0 +1,176 @@
+from typing import Type
+import os
+import numpy as np
+import torch
+from torch.nn.functional import softmax as f_softmax
+
+from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
+from federatedscope.core.trainers.trainer_multi_model import GeneralMultiModelTrainer
+
+
+class FedEMTrainer(GeneralMultiModelTrainer):
+ """
+ The FedEM implementation, "Federated Multi-Task Learning under a Mixture of Distributions (NeurIPS 2021)"
+ based on the Algorithm 1 in their paper and official codes: https://github.com/omarfoq/FedEM
+ """
+ def __init__(self,
+ model_nums,
+ models_interact_mode="sequential",
+ model=None,
+ data=None,
+ device=None,
+ config=None,
+ base_trainer: Type[GeneralTorchTrainer] = None):
+ super(FedEMTrainer,
+ self).__init__(model_nums, models_interact_mode, model, data,
+ device, config, base_trainer)
+ device = self.ctx.device
+
+ # ---------------- attribute-level modifications -----------------------
+ # used to mixture the internal models
+ self.weights_internal_models = (torch.ones(self.model_nums) /
+ self.model_nums).to(device)
+ self.weights_data_sample = (
+ torch.ones(self.model_nums, self.ctx.num_train_batch) /
+ self.model_nums).to(device)
+
+ self.ctx.all_losses_model_batch = torch.zeros(
+ self.model_nums, self.ctx.num_train_batch).to(device)
+ self.ctx.cur_batch_idx = -1
+ # `ctx[f"{cur_data}_y_prob_ensemble"] = 0` in func `_hook_on_fit_end_ensemble_eval`
+ # -> self.ctx.test_y_prob_ensemble = 0
+ # -> self.ctx.train_y_prob_ensemble = 0
+ # -> self.ctx.val_y_prob_ensemble = 0
+
+ # ---------------- action-level modifications -----------------------
+ # see register_multiple_model_hooks(), which is called in the __init__ of `GeneralMultiModelTrainer`
+
+ def register_multiple_model_hooks(self):
+ """
+ customized multiple_model_hooks, which is called in the __init__ of `GeneralMultiModelTrainer`
+ """
+ # First register hooks for model 0
+ # ---------------- train hooks -----------------------
+ self.register_hook_in_train(
+ new_hook=self.hook_on_fit_start_mixture_weights_update,
+ trigger="on_fit_start",
+ insert_pos=0) # insert at the front
+ self.register_hook_in_train(
+ new_hook=self._hook_on_fit_start_flop_count,
+ trigger="on_fit_start",
+ insert_pos=1 # follow the mixture operation
+ )
+ self.register_hook_in_train(new_hook=self._hook_on_fit_end_flop_count,
+ trigger="on_fit_end",
+ insert_pos=-1)
+ self.register_hook_in_train(
+ new_hook=self.hook_on_batch_forward_weighted_loss,
+ trigger="on_batch_forward",
+ insert_pos=-1)
+ self.register_hook_in_train(
+ new_hook=self.hook_on_batch_start_track_batch_idx,
+ trigger="on_batch_start",
+ insert_pos=0) # insert at the front
+ # ---------------- eval hooks -----------------------
+ self.register_hook_in_eval(new_hook=self.save_local_model,
+ trigger="on_fit_start",
+ insert_pos=-1)
+ self.register_hook_in_eval(
+ new_hook=self.hook_on_batch_end_gather_loss,
+ trigger="on_batch_end",
+ insert_pos=0
+ ) # insert at the front, (we need gather the loss before clean it)
+ self.register_hook_in_eval(
+ new_hook=self.hook_on_batch_start_track_batch_idx,
+ trigger="on_batch_start",
+ insert_pos=0) # insert at the front
+ # replace the original evaluation into the ensemble one
+ self.replace_hook_in_eval(new_hook=self._hook_on_fit_end_ensemble_eval,
+ target_trigger="on_fit_end",
+ target_hook_name="_hook_on_fit_end")
+
+ # Then for other models, set the same hooks as model 0
+ # since we differentiate different models in the hook implementations via ctx.cur_model_idx
+ self.hooks_in_train_multiple_models.extend([
+ self.hooks_in_train_multiple_models[0]
+ for _ in range(1, self.model_nums)
+ ])
+ self.hooks_in_eval_multiple_models.extend([
+ self.hooks_in_eval_multiple_models[0]
+ for _ in range(1, self.model_nums)
+ ])
+
+ def hook_on_batch_start_track_batch_idx(self, ctx):
+ # for both train & eval
+ ctx.cur_batch_idx = (self.ctx.cur_batch_idx +
+ 1) % self.ctx.num_train_batch
+
+ def hook_on_batch_forward_weighted_loss(self, ctx):
+ # for only train
+ ctx.loss_batch *= self.weights_internal_models[ctx.cur_model_idx]
+
+ def hook_on_batch_end_gather_loss(self, ctx):
+ # for only eval
+ # before clean the loss_batch; we record it for further weights_data_sample update
+ ctx.all_losses_model_batch[ctx.cur_model_idx][
+ ctx.cur_batch_idx] = ctx.loss_batch.item()
+
+    def hook_on_fit_start_mixture_weights_update(self, ctx):
+        # for only train
+        if ctx.cur_model_idx != 0:
+            # the mixture weights are updated only once per fit, triggered by model 0
+            pass
+        else:
+            # gather losses over all samples in the iterator for each internal model by calling *evaluate()*
+ for model_idx in range(self.model_nums):
+ self._switch_model_ctx(model_idx)
+ self.evaluate(target_data_split_name="train")
+
+ self.weights_data_sample = f_softmax(
+ (torch.log(self.weights_internal_models) -
+ ctx.all_losses_model_batch.T),
+ dim=1).T
+ self.weights_internal_models = self.weights_data_sample.mean(dim=1)
+
+ # restore the model_ctx
+ self._switch_model_ctx(0)
+
+ def _hook_on_fit_start_flop_count(self, ctx):
+ self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * self.model_nums * ctx.num_train_data
+
+ def _hook_on_fit_end_flop_count(self, ctx):
+ self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * self.model_nums * ctx.num_train_data
+
+ def _hook_on_fit_end_ensemble_eval(self, ctx):
+ """
+ Ensemble evaluation
+ """
+ cur_data = ctx.cur_data_split
+ if f"{cur_data}_y_prob_ensemble" not in ctx:
+ ctx[f"{cur_data}_y_prob_ensemble"] = 0
+ ctx[f"{cur_data}_y_prob_ensemble"] += np.concatenate(ctx[f"{cur_data}_y_prob"]) * \
+ self.weights_internal_models[ctx.cur_model_idx].item()
+
+ # do metrics calculation after the last internal model evaluation done
+ if ctx.cur_model_idx == self.model_nums - 1:
+ ctx[f"{cur_data}_y_true"] = np.concatenate(
+ ctx[f"{cur_data}_y_true"])
+ ctx[f"{cur_data}_y_prob"] = ctx[f"{cur_data}_y_prob_ensemble"]
+ ctx.eval_metrics = self.metric_calculator.eval(ctx)
+ # reset for next run_routine that may have different len([f"{cur_data}_y_prob"])
+ ctx[f"{cur_data}_y_prob_ensemble"] = 0
+
+ ctx[f"{cur_data}_y_prob"] = []
+ ctx[f"{cur_data}_y_true"] = []
+
+ def save_local_model(self, ctx):
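+        # NOTE: this checkpoint directory is hard-coded for one particular
+        # experiment; making it configurable (e.g., derived from the config)
+        # is left as a TODO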
+ i = 0
+ for model_ in self.ctx.models:
+ path = '/mnt/zeyuqin/FederatedScope/exp/FedEM_resnet18_on_CIFAR10@torchvision_lr0.5_lepoch1/backdoor_hkTrigger_fix' + '/model_' + str(
+ i) + '.pth'
+ if os.path.exists(path):
+ break
+ else:
+ ckpt = {'model': model_.state_dict()}
+ torch.save(ckpt, path)
+ i += 1
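The mixture-weight update above is an EM-style E-step: per batch it computes a posterior over internal models, q ∝ w · exp(−loss). A standalone sketch of the same computation (shapes and values are illustrative):

    import torch
    from torch.nn.functional import softmax

    M, B = 3, 5                  # internal models, train batches
    w = torch.ones(M) / M        # current mixture weights
    losses = torch.rand(M, B)    # stands in for ctx.all_losses_model_batch

    q = softmax(torch.log(w) - losses.T, dim=1).T   # posterior, shape (M, B)
    w_new = q.mean(dim=1)                           # updated mixture weights
    assert torch.allclose(w_new.sum(), torch.tensor(1.0))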
diff --git a/federatedscope/core/trainers/trainer_FedRep.py b/federatedscope/core/trainers/trainer_FedRep.py
new file mode 100644
index 000000000..d7afa2f9e
--- /dev/null
+++ b/federatedscope/core/trainers/trainer_FedRep.py
@@ -0,0 +1,103 @@
+import copy
+import torch
+import logging
+
+from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
+from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
+
+from typing import Type
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_FedRepTrainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ """
+ Build a `FedRapTrainer` with a plug-in manner, by registering new functions into specific `BaseTrainer`
+
+
+ """
+
+ init_FedRep_ctx(base_trainer)
+
+ base_trainer.register_hook_in_train(new_hook=hook_on_fit_start_fedrep,
+ trigger="on_fit_start",
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(new_hook=hook_on_epoch_start_fedrep,
+ trigger="on_epoch_start",
+ insert_pos=-1)
+
+ return base_trainer
+
+
+def init_FedRep_ctx(base_trainer):
+
+ ctx = base_trainer.ctx
+ cfg = base_trainer.cfg
+
+ ctx.epoch_feature = cfg.personalization.epoch_feature
+ ctx.epoch_linear = cfg.personalization.epoch_linear
+
+ ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear
+
+ ctx.epoch_number = 0
+
+ ctx.lr_feature = cfg.personalization.lr_feature
+ ctx.lr_linear = cfg.personalization.lr_linear
+ ctx.weight_decay = cfg.personalization.weight_decay
+
+ ctx.local_param = cfg.personalization.local_param
+
+ ctx.local_update_param = []
+ ctx.global_update_param = []
+
+ for name, param in ctx.model.named_parameters():
+ if name.split(".")[0] in ctx.local_param:
+ ctx.local_update_param.append(param)
+ else:
+ ctx.global_update_param.append(param)
+
+ del ctx.optimizer
+
+
+def hook_on_fit_start_fedrep(ctx):
+
+ ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear
+ ctx.epoch_number = 0
+
+ ctx.optimizer_for_feature = torch.optim.SGD(ctx.global_update_param,
+ lr=ctx.lr_feature,
+ momentum=0,
+ weight_decay=ctx.weight_decay)
+ ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param,
+ lr=ctx.lr_linear,
+ momentum=0,
+ weight_decay=ctx.weight_decay)
+
+ for name, param in ctx.model.named_parameters():
+
+ if name.split(".")[0] in ctx.local_param:
+ param.requires_grad = True
+ else:
+ param.requires_grad = False
+
+ ctx.optimizer = ctx.optimizer_for_linear
+
+
+def hook_on_epoch_start_fedrep(ctx):
+
+ ctx.epoch_number += 1
+
+ if ctx.epoch_number == ctx.epoch_linear + 1:
+
+ for name, param in ctx.model.named_parameters():
+
+ if name.split(".")[0] in ctx.local_param:
+ param.requires_grad = False
+ else:
+ param.requires_grad = True
+
+ ctx.optimizer = ctx.optimizer_for_feature
+    logger.info('the linear classifier learning rate: {}'.format(ctx.lr_linear))
+    logger.info('the feature extractor learning rate: {}'.format(ctx.lr_feature))
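A worked illustration of the alternating schedule the two hooks implement (epoch counts arbitrary):

    # personalization.epoch_linear = 2, personalization.epoch_feature = 3
    # => num_train_epoch = 5, and per fit:
    #   epochs 1-2: only the head (names in cfg.personalization.local_param)
    #               has requires_grad=True; optimizer_for_linear at lr_linear
    #   epochs 3-5: requires_grad flips to the feature extractor;
    #               optimizer_for_feature at lr_feature
    # hook_on_epoch_start_fedrep performs the switch when epoch_number
    # reaches epoch_linear + 1.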
diff --git a/federatedscope/core/trainers/trainer_fedprox.py b/federatedscope/core/trainers/trainer_fedprox.py
new file mode 100644
index 000000000..7857fabc5
--- /dev/null
+++ b/federatedscope/core/trainers/trainer_fedprox.py
@@ -0,0 +1,70 @@
+from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
+from typing import Type
+from copy import deepcopy
+
+
+def wrap_fedprox_trainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ """Implementation of fedprox refer to `Federated Optimization in Heterogeneous Networks` [Tian Li, et al., 2020]
+ (https://proceedings.mlsys.org/paper/2020/file/38af86134b65d0f10fe33d30dd76442e-Paper.pdf)
+
+ """
+
+ # ---------------- attribute-level plug-in -----------------------
+ init_fedprox_ctx(base_trainer)
+
+ # ---------------- action-level plug-in -----------------------
+ base_trainer.register_hook_in_train(new_hook=record_initialization,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(new_hook=record_initialization,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(new_hook=del_initialization,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(new_hook=del_initialization,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ return base_trainer
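+
+
+# A minimal usage sketch (illustrative only; `MyNet`, `data` and `cfg` are
+# assumed placeholders, with `cfg.fedprox.mu` set):
+#
+#   base = GeneralTorchTrainer(model=MyNet(), data=data, device='cpu',
+#                              config=cfg)
+#   fedprox_trainer = wrap_fedprox_trainer(base)
+#   sample_size, model_para, results = fedprox_trainer.train()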
+
+
+def init_fedprox_ctx(base_trainer):
+ """Set proximal regularizer and the factor of regularizer
+
+ """
+ ctx = base_trainer.ctx
+ cfg = base_trainer.cfg
+
+ cfg.defrost()
+ cfg.regularizer.type = 'proximal_regularizer'
+ cfg.regularizer.mu = cfg.fedprox.mu
+ cfg.freeze()
+
+ from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer
+ ctx.regularizer = get_regularizer(cfg.regularizer.type)
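+
+    # Conceptually, the proximal regularizer augments the local objective as
+    #   h_k(w; w^t) = F_k(w) + (mu / 2) * ||w - w^t||^2,
+    # where w^t is the global model received at round t (recorded by
+    # `record_initialization` below as ctx.weight_init)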
+
+
+# ------------------------------------------------------------------------ #
+# Additional functions for FedProx algorithm
+# ------------------------------------------------------------------------ #
+
+
+# Trainer
+def record_initialization(ctx):
+ """Record the initialized weights within local updates
+
+ """
+ ctx.weight_init = deepcopy(
+ [_.data.detach() for _ in ctx.model.parameters()])
+
+
+def del_initialization(ctx):
+ """Clear the variable to avoid memory leakage
+
+ """
+ ctx.weight_init = None
diff --git a/federatedscope/core/trainers/trainer_multi_model.py b/federatedscope/core/trainers/trainer_multi_model.py
new file mode 100644
index 000000000..c1575b119
--- /dev/null
+++ b/federatedscope/core/trainers/trainer_multi_model.py
@@ -0,0 +1,268 @@
+import copy
+from types import FunctionType
+from typing import Type
+
+from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
+from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
+
+
+class GeneralMultiModelTrainer(GeneralTorchTrainer):
+ def __init__(self,
+ model_nums,
+ models_interact_mode="sequential",
+ model=None,
+ data=None,
+ device=None,
+ config=None,
+ base_trainer: Type[GeneralTorchTrainer] = None):
+ """
+ `GeneralMultiModelTrainer` supports train/eval via multiple internal models
+
+ Arguments:
+ model_nums (int): how many internal models and optimizers will be held by the trainer
+ models_interact_mode (str): how the models interact, can be "sequential" or "parallel".
+ model: training model
+ data: a dict contains train/val/test data
+ device: device to run
+ config: for trainer-related configuration
+            base_trainer: if given, the `GeneralMultiModelTrainer` will be initialized based on a copy of base_trainer
+
+ The sequential mode indicates the interaction at run_routine level
+        [one model runs its whole routine, then does some interaction, then the next model runs its whole routine]
+ ... -> run_routine_model_i
+ -> _switch_model_ctx
+ -> (on_fit_end, _interact_to_other_models)
+ -> run_routine_model_i+1
+ -> ...
+
+ The parallel mode indicates the interaction at point-in-time level
+        [At a specific point-in-time, one model calls hooks (including interaction), then the next model calls hooks]
+ ... -> (on_xxx_point, hook_xxx_model_i)
+ -> (on_xxx_point, _interact_to_other_models)
+ -> (on_xxx_point, _switch_model_ctx)
+ -> (on_xxx_point, hook_xxx_model_i+1)
+ -> ...
+
+ """
+ # support two initialization methods for the `GeneralMultiModelTrainer`
+ # 1) from another trainer; or 2) standard init manner given (model, data, device, config)
+ if base_trainer is None:
+ assert model is not None and \
+ data is not None and \
+ device is not None and \
+ config is not None, "when not copy construction, (model, data, device, config) should not be None"
+ super(GeneralMultiModelTrainer,
+ self).__init__(model, data, device, config)
+ else:
+ assert isinstance(base_trainer, GeneralMultiModelTrainer) or \
+ issubclass(type(base_trainer), GeneralMultiModelTrainer) or \
+ isinstance(base_trainer, GeneralTorchTrainer) or \
+                   issubclass(type(base_trainer), GeneralTorchTrainer), \
+ "can only copy instances of `GeneralMultiModelTrainer` and its subclasses, or " \
+ "`GeneralTorchTrainer` and its subclasses"
+ self.__dict__ = copy.deepcopy(base_trainer.__dict__)
+
+ assert models_interact_mode in ["sequential", "parallel"], \
+ f"Invalid models_interact_mode, should be `sequential` or `parallel`, but got {models_interact_mode}"
+ self.models_interact_mode = models_interact_mode
+
+ if int(model_nums) != model_nums or model_nums < 1:
+ raise ValueError(
+ f"model_nums should be integer and >= 1, got {model_nums}.")
+ self.model_nums = model_nums
+
+ self.ctx.cur_model_idx = 0 # used to mark cur model
+
+ # different internal models can have different hook_set
+ self.hooks_in_train_multiple_models = [self.hooks_in_train]
+ self.hooks_in_eval_multiple_models = [self.hooks_in_eval]
+ self.init_multiple_models()
+ self.init_multiple_model_hooks()
+ assert len(self.ctx.models) == model_nums == \
+ len(self.hooks_in_train_multiple_models) == len(self.hooks_in_eval_multiple_models),\
+ "After init, len(hooks_in_train_multiple_models), len(hooks_in_eval_multiple_models), " \
+ "len(ctx.models) and model_nums should be the same"
+
+ def init_multiple_models(self):
+ """
+        init multiple models and optimizers: the default implementation initializes them by copying;
+ ========================= Extension =============================
+ users can override this function according to their own requirements
+ """
+
+ additional_models = [
+ copy.deepcopy(self.ctx.model) for _ in range(self.model_nums - 1)
+ ]
+ self.ctx.models = [self.ctx.model] + additional_models
+
+ additional_optimizers = [
+ get_optimizer(self.ctx.models[i], **self.cfg.optimizer)
+ for i in range(1, self.model_nums)
+ ]
+ self.ctx.optimizers = [self.ctx.optimizer] + additional_optimizers
+
+ def register_multiple_model_hooks(self):
+ """
+ By default, all internal models adopt the same hook_set.
+ ========================= Extension =============================
+ Users can override this function to register customized hooks for different internal models.
+
+ Note:
+ for sequential mode, users can append interact_hook on begin/end triggers such as
+ " -> (on_fit_end, _interact_to_other_models) -> "
+
+ for parallel mode, users can append interact_hook on any trigger they want such as
+ " -> (on_xxx_point, _interact_to_other_models) -> "
+
+ """
+
+ self.hooks_in_train_multiple_models.extend([
+ self.hooks_in_train_multiple_models[0]
+ for _ in range(1, self.model_nums)
+ ])
+ self.hooks_in_eval_multiple_models.extend([
+ self.hooks_in_eval_multiple_models[0]
+ for _ in range(1, self.model_nums)
+ ])
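+
+        # e.g., an overriding subclass could register an interaction hook for
+        # a specific internal model (`_interact_to_other_models` is a
+        # hypothetical user-defined hook, as sketched in the class docstring):
+        #
+        #   self.register_hook_in_train(
+        #       new_hook=self._interact_to_other_models,
+        #       trigger="on_fit_end", model_idx=1, insert_pos=-1)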
+
+ def init_multiple_model_hooks(self):
+ self.register_multiple_model_hooks()
+ if self.models_interact_mode == "sequential":
+ # hooks_in_xxx is a list of dict, hooks_in_xxx[i] stores specific set for i-th internal model;
+ # for each dict, the key indicates point-in-time and the value indicates specific hook
+ self.hooks_in_train = self.hooks_in_train_multiple_models
+ self.hooks_in_eval = self.hooks_in_eval_multiple_models
+ elif self.models_interact_mode == "parallel":
+ # hooks_in_xxx is a dict whose key indicates point-in-time and value indicates specific hook
+ for trigger in list(self.hooks_in_train.keys()):
+ self.hooks_in_train[trigger] = []
+ self.hooks_in_eval[trigger] = []
+ for model_idx in range(len(self.ctx.models)):
+ self.hooks_in_train[trigger].extend(
+ self.hooks_in_train_multiple_models[model_idx]
+ [trigger])
+ self.hooks_in_train[trigger].extend(
+ [self._switch_model_ctx])
+ self.hooks_in_eval[trigger].extend(
+ self.hooks_in_eval_multiple_models[model_idx][trigger])
+ self.hooks_in_eval[trigger].extend(
+ [self._switch_model_ctx])
+ else:
+ raise RuntimeError(
+ f"Invalid models_interact_mode, should be `sequential` or `parallel`,"
+ f" but got {self.models_interact_mode}")
+
+ def register_hook_in_train(self,
+ new_hook,
+ trigger,
+ model_idx=0,
+ insert_pos=None,
+ base_hook=None,
+ insert_mode="before"):
+ hooks_dict = self.hooks_in_train_multiple_models[model_idx]
+ self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
+ new_hook, trigger)
+
+ def register_hook_in_eval(self,
+ new_hook,
+ trigger,
+ model_idx=0,
+ insert_pos=None,
+ base_hook=None,
+ insert_mode="before"):
+ hooks_dict = self.hooks_in_eval_multiple_models[model_idx]
+ self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
+ new_hook, trigger)
+
+ def _switch_model_ctx(self, next_model_idx=None):
+ if next_model_idx is None:
+ next_model_idx = (self.ctx.cur_model_idx + 1) % len(
+ self.ctx.models)
+ self.ctx.cur_model_idx = next_model_idx
+ self.ctx.model = self.ctx.models[next_model_idx]
+ self.ctx.optimizer = self.ctx.optimizers[next_model_idx]
+
+ def _run_routine(self, mode, hooks_set, dataset_name=None):
+ """Run the hooks_set and maintain the mode for multiple internal models
+
+ Arguments:
+ mode: running mode of client, chosen from train/val/test
+
+ Note:
+            Considering evaluation could be in ```hooks_set["on_epoch_end"]```, there could be two data loaders in
+            self.ctx; we must tell the running hooks which data_loader to call and which num_samples to count
+
+ """
+ if self.models_interact_mode == "sequential":
+ assert isinstance(hooks_set, list) and isinstance(hooks_set[0], dict), \
+ "When models_interact_mode=sequential, hooks_set should be a list of dict" \
+ "hooks_set[i] stores specific set for i-th internal model." \
+ "For each dict, the key indicates point-in-time and the value indicates specific hook"
+ for model_idx in range(len(self.ctx.models)):
+ # switch different hooks & ctx for different internal models
+ hooks_set_model_i = hooks_set[model_idx]
+ self._switch_model_ctx(model_idx)
+ # [Interaction at run_routine level]
+                # one model runs its whole routine, then does some interaction, then the next model runs its whole routine
+ # ... -> run_routine_model_i
+ # -> _switch_model_ctx
+ # -> (on_fit_end, _interact_to_other_models)
+ # -> run_routine_model_i+1
+ # -> ...
+ super()._run_routine(mode, hooks_set_model_i, dataset_name)
+ elif self.models_interact_mode == "parallel":
+ assert isinstance(hooks_set, dict), \
+ "When models_interact_mode=parallel, hooks_set should be a dict " \
+ "whose key indicates point-in-time and value indicates specific hook"
+ # [Interaction at point-in-time level]
+            # at a specific point-in-time, one model calls hooks (including interaction), then the next model calls hooks
+ # ... -> (on_xxx_point, hook_xxx_model_i)
+ # -> (on_xxx_point, _interact_to_other_models)
+ # -> (on_xxx_point, _switch_model_ctx)
+ # -> (on_xxx_point, hook_xxx_model_i+1)
+ # -> ...
+ super()._run_routine(mode, hooks_set, dataset_name)
+ else:
+ raise RuntimeError(
+ f"Invalid models_interact_mode, should be `sequential` or `parallel`,"
+ f" but got {self.models_interact_mode}")
+
+ def get_model_para(self):
+ """
+ return multiple model parameters
+ :return:
+ """
+ trained_model_para = []
+ for model_idx in range(self.model_nums):
+ trained_model_para.append(
+ self._param_filter(
+ self.ctx.models[model_idx].cpu().state_dict()))
+
+ return trained_model_para[
+ 0] if self.model_nums == 1 else trained_model_para
+
+ def update(self, model_parameters):
+        """Update the parameters of multiple internal models
+
+        Arguments:
+            model_parameters (list[dict]): state_dicts of multiple PyTorch Module objects
+        """
+ if self.model_nums == 1:
+ super().update(model_parameters)
+ else:
+            assert isinstance(model_parameters, list) and isinstance(model_parameters[0], dict), \
+                "model_parameters should be a list of multiple state_dict"
+            assert len(model_parameters) == self.model_nums, \
+                f"model_parameters should have the same length as self.model_nums, " \
+                f"but got {len(model_parameters)} and {self.model_nums} respectively"
+ for model_idx in range(self.model_nums):
+ self.ctx.models[model_idx].load_state_dict(self._param_filter(
+ model_parameters[model_idx]),
+ strict=False)
+
+ def train(self, target_data_split_name="train"):
+ # return multiple model paras
+ sample_size, _, results = super().train(target_data_split_name)
+
+ return sample_size, self.get_model_para(), results
diff --git a/federatedscope/core/trainers/trainer_nbafl.py b/federatedscope/core/trainers/trainer_nbafl.py
new file mode 100644
index 000000000..8b06d1fb9
--- /dev/null
+++ b/federatedscope/core/trainers/trainer_nbafl.py
@@ -0,0 +1,133 @@
+from federatedscope.core.auxiliaries.utils import get_random
+from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
+# from federatedscope.core.worker.server import Server
+# (the import above is commented out, presumably to avoid a circular import)
+from typing import Type
+from copy import deepcopy
+
+import numpy as np
+import torch
+
+
+def wrap_nbafl_trainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ """Implementation of NbAFL refer to `Federated Learning with Differential Privacy: Algorithms and Performance Analysis` [et al., 2020]
+ (https://ieeexplore.ieee.org/abstract/document/9069945/)
+
+ Arguments:
+ mu: the factor of the regularizer
+ epsilon: the distinguishable bound
+ w_clip: the threshold to clip weights
+
+ """
+
+ # ---------------- attribute-level plug-in -----------------------
+ init_nbafl_ctx(base_trainer)
+
+ # ---------------- action-level plug-in -----------------------
+ base_trainer.register_hook_in_train(new_hook=record_initialization,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(new_hook=record_initialization,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(new_hook=del_initialization,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_eval(new_hook=del_initialization,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ base_trainer.register_hook_in_train(new_hook=inject_noise_in_upload,
+ trigger='on_fit_end',
+ insert_pos=-1)
+ return base_trainer
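+
+
+# A minimal usage sketch (illustrative only; `base` and `server` are assumed
+# pre-built GeneralTorchTrainer / Server instances, with `cfg.nbafl.*` set):
+#
+#   nbafl_trainer = wrap_nbafl_trainer(base)   # client-side uplink noise
+#   wrap_nbafl_server(server)                  # server-side downlink noise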
+
+
+def init_nbafl_ctx(base_trainer):
+ """Set proximal regularizer, and the scale of gaussian noise
+
+ """
+ ctx = base_trainer.ctx
+ cfg = base_trainer.cfg
+
+ # set proximal regularizer
+ cfg.defrost()
+ cfg.regularizer.type = 'proximal_regularizer'
+ cfg.regularizer.mu = cfg.nbafl.mu
+ cfg.freeze()
+ from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer
+ ctx.regularizer = get_regularizer(cfg.regularizer.type)
+
+ # set noise scale during upload
+    ctx.nbafl_scale_u = cfg.nbafl.w_clip * cfg.federate.total_round_num * \
+        cfg.nbafl.constant / ctx.num_train_data / cfg.nbafl.epsilon
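+    # i.e., scale_u = (w_clip * T * c) / (m_k * eps), where T is the total
+    # round number, c = cfg.nbafl.constant, m_k = ctx.num_train_data and
+    # eps = cfg.nbafl.epsilon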
+
+
+# ------------------------------------------------------------------------ #
+# Additional functions for NbAFL algorithm
+# ------------------------------------------------------------------------ #
+
+
+# Trainer
+def record_initialization(ctx):
+ """Record the initialized weights within local updates
+
+ """
+ ctx.weight_init = deepcopy(
+ [_.data.detach() for _ in ctx.model.parameters()])
+
+
+def del_initialization(ctx):
+ """Clear the variable to avoid memory leakage
+
+ """
+ ctx.weight_init = None
+
+
+def inject_noise_in_upload(ctx):
+ """Inject noise into weights before the client upload them to server
+
+ """
+ for p in ctx.model.parameters():
+ noise = get_random("Normal", p.shape, {
+ "loc": 0,
+ "scale": ctx.nbafl_scale_u
+ }, p.device)
+ p.data += noise
+
+
+# Server
+def inject_noise_in_broadcast(cfg, sample_client_num, model):
+ """Inject noise into weights before the server broadcasts them
+
+ """
+
+ # Clip weight
+ for p in model.parameters():
+        # keep the ones tensor on the same device as the parameter
+        p.data = p.data / torch.max(
+            torch.ones(size=p.shape, device=p.data.device),
+            torch.abs(p.data) / cfg.nbafl.w_clip)
+
+ if len(sample_client_num) > 0:
+ # Inject noise
+ L = cfg.federate.sample_client_num if cfg.federate.sample_client_num > 0 else cfg.federate.client_num
+ if cfg.federate.total_round_num > np.sqrt(cfg.federate.client_num) * L:
+ scale_d = 2 * cfg.nbafl.w_clip * cfg.nbafl.constant * np.sqrt(
+ np.power(cfg.federate.total_round_num, 2) -
+ np.power(L, 2) * cfg.federate.client_num) / (
+ min(sample_client_num.values()) * cfg.federate.client_num *
+ cfg.nbafl.epsilon)
+ for p in model.parameters():
+ p.data += get_random("Normal", p.shape, {
+ "loc": 0,
+ "scale": scale_d
+ }, p.device)
+
+
+# the annotated signature `wrap_nbafl_server(server: Type[Server]) -> Type[Server]`
+# is commented out, presumably to avoid a circular import
+def wrap_nbafl_server(server):
+ """Register noise injector for the server
+
+ """
+ server.register_noise_injector(inject_noise_in_broadcast)
diff --git a/federatedscope/core/trainers/trainer_pFedMe.py b/federatedscope/core/trainers/trainer_pFedMe.py
new file mode 100644
index 000000000..540808eaf
--- /dev/null
+++ b/federatedscope/core/trainers/trainer_pFedMe.py
@@ -0,0 +1,149 @@
+import copy
+
+from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
+from federatedscope.core.optimizer import wrap_regularized_optimizer
+from typing import Type
+
+
+def wrap_pFedMeTrainer(
+ base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+ """
+    Build a `pFedMeTrainer` in a plug-in manner, by registering new functions into the given `BaseTrainer`
+
+    The pFedMe implementation of "Personalized Federated Learning with Moreau Envelopes (NeurIPS 2020)"
+    is based on Algorithm 1 of the paper and the official codes: https://github.com/CharlieDinh/pFedMe
+ """
+
+ # ---------------- attribute-level plug-in -----------------------
+ init_pFedMe_ctx(base_trainer)
+
+ # ---------------- action-level plug-in -----------------------
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_fit_start_set_local_para_tmp,
+ trigger="on_fit_start",
+ insert_pos=-1)
+ base_trainer.register_hook_in_train(
+ new_hook=hook_on_epoch_end_update_local,
+ trigger="on_epoch_end",
+ insert_pos=-1)
+ base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_update_local,
+ trigger="on_fit_end",
+ insert_pos=-1)
+
+ base_trainer.replace_hook_in_train(
+ new_hook=_hook_on_batch_forward_flop_count,
+ target_trigger="on_batch_forward",
+ target_hook_name="_hook_on_batch_forward_flop_count")
+ base_trainer.register_hook_in_train(new_hook=_hook_on_epoch_end_flop_count,
+ trigger="on_epoch_end",
+ insert_pos=-1)
+
+ # for "on_batch_start" trigger: replace the original hooks into new ones of pFedMe
+ # 1) cache the original hooks for "on_batch_start"
+ base_trainer.ctx.original_hook_on_batch_start_train = base_trainer.hooks_in_train[
+ "on_batch_start"]
+ # base_trainer.ctx.original_hook_on_batch_start_eval = base_trainer.hooks_in_eval[
+ # "on_batch_start"]
+ # 2) replace the original hooks for "on_batch_start"
+ base_trainer.replace_hook_in_train(
+ new_hook=hook_on_batch_start_init_pfedme,
+ target_trigger="on_batch_start",
+ target_hook_name=None)
+ # base_trainer.replace_hook_in_eval(new_hook=hook_on_batch_start_init_pfedme,
+ # target_trigger="on_batch_start",
+ # target_hook_name=None)
+
+ return base_trainer
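+
+
+# A minimal usage sketch (illustrative only; `base` is an assumed pre-built
+# GeneralTorchTrainer, with `cfg.personalization.K`, `cfg.personalization.lr`
+# and `cfg.personalization.regular_weight` set):
+#
+#   pfedme_trainer = wrap_pFedMeTrainer(base)
+#   sample_size, model_para, results = pfedme_trainer.train()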
+
+
+def init_pFedMe_ctx(base_trainer):
+ """
+ init necessary attributes used in pFedMe,
+ some new attributes will be with prefix `pFedMe` optimizer to avoid namespace pollution
+ """
+ ctx = base_trainer.ctx
+ cfg = base_trainer.cfg
+
+ # pFedMe finds approximate model with K steps using the same data batch
+ # the complexity of each pFedMe client is K times the one of FedAvg
+ ctx.pFedMe_K = cfg.personalization.K
+ ctx.num_train_epoch *= ctx.pFedMe_K
+ ctx.pFedMe_approx_fit_counter = 0
+
+    # local_model_tmp serves as the reference parameter when finding the approximate \theta in the paper;
+    # it will be copied from the model at every run_routine
+ ctx.pFedMe_local_model_tmp = None
+
+ # the optimizer used in pFedMe is based on Moreau Envelopes regularization
+ # besides, there are two distinct lr for the approximate model and base model
+ ctx.optimizer = wrap_regularized_optimizer(
+ ctx.optimizer, cfg.personalization.regular_weight)
+ for g in ctx.optimizer.param_groups:
+ g['lr'] = cfg.personalization.lr
+ ctx.pFedMe_outer_lr = cfg.optimizer.lr
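+
+    # Conceptually, each client solves the Moreau-envelope problem
+    #   min_theta  f_i(theta) + (lambda / 2) * ||theta - w_i||^2
+    # with K inner steps per batch (lambda = cfg.personalization.regular_weight),
+    # and then moves w_i towards the approximate theta with the outer lr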
+
+
+def hook_on_fit_start_set_local_para_tmp(ctx):
+ ctx.pFedMe_local_model_tmp = copy.deepcopy(ctx.model)
+ # set the compared model data, then the optimizer will find approximate model using trainer.cfg.personalization.lr
+ compared_global_model_para = [{
+ "params": list(ctx.pFedMe_local_model_tmp.parameters())
+ }]
+ ctx.optimizer.set_compared_para_group(compared_global_model_para)
+
+
+def hook_on_batch_start_init_pfedme(ctx):
+ # refresh data every K step
+ if ctx.pFedMe_approx_fit_counter == 0:
+ if ctx.cur_mode == "train":
+ for hook in ctx.original_hook_on_batch_start_train:
+ hook(ctx)
+ else:
+ for hook in ctx.original_hook_on_batch_start_eval:
+ hook(ctx)
+ ctx.data_batch_cache = copy.deepcopy(ctx.data_batch)
+ else:
+ # reuse the data_cache since the original hook `_hook_on_batch_end` will clean `data_batch`
+ ctx.data_batch = copy.deepcopy(ctx.data_batch_cache)
+ ctx.pFedMe_approx_fit_counter = (ctx.pFedMe_approx_fit_counter +
+ 1) % ctx.pFedMe_K
+
+
+def _hook_on_batch_forward_flop_count(ctx):
+ if ctx.monitor.flops_per_sample == 0:
+ # calculate the flops_per_sample
+ x, _ = [_.to(ctx.device) for _ in ctx.data_batch]
+ from fvcore.nn import FlopCountAnalysis
+ flops_one_batch = FlopCountAnalysis(ctx.model, x).total()
+        # besides the normal forward flops, pFedMe introduces extra cost:
+        # 1) the regularization term, whose cost is approximated by the number of model parameters
+        flops_one_batch += ctx.monitor.total_model_size / 2
+ ctx.monitor.track_avg_flops(flops_one_batch, ctx.batch_size)
+ ctx.monitor.total_flops += ctx.monitor.flops_per_sample * ctx.batch_size
+
+
+def _hook_on_epoch_end_flop_count(ctx):
+ # due to the local weight updating
+ ctx.monitor.total_flops += ctx.monitor.total_model_size / 2
+
+
+def hook_on_epoch_end_update_local(ctx):
+ # update local weight after finding approximate theta
+ for client_param, local_para_tmp in zip(
+ ctx.model.parameters(), ctx.pFedMe_local_model_tmp.parameters()):
+ local_para_tmp.data = local_para_tmp.data - ctx.optimizer.regular_weight * \
+ ctx.pFedMe_outer_lr * (local_para_tmp.data - client_param.data)
+
+ # set the compared model data, then the optimizer will find approximate model using trainer.cfg.personalization.lr
+ compared_global_model_para = [{
+ "params": list(ctx.pFedMe_local_model_tmp.parameters())
+ }]
+ ctx.optimizer.set_compared_para_group(compared_global_model_para)
+
+
+def hook_on_fit_end_update_local(ctx):
+ for param, local_para_tmp in zip(ctx.model.parameters(),
+ ctx.pFedMe_local_model_tmp.parameters()):
+ param.data = local_para_tmp.data
+
+ del ctx.pFedMe_local_model_tmp
diff --git a/federatedscope/core/worker/__init__.py b/federatedscope/core/worker/__init__.py
new file mode 100644
index 000000000..05be87247
--- /dev/null
+++ b/federatedscope/core/worker/__init__.py
@@ -0,0 +1,10 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+from __future__ import with_statement
+
+from federatedscope.core.worker.base_worker import Worker
+from federatedscope.core.worker.server import Server
+from federatedscope.core.worker.client import Client
+
+__all__ = ['Worker', 'Server', 'Client']
diff --git a/federatedscope/core/worker/base_worker.py b/federatedscope/core/worker/base_worker.py
new file mode 100644
index 000000000..f7f064de9
--- /dev/null
+++ b/federatedscope/core/worker/base_worker.py
@@ -0,0 +1,55 @@
+from federatedscope.core.monitors.monitor import Monitor
+
+
+class Worker(object):
+ """
+ The base worker class.
+ """
+ def __init__(self, ID=-1, state=0, config=None, model=None, strategy=None):
+ self._ID = ID
+ self._state = state
+ self._model = model
+ self._cfg = config
+ self._strategy = strategy
+ self._mode = self._cfg.federate.mode.lower()
+ self._monitor = Monitor(config, monitored_object=self)
+
+ @property
+ def ID(self):
+ return self._ID
+
+ @ID.setter
+ def ID(self, value):
+ self._ID = value
+
+ @property
+ def state(self):
+ return self._state
+
+ @state.setter
+ def state(self, value):
+ self._state = value
+
+ @property
+ def model(self):
+ return self._model
+
+ @model.setter
+ def model(self, value):
+ self._model = value
+
+ @property
+ def strategy(self):
+ return self._strategy
+
+ @strategy.setter
+ def strategy(self, value):
+ self._strategy = value
+
+ @property
+ def mode(self):
+ return self._mode
+
+ @mode.setter
+ def mode(self, value):
+ self._mode = value
diff --git a/federatedscope/core/worker/client.py b/federatedscope/core/worker/client.py
new file mode 100644
index 000000000..6d53a73b6
--- /dev/null
+++ b/federatedscope/core/worker/client.py
@@ -0,0 +1,404 @@
+import copy
+import logging
+import sys
+import pickle
+
+from federatedscope.core.message import Message
+from federatedscope.core.communication import StandaloneCommManager, \
+ gRPCCommManager
+from federatedscope.core.monitors.early_stopper import EarlyStopper
+from federatedscope.core.monitors.monitor import update_best_result
+from federatedscope.core.worker import Worker
+from federatedscope.core.auxiliaries.trainer_builder import get_trainer
+from federatedscope.core.secret_sharing import AdditiveSecretSharing
+from federatedscope.core.auxiliaries.utils import merge_dict
+
+logger = logging.getLogger(__name__)
+
+
+class Client(Worker):
+ """
+    The Client class, which describes the behaviors of a client in an FL course.
+    The behaviors are described by the handling functions (named callback_funcs_for_xxx)
+
+ Arguments:
+ ID: The unique ID of the client, which is assigned by the server when joining the FL course
+        server_id: The ID of the server, 0 by default
+ state: The training round
+ config: The configuration
+ data: The data owned by the client
+ model: The model maintained locally
+ device: The device to run local training and evaluation
+ strategy: redundant attribute
+ """
+ def __init__(self,
+ ID=-1,
+ server_id=None,
+ state=-1,
+ config=None,
+ data=None,
+ model=None,
+ device='cpu',
+ strategy=None,
+ *args,
+ **kwargs):
+
+ super(Client, self).__init__(ID, state, config, model, strategy)
+
+        # Attacks are only supported in standalone mode;
+        # a client is an attacker if config.attack.attack_method is provided and config.attack.attacker_id matches its ID
+ self.is_attacker = config.attack.attacker_id == ID and \
+ config.attack.attack_method != '' and config.federate.mode == 'standalone'
+
+ # Build Trainer
+        # the trainer might need configurations beyond the trainer-related ones
+ self.trainer = get_trainer(model=model,
+ data=data,
+ device=device,
+ config=self._cfg,
+ is_attacker=self.is_attacker,
+ monitor=self._monitor)
+
+ # For client-side evaluation
+ self.best_results = dict()
+ self.history_results = dict()
+        # in local or global training mode, we use the early stopper;
+        # otherwise, we set patience=0 to deactivate the local early-stopper
+ patience = self._cfg.early_stop.patience if self._cfg.federate.method in [
+ "local", "global"
+ ] else 0
+ self.early_stopper = EarlyStopper(
+ patience, self._cfg.early_stop.delta,
+ self._cfg.early_stop.improve_indicator_mode,
+ self._cfg.early_stop.the_smaller_the_better)
+
+ # Secret Sharing Manager and message buffer
+ self.ss_manager = AdditiveSecretSharing(
+ shared_party_num=int(self._cfg.federate.sample_client_num
+ )) if self._cfg.federate.use_ss else None
+ self.msg_buffer = {'train': dict(), 'eval': dict()}
+
+ # Register message handlers
+ self.msg_handlers = dict()
+ self._register_default_handlers()
+
+ # Initialize communication manager
+ self.server_id = server_id
+ if self.mode == 'standalone':
+ comm_queue = kwargs['shared_comm_queue']
+ self.comm_manager = StandaloneCommManager(comm_queue=comm_queue,
+ monitor=self._monitor)
+ self.local_address = None
+ elif self.mode == 'distributed':
+ host = kwargs['host']
+ port = kwargs['port']
+ server_host = kwargs['server_host']
+ server_port = kwargs['server_port']
+ self.comm_manager = gRPCCommManager(
+ host=host, port=port, client_num=self._cfg.federate.client_num)
+ logger.info('Client: Listen to {}:{}...'.format(host, port))
+ self.comm_manager.add_neighbors(neighbor_id=server_id,
+ address={
+ 'host': server_host,
+ 'port': server_port
+ })
+ self.local_address = {
+ 'host': self.comm_manager.host,
+ 'port': self.comm_manager.port
+ }
+
+ def register_handlers(self, msg_type, callback_func):
+ """
+ To bind a message type with a handling function.
+
+ Arguments:
+ msg_type (str): The defined message type
+ callback_func: The handling functions to handle the received message
+ """
+ self.msg_handlers[msg_type] = callback_func
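+        # e.g., a customized client could bind a new message type like
+        #   self.register_handlers('gradient', self.callback_funcs_for_gradient)
+        # where `callback_funcs_for_gradient` is a hypothetical user-defined handler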
+
+ def _register_default_handlers(self):
+ self.register_handlers('assign_client_id',
+ self.callback_funcs_for_assign_id)
+ self.register_handlers('ask_for_join_in_info',
+ self.callback_funcs_for_join_in_info)
+ self.register_handlers('address', self.callback_funcs_for_address)
+ self.register_handlers('model_para',
+ self.callback_funcs_for_model_para)
+ self.register_handlers('ss_model_para',
+ self.callback_funcs_for_model_para)
+ self.register_handlers('evaluate', self.callback_funcs_for_evaluate)
+ self.register_handlers('finish', self.callback_funcs_for_finish)
+ self.register_handlers('converged', self.callback_funcs_for_converged)
+
+ def join_in(self):
+ """
+ To send 'join_in' message to the server for joining in the FL course.
+ """
+ self.comm_manager.send(
+ Message(msg_type='join_in',
+ sender=self.ID,
+ receiver=[self.server_id],
+ content=self.local_address))
+
+ def run(self):
+ """
+        To listen to the messages and handle them accordingly (used for distributed mode)
+ """
+ while True:
+ msg = self.comm_manager.receive()
+ if self.state <= msg.state:
+ self.msg_handlers[msg.msg_type](msg)
+
+ if msg.msg_type == 'finish':
+ break
+
+ def callback_funcs_for_model_para(self, message: Message):
+ """
+ The handling function for receiving model parameters, which triggers the local training process.
+ This handling function is widely used in various FL courses.
+
+ Arguments:
+ message: The received message, which includes sender, receiver, state, and content.
+ More detail can be found in federatedscope.core.message
+ """
+ if 'ss' in message.msg_type:
+ # A fragment of the shared secret
+ state, content = message.state, message.content
+ self.msg_buffer['train'][state].append(content)
+
+ if len(self.msg_buffer['train']
+ [state]) == self._cfg.federate.client_num:
+ # Check whether the received fragments are enough
+ model_list = self.msg_buffer['train'][state]
+ sample_size, first_aggregate_model_para = model_list[0]
+ single_model_case = True
+ if isinstance(first_aggregate_model_para, list):
+ assert isinstance(first_aggregate_model_para[0], dict), \
+ "aggregate_model_para should a list of multiple state_dict for multiple models"
+ single_model_case = False
+ else:
+ assert isinstance(first_aggregate_model_para, dict), \
+ "aggregate_model_para should a state_dict for single model case"
+ first_aggregate_model_para = [first_aggregate_model_para]
+ model_list = [[model] for model in model_list]
+
+ for sub_model_idx, aggregate_single_model_para in enumerate(
+ first_aggregate_model_para):
+ for key in aggregate_single_model_para:
+ for i in range(1, len(model_list)):
+ aggregate_single_model_para[key] += model_list[i][
+ sub_model_idx][key]
+
+ self.comm_manager.send(
+ Message(msg_type='model_para',
+ sender=self.ID,
+ receiver=[self.server_id],
+ state=self.state,
+ content=(sample_size, first_aggregate_model_para[0]
+ if single_model_case else
+ first_aggregate_model_para)))
+
+ else:
+ round, sender, content = message.state, message.sender, message.content
+ self.trainer.update(content)
+ self.state = round
+ if self.early_stopper.early_stopped and self._cfg.federate.method in [
+ "local", "global"
+ ]:
+ sample_size, model_para_all, results = 0, self.trainer.get_model_para(
+ ), {}
+ logger.info(
+ f"Client #{self.ID} has been early stopped, we will skip the local training"
+ )
+ self._monitor.local_converged()
+ else:
+ sample_size, model_para_all, results = self.trainer.train()
+ logger.info(
+ self._monitor.format_eval_res(results,
+ rnd=self.state,
+ role='Client #{}'.format(
+ self.ID),
+ return_raw=True))
+
+ # Return the feedbacks to the server after local update
+ if self._cfg.federate.use_ss:
+ single_model_case = True
+ if isinstance(model_para_all, list):
+ assert isinstance(model_para_all[0], dict), \
+ "model_para should a list of multiple state_dict for multiple models"
+ single_model_case = False
+ else:
+ assert isinstance(model_para_all, dict), \
+ "model_para should a state_dict for single model case"
+ model_para_all = [model_para_all]
+ model_para_list_all = []
+ for model_para in model_para_all:
+ for key in model_para:
+ model_para[key] = model_para[key] * sample_size
+ model_para_list = self.ss_manager.secret_split(model_para)
+ model_para_list_all.append(model_para_list)
+ frame_idx = 0
+ for neighbor in self.comm_manager.neighbors:
+ if neighbor != self.server_id:
+ content_frame = model_para_list_all[0][frame_idx] if single_model_case else \
+ [model_para_list[frame_idx] for model_para_list in model_para_list_all]
+ self.comm_manager.send(
+ Message(msg_type='ss_model_para',
+ sender=self.ID,
+ receiver=[neighbor],
+ state=self.state,
+ content=content_frame))
+ frame_idx += 1
+ content_frame = model_para_list_all[0][frame_idx] if single_model_case else \
+ [model_para_list[frame_idx] for model_para_list in model_para_list_all]
+ self.msg_buffer['train'][self.state] = [(sample_size,
+ content_frame)]
+ else:
+ self.comm_manager.send(
+ Message(msg_type='model_para',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=(sample_size, model_para_all)))
+
+ def callback_funcs_for_assign_id(self, message: Message):
+ """
+ The handling function for receiving the client_ID assigned by the server (during the joining process),
+ which is used in the distributed mode.
+
+ Arguments:
+ message: The received message
+ """
+ content = message.content
+ self.ID = int(content)
+ logger.info('Client (address {}:{}) is assigned with #{:d}.'.format(
+ self.comm_manager.host, self.comm_manager.port, self.ID))
+
+ def callback_funcs_for_join_in_info(self, message: Message):
+ """
+ The handling function for receiving the request of join in information
+ (such as batch_size, num_of_samples) during the joining process.
+
+ Arguments:
+ message: The received message
+ """
+ requirements = message.content
+ join_in_info = dict()
+ for requirement in requirements:
+ if requirement.lower() == 'num_sample':
+ if self._cfg.federate.batch_or_epoch == 'batch':
+ num_sample = self._cfg.federate.local_update_steps * self._cfg.data.batch_size
+ else:
+ num_sample = self._cfg.federate.local_update_steps * self.trainer.ctx.num_train_batch
+ join_in_info['num_sample'] = num_sample
+ else:
+ raise ValueError(
+ 'Fail to get the join in information with type {}'.format(
+ requirement))
+ self.comm_manager.send(
+ Message(msg_type='join_in_info',
+ sender=self.ID,
+ receiver=[self.server_id],
+ state=self.state,
+ content=join_in_info))
+
+ def callback_funcs_for_address(self, message: Message):
+ """
+ The handling function for receiving other clients' IP addresses, which is used for constructing a complex topology
+
+ Arguments:
+ message: The received message
+ """
+ content = message.content
+ for neighbor_id, address in content.items():
+ if int(neighbor_id) != self.ID:
+ self.comm_manager.add_neighbors(neighbor_id, address)
+
+ def callback_funcs_for_evaluate(self, message: Message):
+ """
+ The handling function for receiving the request of evaluating
+
+ Arguments:
+ message: The received message
+ """
+
+ sender = message.sender
+ self.state = message.state
+        if message.content is not None:
+ self.trainer.update(message.content)
+ if self.early_stopper.early_stopped and self._cfg.federate.method in [
+ "local", "global"
+ ]:
+ metrics = list(self.best_results.values())[0]
+ else:
+ metrics = {}
+ if self._cfg.trainer.finetune.before_eval:
+ self.trainer.finetune()
+ for split in self._cfg.eval.split:
+ # new function part
+ if split in ['poison'] and self.is_attacker:
+ continue
+ # new function part
+ eval_metrics = self.trainer.evaluate(
+ target_data_split_name=split)
+
+ if self._cfg.federate.mode == 'distributed':
+ logger.info(
+ self._monitor.format_eval_res(eval_metrics,
+ rnd=self.state,
+ role='Client #{}'.format(
+ self.ID)))
+
+ metrics.update(**eval_metrics)
+
+ formatted_eval_res = self._monitor.format_eval_res(
+ metrics,
+ rnd=self.state,
+ role='Client #{}'.format(self.ID),
+ forms='raw',
+ return_raw=True)
+ update_best_result(self.best_results,
+ formatted_eval_res['Results_raw'],
+ results_type=f"client #{self.ID}",
+ round_wise_update_key=self._cfg.eval.
+ best_res_update_round_wise_key)
+ self.history_results = merge_dict(
+ self.history_results, formatted_eval_res['Results_raw'])
+ self.early_stopper.track_and_check_best(self.history_results[
+ self._cfg.eval.best_res_update_round_wise_key])
+
+ self.comm_manager.send(
+ Message(msg_type='metrics',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=metrics))
+
+ def callback_funcs_for_finish(self, message: Message):
+ """
+ The handling function for receiving the signal of finishing the FL course
+
+ Arguments:
+ message: The received message
+ """
+ logger.info(
+ f"================= client {self.ID} received finish message ============================"
+ )
+
+        if message.content is not None:
+ self.trainer.update(message.content)
+
+ self._monitor.finish_fl()
+
+ def callback_funcs_for_converged(self, message: Message):
+ """
+ The handling function for receiving the signal that the FL course converged
+
+ Arguments:
+ message: The received message
+ """
+
+ self._monitor.global_converged()
diff --git a/federatedscope/core/worker/server.py b/federatedscope/core/worker/server.py
new file mode 100644
index 000000000..6c042ba9a
--- /dev/null
+++ b/federatedscope/core/worker/server.py
@@ -0,0 +1,707 @@
+import logging
+import copy
+import os
+
+import numpy as np
+import pickle
+
+from federatedscope.core.monitors.early_stopper import EarlyStopper
+from federatedscope.core.message import Message
+from federatedscope.core.communication import StandaloneCommManager, gRPCCommManager
+from federatedscope.core.monitors.monitor import update_best_result
+from federatedscope.core.worker import Worker
+from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator
+from federatedscope.core.auxiliaries.utils import merge_dict, Timeout
+from federatedscope.core.auxiliaries.trainer_builder import get_trainer
+from federatedscope.core.secret_sharing import AdditiveSecretSharing
+
+logger = logging.getLogger(__name__)
+
+
+class Server(Worker):
+ """
+    The Server class, which describes the behaviors of the server in an FL course.
+    The behaviors are described by the handling functions (named callback_funcs_for_xxx).
+
+ Arguments:
+ ID: The unique ID of the server, which is set to 0 by default
+ state: The training round
+ config: the configuration
+ data: The data owned by the server (for global evaluation)
+ model: The model used for aggregation
+ client_num: The (expected) client num to start the FL course
+ total_round_num: The total number of the training round
+ device: The device to run local training and evaluation
+ strategy: redundant attribute
+ """
+ def __init__(self,
+ ID=-1,
+ state=0,
+ config=None,
+ data=None,
+ model=None,
+ client_num=5,
+ total_round_num=10,
+ device='cpu',
+ strategy=None,
+ **kwargs):
+
+ super(Server, self).__init__(ID, state, config, model, strategy)
+
+ self.data = data
+ self.device = device
+ self.best_results = dict()
+ self.history_results = dict()
+ self.early_stopper = EarlyStopper(
+ self._cfg.early_stop.patience, self._cfg.early_stop.delta,
+ self._cfg.early_stop.improve_indicator_mode,
+ self._cfg.early_stop.the_smaller_the_better)
+
+ if self._cfg.federate.share_local_model:
+ # put the model to the specified device
+ model.to(device)
+ # Build aggregator
+ self.aggregator = get_aggregator(self._cfg.federate.method,
+ model=model,
+ device=device,
+ online=self._cfg.federate.online_aggr,
+ config=self._cfg)
+ if self._cfg.federate.restore_from != '':
+ cur_round = self.aggregator.load_model(
+ self._cfg.federate.restore_from)
+ logger.info("Restored the model from {}-th round's ckpt")
+
+ if int(config.model.model_num_per_trainer) != config.model.model_num_per_trainer or \
+ config.model.model_num_per_trainer < 1:
+ raise ValueError(
+ f"model_num_per_trainer should be integer and >= 1, "
+ f"got {config.model.model_num_per_trainer}.")
+ self.model_num = config.model.model_num_per_trainer
+ self.models = [self.model]
+ self.aggregators = [self.aggregator]
+ if self.model_num > 1:
+ self.models.extend(
+ [copy.deepcopy(self.model) for _ in range(self.model_num - 1)])
+ self.aggregators.extend([
+ copy.deepcopy(self.aggregator)
+ for _ in range(self.model_num - 1)
+ ])
+
+ # function for recovering shared secret
+ self.recover_fun = AdditiveSecretSharing(
+ shared_party_num=int(self._cfg.federate.sample_client_num)
+ ).fixedpoint2float if self._cfg.federate.use_ss else None
+
+ if self._cfg.federate.make_global_eval:
+ # set up a trainer for conducting evaluation in server
+ assert self.model is not None
+ assert self.data is not None
+ self.trainer = get_trainer(
+ model=self.model,
+ data=self.data,
+ device=self.device,
+ config=self._cfg,
+ only_for_eval=True,
+ monitor=self._monitor
+ ) # the trainer is only used for global evaluation
+ self.trainers = [self.trainer]
+ if self.model_num > 1:
+ # By default, the evaluation is conducted by calling trainer[i].eval over all internal models
+ self.trainers.extend([
+ copy.deepcopy(self.trainer)
+ for _ in range(self.model_num - 1)
+ ])
+
+ # Initialize the number of joined-in clients
+ self._client_num = client_num
+ self._total_round_num = total_round_num
+ self.sample_client_num = int(self._cfg.federate.sample_client_num)
+ self.join_in_client_num = 0
+ self.join_in_info = dict()
+
+ # Register message handlers
+ self.msg_handlers = dict()
+ self._register_default_handlers()
+
+ # Initialize communication manager and message buffer
+ self.msg_buffer = {'train': dict(), 'eval': dict()}
+ if self.mode == 'standalone':
+ comm_queue = kwargs['shared_comm_queue']
+ self.comm_manager = StandaloneCommManager(comm_queue=comm_queue,
+ monitor=self._monitor)
+ elif self.mode == 'distributed':
+ host = kwargs['host']
+ port = kwargs['port']
+ self.comm_manager = gRPCCommManager(host=host,
+ port=port,
+ client_num=client_num)
+ logger.info('Server #{:d}: Listen to {}:{}...'.format(
+ self.ID, host, port))
+
+ # inject noise before broadcast
+ self._noise_injector = None
+
+ @property
+ def client_num(self):
+ return self._client_num
+
+ @client_num.setter
+ def client_num(self, value):
+ self._client_num = value
+
+ @property
+ def total_round_num(self):
+ return self._total_round_num
+
+ @total_round_num.setter
+ def total_round_num(self, value):
+ self._total_round_num = value
+
+ def register_noise_injector(self, func):
+ self._noise_injector = func
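+        # e.g., NbAFL registers its broadcast-time noise injection via
+        #   wrap_nbafl_server(server)  # see core/trainers/trainer_nbafl.py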
+
+ def register_handlers(self, msg_type, callback_func):
+ """
+ To bind a message type with a handling function.
+
+ Arguments:
+ msg_type (str): The defined message type
+ callback_func: The handling functions to handle the received message
+ """
+ self.msg_handlers[msg_type] = callback_func
+
+ def _register_default_handlers(self):
+ self.register_handlers('join_in', self.callback_funcs_for_join_in)
+ self.register_handlers('join_in_info', self.callback_funcs_for_join_in)
+ self.register_handlers('model_para', self.callback_funcs_model_para)
+ self.register_handlers('metrics', self.callback_funcs_for_metrics)
+
+ def run(self):
+ """
+ To start the FL course, listen and handle messages (for distributed mode).
+ """
+
+ # Begin: Broadcast model parameters and start to FL train
+ while self.join_in_client_num < self.client_num:
+ msg = self.comm_manager.receive()
+ self.msg_handlers[msg.msg_type](msg)
+
+ # Running: listen for message (updates from clients),
+ # aggregate and broadcast feedbacks (aggregated model parameters)
+ min_received_num = self._cfg.asyn.min_received_num if hasattr(
+ self._cfg, 'asyn') else self._cfg.federate.sample_client_num
+ num_failure = 0
+ with Timeout(self._cfg.asyn.timeout) as time_counter:
+ while self.state <= self.total_round_num:
+ try:
+ msg = self.comm_manager.receive()
+ move_on_flag = self.msg_handlers[msg.msg_type](msg)
+ if move_on_flag:
+ time_counter.reset()
+ except TimeoutError:
+ logger.info('Time out at the training round #{}'.format(
+ self.state))
+ move_on_flag_eval = self.check_and_move_on(
+ min_received_num=min_received_num,
+ check_eval_result=True)
+ move_on_flag = self.check_and_move_on(
+ min_received_num=min_received_num)
+ if not move_on_flag and not move_on_flag_eval:
+ num_failure += 1
+ # Terminate the training if the number of failure exceeds the maximum number (default value: 10)
+ if time_counter.exceed_max_failure(num_failure):
+ logger.info(
+ '----------- Training fails at round #{:d} -------------'
+ .format(self.state))
+ break
+
+ # Time out, broadcast the model para and re-start the training round
+ logger.info(
+ '----------- Re-starting the training round (Round #{:d}) for {:d} time -------------'
+ .format(self.state, num_failure))
+ # Clean the msg_buffer
+ self.msg_buffer['train'][self.state].clear()
+
+ self.broadcast_model_para(
+ msg_type='model_para',
+ sample_client_num=self.sample_client_num)
+ else:
+ num_failure = 0
+ time_counter.reset()
+
+ self.terminate(msg_type='finish')
+
+ def check_and_move_on(self,
+ check_eval_result=False,
+ min_received_num=None):
+ """
+        To check the message_buffer. When enough messages have been received, some events
+        (such as performing aggregation, evaluation, and moving to the next training round) would be triggered.
+
+ Arguments:
+ check_eval_result (bool): If True, check the message buffer for evaluation;
+                otherwise, check the message buffer for training.
+ """
+ if min_received_num is None:
+ min_received_num = self._cfg.federate.sample_client_num
+ assert min_received_num <= self.sample_client_num
+
+ if check_eval_result:
+ min_received_num = len(list(self.comm_manager.neighbors.keys()))
+
+ move_on_flag = True # To record whether moving to a new training round or finishing the evaluation
+ if self.check_buffer(self.state, min_received_num, check_eval_result):
+
+ if not check_eval_result: # in the training process
+ # Get all the message
+ train_msg_buffer = self.msg_buffer['train'][self.state]
+ for model_idx in range(self.model_num):
+ model = self.models[model_idx]
+ aggregator = self.aggregators[model_idx]
+ msg_list = list()
+ for client_id in train_msg_buffer:
+ if self.model_num == 1:
+ msg_list.append(train_msg_buffer[client_id])
+ else:
+ train_data_size, model_para_multiple = train_msg_buffer[
+ client_id]
+ msg_list.append((train_data_size,
+ model_para_multiple[model_idx]))
+
+ # Trigger the monitor here (for training)
+ if 'dissim' in self._cfg.eval.monitoring:
+                        # pass the current global model parameters as the base
+                        B_val = self._monitor.calc_blocal_dissim(
+                            model.state_dict(), msg_list)
+ formatted_eval_res = self._monitor.format_eval_res(
+ B_val, rnd=self.state, role='Server #')
+ logger.info(formatted_eval_res)
+
+ # Aggregate
+ agg_info = {
+ 'client_feedback': msg_list,
+ 'recover_fun': self.recover_fun
+ }
+ result = aggregator.aggregate(agg_info)
+ model.load_state_dict(result, strict=False)
+
+ self.state += 1
+ if self.state % self._cfg.eval.freq == 0 and self.state != self.total_round_num:
+ # Evaluate
+ logger.info(
+ 'Server #{:d}: Starting evaluation at the end of round {:d}.'
+ .format(self.ID, self.state - 1))
+ self.eval()
+
+ if self.state < self.total_round_num:
+ # Move to next round of training
+ logger.info(
+ '----------- Starting a new training round (Round #{:d}) -------------'
+ .format(self.state))
+ # Clean the msg_buffer
+ self.msg_buffer['train'][self.state - 1].clear()
+
+ self.broadcast_model_para(
+ msg_type='model_para',
+ sample_client_num=self.sample_client_num)
+ else:
+ # Final Evaluate
+ logger.info(
+ 'Server #{:d}: Training is finished! Starting evaluation.'
+ .format(self.ID))
+ self.eval()
+
+ else: # in the evaluation process
+ # Get all the message & aggregate
+ formatted_eval_res = self.merge_eval_results_from_all_clients()
+ self.history_results = merge_dict(self.history_results,
+ formatted_eval_res)
+ self.check_and_save()
+
+ else:
+ move_on_flag = False
+
+ return move_on_flag
+
+ def check_and_save(self):
+ """
+ To save the results and save model after each evaluation.
+ """
+
+ # early stopping
+ if "Results_weighted_avg" in self.history_results and \
+ self._cfg.eval.best_res_update_round_wise_key in self.history_results['Results_weighted_avg']:
+ should_stop = self.early_stopper.track_and_check(
+ self.history_results['Results_weighted_avg'][
+ self._cfg.eval.best_res_update_round_wise_key])
+ elif "Results_avg" in self.history_results and \
+ self._cfg.eval.best_res_update_round_wise_key in self.history_results['Results_avg']:
+ should_stop = self.early_stopper.track_and_check(
+ self.history_results['Results_avg'][
+ self._cfg.eval.best_res_update_round_wise_key])
+ else:
+ should_stop = False
+
+ if should_stop:
+ self._monitor.global_converged()
+ self.comm_manager.send(
+ Message(
+ msg_type="converged",
+ sender=self.ID,
+ receiver=list(self.comm_manager.neighbors.keys()),
+ state=self.state,
+ ))
+ self.state = self.total_round_num + 1
+
+ if should_stop or self.state == self.total_round_num:
+ logger.info(
+ 'Server #{:d}: Final evaluation is finished! Starting merging results.'
+ .format(self.ID))
+ # last round or early stopped
+ self.save_best_results()
+ if not self._cfg.federate.make_global_eval:
+ self.save_client_eval_results()
+ self.terminate(msg_type='finish')
+
+ # Clean the clients evaluation msg buffer
+ if not self._cfg.federate.make_global_eval:
+ round = max(self.msg_buffer['eval'].keys())
+ self.msg_buffer['eval'][round].clear()
+
+ if self.state == self.total_round_num:
+ # break out the loop for distributed mode
+ self.state += 1
+
+ def save_best_results(self):
+ """
+        To save the best evaluation results.
+ """
+
+ if self._cfg.federate.save_to != '':
+ self.aggregator.save_model(self._cfg.federate.save_to, self.state)
+ formatted_best_res = self._monitor.format_eval_res(
+ results=self.best_results,
+ rnd="Final",
+ role='Server #',
+ forms=["raw"],
+ return_raw=True)
+ logger.info(formatted_best_res)
+ self.save_formatted_results(formatted_best_res)
+
+ def save_formatted_results(self, formatted_res):
+ with open(os.path.join(self._cfg.outdir, "eval_results.log"),
+ "a") as outfile:
+ outfile.write(str(formatted_res) + "\n")
+
+ def save_client_eval_results(self):
+ """
+ save the evaluation results of each client when the fl course early stopped or terminated
+
+ :return:
+ """
+ round = max(self.msg_buffer['eval'].keys())
+ eval_msg_buffer = self.msg_buffer['eval'][round]
+
+ with open(os.path.join(self._cfg.outdir, "eval_results.log"),
+ "a") as outfile:
+ for client_id, client_eval_results in eval_msg_buffer.items():
+ formatted_res = self._monitor.format_eval_res(
+ client_eval_results,
+ rnd=self.state,
+ role='Client #{}'.format(client_id),
+ return_raw=True)
+ logger.info(formatted_res)
+ outfile.write(str(formatted_res) + "\n")
+
+ def merge_eval_results_from_all_clients(self):
+ """
+ Merge evaluation results from all clients, update best,
+ log the merged results and save them into eval_results.log
+
+ :returns: the formatted merged results
+ """
+
+ round = max(self.msg_buffer['eval'].keys())
+ eval_msg_buffer = self.msg_buffer['eval'][round]
+ metrics_all_clients = dict()
+ for each_client in eval_msg_buffer:
+ client_eval_results = eval_msg_buffer[each_client]
+ for key in client_eval_results.keys():
+ if key not in metrics_all_clients:
+ metrics_all_clients[key] = list()
+ metrics_all_clients[key].append(float(
+ client_eval_results[key]))
+ formatted_logs = self._monitor.format_eval_res(
+ metrics_all_clients,
+ rnd=self.state,
+ role='Server #',
+ forms=self._cfg.eval.report)
+ logger.info(formatted_logs)
+ update_best_result(self.best_results,
+ metrics_all_clients,
+ results_type="client_individual",
+ round_wise_update_key=self._cfg.eval.
+ best_res_update_round_wise_key)
+ self.save_formatted_results(formatted_logs)
+ for form in self._cfg.eval.report:
+ if form != "raw":
+ update_best_result(self.best_results,
+ formatted_logs[f"Results_{form}"],
+ results_type=f"client_summarized_{form}",
+ round_wise_update_key=self._cfg.eval.
+ best_res_update_round_wise_key)
+
+ return formatted_logs
+
+ def broadcast_model_para(self,
+ msg_type='model_para',
+ sample_client_num=-1):
+ """
+ To broadcast the message to all clients or sampled clients
+
+ Arguments:
+ msg_type: 'model_para' or other user defined msg_type
+            sample_client_num: the number of sampled clients in the broadcast behavior.
+                               sample_client_num = -1 denotes broadcasting to all the clients.
+ """
+
+ if sample_client_num > 0: # only activated at training process
+
+ receiver = np.random.choice(np.arange(1, self.client_num + 1),
+ size=sample_client_num,
+ replace=False).tolist()
+
+ else:
+ # broadcast to all clients
+ receiver = list(self.comm_manager.neighbors.keys())
+
+ if self._noise_injector is not None and msg_type == 'model_para':
+ # Inject noise only when broadcast parameters
+ for model_idx_i in range(len(self.models)):
+ num_sample_clients = [
+ v["num_sample"] for v in self.join_in_info.values()
+ ]
+ self._noise_injector(self._cfg, num_sample_clients,
+ self.models[model_idx_i])
+
+ skip_broadcast = self._cfg.federate.method in ["local", "global"]
+ if self.model_num > 1:
+ model_para = [{} if skip_broadcast else model.state_dict()
+ for model in self.models]
+ else:
+ model_para = {} if skip_broadcast else self.model.state_dict()
+
+ self.comm_manager.send(
+ Message(msg_type=msg_type,
+ sender=self.ID,
+ receiver=receiver,
+ state=min(self.state, self.total_round_num),
+ content=model_para))
+ if self._cfg.federate.online_aggr:
+ for idx in range(self.model_num):
+ self.aggregators[idx].reset()
+
+ def broadcast_client_address(self):
+ """
+ To broadcast the communication addresses of clients (used for additive secret sharing)
+ """
+
+ self.comm_manager.send(
+ Message(msg_type='address',
+ sender=self.ID,
+ receiver=list(self.comm_manager.neighbors.keys()),
+ state=self.state,
+ content=self.comm_manager.get_neighbors()))
+
+ def check_buffer(self,
+ cur_round,
+ min_received_num,
+ check_eval_result=False):
+ """
+ To check the message buffer
+
+ Arguments:
+ cur_round (int): The current round number
+            min_received_num (int): The minimal number of received messages
+            check_eval_result (bool): If True, check evaluation results; otherwise, check training results
+ :returns: Whether enough messages have been received or not
+ :rtype: bool
+ """
+
+ if check_eval_result:
+ if 'eval' not in self.msg_buffer.keys() or len(
+ self.msg_buffer['eval'].keys()) == 0:
+ return False
+ buffer = self.msg_buffer['eval']
+ cur_round = max(buffer.keys())
+ else:
+ buffer = self.msg_buffer['train']
+
+ if cur_round not in buffer or len(
+ buffer[cur_round]) < min_received_num:
+ return False
+ else:
+ return True
+
+ def check_client_join_in(self):
+ """
+ To check whether all the clients have joined in the FL course.
+ """
+
+ if len(self._cfg.federate.join_in_info) != 0:
+ return len(self.join_in_info) == self.client_num
+ else:
+ return self.join_in_client_num == self.client_num
+
+ def trigger_for_start(self):
+ """
+        To start the FL course when the expected number of clients has joined
+ """
+
+ if self.check_client_join_in():
+ if self._cfg.federate.use_ss:
+ self.broadcast_client_address()
+ logger.info(
+ '----------- Starting training (Round #{:d}) -------------'.
+ format(self.state))
+ self.broadcast_model_para(msg_type='model_para',
+ sample_client_num=self.sample_client_num)
+
+ def terminate(self, msg_type='finish'):
+ """
+ To terminate the FL course
+ """
+ if self.model_num > 1:
+ model_para = [model.state_dict() for model in self.models]
+ else:
+ model_para = self.model.state_dict()
+
+ self._monitor.finish_fl()
+
+ self.comm_manager.send(
+ Message(msg_type=msg_type,
+ sender=self.ID,
+ receiver=list(self.comm_manager.neighbors.keys()),
+ state=self.state,
+ content=model_para))
+
+ def eval(self):
+ """
+ To conduct evaluation. When cfg.federate.make_global_eval=True, a global evaluation is conducted by the server.
+ """
+
+ if self._cfg.federate.make_global_eval:
+ # By default, the evaluation is conducted one-by-one for all internal models;
+ # for other cases such as ensemble, override the eval function
+ for i in range(self.model_num):
+ trainer = self.trainers[i]
+            # Perform evaluation on the server
+ metrics = {}
+ for split in self._cfg.eval.split:
+ eval_metrics = trainer.evaluate(
+ target_data_split_name=split)
+ metrics.update(**eval_metrics)
+ formatted_eval_res = self._monitor.format_eval_res(
+ metrics,
+ rnd=self.state,
+ role='Server #',
+ forms=self._cfg.eval.report,
+ return_raw=self._cfg.federate.make_global_eval)
+ update_best_result(self.best_results,
+ formatted_eval_res['Results_raw'],
+ results_type="server_global_eval",
+ round_wise_update_key=self._cfg.eval.
+ best_res_update_round_wise_key)
+ self.history_results = merge_dict(self.history_results,
+ formatted_eval_res)
+ self.save_formatted_results(formatted_eval_res)
+ logger.info(formatted_eval_res)
+ self.check_and_save()
+ else:
+            # Perform evaluation on the clients
+ self.broadcast_model_para(msg_type='evaluate')
+
+ def callback_funcs_model_para(self, message: Message):
+ """
+ The handling function for receiving model parameters, which triggers check_and_move_on
+ (perform aggregation when enough feedback has been received).
+ This handling function is widely used in various FL courses.
+
+ Arguments:
+ message: The received message, which includes sender, receiver, state, and content.
+ More detail can be found in federatedscope.core.message
+ """
+
+ round, sender, content = message.state, message.sender, message.content
+ # For a new round
+ if round not in self.msg_buffer['train'].keys():
+ self.msg_buffer['train'][round] = dict()
+
+ self.msg_buffer['train'][round][sender] = content
+
+ if self._cfg.federate.online_aggr:
+ self.aggregator.inc(content)
+
+ return self.check_and_move_on()
+
+ def callback_funcs_for_join_in(self, message: Message):
+ """
+        The handling function for receiving the join-in information. The server might request some information
+        (such as num_of_samples) if necessary, and assigns IDs to the clients.
+        If all the clients have joined in, the training process will be triggered.
+
+ Arguments:
+ message: The received message
+ """
+
+ if 'info' in message.msg_type:
+ sender, info = message.sender, message.content
+ for key in self._cfg.federate.join_in_info:
+ assert key in info
+ self.join_in_info[sender] = info
+        logger.info('Server #{:d}: Client #{:d} has joined in!'.format(
+            self.ID, sender))
+ else:
+ self.join_in_client_num += 1
+ sender, address = message.sender, message.content
+ if int(sender) == -1: # assign number to client
+ sender = self.join_in_client_num
+ self.comm_manager.add_neighbors(neighbor_id=sender,
+ address=address)
+ self.comm_manager.send(
+ Message(msg_type='assign_client_id',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=str(sender)))
+ else:
+ self.comm_manager.add_neighbors(neighbor_id=sender,
+ address=address)
+
+ if len(self._cfg.federate.join_in_info) != 0:
+ self.comm_manager.send(
+ Message(msg_type='ask_for_join_in_info',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=self._cfg.federate.join_in_info.copy()))
+
+ self.trigger_for_start()
+
+ def callback_funcs_for_metrics(self, message: Message):
+ """
+ The handling function for receiving the evaluation results, which triggers check_and_move_on
+ (perform aggregation when enough feedback has been received).
+
+ Arguments:
+ message: The received message
+ """
+
+ round, sender, content = message.state, message.sender, message.content
+
+ if round not in self.msg_buffer['eval'].keys():
+ self.msg_buffer['eval'][round] = dict()
+
+ self.msg_buffer['eval'][round][sender] = content
+
+ return self.check_and_move_on(check_eval_result=True)
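Taken together, `callback_funcs_model_para` and `check_buffer` implement a simple synchronization barrier: feedback is keyed by round and then by sender, and aggregation moves on only once `min_received_num` senders have reported for the current round. Below is a minimal standalone sketch of that gating logic; `receive_train_feedback` is an illustrative stand-in, not part of the patch.

# A toy model of the buffering in callback_funcs_model_para/check_buffer.
msg_buffer = {'train': {}, 'eval': {}}

def receive_train_feedback(rnd, sender, content, min_received_num=2):
    # Feedback is keyed by round, then by sender, exactly as above.
    msg_buffer['train'].setdefault(rnd, {})[sender] = content
    # Aggregation may proceed only once enough clients have reported back.
    return len(msg_buffer['train'][rnd]) >= min_received_num

assert not receive_train_feedback(rnd=0, sender=1, content='delta_1')
assert receive_train_feedback(rnd=0, sender=2, content='delta_2')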
diff --git a/federatedscope/cross_backends/__init__.py b/federatedscope/cross_backends/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/cross_backends/tf_aggregator.py b/federatedscope/cross_backends/tf_aggregator.py
new file mode 100644
index 000000000..6c59fb335
--- /dev/null
+++ b/federatedscope/cross_backends/tf_aggregator.py
@@ -0,0 +1,44 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from copy import deepcopy
+import numpy as np
+
+
+class FedAvgAggregator(object):
+ def __init__(self, model=None, device='cpu'):
+ self.model = model
+ self.device = device
+
+ def aggregate(self, agg_info):
+ models = agg_info["client_feedback"]
+ avg_model = self._para_weighted_avg(models)
+
+ return avg_model
+
+ def _para_weighted_avg(self, models):
+
+ training_set_size = 0
+ for i in range(len(models)):
+ sample_size, _ = models[i]
+ training_set_size += sample_size
+
+ sample_size, avg_model = models[0]
+ for key in avg_model:
+ for i in range(len(models)):
+ local_sample_size, local_model = models[i]
+ weight = local_sample_size / training_set_size
+ if i == 0:
+ avg_model[key] = np.asarray(local_model[key]) * weight
+ else:
+ avg_model[key] += np.asarray(local_model[key]) * weight
+
+ return avg_model
+
+ def update(self, model_parameters):
+ '''
+ Arguments:
+ model_parameters (dict): PyTorch Module object's state_dict.
+ '''
+ self.model.load_state_dict(model_parameters)
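For reference, `_para_weighted_avg` computes a sample-size-weighted average of the client parameters: each client's weight is its sample count divided by the total. A small usage sketch with toy numpy parameters (the key 'fc.weight' is an arbitrary example name):

import numpy as np
from federatedscope.cross_backends.tf_aggregator import FedAvgAggregator

agg = FedAvgAggregator()
client_feedback = [
    (10, {'fc.weight': np.ones((2, 2))}),      # 10 local samples
    (30, {'fc.weight': 3 * np.ones((2, 2))}),  # 30 local samples
]
avg = agg.aggregate({'client_feedback': client_feedback})
# Weights are 10/40 = 0.25 and 30/40 = 0.75, so 0.25 * 1 + 0.75 * 3 = 2.5
assert np.allclose(avg['fc.weight'], 2.5)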
diff --git a/federatedscope/cross_backends/tf_lr.py b/federatedscope/cross_backends/tf_lr.py
new file mode 100644
index 000000000..2777562c5
--- /dev/null
+++ b/federatedscope/cross_backends/tf_lr.py
@@ -0,0 +1,81 @@
+import tensorflow as tf
+import numpy as np
+
+
+class LogisticRegression(object):
+ def __init__(self, in_channels, class_num, use_bias=True):
+
+ self.input_x = tf.placeholder(tf.float32, [None, in_channels],
+ name='input_x')
+ self.input_y = tf.placeholder(tf.float32, [None, 1], name='input_y')
+
+ self.out = self.fc_layer(input_x=self.input_x,
+ in_channels=in_channels,
+ class_num=class_num,
+ use_bias=use_bias)
+
+ with tf.name_scope('loss'):
+ self.losses = tf.losses.mean_squared_error(predictions=self.out,
+ labels=self.input_y)
+
+ with tf.name_scope('train_op'):
+ self.optimizer = tf.train.GradientDescentOptimizer(
+ learning_rate=0.001)
+ self.train_op = self.optimizer.minimize(self.losses)
+
+ self.sess = tf.Session()
+ self.graph = tf.get_default_graph()
+
+ with self.graph.as_default():
+ with self.sess.as_default():
+ tf.global_variables_initializer().run()
+
+ def fc_layer(self, input_x, in_channels, class_num, use_bias=True):
+ with tf.name_scope('fc'):
+ fc_w = tf.Variable(tf.truncated_normal([in_channels, class_num],
+ stddev=0.1),
+ name='weight')
+ if use_bias:
+ fc_b = tf.Variable(tf.constant(0.0, shape=[
+ class_num,
+ ]),
+ name='bias')
+ fc_out = tf.nn.bias_add(tf.matmul(input_x, fc_w), fc_b)
+ else:
+ fc_out = tf.matmul(input_x, fc_w)
+
+ return fc_out
+
+ def to(self, device):
+ pass
+
+ def trainable_variables(self):
+ return tf.trainable_variables()
+
+ def state_dict(self):
+ with self.graph.as_default():
+ with self.sess.as_default():
+ model_param = list()
+ param_name = list()
+ for var in tf.global_variables():
+ param = self.graph.get_tensor_by_name(var.name).eval()
+ if 'weight' in var.name:
+ param = np.transpose(param, (1, 0))
+ model_param.append(param)
+ param_name.append(var.name.split(':')[0].replace("/", '.'))
+
+ model_dict = {k: v for k, v in zip(param_name, model_param)}
+
+ return model_dict
+
+ def load_state_dict(self, model_para, strict=False):
+ with self.graph.as_default():
+ with self.sess.as_default():
+ for name in model_para.keys():
+ new_param = model_para[name]
+
+ param = self.graph.get_tensor_by_name(
+ name.replace('.', '/') + (':0'))
+ if 'weight' in name:
+ new_param = np.transpose(new_param, (1, 0))
+ tf.assign(param, new_param).eval()
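The transposes in `state_dict()` and `load_state_dict()` exist because TensorFlow stores a dense kernel as [in_channels, class_num], while the PyTorch-style state dicts exchanged with the server use [out, in]. A tiny numpy sketch of that round-trip convention (shapes chosen arbitrarily):

import numpy as np

tf_kernel = np.arange(6).reshape(3, 2)         # TF layout: [in=3, out=2]
torch_style = np.transpose(tf_kernel, (1, 0))  # PyTorch layout: [out=2, in=3]
assert torch_style.shape == (2, 3)
assert np.array_equal(np.transpose(torch_style, (1, 0)), tf_kernel)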
diff --git a/federatedscope/cv/__init__.py b/federatedscope/cv/__init__.py
new file mode 100644
index 000000000..f8e91f237
--- /dev/null
+++ b/federatedscope/cv/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
diff --git a/federatedscope/cv/dataloader/__init__.py b/federatedscope/cv/dataloader/__init__.py
new file mode 100644
index 000000000..1b4d542c4
--- /dev/null
+++ b/federatedscope/cv/dataloader/__init__.py
@@ -0,0 +1,3 @@
+from federatedscope.cv.dataloader.dataloader import load_cv_dataset
+
+__all__ = ['load_cv_dataset']
\ No newline at end of file
diff --git a/federatedscope/cv/dataloader/dataloader.py b/federatedscope/cv/dataloader/dataloader.py
new file mode 100644
index 000000000..706d22c34
--- /dev/null
+++ b/federatedscope/cv/dataloader/dataloader.py
@@ -0,0 +1,83 @@
+from torch.utils.data import DataLoader, Dataset
+
+from federatedscope.cv.dataset.leaf_cv import LEAF_CV
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+
+
+class TensorDataset(Dataset):
+    '''
+    A dataset wrapping a list of (sample, target) tuples.
+    '''
+ def __init__(self, tuple_tensor):
+ self.data_tensor = tuple_tensor
+
+ def __getitem__(self, index):
+ sample, target = self.data_tensor[index]
+ return sample, target
+
+ def __len__(self):
+ return len(self.data_tensor)
+
+
+def load_cv_dataset(config=None):
+ r"""
+ return {
+ 'client_id': {
+ 'train': DataLoader(),
+ 'test': DataLoader(),
+ 'val': DataLoader()
+ }
+ }
+ or return
+ dataset
+ """
+ splits = config.data.splits
+
+ path = config.data.root
+ name = config.data.type.lower()
+ batch_size = config.data.batch_size
+ transforms_funcs = get_transform(config, 'torchvision')
+
+ if name in ['femnist', 'celeba']:
+ dataset = LEAF_CV(root=path,
+ name=name,
+ s_frac=config.data.subsample,
+ tr_frac=splits[0],
+ val_frac=splits[1],
+ seed=1234,
+ **transforms_funcs)
+ else:
+ raise ValueError(f'No dataset named: {name}!')
+
+ client_num = min(len(dataset), config.federate.client_num
+ ) if config.federate.client_num > 0 else len(dataset)
+ config.merge_from_list(['federate.client_num', client_num])
+
+ # get local dataset
+ data_local_dict = dict()
+ for client_idx in range(client_num):
+
+ dataloader = {
+ 'train': DataLoader(dataset[client_idx]['train'],
+ batch_size,
+ shuffle=config.data.shuffle,
+ num_workers=config.data.num_workers),
+ 'test': DataLoader(dataset[client_idx]['test'],
+ batch_size,
+ shuffle=False,
+ num_workers=config.data.num_workers)
+ }
+ if 'val' in dataset[client_idx]:
+ dataloader['val'] = DataLoader(dataset[client_idx]['val'],
+ batch_size,
+ shuffle=False,
+ num_workers=config.data.num_workers)
+
+ data_local_dict[client_idx + 1] = dataloader
+    # We can return two forms, dataset or dataloader, based on the needs of users.
+
+ return data_local_dict, config
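A quick usage sketch of the `TensorDataset` wrapper defined above, fed into a standard `DataLoader`; the FEMNIST-like shapes below are only an example:

import torch
from torch.utils.data import DataLoader
from federatedscope.cv.dataloader.dataloader import TensorDataset

# A list of (sample, target) tuples, as TensorDataset expects.
samples = [(torch.randn(1, 28, 28), torch.tensor(i % 10)) for i in range(8)]
loader = DataLoader(TensorDataset(samples), batch_size=4, shuffle=False)
for x, y in loader:
    assert x.shape == (4, 1, 28, 28) and y.shape == (4,)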
diff --git a/federatedscope/cv/dataset/__init__.py b/federatedscope/cv/dataset/__init__.py
new file mode 100644
index 000000000..42638817a
--- /dev/null
+++ b/federatedscope/cv/dataset/__init__.py
@@ -0,0 +1,8 @@
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
\ No newline at end of file
diff --git a/federatedscope/cv/dataset/leaf.py b/federatedscope/cv/dataset/leaf.py
new file mode 100644
index 000000000..bd1c70b2a
--- /dev/null
+++ b/federatedscope/cv/dataset/leaf.py
@@ -0,0 +1,86 @@
+import zipfile
+import os
+import os.path as osp
+
+from torch.utils.data import Dataset
+
+LEAF_NAMES = [
+ 'femnist', 'celeba', 'synthetic', 'shakespeare', 'twitter', 'subreddit'
+]
+
+
+def is_exists(path, names):
+ exists_list = [osp.exists(osp.join(path, name)) for name in names]
+ return False not in exists_list
+
+
+class LEAF(Dataset):
+ """Base class for LEAF dataset from "LEAF: A Benchmark for Federated Settings"
+
+ Arguments:
+ root (str): root path.
+ name (str): name of dataset, in `LEAF_NAMES`.
+ transform: transform for x.
+ target_transform: transform for y.
+
+ """
+ def __init__(self, root, name, transform, target_transform):
+ self.root = root
+ self.name = name
+ self.data_dict = {}
+ if name not in LEAF_NAMES:
+ raise ValueError(f'No leaf dataset named {self.name}')
+ self.transform = transform
+ self.target_transform = target_transform
+ self.process_file()
+
+ @property
+ def raw_file_names(self):
+ names = ['all_data.zip']
+ return names
+
+ @property
+ def extracted_file_names(self):
+ names = ['all_data']
+ return names
+
+ @property
+ def raw_dir(self):
+ return osp.join(self.root, self.name, 'raw')
+
+ @property
+ def processed_dir(self):
+ return osp.join(self.root, self.name, 'processed')
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}({self.__len__()})'
+
+ def __len__(self):
+ return len(self.data_dict)
+
+ def __getitem__(self, index):
+ raise NotImplementedError
+
+ def __iter__(self):
+ for index in range(len(self.data_dict)):
+ yield self.__getitem__(index)
+
+ def download(self):
+ raise NotImplementedError
+
+ def extract(self):
+ for name in self.raw_file_names:
+ with zipfile.ZipFile(osp.join(self.raw_dir, name), 'r') as f:
+ f.extractall(self.raw_dir)
+
+ def process_file(self):
+ os.makedirs(self.processed_dir, exist_ok=True)
+ if len(os.listdir(self.processed_dir)) == 0:
+ if not is_exists(self.raw_dir, self.extracted_file_names):
+ if not is_exists(self.raw_dir, self.raw_file_names):
+ self.download()
+ self.extract()
+ self.process()
+
+ def process(self):
+ raise NotImplementedError
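`LEAF` is an abstract base: `download()`, `process()`, and `__getitem__` are left to subclasses, while `process_file()` orchestrates the download/extract/process pipeline. A minimal, hypothetical subclass sketch (`MyLeaf` and its method bodies are placeholders for illustration, not part of the patch):

from federatedscope.cv.dataset.leaf import LEAF

class MyLeaf(LEAF):
    def download(self):
        # Fetch self.raw_file_names into self.raw_dir (e.g. via download_url).
        raise NotImplementedError('supply a mirror for all_data.zip')

    def process(self):
        # Populate self.data_dict[task_id] = {'train': ..., 'test': ...}.
        self.data_dict[0] = {'train': ([], []), 'test': ([], [])}

    def __getitem__(self, index):
        return self.data_dict[index]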
diff --git a/federatedscope/cv/dataset/leaf_cv.py b/federatedscope/cv/dataset/leaf_cv.py
new file mode 100644
index 000000000..f1be136e0
--- /dev/null
+++ b/federatedscope/cv/dataset/leaf_cv.py
@@ -0,0 +1,179 @@
+import os
+import random
+import json
+import torch
+import math
+
+import numpy as np
+import os.path as osp
+
+from PIL import Image
+from tqdm import tqdm
+
+from sklearn.model_selection import train_test_split
+
+from federatedscope.core.auxiliaries.utils import save_local_data, download_url
+from federatedscope.cv.dataset.leaf import LEAF
+
+IMAGE_SIZE = {'femnist': (28, 28), 'celeba': (84, 84, 3)}
+MODE = {'femnist': 'L', 'celeba': 'RGB'}
+
+
+class LEAF_CV(LEAF):
+ """
+ LEAF CV dataset from "LEAF: A Benchmark for Federated Settings"
+
+ leaf.cmu.edu
+
+ Arguments:
+ root (str): root path.
+        name (str): name of dataset, 'femnist' or 'celeba'.
+ s_frac (float): fraction of the dataset to be used; default=0.3.
+ tr_frac (float): train set proportion for each task; default=0.8.
+ val_frac (float): valid set proportion for each task; default=0.0.
+        train_tasks_frac (float): fraction of tasks used for training; default=1.0.
+ transform: transform for x.
+ target_transform: transform for y.
+
+ """
+ def __init__(self,
+ root,
+ name,
+ s_frac=0.3,
+ tr_frac=0.8,
+ val_frac=0.0,
+ train_tasks_frac=1.0,
+ seed=123,
+ transform=None,
+ target_transform=None):
+ self.s_frac = s_frac
+ self.tr_frac = tr_frac
+ self.val_frac = val_frac
+ self.seed = seed
+ self.train_tasks_frac = train_tasks_frac
+ super(LEAF_CV, self).__init__(root, name, transform, target_transform)
+ files = os.listdir(self.processed_dir)
+ files = [f for f in files if f.startswith('task_')]
+ if len(files):
+ # Sort by idx
+ files.sort(key=lambda k: int(k[5:]))
+
+ for file in files:
+ train_data, train_targets = torch.load(
+ osp.join(self.processed_dir, file, 'train.pt'))
+ test_data, test_targets = torch.load(
+ osp.join(self.processed_dir, file, 'test.pt'))
+ self.data_dict[int(file[5:])] = {
+ 'train': (train_data, train_targets),
+ 'test': (test_data, test_targets)
+ }
+ if osp.exists(osp.join(self.processed_dir, file, 'val.pt')):
+ val_data, val_targets = torch.load(
+ osp.join(self.processed_dir, file, 'val.pt'))
+ self.data_dict[int(file[5:])]['val'] = (val_data,
+ val_targets)
+ else:
+            raise RuntimeError(
+                'Please delete the "processed" folder and try again!')
+
+ @property
+ def raw_file_names(self):
+ names = [f'{self.name}_all_data.zip']
+ return names
+
+ def download(self):
+ # Download to `self.raw_dir`.
+ url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com'
+ os.makedirs(self.raw_dir, exist_ok=True)
+ for name in self.raw_file_names:
+ download_url(f'{url}/{name}', self.raw_dir)
+
+ def __getitem__(self, index):
+ """
+ Arguments:
+ index (int): Index
+
+ :returns:
+ dict: {'train':[(image, target)],
+ 'test':[(image, target)],
+ 'val':[(image, target)]}
+ where target is the target class.
+ """
+ img_dict = {}
+ data = self.data_dict[index]
+ for key in data:
+ img_dict[key] = []
+ imgs, targets = data[key]
+ for idx in range(targets.shape[0]):
+ img = np.resize(imgs[idx].numpy().astype(np.uint8),
+ IMAGE_SIZE[self.name])
+ img = Image.fromarray(img, mode=MODE[self.name])
+ target = targets[idx]
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+                img_dict[key].append((img, target))
+
+ return img_dict
+
+ def process(self):
+ raw_path = osp.join(self.raw_dir, "all_data")
+ files = os.listdir(raw_path)
+ files = [f for f in files if f.endswith('.json')]
+
+ n_tasks = math.ceil(len(files) * self.s_frac)
+ random.shuffle(files)
+ files = files[:n_tasks]
+
+        print("Preprocessing data (make sure there is enough disk space)...")
+
+ idx = 0
+        for file in tqdm(files):
+
+ with open(osp.join(raw_path, file), 'r') as f:
+ raw_data = json.load(f)
+
+ # Numpy to Tensor
+ for writer, v in raw_data['user_data'].items():
+ data, targets = v['x'], v['y']
+
+ if len(v['x']) > 2:
+ data = torch.tensor(np.stack(data))
+ targets = torch.LongTensor(np.stack(targets))
+ else:
+ data = torch.tensor(data)
+ targets = torch.LongTensor(targets)
+
+ train_data, test_data, train_targets, test_targets =\
+ train_test_split(
+ data,
+ targets,
+ train_size=self.tr_frac,
+ random_state=self.seed
+ )
+
+ if self.val_frac > 0:
+ val_data, test_data, val_targets, test_targets = \
+ train_test_split(
+ test_data,
+ test_targets,
+ train_size=self.val_frac / (1.-self.tr_frac),
+ random_state=self.seed
+ )
+
+ else:
+ val_data, val_targets = None, None
+ save_path = osp.join(self.processed_dir, f"task_{idx}")
+ os.makedirs(save_path, exist_ok=True)
+
+ save_local_data(dir_path=save_path,
+ train_data=train_data,
+ train_targets=train_targets,
+ test_data=test_data,
+ test_targets=test_targets,
+ val_data=val_data,
+ val_targets=val_targets)
+ idx += 1
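Note how `process()` rescales the validation fraction: `val_frac` refers to the whole task, so after the first split has taken `tr_frac`, the held-out remainder is split with train_size = val_frac / (1 - tr_frac). A short sketch of the arithmetic with assumed fractions 0.8/0.1:

from sklearn.model_selection import train_test_split

data, targets = list(range(100)), [i % 2 for i in range(100)]
tr_frac, val_frac = 0.8, 0.1
train_x, rest_x, train_y, rest_y = train_test_split(
    data, targets, train_size=tr_frac, random_state=1234)
# Rescale val_frac w.r.t. the remaining (1 - tr_frac) of the task.
val_x, test_x, val_y, test_y = train_test_split(
    rest_x, rest_y, train_size=val_frac / (1. - tr_frac), random_state=1234)
assert (len(train_x), len(val_x), len(test_x)) == (80, 10, 10)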
diff --git a/federatedscope/cv/model/__init__.py b/federatedscope/cv/model/__init__.py
new file mode 100644
index 000000000..7144c1a88
--- /dev/null
+++ b/federatedscope/cv/model/__init__.py
@@ -0,0 +1,8 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.cv.model.cnn import ConvNet2, ConvNet5, VGG11
+from federatedscope.cv.model.model_builder import get_cnn
+
+__all__ = ['ConvNet2', 'ConvNet5', 'VGG11', 'get_cnn']
\ No newline at end of file
diff --git a/federatedscope/cv/model/cnn.py b/federatedscope/cv/model/cnn.py
new file mode 100644
index 000000000..595fef772
--- /dev/null
+++ b/federatedscope/cv/model/cnn.py
@@ -0,0 +1,204 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.nn import Module
+from torch.nn import Sequential
+from torch.nn import Conv2d, BatchNorm2d
+from torch.nn import Flatten
+from torch.nn import Linear
+from torch.nn import MaxPool2d
+from torch.nn import ReLU
+
+
+class ConvNet2(Module):
+ def __init__(self,
+ in_channels,
+ h=32,
+ w=32,
+ hidden=2048,
+ class_num=10,
+ use_bn=True,
+ dropout=.0):
+ super(ConvNet2, self).__init__()
+
+ self.conv1 = Conv2d(in_channels, 32, 5, padding=2)
+ self.conv2 = Conv2d(32, 64, 5, padding=2)
+ self.use_bn = use_bn
+ if use_bn:
+ self.bn1 = BatchNorm2d(32)
+ self.bn2 = BatchNorm2d(64)
+
+ self.fc1 = Linear((h // 2 // 2) * (w // 2 // 2) * 64, hidden)
+ self.fc2 = Linear(hidden, class_num)
+
+ self.relu = ReLU(inplace=True)
+ self.maxpool = MaxPool2d(2)
+ self.dropout = dropout
+
+ def feature(self, x):
+ x = self.bn1(self.conv1(x)) if self.use_bn else self.conv1(x)
+ x = self.maxpool(self.relu(x))
+ x = self.bn2(self.conv2(x)) if self.use_bn else self.conv2(x)
+ x = self.maxpool(self.relu(x))
+ x = Flatten()(x)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.relu(self.fc1(x))
+
+ return x
+
+ def forward(self, x):
+ x = self.bn1(self.conv1(x)) if self.use_bn else self.conv1(x)
+ x = self.maxpool(self.relu(x))
+ x = self.bn2(self.conv2(x)) if self.use_bn else self.conv2(x)
+ x = self.maxpool(self.relu(x))
+ x = Flatten()(x)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.relu(self.fc1(x))
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.fc2(x)
+
+ return x
+
+
+class ConvNet5(Module):
+ def __init__(self,
+ in_channels,
+ h=32,
+ w=32,
+ hidden=2048,
+ class_num=10,
+ dropout=.0):
+ super(ConvNet5, self).__init__()
+
+ self.conv1 = Conv2d(in_channels, 32, 5, padding=2)
+ self.bn1 = BatchNorm2d(32)
+
+ self.conv2 = Conv2d(32, 64, 5, padding=2)
+ self.bn2 = BatchNorm2d(64)
+
+ self.conv3 = Conv2d(64, 64, 5, padding=2)
+ self.bn3 = BatchNorm2d(64)
+
+ self.conv4 = Conv2d(64, 128, 5, padding=2)
+ self.bn4 = BatchNorm2d(128)
+
+ self.conv5 = Conv2d(128, 128, 5, padding=2)
+ self.bn5 = BatchNorm2d(128)
+
+ self.relu = ReLU(inplace=True)
+ self.maxpool = MaxPool2d(2)
+
+ self.fc1 = Linear(
+ (h // 2 // 2 // 2 // 2 // 2) * (w // 2 // 2 // 2 // 2 // 2) * 128,
+ hidden)
+ self.fc2 = Linear(hidden, class_num)
+
+ self.dropout = dropout
+
+ def forward(self, x):
+ x = self.relu(self.bn1(self.conv1(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn2(self.conv2(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn4(self.conv4(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn5(self.conv5(x)))
+ x = self.maxpool(x)
+
+ x = Flatten()(x)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.relu(self.fc1(x))
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.fc2(x)
+
+ return x
+
+
+class VGG11(Module):
+ def __init__(self,
+ in_channels,
+ h=32,
+ w=32,
+ hidden=128,
+ class_num=10,
+ dropout=.0):
+ super(VGG11, self).__init__()
+
+        # VGG-11 layer configuration for reference:
+        # [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
+
+ self.conv1 = Conv2d(in_channels, 64, 3, padding=1)
+ self.bn1 = BatchNorm2d(64)
+
+ self.conv2 = Conv2d(64, 128, 3, padding=1)
+ self.bn2 = BatchNorm2d(128)
+
+ self.conv3 = Conv2d(128, 256, 3, padding=1)
+ self.bn3 = BatchNorm2d(256)
+
+ self.conv4 = Conv2d(256, 256, 3, padding=1)
+ self.bn4 = BatchNorm2d(256)
+
+ self.conv5 = Conv2d(256, 512, 3, padding=1)
+ self.bn5 = BatchNorm2d(512)
+
+ self.conv6 = Conv2d(512, 512, 3, padding=1)
+ self.bn6 = BatchNorm2d(512)
+
+ self.conv7 = Conv2d(512, 512, 3, padding=1)
+ self.bn7 = BatchNorm2d(512)
+
+ self.conv8 = Conv2d(512, 512, 3, padding=1)
+ self.bn8 = BatchNorm2d(512)
+
+ self.relu = ReLU(inplace=True)
+ self.maxpool = MaxPool2d(2)
+
+ self.fc1 = Linear(
+ (h // 2 // 2 // 2 // 2 // 2) * (w // 2 // 2 // 2 // 2 // 2) * 512,
+ hidden)
+ self.fc2 = Linear(hidden, hidden)
+ self.fc3 = Linear(hidden, class_num)
+
+ self.dropout = dropout
+
+ def forward(self, x):
+ x = self.relu(self.bn1(self.conv1(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn2(self.conv2(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn4(self.conv4(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn5(self.conv5(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn6(self.conv6(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn7(self.conv7(x)))
+ x = self.maxpool(x)
+
+ x = self.relu(self.bn8(self.conv8(x)))
+ x = self.maxpool(x)
+
+ x = Flatten()(x)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.relu(self.fc1(x))
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.relu(self.fc2(x))
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.fc3(x)
+
+ return x
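A minimal shape check for `ConvNet2` as defined above: each of the two conv blocks halves H and W via `MaxPool2d(2)`, so `fc1` expects (h // 4) * (w // 4) * 64 inputs. The FEMNIST-like sizes below are just an example:

import torch
from federatedscope.cv.model.cnn import ConvNet2

model = ConvNet2(in_channels=1, h=28, w=28, hidden=2048, class_num=62)
out = model(torch.randn(4, 1, 28, 28))  # fc1 sees 7 * 7 * 64 = 3136 features
assert out.shape == (4, 62)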
diff --git a/federatedscope/cv/model/model_builder.py b/federatedscope/cv/model/model_builder.py
new file mode 100644
index 000000000..2d45e8c07
--- /dev/null
+++ b/federatedscope/cv/model/model_builder.py
@@ -0,0 +1,49 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.cv.model.cnn import ConvNet2, ConvNet5, VGG11
+
+
+def get_cnn(model_config, local_data):
+ if isinstance(local_data, dict):
+ if 'data' in local_data.keys():
+ data = local_data['data']
+ elif 'train' in local_data.keys():
+            # local_data['train'] is a DataLoader
+ data = next(iter(local_data['train']))
+ else:
+ raise TypeError('Unsupported data type.')
+ else:
+ data = local_data
+
+ x, _ = data
+    # Build the model according to the configured type
+ if model_config.type == 'convnet2':
+ model = ConvNet2(in_channels=x.shape[-3],
+ h=x.shape[-2],
+ w=x.shape[-1],
+ hidden=model_config.hidden,
+ class_num=model_config.out_channels,
+ dropout=model_config.dropout)
+ elif model_config.type == 'convnet5':
+ model = ConvNet5(in_channels=x.shape[-3],
+ h=x.shape[-2],
+ w=x.shape[-1],
+ hidden=model_config.hidden,
+ class_num=model_config.out_channels,
+ dropout=model_config.dropout)
+ elif model_config.type == 'vgg11':
+ model = VGG11(in_channels=x.shape[-3],
+ h=x.shape[-2],
+ w=x.shape[-1],
+ hidden=model_config.hidden,
+ class_num=model_config.out_channels,
+ dropout=model_config.dropout)
+ else:
+ raise ValueError(f'No model named {model_config.type}!')
+
+ return model
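A usage sketch for `get_cnn`; `SimpleNamespace` is a stand-in for the real yacs config node and carries only the fields `get_cnn` reads (an assumption made for illustration):

import torch
from types import SimpleNamespace
from federatedscope.cv.model.model_builder import get_cnn

model_config = SimpleNamespace(type='convnet2', hidden=2048,
                               out_channels=62, dropout=0.0)
# get_cnn infers in_channels/h/w from a sample batch under the 'data' key.
batch = (torch.randn(4, 1, 28, 28), torch.zeros(4, dtype=torch.long))
model = get_cnn(model_config, {'data': batch})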
diff --git a/federatedscope/cv/trainer/__init__.py b/federatedscope/cv/trainer/__init__.py
new file mode 100644
index 000000000..8ec5ce7cd
--- /dev/null
+++ b/federatedscope/cv/trainer/__init__.py
@@ -0,0 +1,30 @@
+"""
+Copyright (c) 2021 Matthias Fey, Jiaxuan You
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
diff --git a/federatedscope/cv/trainer/trainer.py b/federatedscope/cv/trainer/trainer.py
new file mode 100644
index 000000000..b00ebd66a
--- /dev/null
+++ b/federatedscope/cv/trainer/trainer.py
@@ -0,0 +1,15 @@
+from federatedscope.register import register_trainer
+from federatedscope.core.trainers import GeneralTorchTrainer
+
+
+class CVTrainer(GeneralTorchTrainer):
+ pass
+
+
+def call_cv_trainer(trainer_type):
+ if trainer_type == 'cvtrainer':
+ trainer_builder = CVTrainer
+ return trainer_builder
+
+
+register_trainer('cvtrainer', call_cv_trainer)
diff --git a/federatedscope/gfl/__init__.py b/federatedscope/gfl/__init__.py
new file mode 100644
index 000000000..f8e91f237
--- /dev/null
+++ b/federatedscope/gfl/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
diff --git a/federatedscope/gfl/baseline/__init__.py b/federatedscope/gfl/baseline/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/gfl/dataloader/__init__.py b/federatedscope/gfl/dataloader/__init__.py
new file mode 100644
index 000000000..31a526c3c
--- /dev/null
+++ b/federatedscope/gfl/dataloader/__init__.py
@@ -0,0 +1,8 @@
+from federatedscope.gfl.dataloader.dataloader_node import load_nodelevel_dataset
+from federatedscope.gfl.dataloader.dataloader_graph import load_graphlevel_dataset
+from federatedscope.gfl.dataloader.dataloader_link import load_linklevel_dataset
+
+__all__ = [
+ 'load_nodelevel_dataset', 'load_graphlevel_dataset',
+ 'load_linklevel_dataset'
+]
diff --git a/federatedscope/gfl/dataloader/dataloader_graph.py b/federatedscope/gfl/dataloader/dataloader_graph.py
new file mode 100644
index 000000000..13b260e54
--- /dev/null
+++ b/federatedscope/gfl/dataloader/dataloader_graph.py
@@ -0,0 +1,155 @@
+import numpy as np
+
+from torch_geometric import transforms
+from torch_geometric.loader import DataLoader
+from torch_geometric.datasets import TUDataset, MoleculeNet
+
+from federatedscope.core.auxiliaries.splitter_builder import get_splitter
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+
+
+def get_numGraphLabels(dataset):
+ s = set()
+ for g in dataset:
+ s.add(g.y.item())
+ return len(s)
+
+
+def load_graphlevel_dataset(config=None):
+ r"""Convert dataset to Dataloader.
+ :returns:
+ data_local_dict
+ :rtype: Dict {
+ 'client_id': {
+ 'train': DataLoader(),
+ 'val': DataLoader(),
+ 'test': DataLoader()
+ }
+ }
+ """
+ splits = config.data.splits
+ path = config.data.root
+ name = config.data.type.upper()
+ client_num = config.federate.client_num
+ batch_size = config.data.batch_size
+
+ # Splitter
+ splitter = get_splitter(config)
+
+ # Transforms
+ transforms_funcs = get_transform(config, 'torch_geometric')
+
+ if name in [
+ 'MUTAG', 'BZR', 'COX2', 'DHFR', 'PTC_MR', 'AIDS', 'NCI1',
+ 'ENZYMES', 'DD', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'IMDB-MULTI',
+ 'REDDIT-BINARY'
+ ]:
+        # Add constant features for datasets without node attributes
+ if name in ['IMDB-BINARY', 'IMDB-MULTI'
+ ] and 'pre_transform' not in transforms_funcs:
+ transforms_funcs['pre_transform'] = transforms.Constant(value=1.0,
+ cat=False)
+ dataset = TUDataset(path, name, **transforms_funcs)
+ if splitter is None:
+            raise ValueError(
+                'Please set a splitter for the graph-level dataset.')
+ dataset = splitter(dataset)
+
+ elif name in [
+ 'HIV', 'ESOL', 'FREESOLV', 'LIPO', 'PCBA', 'MUV', 'BACE', 'BBBP',
+ 'TOX21', 'TOXCAST', 'SIDER', 'CLINTOX'
+ ]:
+ dataset = MoleculeNet(path, name, **transforms_funcs)
+ if splitter is None:
+            raise ValueError(
+                'Please set a splitter for the graph-level dataset.')
+ dataset = splitter(dataset)
+ elif name.startswith('graph_multi_domain'.upper()):
+ if name.endswith('mol'.upper()):
+ dnames = ['MUTAG', 'BZR', 'COX2', 'DHFR', 'PTC_MR', 'AIDS', 'NCI1']
+ elif name.endswith('small'.upper()):
+ dnames = [
+ 'MUTAG', 'BZR', 'COX2', 'DHFR', 'PTC_MR', 'ENZYMES', 'DD',
+ 'PROTEINS'
+ ]
+ elif name.endswith('mix'.upper()):
+ if 'pre_transform' not in transforms_funcs:
+                raise ValueError('pre_transform is None!')
+ dnames = [
+ 'MUTAG', 'BZR', 'COX2', 'DHFR', 'PTC_MR', 'AIDS', 'NCI1',
+ 'ENZYMES', 'DD', 'PROTEINS', 'COLLAB', 'IMDB-BINARY',
+ 'IMDB-MULTI'
+ ]
+ elif name.endswith('biochem'.upper()):
+ dnames = [
+ 'MUTAG', 'BZR', 'COX2', 'DHFR', 'PTC_MR', 'AIDS', 'NCI1',
+ 'ENZYMES', 'DD', 'PROTEINS'
+ ]
+ # We provide kddcup dataset here.
+ elif name.endswith('kddcupv1'.upper()):
+ dnames = [
+ 'MUTAG', 'BZR', 'COX2', 'DHFR', 'PTC_MR', 'AIDS', 'NCI1',
+ 'Mutagenicity', 'NCI109', 'PTC_MM', 'PTC_FR'
+ ]
+ elif name.endswith('kddcupv2'.upper()):
+ dnames = ['TBD']
+ else:
+ raise ValueError(f'No dataset named: {name}!')
+ dataset = []
+        # Some datasets already contain node features (x)
+ for dname in dnames:
+ if dname.startswith('IMDB') or dname == 'COLLAB':
+ tmp_dataset = TUDataset(path, dname, **transforms_funcs)
+ else:
+ tmp_dataset = TUDataset(
+ path,
+ dname,
+ pre_transform=None,
+ transform=transforms_funcs['transform']
+ if 'transform' in transforms_funcs else None)
+ dataset.append(tmp_dataset)
+ else:
+ raise ValueError(f'No dataset named: {name}!')
+
+ client_num = min(len(dataset), config.federate.client_num
+ ) if config.federate.client_num > 0 else len(dataset)
+ config.merge_from_list(['federate.client_num', client_num])
+
+ # get local dataset
+ data_local_dict = dict()
+
+ # Build train/valid/test dataloader
+ raw_train = []
+ raw_valid = []
+ raw_test = []
+ for client_idx, gs in enumerate(dataset):
+ index = np.random.permutation(np.arange(len(gs)))
+ train_idx = index[:int(len(gs) * splits[0])]
+ valid_idx = index[int(len(gs) *
+ splits[0]):int(len(gs) * sum(splits[:2]))]
+ test_idx = index[int(len(gs) * sum(splits[:2])):]
+ dataloader = {
+ 'num_label': get_numGraphLabels(gs),
+ 'train': DataLoader([gs[idx] for idx in train_idx],
+ batch_size,
+ shuffle=True,
+ num_workers=config.data.num_workers),
+ 'val': DataLoader([gs[idx] for idx in valid_idx],
+ batch_size,
+ shuffle=False,
+ num_workers=config.data.num_workers),
+ 'test': DataLoader([gs[idx] for idx in test_idx],
+ batch_size,
+ shuffle=False,
+ num_workers=config.data.num_workers),
+ }
+ data_local_dict[client_idx + 1] = dataloader
+ raw_train = raw_train + [gs[idx] for idx in train_idx]
+ raw_valid = raw_valid + [gs[idx] for idx in valid_idx]
+ raw_test = raw_test + [gs[idx] for idx in test_idx]
+ if not name.startswith('graph_multi_domain'.upper()):
+ data_local_dict[0] = {
+ 'train': DataLoader(raw_train, batch_size, shuffle=True),
+ 'val': DataLoader(raw_valid, batch_size, shuffle=False),
+ 'test': DataLoader(raw_test, batch_size, shuffle=False),
+ }
+
+ return data_local_dict, config
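The per-client split above is a plain permutation slice over that client's graphs. The same arithmetic in isolation, with assumed splits [0.8, 0.1, 0.1]:

import numpy as np

splits, n = [0.8, 0.1, 0.1], 50    # n: graphs owned by one client
index = np.random.permutation(np.arange(n))
train_idx = index[:int(n * splits[0])]
valid_idx = index[int(n * splits[0]):int(n * sum(splits[:2]))]
test_idx = index[int(n * sum(splits[:2])):]
assert len(train_idx) + len(valid_idx) + len(test_idx) == n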
diff --git a/federatedscope/gfl/dataloader/dataloader_link.py b/federatedscope/gfl/dataloader/dataloader_link.py
new file mode 100644
index 000000000..282e0da97
--- /dev/null
+++ b/federatedscope/gfl/dataloader/dataloader_link.py
@@ -0,0 +1,128 @@
+import torch
+
+from torch_geometric.data import Data
+from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborSampler
+
+from federatedscope.core.auxiliaries.splitter_builder import get_splitter
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+
+
+def raw2loader(raw_data, config=None):
+    """Transform a graph into either a dataloader for graph-sampling-based mini-batch training
+    or keep it as a graph for full-batch training.
+    Arguments:
+        raw_data (PyG.Data): a raw graph.
+    :returns:
+        sampler (object): a dict containing the loader and the subgraph sampler, or a PyG.Data object.
+ """
+
+ if config.data.loader == '':
+ sampler = raw_data
+ elif config.data.loader == 'graphsaint-rw':
+ loader = GraphSAINTRandomWalkSampler(
+ raw_data,
+ batch_size=config.data.batch_size,
+ walk_length=config.data.graphsaint.walk_length,
+ num_steps=config.data.graphsaint.num_steps,
+ sample_coverage=0)
+ subgraph_sampler = NeighborSampler(raw_data.edge_index,
+ sizes=[-1],
+ batch_size=4096,
+ shuffle=False,
+ num_workers=config.data.num_workers)
+ sampler = dict(data=raw_data,
+ train=loader,
+ val=subgraph_sampler,
+ test=subgraph_sampler)
+ else:
+ raise TypeError('Unsupported DataLoader Type {}'.format(
+ config.data.loader))
+
+ return sampler
+
+
+def load_linklevel_dataset(config=None):
+ r"""
+ :returns:
+ data_local_dict
+ :rtype:
+ (Dict): dict{'client_id': Data()}
+ """
+ path = config.data.root
+ name = config.data.type.lower()
+
+ # Splitter
+ splitter = get_splitter(config)
+
+ # Transforms
+ transforms_funcs = get_transform(config, 'torch_geometric')
+
+ if name in ['epinions', 'ciao']:
+ from federatedscope.gfl.dataset.recsys import RecSys
+ dataset = RecSys(path,
+ name,
+ FL=True,
+ splits=config.data.splits,
+ **transforms_funcs)
+ global_dataset = RecSys(path,
+ name,
+ FL=False,
+ splits=config.data.splits,
+ **transforms_funcs)
+ elif name in ['fb15k-237', 'wn18', 'fb15k', 'toy']:
+ from federatedscope.gfl.dataset.kg import KG
+ dataset = KG(path, name, **transforms_funcs)
+ dataset = splitter(dataset[0])
+ global_dataset = KG(path, name, **transforms_funcs)
+ else:
+ raise ValueError(f'No dataset named: {name}!')
+
+ dataset = [ds for ds in dataset]
+ client_num = min(len(dataset), config.federate.client_num
+ ) if config.federate.client_num > 0 else len(dataset)
+ config.merge_from_list(['federate.client_num', client_num])
+
+ # get local dataset
+ data_local_dict = dict()
+
+ for client_idx in range(len(dataset)):
+ local_data = raw2loader(dataset[client_idx], config)
+ data_local_dict[client_idx + 1] = local_data
+
+ if global_dataset is not None:
+ # Recode train & valid & test mask for global data
+ global_graph = global_dataset[0]
+ train_edge_mask = torch.BoolTensor([])
+ valid_edge_mask = torch.BoolTensor([])
+ test_edge_mask = torch.BoolTensor([])
+ global_edge_index = torch.LongTensor([[], []])
+ global_edge_type = torch.LongTensor([])
+
+ for client_sampler in data_local_dict.values():
+ if isinstance(client_sampler, Data):
+ client_subgraph = client_sampler
+ else:
+ client_subgraph = client_sampler['data']
+ orig_index = torch.zeros_like(client_subgraph.edge_index)
+ orig_index[0] = client_subgraph.index_orig[
+ client_subgraph.edge_index[0]]
+ orig_index[1] = client_subgraph.index_orig[
+ client_subgraph.edge_index[1]]
+ train_edge_mask = torch.cat(
+ (train_edge_mask, client_subgraph.train_edge_mask), dim=-1)
+ valid_edge_mask = torch.cat(
+ (valid_edge_mask, client_subgraph.valid_edge_mask), dim=-1)
+ test_edge_mask = torch.cat(
+ (test_edge_mask, client_subgraph.test_edge_mask), dim=-1)
+ global_edge_index = torch.cat((global_edge_index, orig_index),
+ dim=-1)
+ global_edge_type = torch.cat(
+ (global_edge_type, client_subgraph.edge_type), dim=-1)
+ global_graph.train_edge_mask = train_edge_mask
+ global_graph.valid_edge_mask = valid_edge_mask
+ global_graph.test_edge_mask = test_edge_mask
+ global_graph.edge_index = global_edge_index
+ global_graph.edge_type = global_edge_type
+ data_local_dict[0] = raw2loader(global_graph, config)
+
+ return data_local_dict, config
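The mask-recoding loop above maps each client's local edge endpoints back to global node ids through `index_orig` before concatenating the edge masks. A toy sketch of that index translation:

import torch

# A client subgraph holding global nodes [2, 0, 3] with one local edge 0 -> 1.
index_orig = torch.tensor([2, 0, 3])
local_edge_index = torch.tensor([[0], [1]])
orig_index = torch.zeros_like(local_edge_index)
orig_index[0] = index_orig[local_edge_index[0]]  # global source: 2
orig_index[1] = index_orig[local_edge_index[1]]  # global target: 0
assert orig_index.tolist() == [[2], [0]]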
diff --git a/federatedscope/gfl/dataloader/dataloader_node.py b/federatedscope/gfl/dataloader/dataloader_node.py
new file mode 100644
index 000000000..1b36c84d4
--- /dev/null
+++ b/federatedscope/gfl/dataloader/dataloader_node.py
@@ -0,0 +1,183 @@
+import torch
+import numpy as np
+
+from torch_geometric.datasets import Planetoid
+from torch_geometric.utils import add_self_loops, remove_self_loops, to_undirected
+from torch_geometric.data import Data
+from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborSampler
+
+from federatedscope.core.auxiliaries.splitter_builder import get_splitter
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+
+INF = np.iinfo(np.int64).max
+
+
+def raw2loader(raw_data, config=None):
+    """Transform a graph into either a dataloader for graph-sampling-based mini-batch training
+    or keep it as a graph for full-batch training.
+    Arguments:
+        raw_data (PyG.Data): a raw graph.
+    :returns:
+        sampler (object): a dict containing the loader and the subgraph sampler, or a PyG.Data object.
+ """
+ # change directed graph to undirected
+ raw_data.edge_index = to_undirected(
+ remove_self_loops(raw_data.edge_index)[0])
+
+ if config.data.loader == '':
+ sampler = raw_data
+ elif config.data.loader == 'graphsaint-rw':
+ # Sampler would crash if there was isolated node.
+ raw_data.edge_index = add_self_loops(raw_data.edge_index,
+ num_nodes=raw_data.x.shape[0])[0]
+ loader = GraphSAINTRandomWalkSampler(
+ raw_data,
+ batch_size=config.data.batch_size,
+ walk_length=config.data.graphsaint.walk_length,
+ num_steps=config.data.graphsaint.num_steps,
+ sample_coverage=0)
+ #save_dir=dataset.processed_dir)
+ subgraph_sampler = NeighborSampler(raw_data.edge_index,
+ sizes=[-1],
+ batch_size=4096,
+ shuffle=False,
+ num_workers=config.data.num_workers)
+ sampler = dict(data=raw_data,
+ train=loader,
+ val=subgraph_sampler,
+ test=subgraph_sampler)
+ elif config.data.loader == 'neighbor':
+ # Sampler would crash if there was isolated node.
+ raw_data.edge_index = add_self_loops(raw_data.edge_index,
+ num_nodes=raw_data.x.shape[0])[0]
+
+ train_idx = raw_data.train_mask.nonzero(as_tuple=True)[0]
+ loader = NeighborSampler(raw_data.edge_index,
+ node_idx=train_idx,
+ sizes=config.data.sizes,
+ batch_size=config.data.batch_size,
+ shuffle=config.data.shuffle,
+ num_workers=config.data.num_workers)
+ subgraph_sampler = NeighborSampler(raw_data.edge_index,
+ sizes=[-1],
+ batch_size=4096,
+ shuffle=False,
+ num_workers=config.data.num_workers)
+ sampler = dict(data=raw_data,
+ train=loader,
+ val=subgraph_sampler,
+                       test=subgraph_sampler)
+    else:
+        raise TypeError('Unsupported DataLoader Type {}'.format(
+            config.data.loader))
+
+    return sampler
+
+
+def load_nodelevel_dataset(config=None):
+ r"""
+ :returns:
+ data_local_dict
+ :rtype:
+ Dict: dict{'client_id': Data()}
+ """
+ path = config.data.root
+ name = config.data.type.lower()
+
+ # Splitter
+ splitter = get_splitter(config)
+
+ # Transforms
+ transforms_funcs = get_transform(config, 'torch_geometric')
+
+ # Dataset
+ if name in ["cora", "citeseer", "pubmed"]:
+ num_split = {
+ 'cora': [232, 542, INF],
+ 'citeseer': [332, 665, INF],
+ 'pubmed': [3943, 3943, INF],
+ }
+
+ dataset = Planetoid(path,
+ name,
+ split='random',
+ num_train_per_class=num_split[name][0],
+ num_val=num_split[name][1],
+ num_test=num_split[name][2],
+ **transforms_funcs)
+ dataset = splitter(dataset[0])
+ global_dataset = Planetoid(path,
+ name,
+ split='random',
+ num_train_per_class=num_split[name][0],
+ num_val=num_split[name][1],
+ num_test=num_split[name][2],
+ **transforms_funcs)
+ elif name == "dblp_conf":
+ from federatedscope.gfl.dataset.dblp_new import DBLPNew
+ dataset = DBLPNew(path,
+ FL=1,
+ splits=config.data.splits,
+ **transforms_funcs)
+ global_dataset = DBLPNew(path,
+ FL=0,
+ splits=config.data.splits,
+ **transforms_funcs)
+ elif name == "dblp_org":
+ from federatedscope.gfl.dataset.dblp_new import DBLPNew
+ dataset = DBLPNew(path,
+ FL=2,
+ splits=config.data.splits,
+ **transforms_funcs)
+ global_dataset = DBLPNew(path,
+ FL=0,
+ splits=config.data.splits,
+ **transforms_funcs)
+ elif name.startswith("csbm"):
+ from federatedscope.gfl.dataset.cSBM_dataset import dataset_ContextualSBM
+ dataset = dataset_ContextualSBM(
+ root=path,
+ name=name if len(name) > len("csbm") else None,
+ theta=config.data.cSBM_phi,
+ epsilon=3.25,
+ n=2500,
+ d=5,
+ p=1000,
+ train_percent=0.2)
+ global_dataset = None
+ else:
+ raise ValueError(f'No dataset named: {name}!')
+
+ dataset = [ds for ds in dataset]
+ client_num = min(len(dataset), config.federate.client_num
+ ) if config.federate.client_num > 0 else len(dataset)
+ config.merge_from_list(['federate.client_num', client_num])
+
+ # get local dataset
+ data_local_dict = dict()
+
+ for client_idx in range(len(dataset)):
+ local_data = raw2loader(dataset[client_idx], config)
+ data_local_dict[client_idx + 1] = local_data
+
+ if global_dataset is not None:
+ global_graph = global_dataset[0]
+ train_mask = torch.zeros_like(global_graph.train_mask)
+ val_mask = torch.zeros_like(global_graph.val_mask)
+ test_mask = torch.zeros_like(global_graph.test_mask)
+
+ for client_sampler in data_local_dict.values():
+ if isinstance(client_sampler, Data):
+ client_subgraph = client_sampler
+ else:
+ client_subgraph = client_sampler['data']
+ train_mask[client_subgraph.index_orig[
+ client_subgraph.train_mask]] = True
+ val_mask[client_subgraph.index_orig[
+ client_subgraph.val_mask]] = True
+ test_mask[client_subgraph.index_orig[
+ client_subgraph.test_mask]] = True
+ global_graph.train_mask = train_mask
+ global_graph.val_mask = val_mask
+ global_graph.test_mask = test_mask
+
+ data_local_dict[0] = raw2loader(global_graph, config)
+
+ return data_local_dict, config
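The node-level analogue rebuilds the global masks by scattering each client's local masks through `index_orig`. A toy sketch:

import torch

# A client owning global nodes [1, 4] that trains on its first local node.
global_train_mask = torch.zeros(6, dtype=torch.bool)
index_orig = torch.tensor([1, 4])
local_train_mask = torch.tensor([True, False])
global_train_mask[index_orig[local_train_mask]] = True
assert global_train_mask.tolist() == [False, True, False,
                                      False, False, False]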
diff --git a/federatedscope/gfl/dataset/__init__.py b/federatedscope/gfl/dataset/__init__.py
new file mode 100644
index 000000000..0479564bb
--- /dev/null
+++ b/federatedscope/gfl/dataset/__init__.py
@@ -0,0 +1,10 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.gfl.dataset.recsys import RecSys
+from federatedscope.gfl.dataset.dblp_new import DBLPNew
+from federatedscope.gfl.dataset.kg import KG
+from federatedscope.gfl.dataset.cSBM_dataset import dataset_ContextualSBM
+
+__all__ = ['RecSys', 'DBLPNew', 'KG', 'dataset_ContextualSBM']
diff --git a/federatedscope/gfl/dataset/cSBM_dataset.py b/federatedscope/gfl/dataset/cSBM_dataset.py
new file mode 100644
index 000000000..a1f65c820
--- /dev/null
+++ b/federatedscope/gfl/dataset/cSBM_dataset.py
@@ -0,0 +1,372 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+#
+# Distributed under terms of the MIT license.
+"""
+cSBM is a configurable random-graph model for studying homophily and heterophily.
+Source: https://github.com/jianhao2016/GPRGNN
+
+This is a script for the contextual SBM model and its dataset generator.
+It contains the functions:
+    ContextualSBM
+    parameterized_Lambda_and_mu
+    save_data_to_pickle
+and the class:
+    dataset_ContextualSBM
+"""
+import pickle
+from datetime import datetime
+import os
+import os.path as osp
+
+import numpy as np
+import torch
+from torch_geometric.data import Data, InMemoryDataset
+
+from federatedscope.gfl.dataset.utils import random_planetoid_splits
+
+
+def ContextualSBM(n, d, Lambda, p, mu, train_percent=0.01, u=None):
+ """To generate a graph with specified homophilic degree, avg node degree, feature dimension, etc.
+ Arguments:
+ n (int): the number of nodes.
+ d (int): the average node degree.
+ Lambda (float): the parameter controlling homophilic degree.
+        p (int): the dimension of the node features.
+        mu (float): the mean of the node features.
+        train_percent (float): (optional) the fraction of nodes used for training.
+        u (numpy.ndarray): (optional) the parameter controlling the node features.
+ :returns:
+ data : the constructed graph.
+ u : the parameter controlling the node feature.
+ :rtype:
+ tuple: (PyG.Data, numpy.Array)
+
+ """
+ # n = 800 #number of nodes
+ # d = 5 # average degree
+ # Lambda = 1 # parameters
+ # p = 1000 # feature dim
+ # mu = 1 # mean of Gaussian
+ gamma = n / p
+
+ c_in = d + np.sqrt(d) * Lambda
+ c_out = d - np.sqrt(d) * Lambda
+ y = np.ones(n)
+ y[n // 2:] = -1
+ y = np.asarray(y, dtype=int)
+
+ quarter_len = n // 4
+ # creating edge_index
+ edge_index = [[], []]
+ for i in range(n - 1):
+ for j in range(i + 1, n):
+ if y[i] * y[j] > 0 and ((i // quarter_len) == (j // quarter_len)):
+ if (i // quarter_len == 0) or (i // quarter_len == 2):
+ Flip = np.random.binomial(1, c_in / n)
+ else:
+ Flip = np.random.binomial(1, c_out / n)
+ elif (y[i] * y[j] > 0) or (i // quarter_len + j // quarter_len
+ == 3):
+ Flip = np.random.binomial(1, 0.5 * (c_in / n + c_out / n))
+ else:
+ if i // quarter_len == 0:
+ Flip = np.random.binomial(1, c_out / n)
+ else:
+ Flip = np.random.binomial(1, c_in / n)
+ if Flip > 0.5:
+ edge_index[0].append(i)
+ edge_index[1].append(j)
+ edge_index[0].append(j)
+ edge_index[1].append(i)
+
+ # creating node features
+ x = np.zeros([n, p])
+ u = np.random.normal(0, 1 / np.sqrt(p), [1, p]) if u is None else u
+ for i in range(n):
+ Z = np.random.normal(0, 1, [1, p])
+ x[i] = np.sqrt(mu / n) * y[i] * u + Z / np.sqrt(p)
+ data = Data(x=torch.tensor(x, dtype=torch.float32),
+ edge_index=torch.tensor(edge_index),
+ y=torch.tensor((y + 1) // 2, dtype=torch.int64))
+ # order edge list and remove duplicates if any.
+ data.coalesce()
+
+ num_class = len(np.unique(y))
+ val_lb = int(n * train_percent)
+ percls_trn = int(round(train_percent * n / num_class))
+ data = random_planetoid_splits(data, num_class, percls_trn, val_lb)
+
+ # add parameters to attribute
+ data.Lambda = Lambda
+ data.mu = mu
+ data.n = n
+ data.p = p
+ data.d = d
+ data.train_percent = train_percent
+
+ return data, u
+
+
+def parameterized_Lambda_and_mu(theta, p, n, epsilon=0.1):
+    '''
+    Based on claim 3 in the paper, (Lambda, mu) lies on the ellipse
+
+        lambda^2 + mu^2 / gamma = 1 + epsilon,    where 1 / gamma = p / n,
+
+    whose longer axis is 1 and shorter axis is 1 / gamma. Parameterizing by
+    the angle theta gives
+
+        lambda = sqrt(1 + epsilon) * sin(theta * pi / 2)
+        mu = sqrt(gamma * (1 + epsilon)) * cos(theta * pi / 2)
+ Arguments:
+ theta (float): controlling the homophilic degree.
+        p (int): the dimension of the node features.
+        n (int): the number of nodes.
+        epsilon (float): (optional) controlling the variance of the node features.
+ :returns:
+ Lambda : controlling the homophilic degree.
+ mu : the mean of node feature.
+ :rtype:
+ tuple: (float, float)
+ '''
+ from math import pi
+ gamma = n / p
+ assert (theta >= -1) and (theta <= 1)
+ Lambda = np.sqrt(1 + epsilon) * np.sin(theta * pi / 2)
+ mu = np.sqrt(gamma * (1 + epsilon)) * np.cos(theta * pi / 2)
+ return Lambda, mu
+
+
+def save_data_to_pickle(data, p2root='../data/', file_name=None):
+    '''
+    If the file name is not specified, use a timestamp.
+    Arguments:
+        data (PyG.Data): the graph to be saved.
+        p2root (str): the path of the dataset folder.
+        file_name (str): (optional) the name of the output file.
+    :returns:
+        p2cSBM_data : the path of the saved file.
+    :rtype:
+        string
+    '''
+ now = datetime.now()
+ surfix = now.strftime('%b_%d_%Y-%H:%M')
+ if file_name is None:
+ tmp_data_name = '_'.join(['cSBM_data', surfix])
+ else:
+ tmp_data_name = file_name
+ p2cSBM_data = osp.join(p2root, tmp_data_name)
+ if not osp.isdir(p2root):
+ os.makedirs(p2root)
+ with open(p2cSBM_data, 'bw') as f:
+ pickle.dump(data, f)
+ return p2cSBM_data
+
+
+class dataset_ContextualSBM(InMemoryDataset):
+ r"""Create synthetic dataset based on the contextual SBM from the paper:
+ https://arxiv.org/pdf/1807.09596.pdf
+
+    Uses a similar interface to InMemoryDataset, but does not require the root folder.
+
+ Arguments:
+ root (string): Root directory where the dataset should be saved.
+        name (string): The name of the dataset; if not specified, a timestamp is used.
+
+    For {n, d, p, Lambda, mu}: with '_' as prefix, the initial/fed-in argument;
+    without the '_' prefix, the value loaded from the data information.
+
+    n: number of nodes
+    d: average degree of nodes
+    p: dimension of the feature vector.
+
+    Lambda, mu: parameters balancing the mixture of information;
+                if not specified, generated via the parameterized method.
+
+    epsilon, theta: gap between the boundary and the chosen ellipsoid. theta is
+                the angle between the selected parameter and the x-axis,
+                chosen in [0, 1] => 0 = 0, 1 = pi/2
+
+ transform (callable, optional): A function/transform that takes in an
+ :obj:`torch_geometric.data.Data` object and returns a transformed
+ version. The data object will be transformed before every access.
+ (default: :obj:`None`)
+ pre_transform (callable, optional): A function/transform that takes in
+ an :obj:`torch_geometric.data.Data` object and returns a
+ transformed version. The data object will be transformed before
+ being saved to disk. (default: :obj:`None`)
+ """
+
+ def __init__(self,
+ root,
+ name=None,
+ n=800,
+ d=5,
+ p=100,
+ Lambda=None,
+ mu=None,
+ epsilon=0.1,
+ theta=[-0.5, -0.25, 0.25, 0.5],
+ train_percent=0.01,
+ transform=None,
+ pre_transform=None):
+
+ now = datetime.now()
+ surfix = now.strftime('%b_%d_%Y-%H:%M').lower()
+ if name is None:
+            # dataset name not specified; create one with a timestamp.
+ self.name = '_'.join(['csbm_data', surfix])
+ else:
+ self.name = name
+
+ self._n = n
+ self._d = d
+ self._p = p
+
+ self._Lambda = Lambda
+ self._mu = mu
+ self._epsilon = epsilon
+ self._theta = theta
+
+ self._train_percent = train_percent
+
+ root = osp.join(root, self.name)
+ if not osp.isdir(root):
+ os.makedirs(root)
+ super(dataset_ContextualSBM, self).__init__(root, transform,
+ pre_transform)
+
+ self.data, self.slices = torch.load(self.processed_paths[0])
+ # overwrite the dataset attribute n, p, d, Lambda, mu
+ if isinstance(self._Lambda, list):
+ self.Lambda = self.data.Lambda.numpy()
+ self.mu = self.data.mu.numpy()
+ self.n = self.data.n.numpy()
+ self.p = self.data.p.numpy()
+ self.d = self.data.d.numpy()
+ self.train_percent = self.data.train_percent.numpy()
+ else:
+ self.Lambda = self.data.Lambda.item()
+ self.mu = self.data.mu.item()
+ self.n = self.data.n.item()
+ self.p = self.data.p.item()
+ self.d = self.data.d.item()
+ self.train_percent = self.data.train_percent.item()
+
+
+ @property
+ def raw_file_names(self):
+ file_names = [self.name]
+ return file_names
+
+ @property
+ def processed_file_names(self):
+ return ['data.pt']
+
+ def download(self):
+ for name in self.raw_file_names:
+ p2f = osp.join(self.raw_dir, name)
+ if not osp.isfile(p2f):
+                # file does not exist, so we create and save it there.
+ if self._Lambda is None or self._mu is None:
+                    # auto-generate the Lambda and mu parameters from theta.
+ self._Lambda = []
+ self._mu = []
+ for theta in self._theta:
+ Lambda, mu = parameterized_Lambda_and_mu(
+ theta, self._p, self._n, self._epsilon)
+ self._Lambda.append(Lambda)
+ self._mu.append(mu)
+
+ if isinstance(self._Lambda, list):
+ u = None
+ for i, (Lambda,
+ mu) in enumerate(zip(self._Lambda, self._mu)):
+ tmp_data, u = ContextualSBM(self._n, self._d, Lambda,
+ self._p, mu,
+ self._train_percent, u)
+ name_split_idx = self.name.index('_', 2)
+ name = self.name[:name_split_idx] + '_{}'.format(
+ i) + self.name[name_split_idx:]
+ _ = save_data_to_pickle(tmp_data,
+ p2root=self.raw_dir,
+ file_name=name)
+
+ else:
+ tmp_data, _ = ContextualSBM(self._n, self._d, self._Lambda,
+ self._p, self._mu,
+ self._train_percent)
+
+ _ = save_data_to_pickle(tmp_data,
+ p2root=self.raw_dir,
+ file_name=self.name)
+ else:
+ # file exists already. Do nothing.
+ pass
+
+ def process(self):
+ if isinstance(self._Lambda, list):
+ all_data = []
+ for i, Lambda in enumerate(self._Lambda):
+ name_split_idx = self.name.index('_', 2)
+ name = self.name[:name_split_idx] + '_{}'.format(
+ i) + self.name[name_split_idx:]
+ p2f = osp.join(self.raw_dir, name)
+ with open(p2f, 'rb') as f:
+ data = pickle.load(f)
+ all_data.append(data)
+ for i in range(len(all_data)):
+ all_data[i] = all_data[
+ i] if self.pre_transform is None else self.pre_transform(
+ all_data[i])
+ torch.save(self.collate(all_data), self.processed_paths[0])
+ else:
+ p2f = osp.join(self.raw_dir, self.name)
+ with open(p2f, 'rb') as f:
+ data = pickle.load(f)
+ data = data if self.pre_transform is None else self.pre_transform(
+ data)
+ torch.save(self.collate([data]), self.processed_paths[0])
+
+ def __repr__(self):
+ return '{}()'.format(self.name)
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser()
+ parser.add_argument('--phi', type=float, default=1)
+ parser.add_argument('--epsilon', type=float, default=3.25)
+ parser.add_argument('--root', default='../data/')
+ parser.add_argument('--name', default='cSBM_demo')
+ parser.add_argument('--num_nodes', type=int, default=800)
+ parser.add_argument('--num_features', type=int, default=1000)
+ parser.add_argument('--avg_degree', type=float, default=5)
+
+ args = parser.parse_args()
+
+ dataset_ContextualSBM(root=args.root,
+ name=args.name,
+ theta=args.phi,
+ epsilon=args.epsilon,
+ n=args.num_nodes,
+ d=args.avg_degree,
+ p=args.num_features)
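As a sanity check on `parameterized_Lambda_and_mu`: for any theta in [-1, 1], the returned pair lies on the ellipse from claim 3, i.e. Lambda^2 + mu^2 * p / n = 1 + epsilon. A short verification sketch:

import numpy as np
from federatedscope.gfl.dataset.cSBM_dataset import parameterized_Lambda_and_mu

n, p, eps = 2500, 1000, 3.25
for theta in (-0.5, 0.0, 0.5):
    Lambda, mu = parameterized_Lambda_and_mu(theta, p, n, epsilon=eps)
    assert np.isclose(Lambda ** 2 + mu ** 2 * p / n, 1 + eps)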
diff --git a/federatedscope/gfl/dataset/dblp_new.py b/federatedscope/gfl/dataset/dblp_new.py
new file mode 100644
index 000000000..38f671b47
--- /dev/null
+++ b/federatedscope/gfl/dataset/dblp_new.py
@@ -0,0 +1,184 @@
+import os.path as osp
+import numpy as np
+import networkx as nx
+import torch
+from torch_geometric.data import InMemoryDataset, download_url
+from torch_geometric.utils import from_networkx
+from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS as sklearn_stopwords
+
+
+class LemmaTokenizer(object):
+ def __init__(self):
+ from nltk.stem import WordNetLemmatizer
+ self.wnl = WordNetLemmatizer()
+
+ def __call__(self, doc):
+ from nltk import word_tokenize
+ return [self.wnl.lemmatize(t) for t in word_tokenize(doc)]
+
+
+def build_feature(words, threshold):
+ from nltk.corpus import stopwords as nltk_stopwords
+ # use bag-of-words representation of paper titles as the features of papers
+ stopwords = sklearn_stopwords.union(set(nltk_stopwords.words('english')))
+ vectorizer = CountVectorizer(min_df=int(threshold),
+ stop_words=stopwords,
+ tokenizer=LemmaTokenizer())
+ features_paper = vectorizer.fit_transform(words)
+
+ return features_paper
+
+
+def build_graph(path, filename, FL=0, threshold=15):
+ with open(osp.join(path, filename), 'r') as f:
+ node_cnt = sum([1 for line in f])
+
+ G = nx.DiGraph()
+ desc = node_cnt * [None]
+ neighbors = node_cnt * [None]
+ if FL == 1:
+ conf2paper = dict()
+ elif FL == 2:
+ org2paper = dict()
+
+ # Build node feature from title
+ with open(osp.join(path, filename), 'r') as f:
+ for line in f:
+ cols = line.strip().split('\t')
+ nid, title = int(cols[0]), cols[3]
+ desc[nid] = title
+
+ features = np.array(build_feature(desc, threshold).todense(),
+ dtype=np.float32)
+
+ # Build graph structure
+ with open(osp.join(path, filename), 'r') as f:
+ for line in f:
+ cols = line.strip().split('\t')
+ nid, conf, org, label = int(cols[0]), cols[1], cols[2], int(
+ cols[4])
+ neighbors[nid] = [int(val) for val in cols[-1].split(',')]
+
+ if FL == 1:
+ if conf not in conf2paper:
+ conf2paper[conf] = [nid]
+ else:
+ conf2paper[conf].append(nid)
+ elif FL == 2:
+ if org not in org2paper:
+ org2paper[org] = [nid]
+ else:
+ org2paper[org].append(nid)
+
+ G.add_node(nid, y=label, x=features[nid], index_orig=nid)
+
+ for nid, nbs in enumerate(neighbors):
+ for vid in nbs:
+ G.add_edge(nid, vid)
+
+ # Sort node id for index_orig
+ H = nx.Graph()
+ H.add_nodes_from(sorted(G.nodes(data=True)))
+ H.add_edges_from(G.edges(data=True))
+ G = H
+ graphs = []
+ if FL == 1:
+ for conf in conf2paper:
+ graphs.append(from_networkx(nx.subgraph(G, conf2paper[conf])))
+ elif FL == 2:
+ for org in org2paper:
+ graphs.append(from_networkx(nx.subgraph(G, org2paper[org])))
+ else:
+ graphs.append(from_networkx(G))
+
+ return graphs
+
+
+class DBLPNew(InMemoryDataset):
+ r"""
+ Args:
+ root (string): Root directory where the dataset should be saved.
+        FL (int): Federated setting: `0` for DBLP, `1` for FLDBLPbyConf, `2` for FLDBLPbyOrg
+ transform (callable, optional): A function/transform that takes in an
+ :obj:`torch_geometric.data.Data` object and returns a transformed
+ version. The data object will be transformed before every access.
+ (default: :obj:`None`)
+ pre_transform (callable, optional): A function/transform that takes in
+ an :obj:`torch_geometric.data.Data` object and returns a
+ transformed version. The data object will be transformed before
+ being saved to disk. (default: :obj:`None`)
+ """
+ def __init__(self,
+ root,
+ FL=0,
+ splits=[0.5, 0.2, 0.3],
+ transform=None,
+ pre_transform=None):
+ self.FL = FL
+ if self.FL == 0:
+ self.name = 'DBLPNew'
+ elif self.FL == 1:
+ self.name = 'FLDBLPbyConf'
+ else:
+ self.name = 'FLDBLPbyOrg'
+ self._customized_splits = splits
+ super(DBLPNew, self).__init__(root, transform, pre_transform)
+ self.data, self.slices = torch.load(self.processed_paths[0])
+
+ @property
+ def raw_file_names(self):
+ names = ['dblp_new.tsv']
+ return names
+
+ @property
+ def processed_file_names(self):
+ return ['data.pt']
+
+ @property
+ def raw_dir(self):
+ return osp.join(self.root, self.name, 'raw')
+
+ @property
+ def processed_dir(self):
+ return osp.join(self.root, self.name, 'processed')
+
+ def download(self):
+ # Download to `self.raw_dir`.
+ url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com'
+ for name in self.raw_file_names:
+ download_url(f'{url}/{name}', self.raw_dir)
+
+ def process(self):
+ # Read data into huge `Data` list.
+ data_list = build_graph(self.raw_dir, self.raw_file_names[0], self.FL)
+
+ data_list_w_masks = []
+ for data in data_list:
+ if data.num_nodes == 0:
+ continue
+ indices = torch.randperm(data.num_nodes)
+ data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
+ data.train_mask[indices[:round(self._customized_splits[0] *
+ len(data.y))]] = True
+ data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
+ data.val_mask[
+ indices[round(self._customized_splits[0] *
+ len(data.y)):round((self._customized_splits[0] +
+ self._customized_splits[1]) *
+ len(data.y))]] = True
+ data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
+ data.test_mask[indices[round((self._customized_splits[0] +
+ self._customized_splits[1]) *
+ len(data.y)):]] = True
+ data_list_w_masks.append(data)
+ data_list = data_list_w_masks
+
+ if self.pre_filter is not None:
+ data_list = [data for data in data_list if self.pre_filter(data)]
+
+ if self.pre_transform is not None:
+ data_list = [self.pre_transform(data) for data in data_list]
+
+ data, slices = self.collate(data_list)
+ torch.save((data, slices), self.processed_paths[0])
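+
+
+# A minimal sketch of how the masks built in `process` partition each
+# subgraph (illustrative; `splits` defaults to [0.5, 0.2, 0.3] and the
+# root path here is hypothetical):
+#   >>> dataset = DBLPNew(root='data', FL=1)
+#   >>> g = dataset[0]
+#   >>> bool((g.train_mask | g.val_mask | g.test_mask).all())
+#   True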
diff --git a/federatedscope/gfl/dataset/kg.py b/federatedscope/gfl/dataset/kg.py
new file mode 100644
index 000000000..994f91a10
--- /dev/null
+++ b/federatedscope/gfl/dataset/kg.py
@@ -0,0 +1,130 @@
+"""This file is part of https://github.com/pyg-team/pytorch_geometric
+Copyright (c) 2021 Matthias Fey, Jiaxuan You
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import os
+import os.path as osp
+
+import torch
+from torch_geometric.data import InMemoryDataset, Data, download_url
+
+
+class KG(InMemoryDataset):
+ def __init__(self, root, name, transform=None, pre_transform=None):
+ self.name = name
+ super().__init__(root, transform, pre_transform)
+ self.data, self.slices = torch.load(self.processed_paths[0])
+
+ @property
+ def num_relations(self):
+ return int(self.data.edge_type.max()) + 1
+
+ @property
+ def raw_dir(self):
+ return os.path.join(self.root, self.name, 'raw')
+
+ @property
+ def processed_dir(self):
+ return os.path.join(self.root, self.name, 'processed')
+
+ @property
+ def processed_file_names(self):
+ return 'data.pt'
+
+ @property
+ def raw_file_names(self):
+ return [
+ 'entities.dict', 'relations.dict', 'test.txt', 'train.txt',
+ 'valid.txt'
+ ]
+
+ def download(self):
+ url = 'https://github.com/MichSchli/RelationPrediction/tree/master/data/'
+ urls = {
+ 'fb15k': url + 'FB15k',
+ 'fb15k-237': url + 'FB-Toutanova',
+ 'wn18': url + 'wn18',
+ 'toy': url + 'Toy'
+ }
+ for file_name in self.raw_file_names:
+ download_url(f'{urls[self.name]}/{file_name}', self.raw_dir)
+
+ def process(self):
+ with open(osp.join(self.raw_dir, 'entities.dict'), 'r') as f:
+ lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
+ entities_dict = {key: int(value) for value, key in lines}
+
+ with open(osp.join(self.raw_dir, 'relations.dict'), 'r') as f:
+ lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
+ relations_dict = {key: int(value) for value, key in lines}
+
+ kwargs = {}
+ for split in ['train', 'valid', 'test']:
+ with open(osp.join(self.raw_dir, f'{split}.txt'), 'r') as f:
+ lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
+ src = [entities_dict[row[0]] for row in lines]
+ rel = [relations_dict[row[1]] for row in lines]
+ dst = [entities_dict[row[2]] for row in lines]
+ kwargs[f'{split}_edge_index'] = torch.tensor([src, dst])
+ kwargs[f'{split}_edge_type'] = torch.tensor(rel)
+
+ # For message passing, we add reverse edges and types to the graph:
+ row, col = kwargs['train_edge_index']
+ edge_type = kwargs['train_edge_type']
+ row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
+ edge_index = torch.stack([row, col], dim=0)
+ edge_type = torch.cat([edge_type, edge_type + len(relations_dict)])
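+        # e.g., with R = len(relations_dict) relation types, each training
+        # triple (s, r, o) also contributes a reverse edge (o, s) labeled
+        # r + R, so messages can flow in both directions.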
+ num_nodes = len(entities_dict)
+ data = Data(num_nodes=num_nodes,
+ edge_index=edge_index,
+ edge_type=edge_type,
+ **kwargs)
+ edge_index = torch.cat((data.train_edge_index, data.valid_edge_index,
+ data.test_edge_index),
+ dim=-1)
+ edge_type = torch.cat(
+ (data.train_edge_type, data.valid_edge_type, data.test_edge_type),
+ dim=0)
+ num_edges = edge_index.size(-1)
+ train_edge_mask = torch.zeros(num_edges, dtype=torch.bool)
+ train_edge_mask[:data.train_edge_index.size(-1)] = True
+ valid_edge_mask = torch.zeros(num_edges, dtype=torch.bool)
+ valid_edge_mask[data.train_edge_index.
+ size(-1):-data.test_edge_index.size(-1)] = True
+ test_edge_mask = torch.zeros(num_edges, dtype=torch.bool)
+ test_edge_mask[-data.test_edge_index.size(-1):] = True
+ data = Data(edge_index=edge_index,
+ index_orig=torch.arange(num_nodes),
+ edge_type=edge_type,
+ num_nodes=num_nodes,
+ train_edge_mask=train_edge_mask,
+ valid_edge_mask=valid_edge_mask,
+ test_edge_mask=test_edge_mask,
+ input_edge_index=data.edge_index)
+
+ data_list = [data]
+ if self.pre_filter is not None:
+ data_list = [data for data in data_list if self.pre_filter(data)]
+
+ if self.pre_transform is not None:
+ data_list = [self.pre_transform(data) for data in data_list]
+
+        torch.save(self.collate(data_list), self.processed_paths[0])
\ No newline at end of file
diff --git a/federatedscope/gfl/dataset/preprocess/__init__.py b/federatedscope/gfl/dataset/preprocess/__init__.py
new file mode 100644
index 000000000..f8e91f237
--- /dev/null
+++ b/federatedscope/gfl/dataset/preprocess/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
diff --git a/federatedscope/gfl/dataset/preprocess/dblp_related.py b/federatedscope/gfl/dataset/preprocess/dblp_related.py
new file mode 100644
index 000000000..38aec84b2
--- /dev/null
+++ b/federatedscope/gfl/dataset/preprocess/dblp_related.py
@@ -0,0 +1,288 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import argparse
+import re
+from bson.json_util import loads
+
+KEYWORDS = ['AAAI', 'Association for the Advancement of Artificial Intelligence', \
+ 'CIKM', 'Conference on Information and Knowledge Management', \
+ 'CVPR', 'Conference on Computer Vision and Pattern Recognition', \
+ 'ECIR', 'European Conference on Information Retrieval', \
+ 'ECML', 'European Conference on Machine Learning', \
+ 'EDBT', 'International Conference on Extending Database Technology', \
+ 'ICDE', 'International Conference on Data Engineering', \
+ 'ICDM', 'International Conference on Data Mining', \
+ 'ICML', 'International Conference on Machine Learning', \
+ 'IJCAI', 'International Joint Conference on Artificial Intelligence', \
+ 'PAKDD', 'Pacific-Asia Conference on Knowledge Discovery and Data Mining', \
+ 'PKDD', 'Principles and Practice of Knowledge Discovery in Databases', \
+ 'KDD', 'Knowledge Discovery and Data Mining', \
+ 'PODS', 'Principles of Database Systems', \
+ 'SIGIR', 'Special Interest Group on Information Retrieval', \
+ 'SIGMOD', 'Special Interest Group on Management of Data', \
+ 'VLDB', 'Very Large Data Bases', \
+ 'WWW', 'World Wide Web Conference', \
+ 'WSDM', 'Web Search and Data Mining', \
+ 'SDM', 'SIAM International Conference on Data Mining']
+
+CONF2ORG = {
+ 'AAAI': 'AAAI',
+ 'CIKM': 'ACM',
+ 'CVPR': 'IEEE',
+ 'ECIR': 'Springer',
+ 'ECML': 'Springer',
+ 'EDBT': 'Springer',
+ 'ICDE': 'IEEE',
+ 'ICDM': 'IEEE',
+ 'ICML': 'PMLR',
+ 'IJCAI': 'the IJCAI, Inc.',
+ 'KDD': 'ACM',
+ 'PAKDD': 'Springer',
+ 'PKDD': 'Springer',
+ 'PODS': 'ACM',
+ 'SDM': 'SIAM',
+ 'SIGIR': 'ACM',
+ 'SIGMOD': 'ACM',
+ 'VLDB': 'VLDB',
+ 'WWW': 'ACM',
+ 'WSDM': 'ACM'
+}
+
+LABELS = [
+ 'Database', 'Data mining', 'Artificial intelligence',
+ 'Information retrieval'
+]
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--choice', type=int, default=-1)
+parser.add_argument('--input_path', type=str, default='')
+parser.add_argument('--output_path', type=str, default='')
+args = parser.parse_args()
+
+
+def extract_considered():
+ keywords = [val.lower() for val in KEYWORDS]
+ pat = re.compile(r'|'.join(keywords))
+
+ ent = 0
+ cnt = 0
+ rsvd = 0
+ ops = open(args.output_path, 'w')
+ try:
+ with open(args.input_path, 'r') as ips:
+ ele_contents = []
+ is_first = True
+ for line in ips:
+ if is_first:
+ is_first = False
+ continue
+
+ if line[0] == '{':
+ ent += 1
+ elif line[0] == '}':
+ ent -= 1
+
+ ele_contents.append(line.strip())
+
+ if ent == 0 and len(ele_contents):
+ json_text = ''.join(ele_contents)
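+                    # Convert Mongo's NumberInt(x) shorthand into valid
+                    # extended JSON so bson.json_util.loads can parse it.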
+ json_text = re.sub(r'NumberInt\s*\(\s*(\S+)\s*\)',
+ r'{"$numberInt": "\1"}', json_text)
+ #print(json_text[:-1])
+ #ele = json.loads(json_text[:-1])
+ if json_text[-1] == ',':
+ ele = loads(json_text[:-1])
+ else:
+ ele = loads(json_text)
+ #if ('venue' in ele and '_id' in ele['venue']) and 'fos' in ele and 'references' in ele:
+ if '_id' in ele and 'venue' in ele and 'raw' in ele[
+ 'venue'] and ele['venue']['raw'] and 'fos' in ele and ele[
+ 'fos'] and 'references' in ele and 'title' in ele and ele[
+ 'title']:
+                        raw_venue_name = ele['venue']['raw'].lower()
+                        if re.search(pat, raw_venue_name):
+ ops.write("{}\t{}\t{}\t{}\t{}\n".format(
+ ele['_id'], ele['venue']['raw'].replace(
+ '\n', '').replace('\t', ' '),
+ ele['title'].replace('\n',
+ '').replace('\t', ' '),
+ ','.join(ele['fos']).replace('\n', '').replace(
+ '\t', ' '), ','.join(ele['references'])))
+ rsvd += 1
+ #print(ele)
+ cnt += 1
+ if cnt % 100000 == 0:
+ print(rsvd, cnt, "======>")
+ ele_contents = []
+ except Exception as ex:
+ print(ex)
+ finally:
+ ops.close()
+
+
+"""
+{'ICDM': 4589, 'KDD': 5476, 'IJCAI': 7586, 'VLDB': 5314, 'PAKDD': 2242, 'ECIR': 1482, 'ICML': 8322, 'CIKM': 5931, 'WWW': 5553, 'CVPR': 13355, 'EDBT': 1636, 'AAAI': 9695, 'ECML': 2216, 'SIGMOD': 4206, 'ICDE': 4330, 'PODS': 1670, 'SDM': 1624, 'SIGIR': 4619, 'WSDM': 746, 'PKDD': 547}
+======================
+{'IEEE': 22274, 'ACM': 28201, 'the IJCAI, Inc.': 7586, 'VLDB': 5314, 'Springer': 8123, 'PMLR': 8322, 'AAAI': 9695, 'SIAM': 1624}
+"""
+
+
+def be_canonical():
+ keywords = [val.lower() for val in KEYWORDS]
+ conf_cnts = dict()
+ org_cnts = dict()
+ ops = open(args.output_path, 'w')
+ with open(args.input_path, 'r') as ips:
+ for line in ips:
+ num_of_tab = line.count('\t')
+ if num_of_tab != 4:
+ print(num_of_tab)
+ print(line.replace('\t', 'TAB'))
+ continue
+ cols = line.strip().split('\t')
+ conf_raw_name = cols[1].lower()
+ org, conf_name = '', ''
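+            # KEYWORDS alternates abbreviation / full conference name, so a
+            # match at index i maps back to the abbreviation at i or i - 1.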
+ for i, kw in enumerate(keywords):
+ if kw in conf_raw_name:
+ conf_name = keywords[i if (i % 2 == 0) else
+ (i - 1)].upper()
+ org = CONF2ORG[conf_name]
+ break
+ if conf_name == '':
+ print(cols[1])
+ continue
+ if conf_name not in conf_cnts:
+ conf_cnts[conf_name] = 0
+ if org not in org_cnts:
+ org_cnts[org] = 0
+ conf_cnts[conf_name] += 1
+ org_cnts[org] += 1
+ ops.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(
+ cols[0], conf_name, org, cols[2], cols[3], cols[4]))
+ ops.close()
+
+ print(conf_cnts)
+ print("======================")
+ print(org_cnts)
+
+
+def be_fourclass_data():
+ labels = [val.lower() for val in LABELS]
+ cnt = 0
+ vset = dict()
+ with open(args.input_path, 'r') as ips:
+ for line in ips:
+ cols = line.strip().split('\t')
+ fos = [val.lower() for val in cols[4].split(',')]
+ for val in fos:
+ if val in labels:
+ cnt += 1
+ vset[cols[0]] = [0, 0]
+                    # assume a single label, i.e. the classes are mutually exclusive
+ break
+ print(cnt)
+
+ e_cnt = 0
+ with open(args.input_path, 'r') as ips:
+ for line in ips:
+ cols = line.strip().split('\t')
+ if cols[0] not in vset:
+ continue
+ refs = cols[-1].split(',')
+ for val in refs:
+ if val in vset:
+ e_cnt += 1
+ vset[cols[0]][0] += 1
+ vset[val][1] += 1
+ print(e_cnt)
+
+ connected = dict([(val, i) for i, val in enumerate(
+ [k for k, v in vset.items() if (v[0] > 0 or v[1] > 0)])])
+ print(len(connected))
+
+ ops = open(args.output_path, 'w')
+ with open(args.input_path, 'r') as ips:
+ for line in ips:
+ cols = line.strip().split('\t')
+ nid = cols[0]
+ if nid not in connected:
+ continue
+ for val in cols[4].split(','):
+ can_val = val.lower()
+ if can_val in labels:
+ lb = labels.index(can_val)
+ break
+ adjs = ','.join([
+ str(connected[val]) for val in cols[-1].split(',')
+ if val in connected
+ ])
+ ops.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(
+ connected[nid], cols[1], cols[2], cols[3], lb, adjs))
+ ops.close()
+
+
+def stats():
+ p2c = dict()
+ p2o = dict()
+ with open(args.input_path, 'r') as ips:
+ for line in ips:
+ cols = line.strip().split('\t')
+ p2c[cols[0]] = cols[1]
+ p2o[cols[0]] = cols[2]
+
+ stats = dict()
+ with open(args.input_path, 'r') as ips:
+ for line in ips:
+ cols = line.strip().split('\t')
+ conf = cols[1]
+ if conf not in stats:
+ stats[conf] = [0, 0, 0, [0, 0, 0, 0]]
+ stats[conf][0] += 1
+ adjs = cols[-1].split(',')
+ for v in adjs:
+ if p2c[v] == conf:
+ stats[conf][1] += 1
+ else:
+ stats[conf][2] += 1
+ lb = int(cols[4])
+ stats[conf][3][lb] += 1
+
+ for k, v in stats.items():
+ print(k, v)
+
+ stats = dict()
+ with open(args.input_path, 'r') as ips:
+ for line in ips:
+ cols = line.strip().split('\t')
+ org = cols[2]
+ if org not in stats:
+ stats[org] = [0, 0, 0, [0, 0, 0, 0]]
+ stats[org][0] += 1
+ adjs = cols[-1].split(',')
+ for v in adjs:
+ if p2o[v] == org:
+ stats[org][1] += 1
+ else:
+ stats[org][2] += 1
+ lb = int(cols[4])
+ stats[org][3][lb] += 1
+
+ for k, v in stats.items():
+ print(k, v)
+
+
+def main():
+ if args.choice == 0:
+ extract_considered()
+ elif args.choice == 1:
+ be_canonical()
+ elif args.choice == 2:
+ be_fourclass_data()
+ elif args.choice == 3:
+ stats()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/federatedscope/gfl/dataset/recsys.py b/federatedscope/gfl/dataset/recsys.py
new file mode 100644
index 000000000..924c568d9
--- /dev/null
+++ b/federatedscope/gfl/dataset/recsys.py
@@ -0,0 +1,169 @@
+import os
+
+import numpy as np
+import os.path as osp
+import networkx as nx
+
+import torch
+from torch_geometric.data import InMemoryDataset, download_url, Data
+from torch_geometric.utils import from_networkx
+
+from federatedscope.gfl.dataset.utils import random_planetoid_splits
+
+
+# RecSys
+def read_mapping(path, filename):
+ mapping = {}
+ with open(os.path.join(path, filename)) as f:
+ for line in f:
+ s = line.strip().split()
+ mapping[int(s[0])] = int(s[1])
+
+ return mapping
+
+
+def partition_by_category(graph, mapping_item2category):
+ partition = {}
+ for key in mapping_item2category:
+ partition[key] = [mapping_item2category[key]]
+ for neighbor in graph.neighbors(key):
+ if neighbor not in partition:
+ partition[neighbor] = []
+ partition[neighbor].append(mapping_item2category[key])
+ return partition
+
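+# A small worked example for partition_by_category (hypothetical ids): with
+# edges {(u1, i1), (u2, i1)} and mapping_item2category = {i1: 0}, the result
+# is {i1: [0], u1: [0], u2: [0]}, i.e. user nodes inherit the categories of
+# the items they interact with.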
+
+def subgraphing(g, partition, mapping_item2category):
+    nodelist = [[] for _ in set(mapping_item2category.keys())]
+    for k, v in partition.items():
+ for category in v:
+ nodelist[category].append(k)
+
+ graphs = []
+ for nodes in nodelist:
+ if len(nodes) < 2:
+ continue
+ graph = nx.subgraph(g, nodes)
+ graphs.append(from_networkx(graph))
+ return graphs
+
+
+def read_RecSys(path, FL=False):
+ mapping_user = read_mapping(path, 'user.dict')
+ mapping_item = read_mapping(path, 'item.dict')
+
+ G = nx.Graph()
+ with open(osp.join(path, 'graph.txt')) as f:
+ for line in f:
+ s = line.strip().split()
+ s = [int(i) for i in s]
+ G.add_edge(mapping_user[s[0]], mapping_item[s[1]], edge_type=s[2])
+ dic = {}
+ for node in G.nodes:
+ dic[node] = node
+ nx.set_node_attributes(G, dic, "index_orig")
+ H = nx.Graph()
+ H.add_nodes_from(sorted(G.nodes(data=True)))
+ H.add_edges_from(G.edges(data=True))
+ G = H
+ if FL:
+ mapping_item2category = read_mapping(path, "category.dict")
+ partition = partition_by_category(G, mapping_item2category)
+ graphs = subgraphing(G, partition, mapping_item2category)
+ return graphs
+ else:
+ return [from_networkx(G)]
+
+
+class RecSys(InMemoryDataset):
+ r"""
+ Arguments:
+ root (string): Root directory where the dataset should be saved.
+ name (string): The name of the dataset (:obj:`"epinions"`, :obj:`"ciao"`).
+ FL (Bool): Federated setting or centralized setting.
+ transform (callable, optional): A function/transform that takes in an
+ :obj:`torch_geometric.data.Data` object and returns a transformed
+ version. The data object will be transformed before every access.
+ (default: :obj:`None`)
+ pre_transform (callable, optional): A function/transform that takes in
+ an :obj:`torch_geometric.data.Data` object and returns a
+ transformed version. The data object will be transformed before
+ being saved to disk. (default: :obj:`None`)
+ """
+ def __init__(self,
+ root,
+ name,
+ FL=False,
+ splits=[0.8, 0.1, 0.1],
+ transform=None,
+ pre_transform=None):
+ self.FL = FL
+ if self.FL:
+ self.name = 'FL' + name
+ else:
+ self.name = name
+ self._customized_splits = splits
+ super().__init__(root, transform, pre_transform)
+ self.data, self.slices = torch.load(self.processed_paths[0])
+
+ @property
+ def raw_file_names(self):
+ names = ['user.dict', 'item.dict', 'category.dict', 'graph.txt']
+ return names
+
+ @property
+ def processed_file_names(self):
+ return ['data.pt']
+
+ @property
+ def raw_dir(self):
+ return osp.join(self.root, self.name, 'raw')
+
+ @property
+ def processed_dir(self):
+ return osp.join(self.root, self.name, 'processed')
+
+ def download(self):
+ # Download to `self.raw_dir`.
+ url = 'https://github.com/FedML-AI/FedGraphNN/tree/main/data/recommender_system'
+ url = osp.join(url, self.name)
+ for name in self.raw_file_names:
+ download_url(f'{url}/{name}', self.raw_dir)
+
+ def process(self):
+ # Read data into huge `Data` list.
+ data_list = read_RecSys(self.raw_dir, self.FL)
+
+ data_list_w_masks = []
+ for data in data_list:
+ if self.name.endswith('epinions'):
+ data.edge_type = data.edge_type - 1
+ if data.num_edges == 0:
+ continue
+ indices = torch.randperm(data.num_edges)
+ data.train_edge_mask = torch.zeros(data.num_edges,
+ dtype=torch.bool)
+ data.train_edge_mask[indices[:round(self._customized_splits[0] *
+ data.num_edges)]] = True
+ data.valid_edge_mask = torch.zeros(data.num_edges,
+ dtype=torch.bool)
+ data.valid_edge_mask[indices[
+ round(self._customized_splits[0] *
+ data.num_edges):round((self._customized_splits[0] +
+ self._customized_splits[1]) *
+ data.num_edges)]] = True
+ data.test_edge_mask = torch.zeros(data.num_edges, dtype=torch.bool)
+ data.test_edge_mask[indices[round((self._customized_splits[0] +
+ self._customized_splits[1]) *
+ data.num_edges):]] = True
+ data_list_w_masks.append(data)
+ data_list = data_list_w_masks
+
+ if self.pre_filter is not None:
+ data_list = [data for data in data_list if self.pre_filter(data)]
+
+ if self.pre_transform is not None:
+ data_list = [self.pre_transform(data) for data in data_list]
+
+ data, slices = self.collate(data_list)
+ torch.save((data, slices), self.processed_paths[0])
diff --git a/federatedscope/gfl/dataset/utils.py b/federatedscope/gfl/dataset/utils.py
new file mode 100644
index 000000000..3f43af354
--- /dev/null
+++ b/federatedscope/gfl/dataset/utils.py
@@ -0,0 +1,53 @@
+import torch
+from torch_geometric.utils import to_networkx
+
+
+def index_to_mask(index, size, device='cpu'):
+ mask = torch.zeros(size, dtype=torch.bool, device=device)
+ mask[index] = 1
+ return mask
+
+
+def random_planetoid_splits(data,
+ num_classes,
+ percls_trn=20,
+ val_lb=500,
+ Flag=0):
+
+ indices = []
+ for i in range(num_classes):
+ index = (data.y == i).nonzero().view(-1)
+ index = index[torch.randperm(index.size(0))]
+ indices.append(index)
+
+ train_index = torch.cat([i[:percls_trn] for i in indices], dim=0)
+
+ if Flag == 0:
+ rest_index = torch.cat([i[percls_trn:] for i in indices], dim=0)
+ rest_index = rest_index[torch.randperm(rest_index.size(0))]
+
+ data.train_mask = index_to_mask(train_index, size=data.num_nodes)
+ data.val_mask = index_to_mask(rest_index[:val_lb], size=data.num_nodes)
+ data.test_mask = index_to_mask(rest_index[val_lb:],
+ size=data.num_nodes)
+ else:
+ val_index = torch.cat(
+ [i[percls_trn:percls_trn + val_lb] for i in indices], dim=0)
+ rest_index = torch.cat([i[percls_trn + val_lb:] for i in indices],
+ dim=0)
+ rest_index = rest_index[torch.randperm(rest_index.size(0))]
+
+ data.train_mask = index_to_mask(train_index, size=data.num_nodes)
+ data.val_mask = index_to_mask(val_index, size=data.num_nodes)
+ data.test_mask = index_to_mask(rest_index, size=data.num_nodes)
+ return data
+
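+# A minimal usage sketch (illustrative): with Flag=0, each class contributes
+# `percls_trn` training nodes, the next `val_lb` shuffled remaining nodes
+# form the validation set, and all other nodes become the test set:
+#   >>> data = random_planetoid_splits(data, num_classes=7)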
+
+def get_maxDegree(graphs):
+ maxdegree = 0
+ for i, graph in enumerate(graphs):
+ g = to_networkx(graph, to_undirected=True)
+ gdegree = max(dict(g.degree).values())
+ if gdegree > maxdegree:
+ maxdegree = gdegree
+ return maxdegree
diff --git a/federatedscope/gfl/fedsageplus/__init__.py b/federatedscope/gfl/fedsageplus/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/gfl/fedsageplus/trainer.py b/federatedscope/gfl/fedsageplus/trainer.py
new file mode 100644
index 000000000..86d2334a8
--- /dev/null
+++ b/federatedscope/gfl/fedsageplus/trainer.py
@@ -0,0 +1,148 @@
+import torch
+import copy
+import numpy as np
+import torch.nn.functional as F
+
+from federatedscope.gfl.loss import GreedyLoss
+from federatedscope.gfl.trainer.nodetrainer import NodeFullBatchTrainer
+
+
+class LocalGenTrainer(NodeFullBatchTrainer):
+ def __init__(self,
+ model,
+ data,
+ device,
+ config,
+ only_for_eval=False,
+ monitor=None):
+ super(LocalGenTrainer, self).__init__(model, data, device, config,
+ only_for_eval, monitor)
+ self.criterion_num = F.smooth_l1_loss
+ self.criterion_feat = GreedyLoss
+
+ def _hook_on_batch_forward(self, ctx):
+ batch = ctx.data_batch.to(ctx.device)
+ mask = batch['{}_mask'.format(ctx.cur_mode)]
+ pred_missing, pred_feat, nc_pred = ctx.model(batch)
+ pred_missing, pred_feat, nc_pred = pred_missing[mask], pred_feat[
+ mask], nc_pred[mask]
+ loss_num = self.criterion_num(pred_missing, batch.num_missing[mask])
+ loss_feat = self.criterion_feat(
+ pred_feats=pred_feat,
+ true_feats=batch.x_missing[mask],
+ pred_missing=pred_missing,
+ true_missing=batch.num_missing[mask],
+ num_pred=self.cfg.fedsageplus.num_pred).requires_grad_()
+ loss_clf = ctx.criterion(nc_pred, batch.y[mask])
+ ctx.batch_size = torch.sum(mask).item()
+ ctx.loss_batch = (self.cfg.fedsageplus.a * loss_num +
+ self.cfg.fedsageplus.b * loss_feat +
+ self.cfg.fedsageplus.c * loss_clf).float()
+
+ ctx.y_true = batch.num_missing[mask]
+ ctx.y_prob = pred_missing
+
+
+class FedGenTrainer(LocalGenTrainer):
+ def _hook_on_batch_forward(self, ctx):
+ batch = ctx.data_batch.to(ctx.device)
+ mask = batch['{}_mask'.format(ctx.cur_mode)]
+ pred_missing, pred_feat, nc_pred = ctx.model(batch)
+ pred_missing, pred_feat, nc_pred = pred_missing[mask], pred_feat[
+ mask], nc_pred[mask]
+ loss_num = self.criterion_num(pred_missing, batch.num_missing[mask])
+ loss_feat = self.criterion_feat(pred_feats=pred_feat,
+ true_feats=batch.x_missing[mask],
+ pred_missing=pred_missing,
+ true_missing=batch.num_missing[mask],
+ num_pred=self.cfg.fedsageplus.num_pred)
+ loss_clf = ctx.criterion(nc_pred, batch.y[mask])
+ ctx.batch_size = torch.sum(mask).item()
+ ctx.loss_batch = (self.cfg.fedsageplus.a * loss_num +
+ self.cfg.fedsageplus.b * loss_feat +
+ self.cfg.fedsageplus.c *
+ loss_clf).float() / self.cfg.federate.client_num
+
+ ctx.y_true = batch.num_missing[mask]
+ ctx.y_prob = pred_missing
+
+ def update_by_grad(self, grads):
+ """
+ Arguments:
+            grads: gradients from other clients to optimize the local model
+        :returns:
+            state_dict of the generation model
+ """
+ for key in grads.keys():
+ if isinstance(grads[key], list):
+ grads[key] = torch.FloatTensor(grads[key]).to(self.ctx.device)
+
+ for key, value in self.ctx.model.named_parameters():
+ value.grad += grads[key]
+ self.ctx.optimizer.step()
+ return self.ctx.model.cpu().state_dict()
+
+ def cal_grad(self, raw_data, model_para, embedding, true_missing):
+ """
+ Arguments:
+            raw_data (PyG.Data): raw graph
+            model_para: model parameters
+            embedding: output embeddings after the local encoder
+            true_missing: number of missing nodes
+        :returns:
+            grads: gradients to optimize the models of other clients
+ """
+ para_backup = copy.deepcopy(self.ctx.model.cpu().state_dict())
+
+ for key in model_para.keys():
+ if isinstance(model_para[key], list):
+ model_para[key] = torch.FloatTensor(model_para[key])
+ self.ctx.model.load_state_dict(model_para)
+ self.ctx.model = self.ctx.model.to(self.ctx.device)
+ self.ctx.model.train()
+
+ raw_data = raw_data.to(self.ctx.device)
+ embedding = torch.FloatTensor(embedding).to(self.ctx.device)
+ true_missing = true_missing.long().to(self.ctx.device)
+ pred_missing = self.ctx.model.reg_model(embedding)
+ pred_feat = self.ctx.model.gen(embedding)
+
+ # Random pick node and compare its neighbors with predicted nodes
+ choice = np.random.choice(raw_data.num_nodes, embedding.shape[0])
+ global_target_feat = []
+ for c_i in choice:
+ neighbors_ids = raw_data.edge_index[1][torch.where(
+ raw_data.edge_index[0] == c_i)[0]]
+ while len(neighbors_ids) == 0:
+ id_i = np.random.choice(raw_data.num_nodes, 1)[0]
+ neighbors_ids = raw_data.edge_index[1][torch.where(
+ raw_data.edge_index[0] == id_i)[0]]
+ choice_i = np.random.choice(neighbors_ids.detach().cpu().numpy(),
+ self.cfg.fedsageplus.num_pred)
+ for ch_i in choice_i:
+ global_target_feat.append(
+ raw_data.x[ch_i].detach().cpu().numpy())
+ global_target_feat = np.asarray(global_target_feat).reshape(
+ (embedding.shape[0], self.cfg.fedsageplus.num_pred,
+ raw_data.num_node_features))
+ loss_feat = self.criterion_feat(pred_feats=pred_feat,
+ true_feats=global_target_feat,
+ pred_missing=pred_missing,
+ true_missing=true_missing,
+ num_pred=self.cfg.fedsageplus.num_pred)
+ loss = self.cfg.fedsageplus.b * loss_feat
+ loss = (1.0 / self.cfg.federate.client_num * loss).requires_grad_()
+ loss.backward()
+ grads = {
+ key: value.grad
+ for key, value in self.ctx.model.named_parameters()
+ }
+ # Rollback
+ self.ctx.model.load_state_dict(para_backup)
+ return grads
+
+ @torch.no_grad()
+ def embedding(self):
+ model = self.ctx.model.to(self.ctx.device)
+ data = self.ctx.data.to(self.ctx.device)
+ return model.encoder_model(data).to('cpu')
diff --git a/federatedscope/gfl/fedsageplus/utils.py b/federatedscope/gfl/fedsageplus/utils.py
new file mode 100644
index 000000000..ba7f1880c
--- /dev/null
+++ b/federatedscope/gfl/fedsageplus/utils.py
@@ -0,0 +1,134 @@
+import torch
+
+from torch_geometric.data import Data
+from torch_geometric.transforms import BaseTransform
+from torch_geometric.utils import to_networkx, from_networkx
+
+import networkx as nx
+import numpy as np
+
+from federatedscope.core.configs.config import global_cfg
+
+
+class HideGraph(BaseTransform):
+ r"""
+ Generate impaired graph with labels and features to train NeighGen,
+ hide Node from validation set from raw graph.
+
+ Arguments:
+ hidden_portion (int): hidden_portion of validation set.
+ num_pred (int): hyperparameters which limit the maximum value of the prediction
+
+ :returns:
+ filled_data : impaired graph with attribute "num_missing"
+ :rtype:
+ nx.Graph
+ """
+ def __init__(self, hidden_portion=0.5, num_pred=5):
+ self.hidden_portion = hidden_portion
+ self.num_pred = num_pred
+
+ def __call__(self, data):
+
+ val_ids = torch.where(data.val_mask == True)[0]
+ hide_ids = np.random.choice(val_ids,
+ int(len(val_ids) * self.hidden_portion),
+ replace=False)
+ remaining_mask = torch.ones(data.num_nodes, dtype=torch.bool)
+ remaining_mask[hide_ids] = False
+ remaining_nodes = torch.where(remaining_mask == True)[0].numpy()
+
+ data.ids_missing = [[] for _ in range(data.num_nodes)]
+
+ G = to_networkx(data,
+ node_attrs=[
+ 'x', 'y', 'train_mask', 'val_mask', 'test_mask',
+ 'index_orig', 'ids_missing'
+ ],
+ to_undirected=True)
+
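+        # For every hidden node, record it in each remaining neighbor's
+        # ids_missing list; these lists later yield the regression target
+        # num_missing and the feature-generation target x_missing.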
+ for missing_node in hide_ids:
+ neighbors = G.neighbors(missing_node)
+ for i in neighbors:
+ G.nodes[i]['ids_missing'].append(missing_node)
+ for i in G.nodes:
+ ids_missing = G.nodes[i]['ids_missing']
+ del G.nodes[i]['ids_missing']
+ G.nodes[i]['num_missing'] = np.array([len(ids_missing)],
+ dtype=np.float32)
+ if len(ids_missing) > 0:
+ if len(ids_missing) <= self.num_pred:
+ G.nodes[i]['x_missing'] = np.vstack(
+ (data.x[ids_missing],
+ np.zeros((self.num_pred - len(ids_missing),
+ data.x.shape[1]))))
+ else:
+ G.nodes[i]['x_missing'] = data.x[
+ ids_missing[:self.num_pred]]
+ else:
+ G.nodes[i]['x_missing'] = np.zeros(
+ (self.num_pred, data.x.shape[1]))
+
+ return from_networkx(nx.subgraph(G, remaining_nodes))
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}({self.hidden_portion})'
+
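+# A minimal usage sketch (illustrative): hide half of the validation nodes
+# and keep at most 5 ground-truth neighbors per node as generation targets:
+#   >>> impaired = HideGraph(hidden_portion=0.5, num_pred=5)(data)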
+
+def FillGraph(impaired_data, original_data, pred_missing, pred_feats,
+ num_pred):
+ # Mend the original data
+ original_data = original_data.detach().cpu()
+ new_features = original_data.x
+ new_edge_index = original_data.edge_index.T
+ pred_missing = pred_missing.detach().cpu().numpy()
+ pred_feats = pred_feats.detach().cpu().reshape(
+ (-1, num_pred, original_data.num_node_features))
+
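+    # For every node with predicted missing neighbors, append one synthetic
+    # node per prediction (up to num_pred) carrying the generated features,
+    # and wire it to the corresponding node in the original graph.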
+ start_id = original_data.num_nodes
+ for node in range(len(pred_missing)):
+ num_fill_node = np.around(pred_missing[node]).astype(np.int32).item()
+ if num_fill_node > 0:
+ new_ids_i = np.arange(start_id,
+ start_id + min(num_pred, num_fill_node))
+ org_id = impaired_data.index_orig[node]
+ org_node = torch.where(
+ original_data.index_orig == org_id)[0].item()
+ new_edges = torch.tensor([[org_node, fill_id]
+ for fill_id in new_ids_i],
+ dtype=torch.int64)
+ new_features = torch.vstack(
+ (new_features, pred_feats[node][:num_fill_node]))
+ new_edge_index = torch.vstack((new_edge_index, new_edges))
+ start_id = start_id + min(num_pred, num_fill_node)
+ new_y = torch.zeros(new_features.shape[0], dtype=torch.int64)
+ new_y[:original_data.num_nodes] = original_data.y
+ filled_data = Data(
+ x=new_features,
+ edge_index=new_edge_index.T,
+ train_idx=torch.where(original_data.train_mask == True)[0],
+ valid_idx=torch.where(original_data.val_mask == True)[0],
+ test_idx=torch.where(original_data.test_mask == True)[0],
+ y=new_y,
+ )
+ return filled_data
+
+
+@torch.no_grad()
+def GraphMender(model, impaired_data, original_data):
+ r"""Mend the graph with generation model
+ Arguments:
+ model (torch.nn.module): trained generation model
+ impaired_data (PyG.Data): impaired graph
+ original_data (PyG.Data): raw graph
+ :returns:
+ filled_data : Graph after Data Enhancement
+ :rtype:
+ PyG.data
+ """
+ device = impaired_data.x.device
+ model = model.to(device)
+ pred_missing, pred_feats, _ = model(impaired_data)
+
+ return FillGraph(impaired_data, original_data, pred_missing, pred_feats,
+ global_cfg.fedsageplus.num_pred)
\ No newline at end of file
diff --git a/federatedscope/gfl/fedsageplus/worker.py b/federatedscope/gfl/fedsageplus/worker.py
new file mode 100644
index 000000000..cafc4ccbe
--- /dev/null
+++ b/federatedscope/gfl/fedsageplus/worker.py
@@ -0,0 +1,392 @@
+import torch
+import logging
+
+from torch_geometric.loader import NeighborSampler
+
+from federatedscope.core.message import Message
+from federatedscope.core.worker.server import Server
+from federatedscope.core.worker.client import Client
+from federatedscope.core.auxiliaries.utils import merge_dict
+
+from federatedscope.gfl.trainer.nodetrainer import NodeMiniBatchTrainer
+from federatedscope.gfl.model.fedsageplus import LocalSage_Plus, FedSage_Plus
+from federatedscope.gfl.fedsageplus.utils import GraphMender, HideGraph
+from federatedscope.gfl.fedsageplus.trainer import LocalGenTrainer, FedGenTrainer
+
+logger = logging.getLogger(__name__)
+
+
+class FedSagePlusServer(Server):
+ def __init__(self,
+ ID=-1,
+ state=0,
+ config=None,
+ data=None,
+ model=None,
+ client_num=5,
+ total_round_num=10,
+ device='cpu',
+ strategy=None,
+ **kwargs):
+ r"""
+        FedSage+ consists of three training stages.
+        Stage 1: state 0, local pre-training of the generator.
+        Stage 2: states up to 2 * fedgen_epoch, federated training of the generator.
+        Stage 3: states up to 2 * fedgen_epoch + total_round_num, federated training of the GraphSAGE classifier.
+ """
+ super(FedSagePlusServer,
+ self).__init__(ID, state, config, data, model, client_num,
+ total_round_num, device, strategy, **kwargs)
+
+        assert self.model_num == 1, "FedSagePlusServer does not support multiple models"
+
+ # If state < fedgen_epoch and state % 2 == 0:
+ # Server receive [model, embedding, label]
+ # If state < fedgen_epoch and state % 2 == 1:
+ # Server receive [gradient]
+ self.fedgen_epoch = 2 * self._cfg.fedsageplus.fedgen_epoch
+ self.total_round_num = total_round_num + self.fedgen_epoch
+ self.grad_cnt = 0
+
+ def _register_default_handlers(self):
+ self.register_handlers('join_in', self.callback_funcs_for_join_in)
+ self.register_handlers('join_in_info', self.callback_funcs_for_join_in)
+ self.register_handlers('clf_para', self.callback_funcs_model_para)
+ self.register_handlers('gen_para', self.callback_funcs_model_para)
+ self.register_handlers('gradient', self.callback_funcs_gradient)
+ self.register_handlers('metrics', self.callback_funcs_for_metrics)
+
+ def callback_funcs_for_join_in(self, message: Message):
+ if 'info' in message.msg_type:
+ sender, info = message.sender, message.content
+ for key in self._cfg.federate.join_in_info:
+ assert key in info
+ self.join_in_info[sender] = info
+ logger.info('Server #{:d}: Client #{:d} has joined in !'.format(
+ self.ID, sender))
+ else:
+ self.join_in_client_num += 1
+ sender, address = message.sender, message.content
+ if int(sender) == -1: # assign number to client
+ sender = self.join_in_client_num
+ self.comm_manager.add_neighbors(neighbor_id=sender,
+ address=address)
+ self.comm_manager.send(
+ Message(msg_type='assign_client_id',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=str(sender)))
+ else:
+ self.comm_manager.add_neighbors(neighbor_id=sender,
+ address=address)
+
+ if len(self._cfg.federate.join_in_info) != 0:
+ self.comm_manager.send(
+ Message(msg_type='ask_for_join_in_info',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=self._cfg.federate.join_in_info.copy()))
+
+ if self.check_client_join_in():
+ if self._cfg.federate.use_ss:
+ self.broadcast_client_address()
+
+ self.comm_manager.send(
+ Message(msg_type='local_pretrain',
+ sender=self.ID,
+ receiver=list(self.comm_manager.neighbors.keys()),
+ state=self.state))
+
+ def callback_funcs_gradient(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ gen_grad, ID = content
+ # For a new round
+ if round not in self.msg_buffer['train'].keys():
+ self.msg_buffer['train'][round] = dict()
+ self.grad_cnt += 1
+ # Sum up all grad from other client
+ if ID not in self.msg_buffer['train'][round]:
+ self.msg_buffer['train'][round][ID] = dict()
+ for key in gen_grad.keys():
+ self.msg_buffer['train'][round][ID][key] = torch.FloatTensor(
+ gen_grad[key].cpu())
+ else:
+ for key in gen_grad.keys():
+ self.msg_buffer['train'][round][ID][key] += torch.FloatTensor(
+ gen_grad[key].cpu())
+ self.check_and_move_on()
+
+ def check_and_move_on(self, check_eval_result=False):
+ client_IDs = [i for i in range(1, self.client_num + 1)]
+
+ if check_eval_result:
+ # all clients are participating in evaluation
+ minimal_number = self.client_num
+ else:
+ # sampled clients are participating in training
+ minimal_number = self.sample_client_num
+
+ # Transmit model and embedding to get gradient back
+ if self.check_buffer(
+ self.state, self.client_num
+ ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state % 2 == 0:
+ # FedGen: we should wait for all messages
+ for sender in self.msg_buffer['train'][self.state]:
+ content = self.msg_buffer['train'][self.state][sender]
+ gen_para, embedding, label = content
+ receiver_IDs = client_IDs[:sender - 1] + client_IDs[sender:]
+ self.comm_manager.send(
+ Message(msg_type='gen_para',
+ sender=self.ID,
+ receiver=receiver_IDs,
+ state=self.state + 1,
+ content=[gen_para, embedding, label, sender]))
+ logger.info(
+ f'\tServer #{self.ID}: Transmit gen_para to {receiver_IDs} @{self.state//2}.'
+ )
+ self.state += 1
+
+ # Sum up gradient client-wisely and send back
+ if self.check_buffer(
+ self.state, self.client_num
+ ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state % 2 == 1 and self.grad_cnt == self.client_num * (
+ self.client_num - 1):
+ for ID in self.msg_buffer['train'][self.state]:
+ grad = self.msg_buffer['train'][self.state][ID]
+ self.comm_manager.send(
+ Message(msg_type='gradient',
+ sender=self.ID,
+ receiver=[ID],
+ state=self.state + 1,
+ content=grad))
+ # reset num of grad counter
+ self.grad_cnt = 0
+ self.state += 1
+
+ if self.check_buffer(
+ self.state, self.client_num
+ ) and self.state == self._cfg.fedsageplus.fedgen_epoch:
+ self.state += 1
+ # Setup Clf_trainer for each client
+ self.comm_manager.send(
+ Message(msg_type='setup',
+ sender=self.ID,
+ receiver=list(self.comm_manager.neighbors.keys()),
+ state=self.state))
+
+ if self.check_buffer(
+ self.state, minimal_number, check_eval_result
+ ) and self.state >= self._cfg.fedsageplus.fedgen_epoch:
+
+ if not check_eval_result: # in the training process
+ # Get all the message
+ train_msg_buffer = self.msg_buffer['train'][self.state]
+ msg_list = list()
+ for client_id in train_msg_buffer:
+ msg_list.append(train_msg_buffer[client_id])
+
+ # Trigger the monitor here (for training)
+ if 'dissim' in self._cfg.eval.monitoring:
+                    B_val = self._monitor.calc_blocal_dissim(
+                        self.model.state_dict(), msg_list)
+ formatted_logs = self._monitor.format_eval_res(
+ B_val, rnd=self.state, role='Server #')
+ logger.info(formatted_logs)
+
+ # Aggregate
+ agg_info = {
+ 'client_feedback': msg_list,
+ 'recover_fun': self.recover_fun
+ }
+ result = self.aggregator.aggregate(agg_info)
+ self.model.load_state_dict(result)
+ self.aggregator.update(result)
+
+ self.state += 1
+ if self.state % self._cfg.eval.freq == 0 and self.state != self.total_round_num:
+ # Evaluate
+ logger.info(
+ 'Server #{:d}: Starting evaluation at round {:d}.'.
+ format(self.ID, self.state))
+ self.eval()
+
+ if self.state < self.total_round_num:
+ # Move to next round of training
+ logger.info(
+ '----------- Starting a new training round (Round #{:d}) -------------'
+ .format(self.state))
+ self.broadcast_model_para(
+ msg_type='model_para',
+ sample_client_num=self.sample_client_num)
+ else:
+ # Final Evaluate
+ logger.info(
+ 'Server #{:d}: Training is finished! Starting evaluation.'
+ .format(self.ID))
+ self.eval()
+
+ else: # in the evaluation process
+ # Get all the message & aggregate
+ formatted_eval_res = self.merge_eval_results_from_all_clients()
+ self.history_results = merge_dict(self.history_results,
+ formatted_eval_res)
+ self.check_and_save()
+
+
+class FedSagePlusClient(Client):
+ def __init__(self,
+ ID=-1,
+ server_id=None,
+ state=-1,
+ config=None,
+ data=None,
+ model=None,
+ device='cpu',
+ strategy=None,
+ *args,
+ **kwargs):
+ super(FedSagePlusClient,
+ self).__init__(ID, server_id, state, config, data, model, device,
+ strategy, *args, **kwargs)
+ self.data = data
+ self.hide_data = HideGraph(self._cfg.fedsageplus.hide_portion)(data)
+ self.device = device
+ self.sage_batch_size = 64
+ self.gen = LocalSage_Plus(data.x.shape[-1],
+ self._cfg.model.out_channels,
+ hidden=self._cfg.model.hidden,
+ gen_hidden=self._cfg.fedsageplus.gen_hidden,
+ dropout=self._cfg.model.dropout,
+ num_pred=self._cfg.fedsageplus.num_pred)
+ self.clf = model
+ self.trainer_loc = LocalGenTrainer(self.gen,
+ self.hide_data,
+ self.device,
+ self._cfg,
+ monitor=self._monitor)
+
+ self.register_handlers('clf_para', self.callback_funcs_for_model_para)
+ self.register_handlers('local_pretrain',
+ self.callback_funcs_for_local_pre_train)
+ self.register_handlers('gradient', self.callback_funcs_for_gradient)
+ self.register_handlers('gen_para', self.callback_funcs_for_gen_para)
+ self.register_handlers('setup', self.callback_funcs_for_setup_fedsage)
+
+ def callback_funcs_for_local_pre_train(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ # Local pre-train
+ logger.info(f'\tClient #{self.ID} pre-train start...')
+ for i in range(self._cfg.fedsageplus.loc_epoch):
+ num_samples_train, _, _ = self.trainer_loc.train()
+ logger.info(f'\tClient #{self.ID} local pre-train @Epoch {i}.')
+ # Build fedgen base on locgen
+ self.fedgen = FedSage_Plus(self.gen)
+ # Build trainer for fedgen
+ self.trainer_fedgen = FedGenTrainer(self.fedgen,
+ self.hide_data,
+ self.device,
+ self._cfg,
+ monitor=self._monitor)
+
+ gen_para = self.fedgen.cpu().state_dict()
+ embedding = self.trainer_fedgen.embedding()
+ self.state = round
+ logger.info(f'\tClient #{self.ID} pre-train finish!')
+ # Start the training of fedgen
+ self.comm_manager.send(
+ Message(msg_type='gen_para',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=[gen_para, embedding, self.hide_data.num_missing]))
+ logger.info(f'\tClient #{self.ID} send gen_para to Server #{sender}.')
+
+ def callback_funcs_for_gen_para(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ gen_para, embedding, label, ID = content
+
+ gen_grad = self.trainer_fedgen.cal_grad(self.data, gen_para, embedding,
+ label)
+ self.state = round
+ self.comm_manager.send(
+ Message(msg_type='gradient',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=[gen_grad, ID]))
+ logger.info(f'\tClient #{self.ID}: send gradient to Server #{sender}.')
+
+ def callback_funcs_for_gradient(self, message):
+ # Aggregate gen_grad on server
+ round, sender, content = message.state, message.sender, message.content
+ gen_grad = content
+ self.trainer_fedgen.train()
+ gen_para = self.trainer_fedgen.update_by_grad(gen_grad)
+ embedding = self.trainer_fedgen.embedding()
+ self.state = round
+ self.comm_manager.send(
+ Message(msg_type='gen_para',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=[gen_para, embedding, self.hide_data.num_missing]))
+ logger.info(f'\tClient #{self.ID}: send gen_para to Server #{sender}.')
+
+ def callback_funcs_for_setup_fedsage(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ self.filled_data = GraphMender(model=self.fedgen,
+ impaired_data=self.hide_data.cpu(),
+ original_data=self.data)
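+        # A NeighborSampler with sizes=[-1] takes the full neighborhood, so
+        # validation and test are evaluated on the complete filled graph.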
+ subgraph_sampler = NeighborSampler(
+ self.filled_data.edge_index,
+ sizes=[-1],
+ batch_size=4096,
+ shuffle=False,
+ num_workers=self._cfg.data.num_workers)
+ fill_dataloader = {
+ 'data': self.filled_data,
+ 'train': NeighborSampler(self.filled_data.edge_index,
+ node_idx=self.filled_data.train_idx,
+ sizes=self._cfg.data.sizes,
+ batch_size=self.sage_batch_size,
+ shuffle=self._cfg.data.shuffle,
+ num_workers=self._cfg.data.num_workers),
+ 'val': subgraph_sampler,
+ 'test': subgraph_sampler
+ }
+ self._cfg.merge_from_list(['data.batch_size', self.sage_batch_size])
+ self.trainer_clf = NodeMiniBatchTrainer(self.clf,
+ fill_dataloader,
+ self.device,
+ self._cfg,
+ monitor=self._monitor)
+ sample_size, clf_para, results = self.trainer_clf.train()
+ self.state = round
+ logger.info(
+ self._monitor.format_eval_res(results,
+ rnd=self.state,
+ role='Client #{}'.format(self.ID)))
+ self.comm_manager.send(
+ Message(msg_type='clf_para',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=(sample_size, clf_para)))
+
+ def callback_funcs_for_model_para(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ self.trainer_clf.update(content)
+ self.state = round
+ sample_size, clf_para, results = self.trainer_clf.train()
+ logger.info(
+ self._monitor.format_eval_res(results,
+ rnd=self.state,
+ role='Client #{}'.format(self.ID)))
+ self.comm_manager.send(
+ Message(msg_type='clf_para',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=(sample_size, clf_para)))
diff --git a/federatedscope/gfl/flitplus/__init__.py b/federatedscope/gfl/flitplus/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/gfl/flitplus/trainer.py b/federatedscope/gfl/flitplus/trainer.py
new file mode 100644
index 000000000..65a45ba2e
--- /dev/null
+++ b/federatedscope/gfl/flitplus/trainer.py
@@ -0,0 +1,264 @@
+import torch
+from copy import deepcopy
+
+from federatedscope.gfl.loss.vat import VATLoss
+from federatedscope.core.trainers.trainer import GeneralTorchTrainer
+
+
+class FLITTrainer(GeneralTorchTrainer):
+ def register_default_hooks_train(self):
+ super(FLITTrainer, self).register_default_hooks_train()
+ self.register_hook_in_train(new_hook=record_initialization_local,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ self.register_hook_in_train(new_hook=del_initialization_local,
+ trigger='on_fit_end',
+ insert_pos=-1)
+ self.register_hook_in_train(new_hook=record_initialization_global,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ self.register_hook_in_train(new_hook=del_initialization_global,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ def register_default_hooks_eval(self):
+ super(FLITTrainer, self).register_default_hooks_eval()
+ self.register_hook_in_eval(new_hook=record_initialization_local,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ self.register_hook_in_eval(new_hook=del_initialization_local,
+ trigger='on_fit_end',
+ insert_pos=-1)
+ self.register_hook_in_eval(new_hook=record_initialization_global,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ self.register_hook_in_eval(new_hook=del_initialization_global,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ def _hook_on_batch_forward(self, ctx):
+ batch = ctx.data_batch.to(ctx.device)
+ pred = ctx.model(batch)
+ ctx.global_model.to(ctx.device)
+ predG = ctx.global_model(batch)
+ if ctx.criterion._get_name() == 'CrossEntropyLoss':
+ label = batch.y.squeeze(-1).long()
+ elif ctx.criterion._get_name() == 'MSELoss':
+ label = batch.y.float()
+ else:
+            raise ValueError(
+                f'FLIT trainer does not support {ctx.criterion._get_name()}.')
+ if len(label.size()) == 0:
+ label = label.unsqueeze(0)
+
+ lossGlobalLabel = ctx.criterion(predG, label)
+ lossLocalLabel = ctx.criterion(pred, label)
+
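+        # FLIT re-weighting: samples whose local loss exceeds the (detached)
+        # global-model loss receive a larger weight, normalized by an
+        # exponential moving average kept in ctx.weight_denominator.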
+        weightloss = lossLocalLabel + torch.relu(lossLocalLabel -
+                                                 lossGlobalLabel.detach())
+        if ctx.weight_denominator is None:
+            ctx.weight_denominator = weightloss.mean(dim=0,
+                                                     keepdim=True).detach()
+        else:
+            ctx.weight_denominator = self.cfg.flitplus.factor_ema * ctx.weight_denominator + (
+                1 - self.cfg.flitplus.factor_ema) * weightloss.mean(
+                    dim=0, keepdim=True).detach()
+
+        loss = (1 - torch.exp(-weightloss / (ctx.weight_denominator + 1e-7)) +
+                1e-7)**self.cfg.flitplus.tmpFed * lossLocalLabel
+        ctx.loss_batch = loss.mean()
+
+ ctx.batch_size = len(label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+
+class FLITPlusTrainer(FLITTrainer):
+ def _hook_on_batch_forward(self, ctx):
+        # LDS (local distributional smoothness, i.e. the VAT loss) should be
+        # calculated before the forward pass for cross entropy
+ batch = ctx.data_batch.to(ctx.device)
+ ctx.global_model.to(ctx.device)
+ if ctx.cur_mode == 'test':
+ lossLocalVAT, lossGlobalVAT = torch.tensor(0.), torch.tensor(0.)
+ else:
+ vat_loss = VATLoss() # xi, and eps
+ lossLocalVAT = vat_loss(deepcopy(ctx.model), batch,
+ deepcopy(ctx.criterion))
+ lossGlobalVAT = vat_loss(deepcopy(ctx.global_model), batch,
+ deepcopy(ctx.criterion))
+
+ pred = ctx.model(batch)
+ predG = ctx.global_model(batch)
+ if ctx.criterion._get_name() == 'CrossEntropyLoss':
+ label = batch.y.squeeze(-1).long()
+ elif ctx.criterion._get_name() == 'MSELoss':
+ label = batch.y.float()
+ else:
+            raise ValueError(
+                f'FLITPLUS trainer does not support {ctx.criterion._get_name()}.')
+ if len(label.size()) == 0:
+ label = label.unsqueeze(0)
+ lossGlobalLabel = ctx.criterion(predG, label)
+ lossLocalLabel = ctx.criterion(pred, label)
+
+ weightloss_loss = lossLocalLabel + torch.relu(lossLocalLabel -
+ lossGlobalLabel.detach())
+ weightloss_vat = (lossLocalVAT +
+ torch.relu(lossLocalVAT - lossGlobalVAT.detach()))
+ weightloss = weightloss_loss + self.cfg.flitplus.lambdavat * weightloss_vat
+        if ctx.weight_denominator is None:
+            ctx.weight_denominator = weightloss.mean(dim=0,
+                                                     keepdim=True).detach()
+        else:
+            ctx.weight_denominator = self.cfg.flitplus.factor_ema * ctx.weight_denominator + (
+                1 - self.cfg.flitplus.factor_ema) * weightloss.mean(
+                    dim=0, keepdim=True).detach()
+
+        loss = (1 - torch.exp(-weightloss / (ctx.weight_denominator + 1e-7)) +
+                1e-7)**self.cfg.flitplus.tmpFed * (
+                    lossLocalLabel +
+                    self.cfg.flitplus.weightReg * lossLocalVAT)
+        ctx.loss_batch = loss.mean()
+
+ ctx.batch_size = len(label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+
+class FedFocalTrainer(GeneralTorchTrainer):
+ def register_default_hooks_train(self):
+ super(FedFocalTrainer, self).register_default_hooks_train()
+ self.register_hook_in_train(new_hook=record_initialization_local,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ self.register_hook_in_train(new_hook=del_initialization_local,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ def register_default_hooks_eval(self):
+ super(FedFocalTrainer, self).register_default_hooks_eval()
+ self.register_hook_in_eval(new_hook=record_initialization_local,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ self.register_hook_in_eval(new_hook=del_initialization_local,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ def _hook_on_batch_forward(self, ctx):
+ batch = ctx.data_batch.to(ctx.device)
+ pred = ctx.model(batch)
+ if ctx.criterion._get_name() == 'CrossEntropyLoss':
+ label = batch.y.squeeze(-1).long()
+ elif ctx.criterion._get_name() == 'MSELoss':
+ label = batch.y.float()
+ else:
+            raise ValueError(
+                f'FedFocal trainer does not support {ctx.criterion._get_name()}.')
+ if len(label.size()) == 0:
+ label = label.unsqueeze(0)
+
+ lossLocalLabel = ctx.criterion(pred, label)
+ weightloss = lossLocalLabel
+        if ctx.weight_denominator is None:
+            ctx.weight_denominator = weightloss.mean(dim=0,
+                                                     keepdim=True).detach()
+        else:
+            ctx.weight_denominator = self.cfg.flitplus.factor_ema * ctx.weight_denominator + (
+                1 - self.cfg.flitplus.factor_ema) * weightloss.mean(
+                    dim=0, keepdim=True).detach()
+
+        loss = (1 - torch.exp(-weightloss / (ctx.weight_denominator + 1e-7)) +
+                1e-7)**self.cfg.flitplus.tmpFed * lossLocalLabel
+        ctx.loss_batch = loss.mean()
+
+ ctx.batch_size = len(label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+
+class FedVATTrainer(GeneralTorchTrainer):
+ def register_default_hooks_train(self):
+ super(FedVATTrainer, self).register_default_hooks_train()
+ self.register_hook_in_train(new_hook=record_initialization_local,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ self.register_hook_in_train(new_hook=del_initialization_local,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ def register_default_hooks_eval(self):
+ super(FedVATTrainer, self).register_default_hooks_eval()
+ self.register_hook_in_eval(new_hook=record_initialization_local,
+ trigger='on_fit_start',
+ insert_pos=-1)
+ self.register_hook_in_eval(new_hook=del_initialization_local,
+ trigger='on_fit_end',
+ insert_pos=-1)
+
+ def _hook_on_batch_forward(self, ctx):
+ batch = ctx.data_batch.to(ctx.device)
+ if ctx.cur_mode == 'test':
+ lossLocalVAT = torch.tensor(0.)
+ else:
+ vat_loss = VATLoss() # xi, and eps
+ lossLocalVAT = vat_loss(deepcopy(ctx.model), batch,
+ deepcopy(ctx.criterion))
+
+ pred = ctx.model(batch)
+ if ctx.criterion._get_name() == 'CrossEntropyLoss':
+ label = batch.y.squeeze(-1).long()
+ elif ctx.criterion._get_name() == 'MSELoss':
+ label = batch.y.float()
+ else:
+            raise ValueError(
+                f'FedVAT trainer does not support {ctx.criterion._get_name()}.')
+ if len(label.size()) == 0:
+ label = label.unsqueeze(0)
+ lossLocalLabel = ctx.criterion(pred, label)
+ weightloss = lossLocalLabel + self.cfg.flitplus.lambdavat * lossLocalVAT
+        if ctx.weight_denominator is None:
+            ctx.weight_denominator = weightloss.mean(dim=0,
+                                                     keepdim=True).detach()
+        else:
+            ctx.weight_denominator = self.cfg.flitplus.factor_ema * ctx.weight_denominator + (
+                1 - self.cfg.flitplus.factor_ema) * weightloss.mean(
+                    dim=0, keepdim=True).detach()
+
+        loss = (1 - torch.exp(-weightloss / (ctx.weight_denominator + 1e-7)) +
+                1e-7)**self.cfg.flitplus.tmpFed * (
+                    lossLocalLabel +
+                    self.cfg.flitplus.weightReg * lossLocalVAT)
+        ctx.loss_batch = loss.mean()
+
+ ctx.batch_size = len(label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+
+def record_initialization_local(ctx):
+    """Initialize the weight denominator (an EMA over the sample weights)
+
+    """
+    ctx.weight_denominator = None
+
+
+def del_initialization_local(ctx):
+    """Clear the variable to avoid memory leakage
+
+    """
+    ctx.weight_denominator = None
+
+
+def record_initialization_global(ctx):
+ """Record the shared global model to cpu
+
+ """
+ ctx.global_model = deepcopy(ctx.model)
+ ctx.global_model.to(torch.device("cpu"))
+
+
+def del_initialization_global(ctx):
+ """Clear the variable to avoid memory leakage
+
+ """
+ ctx.global_model = None
diff --git a/federatedscope/gfl/gcflplus/__init__.py b/federatedscope/gfl/gcflplus/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/gfl/gcflplus/utils.py b/federatedscope/gfl/gcflplus/utils.py
new file mode 100644
index 000000000..ba8042033
--- /dev/null
+++ b/federatedscope/gfl/gcflplus/utils.py
@@ -0,0 +1,33 @@
+import torch
+import numpy as np
+import networkx as nx
+from dtaidistance import dtw
+"""
+ Utils from: https://github.com/Oxfordblue7/GCFL
+"""
+
+
+def norm(w):
+ return torch.norm(torch.cat([v.flatten() for v in w.values()])).item()
+
+
+def compute_pairwise_distances(seqs, standardize=False):
+ """ computes DTW distances for gcfl+"""
+ if standardize:
+ # standardize to only focus on the trends
+ seqs = np.array(seqs)
+ seqs = seqs / seqs.std(axis=1).reshape(-1, 1)
+ distances = dtw.distance_matrix(seqs)
+ else:
+ distances = dtw.distance_matrix(seqs)
+ return distances
+
+
+def min_cut(similarity, cluster):
+ g = nx.Graph()
+ for i in range(len(similarity)):
+ for j in range(len(similarity)):
+ g.add_edge(i, j, weight=similarity[i][j])
+ cut, partition = nx.stoer_wagner(g)
+ c1 = np.array([cluster[x] for x in partition[0]])
+ c2 = np.array([cluster[x] for x in partition[1]])
+ return c1, c2
diff --git a/federatedscope/gfl/gcflplus/worker.py b/federatedscope/gfl/gcflplus/worker.py
new file mode 100644
index 000000000..e69f66ee9
--- /dev/null
+++ b/federatedscope/gfl/gcflplus/worker.py
@@ -0,0 +1,208 @@
+import torch
+import logging
+import copy
+import numpy as np
+
+from federatedscope.core.message import Message
+from federatedscope.core.worker.server import Server
+from federatedscope.core.worker.client import Client
+from federatedscope.core.auxiliaries.utils import merge_dict
+from federatedscope.gfl.gcflplus.utils import compute_pairwise_distances, min_cut, norm
+
+logger = logging.getLogger(__name__)
+
+
+class GCFLPlusServer(Server):
+ def __init__(self,
+ ID=-1,
+ state=0,
+ config=None,
+ data=None,
+ model=None,
+ client_num=5,
+ total_round_num=10,
+ device='cpu',
+ strategy=None,
+ **kwargs):
+ super(GCFLPlusServer,
+ self).__init__(ID, state, config, data, model, client_num,
+ total_round_num, device, strategy, **kwargs)
+ # Initial cluster
+ self.cluster_indices = [
+ np.arange(1, self._cfg.federate.client_num + 1).astype("int")
+ ]
+ self.client_clusters = [[ID for ID in cluster_id]
+ for cluster_id in self.cluster_indices]
+ # Maintain a grad sequence
+ self.seqs_grads = {
+ idx: []
+ for idx in range(1, self._cfg.federate.client_num + 1)
+ }
+
+ def compute_update_norm(self, cluster):
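+        # Return the largest single-client update norm and the norm of the
+        # cluster-averaged update; both statistics feed the split criterion.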
+ max_norm = -np.inf
+ cluster_dWs = []
+ for key in cluster:
+ content = self.msg_buffer['train'][self.state][key]
+ _, model_para, client_dw, _ = content
+ dW = {}
+ for k in model_para.keys():
+ dW[k] = client_dw[k]
+ update_norm = norm(dW)
+ if update_norm > max_norm:
+ max_norm = update_norm
+ cluster_dWs.append(
+ torch.cat([value.flatten() for value in dW.values()]))
+ mean_norm = torch.norm(torch.mean(torch.stack(cluster_dWs),
+ dim=0)).item()
+ return max_norm, mean_norm
+
+ def check_and_move_on(self, check_eval_result=False):
+
+ if check_eval_result:
+ # all clients are participating in evaluation
+ minimal_number = self.client_num
+ else:
+ # sampled clients are participating in training
+ minimal_number = self.sample_client_num
+
+ if self.check_buffer(self.state, minimal_number, check_eval_result):
+
+ if not check_eval_result: # in the training process
+ # Get all the message
+ train_msg_buffer = self.msg_buffer['train'][self.state]
+ for model_idx in range(self.model_num):
+ model = self.models[model_idx]
+ aggregator = self.aggregators[model_idx]
+ msg_list = list()
+ for client_id in train_msg_buffer:
+ if self.model_num == 1:
+ train_data_size, model_para, _, convGradsNorm = train_msg_buffer[
+ client_id]
+ self.seqs_grads[client_id].append(convGradsNorm)
+ msg_list.append((train_data_size, model_para))
+ else:
+ raise ValueError(
+                                'GCFLPlus server does not support multi-model.')
+
+ cluster_indices_new = []
+ for cluster in self.cluster_indices:
+ max_norm, mean_norm = self.compute_update_norm(cluster)
+ # create new cluster
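+                    # Split when the clients agree on average (small mean
+                    # norm) but diverge individually (large max norm), as in
+                    # the GCFL+ criterion.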
+ if mean_norm < self._cfg.gcflplus.EPS_1 and max_norm > self._cfg.gcflplus.EPS_2 and len(
+ cluster) > 2 and self.state > 20 and all(
+ len(value) >= self._cfg.gcflplus.seq_length
+ for value in self.seqs_grads.values()):
+ _, model_para_cluster, _, _ = self.msg_buffer[
+ 'train'][self.state][cluster[0]]
+ tmp = [
+ self.seqs_grads[ID]
+ [-self._cfg.gcflplus.seq_length:]
+ for ID in cluster
+ ]
+ dtw_distances = compute_pairwise_distances(
+ tmp, self._cfg.gcflplus.standardize)
+ c1, c2 = min_cut(
+ np.max(dtw_distances) - dtw_distances, cluster)
+ cluster_indices_new += [c1, c2]
+ # reset seqs_grads for all clients
+ self.seqs_grads = {
+ idx: []
+ for idx in range(
+ 1, self._cfg.federate.client_num + 1)
+ }
+ # keep this cluster
+ else:
+ cluster_indices_new += [cluster]
+
+ self.cluster_indices = cluster_indices_new
+ self.client_clusters = [[
+ ID for ID in cluster_id
+ ] for cluster_id in self.cluster_indices]
+
+ self.state += 1
+ if self.state % self._cfg.eval.freq == 0 and self.state != self.total_round_num:
+ # Evaluate
+ logger.info(
+ 'Server #{:d}: Starting evaluation at round {:d}.'.
+ format(self.ID, self.state))
+ self.eval()
+
+ if self.state < self.total_round_num:
+ for cluster in self.cluster_indices:
+                    msg_list = list()
+ for key in cluster:
+ content = self.msg_buffer['train'][self.state -
+ 1][key]
+ train_data_size, model_para, client_dw, convGradsNorm = content
+                        msg_list.append((train_data_size, model_para))
+
+ agg_info = {
+ 'client_feedback': msg_list,
+ 'recover_fun': self.recover_fun
+ }
+ result = aggregator.aggregate(agg_info)
+ model.load_state_dict(result, strict=False)
+ # aggregator.update(result)
+ # Send to Clients
+ self.comm_manager.send(
+ Message(msg_type='model_para',
+ sender=self.ID,
+ receiver=cluster.tolist(),
+ state=self.state,
+ content=result))
+
+ # Move to next round of training
+ logger.info(
+ '----------- Starting a new training round (Round #{:d}) -------------'
+ .format(self.state))
+ # Clean the msg_buffer
+ self.msg_buffer['train'][self.state - 1].clear()
+
+ else:
+ # Final Evaluate
+ logger.info(
+ 'Server #{:d}: Training is finished! Starting evaluation.'
+ .format(self.ID))
+ self.eval()
+
+ else: # in the evaluation process
+ # Get all the message & aggregate
+ formatted_eval_res = self.merge_eval_results_from_all_clients()
+ self.history_results = merge_dict(self.history_results,
+ formatted_eval_res)
+ self.check_and_save()
+
+
+class GCFLPlusClient(Client):
+ def callback_funcs_for_model_para(self, message: Message):
+ round, sender, content = message.state, message.sender, message.content
+ # Cache old W
+ W_old = copy.deepcopy(content)
+ self.trainer.update(content)
+ self.state = round
+ sample_size, model_para, results = self.trainer.train()
+ logger.info(
+ self._monitor.format_eval_res(results,
+ rnd=self.state,
+ role='Client #{}'.format(self.ID)))
+
+ # Compute norm of W & norm of grad
+ dW = dict()
+ for key in model_para.keys():
+ dW[key] = model_para[key] - W_old[key].cpu()
+
+ self.W = {key: value for key, value in self.model.named_parameters()}
+
+        conv_grads = dict()
+        for key in model_para.keys():
+            if key in self.W and self.W[key].grad is not None:
+                conv_grads[key] = self.W[key].grad
+        convGradsNorm = norm(conv_grads)
+
+ self.comm_manager.send(
+ Message(msg_type='model_para',
+ sender=self.ID,
+ receiver=[sender],
+ state=self.state,
+ content=(sample_size, model_para, dW, convGradsNorm)))
diff --git a/federatedscope/gfl/loss/__init__.py b/federatedscope/gfl/loss/__init__.py
new file mode 100644
index 000000000..bb0c1a585
--- /dev/null
+++ b/federatedscope/gfl/loss/__init__.py
@@ -0,0 +1,7 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.gfl.loss.greedy_loss import GreedyLoss
+
+__all__ = ['GreedyLoss']
\ No newline at end of file
diff --git a/federatedscope/gfl/loss/greedy_loss.py b/federatedscope/gfl/loss/greedy_loss.py
new file mode 100644
index 000000000..9d20be2f6
--- /dev/null
+++ b/federatedscope/gfl/loss/greedy_loss.py
@@ -0,0 +1,67 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def GreedyLoss(pred_feats, true_feats, pred_missing, true_missing, num_pred):
+ r"""Greedy loss is a loss function of cacluating the MSE loss for the feature.
+ https://proceedings.neurips.cc//paper/2021/file/34adeb8e3242824038aa65460a47c29e-Paper.pdf
+ Fedsageplus models from the "Subgraph Federated Learning with Missing Neighbor Generation" (FedSage+) paper, in NeurIPS'21
+ Source: https://github.com/zkhku/fedsage
+
+ Arguments:
+ pred_feats (torch.Tensor): generated missing features
+ true_feats (torch.Tensor): real missing features
+        pred_missing (torch.Tensor): predicted number of missing nodes
+        true_missing (torch.Tensor): true number of missing nodes
+        num_pred (int): hyperparameter limiting the maximum number of predictions
+ :returns:
+ loss : the Greedy Loss
+ :rtype:
+ torch.FloatTensor
+ """
+ CUDA, device = (pred_feats.device.type != 'cpu'), pred_feats.device
+ if CUDA:
+ true_missing = true_missing.cpu()
+ pred_missing = pred_missing.cpu()
+ loss = torch.zeros(pred_feats.shape)
+ if CUDA:
+ loss = loss.to(device)
+ pred_len = len(pred_feats)
+ pred_missing_np = np.round(
+ pred_missing.detach().numpy()).reshape(-1).astype(np.int32)
+ true_missing_np = true_missing.detach().numpy().reshape(-1).astype(
+ np.int32)
+ true_missing_np = np.clip(true_missing_np, 0, num_pred)
+ pred_missing_np = np.clip(pred_missing_np, 0, num_pred)
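+    # Greedy matching: each predicted feature is first scored against one
+    # reference true feature, then keeps the smallest MSE over all true
+    # missing features of the same node.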
+ for i in range(pred_len):
+ for pred_j in range(min(num_pred, pred_missing_np[i])):
+ if true_missing_np[i] > 0:
+ if isinstance(true_feats[i][true_missing_np[i] - 1],
+ np.ndarray):
+ true_feats_tensor = torch.tensor(
+ true_feats[i][true_missing_np[i] - 1])
+ if CUDA:
+ true_feats_tensor = true_feats_tensor.to(device)
+ else:
+ true_feats_tensor = true_feats[i][true_missing_np[i] - 1]
+ loss[i][pred_j] += F.mse_loss(
+ pred_feats[i][pred_j].unsqueeze(0).float(),
+ true_feats_tensor.unsqueeze(0).float()).squeeze(0)
+
+ for true_k in range(min(num_pred, true_missing_np[i])):
+ if isinstance(true_feats[i][true_k], np.ndarray):
+ true_feats_tensor = torch.tensor(true_feats[i][true_k])
+ if CUDA:
+ true_feats_tensor = true_feats_tensor.to(device)
+ else:
+ true_feats_tensor = true_feats[i][true_k]
+
+ loss_ijk = F.mse_loss(
+ pred_feats[i][pred_j].unsqueeze(0).float(),
+ true_feats_tensor.unsqueeze(0).float()).squeeze(0)
+ if torch.sum(loss_ijk) < torch.sum(loss[i][pred_j].data):
+ loss[i][pred_j] = loss_ijk
+ return loss.unsqueeze(0).mean().float()
diff --git a/federatedscope/gfl/loss/vat.py b/federatedscope/gfl/loss/vat.py
new file mode 100644
index 000000000..9fb03a28c
--- /dev/null
+++ b/federatedscope/gfl/loss/vat.py
@@ -0,0 +1,88 @@
+import contextlib
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data.batch import Batch
+
+
+@contextlib.contextmanager
+def _disable_tracking_bn_stats(model):
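+    # Temporarily flip `track_running_stats` on all batch-norm layers so the
+    # extra adversarial forward passes do not pollute the running statistics.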
+ def switch_attr(m):
+ if hasattr(m, 'track_running_stats'):
+ m.track_running_stats ^= True
+
+ model.apply(switch_attr)
+ yield
+ model.apply(switch_attr)
+
+
+def _l2_normalize(d):
+ d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
+ d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
+ return d
+
+
+class VATLoss(nn.Module):
+ def __init__(self, xi=1e-3, eps=2.5, ip=1):
+ r"""VAT loss
+ Source: https://github.com/lyakaap/VAT-pytorch
+
+ Arguments:
+        xi: hyperparameter of VAT in Eq.9, default: 1e-3
+ eps: hyperparameter of VAT in Eq.9, default: 2.5
+ ip: iteration times of computing adv noise
+
+ Returns:
+ loss : the VAT Loss
+
+ """
+ super(VATLoss, self).__init__()
+ self.xi = xi
+ self.eps = eps
+ self.ip = ip
+
+ def forward(self, model, graph, criterion):
+ pred = model(graph)
+ if criterion._get_name() == 'CrossEntropyLoss':
+ pred = torch.max(pred, dim=1).indices.long().view(-1)
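+        # `pred` now serves as the (pseudo-)target; VAT measures how much a
+        # small perturbation of the node features changes the prediction.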
+
+ # prepare random unit tensor
+ nodefea = graph.x
+ dn = torch.rand(nodefea.shape).sub(0.5).to(nodefea.device)
+ dn = _l2_normalize(dn)
+
+ with _disable_tracking_bn_stats(model):
+ # calc adversarial direction
+ with torch.enable_grad():
+ for _ in range(self.ip):
+ dn.requires_grad_()
+ x_neighbor = Batch(x=nodefea + self.xi * dn,
+ edge_index=graph.edge_index,
+ y=graph.y,
+ edge_attr=graph.edge_attr,
+ batch=graph.batch)
+ pred_hat = model(x_neighbor)
+ # logp_hat = F.log_softmax(pred_hat, dim=1)
+ # adv_distance = F.kl_div(logp_hat, logp, reduction='batchmean')
+ # adv_distance = ((pred - pred_hat) ** 2).sum(axis=0).sqrt()
+ adv_distance = criterion(pred_hat, pred)
+ # adv_distance.backward()
+ # dn = _l2_normalize(dn.grad)
+ dn = _l2_normalize(
+ torch.autograd.grad(outputs=adv_distance,
+ inputs=dn,
+ retain_graph=True)[0])
+ model.zero_grad()
+ del x_neighbor, pred_hat, adv_distance
+
+ # calc LDS
+ rn_adv = dn * self.eps
+ x_adv = Batch(x=nodefea + rn_adv,
+ edge_index=graph.edge_index,
+ y=graph.y,
+ edge_attr=graph.edge_attr,
+ batch=graph.batch)
+ pred_hat = model(x_adv)
+ lds = criterion(pred_hat, pred)
+
+ return lds
diff --git a/federatedscope/gfl/model/__init__.py b/federatedscope/gfl/model/__init__.py
new file mode 100644
index 000000000..ec5bf0316
--- /dev/null
+++ b/federatedscope/gfl/model/__init__.py
@@ -0,0 +1,19 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.core.mlp import MLP
+from federatedscope.gfl.model.model_builder import get_gnn
+from federatedscope.gfl.model.gcn import GCN_Net
+from federatedscope.gfl.model.sage import SAGE_Net
+from federatedscope.gfl.model.gin import GIN_Net
+from federatedscope.gfl.model.gat import GAT_Net
+from federatedscope.gfl.model.gpr import GPR_Net
+from federatedscope.gfl.model.graph_level import GNN_Net_Graph
+from federatedscope.gfl.model.link_level import GNN_Net_Link
+from federatedscope.gfl.model.fedsageplus import LocalSage_Plus, FedSage_Plus
+
+__all__ = [
+ 'get_gnn', 'GCN_Net', 'SAGE_Net', 'GIN_Net', 'GAT_Net', 'GPR_Net',
+ 'GNN_Net_Graph', 'GNN_Net_Link', 'LocalSage_Plus', 'FedSage_Plus', 'MLP'
+]
diff --git a/federatedscope/gfl/model/fedsageplus.py b/federatedscope/gfl/model/fedsageplus.py
new file mode 100644
index 000000000..d067ddf3b
--- /dev/null
+++ b/federatedscope/gfl/model/fedsageplus.py
@@ -0,0 +1,175 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import numpy as np
+import scipy.sparse as sp
+
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+
+from federatedscope.gfl.model import SAGE_Net
+"""
+https://proceedings.neurips.cc//paper/2021/file/34adeb8e3242824038aa65460a47c29e-Paper.pdf
+Fedsageplus models from the "Subgraph Federated Learning with Missing Neighbor Generation" (FedSage+) paper, in NeurIPS'21
+Source: https://github.com/zkhku/fedsage
+"""
+
+
+class Sampling(nn.Module):
+ def __init__(self):
+ super(Sampling, self).__init__()
+
+ def forward(self, inputs):
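+        # Add standard Gaussian noise to the input (used to perturb the
+        # latent code in FeatGenerator).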
+ rand = torch.normal(0, 1, size=inputs.shape)
+
+ return inputs + rand.to(inputs.device)
+
+
+class FeatGenerator(nn.Module):
+ def __init__(self, latent_dim, dropout, num_pred, feat_shape):
+ super(FeatGenerator, self).__init__()
+ self.num_pred = num_pred
+ self.feat_shape = feat_shape
+ self.dropout = dropout
+ self.sample = Sampling()
+ self.fc1 = nn.Linear(latent_dim, 256)
+ self.fc2 = nn.Linear(256, 2048)
+ self.fc_flat = nn.Linear(2048, self.num_pred * self.feat_shape)
+
+ def forward(self, x):
+ x = self.sample(x)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = torch.tanh(self.fc_flat(x))
+
+ return x
+
+
+class NumPredictor(nn.Module):
+ def __init__(self, latent_dim):
+        super(NumPredictor, self).__init__()
+        self.latent_dim = latent_dim
+ self.reg_1 = nn.Linear(self.latent_dim, 1)
+
+ def forward(self, x):
+ x = F.relu(self.reg_1(x))
+ return x
+
+
+# Mend the graph via NeighGen
+class MendGraph(nn.Module):
+ def __init__(self, num_pred):
+ super(MendGraph, self).__init__()
+ self.num_pred = num_pred
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def mend_graph(self, x, edge_index, pred_degree, gen_feats):
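+        # Append `num_pred` generated nodes per real node and connect node i
+        # to round(pred_degree[i]) of its generated neighbors.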
+ device = gen_feats.device
+ num_node, num_feature = x.shape
+ new_edges = []
+ gen_feats = gen_feats.view(-1, self.num_pred, num_feature)
+
+ if pred_degree.device.type != 'cpu':
+ pred_degree = pred_degree.cpu()
+        pred_degree = torch.round(pred_degree).int().detach()
+ x = x.detach()
+ fill_feats = torch.vstack((x, gen_feats.view(-1, num_feature)))
+
+ for i in range(num_node):
+ for j in range(min(self.num_pred, max(0, pred_degree[i]))):
+ new_edges.append(
+ np.asarray([i, num_node + i * self.num_pred + j]))
+
+ new_edges = torch.tensor(np.asarray(new_edges).reshape((-1, 2)),
+ dtype=torch.int64).T
+ new_edges = new_edges.to(device)
+        if new_edges.size(1) > 0:
+ fill_edges = torch.hstack((edge_index, new_edges))
+ else:
+ fill_edges = torch.clone(edge_index)
+ return fill_feats, fill_edges
+
+ def forward(self, x, edge_index, pred_missing, gen_feats):
+ fill_feats, fill_edges = self.mend_graph(x, edge_index, pred_missing,
+ gen_feats)
+
+ return fill_feats, fill_edges
+
+
+class LocalSage_Plus(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden,
+ gen_hidden,
+ dropout=0.5,
+ num_pred=5):
+ super(LocalSage_Plus, self).__init__()
+
+ self.encoder_model = SAGE_Net(in_channels=in_channels,
+ out_channels=gen_hidden,
+ hidden=hidden,
+ max_depth=2,
+ dropout=dropout)
+ self.reg_model = NumPredictor(latent_dim=gen_hidden)
+ self.gen = FeatGenerator(latent_dim=gen_hidden,
+ dropout=dropout,
+ num_pred=num_pred,
+ feat_shape=in_channels)
+ self.mend_graph = MendGraph(num_pred)
+
+ self.classifier = SAGE_Net(in_channels=in_channels,
+ out_channels=out_channels,
+ hidden=hidden,
+ max_depth=2,
+ dropout=dropout)
+
+ def forward(self, data):
+ x = self.encoder_model(data)
+ degree = self.reg_model(x)
+ gen_feat = self.gen(x)
+ mend_feats, mend_edge_index = self.mend_graph(data.x, data.edge_index,
+ degree, gen_feat)
+ nc_pred = self.classifier(
+ Data(x=mend_feats, edge_index=mend_edge_index))
+ return degree, gen_feat, nc_pred[:data.num_nodes]
+
+    def inference(self, impaired_data, raw_data):
+        x = self.encoder_model(impaired_data)
+ degree = self.reg_model(x)
+ gen_feat = self.gen(x)
+ mend_feats, mend_edge_index = self.mend_graph(raw_data.x,
+ raw_data.edge_index,
+ degree, gen_feat)
+ nc_pred = self.classifier(
+ Data(x=mend_feats, edge_index=mend_edge_index))
+ return degree, gen_feat, nc_pred[:raw_data.num_nodes]
+
+
+class FedSage_Plus(nn.Module):
+ def __init__(self, local_graph: LocalSage_Plus):
+ super(FedSage_Plus, self).__init__()
+ self.encoder_model = local_graph.encoder_model
+ self.reg_model = local_graph.reg_model
+ self.gen = local_graph.gen
+ self.mend_graph = local_graph.mend_graph
+ self.classifier = local_graph.classifier
+ self.encoder_model.requires_grad_(False)
+ self.reg_model.requires_grad_(False)
+ self.mend_graph.requires_grad_(False)
+ self.classifier.requires_grad_(False)
+
+ def forward(self, data):
+ x = self.encoder_model(data)
+ degree = self.reg_model(x)
+ gen_feat = self.gen(x)
+ mend_feats, mend_edge_index = self.mend_graph(data.x, data.edge_index,
+ degree, gen_feat)
+ nc_pred = self.classifier(
+ Data(x=mend_feats, edge_index=mend_edge_index))
+ return degree, gen_feat, nc_pred[:data.num_nodes]
\ No newline at end of file
diff --git a/federatedscope/gfl/model/gat.py b/federatedscope/gfl/model/gat.py
new file mode 100644
index 000000000..07449adcd
--- /dev/null
+++ b/federatedscope/gfl/model/gat.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import ModuleList
+from torch_geometric.data import Data
+from torch_geometric.nn import GATConv
+
+
+class GAT_Net(torch.nn.Module):
+ r"""GAT model from the "Graph Attention Networks" paper, in ICLR'18
+
+ Arguments:
+ in_channels (int): dimension of input.
+ out_channels (int): dimension of output.
+ hidden (int): dimension of hidden units, default=64.
+ max_depth (int): layers of GNN, default=2.
+ dropout (float): dropout ratio, default=.0.
+
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden=64,
+ max_depth=2,
+ dropout=.0):
+ super(GAT_Net, self).__init__()
+ self.convs = ModuleList()
+ for i in range(max_depth):
+ if i == 0:
+ self.convs.append(GATConv(in_channels, hidden))
+ elif (i + 1) == max_depth:
+ self.convs.append(GATConv(hidden, out_channels))
+ else:
+ self.convs.append(GATConv(hidden, hidden))
+ self.dropout = dropout
+
+ def reset_parameters(self):
+ for m in self.convs:
+ m.reset_parameters()
+
+ def forward(self, data):
+ if isinstance(data, Data):
+ x, edge_index = data.x, data.edge_index
+ elif isinstance(data, tuple):
+ x, edge_index = data
+ else:
+ raise TypeError('Unsupported data type!')
+
+ for i, conv in enumerate(self.convs):
+ x = conv(x, edge_index)
+ if (i + 1) == len(self.convs):
+ break
+ x = F.relu(F.dropout(x, p=self.dropout, training=self.training))
+ return x
\ No newline at end of file
diff --git a/federatedscope/gfl/model/gcn.py b/federatedscope/gfl/model/gcn.py
new file mode 100644
index 000000000..c1fedf75c
--- /dev/null
+++ b/federatedscope/gfl/model/gcn.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import ModuleList
+from torch_geometric.data import Data
+from torch_geometric.nn import GCNConv
+
+
+class GCN_Net(torch.nn.Module):
+ r""" GCN model from the "Semi-supervised Classification with Graph Convolutional Networks" paper, in ICLR'17.
+
+ Arguments:
+ in_channels (int): dimension of input.
+ out_channels (int): dimension of output.
+ hidden (int): dimension of hidden units, default=64.
+ max_depth (int): layers of GNN, default=2.
+ dropout (float): dropout ratio, default=.0.
+
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden=64,
+ max_depth=2,
+ dropout=.0):
+ super(GCN_Net, self).__init__()
+ self.convs = ModuleList()
+ for i in range(max_depth):
+ if i == 0:
+ self.convs.append(GCNConv(in_channels, hidden))
+ elif (i + 1) == max_depth:
+ self.convs.append(GCNConv(hidden, out_channels))
+ else:
+ self.convs.append(GCNConv(hidden, hidden))
+ self.dropout = dropout
+
+ def reset_parameters(self):
+ for m in self.convs:
+ m.reset_parameters()
+
+ def forward(self, data):
+ if isinstance(data, Data):
+ x, edge_index = data.x, data.edge_index
+ elif isinstance(data, tuple):
+ x, edge_index = data
+ else:
+ raise TypeError('Unsupported data type!')
+
+ for i, conv in enumerate(self.convs):
+ x = conv(x, edge_index)
+ if (i + 1) == len(self.convs):
+ break
+ x = F.relu(F.dropout(x, p=self.dropout, training=self.training))
+ return x
\ No newline at end of file
diff --git a/federatedscope/gfl/model/gin.py b/federatedscope/gfl/model/gin.py
new file mode 100644
index 000000000..73c6d1514
--- /dev/null
+++ b/federatedscope/gfl/model/gin.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import ModuleList
+from torch_geometric.data import Data
+from torch_geometric.nn import GINConv
+
+from federatedscope.core.mlp import MLP
+"""
+Model param names of GIN:
+['convs.0.eps', 'convs.0.nn.linears.0.weight', 'convs.0.nn.linears.0.bias', 'convs.0.nn.linears.1.weight',
+'convs.0.nn.linears.1.bias', 'convs.0.nn.norms.0.weight', 'convs.0.nn.norms.0.bias', 'convs.0.nn.norms.0.running_mean',
+'convs.0.nn.norms.0.running_var', 'convs.0.nn.norms.0.num_batches_tracked', 'convs.0.nn.norms.1.weight',
+'convs.0.nn.norms.1.bias', 'convs.0.nn.norms.1.running_mean', 'convs.0.nn.norms.1.running_var',
+'convs.0.nn.norms.1.num_batches_tracked', 'convs.1.eps', 'convs.1.nn.linears.0.weight', 'convs.1.nn.linears.0.bias',
+'convs.1.nn.linears.1.weight', 'convs.1.nn.linears.1.bias', 'convs.1.nn.norms.0.weight', 'convs.1.nn.norms.0.bias',
+'convs.1.nn.norms.0.running_mean', 'convs.1.nn.norms.0.running_var', 'convs.1.nn.norms.0.num_batches_tracked',
+'convs.1.nn.norms.1.weight', 'convs.1.nn.norms.1.bias', 'convs.1.nn.norms.1.running_mean',
+'convs.1.nn.norms.1.running_var', 'convs.1.nn.norms.1.num_batches_tracked',]
+"""
+
+
+class GIN_Net(torch.nn.Module):
+ r"""Graph Isomorphism Network model from the "How Powerful are Graph Neural Networks?" paper, in ICLR'19
+
+ Arguments:
+ in_channels (int): dimension of input.
+ out_channels (int): dimension of output.
+ hidden (int): dimension of hidden units, default=64.
+ max_depth (int): layers of GNN, default=2.
+ dropout (float): dropout ratio, default=.0.
+
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden=64,
+ max_depth=2,
+ dropout=.0):
+ super(GIN_Net, self).__init__()
+ self.convs = ModuleList()
+ for i in range(max_depth):
+ if i == 0:
+ self.convs.append(
+ GINConv(MLP([in_channels, hidden, hidden],
+ batch_norm=True)))
+ elif (i + 1) == max_depth:
+ self.convs.append(
+ GINConv(
+ MLP([hidden, hidden, out_channels], batch_norm=True)))
+ else:
+ self.convs.append(
+ GINConv(MLP([hidden, hidden, hidden], batch_norm=True)))
+ self.dropout = dropout
+
+ def forward(self, data):
+ if isinstance(data, Data):
+ x, edge_index = data.x, data.edge_index
+ elif isinstance(data, tuple):
+ x, edge_index = data
+ else:
+ raise TypeError('Unsupported data type!')
+
+ for i, conv in enumerate(self.convs):
+ x = conv(x, edge_index)
+ if (i + 1) == len(self.convs):
+ break
+ x = F.relu(F.dropout(x, p=self.dropout, training=self.training))
+ return x
diff --git a/federatedscope/gfl/model/gpr.py b/federatedscope/gfl/model/gpr.py
new file mode 100644
index 000000000..c932a7c91
--- /dev/null
+++ b/federatedscope/gfl/model/gpr.py
@@ -0,0 +1,133 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from torch.nn import Parameter
+from torch.nn import Linear
+from torch_geometric.data import Data
+from torch_geometric.nn.conv.gcn_conv import gcn_norm
+from torch_geometric.nn import MessagePassing, APPNP
+
+
+class GPR_prop(MessagePassing):
+ '''
+ propagation class for GPR_GNN
+ source: https://github.com/jianhao2016/GPRGNN/blob/master/src/GNN_models.py
+ '''
+ def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
+ super(GPR_prop, self).__init__(aggr='add', **kwargs)
+ self.K = K
+ self.Init = Init
+ self.alpha = alpha
+
+ assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
+ if Init == 'SGC':
+            # SGC-like; note that in this case alpha has to be an integer.
+            # It indicates where the peak is when initializing the GPR weights.
+ TEMP = 0.0 * np.ones(K + 1)
+ TEMP[alpha] = 1.0
+ elif Init == 'PPR':
+ # PPR-like
+ TEMP = alpha * (1 - alpha)**np.arange(K + 1)
+ TEMP[-1] = (1 - alpha)**K
+ elif Init == 'NPPR':
+ # Negative PPR
+ TEMP = (alpha)**np.arange(K + 1)
+ TEMP = TEMP / np.sum(np.abs(TEMP))
+ elif Init == 'Random':
+ # Random
+ bound = np.sqrt(3 / (K + 1))
+ TEMP = np.random.uniform(-bound, bound, K + 1)
+ TEMP = TEMP / np.sum(np.abs(TEMP))
+ elif Init == 'WS':
+ # Specify Gamma
+ TEMP = Gamma
+
+ self.temp = Parameter(torch.tensor(TEMP))
+
+ def reset_parameters(self):
+ torch.nn.init.zeros_(self.temp)
+ for k in range(self.K + 1):
+ self.temp.data[k] = self.alpha * (1 - self.alpha)**k
+ self.temp.data[-1] = (1 - self.alpha)**self.K
+
+ def forward(self, x, edge_index, edge_weight=None):
+ edge_index, norm = gcn_norm(edge_index,
+ edge_weight,
+ num_nodes=x.size(0),
+ dtype=x.dtype)
+
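+        # Generalized PageRank: hidden = sum_k temp[k] * A_hat^k x, where
+        # A_hat is the GCN-normalized adjacency and temp holds the learnable
+        # gamma weights.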
+ hidden = x * (self.temp[0])
+ for k in range(self.K):
+ x = self.propagate(edge_index, x=x, norm=norm)
+ gamma = self.temp[k + 1]
+ hidden = hidden + gamma * x
+ return hidden
+
+ def message(self, x_j, norm):
+ return norm.view(-1, 1) * x_j
+
+ def __repr__(self):
+ return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
+ self.temp)
+
+
+class GPR_Net(torch.nn.Module):
+ r"""GPR-GNN model from the "Adaptive Universal Generalized PageRank Graph Neural Network" paper, in ICLR'21
+
+ Arguments:
+ in_channels (int): dimension of input.
+ out_channels (int): dimension of output.
+ hidden (int): dimension of hidden units, default=64.
+ K (int): power of GPR-GNN, default=10.
+ dropout (float): dropout ratio, default=.0.
+ ppnp (str): propagation method in ['PPNP', 'GPR_prop']
+ Init (str): init method in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
+
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden=64,
+ K=10,
+ dropout=.0,
+ ppnp='GPR_prop',
+ alpha=0.1,
+ Init='PPR',
+ Gamma=None):
+ super(GPR_Net, self).__init__()
+ self.lin1 = Linear(in_channels, hidden)
+ self.lin2 = Linear(hidden, out_channels)
+
+ if ppnp == 'PPNP':
+ self.prop1 = APPNP(K, alpha)
+ elif ppnp == 'GPR_prop':
+ self.prop1 = GPR_prop(K, alpha, Init, Gamma)
+
+ self.Init = Init
+ self.dprate = 0.5
+ self.dropout = dropout
+
+ def reset_parameters(self):
+ self.prop1.reset_parameters()
+
+ def forward(self, data):
+ if isinstance(data, Data):
+ x, edge_index = data.x, data.edge_index
+ elif isinstance(data, tuple):
+ x, edge_index = data
+ else:
+ raise TypeError('Unsupported data type!')
+
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = F.relu(self.lin1(x))
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.lin2(x)
+
+ if self.dprate == 0.0:
+ x = self.prop1(x, edge_index)
+ return F.log_softmax(x, dim=1)
+ else:
+ x = F.dropout(x, p=self.dprate, training=self.training)
+ x = self.prop1(x, edge_index)
+ return F.log_softmax(x, dim=1)
diff --git a/federatedscope/gfl/model/graph_level.py b/federatedscope/gfl/model/graph_level.py
new file mode 100644
index 000000000..5418d2622
--- /dev/null
+++ b/federatedscope/gfl/model/graph_level.py
@@ -0,0 +1,124 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Linear, Sequential
+from torch_geometric.data import Data
+from torch_geometric.data.batch import Batch
+from torch_geometric.nn.glob import global_add_pool, global_mean_pool, global_max_pool
+
+from federatedscope.gfl.model.gcn import GCN_Net
+from federatedscope.gfl.model.sage import SAGE_Net
+from federatedscope.gfl.model.gat import GAT_Net
+from federatedscope.gfl.model.gin import GIN_Net
+from federatedscope.gfl.model.gpr import GPR_Net
+
+EMB_DIM = 200
+
+
+class AtomEncoder(torch.nn.Module):
+ def __init__(self, in_channels, hidden):
+ super(AtomEncoder, self).__init__()
+ self.atom_embedding_list = torch.nn.ModuleList()
+ for i in range(in_channels):
+            emb = torch.nn.Embedding(EMB_DIM, hidden)
+ torch.nn.init.xavier_uniform_(emb.weight.data)
+ self.atom_embedding_list.append(emb)
+
+ def forward(self, x):
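+        # Sum one embedding per integer-valued atom feature column
+        # (OGB-style atom encoding).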
+ x_embedding = 0
+ for i in range(x.shape[1]):
+ x_embedding += self.atom_embedding_list[i](x[:, i])
+ return x_embedding
+
+
+class GNN_Net_Graph(torch.nn.Module):
+ r"""GNN model with pre-linear layer, pooling layer
+ and output layer for graph classification tasks.
+
+ Arguments:
+ in_channels (int): input channels.
+ out_channels (int): output channels.
+ hidden (int): hidden dim for all modules.
+ max_depth (int): number of layers for gnn.
+ dropout (float): dropout probability.
+ gnn (str): name of gnn type, use ("gcn" or "gin").
+ pooling (str): pooling method, use ("add", "mean" or "max").
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden=64,
+ max_depth=2,
+ dropout=.0,
+ gnn='gcn',
+ pooling='add'):
+ super(GNN_Net_Graph, self).__init__()
+ self.dropout = dropout
+ # Embedding (pre) layer
+ self.encoder_atom = AtomEncoder(in_channels, hidden)
+ self.encoder = Linear(in_channels, hidden)
+ # GNN layer
+ if gnn == 'gcn':
+ self.gnn = GCN_Net(in_channels=hidden,
+ out_channels=hidden,
+ hidden=hidden,
+ max_depth=max_depth,
+ dropout=dropout)
+ elif gnn == 'sage':
+ self.gnn = SAGE_Net(in_channels=hidden,
+ out_channels=hidden,
+ hidden=hidden,
+ max_depth=max_depth,
+ dropout=dropout)
+ elif gnn == 'gat':
+ self.gnn = GAT_Net(in_channels=hidden,
+ out_channels=hidden,
+ hidden=hidden,
+ max_depth=max_depth,
+ dropout=dropout)
+ elif gnn == 'gin':
+ self.gnn = GIN_Net(in_channels=hidden,
+ out_channels=hidden,
+ hidden=hidden,
+ max_depth=max_depth,
+ dropout=dropout)
+ elif gnn == 'gpr':
+ self.gnn = GPR_Net(in_channels=hidden,
+ out_channels=hidden,
+ hidden=hidden,
+ K=max_depth,
+ dropout=dropout)
+ else:
+ raise ValueError(f'Unsupported gnn type: {gnn}.')
+
+ # Pooling layer
+ if pooling == 'add':
+ self.pooling = global_add_pool
+ elif pooling == 'mean':
+ self.pooling = global_mean_pool
+ elif pooling == 'max':
+ self.pooling = global_max_pool
+ else:
+ raise ValueError(f'Unsupported pooling type: {pooling}.')
+ # Output layer
+ self.linear = Sequential(Linear(hidden, hidden), torch.nn.ReLU())
+ self.clf = Linear(hidden, out_channels)
+
+ def forward(self, data):
+ if isinstance(data, Batch):
+ x, edge_index, batch = data.x, data.edge_index, data.batch
+ elif isinstance(data, tuple):
+ x, edge_index, batch = data
+ else:
+ raise TypeError('Unsupported data type!')
+
+ if x.dtype == torch.int64:
+ x = self.encoder_atom(x)
+ else:
+ x = self.encoder(x)
+
+ x = self.gnn((x, edge_index))
+ x = self.pooling(x, batch)
+ x = self.linear(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.clf(x)
+ return x
diff --git a/federatedscope/gfl/model/link_level.py b/federatedscope/gfl/model/link_level.py
new file mode 100644
index 000000000..cc53c2d43
--- /dev/null
+++ b/federatedscope/gfl/model/link_level.py
@@ -0,0 +1,88 @@
+import torch
+from torch_geometric.data import Data
+
+from federatedscope.core.mlp import MLP
+from federatedscope.gfl.model.gcn import GCN_Net
+from federatedscope.gfl.model.sage import SAGE_Net
+from federatedscope.gfl.model.gat import GAT_Net
+from federatedscope.gfl.model.gin import GIN_Net
+from federatedscope.gfl.model.gpr import GPR_Net
+
+
+class GNN_Net_Link(torch.nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden=64,
+ max_depth=2,
+ dropout=.0,
+ gnn='gcn',
+ layers=2):
+ r"""GNN model with LinkPredictor for link prediction tasks.
+
+ Arguments:
+ in_channels (int): input channels.
+ out_channels (int): output channels.
+ hidden (int): hidden dim for all modules.
+ max_depth (int): number of layers for gnn.
+ dropout (float): dropout probability.
+ gnn (str): name of gnn type, use ("gcn" or "gin").
+ layers (int): number of layers for LinkPredictor.
+
+ """
+ super(GNN_Net_Link, self).__init__()
+ self.dropout = dropout
+
+ # GNN layer
+ if gnn == 'gcn':
+ self.gnn = GCN_Net(in_channels=in_channels,
+ out_channels=hidden,
+ hidden=hidden,
+ max_depth=max_depth,
+ dropout=dropout)
+ elif gnn == 'sage':
+ self.gnn = SAGE_Net(in_channels=in_channels,
+ out_channels=hidden,
+ hidden=hidden,
+ max_depth=max_depth,
+ dropout=dropout)
+ elif gnn == 'gat':
+ self.gnn = GAT_Net(in_channels=in_channels,
+ out_channels=hidden,
+ hidden=hidden,
+ max_depth=max_depth,
+ dropout=dropout)
+ elif gnn == 'gin':
+ self.gnn = GIN_Net(in_channels=in_channels,
+ out_channels=hidden,
+ hidden=hidden,
+ max_depth=max_depth,
+ dropout=dropout)
+ elif gnn == 'gpr':
+ self.gnn = GPR_Net(in_channels=in_channels,
+ out_channels=hidden,
+ hidden=hidden,
+ K=max_depth,
+ dropout=dropout)
+ else:
+ raise ValueError(f'Unsupported gnn type: {gnn}.')
+
+ dim_list = [hidden for _ in range(layers)]
+ self.output = MLP([hidden] + dim_list + [out_channels],
+ batch_norm=True)
+
+ def forward(self, data):
+ if isinstance(data, Data):
+ x, edge_index = data.x, data.edge_index
+ elif isinstance(data, tuple):
+ x, edge_index = data
+ else:
+ raise TypeError('Unsupported data type!')
+
+ x = self.gnn((x, edge_index))
+ return x
+
+ def link_predictor(self, x, edge_index):
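+        # Score each candidate edge from the elementwise product of its two
+        # endpoint embeddings, mapped to logits by the output MLP.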
+ x = x[edge_index[0]] * x[edge_index[1]]
+ x = self.output(x)
+ return x
diff --git a/federatedscope/gfl/model/model_builder.py b/federatedscope/gfl/model/model_builder.py
new file mode 100644
index 000000000..614d69f17
--- /dev/null
+++ b/federatedscope/gfl/model/model_builder.py
@@ -0,0 +1,90 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.gfl.model.gcn import GCN_Net
+from federatedscope.gfl.model.sage import SAGE_Net
+from federatedscope.gfl.model.gat import GAT_Net
+from federatedscope.gfl.model.gin import GIN_Net
+from federatedscope.gfl.model.gpr import GPR_Net
+from federatedscope.gfl.model.link_level import GNN_Net_Link
+from federatedscope.gfl.model.graph_level import GNN_Net_Graph
+from federatedscope.gfl.model.mpnn import MPNNs2s
+
+
+def get_gnn(model_config, local_data):
+ num_label = 0
+ if isinstance(local_data, dict):
+ if 'data' in local_data.keys():
+ data = local_data['data']
+ elif 'train' in local_data.keys():
+ # local_data['train'] is Dataloader
+ data = next(iter(local_data['train']))
+ if 'num_label' in local_data.keys():
+ num_label = local_data['num_label']
+ else:
+ raise TypeError('Unsupported data type.')
+ else:
+ data = local_data
+
+ if model_config.task == 'node':
+ if model_config.type == 'gcn':
+            # `data` here is a PyG Data object; infer input dim from node features
+ model = GCN_Net(data.x.shape[-1],
+ model_config.out_channels,
+ hidden=model_config.hidden,
+ max_depth=model_config.layer,
+ dropout=model_config.dropout)
+ elif model_config.type == 'sage':
+ model = SAGE_Net(data.x.shape[-1],
+ model_config.out_channels,
+ hidden=model_config.hidden,
+ max_depth=model_config.layer,
+ dropout=model_config.dropout)
+ elif model_config.type == 'gat':
+ model = GAT_Net(data.x.shape[-1],
+ model_config.out_channels,
+ hidden=model_config.hidden,
+ max_depth=model_config.layer,
+ dropout=model_config.dropout)
+ elif model_config.type == 'gin':
+ model = GIN_Net(data.x.shape[-1],
+ model_config.out_channels,
+ hidden=model_config.hidden,
+ max_depth=model_config.layer,
+ dropout=model_config.dropout)
+ elif model_config.type == 'gpr':
+ model = GPR_Net(data.x.shape[-1],
+ model_config.out_channels,
+ hidden=model_config.hidden,
+ K=model_config.layer,
+ dropout=model_config.dropout)
+ else:
+            raise ValueError('Unrecognized GNN model type {}.'.format(
+                model_config.type))
+
+ elif model_config.task == 'link':
+ model = GNN_Net_Link(data.x.shape[-1],
+ model_config.out_channels,
+ hidden=model_config.hidden,
+ max_depth=model_config.layer,
+ dropout=model_config.dropout,
+ gnn=model_config.type)
+ elif model_config.task == 'graph':
+ if model_config.type == 'mpnn':
+ model = MPNNs2s(in_channels=data.x.shape[-1],
+ out_channels=model_config.out_channels,
+ num_nn=data.num_edge_features,
+ hidden=model_config.hidden)
+ else:
+ model = GNN_Net_Graph(data.x.shape[-1],
+ max(model_config.out_channels, num_label),
+ hidden=model_config.hidden,
+ max_depth=model_config.layer,
+ dropout=model_config.dropout,
+ gnn=model_config.type,
+ pooling=model_config.graph_pooling)
+ else:
+        raise ValueError('Unrecognized task {}.'.format(
+            model_config.task))
+ return model
diff --git a/federatedscope/gfl/model/mpnn.py b/federatedscope/gfl/model/mpnn.py
new file mode 100644
index 000000000..2c1058e21
--- /dev/null
+++ b/federatedscope/gfl/model/mpnn.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch_geometric.data import Data
+from torch_geometric.data.batch import Batch
+
+from torch.nn import GRU, Linear, ReLU, Sequential
+from torch_geometric.nn import NNConv, Set2Set
+
+
+class MPNNs2s(nn.Module):
+ r"""MPNN from "Neural Message Passing for Quantum Chemistry" for regression and classification on graphs.
+ Source: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_nn_conv.py
+
+ Arguments:
+ in_channels (int): Size for the input node features.
+ out_channels (int): dimension of output.
+ num_nn (int): num_edge_features.
+ hidden (int): Size for the output node representations. Default to 64.
+
+ """
+ def __init__(self, in_channels, out_channels, num_nn, hidden=64):
+ super(MPNNs2s, self).__init__()
+ self.lin0 = torch.nn.Linear(in_channels, hidden)
+
+        edge_nn = Sequential(Linear(num_nn, 16), ReLU(),
+                             Linear(16, hidden * hidden))
+        self.conv = NNConv(hidden, hidden, edge_nn, aggr='add')
+ self.gru = GRU(hidden, hidden)
+
+ self.set2set = Set2Set(hidden, processing_steps=3, num_layers=3)
+ self.lin1 = torch.nn.Linear(2 * hidden, hidden)
+ self.lin2 = torch.nn.Linear(hidden, out_channels)
+
+ def forward(self, data):
+ if isinstance(data, Batch):
+ x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
+ elif isinstance(data, tuple):
+            x, edge_index, edge_attr, batch = data
+ else:
+ raise TypeError('Unsupported data type!')
+
+ self.gru.flatten_parameters()
+ out = F.relu(self.lin0(x.float()))
+ h = out.unsqueeze(0)
+
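+        # Three rounds of edge-conditioned message passing; the GRU carries
+        # the hidden node state across rounds.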
+ for i in range(3):
+ m = F.relu(self.conv(out, edge_index, edge_attr.float()))
+ out, h = self.gru(m.unsqueeze(0), h)
+ out = out.squeeze(0)
+
+ out = self.set2set(out, batch)
+ out = F.relu(self.lin1(out))
+ out = self.lin2(out)
+ return out
diff --git a/federatedscope/gfl/model/sage.py b/federatedscope/gfl/model/sage.py
new file mode 100644
index 000000000..3bd410d60
--- /dev/null
+++ b/federatedscope/gfl/model/sage.py
@@ -0,0 +1,125 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import ModuleList
+from torch_geometric.data import Data
+from torch_geometric.nn import SAGEConv
+
+
+class SAGE_Net(torch.nn.Module):
+ r"""GraphSAGE model from the "Inductive Representation Learning on Large Graphs" paper, in NeurIPS'17
+
+ Source: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py
+
+ Arguments:
+ in_channels (int): dimension of input.
+ out_channels (int): dimension of output.
+ hidden (int): dimension of hidden units, default=64.
+ max_depth (int): layers of GNN, default=2.
+ dropout (float): dropout ratio, default=.0.
+
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden=64,
+ max_depth=2,
+ dropout=.0):
+ super(SAGE_Net, self).__init__()
+
+ self.num_layers = max_depth
+ self.dropout = dropout
+
+ self.convs = torch.nn.ModuleList()
+ self.convs.append(SAGEConv(in_channels, hidden))
+ for _ in range(self.num_layers - 2):
+ self.convs.append(SAGEConv(hidden, hidden))
+ self.convs.append(SAGEConv(hidden, out_channels))
+
+ def reset_parameters(self):
+ for conv in self.convs:
+ conv.reset_parameters()
+
+ def forward_full(self, data):
+ if isinstance(data, Data):
+ x, edge_index = data.x, data.edge_index
+ elif isinstance(data, tuple):
+ x, edge_index = data
+ else:
+ raise TypeError('Unsupported data type!')
+
+ for i, conv in enumerate(self.convs):
+ x = conv(x, edge_index)
+ if (i + 1) == len(self.convs):
+ break
+ x = F.relu(F.dropout(x, p=self.dropout, training=self.training))
+ return x
+
+ def forward(self, x, edge_index=None, edge_weight=None, adjs=None):
+ r"""
+ `train_loader` computes the k-hop neighborhood of a batch of nodes,
+ and returns, for each layer, a bipartite graph object, holding the
+ bipartite edges `edge_index`, the index `e_id` of the original edges,
+ and the size/shape `size` of the bipartite graph.
+ Target nodes are also included in the source nodes so that one can
+ easily apply skip-connections or add self-loops.
+
+ Arguments:
+ x (torch.Tensor or PyG.data or tuple): node features or full-batch data
+ edge_index (torch.Tensor): edge index.
+ edge_weight (torch.Tensor): edge weight.
+ adjs (List[PyG.loader.neighbor_sampler.EdgeIndex]): batched edge index
+ :returns:
+ x: output
+ :rtype:
+ torch.Tensor
+ """
+ if isinstance(x, torch.Tensor):
+ if edge_index is None:
+ for i, (edge_index, _, size) in enumerate(adjs):
+ x_target = x[:size[1]]
+ x = self.convs[i]((x, x_target), edge_index)
+ if i != self.num_layers - 1:
+ x = F.relu(x)
+ x = F.dropout(x,
+ p=self.dropout,
+ training=self.training)
+ else:
+ for conv in self.convs[:-1]:
+ x = conv(x, edge_index, edge_weight)
+ x = F.relu(x)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.convs[-1](x, edge_index, edge_weight)
+ return x
+ elif isinstance(x, Data) or isinstance(x, tuple):
+ return self.forward_full(x)
+ else:
+ raise TypeError
+
+ def inference(self, x_all, subgraph_loader, device):
+ r"""
+ Compute representations of nodes layer by layer, using *all*
+ available edges. This leads to faster computation in contrast to
+ immediately computing the final representations of each batch.
+
+ Arguments:
+ x_all (torch.Tensor): all node features
+ subgraph_loader (PyG.dataloader): dataloader
+ device (str): device
+ :returns:
+ x_all: output
+ """
+ total_edges = 0
+ for i in range(self.num_layers):
+ xs = []
+ for batch_size, n_id, adj in subgraph_loader:
+ edge_index, _, size = adj.to(device)
+ total_edges += edge_index.size(1)
+ x = x_all[n_id].to(device)
+ x_target = x[:size[1]]
+ x = self.convs[i]((x, x_target), edge_index)
+ if i != self.num_layers - 1:
+ x = F.relu(x)
+ xs.append(x.cpu())
+ x_all = torch.cat(xs, dim=0)
+
+ return x_all
\ No newline at end of file
diff --git a/federatedscope/gfl/trainer/__init__.py b/federatedscope/gfl/trainer/__init__.py
new file mode 100644
index 000000000..55b5e6759
--- /dev/null
+++ b/federatedscope/gfl/trainer/__init__.py
@@ -0,0 +1,12 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.gfl.trainer.graphtrainer import GraphMiniBatchTrainer
+from federatedscope.gfl.trainer.linktrainer import LinkFullBatchTrainer, LinkMiniBatchTrainer
+from federatedscope.gfl.trainer.nodetrainer import NodeFullBatchTrainer, NodeMiniBatchTrainer
+
+__all__ = [
+ 'GraphMiniBatchTrainer', 'LinkFullBatchTrainer', 'LinkMiniBatchTrainer',
+ 'NodeFullBatchTrainer', 'NodeMiniBatchTrainer'
+]
\ No newline at end of file
diff --git a/federatedscope/gfl/trainer/graphtrainer.py b/federatedscope/gfl/trainer/graphtrainer.py
new file mode 100644
index 000000000..8e8e8e538
--- /dev/null
+++ b/federatedscope/gfl/trainer/graphtrainer.py
@@ -0,0 +1,67 @@
+import logging
+
+from federatedscope.core.monitors import Monitor
+from federatedscope.register import register_trainer
+from federatedscope.core.trainers import GeneralTorchTrainer
+
+logger = logging.getLogger(__name__)
+
+
+class GraphMiniBatchTrainer(GeneralTorchTrainer):
+ def _hook_on_batch_forward(self, ctx):
+ batch = ctx.data_batch.to(ctx.device)
+ pred = ctx.model(batch)
+ label = batch.y.squeeze(-1).long()
+ if len(label.size()) == 0:
+ label = label.unsqueeze(0)
+ ctx.loss_batch = ctx.criterion(pred, label)
+
+ ctx.batch_size = len(label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+ def _hook_on_batch_forward_flop_count(self, ctx):
+ if not isinstance(self.ctx.monitor, Monitor):
+ logger.warning(
+ f"The trainer {type(self)} does contain a valid monitor, this may be caused by "
+ f"initializing trainer subclasses without passing a valid monitor instance."
+ f"Plz check whether this is you want.")
+ return
+
+ if self.ctx.monitor.flops_per_sample == 0:
+ # calculate the flops_per_sample
+ try:
+ batch = ctx.data_batch.to(ctx.device)
+ from torch_geometric.data import Data
+ if isinstance(batch, Data):
+ x, edge_index = batch.x, batch.edge_index
+ from fvcore.nn import FlopCountAnalysis
+ flops_one_batch = FlopCountAnalysis(ctx.model,
+ (x, edge_index)).total()
+ if self.model_nums > 1 and ctx.mirrored_models:
+ flops_one_batch *= self.model_nums
+ logger.warning(
+ "the flops_per_batch is multiplied by internal model nums as self.mirrored_models=True."
+ "if this is not the case you want, please customize the count hook"
+ )
+ self.ctx.monitor.track_avg_flops(flops_one_batch,
+ ctx.batch_size)
+            except Exception:
+                logger.error(
+                    "The current flop count implementation is for the general GraphMiniBatchTrainer case: "
+                    "1) the ctx.model takes only batch = ctx.data_batch as input. "
+                    "Please check the forward format or implement your own flop_count function."
+ )
+
+ # by default, we assume the data has the same input shape,
+ # thus simply multiply the flops to avoid redundant forward
+ self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * ctx.batch_size
+
+
+def call_graph_level_trainer(trainer_type):
+ if trainer_type == 'graphminibatch_trainer':
+ trainer_builder = GraphMiniBatchTrainer
+ return trainer_builder
+
+
+register_trainer('graphminibatch_trainer', call_graph_level_trainer)
diff --git a/federatedscope/gfl/trainer/linktrainer.py b/federatedscope/gfl/trainer/linktrainer.py
new file mode 100644
index 000000000..c7f102f01
--- /dev/null
+++ b/federatedscope/gfl/trainer/linktrainer.py
@@ -0,0 +1,206 @@
+import torch
+
+from torch.utils.data import DataLoader
+from torch_geometric.loader import DataLoader as PyGDataLoader
+from torch_geometric.data import Data
+from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborSampler
+
+from federatedscope.core.monitors import Monitor
+from federatedscope.register import register_trainer
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.core.auxiliaries.ReIterator import ReIterator
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+MODE2MASK = {
+ 'train': 'train_edge_mask',
+ 'val': 'valid_edge_mask',
+ 'test': 'test_edge_mask'
+}
+
+
+class LinkFullBatchTrainer(GeneralTorchTrainer):
+ def register_default_hooks_eval(self):
+ super().register_default_hooks_eval()
+ self.register_hook_in_eval(
+ new_hook=self._hook_on_epoch_start_data2device,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ def register_default_hooks_train(self):
+ super().register_default_hooks_train()
+ self.register_hook_in_train(
+ new_hook=self._hook_on_epoch_start_data2device,
+ trigger='on_fit_start',
+ insert_pos=-1)
+
+ def parse_data(self, data):
+ """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes
+
+ """
+ init_dict = dict()
+ if isinstance(data, Data):
+ for mode in ["train", "val", "test"]:
+ edges = data.edge_index.T[data[MODE2MASK[mode]]]
+ # Use an index loader
+ index_loader = DataLoader(range(edges.size(0)),
+ self.cfg.data.batch_size,
+ shuffle=self.cfg.data.shuffle
+ if mode == 'train' else False,
+ drop_last=self.cfg.data.drop_last
+ if mode == 'train' else False)
+ init_dict["{}_loader".format(mode)] = index_loader
+ init_dict["num_{}_data".format(mode)] = edges.size(0)
+ init_dict["{}_data".format(mode)] = None
+ else:
+ raise TypeError("Type of data should be PyG data.")
+ return init_dict
+
+ def _hook_on_epoch_start_data2device(self, ctx):
+ ctx.data = ctx.data.to(ctx.device)
+ # For handling different dict key
+ if "input_edge_index" in ctx.data:
+ ctx.input_edge_index = ctx.data.input_edge_index
+ else:
+ ctx.input_edge_index = ctx.data.edge_index.T[
+ ctx.data.train_edge_mask].T
+
+ def _hook_on_batch_forward(self, ctx):
+ data = ctx.data
+ perm = ctx.data_batch
+ mask = ctx.data[MODE2MASK[ctx.cur_data_split]]
+ edges = data.edge_index.T[mask]
+ if ctx.cur_data_split in ['train', 'val']:
+ h = ctx.model((data.x, ctx.input_edge_index))
+ else:
+ h = ctx.model((data.x, data.edge_index))
+ pred = ctx.model.link_predictor(h, edges[perm].T)
+ label = data.edge_type[mask][perm] # edge_type is y
+
+ ctx.loss_batch = ctx.criterion(pred, label)
+
+ ctx.batch_size = len(label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+ def _hook_on_batch_forward_flop_count(self, ctx):
+ if not isinstance(self.ctx.monitor, Monitor):
+ logger.warning(
+ f"The trainer {type(self)} does contain a valid monitor, this may be caused by "
+ f"initializing trainer subclasses without passing a valid monitor instance."
+ f"Plz check whether this is you want.")
+ return
+
+ if self.ctx.monitor.flops_per_sample == 0:
+ # calculate the flops_per_sample
+ try:
+ data = ctx.data
+ from fvcore.nn import FlopCountAnalysis
+ if ctx.cur_data_split in ['train', 'val']:
+ flops_one_batch = FlopCountAnalysis(
+ ctx.model, (data.x, ctx.input_edge_index)).total()
+ else:
+ flops_one_batch = FlopCountAnalysis(
+ ctx.model, (data.x, data.edge_index)).total()
+ if self.model_nums > 1 and ctx.mirrored_models:
+ flops_one_batch *= self.model_nums
+ logger.warning(
+ "the flops_per_batch is multiplied by internal model nums as self.mirrored_models=True."
+ "if this is not the case you want, please customize the count hook"
+ )
+ self.ctx.monitor.track_avg_flops(flops_one_batch,
+ ctx.batch_size)
+            except Exception:
+                logger.error(
+                    "The current flop count implementation is for the general LinkFullBatchTrainer case: "
+                    "1) the ctx.model takes the "
+                    "tuple (data.x, data.edge_index) or tuple (data.x, ctx.input_edge_index) as input. "
+                    "Please check the forward format or implement your own flop_count function."
+ )
+
+
+class LinkMiniBatchTrainer(GeneralTorchTrainer):
+ """
+ # Support GraphSAGE with GraphSAINTRandomWalkSampler in train ONLY!
+ """
+ def parse_data(self, data):
+ """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes
+
+ """
+ init_dict = dict()
+ if isinstance(data, dict):
+ for mode in ["train", "val", "test"]:
+ init_dict["{}_data".format(mode)] = None
+ init_dict["{}_loader".format(mode)] = None
+ init_dict["num_{}_data".format(mode)] = 0
+ if data.get(mode, None) is not None:
+ if isinstance(
+ data.get(mode), NeighborSampler) or isinstance(
+ data.get(mode), GraphSAINTRandomWalkSampler):
+ if mode == 'train':
+ init_dict["{}_loader".format(mode)] = data.get(
+ mode)
+ init_dict["num_{}_data".format(mode)] = len(
+ data.get(mode).dataset)
+ else:
+ # We need to pass Full Dataloader to model
+ init_dict["{}_loader".format(mode)] = [
+ data.get(mode)
+ ]
+ init_dict["num_{}_data".format(
+ mode)] = self.cfg.data.batch_size
+ else:
+ raise TypeError("Type {} is not supported.".format(
+ type(data.get(mode))))
+ else:
+ raise TypeError("Type of data should be dict.")
+ return init_dict
+
+ def _hook_on_batch_forward(self, ctx):
+ if ctx.cur_data_split == 'train':
+ batch = ctx.data_batch.to(ctx.device)
+ mask = batch[MODE2MASK[ctx.cur_data_split]]
+ edges = batch.edge_index.T[mask].T
+ h = ctx.model((batch.x, edges))
+ pred = ctx.model.link_predictor(h, edges)
+ label = batch.edge_type[mask]
+ ctx.batch_size = torch.sum(
+ ctx.data_batch[MODE2MASK[ctx.cur_data_split]]).item()
+ else:
+ # For inference
+ mask = ctx.data['data'][MODE2MASK[ctx.cur_data_split]]
+ subgraph_loader = ctx.data_batch
+ h = ctx.model.gnn.inference(ctx.data['data'].x, subgraph_loader,
+ ctx.device).to(ctx.device)
+ edges = ctx.data['data'].edge_index.T[mask].to(ctx.device)
+ pred = []
+
+ for perm in DataLoader(range(edges.size(0)),
+ self.cfg.data.batch_size):
+ edge = edges[perm].T
+ pred += [ctx.model.link_predictor(h, edge).squeeze()]
+ pred = torch.cat(pred, dim=0)
+ label = ctx.data['data'].edge_type[mask].to(ctx.device)
+ ctx.batch_size = torch.sum(
+ ctx.data['data'][MODE2MASK[ctx.cur_data_split]]).item()
+
+ ctx.loss_batch = ctx.criterion(pred, label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+
+def call_link_level_trainer(trainer_type):
+ if trainer_type == 'linkfullbatch_trainer':
+ trainer_builder = LinkFullBatchTrainer
+ elif trainer_type == 'linkminibatch_trainer':
+ trainer_builder = LinkMiniBatchTrainer
+ else:
+ raise ValueError
+
+ return trainer_builder
+
+
+register_trainer('linkfullbatch_trainer', call_link_level_trainer)
+register_trainer('linkminibatch_trainer', call_link_level_trainer)
diff --git a/federatedscope/gfl/trainer/nodetrainer.py b/federatedscope/gfl/trainer/nodetrainer.py
new file mode 100644
index 000000000..e24c21234
--- /dev/null
+++ b/federatedscope/gfl/trainer/nodetrainer.py
@@ -0,0 +1,177 @@
+import torch
+
+from torch_geometric.loader import DataLoader as PyGDataLoader
+from torch_geometric.data import Data
+from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborSampler
+
+from federatedscope.core.monitors import Monitor
+from federatedscope.register import register_trainer
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.core.auxiliaries.ReIterator import ReIterator
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class NodeFullBatchTrainer(GeneralTorchTrainer):
+ def parse_data(self, data):
+ """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes
+
+ """
+ init_dict = dict()
+ if isinstance(data, Data):
+ for mode in ["train", "val", "test"]:
+ init_dict["{}_loader".format(mode)] = PyGDataLoader([data])
+ init_dict["{}_data".format(mode)] = None
+ # For node-level task dataloader contains one graph
+ init_dict["num_{}_data".format(mode)] = 1
+
+ else:
+ raise TypeError("Type of data should be PyG data.")
+ return init_dict
+
+ def _hook_on_batch_forward(self, ctx):
+ batch = ctx.data_batch.to(ctx.device)
+ pred = ctx.model(batch)[batch['{}_mask'.format(ctx.cur_data_split)]]
+ label = batch.y[batch['{}_mask'.format(ctx.cur_data_split)]]
+ ctx.batch_size = torch.sum(ctx.data_batch['{}_mask'.format(
+ ctx.cur_data_split)]).item()
+
+ ctx.loss_batch = ctx.criterion(pred, label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+ def _hook_on_batch_forward_flop_count(self, ctx):
+ if not isinstance(self.ctx.monitor, Monitor):
+ logger.warning(
+ f"The trainer {type(self)} does contain a valid monitor, this may be caused by "
+ f"initializing trainer subclasses without passing a valid monitor instance."
+ f"Plz check whether this is you want.")
+ return
+
+ if self.ctx.monitor.flops_per_sample == 0:
+ # calculate the flops_per_sample
+ try:
+ batch = ctx.data_batch.to(ctx.device)
+ from torch_geometric.data import Data
+ if isinstance(batch, Data):
+ x, edge_index = batch.x, batch.edge_index
+ from fvcore.nn import FlopCountAnalysis
+ flops_one_batch = FlopCountAnalysis(ctx.model,
+ (x, edge_index)).total()
+
+ if self.model_nums > 1 and ctx.mirrored_models:
+ flops_one_batch *= self.model_nums
+ logger.warning(
+ "the flops_per_batch is multiplied by internal model nums as self.mirrored_models=True."
+ "if this is not the case you want, please customize the count hook"
+ )
+ self.ctx.monitor.track_avg_flops(flops_one_batch,
+ ctx.batch_size)
+            except Exception:
+                logger.error(
+                    "The current flop count implementation covers the general NodeFullBatchTrainer case, "
+                    "where ctx.model takes only batch = ctx.data_batch as input. "
+                    "Please check the forward format or implement your own flop_count function."
+                )
+
+ # by default, we assume the data has the same input shape,
+ # thus simply multiply the flops to avoid redundant forward
+ self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * ctx.batch_size
+
+
+class NodeMiniBatchTrainer(GeneralTorchTrainer):
+ def parse_data(self, data):
+ """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes
+
+ """
+ init_dict = dict()
+ if isinstance(data, dict):
+ for mode in ["train", "val", "test"]:
+ init_dict["{}_data".format(mode)] = None
+ init_dict["{}_loader".format(mode)] = None
+ init_dict["num_{}_data".format(mode)] = 0
+ if data.get(mode, None) is not None:
+ if isinstance(
+ data.get(mode), NeighborSampler) or isinstance(
+ data.get(mode), GraphSAINTRandomWalkSampler):
+ if mode == 'train':
+ init_dict["{}_loader".format(mode)] = data.get(
+ mode)
+ init_dict["num_{}_data".format(mode)] = len(
+ data.get(mode).dataset)
+ else:
+                        # We need to pass the full dataloader to the model
+ init_dict["{}_loader".format(mode)] = [
+ data.get(mode)
+ ]
+ init_dict["num_{}_data".format(
+ mode)] = self.cfg.data.batch_size
+ else:
+ raise TypeError("Type {} is not supported.".format(
+ type(data.get(mode))))
+ else:
+ raise TypeError("Type of data should be dict.")
+ return init_dict
+
+ def _hook_on_epoch_start(self, ctx):
+ # TODO: blind torch
+ if not isinstance(ctx.get("{}_loader".format(ctx.cur_data_split)),
+ ReIterator):
+ if isinstance(ctx.get("{}_loader".format(ctx.cur_data_split)),
+ NeighborSampler):
+ self.is_NeighborSampler = True
+ ctx.data['data'].x = ctx.data['data'].x.to(ctx.device)
+ ctx.data['data'].y = ctx.data['data'].y.to(ctx.device)
+ else:
+ self.is_NeighborSampler = False
+ setattr(
+ ctx, "{}_loader".format(ctx.cur_data_split),
+ ReIterator(ctx.get("{}_loader".format(ctx.cur_data_split))))
+
+ def _hook_on_batch_forward(self, ctx):
+ if ctx.cur_data_split == 'train':
+ # For training
+ if self.is_NeighborSampler:
+                # For NeighborSampler
+ batch_size, n_id, adjs = ctx.data_batch
+ adjs = [adj.to(ctx.device) for adj in adjs]
+ pred = ctx.model(ctx.data['data'].x[n_id], adjs=adjs)
+ label = ctx.data['data'].y[n_id[:batch_size]]
+                ctx.batch_size = batch_size
+ else:
+ # For GraphSAINTRandomWalkSampler or PyGDataLoader
+ batch = ctx.data_batch.to(ctx.device)
+ pred = ctx.model(batch.x,
+ batch.edge_index)[batch['{}_mask'.format(
+ ctx.cur_data_split)]]
+ label = batch.y[batch['{}_mask'.format(ctx.cur_data_split)]]
+ ctx.batch_size = torch.sum(ctx.data_batch['train_mask']).item()
+ else:
+ # For inference
+ subgraph_loader = ctx.data_batch
+ mask = ctx.data['data']['{}_mask'.format(ctx.cur_data_split)]
+ pred = ctx.model.inference(ctx.data['data'].x, subgraph_loader,
+ ctx.device)[mask]
+ label = ctx.data['data'].y[mask]
+ ctx.batch_size = torch.sum(ctx.data['data']['{}_mask'.format(
+ ctx.cur_data_split)]).item()
+
+ ctx.loss_batch = ctx.criterion(pred, label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+
+def call_node_level_trainer(trainer_type):
+ if trainer_type == 'nodefullbatch_trainer':
+ trainer_builder = NodeFullBatchTrainer
+ elif trainer_type == 'nodeminibatch_trainer':
+ trainer_builder = NodeMiniBatchTrainer
+ else:
+        raise ValueError(f'Trainer {trainer_type} is not supported.')
+
+ return trainer_builder
+
+
+register_trainer('nodefullbatch_trainer', call_node_level_trainer)
+register_trainer('nodeminibatch_trainer', call_node_level_trainer)
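
NodeFullBatchTrainer runs one forward pass over the whole graph and restricts loss and batch size to the current split's node mask. A minimal sketch of that masking step with toy tensors (no FederatedScope objects involved):

    import torch
    import torch.nn.functional as F

    num_nodes, num_classes = 6, 3
    logits = torch.randn(num_nodes, num_classes)       # stand-in for ctx.model(batch)
    labels = torch.randint(0, num_classes, (num_nodes,))
    train_mask = torch.tensor([1, 1, 0, 0, 1, 0], dtype=torch.bool)

    # loss and ctx.batch_size are computed on the masked subset only
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    batch_size = int(train_mask.sum())                 # 3
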
diff --git a/federatedscope/hpo.py b/federatedscope/hpo.py
new file mode 100644
index 000000000..96819c31d
--- /dev/null
+++ b/federatedscope/hpo.py
@@ -0,0 +1,42 @@
+import os
+import sys
+
+DEV_MODE = False # simplify the federatedscope re-setup every time we change the federatedscope source code
+if DEV_MODE:
+ file_dir = os.path.join(os.path.dirname(__file__), '..')
+ sys.path.append(file_dir)
+
+from federatedscope.core.auxiliaries.utils import setup_seed, update_logger
+from federatedscope.core.cmd_args import parse_args
+from federatedscope.core.configs.config import global_cfg
+from federatedscope.autotune import get_scheduler
+
+if os.environ.get('https_proxy'):
+ del os.environ['https_proxy']
+if os.environ.get('http_proxy'):
+ del os.environ['http_proxy']
+
+if __name__ == '__main__':
+ init_cfg = global_cfg.clone()
+ args = parse_args()
+ init_cfg.merge_from_file(args.cfg_file)
+ init_cfg.merge_from_list(args.opts)
+
+ update_logger(init_cfg)
+ setup_seed(init_cfg.seed)
+
+ assert not args.client_cfg_file, 'No support for client-wise config in HPO mode.'
+
+ scheduler = get_scheduler(init_cfg)
+ _ = scheduler.optimize()
diff --git a/federatedscope/main.py b/federatedscope/main.py
new file mode 100644
index 000000000..e2ea3b3a2
--- /dev/null
+++ b/federatedscope/main.py
@@ -0,0 +1,52 @@
+import os
+import sys
+
+from yacs.config import CfgNode
+
+DEV_MODE = False # simplify the federatedscope re-setup every time we change the federatedscope source code
+if DEV_MODE:
+ file_dir = os.path.join(os.path.dirname(__file__), '..')
+ sys.path.append(file_dir)
+
+from federatedscope.core.cmd_args import parse_args
+from federatedscope.core.auxiliaries.data_builder import get_data
+from federatedscope.core.auxiliaries.utils import setup_seed, update_logger
+from federatedscope.core.auxiliaries.worker_builder import get_client_cls, get_server_cls
+from federatedscope.core.configs.config import global_cfg
+from federatedscope.core.fed_runner import FedRunner
+
+if os.environ.get('https_proxy'):
+ del os.environ['https_proxy']
+if os.environ.get('http_proxy'):
+ del os.environ['http_proxy']
+
+if __name__ == '__main__':
+ init_cfg = global_cfg.clone()
+ args = parse_args()
+ init_cfg.merge_from_file(args.cfg_file)
+ init_cfg.merge_from_list(args.opts)
+
+ update_logger(init_cfg)
+ setup_seed(init_cfg.data.seed)
+
+ # load clients' cfg file
+ client_cfg = CfgNode.load_cfg(open(args.client_cfg_file,
+ 'r')) if args.client_cfg_file else None
+
+ # federated dataset might change the number of clients
+ # thus, we allow the creation procedure of dataset to modify the global cfg object
+ data, modified_cfg = get_data(config=init_cfg.clone())
+ init_cfg.merge_from_other_cfg(modified_cfg)
+
+ setup_seed(init_cfg.seed)
+ init_cfg.freeze()
+
+ runner = FedRunner(data=data,
+ server_class=get_server_cls(init_cfg),
+ client_class=get_client_cls(init_cfg),
+ config=init_cfg.clone(),
+ client_config=client_cfg)
+ _ = runner.run()
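
main.py layers three configuration sources in increasing priority: the built-in defaults in global_cfg, the YAML file passed on the command line, and the trailing opts. A minimal yacs sketch of that precedence (toy keys, not the real FederatedScope schema):

    from yacs.config import CfgNode

    cfg = CfgNode()
    cfg.seed = 1                       # built-in default
    cfg.data = CfgNode()
    cfg.data.type = 'toy'

    # a YAML file would be merged via cfg.merge_from_file(path);
    # command-line opts are merged last and win, as in __main__ above
    cfg.merge_from_list(['seed', 2, 'data.type', 'femnist'])
    assert cfg.seed == 2 and cfg.data.type == 'femnist'
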
diff --git a/federatedscope/mf/__init__.py b/federatedscope/mf/__init__.py
new file mode 100644
index 000000000..f8e91f237
--- /dev/null
+++ b/federatedscope/mf/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
diff --git a/federatedscope/mf/baseline/__init__.py b/federatedscope/mf/baseline/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/mf/dataloader/__init__.py b/federatedscope/mf/dataloader/__init__.py
new file mode 100644
index 000000000..915b1c049
--- /dev/null
+++ b/federatedscope/mf/dataloader/__init__.py
@@ -0,0 +1,3 @@
+from federatedscope.mf.dataloader.dataloader import load_mf_dataset, MFDataLoader
+
+__all__ = ['load_mf_dataset', 'MFDataLoader']
diff --git a/federatedscope/mf/dataloader/dataloader.py b/federatedscope/mf/dataloader/dataloader.py
new file mode 100644
index 000000000..b8b2a47c4
--- /dev/null
+++ b/federatedscope/mf/dataloader/dataloader.py
@@ -0,0 +1,169 @@
+from scipy.sparse import csc_matrix
+from scipy.sparse import coo_matrix
+from numpy.random import shuffle
+
+import numpy as np
+
+import collections
+import importlib
+
+MFDATA_CLASS_DICT = {
+ "vflmovielens1m": "VFLMovieLens1M",
+ "vflmovielens10m": "VFLMovieLens10M",
+ "hflmovielens1m": "HFLMovieLens1M",
+ "hflmovielens10m": "HFLMovieLens10M"
+}
+
+
+def load_mf_dataset(config=None):
+ """Return the dataset of matrix factorization
+
+ Format:
+ {
+ 'client_id': {
+ 'train': DataLoader(),
+ 'test': DataLoader(),
+ 'val': DataLoader()
+ }
+ }
+
+ """
+ if config.data.type.lower() in MFDATA_CLASS_DICT:
+ # Dataset
+ dataset = getattr(
+ importlib.import_module("federatedscope.mf.dataset.movielens"),
+ MFDATA_CLASS_DICT[config.data.type.lower()])(
+ root=config.data.root,
+ num_client=config.federate.client_num,
+ train_portion=config.data.splits[0],
+ download=True)
+ else:
+ raise NotImplementedError("Dataset {} is not implemented.".format(
+ config.data.type))
+
+ data_local_dict = collections.defaultdict(dict)
+ for id_client, data in dataset.data.items():
+ data_local_dict[id_client]["train"] = MFDataLoader(
+ data["train"],
+ shuffle=config.data.shuffle,
+ batch_size=config.data.batch_size,
+ drop_last=config.data.drop_last,
+ theta=config.sgdmf.theta)
+ data_local_dict[id_client]["test"] = MFDataLoader(
+ data["test"],
+ shuffle=False,
+ batch_size=config.data.batch_size,
+ drop_last=config.data.drop_last,
+ theta=config.sgdmf.theta)
+
+ # Modify config
+ config.merge_from_list(['model.num_user', dataset.n_user])
+ config.merge_from_list(['model.num_item', dataset.n_item])
+
+ return data_local_dict, config
+
+
+class MFDataLoader(object):
+ """DataLoader for MF dataset
+
+ Args:
+ data (csc_matrix): sparse MF dataset
+ batch_size (int): the size of batch data
+ shuffle (bool): shuffle the dataset
+ drop_last (bool): drop the last batch if True
+ theta (int): the maximal number of ratings for each user
+ """
+ def __init__(self,
+ data: csc_matrix,
+ batch_size: int,
+ shuffle=True,
+ drop_last=False,
+ theta=None):
+ super(MFDataLoader, self).__init__()
+ self.dataset = self._trim_data(data, theta)
+ self.shuffle = shuffle
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ self.n_row = self.dataset.shape[0]
+ self.n_col = self.dataset.shape[1]
+ self.n_rating = self.dataset.count_nonzero()
+
+ self._idx_samples = None
+ self._idx_cur = None
+
+ self._reset()
+
+ def _trim_data(self, data, theta=None):
+ """Trim rating data by parameter theta (per-user privacy)
+
+ Arguments:
+ data (csc_matrix): the dataset
+ theta (int): The maximal number of ratings for each user
+ """
+ if theta is None or theta <= 0:
+ return data
+ else:
+ # Each user has at most $theta$ items
+ dataset = data.tocoo()
+ user2items = collections.defaultdict(list)
+ for idx, user_id in enumerate(dataset.row):
+ user2items[user_id].append(idx)
+ # sample theta each
+ idx_select = list()
+ for items in user2items.values():
+ if len(items) > theta:
+ idx_select += np.random.choice(items, theta,
+ replace=False).tolist()
+ else:
+ idx_select += items
+ dataset = coo_matrix(
+ (dataset.data[idx_select],
+ (dataset.row[idx_select], dataset.col[idx_select])),
+ shape=dataset.shape).tocsc()
+ return dataset
+
+ def _reset(self):
+ self._idx_cur = 0
+ if self._idx_samples is None:
+ self._idx_samples = np.arange(self.n_rating)
+ if self.shuffle:
+ shuffle(self._idx_samples)
+
+ def _sample_data(self, sampled_rating_idx):
+ dataset = self.dataset.tocoo()
+ data = dataset.data[sampled_rating_idx]
+ rows = dataset.row[sampled_rating_idx]
+ cols = dataset.col[sampled_rating_idx]
+ return (rows, cols), data
+
+ def __len__(self):
+ """The number of batches within an epoch
+
+ """
+ if self.drop_last:
+            return self.n_rating // self.batch_size
+        else:
+            return self.n_rating // self.batch_size + int(
+                (self.n_rating % self.batch_size) != 0)
+
+ def __next__(self, theta=None):
+ """Get the next batch of data
+
+ Args:
+ theta (int): the maximal number of ratings for each user
+ """
+        n_samples = len(self._idx_samples)
+        idx_end = self._idx_cur + self.batch_size
+        if self._idx_cur == n_samples or (self.drop_last
+                                          and idx_end > n_samples):
+            raise StopIteration
+        idx_end = min(idx_end, n_samples)
+ idx_choice_samples = self._idx_samples[self._idx_cur:idx_end]
+ self._idx_cur = idx_end
+
+ return self._sample_data(idx_choice_samples)
+
+ def __iter__(self):
+ self._reset()
+ return self
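
MFDataLoader iterates over the non-zero entries of a sparse rating matrix and yields ((rows, cols), values) batches. A minimal usage sketch with a made-up 2x3 rating matrix, assuming the module above is importable:

    import numpy as np
    from scipy.sparse import csc_matrix
    from federatedscope.mf.dataloader.dataloader import MFDataLoader

    ratings = csc_matrix(np.array([[5., 0., 3.],
                                   [0., 4., 0.]]))  # 2 users x 3 items, 3 ratings
    loader = MFDataLoader(ratings, batch_size=2, shuffle=True, theta=2)

    for (rows, cols), values in loader:
        pass  # e.g. rows=[0, 1], cols=[0, 1], values=[5., 4.] in the first batch
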
diff --git a/federatedscope/mf/dataset/__init__.py b/federatedscope/mf/dataset/__init__.py
new file mode 100644
index 000000000..4c7fc0b75
--- /dev/null
+++ b/federatedscope/mf/dataset/__init__.py
@@ -0,0 +1,6 @@
+from federatedscope.mf.dataset.movielens import *
+
+__all__ = [
+ 'VMFDataset', 'HMFDataset', 'MovieLensData', 'MovieLens1M', 'MovieLens10M',
+ 'VFLMovieLens1M', 'HFLMovieLens1M', 'VFLMovieLens10M', 'HFLMovieLens10M'
+]
diff --git a/federatedscope/mf/dataset/movielens.py b/federatedscope/mf/dataset/movielens.py
new file mode 100644
index 000000000..63b9ed6aa
--- /dev/null
+++ b/federatedscope/mf/dataset/movielens.py
@@ -0,0 +1,247 @@
+import os
+import pickle
+import logging
+
+from torchvision.datasets.utils import check_integrity, download_and_extract_archive
+import pandas as pd
+from numpy.random import shuffle
+from scipy.sparse import coo_matrix
+from scipy.sparse import csc_matrix
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+class VMFDataset:
+ """Dataset of matrix factorization task in vertical federated learning.
+
+ """
+ def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int,
+ test_portion: float):
+ id_item = np.arange(self.n_item)
+ shuffle(id_item)
+ items_per_client = np.array_split(id_item, num_client)
+ data = dict()
+ for clientId, items in enumerate(items_per_client):
+ client_ratings = ratings[:, items]
+ train_ratings, test_ratings = self._split_train_test_ratings(
+ client_ratings, test_portion)
+ data[clientId + 1] = {"train": train_ratings, "test": test_ratings}
+ self.data = data
+
+
+class HMFDataset:
+ """Dataset of matrix factorization task in horizontal federated learning.
+
+ """
+ def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int,
+ test_portion: float):
+ id_user = np.arange(self.n_user)
+ shuffle(id_user)
+ users_per_client = np.array_split(id_user, num_client)
+ data = dict()
+        for clientId, users in enumerate(users_per_client):
+ client_ratings = ratings[users, :]
+ train_ratings, test_ratings = self._split_train_test_ratings(
+ client_ratings, test_portion)
+            data[clientId + 1] = {"train": train_ratings, "test": test_ratings}
+ self.data = data
+
+
+class MovieLensData(object):
+ """Download and split MF datasets
+
+ Arguments:
+ root (string): the path of data
+ num_client (int): the number of clients
+ train_portion (float): the portion of training data
+ download (bool): indicator to download dataset
+ """
+ def __init__(self, root, num_client, train_portion=0.9, download=True):
+ super(MovieLensData, self).__init__()
+
+ self.root = root
+ self.data = None
+
+ self.n_user = None
+ self.n_item = None
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted." +
+ "You can use download=True to download it")
+
+ ratings = self._load_meta()
+ self._split_n_clients_rating(ratings, num_client, 1 - train_portion)
+
+ def _split_train_test_ratings(self, ratings: csc_matrix,
+ test_portion: float):
+ n_ratings = ratings.count_nonzero()
+ id_test = np.random.choice(n_ratings,
+ int(n_ratings * test_portion),
+ replace=False)
+ id_train = list(set(np.arange(n_ratings)) - set(id_test))
+
+ ratings = ratings.tocoo()
+ test = coo_matrix((ratings.data[id_test],
+ (ratings.row[id_test], ratings.col[id_test])),
+ shape=ratings.shape)
+ train = coo_matrix((ratings.data[id_train],
+ (ratings.row[id_train], ratings.col[id_train])),
+ shape=ratings.shape)
+
+ train_ratings, test_ratings = train.tocsc(), test.tocsc()
+ return train_ratings, test_ratings
+
+ def _load_meta(self):
+ meta_path = os.path.join(self.root, self.base_folder, "ratings.pkl")
+ if not os.path.exists(meta_path):
+            logger.info("Processing the raw ratings data...")
+ fpath = os.path.join(self.root, self.base_folder, self.filename,
+ self.raw_file)
+ data = pd.read_csv(fpath,
+ sep="::",
+ engine="python",
+ usecols=[0, 1, 2],
+ names=["userId", "movieId", "rating"],
+ dtype={
+ "userId": np.int32,
+ "movieId": np.int32,
+ "rating": np.float32
+ })
+ # Map idx
+ unique_id_item, unique_id_user = np.sort(
+ data["movieId"].unique()), np.sort(data["userId"].unique())
+ n_item, n_user = len(unique_id_item), len(unique_id_user)
+ mapping_item, mapping_user = {
+ mid: idx
+ for idx, mid in enumerate(unique_id_item)
+ }, {mid: idx
+ for idx, mid in enumerate(unique_id_user)}
+
+ row = [mapping_user[mid] for _, mid in data["userId"].iteritems()]
+ col = [mapping_item[mid] for _, mid in data["movieId"].iteritems()]
+
+ ratings = coo_matrix((data["rating"], (row, col)),
+ shape=(n_user, n_item))
+ ratings = ratings.tocsc()
+
+ with open(meta_path, 'wb') as f:
+ pickle.dump(ratings, f)
+ logger.info("Done.")
+ else:
+ with open(meta_path, 'rb') as f:
+ ratings = pickle.load(f)
+
+ self.n_user, self.n_item = ratings.shape
+ return ratings
+
+ def _check_integrity(self):
+ fpath = os.path.join(self.root, self.base_folder, self.filename,
+ self.raw_file)
+ return check_integrity(fpath, self.raw_file_md5)
+
+ def download(self):
+ if self._check_integrity():
+ logger.info("Files already downloaded and verified")
+ return
+ download_and_extract_archive(self.url,
+ os.path.join(self.root, self.base_folder),
+ filename=self.url.split('/')[-1],
+ md5=self.zip_md5)
+
+
+class MovieLens1M(MovieLensData):
+    """MovieLens 1M Dataset
+    (https://grouplens.org/datasets/movielens)
+
+    Format:
+        UserID::MovieID::Rating::Timestamp
+
+    Arguments:
+        root (str): Root directory of dataset where directory
+            ``MovieLens1M`` exists or will be saved to if download is set to True.
+        num_client (int): The number of clients the ratings are split across.
+        train_portion (float, optional): The proportion of training data.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+ """
+ base_folder = 'MovieLens1M'
+ url = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
+ filename = "ml-1m"
+ zip_md5 = "c4d9eecfca2ab87c1945afe126590906"
+ raw_file = "ratings.dat"
+ raw_file_md5 = "a89aa3591bc97d6d4e0c89459ff39362"
+
+
+class MovieLens10M(MovieLensData):
+    """MovieLens 10M Dataset
+    (https://grouplens.org/datasets/movielens)
+
+    Format:
+        UserID::MovieID::Rating::Timestamp
+
+    Arguments:
+        root (str): Root directory of dataset where directory
+            ``MovieLens10M`` exists or will be saved to if download is set to True.
+        num_client (int): The number of clients the ratings are split across.
+        train_portion (float, optional): The proportion of training data.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+ """
+ base_folder = 'MovieLens10M'
+ url = "https://files.grouplens.org/datasets/movielens/ml-10m.zip"
+ filename = "ml-10M100K"
+
+ zip_md5 = "ce571fd55effeba0271552578f2648bd"
+ raw_file = "ratings.dat"
+ raw_file_md5 = "3f317698625386f66177629fa5c6b2dc"
+
+
+class VFLMovieLens1M(MovieLens1M, VMFDataset):
+ """MovieLens1M dataset in VFL setting
+
+ """
+ pass
+
+
+class HFLMovieLens1M(MovieLens1M, HMFDataset):
+ """MovieLens1M dataset in HFL setting
+
+ """
+ pass
+
+
+class VFLMovieLens10M(MovieLens10M, VMFDataset):
+ """MovieLens10M dataset in VFL setting
+
+ """
+ pass
+
+
+class HFLMovieLens10M(MovieLens10M, HMFDataset):
+ """MovieLens10M dataset in HFL setting
+
+ """
+ pass
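
The VFL variants split the rating matrix by item columns across clients, while the HFL variants split by user rows; each client id (starting from 1) maps to csc_matrix train/test shards. A minimal sketch (downloads ML-1M on first use; the root path is illustrative):

    from federatedscope.mf.dataset.movielens import VFLMovieLens1M

    dataset = VFLMovieLens1M(root='data', num_client=5,
                             train_portion=0.9, download=True)
    client_1 = dataset.data[1]      # dict with 'train' and 'test' csc_matrix shards
    print(client_1['train'].shape)  # all users x this client's slice of the items
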
diff --git a/federatedscope/mf/model/__init__.py b/federatedscope/mf/model/__init__.py
new file mode 100644
index 000000000..709fc11f9
--- /dev/null
+++ b/federatedscope/mf/model/__init__.py
@@ -0,0 +1,4 @@
+from federatedscope.mf.model.model import BasicMFNet, VMFNet, HMFNet
+from federatedscope.mf.model.model_builder import get_mfnet
+
+__all__ = ["get_mfnet", "BasicMFNet", "VMFNet", "HMFNet"]
diff --git a/federatedscope/mf/model/model.py b/federatedscope/mf/model/model.py
new file mode 100644
index 000000000..2fe2572d3
--- /dev/null
+++ b/federatedscope/mf/model/model.py
@@ -0,0 +1,74 @@
+from torch.nn import Parameter
+from torch.nn import Module
+
+import numpy as np
+import torch
+
+
+class BasicMFNet(Module):
+ """Basic model for MF task
+
+ Arguments:
+ num_user (int): the number of users
+ num_item (int): the number of items
+ num_hidden (int): the dimension of embedding vector
+ """
+ def __init__(self, num_user, num_item, num_hidden):
+ super(BasicMFNet, self).__init__()
+
+ self.embed_user = Parameter(
+ torch.normal(mean=0,
+ std=0.1,
+ size=(num_user, num_hidden),
+ requires_grad=True,
+ dtype=torch.float32))
+ self.register_parameter('embed_user', self.embed_user)
+ self.embed_item = Parameter(
+ torch.normal(mean=0,
+ std=0.1,
+ size=(num_item, num_hidden),
+ requires_grad=True,
+ dtype=torch.float32))
+ self.register_parameter('embed_item', self.embed_item)
+
+ def forward(self, indices, ratings):
+ pred = torch.matmul(self.embed_user, self.embed_item.T)
+ label = torch.sparse_coo_tensor(indices,
+ ratings,
+ size=pred.shape,
+ device=pred.device,
+ dtype=torch.float32).to_dense()
+ mask = torch.sparse_coo_tensor(indices,
+ np.ones(len(ratings)),
+ size=pred.shape,
+ device=pred.device,
+ dtype=torch.float32).to_dense()
+
+ return mask * pred, label, float(np.prod(pred.size())) / len(ratings)
+
+ def load_state_dict(self,
+ state_dict: 'OrderedDict[str, Tensor]',
+ strict: bool = True):
+
+ state_dict[self.name_reserve] = getattr(self, self.name_reserve)
+ super().load_state_dict(state_dict, strict)
+
+ def state_dict(self, destination=None, prefix='', keep_vars=False):
+ state_dict = super().state_dict(destination, prefix, keep_vars)
+ # Mask embed_item
+ del state_dict[self.name_reserve]
+ return state_dict
+
+
+class VMFNet(BasicMFNet):
+ """MF model for vertical federated learning
+
+ """
+ name_reserve = "embed_item"
+
+
+class HMFNet(BasicMFNet):
+ """MF model for horizontal federated learning
+
+ """
+ name_reserve = "embed_user"
diff --git a/federatedscope/mf/model/model_builder.py b/federatedscope/mf/model/model_builder.py
new file mode 100644
index 000000000..33021396a
--- /dev/null
+++ b/federatedscope/mf/model/model_builder.py
@@ -0,0 +1,17 @@
+def get_mfnet(model_config, local_data):
+ """Return the MF model according to model configs
+
+ Arguments:
+ model_config: the model related parameters
+ local_data (dict): the dataset used for this model
+ """
+ if model_config.type.lower() == 'vmfnet':
+ from federatedscope.mf.model.model import VMFNet
+ return VMFNet(num_user=model_config.num_user,
+ num_item=local_data["train"].n_col,
+ num_hidden=model_config.hidden)
+ else:
+ from federatedscope.mf.model.model import HMFNet
+ return HMFNet(num_user=local_data["train"].n_row,
+ num_item=model_config.num_item,
+ num_hidden=model_config.hidden)
diff --git a/federatedscope/mf/trainer/__init__.py b/federatedscope/mf/trainer/__init__.py
new file mode 100644
index 000000000..4fc87a487
--- /dev/null
+++ b/federatedscope/mf/trainer/__init__.py
@@ -0,0 +1,7 @@
+from federatedscope.mf.trainer.trainer import MFTrainer
+from federatedscope.mf.trainer.trainer_sgdmf import wrap_MFTrainer, init_sgdmf_ctx, embedding_clip, hook_on_batch_backward
+
+__all__ = [
+ 'MFTrainer', 'wrap_MFTrainer', 'init_sgdmf_ctx', 'embedding_clip',
+ 'hook_on_batch_backward'
+]
diff --git a/federatedscope/mf/trainer/trainer.py b/federatedscope/mf/trainer/trainer.py
new file mode 100644
index 000000000..23f3e63ba
--- /dev/null
+++ b/federatedscope/mf/trainer/trainer.py
@@ -0,0 +1,135 @@
+import numpy
+import torch
+
+from federatedscope.core.monitors import Monitor
+from federatedscope.mf.dataloader.dataloader import MFDataLoader
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.register import register_trainer
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class MFTrainer(GeneralTorchTrainer):
+ """Trainer for MF task
+
+ Arguments:
+ model (torch.nn.module): MF model.
+ data (dict): input data
+ device (str): device.
+ """
+ def parse_data(self, data):
+ """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes
+
+ """
+ init_dict = dict()
+ if isinstance(data, dict):
+ for mode in ["train", "val", "test"]:
+ init_dict["{}_data".format(mode)] = None
+ init_dict["{}_loader".format(mode)] = None
+ init_dict["num_{}_data".format(mode)] = 0
+ if data.get(mode, None) is not None:
+ if isinstance(data.get(mode), MFDataLoader):
+ init_dict["{}_loader".format(mode)] = data.get(mode)
+ init_dict["num_{}_data".format(mode)] = data.get(
+ mode).n_rating
+ else:
+ raise TypeError(
+ "Type {} is not supported for MFTrainer.".format(
+ type(data.get(mode))))
+ else:
+ raise TypeError("Type of data should be dict.")
+ return init_dict
+
+ def _hook_on_fit_end(self, ctx):
+ results = {
+ f"{ctx.cur_mode}_avg_loss": ctx.get("loss_batch_total_{}".format(
+ ctx.cur_mode)) /
+ ctx.get("num_samples_{}".format(ctx.cur_mode)),
+ f"{ctx.cur_mode}_total": ctx.get("num_samples_{}".format(
+ ctx.cur_mode))
+ }
+ setattr(ctx, 'eval_metrics', results)
+
+ def _hook_on_batch_end(self, ctx):
+ # update statistics
+ setattr(
+ ctx, "loss_batch_total_{}".format(ctx.cur_mode),
+ ctx.get("loss_batch_total_{}".format(ctx.cur_mode)) +
+ ctx.loss_batch.item() * ctx.batch_size)
+
+ if ctx.get("loss_regular", None) is None or ctx.loss_regular == 0:
+ loss_regular = 0.
+ else:
+ loss_regular = ctx.loss_regular.item()
+ setattr(
+ ctx, "loss_regular_total_{}".format(ctx.cur_mode),
+ ctx.get("loss_regular_total_{}".format(ctx.cur_mode)) +
+ loss_regular)
+ setattr(
+ ctx, "num_samples_{}".format(ctx.cur_mode),
+ ctx.get("num_samples_{}".format(ctx.cur_mode)) + ctx.batch_size)
+
+ # clean temp ctx
+ ctx.data_batch = None
+ ctx.batch_size = None
+ ctx.loss_task = None
+ ctx.loss_batch = None
+ ctx.loss_regular = None
+ ctx.y_true = None
+ ctx.y_prob = None
+
+ def _hook_on_batch_forward(self, ctx):
+ indices, ratings = ctx.data_batch
+ pred, label, ratio = ctx.model(indices, ratings)
+ ctx.loss_batch = ctx.criterion(pred, label) * ratio
+
+ ctx.batch_size = len(ratings)
+
+ def _hook_on_batch_forward_flop_count(self, ctx):
+ if not isinstance(self.ctx.monitor, Monitor):
+            logger.warning(
+                f"The trainer {type(self)} does not contain a valid monitor; this may be caused by "
+                f"initializing trainer subclasses without passing a valid monitor instance. "
+                f"Please check whether this is what you want.")
+ return
+
+ if self.ctx.monitor.flops_per_sample == 0:
+ # calculate the flops_per_sample
+ try:
+ indices, ratings = ctx.data_batch
+ if isinstance(indices, numpy.ndarray):
+ indices = torch.from_numpy(indices)
+ if isinstance(ratings, numpy.ndarray):
+ ratings = torch.from_numpy(ratings)
+ from fvcore.nn import FlopCountAnalysis
+ flops_one_batch = FlopCountAnalysis(
+ ctx.model, (indices, ratings)).total()
+ if self.model_nums > 1 and ctx.mirrored_models:
+ flops_one_batch *= self.model_nums
+                    logger.warning(
+                        "The flops_per_batch is multiplied by the internal model number, as self.mirrored_models=True. "
+                        "If this is not what you want, please customize the count hook."
+                    )
+ self.ctx.monitor.track_avg_flops(flops_one_batch,
+ ctx.batch_size)
+            except Exception:
+                logger.error(
+                    "The current flop count implementation covers the general MFTrainer case, "
+                    "where ctx.model takes the tuple (indices, ratings) as input. "
+                    "Please check the forward format or implement your own flop_count function."
+                )
+
+ # by default, we assume the data has the same input shape,
+ # thus simply multiply the flops to avoid redundant forward
+ self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * ctx.batch_size
+
+
+def call_mf_trainer(trainer_type):
+ if trainer_type == "mftrainer":
+ trainer_builder = MFTrainer
+ return trainer_builder
+
+
+register_trainer("mftrainer", call_mf_trainer)
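
The ratio returned by the model makes a dense criterion act like a mean over the observed ratings only: cells outside the mask contribute zero, and the numel-to-ratings factor undoes the dense averaging. A small numeric check of that identity, assuming MSE as the criterion:

    import torch

    mask = torch.zeros(4, 6)
    mask[0, 1] = mask[2, 3] = mask[3, 5] = 1.
    pred = torch.randn(4, 6) * mask       # masked predictions, as BasicMFNet returns
    label = torch.randn(4, 6) * mask      # sparse labels densified with zeros
    ratio = mask.numel() / mask.sum()     # 24 / 3

    dense_loss = torch.nn.functional.mse_loss(pred, label) * ratio
    observed_loss = ((pred - label)[mask.bool()] ** 2).mean()
    assert torch.allclose(dense_loss, observed_loss)
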
diff --git a/federatedscope/mf/trainer/trainer_sgdmf.py b/federatedscope/mf/trainer/trainer_sgdmf.py
new file mode 100644
index 000000000..685d761d8
--- /dev/null
+++ b/federatedscope/mf/trainer/trainer_sgdmf.py
@@ -0,0 +1,97 @@
+import logging
+
+from federatedscope.mf.trainer.trainer import MFTrainer
+from federatedscope.core.auxiliaries.utils import get_random
+from typing import Type
+import numpy as np
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_MFTrainer(base_trainer: Type[MFTrainer]) -> Type[MFTrainer]:
+ """Build `SGDMFTrainer` with a plug-in manner, by registering new functions into specific `MFTrainer`
+
+ """
+
+ # ---------------- attribute-level plug-in -----------------------
+ init_sgdmf_ctx(base_trainer)
+
+ # ---------------- action-level plug-in -----------------------
+ base_trainer.replace_hook_in_train(
+ new_hook=hook_on_batch_backward,
+ target_trigger="on_batch_backward",
+ target_hook_name="_hook_on_batch_backward")
+
+ return base_trainer
+
+
+def init_sgdmf_ctx(base_trainer):
+    """Initialize the attributes used by SGDMF;
+    new attributes are prefixed with `sgdmf` to avoid namespace pollution.
+
+ """
+ ctx = base_trainer.ctx
+ cfg = base_trainer.cfg
+
+ sample_ratio = float(cfg.data.batch_size) / cfg.model.num_user
+ # Noise multiplier
+ tmp = cfg.sgdmf.constant * np.power(sample_ratio, 2) * (
+ cfg.federate.total_round_num * ctx.num_total_train_batch) * np.log(
+ 1. / cfg.sgdmf.delta)
+    noise_multiplier = np.sqrt(tmp / np.power(cfg.sgdmf.epsilon, 2))
+    ctx.scale = max(cfg.sgdmf.theta, 1.) * noise_multiplier * np.power(
+ cfg.sgdmf.R, 1.5)
+ logger.info("Inject noise: (loc=0, scale={})".format(ctx.scale))
+ ctx.sgdmf_R = cfg.sgdmf.R
+
+
+def embedding_clip(param, R: int):
+ """Clip embedding vector according to $R$
+
+ Arguments:
+ param (tensor): The embedding vector
+ R (int): The upper bound of ratings
+ """
+ # Turn all negative entries of U into 0
+ param.data = (torch.abs(param.data) + param.data) * 0.5
+ # Clip tensor
+ norms = torch.linalg.norm(param.data, dim=1)
+ threshold = np.sqrt(R)
+ param.data[norms > threshold] *= (threshold /
+ norms[norms > threshold]).reshape(
+ (-1, 1))
+ param.data[param.data < 0] = 0.
+
+
+def hook_on_batch_backward(ctx):
+ """Private local updates in SGDMF
+
+ """
+ ctx.optimizer.zero_grad()
+ ctx.loss_task.backward()
+
+ # Inject noise
+ ctx.model.embed_user.grad.data += get_random(
+ "Normal",
+ sample_shape=ctx.model.embed_user.shape,
+ params={
+ "loc": 0,
+ "scale": ctx.scale
+ },
+ device=ctx.model.embed_user.device)
+ ctx.model.embed_item.grad.data += get_random(
+ "Normal",
+ sample_shape=ctx.model.embed_item.shape,
+ params={
+ "loc": 0,
+ "scale": ctx.scale
+ },
+ device=ctx.model.embed_item.device)
+ ctx.optimizer.step()
+
+ # Embedding clipping
+ with torch.no_grad():
+ embedding_clip(ctx.model.embed_user, ctx.sgdmf_R)
+ embedding_clip(ctx.model.embed_item, ctx.sgdmf_R)
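
embedding_clip projects each embedding row onto the non-negative ball of radius sqrt(R), which bounds the per-step sensitivity used in the DP analysis. A quick standalone check of both properties:

    import torch
    from federatedscope.mf.trainer.trainer_sgdmf import embedding_clip

    emb = torch.nn.Parameter(torch.randn(8, 4) * 3.)
    with torch.no_grad():
        embedding_clip(emb, R=5)

    assert (emb >= 0).all()                    # no negative entries remain
    assert torch.linalg.norm(emb, dim=1).max() <= 5 ** 0.5 + 1e-5  # norms <= sqrt(R)
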
diff --git a/federatedscope/nlp/__init__.py b/federatedscope/nlp/__init__.py
new file mode 100644
index 000000000..f8e91f237
--- /dev/null
+++ b/federatedscope/nlp/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
diff --git a/federatedscope/nlp/dataloader/__init__.py b/federatedscope/nlp/dataloader/__init__.py
new file mode 100644
index 000000000..fa39c9c7c
--- /dev/null
+++ b/federatedscope/nlp/dataloader/__init__.py
@@ -0,0 +1,3 @@
+from federatedscope.nlp.dataloader.dataloader import load_nlp_dataset
+
+__all__ = ['load_nlp_dataset']
\ No newline at end of file
diff --git a/federatedscope/nlp/dataloader/dataloader.py b/federatedscope/nlp/dataloader/dataloader.py
new file mode 100644
index 000000000..504a5a325
--- /dev/null
+++ b/federatedscope/nlp/dataloader/dataloader.py
@@ -0,0 +1,63 @@
+from torch.utils.data import DataLoader
+
+from federatedscope.nlp.dataset.leaf_nlp import LEAF_NLP
+from federatedscope.nlp.dataset.leaf_synthetic import LEAF_SYNTHETIC
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+
+
+def load_nlp_dataset(config=None):
+ r"""
+ return {
+ 'client_id': {
+ 'train': DataLoader(),
+ 'test': DataLoader(),
+ 'val': DataLoader()
+ }
+ }
+ """
+ splits = config.data.splits
+
+ path = config.data.root
+ name = config.data.type.lower()
+ batch_size = config.data.batch_size
+ transforms_funcs = get_transform(config, 'torchtext')
+
+ if name in ['shakespeare', 'subreddit', 'twitter']:
+ dataset = LEAF_NLP(root=path,
+ name=name,
+ s_frac=config.data.subsample,
+ tr_frac=splits[0],
+ val_frac=splits[1],
+ seed=1234,
+ **transforms_funcs)
+ elif name == 'synthetic':
+ dataset = LEAF_SYNTHETIC(root=path)
+ else:
+ raise ValueError(f'No dataset named: {name}!')
+
+ client_num = min(len(dataset), config.federate.client_num
+ ) if config.federate.client_num > 0 else len(dataset)
+ config.merge_from_list(['federate.client_num', client_num])
+
+ # get local dataset
+ data_local_dict = dict()
+ for client_idx in range(client_num):
+ dataloader = {
+ 'train': DataLoader(dataset[client_idx]['train'],
+ batch_size,
+ shuffle=config.data.shuffle,
+ num_workers=config.data.num_workers),
+ 'test': DataLoader(dataset[client_idx]['test'],
+ batch_size,
+ shuffle=False,
+ num_workers=config.data.num_workers)
+ }
+ if 'val' in dataset[client_idx]:
+ dataloader['val'] = DataLoader(dataset[client_idx]['val'],
+ batch_size,
+ shuffle=False,
+ num_workers=config.data.num_workers)
+
+ data_local_dict[client_idx + 1] = dataloader
+
+ return data_local_dict, config
diff --git a/federatedscope/nlp/dataset/__init__.py b/federatedscope/nlp/dataset/__init__.py
new file mode 100644
index 000000000..42638817a
--- /dev/null
+++ b/federatedscope/nlp/dataset/__init__.py
@@ -0,0 +1,8 @@
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
\ No newline at end of file
diff --git a/federatedscope/nlp/dataset/leaf_nlp.py b/federatedscope/nlp/dataset/leaf_nlp.py
new file mode 100644
index 000000000..1e4fc0870
--- /dev/null
+++ b/federatedscope/nlp/dataset/leaf_nlp.py
@@ -0,0 +1,266 @@
+import os
+import random
+import pickle
+import json
+import torch
+import math
+
+import os.path as osp
+
+from tqdm import tqdm
+from collections import defaultdict
+
+from sklearn.model_selection import train_test_split
+
+from federatedscope.core.auxiliaries.utils import save_local_data, download_url
+from federatedscope.cv.dataset.leaf import LEAF
+from federatedscope.nlp.dataset.utils import *
+
+
+class LEAF_NLP(LEAF):
+ """
+ LEAF NLP dataset from
+
+ leaf.cmu.edu
+
+ Arguments:
+ root (str): root path.
+        name (str): name of dataset ('shakespeare', 'twitter' or 'subreddit').
+ s_frac (float): fraction of the dataset to be used; default=0.3.
+ tr_frac (float): train set proportion for each task; default=0.8.
+ val_frac (float): valid set proportion for each task; default=0.0.
+ transform: transform for x.
+ target_transform: transform for y.
+
+ """
+ def __init__(self,
+ root,
+ name,
+ s_frac=0.3,
+ tr_frac=0.8,
+ val_frac=0.0,
+ seed=123,
+ transform=None,
+ target_transform=None):
+ self.s_frac = s_frac
+ self.tr_frac = tr_frac
+ self.val_frac = val_frac
+ self.seed = seed
+ super(LEAF_NLP, self).__init__(root, name, transform, target_transform)
+ files = os.listdir(self.processed_dir)
+ files = [f for f in files if f.startswith('task_')]
+ if len(files):
+ # Sort by idx
+ files.sort(key=lambda k: int(k[5:]))
+
+ for file in files:
+ train_data, train_targets = torch.load(
+ osp.join(self.processed_dir, file, 'train.pt'))
+ test_data, test_targets = torch.load(
+ osp.join(self.processed_dir, file, 'test.pt'))
+ self.data_dict[int(file[5:])] = {
+ 'train': (train_data, train_targets),
+ 'test': (test_data, test_targets)
+ }
+ if osp.exists(osp.join(self.processed_dir, file, 'val.pt')):
+ val_data, val_targets = torch.load(
+ osp.join(self.processed_dir, file, 'val.pt'))
+ self.data_dict[int(file[5:])]['val'] = (val_data,
+ val_targets)
+ else:
+ raise RuntimeError(
+                "Please delete the 'processed' folder and try again!")
+
+ @property
+ def raw_file_names(self):
+ names = [f'{self.name}_all_data.zip']
+ return names
+
+ def download(self):
+ # Download to `self.raw_dir`.
+ url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com'
+ os.makedirs(self.raw_dir, exist_ok=True)
+ for name in self.raw_file_names:
+ download_url(f'{url}/{name}', self.raw_dir)
+
+ def __getitem__(self, index):
+ """
+ Arguments:
+ index (int): Index
+
+ :returns:
+ dict: {'train':[(text, target)],
+ 'test':[(text, target)],
+ 'val':[(text, target)]}
+ where target is the target class.
+ """
+ text_dict = {}
+ data = self.data_dict[index]
+ for key in data:
+ text_dict[key] = []
+ texts, targets = data[key]
+ for idx in range(targets.shape[0]):
+                text, target = texts[idx], targets[idx]
+
+ if self.transform is not None:
+ text = self.transform(text)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+                text_dict[key].append((text, target))
+
+ return text_dict
+
+ def tokenizer(self, data, targets):
+ """
+ TOKENIZER = {
+ 'shakespeare': {
+ 'x': word_to_indices,
+ 'y': letter_to_vec
+ },
+ 'twitter': {
+ 'x': bag_of_words,
+ 'y': target_to_binary
+ },
+ 'subreddit': {
+ 'x': token_to_ids,
+ 'y': token_to_ids
+ }
+ }
+ """
+ if self.name == 'shakespeare':
+ data = [
+                word_to_indices(re.sub(r" +", r' ', raw_text))
+ for raw_text in data
+ ]
+ targets = [letter_to_vec(raw_target) for raw_target in targets]
+
+ elif self.name == 'twitter':
+ # Loading bag of word embeddings
+ with open(osp.join(self.raw_dir, 'embs.json'), 'r') as inf:
+ embs = json.load(inf)
+ id2word = embs['vocab']
+ word2id = {v: k for k, v in enumerate(id2word)}
+ # [ID, Date, Query, User, Content]
+ data = [bag_of_words(raw_text[4], word2id) for raw_text in data]
+ targets = [target_to_binary(raw_target) for raw_target in targets]
+
+ elif self.name == 'subreddit':
+ with open(osp.join(self.raw_dir, 'reddit_vocab.pck'), 'rb') as inf:
+ vocab_file = pickle.load(inf)
+ vocab = defaultdict(lambda: vocab_file['unk_symbol'])
+ vocab.update(vocab_file['vocab'])
+
+ data_x_by_seq, data_y_by_seq, mask_by_seq = [], [], []
+
+ for c, l in zip(data, targets):
+ data_x_by_seq.extend(c)
+ data_y_by_seq.extend(l['target_tokens'])
+ mask_by_seq.extend(l['count_tokens'])
+
+ data, targets, mask = data_x_by_seq, data_y_by_seq, mask_by_seq
+
+ data = token_to_ids(data, vocab)
+ targets = token_to_ids(targets, vocab)
+ # Next word prediction
+ targets = [words[-1] for words in targets]
+
+ return data, targets
+
+ def process(self):
+ raw_path = osp.join(self.raw_dir, "all_data")
+ files = os.listdir(raw_path)
+ files = [f for f in files if f.endswith('.json')]
+
+ if self.name == 'subreddit':
+ self.s_frac = 1.0
+
+ n_tasks = math.ceil(len(files) * self.s_frac)
+ random.shuffle(files)
+ files = files[:n_tasks]
+
+        print("Preprocessing data (please ensure enough disk space)...")
+
+ idx = 0
+ reddit_idx = []
+ for num, file in enumerate(tqdm(files)):
+ with open(osp.join(raw_path, file), 'r') as f:
+ raw_data = json.load(f)
+
+ user_list = list(raw_data['user_data'].keys())
+ n_tasks = math.ceil(len(user_list) * self.s_frac)
+ random.shuffle(user_list)
+ user_list = user_list[:n_tasks]
+ for user in user_list:
+ data, targets = raw_data['user_data'][user]['x'], raw_data[
+ 'user_data'][user]['y']
+
+                # Skip twitter users with no more than 50 samples
+ if self.name == 'twitter' and len(data) <= 50:
+ continue
+ if self.name == 'subreddit':
+ if user not in reddit_idx:
+ reddit_idx.append(user)
+
+ # Tokenize
+ data, targets = self.tokenizer(data, targets)
+
+ if len(data) > 2:
+ data = torch.LongTensor(np.stack(data))
+ targets = torch.LongTensor(np.stack(targets))
+ else:
+                    data = torch.LongTensor(data)
+ targets = torch.LongTensor(targets)
+
+ if self.name == 'subreddit':
+ # subreddit has fixed split
+ train_data, test_data, val_data = None, None, None
+ train_targets, test_targets, val_targets = None, None, None
+ if file.startswith('train'):
+ train_data = data
+ train_targets = targets
+ elif file.startswith('test'):
+ test_data = data
+ test_targets = targets
+ elif file.startswith('val'):
+ val_data = data
+ val_targets = targets
+ else:
+ continue
+ save_path = osp.join(self.processed_dir,
+ f"task_{reddit_idx.index(user)}")
+ else:
+ train_data, test_data, train_targets, test_targets =\
+ train_test_split(
+ data,
+ targets,
+ train_size=self.tr_frac,
+ random_state=self.seed
+ )
+
+ if self.val_frac > 0:
+ try:
+ val_data, test_data, val_targets, test_targets = \
+ train_test_split(
+ test_data,
+ test_targets,
+ train_size=self.val_frac / (1.-self.tr_frac),
+ random_state=self.seed
+ )
+                    except ValueError:
+ val_data, val_targets = None, None
+
+ else:
+ val_data, val_targets = None, None
+ save_path = osp.join(self.processed_dir, f"task_{idx}")
+ os.makedirs(save_path, exist_ok=True)
+
+ save_local_data(dir_path=save_path,
+ train_data=train_data,
+ train_targets=train_targets,
+ test_data=test_data,
+ test_targets=test_targets,
+ val_data=val_data,
+ val_targets=val_targets)
+ idx += 1
diff --git a/federatedscope/nlp/dataset/leaf_synthetic.py b/federatedscope/nlp/dataset/leaf_synthetic.py
new file mode 100644
index 000000000..24de31623
--- /dev/null
+++ b/federatedscope/nlp/dataset/leaf_synthetic.py
@@ -0,0 +1,199 @@
+import os
+import pickle
+import argparse
+import torch
+import numpy as np
+import os.path as osp
+
+from sklearn.utils import shuffle
+from torch.utils.data import Dataset
+
+from federatedscope.core.auxiliaries.utils import save_local_data
+from federatedscope.cv.dataset.leaf import LEAF
+
+
+def sigmoid(x):
+ return 1 / (1 + np.exp(-x))
+
+
+def softmax(x):
+ ex = np.exp(x)
+ sum_ex = np.sum(np.exp(x))
+ return ex / sum_ex
+
+
+class LEAF_SYNTHETIC(LEAF):
+ """SYNTHETIC dataset from "Federated Multi-Task Learning under a Mixture of Distributions"
+
+ Source: https://github.com/omarfoq/FedEM/tree/main/data/synthetic
+
+ Arguments:
+ root (str): root path.
+ name (str): name of dataset, `SYNTHETIC`.
+ n_components (int): number of mixture components, default=3.
+        n_tasks (int): number of tasks/clients, default=300.
+ n_test (int): size of test set, default=5,000.
+ n_val (int): size of validation set, default=5,000.
+ dim (int): dimension of the data, default=150.
+ noise_level (float): proportion of noise, default=0.1.
+ alpha (float): alpha of LDA, default=0.4.
+ box (list): box of `x`, default=(-1.0, 1.0).
+
+ """
+ def __init__(self,
+ root,
+ name='synthetic',
+ n_components=3,
+ n_tasks=300,
+ n_test=5000,
+ n_val=5000,
+ dim=150,
+ noise_level=0.1,
+ alpha=0.4,
+ box=(-1.0, 1.0),
+ uniform_marginal=True):
+
+ self.root = root
+ self.n_components = n_components
+ self.n_tasks = n_tasks
+ self.n_test = n_test
+ self.n_val = n_val
+ self.dim = dim
+ self.noise_level = noise_level
+ self.alpha = alpha * np.ones(n_components)
+ self.box = box
+ self.uniform_marginal = uniform_marginal
+ self.num_samples = self.get_num_samples(self.n_tasks)
+
+ self.theta = np.zeros((self.n_components, self.dim))
+ self.mixture_weights = np.zeros((self.n_tasks, self.n_components))
+
+ self.generate_mixture_weights()
+ self.generate_components()
+
+ super(LEAF_SYNTHETIC, self).__init__(root, name, None, None)
+ files = os.listdir(self.processed_dir)
+ files = [f for f in files if f.startswith('task_')]
+ if len(files):
+ # Sort by idx
+ files.sort(key=lambda k: int(k[5:]))
+
+ for file in files:
+ train_data, train_targets = torch.load(
+ osp.join(self.processed_dir, file, 'train.pt'))
+ test_data, test_targets = torch.load(
+ osp.join(self.processed_dir, file, 'test.pt'))
+ self.data_dict[int(file[5:])] = {
+ 'train': (train_data, train_targets),
+ 'test': (test_data, test_targets)
+ }
+ if osp.exists(osp.join(self.processed_dir, file, 'val.pt')):
+ val_data, val_targets = torch.load(
+ osp.join(self.processed_dir, file, 'val.pt'))
+ self.data_dict[int(file[5:])]['val'] = (val_data,
+ val_targets)
+ else:
+ raise RuntimeError(
+                "Please delete the 'processed' folder and try again!")
+
+ def download(self):
+ pass
+
+ def extract(self):
+ pass
+
+ def __getitem__(self, index):
+ """
+ Arguments:
+ index (int): Index
+
+ :returns:
+ dict: {'train':[(x, target)],
+ 'test':[(x, target)],
+ 'val':[(x, target)]}
+ where target is the target class.
+ """
+ text_dict = {}
+ data = self.data_dict[index]
+ for key in data:
+ text_dict[key] = []
+ texts, targets = data[key]
+ for idx in range(targets.shape[0]):
+ text = texts[idx]
+ text_dict[key].append((text, targets[idx]))
+
+ return text_dict
+
+ def generate_mixture_weights(self):
+ for task_id in range(self.n_tasks):
+ self.mixture_weights[task_id] = np.random.dirichlet(
+ alpha=self.alpha)
+
+ def generate_components(self):
+ self.theta = np.random.uniform(self.box[0],
+ self.box[1],
+ size=(self.n_components, self.dim))
+
+ def generate_data(self, task_id, n_samples=10000):
+ latent_variable_count = np.random.multinomial(
+ n_samples, self.mixture_weights[task_id])
+ y = np.zeros(n_samples)
+
+ if self.uniform_marginal:
+ x = np.random.uniform(self.box[0],
+ self.box[1],
+ size=(n_samples, self.dim))
+ else:
+ raise NotImplementedError(
+ "Only uniform marginal is available for the moment")
+
+ current_index = 0
+ for component_id in range(self.n_components):
+ y_hat = x[current_index:current_index +
+ latent_variable_count[component_id]] @ self.theta[
+ component_id]
+ noise = np.random.normal(size=latent_variable_count[component_id],
+ scale=self.noise_level)
+            y[current_index: current_index + latent_variable_count[component_id]] = \
+                np.round(sigmoid(y_hat + noise)).astype(int)
+            current_index += latent_variable_count[component_id]
+
+ return shuffle(x.astype(np.float32), y.astype(np.int64))
+
+ def save_metadata(self, path_):
+ metadata = dict()
+ metadata["mixture_weights"] = self.mixture_weights
+ metadata["theta"] = self.theta
+
+ with open(path_, 'wb') as f:
+ pickle.dump(metadata, f)
+
+ def get_num_samples(self,
+ num_tasks,
+ min_num_samples=50,
+ max_num_samples=1000):
+ num_samples = np.random.lognormal(4, 2, num_tasks).astype(int)
+ num_samples = [
+ min(s + min_num_samples, max_num_samples) for s in num_samples
+ ]
+ return num_samples
+
+ def process(self):
+ for task_id in range(self.n_tasks):
+ save_path = os.path.join(self.processed_dir, f"task_{task_id}")
+ os.makedirs(save_path, exist_ok=True)
+
+ train_data, train_targets = self.generate_data(
+ task_id, self.num_samples[task_id])
+ test_data, test_targets = self.generate_data(task_id, self.n_test)
+
+ if self.n_val > 0:
+ val_data, val_targets = self.generate_data(task_id, self.n_val)
+ else:
+ val_data, val_targets = None, None
+ save_local_data(dir_path=save_path,
+ train_data=train_data,
+ train_targets=train_targets,
+ test_data=test_data,
+ test_targets=test_targets,
+ val_data=val_data,
+ val_targets=val_targets)
\ No newline at end of file
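
Each synthetic task mixes the same n_components linear models with task-specific Dirichlet weights, and labels come from thresholded sigmoids of noisy linear scores. The generative core in isolation, with small sizes for readability:

    import numpy as np

    rng = np.random.default_rng(0)
    n_components, dim, n_samples = 3, 10, 100

    theta = rng.uniform(-1.0, 1.0, (n_components, dim))    # shared components
    weights = rng.dirichlet(0.4 * np.ones(n_components))   # one task's mixture weights
    counts = rng.multinomial(n_samples, weights)           # samples per component

    x = rng.uniform(-1.0, 1.0, (n_samples, dim))
    y = np.zeros(n_samples, dtype=np.int64)
    start = 0
    for c in range(n_components):
        scores = x[start:start + counts[c]] @ theta[c]
        noise = rng.normal(scale=0.1, size=counts[c])
        y[start:start + counts[c]] = np.round(1 / (1 + np.exp(-(scores + noise))))
        start += counts[c]
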
diff --git a/federatedscope/nlp/dataset/utils.py b/federatedscope/nlp/dataset/utils.py
new file mode 100644
index 000000000..736d8c163
--- /dev/null
+++ b/federatedscope/nlp/dataset/utils.py
@@ -0,0 +1,88 @@
+"""
+Utils for language models.
+from https://github.com/litian96/FedProx/blob/master/flearn/utils/language_utils.py
+"""
+
+import re
+import numpy as np
+from collections import Counter
+
+# ------------------------
+# utils for shakespeare dataset
+
+ALL_LETTERS = "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz}"
+NUM_LETTERS = len(ALL_LETTERS)
+
+
+def _one_hot(index, size):
+ '''returns one-hot vector with given size and value 1 at given index
+ '''
+ vec = [0 for _ in range(size)]
+ vec[int(index)] = 1
+ return vec
+
+
+def letter_to_vec(letter):
+ index = ALL_LETTERS.find(letter)
+ return index
+
+
+def word_to_indices(word):
+ '''returns a list of character indices
+ Arguments:
+ word: string
+
+ :returns:
+ indices: int list with length len(word)
+ '''
+ indices = []
+ for c in word:
+ indices.append(ALL_LETTERS.find(c))
+ return indices
+
+
+# ------------------------
+# utils for sent140 dataset
+
+
+def split_line(line):
+ '''split given line/phrase into list of words
+ Arguments:
+ line: string representing phrase to be split
+
+ :returns:
+ list of strings, with each string representing a word
+ '''
+ return re.findall(r"[\w']+|[.,!?;]", line)
+
+
+def bag_of_words(line, vocab):
+ '''returns bag of words representation of given phrase using given vocab
+ Arguments:
+ line: string representing phrase to be parsed
+ vocab: dictionary with words as keys and indices as values
+ :returns:
+ integer list
+ '''
+ bag = [0] * len(vocab)
+ words = split_line(line)
+ for w in words:
+ if w in vocab:
+ bag[vocab[w]] += 1
+ return bag
+
+
+def target_to_binary(label):
+ return int(label == 1)
+
+
+def token_to_ids(texts, vocab):
+ to_ret = [[vocab[word] for word in line] for line in texts]
+ return np.array(to_ret)
+
+
+def label_to_index(labels):
+ counter = Counter(labels)
+ sorted_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
+ label_list = [x[0] for x in sorted_tuples]
+ return [label_list.index(x) for x in labels]
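
The LEAF tokenizers are plain lookups into fixed vocabularies. A short demo of the character- and word-level helpers (the two-word vocab is made up):

    from federatedscope.nlp.dataset.utils import (word_to_indices, letter_to_vec,
                                                  bag_of_words)

    x = word_to_indices("to be")             # 5 character indices into ALL_LETTERS
    y = letter_to_vec("o")                   # index of the next character to predict

    vocab = {'to': 0, 'be': 1}
    print(bag_of_words("to be to", vocab))   # [2, 1]
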
diff --git a/federatedscope/nlp/loss/__init__.py b/federatedscope/nlp/loss/__init__.py
new file mode 100644
index 000000000..ccd787417
--- /dev/null
+++ b/federatedscope/nlp/loss/__init__.py
@@ -0,0 +1,5 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.nlp.loss.character_loss import *
\ No newline at end of file
diff --git a/federatedscope/nlp/loss/character_loss.py b/federatedscope/nlp/loss/character_loss.py
new file mode 100644
index 000000000..7ec766f82
--- /dev/null
+++ b/federatedscope/nlp/loss/character_loss.py
@@ -0,0 +1,54 @@
+import torch
+
+from federatedscope.register import register_criterion
+"""
+Norm for Letters freq from FedEM:
+https://github.com/omarfoq/FedEM/blob/13f366c41c14b234147c2662c258b8a9db2f38cc/utils/constants.py
+"""
+CHARACTERS_WEIGHTS = {
+ '\n': 0.43795308843799086,
+ ' ': 0.042500849608091536,
+ ',': 0.6559597911540539,
+ '.': 0.6987226398690805,
+ 'I': 0.9777491725556848,
+ 'a': 0.2226022051965085,
+ 'c': 0.813311655455682,
+ 'd': 0.4071860494572223,
+ 'e': 0.13455606165058104,
+ 'f': 0.7908671114133974,
+ 'g': 0.9532922255751889,
+ 'h': 0.2496906467588955,
+ 'i': 0.27444893060347214,
+ 'l': 0.37296488139109546,
+ 'm': 0.569937324017103,
+ 'n': 0.2520734570378263,
+ 'o': 0.1934141300462555,
+ 'r': 0.26035705948768273,
+ 's': 0.2534775933879391,
+ 't': 0.1876471355731429,
+ 'u': 0.47430062920373184,
+ 'w': 0.7470615815733715,
+ 'y': 0.6388302610200002
+}
+
+ALL_LETTERS = "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz}"
+
+
+def create_character_loss(type, device):
+ """
+ Character_loss from FedEM:
+ https://github.com/omarfoq/FedEM/blob/13f366c41c14b234147c2662c258b8a9db2f38cc/utils/utils.py
+ """
+ if type == 'character_loss':
+ all_characters = ALL_LETTERS
+ labels_weight = torch.ones(len(all_characters), device=device)
+ for character in CHARACTERS_WEIGHTS:
+ labels_weight[all_characters.index(
+ character)] = CHARACTERS_WEIGHTS[character]
+ labels_weight = labels_weight * 8
+ criterion = torch.nn.CrossEntropyLoss(weight=labels_weight).to(device)
+
+ return criterion
+
+
+register_criterion('character_loss', create_character_loss)
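
The registered criterion is an ordinary CrossEntropyLoss whose class weights down-weight frequent characters. A minimal sketch of obtaining and applying it:

    import torch
    from federatedscope.nlp.loss.character_loss import (create_character_loss,
                                                        ALL_LETTERS)

    criterion = create_character_loss('character_loss', device='cpu')
    logits = torch.randn(4, len(ALL_LETTERS))              # one logit per character
    targets = torch.randint(0, len(ALL_LETTERS), (4,))
    loss = criterion(logits, targets)
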
diff --git a/federatedscope/nlp/model/__init__.py b/federatedscope/nlp/model/__init__.py
new file mode 100644
index 000000000..90d10121f
--- /dev/null
+++ b/federatedscope/nlp/model/__init__.py
@@ -0,0 +1,8 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.nlp.model.rnn import LSTM
+from federatedscope.nlp.model.model_builder import get_rnn, get_transformer
+
+__all__ = ['LSTM', 'get_rnn', 'get_transformer']
\ No newline at end of file
diff --git a/federatedscope/nlp/model/model_builder.py b/federatedscope/nlp/model/model_builder.py
new file mode 100644
index 000000000..3a2898fc4
--- /dev/null
+++ b/federatedscope/nlp/model/model_builder.py
@@ -0,0 +1,52 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+
+def get_rnn(model_config, local_data):
+ from federatedscope.nlp.model.rnn import LSTM
+ if isinstance(local_data, dict):
+ if 'data' in local_data.keys():
+ data = local_data['data']
+ elif 'train' in local_data.keys():
+ # local_data['train'] is Dataloader
+ data = next(iter(local_data['train']))
+ else:
+ raise TypeError('Unsupported data type.')
+ else:
+ data = local_data
+
+ x, _ = data
+
+ # check the task
+ if model_config.type == 'lstm':
+ model = LSTM(in_channels=x.shape[1] if not model_config.in_channels
+ else model_config.in_channels,
+ hidden=model_config.hidden,
+ out_channels=model_config.out_channels,
+ embed_size=model_config.embed_size,
+ dropout=model_config.dropout)
+ else:
+ raise ValueError(f'No model named {model_config.type}!')
+
+ return model
+
+
+def get_transformer(model_config, local_data):
+ from transformers import AutoModelForPreTraining, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelWithLMHead, AutoModel
+
+ model_func_dict = {
+ 'PreTraining'.lower(): AutoModelForPreTraining,
+ 'QuestionAnswering'.lower(): AutoModelForQuestionAnswering,
+ 'SequenceClassification'.lower(): AutoModelForSequenceClassification,
+ 'TokenClassification'.lower(): AutoModelForTokenClassification,
+ 'WithLMHead'.lower(): AutoModelWithLMHead,
+ 'Auto'.lower(): AutoModel
+ }
+ assert model_config.task.lower(
+    ) in model_func_dict, f'model_config.task should be in {model_func_dict.keys()} when using a pre-trained transformer model'
+ path, _ = model_config.type.split('@')
+ model = model_func_dict[model_config.task.lower()].from_pretrained(
+ path, num_labels=model_config.out_channels)
+
+ return model
diff --git a/federatedscope/nlp/model/rnn.py b/federatedscope/nlp/model/rnn.py
new file mode 100644
index 000000000..8d5607590
--- /dev/null
+++ b/federatedscope/nlp/model/rnn.py
@@ -0,0 +1,40 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LSTM(nn.Module):
+ def __init__(self,
+ in_channels,
+ hidden,
+ out_channels,
+ n_layers=2,
+ embed_size=8,
+ dropout=.0):
+ super(LSTM, self).__init__()
+ self.in_channels = in_channels
+ self.hidden = hidden
+ self.embed_size = embed_size
+ self.out_channels = out_channels
+ self.n_layers = n_layers
+
+ self.encoder = nn.Embedding(in_channels, embed_size)
+
+ self.rnn =\
+ nn.LSTM(
+ input_size=embed_size if embed_size else in_channels,
+ hidden_size=hidden,
+ num_layers=n_layers,
+ batch_first=True,
+ dropout=dropout
+ )
+
+ self.decoder = nn.Linear(hidden, out_channels)
+
+ def forward(self, input_):
+ if self.embed_size:
+ input_ = self.encoder(input_)
+ output, _ = self.rnn(input_)
+ output = self.decoder(output)
+ output = output.permute(0, 2, 1) # change dimension to (B, C, T)
+ final_word = output[:, :, -1]
+ return final_word
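
The LSTM embeds token ids, runs the recurrent stack, and keeps only the final time step's logits for next-character prediction. A shape check with toy dimensions:

    import torch
    from federatedscope.nlp.model.rnn import LSTM

    model = LSTM(in_channels=80, hidden=128, out_channels=80, embed_size=8)
    tokens = torch.randint(0, 80, (4, 20))  # batch of 4 sequences, 20 tokens each
    logits = model(tokens)
    print(logits.shape)                     # torch.Size([4, 80]): last-step logits
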
diff --git a/federatedscope/nlp/trainer/__init__.py b/federatedscope/nlp/trainer/__init__.py
new file mode 100644
index 000000000..42638817a
--- /dev/null
+++ b/federatedscope/nlp/trainer/__init__.py
@@ -0,0 +1,8 @@
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+ basename(f)[:-3] for f in modules
+ if isfile(f) and not f.endswith('__init__.py')
+]
\ No newline at end of file
diff --git a/federatedscope/nlp/trainer/trainer.py b/federatedscope/nlp/trainer/trainer.py
new file mode 100644
index 000000000..8f3fac5fc
--- /dev/null
+++ b/federatedscope/nlp/trainer/trainer.py
@@ -0,0 +1,28 @@
+from federatedscope.register import register_trainer
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.core.auxiliaries import utils
+
+
+class NLPTrainer(GeneralTorchTrainer):
+ def _hook_on_batch_forward(self, ctx):
+ x, label = [utils.move_to(_, ctx.device) for _ in ctx.data_batch]
+ if isinstance(x, dict):
+ pred = ctx.model(**x)[0]
+ else:
+ pred = ctx.model(x)
+ if len(label.size()) == 0:
+ label = label.unsqueeze(0)
+ ctx.loss_batch = ctx.criterion(pred, label)
+ ctx.y_true = label
+ ctx.y_prob = pred
+
+ ctx.batch_size = len(label)
+
+
+def call_nlp_trainer(trainer_type):
+ if trainer_type == 'nlptrainer':
+ trainer_builder = NLPTrainer
+ return trainer_builder
+
+
+register_trainer('nlptrainer', call_nlp_trainer)
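The `register_trainer` call stores `call_nlp_trainer` in `federatedscope.register.trainer_dict` (added later in this diff). A minimal sketch of the assumed resolution logic, not the framework's actual resolver: each registered call function is tried until one returns a class for the requested type.

    from federatedscope.register import trainer_dict

    def resolve_trainer(trainer_type):
        for call_func in trainer_dict.values():
            builder = call_func(trainer_type)  # returns None for unknown types
            if builder is not None:
                return builder
        raise ValueError(f'No registered trainer for type: {trainer_type}')

    # resolve_trainer('nlptrainer') would return the NLPTrainer class above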
diff --git a/federatedscope/parse_exp_results.py b/federatedscope/parse_exp_results.py
new file mode 100644
index 000000000..3a8e13309
--- /dev/null
+++ b/federatedscope/parse_exp_results.py
@@ -0,0 +1,66 @@
+import argparse
+import json
+import copy
+import numpy as np
+
+parser = argparse.ArgumentParser(description='FederatedScope result parsing')
+parser.add_argument('--input',
+ help='path of exp results',
+ required=True,
+ type=str)
+args = parser.parse_args()
+
+
+def merge_local_results(local_results):
+ aggr_results = copy.deepcopy(local_results[0])
+ aggr_results = {key: [aggr_results[key]] for key in aggr_results}
+ for i in range(1, len(local_results)):
+ for k, v in local_results[i].items():
+ aggr_results[k].append(v)
+ return aggr_results
+
+
+def main():
+ result_list_wavg = []
+ result_list_avg = []
+ result_list_global = []
+
+ with open(args.input, 'r') as ips:
+ for line in ips:
+ try:
+ state, line = line.split('INFO: ')
+ except ValueError:
+ # skip lines without the expected 'INFO: ' delimiter
+ continue
+ if line.startswith('{'):
+ line = line.replace("\'", "\"")
+ line = json.loads(s=line)
+ if line['Round'] == 'Final' and line['Role'] == 'Server #' \
+ and 'Results_raw' in line.keys():
+ res = line['Results_raw']
+ if 'server_global_eval' in res.keys():
+ result_list_global.append(
+ res['server_global_eval'])
+ if 'client_summarized_weighted_avg' in res.keys():
+ result_list_wavg.append(
+ res['client_summarized_weighted_avg'])
+ if 'client_summarized_avg' in res.keys():
+ result_list_avg.append(
+ res['client_summarized_avg'])
+
+ print(args.input)
+ if len(result_list_wavg):
+ print('\tResults_weighted_avg')
+ for key, v in merge_local_results(result_list_wavg).items():
+ print("\t{}, {:.4f}, {:.4f}".format(key, np.mean(v), np.std(v)))
+ if len(result_list_avg):
+ print('\tResults_avg')
+ for key, v in merge_local_results(result_list_avg).items():
+ print("\t{}, {:.4f}, {:.4f}".format(key, np.mean(v), np.std(v)))
+ if len(result_list_global):
+ print('\tserver_global_eval')
+ for key, v in merge_local_results(result_list_global).items():
+ print("\t{}, {:.4f}, {:.4f}".format(key, np.mean(v), np.std(v)))
+
+
+if __name__ == "__main__":
+ main()
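To make the aggregation concrete: `merge_local_results` turns a list of per-run result dicts into a dict of lists, which `main` then reduces to mean/std pairs. A worked example with made-up numbers:

    results = [{'test_acc': 0.80, 'test_loss': 0.52},
               {'test_acc': 0.84, 'test_loss': 0.47}]
    merged = merge_local_results(results)
    # merged == {'test_acc': [0.80, 0.84], 'test_loss': [0.52, 0.47]}
    # reported as "test_acc, 0.8200, 0.0200" (mean, std)

The script itself is run as, e.g., `python federatedscope/parse_exp_results.py --input <log file>`.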
diff --git a/federatedscope/register.py b/federatedscope/register.py
new file mode 100644
index 000000000..88fd68eb8
--- /dev/null
+++ b/federatedscope/register.py
@@ -0,0 +1,84 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def register(key, module, module_dict):
+ if key in module_dict:
+ logger.warning(
+ 'Key {} is already pre-defined, overwritten.'.format(key))
+ module_dict[key] = module
+
+
+data_dict = {}
+
+
+def register_data(key, module):
+ register(key, module, data_dict)
+
+
+model_dict = {}
+
+
+def register_model(key, module):
+ register(key, module, model_dict)
+
+
+trainer_dict = {}
+
+
+def register_trainer(key, module):
+ register(key, module, trainer_dict)
+
+
+config_dict = {}
+
+
+def register_config(key, module):
+ register(key, module, config_dict)
+
+
+metric_dict = {}
+
+
+def register_metric(key, module):
+ register(key, module, metric_dict)
+
+
+criterion_dict = {}
+
+
+def register_criterion(key, module):
+ register(key, module, criterion_dict)
+
+
+regularizer_dict = {}
+
+
+def register_regularizer(key, module):
+ register(key, module, regularizer_dict)
+
+
+auxiliary_data_loader_PIA_dict = {}
+
+
+def register_auxiliary_data_loader_PIA(key, module):
+ register(key, module, auxiliary_data_loader_PIA_dict)
+
+
+splitter_dict = {}
+
+
+def register_splitter(key, module):
+ register(key, module, splitter_dict)
+
+
+transform_dict = {}
+
+
+def register_transform(key, module):
+ register(key, module, transform_dict)
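As a usage sketch, a contributed component pairs a key with a call function, mirroring the `register_trainer('nlptrainer', ...)` call earlier in this diff (the model name, class, and assumed `(model_config, local_data)` signature below are hypothetical):

    import torch
    from federatedscope.register import register_model

    class TinyNet(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super(TinyNet, self).__init__()
            self.fc = torch.nn.Linear(in_channels, out_channels)

        def forward(self, x):
            return self.fc(x)

    def call_tiny_net(model_config, local_data):
        if model_config.type == 'tinynet':
            return TinyNet(model_config.in_channels, model_config.out_channels)

    register_model('tinynet', call_tiny_net)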
diff --git a/federatedscope/tabular/__init__.py b/federatedscope/tabular/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/tabular/dataloader/__init__.py b/federatedscope/tabular/dataloader/__init__.py
new file mode 100644
index 000000000..d3622deb9
--- /dev/null
+++ b/federatedscope/tabular/dataloader/__init__.py
@@ -0,0 +1,3 @@
+from federatedscope.tabular.dataloader.quadratic import load_quadratic_dataset
+
+__all__ = ['load_quadratic_dataset']
diff --git a/federatedscope/tabular/dataloader/quadratic.py b/federatedscope/tabular/dataloader/quadratic.py
new file mode 100644
index 000000000..37b73829d
--- /dev/null
+++ b/federatedscope/tabular/dataloader/quadratic.py
@@ -0,0 +1,21 @@
+import numpy as np
+
+from torch.utils.data import DataLoader
+
+
+def load_quadratic_dataset(config):
+ dataset = dict()
+ d = config.data.quadratic.dim
+ base = np.exp(
+ np.log(config.data.quadratic.max_curv / config.data.quadratic.min_curv)
+ / (config.federate.client_num - 1))
+ for i in range(1, 1 + config.federate.client_num):
+ # TODO: enable sphere
+ a = 0.02 * base**(i - 1) * np.identity(d)
+ # TODO: enable non-zero minimizer, i.e., provide a shift
+ client_data = dict()
+ client_data['train'] = DataLoader([(a.astype(np.float32), .0)])
+ client_data['val'] = DataLoader([(a.astype(np.float32), .0)])
+ client_data['test'] = DataLoader([(a.astype(np.float32), .0)])
+ dataset[i] = client_data
+ return dataset, config
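The client curvatures are geometrically spaced: `base ** (client_num - 1)` equals `max_curv / min_curv`, so client `i` receives `0.02 * base ** (i - 1)` and the largest-to-smallest curvature ratio is exactly `max_curv / min_curv`. With made-up config values:

    import numpy as np

    min_curv, max_curv, client_num = 0.02, 2.0, 5  # hypothetical config values
    base = np.exp(np.log(max_curv / min_curv) / (client_num - 1))
    curvs = [0.02 * base ** i for i in range(client_num)]
    # curvs ~= [0.02, 0.063, 0.2, 0.632, 2.0] -- a geometric progression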
diff --git a/federatedscope/tabular/model/__init__.py b/federatedscope/tabular/model/__init__.py
new file mode 100644
index 000000000..bbb23a4f9
--- /dev/null
+++ b/federatedscope/tabular/model/__init__.py
@@ -0,0 +1,3 @@
+from federatedscope.tabular.model.quadratic import QuadraticModel
+
+__all__ = ['QuadraticModel']
diff --git a/federatedscope/tabular/model/quadratic.py b/federatedscope/tabular/model/quadratic.py
new file mode 100644
index 000000000..92c235dbd
--- /dev/null
+++ b/federatedscope/tabular/model/quadratic.py
@@ -0,0 +1,11 @@
+import torch
+
+
+class QuadraticModel(torch.nn.Module):
+ def __init__(self, in_channels, class_num):
+ super(QuadraticModel, self).__init__()
+ x = torch.ones((in_channels, 1))
+ self.x = torch.nn.parameter.Parameter(x.uniform_(-10.0, 10.0).float())
+
+ def forward(self, A):
+ return torch.sum(self.x * torch.matmul(A, self.x), -1)
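Note the output shape: with `x` of shape `(in_channels, 1)`, the forward pass returns the per-coordinate terms `x_i * (A x)_i` with shape `(batch, in_channels)`; their sum over coordinates is the quadratic form `x^T A x`. A toy check:

    import torch

    model = QuadraticModel(in_channels=3, class_num=1)
    A = torch.eye(3).unsqueeze(0)  # one sample: the 3x3 identity matrix
    out = model(A)                 # shape (1, 3): x_i * (A x)_i per coordinate
    quad_form = out.sum()          # x^T A x; for A = identity this is x.pow(2).sum()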
diff --git a/federatedscope/vertical_fl/Paillier/__init__.py b/federatedscope/vertical_fl/Paillier/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/federatedscope/vertical_fl/Paillier/abstract_paillier.py b/federatedscope/vertical_fl/Paillier/abstract_paillier.py
new file mode 100644
index 000000000..4e9c3f557
--- /dev/null
+++ b/federatedscope/vertical_fl/Paillier/abstract_paillier.py
@@ -0,0 +1,45 @@
+# You can refer to python-paillier for a detailed implementation
+# (https://github.com/data61/python-paillier/blob/master/phe/paillier.py),
+# or implement an efficient version of Paillier yourself.
+
+DEFAULT_KEYSIZE = 3072
+
+
+def generate_paillier_keypair(n_length=DEFAULT_KEYSIZE):
+ """Generate public key and private key used Paillier`.
+
+ Args:
+ n_length: key size in bits.
+
+ Returns:
+ tuple: The generated :class:`PaillierPublicKey` and
+ :class:`PaillierPrivateKey`
+ """
+ n = p = q = None
+ public_key = PaillierPublicKey(n)
+ private_key = PaillierPrivateKey(public_key, p, q)
+
+ return public_key, private_key
+
+
+class PaillierPublicKey(object):
+ """Contains a public key and associated encryption methods.
+ """
+ def __init__(self, n):
+ pass
+
+ def encrypt(self, value):
+ # We only provide an abstract implementation here
+
+ return value
+
+
+class PaillierPrivateKey(object):
+ """Contains a private key and associated decryption method.
+ """
+ def __init__(self, public_key, p, q):
+ pass
+
+ def decrypt(self, encrypted_number):
+ # We only provide an abstract implementation here
+
+ return encrypted_number
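The keypair above is deliberately a stub (`n`, `p`, and `q` are `None`, and `encrypt`/`decrypt` are identity functions). A working drop-in is the python-paillier library linked at the top of this file, whose API uses the same class and function names:

    # requires: pip install phe  (the python-paillier package)
    from phe import paillier

    public_key, private_key = paillier.generate_paillier_keypair(n_length=3072)
    enc = public_key.encrypt(3.5) + public_key.encrypt(1.5)  # additively homomorphic
    assert abs(private_key.decrypt(enc) - 5.0) < 1e-9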
diff --git a/federatedscope/vertical_fl/__init__.py b/federatedscope/vertical_fl/__init__.py
new file mode 100644
index 000000000..f9458d03a
--- /dev/null
+++ b/federatedscope/vertical_fl/__init__.py
@@ -0,0 +1 @@
+from federatedscope.vertical_fl.Paillier.abstract_paillier import *
diff --git a/federatedscope/vertical_fl/dataloader/__init__.py b/federatedscope/vertical_fl/dataloader/__init__.py
new file mode 100644
index 000000000..420b3bbd7
--- /dev/null
+++ b/federatedscope/vertical_fl/dataloader/__init__.py
@@ -0,0 +1,3 @@
+from federatedscope.vertical_fl.dataloader.dataloader import load_vertical_data
+
+__all__ = ['load_vertical_data']
diff --git a/federatedscope/vertical_fl/dataloader/dataloader.py b/federatedscope/vertical_fl/dataloader/dataloader.py
new file mode 100644
index 000000000..2378ce2b9
--- /dev/null
+++ b/federatedscope/vertical_fl/dataloader/dataloader.py
@@ -0,0 +1,56 @@
+import numpy as np
+
+
+def load_vertical_data(config=None, generate=False):
+ """
+ Generate synthetic data for vertical FL.
+
+ Arguments:
+ config: configuration
+ generate (bool): whether to generate the synthetic data
+ :returns: the synthetic data and the modified config
+ :rtype: dict
+ """
+
+ if generate:
+ # generate toy data for running a vertical FL example
+ INSTANCE_NUM = 1000
+ TRAIN_SPLIT = 0.9
+
+ total_dims = np.sum(config.vertical.dims)
+ theta = np.random.uniform(low=-1.0, high=1.0, size=(total_dims, 1))
+ x = np.random.choice([-1.0, 1.0, -2.0, 2.0, -3.0, 3.0],
+ size=(INSTANCE_NUM, total_dims))
+ y = np.asarray([
+ 1.0 if x >= 0 else -1.0
+ for x in np.reshape(np.matmul(x, theta), -1)
+ ])
+
+ train_num = int(TRAIN_SPLIT * INSTANCE_NUM)
+ test_data = {'theta': theta, 'x': x[train_num:], 'y': y[train_num:]}
+ data = dict()
+
+ # For Server #0
+ data[0] = dict()
+ data[0]['train'] = None
+ data[0]['val'] = None
+ data[0]['test'] = test_data
+
+ # For Client #1
+ data[1] = dict()
+ data[1]['train'] = {
+ 'x': x[:train_num, :config.vertical.dims[0]],
+ 'y': y[:train_num]
+ }
+ data[1]['val'] = None
+ data[1]['test'] = test_data
+
+ # For Client #2
+ data[2] = dict()
+ data[2]['train'] = {'x': x[:train_num, config.vertical.dims[0]:]}
+ data[2]['val'] = None
+ data[2]['test'] = test_data
+
+ return data, config
+ else:
+ raise ValueError('You must provide the data file')
diff --git a/federatedscope/vertical_fl/dataloader/utils.py b/federatedscope/vertical_fl/dataloader/utils.py
new file mode 100644
index 000000000..96788c665
--- /dev/null
+++ b/federatedscope/vertical_fl/dataloader/utils.py
@@ -0,0 +1,30 @@
+import numpy as np
+import math
+
+
+def batch_iter(data, batch_size, shuffled=True):
+ """
+ A batch iteration
+
+ Arguments:
+ data(dict): data
+ batch_size (int): the batch size
+ shuffled (bool): whether to shuffle the data at the start of each epoch
+ :returns: sample index, batch of x, batch of y
+ :rtype: int, ndarray, ndarray
+ """
+
+ assert 'x' in data and 'y' in data
+ data_x = data['x']
+ data_y = data['y']
+ data_size = len(data_y)
+ num_batches_per_epoch = math.ceil(data_size / batch_size)
+
+ while True:
+ shuffled_index = np.random.permutation(
+ np.arange(data_size)) if shuffled else np.arange(data_size)
+ for batch in range(num_batches_per_epoch):
+ start_index = batch * batch_size
+ end_index = min(data_size, (batch + 1) * batch_size)
+ sample_index = shuffled_index[start_index:end_index]
+ yield sample_index, data_x[sample_index], data_y[sample_index]
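A small usage example with made-up data; note that the generator loops forever and re-shuffles once per epoch, so the caller draws exactly as many batches as it needs:

    import numpy as np

    data = {'x': np.arange(10).reshape(5, 2), 'y': np.arange(5)}
    batches = batch_iter(data, batch_size=2, shuffled=False)
    index, x_batch, y_batch = next(batches)
    # first batch: index [0, 1], x_batch [[0, 1], [2, 3]], y_batch [0, 1]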
diff --git a/federatedscope/vertical_fl/worker/__init__.py b/federatedscope/vertical_fl/worker/__init__.py
new file mode 100644
index 000000000..3db6d2606
--- /dev/null
+++ b/federatedscope/vertical_fl/worker/__init__.py
@@ -0,0 +1,4 @@
+from federatedscope.vertical_fl.worker.vertical_client import vFLClient
+from federatedscope.vertical_fl.worker.vertical_server import vFLServer
+
+__all__ = ['vFLServer', 'vFLClient']
\ No newline at end of file
diff --git a/federatedscope/vertical_fl/worker/vertical_client.py b/federatedscope/vertical_fl/worker/vertical_client.py
new file mode 100644
index 000000000..c42e9d49f
--- /dev/null
+++ b/federatedscope/vertical_fl/worker/vertical_client.py
@@ -0,0 +1,107 @@
+import numpy as np
+import logging
+
+from federatedscope.core.worker import Client
+from federatedscope.core.message import Message
+from federatedscope.vertical_fl.dataloader.utils import batch_iter
+
+
+class vFLClient(Client):
+ """
+ The client class for vertical FL, which customizes the handled functions.
+ Please refer to the tutorial for more details about the algorithm.
+ The implementation of vertical FL follows `Private federated learning on vertically partitioned data via entity resolution and additively homomorphic encryption` [Hardy, et al., 2017]
+ (https://arxiv.org/abs/1711.10677)
+ """
+ def __init__(self,
+ ID=-1,
+ server_id=None,
+ state=-1,
+ config=None,
+ data=None,
+ model=None,
+ device='cpu',
+ strategy=None,
+ *args,
+ **kwargs):
+
+ super(vFLClient,
+ self).__init__(ID, server_id, state, config, data, model, device,
+ strategy, *args, **kwargs)
+ self.data = data
+ self.public_key = None
+ self.theta = None
+ self.batch_index = None
+ self.own_label = ('y' in self.data['train'])
+ self.dataloader = batch_iter(self.data['train'],
+ self._cfg.data.batch_size,
+ shuffled=True)
+
+ self.register_handlers('public_keys',
+ self.callback_funcs_for_public_keys)
+ self.register_handlers('model_para',
+ self.callback_funcs_for_model_para)
+ self.register_handlers('encryped_gradient_u',
+ self.callback_funcs_for_encryped_gradient_u)
+ self.register_handlers('encryped_gradient_v',
+ self.callback_funcs_for_encryped_gradient_v)
+
+ def sample_data(self, index=None):
+ if index is None:
+ assert self.own_label
+ return next(self.dataloader)
+ else:
+ return self.data['train']['x'][index]
+
+ def callback_funcs_for_public_keys(self, message: Message):
+ self.public_key = message.content
+
+ def callback_funcs_for_model_para(self, message: Message):
+ self.theta = message.content
+ if self.own_label:
+ index, input_x, input_y = self.sample_data()
+ self.batch_index = index
+
+ u_A = 0.25 * np.matmul(input_x, self.theta) - 0.5 * input_y
+ en_u_A = [self.public_key.encrypt(x) for x in u_A]
+
+ self.comm_manager.send(
+ Message(msg_type='encryped_gradient_u',
+ sender=self.ID,
+ receiver=[
+ each for each in self.comm_manager.neighbors
+ if each != self.server_id
+ ],
+ state=self.state,
+ content=(self.batch_index, en_u_A)))
+
+ def callback_funcs_for_encryped_gradient_u(self, message: Message):
+ index, en_u_A = message.content
+ self.batch_index = index
+ input_x = self.sample_data(index=self.batch_index)
+ u_B = 0.25 * np.matmul(input_x, self.theta)
+ en_u_B = [self.public_key.encrypt(x) for x in u_B]
+ en_u = np.expand_dims([sum(x) for x in zip(en_u_A, en_u_B)], -1)
+ en_v_B = en_u * input_x
+
+ self.comm_manager.send(
+ Message(msg_type='encryped_gradient_v',
+ sender=self.ID,
+ receiver=[
+ each for each in self.comm_manager.neighbors
+ if each != self.server_id
+ ],
+ state=self.state,
+ content=(en_u, en_v_B)))
+
+ def callback_funcs_for_encryped_gradient_v(self, message: Message):
+ en_u, en_v_B = message.content
+ input_x = self.sample_data(index=self.batch_index)
+ en_v_A = en_u * input_x
+ en_v = np.concatenate([en_v_A, en_v_B], axis=-1)
+
+ self.comm_manager.send(
+ Message(msg_type='encryped_gradient',
+ sender=self.ID,
+ receiver=[self.server_id],
+ state=self.state,
+ content=(len(input_x), en_v)))
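The constants `0.25` and `-0.5` above come from the Taylor approximation of the logistic loss used in Hardy et al., 2017: for labels `y` in `{-1, +1}`, the per-sample gradient is approximated by `(0.25 * x·theta - 0.5 * y) * x`, which the two clients assemble jointly under encryption. A plaintext sanity check of the same computation:

    import numpy as np

    x = np.array([[1.0, 2.0], [0.5, -1.0]])  # two samples, two features
    y = np.array([1.0, -1.0])
    theta = np.zeros(2)

    u = 0.25 * np.matmul(x, theta) - 0.5 * y  # the u_A / u_B terms above
    grad = np.mean(u[:, None] * x, axis=0)    # the average the server applies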
diff --git a/federatedscope/vertical_fl/worker/vertical_server.py b/federatedscope/vertical_fl/worker/vertical_server.py
new file mode 100644
index 000000000..9e6e1af34
--- /dev/null
+++ b/federatedscope/vertical_fl/worker/vertical_server.py
@@ -0,0 +1,121 @@
+import numpy as np
+import logging
+
+from federatedscope.core.monitors.monitor import update_best_result
+from federatedscope.core.worker import Server
+from federatedscope.core.message import Message
+from federatedscope.vertical_fl.Paillier import abstract_paillier
+
+logger = logging.getLogger(__name__)
+
+
+class vFLServer(Server):
+ """
+ The server class for vertical FL, which customizes the handled functions.
+ Please refer to the tutorial for more details about the algorithm.
+ The implementation of vertical FL follows `Private federated learning on vertically partitioned data via entity resolution and additively homomorphic encryption` [Hardy, et al., 2017]
+ (https://arxiv.org/abs/1711.10677)
+ """
+ def __init__(self,
+ ID=-1,
+ state=0,
+ config=None,
+ data=None,
+ model=None,
+ client_num=5,
+ total_round_num=10,
+ device='cpu',
+ strategy=None,
+ **kwargs):
+ super(vFLServer,
+ self).__init__(ID, state, config, data, model, client_num,
+ total_round_num, device, strategy, **kwargs)
+ self.public_key, self.private_key = abstract_paillier.generate_paillier_keypair(
+ n_length=config.vertical.key_size)
+ self.dims = [0] + config.vertical.dims
+ self.theta = self.model.state_dict()['fc.weight'].numpy().reshape(-1)
+ self.lr = config.optimizer.lr
+
+ self.register_handlers('encryped_gradient',
+ self.callback_funcs_for_encryped_gradient)
+
+ def trigger_for_start(self):
+ if self.check_client_join_in():
+ self.broadcast_public_keys()
+ self.broadcast_client_address()
+ self.broadcast_model_para()
+
+ def broadcast_public_keys(self):
+ self.comm_manager.send(
+ Message(msg_type='public_keys',
+ sender=self.ID,
+ receiver=list(self.comm_manager.get_neighbors().keys()),
+ state=self.state,
+ content=self.public_key))
+
+ def broadcast_model_para(self):
+
+ client_ids = self.comm_manager.neighbors.keys()
+ cur_idx = 0
+ for client_id in client_ids:
+ theta_slices = self.theta[cur_idx:cur_idx +
+ self.dims[int(client_id)]]
+ self.comm_manager.send(
+ Message(msg_type='model_para',
+ sender=self.ID,
+ receiver=client_id,
+ state=self.state,
+ content=theta_slices))
+ cur_idx += self.dims[int(client_id)]
+
+ def callback_funcs_for_encryped_gradient(self, message: Message):
+ sample_num, en_v = message.content
+
+ v = np.reshape(
+ [self.private_key.decrypt(x) for x in np.reshape(en_v, -1)],
+ [sample_num, -1])
+ avg_gradients = np.mean(v, axis=0)
+ self.theta = self.theta - self.lr * avg_gradients
+
+ self.state += 1
+ if self.state % self._cfg.eval.freq == 0 and self.state != self.total_round_num:
+ metrics = self.evaluate()
+ update_best_result(self.best_results,
+ metrics,
+ results_type='server_global_eval',
+ round_wise_update_key=self._cfg.eval.
+ best_res_update_round_wise_key)
+ formatted_logs = self._monitor.format_eval_res(
+ metrics,
+ rnd=self.state,
+ role='Global-Eval-Server #',
+ forms=self._cfg.eval.report)
+ logger.info(formatted_logs)
+
+ if self.state < self.total_round_num:
+ # Move to next round of training
+ logger.info(
+ '----------- Starting a new training round (Round #{:d}) -------------'
+ .format(self.state))
+ self.broadcast_model_para()
+ else:
+ metrics = self.evaluate()
+ update_best_result(self.best_results,
+ metrics,
+ results_type='server_global_eval',
+ round_wise_update_key=self._cfg.eval.
+ best_res_update_round_wise_key)
+ formatted_logs = self._monitor.format_eval_res(
+ metrics,
+ rnd=self.state,
+ role='Server #',
+ forms=self._cfg.eval.report)
+ logger.info(formatted_logs)
+
+ def evaluate(self):
+ test_x = self.data['test']['x']
+ test_y = self.data['test']['y']
+ loss = np.mean(
+ np.log(1 + np.exp(-test_y * np.matmul(test_x, self.theta))))
+ acc = np.mean((test_y * np.matmul(test_x, self.theta)) > 0)
+
+ return {'test_loss': loss, 'test_acc': acc, 'test_total': len(test_y)}
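For completeness, the plaintext counterpart of the decrypt-and-update step above (numbers made up): each row of `v` is a per-sample gradient contribution, so the parameter update is plain averaged gradient descent.

    import numpy as np

    v = np.array([[0.1, -0.2], [0.3, 0.0]])  # decrypted per-sample gradients
    theta, lr = np.zeros(2), 0.5
    theta = theta - lr * np.mean(v, axis=0)   # -> array([-0.1,  0.05])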
diff --git a/materials/paper_list/FL-Attacker/README.md b/materials/paper_list/FL-Attacker/README.md
new file mode 100644
index 000000000..17862602b
--- /dev/null
+++ b/materials/paper_list/FL-Attacker/README.md
@@ -0,0 +1 @@
+The paper list goes here.
\ No newline at end of file
diff --git a/materials/paper_list/FL-Database/README.md b/materials/paper_list/FL-Database/README.md
new file mode 100644
index 000000000..b63d6bedb
--- /dev/null
+++ b/materials/paper_list/FL-Database/README.md
@@ -0,0 +1,25 @@
+## Federated Database
+
+### 2022
+| Title | Venue | Link |
+| --- | --- | --- |
+| Conjunctive Queries with Comparisons | SIGMOD | [pdf](https://www.cse.ust.hk/~yike/CQC.pdf) |
+| R2T: Instance-Optimal Truncation for Differentially Private Query Evaluation with Foreign Keys | SIGMOD | [pdf](https://www.cse.ust.hk/~yike/R2T.pdf) |
+
+### 2021
+| Title | Venue | Link |
+| --- | --- | --- |
+| Approximate Range Counting Under Differential Privacy | SoCG | [pdf](https://drops.dagstuhl.de/opus/volltexte/2021/13844/pdf/LIPIcs-SoCG-2021-45.pdf) |
+| Weighted Distinct Sampling: Cardinality Estimation for SPJ Queries | SIGMOD | [pdf](https://dl.acm.org/doi/abs/10.1145/3448016.3452821) |
+| Residual Sensitivity for Differentially Private Multi-way Joins | SIGMOD | [pdf](https://dl.acm.org/doi/abs/10.1145/3448016.3452813) |
+
+### 2020
+| Title | Venue | Link |
+| --- | --- | --- |
+| Collecting and Analyzing Data Jointly from Multiple Services under Local Differential Privacy | VLDB | [pdf](https://dl.acm.org/doi/abs/10.14778/3407790.3407859) |
+| Improving Utility and Security of the Shuffler-based Differential Privacy | VLDB | [pdf](https://arxiv.org/abs/1908.11515) |
+
+### 2019
+| Title | Venue | Link |
+| --- | --- | --- |
+| Answering Multi-dimensional Analytical Queries under Local Differential Privacy | SIGMOD | [pdf](https://dl.acm.org/doi/abs/10.1145/3299869.3319891) |
diff --git a/materials/paper_list/FL-Incentive/README.md b/materials/paper_list/FL-Incentive/README.md
new file mode 100644
index 000000000..17862602b
--- /dev/null
+++ b/materials/paper_list/FL-Incentive/README.md
@@ -0,0 +1 @@
+The paper list goes here.
\ No newline at end of file
diff --git a/materials/paper_list/FL-NLP/README.md b/materials/paper_list/FL-NLP/README.md
new file mode 100644
index 000000000..afd812bd1
--- /dev/null
+++ b/materials/paper_list/FL-NLP/README.md
@@ -0,0 +1,47 @@
+## Federated Learning for NLP
+This list is constantly being updated. Feel free to contribute!
+
+
+### 2022
+| Title | Venue | Link |
+| --- | --- | --- |
+| FedBERT: When Federated Learning Meets Pre-Training | TIST | [pdf](https://dl.acm.org/doi/pdf/10.1145/3510033) |
+| FedKC: Federated Knowledge Composition for Multilingual Natural Language Understanding | WWW | [pdf](https://dl.acm.org/doi/pdf/10.1145/3485447.3511988) |
+
+
+### 2021
+| Title | Venue | Link |
+| --- | --- | --- |
+| FedMatch: Federated Learning Over Heterogeneous Question Answering Data | CIKM | [pdf](https://dl.acm.org/doi/pdf/10.1145/3459637.3482345), [code](https://github.com/Chriskuei/FedMatch) |
+| Federated Chinese Word Segmentation with Global Character Associations | ACL | [pdf](https://aclanthology.org/2021.findings-acl.376.pdf), [code](https://github.com/cuhksz-nlp/GCASeg) |
+| Improving Federated Learning for Aspect-based Sentiment Analysis via Topic Memories | EMNLP | [pdf](https://aclanthology.org/2021.emnlp-main.321.pdf), [code](https://github.com/cuhksz-nlp/ASA-TM) |
+| A Secure and Efficient Federated Learning Framework for NLP | EMNLP | [pdf](https://aclanthology.org/2021.emnlp-main.606.pdf) |
+| Distantly Supervised Relation Extraction in Federated Settings | EMNLP | [pdf](https://aclanthology.org/2021.findings-emnlp.52.pdf), [code](https://github.com/DianboWork/FedDS) |
+| FedNLP: Benchmarking Federated Learning Methods for Natural Language Processing Tasks | arXiv | [pdf](https://arxiv.org/pdf/2104.08815.pdf), [code](https://github.com/FedML-AI/FedNLP) |
+| Scaling Federated Learning for Fine-tuning of Large Language Models | arXiv | [pdf](https://arxiv.org/pdf/2102.00875.pdf) |
+
+
+### 2020
+| Title | Venue | Link |
+| --- | --- | --- |
+| Empirical Studies of Institutional Federated Learning For Natural Language Processing | EMNLP | [pdf](https://aclanthology.org/2020.findings-emnlp.55.pdf) |
+| FedED: Federated Learning via Ensemble Distillation for Medical Relation Extraction | EMNLP | [pdf](https://aclanthology.org/2020.emnlp-main.165.pdf) |
+| FedNER: Privacy-preserving Medical Named Entity Recognition with Federated Learning | arXiv | [pdf](https://arxiv.org/pdf/2003.09288.pdf) |
+| Federated Pretraining and Fine Tuning of BERT Using Clinical Notes from Multiple Silos | arXiv | [pdf](https://arxiv.org/pdf/2002.08562.pdf) |
+| Pretraining Federated Text Models for Next Word Prediction | arXiv | [pdf](https://arxiv.org/pdf/2005.04828.pdf), [code](https://github.com/federated-learning-experiments/fl-text-models) |
+
+
+### 2019
+| Title | Venue | Link |
+| --- | --- | --- |
+| Federated Learning of N-gram Language Models | CoNLL | [pdf](https://arxiv.org/pdf/1910.03432.pdf) |
+| Learning Private Neural Language Modeling with Attentive Aggregation | IJCNN | [pdf](https://arxiv.org/pdf/1812.07108.pdf) |
+| Federated Learning Of Out-Of-Vocabulary Words | arXiv | [pdf](https://arxiv.org/pdf/1903.10635.pdf) |
+| Federated Learning for Emoji Prediction in a Mobile Keyboard | arXiv | [pdf](https://arxiv.org/pdf/1906.04329.pdf) |
+
+
+### 2018
+| Title | Venue | Link |
+| --- | --- | --- |
+| Federated Learning for Mobile Keyboard Prediction | arXiv | [pdf](https://arxiv.org/pdf/1811.03604.pdf) |
+| Applied Federated Learning: Improving Google Keyboard Query Suggestions | arXiv | [pdf](https://arxiv.org/pdf/1812.02903.pdf) |
diff --git a/materials/paper_list/FL-Recommendation/README.md b/materials/paper_list/FL-Recommendation/README.md
new file mode 100644
index 000000000..dea3b1744
--- /dev/null
+++ b/materials/paper_list/FL-Recommendation/README.md
@@ -0,0 +1,31 @@
+## Federated Learning for Recommendation
+
+### 2022
+Coming soon!
+
+### 2021
+| Title | Venue | Link |
+| --- | --- | --- |
+| Efficient-FedRec: Efficient Federated Learning Framework for Privacy-Preserving News Recommendation | EMNLP | [pdf](https://aclanthology.org/2021.emnlp-main.223), [code](https://github.com/yjw1029/Efficient-FedRec) |
+| POI Recommendation with Federated Learning and Privacy Preserving in Cross Domain Recommendation | INFOCOM workshop | [pdf](https://ieeexplore.ieee.org/document/9484510) |
+| Fast-adapting and Privacy-preserving Federated Recommender System | VLDB Journal | [pdf](https://arxiv.org/abs/2104.00919) |
+| Stronger Privacy for Federated Collaborative Filtering with Implicit Feedback | RecSys | [pdf](https://arxiv.org/abs/2105.03941) |
+| A Payload Optimization Method for Federated Recommender Systems | RecSys | [pdf](https://arxiv.org/abs/2107.13078) |
+| Practical and Secure Federated Recommendation with Personalized Masks | arXiv | [pdf](https://arxiv.org/abs/2109.02464) |
+| FedGNN: Federated Graph Neural Network for Privacy-Preserving Recommendation | arXiv | [pdf](https://arxiv.org/abs/2102.04925) |
+| Federated Neural Collaborative Filtering | arXiv | [pdf](https://arxiv.org/abs/2106.04405) |
+
+### 2020
+| Title | Venue | Link |
+| --- | --- | --- |
+| Secure Federated Matrix Factorization | IEEE Intelligent Systems | [pdf](https://ieeexplore.ieee.org/abstract/document/9162459), [code](https://github.com/Di-Chai/FedMF) |
+| FedRec: Federated Recommendation With Explicit Feedback | IEEE Intelligent Systems | [pdf](https://ieeexplore.ieee.org/abstract/document/9170754), [code](https://csse.szu.edu.cn/staff/panwk/publications/FedRec/) |
+| Privacy-Preserving News Recommendation Model Learning | EMNLP | [pdf](https://aclanthology.org/2020.findings-emnlp.128/), [code](https://github.com/taoqi98/FedNewsRec) |
+| Federated Recommendation System via Differential Privacy | IEEE ISIT | [pdf](https://arxiv.org/abs/2005.06670) |
+| Meta Matrix Factorization for Federated Rating Predictions | SIGIR | [pdf](https://arxiv.org/abs/1910.10086), [code](https://github.com/TempSDU/MetaMF) |
+| FedFast: Going Beyond Average for Faster Training of Federated Recommender Systems | KDD | [pdf](https://dl.acm.org/doi/10.1145/3394486.3403176) |
+
+### 2019
+| Title | Venue | Link |
+| --- | --- | --- |
+| Federated Collaborative Filtering for Privacy-Preserving Personalized Recommendation System | arXiv | [pdf](https://arxiv.org/abs/1901.09888) |
diff --git a/materials/paper_list/Federated_Graph_Learning/README.md b/materials/paper_list/Federated_Graph_Learning/README.md
new file mode 100644
index 000000000..0a806e2e8
--- /dev/null
+++ b/materials/paper_list/Federated_Graph_Learning/README.md
@@ -0,0 +1,24 @@
+## Federated Learning for Graph
+
+### 2021
+
+| Title | Venue | Link |
+| ------------------------------------------------------------ | ---------- | ------------------------------------------------------------ |
+| Federated Graph Classification over Non-IID Graphs | NeurIPS | [pdf](https://proceedings.neurips.cc//paper/2021/file/9c6947bd95ae487c81d4e19d3ed8cd6f-Paper.pdf), [code](https://github.com/Oxfordblue7/GCFL) |
+| Subgraph Federated Learning with Missing Neighbor Generation | NeurIPS | [pdf](http://proceedings.neurips.cc/paper/2021/file/34adeb8e3242824038aa65460a47c29e-Paper.pdf), [code](https://github.com/zkhku/fedsage) |
+| Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling | KDD | [pdf](https://arxiv.org/pdf/2106.05223v1.pdf), [code](https://github.com/mengcz13/KDD2021_CNFGNN) |
+| Glint: Decentralized Federated Graph Learning with Traffic Throttling and Flow Scheduling | IEEE IWQoS | [pdf](https://ieeexplore.ieee.org/abstract/document/9521331) |
+| Differentially Private Federated Knowledge Graphs Embedding | CIKM | [pdf](https://arxiv.org/pdf/2105.07615v2.pdf), [code](https://github.com/HKUST-KnowComp/FKGE) |
+
+### 2020
+
+| Title | Venue | Link |
+| ----------------------------------------------------- | ----- | ------------------------------------------------------------ |
+| FedE: Embedding Knowledge Graphs in Federated Setting | IJCKG | [pdf](https://dl.acm.org/doi/fullHtml/10.1145/3502223.3502233), [code](https://github.com/zjukg/FedE) |
+
+### 2019
+
+| Title | Venue | Link |
+| ------------------------------------------------------------ | ----- | --------------------------------------------- |
+| Towards Federated Graph Learning for Collaborative Financial Crimes Detection | arXiv | [pdf](https://arxiv.org/pdf/1909.12946v2.pdf) |
+
diff --git a/materials/paper_list/Federated_HPO/README.md b/materials/paper_list/Federated_HPO/README.md
new file mode 100644
index 000000000..caf077502
--- /dev/null
+++ b/materials/paper_list/Federated_HPO/README.md
@@ -0,0 +1,25 @@
+## Federated Hyperparameter Optimization
+
+### 2022
+Coming soon!
+
+### 2021
+
+| Title | Venue | Link |
+| --- | --- | --- |
+| Federated Hyperparameter Tuning: Challenges, Baselines, and Connections to Weight-Sharing | NeurIPS | [pdf](https://openreview.net/forum?id=p99rWde9fVJ), [code](https://github.com/mkhodak/FedEx) |
+| FLoRA: Single-shot Hyper-parameter Optimization for Federated Learning | NeurIPS workshop | [pdf](https://neurips2021workshopfl.github.io/NFFL-2021/papers/2021/Zhou2021.pdf) |
+| FedTune: Automatic Tuning of Federated Learning Hyper-Parameters from System Perspective | arXiv | [pdf](https://arxiv.org/pdf/2110.03061.pdf) |
+
+### 2020
+
+| Title | Venue | Link |
+| --- | --- | --- |
+| Federated Bayesian Optimization via Thompson Sampling | NeurIPS | [pdf](https://proceedings.neurips.cc/paper/2020/file/6dfe08eda761bd321f8a9b239f6f4ec3-Paper.pdf) |
+
+### 2019
+
+| Title | Venue | Link |
+| --- | --- | --- |
+| Learning Rate Adaptation for Differentially Private Learning | AISTATS | [pdf](http://proceedings.mlr.press/v108/koskela20a.html) |
+| Robust Federated Learning Through Representation Matching and Adaptive Hyper-parameters | arXiv | [pdf](https://arxiv.org/pdf/1912.13075.pdf) |
diff --git a/materials/paper_list/Personalized_FL/README.md b/materials/paper_list/Personalized_FL/README.md
new file mode 100644
index 000000000..08e18b345
--- /dev/null
+++ b/materials/paper_list/Personalized_FL/README.md
@@ -0,0 +1,38 @@
+## Personalized Federated Learning
+This list is constantly being updated. Feel free to contribute!
+
+### 2022
+| Title | Venue | Link |
+| --- | --- | --- |
+| Towards Personalized Federated Learning | IEEE Transactions on Neural Networks and Learning Systems | [pdf](https://arxiv.org/pdf/2103.00710) |
+
+### 2021
+| Title | Venue | Link |
+| --- | --- | --- |
+| FedBN: Federated Learning on Non-IID Features via Local Batch Normalization | ICLR | [pdf](https://arxiv.org/pdf/2102.07623), [code](https://github.com/med-air/FedBN) |
+| Ditto: Fair and robust federated learning through personalization | ICML | [pdf](https://arxiv.org/pdf/2012.04221), [code](https://github.com/litian96/ditto) |
+| Parameterized Knowledge Transfer for Personalized Federated Learning | NeurIPS | [pdf](https://arxiv.org/pdf/2111.02862) |
+| Personalized Federated Learning with Gaussian Processes | NeurIPS | [pdf](https://arxiv.org/pdf/2106.15482), [code](https://github.com/IdanAchituve/pFedGP) |
+| Federated multi-task learning under a mixture of distributions | NeurIPS | [pdf](https://arxiv.org/pdf/2108.10252), [code](https://github.com/omarfoq/FedEM) |
+| Personalized Federated Learning using Hypernetworks | ICML | [pdf](https://arxiv.org/pdf/2103.04628), [code](https://github.com/AvivSham/pFedHN) |
+| Personalized Federated Learning with First Order Model Optimization | ICLR | [pdf](https://arxiv.org/pdf/2012.08565), [code](https://github.com/NVlabs/FedFomo) |
+| Exploiting Shared Representations for Personalized Federated Learning | ICML | [pdf](https://arxiv.org/pdf/2102.07078.pdf), [code](https://github.com/lgcollins/FedRep) |
+
+
+### 2020
+| Title | Venue | Link |
+| --- | --- | --- |
+| Personalized federated learning with theoretical guarantees: A model-agnostic meta-learning approach | NeurIPS | [pdf](https://proceedings.neurips.cc/paper/2020/file/24389bfe4fe2eba8bf9aa9203a44cdad-Paper.pdf) |
+| Personalized federated learning with Moreau envelopes | NeurIPS | [pdf](https://proceedings.neurips.cc/paper/2020/file/f4f1f13c8289ac1b1ee0ff176b56fc60-Paper.pdf), [code](https://github.com/CharlieDinh/pFedMe) |
+| An efficient framework for clustered federated learning | NeurIPS | [pdf](https://arxiv.org/pdf/2006.04088), [code](https://github.com/jichan3751/ifca) |
+| Adaptive personalized federated learning | arXiv | [pdf](https://arxiv.org/pdf/2003.13461), [code](https://github.com/MLOPTPSU/FedTorch) |
+| Lower bounds and optimal algorithms for personalized federated learning | NeurIPS | [pdf](https://arxiv.org/pdf/2010.02372)|
+| Personalized Federated Learning With Differential Privacy | IEEE Internet of Things Journal | [pdf](https://par.nsf.gov/servlets/purl/10183051)|
+| Personalized federated learning for intelligent IoT applications: A cloud-edge based framework | IEEE Open Journal of the Computer Society | [pdf](https://ieeexplore.ieee.org/iel7/8782664/8821528/09090366.pdf)|
+| Survey of Personalization Techniques for Federated Learning | 2020 Fourth World Conference on Smart Trends in Systems, Security and Sustainability (WorldS4) | [pdf](https://par.nsf.gov/servlets/purl/10183051)|
+
+
+### 2019
+| Title | Venue | Link |
+| --- | --- | --- |
+| Federated Evaluation of On-device Personalization| arXiv | [pdf](https://arxiv.org/abs/1910.10252) |
\ No newline at end of file
diff --git a/materials/tutorial/KDD_2022/README.md b/materials/tutorial/KDD_2022/README.md
new file mode 100644
index 000000000..04026270f
--- /dev/null
+++ b/materials/tutorial/KDD_2022/README.md
@@ -0,0 +1 @@
+## Materials for the KDD 2022 tutorial will go here.
\ No newline at end of file
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/Ditto_ASR.png b/paper_plot/results_all/ConvNet2@cifar-10/Ditto_ASR.png
new file mode 100644
index 000000000..e98261525
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/Ditto_ASR.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/Ditto_C_Acc.png b/paper_plot/results_all/ConvNet2@cifar-10/Ditto_C_Acc.png
new file mode 100644
index 000000000..ca10b5f49
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/Ditto_C_Acc.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FT_ASR.png b/paper_plot/results_all/ConvNet2@cifar-10/FT_ASR.png
new file mode 100644
index 000000000..03114301e
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FT_ASR.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FT_C_Acc.png b/paper_plot/results_all/ConvNet2@cifar-10/FT_C_Acc.png
new file mode 100644
index 000000000..34797cd66
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FT_C_Acc.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FedAvg_ASR.png b/paper_plot/results_all/ConvNet2@cifar-10/FedAvg_ASR.png
new file mode 100644
index 000000000..ca13b669a
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FedAvg_ASR.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FedAvg_C_Acc.png b/paper_plot/results_all/ConvNet2@cifar-10/FedAvg_C_Acc.png
new file mode 100644
index 000000000..25f0fc1a5
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FedAvg_C_Acc.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FedBN_ASR.png b/paper_plot/results_all/ConvNet2@cifar-10/FedBN_ASR.png
new file mode 100644
index 000000000..c9c024244
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FedBN_ASR.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FedBN_C_Acc.png b/paper_plot/results_all/ConvNet2@cifar-10/FedBN_C_Acc.png
new file mode 100644
index 000000000..8cd208c3d
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FedBN_C_Acc.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FedEM_ASR.png b/paper_plot/results_all/ConvNet2@cifar-10/FedEM_ASR.png
new file mode 100644
index 000000000..3144b3327
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FedEM_ASR.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FedEM_C_Acc.png b/paper_plot/results_all/ConvNet2@cifar-10/FedEM_C_Acc.png
new file mode 100644
index 000000000..d48b37a89
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FedEM_C_Acc.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FedRep_ASR.png b/paper_plot/results_all/ConvNet2@cifar-10/FedRep_ASR.png
new file mode 100644
index 000000000..59d57d2dd
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FedRep_ASR.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/FedRep_C_Acc.png b/paper_plot/results_all/ConvNet2@cifar-10/FedRep_C_Acc.png
new file mode 100644
index 000000000..8993d5d21
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/FedRep_C_Acc.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/pFedMe_ASR.png b/paper_plot/results_all/ConvNet2@cifar-10/pFedMe_ASR.png
new file mode 100644
index 000000000..d9379175f
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/pFedMe_ASR.png differ
diff --git a/paper_plot/results_all/ConvNet2@cifar-10/pFedMe_C_Acc.png b/paper_plot/results_all/ConvNet2@cifar-10/pFedMe_C_Acc.png
new file mode 100644
index 000000000..796a6c18b
Binary files /dev/null and b/paper_plot/results_all/ConvNet2@cifar-10/pFedMe_C_Acc.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/Ditto_ASR.png b/paper_plot/results_all/ResNet18@cifar-10/Ditto_ASR.png
new file mode 100644
index 000000000..66a0be16f
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/Ditto_ASR.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/Ditto_C_Acc.png b/paper_plot/results_all/ResNet18@cifar-10/Ditto_C_Acc.png
new file mode 100644
index 000000000..35b119da0
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/Ditto_C_Acc.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/FT_ASR.png b/paper_plot/results_all/ResNet18@cifar-10/FT_ASR.png
new file mode 100644
index 000000000..af5752ceb
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/FT_ASR.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/FT_C_Acc.png b/paper_plot/results_all/ResNet18@cifar-10/FT_C_Acc.png
new file mode 100644
index 000000000..71ff39278
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/FT_C_Acc.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/FedAvg_ASR.png b/paper_plot/results_all/ResNet18@cifar-10/FedAvg_ASR.png
new file mode 100644
index 000000000..fcdbf6718
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/FedAvg_ASR.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/FedAvg_C_Acc.png b/paper_plot/results_all/ResNet18@cifar-10/FedAvg_C_Acc.png
new file mode 100644
index 000000000..23204fe3f
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/FedAvg_C_Acc.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/FedBN_ASR.png b/paper_plot/results_all/ResNet18@cifar-10/FedBN_ASR.png
new file mode 100644
index 000000000..075b7d074
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/FedBN_ASR.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/FedBN_C_Acc.png b/paper_plot/results_all/ResNet18@cifar-10/FedBN_C_Acc.png
new file mode 100644
index 000000000..b7a5d1aa9
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/FedBN_C_Acc.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/FedRep_ASR.png b/paper_plot/results_all/ResNet18@cifar-10/FedRep_ASR.png
new file mode 100644
index 000000000..d219ff149
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/FedRep_ASR.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/FedRep_C_Acc.png b/paper_plot/results_all/ResNet18@cifar-10/FedRep_C_Acc.png
new file mode 100644
index 000000000..1ebdb71e3
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/FedRep_C_Acc.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/pFedMe_ASR.png b/paper_plot/results_all/ResNet18@cifar-10/pFedMe_ASR.png
new file mode 100644
index 000000000..8cb73f709
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/pFedMe_ASR.png differ
diff --git a/paper_plot/results_all/ResNet18@cifar-10/pFedMe_C_Acc.png b/paper_plot/results_all/ResNet18@cifar-10/pFedMe_C_Acc.png
new file mode 100644
index 000000000..8fae0cdc7
Binary files /dev/null and b/paper_plot/results_all/ResNet18@cifar-10/pFedMe_C_Acc.png differ
diff --git a/rar/acknow.txt b/rar/acknow.txt
new file mode 100644
index 000000000..b982eceec
--- /dev/null
+++ b/rar/acknow.txt
@@ -0,0 +1,92 @@
+ ACKNOWLEDGMENTS
+
+* We used "Screaming Fast Galois Field Arithmetic Using Intel
+ SIMD Instructions" paper by James S. Plank, Kevin M. Greenan
+ and Ethan L. Miller to improve Reed-Solomon coding performance.
+ Also we are grateful to Artem Drobanov and Bulat Ziganshin
+ for samples and ideas allowed to make Reed-Solomon coding
+ more efficient.
+
+* RAR text compression algorithm is based on Dmitry Shkarin PPMII
+ and Dmitry Subbotin carryless rangecoder public domain source code.
+ You may find it in ftp.elf.stuba.sk/pub/pc/pack.
+
+* RAR encryption includes parts of code from Szymon Stefanek
+ and Brian Gladman AES implementations also as Steve Reid SHA-1 source.
+
+ ---------------------------------------------------------------------------
+ Copyright (c) 2002, Dr Brian Gladman < >, Worcester, UK.
+ All rights reserved.
+
+ LICENSE TERMS
+
+ The free distribution and use of this software in both source and binary
+ form is allowed (with or without changes) provided that:
+
+ 1. distributions of this source code include the above copyright
+ notice, this list of conditions and the following disclaimer;
+
+ 2. distributions in binary form include the above copyright
+ notice, this list of conditions and the following disclaimer
+ in the documentation and/or other associated materials;
+
+ 3. the copyright holder's name is not used to endorse products
+ built using this software without specific written permission.
+
+ ALTERNATIVELY, provided that this notice is retained in full, this product
+ may be distributed under the terms of the GNU General Public License (GPL),
+ in which case the provisions of the GPL apply INSTEAD OF those given above.
+
+ DISCLAIMER
+
+ This software is provided 'as is' with no explicit or implied warranties
+ in respect of its properties, including, but not limited to, correctness
+ and/or fitness for purpose.
+ ---------------------------------------------------------------------------
+
+ Source code of this package also as other cryptographic technology
+ and computing project related links are available on Brian Gladman's
+ web site: http://www.gladman.me.uk
+
+* RAR uses CRC32 function based on Intel Slicing-by-8 algorithm.
+ Original Intel Slicing-by-8 code is available here:
+
+ http://sourceforge.net/projects/slicing-by-8/
+
+ Original Intel Slicing-by-8 code is licensed under BSD License
+ available at http://www.opensource.org/licenses/bsd-license.html
+
+ Copyright (c) 2004-2006 Intel Corporation.
+ All Rights Reserved
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+
+ Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer
+ in the documentation and/or other materials provided with
+ the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
+ OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
+ SUCH DAMAGE.
+
+* RAR archives may optionally include BLAKE2sp hash ( https://blake2.net ),
+ designed by Jean-Philippe Aumasson, Samuel Neves, Zooko Wilcox-O'Hearn
+ and Christian Winnerlein.
+
+* Useful hints provided by Alexander Khoroshev and Bulat Ziganshin allowed
+ to significantly improve RAR compression and speed.
diff --git a/rar/default.sfx b/rar/default.sfx
new file mode 100755
index 000000000..a2a215669
Binary files /dev/null and b/rar/default.sfx differ
diff --git a/rar/license.txt b/rar/license.txt
new file mode 100644
index 000000000..82af4af51
--- /dev/null
+++ b/rar/license.txt
@@ -0,0 +1,127 @@
+ END USER LICENSE AGREEMENT
+
+ The following agreement regarding RAR (and its Windows version - WinRAR)
+ archiver - referred to as "software" - is made between win.rar GmbH -
+ referred to as "licensor" - and anyone who is installing, accessing
+ or in any other way using the software - referred to as "user".
+
+ 1. The author and holder of the copyright of the software is
+ Alexander L. Roshal. The licensor and as such issuer of the license
+ and bearer of the worldwide exclusive usage rights including the rights
+ to reproduce, distribute and make the software available to the public
+ in any form is win.rar GmbH, Marienstr. 12, 10117 Berlin, Germany.
+
+ 2. The software is distributed as try before you buy. This means that
+ anyone may use the software during a test period of a maximum of 40 days
+ at no charge. Following this test period, the user must purchase
+ a license to continue using the software.
+
+ 3. The software's trial version may be freely distributed, with exceptions
+ noted below, provided the distribution package is not modified in any way.
+
+ a. Nobody may distribute separate parts of the package, with the exception
+ of the UnRAR components, without written permission.
+
+ b. The software's unlicensed trial version may not be distributed
+ inside of any other software package without written permission.
+ The software must remain in the original unmodified installation
+ file for download without any barrier and conditions to the user
+ such as collecting fees for the download or making the download
+ conditional on the user giving his contact data.
+
+ c. The unmodified installation file of WinRAR must be provided pure
+ and unpaired. Any bundling is interdicted. In particular the use
+ of any install or download software which is providing any kind
+ of download bundles is prohibited unless granted by win.rar GmbH
+ in written form.
+
+ d. Hacks/cracks, keys or key generators may not be included, pointed to
+ or referred to by the distributor of the trial version.
+
+ e. In case of violation of the precedent conditions the allowance
+ lapses immediately and automatically.
+
+ 4. The trial version of the software can display a registration reminder
+ dialog. Depending on the software version and configuration such dialog
+ can contain either a predefined text and links loaded locally
+ or a web page loaded from the internet. Such web page can contain
+ licensing instructions or other materials according to the licensor's
+ choice, including advertisement. When opening a web page, the software
+ transfers only those parameters which are technically required
+ by HTTP protocol to successfully open a web page in a browser.
+
+ 5. The software is distributed "as is". No warranty of any kind is expressed
+ or implied. You use at your own risk. Neither the author, the licensor
+ nor the agents of the licensor will be liable for data loss, damages,
+ loss of profits or any other kind of loss while using or misusing
+ this software.
+
+ 6. There are 2 basic types of licenses issued for the software. These are:
+
+ a. A single computer usage license. The user purchases one license to
+ use the software on one computer.
+
+ Home users may use their single computer usage license on all
+ computers and mobile devices (USB drive, external hard drive, etc.)
+ which are property of the license owner.
+
+ Business users require one license per computer or mobile device
+ on which the software is installed.
+
+ b. A multiple usage license. The user purchases a number of usage
+ licenses for use, by the purchaser or the purchaser's employees
+ on the same number of computers.
+
+ In a network (server/client) environment the user must purchase
+ a license copy for each separate client (workstation) on which
+ the software is installed, used or accessed. A separate license copy
+ for each client (workstation) is needed regardless of whether
+ the clients (workstations) will use the software simultaneously
+ or at different times. If for example you wish to have 9 different
+ clients (workstations) in your network with access to RAR,
+ you must purchase 9 license copies.
+
+ A user who purchased a license, is granted a non-exclusive right to use
+ the software on as many computers as defined by the licensing terms above
+ according to the number of licenses purchased, for any legal purpose.
+
+ 7. There are no additional license fees, apart from the cost of the license,
+ associated with the creation and distribution of RAR archives,
+ volumes, self-extracting archives or self-extracting volumes.
+ Owners of a license may use their copies of the software to produce
+ archives and self-extracting archives and to distribute those archives
+ free of any additional royalties.
+
+ 8. The licensed software may not be rented or leased but may be permanently
+ transferred, in its entirety, if the recipient agrees to the terms of
+ this license.
+
+ 9. To buy a license, please read the file order.htm provided with
+ the software for details.
+
+ 10. You may not use, copy, emulate, clone, rent, lease, sell, modify,
+ decompile, disassemble, otherwise reverse engineer, or transfer
+ the licensed software, or any subset of the licensed software,
+ except as provided for in this agreement. Any such unauthorized use
+ shall result in immediate and automatic termination of this license
+ and may result in criminal and/or civil prosecution.
+
+ Neither RAR binary code, WinRAR binary code, UnRAR source
+ or UnRAR binary code may be used or reverse engineered to re-create
+ the RAR compression algorithm, which is proprietary, without written
+ permission.
+
+ The software may be using components developed and/or copyrighted
+ by third parties. Please read "Acknowledgments" help file topic
+ for WinRAR or acknow.txt text file for other RAR versions for details.
+
+ 11. This License Agreement is construed solely and exclusively under
+ German law. If you are a merchant, the courts at the registered office
+ of win.rar GmbH in Berlin/Germany shall have exclusive jurisdiction
+ for any and all disputes arising in connection with this License
+ Agreement or its validity.
+
+ 12. Installing and using the software signifies acceptance of these terms
+ and conditions of the license. If you do not agree with the terms of this
+ license, you must remove all software files from your storage devices
+ and cease to use the software.
diff --git a/rar/makefile b/rar/makefile
new file mode 100644
index 000000000..d4798c8af
--- /dev/null
+++ b/rar/makefile
@@ -0,0 +1,13 @@
+###################################################################
+# Installing RAR executables, configuration files and SFX modules #
+# to appropriate directories #
+###################################################################
+
+PREFIX=/usr/local
+
+install:
+ mkdir -p $(PREFIX)/bin
+ mkdir -p $(PREFIX)/lib
+ cp rar unrar $(PREFIX)/bin
+ cp rarfiles.lst /etc
+ cp default.sfx $(PREFIX)/lib
\ No newline at end of file
diff --git a/rar/order.htm b/rar/order.htm
new file mode 100644
index 000000000..8bd93144c
--- /dev/null
+++ b/rar/order.htm
@@ -0,0 +1,85 @@
+How to buy WinRAR and RAR license.
+
+
+If you wish to use WinRAR and RAR after the evaluation period of 40 days,
+you need to purchase its license from one of the regional dealers
+listed here.
+You can also check the latest price list and buy on-line at
+www.rarlab.com.
+
+
+Upon receipt of your registration fee you will receive an email
+containing a registration key file corresponding to the user name
+string which you have chosen. Please specify a valid email address
+when buying the licence, as it will be used to send you the key file.
+
+
+The registration email will also contain all necessary instructions,
+so please just follow them. Below, we provide a brief explanation of
+the typical registration procedure, but instructions in the email are
+more up to date than this file and should thus have a higher precedence.
+
+
+If you use WinRAR, you will need to copy the registration key file
+(rarreg.key) to a WinRAR folder or to %APPDATA%\WinRAR folder.
+By default WinRAR folder is "C:\Program Files\WinRAR", but it can be
+changed by a user when installing WinRAR. You can also drag rarreg.key file
+and drop it to WinRAR window to register.
+
+
+If the key is archived in a .rar or .zip file, please extract
+rarreg.key from the archive before copying it. If archive name is
+rarkey.rar, another way to install the key file is to open such
+archive in WinRAR and answer "Yes" to confirmation prompt.
+
+
+If you use RAR/Unix and RAR for OS X, you should copy rarreg.key
+to your home directory or to one of the following directories:
+/etc, /usr/lib, /usr/local/lib, /usr/local/etc. You may rename it
+to .rarreg.key or .rarregkey, if you wish, but rarreg.key is also valid.
+
+
+WinRAR, RAR for Unix and OS X now use the same registration key
+format, so you can use the same key with current WinRAR and RAR versions
+for all mentioned platforms. It is not guaranteed for WinRAR and RAR
+versions that are not equal to version included to this distributive.
+For example, versions prior to 2.60 used different keys.
+
+
+Please send your further questions about sales and licensing
+to .
+English, French, German or Spanish please.
+
+
+
diff --git a/rar/rar b/rar/rar
new file mode 100755
index 000000000..106e15011
Binary files /dev/null and b/rar/rar differ
diff --git a/rar/rar.txt b/rar/rar.txt
new file mode 100644
index 000000000..bdc071f6a
--- /dev/null
+++ b/rar/rar.txt
@@ -0,0 +1,2464 @@
+ User's Manual
+ ~~~~~~~~~~~~~
+ RAR 5.61 console version
+ ~~~~~~~~~~~~~~~~~~~~~~~~
+
+ =-=-=-=-=-=-=-=-=-=-=-=-=-=-
+ Welcome to the RAR Archiver!
+ -=-=-=-=-=-=-=-=-=-=-=-=-=-=
+
+ Introduction
+ ~~~~~~~~~~~~
+
+ RAR is a console application for managing archive files
+ in command line mode. RAR provides compression, encryption,
+ data recovery and many other functions described in this manual.
+
+ RAR supports only RAR format archives, which have .rar file name
+ extension by default. ZIP and other formats are not supported.
+ Even if you specify .zip extension when creating an archive, it will
+ still be in RAR format. Windows users may install WinRAR, which supports
+ more archive types including RAR and ZIP formats.
+
+ WinRAR provides both graphical user interface and command line mode.
+ While console RAR and GUI WinRAR have similar command line syntax,
+ some differences exist. So it is recommended to use this rar.txt manual
+ for console RAR (rar.exe in case of Windows version) and winrar.chm
+ WinRAR help file for GUI WinRAR (winrar.exe).
+
+
+ Configuration file
+ ~~~~~~~~~~~~~~~~~~
+
+ RAR and UnRAR for Unix read configuration information from .rarrc file
+ in a user's home directory (stored in HOME environment variable)
+ or in /etc directory.
+
+ RAR and UnRAR for Windows read configuration information from rar.ini file,
+ placed in the same directory as the rar.exe file.
+
+ This file may contain the following string:
+
+ switches=<any RAR switches separated by spaces>
+
+ For example:
+
+ switches=-m5 -s
+
+ It is also possible to specify separate switch sets for individual
+ RAR commands using the following syntax:
+
+ switches_<command>=<any RAR switches separated by spaces>
+
+ For example:
+
+ switches_a=-m5 -s
+ switches_x=-o+
+
+
+
+ Environment variable
+ ~~~~~~~~~~~~~~~~~~~~
+
+ Default parameters may be added to the RAR command line by establishing
+ an environment variable "RAR".
+
+ For instance, in Unix the following lines may be added to your profile:
+
+ RAR='-s -md1024'
+ export RAR
+
+ RAR will use this string as default parameters in the command line and
+ will create "solid" archives with 1024 KB sliding dictionary size.
+
+ RAR handles options with the following priority:
+
+ command line switches highest priority
+ switches in the RAR variable lower priority
+ switches saved in configuration file lowest priority
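+
+ An illustrative sketch of this priority order (not part of the
+ original manual text): if your profile exports RAR='-m5' and you
+ then run
+
+ rar a -m0 archive files
+
+ the archive is created with -m0 (store), because command line
+ switches have the highest priority and override the RAR variable.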
+
+
+ Log file
+ ~~~~~~~~
+
+ If the switch -ilog is specified in the command line or configuration
+ file, RAR will write informational messages, concerning errors
+ encountered while processing archives, into a log file. Read switch
+ -ilog description for more details.
+
+
+ The file order list for solid archiving - rarfiles.lst
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ rarfiles.lst contains a user-defined file list, which tells RAR
+ the order in which to add files to a solid archive. It may contain
+ file names, wildcards and special entry - $default. The default
+ entry defines the place in order list for files not matched
+ with other entries in this file. The comment character is ';'.
+
+ In Windows this file should be placed in the same directory as RAR
+ or in %APPDATA%\WinRAR directory, in Unix - to the user's home directory
+ or in /etc.
+
+ Tips for improved compression and speed of operation:
+
+ - similar files should be grouped together in the archive;
+ - frequently accessed files should be placed at the beginning.
+
+ Normally masks placed nearer to the top of the list have a higher priority,
+ but there is an exception to this rule. If rarfiles.lst contains two
+ masks such that all files matched by one mask are also matched by the other,
+ the mask which matches the smaller subset of file names will have the higher
+ priority regardless of its position in the list. For example, if you have
+ *.cpp and f*.cpp masks, f*.cpp has a higher priority, so the position of
+ 'filename.cpp' will be chosen according to 'f*.cpp', not '*.cpp'.
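+
+ A hypothetical rarfiles.lst fragment illustrating these rules
+ (the entries are examples, not defaults):
+
+ ; frequently accessed files first
+ readme.*
+ *.ini
+ ; more specific source mask before the general one
+ f*.cpp
+ *.cpp
+ ; everything else
+ $default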
+
+
+ RAR command line syntax
+ ~~~~~~~~~~~~~~~~~~~~~~~
+
+ Syntax
+
+ RAR <command> [ -<switch 1> -<switch N> ] <archive> [ <files...> ]
+ [ <@listfiles...> ] [ <path_to_extract\> ]
+
+ Description
+
+ Command line options (commands and switches) provide control of
+ creating and managing archives with RAR. The command is a string (or a
+ single letter) which commands RAR to perform a corresponding action.
+ Switches are designed to modify the way RAR performs the action. Other
+ parameters are archive name and files to be archived into or extracted
+ from the archive.
+
+ Listfiles are plain text files that contain names of files to process.
+ File names should start at the first column. It is possible to
+ put comments in the listfile after // characters. For example,
+ you may create backup.lst containing the following strings:
+
+ c:\work\doc\*.txt //backup text documents
+ c:\work\image\*.bmp //backup pictures
+ c:\work\misc
+
+ and then run:
+
+ rar a backup @backup.lst
+
+ If you wish to read file names from stdin (standard input),
+ specify the empty listfile name (just @).
+
+ By default, console RAR uses the single byte encoding in list files,
+ but it can be redefined with -scl switch.
+
+ You may specify both usual file names and list files in the same
+ command line. If neither files nor listfiles are specified,
+ then *.* is implied and RAR will process all files.
+
+ Many RAR commands, such as extraction, test or list, allow the use of
+ wildcards in the archive name. If no extension is specified in archive
+ mask, RAR assumes .rar, so * means all archives with .rar extension.
+ If you need to process all archives without extension, use *. mask.
+ *.* mask selects all files. Wildcards in archive name are not allowed
+ when archiving and deleting.
+
+ In Unix you need to enclose RAR command line parameters containing
+ wildcards in single or double quotes to prevent their expansion
+ by Unix shell. For example, this command will extract *.asm files
+ from all *.rar archives in current directory:
+
+ rar e '*.rar' '*.asm'
+
+
+ Command could be any of the following:
+
+ a Add files to archive.
+
+ Examples:
+
+ 1) add all *.hlp files from the current directory to
+ the archive help.rar:
+
+ rar a help *.hlp
+
+ 2) archive all files from the current directory and subdirectories
+ to 362000 bytes size solid, self-extracting volumes
+ and add the recovery record to each volume:
+
+ rar a -r -v362 -s -sfx -rr save
+
+ Because no file names are specified, all files (*) are assumed.
+
+ 3) as a special exception, if a directory name is specified as
+ an argument and the directory name does not include file masks
+ or a trailing backslash, the entire contents of the directory
+ and all subdirectories will be added to the archive even
+ if switch -r is not specified.
+
+ The following command will add all files from the directory
+ Bitmaps and its subdirectories to the RAR archive Pictures.rar:
+
+ rar a Pictures.rar Bitmaps
+
+ 4) if directory name includes file masks or trailing backslashes,
+ normal rules apply and you need to specify switch -r to process
+ its subdirectories.
+
+ The following command will add all files from directory Bitmaps,
+ but not from its subdirectories, because switch -r is not
+ specified:
+
+ rar a Pictures.rar Bitmaps\*
+
+
+ c Add archive comment. Comments are displayed while the archive is
+ being processed. Comment length is limited to 256 KB.
+
+ Examples:
+
+ rar c distrib.rar
+
+ Comments may also be added from a file using the -z[file] switch.
+ The following command adds a comment from info.txt file:
+
+ rar c -zinfo.txt dummy
+
+
+ ch Change archive parameters.
+
+ This command can be used with most of archive modification
+ switches to modify archive parameters. It is especially
+ convenient for switches like -cl, -cu, -tl, which do not
+ have a dedicated command.
+
+ It is not able to recompress, encrypt or decrypt archive data
+ and it cannot merge or create volumes. If used without any
+ switches, 'ch' command just copies the archive data without
+ modification.
+
+ Example:
+
+ Set archive time to latest file:
+
+ rar ch -tl files.rar
+
+
+ cw Write archive comment to specified file.
+
+ Format of output file depends on -sc switch.
+
+ If output file name is not specified, comment data will be
+ sent to stdout.
+
+ Examples:
+
+ 1) rar cw arc comment.txt
+
+ 2) rar cw -scuc arc unicode.txt
+
+ 3) rar cw arc
+
+
+ d Delete files from archive. If this command removes all files
+ from archive, the empty archive is removed.
+
+
+ e Extract files without archived paths.
+
+ Extract files excluding their path component, so all files
+ are created in the same destination directory.
+
+ Use 'x' command if you wish to extract full pathnames.
+
+ Example:
+
+ rar e -or html.rar *.css css\
+
+ extract all *.css files from html.rar archive to 'css' folder
+ excluding archived paths. Rename extracted files automatically
+ in case several files have the same name.
+
+
+ f Freshen files in archive. Updates archived files older
+ than files to add. This command will not add new files
+ to the archive.
+
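+ An illustrative sketch (assumed names, not from the original manual):
+ replace already archived *.txt files in backup.rar with newer copies
+ from the current directory, without adding new files:
+
+ rar f backup.rar *.txt
+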
+
+ i[i|c|h|t]=<string>
+ Find string in archives.
+
+ Supports following optional parameters:
+
+ i - case insensitive search (default);
+
+ c - case sensitive search;
+
+ h - hexadecimal search;
+
+ t - use ANSI, Unicode and OEM character tables (Windows only);
+
+ If no parameters are specified, it is possible to use
+ the simplified command syntax i<string> instead of i=<string>
+
+ It is allowed to specify 't' modifier with other parameters,
+ for example, ict=string performs case sensitive search
+ using all mentioned above character tables.
+
+ Examples:
+
+ 1) rar "ic=first level" -r c:\*.rar *.txt
+
+ Perform case sensitive search of "first level" string
+ in *.txt files in *.rar archives on the disk c:
+
+ 2) rar ih=f0e0aeaeab2d83e3a9 -r e:\texts\*.rar
+
+ Search for hex string f0 e0 ae ae ab 2d 83 e3 a9
+ in rar archives in e:\texts directory.
+
+
+ k Lock archive. Any command which intends to change the archive
+ will be ignored.
+
+ Example:
+
+ rar k final.rar
+
+
+ l[t[a],b]
+ List archive contents [technical [all], bare].
+
+ 'l' command lists archived file attributes, size, date,
+ time and name, one file per line. If file is encrypted,
+ line starts from '*' character.
+
+ 'lt' displays the detailed file information in multiline mode.
+ This information includes file checksum value, host OS,
+ compression options and other parameters.
+
+ 'lta' provides the detailed information not only for files,
+ but also for service headers like NTFS streams
+ or file security data.
+
+ 'lb' lists bare file names with path, one per line,
+ without any additional information.
+
+ You can use -v switch to list contents of all volumes
+ in volume set: rar l -v vol.part1.rar
+
+ Commands 'lt', 'lta' and 'lb' are equal to 'vt', 'vta'
+ and 'vb' respectively.
+
+
+ m[f] Move to archive [files only]. Moving files and directories
+ results in the files and directories being erased upon
+ successful completion of the packing operation. Directories will
+ not be removed if 'f' modifier is used and/or '-ed' switch is
+ applied.
+
+
+ p Print file to stdout.
+
+ You may use this command together with -inul switch to disable
+ all RAR messages and print only file data. It may be important
+ when you need to send a file to stdout for use in pipes.
+
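+ A minimal illustrative example (assumed file names, not from the
+ original manual), printing an archived file to stdout and piping it
+ to another program:
+
+ rar p -inul docs.rar readme.txt | less
+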
+
+ r Repair archive. Archive repairing is performed in two stages.
+ First, the damaged archive is searched for a recovery record
+ (see 'rr' command). If archive contains the previously added
+ recovery record and if damaged data area is continuous
+ and smaller than error correction code size in recovery record,
+ chance of successful archive reconstruction is high.
+ When this stage has been completed, a new archive is created,
+ named as fixed.arcname.rar, where 'arcname' is the original
+ (damaged) archive name.
+
+ If broken archive does not contain a recovery record or if
+ archive is not completely recovered due to major damage,
+ second stage is performed. During this stage only the archive
+ structure is reconstructed and it is impossible to recover
+ files which fail checksum validation. It is still possible,
+ however, to recover undamaged files, which were inaccessible
+ due to the broken archive structure. Mostly this is useful
+ for non-solid archives. This stage is never efficient
+ for archives with encrypted file headers, which can be repaired
+ only if recovery record is present.
+
+ When the second stage is completed, the reconstructed archive
+ is saved as rebuilt.arcname.rar, where 'arcname' is
+ the original archive name.
+
+ By default, repaired archives are created in the current
+ directory, but you can append an optional destpath\ parameter
+ to specify another destination directory.
+
+ Example:
+
+ rar r buggy.rar c:\fixed\
+
+ repair buggy.rar and place the result to 'c:\fixed' directory.
+
+
+ rc Reconstruct missing and damaged volumes using recovery volumes
+ (.rev files). You need to specify any existing volume
+ as the archive name, for example, 'rar rc backup.part03.rar'
+
+ Read 'rv' command description for information about
+ recovery volumes.
+
+
+ rn Rename archived files.
+
+ The command syntax is:
+
+ rar rn <arcname> <srcname1> <destname1> ... <srcnameN> <destnameN>
+
+ For example, the following command:
+
+ rar rn data.rar readme.txt readme.bak info.txt info.bak
+
+ will rename readme.txt to readme.bak and info.txt to info.bak
+ in the archive data.rar.
+
+ It is allowed to use wildcards in the source and destination
+ names for simple name transformations like changing file
+ extensions. For example:
+
+ rar rn data.rar *.txt *.bak
+
+ will rename all *.txt files to *.bak.
+
+ RAR does not check if the destination file name is already
+ present in the archive, so you need to be careful to avoid
+ duplicated names. It is especially important when using
+ wildcards. Such a command is potentially dangerous, because
+ a wrong wildcard may corrupt all archived names.
+
+
+ rr[N] Add data recovery record. Optionally, redundant information
+ (recovery record) may be added to archive. While it increases
+ the archive size, it helps to recover archived files in case of
+ disk failure or data loss of other kind, provided that damage
+ is not too severe. Such damage recovery can be done with
+ command "r" (repair).
+
+ RAR 4.x and RAR 5.0 archives use different recovery record
+ structure and algorithms.
+
+ RAR 4.x recovery record is based on XOR algorithm.
+ You can specify 4.x record size as a number of recovery sectors
+ or as a percent of archive size. To specify a number of sectors
+ just add it directly after 'rr', like 'rr1000' for 1000 sectors.
+ To use a percent append 'p' or '%' modifier after the percent
+ number, such as 'rr5p' or 'rr5%' for 5%. Note that in Windows
+ .bat and .cmd files it is necessary to use 'rr5%%' instead of
+ 'rr5%', because the command processor treats the single '%'
+ as the start of a batch file parameter, so it might be more
+ convenient to use 'p' instead of '%' in this case.
+
+ RAR 4.x recovery sectors are 512 bytes long. If damaged area
+ is continuous, every sector helps to recover 512 bytes of
+ damaged information. This value may be lower in cases of
+ multiple damage. Maximum number of recovery sectors is 524288.
+
+ Size of 4.x recovery record may be approximately determined as
+ <archive size>/256 + <number of recovery sectors>*512 bytes.
+
+ RAR 5.0 recovery record uses Reed-Solomon error correction codes.
+ Its ability to repair continuous damage is about the same
+ as for RAR 4.x, allowing to restore slightly less data than
+ recovery record size. But it is significantly more efficient
+ than RAR 4.x record in case of multiple damaged areas.
+
+ RAR 5.0 record does not use 512 byte sectors and you can specify
+ its size only as a percent of archive size. Even if '%' or 'p'
+ modifier is not present, RAR treats the value as a percent
+ in case of RAR 5.0 format, so both 'rr5' and 'rr5p' mean 5%.
+ Due to service data overhead the actual resulting recovery record
+ size only approximately matches the user defined percent
+ and difference is larger for smaller archives.
+
+ RAR 5.0 recovery record size cannot exceed the protecting
+ archive size, so you cannot use more than 100% as a parameter.
+ Larger recovery records are processed slower both when creating
+ and repairing.
+
+ RAR 5.0 recovery record is more resistant to damage of recovery
+ record itself and can utilize a partially corrupt recovery
+ record data. Note, though, that 'R' repair command does not fix
+ broken blocks in recovery record. Only file data are corrected.
+ After successful archive repair, you may need to create a new
+ recovery record for saved files.
+
+ Both 4.x and 5.0 records are most efficient if data positions
+ in damaged archive are not shifted. If you copy an archive
+ from damaged media using some special software and if you have
+ a choice to fill damaged areas with zeroes or to cut them out
+ of the copied file, filling with zeroes or any other value is
+ preferable, because it preserves original data positions.
+ Still, even though it is not an optimal mode, both versions
+ of records attempt to repair data even in case of deletions
+ or insertions of reasonable size, when data positions were
+ shifted. RAR 5.0 recovery record handles deletions and insertions
+ more efficiently than RAR 4.x.
+
+ If you use the plain 'rr' command without optional parameter,
+ RAR will set the recovery record size to 3% of archive size
+ by default.
+
+ Example:
+
+ rar rr5p arcname
+
+ add the recovery record of 5% of archive size.
+
+
+ rv[N] Create recovery volumes (.rev files), which can be later
+ used to reconstruct missing and damaged files in a volume
+ set. This command makes sense only for multivolume archives
+ and you need to specify the name of the first volume
+ in the set as the archive name. For example:
+
+ rar rv3 data.part01.rar
+
+ This feature may be useful for backups or, for example,
+ when you posted a multivolume archive to a newsgroup
+ and a part of subscribers did not receive some of the files.
+ Reposting recovery volumes instead of usual volumes
+ may reduce the total number of files to repost.
+
+ Each recovery volume is able to reconstruct one missing
+ or damaged RAR volume. For example, if you have 30 volumes
+ and 3 recovery volumes, you are able to reconstruct any
+ 3 missing volumes. If the number of .rev files is less than
+ the number of missing volumes, reconstructing is impossible.
+ The total number of usual and recovery volumes must not
+ exceed 255 for RAR 4.x and 65535 for RAR 5.0 archive format.
+
+ Original RAR volumes must not be modified after creating
+ recovery volumes. Recovery algorithm uses data stored both
+ in REV files and in RAR volumes to rebuild missing RAR volumes.
+ So if you modify RAR volumes, for example, lock them, after
+ creating REV files, recovery process will fail.
+
+ Additionally to recovery data, RAR 5.0 recovery volumes
+ also store service information such as checksums of protected
+ RAR files. So they are slightly larger than RAR volumes
+ which they protect. If you plan to copy individual RAR and REV
+ files to some removable media, you need to take it into account
+ and specify RAR volume size by a few kilobytes smaller
+ than media size.
+
+ The optional parameter specifies a number of recovery
+ volumes to create and must be less than the total number
+ of RAR volumes in the set. You may also append a percent
+ or 'p' character to this parameter, in which case the number of
+ created .rev files will be equal to this percent of
+ the total number of RAR volumes. For example:
+
+ rar rv15% data.part01.rar
+
+ RAR reconstructs missing and damaged volumes either when
+ using 'rc' command or automatically, if it cannot locate
+ the next volume and finds the required number of .rev files
+ when unpacking.
+
+ Original copies of damaged volumes are renamed to *.bad
+ before reconstruction. For example, volname.part03.rar
+ will be renamed to volname.part03.rar.bad.
+
+
+ s[name] Convert archive to SFX. The archive is merged with a SFX module
+ (using a module in file default.sfx or specified in the switch).
+ In the Windows version default.sfx should be placed in the
+ same directory as the rar.exe, in Unix - in the user's
+ home directory, in /usr/lib or /usr/local/lib.
+
+ s- Remove SFX module from the already existing SFX archive.
+ RAR creates a new archive without SFX module, the original
+ SFX archive is not deleted.
+
+ t Test archive files. This command performs a dummy file
+ extraction, writing nothing to the output stream, in order to
+ validate the specified file(s).
+
+ Examples:
+
+ Test archives in current directory:
+
+ rar t *
+
+ or for Unix:
+
+ rar t '*'
+
+ User may test archives in all sub-directories, starting
+ with the current path:
+
+ rar t -r *
+
+ or for Unix:
+
+ rar t -r '*'
+
+
+ u Update files in archive. Adds files not yet in the archive
+ and updates archived files that are older than files to add.
+
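+ An illustrative sketch (assumed names, not from the original manual),
+ updating backup.rar with new and newer files from the docs directory:
+
+ rar u -r backup.rar docs
+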
+
+ v[t[a],b]
+ Verbosely list archive contents [technical [all], bare].
+
+ 'v' command lists archived file attributes, size, packed size,
+ compression ratio, date, time, checksum and name, one file
+ per line. If file is encrypted, line starts from '*' character.
+ For BLAKE2sp checksum only two first and one last symbol are
+ displayed.
+
+ 'vt' displays the detailed file information in multiline mode.
+ This information includes file checksum value, host OS,
+ compression options and other parameters.
+
+ 'vta' provides the detailed information not only for files,
+ but also for service headers like NTFS streams
+ or file security data.
+
+ 'vb' lists bare file names with path, one per line,
+ without any additional information.
+
+ You can use -v switch to list contents of all volumes
+ in volume set: rar v -v vol.part1.rar
+
+ Commands 'vt', 'vta' and 'vb' are equal to 'lt', 'lta'
+ and 'lb' respectively.
+
+
+ x Extract files with full path.
+
+ Examples:
+
+ 1) extract 10cents.txt to current directory not displaying
+ the archive comment
+
+ rar x -c- dime 10cents.txt
+
+ 2) extract *.txt from docs.rar to c:\docs directory
+
+ rar x docs.rar *.txt c:\docs\
+
+ 3) extract the entire contents of docs.rar to current directory
+
+ rar x docs.rar
+
+
+ Switches (used in conjunction with a command):
+
+
+ -? Display help on commands and switches. The same as when none
+ or an illegal command line option is entered.
+
+
+ -- Stop switches scanning
+
+ This switch tells RAR that there are no more switches
+ in the command line. It could be useful, if either an archive
+ or a file name starts with the '-' character. Without the '--' switch
+ such a name would be treated as a switch.
+
+ Example:
+
+ add all files from the current directory to the solid archive
+ '-StrangeName'
+
+ RAR a -s -- -StrangeName
+
+
+ -@[+] Disable [enable] file lists
+
+ RAR treats command line parameters starting from '@' character
+ as file lists. So by default, RAR attempts to read 'filename'
+ filelist, when encountering '@filename' parameter.
+ But if '@filename' file exists, RAR treats the parameter
+ as '@filename' file instead of reading the file list.
+
+ Switch -@[+] avoids this ambiguity and strictly
+ defines how to handle parameters starting with the '@' character.
+
+ If you specify -@, all such parameters found after this switch
+ will be considered as file names, not file lists.
+
+ If you specify -@+, all such parameters found after this switch
+ will be considered as file lists, not file names.
+
+ This switch does not affect processing parameters located
+ before it.
+
+ Example:
+
+ test the archived file '@home'
+
+ rar t -@ notes.rar @home
+
+
+ -ac Clear Archive attribute after compression or extraction
+ (Windows version only).
+
+ If -ac is specified when archiving, "Archive" file attribute
+ is cleared for successfully compressed files. When extracting,
+ -ac will clear "Archive" attribute for extracted files.
+ This switch does not affect directory attributes.
+
+
+ -ad Append archive name to destination path.
+
+ This option may be useful when unpacking a group of archives.
+ By default RAR places files from all archives in the same
+ directory, but this switch creates a separate directory
+ for files unpacked from each archive.
+
+ Example:
+
+ rar x -ad *.rar data\
+
+ RAR will create subdirectories below 'data' for every unpacked
+ archive.
+
+
+ -ag[format]
+ Generate archive name using the current date and time.
+
+ Appends the current date string to an archive name when
+ creating or processing an archive. Useful for daily backups.
+
+ Format of the appending string is defined by the optional
+ "format" parameter or by "YYYYMMDDHHMMSS" if this parameter
+ is absent. The format string may include the following
+ characters:
+
+ Y - year
+ M - month
+ MMM - month name as text string (Jan, Feb, etc.)
+ W - a week number (a week starts with Monday)
+ A - day of week number (Monday is 1, Sunday - 7)
+ D - day of month
+ E - day of year
+ H - hours
+ M - minutes (treated as minutes if encountered after hours)
+ I - minutes (treated as minutes regardless of hours position)
+ S - seconds
+ N - archive number. RAR searches for already existing archive
+ with generated name and if found, increments the archive
+ number until generating a unique name. 'N' format character
+ is not supported when creating volumes.
+ When performing non-archiving operations like extracting,
+ RAR selects the existing archive preceding the first
+ unused name or sets N to 1 if no such archive exists.
+
+ Each of format string characters listed above represents only
+ one character added to archive name. For example, use WW for
+ two digit week number or YYYY to define four digit year.
+
+ If the first character in the format string is '+', positions
+ of the date string and base archive name are exchanged,
+ so a date will precede an archive name.
+
+ The format string may contain optional text enclosed in '{'
+ and '}' characters. This text is inserted into archive name.
+
+ All other characters are added to an archive name without
+ changes.
+
+ If you need to process an already existing archive, be careful
+ with -ag switch. Depending on the format string and time passed
+ since previous -ag use, generated and existing archive names
+ may mismatch. In this case RAR will create or open a new archive
+ instead of processing the already existing one. You may use
+ -log switch to write the generated archive name to a file
+ and then read it from file for further processing.
+
+
+ Examples:
+
+ 1) use the default YYYYMMDDHHMMSS format
+
+ rar a -ag backup
+
+ 2) use DD-MMM-YY format
+
+ rar t -agDD-MMM-YY backup
+
+ 3) use YYYYMMDDHHMM format, place date before 'backup'
+
+ rar a -ag+YYYYMMDDHHMM backup
+
+ 4) use YYYY-WW-A format, include fields description
+
+ rar a -agYYYY{year}-WW{week}-A{wday} backup
+
+ 5) use YYYYMMDD and the archive number. It allows generating
+ unique names even when the YYYYMMDD format mask is used more than
+ once in the same day
+
+ rar a -agYYYYMMDD-NN backup
+
+
+ -ai Ignore file attributes.
+
+ If this switch is used when extracting, RAR does not set
+ general file attributes stored in archive to extracted files.
+ This switch preserves attributes assigned by operating system
+ to a newly created file.
+
+ If this switch is used when archiving, predefined values,
+ typical for file and directory, are stored instead of actual
+ attributes.
+
+ In Windows it affects archive, system, hidden and read-only
+ attributes; in Unix - user, group and others file permissions.
+
+
+ -ao Add files with "Archive" attribute set
+ (Windows version only).
+
+ If -ao is used when archiving, only files with "Archive"
+ file attribute will be added to archive. This switch does not
+ affect directories, so all matching directories are added
+ regardless of their attributes. You can also specify -ed switch
+ if you prefer to omit all directory records.
+
+ Example:
+
+ add all disk C: files with "Archive" attribute set
+ to 'f:backup' and clear the files' "Archive" attribute
+
+ rar a -r -ac -ao f:backup c:\*.*
+
+
+ -ap Set path inside archive. This path is merged to file
+ names when adding files to an archive and removed
+ from file names when extracting.
+
+ For example, if you wish to add the file 'readme.txt'
+ to the directory 'DOCS\ENG' of archive 'release',
+ you may run:
+
+ rar a -apDOCS\ENG release readme.txt
+
+ or to extract 'ENG' to the current directory:
+
+ rar x -apDOCS release DOCS\ENG\*.*
+
+
+ -as Synchronize archive contents
+
+ If this switch is used when archiving, those archived files
+ which are not present in the list of the currently added
+ files, will be deleted from the archive. It is convenient to
+ use this switch in combination with -u (update) to synchronize
+ contents of an archive and an archiving directory.
+
+ For example, after the command:
+
+ rar a -u -as backup sources\*.cpp
+
+ the archive 'backup.rar' will contain only *.cpp files
+ from directory 'sources', all other files will be deleted
+ from the archive. It looks similar to creating a new archive,
+ but with one important exception: if no files are modified
+ since the last backup, the operation is performed much faster
+ than the creation of a new archive.
+
+
+ -cfg- Ignore configuration file and RAR environment variable.
+
+
+ -cl Convert file names to lower case.
+
+
+ -cu Convert file names to upper case.
+
+
+ -c- Disable comments show.
+
+
+ -df Delete files after archiving
+
+ Move files to archive. This switch in combination with
+ the command "A" performs the same action as the command "M".
+
+
+ -dh Open shared files
+
+ Allows to process files opened by other applications
+ for writing.
+
+ This switch helps if an application allowed read access
+ to file, but if all types of file access are prohibited,
+ the file open operation will still fail.
+
+ This option could be dangerous, because it allows
+ to archive a file, which at the same time is modified
+ by another application, so use it carefully.
+
+
+ -dr Delete files to Recycle Bin
+
+ Delete files after archiving and place them to Recycle Bin.
+ Available in Windows version only.
+
+
+ -ds Do not sort files while adding to a solid archive.
+
+
+ -dw Wipe files after archiving
+
+ Delete files after archiving. Before deleting file data
+ are overwritten by zero bytes to prevent recovery of deleted
+ files, file is truncated and renamed to temporary name.
+
+ Please be aware that this approach is designed for conventional
+ hard disks, but may fail to overwrite the original file data
+ on solid state disks, as a result of SSD wear leveling technology
+ and more complicated data addressing.
+
+
+ -ed Do not add empty directories
+
+ This switch indicates that directory records are not to be
+ stored in the created archive. When extracting such archives,
+ RAR creates non-empty directories based on paths of files
+ contained in them. Information about empty directories is
+ lost. All attributes of non-empty directories except a name
+ (access rights, streams, etc.) will be lost as well, so use
+ this switch only if you do not need to preserve such information.
+
+ If -ed is used with 'm' command or -df switch, RAR will not
+ remove empty directories.
+
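+ A brief illustrative example (not from the original manual), archiving
+ a directory tree without storing empty directory records:
+
+ rar a -r -ed backup.rar projects
+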
+
+ -en Do not add "end of archive" block
+
+ Not supported for RAR 5.0 archives.
+
+ By default, RAR adds an "end of archive" block to the end of
+ a new or updated archive. It allows to skip external data like
+ digital signatures safely, but in some special cases it may be
+ useful to disable this feature. For example, if an archive
+ is transferred between two systems via an unreliable link and
+ at the same time a sender adds new files to it, it may be
+ important to be sure that the already received file part will
+ not be modified on the other end between transfer sessions.
+
+ This switch cannot be used with volumes, because the end
+ of archive block contains information important for correct
+ volume processing.
+
+
+ -ep Exclude paths from names. This switch enables files to be
+ added to an archive without including the path information.
+ This could result in multiple files with the same name
+ existing in the archive.
+
+ If used when extracting, archived paths are ignored
+ for extracted files, so all files are created in the same
+ destination directory.
+
+
+ -ep1 Exclude base dir from names. Do not store or extract the path
+ entered in the command line. Ignored if path includes wildcards.
+
+ Examples:
+
+ 1) add all files and directories from 'tmp' directory to archive
+ 'test', but exclude 'tmp\' from archived names path:
+
+ rar a -ep1 -r test tmp\*
+
+ This is an equivalent to commands:
+
+ cd tmp
+ rar a -r ..\test
+ cd ..
+
+ 2) extract files matching images\* mask to dest\ directory,
+ but remove 'images\' from paths of created files:
+
+ rar x -ep1 data images\* dest\
+
+
+ -ep2 Expand paths to full. Store full file paths (except the drive
+ letter and leading path separator) when archiving.
+
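+ An illustrative sketch (assumed paths, not from the original manual):
+
+ rar a -ep2 backup.rar c:\work\docs\plan.txt
+
+ stores the file with its full path work\docs\plan.txt, dropping only
+ the drive letter and the leading path separator.
+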
+
+ -ep3 Expand paths to full including the drive letter.
+ Windows version only.
+
+ This switch stores full file paths including the drive
+ letter if used when archiving. Drive separators (colons)
+ are replaced by underscore characters.
+
+ If you use -ep3 when extracting, it will change
+ underscores back to colons and create unpacked files
+ in their original directories and disks. If the user
+ also specified a destination path, it will be ignored.
+
+ It also converts UNC paths from \\server\share to
+ __server\share when archiving and restores them to
+ the original state when extracting.
+
+ This switch can help to backup several disks to the same
+ archive. For example, you may run:
+
+ rar a -ep3 -r backup.rar c:\ d:\ e:\
+
+ to create backup and:
+
+ rar x -ep3 backup.rar
+
+ to restore it.
+
+ But be cautious and use -ep3 only if you are sure that
+ extracting archive does not contain any malicious files.
+ In other words, use it if you have created an archive yourself
+ or completely trust its author. This switch allows to overwrite
+ any file in any location on your computer including important
+ system files and should normally be used only for the purpose
+ of backup and restore.
+
+
+ -e[+]<attr>
+ Specifies file exclude or include attributes mask.
+
+ <attr> is a number in the decimal, octal (with leading '0')
+ or hex (with leading '0x') format.
+
+ By default, without '+' sign before <attr>, this switch
+ defines the exclude mask. So if result of bitwise AND between
+ <attr> and file attributes is nonzero, file would not be
+ processed.
+
+ If '+' sign is present, it specifies the include mask.
+ Only those files which have at least one attribute specified
+ in the mask will be processed.
+
+ In the Windows version it is also possible to use symbols D, S, H,
+ A and R instead of a digital mask to denote directories
+ and files with system, hidden, archive and read-only attributes.
+ The order in which the attributes are given is not significant.
+ Unix version supports D and V symbols to define directory
+ and device attributes.
+
+ It is allowed to specify both -e<attr> and -e+<attr>
+ in the same command line.
+
+ Examples:
+
+ 1) archive only directory names without their contents
+
+ rar a -r -e+d dirs
+
+ 2) do not compress system and hidden files:
+
+ rar a -esh files
+
+ 3) do not extract read-only files:
+
+ rar x -er files
+
+
+ -f Freshen files. May be used with archive extraction or creation.
+ The command string "a -f" is equivalent to the command 'f', you
+ could also use the switch '-f' with the commands 'm' or 'mf'. If
+ the switch '-f' is used with the commands 'x' or 'e', then only
+ old files would be replaced with new versions extracted from the
+ archive.
+
+
+ -hp[p] Encrypt both file data and headers.
+
+ This switch is similar to -p[p], but switch -p encrypts
+ only file data and leaves other information like file names
+ visible. This switch encrypts all sensitive archive areas
+ including file data, file names, sizes, attributes, comments
+ and other blocks, so it provides a higher security level.
+ Without a password it is impossible to view even the list of
+ files in archive encrypted with -hp.
+
+ Example:
+
+ rar a -hpfGzq5yKw secret report.txt
+
+ will add the file report.txt to the encrypted archive
+ secret.rar using the password 'fGzq5yKw'
+
+
+ -ht[b|c]
+ Select hash type [BLAKE2,CRC32] for file checksum.
+
+ File data integrity in RAR archive is protected by checksums
+ calculated and stored for every archived file.
+
+ By default, RAR uses CRC32 function to calculate the checksum.
+ RAR 5.0 archive format also allows to select BLAKE2sp hash
+ function instead of CRC32.
+
+ Specify -htb switch for BLAKE2sp and -htc for CRC32 hash function.
+ Since CRC32 is the default algorithm, you may need -htc only to
+ override -htb in RAR configuration.
+
+ CRC32 output is 32 bits long. While CRC32 properties are
+ suitable to detect most unintentional data errors,
+ it is not reliable enough to verify file data identity.
+ In other words, if two files have the same CRC32,
+ it does not guarantee that file contents is the same.
+
+ BLAKE2sp output is 256 bits long. Being a cryptographically strong
+ hash function, it practically guarantees that if two files
+ have the same value of BLAKE2sp, their contents is the same.
+ BLAKE2sp error detection property is also more reliable than
+ in shorter CRC32.
+
+ Since BLAKE2sp output is longer, resulting archive is
+ slightly larger for -htb switch.
+
+ If archive headers are unencrypted (no switch -hp), checksums
+ for encrypted RAR 5.0 files are modified using a special
+ password dependent algorithm, to make it impossible to guess
+ file contents based on checksums. Do not expect such encrypted
+ file checksums to match usual CRC32 and BLAKE2sp values.
+
+ This switch is supported only by RAR 5.0 format, so you
+ need to use -ma switch with it.
+
+ You can see checksums of archived files using 'vt' or 'lt'
+ commands.
+
+
+ Example:
+
+ rar a -ma -htb lists.rar *.lst
+
+ will add *.lst to lists.rar using BLAKE2sp for file checksums.
+
+
+ -id[c,d,p,q]
+ Disable messages.
+
+ Switch -idc disables the copyright string.
+
+ Switch -idd disables "Done" string at the end of operation.
+
+ Switch -idp disables the percentage indicator.
+
+ Switch -idq turns on the quiet mode, so only error messages
+ and questions are displayed.
+
+ It is allowed to use several modifiers at once,
+ so switch -idcdp is correct.
+
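+ A short illustrative example (not from the original manual), archiving
+ in quiet mode so that only error messages and questions are shown:
+
+ rar a -idq backup.rar *.txt
+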
+
+ -ieml[.][addr]
+ Send archive by email. Windows version only.
+
+ Attach an archive created or updated by the add command
+ to email message. You need to have a MAPI compliant email
+ client to use this switch (most modern email programs
+ support MAPI interface).
+
+ You may enter a destination email address directly
+ in the switch or leave it blank. In the latter case you
+ will be asked for it by your email program. It is possible
+ to specify several addresses separated by commas or semicolons.
+
+ If you append a dot character to -ieml, an archive will be
+ deleted after it was successfully attached to an email.
+ If the switch is used when creating a multivolume archive,
+ every volume is attached to a separate email message.
+
+
+ -ierr Send all messages to stderr.
+
+
+ -ilog[name]
+ Log errors to file.
+
+ Write error messages to rar.log file. If optional 'name'
+ parameter is not specified, the log file is created
+ using the following defaults:
+
+ Unix: .rarlog file in the user's home directory;
+ Windows: rar.log file in %APPDATA%\WinRAR directory.
+
+ If 'name' parameter includes a file name without path,
+ RAR will create the log file in the default directory
+ mentioned above using the specified name. Include both path
+ and name to 'name' parameter if you wish to change
+ the location of log file.
+
+ By default, log file uses UTF-16 little endian encoding,
+ but it can be changed with -scg switch, such as -scag
+ for native single byte encoding.
+
+
+ Example:
+
+ rar a -ilogc:\log\backup.log backup d:\docs
+
+ will create c:\log\backup.log log file in case of errors.
+
+
+ -inul Disable all messages.
+
+
+ -ioff[n]
+ Turn PC off after completing an operation.
+
+ Use -ioff or -ioff1 to turn PC off, -ioff2 to hibernate,
+ -ioff3 to sleep and -ioff4 to restart. Appropriate power features
+ must be supported by operating system. Windows version only.
+
+
+ -isnd Enable sound.
+
+
+ -iver Display the version number and quit. You can run just "RAR -iver".
+
+
+ -k Lock archive. Any command which intends to change the archive
+ will be ignored.
+
+
+ -kb Keep broken extracted files.
+
+ RAR, by default, deletes files with checksum errors
+ after extraction. The switch -kb specifies that files
+ with checksum errors should not be deleted.
+
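+ An illustrative example (assumed archive name, not from the original
+ manual), keeping partially extracted files from a damaged archive:
+
+ rar x -kb damaged.rar
+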
+
+ -log[fmt][=name]
+ Write names to log file.
+
+ This switch allows to write archive and file names to specified
+ text file in archiving, extracting, deleting and listing commands.
+ Its behavior is defined by 'fmt' string, which can include one
+ or more of following characters:
+
+ A - write archive names to log file. If RAR creates or processes
+ volumes, all volume names are logged.
+
+ F - write processed file names to log file. It includes
+ files added to archive and extracted, deleted or listed
+ files inside of archive.
+
+ P - if log file with specified name already exists, append data
+ to existing file instead of creating a new one.
+
+ U - write data in Unicode format.
+
+ If neither 'A' nor 'F' are specified, 'A' is assumed.
+
+ 'name' parameter allows to specify the name of log file.
+ It must be separated from 'fmt' string by '=' character.
+ If 'name' is not present, RAR will use the default rarinfo.log
+ file name.
+
+ It is allowed to specify several -log switches in the same
+ command line.
+
+ This switch can be particularly useful, when you need to process
+ an archive created with -ag or -v switches in a batch script.
+ You can specify -loga=arcname.txt when creating an archive
+ and then read an archive name generated by RAR from arcname.txt
+ with an appropriate command. For example, in Windows batch file
+ it can be: set /p name=<arcname.txt
+
+
+ -m<0..5> Set compression method:
+
+ -m0 store do not compress file when adding to archive
+ -m1 fastest use fastest method (less compressive)
+ -m2 fast use fast compression method
+ -m3 normal use normal (default) compression method
+ -m4 good use good compression method (more
+ compressive, but slower)
+ -m5 best use best compression method (slightly more
+ compressive, but slowest)
+
+ If this switch is not specified, RAR uses -m3 method
+ (normal compression).
+
+
+ -ma[4|5]
+ Specify a version of archiving format.
+
+ By default RAR creates archives in RAR 5.0 format.
+ Use -ma4 to create RAR 4.x archives.
+ Use -ma5 or just -ma in case you need to override -ma4 saved
+ in configuration and force RAR to use RAR 5.0 format.
+
+ This switch is used only when creating a new archive.
+ It is ignored when updating an existing archive.
+
+
+ -mc<par>
+ Set advanced compression parameters.
+
+ This switch is intended mainly for benchmarking and
+ experiments. In the real environment it is usually better
+ to allow RAR to select optimal parameters automatically.
+ Please note that improper use of this switch may lead
+ to very serious performance and compression loss, so use
+ it only if you clearly understand what you do.
+
+ It has the following syntax:
+
+ -mc[param1][:param2]<module>[+ or -]
+
+ where <module> is the one character field denoting a part
+ of the compression algorithm, which has to be configured.
+
+ It may have the following values:
+
+ A - audio compression;
+ C - true color (RGB) data compression;
+ D - delta compression;
+ E - x86 executable compression;
+ I - Intel Itanium executable compression;
+ T - text compression.
+
+ RAR 5.0 archive format supports only 'D' and 'E' values.
+
+ '+' sign at the end of switch applies the selected algorithm
+ module to all processed data, '-' disables the module at all.
+ If no sign is specified, RAR will choose modules automatically,
+ based on data and the current compression method.
+
+ Switch -mc- disables all optional modules and allows only
+ the general compression algorithm.
+
+ <param1> and <param2> are module dependent parameters
+ described below.
+
+
+ Audio compression, delta compression:
+
+ <param1> is a number of byte channels (can be 1 - 31).
+ RAR splits multibyte channels to bytes, for example,
+ two 16-bit audio channels are considered by RAR as four
+ channels one byte each.
+
+ <param2> is ignored.
+
+
+ x86 Intel executable compression, Intel Itanium executable
+ compression, true color (RGB) data compression:
+
+ <param1> and <param2> are ignored.
+
+
+ Text compression:
+
+ Text compression algorithm provides noticeably higher compression
+ on plain text data. But it cannot utilize several CPU cores
+ efficiently, resulting in slower compression compared to the
+ general algorithm in multicore and multiprocessor environments.
+ Also its decompression speed is much slower than that of the general
+ algorithm regardless of the number of CPU cores. This is why
+ the text compression is disabled by default. You can specify
+ -mct switch to allow RAR to select this algorithm automatically
+ for suitable data. Switch -mct+ will force use of the text
+ compression for all data.
+
+ Switch -mct can also include <param1> and <param2>, so its
+ full syntax is -mc[param1][:param2]t[+ or -].
+
+ <param1> is the order of PPM algorithm (can be 2 - 63).
+ Usually a higher value slightly increases the compression ratio
+ of redundant data, but only if enough memory is available
+ to PPM. In case of lack of memory the result may be negative.
+ Higher order values decrease both compression and decompression
+ speed.
+
+ <param2> is memory in megabytes allocated for PPM (1-128).
+ Higher values may increase the compression ratio, but note
+ that PPM uses the equal memory size both to compress and
+ decompress, so if you allocate too much memory when creating
+ an archive, other people may have problems when decompressing
+ it on a computer with less memory installed. Decompression
+ will be still possible using virtual memory, but it may
+ become very slow.
+
+
+ Examples:
+
+ 1) switch -mc1a+ forces use of 8-bit mono audio compression
+ for all data.
+
+ 2) switch -mc10:40t+ forces use of text compression
+ algorithm for all data, sets the compression order to 10
+ and allocates 40 MB memory.
+
+ 3) switch -mc12t sets the text compression order to 12,
+ when the text compression is used, but leaves to RAR to
+ decide when to use it.
+
+ 4) switches -mct -mcd- allow RAR to apply the text compression
+ to suitable data and disable the delta compression.
+
+
+ -md<n>[k,m,g]
+ Select the dictionary size.
+
+ Sliding dictionary is the memory area used by compression
+ algorithm to find and compress repeated data patterns.
+ If size of file being compressed (or total files size in case
+ of solid archive) is larger than dictionary size, increasing
+ the dictionary is likely to increase the compression ratio,
+ reduce the archiving speed and increase memory requirements.
+
+ For RAR 4.x archive format the dictionary size can be:
+ 64 KB, 128 KB, 256 KB, 512 KB, 1 MB, 2 MB, 4 MB.
+
+ For RAR 5.0 archive format the dictionary size can be:
+ 128 KB, 256 KB, 512 KB, 1 MB, 2 MB, 4 MB, 8 MB, 16 MB,
+ 32 MB, 64 MB, 128 MB, 256 MB, 512 MB, 1 GB.
+
+ You can use 'k', 'm' and 'g' modifiers to specify the size
+ in kilo-, mega- and gigabytes, like -md64m for 64 MB dictionary.
+ If no modifier is specified, megabytes are assumed,
+ so -md64m and -md64 are equal.
+
+ When archiving, RAR needs about 6x memory of specified
+ dictionary size, so 512 MB and 1 GB sizes are available
+ in 64 bit RAR version only. When extracting, slightly more
+ than a single dictionary size is allocated, so both 32
+ and 64 bit versions can unpack archives with all dictionaries
+ up to and including 1 GB.
+
+ If the size of all source files for a solid archive, or of the
+ largest source file for a non-solid archive, is no more than half
+ the dictionary size, RAR can reduce the dictionary size. It helps
+ to lower memory usage without decreasing compression.
+
+ Default sliding dictionary size is 4 MB for RAR 4.x
+ and 32 MB for RAR 5.0 archive format.
+
+ Example:
+
+ RAR a -s -ma -md128 lib *.dll
+
+ create a solid archive in RAR 5.0 format with 128 MB dictionary.
+
+
+ -ms[list]
+ Specify file types to store.
+
+ Specify file types, which will be stored without compression.
+ This switch may be used to store already compressed files,
+ which helps to increase archiving speed without noticeable
+ loss in the compression ratio.
+
+ Optional <list> parameter defines the list of file extensions
+ separated by semicolons. For example, -msrar;zip;jpg will
+ force RAR to store without compression all RAR and ZIP
+ archives and JPG images. It is also allowed to specify wildcard
+ file masks in the list, so -ms*.rar;*.zip;*.jpg will work too.
+ Several -ms switches are permitted, such as -msrar -mszip
+ instead of -msrar;zip.
+
+ In Unix -ms switch containing several file types needs to be
+ enclosed in quote marks. It protects semicolons from processing
+ by Unix shell. Another solution is to use individual -ms
+ switches for every file type.
+
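+ An illustrative sketch of the Unix quoting described above
+ (not part of the original manual text):
+
+ rar a -ms'rar;zip;jpg' backup.rar photos
+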
+ If <list> is not specified, -ms switch will use the default
+ set of extensions, which includes the following file types:
+
+ 7z, ace, arj, bz2, cab, gz, jpeg, jpg, lha, lz, lzh, mp3,
+ rar, taz, tgz, xz, z, zip, zipx
+
+
+ -mt<threads>
+ Set the number of threads.
+
+ <threads> parameter can take values from 1 to 32.
+ It defines the recommended maximum number of active threads
+ for the compression algorithm as well as for other RAR modules,
+ which can start several threads. While RAR attempts to follow
+ this recommendation, sometimes the real number of active
+ threads can exceed the specified value.
+
+ Change of <threads> parameter slightly affects the compression
+ ratio, so archives created with different -mt switches
+ will not be exactly the same even if all other compression
+ settings are equal.
+
+ If -mt switch is not specified, RAR will try to detect
+ the number of available processors and select the optimal
+ number of threads automatically.
+
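+ A short illustrative example (not from the original manual), limiting
+ compression to four threads:
+
+ rar a -mt4 archive.rar bigdata
+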
+
+ -n<file> Additionally filter included files.
+
+ Apply the mask as an additional filter to included file list.
+ Wildcards can be used both in the name and file parts of
+ file mask. See switch -x description for details on mask syntax.
+ You can specify the switch '-n' several times.
+
+ This switch does not replace usual file masks, which still
+ need to be entered in the command line. It is an additional
+ filter limiting processed files only to those matching
+ the include mask specified in -n switch. It can help to
+ reduce the command line length sometimes.
+
+ For example, if you need to compress all *.txt and *.lst
+ files in directories Project and Info, you can enter:
+
+ rar a -r text Project\*.txt Project\*.lst Info\*.txt Info\*.lst
+
+ or using the switch -n:
+
+ rar a -r -n*.txt -n*.lst text Project Info
+
+
+ -n@<list> Read additional filter masks from list file.
+
+ Similar to -n<file> switch, but reads filter masks from
+ the list file. If you use -n@ without the list file name
+ parameter, it will read filter masks from stdin.
+
+ This switch does not replace usual list files or file masks,
+ which still need to be entered in the command line.
+ It is an additional filter limiting processed files only to
+ those matching the include mask specified in -n switch.
+
+ Example:
+
+ rar a -r -n@inclist.txt text Project Info @listfile.txt
+
+
+ -oc Set NTFS Compressed attribute. Windows version only.
+
+ This switch allows to restore NTFS Compressed attribute
+ when extracting files. RAR saves Compressed file attributes
+ when creating an archive, but does not restore them unless
+ -oc switch is specified.
+
+ -oh Save hard links as the link instead of the file.
+
+ If archiving files include several hard links, store the first
+ archived hard link as usual file and the rest of hard links
+ in the same set as links to this first file. When extracting
+ such files, RAR will create hard links instead of usual files.
+
+ You must not delete or rename the first hard link in archive
+ after the archive was created, because it will make extraction
+ of following links impossible. If you modify the first link,
+ all following links will also have the modified contents
+ after extracting. Extraction command must involve the first
+ hard link to create following hard links successfully.
+
+ This switch is supported only by RAR 5.0 format.
+
+
+ -oi[0-4][:<min size>]
+ Save identical files as references.
+
+ Switch -oi0 (or just -oi-) turns off identical file processing,
+ so such files are compressed as usual files. It can be used to
+ override another -oi value stored in RAR configuration.
+
+ If -oi1 (or just -oi) is specified, RAR analyzes the file
+ contents before starting archiving. If several identical files
+ are found, the first file in the set is saved as usual file
+ and all following files are saved as references to this first
+ file. It allows to reduce the archive size, but applies some
+ restrictions to resulting archive. You must not delete or rename
+ the first identical file in archive after the archive was
+ created, because it will make extraction of following files
+ using it as a reference impossible. If you modify the first file,
+ following files will also have the modified contents
+ after extracting. Extraction command must involve the first file
+ to create following files successfully.
+
+ It is recommended to use -oi only if you compress a lot of
+ identical files, will not modify an archive later and will
+ extract an archive entirely, without necessity to unpack or skip
+ individual files. If all identical files are small enough to
+ fit into compression dictionary specified with -md switch,
+ switch -s can provide more flexible solution than -oi.
+
+ Switch -oi2 is similar to -oi1, with the only difference:
+ it will display names of found identical files before starting
+ archiving.
+
+ Switches -oi3 and -oi4 allow to utilize RAR to generate
+ lists of identical files. Though you still need to provide
+ a dummy archive name to make the command syntax valid,
+ in this mode an archive is not created and nothing is compressed.
+ If -oi3 is used, file sizes and names are displayed
+ and every identical file group is separated with empty line.
+ Switch -oi4 displays bare file names and skips the first
+ identical file in every file group, so only duplicates
+ are listed.
+
+ Optional <min size> value allows to define the minimum file size
+ threshold. Files smaller than <min size> are not analyzed
+ and not considered as identical. If this parameter is not
+ present, it is assumed to be 64 KB by default. Selecting
+ too small <min size> may increase the time required to detect
+ identical files.
+
+ Switches -oi1 and -oi2 are supported only by RAR 5.0 format.
+
+ Examples:
+
+ 1) rar a -oi -ma archive
+
+ Save contents of current directory to archive.rar.
+ Store identical files as references.
+
+ 2) rar a -oi3:1000000 -r dummy c:\photo\*.jpg
+
+ List all duplicate *.jpg files larger than 1000000 bytes
+ found in c:\photo and its subdirectories.
+
+
+ -ol[a] Process symbolic links as the link [absolute paths]
+
+ Save symbolic links as links, so file contents is not archived.
+ In Windows version it also saves reparse points as links.
+ Such archive entries are restored as symbolic links
+ or reparse points when extracting.
+
+ Supported both for RAR 4.x and RAR 5.0 archives in RAR for Unix
+ and only for RAR 5.0 archives in RAR for Windows.
+
+ In Windows you may need to run RAR as administrator to create
+ symbolic links when extracting.
+
+ RAR adds all links regardless of target when archiving with
+ -ol switch. When extracting, by default, RAR skips symbolic
+ links pointing outside of destination directory, with absolute
+ paths, excessive number of ".." in link target or with other
+ potentially dangerous link parameters. You can enable extracting
+ such links with -ola switch.
+
+ Links pointing to directories outside of extraction destination
+ directory can present a security risk. Enable their extraction
+ only if you are sure that archive contents is safe,
+ such as your own backup.
+
+ Links that are considered safe by RAR are extracted always
+ regardless of -ol or -ola switch.
+
+
+ -oni Allow potentially incompatible names.
+
+ While NTFS file system permits file names with trailing spaces
+ and dots, a lot of Windows programs fail to process such names
+ correctly. If this switch is not specified, RAR removes trailing
+ spaces and dots, if any, from file names when extracting.
+ Specify this switch if you need to extract such names as is.
+
+ Windows version only.
+
+
+ -or Rename extracted files automatically if file with the same name
+ already exists. Renamed file will get the name like
+ 'filename(N).txt', where 'filename.txt' is the original file
+ name and 'N' is a number starting from 1 and incrementing
+ if file exists.
+
+
+ -os Save NTFS streams. Windows version only.
+
+ This switch has meaning only for the NTFS file system and
+ allows you to save alternate data streams associated with a
+ file. You may need to specify it when archiving if you use
+ software storing data in alternate streams and wish to
+ preserve these streams.
+
+ Streams are not saved for NTFS encrypted files.
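+
+ Example (illustrative archive and directory names):
+
+ rar a -os streams.rar mydocs
+
+ archives 'mydocs' preserving alternate data streams
+ of its files.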
+
+
+ -ow Use this switch when archiving to save file security
+ information and when extracting to restore it.
+
+ Unix RAR version saves file owner and group when using
+ this switch.
+
+ The Windows version stores owner, group, file permissions and
+ audit information, but only if you have the necessary
+ privileges to read them. Note that only the NTFS file system
+ supports file-based security under Windows.
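+
+ For example (illustrative names):
+
+ rar a -ow secure.rar mydir
+ rar x -ow secure.rar
+
+ The first command saves security information when archiving,
+ the second restores it when extracting.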
+
+
+ -o[+|-] Set the overwrite mode. Can be used both when extracting
+ and updating archived files. The following modes are available:
+
+ -o Ask before overwrite
+ (default for extracting files);
+
+ -o+ Overwrite all
+ (default for updating archived files);
+
+ -o- Skip existing files.
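+
+ For example (an illustrative archive name):
+
+ rar x -o- backup.rar
+
+ extracts backup.rar, silently skipping files which already
+ exist in the destination directory.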
+
+
+ -p[pwd] Set password
+
+ Set password to encrypt files when archiving
+ or to decrypt when extracting.
+
+ Passwords are case-sensitive. The maximum password length is
+ 127 characters; longer passwords are truncated to this length.
+ If you omit the password on the command line, you will be
+ prompted with an "Enter password" message. You can also use
+ file redirection or a pipe to specify a password if the
+ parameter is missing.
+
+ Examples:
+
+ 1) rar a -psecret texts.rar *.txt
+
+ add files *.txt and encrypt them with password "secret".
+
+ 2) rar -p texts.rar *.txt < psw.txt
+
+ set contents of psw.txt file as a password.
+
+
+ -p- Do not query password
+
+ Do not query password for encrypted files when extracting.
+ Actually, you can specify any invalid password to suppress
+ the password prompt and force RAR to issue the 'incorrect
+ password' message when extracting an encrypted file. This
+ switch simply sets '-' as the password.
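+
+ Example (an illustrative archive name):
+
+ rar x -p- archive
+
+ extracts unencrypted files from archive.rar without pausing
+ at a password prompt; encrypted files fail with the
+ 'incorrect password' message instead.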
+
+
+ -qo[-|+]
+ Add quick open information [none|force]
+
+ RAR archives store every file header, containing information
+ such as file name, time, size and attributes, immediately
+ before the data of the described file. This approach is more
+ damage resistant than storing all file headers in a single
+ continuous block, which if broken or truncated would destroy
+ the entire archive contents. But while being more reliable,
+ such file headers scattered around the entire archive are
+ slower to access if we need to quickly open the archive
+ contents in a shell such as the WinRAR graphical interface.
+
+ To improve archive open speed and still not make the entire
+ archive dependent on a single damaged block, RAR 5.0 archives
+ can include an optional quick open record. Such a record is
+ added to the end of the archive and contains copies of file
+ names and other file information stored in a single continuous
+ block, in addition to the normal file headers inside the archive.
+ Since the block is continuous, its contents can be read quickly,
+ without necessity to perform a lot of disk seek operations.
+ Every file header in this block is protected with a checksum.
+ If RAR detects that quick open information is damaged,
+ it resorts to reading individual headers from inside the archive,
+ so damage resistance is not lessened.
+
+ The quick open record contains a full copy of the file header,
+ which may be several tens or hundreds of bytes per file,
+ increasing the archive size by the same amount. This size
+ increase is most noticeable for many small files, when the
+ file data size is comparable to the file header size. So by
+ default, if no -qo is specified or -qo without a parameter is
+ used, RAR stores copies of headers only for relatively large
+ files and continues to use local headers for smaller files.
+ The exact file size threshold can depend on the RAR version.
+ Such approach provides a reasonable open speed to archive size
+ tradeoff. If you prefer to have the maximum archive open speed
+ regardless of size, you can use -qo+ to store copies of all
+ file headers. If you need to have the smallest possible archive
+ and do not care about archive open speed in different programs,
+ specify -qo- to exclude the quick open information completely.
+
+ If you wish to measure the performance effect of this switch,
+ make sure that the archive contents are not stored in the
+ disk cache. No real disk seeks are performed for a cached
+ archive file, making access to file headers fast even without
+ the quick open record.
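+
+ For example (illustrative names):
+
+ rar a -qo+ fast.rar docs
+ rar a -qo- small.rar docs
+
+ The first command stores quick open copies of all file
+ headers for the fastest archive open time, the second omits
+ the quick open record to minimize the archive size.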
+
+
+ -r Recurse subdirectories. May be used with commands:
+ a, u, f, m, x, e, t, p, v, l, c, cf and s.
+
+ When used with the commands 'a', 'u', 'f', 'm', RAR will
+ process files in all subdirectories as well as in the current
+ working directory.
+
+ When used with the commands x, e, t, p, v, l, c, cf or s, RAR
+ will process all archives in subdirectories as well as in the
+ current working directory.
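+
+ Example (illustrative names):
+
+ rar a -r pictures.rar *.jpg
+
+ adds *.jpg files from the current directory and all of its
+ subdirectories to pictures.rar.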
+
+
+ -r- Disable recursion.
+
+ Even without the -r switch, RAR can enable recursion
+ automatically in some situations. Switch -r- prohibits it.
+
+ If you specify a directory name when archiving and that
+ name does not include wildcards, by default RAR adds the
+ directory contents even if switch -r is not specified.
+ RAR also automatically enables recursion if a disk root
+ without wildcards is specified as the file mask. Switch -r-
+ disables such behavior.
+
+ For example:
+
+ rar a -r- arc dirname
+
+ command will add only the empty 'dirname' directory and ignore
+ its contents. The following command:
+
+ rar a -r- arc c:\
+
+ will compress contents of root c: directory only and
+ will not recurse into subdirectories.
+
+
+ -r0 Similar to -r, but when used with the commands 'a', 'u', 'f',
+ 'm' will recurse into subdirectories only for those file masks,
+ which include wildcard characters '*' and '?'.
+
+ This switch works only for file names. Directory names without
+ a file name part, such as 'dirname', are not affected by -r0
+ and their contents are added to the archive completely unless
+ the -r- switch is specified.
+
+ Example:
+
+ rar a -r0 docs.rar *.doc readme.txt
+
+ add *.doc files from the current directory and its subdirectories
+ and readme.txt only from the current directory to docs.rar
+ archive. In case of usual -r switch, RAR would search for
+ readme.txt in subdirectories too.
+
+
+ -ri<P>[:<S>]
+ Set priority and sleep time. Available only in RAR for Windows.
+
+ This switch regulates system load by RAR in multitasking
+ environment. Possible task priority <P> values are 0 - 15.
+
+ If <P> is 0, RAR uses the default task priority.
+ <P> equal to 1 sets the lowest possible priority,
+ 15 - the highest possible.
+
+ Sleep time <S> is a value from 0 to 1000 (milliseconds).
+ This is a period of time that RAR gives back to the system
+ after read or write operations while compressing or extracting.
+ Non-zero <S> may be useful if you need to reduce system load
+ even more than can be achieved with the <P> parameter.
+
+ Example:
+
+ execute RAR with default priority and 10 ms sleep time:
+
+ rar a -ri0:10 backup *.*
+
+
+ -rr[N] Add data recovery record. This switch is used when creating
+ or modifying an archive to add a data recovery record to
+ the archive. See the 'rr[N]' command description for details.
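+
+ For instance (an illustrative archive name):
+
+ rar a -rr backup *.txt
+
+ adds *.txt files to backup.rar together with a data recovery
+ record of the default size.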
+
+
+ -rv[N] Create recovery volumes. This switch is used when creating
+ a multivolume archive to generate recovery volumes.
+ See the 'rv[N]' command description for details.
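+
+ For instance (illustrative names; the volume size and
+ recovery volume count are arbitrary):
+
+ rar a -v100k -rv2 vols *.dat
+
+ creates 100 KB volumes plus two recovery volumes, so the
+ volume set can be reconstructed even if up to two volumes
+ are lost or damaged.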
+
+
+ -s Create solid archive. A solid archive is an archive packed
+ by a special compression method, which treats several or all
+ files within the archive as one continuous data stream.
+
+ Solid archiving significantly increases compression when
+ adding a large number of small, similar files. But it also
+ has a few important disadvantages: slower updating of existing
+ solid archives, slower access to individual files, lower
+ damage resistance.
+
+ Usually files in a solid archive are sorted by extension.
+ But it is possible to disable sorting with -ds switch or set
+ an alternative file order using a special file, rarfiles.lst.
+
+ Example:
+
+ create solid archive sources.rar with 512 KB dictionary,
+ recursing all directories, starting with the current directory.
+ Add only .asm files:
+
+ rar a -s -md512 sources.rar *.asm -r
+
+
+ -s<N> Create solid groups using file count
+
+ Similar to -s, but resets solid statistics after compressing
+ <N> files. Usually decreases compression, but also
+ decreases losses in case of solid archive damage.
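+
+ Example (an illustrative archive name; the group size of 100
+ is arbitrary):
+
+ rar a -s100 archive -r *.txt
+
+ compresses *.txt files solidly, resetting solid statistics
+ after every 100 files.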
+
+
+ -sc<charset>[objects]
+ Specify the character set.
+
+ 'Charset' parameter is mandatory and can have one
+ of the following values:
+
+ U - Unicode UTF-16;
+ F - Unicode UTF-8;
+ A - the native single byte encoding, which is ANSI
+ for Windows version;
+ O - OEM (DOS) encoding. Windows version only.
+
+ Endianness of source UTF-16 files, such as list files
+ or comments, is detected based on the byte order mark.
+ If the byte order mark is missing, little endian encoding
+ is assumed.
+
+ 'Objects' parameter is optional and can have one of
+ the following values:
+
+ G - log files produced by -ilog switch;
+ L - list files;
+ C - comment files;
+ R - messages sent to redirected files and pipes (Windows only).
+
+ It is allowed to specify more than one object, for example,
+ -scolc. If 'objects' parameter is missing, 'charset' is applied
+ to all objects.
+
+ This switch allows you to specify the character set for files
+ used with the -z[file] switch, for list files and for comment
+ files written by the "cw" command.
+
+ Examples:
+
+ 1) rar a -scol data @list
+
+ Read names contained in 'list' using OEM encoding.
+
+ 2) rar c -scuc -zcomment.txt data
+
+ Read comment.txt as Unicode file.
+
+ 3) rar cw -scuc data comment.txt
+
+ Write comment.txt as Unicode file.
+
+ 4) rar lb -scur data > list.txt
+
+ Save archived file names in data.rar to list.txt in Unicode.
+
+
+ -se Create solid groups using extension
+
+ Similar to -s, but resets solid statistics if the file
+ extension changes. Usually decreases compression, but also
+ decreases losses from solid archive damage.
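+
+ Example (an illustrative archive name):
+
+ rar a -se sources.rar -r *.*
+
+ creates a solid archive in which solid statistics are reset
+ whenever the file extension changes.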
+
+
+ -sfx[name]
+ Create SFX archives. If this switch is used when creating a new
+ archive, a Self-Extracting archive (using a module in the file
+ default.sfx or specified in the switch) will be created.
+ In the Windows version default.sfx should be placed in the
+ same directory as rar.exe, in Unix - in the user's
+ home directory, in /usr/lib or /usr/local/lib.
+
+ Example:
+
+ rar a -sfxwincon.sfx myinst
+
+ create SelF-eXtracting (SFX) archive using wincon.sfx
+ SFX-module.
+
+
+ -si[name]
+ Read data from stdin (standard input) when creating
+ an archive. The optional 'name' parameter allows you to
+ specify the file name of the compressed stdin data in the
+ created archive. If this parameter is missing, the name is
+ set to 'stdin'.
+
+ Example:
+
+ type Tree.Far | rar a -siTree.Far tree.rar
+
+ will compress 'type Tree.Far' output as 'Tree.Far' file.
+
+
+ -sl<size>
+ Process only those files whose size is less than the value
+ specified in the <size> parameter of this switch.
+ The <size> parameter must be specified in bytes.
+
+
+ -sm<size>
+ Process only those files whose size is more than the value
+ specified in the <size> parameter of this switch.
+ The <size> parameter must be specified in bytes.
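+
+ For example, the two switches can be used to split data by
+ size (illustrative archive names; sizes are in bytes):
+
+ rar a -sl1000000 small.rar -r *.*
+ rar a -sm1000000 big.rar -r *.*
+
+ The first command archives only files smaller than 1000000
+ bytes, the second only files larger than that.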
+
+
+ -sv Create independent solid volumes
+
+ By default RAR tries to reset solid statistics as soon
+ as possible when starting a new volume, but only
+ if enough data was packed after a previous reset
+ (at least a few megabytes).
+
+ This switch forces RAR to ignore packed data size and attempt
+ to reset statistics for volumes of any size. It decreases
+ compression, but increases chances to extract a part of data
+ if one of several solid volumes in a volume set was lost
+ or damaged.
+
+ Note that sometimes RAR cannot reset statistics even with
+ this switch. For example, it cannot be done when compressing
+ one large file split between several volumes. RAR is able to
+ reset solid statistics only between separate files, but not
+ inside of a single file.
+
+ Ignored if used when creating a non-volume archive.
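+
+ Example (illustrative names and volume size):
+
+ rar a -s -v50m -sv backup -r mydir
+
+ creates 50 MB solid volumes, attempting to reset solid
+ statistics at every volume boundary.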
+
+
+ -sv- Create dependent solid volumes
+
+ Disables resetting solid statistics between volumes.
+
+ It slightly increases compression, but significantly reduces
+ chances to extract a part of data if one of several solid
+ volumes in a volume set was lost or damaged.
+
+ Ignored if used when creating a non-volume archive.
+
+
+ -s- Disable solid archiving
+
+
+ -t Test files after archiving. This switch is especially
+ useful in combination with the move command, so files will be
+ deleted only if the archive has been successfully tested.
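+
+ For instance (illustrative names):
+
+ rar m -t logs.rar *.log
+
+ moves *.log files into logs.rar, deleting the originals only
+ after the archive passes the test.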
+
+
+ -ta<date>
+ Process only files modified after the specified date.
+
+ Format of the date string is YYYYMMDDHHMMSS.
+ It is allowed to insert separators like '-' or ':' to
+ the date string and omit trailing fields. For example,
+ the following switch is correct: -ta2001-11-20
+ Internally it will be expanded to -ta20011120000000
+ and treated as "files modified after 0 hour 0 minutes
+ 0 seconds of 20 November 2001".
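+
+ Example (illustrative names and date):
+
+ rar a -ta20230101 recent.rar -r *.doc
+
+ adds only those *.doc files modified after the beginning
+ of 1 January 2023.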
+
+
+ -tb<date>
+ Process only files modified before the specified date.
+ Format of the switch is the same as -ta.
+
+
+ -tk Keep original archive date. Prevents RAR from modifying the
+ archive date when changing an archive.
+
+
+ -tl Set archive time to newest file. Forces RAR to set the date of a
+ changed archive to the date of the newest file in the archive.
+
+
+ -tn<time>