Commit
Add Whisper example for federated downstreaming (#2569)
Showing 13 changed files with 1,055 additions and 0 deletions.
Large diffs are not rendered by default.
Binary file added: BIN +51.6 KB ...s/whisper-federated-finetuning/_static/federated_finetuning_flower_pipeline.png
Binary file added: BIN +32 KB examples/whisper-federated-finetuning/_static/keyword_spotting_overview.png
Binary file added: BIN +32.4 KB examples/whisper-federated-finetuning/_static/whisper_flower_data.png
@@ -0,0 +1,138 @@
import argparse
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np
from datasets import concatenate_datasets
import random

from utils import (
    get_model,
    train_one_epoch,
    eval_model,
    prepare_silences_dataset,
    get_encoding_fn,
    remove_cols,
)

random.seed(1989)
torch.set_float32_matmul_precision(
    "high"
)  # If "high" or "medium" is set, then TensorFloat32 is used
NUM_CLASSES = 12
parser = argparse.ArgumentParser(description="Whisper centralised")

parser.add_argument("--checkpoint", type=str, help="path to classifier's checkpoint")
parser.add_argument(
    "--epochs", type=int, default=3, help="Number of epochs of training."
)
parser.add_argument(
    "--compile", action="store_true", help="compiles model (pytorch 2.0+ only)"
)


def save_classifier(classifier, acc: float):
    filename = f"classifier_{acc:.4f}.pt"
    torch.save(classifier.cpu().state_dict(), filename)
    return filename


def main():
    args = parser.parse_args()

    # load train and test partitions
    sc = load_dataset("speech_commands", "v0.02", split="train", token=False)
    sc_val = load_dataset("speech_commands", "v0.02", split="validation", token=False)
    sc_test = load_dataset("speech_commands", "v0.02", split="test", token=False)

    # pre-process dataset
    # ! If you know how to speed up this pre-processing stage, please do let us know!
    # ! Become a contributor by proposing it as a new PR!
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    prepare_dataset_fn = get_encoding_fn(processor)
    og_threads = torch.get_num_threads()
    print(f"{og_threads = }")
    torch.set_num_threads(
        1
    )  # not clear why this is needed in order to use `num_proc > 1` for `.map`
    train_encoded = sc.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols)
    val_encoded = sc_val.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols)
    test_encoded = sc_test.map(
        prepare_dataset_fn, num_proc=4, remove_columns=remove_cols
    )

    # create and pre-process the dataset of silences
    silences_dataset = prepare_silences_dataset(sc, ratio_silence=0.1)
    # ! You might want to save this encoded_silences dataset to disk, so this stage is not
    # ! needed each time you run the code. Alternatively, this silence generation could be
    # ! implemented as part of a `collate_fn` in the standard PyTorch dataloader...
    encoded_silences = silences_dataset.map(
        prepare_dataset_fn, num_proc=4, remove_columns=remove_cols
    )
    full_train_dataset = concatenate_datasets([train_encoded, encoded_silences])

    torch.set_num_threads(og_threads)

    lbls = set(full_train_dataset["targets"])
    print(f"{lbls = }")
    hist = np.histogram(full_train_dataset["targets"], bins=12)
    print(f"{[int(count) for count in hist[0]]}")

    # make balanced batches with a WeightedRandomSampler
    w_per_class = (
        len(full_train_dataset) / hist[0]
    )  # doesn't have to add up to 1 (relative is what matters)
    print(f"{w_per_class = }")
    w_ss = [w_per_class[t] for t in full_train_dataset["targets"]]
    sampler = WeightedRandomSampler(w_ss, len(w_ss))

    # prepare dataloaders
    train_dataset = full_train_dataset.with_format("torch", columns=["data", "targets"])
    train_loader = DataLoader(
        train_dataset, batch_size=64, shuffle=False, num_workers=4, sampler=sampler
    )
    val_encoded = val_encoded.with_format("torch", columns=["data", "targets"])
    val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4)
    test_dataset = test_encoded.with_format("torch", columns=["data", "targets"])
    test_loader = DataLoader(test_dataset, batch_size=64, num_workers=4)

    # model to cuda, set criterion, classification layer to train and optimiser
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    encoder, classifier = get_model(device, num_classes=12)
    criterion = torch.nn.CrossEntropyLoss()

    if args.checkpoint:
        print(f"Loading checkpoint: {args.checkpoint = }")
        classifier.load_state_dict(torch.load(args.checkpoint))
    classifier = classifier.to(device)
    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.001)
    encoder.eval()

    # Let's count the size of the classification head
    classifier_head_params = sum(p.numel() for p in classifier.parameters())
    print(f"{classifier_head_params = }")

    # eval initial model
    loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device)
    print(f"Initial (loss, acc): {loss = }, {accuracy = }")
    best = [-float("inf"), None]
    for e in range(args.epochs):
        print(f"Epoch: {e}")
        train_one_epoch(encoder, classifier, optimizer, criterion, train_loader, device)
        loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device)
        last_saved = save_classifier(classifier, accuracy)
        if accuracy > best[0]:
            best[0] = accuracy
            best[1] = last_saved
        print(f"VALIDATION ---> {loss = }, {accuracy = }")

    print("Training done...")
    print("Evaluating test set. Loading best model")
    classifier.load_state_dict(torch.load(best[1]))
    loss, accuracy = eval_model(encoder, classifier, criterion, test_loader, device)
    print(f"TEST ---> {loss = }, {accuracy = }")


if __name__ == "__main__":
    main()
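The script above imports get_model, train_one_epoch, eval_model, prepare_silences_dataset, get_encoding_fn, and remove_cols from a utils.py module that is part of this commit but not rendered in this excerpt. Purely as an illustration of how two of those helpers might look: the column names "data" and "targets" follow their usage above, while everything else (including how the sequence dimension is pooled and any remapping of raw labels to the 12 keyword-spotting classes) is an assumption rather than the actual implementation.

import torch


def get_encoding_fn(processor):
    """Return a function that turns one raw SpeechCommands example into model inputs.

    Hypothetical sketch: stores Whisper log-Mel features under "data" and an
    integer keyword label under "targets", the columns the scripts above expect.
    """

    def encode(example):
        audio = example["audio"]
        example["data"] = processor(
            audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="np"
        ).input_features[0]
        example["targets"] = example["label"]
        return example

    return encode


@torch.no_grad()
def eval_model(encoder, classifier, criterion, loader, device):
    """Evaluate the frozen encoder plus classification head; return (loss, accuracy)."""
    classifier.eval()
    loss_sum, correct, total = 0.0, 0, 0
    for batch in loader:
        data = batch["data"].to(device)
        targets = batch["targets"].to(device)
        features = encoder(data).last_hidden_state
        logits = classifier(features)  # assumes the head pools the sequence dimension
        loss_sum += criterion(logits, targets).item()
        correct += (logits.argmax(dim=-1) == targets).sum().item()
        total += targets.shape[0]
    return loss_sum / len(loader), correct / total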
@@ -0,0 +1,183 @@
import argparse
import torch
import flwr as fl
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler
from datasets import load_dataset, load_from_disk, concatenate_datasets
from transformers import WhisperProcessor

from utils import (
    get_model,
    set_params,
    train_one_epoch,
    remove_cols,
    prepare_silences_dataset,
    construct_client_mapping,
    get_encoding_fn,
)

parser = argparse.ArgumentParser(description="Flower+Whisper")
parser.add_argument("--cid", type=int, required=True, help="Client id.")
parser.add_argument(
    "--server_address", type=str, required=True, help="IP of the server."
)
parser.add_argument(
    "--no-compile", action="store_true", help="To not compile client models."
)

CLIENT_DATA = "client_datasets"


class WhisperFlowerClient(fl.client.NumPyClient):
    """A Flower client that trains a classification head attached to the encoder
    of a Whisper-tiny model for keyword spotting."""

    def __init__(self, trainset, num_classes: int, disable_tqdm: bool, compile: bool):
        self.disable_tqdm = disable_tqdm
        self.trainset = trainset.with_format("torch", columns=["data", "targets"])

        # Determine device
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        self.encoder, self.classifier = get_model(self.device, num_classes, compile)

    def get_parameters(self, config):
        """Return parameters in a format that is understood by the server."""
        return [val.cpu().numpy() for _, val in self.classifier.state_dict().items()]

    def fit(self, parameters, config):
        """Do on-device training.

        Here the client receives the parameters of the classification head from
        the server. Then it trains that classifier using the data that belongs to
        this client. Finally, the updated classifier is sent back to the server
        for aggregation.
        """

        # Apply the classifier parameters to the model in this client
        set_params(self.classifier, parameters)

        # Read from config
        batch, epochs = config["batch_size"], config["epochs"]

        # construct sampler in order to have balanced batches
        hist = np.histogram(self.trainset["targets"], bins=12)
        w_per_class = (
            len(self.trainset) / hist[0]
        )  # doesn't have to add up to 1 (relative is what matters)
        # print(f"{w_per_class = }")
        w_ss = [w_per_class[t] for t in self.trainset["targets"]]
        ss = WeightedRandomSampler(w_ss, len(w_ss))

        # Construct dataloader
        train_loader = DataLoader(
            self.trainset,
            batch_size=batch,
            shuffle=False,
            num_workers=0,
            sampler=ss,
            drop_last=True,
        )

        # Define optimizer and criterion
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.classifier.parameters(), lr=0.001)
        # Train
        train_one_epoch(
            self.encoder,
            self.classifier,
            optimizer,
            criterion,
            train_loader,
            self.device,
            disable_tqdm=self.disable_tqdm,
        )

        # Return local classification head and statistics
        return self.get_parameters({}), len(train_loader.dataset), {}


def get_client_fn(
    full_data,
    encoding_fn,
    client_mapping,
    client_data_path: str = "./",
    num_classes: int = 12,
    disable_tqdm: bool = False,
    compile: bool = True,
):
    """Return a function that can be used to instantiate a particular client."""

    def client_fn(cid: str):
        torch.set_float32_matmul_precision(
            "high"
        )  # If "high" or "medium" is set, then TensorFloat32 is used

        # if the dataset hasn't been processed for this client, do so;
        # else, just load it
        try:
            full_train_dataset = load_from_disk(f"{client_data_path}/client{cid}.hf")
        except FileNotFoundError:
            # get this client's data and preprocess it
            print(f"Dataset for client {cid} not found. Pre-processing...")
            og_threads = torch.get_num_threads()
            torch.set_num_threads(1)
            sc_client = full_data.filter(
                lambda example: example["speaker_id"] in client_mapping[int(cid)]
            )
            client_train_data = sc_client.map(
                encoding_fn, num_proc=4, remove_columns=remove_cols
            )

            # now let's add some _silence_ training examples (add 10% of total examples in this client's data)
            ratio_silences_for_client = 0.1 * (len(client_train_data) / len(full_data))
            silence_dataset = prepare_silences_dataset(
                full_data, ratio_silences_for_client
            )
            print(
                f"adding {len(silence_dataset)} to client data ({len(client_train_data)})"
            )
            silence_enc = silence_dataset.map(encoding_fn, remove_columns=remove_cols)

            full_train_dataset = concatenate_datasets([client_train_data, silence_enc])
            # save dataset. It will be loaded next time this client is spawned
            full_train_dataset.save_to_disk(f"{client_data_path}/client{cid}.hf")
            torch.set_num_threads(og_threads)

        return WhisperFlowerClient(
            full_train_dataset, num_classes, disable_tqdm, compile
        )

    return client_fn


def run_client():
    """Run client."""

    # Parse input arguments
    args = parser.parse_args()

    sc_train = load_dataset("speech_commands", "v0.02", split="train", token=False)

    # generate splits
    client_mapping = construct_client_mapping(sc_train, num_clients=100)

    # pre-process all partitions (+store to disk)
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    prepare_dataset_fn = get_encoding_fn(processor)

    client_fn = get_client_fn(
        sc_train,
        prepare_dataset_fn,
        client_mapping,
        compile=not args.no_compile,
        client_data_path=CLIENT_DATA,
    )

    fl.client.start_numpy_client(
        server_address=f"{args.server_address}:8080", client=client_fn(args.cid)
    )


if __name__ == "__main__":
    run_client()
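The client reads batch_size and epochs from the fit config and connects to {args.server_address}:8080, so a matching server has to supply that config and listen on port 8080. The server script itself is among the 13 changed files but is not rendered here; the snippet below is only a minimal sketch of a compatible Flower server, with round count, client fractions, and config values chosen for illustration rather than taken from this commit.

import flwr as fl


def fit_config(server_round: int):
    # Values WhisperFlowerClient.fit() reads from `config`; illustrative only.
    return {"batch_size": 8, "epochs": 1}


strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.05,  # sample a small fraction of available clients per round
    fraction_evaluate=0.0,  # the client above does not implement evaluate()
    min_fit_clients=2,
    min_available_clients=2,
    on_fit_config_fn=fit_config,
)

fl.server.start_server(
    server_address="0.0.0.0:8080",  # clients connect on port 8080, as above
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)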
@@ -0,0 +1,19 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "whisper-flower"
version = "0.1.0"
description = "On-device Federated Downstreaming for Speech Classification"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = { extras = ["simulation"], version = ">=1.0,<2.0" }
transformers = "4.32.1"
tokenizers = "0.13.3"
datasets = "2.14.6"
soundfile = "0.12.1"
librosa = "0.10.1"
# this example was tested with pytorch 2.1.0
@@ -0,0 +1,7 @@
transformers==4.32.1
tokenizers==0.13.3
datasets==2.14.6
soundfile==0.12.1
librosa==0.10.1
flwr==1.5.0
ray==2.6.3
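flwr is installed with the "simulation" extra and ray is pinned, which suggests the example can also be run as a single-process Flower simulation instead of launching separate server and client processes. A hypothetical sketch of such a run, assuming the client script above lives in client.py; the round count, sampling fraction, and resource values are illustrative only.

import flwr as fl
from datasets import load_dataset
from transformers import WhisperProcessor

from client import CLIENT_DATA, get_client_fn
from utils import construct_client_mapping, get_encoding_fn

# Same data preparation as in run_client(), reused for the simulation entry point.
sc_train = load_dataset("speech_commands", "v0.02", split="train", token=False)
client_mapping = construct_client_mapping(sc_train, num_clients=100)
encoding_fn = get_encoding_fn(WhisperProcessor.from_pretrained("openai/whisper-tiny"))

fl.simulation.start_simulation(
    client_fn=get_client_fn(
        sc_train, encoding_fn, client_mapping, client_data_path=CLIENT_DATA
    ),
    num_clients=100,
    config=fl.server.ServerConfig(num_rounds=2),
    strategy=fl.server.strategy.FedAvg(
        fraction_fit=0.05,
        fraction_evaluate=0.0,
        on_fit_config_fn=lambda rnd: {"batch_size": 8, "epochs": 1},
    ),
    client_resources={"num_cpus": 4, "num_gpus": 0.0},
)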