diff --git a/examples/whisper-federated-finetuning/README.md b/examples/whisper-federated-finetuning/README.md new file mode 100644 index 000000000000..712cf0e88369 --- /dev/null +++ b/examples/whisper-federated-finetuning/README.md @@ -0,0 +1,250 @@ +# On-device Federated Finetuning for Speech Classification + +This example demonstrates how to, from a pre-trained [Whisper](https://openai.com/research/whisper) model, finetune it for the downstream task of keyword spotting. We'll be implementing a federated downstream finetuning pipeline using Flower involving a total of 100 clients. As for the downstream dataset, we'll be using the [Google Speech Commands](https://huggingface.co/datasets/speech_commands) dataset for keyword spotting. We'll take the encoder part of the [Whisper-tiny](https://huggingface.co/openai/whisper-tiny) model, freeze its parameters, and learn a lightweight classification (\<800K parameters !!) head to correctly classify a spoken word. + +![Keyword Spotting with Whisper overview](_static/keyword_spotting_overview.png) + +This example can be run in three modes: + +- **Centralized training**: the standard way of training ML models, where all the data is available to the node doing the finetuning. +- **Federated Learning**: the better way of doing ML, where a model is finetuned collaboratively by nodes (i.e. clients), each using their own data. These clients can run: + - in _simulation_ mode: a client is an ephemeral Python process with a portion of the system resources assigned to it. + - in _on-device_ mode: clients are detached entities and each can run on a different device. + +## Running the example + +Start by cloning the code example. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/whisper-federated-finetuning . && rm -rf flower && cd whisper-federated-finetuning +``` + +This will create a new directory called `whisper-federated-finetuning` containing the following files: + +``` +-- README.md <- Your're reading this right now +-- rpi_setup.md <- A guide that illustrates how to setup your RPi from scratch +-- sim.py <- Runs the example with Flower simulation +-- server.py <- Defines the server-side logic for the on-device setting +-- client.py <- Defines the client-side logic for the on-device setting +-- utils.py <- auxiliary functions for this example +-- centralised.py <- Runs the example in centralized mode +-- pyproject.toml <- Example dependencies (if you use Poetry) +-- requirements.txt <- Example dependencies +``` + +This example can be run in different ways, please refer to the corresponding section for further instructions. This example was tested with `PyTorch 2.1.0` for all the different ways of running this example except when running on the Raspberry Pi, which seemed to only work with `PyTorch 1.13.1`. Please note the requirement files do not specify a version of PyTorch, therefore you need to choose one that works for you and your system. + +## Centralized Training + +This section describes how to finetune `Whisper-tiny` for keyword spotting without making use of Federated Learning. This means that the whole training set is available at any point and therefore it is in its entirety to finetune the model each epoch. + +On your favorite Python environment manager, install a recent version of [PyTorch](https://pytorch.org/get-started/locally/) (PyTorch 2.0+ is recommended for faster training times). Then install the rest of the requirements. For instance: + +```bash +pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118 +pip install -r requirements.txt +``` + +Then run centralized training as follows. Please note that the first time you run the code, the `SpeechCommnads` dataset will be downloaded and pre-processed using πŸ€— API (which takes a little while -- approx 40min -- and is cached in `~/.cache/huggingface/datasets/speechcommands` wiht a footprint of ~83GB). Subsequent runs shouldn't require this preprocessing. + +```bash +python centralised.py --compile # don't use `--compile` flag if you are using pytorch < 2.0 + +# The script will save a checkpoint of the classifier head after each epoch +# These checkpoints followo the naming style: `classifier_.pt` + +# You can load a checkpoint by passing it like this: +python centralised.py --checkpoint .pt +``` + +Within 2 epochs you should see a validation accuracy of over 95%. On an RTX 3090Ti each epoch takes ~3min30sec. The final test set consistently reaches 97%+. Below is the log you should expect to see: + +```bash +... +classifier_head_params = 781964 +Initial (loss, acc): loss = 0.04124763025785586, accuracy = 0.03215788419154478 +Epoch: 0 +100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 84928/84928 [03:05<00:00, 456.93it/s, avg_loss=0.7269, avg_acc=0.8282] +VALIDATION ---> loss = 0.0051703976778501234, accuracy = 0.9319775596072931 +Epoch: 1 +100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 84928/84928 [03:07<00:00, 454.06it/s, avg_loss=0.1588, avg_acc=0.9629] +VALIDATION ---> loss = 0.003613288299632327, accuracy = 0.943097575636145 +Epoch: 2 +100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 84928/84928 [03:06<00:00, 454.16it/s, avg_loss=0.1208, avg_acc=0.9675] +VALIDATION ---> loss = 0.0022978041400064466, accuracy = 0.9610298537367261 +Training done... +Evaluating test set. Loading best model +TEST ---> loss = 0.001703281509680464, accuracy = 0.9740286298568507 +``` + +> You made it work better ? Let us know how you did it by opening an GitHub issue or a PR and we'll gladly incorporate your suggestions! + +## Federated Learning + +Centralized training is ok but in many settings it cannot be realised. Primarily because the training data must remain distributed (i.e. on the client side) and cannot be aggregated into a single node (e.g. your server). With Flower we can easily design a federated finetuning pipeline by which clients locally train the classification head on their data, before communicating it to a central server. There, the updates sent by the clients get aggregated and re-distributed among clients for another round of FL. This process is repeated until convergence. Note that, unlike the encoder part of the Whisper model, the classification head is incredibly lightweight (just 780K parameters), adding little communication costs as a result. + +In this example, we partition the training set along the `speaker_id` column into 100 buckets to simulate that many groups of people. You can think of each group as an individual FL _client_ that contains several users/speakers. One way to think about this is to view each client as an office with several people working there, each interacting with the Keyword spotting system. This example exclusively federates the training of the classification head. + +```python +from datasets import load_dataset +sc_train = load_dataset("speech_commands", "v0.02", split="train", token=False) +print(sc_train) +# Dataset({ +# features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'], +# num_rows: 84848 +# }) + +# The training set is comprised of ~85K 1-second audio clips from 2112 individual speakers +ids = set(sc_train['speaker_id']) +print(len(ids)) +# 2113 # <--- +1 since a "None" speaker is included (for clips to construct the _silence_ training examples) +``` + +![Federated Whisper Finetuning pipeline](_static/federated_finetuning_flower_pipeline.png) + +An overview of the FL pipeline built with Flower for this example is illustrated above. + +1. At the start of a round, the server communicates the classification head to a fraction of the clients. At round #0, the classification head is randomly intialised. +2. Each client, using a frozen pre-trained Whisper encoder, trains the classification head using its own data samples. +3. Once on-site training is completed, each client sends back the (now updated) classification head to the Flower server. +4. The Flower server aggregates (via FedAvg) the classification heads in order to obtain a new _global_ classification head. This head will be shared with clients in the next round. + +Flower supports two ways of doing Federated Learning: simulated and non-simulated FL. The former, managed by the [`VirtualClientEngine`](https://flower.dev/docs/framework/how-to-run-simulations.html), allows you to run large-scale workloads in a system-aware manner, that scales with the resources available on your system (whether it is a laptop, a desktop with a single GPU, or a cluster of GPU servers). The latter is better suited for settings where clients are unique devices (e.g. a server, a smart device, etc). This example shows you how to use both. + +### Preparing the dataset + +If you have run the centralized version of this example first, you probably realized that it takes some time to get a fully pre-processed SpeechCommands dataset using the πŸ€— HuggingFace API. This pre-processing is ideal so nothing slowdowns our training once we launch the experiment. For the federated part of this example, we also need to pre-process the data however in a different way since first the training set needs to be split into N different buckets, one for each FL client. + +To launch a Flower client we need a `client_fn` callable that will: (1) Load the dataset of the client; then, (2) return the Client object itself. In `client.py` we have included a few lines of code that preprocess the training partition of a given client and save it to disk (so this doesn't have to be repeated each time you run the experiment). The average pre-processed partition is ~0.5GB. You can run the experiment right away and the data will be pre-processed on-demand (i.e. when the `i`-th client is spawned for the first time), or you can pre-process all client partitions first. In order to do so, please run: + +```bash +# will write to disk all pre-processed data partitions +# by default these will go to a new directory named `client_datasets` +# Similarly to the centralised setting, this preprocessing will take a while (30mins approx) +python sim.py --preprocess +``` + +The resulting data partitions are not equal-sized (which is what you'd often find in practice in the real world) because not all `speaker_id` contributed the same amount of audio clips when the [Speech Commands Dataset](https://arxiv.org/abs/1804.03209) was created. If we make a bar plot showing the amount of data each client has this is the result. + +![Amount of data per client](_static/whisper_flower_data.png) + +### Federated Finetuning (Simulation) + +The setup instructions for simulations are the same as those described for the centralized setting above: install PyTorch and then `pip install -r requirements.txt`. Then, you can launch your simulation as shown below. Without changes to the code or input arguments, the simulation will sample `10` clients per round, these would do 1 local epoch of finetuning the classification head while the encoder remains frozen. Once this is completed, the classification head is sent to the server for aggregation via `FedAvg`. By default, this example assumes you have a GPU available. + +```bash +# By default it will run 2 clients in parallel on a single GPU (which should be fine if your GPU has at least 16GB ) +# If that's too much, consider reduing either the batch size or raise `num_gpus` passed to `start_simulation` +python sim.py # append --num_gpus=0 if you don't have GPUs on your system + +# Once finished centralised evaluation loss/acc metrics will be shown + +INFO flwr 2023-11-08 14:03:57,557 | app.py:229 | app_fit: metrics_centralized {'val_accuracy': [(0, 0.03977158885994791), + (1, 0.6940492887196954), (2, 0.5969745541975556), (3, 0.8794830695251452), (4, 0.9021238228811861), (5, 0.8943097575636145), + (6, 0.9047285113203767), (7, 0.9330795431777199), (8, 0.9446002805049089), (9, 0.9556201162091765)], + 'test_accuracy': [(10, 0.9719836400817996)]} +``` + +![Global validation accuracy FL with Whisper model](_static/whisper_flower_acc.png) + +With just 5 FL rounds, the global model should be reaching ~95% validation accuracy. A test accuracy of 97% can be reached with 10 rounds of FL training using the default hyperparameters. On an RTX 3090Ti, each round takes ~20-30s depending on the amount of data the clients selected in a round have. + +Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for more details on how you can customize your simulation. + +### Federated Finetuning (non-simulated) + +Running the exact same FL pipeline as in the simulation setting can be done without using Flower's simulation engine. To achieve this, you need to launch first a server and then two or more clients. You can do this on your development machine assuming you have set up your environment already. + +First, launch the server, which will orchestrate the FL process: + +```bash +# The server will wait until at least two clients are connected +# you can use `--server_address='localhost'` if you are running everything on the same machine. +python server.py --server_addres= +``` + +Then on different (new) terminals run: + +```bash +# use a difference `--cid` (client id) to make the client load a particular dataset partition (any integer between 0-99) +# you can use `--server_address='localhost'` if you are running everything on the same machine. +python client.py --server_address= --cid=0 + +# and on a new terminal/machine (and optionally a different `cid`) +python client.py --server_address= --cid=1 +``` + +Once the second client connects to the server, the FL process will begin. Each client will report its training progress. The server process will do the same + +```bash +# python client.py --server_address='localhost' --cid=50 +# This client runs on a NVIDIA RTX 3090Ti +INFO flwr 2023-11-08 14:12:50,135 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed) +DEBUG flwr 2023-11-08 14:12:50,136 | connection.py:42 | ChannelConnectivity.IDLE +DEBUG flwr 2023-11-08 14:12:50,136 | connection.py:42 | ChannelConnectivity.CONNECTING +DEBUG flwr 2023-11-08 14:12:50,140 | connection.py:42 | ChannelConnectivity.READY +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:09<00:00, 93.39it/s, avg_loss=2.4414, avg_acc=0.1837] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:04<00:00, 216.93it/s, avg_loss=2.0191, avg_acc=0.3315] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:04<00:00, 214.29it/s, avg_loss=1.5950, avg_acc=0.5500] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:04<00:00, 212.70it/s, avg_loss=1.1883, avg_acc=0.7348] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:04<00:00, 208.69it/s, avg_loss=0.8466, avg_acc=0.8228] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:04<00:00, 206.31it/s, avg_loss=0.6353, avg_acc=0.8837] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:03<00:00, 266.73it/s, avg_loss=0.4842, avg_acc=0.9207] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:04<00:00, 212.13it/s, avg_loss=0.3519, avg_acc=0.9391] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:04<00:00, 213.17it/s, avg_loss=0.3233, avg_acc=0.9359] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [00:04<00:00, 205.12it/s, avg_loss=0.2646, avg_acc=0.9543] +DEBUG flwr 2023-11-08 14:20:01,065 | connection.py:139 | gRPC channel closed +INFO flwr 2023-11-08 14:20:01,065 | app.py:215 | Disconnect and shut down +``` + +### Federated Finetuning on Raspberry Pi + +Setting up the environment for the Raspberry Pi is not that different from the steps you'd follow on any other Ubuntu machine (this example assumes your Raspberry Pi -- either 5 or 4 -- runs Ubuntu server 22.04/23.10 64bits). Using the code as-is, RAM usage on the Raspberry Pi does not exceed 1.5GB. Note that unlike in the previous sections of this example, clients for Raspberry Pi work better when using PyTorch 1.13.1 (or earlier versions to PyTorch 2.0 in general). + +> Please follow the steps [here](rpi_setup.md) if you are looking for a step-by-step guide on how to setup your Raspberry Pi to run this example. + +In order to run this example on a Raspberry Pi, you'll need to follow the same steps as outlined above in the `non-simulated` section. First, launch the server on your development machine. + +```bash +# The server will wait until at least two clients are connected +python server.py --server_addres= +``` + +Then, on each of your Raspberry Pi do the following. If you only have one RPi, you can still run the example! But you will need two clients. In addition to the one on the Raspberry Pi, you could launch a client in a separate terminal on your development machine (as shown above in the `non-simulated` section). + +```bash +# use a difference `--cid` (client id) to make this device load a particular dataset partition +# we pass the `--no-compile` option since for RPi we are not using PyTorch 2.0+ +python client.py --server_address= --cid=0 --no-compile +``` + +The first time you run a client on the RPi, the dataset of a client needs to be extracted from the full train set and then pre-processed. The Raspberry Pi 5 is also faster in this pre-processing stage using `.filter()` and `.map()` of πŸ€— HuggingFace Dataset. `map()` used `num_proc=4`: + +| **Stage** | Notes | **RPi 4** | **RPi 5** | +| :-------------------------------------: | :----------------------------------------------: | --------- | --------- | +| Filter through training set (~85k rows) | doing `.filter()` in `client.client_fn` | 1:58 | 0.37 | +| Encode 845 rows with `WhisperProcessor` | doing `.map()` passing `utils.prepare_dataset()` | 1:55 | 1:06 | + +Some clients have more data than others, but on average, the RPi5 is 1.9x faster than an RPi4 when training the classification head given a frozen encoder. A client with 925 training examples needs ~20min on an RPi to complete an epoch of on-device finetuning. + +```bash +# Running the 50-th client on a RPi 5 showed the following log (a RPi4 ran client 83) +python client.py --cid=50 --server_address= --no-compile +INFO flwr 2023-11-08 16:20:33,331 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed) +DEBUG flwr 2023-11-08 16:20:33,333 | connection.py:42 | ChannelConnectivity.IDLE +DEBUG flwr 2023-11-08 16:20:33,334 | connection.py:42 | ChannelConnectivity.CONNECTING +DEBUG flwr 2023-11-08 16:20:33,349 | connection.py:42 | ChannelConnectivity.READY +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:09<00:06, 1.31s/it, avg_loss=2.4392, avg_acc=0.1902] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:06<00:06, 1.31s/it, avg_loss=1.9830, avg_acc=0.3533] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:06<00:06, 1.31s/it, avg_loss=1.6069, avg_acc=0.5641] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:07<00:06, 1.31s/it, avg_loss=1.1933, avg_acc=0.7402] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:07<00:06, 1.31s/it, avg_loss=0.8749, avg_acc=0.8478] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:06<00:06, 1.31s/it, avg_loss=0.5933, avg_acc=0.9109] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:08<00:06, 1.31s/it, avg_loss=0.4882, avg_acc=0.9359] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:01<00:06, 1.31s/it, avg_loss=0.4022, avg_acc=0.9304] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:10<00:06, 1.32s/it, avg_loss=0.3219, avg_acc=0.9533] +99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 920/925 [20:13<00:06, 1.32s/it, avg_loss=0.2729, avg_acc=0.9641] +DEBUG flwr 2023-11-08 19:47:56,544 | connection.py:139 | gRPC channel closed +INFO flwr 2023-11-08 19:47:56,544 | app.py:215 | Disconnect and shut down +``` diff --git a/examples/whisper-federated-finetuning/_static/federated_finetuning_flower_pipeline.png b/examples/whisper-federated-finetuning/_static/federated_finetuning_flower_pipeline.png new file mode 100644 index 000000000000..8b931e43c80e Binary files /dev/null and b/examples/whisper-federated-finetuning/_static/federated_finetuning_flower_pipeline.png differ diff --git a/examples/whisper-federated-finetuning/_static/keyword_spotting_overview.png b/examples/whisper-federated-finetuning/_static/keyword_spotting_overview.png new file mode 100644 index 000000000000..141f11edfe04 Binary files /dev/null and b/examples/whisper-federated-finetuning/_static/keyword_spotting_overview.png differ diff --git a/examples/whisper-federated-finetuning/_static/whisper_flower_acc.png b/examples/whisper-federated-finetuning/_static/whisper_flower_acc.png new file mode 100644 index 000000000000..9988cdaefa35 Binary files /dev/null and b/examples/whisper-federated-finetuning/_static/whisper_flower_acc.png differ diff --git a/examples/whisper-federated-finetuning/_static/whisper_flower_data.png b/examples/whisper-federated-finetuning/_static/whisper_flower_data.png new file mode 100644 index 000000000000..92a29ceff979 Binary files /dev/null and b/examples/whisper-federated-finetuning/_static/whisper_flower_data.png differ diff --git a/examples/whisper-federated-finetuning/centralised.py b/examples/whisper-federated-finetuning/centralised.py new file mode 100644 index 000000000000..6af591a7502b --- /dev/null +++ b/examples/whisper-federated-finetuning/centralised.py @@ -0,0 +1,138 @@ +import argparse +from datasets import load_dataset +from transformers import WhisperForConditionalGeneration, WhisperProcessor +import torch +from torch.utils.data import DataLoader, WeightedRandomSampler +import numpy as np +from datasets import concatenate_datasets +import random + +from utils import ( + get_model, + train_one_epoch, + eval_model, + prepare_silences_dataset, + get_encoding_fn, + remove_cols, +) + +random.seed(1989) +torch.set_float32_matmul_precision( + "high" +) # If β€œhigh” or β€œmedium” are set then the TensorFloat32 is used +NUM_CLASSES = 12 +parser = argparse.ArgumentParser(description="Whisper centralised") + +parser.add_argument("--checkpoint", type=str, help="path to classifier`s checkpoint") +parser.add_argument( + "--epochs", type=int, default=3, help="Number of epochs of training." +) +parser.add_argument( + "--compile", action="store_true", help="compiles model (pytorch 2.0+ only)" +) + + +def save_classifier(classifier, acc: float): + filename = f"classifier_{acc:.4f}.pt" + torch.save(classifier.cpu().state_dict(), filename) + return filename + + +def main(): + args = parser.parse_args() + + # load train and test partitions + sc = load_dataset("speech_commands", "v0.02", split="train", token=False) + sc_val = load_dataset("speech_commands", "v0.02", split="validation", token=False) + sc_test = load_dataset("speech_commands", "v0.02", split="test", token=False) + + # pre-process dataset + # ! If you know how to speedup this pre-processing stage, please do let us know! + # ! Become a contributor by proposing as a new PR ! + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + prepare_dataset_fn = get_encoding_fn(processor) + og_threads = torch.get_num_threads() + print(f"{og_threads = }") + torch.set_num_threads( + 1 + ) # not clear to me why we need this in order to be able to use `num_proc > 1 for .map` + train_encoded = sc.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols) + val_encoded = sc_val.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols) + test_encoded = sc_test.map( + prepare_dataset_fn, num_proc=4, remove_columns=remove_cols + ) + + # create and pre-process the dataset of silences + silences_dataset = prepare_silences_dataset(sc, ratio_silence=0.1) + # ! You might want to save this encoded_silences dataset to disk, so this stage is not + # ! needed each time you run the code. Alternatively, this silence generation could be + # ! implemented as part of a `collate_fn` in the standard PyTorch dataloader... + encoded_silences = silences_dataset.map( + prepare_dataset_fn, num_proc=4, remove_columns=remove_cols + ) + full_train_dataset = concatenate_datasets([train_encoded, encoded_silences]) + + torch.set_num_threads(og_threads) + + lbls = set(full_train_dataset["targets"]) + print(f"{lbls = }") + hist = np.histogram(full_train_dataset["targets"], bins=12) + print(f"{[int(count) for count in hist[0]]}") + + # make balanced batches with a WeightedRandomSampler + w_per_class = ( + len(full_train_dataset) / hist[0] + ) # doesn't have to add up to 1 (relative is what matters) + print(f"{w_per_class = }") + w_ss = [w_per_class[t] for t in full_train_dataset["targets"]] + sampler = WeightedRandomSampler(w_ss, len(w_ss)) + + # prepare dataloaders + train_dataset = full_train_dataset.with_format("torch", columns=["data", "targets"]) + train_loader = DataLoader( + train_dataset, batch_size=64, shuffle=False, num_workers=4, sampler=sampler + ) + val_encoded = val_encoded.with_format("torch", columns=["data", "targets"]) + val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4) + test_dataset = test_encoded.with_format("torch", columns=["data", "targets"]) + test_loader = DataLoader(test_dataset, batch_size=64, num_workers=4) + + # model to cuda, set criterion, classification layer to train and optimiser + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + encoder, classifier = get_model(device, num_classes=12) + criterion = torch.nn.CrossEntropyLoss() + + if args.checkpoint: + print(f"Loading checkpoint: {args.checkpoint = }") + classifier.load_state_dict(torch.load(args.checkpoint)) + classifier = classifier.to(device) + optimizer = torch.optim.SGD(classifier.parameters(), lr=0.001) + encoder.eval() + + # Let's count the size of the classification head + classifier_head_params = sum(p.numel() for p in classifier.parameters()) + print(f"{classifier_head_params = }") + + # eval initial model + loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device) + print(f"Initial (loss, acc): {loss = }, {accuracy = }") + best = [-float("inf"), None] + for e in range(args.epochs): + print(f"Epoch: {e}") + train_one_epoch(encoder, classifier, optimizer, criterion, train_loader, device) + loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device) + last_saved = save_classifier(classifier, accuracy) + if accuracy > best[0]: + best[0] = accuracy + best[1] = last_saved + print(f"VALIDATION ---> {loss = }, {accuracy = }") + + print("Training done...") + print("Evaluating test set. Loading best model") + classifier.load_state_dict(torch.load(best[1])) + loss, accuracy = eval_model(encoder, classifier, criterion, test_loader, device) + print(f"TEST ---> {loss = }, {accuracy = }") + + +if __name__ == "__main__": + main() diff --git a/examples/whisper-federated-finetuning/client.py b/examples/whisper-federated-finetuning/client.py new file mode 100644 index 000000000000..2bfeadfbdae6 --- /dev/null +++ b/examples/whisper-federated-finetuning/client.py @@ -0,0 +1,183 @@ +import argparse +import torch +import flwr as fl +import numpy as np +from torch.utils.data import DataLoader, WeightedRandomSampler +from datasets import load_dataset, load_from_disk, concatenate_datasets +from transformers import WhisperProcessor + +from utils import ( + get_model, + set_params, + train_one_epoch, + remove_cols, + prepare_silences_dataset, + construct_client_mapping, + get_encoding_fn, +) + +parser = argparse.ArgumentParser(description="Flower+Whisper") +parser.add_argument("--cid", type=int, required=True, help="Client id.") +parser.add_argument( + "--server_address", type=str, required=True, help="IP of the server." +) +parser.add_argument( + "--no-compile", action="store_true", help="To not compile client models." +) + +CLIENT_DATA = "client_datasets" + + +class WhisperFlowerClient(fl.client.NumPyClient): + """A Flower client that does trains a classification head attached to the encoder of + a Whisper-tiny encoder for Keyword spotting.""" + + def __init__(self, trainset, num_classes: int, disable_tqdm: bool, compile: bool): + self.disable_tqdm = disable_tqdm + self.trainset = trainset.with_format("torch", columns=["data", "targets"]) + + # Determine device + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + self.encoder, self.classifier = get_model(self.device, num_classes, compile) + + def get_parameters(self, config): + """Return parameters in a format that is understood by the server.""" + return [val.cpu().numpy() for _, val in self.classifier.state_dict().items()] + + def fit(self, parameters, config): + """Do on-device training. + + Here the client receives the parameters of the classification head from the + server. Then trains that classifier using the data that belongs to this client. + Finally, The updated classifier is sent back to the server for aggregation. + """ + + # Apply the classifier parameters to the model in this client + set_params(self.classifier, parameters) + + # Read from config + batch, epochs = config["batch_size"], config["epochs"] + + # construct sampler in order to have balanced batches + hist = np.histogram(self.trainset["targets"], bins=12) + w_per_class = ( + len(self.trainset) / hist[0] + ) # doesn't have to add up to 1 (relative is what matters) + # print(f"{w_per_class = }") + w_ss = [w_per_class[t] for t in self.trainset["targets"]] + ss = WeightedRandomSampler(w_ss, len(w_ss)) + + # Construct dataloader + train_loader = DataLoader( + self.trainset, + batch_size=batch, + shuffle=False, + num_workers=0, + sampler=ss, + drop_last=True, + ) + + # Define optimizer and criterion + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(self.classifier.parameters(), lr=0.001) + # Train + train_one_epoch( + self.encoder, + self.classifier, + optimizer, + criterion, + train_loader, + self.device, + disable_tqdm=self.disable_tqdm, + ) + + # Return local classification head and statistics + return self.get_parameters({}), len(train_loader.dataset), {} + + +def get_client_fn( + full_data, + encoding_fn, + client_mapping, + client_data_path: str = "./", + num_classes: int = 12, + disable_tqdm: bool = False, + compile: bool = True, +): + """Return a function that can be used to instantiate a particular client.""" + + def client_fn(cid: str): + torch.set_float32_matmul_precision( + "high" + ) # If β€œhigh” or β€œmedium” are set then the TensorFloat32 is used + + # if dataset hasn't been processed for this client, do so. + # else, just load it + try: + full_train_dataset = load_from_disk(f"{client_data_path}/client{cid}.hf") + except: + # get this client's data and preprocess it + print(f"Dataset for client {cid} not found. Pre-processing...") + og_threads = torch.get_num_threads() + torch.set_num_threads(1) + sc_client = full_data.filter( + lambda example: example["speaker_id"] in client_mapping[int(cid)] + ) + client_train_data = sc_client.map( + encoding_fn, num_proc=4, remove_columns=remove_cols + ) + + # now let's add some _silence_ training examples (add 10% of total examples in this client's data) + ratio_silences_for_client = 0.1 * (len(client_train_data) / len(full_data)) + silence_dataset = prepare_silences_dataset( + full_data, ratio_silences_for_client + ) + print( + f"adding {len(silence_dataset)} to client data ({len(client_train_data)})" + ) + silence_enc = silence_dataset.map(encoding_fn, remove_columns=remove_cols) + + full_train_dataset = concatenate_datasets([client_train_data, silence_enc]) + # save dataset. It will be loaded next time this client is spawned + full_train_dataset.save_to_disk(f"{client_data_path}/client{cid}.hf") + torch.set_num_threads(og_threads) + + return WhisperFlowerClient( + full_train_dataset, num_classes, disable_tqdm, compile + ) + + return client_fn + + +def run_client(): + """Run clinet.""" + + # Parse input arguments + args = parser.parse_args() + + sc_train = load_dataset("speech_commands", "v0.02", split="train", token=False) + + # generate splits + client_mapping = construct_client_mapping(sc_train, num_clients=100) + + # pre-process all partitions (+store to disk) + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + prepare_dataset_fn = get_encoding_fn(processor) + + client_fn = get_client_fn( + sc_train, + prepare_dataset_fn, + client_mapping, + compile=not (args.no_compile), + client_data_path=CLIENT_DATA, + ) + + fl.client.start_numpy_client( + server_address=f"{args.server_address}:8080", client=client_fn(args.cid) + ) + + +if __name__ == "__main__": + run_client() diff --git a/examples/whisper-federated-finetuning/pyproject.toml b/examples/whisper-federated-finetuning/pyproject.toml new file mode 100644 index 000000000000..dd5578b8b3d0 --- /dev/null +++ b/examples/whisper-federated-finetuning/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "whisper-flower" +version = "0.1.0" +description = "On-device Federated Downstreaming for Speech Classification" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } +transformers = "4.32.1" +tokenizers = "0.13.3" +datasets = "2.14.6" +soundfile = "0.12.1" +librosa = "0.10.1" +# this example was tested with pytorch 2.1.0 \ No newline at end of file diff --git a/examples/whisper-federated-finetuning/requirements.txt b/examples/whisper-federated-finetuning/requirements.txt new file mode 100644 index 000000000000..eb4a5d7eb47b --- /dev/null +++ b/examples/whisper-federated-finetuning/requirements.txt @@ -0,0 +1,7 @@ +transformers==4.32.1 +tokenizers==0.13.3 +datasets==2.14.6 +soundfile==0.12.1 +librosa==0.10.1 +flwr==1.5.0 +ray==2.6.3 \ No newline at end of file diff --git a/examples/whisper-federated-finetuning/rpi_setup.md b/examples/whisper-federated-finetuning/rpi_setup.md new file mode 100644 index 000000000000..d49bbd6a472b --- /dev/null +++ b/examples/whisper-federated-finetuning/rpi_setup.md @@ -0,0 +1,49 @@ +# Setting up your RaspberryPi + +> This guide assumes you have a fresh install of Ubuntu Server (either 22.04 or 23.10) and that you have successfully `ssh`-ed into your device. + +## Setting up your device for Python developemnet + +We are going to use [`pyenv`](https://github.com/pyenv/pyenv) to manage different Python versions and to create an environment. First, we need to install some system dependencies + +```bash +sudo apt update +# the last package is needed for whisper +sudo apt install build-essential zlib1g-dev libssl-dev libsqlite3-dev libreadline-dev libbz2-dev libffi-dev liblzma-dev libsndfile1 +``` + +Create Python environment with `pyenv`: + +```bash + +# Ensure you have installed pyenv, else do the below: +git clone https://github.com/pyenv/pyenv.git ~/.pyenv +echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc +echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc +echo 'eval "$(pyenv init -)"' >> ~/.bashrc + +# Install python 3.9+ +pyenv install 3.9.17 + +# Install pyenv virtual env plugin +git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv +# Restart your shell +exec "$SHELL" + +# Create the environment +pyenv virtualenv 3.9.17 flower-whisperer +``` + +## Installing the dependencies for Whisper+Flower + +With our environmnet ready, let's install the dependencies. Please note that at the time of writing, PyTorch 2.0+ won't work properly on `aarm64`. Because of this, we'll be using an earlier version of this package. + +```bash +# activate your environment +pyenv activate flower-whisperer + +# install pytorch (RPi aren't ready for PyTorch 2.0+ apparently...) +pip install torch==1.13.1 +# install rest of requirerments +pip install -r requirements.txt +``` diff --git a/examples/whisper-federated-finetuning/server.py b/examples/whisper-federated-finetuning/server.py new file mode 100644 index 000000000000..101d43f04ec2 --- /dev/null +++ b/examples/whisper-federated-finetuning/server.py @@ -0,0 +1,104 @@ +import argparse + +import torch +from datasets import load_dataset +from transformers import WhisperProcessor +from torch.utils.data import DataLoader +import flwr as fl + +from utils import eval_model, get_model, set_params, remove_cols, get_encoding_fn + + +parser = argparse.ArgumentParser(description="Flower+Whisper") +parser.add_argument("--num_rounds", type=int, default=5, help="Number of FL rounds.") +parser.add_argument( + "--server_address", type=str, required=True, help="IP of the server." +) + + +NUM_CLASSES = 12 +NUM_CLIENTS = 100 + + +def fit_config(server_round: int): + """Return a configuration with static batch size and (local) epochs.""" + config = { + "epochs": 1, # Number of local epochs done by clients + "batch_size": 8, # Batch size to use by clients during fit() + } + return config + + +def get_evaluate_fn(val_set, test_set, encoding_fn, num_rounds): + def evaluate(server_round: int, parameters: fl.common.NDArrays, config): + """Use the entire CIFAR-10 test set for evaluation.""" + + # Determine device + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # prepare model + encoder, classifier = get_model(device, NUM_CLASSES) + set_params(classifier, parameters) + classifier.to(device) + + # prepare dataset + og_threads = torch.get_num_threads() + torch.set_num_threads( + 1 + ) # ! still, not clear to me why this is needed if we want `num_proc>1` + if server_round == num_rounds: + prefix = "test" + encoded = test_set.map(encoding_fn, num_proc=4, remove_columns=remove_cols) + else: + prefix = "val" + encoded = val_set.map(encoding_fn, num_proc=4, remove_columns=remove_cols) + torch.set_num_threads(og_threads) + + val_encoded = encoded.with_format("torch", columns=["data", "targets"]) + val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4) + + # Run global evaluation + criterion = torch.nn.CrossEntropyLoss() + loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device) + + print(f"{prefix}: --> {loss = }, {accuracy = }") + + return loss, {f"{prefix}_accuracy": accuracy} + + return evaluate + + +def main(): + # Parse input arguments + args = parser.parse_args() + + # The sever will use the validation set to assess the performance of the global + # model after each round. Then, the test set will be used for evaluating the global + # model after the last round + sc_val = load_dataset("speech_commands", "v0.02", split="validation", token=False) + sc_test = load_dataset("speech_commands", "v0.02", split="test", token=False) + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + prepare_dataset_fn = get_encoding_fn(processor) + + # We use a standard FedAvg strategy + strategy = fl.server.strategy.FedAvg( + fraction_fit=0.00001, + min_fit_clients=2, # the strategy will wait until at least 2 clients are sampled for fit + fraction_evaluate=0.0, # we don't do federated evaluation in this example + min_available_clients=2, # the strategy will do nothing until 2 clients are connected to the server + on_fit_config_fn=fit_config, + evaluate_fn=get_evaluate_fn( + sc_val, sc_test, prepare_dataset_fn, args.num_rounds + ), + ) + + fl.server.start_server( + server_address=f"{args.server_address}:8080", + config=fl.server.ServerConfig(num_rounds=args.num_rounds), + strategy=strategy, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/whisper-federated-finetuning/sim.py b/examples/whisper-federated-finetuning/sim.py new file mode 100644 index 000000000000..c04f768bb24a --- /dev/null +++ b/examples/whisper-federated-finetuning/sim.py @@ -0,0 +1,95 @@ +import argparse + +import torch +from datasets import load_dataset +from transformers import WhisperProcessor + +import flwr as fl + +from client import get_client_fn +from server import fit_config, get_evaluate_fn +from utils import construct_client_mapping, get_encoding_fn + +parser = argparse.ArgumentParser(description="Flower+Whisper") + +parser.add_argument("--num_rounds", type=int, default=10, help="Number of FL rounds.") +parser.add_argument( + "--num_cpus", type=int, default=4, help="Number of CPUs reserved for each client." +) +parser.add_argument( + "--num_gpus", + type=float, + default=0.5, + help="GPU ratio reserved for each client (`num_gpus`=1.0 means one client gets the whole GPU)", +) +parser.add_argument( + "--preprocess", + action="store_true", + help="Preprocesses all client's datasets and exits (creates ~83GB data)", +) + +NUM_CLASSES = 12 +NUM_CLIENTS = 100 +CLIENT_DATA = "client_datasets" +torch.set_float32_matmul_precision( + "high" +) # If β€œhigh” or β€œmedium” are set then the TensorFloat32 is used + + +def main(): + # Parse input arguments + args = parser.parse_args() + + # dataset download and preparation + sc_train = load_dataset("speech_commands", "v0.02", split="train", token=False) + sc_val = load_dataset("speech_commands", "v0.02", split="validation", token=False) + sc_test = load_dataset("speech_commands", "v0.02", split="test", token=False) + + # generate splits + client_mapping = construct_client_mapping(sc_train, num_clients=NUM_CLIENTS) + + # pre-process all partitions (+store to disk) + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + prepare_dataset_fn = get_encoding_fn(processor) + if args.preprocess: + import sys + + client_fn = get_client_fn( + sc_train, prepare_dataset_fn, client_mapping, CLIENT_DATA, NUM_CLASSES + ) + + for i in range(NUM_CLIENTS): + _ = client_fn(str(i)) + print("Preprocessing completed. Run the code again without `--preprocess`") + sys.exit(0) + + strategy = fl.server.strategy.FedAvg( + fraction_fit=0.00001, + min_fit_clients=10, + fraction_evaluate=0.0, + min_available_clients=NUM_CLIENTS, + on_fit_config_fn=fit_config, + evaluate_fn=get_evaluate_fn( + sc_val, sc_test, prepare_dataset_fn, args.num_rounds + ), + ) + + # Start simulation + fl.simulation.start_simulation( + client_fn=get_client_fn( + sc_train, + prepare_dataset_fn, + client_mapping, + CLIENT_DATA, + NUM_CLASSES, + disable_tqdm=True, + ), + num_clients=NUM_CLIENTS, + client_resources={"num_cpus": args.num_cpus, "num_gpus": args.num_gpus}, + config=fl.server.ServerConfig(num_rounds=args.num_rounds), + strategy=strategy, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/whisper-federated-finetuning/utils.py b/examples/whisper-federated-finetuning/utils.py new file mode 100644 index 000000000000..21fe0309151c --- /dev/null +++ b/examples/whisper-federated-finetuning/utils.py @@ -0,0 +1,210 @@ +from tqdm import tqdm +import torch +import random +from datasets import Dataset +import numpy as np +from collections import OrderedDict +from transformers import WhisperForConditionalGeneration + +from typing import List + +import flwr as fl + + +remove_cols = ["file", "audio", "label", "is_unknown", "speaker_id", "utterance_id"] + + +class RunningAvg: + def __init__(self): + self.n = 0 + self.total = 0 + + def update(self, val): + self.total += val + self.n += 1 + + def __call__(self): + return self.total / self.n + + +def train_one_epoch( + model, + classifier, + optimizer, + criterion, + dataloader, + device, + disable_tqdm: bool = False, +): + """Train the classification head. + + This is a very standard looking way of training PyTorch models. + """ + model.eval() + classifier.train() + classifier.to(device) + loss_avg, acc_avg = RunningAvg(), RunningAvg() + with tqdm(total=len(dataloader.dataset), disable=disable_tqdm) as t: + for b in dataloader: + optimizer.zero_grad() + data = b["data"].squeeze().to(device) + # print(data.shape) + labels = b["targets"].to(device) + with torch.no_grad(): + res = model(data)[0] + + resres = classifier(res) + + loss = criterion(resres.float(), labels) + loss.backward() + optimizer.step() + _, predicted = torch.max(resres.data, 1) + correct = (predicted == labels).sum().item() + acc = correct / data.shape[0] + loss_ = loss.cpu().item() + + loss_avg.update(loss_) + acc_avg.update(acc) + + t.update(data.shape[0]) + t.set_postfix( + {"avg_loss": f"{loss_avg():.4f}", "avg_acc": f"{acc_avg():.4f}"} + ) + + +def eval_model(model, classifier, criterion, dataloader, device): + """Evaluate a model on a validation/test set. + + This is a very normal looking way of doing this with PyTorch. + """ + model.eval() + classifier.eval() + classifier.to(device) + correct = 0 + loss_ = 0 + total = 0 + with torch.no_grad(): + for b in dataloader: + data = b["data"].squeeze().to(device) + # print(data.shape) + labels = b["targets"].to(device) + res = model(data)[0] + resres = classifier(res) + + loss = criterion(resres.float(), labels) + _, predicted = torch.max(resres.data, 1) + correct += (predicted == labels).sum().item() + total += data.shape[0] + loss_ += loss.cpu().item() + + accuracy = correct / total + loss = loss_ / total + + return loss, accuracy + + +def prepare_silences_dataset(train_dataset, ratio_silence: float = 0.1) -> Dataset: + """Generate silences for the train set. + + One of the classes in the SpeechCommands datatset is `silence`. However, the dataset + does not include clips of silence. It does however include 5 long files with different + background sounds. The taks of this function is to extract several (defined by `ratio_silence`) + one-second long clips from those background audio files. Later, those audio clips will be + included into the training set. + """ + # retrieve original silence audio clips + silences = [d for d in train_dataset if d["label"] == 35] + # figure out how many to add + num_silence_total = int(len(train_dataset) * ratio_silence) + # num new entries per background noise clip + num_silence_per_bkg = num_silence_total // len(silences) + + silence_to_add = [] + for sil in silences: + sil_array = sil["audio"]["array"] + sr = sil["audio"]["sampling_rate"] + print(f"Extracting audio from: {sil['file']} ...") + for _ in range(num_silence_per_bkg): + random_offset = random.randint(0, len(sil_array) - sr - 1) + sil_array_crop = sil_array[random_offset : random_offset + sr] + + entry = sil + silence_to_add.append(entry) + silence_to_add[-1]["audio"]["array"] = sil_array_crop + + return Dataset.from_list(silence_to_add) + + +def construct_client_mapping(full_trainset, num_clients: int = 100): + """Create a mapping to partition the dataset into `num_client` buckets. + + These buckets contain the same number of `spekaer_id` but likely different + number of training exampes since each `speaker_id` in SpeechCommands does + provide different amounts of data to the dataset. + """ + client_ids = list(set(full_trainset["speaker_id"])) + client_ids.remove( + None + ) # remove "none" which corresponds to the _silence_ audio clips + client_ids.sort() # we sort this as a quick way of ensuring our client mapping is consistent between runs + len( + client_ids + ) # should be 2112 (i.e. the number of participats in SpeechCommands dataset v0.02) + + # split into groups (each group represents a client) + client_mapping = np.array_split(client_ids, num_clients) + + return client_mapping + + +def get_encoding_fn(processor): + """Return a function to use to pre-process/encode the SpeechCommands dataset. + + We are working with the 12classes version of this dataset, therefore we need to do + some reassignment of labels. + """ + + def prepare_dataset(batch): + audio = batch["audio"] + data = {} + data["data"] = processor( + audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt" + ).input_features + + # All unknown keywords are assigned label 11. The silence clips get assigned label 10 + # In this way we have 12 classes with labels 0-11 + data["targets"] = ( + 11 + if batch["is_unknown"] + else (10 if batch["label"] == 35 else batch["label"]) + ) + return data + + return prepare_dataset + + +def set_params(model: torch.nn.ModuleList, params: List[fl.common.NDArrays]): + """Set model weights from a list of NumPy ndarrays.""" + params_dict = zip(model.state_dict().keys(), params) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + model.load_state_dict(state_dict, strict=True) + + +def get_model(device, num_classes, compile: bool = True): + """Create model: Whisper-tiny Encoder + classification head""" + encoder = WhisperForConditionalGeneration.from_pretrained( + "openai/whisper-tiny" + ).get_encoder() + encoder = encoder.to(device) + if compile: + encoder = torch.compile(encoder) + + # This classification head is 782K parameters + # This is the only part of the model that is trained in federation + classifier = torch.nn.Sequential( + torch.nn.Conv1d(1500, 128, kernel_size=1), + torch.nn.ReLU(), + torch.nn.Flatten(1), + torch.nn.Linear(128 * 384, num_classes), + ).to(device) + return encoder, classifier