Example for Custom Metrics calculation during Federated Learning (#1958)
Co-authored-by: Yan Gao <[email protected]>
gubertoli and yan-gao-GY authored Jan 23, 2024
1 parent d7be8fb commit dfa30a3
Showing 6 changed files with 273 additions and 0 deletions.
106 changes: 106 additions & 0 deletions examples/custom-metrics/README.md
@@ -0,0 +1,106 @@
# Flower Example using Custom Metrics

This simple example demonstrates how to calculate custom metrics over multiple clients, beyond the traditional ones available in ML frameworks. In this case, it demonstrates the use of readily available `scikit-learn` metrics: accuracy, recall, precision, and F1-score.

Once both the test labels (`y_test`) and the predictions (`y_pred`) are available on the client side (`client.py`), any other metrics, including custom ones, can be calculated.
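
For example, a metric that is not part of this example, such as `balanced_accuracy_score`, could be added the same way. A minimal sketch, with hypothetical stand-ins for the real `y_test` and `y_pred`:

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score

# Hypothetical stand-ins for the arrays available inside `evaluate` in client.py
y_test = np.array([0, 1, 2, 2])
y_pred = np.array([0, 2, 2, 2])

# Any scikit-learn metric that takes (y_true, y_pred) can be added to the output dict
print(balanced_accuracy_score(y_test, y_pred))  # ~0.667 (mean of per-class recall)
```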

The main takeaways of this implementation are:

- the use of `output_dict` on the client side - inside the `evaluate` method in `client.py`
- the use of `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, as part of the `strategy` in `server.py` (both sketched below)
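
As a minimal, self-contained sketch of the aggregation contract (simplified from `average_metrics` in `server.py`): the strategy passes the aggregation function a list of `(num_examples, metrics_dict)` tuples, one per client, and expects a single aggregated dictionary back.

```python
from typing import Dict, List, Tuple


def average_metrics(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Unweighted mean of each metric across clients (num_examples is ignored)."""
    return {
        key: sum(m[key] for _, m in metrics) / len(metrics) for key in metrics[0][1]
    }


# Two hypothetical clients reporting accuracies of 0.5 and 0.25
print(average_metrics([(10000, {"acc": 0.5}), (10000, {"acc": 0.25})]))  # {'acc': 0.375}
```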

This example is based on the `quickstart-tensorflow` example with CIFAR-10 (source [here](https://flower.dev/docs/quickstart-tensorflow.html)), with the addition of [Flower Datasets](https://flower.dev/docs/datasets/index.html) to retrieve CIFAR-10.

Since CIFAR-10 classification is a multi-class problem, the metrics have to be calculated with `average='micro'`, and the model's probability outputs have to be converted to class labels with `np.argmax`. For binary classification, this is not required. For unsupervised learning tasks, such as using a deep autoencoder, a custom metric based on reconstruction error could be implemented on the client side, as sketched below.
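
As a sketch of that unsupervised case (a hypothetical autoencoder, not part of this example), a reconstruction-error metric could be computed on the client like this:

```python
import numpy as np

# Hypothetical data: `x_test` is the input batch and `x_reconstructed` the
# autoencoder output; in a real client both would come from the model
rng = np.random.default_rng(0)
x_test = rng.random((8, 32, 32, 3))
x_reconstructed = x_test + rng.normal(0, 0.01, size=x_test.shape)

# Mean squared reconstruction error over the batch; the value would be
# returned from `evaluate` as e.g. {"reconstruction_error": ...}
reconstruction_error = float(np.mean((x_test - x_reconstructed) ** 2))
print(reconstruction_error)
```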

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/custom-metrics . && rm -rf flower && cd custom-metrics
```

This will create a new directory called `custom-metrics` containing the following files:

```shell
-- pyproject.toml
-- requirements.txt
-- client.py
-- server.py
-- run.sh
-- README.md
```

### Installing Dependencies

Project dependencies (such as `scikit-learn`, `tensorflow` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Run the commands below in your terminal to install the dependencies listed in `requirements.txt`:

```shell
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

## Run Federated Learning with Custom Metrics

Afterwards, you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:

```shell
python server.py
```

Now you are ready to start the Flower clients, which will participate in the learning. To do so, simply open two more terminals and run the following command in each:

```shell
python client.py
```

Alternatively you can run all of it in one shell as follows:

```shell
python server.py &
# Wait for a few seconds to give the server enough time to start, then:
python client.py &
python client.py
```

or

```shell
chmod +x run.sh
./run.sh
```

You will see that Keras is starting a federated training. Have a look at the [Flower Quickstart documentation](https://flower.dev/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to verify that everything works without waiting for the client-side training to finish (this will save you a lot of time during development).
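
For reference, that development-only tweak would change the `fit` call in `client.py` along these lines (not part of the committed example):

```python
# Development only: cap each round at 3 training batches so rounds finish quickly
model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)
```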

Running `run.sh` will result in the following output (after 3 rounds):

```shell
INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed {
'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)],
'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)],
'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)],
'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)],
'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)]
}
```
71 changes: 71 additions & 0 deletions examples/custom-metrics/client.py
@@ -0,0 +1,71 @@
import os

import flwr as fl
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from flwr_datasets import FederatedDataset


# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


# Load model (MobileNetV2)
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

# Load data with Flower Datasets (CIFAR-10)
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
train = fds.load_full("train")
test = fds.load_full("test")

# Using Numpy format
train_np = train.with_format("numpy")
test_np = test.with_format("numpy")
x_train, y_train = train_np["img"], train_np["label"]
x_test, y_test = test_np["img"], test_np["label"]


# Method for extra learning metrics calculation
def eval_learning(y_test, y_pred):
    acc = accuracy_score(y_test, y_pred)
    rec = recall_score(
        y_test, y_pred, average="micro"
    )  # average argument required for multi-class
    prec = precision_score(y_test, y_pred, average="micro")
    f1 = f1_score(y_test, y_pred, average="micro")
    return acc, rec, prec, f1


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return model.get_weights()

    def fit(self, parameters, config):
        model.set_weights(parameters)
        model.fit(x_train, y_train, epochs=1, batch_size=32)
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test)
        y_pred = model.predict(x_test)
        y_pred = np.argmax(y_pred, axis=1).reshape(
            -1, 1
        )  # MobileNetV2 outputs 10 possible classes, argmax returns just the most probable

        acc, rec, prec, f1 = eval_learning(y_test, y_pred)
        output_dict = {
            "accuracy": accuracy,  # accuracy from tensorflow model.evaluate
            "acc": acc,
            "rec": rec,
            "prec": prec,
            "f1": f1,
        }
        return loss, len(x_test), output_dict


# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
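
As a quick standalone sanity check (hypothetical labels, run outside of federated training), the metrics helper can be exercised directly:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def eval_learning(y_test, y_pred):  # copied from client.py for a standalone check
    acc = accuracy_score(y_test, y_pred)
    rec = recall_score(y_test, y_pred, average="micro")
    prec = precision_score(y_test, y_pred, average="micro")
    f1 = f1_score(y_test, y_pred, average="micro")
    return acc, rec, prec, f1


# With micro averaging on a single-label multi-class problem, all four
# values reduce to the fraction of correct predictions
y_true = np.array([0, 1, 2, 2])
y_hat = np.array([0, 2, 2, 2])
print(eval_learning(y_true, y_hat))  # (0.75, 0.75, 0.75, 0.75)
```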
19 changes: 19 additions & 0 deletions examples/custom-metrics/pyproject.toml
@@ -0,0 +1,19 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "custom-metrics"
version = "0.1.0"
description = "Federated Learning with Flower and Custom Metrics"
authors = [
    "The Flower Authors <[email protected]>",
    "Gustavo Bertoli <gubertoli -at- gmail.com>"
]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = { version = "*", extras = ["vision"] }
scikit-learn = "^1.2.2"
tensorflow = "==2.12.0"
4 changes: 4 additions & 0 deletions examples/custom-metrics/requirements.txt
@@ -0,0 +1,4 @@
flwr>=1.0,<2.0
flwr-datasets[vision]
scikit-learn>=1.2.2
tensorflow==2.12.0
15 changes: 15 additions & 0 deletions examples/custom-metrics/run.sh
@@ -0,0 +1,15 @@
#!/bin/bash

echo "Starting server"
python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

for i in `seq 0 1`; do
    echo "Starting client $i"
    python client.py &
done

# This will allow you to use CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
58 changes: 58 additions & 0 deletions examples/custom-metrics/server.py
@@ -0,0 +1,58 @@
import flwr as fl
import numpy as np


# Define metrics aggregation function
def average_metrics(metrics):
    """Aggregate metrics from multiple clients by calculating mean averages.

    Parameters:
    - metrics (list): A list containing tuples, where each tuple represents metrics
      for a client. Each tuple is structured as (num_examples, metric), where:
        - num_examples (int): The number of examples used to compute the metrics.
        - metric (dict): A dictionary containing custom metrics provided as
          `output_dict` in the `evaluate` method from `client.py`.

    Returns:
    A dictionary with the aggregated metrics, calculating mean averages. The keys of
    the dictionary represent different metrics, including:
    - 'accuracy': Mean accuracy calculated by TensorFlow.
    - 'acc': Mean accuracy from scikit-learn.
    - 'rec': Mean recall from scikit-learn.
    - 'prec': Mean precision from scikit-learn.
    - 'f1': Mean F1 score from scikit-learn.

    Note: If a weighted average is required, the `num_examples` parameter can be
    leveraged.

    Example:
    Example `metrics` list for two clients after the last round:
    [(10000, {'prec': 0.108, 'acc': 0.108, 'f1': 0.108, 'accuracy': 0.1080000028014183, 'rec': 0.108}),
     (10000, {'f1': 0.108, 'rec': 0.108, 'accuracy': 0.1080000028014183, 'prec': 0.108, 'acc': 0.108})]
    """

    # Here num_examples is not taken into account (discarded with _)
    accuracies_tf = np.mean([metric["accuracy"] for _, metric in metrics])
    accuracies = np.mean([metric["acc"] for _, metric in metrics])
    recalls = np.mean([metric["rec"] for _, metric in metrics])
    precisions = np.mean([metric["prec"] for _, metric in metrics])
    f1s = np.mean([metric["f1"] for _, metric in metrics])

    return {
        "accuracy": accuracies_tf,
        "acc": accuracies,
        "rec": recalls,
        "prec": precisions,
        "f1": f1s,
    }


# Define strategy with the custom aggregation function for the evaluation metrics
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=average_metrics)


# Start Flower server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)
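
As the docstring notes, `num_examples` enables a weighted average. A sketch of that variant (a drop-in replacement for `average_metrics` under the same `metrics` structure; not part of this commit):

```python
def weighted_average_metrics(metrics):
    """Aggregate metrics weighted by each client's number of test examples."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    return {
        key: sum(num_examples * m[key] for num_examples, m in metrics) / total_examples
        for key in metrics[0][1]
    }
```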
