diff --git a/.github/workflows/update-pr.yml b/.github/workflows/update-pr.yml
new file mode 100644
index 000000000000..5faa0cd7049d
--- /dev/null
+++ b/.github/workflows/update-pr.yml
@@ -0,0 +1,20 @@
+name: PR update
+
+on:
+ push:
+ branches:
+ - 'main'
+jobs:
+ autoupdate:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Automatically update mergeable PRs
+ uses: adRise/update-pr-branch@0.7
+ with:
+ token: ${{ secrets.ACTION_USER_TOKEN }}
+ base: 'main'
+ required_approval_count: 1
+ require_passed_checks: true
+ require_auto_merge_enabled: true
+ sort: 'created'
+ direction: 'desc'
diff --git a/baselines/fedmeta/LICENSE b/baselines/fedmeta/LICENSE
new file mode 100644
index 000000000000..d64569567334
--- /dev/null
+++ b/baselines/fedmeta/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/baselines/fedmeta/README.md b/baselines/fedmeta/README.md
new file mode 100644
index 000000000000..a1ed982f8bf2
--- /dev/null
+++ b/baselines/fedmeta/README.md
@@ -0,0 +1,136 @@
+---
+title: Federated Meta-Learning with Fast Convergence and Efficient Communication
+url: https://arxiv.org/abs/1802.07876
+labels: [meta learning, maml, meta-sgd, personalization]
+dataset: [FEMNIST, SHAKESPEARE]
+---
+
+# FedMeta: Federated Meta-Learning with Fast Convergence and Efficient Communication
+
+**Paper:** [arxiv.org/abs/1802.07876](https://arxiv.org/abs/1802.07876)
+
+**Authors:** Fei Chen, Mi Luo, Zhenhua Dong, Zhenguo Li, Xiuqiang He
+
+**Abstract:** Statistical and systematic challenges in collaboratively training machine learning models across distributed networks of mobile devices have been the bottlenecks in the real-world application of federated learning. In this work, we show that meta-learning is a natural choice to handle these issues, and propose a federated meta-learning framework FedMeta, where a parameterized algorithm (or meta-learner) is shared, instead of a global model in previous approaches. We conduct an extensive empirical evaluation on LEAF datasets and a real-world production dataset, and demonstrate that FedMeta achieves a reduction in required communication cost by 2.82-4.33 times with faster convergence, and an increase in accuracy by 3.23%-14.84% as compared to Federated Averaging (FedAvg) which is a leading optimization algorithm in federated learning. Moreover, FedMeta preserves user privacy since only the parameterized algorithm is transmitted between mobile devices and central servers, and no raw data is collected onto the servers.
+
+
+## About this baseline
+
+**What’s implemented:** We reimplemented the experiments from the paper 'FedMeta: Federated Meta-Learning with Fast Convergence and Efficient Communication' by Fei Chen (2018). which proposed the FedMeta(MAML & Meta-SGD) algorithm. Specifically, we replicate the results from Table 2 and Figure 2 of the paper.
+
+**Datasets:** FEMNIST and SHAKESPEARE from Leaf Federated Learning Dataset
+
+**Hardware Setup:** These experiments were run on a machine with 16 CPU threads and 1 GPU(GeForce RTX 2080 Ti). **FedMeta experiment using the Shakespeare dataset required more computing power.** Out of Memory errors may occur with some clients, but federated learning can continue to operate. On a GPU with more VRAM (A6000 with 48GB) no clients failed.
+
+**Contributors:** Jinsoo Kim and Kangyoon Lee
+
+
+## Experimental Setup
+
+**Task:** A comparison task of four algorithms(FedAvg, FedAvg(Meta), FedMeta(MAML), FedMeta(Meta-SGD)) in the categories of Image Classification and next-word prediction.
+
+**Model:** This directory implements two models:
+* A two-layer CNN network as used in the FedMeta paper for Femnist (see `models/CNN_Network`).
+* A StackedLSTM model used in the FedMeta paper for Shakespeare (see `models/StackedLSTM`).
+
+**You can see more detail in Apendix.A of the paper**
+
+**Dataset:** This baseline includes the FEMNIST dataset and SHAKESPEARE. For data partitioning and sampling per client, we use the Leaf GitHub([LEAF: A Benchmark for Federated Settings](https://github.com/TalwalkarLab/leaf)). The data and client specifications used in this experiment are listed in the table below (Table 1 in the paper).
+
+**Shakespeare Dataset Issue:** In the FedMeta paper experiment, the Shakespeare dataset had 1126 users. However, due to a current bug, the number of users has decreased to 660 users. Therefore, we have only maintained the total number of data.
+
+| Dataset | #Clients | #Samples | #Classes | #Partition Clients | #Partition Dataset |
+|:-----------:|:--------:|:--------:|:--------:|:---------------------------------------------------------------:|:----------------------:|
+| FEMNIST | 1109 | 245,337 | 62 | Train Clients : 0.8 Valid Clients : 0.1, Test Clients : 0.1 | Sup : 0.2 Qry : 0.8 |
+| SHAKESPEARE | 138 | 646,697 | 80 | Train Clients : 0.8 Valid Clients : 0.1, Test Clients : 0.1 | Sup : 0.2 Qry : 0.8 |
+
+**The original specifications of the Leaf dataset can be found in the Leaf paper(_"LEAF: A Benchmark for Federated Settings"_).**
+
+**Training Hyperparameters:** The following table shows the main hyperparameters for this baseline with their default value (i.e. the value used if you run `python main.py algo=? data=?` directly)
+
+| Algorithm | Dataset | Clients per Round | Number of Rounds | Batch Size | Optimizer | Learning Rate(α, β) | Client Resources | Gradient Step |
+|:-----------------:|:--------------:|:-----------------:|:----------------:|:----------:|:---------:|:-------------------:|:---------------------------------------:|:-------------:|
+| FedAvg | FEMNIST SHAKESPEARE | 4 | 2000 400 | 10 | Adam | 0.0001 0.001 | {'num_cpus': 4.0, 'num_gpus': 0.25 } | - |
+| FedAvg(Meta) | FEMNIST SHAKESPEARE | 4 | 2000 400 | 10 | Adam | 0.0001 0.001 | {'num_cpus': 4.0, 'num_gpus': 0.25 } | - |
+| FedMeta(MAML) | FEMNIST SHAKESPEARE | 4 | 2000 400 | 10 | Adam | (0.001, 0.0001) (0.1, 0.01) | {'num_cpus': 4.0, 'num_gpus': 1.0 } | 5 1 |
+| FedMeta(Meta-SGD) | FEMNIST SHAKESPEARE | 4 | 2000 400 | 10 | Adam | (0.001, 0.0001) (0.1, 0.01) | {'num_cpus': 4.0, 'num_gpus': 1.0 } | 5 1 |
+
+
+## Environment Setup
+```bash
+#Environment Setup
+# Set python version
+pyenv install 3.10.6
+pyenv local 3.10.6
+
+# Tell poetry to use python 3.10
+poetry env use 3.10.6
+
+# install the base Poetry environment
+poetry install
+poetry shell
+```
+
+## Running the Experiments
+
+**Download Dataset:** Go [LEAF: A Benchmark for Federated Settings](https://github.com/TalwalkarLab/leaf) and Use the command below! You can download dataset (FEMNIST and SHAKESPEARE).
+```bash
+# clone LEAF repo
+git clone https://github.com/TalwalkarLab/leaf.git
+
+# navigate to data directory and then the dataset
+cd leaf/data/femnist
+#FEMNIST dataset Download command for these experiments
+./preprocess.sh -s niid --sf 0.3 -k 0 -t sample
+
+# navigate to data directory and then the dataset
+cd leaf/data/shakespeare
+#SHAKESEPEARE dataset Download command for these experiments
+./preprocess.sh -s niid --sf 0.16 -k 0 -t sample
+```
+
+*Run `./preprocess.sh` with a choice of the following tags*
+* `-s` := 'iid' to sample in an i.i.d. manner, or 'niid' to sample in a non-i.i.d. manner; more information on i.i.d. versus non-i.i.d. is included in the 'Notes' section
+* `--sf` := fraction of data to sample, written as a decimal; default is 0.1
+* `-k` := minimum number of samples per user
+* `-t` := 'user' to partition users into train-test groups, or 'sample' to partition each user's samples into train-test groups
+
+More detailed tag information can be found on Leaf GitHub.
+
+****Start experiments****
+```bash
+# FedAvg + Femnist Dataset
+python -m fedmeta.main algo=fedavg data=femnist path=(your leaf dataset path)/leaf/data/femnist/data
+
+# FedAvg(Meta) + Femnist Dataset
+python -m fedmeta.main algo=fedavg_meta data=femnist path=./leaf/data/femnist/data
+
+# FedMeta(MAML) + Femnist Dataset
+python -m fedmeta.main algo=fedmeta_maml data=femnist path=./leaf/data/femnist/data
+
+# FedMeta(Meta-SGD) + Femnist Dataset
+python -m fedmeta.main algo=fedmeta_meta_sgd data=femnist path=./leaf/data/femnist/data
+
+
+
+#FedAvg + Shakespeare Dataset
+python -m fedmeta.main algo=fedavg data=shakespeare path=./leaf/data/shakespeare/data
+
+#FedAvg(Meta) + Shakespeare Dataset
+python -m fedmeta.main algo=fedavg_meta data=shakespeare path=./leaf/data/shakespeare/data
+
+#FedMeta(MAML) + Shakespeare Dataset
+python -m fedmeta.main algo=fedmeta_maml data=shakespeare path=./leaf/data/shakespeare/data
+
+#FedMeta(Meta-SGD) + Shakespeare Dataset
+python -m fedmeta.main algo=fedmeta_meta_sgd data=shakespeare path=./leaf/data/shakespeare/data
+
+```
+
+
+## Expected Results
+If you proceed with all of the above experiments, You can get a graph of your experiment results as shown below along that `./femnist or shakespeare/graph_params/result_graph.png`.
+
+| FEMNIST | SHAKESPEARE |
+|:-------------------------------------------:|:----------------------------------------------------:|
+| ![](_static/femnist_result_graph.png) | ![](_static/shakespeare_result_graph.png) |
diff --git a/baselines/fedmeta/_static/femnist_result_graph.png b/baselines/fedmeta/_static/femnist_result_graph.png
new file mode 100644
index 000000000000..935643b46f90
Binary files /dev/null and b/baselines/fedmeta/_static/femnist_result_graph.png differ
diff --git a/baselines/fedmeta/_static/shakespeare_result_graph.png b/baselines/fedmeta/_static/shakespeare_result_graph.png
new file mode 100644
index 000000000000..f23f529adf32
Binary files /dev/null and b/baselines/fedmeta/_static/shakespeare_result_graph.png differ
diff --git a/baselines/fedmeta/fedmeta/__init__.py b/baselines/fedmeta/fedmeta/__init__.py
new file mode 100644
index 000000000000..a5e567b59135
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/__init__.py
@@ -0,0 +1 @@
+"""Template baseline package."""
diff --git a/baselines/fedmeta/fedmeta/client.py b/baselines/fedmeta/fedmeta/client.py
new file mode 100644
index 000000000000..fb09773eebed
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/client.py
@@ -0,0 +1,185 @@
+"""Define your client class and a function to construct such clients."""
+
+from collections import OrderedDict
+from typing import Callable, Dict, List, Tuple
+
+import flwr as fl
+import torch
+import torch.nn
+from flwr.common.typing import NDArrays, Scalar
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+
+from fedmeta.models import test, test_meta, train, train_meta
+
+
+# pylint: disable=too-many-instance-attributes
+class FlowerClient(fl.client.NumPyClient):
+ """Standard Flower client for Local training."""
+
+ # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ net: torch.nn.Module,
+ trainloaders: DataLoader,
+ valloaders: DataLoader,
+ cid: str,
+ device: torch.device,
+ num_epochs: int,
+ learning_rate: float,
+ gradient_step: int,
+ ):
+ self.net = net
+ self.trainloaders = trainloaders
+ self.valloaders = valloaders
+ self.cid = int(cid)
+ self.device = device
+ self.num_epochs = num_epochs
+ self.learning_rate = learning_rate
+ self.gradient_step = gradient_step
+
+ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+ """Return the parameters of the current net."""
+ return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
+
+ def set_parameters(self, parameters: NDArrays) -> None:
+ """Change the parameters of the model using the given ones."""
+ params_dict = zip(self.net.state_dict().keys(), parameters)
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+ self.net.load_state_dict(state_dict, strict=True)
+
+ def fit( # type: ignore
+ self, parameters: NDArrays, config: Dict[str, Scalar]
+ ) -> Tuple[NDArrays, int, Dict]:
+ """Implement distributed fit function for a given client."""
+ self.set_parameters(parameters)
+ algo = config["algo"]
+
+ # Total number of data for Weighted Avg and Grad
+ total_len = len(self.trainloaders["qry"][self.cid].dataset) + len(
+ self.trainloaders["sup"][self.cid].dataset
+ )
+
+ # FedAvg & FedAvg(Meta) train basic Learning
+ if algo in ("fedavg", "fedavg_meta"):
+ loss = train(
+ self.net,
+ self.trainloaders["sup"][self.cid],
+ self.device,
+ epochs=self.num_epochs,
+ learning_rate=self.learning_rate,
+ )
+ return self.get_parameters({}), total_len, {"loss": loss}
+
+ # FedMeta(MAML) & FedMeta(Meta-SGD) train inner and outer loop
+ if algo in ("fedmeta_maml", "fedmeta_meta_sgd"):
+ alpha = config["alpha"]
+ loss, grads = train_meta( # type: ignore
+ self.net,
+ self.trainloaders["sup"][self.cid],
+ self.trainloaders["qry"][self.cid],
+ alpha,
+ self.device,
+ self.gradient_step,
+ )
+ return self.get_parameters({}), total_len, {"loss": loss, "grads": grads}
+ raise ValueError("Unsupported algorithm")
+
+ def evaluate( # type: ignore
+ self, parameters: NDArrays, config: Dict[str, Scalar]
+ ) -> Tuple[float, int, Dict]:
+ """Implement distributed evaluation for a given client."""
+ self.set_parameters(parameters)
+ algo = config["algo"]
+
+ # Total number of data for Weighted Avg and Grad
+ total_len = len(self.valloaders["qry"][self.cid].dataset) + len(
+ self.valloaders["sup"][self.cid].dataset
+ )
+
+ # FedAvg & FedAvg(Meta) train basic Learning
+ if algo in ("fedavg", "fedavg_meta"):
+ loss, accuracy = test(
+ self.net,
+ self.valloaders["sup"][self.cid],
+ self.valloaders["qry"][self.cid],
+ self.device,
+ algo=str(config["algo"]),
+ data=str(config["data"]),
+ learning_rate=self.learning_rate,
+ )
+ return float(loss), total_len, {"correct": accuracy, "loss": loss}
+
+ # FedMeta(MAML) & FedMeta(Meta-SGD) train inner and outer loop
+ if algo in ("fedmeta_maml", "fedmeta_meta_sgd"):
+ alpha = config["alpha"]
+ loss, accuracy = test_meta(
+ self.net,
+ self.valloaders["sup"][self.cid],
+ self.valloaders["qry"][self.cid],
+ alpha,
+ self.device,
+ self.gradient_step,
+ )
+ return float(loss), total_len, {"correct": float(accuracy), "loss": loss}
+ raise ValueError("Unsupported algorithm")
+
+
+# pylint: disable=too-many-arguments
+def gen_client_fn(
+ num_epochs: int,
+ trainloaders: List[DataLoader],
+ valloaders: List[DataLoader],
+ learning_rate: float,
+ model: DictConfig,
+ gradient_step: int,
+) -> Callable[[str], FlowerClient]:
+ """Generate the client function that creates the Flower Clients.
+
+ Parameters
+ ----------
+ num_epochs : int
+ The number of local epochs each client should run the training for before
+ sending it to the server.
+ trainloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset training partition
+ belonging to a particular client.
+ valloaders: List[DataLoader]
+ A list of DataLoaders, each pointing to the dataset validation partition
+ belonging to a particular client.
+ model: DictConfig
+ The global Model for Federated Learning.
+ learning_rate : float
+ The learning rate for the SGD optimizer of clients.
+ gradient_step : int
+ The gradient step for Meta Learning of clients.
+ FedAvg and FedAvg(Meta) is None
+
+ Returns
+ -------
+ Tuple[Callable[[str], FlowerClient], DataLoader]
+ A tuple containing the client function that creates Flower Clients and
+ the DataLoader that will be used for testing
+ """
+
+ def client_fn(cid: str) -> FlowerClient:
+ """Create a Flower client representing a single organization."""
+ # Load model
+ torch.manual_seed(42)
+ torch.cuda.manual_seed_all(42)
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ net = instantiate(model).to(device)
+
+ return FlowerClient(
+ net,
+ trainloaders,
+ valloaders,
+ cid,
+ device,
+ num_epochs,
+ learning_rate,
+ gradient_step,
+ )
+
+ return client_fn
diff --git a/baselines/fedmeta/fedmeta/conf/algo/fedavg.yaml b/baselines/fedmeta/fedmeta/conf/algo/fedavg.yaml
new file mode 100644
index 000000000000..df5a5b4b6b65
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/conf/algo/fedavg.yaml
@@ -0,0 +1,14 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+algo: fedavg
+femnist:
+ alpha: 0.0001
+ beta: 0
+
+shakespeare:
+ alpha: 0.001
+ beta: 0
+
diff --git a/baselines/fedmeta/fedmeta/conf/algo/fedavg_meta.yaml b/baselines/fedmeta/fedmeta/conf/algo/fedavg_meta.yaml
new file mode 100644
index 000000000000..928fd0a96cb9
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/conf/algo/fedavg_meta.yaml
@@ -0,0 +1,13 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+algo: fedavg_meta
+femnist:
+ alpha: 0.0001
+ beta: None
+
+shakespeare:
+ alpha: 0.001
+ beta: None
diff --git a/baselines/fedmeta/fedmeta/conf/algo/fedmeta_maml.yaml b/baselines/fedmeta/fedmeta/conf/algo/fedmeta_maml.yaml
new file mode 100644
index 000000000000..2c3c86df4edb
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/conf/algo/fedmeta_maml.yaml
@@ -0,0 +1,13 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+algo: fedmeta_maml
+femnist:
+ alpha: 0.001
+ beta: 0.0001
+
+shakespeare:
+ alpha: 0.1
+ beta: 0.01
diff --git a/baselines/fedmeta/fedmeta/conf/algo/fedmeta_meta_sgd.yaml b/baselines/fedmeta/fedmeta/conf/algo/fedmeta_meta_sgd.yaml
new file mode 100644
index 000000000000..cb72f1fe6e2c
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/conf/algo/fedmeta_meta_sgd.yaml
@@ -0,0 +1,13 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+algo: fedmeta_meta_sgd
+femnist:
+ alpha: 0.001
+ beta: 0.0001
+
+shakespeare:
+ alpha: 0.1
+ beta: 0.01
diff --git a/baselines/fedmeta/fedmeta/conf/config.yaml b/baselines/fedmeta/fedmeta/conf/config.yaml
new file mode 100644
index 000000000000..db1bfb11f169
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/conf/config.yaml
@@ -0,0 +1,21 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+path: ???
+num_epochs: 1
+clients_per_round: 4
+
+defaults:
+ - _self_
+ - algo: ???
+ - data: ???
+
+strategy:
+ _target_: fedmeta.strategy.FedMeta
+ fraction_fit: 0.00001
+ fraction_evaluate: 0.00001
+ min_fit_clients : ${clients_per_round}
+ min_evaluate_clients : ${clients_per_round}
+ min_available_clients : ${clients_per_round}
diff --git a/baselines/fedmeta/fedmeta/conf/data/femnist.yaml b/baselines/fedmeta/fedmeta/conf/data/femnist.yaml
new file mode 100644
index 000000000000..7f24eed1a951
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/conf/data/femnist.yaml
@@ -0,0 +1,17 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+model:
+ _target_: fedmeta.models.FemnistNetwork # model config
+
+client_resources:
+ num_cpus: 4
+ num_gpus: 0.25
+
+num_rounds: 2000
+data: femnist
+support_ratio: 0.2
+batch_size: 10
+gradient_step: 5
diff --git a/baselines/fedmeta/fedmeta/conf/data/shakespeare.yaml b/baselines/fedmeta/fedmeta/conf/data/shakespeare.yaml
new file mode 100644
index 000000000000..300e11da77f6
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/conf/data/shakespeare.yaml
@@ -0,0 +1,17 @@
+---
+# this is the config that will be loaded as default by main.py
+# Please follow the provided structure (this will ensuring all baseline follow
+# a similar configuration structure and hence be easy to customise)
+
+model:
+ _target_: fedmeta.models.StackedLSTM
+
+client_resources:
+ num_cpus: 4
+ num_gpus: 1.0
+
+num_rounds: 400
+data: shakespeare
+support_ratio: 0.2
+batch_size: 10
+gradient_step: 1
diff --git a/baselines/fedmeta/fedmeta/dataset.py b/baselines/fedmeta/fedmeta/dataset.py
new file mode 100644
index 000000000000..25ff5201bfdb
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/dataset.py
@@ -0,0 +1,234 @@
+"""Handle basic dataset creation.
+
+In case of PyTorch it should return dataloaders for your dataset (for both the clients
+and the server). If you are using a custom dataset class, this module is the place to
+define it. If your dataset requires to be downloaded (and this is not done
+automatically -- e.g. as it is the case for many dataset in TorchVision) and
+partitioned, please include all those functions and logic in the
+`dataset_preparation.py` module. You can use all those functions from functions/methods
+defined here of course.
+"""
+
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader, Dataset
+
+from fedmeta.dataset_preparation import (
+ _partition_data,
+ split_train_validation_test_clients,
+)
+from fedmeta.utils import letter_to_vec, word_to_indices
+
+
+class ShakespeareDataset(Dataset):
+ """
+ [LEAF: A Benchmark for Federated Settings](https://github.com/TalwalkarLab/leaf).
+
+ We imported the preprocessing method for the Shakespeare dataset from GitHub.
+
+ word_to_indices : returns a list of character indices
+ sentences_to_indices: converts an index to a one-hot vector of a given size.
+ letter_to_vec : returns one-hot representation of given letter
+
+ """
+
+ def __init__(self, data):
+ sentence, label = data["x"], data["y"]
+ sentences_to_indices = [word_to_indices(word) for word in sentence]
+ sentences_to_indices = np.array(sentences_to_indices)
+ self.sentences_to_indices = np.array(sentences_to_indices, dtype=np.int64)
+ self.labels = np.array(
+ [letter_to_vec(letter) for letter in label], dtype=np.int64
+ )
+
+ def __len__(self):
+ """Return the number of labels present in the dataset.
+
+ Returns
+ -------
+ int: The total number of labels.
+ """
+ return len(self.labels)
+
+ def __getitem__(self, index):
+ """Retrieve the data and its corresponding label at a given index.
+
+ Args:
+ index (int): The index of the data item to fetch.
+
+ Returns
+ -------
+ tuple: (data tensor, label tensor)
+ """
+ data, target = self.sentences_to_indices[index], self.labels[index]
+ return torch.tensor(data), torch.tensor(target)
+
+
+class FemnistDataset(Dataset):
+ """
+ [LEAF: A Benchmark for Federated Settings](https://github.com/TalwalkarLab/leaf).
+
+ We imported the preprocessing method for the Femnist dataset from GitHub.
+ """
+
+ def __init__(self, dataset, transform):
+ self.x = dataset["x"]
+ self.y = dataset["y"]
+ self.transform = transform
+
+ def __getitem__(self, index):
+ """Retrieve the input data and its corresponding label at a given index.
+
+ Args:
+ index (int): The index of the data item to fetch.
+
+ Returns
+ -------
+ tuple:
+ - input_data (torch.Tensor): Reshaped and optionally transformed data.
+ - target_data (int or torch.Tensor): Label for the input data.
+ """
+ input_data = np.array(self.x[index]).reshape(28, 28)
+ if self.transform:
+ input_data = self.transform(input_data)
+ target_data = self.y[index]
+ return input_data.to(torch.float32), target_data
+
+ def __len__(self):
+ """Return the number of labels present in the dataset.
+
+ Returns
+ -------
+ int: The total number of labels.
+ """
+ return len(self.y)
+
+
+def load_datasets(
+ config: DictConfig,
+ path: str,
+) -> Tuple[DataLoader, DataLoader, DataLoader]:
+ """Create the dataloaders to be fed into the model.
+
+ Parameters
+ ----------
+ config: DictConfig
+ data: float
+ Used data type
+ batch_size : int
+ The size of the batches to be fed into the model,
+ by default 10
+ support_ratio : float
+ The ratio of Support set for each client.(between 0 and 1)
+ by default 0.2
+ path : str
+ The path where the leaf dataset was downloaded
+
+ Returns
+ -------
+ Tuple[DataLoader, DataLoader, DataLoader]
+ """
+ dataset = _partition_data(
+ data_type=config.data, dir_path=path, support_ratio=config.support_ratio
+ )
+
+ # Client list : 0.8, 0.1, 0.1
+ clients_list = split_train_validation_test_clients(dataset[0]["users"])
+
+ trainloaders: Dict[str, List[DataLoader]] = {"sup": [], "qry": []}
+ valloaders: Dict[str, List[DataLoader]] = {"sup": [], "qry": []}
+ testloaders: Dict[str, List[DataLoader]] = {"sup": [], "qry": []}
+
+ data_type = config.data
+ if data_type == "femnist":
+ transform = transforms.Compose([transforms.ToTensor()])
+ for user in clients_list[0]:
+ trainloaders["sup"].append(
+ DataLoader(
+ FemnistDataset(dataset[0]["user_data"][user], transform),
+ batch_size=config.batch_size,
+ shuffle=True,
+ )
+ )
+ trainloaders["qry"].append(
+ DataLoader(
+ FemnistDataset(dataset[1]["user_data"][user], transform),
+ batch_size=config.batch_size,
+ )
+ )
+ for user in clients_list[1]:
+ valloaders["sup"].append(
+ DataLoader(
+ FemnistDataset(dataset[0]["user_data"][user], transform),
+ batch_size=config.batch_size,
+ )
+ )
+ valloaders["qry"].append(
+ DataLoader(
+ FemnistDataset(dataset[1]["user_data"][user], transform),
+ batch_size=config.batch_size,
+ )
+ )
+ for user in clients_list[2]:
+ testloaders["sup"].append(
+ DataLoader(
+ FemnistDataset(dataset[0]["user_data"][user], transform),
+ batch_size=config.batch_size,
+ )
+ )
+ testloaders["qry"].append(
+ DataLoader(
+ FemnistDataset(dataset[1]["user_data"][user], transform),
+ batch_size=config.batch_size,
+ )
+ )
+
+ elif data_type == "shakespeare":
+ for user in clients_list[0]:
+ trainloaders["sup"].append(
+ DataLoader(
+ ShakespeareDataset(dataset[0]["user_data"][user]),
+ batch_size=config.batch_size,
+ shuffle=True,
+ )
+ )
+ trainloaders["qry"].append(
+ DataLoader(
+ ShakespeareDataset(dataset[1]["user_data"][user]),
+ batch_size=config.batch_size,
+ )
+ )
+ for user in clients_list[1]:
+ valloaders["sup"].append(
+ DataLoader(
+ ShakespeareDataset(dataset[0]["user_data"][user]),
+ batch_size=config.batch_size,
+ shuffle=True,
+ )
+ )
+ valloaders["qry"].append(
+ DataLoader(
+ ShakespeareDataset(dataset[1]["user_data"][user]),
+ batch_size=config.batch_size,
+ )
+ )
+ for user in clients_list[2]:
+ testloaders["sup"].append(
+ DataLoader(
+ ShakespeareDataset(dataset[0]["user_data"][user]),
+ batch_size=config.batch_size,
+ shuffle=True,
+ )
+ )
+ testloaders["qry"].append(
+ DataLoader(
+ ShakespeareDataset(dataset[1]["user_data"][user]),
+ batch_size=config.batch_size,
+ )
+ )
+
+ return trainloaders, valloaders, testloaders
diff --git a/baselines/fedmeta/fedmeta/dataset_preparation.py b/baselines/fedmeta/fedmeta/dataset_preparation.py
new file mode 100644
index 000000000000..c139cdf86d69
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/dataset_preparation.py
@@ -0,0 +1,190 @@
+"""Handle the dataset partitioning and (optionally) complex downloads.
+
+Please add here all the necessary logic to either download, uncompress, pre/post-process
+your dataset (or all of the above). If the desired way of running your baseline is to
+first download the dataset and partition it and then run the experiments, please
+uncomment the lines below and tell us in the README.md (see the "Running the Experiment"
+block) that this file should be executed first.
+"""
+import json
+import os
+from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Tuple
+
+import numpy as np
+from sklearn.model_selection import train_test_split
+
+
+def _read_dataset(path: str) -> Tuple[List, DefaultDict, List]:
+ """Read (if necessary) and returns the leaf dataset.
+
+ Parameters
+ ----------
+ path : str
+ The path where the leaf dataset was downloaded
+
+ Returns
+ -------
+ Tuple[user, data[x,y], num_total_data]
+ The dataset for training and the dataset for testing.
+ """
+ users = []
+ data: DefaultDict[str, Any] = defaultdict(lambda: None)
+ num_example = []
+
+ files = [f for f in os.listdir(path) if f.endswith(".json")]
+
+ for file_name in files:
+ with open(f"{path}/{file_name}") as file:
+ dataset = json.load(file)
+ users.extend(dataset["users"])
+ data.update(dataset["user_data"])
+ num_example.extend(dataset["num_samples"])
+
+ users = sorted(data.keys())
+ return users, data, num_example
+
+
+def support_query_split(
+ data,
+ label,
+ support_ratio: float,
+) -> Tuple[List, List, List, List]:
+ """Separate support set and query set.
+
+ Parameters
+ ----------
+ data: DefaultDict,
+ Raw all Datasets
+ label: List,
+ Raw all Labels
+ support_ratio : float
+ The ratio of Support set for each client.(between 0 and 1)
+ by default 0.2
+
+ Returns
+ -------
+ Tuple[List, List, List, List]
+ Support set and query set classification of data and labels
+ """
+ x_train, x_test, y_train, y_test = train_test_split(
+ data, label, train_size=support_ratio, stratify=label, random_state=42
+ )
+
+ return x_train, x_test, y_train, y_test
+
+
+def split_train_validation_test_clients(
+ clients: List,
+ train_rate: float = 0.8,
+ val_rate: float = 0.1,
+) -> Tuple[List[str], List[str], List[str]]:
+ """Classification of all clients into train, valid, and test.
+
+ Parameters
+ ----------
+ clients: List,
+ Full list of clients for the sampled leaf dataset.
+ train_rate: float, optional
+ The ratio of training clients to total clients
+ by default 0.8
+ val_rate: float, optional
+ The ratio of validation clients to total clients
+ by default 0.1
+
+ Returns
+ -------
+ Tuple[List, List, List]
+ List of each train client, valid client, and test client
+ """
+ np.random.seed(42)
+ train_rate = int(train_rate * len(clients))
+ val_rate = int(val_rate * len(clients))
+
+ index = np.random.permutation(len(clients))
+ trans_numpy = np.asarray(clients)
+ train_clients = trans_numpy[index[:train_rate]].tolist()
+ val_clients = trans_numpy[index[train_rate : train_rate + val_rate]].tolist()
+ test_clients = trans_numpy[index[train_rate + val_rate :]].tolist()
+
+ return train_clients, val_clients, test_clients
+
+
+# pylint: disable=too-many-locals
+def _partition_data(
+ data_type: str,
+ dir_path: str,
+ support_ratio: float,
+) -> Tuple[Dict, Dict]:
+ """Classification of support sets and query sets by client.
+
+ Parameters
+ ----------
+ data_type: str,
+ The type of femnist for classification or shakespeare for regression
+ dir_path: str,
+ The path where the leaf dataset was downloaded
+ support_ratio: float,
+ The ratio of Support set for each client.(between 0 and 1)
+ by default 0.2
+
+ Returns
+ -------
+ Tuple[Dict, Dict]
+ Return support set and query set for total data
+ """
+ train_path = f"{dir_path}/train"
+ test_path = f"{dir_path}/test"
+
+ train_users, train_data, _ = _read_dataset(train_path)
+ _, test_data, _ = _read_dataset(test_path)
+
+ all_dataset: Dict[str, Any] = {"users": [], "user_data": {}, "num_samples": []}
+ support_dataset: Dict[str, Any] = {"users": [], "user_data": {}, "num_samples": []}
+ query_dataset: Dict[str, Any] = {"users": [], "user_data": {}, "num_samples": []}
+
+ for user in train_users:
+ all_x = np.asarray(train_data[user]["x"] + test_data[user]["x"])
+ all_y = np.asarray(train_data[user]["y"] + test_data[user]["y"])
+
+ if data_type == "femnist":
+ unique, counts = np.unique(all_y, return_counts=True)
+ class_counts = dict(zip(unique, counts))
+
+ # Find classes with only one sample
+ classes_to_remove = [
+ cls for cls, count in class_counts.items() if count == 1
+ ]
+
+ # Filter out the samples of those classes
+ mask = np.isin(all_y, classes_to_remove, invert=True)
+
+ all_x = all_x[mask]
+ all_y = all_y[mask]
+
+ # Client filtering for support set and query set classification
+ try:
+ sup_x, qry_x, sup_y, qry_y = support_query_split(
+ all_x, all_y, support_ratio
+ )
+ except Exception: # pylint: disable=broad-except
+ continue
+
+ elif data_type == "shakespeare":
+ sup_x, qry_x, sup_y, qry_y = train_test_split(
+ all_x, all_y, train_size=support_ratio, random_state=42
+ )
+
+ all_dataset["users"].append(user)
+ all_dataset["user_data"][user] = {"x": all_x.tolist(), "y": all_y.tolist()}
+ all_dataset["num_samples"].append(len(all_y.tolist()))
+
+ support_dataset["users"].append(user)
+ support_dataset["user_data"][user] = {"x": sup_x, "y": sup_y}
+ support_dataset["num_samples"].append(len(sup_y))
+
+ query_dataset["users"].append(user)
+ query_dataset["user_data"][user] = {"x": qry_x, "y": qry_y}
+ query_dataset["num_samples"].append(len(qry_y))
+
+ return support_dataset, query_dataset
diff --git a/baselines/fedmeta/fedmeta/fedmeta_client_manager.py b/baselines/fedmeta/fedmeta/fedmeta_client_manager.py
new file mode 100644
index 000000000000..098922b92215
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/fedmeta_client_manager.py
@@ -0,0 +1,67 @@
+"""Handles clients that are sampled every round.
+
+In a FedMeta experiment, there is a train and a test client. So we modified the manager
+to sample from each list each round.
+"""
+
+import random
+from logging import INFO
+from typing import List, Optional
+
+from flwr.common.logger import log
+from flwr.server.client_manager import SimpleClientManager
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.criterion import Criterion
+
+
+class FedmetaClientManager(SimpleClientManager):
+ """In the fit phase, clients must be sampled from the training client list.
+
+ And in the evaluate stage, clients must be sampled from the validation client list.
+ So we modify 'fedmeta_client_manager' to sample clients from [cid: List] for each
+ list.
+ """
+
+ def __init__(self, valid_client, **kwargs):
+ super().__init__(**kwargs)
+ self.valid_client = valid_client
+
+ # pylint: disable=too-many-arguments
+ def sample( # pylint: disable=arguments-differ
+ self,
+ num_clients: int,
+ min_num_clients: Optional[int] = None,
+ criterion: Optional[Criterion] = None,
+ server_round: Optional[int] = None,
+ step: Optional[str] = None,
+ ) -> List[ClientProxy]:
+ """Sample a number of Flower ClientProxy instances."""
+ # Block until at least num_clients are connected.
+ if min_num_clients is None:
+ min_num_clients = num_clients
+ self.wait_for(min_num_clients)
+
+ # Sample clients which meet the criterion
+ if step == "evaluate":
+ available_cids = [str(result) for result in range(0, self.valid_client)]
+ else:
+ available_cids = list(self.clients)
+
+ if criterion is not None:
+ available_cids = [
+ cid for cid in available_cids if criterion.select(self.clients[cid])
+ ]
+
+ if num_clients > len(available_cids):
+ log(
+ INFO,
+ "Sampling failed: number of available clients"
+ " (%s) is less than number of requested clients (%s).",
+ len(available_cids),
+ num_clients,
+ )
+ return []
+ if server_round is not None:
+ random.seed(server_round)
+ sampled_cids = random.sample(available_cids, num_clients)
+ return [self.clients[cid] for cid in sampled_cids]
diff --git a/baselines/fedmeta/fedmeta/main.py b/baselines/fedmeta/fedmeta/main.py
new file mode 100644
index 000000000000..e43ad94a3089
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/main.py
@@ -0,0 +1,100 @@
+"""Create and connect the building blocks for your experiments; start the simulation.
+
+It includes processioning the dataset, instantiate strategy, specify how the global
+model is going to be evaluated, etc. At the end, this script saves the results.
+"""
+
+
+import flwr as fl
+import hydra
+from hydra.core.hydra_config import HydraConfig
+from hydra.utils import instantiate
+from omegaconf import DictConfig, OmegaConf
+
+import fedmeta.client as client
+from fedmeta.dataset import load_datasets
+from fedmeta.fedmeta_client_manager import FedmetaClientManager
+from fedmeta.strategy import weighted_average
+from fedmeta.utils import plot_from_pkl, save_graph_params
+
+
+@hydra.main(config_path="conf", config_name="config", version_base=None)
+def main(cfg: DictConfig) -> None:
+ """Run the baseline.
+
+ Parameters
+ ----------
+ cfg : DictConfig
+ An omegaconf object that stores the hydra config.
+
+ algo : FedAvg, FedAvg(Meta), FedMeta(MAML), FedMeta(Meta-SGD)
+ data : Femnist, Shakespeare
+ """
+ # print config structured as YAML
+ print(OmegaConf.to_yaml(cfg))
+
+ # partition dataset and get dataloaders
+ trainloaders, valloaders, _ = load_datasets(config=cfg.data, path=cfg.path)
+
+ # prepare function that will be used to spawn each client
+ client_fn = client.gen_client_fn(
+ num_epochs=cfg.num_epochs,
+ trainloaders=trainloaders,
+ valloaders=valloaders,
+ learning_rate=cfg.algo[cfg.data.data].alpha,
+ model=cfg.data.model,
+ gradient_step=cfg.data.gradient_step,
+ )
+
+ # prepare strategy function
+ strategy = instantiate(
+ cfg.strategy,
+ evaluate_metrics_aggregation_fn=weighted_average,
+ alpha=cfg.algo[cfg.data.data].alpha,
+ beta=cfg.algo[cfg.data.data].beta,
+ data=cfg.data.data,
+ algo=cfg.algo.algo,
+ )
+
+ # Start Simulation
+ history = fl.simulation.start_simulation(
+ client_fn=client_fn,
+ num_clients=len(trainloaders["sup"]),
+ config=fl.server.ServerConfig(num_rounds=cfg.data.num_rounds),
+ client_resources={
+ "num_cpus": cfg.data.client_resources.num_cpus,
+ "num_gpus": cfg.data.client_resources.num_gpus,
+ },
+ client_manager=FedmetaClientManager(valid_client=len(valloaders["qry"])),
+ strategy=strategy,
+ )
+
+ # 6. Save your results
+ # Here you can save the `history` returned by the simulation and include
+ # also other buffers, statistics, info needed to be saved in order to later
+ # on generate the plots you provide in the README.md. You can for instance
+ # access elements that belong to the strategy for example:
+ # data = strategy.get_my_custom_data() -- assuming you have such method defined.
+ # Hydra will generate for you a directory each time you run the code. You
+ # can retrieve the path to that directory with this:
+ # save_path = HydraConfig.get().runtime.output_dir
+
+ print("................")
+ print(history)
+ output_path = HydraConfig.get().runtime.output_dir
+
+ data_params = {
+ "algo": cfg.algo.algo,
+ "data": cfg.data.data,
+ "loss": history.losses_distributed,
+ "accuracy": history.metrics_distributed,
+ "path": output_path,
+ }
+
+ save_graph_params(data_params)
+ plot_from_pkl(directory=output_path)
+ print("................")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/baselines/fedmeta/fedmeta/models.py b/baselines/fedmeta/fedmeta/models.py
new file mode 100644
index 000000000000..f065ae6c372b
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/models.py
@@ -0,0 +1,445 @@
+"""Define our models, and training and eval functions.
+
+If your model is 100% off-the-shelf (e.g. directly from torchvision without requiring
+modifications) you might be better off instantiating your model directly from the Hydra
+config. In this way, swapping your model for another one can be done without changing
+the python code at all
+"""
+
+from copy import deepcopy
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+
+class StackedLSTM(nn.Module):
+ """StackedLSTM architecture.
+
+ As described in Fei Chen 2018 paper :
+
+ [FedMeta: Federated Meta-Learning with Fast Convergence and Efficient Communication]
+ (https://arxiv.org/abs/1802.07876)
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ self.embedding = nn.Embedding(80, 8)
+ self.lstm = nn.LSTM(8, 256, num_layers=2, dropout=0.5, batch_first=True)
+ self.fully_ = nn.Linear(256, 80)
+
+ def forward(self, text):
+ """Forward pass of the StackedLSTM.
+
+ Parameters
+ ----------
+ text : torch.Tensor
+ Input Tensor that will pass through the network
+
+ Returns
+ -------
+ torch.Tensor
+ The resulting Tensor after it has passed through the network
+ """
+ embedded = self.embedding(text)
+ self.lstm.flatten_parameters()
+ lstm_out, _ = self.lstm(embedded)
+ final_output = self.fully_(lstm_out[:, -1, :])
+ return final_output
+
+
+class FemnistNetwork(nn.Module):
+ """Convolutional Neural Network architecture.
+
+ As described in Fei Chen 2018 paper :
+
+ [FedMeta: Federated Meta-Learning with Fast Convergence and Efficient Communication]
+ (https://arxiv.org/abs/1802.07876)
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)
+ self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))
+ self.conv2 = nn.Conv2d(
+ in_channels=32, out_channels=64, kernel_size=5, padding=2
+ )
+ self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))
+ self.linear1 = nn.Linear(7 * 7 * 64, 2048)
+ self.linear2 = nn.Linear(2048, 62)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass of the CNN.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input Tensor that will pass through the network
+
+ Returns
+ -------
+ torch.Tensor
+ The resulting Tensor after it has passed through the network
+ """
+ x = torch.relu(self.conv1(x))
+ x = self.maxpool1(x)
+ x = torch.relu(self.conv2(x))
+ x = self.maxpool2(x)
+ x = torch.flatten(x, start_dim=1)
+ x = torch.relu((self.linear1(x)))
+ x = self.linear2(x)
+ return x
+
+
+# pylint: disable=too-many-arguments
+def train(
+ net: nn.Module,
+ trainloader: DataLoader,
+ device: torch.device,
+ epochs: int,
+ learning_rate: float,
+) -> Tuple[float]:
+ """Train the network on the training set.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to train.
+ trainloader : DataLoader
+ The DataLoader containing the data to train the network on.
+ testloader : DataLoader
+ The DataLoader containing the data to test the network on.
+ device : torch.device
+ The device on which the model should be trained, either 'cpu' or 'cuda'.
+ epochs : int
+ The number of epochs the model should be trained for.
+ learning_rate : float
+ The learning rate for the optimizer.
+
+ Returns
+ -------
+ nn.Module
+ The model that has been trained for one epoch.
+ loss
+ The Loss that bas been trained for one epoch
+ """
+ criterion = torch.nn.CrossEntropyLoss()
+ optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.001)
+ net.train()
+ for _ in range(epochs):
+ net, loss = _train_one_epoch(net, trainloader, device, criterion, optimizer)
+ return loss
+
+
+def _train_one_epoch(
+ net: nn.Module,
+ trainloader: DataLoader,
+ device: torch.device,
+ criterion: torch.nn.CrossEntropyLoss,
+ optimizer: torch.optim.Adam,
+) -> nn.Module:
+ """Train for one epoch.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to train.
+ trainloader : DataLoader
+ The DataLoader containing the data to train the network on.
+ device : torch.device
+ The device on which the model should be trained, either 'cpu' or 'cuda'.
+ criterion : torch.nn.CrossEntropyLoss
+ The loss function to use for training
+ optimizer : torch.optim.Adam
+ The optimizer to use for training
+
+ Returns
+ -------
+ nn.Module
+ The model that has been trained for one epoch.
+ total_loss
+ The Loss that has been trained for one epoch.
+ """
+ total_loss = 0.0
+
+ for images, labels in trainloader:
+ images, labels = images.to(device), labels.to(device)
+ optimizer.zero_grad()
+ loss = criterion(net(images), labels)
+ total_loss += loss.item() * labels.size(0)
+ loss.backward()
+ optimizer.step()
+ total_loss = total_loss / len(trainloader.dataset)
+ return net, total_loss
+
+
+# pylint: disable=too-many-locals
+def test(
+ net: nn.Module,
+ trainloader: DataLoader,
+ testloader: DataLoader,
+ device: torch.device,
+ algo: str,
+ data: str,
+ learning_rate: float,
+) -> Tuple[float, float]:
+ """Evaluate the network on the entire test set.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to test.
+ trainloader: DataLoader,
+ The DataLoader containing the data to train the network on.
+ testloader : DataLoader
+ The DataLoader containing the data to test the network on.
+ device : torch.device
+ The device on which the model should be tested, either 'cpu' or 'cuda'.
+ algo: str
+ The Algorithm of Federated Learning
+ data: str
+ The training data type of Federated Learning
+ learning_rate: float
+ The learning rate for the optimizer.
+
+ Returns
+ -------
+ Tuple[float, float]
+ The loss and the accuracy of the input model on the given data.
+ """
+ criterion = torch.nn.CrossEntropyLoss()
+ total_loss = 0.0
+ if algo == "fedavg_meta":
+ optimizer = torch.optim.Adam(
+ net.parameters(), lr=learning_rate, weight_decay=0.001
+ )
+ net.train()
+ optimizer.zero_grad()
+ if data == "femnist":
+ for images, labels in trainloader:
+ images, labels = images.to(device), labels.to(device)
+ loss = criterion(net(images), labels)
+ loss.backward()
+ total_loss += loss * labels.size(0)
+ total_loss = total_loss / len(trainloader.dataset)
+ optimizer.step()
+
+ elif data == "shakespeare":
+ for images, labels in trainloader:
+ images, labels = images.to(device), labels.to(device)
+ loss = criterion(net(images), labels)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ correct, total, loss = 0, 0, 0.0
+ net.eval()
+ with torch.no_grad():
+ for images, labels in testloader:
+ images, labels = images.to(device), labels.to(device)
+ outputs = net(images)
+ loss += criterion(outputs, labels).item() * labels.size(0)
+ _, predicted = torch.max(outputs.data, 1)
+ total += labels.size(0)
+ correct += (predicted == labels).sum().item()
+ if len(testloader.dataset) == 0:
+ raise ValueError("Testloader can't be 0, exiting...")
+ loss /= len(testloader.dataset)
+ accuracy = correct / total
+ return loss, accuracy
+
+
+def train_meta(
+ net: nn.Module,
+ supportloader: DataLoader,
+ queryloader: DataLoader,
+ alpha: torch.nn.ParameterList,
+ device: torch.device,
+ gradient_step: int,
+) -> Tuple[float, List]:
+ """Train the network on the training set.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to train.
+ supportloader : DataLoader
+ The DataLoader containing the data to inner loop train the network on.
+ queryloader : DataLoader
+ The DataLoader containing the data to outer loop train the network on.
+ alpha : torch.nn.ParameterList
+ The learning rate for the optimizer.
+ device : torch.device
+ The device on which the model should be trained, either 'cpu' or 'cuda'.
+ gradient_step : int
+ The number of inner loop learning
+
+ Returns
+ -------
+ total_loss
+ The Loss that has been trained for one epoch.
+ grads
+ The gradients that has been trained for one epoch.
+ """
+ criterion = torch.nn.CrossEntropyLoss()
+ for _ in range(1):
+ loss, grads = _train_meta_one_epoch(
+ net, supportloader, queryloader, alpha, criterion, device, gradient_step
+ )
+ return loss, grads
+
+
+# pylint: disable=too-many-locals
+def _train_meta_one_epoch(
+ net: nn.Module,
+ supportloader: DataLoader,
+ queryloader: DataLoader,
+ alpha: torch.nn.ParameterList,
+ criterion: torch.nn.CrossEntropyLoss,
+ device: torch.device,
+ gradient_step: int,
+) -> Tuple[float, List]:
+ """Train for one epoch.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to train.
+ supportloader : DataLoader
+ The DataLoader containing the data to inner loop train the network on.
+ queryloader : DataLoader
+ The DataLoader containing the data to outer loop train the network on.
+ alpha : torch.nn.ParameterList
+ The learning rate for the optimizer.
+ criterion : torch.nn.CrossEntropyLoss
+ The loss function to use for training
+ device : torch.device
+ The device on which the model should be trained, either 'cpu' or 'cuda'.
+ gradient_step : int
+ The number of inner loop learning
+
+ Returns
+ -------
+ total_loss
+ The Loss that has been trained for one epoch.
+ grads
+ The gradients that has been trained for one epoch.
+ """
+ num_adaptation_steps = gradient_step
+ train_net = deepcopy(net)
+ alpha = [alpha.to(device) for alpha in alpha]
+ train_net.train()
+ for _ in range(num_adaptation_steps):
+ loss_sum = 0.0
+ sup_num_sample = []
+ sup_total_loss = []
+ for images, labels in supportloader:
+ images, labels = images.to(device), labels.to(device)
+ loss = criterion(train_net(images), labels)
+ loss_sum += loss * labels.size(0)
+ sup_num_sample.append(labels.size(0))
+ sup_total_loss.append(loss * labels.size(0))
+ grads = torch.autograd.grad(
+ loss, list(train_net.parameters()), create_graph=True, retain_graph=True
+ )
+
+ for param, grad_, alphas in zip(train_net.parameters(), grads, alpha):
+ param.data = param.data - alphas * grad_
+
+ for param in train_net.parameters():
+ if param.grad is not None:
+ param.grad.zero_()
+
+ qry_total_loss = []
+ qry_num_sample = []
+ loss_sum = 0.0
+ for images, labels in queryloader:
+ images, labels = images.to(device), labels.to(device)
+ loss = criterion(train_net(images), labels)
+ loss_sum += loss * labels.size(0)
+ qry_num_sample.append(labels.size(0))
+ qry_total_loss.append(loss.item())
+ loss_sum = loss_sum / sum(qry_num_sample)
+ grads = torch.autograd.grad(loss_sum, list(train_net.parameters()))
+
+ for param in train_net.parameters():
+ if param.grad is not None:
+ param.grad.zero_()
+
+ grads = [grad_.cpu().numpy() for grad_ in grads]
+ loss = sum(sup_total_loss) / sum(sup_num_sample)
+ return loss, grads
+
+
+def test_meta(
+ net: nn.Module,
+ supportloader: DataLoader,
+ queryloader: DataLoader,
+ alpha: torch.nn.ParameterList,
+ device: torch.device,
+ gradient_step: int,
+) -> Tuple[float, float]:
+ """Evaluate the network on the entire test set.
+
+ Parameters
+ ----------
+ net : nn.Module
+ The neural network to test.
+ supportloader : DataLoader
+ The DataLoader containing the data to test the network on.
+ queryloader : DataLoader
+ The DataLoader containing the data to test the network on.
+ alpha : torch.nn.ParameterList
+ The learning rate for the optimizer.
+ device : torch.device
+ The device on which the model should be tested, either 'cpu' or 'cuda'.
+ gradient_step : int
+ The number of inner loop learning
+
+ Returns
+ -------
+ Tuple[float, float]
+ The loss and the accuracy of the input model on the given data.
+ """
+ criterion = torch.nn.CrossEntropyLoss()
+ test_net = deepcopy(net)
+ num_adaptation_steps = gradient_step
+ alpha = [alpha_tensor.to(device) for alpha_tensor in alpha]
+ test_net.train()
+ for _ in range(num_adaptation_steps):
+ loss_sum = 0.0
+ sup_num_sample = []
+ sup_total_loss = []
+ for images, labels in supportloader:
+ images, labels = images.to(device), labels.to(device)
+ loss = criterion(test_net(images), labels)
+ loss_sum += loss * labels.size(0)
+ sup_num_sample.append(labels.size(0))
+ sup_total_loss.append(loss)
+ grads = torch.autograd.grad(
+ loss, list(test_net.parameters()), create_graph=True, retain_graph=True
+ )
+
+ for param, grad_, alphas in zip(test_net.parameters(), grads, alpha):
+ param.data -= alphas * grad_
+
+ for param in test_net.parameters():
+ if param.grad is not None:
+ param.grad.zero_()
+
+ test_net.eval()
+ correct, total, loss = 0, 0, 0.0
+ for images, labels in queryloader:
+ images, labels = images.to(device), labels.to(device)
+ outputs = test_net(images)
+ loss += criterion(outputs, labels).item() * labels.size(0)
+ _, predicted = torch.max(outputs.data, 1)
+ total += labels.size(0)
+ correct += (predicted == labels).sum().item()
+ if len(queryloader.dataset) == 0:
+ raise ValueError("Testloader can't be 0, exiting...")
+ loss = loss / total
+ accuracy = correct / total
+ return loss, accuracy
diff --git a/baselines/fedmeta/fedmeta/server.py b/baselines/fedmeta/fedmeta/server.py
new file mode 100644
index 000000000000..b24928de48b3
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/server.py
@@ -0,0 +1 @@
+"""Flower Server."""
diff --git a/baselines/fedmeta/fedmeta/strategy.py b/baselines/fedmeta/fedmeta/strategy.py
new file mode 100644
index 000000000000..7947938116e9
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/strategy.py
@@ -0,0 +1,333 @@
+"""Optionally define a custom strategy.
+
+Needed only when the strategy is not yet implemented in Flower or because you want to
+extend or modify the functionality of an existing strategy.
+"""
+from collections import OrderedDict
+from logging import WARNING
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from flwr.common import (
+ EvaluateIns,
+ EvaluateRes,
+ FitIns,
+ FitRes,
+ Metrics,
+ NDArrays,
+ Parameters,
+ Scalar,
+ ndarrays_to_parameters,
+ parameters_to_ndarrays,
+)
+from flwr.common.logger import log
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.strategy import FedAvg
+from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
+
+from fedmeta.models import FemnistNetwork, StackedLSTM
+from fedmeta.utils import update_ema
+
+
+# pylint: disable=too-many-arguments
+def fedmeta_update_meta_sgd(
+ net: torch.nn.Module,
+ alpha: torch.nn.ParameterList,
+ beta: float,
+ weights_results: NDArrays,
+ gradients_aggregated: NDArrays,
+ weight_decay: float,
+) -> Tuple[NDArrays, torch.nn.ParameterList]:
+ """Update model parameters for FedMeta(Meta-SGD).
+
+ Parameters
+ ----------
+ net : torch.nn.Module
+ The list of metrics to aggregate.
+ alpha : torch.nn.ParameterList
+ alpha is the learning rate. it is updated with parameters in FedMeta (Meta-SGD).
+ beta : float
+ beta is the learning rate for updating parameters and alpha on the server.
+ weights_results : List[Tuple[NDArrays, int]]
+ These are the global model parameters for the current round.
+ gradients_aggregated : List[Tuple[NDArrays, int]]
+ Weighted average of the gradient in the current round.
+ WD : float
+ The weight decay for Adam optimizer
+
+ Returns
+ -------
+ weights_prime : List[Tuple[NDArrays, int]]
+ These are updated parameters.
+ alpha : torch.nn.ParameterLis
+ These are updated alpha.
+ """
+ params_dict = zip(net.state_dict().keys(), weights_results)
+ state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+ net.load_state_dict(state_dict, strict=True)
+ optimizer = torch.optim.Adam(
+ list(net.parameters()) + list(alpha), lr=beta, weight_decay=weight_decay
+ )
+ for params, grad_ins, alphas in zip(net.parameters(), gradients_aggregated, alpha):
+ params.grad = torch.tensor(grad_ins).to(params.dtype)
+ alphas.grad = torch.tensor(grad_ins).to(params.dtype)
+ optimizer.step()
+ optimizer.zero_grad()
+ weights_prime = [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+ return weights_prime, alpha
+
+
+def fedmeta_update_maml(
+ net: torch.nn.Module,
+ beta: float,
+ weights_results: NDArrays,
+ gradients_aggregated: NDArrays,
+ weight_decay: float,
+) -> NDArrays:
+ """Update model parameters for FedMeta(Meta-SGD).
+
+ Parameters
+ ----------
+ net : torch.nn.Module
+ The list of metrics to aggregate.
+ beta : float
+ beta is the learning rate for updating parameters on the server.
+ weights_results : List[Tuple[NDArrays, int]]
+ These are the global model parameters for the current round.
+ gradients_aggregated : List[Tuple[NDArrays, int]]
+ Weighted average of the gradient in the current round.
+ WD : float
+ The weight decay for Adam optimizer
+
+ Returns
+ -------
+ weights_prime : List[Tuple[NDArrays, int]]
+ These are updated parameters.
+ """
+ params_dict = zip(net.state_dict().keys(), weights_results)
+ state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+ net.load_state_dict(state_dict, strict=True)
+ optimizer = torch.optim.Adam(
+ list(net.parameters()), lr=beta, weight_decay=weight_decay
+ )
+ for params, grad_ins in zip(net.parameters(), gradients_aggregated):
+ params.grad = torch.tensor(grad_ins).to(params.dtype)
+ optimizer.step()
+ optimizer.zero_grad()
+ weights_prime = [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+ return weights_prime
+
+
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+ """Aggregate using a weighted average during evaluation.
+
+ Parameters
+ ----------
+ metrics : List[Tuple[int, Metrics]]
+ The list of metrics to aggregate.
+
+ Returns
+ -------
+ Metrics
+ The weighted average metric.
+ """
+ # Multiply accuracy of each client by number of examples used
+ correct = [num_examples * float(m["correct"]) for num_examples, m in metrics]
+ examples = [num_examples for num_examples, _ in metrics]
+
+ # Aggregate and return custom metric (weighted average)
+ return {"accuracy": float(sum(correct)) / float(sum(examples))}
+
+
+class FedMeta(FedAvg):
+ """FedMeta averages the gradient and server parameter update through it."""
+
+ def __init__(self, alpha, beta, data, algo, **kwargs):
+ super().__init__(**kwargs)
+ self.algo = algo
+ self.data = data
+ self.beta = beta
+ self.ema_loss = None
+ self.ema_acc = None
+
+ if self.data == "femnist":
+ self.net = FemnistNetwork()
+ elif self.data == "shakespeare":
+ self.net = StackedLSTM()
+
+ self.alpha = torch.nn.ParameterList(
+ [
+ torch.nn.Parameter(torch.full_like(p, alpha))
+ for p in self.net.parameters()
+ ]
+ )
+
+ def configure_fit(
+ self, server_round: int, parameters: Parameters, client_manager: ClientManager
+ ) -> List[Tuple[ClientProxy, FitIns]]:
+ """Configure the next round of training."""
+ config = {"alpha": self.alpha, "algo": self.algo, "data": self.data}
+ if self.on_fit_config_fn is not None:
+ # Custom fit config function provided
+ config = self.on_fit_config_fn(server_round)
+ fit_ins = FitIns(parameters, config)
+
+ # Sample clients
+ sample_size, min_num_clients = self.num_fit_clients(
+ client_manager.num_available()
+ )
+ clients = client_manager.sample( # type: ignore
+ num_clients=sample_size,
+ min_num_clients=min_num_clients,
+ server_round=server_round,
+ step="fit",
+ )
+
+ # Return client/config pairs
+ return [(client, fit_ins) for client in clients]
+
+ def configure_evaluate(
+ self, server_round: int, parameters: Parameters, client_manager: ClientManager
+ ) -> List[Tuple[ClientProxy, EvaluateIns]]:
+ """Configure the next round of evaluation."""
+ # Do not configure federated evaluation if fraction eval is 0.
+ if self.fraction_evaluate == 0.0:
+ return []
+
+ # Parameters and config
+ config = {"alpha": self.alpha, "algo": self.algo, "data": self.data}
+ if self.on_evaluate_config_fn is not None:
+ # Custom evaluation config function provided
+ config = self.on_evaluate_config_fn(server_round)
+ evaluate_ins = EvaluateIns(parameters, config)
+
+ # Sample clients
+ sample_size, min_num_clients = self.num_evaluation_clients(
+ client_manager.num_available()
+ )
+ clients = client_manager.sample( # type: ignore
+ num_clients=sample_size,
+ min_num_clients=min_num_clients,
+ server_round=server_round,
+ step="evaluate",
+ )
+
+ # Return client/config pairs
+ return [(client, evaluate_ins) for client in clients]
+
+ def aggregate_fit(
+ self,
+ server_round: int,
+ results: List[Tuple[ClientProxy, FitRes]],
+ failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+ ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+ """Aggregate fit results using weighted average."""
+ if not results:
+ return None, {}
+ # Do not aggregate if there are failures and failures are not accepted
+ if not self.accept_failures and failures:
+ return None, {}
+
+ # Convert results
+ weights_results: List[Tuple[NDArrays, int]] = [
+ (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
+ for _, fit_res in results
+ ]
+
+ parameters_aggregated = aggregate(weights_results)
+ if self.data == "femnist":
+ weight_decay_ = 0.001
+ else:
+ weight_decay_ = 0.0001
+
+ # Gradient Average and Update Parameter for FedMeta(MAML)
+ if self.algo == "fedmeta_maml":
+ grads_results: List[Tuple[NDArrays, int]] = [
+ (fit_res.metrics["grads"], fit_res.num_examples) # type: ignore
+ for _, fit_res in results
+ ]
+ gradients_aggregated = aggregate(grads_results)
+ weights_prime = fedmeta_update_maml(
+ self.net,
+ self.beta,
+ weights_results[0][0],
+ gradients_aggregated,
+ weight_decay_,
+ )
+ parameters_aggregated = weights_prime
+
+ # Gradient Average and Update Parameter for FedMeta(Meta-SGD)
+ elif self.algo == "fedmeta_meta_sgd":
+ grads_results: List[Tuple[NDArrays, int]] = [ # type: ignore
+ (fit_res.metrics["grads"], fit_res.num_examples)
+ for _, fit_res in results
+ ]
+ gradients_aggregated = aggregate(grads_results)
+ weights_prime, update_alpha = fedmeta_update_meta_sgd(
+ self.net,
+ self.alpha,
+ self.beta,
+ weights_results[0][0],
+ gradients_aggregated,
+ weight_decay_,
+ )
+ self.alpha = update_alpha
+ parameters_aggregated = weights_prime
+
+ # Aggregate custom metrics if aggregation fn was provided
+ metrics_aggregated = {}
+ if self.fit_metrics_aggregation_fn:
+ fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
+ metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
+ elif server_round == 1: # Only log this warning once
+ log(WARNING, "No fit_metrics_aggregation_fn provided")
+
+ return ndarrays_to_parameters(parameters_aggregated), metrics_aggregated
+
+ def aggregate_evaluate(
+ self,
+ server_round: int,
+ results: List[Tuple[ClientProxy, EvaluateRes]],
+ failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
+ ) -> Tuple[Optional[float], Dict[str, Scalar]]:
+ """Aggregate evaluation losses using weighted average."""
+ if not results:
+ return None, {}
+ # Do not aggregate if there are failures and failures are not accepted
+ if not self.accept_failures and failures:
+ return None, {}
+
+ # Aggregate loss
+ loss_aggregated = weighted_loss_avg(
+ [
+ (evaluate_res.num_examples, evaluate_res.loss)
+ for _, evaluate_res in results
+ ]
+ )
+
+ if self.data == "femnist":
+ smoothing_weight = 0.95
+ else:
+ smoothing_weight = 0.7
+ self.ema_loss = update_ema(self.ema_loss, loss_aggregated, smoothing_weight)
+ loss_aggregated = self.ema_loss
+
+ # Aggregate custom metrics if aggregation fn was provided
+ metrics_aggregated = {}
+ if self.evaluate_metrics_aggregation_fn:
+ eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
+ metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
+ self.ema_acc = update_ema(
+ self.ema_acc,
+ round(float(metrics_aggregated["accuracy"] * 100), 3),
+ smoothing_weight,
+ )
+ metrics_aggregated["accuracy"] = self.ema_acc
+
+ elif server_round == 1: # Only log this warning once
+ log(WARNING, "No evaluate_metrics_aggregation_fn provided")
+
+ return loss_aggregated, metrics_aggregated
diff --git a/baselines/fedmeta/fedmeta/utils.py b/baselines/fedmeta/fedmeta/utils.py
new file mode 100644
index 000000000000..b8e1dd95acab
--- /dev/null
+++ b/baselines/fedmeta/fedmeta/utils.py
@@ -0,0 +1,160 @@
+"""Define any utility function.
+
+They are not directly relevant to the other (more FL specific) python modules. For
+example, you may define here things like: loading a model from a checkpoint, saving
+results, plotting.
+"""
+
+import os
+import pickle
+from typing import Dict, List
+
+import matplotlib.pyplot as plt
+
+# Encoding list for the Shakespeare dataset
+ALL_LETTERS = (
+ "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz}"
+)
+
+
+def _one_hot(
+ index: int,
+ size: int,
+) -> List:
+ """Return one-hot vector with given size and value 1 at given index."""
+ vec = [0 for _ in range(size)]
+ vec[int(index)] = 1
+ return vec
+
+
+def letter_to_vec(
+ letter: str,
+) -> int:
+ """Return one-hot representation of given letter."""
+ index = ALL_LETTERS.find(letter)
+ return index
+
+
+def word_to_indices(
+ word: str,
+) -> List:
+ """Return a list of character indices.
+
+ Parameters
+ ----------
+ word: string.
+
+ Returns
+ -------
+ indices: int list with length len(word)
+ """
+ indices = []
+ for count in word:
+ indices.append(ALL_LETTERS.find(count))
+ return indices
+
+
+def update_ema(
+ prev_ema: float,
+ current_value: float,
+ smoothing_weight: float,
+) -> float:
+ """We use EMA to visually enhance the learning trend for each round.
+
+ Parameters
+ ----------
+ prev_ema : float
+ The list of metrics to aggregate.
+ current_value : float
+ The list of metrics to aggregate.
+ smoothing_weight : float
+ The list of metrics to aggregate.
+
+
+ Returns
+ -------
+ EMA_Loss or EMA_ACC
+ The weighted average metric.
+ """
+ if prev_ema is None:
+ return current_value
+ return (1 - smoothing_weight) * current_value + smoothing_weight * prev_ema
+
+
+def save_graph_params(data_info: Dict):
+ """Save parameters to visualize experiment results (Loss, ACC).
+
+ Parameters
+ ----------
+ data_info : Dict
+ This is a parameter dictionary of data from which the experiment was completed.
+ """
+ if os.path.exists(f"{data_info['path']}/{data_info['algo']}.pkl"):
+ raise ValueError(
+ f"'{data_info['path']}/{data_info['algo']}.pkl' is already exists!"
+ )
+
+ with open(f"{data_info['path']}/{data_info['algo']}.pkl", "wb") as file:
+ pickle.dump(data_info, file)
+
+
+def plot_from_pkl(directory="."):
+ """Visualization of algorithms like 4 Algorithm for data.
+
+ Parameters
+ ----------
+ directory : str
+ Graph params directory path for Femnist or Shakespeare
+ """
+ color_mapping = {
+ "fedavg.pkl": "#66CC00",
+ "fedavg_meta.pkl": "#3333CC",
+ "fedmeta_maml.pkl": "#FFCC00",
+ "fedmeta_meta_sgd.pkl": "#CC0000",
+ }
+
+ pkl_files = [f for f in os.listdir(directory) if f.endswith(".pkl")]
+
+ all_data = {}
+
+ for file in pkl_files:
+ with open(os.path.join(directory, file), "rb") as file_:
+ data = pickle.load(file_)
+ all_data[file] = data
+
+ plt.figure(figsize=(7, 12))
+
+ # Acc graph
+ plt.subplot(2, 1, 1)
+ for file in sorted(all_data.keys()):
+ data = all_data[file]
+ accuracies = [acc for _, acc in data["accuracy"]["accuracy"]]
+ legend_ = file[:-4] if file.endswith(".pkl") else file
+ plt.plot(
+ accuracies,
+ label=legend_,
+ color=color_mapping.get(file, "black"),
+ linewidth=3,
+ )
+ plt.title("Accuracy")
+ plt.grid(True)
+ plt.legend()
+
+ plt.subplot(2, 1, 2)
+ for file in sorted(all_data.keys()):
+ data = all_data[file]
+ loss = [loss for _, loss in data["loss"]]
+ legend_ = file[:-4] if file.endswith(".pkl") else file
+ plt.plot(
+ loss, label=legend_, color=color_mapping.get(file, "black"), linewidth=3
+ )
+ plt.title("Loss")
+ plt.legend()
+ plt.grid(True)
+
+ plt.tight_layout()
+
+ save_path = f"{directory}/result_graph.png"
+ plt.savefig(save_path)
+
+ plt.show()
diff --git a/baselines/fedmeta/pyproject.toml b/baselines/fedmeta/pyproject.toml
new file mode 100644
index 000000000000..cbaa9bb5d110
--- /dev/null
+++ b/baselines/fedmeta/pyproject.toml
@@ -0,0 +1,143 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.masonry.api"
+
+[tool.poetry]
+name = "fedmeta"
+version = "1.0.0"
+description = "Implementation of FedMeta (Fei Chen et al. 2018)"
+license = "Apache-2.0"
+authors = ["Jinsoo Kim ", "Kangyoon Lee "]
+readme = "README.md"
+homepage = "https://flower.dev"
+repository = "https://github.com/adap/flower"
+documentation = "https://flower.dev"
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: MacOS :: MacOS X",
+ "Operating System :: POSIX :: Linux",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: Implementation :: CPython",
+ "Topic :: Scientific/Engineering",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Scientific/Engineering :: Mathematics",
+ "Topic :: Software Development",
+ "Topic :: Software Development :: Libraries",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+ "Typing :: Typed",
+]
+
+[tool.poetry.dependencies]
+python = ">=3.10.0, <3.11.0"
+flwr = { extras = ["simulation"], version = "1.5.0" }
+hydra-core = "1.3.2" # don't change this
+matplotlib = "3.7.1"
+scikit-learn = "1.3.1"
+torch = { url = "https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl"}
+torchvision = { url = "https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl"}
+pillow = "9.5.0" # needed <10.0.0 for LEAF repo scripts
+
+
+[tool.poetry.dev-dependencies]
+isort = "==5.11.5"
+black = "==23.1.0"
+docformatter = "==1.5.1"
+mypy = "==1.4.1"
+pylint = "==2.8.2"
+flake8 = "==3.9.2"
+pytest = "==6.2.4"
+pytest-watch = "==4.2.0"
+ruff = "==0.0.272"
+types-requests = "==2.27.7"
+
+[tool.isort]
+line_length = 88
+indent = " "
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+
+[tool.black]
+line-length = 88
+target-version = ["py38", "py39", "py310", "py311"]
+
+[tool.pytest.ini_options]
+minversion = "6.2"
+addopts = "-qq"
+testpaths = [
+ "flwr_baselines",
+]
+
+[tool.mypy]
+ignore_missing_imports = true
+strict = false
+plugins = "numpy.typing.mypy_plugin"
+
+[tool.pylint."MESSAGES CONTROL"]
+disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias"
+good-names = "i,j,k,_,x,y,X,Y"
+signature-mutators="hydra.main.main"
+
+[tool.pylint.typecheck]
+generated-members="numpy.*, torch.*, tensorflow.*"
+
+[[tool.mypy.overrides]]
+module = [
+ "importlib.metadata.*",
+ "importlib_metadata.*",
+]
+follow_imports = "skip"
+follow_imports_for_stubs = true
+disallow_untyped_calls = false
+
+[[tool.mypy.overrides]]
+module = "torch.*"
+follow_imports = "skip"
+follow_imports_for_stubs = true
+
+[tool.docformatter]
+wrap-summaries = 88
+wrap-descriptions = 88
+
+[tool.ruff]
+target-version = "py38"
+line-length = 88
+select = ["D", "E", "F", "W", "B", "ISC", "C4"]
+fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
+ignore = ["B024", "B027"]
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".hg",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pytype",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "venv",
+ "proto",
+]
+
+[tool.ruff.pydocstyle]
+convention = "numpy"
diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md
index cf482521fa73..891632edaaf5 100644
--- a/doc/source/ref-changelog.md
+++ b/doc/source/ref-changelog.md
@@ -26,6 +26,8 @@
- TAMUNA ([#2254](https://github.com/adap/flower/pull/2254), [#2508](https://github.com/adap/flower/pull/2508))
+ - FedMeta [#2438](https://github.com/adap/flower/pull/2438)
+
- **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384)), ([#2425](https://github.com/adap/flower/pull/2425))
- **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301), [#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327), [#2435](https://github.com/adap/flower/pull/2435))
diff --git a/examples/android/README.md b/examples/android/README.md
index 6e45803eb2f5..7931aa96b0c5 100644
--- a/examples/android/README.md
+++ b/examples/android/README.md
@@ -1,12 +1,14 @@
# Flower Android Example (TensorFlowLite)
-This example demonstrates a federated learning setup with Android Clients. The training on Android is done on a CIFAR10 dataset using TensorFlow Lite. The setup is as follows:
+This example demonstrates a federated learning setup with Android clients in a background thread. The training on Android is done on a CIFAR10 dataset using TensorFlow Lite. The setup is as follows:
- The CIFAR10 dataset is randomly split across 10 clients. Each Android client holds a local dataset of 5000 training examples and 1000 test examples.
- The FL server runs in Python but all the clients run on Android.
- We use a strategy called FedAvgAndroid for this example.
- The strategy is vanilla FedAvg with a custom serialization and deserialization to handle the Bytebuffers sent from Android clients to Python server.
+The background thread is established via the `WorkManager` library of Android, thus, it will run comfortably on Android Versions from 8 to 13.
+
## Project Setup
Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
diff --git a/examples/android/client/app/src/main/AndroidManifest.xml b/examples/android/client/app/src/main/AndroidManifest.xml
index 18eb6bad1feb..4b2319013878 100644
--- a/examples/android/client/app/src/main/AndroidManifest.xml
+++ b/examples/android/client/app/src/main/AndroidManifest.xml
@@ -1,8 +1,16 @@
-
-
-
+
+
+
+
+
+
+
+
-
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/android/client/app/src/main/java/flwr/android_client/FlowerClient.java b/examples/android/client/app/src/main/java/flwr/android_client/FlowerClient.java
index c453a1d106ea..e789e8f15cbc 100644
--- a/examples/android/client/app/src/main/java/flwr/android_client/FlowerClient.java
+++ b/examples/android/client/app/src/main/java/flwr/android_client/FlowerClient.java
@@ -11,6 +11,8 @@
import androidx.lifecycle.MutableLiveData;
import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
@@ -64,6 +66,7 @@ public void setLastLoss(int epoch, float newLoss) {
public void loadData(int device_id) {
try {
+ Log.d("FLOWERCLIENT_LOAD", "loadData: ");
BufferedReader reader = new BufferedReader(new InputStreamReader(this.context.getAssets().open("data/partition_" + (device_id - 1) + "_train.txt")));
String line;
int i = 0;
@@ -137,4 +140,38 @@ private static float[] prepareImage(Bitmap bitmap) {
return normalizedRgb;
}
-}
\ No newline at end of file
+
+ // function to write to a file :
+
+ public void writeStringToFile( Context context , String fileName, String content) {
+ try {
+ // Get the app-specific external storage directory
+ File directory = context.getExternalFilesDir(null);
+
+ if (directory != null) {
+ File file = new File(directory, fileName);
+
+ // Check if the file exists
+ boolean fileExists = file.exists();
+
+ // Open a FileWriter in append mode
+ FileWriter writer = new FileWriter(file, true);
+
+ // If the file exists and is not empty, add a new line
+ if (fileExists && file.length() > 0) {
+ writer.append("\n");
+ }
+
+ // Write the string to the file
+ writer.append(content);
+
+ // Close the FileWriter
+ writer.close();
+ }
+ } catch (IOException e) {
+ e.printStackTrace(); // Handle the exception as needed
+ }
+ }
+
+
+}
diff --git a/examples/android/client/app/src/main/java/flwr/android_client/FlowerWorker.java b/examples/android/client/app/src/main/java/flwr/android_client/FlowerWorker.java
new file mode 100644
index 000000000000..8e0bc4347d3a
--- /dev/null
+++ b/examples/android/client/app/src/main/java/flwr/android_client/FlowerWorker.java
@@ -0,0 +1,480 @@
+package flwr.android_client;
+
+
+
+import static android.content.Context.NOTIFICATION_SERVICE;
+import android.app.Notification;
+import android.app.NotificationChannel;
+import android.app.NotificationManager;
+import android.app.PendingIntent;
+import android.content.Context;
+import android.icu.text.SimpleDateFormat;
+import android.os.Build;
+import androidx.annotation.NonNull;
+import androidx.annotation.RequiresApi;
+import android.util.Log;
+import android.util.Pair;
+import io.grpc.ManagedChannel;
+import io.grpc.ManagedChannelBuilder;
+import flwr.android_client.FlowerServiceGrpc.FlowerServiceStub;
+import com.google.protobuf.ByteString;
+import io.grpc.stub.StreamObserver;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.HashMap;
+import java.util.Map;
+import androidx.core.app.NotificationCompat;
+import androidx.work.Data;
+import androidx.work.ForegroundInfo;
+import androidx.work.WorkManager;
+import androidx.work.Worker;
+import androidx.work.WorkerParameters;
+
+public class FlowerWorker extends Worker {
+
+ private ManagedChannel channel;
+ public FlowerClient fc;
+ private StreamObserver UniversalRequestObserver;
+ private static final String TAG = "Flower";
+ String serverIp = "00:00:00";
+ String serverPort = "0000";
+ String dataslice = "1";
+ public static String start_time;
+ public static String end_time;
+ // following variables are just to send the worker routine to the
+ public static String workerStartTime = "";
+
+ public static String workerEndTime = "";
+
+ public static String workerEndReason = "worker ended";
+
+ public String getTime() {
+ // Extract hours, minutes, and seconds
+ if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.O) {
+ java.text.SimpleDateFormat sdf = new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault());
+ String formattedDateTime = sdf.format(new Date());
+ return formattedDateTime;
+ }
+ return "";
+ }
+ private NotificationManager notificationManager;
+
+ private static String PROGRESS = "PROGRESS";
+
+ public FlowerWorker(@NonNull Context context, @NonNull WorkerParameters workerParams) {
+ super(context, workerParams);
+ FlowerWorker worker = this;
+ notificationManager = (NotificationManager)
+ context.getSystemService(NOTIFICATION_SERVICE);
+ fc = new FlowerClient(context.getApplicationContext());
+ }
+
+ @NonNull
+ @Override
+ public Result doWork() {
+
+ Data checkData = getInputData();
+ serverIp = checkData.getString("server");
+ serverPort = checkData.getString("port");
+ dataslice = checkData.getString("dataslice");
+
+ // Creating Foreground Notification Service about the Background Worker FL tasks
+ setForegroundAsync(createForegroundInfo("Progress"));
+ try {
+ workerStartTime = getTime();
+ // Ensuring whether the connection is establish or not with the given gRPC IP & port
+ boolean resultConnect = connect();
+ if(resultConnect)
+ {
+ loadData();
+ CompletableFuture grpcFuture = runGrpc();
+ grpcFuture.get();
+ return Result.success();
+ }
+ else
+ {
+ workerEndReason = "GRPC Connection failed";
+ return Result.failure();
+ }
+
+ } catch (Exception e) {
+ // To handle any exceptions and return a failure result
+ // Failure if there is any OOM or midway connection error
+ workerEndReason = "Unknown Error occured in main try catch";
+ Log.e(TAG, "Error executing flower code: " + e.getMessage(), e);
+ return Result.failure();
+ }
+ }
+
+ @Override
+ public void onStopped() {
+ super.onStopped();
+ // Worker is canceled, stopping the global requestObserver if it's not null
+ Throwable cancellationCause = new Throwable("Worker canceled");
+ if (UniversalRequestObserver != null) {
+ UniversalRequestObserver.onError(cancellationCause); // Signal to the server that communication is done
+ }
+ }
+
+ public boolean connect() {
+ int port = Integer.parseInt(serverPort);
+ try {
+ channel = ManagedChannelBuilder.forAddress(serverIp, port)
+ .maxInboundMessageSize(10 * 1024 * 1024)
+ .usePlaintext()
+ .build();
+ fc.writeStringToFile(getApplicationContext(), "FlowerResults.txt" , "Connection : Successful with " + serverIp + " : " + serverPort + " : " + dataslice);
+ return true; // connection is successful
+ } catch (Exception e) {
+ Log.e(TAG, "Failed to connect to the server: " + e.getMessage(), e);
+ fc.writeStringToFile(getApplicationContext(), "FlowerResults.txt" , "Connection : Failed with " + serverIp + " : " + serverPort + " : " + dataslice);
+ return false; // connection failed
+ }
+ }
+
+ public void loadData() {
+ try {
+ fc.loadData(Integer.parseInt(dataslice));
+ Log.d("LOAD", "Loading is complete");
+ fc.writeStringToFile(getApplicationContext(), "FlowerResults.txt", "Loading Bit Images : Success" );
+ } catch (Exception e) {
+ StringWriter sw = new StringWriter();
+ PrintWriter pw = new PrintWriter(sw);
+ e.printStackTrace(pw);
+ pw.flush();
+ Log.d("LOAD_ERROR", "Error occured in Loading");
+ fc.writeStringToFile(getApplicationContext(), "FlowerResults.txt", "Loading Bit Images : Failed" );
+ }
+ }
+
+ public CompletableFuture runGrpc() {
+
+ CompletableFuture future = new CompletableFuture<>();
+ FlowerWorker worker = this;
+ ExecutorService executor = Executors.newSingleThreadExecutor();
+
+ ProgressUpdater progressUpdater = new ProgressUpdater();
+
+ executor.execute(new Runnable() {
+ @Override
+ public void run() {
+ try {
+ CountDownLatch latch = new CountDownLatch(1);
+
+ (new FlowerServiceRunnable()).run(FlowerServiceGrpc.newStub(channel), worker, latch , progressUpdater , getApplicationContext());
+
+ latch.await(); // Wait for the latch to count down
+ future.complete(null); // Complete the future when the latch is counted down
+
+ Log.d("GRPC", "inside GRPC");
+ } catch (Exception e) {
+ StringWriter sw = new StringWriter();
+ PrintWriter pw = new PrintWriter(sw);
+ e.printStackTrace(pw);
+ pw.flush();
+ Log.e("GRPC", "Failed to connect to the FL server \n" + sw);
+ future.completeExceptionally(e); // Complete the future with an exception
+ }
+ }
+ });
+
+ return future;
+ }
+
+
+ @NonNull
+ private ForegroundInfo createForegroundInfo(@NonNull String progress) {
+ // Building a notification using bytesRead and contentLength
+ Context context = getApplicationContext();
+ String id = context.getString(R.string.notification_channel_id);
+ String title = context.getString(R.string.notification_title);
+ String cancel =context.getString(R.string.cancel_download);
+ // Creating a PendingIntent that can be used to cancel the worker
+ PendingIntent intent = WorkManager.getInstance(context)
+ .createCancelPendingIntent(getId());
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
+ createChannel();
+ }
+ Notification notification = new NotificationCompat.Builder(context, id)
+ .setContentTitle(title)
+ .setTicker(title)
+ .setSmallIcon(R.drawable.ic_logo)
+ .setOngoing(true)
+ // Add the cancel action to the notification which can
+ // be used to cancel the worker
+ .addAction(android.R.drawable.ic_delete, cancel, intent)
+ .build();
+ int notificationId = 1002;
+ return new ForegroundInfo(notificationId, notification);
+ }
+
+ @RequiresApi(Build.VERSION_CODES.O)
+ private void createChannel() {
+ Context context = getApplicationContext();
+ String channelId = context.getString(R.string.notification_channel_id);
+ String channelName = context.getString(R.string.notification_title);
+ int importance = NotificationManager.IMPORTANCE_DEFAULT;
+
+ NotificationChannel channel = new NotificationChannel(channelId, channelName, importance);
+ // Configure the channel
+ channel.setDescription("Channel description");
+ // Set other properties of the channel as needed if needed ...
+ NotificationManager notificationManager = (NotificationManager) context.getSystemService(Context.NOTIFICATION_SERVICE);
+ notificationManager.createNotificationChannel(channel);
+ }
+
+
+ public class ProgressUpdater {
+ public void setProgress() {
+ // Aim of this class is to allow static FlowerServiceRunnable Object to notifiy Main Activity about the changes in real time to be displayed to User
+ Log.d("DATA-BACKGROUND","Sending it to the main activity");
+ setProgressAsync(new Data.Builder().putInt("progress", 0).build());
+
+ }
+ }
+
+ private static class FlowerServiceRunnable{
+ protected Throwable failed;
+ public void run(FlowerServiceStub asyncStub, FlowerWorker worker , CountDownLatch latch , ProgressUpdater progressUpdater , Context context) {
+ join(asyncStub, worker , latch , progressUpdater , context);
+ }
+
+ public void writeStringToFile( Context context , String fileName, String content) {
+ try {
+ // Getting the app-specific external storage directory
+ File directory = context.getExternalFilesDir(null);
+
+ if (directory != null) {
+ File file = new File(directory, fileName);
+
+ // Checking if the file exists
+ boolean fileExists = file.exists();
+
+ // Open a FileWriter in append mode
+ FileWriter writer = new FileWriter(file, true);
+
+ // If the file exists and is not empty, add a new line
+ if (fileExists && file.length() > 0) {
+ writer.append("\n");
+ }
+
+ // Write the string to the file
+ writer.append(content);
+
+ // Close the FileWriter
+ writer.close();
+ }
+ } catch (IOException e) {
+ e.printStackTrace(); // Handle the exception as needed
+ }
+ }
+
+ private void join(FlowerServiceStub asyncStub, FlowerWorker worker, CountDownLatch latch , ProgressUpdater progressUpdater , Context context)
+ throws RuntimeException {
+ final CountDownLatch finishLatch = new CountDownLatch(1);
+
+ worker.UniversalRequestObserver = asyncStub.join(new StreamObserver() {
+ @Override
+ public void onNext(ServerMessage msg) {
+ handleMessage(msg, worker , progressUpdater , context);
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ t.printStackTrace();
+ failed = t;
+ finishLatch.countDown();
+ latch.countDown();
+ // Error handling for timeout & other GRPC communication related Errors
+ workerEndReason = t.getMessage();
+ writeStringToFile( context ,"FlowerResults.txt", workerEndReason);
+ Log.e(TAG, t.getMessage());
+ }
+
+ @Override
+ public void onCompleted() {
+ finishLatch.countDown();
+ latch.countDown();
+ Log.e(TAG, "Done");
+ }
+ });
+
+
+ try {
+ finishLatch.await();
+ } catch (InterruptedException e) {
+ Log.e(TAG, "Interrupted while waiting for gRPC communication to finish: " + e.getMessage(), e);
+ Thread.currentThread().interrupt();
+ }
+ }
+
+ private void handleMessage(ServerMessage message, FlowerWorker worker , ProgressUpdater progressUpdater , Context context) {
+
+ try {
+ ByteBuffer[] weights;
+ ClientMessage c = null;
+
+ if (message.hasGetParametersIns()) {
+ Log.e(TAG, "Handling GetParameters");
+
+ weights = worker.fc.getWeights();
+ c = weightsAsProto(weights);
+ } else if (message.hasFitIns()) {
+
+ SimpleDateFormat sdf = null;
+ if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) {
+ sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault());
+ }
+
+ // Get the current date and time
+ Date currentDate = new Date();
+
+ // Format the date and time using the SimpleDateFormat object
+ // String formattedDate = null;
+ if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) {
+ start_time = sdf.format(currentDate);
+ }
+ Log.e(TAG, "Handling FitIns");
+
+ List layers = message.getFitIns().getParameters().getTensorsList();
+
+ Scalar epoch_config = null;
+ if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) {
+ epoch_config = message.getFitIns().getConfigMap().getOrDefault("local_epochs", Scalar.newBuilder().setSint64(1).build());
+ }
+
+ assert epoch_config != null;
+ int local_epochs = (int) epoch_config.getSint64();
+
+ // Our model has 10 layers
+ ByteBuffer[] newWeights = new ByteBuffer[10] ;
+ for (int i = 0; i < 10; i++) {
+ newWeights[i] = ByteBuffer.wrap(layers.get(i).toByteArray());
+ }
+
+ Pair outputs = worker.fc.fit(newWeights, local_epochs);
+ currentDate = new Date();
+ // Format the date and time using the SimpleDateFormat object
+ if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) {
+ end_time = sdf.format(currentDate);
+ }
+ Log.d("FIT-RESPONSE", "ABOUT TO SEND FIT RESPONSE");
+ c = fitResAsProto(outputs.first, outputs.second);
+ } else if (message.hasEvaluateIns()) {
+ Log.e(TAG, "Handling EvaluateIns");
+
+ SimpleDateFormat sdf = null;
+ if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) {
+ sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault());
+ }
+ Date currentDate = new Date();
+ // Format the date and time using the SimpleDateFormat object
+ if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) {
+ start_time = sdf.format(currentDate);
+ }
+ List layers = message.getEvaluateIns().getParameters().getTensorsList();
+ // Our model has 10 layers
+ ByteBuffer[] newWeights = new ByteBuffer[10] ;
+ for (int i = 0; i < 10; i++) {
+ newWeights[i] = ByteBuffer.wrap(layers.get(i).toByteArray());
+ }
+ Pair, Integer> inference = worker.fc.evaluate(newWeights);
+ float loss = inference.first.first;
+ float accuracy = inference.first.second;
+ int test_size = inference.second;
+ currentDate = new Date();
+ // Format the date and time using the SimpleDateFormat object
+ if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) {
+ end_time = sdf.format(currentDate);
+ }
+ Log.d("EVALUATE-RESPONSE", "ABOUT TO SEND EVALUATE RESPONSE");
+ String newMessage = "Time : " + end_time + " , " + " Round Accuracy : " + String.valueOf(accuracy);
+ writeStringToFile( context ,"FlowerResults.txt", newMessage);
+ progressUpdater.setProgress();
+ c = evaluateResAsProto(loss , accuracy , test_size);
+ }
+ worker.UniversalRequestObserver.onNext(c);
+ }
+ catch (Exception e){
+ Log.e("Exception","Exception occured in GRPC Connection");
+ Log.e(TAG, e.getMessage());
+ }
+ }
+ }
+
+ private static ClientMessage weightsAsProto(ByteBuffer[] weights){
+ List layers = new ArrayList<>();
+ for (ByteBuffer weight : weights) {
+ layers.add(ByteString.copyFrom(weight));
+ }
+ Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build();
+ ClientMessage.GetParametersRes res = ClientMessage.GetParametersRes.newBuilder().setParameters(p).build();
+ return ClientMessage.newBuilder().setGetParametersRes(res).build();
+ }
+
+ private static ClientMessage fitResAsProto(ByteBuffer[] weights, int training_size){
+ List layers = new ArrayList<>();
+ for (ByteBuffer weight : weights) {
+ layers.add(ByteString.copyFrom(weight));
+ }
+
+ Log.d("ENDTIME", end_time);
+ Log.d("STARTTIME", start_time);
+
+ // An example portraying how to upload data to the server via FLower Server side GRPC
+ Map metrics = new HashMap<>();
+
+ metrics.put("start_time", Scalar.newBuilder().setString(start_time).build());
+ metrics.put("end_time", Scalar.newBuilder().setString(end_time).build());
+ Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build();
+ ClientMessage.FitRes res = ClientMessage.FitRes.newBuilder().setParameters(p).setNumExamples(training_size).putAllMetrics(metrics).build();
+ return ClientMessage.newBuilder().setFitRes(res).build();
+ }
+
+
+
+ private static ClientMessage evaluateResAsProto(float loss, float accuracy ,int testing_size){
+
+ // attempting to send accuracy to the server :
+ Map metrics = new HashMap<>();
+
+
+ Log.d("ENDTIME", end_time);
+ Log.d("STARTTIME", start_time);
+
+ Log.d("Accuracy", String.valueOf(accuracy));
+ Log.d("Loss", String.valueOf(loss));
+
+
+ // An example portraying how to upload data to the server via FLower Server side GRPC
+ metrics.put("Accuracy", Scalar.newBuilder().setString(String.valueOf(accuracy)).build());
+ metrics.put("Loss" , Scalar.newBuilder().setString(String.valueOf(loss)).build());
+ metrics.put("start_time", Scalar.newBuilder().setString(start_time).build());
+ metrics.put("end_time", Scalar.newBuilder().setString(end_time).build());
+
+
+ ClientMessage.EvaluateRes res = ClientMessage.EvaluateRes.newBuilder().setLoss(loss).setNumExamples(testing_size).putAllMetrics(metrics).build();
+ return ClientMessage.newBuilder().setEvaluateRes(res).build();
+ }
+
+
+}
+
+
+
+
+
+
+
diff --git a/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java b/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java
index 911d5043dfef..cbf804140954 100644
--- a/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java
+++ b/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java
@@ -1,289 +1,278 @@
package flwr.android_client;
-import android.app.Activity;
-import android.icu.text.SimpleDateFormat;
+import android.content.Context;
+import android.content.Intent;
+import android.net.Uri;
+import android.os.Build;
import android.os.Bundle;
-
import androidx.appcompat.app.AppCompatActivity;
-
-import android.os.Handler;
-import android.os.Looper;
+import androidx.lifecycle.LifecycleOwner;
+import androidx.recyclerview.widget.LinearLayoutManager;
+import androidx.recyclerview.widget.RecyclerView;
+import androidx.work.Constraints;
+import androidx.work.Data;
+import androidx.work.ExistingPeriodicWorkPolicy;
+import androidx.work.PeriodicWorkRequest;
+import androidx.work.WorkInfo;
+import androidx.work.WorkManager;
+import android.os.PowerManager;
+import android.provider.Settings;
import android.text.TextUtils;
-import android.text.method.ScrollingMovementMethod;
-import android.util.Log;
-import android.util.Pair;
-import android.util.Patterns;
import android.view.View;
-import android.view.inputmethod.InputMethodManager;
import android.widget.Button;
import android.widget.EditText;
-import android.widget.TextView;
import android.widget.Toast;
-
-import io.grpc.ManagedChannel;
-import io.grpc.ManagedChannelBuilder;
-
-import flwr.android_client.FlowerServiceGrpc.FlowerServiceStub;
-import com.google.protobuf.ByteString;
-
-import io.grpc.stub.StreamObserver;
-
-import java.io.PrintWriter;
-import java.io.StringWriter;
-import java.nio.ByteBuffer;
+import androidx.lifecycle.Observer;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
import java.util.ArrayList;
-import java.util.Date;
import java.util.List;
-import java.util.Locale;
-import java.util.Objects;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
public class MainActivity extends AppCompatActivity {
- private EditText ip;
- private EditText port;
- private Button loadDataButton;
- private Button connectButton;
- private Button trainButton;
- private TextView resultText;
- private EditText device_id;
- private ManagedChannel channel;
- public FlowerClient fc;
- private static final String TAG = "Flower";
+ private static final String TAG = "Flower";
+ private static final int REQUEST_WRITE_PERMISSION = 786;
+ private Button batteryOptimisationButton;
+ MessageAdapter messageAdapter;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
- resultText = (TextView) findViewById(R.id.grpc_response_text);
- resultText.setMovementMethod(new ScrollingMovementMethod());
- device_id = (EditText) findViewById(R.id.device_id_edit_text);
- ip = (EditText) findViewById(R.id.serverIP);
- port = (EditText) findViewById(R.id.serverPort);
- loadDataButton = (Button) findViewById(R.id.load_data) ;
- connectButton = (Button) findViewById(R.id.connect);
- trainButton = (Button) findViewById(R.id.trainFederated);
-
- fc = new FlowerClient(this);
- }
+ RecyclerView recyclerView = findViewById(R.id.recyclerView);
+ recyclerView.setLayoutManager(new LinearLayoutManager(this));
- public static void hideKeyboard(Activity activity) {
- InputMethodManager imm = (InputMethodManager) activity.getSystemService(Activity.INPUT_METHOD_SERVICE);
- View view = activity.getCurrentFocus();
- if (view == null) {
- view = new View(activity);
- }
- imm.hideSoftInputFromWindow(view.getWindowToken(), 0);
- }
+ messageAdapter = new MessageAdapter(readStringFromFile( getApplicationContext() , "FlowerResults.txt")); // Create your custom adapter
+ recyclerView.setLayoutManager(new LinearLayoutManager(this));
+ recyclerView.setAdapter(messageAdapter);
+ requestPermission();
- public void setResultText(String text) {
- SimpleDateFormat dateFormat = new SimpleDateFormat("HH:mm:ss", Locale.GERMANY);
- String time = dateFormat.format(new Date());
- resultText.append("\n" + time + " " + text);
+ LifecycleOwner lifecycleOwner = this ;
+ WorkManager.getInstance(getApplicationContext()).getWorkInfosForUniqueWorkLiveData("my_unique_periodic_work").observe(lifecycleOwner, new Observer>() {
+ @Override
+ public void onChanged(List workInfos) {
+ if (workInfos.size() > 0) {
+ WorkInfo info = workInfos.get(0);
+ int progress = info.getProgress().getInt("progress", -1);
+ // You can recieve any message from the Worker Thread
+ refreshRecyclerView();
+ }
+ }
+ });
+ // code for functionality of permission buttons :
+ batteryOptimisationButton = findViewById(R.id.battery_optimisation);
+ batteryOptimisationButton.setOnClickListener(new View.OnClickListener() {
+ @Override
+ public void onClick(View v) {
+ toggleBatteryOptimization();
+ }
+ });
}
- public void loadData(View view){
- if (TextUtils.isEmpty(device_id.getText().toString())) {
- Toast.makeText(this, "Please enter a client partition ID between 1 and 10 (inclusive)", Toast.LENGTH_LONG).show();
+ private void requestPermission() {
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ requestPermissions(new String[]{android.Manifest.permission.WRITE_EXTERNAL_STORAGE}, REQUEST_WRITE_PERMISSION);
+ createEmptyFile("FlowerResults.txt");
}
- else if (Integer.parseInt(device_id.getText().toString()) > 10 || Integer.parseInt(device_id.getText().toString()) < 1)
+ else
{
- Toast.makeText(this, "Please enter a client partition ID between 1 and 10 (inclusive)", Toast.LENGTH_LONG).show();
- }
- else{
- hideKeyboard(this);
- setResultText("Loading the local training dataset in memory. It will take several seconds.");
- loadDataButton.setEnabled(false);
-
- ExecutorService executor = Executors.newSingleThreadExecutor();
- Handler handler = new Handler(Looper.getMainLooper());
-
- executor.execute(new Runnable() {
- private String result;
- @Override
- public void run() {
- try {
- fc.loadData(Integer.parseInt(device_id.getText().toString()));
- result = "Training dataset is loaded in memory.";
- } catch (Exception e) {
- StringWriter sw = new StringWriter();
- PrintWriter pw = new PrintWriter(sw);
- e.printStackTrace(pw);
- pw.flush();
- result = "Training dataset is loaded in memory.";
- }
- handler.post(() -> {
- setResultText(result);
- connectButton.setEnabled(true);
- });
- }
- });
+ createEmptyFile("FlowerResults.txt");
}
}
+ private List readStringFromFile(Context context, String fileName) {
+ List lines = new ArrayList<>();
- public void connect(View view) {
- String host = ip.getText().toString();
- String portStr = port.getText().toString();
- if (TextUtils.isEmpty(host) || TextUtils.isEmpty(portStr) || !Patterns.IP_ADDRESS.matcher(host).matches()) {
- Toast.makeText(this, "Please enter the correct IP and port of the FL server", Toast.LENGTH_LONG).show();
- }
- else {
- int port = TextUtils.isEmpty(portStr) ? 0 : Integer.parseInt(portStr);
- channel = ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(10 * 1024 * 1024).usePlaintext().build();
- hideKeyboard(this);
- trainButton.setEnabled(true);
- connectButton.setEnabled(false);
- setResultText("Channel object created. Ready to train!");
+ try {
+ File directory = context.getExternalFilesDir(null);
+
+ if (directory != null) {
+ File file = new File(directory, fileName);
+
+ // Checking if the file exists
+ if (!file.exists()) {
+ return lines; // File doesn't exist then return an empty list
+ }
+ // Opening a FileReader to read the file
+ FileReader reader = new FileReader(file);
+ BufferedReader bufferedReader = new BufferedReader(reader);
+ String line;
+ while ((line = bufferedReader.readLine()) != null) {
+ lines.add(line);
+ }
+ // Closing the readers
+ bufferedReader.close();
+ reader.close();
+ }
+ } catch (IOException e) {
+ e.printStackTrace(); // Handle the exception as needed
}
+
+ return lines;
}
- public void runGrpc(View view){
- MainActivity activity = this;
- ExecutorService executor = Executors.newSingleThreadExecutor();
- Handler handler = new Handler(Looper.getMainLooper());
- executor.execute(new Runnable() {
- private String result;
- @Override
- public void run() {
- try {
- (new FlowerServiceRunnable()).run(FlowerServiceGrpc.newStub(channel), activity);
- result = "Connection to the FL server successful \n";
- } catch (Exception e) {
- StringWriter sw = new StringWriter();
- PrintWriter pw = new PrintWriter(sw);
- e.printStackTrace(pw);
- pw.flush();
- result = "Failed to connect to the FL server \n" + sw;
+ private void clearFileContents(Context context, String fileName) {
+ try {
+ File directory = context.getExternalFilesDir(null);
+
+ if (directory != null) {
+ File file = new File(directory, fileName);
+
+ // Checking if the file exists
+ if (!file.exists()) {
+ // File doesn't exist, so there's nothing to clear
+ return;
}
- handler.post(() -> {
- setResultText(result);
- trainButton.setEnabled(false);
- });
+
+ // Opens a FileWriter with append mode set to false (this will clear the file)
+ FileWriter writer = new FileWriter(file, false);
+ writer.write(""); // Write an empty string to clear the file
+ writer.close();
+
+ refreshRecyclerView();
}
- });
+ } catch (IOException e) {
+ e.printStackTrace(); // Handle the exception as needed
+ }
}
- private static class FlowerServiceRunnable{
- protected Throwable failed;
- private StreamObserver requestObserver;
- public void run(FlowerServiceStub asyncStub, MainActivity activity) {
- join(asyncStub, activity);
- }
+ public void startWorker(View view) {
+
+ // ensuring all inputs are entered :
+
+ EditText deviceIdEditText = findViewById(R.id.device_id_edit_text);
+ EditText serverIPEditText = findViewById(R.id.serverIP);
+ EditText serverPortEditText = findViewById(R.id.serverPort);
+
+ // Get the text from the EditText widgets
+ String dataSlice = deviceIdEditText.getText().toString();
+ String serverIP = serverIPEditText.getText().toString();
+ String serverPort = serverPortEditText.getText().toString();
- private void join(FlowerServiceStub asyncStub, MainActivity activity)
- throws RuntimeException {
-
- final CountDownLatch finishLatch = new CountDownLatch(1);
- requestObserver = asyncStub.join(
- new StreamObserver() {
- @Override
- public void onNext(ServerMessage msg) {
- handleMessage(msg, activity);
- }
-
- @Override
- public void onError(Throwable t) {
- t.printStackTrace();
- failed = t;
- finishLatch.countDown();
- Log.e(TAG, t.getMessage());
- }
-
- @Override
- public void onCompleted() {
- finishLatch.countDown();
- Log.e(TAG, "Done");
- }
- });
+ if (TextUtils.isEmpty(dataSlice) || TextUtils.isEmpty(serverIP) || TextUtils.isEmpty(serverPort)) {
+ // Display a toast message indicating that fields are omitted
+ Toast.makeText(this, "Please fill in all fields", Toast.LENGTH_SHORT).show();
+ } else {
+
+ // Launching the Worker :
+ Constraints constraints = new Constraints.Builder()
+ // Add constraints if needed (e.g., network connectivity)
+ .build();
+
+ PeriodicWorkRequest workRequest = new PeriodicWorkRequest.Builder(
+ FlowerWorker.class, 15, TimeUnit.MINUTES)
+ .setInitialDelay(0, TimeUnit.MILLISECONDS)
+ .setInputData(new Data.Builder()
+ .putString( "dataslice", deviceIdEditText.getText().toString() )
+ .putString( "server", serverIPEditText.getText().toString())
+ .putString( "port" , serverPortEditText.getText().toString())
+ .build())
+ .setConstraints(constraints)
+ .build();
+
+ String uniqueWorkName = "my_unique_periodic_work";
+
+ WorkManager.getInstance(getApplicationContext())
+ .enqueueUniquePeriodicWork(uniqueWorkName, ExistingPeriodicWorkPolicy.KEEP, workRequest);
+
+ // Providing user feedback, e.g., a toast message
+ Toast.makeText(this, "Worker started!", Toast.LENGTH_SHORT).show();
}
+ }
- private void handleMessage(ServerMessage message, MainActivity activity) {
+ // Listener function for the "Stop" button
+ public void stopWorker(View view) {
+ // Cancel the worker
+ WorkManager.getInstance(getApplicationContext()).cancelAllWork();
+ // Providing user feedback again, e.g., a toast message
+ Toast.makeText(this, "Worker stopped!", Toast.LENGTH_SHORT).show();
+ }
- try {
- ByteBuffer[] weights;
- ClientMessage c = null;
- if (message.hasGetParametersIns()) {
- Log.e(TAG, "Handling GetParameters");
- activity.setResultText("Handling GetParameters message from the server.");
+ // Another Listener function to refresh the updates :
- weights = activity.fc.getWeights();
- c = weightsAsProto(weights);
- } else if (message.hasFitIns()) {
- Log.e(TAG, "Handling FitIns");
- activity.setResultText("Handling Fit request from the server.");
+ public void refresh(View view)
+ {
+ refreshRecyclerView();
+ }
- List layers = message.getFitIns().getParameters().getTensorsList();
+ // Another Listener to clear the contents of the File :
- Scalar epoch_config = message.getFitIns().getConfigMap().getOrDefault("local_epochs", Scalar.newBuilder().setSint64(1).build());
+ public void clear(View view)
+ {
+ clearFileContents(getApplicationContext() , "FlowerResults.txt");
+ }
- assert epoch_config != null;
- int local_epochs = (int) epoch_config.getSint64();
- // Our model has 10 layers
- ByteBuffer[] newWeights = new ByteBuffer[10] ;
- for (int i = 0; i < 10; i++) {
- newWeights[i] = ByteBuffer.wrap(layers.get(i).toByteArray());
- }
+ private void refreshRecyclerView() {
+ // Get messages from MessageRepository using the getMessagesArray method
+ List messages = readStringFromFile( getApplicationContext() ,"FlowerResults.txt");
- Pair outputs = activity.fc.fit(newWeights, local_epochs);
- c = fitResAsProto(outputs.first, outputs.second);
- } else if (message.hasEvaluateIns()) {
- Log.e(TAG, "Handling EvaluateIns");
- activity.setResultText("Handling Evaluate request from the server");
+ // Update the data source of your adapter with the new messages
+ messageAdapter.setData(messages);
- List layers = message.getEvaluateIns().getParameters().getTensorsList();
+ // Notify the adapter that the data has changed
+ messageAdapter.notifyDataSetChanged();
- // Our model has 10 layers
- ByteBuffer[] newWeights = new ByteBuffer[10] ;
- for (int i = 0; i < 10; i++) {
- newWeights[i] = ByteBuffer.wrap(layers.get(i).toByteArray());
- }
- Pair, Integer> inference = activity.fc.evaluate(newWeights);
+ }
- float loss = inference.first.first;
- float accuracy = inference.first.second;
- activity.setResultText("Test Accuracy after this round = " + accuracy);
- int test_size = inference.second;
- c = evaluateResAsProto(loss, test_size);
- }
- requestObserver.onNext(c);
- activity.setResultText("Response sent to the server");
- }
- catch (Exception e){
- Log.e(TAG, e.getMessage());
- }
+
+
+ // following code is for just the permissions :
+ private void toggleBatteryOptimization() {
+ if (isBatteryOptimizationEnabled()) {
+ disableBatteryOptimization();
+ } else {
+ requestBatteryOptimization();
}
}
- private static ClientMessage weightsAsProto(ByteBuffer[] weights){
- List layers = new ArrayList<>();
- for (ByteBuffer weight : weights) {
- layers.add(ByteString.copyFrom(weight));
+ private boolean isBatteryOptimizationEnabled() {
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ String packageName = getPackageName();
+ PowerManager powerManager = (PowerManager) getSystemService(Context.POWER_SERVICE);
+ return powerManager.isIgnoringBatteryOptimizations(packageName);
}
- Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build();
- ClientMessage.GetParametersRes res = ClientMessage.GetParametersRes.newBuilder().setParameters(p).build();
- return ClientMessage.newBuilder().setGetParametersRes(res).build();
+ // Battery optimization is not available on versions prior to M, so return false.
+ return false;
}
- private static ClientMessage fitResAsProto(ByteBuffer[] weights, int training_size){
- List layers = new ArrayList<>();
- for (ByteBuffer weight : weights) {
- layers.add(ByteString.copyFrom(weight));
+ private void disableBatteryOptimization() {
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ Intent intent = new Intent(Settings.ACTION_IGNORE_BATTERY_OPTIMIZATION_SETTINGS);
+ startActivity(intent);
}
- Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build();
- ClientMessage.FitRes res = ClientMessage.FitRes.newBuilder().setParameters(p).setNumExamples(training_size).build();
- return ClientMessage.newBuilder().setFitRes(res).build();
}
- private static ClientMessage evaluateResAsProto(float accuracy, int testing_size){
- ClientMessage.EvaluateRes res = ClientMessage.EvaluateRes.newBuilder().setLoss(accuracy).setNumExamples(testing_size).build();
- return ClientMessage.newBuilder().setEvaluateRes(res).build();
+ private void requestBatteryOptimization() {
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ Intent intent = new Intent(Settings.ACTION_REQUEST_IGNORE_BATTERY_OPTIMIZATIONS);
+ intent.setData(Uri.parse("package:" + getPackageName()));
+ startActivity(intent);
+// startActivityForResult(intent, BATTERY_OPTIMIZATION_REQUEST_CODE);
+ }
+ }
+
+ public void createEmptyFile(String fileName) {
+ try {
+ File file = new File(fileName);
+
+ // Create the file if it doesn't exist
+ if (!file.exists()) {
+ file.createNewFile();
+ }
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
}
}
+
+
diff --git a/examples/android/client/app/src/main/java/flwr/android_client/MessageAdapter.java b/examples/android/client/app/src/main/java/flwr/android_client/MessageAdapter.java
new file mode 100644
index 000000000000..75bcfdbdbe66
--- /dev/null
+++ b/examples/android/client/app/src/main/java/flwr/android_client/MessageAdapter.java
@@ -0,0 +1,58 @@
+package flwr.android_client;
+
+import android.view.LayoutInflater;
+import android.view.View;
+import android.view.ViewGroup;
+import android.widget.TextView;
+import androidx.annotation.NonNull;
+import androidx.recyclerview.widget.RecyclerView;
+import java.util.List;
+
+public class MessageAdapter extends RecyclerView.Adapter {
+
+ private List messages;
+
+ // Constructor to initialize the data source
+ public MessageAdapter(List messages) {
+ this.messages = messages;
+ }
+
+ @NonNull
+ @Override
+ public MessageViewHolder onCreateViewHolder(@NonNull ViewGroup parent, int viewType) {
+ View view = LayoutInflater.from(parent.getContext()).inflate(R.layout.item_message, parent, false);
+ return new MessageViewHolder(view);
+ }
+
+ @Override
+ public void onBindViewHolder(@NonNull MessageViewHolder holder, int position) {
+ String message = messages.get(position);
+ holder.bind(message);
+ }
+
+ @Override
+ public int getItemCount() {
+ return messages != null ? messages.size() : 0;
+ }
+
+ public void setData(List messages) {
+ this.messages = messages;
+ notifyDataSetChanged();
+ }
+
+
+ // ViewHolder class
+ public static class MessageViewHolder extends RecyclerView.ViewHolder {
+ TextView messageTextView;
+
+ public MessageViewHolder(@NonNull View itemView) {
+ super(itemView);
+ messageTextView = itemView.findViewById(R.id.messageTextView);
+ }
+
+ // Bind data to the TextView
+ public void bind(String message) {
+ messageTextView.setText(message);
+ }
+ }
+}
diff --git a/examples/android/client/app/src/main/res/layout/activity_main.xml b/examples/android/client/app/src/main/res/layout/activity_main.xml
index 543f1eb1cd65..7d98e65823be 100644
--- a/examples/android/client/app/src/main/res/layout/activity_main.xml
+++ b/examples/android/client/app/src/main/res/layout/activity_main.xml
@@ -1,131 +1,200 @@
-
+
-
+
-
-
-
-
-
+ android:layout_margin="8dp"
+ android:hint="Client Partition ID (1-10)"
+ android:inputType="numberDecimal"
+ android:textAppearance="@style/TextAppearance.AppCompat.Medium"
+ android:textColor="#4a5663"
+ app:layout_constraintStart_toStartOf="parent"
+ app:layout_constraintEnd_toEndOf="parent"
+ app:layout_constraintTop_toTopOf="parent" />
+
-
+ android:textColor="#4a5663"
+ app:layout_constraintStart_toStartOf="parent"
+ app:layout_constraintEnd_toEndOf="parent"
+ app:layout_constraintTop_toBottomOf="@+id/device_id_edit_text" />
+
+ android:textColor="#4a5663"
+ app:layout_constraintStart_toStartOf="parent"
+ app:layout_constraintEnd_toEndOf="parent"
+ app:layout_constraintTop_toBottomOf="@+id/serverIP" />
+
+
+
+
-
+ android:onClick="startWorker"
+ android:text="Start"
+ app:layout_constraintEnd_toEndOf="parent"
+ app:layout_constraintHorizontal_bias="0.328"
+ app:layout_constraintStart_toStartOf="parent"
+ app:layout_constraintTop_toBottomOf="@+id/serverPort" />
+ android:enabled="true"
+ android:onClick="stopWorker"
+ android:text="Stop"
+ app:layout_constraintEnd_toEndOf="parent"
+ app:layout_constraintHorizontal_bias="0.676"
+ app:layout_constraintStart_toStartOf="parent"
+ app:layout_constraintTop_toBottomOf="@+id/serverPort" />
+ android:onClick="batteryOptimisation"
+ android:text="battery_optimisation"
+ app:layout_constraintStart_toStartOf="parent"
+ app:layout_constraintEnd_toEndOf="parent"
+ app:layout_constraintTop_toBottomOf="@+id/connect" />
+
+
+
+
+
+
+
+
+ android:layout_marginStart="16dp"
+ android:layout_marginTop="12dp"
+ android:layout_marginEnd="16dp"
+ android:text="Results"
+ android:textAppearance="@style/TextAppearance.AppCompat.Large"
+ android:textColor="#000000"
+ android:textStyle="bold"
+ app:layout_constraintEnd_toEndOf="parent"
+ app:layout_constraintHorizontal_bias="0.16"
+ app:layout_constraintStart_toStartOf="parent"
+ app:layout_constraintTop_toBottomOf="@id/battery_optimisation" />
-
+
+
-
-
+ app:layout_constraintBottom_toBottomOf="parent"
+ app:layout_constraintEnd_toEndOf="parent"
+ app:layout_constraintHorizontal_bias="0.0"
+ app:layout_constraintStart_toStartOf="parent"
+ app:layout_constraintTop_toBottomOf="@id/results_heading"
+ app:layout_constraintVertical_bias="0.0" />
+
+
+
+
+
diff --git a/examples/android/client/app/src/main/res/layout/item_message.xml b/examples/android/client/app/src/main/res/layout/item_message.xml
new file mode 100644
index 000000000000..e3f7488ad4c4
--- /dev/null
+++ b/examples/android/client/app/src/main/res/layout/item_message.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
diff --git a/examples/android/client/app/src/main/res/values/strings.xml b/examples/android/client/app/src/main/res/values/strings.xml
index 12b3ba5e5ebd..2fcdf7b0f1e0 100644
--- a/examples/android/client/app/src/main/res/values/strings.xml
+++ b/examples/android/client/app/src/main/res/values/strings.xml
@@ -1,3 +1,6 @@
Flower
+ 1002
+ Federated learning running
+ Cancel federated learning