diff --git a/examples/embedded-devices/Dockerfile b/examples/embedded-devices/Dockerfile
deleted file mode 100644
index 428739114aeb..000000000000
--- a/examples/embedded-devices/Dockerfile
+++ /dev/null
@@ -1,14 +0,0 @@
-ARG BASE_IMAGE
-
-# Pull the base image from NVIDIA
-FROM $BASE_IMAGE
-
-# Update pip
-RUN pip3 install --upgrade pip
-
-# Install flower
-RUN pip3 install flwr>=1.0
-RUN pip3 install flwr-datasets>=0.0.2
-RUN pip3 install tqdm==4.65.0
-
-WORKDIR /client
diff --git a/examples/embedded-devices/README.md b/examples/embedded-devices/README.md
index 86f19399932d..c03646d475ac 100644
--- a/examples/embedded-devices/README.md
+++ b/examples/embedded-devices/README.md
@@ -1,24 +1,25 @@
---
-tags: [basic, vision, fds]
-dataset: [CIFAR-10, MNIST]
-framework: [torch, tensorflow]
+tags: [basic, vision, embedded]
+dataset: [Fashion-MNIST]
+framework: [torch]
---
-# Federated Learning on Embedded Devices with Flower
+# Federated AI with Embedded Devices using Flower
-This example will show you how Flower makes it very easy to run Federated Learning workloads on edge devices. Here we'll be showing how to use NVIDIA Jetson devices and Raspberry Pi as Flower clients. You can run this example using either PyTorch or Tensorflow. The FL workload (i.e. model, dataset and training loop) is mostly borrowed from the [quickstart-pytorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch) and [quickstart-tensorflow](https://github.com/adap/flower/tree/main/examples/quickstart-tensorflow) examples.
+This example will show you how Flower makes it very easy to run Federated Learning workloads on edge devices. Here we'll be showing how to use Raspberry Pi as Flower clients, or better said, `SuperNodes`. The FL workload (i.e. model, dataset and training loop) is mostly borrowed from the [quickstart-pytorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch) example, but you could adjust it to follow [quickstart-tensorflow](https://github.com/adap/flower/tree/main/examples/quickstart-tensorflow) if you prefere using TensorFlow. The main difference compare to those examples is that here you'll learn how to use Flower's Deployment Engine to run FL across multiple embedded devices.
![Different was of running Flower FL on embedded devices](_static/diagram.png)
## Getting things ready
+> \[!NOTE\]
> This example is designed for beginners that know a bit about Flower and/or ML but that are less familiar with embedded devices. If you already have a couple of devices up and running, clone this example and start the Flower clients after launching the Flower server.
This tutorial allows for a variety of settings (some shown in the diagrams above). As long as you have access to one embedded device, you can follow along. This is a list of components that you'll need:
-- For Flower server: A machine running Linux/macOS/Windows (e.g. your laptop). You can run the server on an embedded device too!
-- For Flower clients (one or more): Raspberry Pi 4 (or Zero 2), or an NVIDIA Jetson Xavier-NX (or Nano), or anything similar to these.
-- A uSD card with 32GB or more. While 32GB is enough for the RPi, a larger 64GB uSD card works best for the NVIDIA Jetson.
+- For Flower server: A machine running Linux/macOS (e.g. your laptop). You can run the server on an embedded device too!
+- For Flower clients (one or more): Raspberry Pi 5 or 4 (or Zero 2), or anything similar to these.
+- A uSD card with 32GB or more.
- Software to flash the images to a uSD card:
- For Raspberry Pi we recommend the [Raspberry Pi Imager](https://www.raspberrypi.com/software/)
- For other devices [balenaEtcher](https://www.balena.io/etcher/) it's a great option.
@@ -27,197 +28,120 @@ What follows is a step-by-step guide on how to setup your client/s and the serve
## Clone this example
-Start with cloning this example on your laptop or desktop machine. Later you'll run the same command on your embedded devices. We have prepared a single line which you can copy and execute:
+> \[!NOTE\]
+> Cloning the example and installing the project is only needed for the machine that's going to start the run. The embedded devices would typically run a Flower `SuperNode` for which only `flwr` and relevant libraries needed to run the `ClientApp` (more on this later) are needed.
-```bash
-git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/embedded-devices . && rm -rf flower && cd embedded-devices
+Start with cloning this example on your laptop or desktop machine. We have prepared a single line which you can copy and execute:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git \
+ && mv flower/examples/embedded-devices . \
+ && rm -rf flower && cd embedded-devices
```
-## Setting up the server
+This will create a new directory called `embedded-devices` with the following structure:
+
+```shell
+embedded-devices
+├── embeddedexample
+│ ├── __init__.py
+│ ├── client_app.py # Defines your ClientApp
+│ ├── server_app.py # Defines your ServerApp
+│ └── task.py # Defines your model, training and data loading
+├── pyproject.toml # Project metadata like dependencies and configs
+└── README.md
+```
-The only requirement for the server is to have Flower installed alongside your ML framework of choice. Inside your Python environment run:
+Install the dependencies defined in `pyproject.toml` as well as the `embeddedexample` package.
```bash
-pip install -r requierments_pytorch.txt # to install Flower and PyTorch
-
-# or the below for TensorFlower
-# pip install -r requirements_tensorflow.txt
+pip install -e .
```
-If you are working on this tutorial on your laptop or desktop, it can host the Flower server that will orchestrate the entire FL process. You could also use an embedded device (e.g. a Raspberry Pi) as the Flower server. In order to do that, please follow the setup steps below.
-
## Setting up a Raspberry Pi
-> Wheter you use your RPi as a Flower server or a client, you need to follow these steps.
+> \[!TIP\]
+> This steps walk you through the process of setting up a Rapsberry Pi. If you have one already running and you have a Python environment with `flwr` installed already, you can skip this section entirely. Taking a quick look at the [Embedded Devices Setup](device_setup.md) page might be useful.
![alt text](_static/rpi_imager.png)
1. **Installing Ubuntu server on your Raspberry Pi** is easy with the [Raspberry Pi Imager](https://www.raspberrypi.com/software/). Before starting ensure you have a uSD card attached to your PC/Laptop and that it has sufficient space (ideally larger than 16GB). Then:
- - Click on `CHOOSE OS` > `Other general-pupose OS` > `Ubuntu` > `Ubuntu Server 22.04.03 LTS (64-bit)`. Other versions of `Ubuntu Server` would likely work but try to use a `64-bit` one.
+ - Click on `CHOOSE OS` > `Raspberry Pi OS (other)` > `Raspberry Pi OS Lite (64-bit)`. Other versions of `Raspberry Pi OS` or even `Ubuntu Server` would likely work but try to use a `64-bit` one.
- Select the uSD you want to flash the OS onto. (This will be the uSD you insert in your Raspberry Pi)
- - Click on the gear icon on the bottom right of the `Raspberry Pi Imager` window (the icon only appears after choosing your OS image). Here you can very conveniently set the username/password to access your device over ssh. You'll see I use as username `piubuntu` (you can choose something different) It's also the ideal place to select your WiFi network and add the password (this is of course not needed if you plan to connect the Raspberry Pi via ethernet). Click "save" when you are done.
- - Finally, click on `WRITE` to start flashing Ubuntu onto the uSD card.
-
-2. **Connecting to your Rapsberry Pi**
-
- After `ssh`-ing into your Raspberry Pi for the first time, make sure your OS is up-to-date.
-
- - Run: `sudo apt update` to look for updates
- - And then: `sudo apt upgrade -y` to apply updates (this might take a few minutes on the RPi Zero)
- - Then reboot your RPi with `sudo reboot`. Then ssh into it again.
-
-3. **Preparations for your Flower experiments**
-
- - Install `pip`. In the terminal type: `sudo apt install python3-pip -y`
- - Now clone this directory. You just need to execute the `git clone` command shown at the top of this README.md on your device.
- - Install Flower and your ML framework of choice: We have prepared some convenient installation scripts that will install everything you need. You are free to install other versions of these ML frameworks to suit your needs.
- - If you want your clients to use PyTorch: `pip3 install -r requirements_pytorch.txt`
- - If you want your clients to use TensorFlow: `pip3 install -r requirements_tf.txt`
-
- > While preparing this example I noticed that installing TensorFlow on the **Raspberry pi Zero** would fail due to lack of RAM (it only has 512MB). A workaround is to create a `swap` disk partition (non-existant by default) so the OS can offload some elements to disk. I followed the steps described [in this blogpost](https://www.digitalocean.com/community/tutorials/how-to-add-swap-space-on-ubuntu-20-04) that I copy below. You can follow these steps if you often see your RPi Zero running out of memory:
-
- ```bash
- # Let's create a 1GB swap partition
- sudo fallocate -l 1G /swapfile
- sudo chmod 600 /swapfile
- sudo mkswap /swapfile
- # Enable swap
- sudo swapon /swapfile # you should now be able to see the swap size on htop.
- # make changes permanent after reboot
- sudo cp /etc/fstab /etc/fstab.bak
- echo '/swapfile none swap sw 0 0' | sudo tee -a /etc/fstab
- ```
-
- Please note using swap as if it was RAM comes with a large penalty in terms of data movement.
-
-4. Run your Flower experiments following the steps in the [Running FL with Flower](https://github.com/adap/flower/tree/main/examples/embedded-devices#running-fl-training-with-flower) section.
-
-## Setting up a Jetson Xavier-NX
-
-> These steps have been validated for a Jetson Xavier-NX Dev Kit. An identical setup is needed for a Jetson Nano once you get ssh access to it (i.e. jumping straight to point `4` below). For instructions on how to setup these devices please refer to the "getting started guides" for [Jetson Nano](https://developer.nvidia.com/embedded/learn/get-started-jetson-nano-devkit#intro).
-
-1. **Install JetPack 5.1.2 on your Jetson device**
-
- - Download the JetPack 5.1.2 image from [NVIDIA-embedded](https://developer.nvidia.com/embedded/jetpack-sdk-512), note that you might need an NVIDIA developer account. You can find the download link under the `SD Card Image Method` section on NVIDIA's site. This image comes with Docker pre-installed as well as PyTorch+Torchvision and TensorFlow compiled with GPU support.
+ - After selecting your storage, click on `Next`. Then, you'll be asked if you want to edit the settings of the image you are about to flash. This allows you to setup a custom username and password as well as indicate to which WiFi network your device should connect to. In the screenshot you can see some dummy values. This tutorial doesn't make any assumptions on these values, set them according to your needs.
+ - Finally, complete the remaining steps to start flashing the chosen OS onto the uSD card.
- - Extract the image (~18GB and named `sd-blob.img`) and flash it onto the uSD card using [balenaEtcher](https://www.balena.io/etcher/) (or equivalent).
+2. **Preparations for your Flower experiments**
-2. **Follow [the instructions](https://developer.nvidia.com/embedded/learn/get-started-jetson-xavier-nx-devkit) to set up the device.** The first time you boot your Xavier-NX you should plug it into a display to complete the installation process. After that, a display is no longer needed for this example but you could still use it instead of connecting to your device over ssh.
+ - SSH into your Rapsberry Pi.
+ - Follow the steps outlined in [Embedded Devices Setup](device_setup.md) to set it up for develpment. The objetive of this step is to have your Pi ready to join later as a Flower `SuperNode` to an existing federation.
-3. **Setup Docker**: Docker comes pre-installed with the Ubuntu image provided by NVIDIA. But for convenience, we will create a new user group and add our user to it (with the idea of not having to use `sudo` for every command involving docker (e.g. `docker run`, `docker ps`, etc)). More details about what this entails can be found in the [Docker documentation](https://docs.docker.com/engine/install/linux-postinstall/). You can achieve this by doing:
+3. Run your Flower experiments following the steps in the [Running FL with Flower](https://github.com/adap/flower/tree/main/examples/embedded-devices#running-fl-training-with-flower) section.
- ```bash
- sudo usermod -aG docker $USER
- # apply changes to current shell (or logout/reboot)
- newgrp docker
- ```
+## Embedded Federated AI
-4. **Update OS and install utilities.** Then, install some useful utilities:
+For this demo, we'll be using [Fashion-MNIST](https://huggingface.co/datasets/zalando-datasets/fashion_mnist), a popular dataset for image classification comprised of 10 classes (e.g. boot, dress, trouser) and a total of 70K `28x28` greyscale images. The training set contains 60K images.
- ```bash
- sudo apt update && sudo apt upgrade -y
- # now reboot
- sudo reboot
- ```
+> \[!TIP\]
+> Refer to the [Flower Architecture](https://flower.ai/docs/framework/explanation-flower-architecture.html) page for an overview of the different components involved in a federation.
- Login again and (optional) install the following packages:
+### Ensure your embedded devices have some data
-
+Unless your devices already have some images that could be used to train a small CNN, we need to send a partition of the `Fashion-MNIST` dataset to each device that will run as a `SuperNode`. You can make use of the `generate_dataset.py` script to partition the `Fashion-MNIST` into N disjoint partitions that can be then given to each device in the federation.
- - [jtop](https://github.com/rbonghi/jetson_stats), to monitor CPU/GPU utilization, power consumption and, many more. You can read more about it in [this blog post](https://jetsonhacks.com/2023/02/07/jtop-the-ultimate-tool-for-monitoring-nvidia-jetson-devices/).
-
- ```bash
- # First we need to install pip3
- sudo apt install python3-pip -y
- # finally, install jtop
- sudo pip3 install -U jetson-stats
- # now reboot (or run `sudo systemctl restart jtop.service` and login again)
- sudo reboot
- ```
-
- Now you have installed `jtop`, just launch it by running the `jtop` command on your terminal. An interactive panel similar to the one shown on the right will show up. `jtop` allows you to monitor and control many features of your Jetson device. Read more in the [jtop documentation](https://rnext.it/jetson_stats/jtop/jtop.html)
-
- - [TMUX](https://github.com/tmux/tmux/wiki), a terminal multiplexer. As its name suggests, it allows you to device a single terminal window into multiple panels. In this way, you could (for example) use one panel to show your terminal and another to show `jtop`. That's precisely what the visualization on the right shows.
-
- ```bash
- # install tmux
- sudo apt install tmux -y
- # add mouse support
- echo set -g mouse on > ~/.tmux.conf
- ```
-
-5. **Power modes**. The Jetson devices can operate at different power modes, each making use of more or less CPU cores clocked at different frequencies. The right power mode might very much depend on the application and scenario. When power consumption is not a limiting factor, we could use the highest 15W mode using all 6 CPU cores. On the other hand, if the devices are battery-powered we might want to make use of a low-power mode using 10W and 2 CPU cores. All the details regarding the different power modes of a Jetson Xavier-NX can be found [here](https://docs.nvidia.com/jetson/l4t/index.html#page/Tegra%2520Linux%2520Driver%2520Package%2520Development%2520Guide%2Fpower_management_jetson_xavier.html%23wwpID0E0NO0HA). For this demo, we'll be setting the device to high-performance mode:
-
- ```bash
- sudo /usr/sbin/nvpmodel -m 2 # 15W with 6cpus @ 1.4GHz
- ```
-
- Jetson Stats (that you launch via `jtop`) also allows you to see and set the power mode on your device. Navigate to the `CTRL` panel and click on one of the `NVM modes` available.
-
-6. **Build base client image**. Before running a Flower client, we need to install `Flower` and other ML dependencies (i.e. Pytorch or Tensorflow). Instead of installing this manually via `pip3 install ...`, let's use the pre-built Docker images provided by NVIDIA. In this way, we can be confident that the ML infrastructure is optimized for these devices. Build your Flower client image with:
-
- ```bash
- # On your Jetson's terminal run
- ./build_jetson_flower_client.sh --pytorch # or --tensorflow
- # Bear in mind this might take a few minutes since the base images need to be donwloaded (~7GB) and decompressed.
- # To the above script pass the additional flag `--no-cache` to re-build the image.
- ```
-
- Once your script is finished, verify your `flower_client` Docker image is present. If you type `docker images` you'll see something like the following:
+```shell
+# Partition the Fashion-MNIST dataset into two partitions
+python generate_dataset.py --num-supernodes=2
+```
- ```bash
- REPOSITORY TAG IMAGE ID CREATED SIZE
- flower_client latest 87e935a8ee37 18 seconds ago 12.6GB
- ```
+The above command will create two subdirectories in `./datasets`, one for each partition. Next, copy those dataset over to your devices. You can use `scp` for this. Like shown below. Repeat for all your devices.
-7. **Access your client image**. Before launching the Flower client, we need to run the image we just created. To keep things simpler, let's run the image in interactive mode (`-it`), mount the entire repository you cloned inside the `/client` directory of your container (`` -v `pwd`:/client ``), and use the NVIDIA runtime so we can access the GPU `--runtime nvidia`:
+```shell
+# Copy one partition to a device
+scp -r datasets/fashionmnist_part_1 @:/path/to/home
+```
- ```bash
- # first ensure you are in the `embedded-devices` directory. If you are not, use the `cd` command to navigate to it
+### Launching the Flower `SuperLink`
- # run the client container (this won't launch your Flower client, it will just "take you inside docker". The client can be run following the steps in the next section of the readme)
- docker run -it --rm --runtime nvidia -v `pwd`:/client flower_client
- # this will take you to a shell that looks something like this:
- root@6e6ce826b8bb:/client#
- ```
+On your development machine, launch the `SuperLink`. You will connnect Flower `SuperNodes` to it in the next step.
-8. **Run your FL experiments with Flower**. Follow the steps in the section below.
+> \[!NOTE\]
+> If you decide to run the `SuperLink` in a different machine, you'll need to adjust the `address` under the `[tool.flwr.federations.embedded-federation]` tag in the `pyproject.toml`.
-## Running Embedded FL with Flower
+```shell
+flower-superlink --insecure
+```
-For this demo, we'll be using [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html), a popular dataset for image classification comprised of 10 classes (e.g. car, bird, airplane) and a total of 60K `32x32` RGB images. The training set contains 50K images. The server will automatically download the dataset should it not be found in `./data`. The clients do the same. The dataset is by default split into 50 partitions (each to be assigned to a different client). This can be controlled with the `NUM_CLIENTS` global variable in the client scripts. In this example, each device will play the role of a specific user (specified via `--cid` -- we'll show this later) and therefore only do local training with that portion of the data. For CIFAR-10, clients will be training a MobileNet-v2/3 model.
+### Connecting Flower `SuperNodes`
-You can run this example using MNIST and a smaller CNN model by passing flag `--mnist`. This is useful if you are using devices with a very limited amount of memory (e.g. RaspberryPi Zero) or if you want the training taking place on the embedded devices to be much faster (specially if these are CPU-only). The partitioning of the dataset is done in the same way.
+With the `SuperLink` up and running, now let's launch a `SuperNode` on each embedded device. In order to do this ensure you know what the IP of the machine running the `SuperLink` is and that you have copied the data to the device. Note with `--node-config` we set a key named `dataset-path`. That's the one expected by the `client_fn()` in [client_app.py](embeddedexample/client_app.py). This file will be automatically delivered to the `SuperNode` so it knows how to execute the `ClientApp` logic.
-### Start your Flower Server
+> \[!NOTE\]
+> You don't need to clone this example to your embedded devices running as Flower `SuperNodes`. The code they will execute (in [embeddedexamples/client_app.py](embeddedexamples/client_app.py)) will automatically be delivered.
-On the machine of your choice, launch the server:
+Ensure the Python environment you created earlier when setting up your device has all dependencies installed. For this example you'll need the following:
-```bash
-# Launch your server.
-# Will wait for at least 2 clients to be connected, then will train for 3 FL rounds
-# The command below will sample all clients connected (since sample_fraction=1.0)
-# The server is dataset agnostic (use the same command for MNIST and CIFAR10)
-python server.py --rounds 3 --min_num_clients 2 --sample_fraction 1.0
+```shell
+# After activating your environment
+pip install -U flwr
+pip install torch torchvision datasets
```
-> If you are on macOS with Apple Silicon (i.e. M1, M2 chips), you might encounter a `grpcio`-related issue when launching your server. If you are in a conda environment you can solve this easily by doing: `pip uninstall grpcio` and then `conda install grpcio`.
+Now, launch your `SuperNode` pointing it to the dataset you `scp`-ed earlier:
-### Start the Flower Clients
-
-It's time to launch your clients! Ensure you have followed the setup stages outline above for the devices at your disposal.
+```shell
+# Repeat for each embedded device (adjust SuperLink IP and dataset-path)
+flower-supernode --insecure --superlink="SUPERLINK_IP:9092" \
+ --node-config="dataset-path='path/to/fashionmnist_part_1'"
+```
-The first time you run this, the dataset will be downloaded. From the commands below, replace `` with either `pytorch` or `tf` to run the corresponding client Python file. In a FL setting, each client has its unique dataset. In this example you can simulate this by manually assigning an ID to a client (`cid`) which should be an integer `[0, NUM_CLIENTS-1]`, where `NUM_CLIENTS` is the total number of partitions or clients that could participate at any point. This is defined at the top of the client files and defaults to `50`. You can change this value to make each partition larger or smaller.
+Repeat for each embedded device that you want to connect to the `SuperLink`.
-Launch your Flower clients as follows. Remember that if you are using a Jetson device, you need first to run your Docker container (see tha last steps for the Jetson setup). If you are using Raspberry Pi Zero devices, it is normal if starting the clients take a few seconds.
+### Run the Flower App
-```bash
-# Run the default example (CIFAR-10)
-python3 client_.py --cid= --server_address=
+With both the long-running server (`SuperLink`) and two `SuperNodes` up and running, we can now start run. Note that the command below points to a federation named `embedded-federation`. Its entry point is defined in the `pyproject.toml`. Run the following from your development machine where you have cloned this example to, e.g. your laptop.
-# Use MNIST (and a smaller model) if your devices require a more lightweight workload
-python3 client_.py --cid= --server_address= --mnist
+```shell
+flwr run . embedded-federation
```
-
-Repeat the above for as many devices as you have. Pass a different `CLIENT_ID` to each device. You can naturally run this example using different types of devices (e.g. RPi, RPi Zero, Jetson) at the same time as long as they are training the same model. If you want to start more clients than the number of embedded devices you currently have access to, you can launch clients in your laptop: simply open a new terminal and run one of the `python3 client_.py ...` commands above.
diff --git a/examples/embedded-devices/_static/rpi_imager.png b/examples/embedded-devices/_static/rpi_imager.png
index a59a3137334e..958290fc112f 100644
Binary files a/examples/embedded-devices/_static/rpi_imager.png and b/examples/embedded-devices/_static/rpi_imager.png differ
diff --git a/examples/embedded-devices/_static/tmux_jtop_view.gif b/examples/embedded-devices/_static/tmux_jtop_view.gif
deleted file mode 100644
index 7e92b586851a..000000000000
Binary files a/examples/embedded-devices/_static/tmux_jtop_view.gif and /dev/null differ
diff --git a/examples/embedded-devices/build_jetson_flower_client.sh b/examples/embedded-devices/build_jetson_flower_client.sh
deleted file mode 100755
index 32725a58f1f7..000000000000
--- a/examples/embedded-devices/build_jetson_flower_client.sh
+++ /dev/null
@@ -1,42 +0,0 @@
-#!/bin/bash
-
-if [ -z "${CI}" ]; then
- BUILDKIT=1
-else
- BUILDKIT=0
-fi
-
-# This script build a docker image that's ready to run your flower client.
-# Depending on your choice of ML framework (TF or PyTorch), the appropiate
-# base image from NVIDIA will be pulled. This ensures you get the best
-# performance out of your Jetson device.
-
-BASE_PYTORCH=nvcr.io/nvidia/l4t-pytorch:r35.1.0-pth1.13-py3
-BASE_TF=nvcr.io/nvidia/l4t-tensorflow:r35.3.1-tf2.11-py3
-EXTRA=""
-
-while [[ $# -gt 0 ]]; do
- case $1 in
- -p|--pytorch)
- BASE_IMAGE=$BASE_PYTORCH
- shift
- ;;
- -t|--tensorflow)
- BASE_IMAGE=$BASE_TF
- shift
- ;;
- -r|--no-cache)
- EXTRA="--no-cache"
- shift
- ;;
- -*|--*)
- echo "Unknown option $1 (pass either --pytorch or --tensorflow)"
- exit 1
- ;;
- esac
-done
-
-DOCKER_BUILDKIT=${BUILDKIT} docker build $EXTRA \
- --build-arg BASE_IMAGE=$BASE_IMAGE \
- . \
- -t flower_client:latest
diff --git a/examples/embedded-devices/client_pytorch.py b/examples/embedded-devices/client_pytorch.py
deleted file mode 100644
index 0fee7a854d67..000000000000
--- a/examples/embedded-devices/client_pytorch.py
+++ /dev/null
@@ -1,195 +0,0 @@
-import argparse
-import warnings
-from collections import OrderedDict
-
-import flwr as fl
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from flwr_datasets import FederatedDataset
-from torch.utils.data import DataLoader
-from torchvision.models import mobilenet_v3_small
-from torchvision.transforms import Compose, Normalize, ToTensor
-from tqdm import tqdm
-
-parser = argparse.ArgumentParser(description="Flower Embedded devices")
-parser.add_argument(
- "--server_address",
- type=str,
- default="0.0.0.0:8080",
- help=f"gRPC server address (default '0.0.0.0:8080')",
-)
-parser.add_argument(
- "--cid",
- type=int,
- required=True,
- help="Client id. Should be an integer between 0 and NUM_CLIENTS",
-)
-parser.add_argument(
- "--mnist",
- action="store_true",
- help="If you use Raspberry Pi Zero clients (which just have 512MB or RAM) use "
- "MNIST",
-)
-
-warnings.filterwarnings("ignore", category=UserWarning)
-NUM_CLIENTS = 50
-
-
-class Net(nn.Module):
- """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')."""
-
- def __init__(self) -> None:
- super(Net, self).__init__()
- self.conv1 = nn.Conv2d(1, 6, 5)
- self.pool = nn.MaxPool2d(2, 2)
- self.conv2 = nn.Conv2d(6, 16, 5)
- self.fc1 = nn.Linear(16 * 4 * 4, 120)
- self.fc2 = nn.Linear(120, 84)
- self.fc3 = nn.Linear(84, 10)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = self.pool(F.relu(self.conv1(x)))
- x = self.pool(F.relu(self.conv2(x)))
- x = x.view(-1, 16 * 4 * 4)
- x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
- return self.fc3(x)
-
-
-def train(net, trainloader, optimizer, epochs, device):
- """Train the model on the training set."""
- criterion = torch.nn.CrossEntropyLoss()
- for _ in range(epochs):
- for batch in tqdm(trainloader):
- batch = list(batch.values())
- images, labels = batch[0], batch[1]
- optimizer.zero_grad()
- criterion(net(images.to(device)), labels.to(device)).backward()
- optimizer.step()
-
-
-def test(net, testloader, device):
- """Validate the model on the test set."""
- criterion = torch.nn.CrossEntropyLoss()
- correct, loss = 0, 0.0
- with torch.no_grad():
- for batch in tqdm(testloader):
- batch = list(batch.values())
- images, labels = batch[0], batch[1]
- outputs = net(images.to(device))
- labels = labels.to(device)
- loss += criterion(outputs, labels).item()
- correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
- accuracy = correct / len(testloader.dataset)
- return loss, accuracy
-
-
-def prepare_dataset(use_mnist: bool):
- """Get MNIST/CIFAR-10 and return client partitions and global testset."""
- if use_mnist:
- fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})
- img_key = "image"
- norm = Normalize((0.1307,), (0.3081,))
- else:
- fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS})
- img_key = "img"
- norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- pytorch_transforms = Compose([ToTensor(), norm])
-
- def apply_transforms(batch):
- """Apply transforms to the partition from FederatedDataset."""
- batch[img_key] = [pytorch_transforms(img) for img in batch[img_key]]
- return batch
-
- trainsets = []
- validsets = []
- for partition_id in range(NUM_CLIENTS):
- partition = fds.load_partition(partition_id, "train")
- # Divide data on each node: 90% train, 10% test
- partition = partition.train_test_split(test_size=0.1, seed=42)
- partition = partition.with_transform(apply_transforms)
- trainsets.append(partition["train"])
- validsets.append(partition["test"])
- testset = fds.load_split("test")
- testset = testset.with_transform(apply_transforms)
- return trainsets, validsets, testset
-
-
-# Flower client, adapted from Pytorch quickstart/simulation example
-class FlowerClient(fl.client.NumPyClient):
- """A FlowerClient that trains a MobileNetV3 model for CIFAR-10 or a much smaller CNN
- for MNIST."""
-
- def __init__(self, trainset, valset, use_mnist):
- self.trainset = trainset
- self.valset = valset
- # Instantiate model
- if use_mnist:
- self.model = Net()
- else:
- self.model = mobilenet_v3_small(num_classes=10)
- # Determine device
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- self.model.to(self.device) # send model to device
-
- def set_parameters(self, params):
- """Set model weights from a list of NumPy ndarrays."""
- params_dict = zip(self.model.state_dict().keys(), params)
- state_dict = OrderedDict(
- {
- k: torch.Tensor(v) if v.shape != torch.Size([]) else torch.Tensor([0])
- for k, v in params_dict
- }
- )
- self.model.load_state_dict(state_dict, strict=True)
-
- def get_parameters(self, config):
- return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
-
- def fit(self, parameters, config):
- print("Client sampled for fit()")
- self.set_parameters(parameters)
- # Read hyperparameters from config set by the server
- batch, epochs = config["batch_size"], config["epochs"]
- # Construct dataloader
- trainloader = DataLoader(self.trainset, batch_size=batch, shuffle=True)
- # Define optimizer
- optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
- # Train
- train(self.model, trainloader, optimizer, epochs=epochs, device=self.device)
- # Return local model and statistics
- return self.get_parameters({}), len(trainloader.dataset), {}
-
- def evaluate(self, parameters, config):
- print("Client sampled for evaluate()")
- self.set_parameters(parameters)
- # Construct dataloader
- valloader = DataLoader(self.valset, batch_size=64)
- # Evaluate
- loss, accuracy = test(self.model, valloader, device=self.device)
- # Return statistics
- return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)}
-
-
-def main():
- args = parser.parse_args()
- print(args)
-
- assert args.cid < NUM_CLIENTS
-
- use_mnist = args.mnist
- # Download dataset and partition it
- trainsets, valsets, _ = prepare_dataset(use_mnist)
-
- # Start Flower client setting its associated data partition
- fl.client.start_client(
- server_address=args.server_address,
- client=FlowerClient(
- trainset=trainsets[args.cid], valset=valsets[args.cid], use_mnist=use_mnist
- ).to_client(),
- )
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/embedded-devices/client_tf.py b/examples/embedded-devices/client_tf.py
deleted file mode 100644
index 524404b3ef8b..000000000000
--- a/examples/embedded-devices/client_tf.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import argparse
-import math
-import warnings
-
-import flwr as fl
-import tensorflow as tf
-from flwr_datasets import FederatedDataset
-from tensorflow import keras as keras
-
-parser = argparse.ArgumentParser(description="Flower Embedded devices")
-parser.add_argument(
- "--server_address",
- type=str,
- default="0.0.0.0:8080",
- help=f"gRPC server address (deafault '0.0.0.0:8080')",
-)
-parser.add_argument(
- "--cid",
- type=int,
- required=True,
- help="Client id. Should be an integer between 0 and NUM_CLIENTS",
-)
-parser.add_argument(
- "--mnist",
- action="store_true",
- help="If you use Raspberry Pi Zero clients (which just have 512MB or RAM) use MNIST",
-)
-
-warnings.filterwarnings("ignore", category=UserWarning)
-NUM_CLIENTS = 50
-
-
-def prepare_dataset(use_mnist: bool):
- """Download and partitions the CIFAR-10/MNIST dataset."""
- if use_mnist:
- fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})
- img_key = "image"
- else:
- fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS})
- img_key = "img"
- partitions = []
- for partition_id in range(NUM_CLIENTS):
- partition = fds.load_partition(partition_id, "train")
- partition.set_format("numpy")
- # Divide data on each node: 90% train, 10% test
- partition = partition.train_test_split(test_size=0.1, seed=42)
- x_train, y_train = (
- partition["train"][img_key] / 255.0,
- partition["train"]["label"],
- )
- x_test, y_test = partition["test"][img_key] / 255.0, partition["test"]["label"]
- partitions.append(((x_train, y_train), (x_test, y_test)))
- data_centralized = fds.load_split("test")
- data_centralized.set_format("numpy")
- x_centralized = data_centralized[img_key] / 255.0
- y_centralized = data_centralized["label"]
- return partitions, (x_centralized, y_centralized)
-
-
-class FlowerClient(fl.client.NumPyClient):
- """A FlowerClient that uses MobileNetV3 for CIFAR-10 or a much smaller CNN for
- MNIST."""
-
- def __init__(self, trainset, valset, use_mnist: bool):
- self.x_train, self.y_train = trainset
- self.x_val, self.y_val = valset
- # Instantiate model
- if use_mnist:
- # small model for MNIST
- self.model = keras.Sequential(
- [
- keras.Input(shape=(28, 28, 1)),
- keras.layers.Conv2D(32, kernel_size=(5, 5), activation="relu"),
- keras.layers.MaxPooling2D(pool_size=(2, 2)),
- keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
- keras.layers.MaxPooling2D(pool_size=(2, 2)),
- keras.layers.Flatten(),
- keras.layers.Dropout(0.5),
- keras.layers.Dense(10, activation="softmax"),
- ]
- )
- else:
- # let's use a larger model for cifar
- self.model = tf.keras.applications.MobileNetV3Small(
- (32, 32, 3), classes=10, weights=None
- )
- self.model.compile(
- "adam", "sparse_categorical_crossentropy", metrics=["accuracy"]
- )
-
- def get_parameters(self, config):
- return self.model.get_weights()
-
- def set_parameters(self, params):
- self.model.set_weights(params)
-
- def fit(self, parameters, config):
- print("Client sampled for fit()")
- self.set_parameters(parameters)
- # Set hyperparameters from config sent by server/strategy
- batch, epochs = config["batch_size"], config["epochs"]
- # train
- self.model.fit(self.x_train, self.y_train, epochs=epochs, batch_size=batch)
- return self.get_parameters({}), len(self.x_train), {}
-
- def evaluate(self, parameters, config):
- print("Client sampled for evaluate()")
- self.set_parameters(parameters)
- loss, accuracy = self.model.evaluate(self.x_val, self.y_val)
- return loss, len(self.x_val), {"accuracy": accuracy}
-
-
-def main():
- args = parser.parse_args()
- print(args)
-
- assert args.cid < NUM_CLIENTS
-
- use_mnist = args.mnist
- # Download dataset and partition it
- partitions, _ = prepare_dataset(use_mnist)
- trainset, valset = partitions[args.cid]
-
- # Start Flower client setting its associated data partition
- fl.client.start_client(
- server_address=args.server_address,
- client=FlowerClient(
- trainset=trainset, valset=valset, use_mnist=use_mnist
- ).to_client(),
- )
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/embedded-devices/device_setup.md b/examples/embedded-devices/device_setup.md
new file mode 100644
index 000000000000..642ad4c0f93c
--- /dev/null
+++ b/examples/embedded-devices/device_setup.md
@@ -0,0 +1,94 @@
+# Setting up your Embedded Device
+
+> \[!NOTE\]
+> This guide is applicable to many embedded devices such as Raspberry Pi. This guide assumes you have a fresh install of Raspberry Pi OS Lite or Ubuntu Server (e.g. 22.04) and that you have successfully `ssh`-ed into your device.
+
+## Setting up your device for Python developemnet
+
+We are going to use [`pyenv`](https://github.com/pyenv/pyenv) to manage different Python versions and to create an environment. First, we need to install some system dependencies
+
+```shell
+sudo apt-get update
+# Install python deps relevant for this and other examples
+sudo apt-get install build-essential zlib1g-dev libssl-dev \
+ libsqlite3-dev libreadline-dev libbz2-dev \
+ git libffi-dev liblzma-dev libsndfile1 -y
+
+# Install some good to have
+sudo apt-get install htop tmux -y
+
+# Add mouse support for tmux
+echo "set-option -g mouse on" >> ~/.tmux.conf
+```
+
+It is recommended to work on virtual environments instead of in the global Python environment. Let's install `pyenv` with the `virtualenv` plugin.
+
+### Install `pyenv` and `virtualenv` plugin
+
+```shell
+git clone https://github.com/pyenv/pyenv.git ~/.pyenv
+echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
+echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc
+echo 'eval "$(pyenv init -)"' >> ~/.bashrc
+
+# Now reload .bashrc
+source ~/.bashrc
+
+# Install pyenv virtual env plugin
+git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv
+# Restart your shell
+exec "$SHELL"
+```
+
+## Create a Python environment and activate it
+
+> \[!TIP\]
+> If you are using a Raspberry Pi Zero 2 or another embedded device with a small amount of RAM (e.g. \<1GB), you probably need to extend the size of the SWAP partition. See the guide at the end of this readme.
+
+Now all is ready to create a virtualenvironment. But first, let's install a recent version of Python:
+
+```shell
+# Install python 3.10+
+pyenv install 3.10.14
+
+# Then create a virtual environment
+pyenv virtualenv 3.10.14 my-env
+```
+
+Finally, activate your environment and install the dependencies for your project:
+
+```shell
+# Activate your environment
+pyenv activate my-env
+
+# Then, install flower
+pip install flwr
+
+# Install any other dependency needed for your device
+# Likely your embedded device will run a Flower SuperNode
+# This means you'll likely want to install dependencies that
+# your Flower `ClientApp` needs.
+
+pip install
+```
+
+## Extening SWAP for `RPi Zero 2`
+
+> \[!NOTE\]
+> This mini-guide is useful if your RPi Zero 2 cannot complete installing some packages (e.g. TensorFlow or even Python) or do some processing due to its limited RAM.
+
+A workaround is to create a `swap` disk partition (non-existant by default) so the OS can offload some elements to disk. I followed the steps described [in this blogpost](https://www.digitalocean.com/community/tutorials/how-to-add-swap-space-on-ubuntu-20-04) that I copy below. You can follow these steps if you often see your RPi Zero running out of memory:
+
+```shell
+# Let's create a 1GB swap partition
+sudo fallocate -l 1G /swapfile
+sudo chmod 600 /swapfile
+sudo mkswap /swapfile
+# Enable swap
+sudo swapon /swapfile # you should now be able to see the swap size on htop.
+# make changes permanent after reboot
+sudo cp /etc/fstab /etc/fstab.bak
+echo '/swapfile none swap sw 0 0' | sudo tee -a /etc/fstab
+```
+
+Please note using swap as if it was RAM comes with a large penalty in terms of data movement.
diff --git a/examples/embedded-devices/embeddedexample/__init__.py b/examples/embedded-devices/embeddedexample/__init__.py
new file mode 100644
index 000000000000..d70d6aaf8d39
--- /dev/null
+++ b/examples/embedded-devices/embeddedexample/__init__.py
@@ -0,0 +1 @@
+"""embeddedexample: A Flower / PyTorch app."""
diff --git a/examples/embedded-devices/embeddedexample/client_app.py b/examples/embedded-devices/embeddedexample/client_app.py
new file mode 100644
index 000000000000..442e16b4cb3b
--- /dev/null
+++ b/examples/embedded-devices/embeddedexample/client_app.py
@@ -0,0 +1,64 @@
+"""embeddedexample: A Flower / PyTorch app."""
+
+import torch
+from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
+
+from embeddedexample.task import (
+ Net,
+ get_weights,
+ load_data_from_disk,
+ set_weights,
+ test,
+ train,
+)
+
+
+# Define Flower Client
+class FlowerClient(NumPyClient):
+ def __init__(self, trainloader, valloader, local_epochs, learning_rate):
+ self.net = Net()
+ self.trainloader = trainloader
+ self.valloader = valloader
+ self.local_epochs = local_epochs
+ self.lr = learning_rate
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ def fit(self, parameters, config):
+ """Train the model with data of this client."""
+ set_weights(self.net, parameters)
+ results = train(
+ self.net,
+ self.trainloader,
+ self.valloader,
+ self.local_epochs,
+ self.lr,
+ self.device,
+ )
+ return get_weights(self.net), len(self.trainloader.dataset), results
+
+ def evaluate(self, parameters, config):
+ """Evaluate the model on the data this client has."""
+ set_weights(self.net, parameters)
+ loss, accuracy = test(self.net, self.valloader, self.device)
+ return loss, len(self.valloader.dataset), {"accuracy": accuracy}
+
+
+def client_fn(context: Context):
+ """Construct a Client that will be run in a ClientApp."""
+
+ # Read the node_config to know where dataset is located
+ dataset_path = context.node_config["dataset-path"]
+
+ # Read run_config to fetch hyperparameters relevant to this run
+ batch_size = context.run_config["batch-size"]
+ trainloader, valloader = load_data_from_disk(dataset_path, batch_size)
+ local_epochs = context.run_config["local-epochs"]
+ learning_rate = context.run_config["learning-rate"]
+
+ # Return Client instance
+ return FlowerClient(trainloader, valloader, local_epochs, learning_rate).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(client_fn)
diff --git a/examples/embedded-devices/embeddedexample/server_app.py b/examples/embedded-devices/embeddedexample/server_app.py
new file mode 100644
index 000000000000..59ec72bebbfa
--- /dev/null
+++ b/examples/embedded-devices/embeddedexample/server_app.py
@@ -0,0 +1,46 @@
+"""embeddedexample: A Flower / PyTorch app."""
+
+from typing import List, Tuple
+
+from flwr.common import Context, Metrics, ndarrays_to_parameters
+from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+from flwr.server.strategy import FedAvg
+
+from embeddedexample.task import Net, get_weights
+
+
+# Define metric aggregation function
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+ # Multiply accuracy of each client by number of examples used
+ accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
+ examples = [num_examples for num_examples, _ in metrics]
+
+ # Aggregate and return custom metric (weighted average)
+ return {"accuracy": sum(accuracies) / sum(examples)}
+
+
+def server_fn(context: Context):
+ """Construct components that set the ServerApp behaviour."""
+
+ # Read from config
+ num_rounds = context.run_config["num-server-rounds"]
+
+ # Initialize model parameters
+ ndarrays = get_weights(Net())
+ parameters = ndarrays_to_parameters(ndarrays)
+
+ # Define the strategy
+ strategy = FedAvg(
+ fraction_fit=1.0,
+ fraction_evaluate=context.run_config["fraction-evaluate"],
+ min_available_clients=2,
+ evaluate_metrics_aggregation_fn=weighted_average,
+ initial_parameters=parameters,
+ )
+ config = ServerConfig(num_rounds=num_rounds)
+
+ return ServerAppComponents(strategy=strategy, config=config)
+
+
+# Create ServerApp
+app = ServerApp(server_fn=server_fn)
diff --git a/examples/embedded-devices/embeddedexample/task.py b/examples/embedded-devices/embeddedexample/task.py
new file mode 100644
index 000000000000..f08441c0426a
--- /dev/null
+++ b/examples/embedded-devices/embeddedexample/task.py
@@ -0,0 +1,98 @@
+"""embeddedexample: A Flower / PyTorch app."""
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from datasets import load_from_disk
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose, Normalize, ToTensor
+
+
+class Net(nn.Module):
+ """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
+
+ def __init__(self):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(1, 6, 5)
+ self.pool = nn.MaxPool2d(2, 2)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+ self.fc1 = nn.Linear(16 * 4 * 4, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = self.pool(F.relu(self.conv1(x)))
+ x = self.pool(F.relu(self.conv2(x)))
+ x = x.view(-1, 16 * 4 * 4)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ return self.fc3(x)
+
+
+def get_weights(net):
+ return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def set_weights(net, parameters):
+ params_dict = zip(net.state_dict().keys(), parameters)
+ state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+ net.load_state_dict(state_dict, strict=True)
+
+
+def load_data_from_disk(path: str, batch_size: int):
+ """Load a dataset in Huggingface format from disk and creates dataloaders."""
+ partition_train_test = load_from_disk(path)
+ pytorch_transforms = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
+
+ def apply_transforms(batch):
+ """Apply transforms to the partition from FederatedDataset."""
+ batch["image"] = [pytorch_transforms(img) for img in batch["image"]]
+ return batch
+
+ partition_train_test = partition_train_test.with_transform(apply_transforms)
+ trainloader = DataLoader(
+ partition_train_test["train"], batch_size=batch_size, shuffle=True
+ )
+ testloader = DataLoader(partition_train_test["test"], batch_size=batch_size)
+ return trainloader, testloader
+
+
+def train(net, trainloader, valloader, epochs, learning_rate, device):
+ """Train the model on the training set."""
+ net.to(device) # move model to GPU if available
+ criterion = torch.nn.CrossEntropyLoss().to(device)
+ optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
+ net.train()
+ for _ in range(epochs):
+ for batch in trainloader:
+ images = batch["image"]
+ labels = batch["label"]
+ optimizer.zero_grad()
+ criterion(net(images.to(device)), labels.to(device)).backward()
+ optimizer.step()
+
+ val_loss, val_acc = test(net, valloader, device)
+
+ results = {
+ "val_loss": val_loss,
+ "val_accuracy": val_acc,
+ }
+ return results
+
+
+def test(net, testloader, device):
+ """Validate the model on the test set."""
+ criterion = torch.nn.CrossEntropyLoss()
+ correct, loss = 0, 0.0
+ with torch.no_grad():
+ for batch in testloader:
+ images = batch["image"].to(device)
+ labels = batch["label"].to(device)
+ outputs = net(images)
+ loss += criterion(outputs, labels).item()
+ correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+ accuracy = correct / len(testloader.dataset)
+ loss = loss / len(testloader)
+ return loss, accuracy
diff --git a/examples/embedded-devices/generate_dataset.py b/examples/embedded-devices/generate_dataset.py
new file mode 100644
index 000000000000..e1ab30ad31da
--- /dev/null
+++ b/examples/embedded-devices/generate_dataset.py
@@ -0,0 +1,47 @@
+import argparse
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
+
+
+DATASET_DIRECTORY = "datasets"
+
+
+def save_dataset_to_disk(num_partitions: int):
+ """This function downloads the Fashion-MNIST dataset and generates N partitions.
+
+ Each will be saved into the DATASET_DIRECTORY.
+ """
+ partitioner = IidPartitioner(num_partitions=num_partitions)
+ fds = FederatedDataset(
+ dataset="zalando-datasets/fashion_mnist",
+ partitioners={"train": partitioner},
+ )
+
+ for partition_id in range(num_partitions):
+ partition = fds.load_partition(partition_id)
+ partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
+ file_path = f"./{DATASET_DIRECTORY}/fashionmnist_part_{partition_id + 1}"
+ partition_train_test.save_to_disk(file_path)
+ print(f"Written: {file_path}")
+
+
+if __name__ == "__main__":
+ # Initialize argument parser
+ parser = argparse.ArgumentParser(
+ description="Save Fashion-MNIST dataset partitions to disk"
+ )
+
+ # Add an optional positional argument for number of partitions
+ parser.add_argument(
+ "--num-supernodes",
+ type=int,
+ nargs="?",
+ default=2,
+ help="Number of partitions to create (default: 2)",
+ )
+
+ # Parse the arguments
+ args = parser.parse_args()
+
+ # Call the function with the provided argument
+ save_dataset_to_disk(args.num_supernodes)
diff --git a/examples/embedded-devices/pyproject.toml b/examples/embedded-devices/pyproject.toml
new file mode 100644
index 000000000000..f7354a4e95d2
--- /dev/null
+++ b/examples/embedded-devices/pyproject.toml
@@ -0,0 +1,45 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "embeddedexample"
+version = "1.0.0"
+description = "Federated AI with Embedded Devices using Flower"
+license = "Apache-2.0"
+dependencies = [
+ "flwr>=1.13.0",
+ "flwr-datasets[vision]>=0.3.0",
+ "torch==2.2.1",
+ "torchvision==0.17.1",
+]
+
+[tool.hatch.build]
+exclude = [
+ "datasets/*", # Exclude datasets from FAB (if generated in this directory)
+ "_static/*", # Exclude images in README from FAB
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.flwr.app]
+publisher = "flwrlabs"
+
+[tool.flwr.app.components]
+serverapp = "embeddedexample.server_app:app"
+clientapp = "embeddedexample.client_app:app"
+
+[tool.flwr.app.config]
+num-server-rounds = 3
+fraction-evaluate = 0.5
+local-epochs = 1
+learning-rate = 0.1
+batch-size = 32
+
+[tool.flwr.federations]
+default = "embedded-federation"
+
+[tool.flwr.federations.embedded-federation]
+address = "49.12.200.204:9093"
+insecure = true
diff --git a/examples/embedded-devices/requirements_pytorch.txt b/examples/embedded-devices/requirements_pytorch.txt
deleted file mode 100644
index dbad686d914e..000000000000
--- a/examples/embedded-devices/requirements_pytorch.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-flwr>=1.0, <2.0
-flwr-datasets[vision]>=0.0.2, <1.0.0
-torch==1.13.1
-torchvision==0.14.1
-tqdm==4.66.3
diff --git a/examples/embedded-devices/requirements_tf.txt b/examples/embedded-devices/requirements_tf.txt
deleted file mode 100644
index ff65b9c31648..000000000000
--- a/examples/embedded-devices/requirements_tf.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-flwr>=1.0, <2.0
-flwr-datasets[vision]>=0.0.2, <1.0.0
-tensorflow >=2.9.1, != 2.11.1
diff --git a/examples/embedded-devices/server.py b/examples/embedded-devices/server.py
deleted file mode 100644
index 49c72720f02a..000000000000
--- a/examples/embedded-devices/server.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import argparse
-from typing import List, Tuple
-
-import flwr as fl
-from flwr.common import Metrics
-
-parser = argparse.ArgumentParser(description="Flower Embedded devices")
-parser.add_argument(
- "--server_address",
- type=str,
- default="0.0.0.0:8080",
- help=f"gRPC server address (deafault '0.0.0.0:8080')",
-)
-parser.add_argument(
- "--rounds",
- type=int,
- default=5,
- help="Number of rounds of federated learning (default: 5)",
-)
-parser.add_argument(
- "--sample_fraction",
- type=float,
- default=1.0,
- help="Fraction of available clients used for fit/evaluate (default: 1.0)",
-)
-parser.add_argument(
- "--min_num_clients",
- type=int,
- default=2,
- help="Minimum number of available clients required for sampling (default: 2)",
-)
-
-
-# Define metric aggregation function
-def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
- """This function averages teh `accuracy` metric sent by the clients in a `evaluate`
- stage (i.e. clients received the global model and evaluate it on their local
- validation sets)."""
- # Multiply accuracy of each client by number of examples used
- accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
- examples = [num_examples for num_examples, _ in metrics]
-
- # Aggregate and return custom metric (weighted average)
- return {"accuracy": sum(accuracies) / sum(examples)}
-
-
-def fit_config(server_round: int):
- """Return a configuration with static batch size and (local) epochs."""
- config = {
- "epochs": 3, # Number of local epochs done by clients
- "batch_size": 16, # Batch size to use by clients during fit()
- }
- return config
-
-
-def main():
- args = parser.parse_args()
-
- print(args)
-
- # Define strategy
- strategy = fl.server.strategy.FedAvg(
- fraction_fit=args.sample_fraction,
- fraction_evaluate=args.sample_fraction,
- min_fit_clients=args.min_num_clients,
- on_fit_config_fn=fit_config,
- evaluate_metrics_aggregation_fn=weighted_average,
- )
-
- # Start Flower server
- fl.server.start_server(
- server_address=args.server_address,
- config=fl.server.ServerConfig(num_rounds=3),
- strategy=strategy,
- )
-
-
-if __name__ == "__main__":
- main()