Skip to content

Commit

Permalink
Migrate TensorFlow quickstart to Flower Datasets (#2318)
Browse files Browse the repository at this point in the history
Co-authored-by: jafermarq <[email protected]>
  • Loading branch information
adam-narozniak and jafermarq authored Dec 20, 2023
1 parent c078d3a commit f6a10f9
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 7 deletions.
8 changes: 4 additions & 4 deletions examples/quickstart-tensorflow/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Flower Example using TensorFlow/Keras

This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understanding how to adapt Flower to your use-cases.
Running this example in itself is quite easy.
This introductory example to Flower uses Keras but deep knowledge of Keras is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case.
Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset.

## Project Setup

Expand Down Expand Up @@ -50,7 +50,7 @@ pip install -r requirements.txt

## Run Federated Learning with TensorFlow/Keras and Flower

Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:
Afterward, you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:

```shell
poetry run python3 server.py
Expand All @@ -62,7 +62,7 @@ Now you are ready to start the Flower clients which will participate in the lear
poetry run python3 client.py
```

Alternatively you can run all of it in one shell as follows:
Alternatively, you can run all of it in one shell as follows:

```shell
poetry run python3 server.py &
Expand Down
26 changes: 24 additions & 2 deletions examples/quickstart-tensorflow/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,38 @@
import argparse
import os

import flwr as fl
import tensorflow as tf

from flwr_datasets import FederatedDataset

# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Parse arguments
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--node-id",
type=int,
choices=[0, 1, 2],
required=True,
help="Partition of the dataset (0,1 or 2). "
"The dataset is divided into 3 partitions created artificially.",
)
args = parser.parse_args()

# Load model and data (MobileNetV2, CIFAR-10)
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Download and partition dataset
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
partition = fds.load_partition(args.node_id, "train")
partition.set_format("numpy")

# Divide data on each node: 80% train, 20% test
partition = partition.train_test_split(test_size=0.2)
x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]
x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"]


# Define Flower client
Expand Down
1 change: 1 addition & 0 deletions examples/quickstart-tensorflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ authors = ["The Flower Authors <[email protected]>"]
[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" }
tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""}
tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""}
1 change: 1 addition & 0 deletions examples/quickstart-tensorflow/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
flwr>=1.0, <2.0
flwr-datasets[vision]>=0.0.2, <1.0.0
tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64"
tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64"
2 changes: 1 addition & 1 deletion examples/quickstart-tensorflow/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start

for i in `seq 0 1`; do
echo "Starting client $i"
python client.py &
python client.py --node-id $i &
done

# This will allow you to use CTRL+C to stop all background processes
Expand Down
17 changes: 17 additions & 0 deletions examples/quickstart-tensorflow/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


# Define strategy
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)

# Start Flower server
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)

0 comments on commit f6a10f9

Please sign in to comment.