Example for Custom Metrics calculation during Federated Learning #1958

Merged: 34 commits, Jan 23, 2024
Changes shown below are from 16 commits.

Commits:
- 92e598c Update sim.ipynb (gubertoli, Mar 6, 2022)
- 7d58b6f Merge branch 'main' into main (danieljanes, Mar 14, 2022)
- 2311465 Merge branch 'main' into main (danieljanes, Mar 14, 2022)
- b559f88 Merge branch 'adap:main' into main (gubertoli, Mar 15, 2022)
- 6372cc7 Merge branch 'adap:main' into main (gubertoli, Jun 28, 2022)
- 64c7124 Merge branch 'adap:main' into main (gubertoli, Jul 12, 2022)
- 0770547 Change to comply with .fit() tuple requirements (gubertoli, Jul 12, 2022)
- 6a9076f Merge branch 'main' into main (danieljanes, Jul 13, 2022)
- 99645ef Merge branch 'adap:main' into main (gubertoli, Jul 13, 2022)
- 6774045 Merge branch 'adap:main' into main (gubertoli, Jun 22, 2023)
- 1eefe76 custom metrics example (gubertoli, Jun 22, 2023)
- 0edb274 Merge branch 'main' into extra_metrics (gubertoli, Oct 18, 2023)
- e441f7a Merge branch 'main' into extra_metrics (gubertoli, Jan 16, 2024)
- 63ee989 Format and test ok (gubertoli, Jan 16, 2024)
- 5e25c3c README (gubertoli, Jan 16, 2024)
- 947897a Merge branch 'main' into extra_metrics (danieljanes, Jan 17, 2024)
- 949b4ee Update examples/custom-metrics/requirements.txt (gubertoli, Jan 17, 2024)
- 8298aec Update examples/custom-metrics/client.py (gubertoli, Jan 17, 2024)
- c5be003 Update to FlowerClient class and added e-mail (gubertoli, Jan 17, 2024)
- 41043b7 Using flwr-datasets and tested with pip and poetry (gubertoli, Jan 17, 2024)
- 5762f59 Merge branch 'main' into extra_metrics (gubertoli, Jan 17, 2024)
- 76e0711 Merge branch 'main' into extra_metrics (gubertoli, Jan 17, 2024)
- 8474a75 Merge branch 'main' into extra_metrics (gubertoli, Jan 17, 2024)
- 9d7002f Merge branch 'main' into extra_metrics (danieljanes, Jan 18, 2024)
- 56e0b36 Uppercase comment (gubertoli, Jan 18, 2024)
- d3eebe7 Uppercase comment (gubertoli, Jan 18, 2024)
- 290ac5a Add comment (gubertoli, Jan 18, 2024)
- 8d1c380 Add comment about waiting for server.py (gubertoli, Jan 18, 2024)
- 4b0978d Add comment about strategy definition (gubertoli, Jan 18, 2024)
- 5d9319b Fix typos (gubertoli, Jan 18, 2024)
- 8e29ba8 Fix typo (gubertoli, Jan 18, 2024)
- 01e3ba2 Add missing reference to run.sh (gubertoli, Jan 18, 2024)
- e5ecd6b Improving docstring about mean average and about weighted average (gubertoli, Jan 18, 2024)
- 73f2cac Merge branch 'main' into extra_metrics (yan-gao-GY, Jan 23, 2024)

83 changes: 83 additions & 0 deletions examples/custom-metrics/README.md
# Flower Example using Custom Metrics

This simple example demonstrates how to calculate custom metrics across multiple clients, beyond the standard metrics available in ML frameworks. In this case, it uses readily available scikit-learn metrics: accuracy, recall, precision, and F1-score.

Once both the test labels (`y_test`) and the predictions (`y_pred`) are available on the client side (`client.py`), other standard or fully custom metrics can be calculated.

The main takeaways of this implementation are:

- The use of `output_dict` on the client side, inside the `evaluate` method in `client.py`
- The use of `evaluate_metrics_aggregation_fn` to aggregate the metrics on the server side, as part of the `strategy` in `server.py`

This example is based on the `quickstart_tensorflow` example with CIFAR-10 (source [here](https://flower.dev/docs/quickstart-tensorflow.html)).

Because CIFAR-10 classification is a multi-class problem, the metrics have to be calculated with `average='micro'` and the predicted class has to be extracted with `np.argmax`; for binary classification this is not required. For unsupervised learning tasks, such as a deep autoencoder, a custom metric based on the reconstruction error could be implemented on the client side instead.
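
For illustration only, a client-side reconstruction-error metric for such an autoencoder case could look like the sketch below (`autoencoder` and `x_test` are assumed to exist; this snippet is not part of the example files):

```python
import numpy as np


def reconstruction_error(autoencoder, x_test):
    # Mean squared reconstruction error as a custom, label-free metric
    x_reconstructed = autoencoder.predict(x_test)
    return float(np.mean(np.square(x_test - x_reconstructed)))


# The value could then be added to the dictionary returned by `evaluate`, e.g.:
# output_dict = {"reconstruction_error": reconstruction_error(autoencoder, x_test)}
```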

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell to check out the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/custom-metrics . && rm -rf flower && cd custom-metrics
```

This will create a new directory called `custom-metrics` containing the following files:

```shell
-- pyproject.toml
-- requirements.txt
-- client.py
-- server.py
-- run.sh
-- README.md
```

### Installing Dependencies

Project dependencies (such as `scikit-learn`, `tensorflow` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Run the command below in your terminal to install the dependencies listed in `requirements.txt`.

```shell
pip install -r requirements.txt
```

## Run Federated Learning with Custom Metrics

Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:

```shell
poetry run python3 server.py
```

Now you are ready to start the Flower clients, which will participate in the learning. To do so, simply open two more terminals and run the following command in each:

```shell
poetry run python3 client.py
```

Alternatively, you can run all of it in one shell as follows (see also the included `run.sh` script):

```shell
poetry run python3 server.py &
poetry run python3 client.py &
poetry run python3 client.py
```

You will see that Keras is starting a federated training. Have a look at the [Flower Quickstarter documentation](https://flower.dev/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to check that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development).
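
As a sketch, that development-time tweak to the `fit` call in `client.py` would look like this (`steps_per_epoch=3` caps each local epoch at three batches, so rounds finish quickly):

```python
# Inside CifarClient.fit(): limit each local epoch to 3 batches for quick smoke tests
model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)
```
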
60 changes: 60 additions & 0 deletions examples/custom-metrics/client.py
import os

import flwr as fl
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score


# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


# Load model and data (MobileNetV2, CIFAR-10)
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()


# Method for extra learning metrics calculation
def eval_learning(y_test, y_pred):
    acc = accuracy_score(y_test, y_pred)
    rec = recall_score(
        y_test, y_pred, average="micro"
    )  # average argument required for multi-class
    prec = precision_score(y_test, y_pred, average="micro")
    f1 = f1_score(y_test, y_pred, average="micro")
    return acc, rec, prec, f1


# Define Flower client
class CifarClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return model.get_weights()

    def fit(self, parameters, config):
        model.set_weights(parameters)
        model.fit(x_train, y_train, epochs=1, batch_size=32)
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test)
        y_pred = model.predict(x_test)
        y_pred = np.argmax(y_pred, axis=1).reshape(
            -1, 1
        )  # MobileNetV2 outputs 10 possible classes, argmax returns just the most probable

        acc, rec, prec, f1 = eval_learning(y_test, y_pred)
        output_dict = {
            "accuracy": accuracy,  # Accuracy from tensorflow model.evaluate
            "acc": acc,
            "rec": rec,
            "prec": prec,
            "f1": f1,
        }
        return loss, len(x_test), output_dict


# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=CifarClient())
19 changes: 19 additions & 0 deletions examples/custom-metrics/pyproject.toml
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "custom-metrics"
version = "0.1.0"
description = "Federated Learning with Flower and Custom Metrics"
authors = [
"The Flower Authors <[email protected]>",
"Gustavo Bertoli <>"
]

[tool.poetry.dependencies]
python = "^3.8"
flwr = ">=1.0,<2.0"
scikit-learn = "^1.2.2"
tensorflow-cpu = {version = "^2.9.1, !=2.11.1", markers="platform_machine == 'x86_64'"}
tensorflow-macos = {version = "^2.9.1, !=2.11.1", markers="sys_platform == 'darwin' and platform_machine == 'arm64'"}
3 changes: 3 additions & 0 deletions examples/custom-metrics/requirements.txt
flwr==1.4.0
scikit-learn==1.2.2
tensorflow==2.12.0
15 changes: 15 additions & 0 deletions examples/custom-metrics/run.sh
#!/bin/bash

echo "Starting server"
python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

for i in `seq 0 1`; do
    echo "Starting client $i"
    python client.py &
done

# This will allow you to use CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
29 changes: 29 additions & 0 deletions examples/custom-metrics/server.py
import flwr as fl
import numpy as np


# Define an aggregation function for the custom metrics reported by the clients
def average_metrics(metrics):
    """Aggregate the evaluation metrics from all clients using an unweighted mean."""
    accuracies_tf = np.mean([metric["accuracy"] for _, metric in metrics])
    accuracies = np.mean([metric["acc"] for _, metric in metrics])
    recalls = np.mean([metric["rec"] for _, metric in metrics])
    precisions = np.mean([metric["prec"] for _, metric in metrics])
    f1s = np.mean([metric["f1"] for _, metric in metrics])

    return {
        "accuracy": accuracies_tf,
        "acc": accuracies,
        "rec": recalls,
        "prec": precisions,
        "f1": f1s,
    }


# Define the strategy, passing the custom aggregation function for evaluation metrics
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=average_metrics)


# Start Flower server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)
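
Note that `average_metrics` computes an unweighted mean across clients. Since Flower passes the aggregation function a list of `(num_examples, metrics_dict)` tuples, a weighted variant could be swapped in; the sketch below is illustrative and not part of this example:

```python
def weighted_average_metrics(metrics):
    # Weight each client's metrics by its number of evaluation examples
    total_examples = sum(num_examples for num_examples, _ in metrics)
    keys = ["accuracy", "acc", "rec", "prec", "f1"]
    return {
        key: sum(num_examples * metric[key] for num_examples, metric in metrics)
        / total_examples
        for key in keys
    }


# Usage (same wiring as above):
# strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average_metrics)
```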