From dfa30a30681042e1ad03264fa4aa5c9a1af1960a Mon Sep 17 00:00:00 2001 From: Gustavo Bertoli Date: Tue, 23 Jan 2024 15:25:47 +0100 Subject: [PATCH] Example for Custom Metrics calculation during Federated Learning (#1958) Co-authored-by: Yan Gao --- examples/custom-metrics/README.md | 106 +++++++++++++++++++++++ examples/custom-metrics/client.py | 71 +++++++++++++++ examples/custom-metrics/pyproject.toml | 19 ++++ examples/custom-metrics/requirements.txt | 4 + examples/custom-metrics/run.sh | 15 ++++ examples/custom-metrics/server.py | 58 +++++++++++++ 6 files changed, 273 insertions(+) create mode 100644 examples/custom-metrics/README.md create mode 100644 examples/custom-metrics/client.py create mode 100644 examples/custom-metrics/pyproject.toml create mode 100644 examples/custom-metrics/requirements.txt create mode 100755 examples/custom-metrics/run.sh create mode 100644 examples/custom-metrics/server.py diff --git a/examples/custom-metrics/README.md b/examples/custom-metrics/README.md new file mode 100644 index 000000000000..debcd7919839 --- /dev/null +++ b/examples/custom-metrics/README.md @@ -0,0 +1,106 @@ +# Flower Example using Custom Metrics + +This simple example demonstrates how to calculate custom metrics over multiple clients beyond the traditional ones available in the ML frameworks. In this case, it demonstrates the use of ready-available `scikit-learn` metrics: accuracy, recall, precision, and f1-score. + +Once both the test values (`y_test`) and the predictions (`y_pred`) are available on the client side (`client.py`), other metrics or custom ones are possible to be calculated. + +The main takeaways of this implementation are: + +- the use of the `output_dict` on the client side - inside `evaluate` method on `client.py` +- the use of the `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, part of the `strategy` on `server.py` + +This example is based on the `quickstart-tensorflow` with CIFAR-10, source [here](https://flower.dev/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.dev/docs/datasets/index.html) to retrieve the CIFAR-10. + +Using the CIFAR-10 dataset for classification, this is a multi-class classification problem, thus some changes on how to calculate the metrics using `average='micro'` and `np.argmax` is required. For binary classification, this is not required. Also, for unsupervised learning tasks, such as using a deep autoencoder, a custom metric based on reconstruction error could be implemented on client side. + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/custom-metrics . && rm -rf flower && cd custom-metrics +``` + +This will create a new directory called `custom-metrics` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py +-- server.py +-- run.sh +-- README.md +``` + +### Installing Dependencies + +Project dependencies (such as `scikit-learn`, `tensorflow` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +## Run Federated Learning with Custom Metrics + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +python server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: + +```shell +python client.py +``` + +Alternatively you can run all of it in one shell as follows: + +```shell +python server.py & +# Wait for a few seconds to give the server enough time to start, then: +python client.py & +python client.py +``` + +or + +```shell +chmod +x run.sh +./run.sh +``` + +You will see that Keras is starting a federated training. Have a look to the [Flower Quickstarter documentation](https://flower.dev/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to evaluate that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development). + +Running `run.sh` will result in the following output (after 3 rounds): + +```shell +INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed { + 'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)], + 'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)] +} +``` diff --git a/examples/custom-metrics/client.py b/examples/custom-metrics/client.py new file mode 100644 index 000000000000..b2206118ed44 --- /dev/null +++ b/examples/custom-metrics/client.py @@ -0,0 +1,71 @@ +import os + +import flwr as fl +import numpy as np +import tensorflow as tf +from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score +from flwr_datasets import FederatedDataset + + +# Make TensorFlow log less verbose +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + + +# Load model (MobileNetV2) +model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None) +model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) + +# Load data with Flower Datasets (CIFAR-10) +fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) +train = fds.load_full("train") +test = fds.load_full("test") + +# Using Numpy format +train_np = train.with_format("numpy") +test_np = test.with_format("numpy") +x_train, y_train = train_np["img"], train_np["label"] +x_test, y_test = test_np["img"], test_np["label"] + + +# Method for extra learning metrics calculation +def eval_learning(y_test, y_pred): + acc = accuracy_score(y_test, y_pred) + rec = recall_score( + y_test, y_pred, average="micro" + ) # average argument required for multi-class + prec = precision_score(y_test, y_pred, average="micro") + f1 = f1_score(y_test, y_pred, average="micro") + return acc, rec, prec, f1 + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config): + return model.get_weights() + + def fit(self, parameters, config): + model.set_weights(parameters) + model.fit(x_train, y_train, epochs=1, batch_size=32) + return model.get_weights(), len(x_train), {} + + def evaluate(self, parameters, config): + model.set_weights(parameters) + loss, accuracy = model.evaluate(x_test, y_test) + y_pred = model.predict(x_test) + y_pred = np.argmax(y_pred, axis=1).reshape( + -1, 1 + ) # MobileNetV2 outputs 10 possible classes, argmax returns just the most probable + + acc, rec, prec, f1 = eval_learning(y_test, y_pred) + output_dict = { + "accuracy": accuracy, # accuracy from tensorflow model.evaluate + "acc": acc, + "rec": rec, + "prec": prec, + "f1": f1, + } + return loss, len(x_test), output_dict + + +# Start Flower client +fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) diff --git a/examples/custom-metrics/pyproject.toml b/examples/custom-metrics/pyproject.toml new file mode 100644 index 000000000000..8a2da6562018 --- /dev/null +++ b/examples/custom-metrics/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "custom-metrics" +version = "0.1.0" +description = "Federated Learning with Flower and Custom Metrics" +authors = [ + "The Flower Authors ", + "Gustavo Bertoli " +] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = ">=1.0,<2.0" +flwr-datasets = { version = "*", extras = ["vision"] } +scikit-learn = "^1.2.2" +tensorflow = "==2.12.0" \ No newline at end of file diff --git a/examples/custom-metrics/requirements.txt b/examples/custom-metrics/requirements.txt new file mode 100644 index 000000000000..69d867c5f287 --- /dev/null +++ b/examples/custom-metrics/requirements.txt @@ -0,0 +1,4 @@ +flwr>=1.0,<2.0 +flwr-datasets[vision] +scikit-learn>=1.2.2 +tensorflow==2.12.0 diff --git a/examples/custom-metrics/run.sh b/examples/custom-metrics/run.sh new file mode 100755 index 000000000000..c64f362086aa --- /dev/null +++ b/examples/custom-metrics/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +echo "Starting server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in `seq 0 1`; do + echo "Starting client $i" + python client.py & +done + +# This will allow you to use CTRL+C to stop all background processes +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/custom-metrics/server.py b/examples/custom-metrics/server.py new file mode 100644 index 000000000000..f8420bf51f16 --- /dev/null +++ b/examples/custom-metrics/server.py @@ -0,0 +1,58 @@ +import flwr as fl +import numpy as np + + +# Define metrics aggregation function +def average_metrics(metrics): + """Aggregate metrics from multiple clients by calculating mean averages. + + Parameters: + - metrics (list): A list containing tuples, where each tuple represents metrics for a client. + Each tuple is structured as (num_examples, metric), where: + - num_examples (int): The number of examples used to compute the metrics. + - metric (dict): A dictionary containing custom metrics provided as `output_dict` + in the `evaluate` method from `client.py`. + + Returns: + A dictionary with the aggregated metrics, calculating mean averages. The keys of the + dictionary represent different metrics, including: + - 'accuracy': Mean accuracy calculated by TensorFlow. + - 'acc': Mean accuracy from scikit-learn. + - 'rec': Mean recall from scikit-learn. + - 'prec': Mean precision from scikit-learn. + - 'f1': Mean F1 score from scikit-learn. + + Note: If a weighted average is required, the `num_examples` parameter can be leveraged. + + Example: + Example `metrics` list for two clients after the last round: + [(10000, {'prec': 0.108, 'acc': 0.108, 'f1': 0.108, 'accuracy': 0.1080000028014183, 'rec': 0.108}), + (10000, {'f1': 0.108, 'rec': 0.108, 'accuracy': 0.1080000028014183, 'prec': 0.108, 'acc': 0.108})] + """ + + # Here num_examples are not taken into account by using _ + accuracies_tf = np.mean([metric["accuracy"] for _, metric in metrics]) + accuracies = np.mean([metric["acc"] for _, metric in metrics]) + recalls = np.mean([metric["rec"] for _, metric in metrics]) + precisions = np.mean([metric["prec"] for _, metric in metrics]) + f1s = np.mean([metric["f1"] for _, metric in metrics]) + + return { + "accuracy": accuracies_tf, + "acc": accuracies, + "rec": recalls, + "prec": precisions, + "f1": f1s, + } + + +# Define strategy and the custom aggregation function for the evaluation metrics +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=average_metrics) + + +# Start Flower server +fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +)