# evaluate_salmonn.py
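# Pipeline overview (summarizing the code below): load the SALMONN
# preprocessor and LLaMA model, run batched inference over the test set
# (speech embeddings wrapped with task prompts, then beam-search generation),
# optionally score the hypotheses (WER for ASR, SPIDEr for AAC), and write
# the predictions to submission.csv.
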
import argparse
import json
import random
import sys

import torch
import pandas as pd
from pathlib import Path
from tqdm import tqdm

# Add the bundled trainer package to the import path
sys.path.append(str(Path(__file__).parent / "audiolm-trainer"))

# Custom modules
from salmonn_utils import SALMONNTestDataset, load_preprocessor, load_model
from config import Config
from utils import get_dataloader, prepare_sample
from metrics import compute_wer, compute_spider


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg-path",
        type=str,
        help="path to configuration file",
        default="salmonn_eval_config.yaml",
    )
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override settings in the config; key-value pairs in xxx=yyy format "
        "are merged into the config file (deprecated, use --cfg-options instead)",
    )
    parser.add_argument("--task", type=str, default=None, help="task to evaluate", choices=["asr", "aac"])
    parser.add_argument("--skip_scoring", action="store_true", help="if set, skip scoring after inference")
    return parser.parse_args()


def get_dataset(dataset_cfg, run_cfg, task):
    testset = SALMONNTestDataset(
        dataset_cfg.prefix, dataset_cfg.test_ann_path, dataset_cfg.whisper_path, task
    )
    test_loader = get_dataloader(testset, run_cfg, is_train=False, use_distributed=False)
    return test_loader


def main(args):
    cfg = Config(args)

    # Load models
    salmonn_preprocessor = load_preprocessor(cfg)
    llama_model, tokenizer = load_model(salmonn_preprocessor)
    salmonn_preprocessor.llama_model = llama_model

    # Load data
    dataloader = get_dataset(cfg.config.datasets, cfg.config.run, args.task)
    with open("audiolm-trainer/prompts/test_prompt.json", "r") as f:
        test_prompt = json.load(f)

    # Evaluation
    testset_ids, hyps, refs = [], [], []
    for samples in tqdm(dataloader):
        testset_id = samples["testset_id"]
        testset_ids.extend(testset_id)

        # Preprocess
        samples = prepare_sample(samples, cuda_enabled=torch.cuda.is_available())
        batch_size = samples["spectrogram"].shape[0]
        spectrogram = samples["spectrogram"]
        raw_wav = samples.get("raw_wav", None)
        audio_padding_mask = samples.get("padding_mask", None)
        speech_embeds, speech_atts = salmonn_preprocessor.encode_speech(
            spectrogram, raw_wav=raw_wav, audio_padding_mask=audio_padding_mask
        )

        # Wrap the audio embeddings with the per-task prompt embeddings
        prompts = [test_prompt[task] for task in samples["task"]]
        templated_prompts = [
            cfg.config.model.prompt_template.format(prompt) for prompt in prompts
        ]
        speech_embeds, speech_atts = salmonn_preprocessor.prompt_wrap(
            speech_embeds, speech_atts, templated_prompts, multi_prompt=True
        )

        # Prepend a BOS token embedding to the prompt-wrapped speech sequence
        bos = torch.ones(
            [batch_size, 1],
            dtype=torch.int32,
            device=speech_embeds.device,
        ) * tokenizer.bos_token_id
        bos_embeds = llama_model.model.model.embed_tokens(bos)
        atts_bos = speech_atts[:, :1]
        embeds = torch.cat([bos_embeds, speech_embeds], dim=1)
        attns = torch.cat([atts_bos, speech_atts], dim=1)

        generate_cfg = cfg.config.generate

        # Generation
        outputs = llama_model.model.generate(
            inputs_embeds=embeds,
            pad_token_id=llama_model.config.eos_token_id[0],
            max_new_tokens=generate_cfg.get("max_new_tokens", 200),
            num_beams=generate_cfg.get("num_beams", 4),
            do_sample=generate_cfg.get("do_sample", False),
            min_length=generate_cfg.get("min_length", 1),
            temperature=generate_cfg.get("temperature", 1.0),
            top_p=generate_cfg.get("top_p", 0.9),
            repetition_penalty=generate_cfg.get("repetition_penalty", 1.0),
            length_penalty=generate_cfg.get("length_penalty", 1.0),
            attention_mask=attns,
        )
        results = tokenizer.batch_decode(outputs)
        # Keep the text up to the end symbol and normalize case
        hyp = [result.split(generate_cfg.end_sym)[0].lower() for result in results]
        hyps.extend(hyp)

        if not args.skip_scoring:
            ref = samples["text"]
            refs.extend(ref)

    if not args.skip_scoring:
        if args.task == "asr":
            compute_wer(hyps, refs)
        elif args.task == "aac":
            compute_spider(hyps, refs)

    result_df = pd.DataFrame({"testset_id": testset_ids, "text": hyps})
    result_df.to_csv("submission.csv", index=False)


if __name__ == "__main__":
    args = parse_args()
    random.seed(42)
    main(args)
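
# Usage sketch (assumes salmonn_eval_config.yaml and the dataset paths it
# references exist; flags follow the argparse definitions above):
#   python evaluate_salmonn.py --task asr
#   python evaluate_salmonn.py --task aac --skip_scoring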