Commit

Add Whisper example for federated downstreaming (#2569)
jafermarq authored Nov 10, 2023
1 parent a76364d commit a3111ed
Showing 13 changed files with 1,055 additions and 0 deletions.
250 changes: 250 additions & 0 deletions examples/whisper-federated-finetuning/README.md

Large diffs are not rendered by default.

138 changes: 138 additions & 0 deletions examples/whisper-federated-finetuning/centralised.py
@@ -0,0 +1,138 @@
import argparse
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np
from datasets import concatenate_datasets
import random

from utils import (
get_model,
train_one_epoch,
eval_model,
prepare_silences_dataset,
get_encoding_fn,
remove_cols,
)

random.seed(1989)
torch.set_float32_matmul_precision(
    "high"
)  # if "high" or "medium" is set, TensorFloat32 is used for matmul operations
NUM_CLASSES = 12
parser = argparse.ArgumentParser(description="Whisper centralised")

parser.add_argument("--checkpoint", type=str, help="path to classifier`s checkpoint")
parser.add_argument(
"--epochs", type=int, default=3, help="Number of epochs of training."
)
parser.add_argument(
"--compile", action="store_true", help="compiles model (pytorch 2.0+ only)"
)


def save_classifier(classifier, acc: float):
filename = f"classifier_{acc:.4f}.pt"
torch.save(classifier.cpu().state_dict(), filename)
return filename


def main():
args = parser.parse_args()

# load train and test partitions
sc = load_dataset("speech_commands", "v0.02", split="train", token=False)
sc_val = load_dataset("speech_commands", "v0.02", split="validation", token=False)
sc_test = load_dataset("speech_commands", "v0.02", split="test", token=False)

# pre-process dataset
# ! If you know how to speedup this pre-processing stage, please do let us know!
# ! Become a contributor by proposing as a new PR !
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
prepare_dataset_fn = get_encoding_fn(processor)
og_threads = torch.get_num_threads()
print(f"{og_threads = }")
    torch.set_num_threads(
        1
    )  # not clear why this is needed in order to use `num_proc > 1` with `.map()`
train_encoded = sc.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols)
val_encoded = sc_val.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols)
test_encoded = sc_test.map(
prepare_dataset_fn, num_proc=4, remove_columns=remove_cols
)

# create and pre-process the dataset of silences
silences_dataset = prepare_silences_dataset(sc, ratio_silence=0.1)
# ! You might want to save this encoded_silences dataset to disk, so this stage is not
# ! needed each time you run the code. Alternatively, this silence generation could be
# ! implemented as part of a `collate_fn` in the standard PyTorch dataloader...
encoded_silences = silences_dataset.map(
prepare_dataset_fn, num_proc=4, remove_columns=remove_cols
)
full_train_dataset = concatenate_datasets([train_encoded, encoded_silences])

torch.set_num_threads(og_threads)

lbls = set(full_train_dataset["targets"])
print(f"{lbls = }")
hist = np.histogram(full_train_dataset["targets"], bins=12)
print(f"{[int(count) for count in hist[0]]}")

# make balanced batches with a WeightedRandomSampler
w_per_class = (
len(full_train_dataset) / hist[0]
) # doesn't have to add up to 1 (relative is what matters)
print(f"{w_per_class = }")
w_ss = [w_per_class[t] for t in full_train_dataset["targets"]]
sampler = WeightedRandomSampler(w_ss, len(w_ss))

# prepare dataloaders
train_dataset = full_train_dataset.with_format("torch", columns=["data", "targets"])
train_loader = DataLoader(
train_dataset, batch_size=64, shuffle=False, num_workers=4, sampler=sampler
)
val_encoded = val_encoded.with_format("torch", columns=["data", "targets"])
val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4)
test_dataset = test_encoded.with_format("torch", columns=["data", "targets"])
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=4)

# model to cuda, set criterion, classification layer to train and optimiser
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
encoder, classifier = get_model(device, num_classes=12)
criterion = torch.nn.CrossEntropyLoss()

if args.checkpoint:
print(f"Loading checkpoint: {args.checkpoint = }")
classifier.load_state_dict(torch.load(args.checkpoint))
classifier = classifier.to(device)
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.001)
encoder.eval()

# Let's count the size of the classification head
classifier_head_params = sum(p.numel() for p in classifier.parameters())
print(f"{classifier_head_params = }")

# eval initial model
loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device)
print(f"Initial (loss, acc): {loss = }, {accuracy = }")
best = [-float("inf"), None]
for e in range(args.epochs):
print(f"Epoch: {e}")
train_one_epoch(encoder, classifier, optimizer, criterion, train_loader, device)
loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device)
last_saved = save_classifier(classifier, accuracy)
if accuracy > best[0]:
best[0] = accuracy
best[1] = last_saved
print(f"VALIDATION ---> {loss = }, {accuracy = }")

print("Training done...")
print("Evaluating test set. Loading best model")
classifier.load_state_dict(torch.load(best[1]))
loss, accuracy = eval_model(encoder, classifier, criterion, test_loader, device)
print(f"TEST ---> {loss = }, {accuracy = }")


if __name__ == "__main__":
main()
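
The silences-related note in main() above suggests caching the encoded silences so the `.map()` pre-processing does not run on every invocation. A minimal sketch of that idea follows, assuming the prepare_silences_dataset, get_encoding_fn and remove_cols helpers from this example's utils.py; the helper name and cache path are illustrative and not part of the commit.

from pathlib import Path

import torch
from datasets import load_from_disk
from transformers import WhisperProcessor

from utils import get_encoding_fn, prepare_silences_dataset, remove_cols


def load_or_build_silences(sc_train, cache_dir: str = "./encoded_silences.hf"):
    """Return the encoded silences split, reusing a cached copy when available."""
    if Path(cache_dir).exists():
        return load_from_disk(cache_dir)
    og_threads = torch.get_num_threads()
    torch.set_num_threads(1)  # same workaround as in main() so num_proc > 1 works in .map()
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    encode_fn = get_encoding_fn(processor)
    silences = prepare_silences_dataset(sc_train, ratio_silence=0.1)
    encoded = silences.map(encode_fn, num_proc=4, remove_columns=remove_cols)
    encoded.save_to_disk(cache_dir)  # later runs load this instead of re-encoding
    torch.set_num_threads(og_threads)
    return encoded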
183 changes: 183 additions & 0 deletions examples/whisper-federated-finetuning/client.py
@@ -0,0 +1,183 @@
import argparse
import torch
import flwr as fl
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler
from datasets import load_dataset, load_from_disk, concatenate_datasets
from transformers import WhisperProcessor

from utils import (
get_model,
set_params,
train_one_epoch,
remove_cols,
prepare_silences_dataset,
construct_client_mapping,
get_encoding_fn,
)

parser = argparse.ArgumentParser(description="Flower+Whisper")
parser.add_argument("--cid", type=int, required=True, help="Client id.")
parser.add_argument(
"--server_address", type=str, required=True, help="IP of the server."
)
parser.add_argument(
"--no-compile", action="store_true", help="To not compile client models."
)

CLIENT_DATA = "client_datasets"


class WhisperFlowerClient(fl.client.NumPyClient):
"""A Flower client that does trains a classification head attached to the encoder of
a Whisper-tiny encoder for Keyword spotting."""

def __init__(self, trainset, num_classes: int, disable_tqdm: bool, compile: bool):
self.disable_tqdm = disable_tqdm
self.trainset = trainset.with_format("torch", columns=["data", "targets"])

# Determine device
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
self.encoder, self.classifier = get_model(self.device, num_classes, compile)

def get_parameters(self, config):
"""Return parameters in a format that is understood by the server."""
return [val.cpu().numpy() for _, val in self.classifier.state_dict().items()]

def fit(self, parameters, config):
"""Do on-device training.
        Here the client receives the parameters of the classification head from the
        server, then trains that classifier using the data that belongs to this client.
        Finally, the updated classifier is sent back to the server for aggregation.
"""

# Apply the classifier parameters to the model in this client
set_params(self.classifier, parameters)

# Read from config
batch, epochs = config["batch_size"], config["epochs"]

# construct sampler in order to have balanced batches
hist = np.histogram(self.trainset["targets"], bins=12)
w_per_class = (
len(self.trainset) / hist[0]
) # doesn't have to add up to 1 (relative is what matters)
# print(f"{w_per_class = }")
w_ss = [w_per_class[t] for t in self.trainset["targets"]]
ss = WeightedRandomSampler(w_ss, len(w_ss))

# Construct dataloader
train_loader = DataLoader(
self.trainset,
batch_size=batch,
shuffle=False,
num_workers=0,
sampler=ss,
drop_last=True,
)

# Define optimizer and criterion
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(self.classifier.parameters(), lr=0.001)
# Train
train_one_epoch(
self.encoder,
self.classifier,
optimizer,
criterion,
train_loader,
self.device,
disable_tqdm=self.disable_tqdm,
)

# Return local classification head and statistics
return self.get_parameters({}), len(train_loader.dataset), {}


def get_client_fn(
full_data,
encoding_fn,
client_mapping,
client_data_path: str = "./",
num_classes: int = 12,
disable_tqdm: bool = False,
compile: bool = True,
):
"""Return a function that can be used to instantiate a particular client."""

def client_fn(cid: str):
        torch.set_float32_matmul_precision(
            "high"
        )  # if "high" or "medium" is set, TensorFloat32 is used for matmul operations

# if dataset hasn't been processed for this client, do so.
# else, just load it
try:
full_train_dataset = load_from_disk(f"{client_data_path}/client{cid}.hf")
        except FileNotFoundError:
# get this client's data and preprocess it
print(f"Dataset for client {cid} not found. Pre-processing...")
og_threads = torch.get_num_threads()
torch.set_num_threads(1)
sc_client = full_data.filter(
lambda example: example["speaker_id"] in client_mapping[int(cid)]
)
client_train_data = sc_client.map(
encoding_fn, num_proc=4, remove_columns=remove_cols
)

# now let's add some _silence_ training examples (add 10% of total examples in this client's data)
ratio_silences_for_client = 0.1 * (len(client_train_data) / len(full_data))
silence_dataset = prepare_silences_dataset(
full_data, ratio_silences_for_client
)
print(
f"adding {len(silence_dataset)} to client data ({len(client_train_data)})"
)
silence_enc = silence_dataset.map(encoding_fn, remove_columns=remove_cols)

full_train_dataset = concatenate_datasets([client_train_data, silence_enc])
# save dataset. It will be loaded next time this client is spawned
full_train_dataset.save_to_disk(f"{client_data_path}/client{cid}.hf")
torch.set_num_threads(og_threads)

return WhisperFlowerClient(
full_train_dataset, num_classes, disable_tqdm, compile
)

return client_fn


def run_client():
"""Run clinet."""

# Parse input arguments
args = parser.parse_args()

sc_train = load_dataset("speech_commands", "v0.02", split="train", token=False)

# generate splits
client_mapping = construct_client_mapping(sc_train, num_clients=100)

# pre-process all partitions (+store to disk)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
prepare_dataset_fn = get_encoding_fn(processor)

client_fn = get_client_fn(
sc_train,
prepare_dataset_fn,
client_mapping,
compile=not (args.no_compile),
client_data_path=CLIENT_DATA,
)

fl.client.start_numpy_client(
server_address=f"{args.server_address}:8080", client=client_fn(args.cid)
)


if __name__ == "__main__":
run_client()
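
A rough sketch of how get_client_fn above could also drive Flower's simulation engine (the flwr[simulation] extra appears in pyproject.toml below). The FedAvg strategy, fit-config values, round count and client resources here are illustrative assumptions, not taken from this commit; the fit config must supply the batch_size and epochs keys that WhisperFlowerClient.fit reads.

import flwr as fl
from datasets import load_dataset
from transformers import WhisperProcessor

from client import CLIENT_DATA, get_client_fn
from utils import construct_client_mapping, get_encoding_fn

# Same data preparation as in run_client()
sc_train = load_dataset("speech_commands", "v0.02", split="train", token=False)
client_mapping = construct_client_mapping(sc_train, num_clients=100)
encoding_fn = get_encoding_fn(WhisperProcessor.from_pretrained("openai/whisper-tiny"))

client_fn = get_client_fn(
    sc_train, encoding_fn, client_mapping, client_data_path=CLIENT_DATA
)

strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.1,  # sample 10 of the 100 clients per round (illustrative)
    on_fit_config_fn=lambda server_round: {"batch_size": 8, "epochs": 1},
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=100,
    client_resources={"num_cpus": 4, "num_gpus": 0.0},
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)

For the non-simulated path implemented in run_client(), each client is instead launched directly, e.g. python client.py --cid=0 --server_address=<SERVER_IP>, once a Flower server is listening on port 8080 of that address (the script appends :8080 itself).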
19 changes: 19 additions & 0 deletions examples/whisper-federated-finetuning/pyproject.toml
@@ -0,0 +1,19 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "whisper-flower"
version = "0.1.0"
description = "On-device Federated Downstreaming for Speech Classification"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = { extras = ["simulation"], version = ">=1.0,<2.0" }
transformers = "4.32.1"
tokenizers = "0.13.3"
datasets = "2.14.6"
soundfile = "0.12.1"
librosa = "0.10.1"
# this example was tested with pytorch 2.1.0
7 changes: 7 additions & 0 deletions examples/whisper-federated-finetuning/requirements.txt
@@ -0,0 +1,7 @@
transformers==4.32.1
tokenizers==0.13.3
datasets==2.14.6
soundfile==0.12.1
librosa==0.10.1
flwr==1.5.0
ray==2.6.3
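
Either dependency file reproduces the same environment: pip install -r requirements.txt in a fresh virtual environment, or poetry install when working from the pyproject.toml above. Note that PyTorch itself is not pinned in either file; per the comment in pyproject.toml, the example was tested with PyTorch 2.1.0.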