Optimize data loading, preprocessing, and metric computation #120

Open · wants to merge 1 commit into base: main
training/run_eval.py (108 changes: 49 additions & 59 deletions)
@@ -461,21 +461,27 @@ def main():
             streaming=data_args.streaming,
             num_proc=data_args.preprocessing_num_workers,
         )
-        if dataset_dict["text_column_name"] not in list(sub_dataset.features.keys()):
+
+        text_column_name = dataset_dict["text_column_name"]
+        dataset_features = sub_dataset.features
+
+        if text_column_name not in dataset_features:
             raise ValueError(
-                f"`--text_column_name` {dataset_dict['text_column_name']} not found in the evaluation "
+                f"`--text_column_name` {text_column_name} not found in the evaluation "
                 f"dataset {dataset_dict['name']}. Ensure `text_column_name` is set to the correct column "
-                f"for the target text. Should be one of {' '.join(list(sub_dataset.features.keys()))}"
+                f"for the target text. Should be one of {list(dataset_features.keys())}"
             )
-        if dataset_dict["text_column_name"] != "text":
-            sub_dataset = sub_dataset.rename_column(dataset_dict["text_column_name"], "text")
-        if not data_args.streaming:
-            sub_dataset = sub_dataset.to_iterable_dataset()
-
-        # Clean-up the dataset name for pretty logging
-        # ("distil-whisper/librispeech_asr", "validation.clean") -> "librispeech_asr/validation-clean"
-        pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
-        raw_datasets[pretty_name] = sub_dataset
+        if text_column_name != "text":
+            # rename_column returns a new dataset and must be re-assigned
+            sub_dataset = sub_dataset.rename_column(text_column_name, "text")
+
+        if not data_args.streaming:
+            sub_dataset = sub_dataset.to_iterable_dataset()
+
+        # Clean-up the dataset name for pretty logging
+        # ("distil-whisper/librispeech_asr", "validation.clean") -> "librispeech_asr/validation-clean"
+        pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
+        raw_datasets[pretty_name] = sub_dataset


     # 5. Load pretrained model, tokenizer, and feature extractor
     processor = WhisperProcessor.from_pretrained(
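The hunk above leans on two `datasets` behaviours worth keeping in mind: `Features` supports direct membership tests, and `rename_column` returns a new dataset rather than mutating in place (the old in-place `rename_column_` helper is gone from current releases, so the result must be re-assigned). A minimal sketch with hypothetical column names:

```python
from datasets import Dataset

# Toy stand-in for an evaluation split (hypothetical column names)
sub_dataset = Dataset.from_dict({"sentence": ["hello world"], "id": [0]})

text_column_name = "sentence"  # e.g. the value parsed from --text_column_name
if text_column_name not in sub_dataset.features:
    raise ValueError(
        f"`--text_column_name` {text_column_name} not found. "
        f"Should be one of {list(sub_dataset.features.keys())}"
    )

if text_column_name != "text":
    # rename_column returns a *new* dataset, so the result must be re-assigned
    sub_dataset = sub_dataset.rename_column(text_column_name, "text")

print(sub_dataset.column_names)  # ['text', 'id']
```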
@@ -561,11 +567,9 @@ def prepare_dataset(batch):
                 audio,
                 sampling_rate=sampling_rate,
                 return_tensors="pt",
-                truncation=False,
-                padding="longest",
                 return_attention_mask=True,
             )
-            if inputs.input_features.shape[-1] < 3000:
+            if "truncation" in inputs:
                 inputs = processor.feature_extractor(
                     audio,
                     sampling_rate=sampling_rate,
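One subtlety in this hunk: the feature extractor returns a `BatchFeature`, so a membership test like `"truncation" in inputs` inspects the returned keys (normally `input_features` and, here, `attention_mask`), not whether truncation was applied. A sketch, assuming a default `WhisperFeatureExtractor`, of how short-form audio shows up in the feature shape instead:

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor()  # default Whisper settings
sampling_rate = feature_extractor.sampling_rate

# 5 seconds of audio: well below Whisper's 30-second / 3000-frame window
audio = [np.zeros(5 * sampling_rate, dtype=np.float32)]

inputs = feature_extractor(
    audio,
    sampling_rate=sampling_rate,
    return_tensors="np",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
)

# The BatchFeature's keys are the returned tensors, so a test like
# `"truncation" in inputs` is always False with these arguments.
print(list(inputs.keys()))              # ['input_features', 'attention_mask']
print(inputs.input_features.shape[-1])  # 500 -> short-form (< 3000 frames)
```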
@@ -577,18 +581,18 @@ def prepare_dataset(batch):
         else:
             batch["input_features"] = audio

-        # process audio length
-        batch["length_in_s"] = [len(sample) / sampling_rate for sample in audio]
+        # process audio length using numpy
+        batch["length_in_s"] = np.array([len(sample) for sample in audio]) / sampling_rate

         # process targets
         batch["reference"] = batch["text"]
         return batch

     vectorized_datasets = IterableDatasetDict()

-    for split in raw_datasets:
-        raw_datasets_features = list(raw_datasets[split].features.keys())
-
-        vectorized_datasets[split] = raw_datasets[split].map(
+    for split, dataset in raw_datasets.items():
+        raw_datasets_features = list(dataset.features.keys())
+        vectorized_datasets[split] = dataset.map(
             function=prepare_dataset,
             remove_columns=raw_datasets_features,
             batch_size=data_args.batch_size,
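The vectorized length computation assumes `numpy` is imported as `np` at module level; it performs one division over an integer array instead of one per sample. A quick check of the equivalence:

```python
import numpy as np

sampling_rate = 16_000
audio = [np.zeros(16_000), np.zeros(48_000)]  # 1 s and 3 s of dummy audio

lengths_list = [len(sample) / sampling_rate for sample in audio]           # original
lengths_np = np.array([len(sample) for sample in audio]) / sampling_rate   # this PR

assert lengths_list == lengths_np.tolist()  # [1.0, 3.0]
```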
@@ -609,16 +613,19 @@ def prepare_dataset(batch):

     def compute_metrics(pred_str, label_str):
         # normalize everything and re-compute the WER
-        norm_pred_str = [normalizer(pred) for pred in pred_str]
-        norm_label_str = [normalizer(label) for label in label_str]
-
-        # filtering step to only evaluate the samples that correspond to non-zero normalized references:
-        norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
-        norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
+        norm_pred_str, norm_label_str = [], []
+
+        # normalize and filter in a single pass, keeping only samples with
+        # a non-empty normalized reference
+        for pred, label in zip(pred_str, label_str):
+            norm_pred = normalizer(pred)
+            norm_label = normalizer(label)
+            if len(norm_label) > 0:
+                norm_pred_str.append(norm_pred)
+                norm_label_str.append(norm_label)

         wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
         return wer


     gen_kwargs = {
         "max_length": data_args.generation_max_length,
         "return_timestamps": data_args.return_timestamps,
@@ -653,25 +660,21 @@ def benchmark(batch):
         if model_pipeline is None:
             inputs = torch.stack(batch["input_features"], dim=0).cuda()
             attention_mask = torch.stack(batch["attention_mask"], dim=0).cuda()
             # automatically use long-form args if required
             inner_batch_size, num_mels, seq_len = inputs.shape
-            if seq_len == 3000:
-                batch_gen_kwargs = gen_kwargs
-            else:
-                batch_gen_kwargs = {**gen_kwargs, **long_form_gen_kwargs}
+            batch_gen_kwargs = gen_kwargs if seq_len == 3000 else {**gen_kwargs, **long_form_gen_kwargs}

             set_seed(data_args.seed)
             start_time = time.time()
-            output_ids = model.generate(inputs, attention_mask=attention_mask, **batch_gen_kwargs)
-            batch["time"] = inner_batch_size * [(time.time() - start_time) / inner_batch_size]
+            with torch.no_grad():
+                output_ids = model.generate(inputs, attention_mask=attention_mask, **batch_gen_kwargs)
+            batch["time"] = [(time.time() - start_time) / inner_batch_size] * inner_batch_size

             batch["transcription"] = processor.batch_decode(
                 output_ids, skip_special_tokens=True, decode_with_timestamps=data_args.return_timestamps
             )

         else:
             inputs = batch["input_features"]
             # Time forward: let's make sure that only forward is timed and not pre- and post-processing
             time_result = []

             def _forward_time(*args, **kwargs):
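Two small points on this hunk: the conditional expression only merges `long_form_gen_kwargs` when the features are not Whisper's fixed 3000-frame short-form shape, and the `torch.no_grad()` wrapper is belt-and-braces, since recent `transformers` already runs `generate()` without gradients. A sketch with illustrative kwargs:

```python
gen_kwargs = {"max_length": 256, "return_timestamps": False}  # illustrative values
long_form_gen_kwargs = {"condition_on_prev_tokens": True}     # illustrative long-form args

for seq_len in (3000, 6000):
    # 3000 mel frames == Whisper's 30 s short-form window; anything else is long-form
    batch_gen_kwargs = gen_kwargs if seq_len == 3000 else {**gen_kwargs, **long_form_gen_kwargs}
    print(seq_len, sorted(batch_gen_kwargs))
```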
@@ -683,7 +686,8 @@ def _forward_time(*args, **kwargs):

             model_pipeline._forward = _forward_time

-            result = model_pipeline(inputs, batch_size=PIPELINE_BATCH_SIZE, generate_kwargs=gen_kwargs)[0]["text"]
+            with torch.no_grad():
+                result = model_pipeline(inputs, batch_size=PIPELINE_BATCH_SIZE, generate_kwargs=gen_kwargs)[0]["text"]
             batch["transcription"] = [result]
             batch["time"] = [sum(time_result)]

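For context, `_forward_time` (collapsed above) wraps the pipeline's `_forward` so that only model inference is timed, excluding pre- and post-processing. A generic sketch of the pattern with a stub pipeline:

```python
import time

class TinyPipeline:  # stub standing in for a transformers pipeline
    def _forward(self, model_inputs):
        time.sleep(0.01)  # pretend inference
        return model_inputs

pipe = TinyPipeline()
time_result = []
original_forward = pipe._forward

def _forward_time(*args, **kwargs):
    start_time = time.time()
    result = original_forward(*args, **kwargs)
    time_result.append(time.time() - start_time)
    return result

pipe._forward = _forward_time  # instance attribute shadows the bound method
pipe._forward("chunk-1")
pipe._forward("chunk-2")
print(f"forward-only time: {sum(time_result):.3f}s")
```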
@@ -701,31 +705,25 @@ def _forward_time(*args, **kwargs):
     )

     stats_dataset = DatasetDict()
-
     all_stats = {"rtf": 0, "wer": 0}
-    rtf_stats = {
-        "times_audio_total": 0,
-        "times_transcription_total": 0,
-    }
+    rtf_stats = {"times_audio_total": 0, "times_transcription_total": 0}

     logger.info("***** Running Evaluation *****")
     for key in generation_arguments:
         logger.info(f"  {key}: {generation_arguments[key]}")

-    datasets_evaluated_progress_bar = tqdm(result_datasets, desc="Datasets", position=0)
-    for split in datasets_evaluated_progress_bar:
+    for split in tqdm(result_datasets, desc="Datasets"):
         transcriptions = []
         references = []
         stats = {}
         times_audio_total = 0
         times_transcription_total = 0

-        datasets_evaluated_progress_bar.write(f"Start benchmarking {split}...")
+        logger.info(f"Start benchmarking {split}...")
         result_iter = iter(result_datasets[split])
         for result in tqdm(result_iter, desc="Samples", position=1):
             times_audio_total += result["length_in_s"]
             times_transcription_total += result["time"]
             # ensure prompt is removed from the transcription (awaiting fix in Transformers)
             if data_args.prompt_text is not None:
                 result["transcription"] = result["transcription"].replace(data_args.prompt_text, "")
             transcriptions.append(result["transcription"])
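One behavioural nuance of swapping the progress bar's `write(...)` for `logger.info(...)`: `tqdm.write` repaints the active bar so messages don't tear it, whereas plain logging can interleave with the bar's redraws. A minimal comparison:

```python
import logging
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

for split in tqdm(["librispeech_asr/validation-clean"], desc="Datasets"):
    tqdm.write(f"Start benchmarking {split}...")   # cooperates with the bar
    logger.info(f"Start benchmarking {split}...")  # may interleave with redraws
```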
@@ -734,25 +732,20 @@ def _forward_time(*args, **kwargs):
         norm_transcriptions = [normalizer(pred) for pred in transcriptions]
         norm_references = [normalizer(label) for label in references]

-        transcriptions = [transcriptions[i] for i in range(len(transcriptions)) if len(norm_references[i]) > 0]
-        references = [references[i] for i in range(len(references)) if len(norm_references[i]) > 0]
-
-        norm_transcriptions = [
-            norm_transcriptions[i] for i in range(len(norm_transcriptions)) if len(norm_references[i]) > 0
-        ]
-        norm_references = [norm_references[i] for i in range(len(norm_references)) if len(norm_references[i]) > 0]
+        # drop samples with an empty normalized reference; filter norm_references
+        # last, since the other three lists are zipped against it
+        transcriptions = [t for t, ref in zip(transcriptions, norm_references) if ref]
+        references = [r for r, ref in zip(references, norm_references) if ref]
+        norm_transcriptions = [t for t, ref in zip(norm_transcriptions, norm_references) if ref]
+        norm_references = [ref for ref in norm_references if ref]

         stats["wer"] = compute_metrics(norm_transcriptions, norm_references)

-        wer_per_sample = []
-        for pred, ref in zip(norm_transcriptions, norm_references):
-            wer_per_sample.append(compute_metrics([pred], [ref]))
+        wer_per_sample = [compute_metrics([pred], [ref]) for pred, ref in zip(norm_transcriptions, norm_references)]

         stats["rtf"] = times_audio_total / times_transcription_total
         stats_dataset[split] = stats

         wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in stats.items()])
-        datasets_evaluated_progress_bar.write(wer_desc)
+        logger.info(wer_desc)

         write_wandb_metric(wandb_logger, stats, prefix=split)

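The zip-based filtering keeps the four lists aligned while dropping samples whose normalized reference is empty; `norm_references` itself must be filtered last because the other three comprehensions zip against it. A toy run:

```python
normalizer = str.lower  # stand-in for the Whisper text normalizer

transcriptions = ["Hello there", "Hmm"]
references = ["HELLO THERE", ""]
norm_transcriptions = [normalizer(t) for t in transcriptions]
norm_references = [normalizer(r) for r in references]

transcriptions = [t for t, ref in zip(transcriptions, norm_references) if ref]
references = [r for r, ref in zip(references, norm_references) if ref]
norm_transcriptions = [t for t, ref in zip(norm_transcriptions, norm_references) if ref]
norm_references = [ref for ref in norm_references if ref]  # filtered last

print(references, norm_references)  # ['HELLO THERE'] ['hello there']
```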
@@ -771,8 +764,7 @@ def _forward_time(*args, **kwargs):
         rtf_stats["times_transcription_total"] += times_transcription_total
         all_stats["wer"] += stats["wer"]

-    all_stats["wer"] = all_stats["wer"] / len(result_datasets)
+    all_stats["wer"] /= len(result_datasets)
     # technically this is the reciprocal of the RTF, but it makes the scale easier to read on wandb
     all_stats["rtf"] = rtf_stats["times_audio_total"] / rtf_stats["times_transcription_total"]

     stats_dataset["all"] = all_stats
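A worked example of the aggregation: per-dataset WERs are macro-averaged, and the reported "rtf" is audio seconds per second of transcription time, i.e. the reciprocal of the usual real-time factor (higher means faster). With hypothetical totals:

```python
rtf_stats = {"times_audio_total": 3600.0, "times_transcription_total": 120.0}  # hypothetical totals
per_dataset_wer = [4.2, 6.8]  # hypothetical per-dataset WERs

all_stats = {
    "wer": sum(per_dataset_wer) / len(per_dataset_wer),  # unweighted macro-average
    "rtf": rtf_stats["times_audio_total"] / rtf_stats["times_transcription_total"],
}
print(all_stats)  # {'wer': 5.5, 'rtf': 30.0}
```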
@@ -783,10 +775,8 @@ def _forward_time(*args, **kwargs):
     with tempfile.TemporaryDirectory() as temp_dir:
         for split in stats_dataset:
             file_name = os.path.join(temp_dir, f"{'_'.join(split.split('/'))}.json")
-
             with open(file_name, "w") as json_file:
                 json.dump(stats_dataset[split], json_file)
-
             benchmark_artifact.add_file(file_name, split)

     wandb_logger.log_artifact(benchmark_artifact)
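Finally, the artifact loop writes one JSON file per split into a temporary directory, flattening the `name/split` key into a file name; the wandb upload itself is stubbed out in this sketch:

```python
import json
import os
import tempfile

stats_dataset = {"librispeech_asr/validation-clean": {"wer": 4.2, "rtf": 30.0}}  # hypothetical stats

with tempfile.TemporaryDirectory() as temp_dir:
    for split, stats in stats_dataset.items():
        file_name = os.path.join(temp_dir, f"{'_'.join(split.split('/'))}.json")
        with open(file_name, "w") as json_file:
            json.dump(stats, json_file)
        print(os.path.basename(file_name))  # librispeech_asr_validation-clean.json
```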