Example for Custom Metrics calculation during Federated Learning (#1958)
Co-authored-by: Yan Gao <[email protected]>
1 parent d7be8fb, commit dfa30a3
Showing 6 changed files with 273 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Flower Example using Custom Metrics | ||
|
||
This simple example demonstrates how to calculate custom metrics over multiple clients, beyond the standard ones built into ML frameworks. Specifically, it uses readily available `scikit-learn` metrics: accuracy, recall, precision, and F1 score. | ||
|
||
Once both the true labels (`y_test`) and the predictions (`y_pred`) are available on the client side (`client.py`), any other metric, including custom ones, can be calculated. | ||
|
||
The main takeaways of this implementation are: | ||
|
||
- the use of `output_dict` on the client side, inside the `evaluate` method in `client.py` | ||
- the use of `evaluate_metrics_aggregation_fn` to aggregate the metrics on the server side, as part of the `strategy` in `server.py` | ||
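
The two pieces above fit together roughly as sketched below. This is a minimal illustration in plain Python with no Flower dependency: `fake_client_evaluate` and `mean_metrics` are hypothetical names, and the metric values are made up. It assumes, as the docstring in `server.py` describes, that the aggregation function receives a list of `(num_examples, metrics_dict)` tuples.

```python
from statistics import mean
from typing import Dict, List, Tuple

Metrics = Dict[str, float]


def fake_client_evaluate(acc: float, n: int) -> Tuple[float, int, Metrics]:
    """Hypothetical stand-in for a client's evaluate(): (loss, num_examples, output_dict)."""
    output_dict = {"acc": acc, "f1": acc}  # made-up values for illustration
    return 1.0 - acc, n, output_dict


def mean_metrics(results: List[Tuple[int, Metrics]]) -> Metrics:
    """Unweighted mean of every metric key reported by the clients."""
    keys = results[0][1].keys()
    return {k: mean(m[k] for _, m in results) for k in keys}


# Assumption: the aggregation function only sees (num_examples, metrics),
# not the loss, so the loss is dropped here before aggregating.
client_results = [fake_client_evaluate(0.25, 100), fake_client_evaluate(0.75, 300)]
received = [(n, metrics) for _, n, metrics in client_results]
aggregated = mean_metrics(received)
print(aggregated)
```

Any key placed in `output_dict` on the client becomes available to the aggregation function under the same name.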
|
||
This example is based on [`quickstart-tensorflow`](https://flower.dev/docs/quickstart-tensorflow.html) with CIFAR-10, with the addition of [Flower Datasets](https://flower.dev/docs/datasets/index.html) to retrieve CIFAR-10. | ||
|
||
Because CIFAR-10 classification is a multi-class problem, the metric calculation requires two adjustments: `average='micro'` for the `scikit-learn` metrics and `np.argmax` to convert the model's class probabilities into predicted labels. For binary classification, these changes are not required. For unsupervised learning tasks, such as training a deep autoencoder, a custom metric based on reconstruction error could be implemented on the client side instead. | ||
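
To make the multi-class handling concrete, here is a small sketch that uses only NumPy so the arithmetic stays visible. The probabilities and labels are made up, and the micro-averaged scores are computed by hand rather than with `scikit-learn`. For single-label multi-class data, micro-averaged precision, recall, and F1 all equal plain accuracy.

```python
import numpy as np

# Hypothetical softmax-like outputs for 4 samples and 3 classes.
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
    [0.6, 0.3, 0.1],
])
y_true = np.array([0, 2, 1, 1])

# argmax collapses each probability row into a single predicted label.
y_pred = np.argmax(probs, axis=1)

# Micro averaging pools true positives, false positives, and false
# negatives over all classes before computing the ratios.
classes = np.unique(np.concatenate([y_true, y_pred]))
tp = sum(int(np.sum((y_pred == c) & (y_true == c))) for c in classes)
fp = sum(int(np.sum((y_pred == c) & (y_true != c))) for c in classes)
fn = sum(int(np.sum((y_pred != c) & (y_true == c))) for c in classes)

accuracy = float(np.mean(y_pred == y_true))
micro_precision = tp / (tp + fp)
micro_recall = tp / (tp + fn)
micro_f1 = 2 * micro_precision * micro_recall / (micro_precision + micro_recall)

print(accuracy, micro_precision, micro_recall, micro_f1)
```

This is consistent with the identical `acc`, `rec`, `prec`, and `f1` values in the sample output of this README. A different setting such as `average='macro'` would generally give values that differ from accuracy.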
|
||
## Project Setup | ||
|
||
Start by cloning the example project. We prepared a single-line command that you can copy into your shell, which will check out the example for you: | ||
|
||
```shell | ||
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/custom-metrics . && rm -rf flower && cd custom-metrics | ||
``` | ||
|
||
This will create a new directory called `custom-metrics` containing the following files: | ||
|
||
```shell | ||
-- pyproject.toml | ||
-- requirements.txt | ||
-- client.py | ||
-- server.py | ||
-- run.sh | ||
-- README.md | ||
``` | ||
|
||
### Installing Dependencies | ||
|
||
Project dependencies (such as `scikit-learn`, `tensorflow` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. | ||
|
||
#### Poetry | ||
|
||
```shell | ||
poetry install | ||
poetry shell | ||
``` | ||
|
||
Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: | ||
|
||
```shell | ||
poetry run python3 -c "import flwr" | ||
``` | ||
|
||
If you don't see any errors you're good to go! | ||
|
||
#### pip | ||
|
||
Run the commands below in your terminal to install the dependencies listed in `requirements.txt`. | ||
|
||
```shell | ||
python -m venv venv | ||
source venv/bin/activate | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Run Federated Learning with Custom Metrics | ||
|
||
Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: | ||
|
||
```shell | ||
python server.py | ||
``` | ||
|
||
Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: | ||
|
||
```shell | ||
python client.py | ||
``` | ||
|
||
Alternatively you can run all of it in one shell as follows: | ||
|
||
```shell | ||
python server.py & | ||
# Wait for a few seconds to give the server enough time to start, then: | ||
python client.py & | ||
python client.py | ||
``` | ||
|
||
or | ||
|
||
```shell | ||
chmod +x run.sh | ||
./run.sh | ||
``` | ||
|
||
You will see Keras start federated training. Have a look at the [Flower Quickstart documentation](https://flower.dev/docs/quickstart-tensorflow.html) for a detailed explanation. To quickly check that everything works without waiting for client-side training to finish, you can add `steps_per_epoch=3` to `model.fit()` (this saves a lot of time during development). | ||
|
||
Running `run.sh` will result in the following output (after 3 rounds): | ||
|
||
```shell | ||
INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed { | ||
'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)], | ||
'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)], | ||
'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)], | ||
'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)], | ||
'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)] | ||
} | ||
``` |
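
In the output above, `'accuracy'` shows values such as `0.10000000149011612`, while `'acc'` shows `0.1`. A likely explanation, assuming Keras computes its metric in 32-bit floats while `scikit-learn` returns 64-bit floats, is ordinary floating-point widening. The sketch below reproduces the effect with NumPy alone; the variable names are illustrative.

```python
import numpy as np

# 0.1 is not exactly representable in float32; widening the stored value
# to a 64-bit Python float exposes the rounding error.
tf_style = float(np.float32(0.1))

# A plain Python float is already 64-bit, so it prints as 0.1.
sk_style = 0.1

print(tf_style)
print(sk_style)
print(round(tf_style, 4))
```

Both values describe the same accuracy; rounding before logging would make them print identically.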
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import os | ||
|
||
# Make TensorFlow log less verbose (must be set before TensorFlow is imported) | ||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | ||
|
||
import flwr as fl | ||
import numpy as np | ||
import tensorflow as tf | ||
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score | ||
from flwr_datasets import FederatedDataset | ||
|
||
|
||
# Load model (MobileNetV2) | ||
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None) | ||
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) | ||
|
||
# Load data with Flower Datasets (CIFAR-10) | ||
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) | ||
train = fds.load_full("train") | ||
test = fds.load_full("test") | ||
|
||
# Using Numpy format | ||
train_np = train.with_format("numpy") | ||
test_np = test.with_format("numpy") | ||
x_train, y_train = train_np["img"], train_np["label"] | ||
x_test, y_test = test_np["img"], test_np["label"] | ||
|
||
|
||
# Method for extra learning metrics calculation | ||
def eval_learning(y_test, y_pred): | ||
    acc = accuracy_score(y_test, y_pred) | ||
    rec = recall_score( | ||
        y_test, y_pred, average="micro" | ||
    )  # average argument required for multi-class | ||
    prec = precision_score(y_test, y_pred, average="micro") | ||
    f1 = f1_score(y_test, y_pred, average="micro") | ||
    return acc, rec, prec, f1 | ||
|
||
|
||
# Define Flower client | ||
class FlowerClient(fl.client.NumPyClient): | ||
    def get_parameters(self, config): | ||
        return model.get_weights() | ||
|
||
    def fit(self, parameters, config): | ||
        model.set_weights(parameters) | ||
        model.fit(x_train, y_train, epochs=1, batch_size=32) | ||
        return model.get_weights(), len(x_train), {} | ||
|
||
    def evaluate(self, parameters, config): | ||
        model.set_weights(parameters) | ||
        loss, accuracy = model.evaluate(x_test, y_test) | ||
        y_pred = model.predict(x_test) | ||
        y_pred = np.argmax(y_pred, axis=1).reshape( | ||
            -1, 1 | ||
        )  # MobileNetV2 outputs 10 possible classes, argmax returns just the most probable | ||
|
||
        acc, rec, prec, f1 = eval_learning(y_test, y_pred) | ||
        output_dict = { | ||
            "accuracy": accuracy,  # accuracy from tensorflow model.evaluate | ||
            "acc": acc, | ||
            "rec": rec, | ||
            "prec": prec, | ||
            "f1": f1, | ||
        } | ||
        return loss, len(x_test), output_dict | ||
|
||
|
||
# Start Flower client | ||
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
[build-system] | ||
requires = ["poetry-core>=1.4.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.poetry] | ||
name = "custom-metrics" | ||
version = "0.1.0" | ||
description = "Federated Learning with Flower and Custom Metrics" | ||
authors = [ | ||
"The Flower Authors <[email protected]>", | ||
"Gustavo Bertoli <gubertoli -at- gmail.com>" | ||
] | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.8,<3.11" | ||
flwr = ">=1.0,<2.0" | ||
flwr-datasets = { version = "*", extras = ["vision"] } | ||
scikit-learn = "^1.2.2" | ||
tensorflow = "==2.12.0" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
flwr>=1.0,<2.0 | ||
flwr-datasets[vision] | ||
scikit-learn>=1.2.2 | ||
tensorflow==2.12.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/bin/bash | ||
|
||
echo "Starting server" | ||
python server.py & | ||
sleep 3 # Sleep for 3s to give the server enough time to start | ||
|
||
for i in `seq 0 1`; do | ||
    echo "Starting client $i" | ||
    python client.py & | ||
done | ||
|
||
# This will allow you to use CTRL+C to stop all background processes | ||
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM | ||
# Wait for all background processes to complete | ||
wait |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import flwr as fl | ||
import numpy as np | ||
|
||
|
||
# Define metrics aggregation function | ||
def average_metrics(metrics): | ||
    """Aggregate metrics from multiple clients by calculating mean averages. | ||
    Parameters: | ||
    - metrics (list): A list containing tuples, where each tuple represents metrics for a client. | ||
      Each tuple is structured as (num_examples, metric), where: | ||
      - num_examples (int): The number of examples used to compute the metrics. | ||
      - metric (dict): A dictionary containing custom metrics provided as `output_dict` | ||
        in the `evaluate` method from `client.py`. | ||
    Returns: | ||
    A dictionary with the aggregated metrics, calculating mean averages. The keys of the | ||
    dictionary represent different metrics, including: | ||
    - 'accuracy': Mean accuracy calculated by TensorFlow. | ||
    - 'acc': Mean accuracy from scikit-learn. | ||
    - 'rec': Mean recall from scikit-learn. | ||
    - 'prec': Mean precision from scikit-learn. | ||
    - 'f1': Mean F1 score from scikit-learn. | ||
    Note: If a weighted average is required, the `num_examples` parameter can be leveraged. | ||
    Example: | ||
    Example `metrics` list for two clients after the last round: | ||
    [(10000, {'prec': 0.108, 'acc': 0.108, 'f1': 0.108, 'accuracy': 0.1080000028014183, 'rec': 0.108}), | ||
    (10000, {'f1': 0.108, 'rec': 0.108, 'accuracy': 0.1080000028014183, 'prec': 0.108, 'acc': 0.108})] | ||
    """ | ||
|
||
    # Here num_examples are not taken into account by using _ | ||
    accuracies_tf = np.mean([metric["accuracy"] for _, metric in metrics]) | ||
    accuracies = np.mean([metric["acc"] for _, metric in metrics]) | ||
    recalls = np.mean([metric["rec"] for _, metric in metrics]) | ||
    precisions = np.mean([metric["prec"] for _, metric in metrics]) | ||
    f1s = np.mean([metric["f1"] for _, metric in metrics]) | ||
|
||
    return { | ||
        "accuracy": accuracies_tf, | ||
        "acc": accuracies, | ||
        "rec": recalls, | ||
        "prec": precisions, | ||
        "f1": f1s, | ||
    } | ||
|
||
|
||
# Define strategy and the custom aggregation function for the evaluation metrics | ||
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=average_metrics) | ||
|
||
|
||
# Start Flower server | ||
fl.server.start_server( | ||
    server_address="0.0.0.0:8080", | ||
    config=fl.server.ServerConfig(num_rounds=3), | ||
    strategy=strategy, | ||
) |
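
The docstring of `average_metrics` notes that `num_examples` can be leveraged when a weighted average is required. A minimal sketch of such a variant follows. It is not part of this commit: `weighted_average_metrics` is a hypothetical name, and the sample values are made up. It takes the same `(num_examples, metric)` list as input and weights each client's metric by its number of examples.

```python
from typing import Dict, List, Tuple

Metrics = Dict[str, float]


def weighted_average_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Hypothetical variant of average_metrics that weights by num_examples."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # No examples reported: nothing meaningful to aggregate.
        return {}
    # Assumes every client reports the same metric keys.
    keys = metrics[0][1].keys()
    return {
        key: sum(num_examples * m[key] for num_examples, m in metrics) / total_examples
        for key in keys
    }


# Made-up results from two clients with different test-set sizes.
example = [
    (100, {"acc": 0.5, "f1": 0.25}),
    (300, {"acc": 1.0, "f1": 0.75}),
]
result = weighted_average_metrics(example)
print(result)
```

Such a function could be passed as `evaluate_metrics_aggregation_fn` in place of `average_metrics`. In this example every client evaluates on the same full test set, so the weighted and unweighted means coincide; the difference only matters when clients hold test sets of different sizes.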