Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[eval] fix check for language arg #139

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 39 additions & 38 deletions training/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,22 +556,47 @@ def main():

assistant_model.cuda()

# 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
# 6. Define generation arguments
gen_kwargs = {
"max_length": data_args.generation_max_length,
"return_timestamps": data_args.return_timestamps,
"num_beams": data_args.num_beams,
"top_k": 0,
}

if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
gen_kwargs["language"] = data_args.language
gen_kwargs["task"] = data_args.task
elif data_args.language is not None:
raise ValueError(
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
"only be set for multilingual checkpoints."
)
Comment on lines +570 to +574
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Informative error that will be thrown if language is passed but the checkpoint is English-only


if assistant_model is not None:
gen_kwargs["assistant_model"] = assistant_model

if data_args.prompt_text is not None:
gen_kwargs["prompt_ids"] = processor.get_prompt_ids(data_args.prompt_text, return_tensors="pt").to("cuda:0")

long_form_gen_kwargs = {
"condition_on_prev_tokens": data_args.condition_on_prev_tokens,
"compression_ratio_threshold": data_args.compression_ratio_threshold,
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) if data_args.temperature_fallback else 0,
"logprob_threshold": data_args.logprob_threshold,
"no_speech_threshold": data_args.no_speech_threshold,
}

# 7. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
# so we just need to set the correct target sampling rate.
raw_datasets = raw_datasets.cast_column(
data_args.audio_column_name,
datasets.features.Audio(sampling_rate=processor.feature_extractor.sampling_rate),
)

# 7. Preprocessing the datasets.
# 8. Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
audio_column_name = data_args.audio_column_name
language = language_to_id(data_args.language, model.generation_config) if data_args.language else None
Copy link
Collaborator Author

@sanchit-gandhi sanchit-gandhi Jun 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue previously was that we could have passed the language arg for an English-only checkpoint, which is invalid. We now check this case first, and then set the normalizer afterwards

if language is None or language == "<|en|>":
normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)
else:
normalizer = BasicTextNormalizer()

sampling_rate = processor.feature_extractor.sampling_rate

if data_args.samples_per_dataset is not None:
Expand Down Expand Up @@ -631,6 +656,12 @@ def prepare_dataset(batch):
logger.info(f"Data preprocessing finished. Files cached at {cache}.")
return

language = language_to_id(data_args.language, model.generation_config) if data_args.language else None
if language is None or language == "<|en|>":
normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)
else:
normalizer = BasicTextNormalizer()

metric = evaluate.load("wer")

def compute_metrics(pred_str, label_str):
Expand All @@ -645,36 +676,6 @@ def compute_metrics(pred_str, label_str):
wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
return wer

gen_kwargs = {
"max_length": data_args.generation_max_length,
"return_timestamps": data_args.return_timestamps,
"num_beams": data_args.num_beams,
"top_k": 0,
}

if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
gen_kwargs["language"] = data_args.language
gen_kwargs["task"] = data_args.task
elif data_args.language is not None:
raise ValueError(
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
"only be set for multilingual checkpoints."
)

if assistant_model is not None:
gen_kwargs["assistant_model"] = assistant_model

if data_args.prompt_text is not None:
gen_kwargs["prompt_ids"] = processor.get_prompt_ids(data_args.prompt_text, return_tensors="pt").to("cuda:0")

long_form_gen_kwargs = {
"condition_on_prev_tokens": data_args.condition_on_prev_tokens,
"compression_ratio_threshold": data_args.compression_ratio_threshold,
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) if data_args.temperature_fallback else 0,
"logprob_threshold": data_args.logprob_threshold,
"no_speech_threshold": data_args.no_speech_threshold,
}

def benchmark(batch):
if model_pipeline is None:
inputs = torch.stack(batch["input_features"], dim=0).cuda()
Expand Down