From 79b5bc70a9d1ce5ecef3c2a9ad0cdd8fd1b58f5a Mon Sep 17 00:00:00 2001
From: sanchit-gandhi
Date: Tue, 18 Jun 2024 15:45:14 +0100
Subject: [PATCH] change ordering

---
 training/run_eval.py | 77 ++++++++++++++++++++++----------------------
 1 file changed, 39 insertions(+), 38 deletions(-)

diff --git a/training/run_eval.py b/training/run_eval.py
index 768ac33..4c9ea37 100644
--- a/training/run_eval.py
+++ b/training/run_eval.py
@@ -556,22 +556,47 @@ def main():
 
         assistant_model.cuda()
 
-    # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
+    # 6. Define generation arguments
+    gen_kwargs = {
+        "max_length": data_args.generation_max_length,
+        "return_timestamps": data_args.return_timestamps,
+        "num_beams": data_args.num_beams,
+        "top_k": 0,
+    }
+
+    if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
+        gen_kwargs["language"] = data_args.language
+        gen_kwargs["task"] = data_args.task
+    elif data_args.language is not None:
+        raise ValueError(
+            "Setting language token for an English-only checkpoint is not permitted. The language argument should "
+            "only be set for multilingual checkpoints."
+        )
+
+    if assistant_model is not None:
+        gen_kwargs["assistant_model"] = assistant_model
+
+    if data_args.prompt_text is not None:
+        gen_kwargs["prompt_ids"] = processor.get_prompt_ids(data_args.prompt_text, return_tensors="pt").to("cuda:0")
+
+    long_form_gen_kwargs = {
+        "condition_on_prev_tokens": data_args.condition_on_prev_tokens,
+        "compression_ratio_threshold": data_args.compression_ratio_threshold,
+        "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) if data_args.temperature_fallback else 0,
+        "logprob_threshold": data_args.logprob_threshold,
+        "no_speech_threshold": data_args.no_speech_threshold,
+    }
+
+    # 7. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
     # so we just need to set the correct target sampling rate.
     raw_datasets = raw_datasets.cast_column(
         data_args.audio_column_name,
         datasets.features.Audio(sampling_rate=processor.feature_extractor.sampling_rate),
     )
 
-    # 7. Preprocessing the datasets.
+    # 8. Preprocessing the datasets.
     # We need to read the audio files as arrays and tokenize the targets.
     audio_column_name = data_args.audio_column_name
-    language = language_to_id(data_args.language, model.generation_config) if data_args.language else None
-    if language is None or language == "<|en|>":
-        normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)
-    else:
-        normalizer = BasicTextNormalizer()
-
     sampling_rate = processor.feature_extractor.sampling_rate
 
     if data_args.samples_per_dataset is not None:
@@ -631,6 +656,12 @@ def prepare_dataset(batch):
         logger.info(f"Data preprocessing finished. Files cached at {cache}.")
         return
 
+    language = language_to_id(data_args.language, model.generation_config) if data_args.language else None
+    if language is None or language == "<|en|>":
+        normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)
+    else:
+        normalizer = BasicTextNormalizer()
+
     metric = evaluate.load("wer")
 
     def compute_metrics(pred_str, label_str):
@@ -645,36 +676,6 @@ def compute_metrics(pred_str, label_str):
         wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
         return wer
 
-    gen_kwargs = {
-        "max_length": data_args.generation_max_length,
-        "return_timestamps": data_args.return_timestamps,
-        "num_beams": data_args.num_beams,
-        "top_k": 0,
-    }
-
-    if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
-        gen_kwargs["language"] = data_args.language
-        gen_kwargs["task"] = data_args.task
-    elif data_args.language is not None:
-        raise ValueError(
-            "Setting language token for an English-only checkpoint is not permitted. The language argument should "
-            "only be set for multilingual checkpoints."
-        )
-
-    if assistant_model is not None:
-        gen_kwargs["assistant_model"] = assistant_model
-
-    if data_args.prompt_text is not None:
-        gen_kwargs["prompt_ids"] = processor.get_prompt_ids(data_args.prompt_text, return_tensors="pt").to("cuda:0")
-
-    long_form_gen_kwargs = {
-        "condition_on_prev_tokens": data_args.condition_on_prev_tokens,
-        "compression_ratio_threshold": data_args.compression_ratio_threshold,
-        "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) if data_args.temperature_fallback else 0,
-        "logprob_threshold": data_args.logprob_threshold,
-        "no_speech_threshold": data_args.no_speech_threshold,
-    }
-
     def benchmark(batch):
         if model_pipeline is None:
             inputs = torch.stack(batch["input_features"], dim=0).cuda()