Uncertainty and evaluation #10

Open · wants to merge 24 commits into base: main
194 changes: 147 additions & 47 deletions asr/asr.py

Large diffs are not rendered by default.

720 changes: 720 additions & 0 deletions asr/comparison.py

Large diffs are not rendered by default.

137 changes: 137 additions & 0 deletions asr/lm.py
@@ -0,0 +1,137 @@
from itertools import combinations
from typing import Any

from tqdm import tqdm
import numpy as np

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerBase
from transformers.generation.utils import GenerationMixin

from asr.comparison import MultipleTextsAlignment


class SequenceScore:
    """
    Calculates a sequence score for a text from an autoregressive LM:
    the negative mean token cross-entropy, so higher (less negative) = more fluent.
    """
    def __init__(
        self,
        name: str | None = 'ai-forever/rugpt3large_based_on_gpt2',
        tokenizer: PreTrainedTokenizerBase | None = None,
        model: GenerationMixin | None = None,
    ):
        if name is not None:
            assert not tokenizer and not model
            # https://stackoverflow.com/a/75242984
            tokenizer = AutoTokenizer.from_pretrained(name, add_bos_token=True)
            model = AutoModelForCausalLM.from_pretrained(name)
        else:
            assert tokenizer and model

        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()

    def __call__(self, text: str) -> float:
        inputs = self.tokenizer([text], return_tensors='pt')
        with torch.no_grad():
            logits = self.model(**inputs, return_dict=True).logits[:, :-1]
            targets = inputs['input_ids'][:, 1:]
            logloss = F.cross_entropy(input=logits.transpose(1, 2), target=targets)

        logloss = logloss.cpu().detach().numpy()

        if np.isnan(logloss):
            return 0.0  # TODO: investigate why this happens

        return float(-logloss)
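
# Usage sketch for SequenceScore (scores are illustrative, not real outputs):
#   scorer = SequenceScore()              # loads the default rugpt3 LM
#   scorer('нейронные сети это хорошо')   # less negative score = more fluent text
#   scorer('нейронные сеть это хорошо')   # the agreement error lowers the score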


def get_all_subsets(elements: list[Any]) -> list[list[Any]]:
    """
    Returns all subsets of a list, from the empty subset to the full list.
    ```
    get_all_subsets([1, 2, 3])
    >>> [[], [1], [2], [3], [1, 2], [1, 3], [2, 3], [1, 2, 3]]
    ```
    """
    return sum((
        [list(x) for x in combinations(elements, r)]
        for r in range(len(elements) + 1)
    ), [])
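
# note: the number of subsets is 2**len(elements), which is why
# accept_suggestions_by_lm below keeps its look-ahead window small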


def accept_suggestions_by_lm(
    base_vs_additional: MultipleTextsAlignment,
    suggestion_indices: list[int],
    scorer: SequenceScore,
    look_forward: int = 2,
    context_before: int = 100,
    context_after: int = 50,
    pbar: bool = True,
    verbose: bool = False,
) -> list[int]:
    """
    When the two predictions disagree, selects the variant that the LM prefers.
    Returns the subset of `suggestion_indices` for which the second prediction
    (`base_vs_additional.text2`) was selected.

    At each step, the next `look_forward` unresolved suggestions are considered
    jointly: every subset of them is substituted into the text, the LM scores a
    window of `context_before`/`context_after` characters around the affected
    span, and the first suggestion is accepted if and only if it is part of the
    best-scoring subset.
    """

    orig_indices_to_resolve = suggestion_indices
    indices_to_resolve = orig_indices_to_resolve.copy()
    indices_accepted = []

    if pbar:
        _pbar = tqdm(total=len(indices_to_resolve))

    while len(indices_to_resolve):
        indices = indices_to_resolve[:look_forward]

        scores = {}

        for indices_to_consider in get_all_subsets(indices):
            text = base_vs_additional.substitute(replace=indices_accepted + indices_to_consider)

            # character span affected by the considered suggestions, shifted by
            # the length difference that the substitutions introduced
            start_idx = base_vs_additional.matches[indices[0]].char_start1
            end_idx = (
                base_vs_additional.matches[indices[-1]].char_end1
                + len(text) - len(base_vs_additional.text1.text)
            )

            start_idx -= context_before
            end_idx += context_after

            start_idx = np.clip(start_idx, 0, len(text))
            end_idx = np.clip(end_idx, 0, len(text))

            text = text[start_idx:end_idx]

            scores[tuple(indices_to_consider)] = {
                'score': scorer(text),
            }

        best_option = max(scores, key=lambda k: scores[k]['score'])

        if verbose:
            print(f'[{len(indices_to_resolve)}] selected {best_option} from {scores}')

        # accept or reject only the first suggestion; the rest are re-examined
        # on later iterations with a shifted look-ahead window
        should_accept_index = indices[0] in best_option

        if should_accept_index:
            indices_accepted.append(indices[0])

        indices_to_resolve = indices_to_resolve[1:]

        if pbar:
            _pbar.update(1)

    if pbar:
        _pbar.close()

    return indices_accepted
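
A sketch of how these pieces compose end to end. The `MultipleTextsAlignment` constructor below (`from_strings`) is hypothetical, since `asr/comparison.py` is not rendered in this diff; only `substitute`, `matches`, and `text1`/`text2` are taken from the calls above:

```
from asr.comparison import MultipleTextsAlignment
from asr.lm import SequenceScore, accept_suggestions_by_lm

# hypothetical constructor; see asr/comparison.py for the real API
alignment = MultipleTextsAlignment.from_strings(
    'нейронные сети это хорошо',   # base ASR hypothesis (text1)
    'нейронные сети это хорошо.',  # additional hypothesis (text2)
)
suggestion_indices = list(range(len(alignment.matches)))  # candidate edits

scorer = SequenceScore()  # default rugpt3 LM
accepted = accept_suggestions_by_lm(alignment, suggestion_indices, scorer)

# apply only the LM-approved suggestions
final_text = alignment.substitute(replace=accepted)
```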
179 changes: 179 additions & 0 deletions asr/whisper_scores.py
@@ -0,0 +1,179 @@
from typing import Any

import torch
import numpy as np
from transformers.models.whisper.tokenization_whisper import bytes_to_unicode
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperTokenizerFast,
    WhisperForConditionalGeneration,
)

from .comparison import TokenizedText


def whisper_pipeline_transcribe_with_word_scores(
    waveform: np.ndarray,
    recognizer: AutomaticSpeechRecognitionPipeline,
) -> tuple[TokenizedText, list[list[str]], list[list[float]]]:
    """
    A wrapper around `whisper_transcribe_with_word_scores()` that extracts the
    feature extractor, tokenizer, model and generation parameters from a pipeline.
    Example:

    ```
    import librosa
    from asr.asr import initialize_model_for_speech_recognition
    waveform, _ = librosa.load('tests/testdata/test_sound_ru.wav', sr=None)
    pipeline = initialize_model_for_speech_recognition()
    whisper_pipeline_transcribe_with_word_scores(waveform, pipeline)
    ```
    """
    return whisper_transcribe_with_word_scores(
        waveform,
        recognizer.feature_extractor,
        recognizer.tokenizer,
        recognizer.model,
        recognizer._forward_params,  # language, task
    )


def whisper_transcribe_with_word_scores(
    waveform: np.ndarray,
    feature_extractor: WhisperFeatureExtractor,
    tokenizer: WhisperTokenizer | WhisperTokenizerFast,
    model: WhisperForConditionalGeneration,
    generate_kwargs: dict[str, Any],
) -> tuple[TokenizedText, list[list[str]], list[list[float]]]:
    """
    Transcribes the audio with Whisper and returns:
    - the resulting text tokenized into words
    - a list of tokens for each word
    - a list of token scores for each word

    Example:
    ```
    import librosa
    from transformers import pipeline
    waveform, _ = librosa.load('tests/testdata/test_sound_ru.wav', sr=None)
    recognizer = pipeline('automatic-speech-recognition', model='openai/whisper-large-v3')
    whisper_transcribe_with_word_scores(
        waveform,
        recognizer.feature_extractor,
        recognizer.tokenizer,
        recognizer.model,
        {'language': '<|ru|>', 'task': 'transcribe'},  # or `recognizer._forward_params`
    )

    >>> (
        TokenizedText(
            text=' нейронные сети это хорошо.',
            tokens=[
                Substring(start=1, stop=10, text='нейронные', is_punct=False),
                Substring(start=11, stop=15, text='сети', is_punct=False),
                Substring(start=16, stop=19, text='это', is_punct=False),
                Substring(start=20, stop=26, text='хорошо', is_punct=False),
                Substring(start=26, stop=27, text='.', is_punct=True)
            ]
        ),
        [[' ней', 'рон', 'ные'], [' с', 'ети'], [' это'], [' хорошо']],
        [[-0.61, -6.80e-05, -0.00], [-8.82e-05, -2.41e-05], [-0.57], [-0.00]]
    )
    ```
    """
    assert model.config.model_type == 'whisper'

    inputs = feature_extractor(
        waveform,
        return_tensors='pt',
        sampling_rate=16_000,
    ).to(model.device, model.dtype)
    result = model.generate(
        **inputs,
        **generate_kwargs,
        return_dict_in_generate=True,
        return_token_timestamps=True,
    )

    # convert token ids and logits to numpy
    token_ids = result['sequences'][0].cpu().numpy()
    logits = torch.nn.functional.log_softmax(torch.stack(result['scores']), dim=-1).cpu().numpy()

    # skip start special tokens to align with logits
    token_ids = token_ids[-len(logits):]

    # skip all special tokens
    is_special = np.array([id in tokenizer.all_special_ids for id in token_ids])
    token_ids = token_ids[~is_special]
    logits = logits[~is_special]

    # log-probability of each generated token (beam index 0)
    score_per_token = np.array([float(l[0, token_id]) for token_id, l in zip(token_ids, logits)])

    # reproducing whisper bpe decoding: each BPE token is a string over a
    # byte-to-unicode alphabet, so map every character back to its raw byte
    byte_decoder = {v: k for k, v in bytes_to_unicode().items()}
    bytes_list_per_token = [
        [byte_decoder[x] for x in bytes_str]
        for bytes_str in tokenizer.convert_ids_to_tokens(token_ids)
    ]

    # searching for token positions in the text
    token_end_positions = []
    for i in range(len(bytes_list_per_token)):
        concatenated_bytes = sum(bytes_list_per_token[:i + 1], [])
        try:
            text = bytearray(concatenated_bytes).decode('utf-8', errors='strict')
            token_end_positions.append(len(text))
        except UnicodeDecodeError:
            token_end_positions.append(None)  # not yet a complete utf-8 character

    assert text == tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)

    # cleaning up tokenization spaces, shifting token_end_positions
    # (see .clean_up_tokenization() in PreTrainedTokenizerBase)
    if tokenizer.clean_up_tokenization_spaces:
        for replace_from in [" .", " ?", " !", " ,", " ' ", " n't", " 'm", " 's", " 've", " 're"]:
            replace_to = replace_from.strip()
            while (start_pos := text.find(replace_from)) != -1:
                delta_len = len(replace_to) - len(replace_from)
                text = text[:start_pos] + replace_to + text[start_pos + len(replace_from):]
                token_end_positions = [
                    (
                        token_end_pos
                        # None (mid-character) positions are kept as-is
                        if token_end_pos is None or token_end_pos <= start_pos
                        else token_end_pos + delta_len
                    )
                    for token_end_pos in token_end_positions
                ]

    assert text == tokenizer.decode(token_ids)

    # tokenizing the text
    tokenized_text = TokenizedText.from_text(text)

    # matching words and tokens: a word starts at the first token whose end
    # position passes word.start and ends at the token whose end position
    # reaches word.stop; tokens with no utf-8 boundary (None) are skipped
    tokens_range_per_word = []
    for word in tokenized_text.get_words():
        first_token_idx = None  # first token of the word, inclusive
        for token_idx, token_end_pos in enumerate(token_end_positions):
            if token_end_pos is None:
                continue
            if token_end_pos > word.start and first_token_idx is None:
                first_token_idx = token_idx
            if token_end_pos >= word.stop:
                break
        tokens_range_per_word.append((first_token_idx, token_idx + 1))

    tokens_per_word = [
        [
            bytearray(b).decode('utf-8', errors='replace')
            for b in bytes_list_per_token[start_token_idx:end_token_idx]
        ]
        for start_token_idx, end_token_idx in tokens_range_per_word
    ]

    token_scores_per_word = [
        list(score_per_token[start_token_idx:end_token_idx])
        for start_token_idx, end_token_idx in tokens_range_per_word
    ]

    return tokenized_text, tokens_per_word, token_scores_per_word
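
A sketch of how the word-level scores could be consumed downstream, e.g. to flag low-confidence words for the LM-based resolution in `asr/lm.py`. The threshold below is an arbitrary assumption, not a value from this PR:

```
tokenized_text, tokens_per_word, token_scores_per_word = \
    whisper_pipeline_transcribe_with_word_scores(waveform, recognizer)

# one possible word-level confidence: the minimum token log-probability in the word
word_confidences = [min(scores) for scores in token_scores_per_word]

LOW_CONFIDENCE = -1.0  # arbitrary threshold; would need tuning on real data
for word, confidence in zip(tokenized_text.get_words(), word_confidences):
    if confidence < LOW_CONFIDENCE:
        print(f'low-confidence word: {word.text!r} ({confidence:.2f})')
```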