diff --git a/asr/asr.py b/asr/asr.py index 9c015ab..ced7a87 100644 --- a/asr/asr.py +++ b/asr/asr.py @@ -1,4 +1,5 @@ import copy +from dataclasses import dataclass import gc import logging from typing import List, Optional, Tuple, Union @@ -8,6 +9,7 @@ from tqdm import tqdm import torch from transformers import pipeline, Pipeline +from scipy.signal import resample_poly from utils.utils import time_to_str from wav_io.wav_io import TARGET_SAMPLING_FREQUENCY @@ -164,8 +166,7 @@ def initialize_model_for_speech_segmentation(language: str = 'ru', model_info: O - for language='en': 'jonatasgrosman/wav2vec2-large-xlsr-53-english'. Returned value: an AutomaticSpeechRecognitionPipeline, to be called on mono sound with rate 16_000 - and argument `return_timestamps='word'`. In Pisets, only the output timestamps are used, not the - transcribed speech. + and argument `return_timestamps='word'`. NOTE: the pipeline should have the ability to process long audios. To achieve this, the method calls the `transformers.pipeline` factory with arguments `chunk_length_s=10, stride_length_s=(4, 2)`. 
@@ -288,11 +289,27 @@ def initialize_model_for_speech_recognition(language: str = 'ru', model_info: Op else: model_name = 'openai/whisper-large-v3' try: + pipeline_kwargs = {} + if 'whisper' in model_name.lower(): + if language == 'ru': + pipeline_kwargs['generate_kwargs'] = { + 'language': '<|ru|>', + 'task': 'transcribe', + 'forced_decoder_ids': None + } + elif language == 'en': + pipeline_kwargs['generate_kwargs'] = { + 'language': '<|en|>', + 'task': 'transcribe', + 'forced_decoder_ids': None + } + if torch.cuda.is_available(): recognizer = pipeline( 'automatic-speech-recognition', model=model_name, chunk_length_s=20, stride_length_s=(4, 2), - device='cuda:0', model_kwargs={'attn_implementation': 'sdpa'}, torch_dtype=torch.float16 + device='cuda:0', model_kwargs={'attn_implementation': 'sdpa'}, torch_dtype=torch.float16, + **pipeline_kwargs ) else: recognizer = pipeline( @@ -306,11 +323,11 @@ def initialize_model_for_speech_recognition(language: str = 'ru', model_info: Op def select_word_groups( - words: List[Tuple[float, float]], + words: List[Tuple[float, float, str]], segment_size: float -) -> List[List[Tuple[float, float]]]: +) -> List[List[Tuple[float, float, str]]]: """ - Accepts a list of consecutive segments, each segment is a tuple (start_time, end_time). + Accepts a list of consecutive segments, each segment is a tuple (start_time, end_time, transcription). Iteratively splits the list of segments into "left" and "right" part by the largest pause between segments, then splits both parts the same way, and so on. 
def strip_segments(
        segments: List[Tuple[float, float, str]],
        max_sound_duration: float
) -> List[Tuple[float, float, str]]:
    """
    Clips tuples (start_time, end_time, transcription) between (0, max_sound_duration).

    Arguments:
    - segments: tuples (start_time, end_time, transcription).
    - max_sound_duration: upper bound applied to every `end_time`.

    Returned value: a new list of the same length and order. In every tuple
    `start_time` is raised to at least 0.0, `end_time` is lowered to at most
    `max_sound_duration`, and the transcription is passed through unchanged.
    """
    return [
        # only the two time bounds are clamped; the text travels along untouched
        (max(0.0, start), min(end, max_sound_duration), transcription)
        for start, end, transcription in segments
    ]
new_segments.pop(segment_idx) else: @@ -410,7 +431,8 @@ def join_short_segments_to_long_ones( if distance_to_right < min_segment_size: new_segments[segment_idx + 1] = ( segment_start, - new_segments[segment_idx + 1][1] + new_segments[segment_idx + 1][1], + text + ' ' + new_segments[segment_idx + 1][2] ) _ = new_segments.pop(segment_idx) else: @@ -420,7 +442,8 @@ def join_short_segments_to_long_ones( if distance_to_left < min_segment_size: new_segments[segment_idx - 1] = ( new_segments[segment_idx - 1][0], - segment_end + segment_end, + new_segments[segment_idx - 1][2] + ' ' + text ) _ = new_segments.pop(segment_idx) else: @@ -430,7 +453,8 @@ def join_short_segments_to_long_ones( if distance_to_right < min_segment_size: new_segments[segment_idx + 1] = ( segment_start, - new_segments[segment_idx + 1][1] + new_segments[segment_idx + 1][1], + text + ' ' + new_segments[segment_idx + 1][2] ) _ = new_segments.pop(segment_idx) else: @@ -449,18 +473,18 @@ def segment_sound( min_segment_size: float, max_segment_size: float, indent_for_silence: float = 0.5 -) -> List[Tuple[float, float]]: +) -> List[Tuple[float, float, str]]: """ Arguments: - mono_sound: 1D waveform with rate 16_000 (equals wav_io.TARGET_SAMPLING_FREQUENCY), possibly very long, and no shorter than asr.MIN_SOUND_LENGTH. - segmenter: an AutomaticSpeechRecognitionPipeline that can process long audios and - returns word timestamps. See `initialize_model_for_speech_segmentation` for details. + returns transcriptions and word timestamps. See `initialize_model_for_speech_segmentation` for details. - min_segment_size: see below - max_segment_size: see below - indent_for_silence: see below - Output: a list of tuples (start_time, end_time) for all found utterances, can be empty. + Output: a list of tuples (start_time, end_time, transcription) for all found utterances, can be empty. Performs the following actions: 1) Obtains speech segment boundaries by applying `segmenter` to `mono_sound`. 
@@ -506,23 +530,39 @@ def segment_sound( gc.collect() torch.cuda.empty_cache() - word_bounds = [(float(it['timestamp'][0]), float(it['timestamp'][1])) for it in output['chunks']] + word_bounds = [ + ( + float(it['timestamp'][0]), + float(it['timestamp'][1]), + str(it['text']) + ) + for it in output['chunks'] + ] if len(word_bounds) < 1: return [] if len(word_bounds) == 1: segment_start = word_bounds[0][0] - indent_for_silence segment_end = word_bounds[0][1] + indent_for_silence - return strip_segments([(segment_start, segment_end)], + full_transcription = word_bounds[0][2] + return strip_segments([(segment_start, segment_end, full_transcription)], mono_sound.shape[0] / TARGET_SAMPLING_FREQUENCY) if (word_bounds[-1][1] - word_bounds[0][0]) <= max_segment_size: segment_start = word_bounds[0][0] - indent_for_silence segment_end = word_bounds[-1][1] + indent_for_silence - return strip_segments([(segment_start, segment_end)], + full_transcription = ' '.join(text for _, _, text in word_bounds) + return strip_segments([(segment_start, segment_end, full_transcription)], mono_sound.shape[0] / TARGET_SAMPLING_FREQUENCY) word_groups = select_word_groups(word_bounds, max_segment_size) segments = strip_segments( - [(cur_group[0][0] - indent_for_silence, cur_group[-1][1] + indent_for_silence) for cur_group in word_groups], + [ + ( + cur_group[0][0] - indent_for_silence, + cur_group[-1][1] + indent_for_silence, + ' '.join(text for _, _, text in cur_group) + ) + for cur_group in word_groups + ], mono_sound.shape[0] / TARGET_SAMPLING_FREQUENCY ) n_segments = len(segments) @@ -531,8 +571,8 @@ def segment_sound( for idx in range(1, n_segments): if segments[idx - 1][1] > segments[idx][0]: overlap = segments[idx - 1][1] - segments[idx][0] - segments[idx - 1] = (segments[idx - 1][0], segments[idx - 1][1] - overlap / 2.0) - segments[idx] = (segments[idx][0] + overlap / 2.0, segments[idx][1]) + segments[idx - 1] = (segments[idx - 1][0], segments[idx - 1][1] - overlap / 2.0, segments[idx - 
@dataclass
class TranscribedSegment:
    """
    A transcribed segment. See `.transcribe()` function for details.

    `Optional[...]` is used instead of `X | None`: dataclass field annotations are
    evaluated when the class is created, and this module does not enable
    `from __future__ import annotations`.
    """
    # Segment bounds, as produced by the segmentation step
    # (presumably seconds from the start of the sound - TODO confirm).
    start: float
    end: float
    # Transcription from the `asr` pipeline.
    transcription: str
    # Transcription from the `segmenter` pipeline, if available.
    transcription_from_segmenter: Optional[str] = None
    # Transcription of the stretched segment from the `asr` pipeline;
    # stays None unless the `stretch` argument of `transcribe` is provided.
    transcription_stretched: Optional[str] = None
See `initialize_model_for_speech_classification` for details. - asr: an AutomaticSpeechRecognitionPipeline that can return transcriptions. See `initialize_model_for_speech_recognition` for details. - min_segment_size: a parameter for segment processing, see `segment_sound` for details. - max_segment_size: a parameter for segment processing, see `segment_sound` for details. + - stretch: if specified, stretches each segment in `stretch[1]/stretch[0]` times and perform an + additional speech recognition with `asr` pipeline. The results are returned in + `.transcription_stretched` field of `TranscribedSegment`. - Output: a list of tuples (start_time, end_time, transcription) for all found utterances, - can be empty. + Output: a list of `TranscribedSegment` for all found utterances, can be empty: + - `.transcription`: a transcription from `asr` Pipeline. + - `.transcription_from_segmenter`: a transcription from `segmenter` Pipeline. + - `.transcription_stretched`: a transcription of stretched segment from `asr` Pipeline + (if `stretch` agument is provided) Example: ``` + from wav_io.wav_io import load_sound + from asr.asr import * + waveform = load_sound('tests/testdata/mono_sound.wav') segmenter = initialize_model_for_speech_segmentation() - voice_activity_detector = initialize_model_for_speech_classification() - asr = initialize_model_for_speech_recognition('ru', 'openai/whisper-tiny') - transcribe(waveform, segmenter, vad, asr, min_segment_size=1, max_segment_size=5) + vad = initialize_model_for_speech_classification() + asr = initialize_model_for_speech_recognition('ru', 'openai/whisper-large-v3') + results = transcribe(waveform, segmenter, vad, asr, min_segment_size=1, max_segment_size=5, stretch=(3, 4)) + print(results) >>> [ - (0.0, 4.18, 'Она советовала нам отнести и спасену предмету к одному почтиному мужу.'), - (4.18, 6.8100000000000005, 'Большому другому и вану переселший годы.'), - (6.8100000000000005, 11.28, 'счастливые дни, как вешные воды, промчались 
они.') + TranscribedSegment( + start=0.0, + end=4.18, + transcription='она советовала нам отнестись посему предмету к одному почтенному мужу', + transcription_from_segmenter='Она советовала нам отнестись по всему предмету к одному почтенному мужу.', + transcription_stretched='Она советовала нам отнестись по всему предмету к одному почтенному мужу.' + ), + TranscribedSegment( + start=4.18, + end=6.8100000000000005, + transcription='бывшему другам ивану переселые годы', + transcription_from_segmenter='бывшему другом Ивану Петровичу.', + transcription_stretched='Бывшему другом Ивану Петровичу.' + ), + TranscribedSegment( + start=6.8100000000000005, + end=11.28, + transcription='счастливые дни как вешние воды промчались они', + transcription_from_segmenter='Счастливые дни, как вешние воды, промчались они.', + transcription_stretched='Счастливые дни, как вешние воды, промчались они.' + ) ] + + from asr.comparison import MultipleTextsAlignment, visualize_correction_suggestions + + for result in results: + suggestions = MultipleTextsAlignment.from_strings( + result.transcription, + result.transcription_stretched + ).get_correction_suggestions() + print(visualize_correction_suggestions(result.transcription, suggestions)) + + >>> она советовала нам отнестись {посему|по всему} предмету к одному почтенному мужу + бывшему {другам|другом} ивану {переселые годы|Петровичу} + счастливые дни как вешние воды промчались они ``` TODO when calling `voice_activity_detector` and `asr`, process all segments at once as @@ -685,14 +777,22 @@ def transcribe( return [] recognized_transcriptions = recognize_sounds( sounds=sounds_with_speech, - recognizer=asr + recognizer=asr, ) - del sounds_with_speech - results = list(filter( - lambda it2: len(it2[2]) > 0, - map( - lambda it: (it[0][0], it[0][1], it[1].strip()), - zip(segments_with_speech, recognized_transcriptions) + results = [ + TranscribedSegment(start, end, transcription.strip(), transcription_from_segmenter.strip()) + for (start, end, 
@dataclass
class TokenizedText:
    """
    Stores text and positions of tokens (words and punctuation marks).

    Tokenization is performed using Razdel (tested for Ru and En). A token
    is considered a punctuation mark if it does not contain letters or digits.

    Example:
    ```
    tokenized = TokenizedText.from_text('Это "тестовый" текст. !!')
    tokenized.tokens

    >>> [Substring(start=0, stop=3, text='это', is_punct=False),
         Substring(start=4, stop=5, text='"', is_punct=True),
         Substring(start=5, stop=13, text='тестовый', is_punct=False),
         Substring(start=13, stop=14, text='"', is_punct=True),
         Substring(start=15, stop=20, text='текст', is_punct=False),
         Substring(start=20, stop=21, text='.', is_punct=True),
         Substring(start=22, stop=24, text='!!', is_punct=True)]

    tokenized.get_words()

    >>> [Substring(start=0, stop=3, text='это', is_punct=False),
         Substring(start=5, stop=13, text='тестовый', is_punct=False),
         Substring(start=15, stop=20, text='текст', is_punct=False)]
    ```
    """
    text: str
    tokens: list[Substring]

    def get_words(self) -> list[Substring]:
        """
        Returns a list of words (skips punctuation marks).
        """
        return [token for token in self.tokens if not token.is_punct]

    @classmethod
    def from_text(cls, text: str, dash_as_separator: bool = True) -> TokenizedText:
        """
        Tokenizes `text` with Razdel and returns an instance of `cls`.

        The instance keeps the original `text`; token texts are lower-cased.
        If `dash_as_separator`, every '-' is replaced by a space before tokenization.
        The replacement is one character for one character, so token positions
        still index into the original text.
        """
        text_to_tokenize = text.replace('-', ' ') if dash_as_separator else text
        tokens = [
            Substring(
                start=t.start,
                stop=t.stop,
                text=t.text.lower(),
                # punctuation == no letter and no digit in the token
                is_punct=not any(c.isalnum() for c in t.text)
            )
            for t in razdel.tokenize(text_to_tokenize)
        ]
        # `cls` (not a hard-coded class name), so subclasses get their own type back
        return cls(text=text, tokens=tokens)

    @classmethod
    def concatenate(cls, texts: list[TokenizedText], sep: str = ' ') -> TokenizedText:
        """
        Joins several tokenized texts into one instance of `cls`, inserting `sep`
        between neighbouring texts.

        Tokens are shallow-copied and their positions are shifted so that they index
        into the joined text; the input objects are not modified. An empty `texts`
        list gives an empty text with no tokens.
        """
        parts = []
        result_tokens = []
        shift = 0
        for tokenized_text in texts:
            for token in tokenized_text.tokens:
                shifted = copy.copy(token)
                shifted.start += shift
                shifted.stop += shift
                result_tokens.append(shifted)
            parts.append(tokenized_text.text)
            # the next text starts after this text and one separator
            shift += len(tokenized_text.text) + len(sep)

        return cls(text=sep.join(parts), tokens=result_tokens)
Represents a matching + part between two lists: `list1[start1:end1]` matches `list2[start2:end2]` + + If self.len1 == self.len2, may be additionally be marked as equal or not equal + match (if not equal this Match represents a replacement operation). + + Use case: usually indices in Match are word indices (not character indices). + """ + start1: int + end1: int + start2: int + end2: int + is_equal: bool + + char_start1: int | None = None + char_end1: int | None = None + char_start2: int | None = None + char_end2: int | None = None + + def __post_init__(self): + assert self.len1 > 0 or self.len2 > 0 + if self.is_equal: + assert self.len1 == self.len2 + + @property + def len1(self) -> int: + return self.end1 - self.start1 + + @property + def len2(self) -> int: + return self.end2 - self.start2 + + @property + def is_replace(self) -> bool: + return self.len1 > 0 and self.len2 > 0 and not self.is_equal + + @property + def is_insert(self) -> bool: + return self.len1 == 0 + + @property + def is_delete(self) -> bool: + return self.len2 == 0 + + +@dataclass +class MultipleTextsAlignment: + """ + Stores text, divided into words, and a list of found matches between the words. + + In the following example, we have two texts: + ``` + text_1 = 'Aaaa aa, bb-bb' + text_2 = 'Aa bbbb cc cc!' + ``` + + We split them into words with `TokenizedText`, which uses Razdel library under the hood. + `TokenizedText` keeps a list of tokens, each token is either a lower-case word, or a + punctuation mark. 
+ ``` + tokenized_text_1 = TokenizedText.from_text(text_1) + tokenized_text_2 = TokenizedText.from_text(text_2) + print(tokenized_text_1.tokens, tokenized_text_2.tokens) + + >>> [ + Substring(start=0, stop=4, text='aaaa', is_punct=False), + Substring(start=5, stop=7, text='aa', is_punct=False), + Substring(start=7, stop=8, text=',', is_punct=True), + Substring(start=9, stop=14, text='bb-bb', is_punct=False) + ], [ + Substring(start=0, stop=2, text='aa', is_punct=False), + Substring(start=3, stop=7, text='bbbb', is_punct=False), + Substring(start=8, stop=10, text='cc', is_punct=False), + Substring(start=11, stop=13, text='cc', is_punct=False), + Substring(start=13, stop=14, text='!', is_punct=True) + ] + ``` + + We then match only words (a method `TokenizedText.get_words()`) in both texts: + ``` + word_matches=MultipleTextsAlighment.get_matches( + tokenized_text_1.get_words(), + tokenized_text_2.get_words() + ) + print(word_matches) + + >>> [ + WordLevelMatch(start1=0, end1=1, start2=0, end2=0, is_equal=False), + WordLevelMatch(start1=1, end1=2, start2=0, end2=1, is_equal=True), + WordLevelMatch(start1=2, end1=3, start2=1, end2=2, is_equal=False), + WordLevelMatch(start1=3, end1=3, start2=2, end2=4, is_equal=False) + ] + ``` + + For example, consider the last `WordLevelMatch`. It means that words [3:3] in + `tokenized_text_1.tokens` match words [2:4] in `tokenized_text_2.tokens`. Since + the first span is empty, this means that the last two words "cc" and "cc" in the + second text have no counterparts in the first text. 
As for the other matches: + + - The 1st match is a deletion (the word "aaaa" is present only in the first text) + - The 2nd match is an equality (the word "aa" is present in both texts) + - The 3rd match is a replacement (the word "bb-bb" is replaced by "bbbb") + - The 4th match is an insertion (the words "cc cc" are present only in the second text) + + Now we can construct `MultipleTextsAlighment`: + ``` + alignment = MultipleTextsAlignment(tokenized_text_1, tokenized_text_2, word_matches) + ``` + + Or we can get the same result from the original texts using `.from_strings()`: + ``` + alignment = MultipleTextsAlignment.from_strings(text_1, text_2) + ``` + + Now we can obtain the corrections that the second text suggests when compared with the + first text. Here the positions (`start_pos`, `end_pos`) are character positions in the + original `text_1`. + ``` + suggestions = alignment.get_correction_suggestions() + print(suggestions) + + >>> [ + CorrectionSuggestion(start_pos=0, end_pos=4, suggestion=''), + CorrectionSuggestion(start_pos=9, end_pos=14, suggestion='bbbb'), + CorrectionSuggestion(start_pos=14, end_pos=14, suggestion=' cc cc') + ] + ``` + + We can visualize them in brackets, so that we can see all the matches: the deletion, + the equality, the replacement and the insertion: + ``` + print(visualize_correction_suggestions(text_1, suggestions)) + + >>> '{Aaaa} aa, {bb-bb|bbbb} {+cc cc}' + ``` + + NOTE: while this class keeps a list of `WordLevelMatch`, and each match `m` may be one of + `m.is_equal`, `m.is_delete`, `m.is_insert` or `m.is_replace`, they do not directly correspond + one-to-one to "delete", "insert" and "replace" operations in Word Error Rate (WER) metric. + Example: + + ``` + print(MultipleTextsAlignment.from_strings('a b c', 'd e').matches) + + >>> [WordLevelMatch(start1=0, end1=3, start2=0, end2=2, is_equal=False)] + ``` + + We can see a single "replace" operation from 3 words to 2 words. 
However, in WER metric this + will be considered as two "replace" and one "delete" operation. To calculate WER correctly, + use `.wer` method. + """ + text1: TokenizedText + text2: TokenizedText + matches: list[WordLevelMatch] + + @classmethod + def from_strings( + cls, + text1: str | TokenizedText, + text2: str | TokenizedText, + ) -> MultipleTextsAlignment: + if isinstance(text1, str): + text1 = TokenizedText.from_text(text1) + if isinstance(text2, str): + text2 = TokenizedText.from_text(text2) + return MultipleTextsAlignment( + text1=text1, + text2=text2, + matches=MultipleTextsAlignment.get_matches( + text1.get_words(), + text2.get_words(), + ) + ) + + def get_uncertainty_mask(self, match_indices: list[int] | None = None) -> np.ndarray: + is_certain = np.full(len(self.text1.get_words()), False) + for i, match in enumerate(self.matches): + if match_indices is not None and i not in match_indices: + is_certain[match.start1:match.end1] = True + else: + is_certain[match.start1:match.end1] = match.is_equal + return ~is_certain + + def wer( + self, + max_insertions: int | None = 4, + uncertainty_mask: np.ndarray = None, + ) -> dict: + """ + Calculates WER. `max_insertions` allows to make WER more robust by not penalizing + too much insertions in a row (usually an oscillatory hallucinations of ASR model). 
+ + TODO switch to n unique insertions + """ + _max_insertions = float('inf') if max_insertions is None else max_insertions + + words1 = self.text1.get_words() + words2 = self.text2.get_words() + + n_equal = sum([m.len1 for m in self.matches if m.is_equal]) + n_deletions = sum([m.len1 for m in self.matches if m.is_delete]) + n_insertions = sum([min(m.len2, _max_insertions) for m in self.matches if m.is_insert]) + n_replacements = 0 + + # replace operations contrubute to n_deletions and n_insertions if len1 != len2 + for match in self.matches: + if match.is_replace: + if match.len1 > match.len2: + n_replacements += match.len2 + n_deletions += match.len1 - match.len2 + elif match.len1 < match.len2: + n_replacements += match.len1 + n_insertions += min(match.len2 - match.len1, _max_insertions) + else: + n_replacements += match.len1 + + assert n_equal + n_deletions + n_replacements == len(words1) + if max_insertions is None: + assert n_equal + n_insertions + n_replacements == len(words2) + + results = {'wer': (n_deletions + n_insertions + n_replacements) / len(words1)} + + if uncertainty_mask is not None: + assert len(uncertainty_mask) == len(words2) + uncertainty_mask = uncertainty_mask.astype(bool) + + certain_n_correct = 0 + certain_n_incorrect = 0 + uncertain_n_correct = 0 + uncertain_n_incorrect = 0 + + for match in self.matches: + mask = uncertainty_mask[match.start2:match.end2] + if match.is_equal: + uncertain_n_correct += mask.sum() + certain_n_correct += (~mask).sum() + elif (match.is_insert or match.is_replace): + uncertain_n_incorrect += mask.sum() + certain_n_incorrect += (~mask).sum() + + if uncertainty_mask is not None: + results['certain_n_correct'] = certain_n_correct + results['certain_n_incorrect'] = certain_n_incorrect + results['uncertain_n_correct'] = uncertain_n_correct + results['uncertain_n_incorrect'] = uncertain_n_incorrect + results['certain_accuracy'] = ( + certain_n_correct / (certain_n_correct + certain_n_incorrect) + ) + 
results['uncertain_accuracy'] = ( + uncertain_n_correct / (uncertain_n_correct + uncertain_n_incorrect) + ) + results['precision'] = ( + results['uncertain_n_incorrect'] + / (results['uncertain_n_incorrect'] + results['uncertain_n_correct']) + ) + results['recall'] = ( + results['uncertain_n_incorrect'] + / (results['uncertain_n_incorrect'] + results['certain_n_incorrect']) + ) + results['uncertainty_ratio'] = uncertainty_mask.mean() + results['report'] = ( + f'uncertainty_ratio {results["uncertainty_ratio"]:.3f}' + f', certain acc. {results["certain_accuracy"]:.3f}' + f', uncertain acc. {results["uncertain_accuracy"]:.3f}' + f', precision {results["precision"]:.3f}' + f', recall {results["recall"]:.3f}' + ) + + return results + + @staticmethod + def get_matches( + words1: list[Substring], + words2: list[Substring], + diff_only: bool = False, + improved_matching: bool = True, + ) -> list[WordLevelMatch]: + """ + Finds matching words (excluding punctuation) in two word lists. If `diff_only`, + returns only non-equal matches: deletions, additions or changes. + + With `improved_matching=True`, performs postprocessing after `difflib.SequenceMatcher` + to split of join some matches. 
+ """ + # get operations (delete, insert, replace, equal) + difflib_opcodes: list[tuple[str, int, int, int, int]] = difflib.SequenceMatcher( + None, + [t.text for t in words1], + [t.text for t in words2], + autojunk=False + ).get_opcodes() + + ops: list[WordLevelMatch] = [ + WordLevelMatch(start1, end1, start2, end2, is_equal=(op == 'equal')) + for op, start1, end1, start2, end2 in difflib_opcodes + ] + + # now we have a list of Match-es between words1 and words2 + + if improved_matching: + for _ in range(10): + # improvements over plain SequenceMatcher + ops, was_change1 = MultipleTextsAlignment._maybe_split_replace_ops(words1, words2, ops) + ops, was_change2 = MultipleTextsAlignment._maybe_join_subsequent_ops(words1, words2, ops) + + if not was_change1 and not was_change2: + break + + if diff_only: + # consider only non-equal matches + ops = [op for op in ops if not op.is_equal] + + # set character positions for each WordLevelMatch + for op in ops: + if op.start1 != op.end1: + op.char_start1 = words1[op.start1].start + op.char_end1 = words1[op.end1 - 1].stop + else: + if op.end1 > 0: + op.char_start1 = op.char_end1 = words1[op.end1 - 1].stop + else: + op.char_start1 = op.char_end1 = words1[op.end1].start + + if op.start2 != op.end2: + op.char_start2 = words2[op.start2].start + op.char_end2 = words2[op.end2 - 1].stop + else: + if op.end2 > 0: + op.char_start2 = op.char_end2 = words2[op.end2 - 1].stop + else: + op.char_start2 = op.char_end2 = words2[op.end2].start + + return ops + + @staticmethod + def _string_match_score(word1: str, word2: str) -> float: + """ + How similar are two strings (character-wise)? 
@staticmethod
def _maybe_split_replace_ops(
    words1: list[Substring],
    words2: list[Substring],
    ops: list[WordLevelMatch],
) -> tuple[list[WordLevelMatch], bool]:
    """
    Tries to peel a single word pair off each "replace" op, for example
    replace('aaaa bbb ccc', 'aaa') -> replace('aaaa', 'aaa') + delete('bbb ccc')

    For every replace op, the first words of both spans are compared; if their
    character-wise similarity is above 0.5, that pair becomes its own op and the
    rest (if any) becomes a second op. Otherwise the same is tried with the last
    words of both spans. Ops that are not replacements, and replacements where
    neither pair is similar enough, are passed through as is.

    Returns
    - a new ops list
    - flag that is True if the new list differs from `ops`
    """
    def is_similar(idx1: int, idx2: int) -> bool:
        score = MultipleTextsAlignment._string_match_score(words1[idx1].text, words2[idx2].text)
        return score > 0.5

    result: list[WordLevelMatch] = []
    for op in ops:
        if not op.is_replace:
            result.append(op)
            continue

        lo1, hi1, lo2, hi2 = op.start1, op.end1, op.start2, op.end2
        last1, last2 = hi1 - 1, hi2 - 1

        if is_similar(lo1, lo2):
            # split off the leading pair; keep the tail only if it is non-empty
            result.append(WordLevelMatch(lo1, lo1 + 1, lo2, lo2 + 1, is_equal=False))
            if hi1 > lo1 + 1 or hi2 > lo2 + 1:
                result.append(WordLevelMatch(lo1 + 1, hi1, lo2 + 1, hi2, is_equal=False))
        elif is_similar(last1, last2):
            # split off the trailing pair; keep the head only if it is non-empty
            if last1 > lo1 or last2 > lo2:
                result.append(WordLevelMatch(lo1, last1, lo2, last2, is_equal=False))
            result.append(WordLevelMatch(last1, hi1, last2, hi2, is_equal=False))
        else:
            result.append(op)

    was_change = ops != result
    return result, was_change
next_op.start1 or op.end2 != next_op.start2: + # ops are not close to each other + new_ops.append(op) + i += 1 + continue + if op.is_equal and next_op.is_equal: + # we usually shouldn't have two `.is_equal` ops in a row, but just in case + new_ops.append(op) + i += 1 + continue + op_words1 = ' '.join(x.text for x in words1[op.start1:op.end1]) + op_words2 = ' '.join(x.text for x in words2[op.start2:op.end2]) + next_op_words1 = ' '.join(x.text for x in words1[next_op.start1:next_op.end1]) + next_op_words2 = ' '.join(x.text for x in words2[next_op.start2:next_op.end2]) + + match_score = MultipleTextsAlignment._string_match_score(op_words1, op_words2) + next_match_score = MultipleTextsAlignment._string_match_score(next_op_words1, next_op_words2) + joint_match_score = MultipleTextsAlignment._string_match_score( + (op_words1 + ' ' + next_op_words1).strip(), + (op_words2 + ' ' + next_op_words2).strip() + ) + + if joint_match_score > max(match_score, next_match_score): + # merging ops + new_ops.append(WordLevelMatch(op.start1, next_op.end1, op.start2, next_op.end2, is_equal=False)) + i += 2 # skipping the next op, since we've already merged it + else: + new_ops.append(op) + i += 1 + + return new_ops, (ops != new_ops) + + def substitute( + self, + replace: Iterable[int] | None = None, + show_in_braces: Iterable[int] | None = None, + pref_first: Iterable[int] | None = None, + pref_second: Iterable[int] | None = None, + ) -> str: + """ + This function can insert fragments from the second text to the first text, + based on matches. + + Explanation. Let we have a `MultipleTextsAlignment` with a single non-equal match + (difference): + + ``` + text1 = "aa bb! cc!" 
def substitute(
    self,
    replace: Iterable[int] | None = None,
    show_in_braces: Iterable[int] | None = None,
    pref_first: Iterable[int] | None = None,
    pref_second: Iterable[int] | None = None,
) -> str:
    """
    This function can insert fragments from the second text to the first text,
    based on matches.

    Explanation. Let we have a `MultipleTextsAlignment` with a single non-equal match
    (difference):

    ```
    text1 = "aa bb! cc!"
    text2 = "aa bbb cc"
    al = MultipleTextsAlignment.from_strings(text1, text2)
    print([m for m in al.matches if not m.is_equal])
    >>> [WordLevelMatch(start1=1, end1=2, start2=1, end2=2, is_equal=False,
                        char_start1=3, char_end1=5, char_start2=3, char_end2=6)]
    ```

    The difference `m = al.matches[1]` corresponds to a substring in both texts:
    1) A segment in the 1st text: `al.text1.text[m.char_start1:m.char_end1] == 'bb'`
    2) A segment in the 2nd text: `al.text2.text[m.char_start2:m.char_end2] == 'bbb'`.

    Based on this, we can cut out the segment from the 1st text, and replace it
    with the segment from the 2nd text. This is exactly what the `substitute` method does.
    The `replace` argument is a list of all differences to apply.

    ```
    print(al.substitute(replace=[1]))
    >>> "aa bbb! cc!"
    ```

    The `show_in_braces` is also a list of differences. It does not replace text parts, but
    visualizes both variants in {braces}.
    - {aaa|bbb} - suggest to replace aaa to bbb
    - {aaa} - suggest to remove aaa
    - {+aaa} - suggest to insert aaa (not present in `text1`)

    ```
    text1 = 'она советовала нам отнестись посему предмету к одному почтенному мужу'
    text2 = 'Она советовала нам отнести и спасену предмету к одному почтиному мужу.'
    al = MultipleTextsAlignment.from_strings(text1, text2)
    al.substitute(show_in_braces=range(len(al.matches)))
    >>> 'она советовала нам {отнестись|отнести} {+и} {посему|спасену} предмету к одному {почтенному|почтиному} мужу'
    ```

    Arguments (all are indices into `self.matches`; indices of equal matches are ignored):
    - replace: differences to apply, i.e. the `text2` segment is written instead of the
      `text1` segment. Takes precedence over `show_in_braces`.
    - show_in_braces: differences to display in {braces} as described above.
    - pref_first: among the differences shown in braces, those whose `text1` variant
      gets a '!' prefix inside the braces.
    - pref_second: same as `pref_first`, but the '!' prefix is put on the `text2` variant.

    Differences that are in neither `replace` nor `show_in_braces` keep the `text1` segment.

    Returns the resulting text.
    """
    text1 = self.text1.text
    text2 = self.text2.text

    # sets give O(1) membership tests; this method may be called with long index lists
    replace_set = set(replace) if replace is not None else set()
    braces_set = set(show_in_braces) if show_in_braces is not None else set()
    pref_first_set = set(pref_first) if pref_first is not None else set()
    pref_second_set = set(pref_second) if pref_second is not None else set()
    # NOTE: an index present in both `pref_first` and `pref_second` is not rejected,
    # both variants simply get the '!' mark.

    parts: list[str] = []
    text1_idx = 0  # position in text1 up to which the text has already been emitted

    for op_idx, op in enumerate(self.matches):
        if op.is_equal:
            continue

        # unchanged text between the previous difference and this one
        parts.append(text1[text1_idx:op.char_start1])

        segment1 = text1[op.char_start1:op.char_end1]
        segment2 = text2[op.char_start2:op.char_end2]

        if op_idx in replace_set:
            fragment = segment2

        elif op_idx in braces_set:
            # the kind of suggestion is decided before any '!' mark is added
            is_insertion = len(segment1) == 0
            is_deletion = len(segment2) == 0

            if op_idx in pref_first_set:
                segment1 = '!' + segment1
            if op_idx in pref_second_set:
                segment2 = '!' + segment2

            if is_insertion:
                fragment = '{+' + segment2.strip() + '}'
                if op.char_start1 >= len(text1):
                    # insertion after the last character of text1: there is no next
                    # character to inspect (indexing it would raise IndexError)
                    if text1 and not text1.endswith(' '):
                        fragment = ' ' + fragment
                elif text1[op.char_start1] == ' ':
                    fragment = ' ' + fragment
                else:
                    fragment = fragment + ' '
            elif is_deletion:
                fragment = '{' + segment1 + '}'
            else:
                fragment = '{' + segment1 + '|' + segment2 + '}'

        else:
            fragment = segment1

        parts.append(fragment)
        text1_idx = op.char_end1

    parts.append(text1[text1_idx:])

    return ''.join(parts)
+ segment2 + + if formatting == 'add': + fragment = '{+' + segment2.strip() + '}' + if text1[op.char_start1] == ' ': + fragment = ' ' + fragment + else: + fragment = fragment + ' ' + elif formatting == 'remove': + fragment = '{' + segment1 + '}' + else: + fragment = '{' + segment1 + '|' + segment2 + '}' + + else: + fragment = segment1 + + result += fragment + text1_idx = op.char_end1 + + result += text1[text1_idx:] + + return result + + +def _is_junk_word(word: str) -> bool: + return word in ['вот', 'ага', 'и', 'а', 'ну', 'это'] + +def _is_junk_word_sequence(text: str) -> bool: + return text in ['то есть', 'да то есть', 'это самое'] + +def _lemmatize(text: str) -> str: + return ''.join(Mystem().lemmatize(text)).strip() # here we need to join with '', not ' ' + +def _should_keep( + alignment: MultipleTextsAlignment, + diff: WordLevelMatch, + skip_word_form_change: bool, +) -> bool: + """ + A single diff variant of .filter_correction_suggestions(). + """ + words1: list[str] = [w.text for w in alignment.text1.get_words()[diff.start1:diff.end1]] + words2: list[str] = [w.text for w in alignment.text2.get_words()[diff.start2:diff.end2]] + + joined1 = ' '.join(words1).lower().replace('ё', 'е') + joined2 = ' '.join(words2).lower().replace('ё', 'е') + + if all([_is_junk_word(w) for w in words1]) and all([_is_junk_word(w) for w in words2]): + # insertion, replacement or deletion of junk words + return False + + if ( + (len(joined1) == 0 or _is_junk_word_sequence(joined1)) + and (len(joined2) == 0 or _is_junk_word_sequence(joined2)) + ): + # insertion, replacement or deletion of junk words + return False + + if diff.is_replace: + if joined1 == joined2: + # the same text + return False + if skip_word_form_change and _lemmatize(joined1) == _lemmatize(joined2): + # different forms of the same words, skip according to `skip_word_form_change=True` + return False + + ru_letters = set('абвгдеёжзийклмнопрстуфхцчшщъыьэюя') + has_ru1 = ru_letters & set(joined1) != set() + has_ru2 = 
def filter_correction_suggestions(
    alignment: MultipleTextsAlignment,
    skip_word_form_change: bool = False,
    pbar: bool = True,
) -> list[int]:
    """
    Select the differences of an alignment that are worth suggesting as corrections.

    Arguments:
    - alignment: a `MultipleTextsAlignment` between base speech recognition predictions and
      additional predictions from another model.
    - skip_word_form_change: whether to skip differences that only change the word form.
    - pbar: whether to show a progress bar.

    Outputs:
    - Indices (into `alignment.matches`) of all non-equal matches that pass several
      heuristics. These are treated as suggestions to replace, delete or insert something
      in the `text1`, based on the difference between words in both texts. Punctuation is
      not compared, since `MultipleTextsAlignment` ignores punctuation.

    NOTE: currently is adapted for Ru language
    """
    kept_indices: list[int] = []
    progress = tqdm(alignment.matches, desc='Filtering suggestions', disable=not pbar)
    for index, op in enumerate(progress):
        if op.is_equal:
            continue
        if _should_keep(
            alignment=alignment,
            diff=op,
            skip_word_form_change=skip_word_form_change,
        ):
            kept_indices.append(index)
    return kept_indices
class SequenceScore:
    """
    Calculates a sequence score for a text from an autoregressive LM.

    The score is the negated cross-entropy of the text tokens given the preceding
    tokens, averaged over the predicted tokens (the default `mean` reduction of
    `F.cross_entropy`). Higher (closer to zero) means the LM finds the text more likely.

    Either pass `name` (the tokenizer and the model are then loaded with
    `from_pretrained`), or pass `name=None` together with both `tokenizer` and `model`.
    """
    def __init__(
        self,
        name: str | None = 'ai-forever/rugpt3large_based_on_gpt2',
        tokenizer: PreTrainedTokenizerBase | None = None,
        model: GenerationMixin | None = None,
    ):
        if name is not None:
            assert tokenizer is None and model is None, \
                'pass either `name` or (`tokenizer` and `model`), not both'
            # a BOS token is needed so that the first real token is also predicted:
            # https://stackoverflow.com/a/75242984
            tokenizer = AutoTokenizer.from_pretrained(name, add_bos_token=True)
            model = AutoModelForCausalLM.from_pretrained(name)
        else:
            assert tokenizer is not None and model is not None, \
                'with `name=None` both `tokenizer` and `model` are required'

        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()

    def __call__(self, text: str) -> float:
        """
        Return the score of `text` (a non-positive float, higher is better).

        Returns 0.0 when the text has fewer than two tokens: there is nothing to
        predict, and the mean cross-entropy over zero tokens would be NaN.
        """
        inputs = self.tokenizer([text], return_tensors='pt')

        # keep the inputs on the same device as the model, if the model exposes one
        device = getattr(self.model, 'device', None)
        if device is not None:
            inputs = {
                key: (value.to(device) if hasattr(value, 'to') else value)
                for key, value in inputs.items()
            }

        # token t is predicted from tokens < t, so the first token has no prediction
        targets = inputs['input_ids'][:, 1:]
        if targets.numel() == 0:
            return 0.0

        with torch.no_grad():
            logits = self.model(**inputs, return_dict=True).logits[:, :-1]
            logloss = F.cross_entropy(input=logits.transpose(1, 2), target=targets)

        score = -float(logloss)

        if np.isnan(score):
            # defensive fallback, e.g. for numerical problems inside the model
            return 0.0

        return score


def get_all_subsets(elements: list[Any]) -> list[list[Any]]:
    """
    Returns all subsets of a list, each subset as a list, ordered by size and then
    by the order of `itertools.combinations`. Note that there are `2 ** len(elements)`
    subsets.
    ```
    get_all_subsets([1, 2, 3])
    >>> [[], [1], [2], [3], [1, 2], [1, 3], [2, 3], [1, 2, 3]]
    ```
    """
    return [
        list(subset)
        for size in range(len(elements) + 1)
        for subset in combinations(elements, size)
    ]
def accept_suggestions_by_lm(
    base_vs_additional: MultipleTextsAlignment,
    suggestion_indices: list[int],
    scorer: SequenceScore,
    look_forward: int = 2,
    context_before: int = 100,
    context_after: int = 50,
    pbar: bool = True,
    verbose: bool = False,
) -> list[int]:
    """
    When two predictions disagree, selects one that LM prefers. Returns suggestion_indices
    when the second prediction (`base_vs_additional.text2`) was selected.

    The suggestions are resolved one by one, in the given order. To decide on the
    current suggestion, it is considered together with the next ones (`look_forward`
    suggestions in total). Every subset of this window is applied on top of the
    already accepted suggestions with `base_vs_additional.substitute(replace=...)`,
    the resulting text is cut to the window plus `context_before` characters before
    and `context_after` characters after it, and the cut is scored with `scorer`.
    The current suggestion is accepted if it belongs to the best-scoring subset;
    the decision about the other suggestions of the window is made later, when
    they become the current one.

    Arguments:
    - base_vs_additional: alignment between the base and the additional predictions.
    - suggestion_indices: indices into `base_vs_additional.matches` to decide on.
    - scorer: callable returning a score for a text, higher is better.
    - look_forward: size of the window of suggestions considered together.
    - context_before, context_after: number of characters of context around the window.
    - pbar: whether to show a progress bar.
    - verbose: whether to print the scores of every decision.
    """
    alignment = base_vs_additional
    base_text_len = len(alignment.text1.text)

    remaining = suggestion_indices.copy()
    accepted = []

    def score_option(window, option):
        # text1 with the accepted suggestions and the candidate subset applied
        candidate = alignment.substitute(replace=accepted + option)

        left = alignment.matches[window[0]].char_start1
        # the right border is shifted by the total change of the text length
        right = alignment.matches[window[-1]].char_end1 + len(candidate) - base_text_len

        left = np.clip(left - context_before, 0, len(candidate))
        right = np.clip(right + context_after, 0, len(candidate))

        return scorer(candidate[left:right])

    if pbar:
        _pbar = tqdm(total=len(remaining))

    while len(remaining):
        window = remaining[:look_forward]

        scores = {}
        for option in get_all_subsets(window):
            scores[tuple(option)] = {
                'score': score_option(window, option),
            }

        # on ties the first option in insertion order wins
        winner = max(scores, key=lambda key: scores[key]['score'])

        if verbose:
            print(f'[{len(remaining)}] selected {winner} from {scores}')

        current = window[0]
        if current in winner:
            accepted.append(current)

        remaining = remaining[1:]

        if pbar:
            _pbar.update(1)

    if pbar:
        _pbar.close()

    return accepted
def whisper_pipeline_transcribe_with_word_scores(
    waveform: np.ndarray,
    recognizer: AutomaticSpeechRecognitionPipeline,
) -> tuple[TokenizedText, list[list[str]], list[list[float]]]:
    """
    Pipeline-based shortcut for `.whisper_transcribe_with_word_scores()`.

    The feature extractor, tokenizer and model are taken from `recognizer`, and the
    generation arguments (language, task) from `recognizer._forward_params`.
    The waveform should be sampled at 16 kHz, since the wrapped function passes
    `sampling_rate=16_000` to the feature extractor.

    Example:

    ```
    import librosa
    from asr.asr import initialize_model_for_speech_recognition
    waveform, _ = librosa.load('tests/testdata/test_sound_ru.wav', sr=16_000)
    pipeline = initialize_model_for_speech_recognition()
    whisper_pipeline_transcribe_with_word_scores(waveform, pipeline)
    ```
    """
    generate_kwargs = recognizer._forward_params  # lang, task
    return whisper_transcribe_with_word_scores(
        waveform=waveform,
        feature_extractor=recognizer.feature_extractor,
        tokenizer=recognizer.tokenizer,
        model=recognizer.model,
        generate_kwargs=generate_kwargs,
    )
def whisper_transcribe_with_word_scores(
    waveform: np.ndarray,
    feature_extractor: WhisperFeatureExtractor,
    tokenizer: WhisperTokenizer | WhisperTokenizerFast,
    model: WhisperForConditionalGeneration,
    generate_kwargs: dict[str, Any],
) -> tuple[TokenizedText, list[list[str]], list[list[float]]]:
    """
    Transcribes the audio with Whisper and returns:
    - the resulting text tokenized into words
    - a list of tokens for each word
    - a list of token scores (log-probabilities of the generated tokens) for each word

    The waveform is passed to the feature extractor with `sampling_rate=16_000`,
    so it should be sampled at 16 kHz.
    NOTE(review): a single `model.generate` call is made, so presumably only audio
    that fits into one Whisper window is fully transcribed - confirm.

    Example:
    ```
    import librosa
    waveform, _ = librosa.load('tests/testdata/test_sound_ru.wav', sr=16_000)
    recognizer = pipeline('automatic-speech-recognition', model='openai/whisper-large-v3')
    whisper_transcribe_with_word_scores(
        waveform,
        recognizer.feature_extractor,
        recognizer.tokenizer,
        recognizer.model,
        {'language': '<|ru|>', 'task': 'transcribe'},  # or `recognizer._forward_params`
    )

    >>> (
        TokenizedText(
            text=' нейронные сети это хорошо.',
            tokens=[
                Substring(start=1, stop=10, text='нейронные', is_punct=False),
                Substring(start=11, stop=15, text='сети', is_punct=False),
                Substring(start=16, stop=19, text='это', is_punct=False),
                Substring(start=20, stop=26, text='хорошо', is_punct=False),
                Substring(start=26, stop=27, text='.', is_punct=True)
            ]
        ),
        [[' ней', 'рон', 'ные'], [' с', 'ети'], [' это'], [' хорошо']],
        [[-0.61, -6.80e-05, -0.00], [-8.82e-05, -2.41e-05], [-0.57], [-0.00]]
    )
    ```
    """
    assert model.config.model_type == 'whisper'

    inputs = feature_extractor(
        waveform,
        return_tensors='pt',
        sampling_rate=16_000,
    ).to(model.device, model.dtype)
    # NOTE: token timestamps are requested here but only 'sequences' and 'scores'
    # of the result are used below
    result = model.generate(
        **inputs,
        **generate_kwargs,
        return_dict_in_generate=True,
        return_token_timestamps=True,
    )

    # convert token ids and logits to numpy
    token_ids = result['sequences'][0].cpu().numpy()
    logits = torch.nn.functional.log_softmax(torch.stack(result['scores']), dim=-1).cpu().numpy()

    # skip start special tokens to align with logits
    token_ids = token_ids[-len(logits):]

    # skip all special tokens
    # (explicit bool dtype: an empty mask would otherwise be a float array, and `~` would fail)
    special_ids = set(tokenizer.all_special_ids)
    is_special = np.array([int(token_id) in special_ids for token_id in token_ids], dtype=bool)
    token_ids = token_ids[~is_special]
    logits = logits[~is_special]

    # log-probability of every generated token (index 0 selects the first sequence)
    score_per_token = np.array([
        float(step_logits[0, token_id]) for token_id, step_logits in zip(token_ids, logits)
    ])

    # reproducing whisper bpe decoding
    byte_decoder = {v: k for k, v in bytes_to_unicode().items()}
    bytes_list_per_token = [
        [byte_decoder[char] for char in token_str]
        for token_str in tokenizer.convert_ids_to_tokens(token_ids)
    ]

    # searching for token positions in the text: for every token, the length of the
    # text decoded from all tokens up to and including it, or None if the token ends
    # in the middle of a multi-byte utf-8 character
    text = ''  # stays empty if there are no tokens at all
    token_end_positions: list[int | None] = []
    prefix_bytes = bytearray()
    for token_bytes in bytes_list_per_token:
        prefix_bytes.extend(token_bytes)
        try:
            text = prefix_bytes.decode('utf-8', errors='strict')
            token_end_positions.append(len(text))
        except UnicodeDecodeError:
            token_end_positions.append(None)  # not a full utf-8 character

    # If the very last token ends mid-character, `text` above is stale; decode the whole
    # byte string leniently. When the last strict decode succeeded this is a no-op.
    text = prefix_bytes.decode('utf-8', errors='replace')

    assert text == tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)

    # cleaning up tokenization spaces, shifting token_end_positions
    # (see .clean_up_tokenization() in PreTrainedTokenizerBase)
    if tokenizer.clean_up_tokenization_spaces:
        for replace_from in [" .", " ?", " !", " ,", " ' ", " n't", " 'm", " 's", " 've", " 're"]:
            replace_to = replace_from.strip()
            delta_len = len(replace_to) - len(replace_from)
            while (start_pos := text.find(replace_from)) != -1:
                text = text[:start_pos] + replace_to + text[start_pos + len(replace_from):]
                token_end_positions = [
                    (
                        token_end_pos
                        # None marks a token without a position, it must stay None
                        if token_end_pos is None or token_end_pos <= start_pos
                        else token_end_pos + delta_len
                    )
                    for token_end_pos in token_end_positions
                ]

    assert text == tokenizer.decode(token_ids)

    # tokenizing the text
    tokenized_text = TokenizedText.from_text(text)

    # matching words and tokens
    tokens_range_per_word = []
    for word in tokenized_text.get_words():
        first_token_idx = None  # first token of the word, inclusive
        for token_idx, token_end_pos in enumerate(token_end_positions):
            if token_end_pos is None:
                continue
            if token_end_pos > word.start and first_token_idx is None:
                first_token_idx = token_idx
            if token_end_pos >= word.stop:
                break
        # `token_idx` is the token on which the word ends (or the last token)
        tokens_range_per_word.append((first_token_idx, token_idx + 1))

    tokens_per_word = [
        [
            bytearray(token_bytes).decode('utf-8', errors='replace')
            for token_bytes in bytes_list_per_token[start_token_idx:end_token_idx]
        ]
        for start_token_idx, end_token_idx in tokens_range_per_word
    ]

    token_scores_per_word = [
        score_per_token[start_token_idx:end_token_idx].tolist()
        for start_token_idx, end_token_idx in tokens_range_per_word
    ]

    return tokenized_text, tokens_per_word, token_scores_per_word
asr.comparison import MultipleTextsAlignment, filter_correction_suggestions, TokenizedText, Substring\n", + "from asr.lm import SequenceScore, accept_suggestions_by_lm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading results from disk" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = load_dataset('dangrebenkin/long_audio_youtube_lectures')['train']\n", + "name_to_transcription = dict(zip(dataset['name'], dataset['transcription']))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/77 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
audio_namepipeline_namealignmentwerscores_per_word
0kolodezevBaseline Whisper longformMultipleTextsAlignment(text1=TokenizedText(tex...0.161696NaN
1zhirinovskyPisets WhisperV3 no-VAD (segments 1s-20s)MultipleTextsAlignment(text1=TokenizedText(tex...0.052458NaN
2zhirinovskyPisets WhisperV3 no-VAD stretched (segments 1s...MultipleTextsAlignment(text1=TokenizedText(tex...0.064849NaN
3lankovPisets WhisperV3 no-VAD Podlodka (segments 1s-...MultipleTextsAlignment(text1=TokenizedText(tex...0.097544NaN
4kolodezevBaseline Whisper longform conditionedMultipleTextsAlignment(text1=TokenizedText(tex...0.276680NaN
..................
2lankovPisets WhisperV3 no-VAD (segments 1s-20s) with...MultipleTextsAlignment(text1=TokenizedText(tex...0.089934[[-0.4045039713382721], [-0.25986623764038086]...
3zaliznyakPisets WhisperV3 no-VAD (segments 1s-20s) with...MultipleTextsAlignment(text1=TokenizedText(tex...0.112038[[-0.1870197057723999, -6.603976362384856e-05]...
4savvateevPisets WhisperV3 no-VAD (segments 1s-20s) with...MultipleTextsAlignment(text1=TokenizedText(tex...0.162985[[-0.6434793472290039], [-0.008379065431654453...
5kolodezevPisets WhisperV3 no-VAD (segments 1s-20s) with...MultipleTextsAlignment(text1=TokenizedText(tex...0.127201[[-1.3415793180465698, -0.010715918615460396],...
6zhirinovskyPisets WhisperV3 no-VAD (segments 1s-20s) with...MultipleTextsAlignment(text1=TokenizedText(tex...0.053697[[-1.8754836320877075, -0.01690865121781826], ...
\n", + "

84 rows × 5 columns

\n", + "" + ], + "text/plain": [ + " audio_name pipeline_name \\\n", + "0 kolodezev Baseline Whisper longform \n", + "1 zhirinovsky Pisets WhisperV3 no-VAD (segments 1s-20s) \n", + "2 zhirinovsky Pisets WhisperV3 no-VAD stretched (segments 1s... \n", + "3 lankov Pisets WhisperV3 no-VAD Podlodka (segments 1s-... \n", + "4 kolodezev Baseline Whisper longform conditioned \n", + ".. ... ... \n", + "2 lankov Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "3 zaliznyak Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "4 savvateev Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "5 kolodezev Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "6 zhirinovsky Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "\n", + " alignment wer \\\n", + "0 MultipleTextsAlignment(text1=TokenizedText(tex... 0.161696 \n", + "1 MultipleTextsAlignment(text1=TokenizedText(tex... 0.052458 \n", + "2 MultipleTextsAlignment(text1=TokenizedText(tex... 0.064849 \n", + "3 MultipleTextsAlignment(text1=TokenizedText(tex... 0.097544 \n", + "4 MultipleTextsAlignment(text1=TokenizedText(tex... 0.276680 \n", + ".. ... ... \n", + "2 MultipleTextsAlignment(text1=TokenizedText(tex... 0.089934 \n", + "3 MultipleTextsAlignment(text1=TokenizedText(tex... 0.112038 \n", + "4 MultipleTextsAlignment(text1=TokenizedText(tex... 0.162985 \n", + "5 MultipleTextsAlignment(text1=TokenizedText(tex... 0.127201 \n", + "6 MultipleTextsAlignment(text1=TokenizedText(tex... 0.053697 \n", + "\n", + " scores_per_word \n", + "0 NaN \n", + "1 NaN \n", + "2 NaN \n", + "3 NaN \n", + "4 NaN \n", + ".. ... \n", + "2 [[-0.4045039713382721], [-0.25986623764038086]... \n", + "3 [[-0.1870197057723999, -6.603976362384856e-05]... \n", + "4 [[-0.6434793472290039], [-0.008379065431654453... \n", + "5 [[-1.3415793180465698, -0.010715918615460396],... \n", + "6 [[-1.8754836320877075, -0.01690865121781826], ... 
\n", + "\n", + "[84 rows x 5 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## WER results" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
audio_nameharvardkolodezevlankovsavvateevtuberculosiszaliznyakzhirinovsky
pipeline_name
Baseline Whisper longform0.0109290.1616960.1030790.2061860.1695760.1580860.043371
Baseline Whisper longform conditioned0.0505460.2766800.1238330.2302410.1399630.6787530.064436
Baseline Whisper pipeline0.0455370.1552280.1473540.1924400.1995010.1316170.115655
Pisets WhisperV3 (segments 10s-30s)0.0118400.1340280.1317880.1826220.1599130.1131250.067741
Pisets WhisperV3 (segments 1s-20s)0.0159380.1293570.0875130.2169860.1312340.1167510.060306
Pisets WhisperV3 Podlodka (segments 1s-20s)0.0309650.1027670.0975440.2911140.0763720.1163890.088806
Pisets WhisperV3 no-VAD (segments 1s-20s)0.0159380.1293570.0875130.1860580.1312340.1065990.052458
Pisets WhisperV3 no-VAD (segments 1s-20s) with scores0.0168490.1272010.0899340.1629850.1296760.1120380.053697
Pisets WhisperV3 no-VAD Podlodka (segments 1s-20s)0.0309650.1027670.0975440.2592050.0763720.1069620.081371
Pisets WhisperV3 no-VAD stretched (segments 1s-20s)0.0377960.1149840.1099970.3166420.1184540.1294420.064849
Pisets WhisperV3 stretched (segments 1s-20s)0.0377960.1149840.1099970.3480610.1184540.1392310.072697
W2V2 Golos LM0.1498180.2716490.3168450.6293570.2793020.2505440.261875
\n", + "
" + ], + "text/plain": [ + "audio_name harvard kolodezev \\\n", + "pipeline_name \n", + "Baseline Whisper longform 0.010929 0.161696 \n", + "Baseline Whisper longform conditioned 0.050546 0.276680 \n", + "Baseline Whisper pipeline 0.045537 0.155228 \n", + "Pisets WhisperV3 (segments 10s-30s) 0.011840 0.134028 \n", + "Pisets WhisperV3 (segments 1s-20s) 0.015938 0.129357 \n", + "Pisets WhisperV3 Podlodka (segments 1s-20s) 0.030965 0.102767 \n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) 0.015938 0.129357 \n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with ... 0.016849 0.127201 \n", + "Pisets WhisperV3 no-VAD Podlodka (segments 1s-20s) 0.030965 0.102767 \n", + "Pisets WhisperV3 no-VAD stretched (segments 1s-... 0.037796 0.114984 \n", + "Pisets WhisperV3 stretched (segments 1s-20s) 0.037796 0.114984 \n", + "W2V2 Golos LM 0.149818 0.271649 \n", + "\n", + "audio_name lankov savvateev \\\n", + "pipeline_name \n", + "Baseline Whisper longform 0.103079 0.206186 \n", + "Baseline Whisper longform conditioned 0.123833 0.230241 \n", + "Baseline Whisper pipeline 0.147354 0.192440 \n", + "Pisets WhisperV3 (segments 10s-30s) 0.131788 0.182622 \n", + "Pisets WhisperV3 (segments 1s-20s) 0.087513 0.216986 \n", + "Pisets WhisperV3 Podlodka (segments 1s-20s) 0.097544 0.291114 \n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) 0.087513 0.186058 \n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with ... 0.089934 0.162985 \n", + "Pisets WhisperV3 no-VAD Podlodka (segments 1s-20s) 0.097544 0.259205 \n", + "Pisets WhisperV3 no-VAD stretched (segments 1s-... 
0.109997 0.316642 \n", + "Pisets WhisperV3 stretched (segments 1s-20s) 0.109997 0.348061 \n", + "W2V2 Golos LM 0.316845 0.629357 \n", + "\n", + "audio_name tuberculosis zaliznyak \\\n", + "pipeline_name \n", + "Baseline Whisper longform 0.169576 0.158086 \n", + "Baseline Whisper longform conditioned 0.139963 0.678753 \n", + "Baseline Whisper pipeline 0.199501 0.131617 \n", + "Pisets WhisperV3 (segments 10s-30s) 0.159913 0.113125 \n", + "Pisets WhisperV3 (segments 1s-20s) 0.131234 0.116751 \n", + "Pisets WhisperV3 Podlodka (segments 1s-20s) 0.076372 0.116389 \n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) 0.131234 0.106599 \n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with ... 0.129676 0.112038 \n", + "Pisets WhisperV3 no-VAD Podlodka (segments 1s-20s) 0.076372 0.106962 \n", + "Pisets WhisperV3 no-VAD stretched (segments 1s-... 0.118454 0.129442 \n", + "Pisets WhisperV3 stretched (segments 1s-20s) 0.118454 0.139231 \n", + "W2V2 Golos LM 0.279302 0.250544 \n", + "\n", + "audio_name zhirinovsky \n", + "pipeline_name \n", + "Baseline Whisper longform 0.043371 \n", + "Baseline Whisper longform conditioned 0.064436 \n", + "Baseline Whisper pipeline 0.115655 \n", + "Pisets WhisperV3 (segments 10s-30s) 0.067741 \n", + "Pisets WhisperV3 (segments 1s-20s) 0.060306 \n", + "Pisets WhisperV3 Podlodka (segments 1s-20s) 0.088806 \n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) 0.052458 \n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with ... 0.053697 \n", + "Pisets WhisperV3 no-VAD Podlodka (segments 1s-20s) 0.081371 \n", + "Pisets WhisperV3 no-VAD stretched (segments 1s-... 
0.064849 \n", + "Pisets WhisperV3 stretched (segments 1s-20s) 0.072697 \n", + "W2V2 Golos LM 0.261875 " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results.pivot_table(values='wer', index='pipeline_name', columns='audio_name')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Uncertainty with model disagreement\n", + "\n", + "\"Method 3: LM filtering\" may take a lot of time. It will be saved on disk as soon as it is calculated." + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) LM filtered diffs\n", + "Pisets 
WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores 
W2V2 Golos LM all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores W2V2 Golos LM LM filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) all diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) filtered diffs\n", + "Pisets WhisperV3 no-VAD (segments 1s-20s) with scores Pisets WhisperV3 no-VAD stretched (segments 1s-20s) LM filtered diffs\n" + ] + } + ], + "source": [ + "uncertainty_results = []\n", + "\n", + "scorer = SequenceScore('ai-forever/rugpt3large_based_on_gpt2')\n", + "\n", + "for audio_name in name_to_transcription:\n", + "\n", + " base_pipeline_name = 'Pisets WhisperV3 no-VAD (segments 1s-20s) with scores'\n", + "\n", + " truth_vs_base: MultipleTextsAlignment = results.query(\n", + " f'audio_name == \"{audio_name}\" and pipeline_name == \"{base_pipeline_name}\"'\n", + " ).iloc[0]['alignment']\n", + "\n", + " for additional_pipeline_name in [\n", + " 'W2V2 Golos LM',\n", + " 'Pisets WhisperV3 no-VAD stretched (segments 1s-20s)',\n", + " ]:\n", + " additional_predictions: TokenizedText = results.query(\n", + " f'audio_name == \"{audio_name}\" and pipeline_name == 
\"{additional_pipeline_name}\"'\n", + " ).iloc[0]['alignment'].text2\n", + "\n", + " base_vs_additional = MultipleTextsAlignment.from_strings(truth_vs_base.text2, additional_predictions)\n", + "\n", + " # method 1: no filtering\n", + " print(base_pipeline_name, additional_pipeline_name, 'all diffs')\n", + "\n", + " is_uncertain = base_vs_additional.get_uncertainty_mask()\n", + " uncertainty_results.append({\n", + " 'audio_name': audio_name,\n", + " 'base_pipeline': base_pipeline_name,\n", + " 'additional_pipeline': additional_pipeline_name,\n", + " 'method': 'all diffs',\n", + " 'mask': is_uncertain,\n", + " 'metrics': truth_vs_base.wer(uncertainty_mask=is_uncertain)\n", + " })\n", + "\n", + " # method 2: filtering\n", + " print(base_pipeline_name, additional_pipeline_name, 'filtered diffs')\n", + "\n", + " correction_indices = filter_correction_suggestions(base_vs_additional, skip_word_form_change=False, pbar=False)\n", + " is_uncertain = base_vs_additional.get_uncertainty_mask(match_indices=correction_indices)\n", + " uncertainty_results.append({\n", + " 'audio_name': audio_name,\n", + " 'base_pipeline': base_pipeline_name,\n", + " 'additional_pipeline': additional_pipeline_name,\n", + " 'method': 'filtered diffs',\n", + " 'mask': is_uncertain,\n", + " 'metrics': truth_vs_base.wer(uncertainty_mask=is_uncertain)\n", + " })\n", + "\n", + " # method 3: LM filtering\n", + " print(base_pipeline_name, additional_pipeline_name, 'LM filtered diffs')\n", + "\n", + " cache_path = (\n", + " Path('/home/oleg/pisets_test_results_lm')\n", + " / f'[{audio_name}] [{base_pipeline_name}] [{additional_pipeline_name}].json'\n", + " )\n", + " if cache_path.is_file():\n", + " lm_filtered_suggestion_indices = json.loads(cache_path.read_text())['indices']\n", + " else:\n", + " lm_filtered_suggestion_indices = accept_suggestions_by_lm(\n", + " base_vs_additional,\n", + " [i for i, m in enumerate(base_vs_additional.matches) if not m.is_equal],\n", + " scorer,\n", + " pbar=False,\n", + " 
verbose=True,\n", + " )\n", + " cache_path.parent.mkdir(parents=True, exist_ok=True)\n", + " cache_path.write_text(json.dumps({'indices': lm_filtered_suggestion_indices}))\n", + " is_uncertain = base_vs_additional.get_uncertainty_mask(match_indices=lm_filtered_suggestion_indices)\n", + " uncertainty_results.append({\n", + " 'audio_name': audio_name,\n", + " 'base_pipeline': base_pipeline_name,\n", + " 'additional_pipeline': additional_pipeline_name,\n", + " 'method': 'LM filtered diffs',\n", + " 'mask': is_uncertain,\n", + " 'metrics': truth_vs_base.wer(uncertainty_mask=is_uncertain)\n", + " })" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Uncertainty with Whisper sequence score" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/oleg/pisets/asr/comparison.py:347: RuntimeWarning: invalid value encountered in scalar divide\n", + " uncertain_n_correct / (uncertain_n_correct + uncertain_n_incorrect)\n", + "/home/oleg/pisets/asr/comparison.py:350: RuntimeWarning: invalid value encountered in scalar divide\n", + " results['uncertain_n_incorrect']\n" + ] + } + ], + "source": [ + "for audio_name in name_to_transcription:\n", + "\n", + " base_pipeline_name = 'Pisets WhisperV3 no-VAD (segments 1s-20s) with scores'\n", + "\n", + " row = results.query(\n", + " f'audio_name == \"{audio_name}\" and pipeline_name == \"{base_pipeline_name}\"'\n", + " ).iloc[0]\n", + " truth_vs_base = row['alignment']\n", + " scores_per_word = row['scores_per_word']\n", + "\n", + " reductions = {'min': min, 'mean': np.mean, 'sum': np.sum}\n", + " log_proba_thresholds = np.linspace(-1.5, -0.1, num=15)\n", + "\n", + " for reduction_name, reduction_fn in reductions.items():\n", + " for log_proba_threshold in log_proba_thresholds:\n", + " is_uncertain = np.array([reduction_fn(s) for s in scores_per_word]) < log_proba_threshold\n", + " 
uncertainty_results.append({\n", + " 'audio_name': audio_name,\n", + " 'base_pipeline': base_pipeline_name,\n", + " 'method': f'WhisperLogProba_{reduction_name}',\n", + " 't': log_proba_threshold,\n", + " 'mask': is_uncertain,\n", + " 'metrics': truth_vs_base.wer(uncertainty_mask=is_uncertain),\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
audio_namebase_pipelineadditional_pipelinemethodmaskmetricst
0zaliznyakPisets WhisperV3 no-VAD (segments 1s-20s) with...W2V2 Golos LMall diffs[False, False, False, False, False, False, Fal...{'wer': 0.112037708484409, 'certain_n_correct'...NaN
1zaliznyakPisets WhisperV3 no-VAD (segments 1s-20s) with...W2V2 Golos LMfiltered diffs[False, False, False, False, False, False, Fal...{'wer': 0.112037708484409, 'certain_n_correct'...NaN
2zaliznyakPisets WhisperV3 no-VAD (segments 1s-20s) with...W2V2 Golos LMLM filtered diffs[False, False, False, False, False, False, Fal...{'wer': 0.112037708484409, 'certain_n_correct'...NaN
3zaliznyakPisets WhisperV3 no-VAD (segments 1s-20s) with...Pisets WhisperV3 no-VAD stretched (segments 1s...all diffs[False, False, False, False, False, False, Fal...{'wer': 0.112037708484409, 'certain_n_correct'...NaN
4zaliznyakPisets WhisperV3 no-VAD (segments 1s-20s) with...Pisets WhisperV3 no-VAD stretched (segments 1s...filtered diffs[False, False, False, False, False, False, Fal...{'wer': 0.112037708484409, 'certain_n_correct'...NaN
........................
352tuberculosisPisets WhisperV3 no-VAD (segments 1s-20s) with...NaNWhisperLogProba_sum[False, False, False, True, False, False, Fals...{'wer': 0.12967581047381546, 'certain_n_correc...-0.5
353tuberculosisPisets WhisperV3 no-VAD (segments 1s-20s) with...NaNWhisperLogProba_sum[False, False, False, True, False, False, Fals...{'wer': 0.12967581047381546, 'certain_n_correc...-0.4
354tuberculosisPisets WhisperV3 no-VAD (segments 1s-20s) with...NaNWhisperLogProba_sum[True, False, False, True, False, False, False...{'wer': 0.12967581047381546, 'certain_n_correc...-0.3
355tuberculosisPisets WhisperV3 no-VAD (segments 1s-20s) with...NaNWhisperLogProba_sum[True, False, False, True, False, False, False...{'wer': 0.12967581047381546, 'certain_n_correc...-0.2
356tuberculosisPisets WhisperV3 no-VAD (segments 1s-20s) with...NaNWhisperLogProba_sum[True, False, False, True, False, False, False...{'wer': 0.12967581047381546, 'certain_n_correc...-0.1
\n", + "

357 rows × 7 columns

\n", + "
" + ], + "text/plain": [ + " audio_name base_pipeline \\\n", + "0 zaliznyak Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "1 zaliznyak Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "2 zaliznyak Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "3 zaliznyak Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "4 zaliznyak Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + ".. ... ... \n", + "352 tuberculosis Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "353 tuberculosis Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "354 tuberculosis Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "355 tuberculosis Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "356 tuberculosis Pisets WhisperV3 no-VAD (segments 1s-20s) with... \n", + "\n", + " additional_pipeline method \\\n", + "0 W2V2 Golos LM all diffs \n", + "1 W2V2 Golos LM filtered diffs \n", + "2 W2V2 Golos LM LM filtered diffs \n", + "3 Pisets WhisperV3 no-VAD stretched (segments 1s... all diffs \n", + "4 Pisets WhisperV3 no-VAD stretched (segments 1s... filtered diffs \n", + ".. ... ... \n", + "352 NaN WhisperLogProba_sum \n", + "353 NaN WhisperLogProba_sum \n", + "354 NaN WhisperLogProba_sum \n", + "355 NaN WhisperLogProba_sum \n", + "356 NaN WhisperLogProba_sum \n", + "\n", + " mask \\\n", + "0 [False, False, False, False, False, False, Fal... \n", + "1 [False, False, False, False, False, False, Fal... \n", + "2 [False, False, False, False, False, False, Fal... \n", + "3 [False, False, False, False, False, False, Fal... \n", + "4 [False, False, False, False, False, False, Fal... \n", + ".. ... \n", + "352 [False, False, False, True, False, False, Fals... \n", + "353 [False, False, False, True, False, False, Fals... \n", + "354 [True, False, False, True, False, False, False... \n", + "355 [True, False, False, True, False, False, False... \n", + "356 [True, False, False, True, False, False, False... 
\n", + "\n", + " metrics t \n", + "0 {'wer': 0.112037708484409, 'certain_n_correct'... NaN \n", + "1 {'wer': 0.112037708484409, 'certain_n_correct'... NaN \n", + "2 {'wer': 0.112037708484409, 'certain_n_correct'... NaN \n", + "3 {'wer': 0.112037708484409, 'certain_n_correct'... NaN \n", + "4 {'wer': 0.112037708484409, 'certain_n_correct'... NaN \n", + ".. ... ... \n", + "352 {'wer': 0.12967581047381546, 'certain_n_correc... -0.5 \n", + "353 {'wer': 0.12967581047381546, 'certain_n_correc... -0.4 \n", + "354 {'wer': 0.12967581047381546, 'certain_n_correc... -0.3 \n", + "355 {'wer': 0.12967581047381546, 'certain_n_correc... -0.2 \n", + "356 {'wer': 0.12967581047381546, 'certain_n_correc... -0.1 \n", + "\n", + "[357 rows x 7 columns]" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "uncertainty_results = pd.DataFrame(uncertainty_results)\n", + "uncertainty_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Ensembling uncertainty estimation methods" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [], + "source": [ + "ensemble_uncertainty_results = []\n", + "\n", + "for audio_name in name_to_transcription:\n", + "\n", + " base_pipeline_name = 'Pisets WhisperV3 no-VAD (segments 1s-20s) with scores'\n", + " additional_pipeline_name = 'W2V2 Golos LM'\n", + "\n", + " truth_vs_base = results.query(\n", + " f'audio_name == \"{audio_name}\" and pipeline_name == \"{base_pipeline_name}\"'\n", + " ).iloc[0]['alignment']\n", + "\n", + " t = -1\n", + " row1 = uncertainty_results.query(\n", + " f'audio_name == \"{audio_name}\"'\n", + " f' and base_pipeline == \"{base_pipeline_name}\"'\n", + " ' and method == \"WhisperLogProba_sum\"'\n", + " f' and t > {t - 0.001}'\n", + " f' and t < {t + 0.001}'\n", + " ).iloc[0]\n", + "\n", + " row2 = uncertainty_results.query(\n", + " f'audio_name == \"{audio_name}\"'\n", + " f' and base_pipeline == 
\"{base_pipeline_name}\"'\n", + " f' and additional_pipeline == \"{additional_pipeline_name}\"'\n", + " ' and method == \"LM filtered diffs\"'\n", + " ).iloc[0]\n", + "\n", + " is_uncertain = row1['mask'] | row2['mask']\n", + "\n", + " ensemble_uncertainty_results.append({\n", + " 'audio_name': audio_name,\n", + " 'base_pipeline': base_pipeline_name,\n", + " 'additional_pipeline': additional_pipeline_name,\n", + " 'method': f'LM filtered diffs + WhisperLogProba_sum (t={t})',\n", + " 'mask': is_uncertain,\n", + " 'metrics': truth_vs_base.wer(uncertainty_mask=is_uncertain),\n", + " })\n", + "\n", + "ensemble_uncertainty_results = pd.DataFrame(ensemble_uncertainty_results)\n", + "uncertainty_results = pd.concat([uncertainty_results, ensemble_uncertainty_results], axis='rows', ignore_index=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Uncertainty plots" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(15, 7))\n", + "\n", + "show_for_all_datasets = True\n", + "\n", + "x_statistics = 'uncertainty_ratio'\n", + "y_statistics = 'recall'\n", + "\n", + "for i, ((base_pipeline, additional_pipeline, method), group_loc) in enumerate(\n", + " uncertainty_results.groupby(['base_pipeline', 'additional_pipeline', 'method']).groups.items()\n", + "):\n", + " group = uncertainty_results.loc[group_loc]\n", + " color = f'C{i}'\n", + " has_t = not pd.isna(group['t'].values[0]) and group['t'].nunique() > 1\n", + " \n", + " if not pd.isna(additional_pipeline):\n", + " label = f'{base_pipeline.replace(\" with scores\", \"\")} | {additional_pipeline} | {method}'\n", + " else:\n", + " label = f'{base_pipeline.replace(\" with scores\", \"\")} | {method}'\n", + "\n", + " if not has_t:\n", + " # no parameter, scatter plot\n", + " assert group.audio_name.nunique() == len(group)\n", + " xs = [m[x_statistics] for m in group.metrics]\n", + " ys = [m[y_statistics] for m in group.metrics]\n", + " assert len(xs) == len(name_to_transcription)\n", + " if show_for_all_datasets:\n", + " plt.scatter(xs, ys, alpha=0.1, color=color)\n", + " plt.scatter([np.mean(xs)], [np.mean(ys)], label=label, color=color)\n", + " \n", + " else:\n", + " # has a parameter, line plot\n", + " t_range = sorted(group['t'].unique())\n", + "\n", + " xs = []\n", + " ys = []\n", + " for t in t_range:\n", + " group_for_t = group[group['t'] == t]\n", + " assert group_for_t.audio_name.nunique() == len(group_for_t)\n", + " xs.append([m[x_statistics] for m in group_for_t.metrics])\n", + " ys.append([m[y_statistics] for m in group_for_t.metrics])\n", + "\n", + " xs = np.array(xs).T # shape: (n_audios, n_t_values)\n", + " ys = np.array(ys).T # shape: (n_audios, n_t_values)\n", + " assert len(xs) == len(name_to_transcription)\n", + "\n", + " if show_for_all_datasets:\n", + " for _xs, _ys in zip(xs, ys):\n", + " plt.plot(_xs, _ys, 
alpha=0.1, color=color)\n", + " plt.plot(xs.mean(axis=0), ys.mean(axis=0), label=label, color=color)\n", + "\n", + "plt.xlabel(x_statistics)\n", + "plt.ylabel(y_statistics)\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visual analysis of uncertainty highlighting" + ] + }, + { + "cell_type": "code", + "execution_count": 124, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Вторая строка фактически была в очень плохом состоянии, но удалось однако же все-таки ее практически целиком восстановить. Я не буду вам выписывать все скобки неполной видимости, это не очень в данном случае существенно, поскольку в конечном счете результат совершенно надежный {остался|оказался} восстановлен. И читается следующее. Адресат. Вот практически все, что сохранилось от этой грамоты, это адресная формула. поклон от {Клименте|элементе} и от {Марьи|марья} к Петку {Копарину. Имя Петко|кабаринаимя пятко} находится далеко. скажем своде тупикова которым постоянно пользуемся своде древнерусских имен {петка|пятко} упоминается 11 раз то есть один из разных персонажей но это очень понятно это одно из {элементов а|имен того} типа как какой-нибудь шестак {3 2|третей второй} и так далее когда Долго не думая, детей называли просто по счету появления, и больше ничего. Что касается опарина, то, конечно, он происходит от имени опара. Но опара – это такое тесто, вылезающее из катки. Я очень себе представляю, какого человека должны были награждать прозвищем опара. {В} всяком случае, это имя вполне... и прозвище, и имя, вполне {существующие|существующий} в русской традиции, и {фамилии|фамилия} хорошо известные. Казалось бы, больше ничего из этого особенного извлечь не можем, кроме того, что имя Пятко, которое раньше не встречалось, внесем в {словарь|словари}, и все. Но нет. Это из тех замечательных случаев, когда... 
так сказать, покопавшись в фонде уже имеющихся материалов, мы обнаруживаем какую-то перекличку. В данном случае эта грамота оказалась в полной перекличке с грамотой, найденной 60 лет назад под номером 311. Грамота под номером 311 гласит... Господину своему Михаилу Юрьевичу. Михаил Юрьевич – посадничий сын, того времени очень важный боярин, начало 15 века. «Христиане твои черенщане челом бьют». Дальше я дам перевод, чтобы не заниматься… лишними деталями ты отдал {пашин|пашенку} {куб крем цук|климцу} опарину а мы его не хотим не соседний человек {в больном|волен} бог {длительное|да и ты} {такая|такой} {замечательная|замечательной} но это формула очень {хорошо известная|хооизвестный} {вольн|волин} бог {дайте|да} То есть смысл стоит в том, что ты один, по сути дела, отвечаешь за то, как ты решишь дело. Климцу опарину. Точное сочетание имени, которое у нас встретилось в {полной|полном} виде. Заметьте, там они называют его Климцом, поскольку он им не нравится. Вообще он такой человек, {которого|который} они хотят, чтобы ему никаких пашинок никто не давал. Здесь он называет себя более официально, от Климентии, но совершенно ясно, что одно и то же имя. И фамилия Парин, которому он пишет просто без всяких там {господину|господин} и так далее, совершенно очевидно, как естественно было бы писать своему родственнику. И тогда довольно ясно, что это начало семейного письма, к брату хозяина. И таких семейных {мероприятий есть некоторое|съесть мехри} количество, они очень похожи по типу. Они бывают или просто приглашение приехать, или иногда поздравление с чем-то. И, пожалуй, тут тоже уместно привести {точный|точные} {пример|примеры} того, Очень похожие примеры того, как можно себе представить, что там дальше было в этой грамоте. К сожалению, в данном случае конец нам неизвестен, пока что его не нашли. Например, грамота того же времени. «Поклон от Гаврилы {Постни|посдни}, зятю моему куму Григорию и сестре моей Улите». Ну, очевидно, мужу и жене тоже. 
Поехали бы вы в город, в город, это, разумеется, в Новгород, то есть это письмо... {послана|посла} куда-то за пределы города. Поехали бы в город к радости моей, а нашего слова не забыли бы. Дай Бог вам радость. Вот, пожалуй, такое очень, кстати, тоже очень известное письмо, номер {497|четыреста девяносто семь}, показывающее, что такого рода записки тоже в то время вполне могли посылать. и другое письмо тоже того же времени на этот раз и старые русы старые русы номер 10 тоже прочту Кстати, еще идет по образцу XII века, а не XIV, что показывает, что это могло держаться. Поклон от Оксинии и Анании. Анания – мужское имя, Оксиния, естественно, женское. Поклон от Оксинии и Анании к Родивону и сестре моей Татьяне. «Пойдите в город, опять-таки, конечно, в Новгород, к сей неделе, то есть к этому воскресенью». Дальше фраза, на которую я обращаю внимание, потому что она еще нам понадобится для дальнейшего. «Давать мне дочь, а сестре моей приставничать». Это надо перевести. То есть я... Мне предстоит выдавать свою дочь, а сестре моей предстоит быть распорядительницей на свадьбе. Все эти термины очень хорошо прослеживаются. Ну и так далее. А я господину своему Родиону и своей сестре много челом бью. Вот примерно тот тип писем, который явно совершенно представленный этим письмом, но вот с таким снова обнаруженными, {ну|но} теперь уже двумя персонажами этим писем. {Кремцом|кримцом} {опаренным|опариным}, про которого мы кое-что знаем из 311 грамоты, давным-давно найденной, и его {братья|брате}, потому что они находились, так сказать, еще в таких вот отношениях взаимных. дружили семьями, скажем так, {сейчас будет} названо вероятно. Ну, об этом... Об этом достаточно, пойдемте дальше. Дальше мы попадаем в тот самый комплекс писем середины второй половины {XIV|четырнадцатого} века. Ну, второй половины, нет, середины там мало что. с сильно пересекающимися именами. Итак, номер {98|девяносто восемь}. Ну, опять-таки, вас не удивит, что я скажу, что грамоты целы. {Семи|семитшесть} строк... 
Шесть строк. Да, я не сказал вам, что на той грамоте, где перечислялись все... кто дал рубль и полтину, имелась запись на обороте, а именно первая половина алфавита. Такое упражнение довольно часто встречается. Терпения, правда, не хватило у писавшего дойти дальше буквы «К». До этого он {всё|все} успел записать. Это бывает. Это такое очень естественное занятие для человека, который... умел грамоте так себя реализовать в свободный момент. Итак, 1098 Здесь тоже большое письмо, {6|шесть} строк, причем более длинных, чем там на лицевой стороне, и еще одна строка на обороте. Ну вот, почитаем. Вас, {уважение|уже не}, должно удивлять. В это время это совершенно нормальное начало писем. Вот если бы такое встретилось в в письме {XII|двенадцатого} века, это была бы совершенная сенсация, чтобы {начиналось|начиналась} поклон. Тогда пришлось бы десять раз контролировать археологов, не ошиблись ли они, и на самом деле не является ли эта грамота более поздней, чем такое предполагается. Но таких случаев не было, это я говорю в абстрактном виде. Итак, поклон... Вот выступает первое лицо, который будет еще нам встречаться. Да. {пожалуй|этожалуйт} я {с ним еще|не} начальная формула Совершенно стандартные имена тоже обычного {набор|набора} из обычного набора вот дальше уже идет содержание как всегда, некоторым некоторой драматической основы, поскольку, если не считать вот таких пригласительных, ласковых писем, которых другого содержания не было, кроме того, что «дай Бог вам радость» или что-нибудь в этом духе, то всегда нужно было что-то такое расхлебать, что было неудовлетворительно для писавшего. Так и здесь. Вот чем он недоволен. Заметьте, XIV век, уже легко вам понимать текст. Это Не всегда, конечно. Бывают и неприятные казусы даже с {XIV|четырнадцатым} веком. Но, тем не менее, пока что вы должны понимать все совершенно без всякого затруднения. Со скоростью, так сказать, чтения. Верно? {Уж}... Это, скорее всего, уже «уж» читалось. 
Но дело в том, что наше с вами «уже» раньше имело ударение «уже». И, соответственно, «уж» очень легко получалось из этого «{уже|ужа}». Но по смыслу уже к вам шлю третью, обращаю ваше внимание, именно такая была древняя форма, это была полная форма, {но} не краткая, {третьюю|третью} третью грамоту. А вы мне подскажите, какое дальше будет слово? А, правильно, конечно, совершенно ясно. Зачем иначе это писать, если не касается того, что... Ну, абсолютно очевидно. Может быть, даже {вы} еще одно слово угадаете? А вы, конечно, смотрите, как все замечательно. Правильно. Совершенно справедливо. Ну, дальше уже... {здесь|есть} разнообразие, {а там|это} совершенно правильно. Комни... Ну, вот это вот первый случай, где у вас {ядь|ять} реализован в виде и. Для {XIV|четырнадцатого} века вещь совершенно нормальная, так что не заявляйтесь, это будет еще и не раз, и не только в этой грамоте. Значит, ко мне, это ко мне с {ядьем|ятем}. В высшей степени все естественно. Кстати, обращаю ваше внимание, что сейчас бы мы сказали, вы ко мне не присылаете. Это нормальный русский оборот, вот то, что я предпочитаю называть {presence|презенс} напрасного ожидания, который требует совершенного вида. Вот как в известной форме, там денег все не соберем, а не собираем, это в точности. этот же тип {синтаксис|синтекс}, который в древних текстах довольно часто. А вы ко мне не пришлете. Это не будущее время, конечно, а то, что сейчас выразилось бы. А вы ко мне не присылаете. Не призываете, но ясно того, что... материального. Придется немножко мне здесь... Вы, конечно, думаете о накладных современных, но это немножко будет поспешно. Нет, он написал это правильно, конечно, было бы через ЕР, но он написал через ЕР, простите меня. Нет, пока еще правильно. Потом он эту ошибку сделает. Не ошибку, а вариативность. Дважды написал чуть-чуть {различным|различно}, потому что это слово еще повторится. Вот такая жалоба. А что это за накладное серебро? 
Ну, я говорю накладное с нынешним ударением, конечно, тогда это было {ударение|дарение} накладного, без всякого сомнения, но как вы думаете? Будьте ближе к нормальным материальным интересам {тогдашнего}. Что такое {осталось|остался} в накладе? В убытке. И... это бы убытки а вообще говоря {накладка|наклада} за то что наложено сверху это вообще это просто проценты {на} серебро конечно означает не серебро так напрасно сразу {думаете|думайте} о том что это такой металл который надо наложить куда-то конечно все это могло быть на далеко идет от значит серебро это деньги абсолютно точно то же самое как по французски {ажан} совершенно тот же семантический переход {а} накладное серебро это серебро {лихвенные|лихвинные} проценты то есть не присылаешь мне процентных денег и {не|рыб} кстати рыба {процент} в древнерусском {употреблении|потреблении} {почти заводит у и в} исчисляемом значении там одна рыба две рыбы {пятеро|пять} и так далее это и сейчас можно но у нас кроме того есть рыба как обобщающий рыбы как товар сколько рыбы мы можем сказать А древнерусский человек, он говорит, сколько рыб. Поэтому сейчас мы бы сказали, не присылаешь мне рыбы, как масло, как товара. {древнерусской|древнерусская} здесь не присылаешь не процентных денег не рыб То, что он должен был сделать. Дальше очень аккуратно он пишет следующее. Здесь смысловой разрыв, который я так символизирую, чтобы дальше мы будем читать. Ныне, в этом слове нормальный {ядь|ять} конечный, ныне с ядьем, так что ныне совершенно регулярно. Ныне не {пришлете|пришлите}... Да, мне не хватает строк. Ну, одну строку я еще умещу, но все равно это будет меньше, чем... Ну, ладно, одну строку ниже. Потому что я пишу строка в строку, чтобы у вас было представление о том, как выглядит письмо. Но все равно все письмо нельзя {уместить|вместить}. Или... А я напишу где-нибудь там рядом. А что такое к неделе? К воскресенью. Совершенно ясно. Эти два «и» – это двояти. К неделе. Соответственно, к неделе. Ну, а дальше все идет к описанному. 
{Причем|чем} он очень аккуратно, не ленится написать второй раз. И на этот раз уже с «ер»ом. Потому что это колеблется. Какое будет следующее слово на следующей строке? Рыб, конечно, {правят|правил}. Давайте тогда... Больше мне ничего не остается, как сюда перейти. Значит, следующая строка. Сколько у нас? Раз, два, три, четыре, пять. Шестая и последняя строка лицевой стороны – и это конец лицевой стороны. Так что дальше ему пришлось писать на... {наоборот истанины но|на оборотной сторонену} довольно понятно что {создать дом число и такое мнение|том присловий такой ныне} не... Не пришлете, а вот теперь, в том смысле, что я долго ждал, но уж теперь, если не пришлете, {подождите|подожнается}, к ближайшему воскресенью, процентных денег и рыб {замечательное|замечательно} не такое сейчас мы скажем {и нато до|иногда} {было|был} естественно не поскольку под отрицанием не рыб то что будет это и соответствует нашему то. Нормально в {XIV|четырнадцатом} веке оборот типа «если то» мог бы быть там какой-нибудь «оже и», в отличие от {XII|двенадцатого} века, где было не «и», а «а» в этом значении. Это меняется и тоже датирует довольно хорошо. И что такое, значит, «и {слатьми|слать ми} по вас»? Как вы это понимаете? Слать ми – это ровно тот же синтаксис, что у нас был мне выдавать мне свою дочь а сестре приставничать как Предстоит, должен и так далее. Совершенно точно. Провалиться мне сквозь землю. Вот типичный синтаксис, прекрасно работающий в современном русском языке. Ну или какой-нибудь там «мне скоро уезжать». Все эти формулы совершенно устойчивы. Так что «мне предстоит слать». Это говорится, очевидно, «{ну|но} ничего не поняла, и что, да? Мне ничего не {остаётся|остается}, как слать по вас». Ну и тогда попробуйте придумать как продолжение. 
слать по вас то есть за вами {придется|придет} мне придется {ссылать|слать} за вами и что то что начинается на {бит вирусчик|бибиющих} в этом что то есть да По сути дела, конечно, конечно, слать за вами каких-то, которые вам крепко покажут, как так плохо себя вести. Но это немножко надо... Может быть, от глагола «бить». Но не буду, действительно, {эта|это} задача немножко слишком сложная. Но следующая строка замечательным образом начинается с следующих четырех букв. Надо знать слово «беречь». Беречь в точности тот {человек}, гражданский офицер тот исполнитель судебный исполнитель которого который призначался для того чтобы там своими кулачными помощниками являлся {за} исканием долга наложением штрафа приведением человека в суд и так далее Так что {это|эта} угроза, которая у нас бывает в других формах и в других текстах, что если там что-то такое вовремя не будет возвращено или выплачено, то за это будет вызван этот... Беричи может иметь и другие названия. В самых древних текстах вместо слова «{беричь|беречь}» выступает слово «{отрак|отрок}» замечательным образом. Это вовсе не младенец, как раз очень такая фигура устрашающая. Это младший офицер, поэтому он отрок первоначально, но он отрок только по сравнению с «могучими воинами». а на самом деле облечен властью, и которого посылают для того, чтобы взыскать силы штраф и так далее. {Спереди|десперечи}. ну и последняя фраза понятно Это очень {такое|такая}... характерная фигура {такого действия|такая то действий}. Значит, если вы не выполните то, что я от вас хочу и требую, то я сделаю такую-то неприятную вещь, и уж тогда на меня не жалуйтесь. Здесь все очень понятно и прозрачно, кроме только одного места «ме», которое явно создает некоторую лингвистическую задачу. Потому что могло бы быть «мя», «намя», а «намя» се не {жальте|жаль}. И можно было бы даже думать о том, что здесь каким-то образом фонетическая смена на мне произошла. Почему и как, это был бы отдельный вопрос, но в принципе можно было думать. 
Если бы не то, что сочетания типа «намя», «замя», «натя», «затя», «предтя» в это время уже ушли из языка. И вместо них уже употребляются полные местоимения. На меня, за {тебя|себя}, за себя и так далее. Совершенно как у нас с вами. Это был бы большой анахронизм. Поэтому идти по пути и думать, каким образом здесь я изменилась на е, бессмысленно. Хронологически это {невозможно|нереально}. {Хронологически|хронлогически} единственное, что остается, это то, что очень простая вещь, что у него было что-то типа на {ме|мессе} {сине|не}, а должен было бы на ме не сине, с двумя не. И тогда {было|был} бы {ми не|дмене}, который полностью здесь ожидается по правилам {XIV|четырнадцатого} века. То есть из этих двух не, немножко разделенных между собой, он одно... по одной, ну, такой психологической ошибке, которая бывает, вообще говоря, пропустил. Так что нам приходится здесь признать все-таки некоторый маленький огрех. А смысл совершенно ясен. Ну вот, не хочу на этом слишком долго останавливаться. Тем не менее, вы видите, что... Вполне такое прозрачное письмо. очень характерного с характерной структурой и концом, который в разных вариантах у нас в других грамотах тоже встречается. Все, перехожу {к следующей Спасибо|в следующих рамтину у}. Вас тоже не удивит, если я скажу, что {если... Прямо|есть там}-то тоже {целые|целое}. Я бы сказал так, это меня самого удивляет. что такое возможно. Но тем не менее, это изобилие есть. Это просто следующие по порядку находки. находимся тоже поклон нет В данном случае он пишет... бытовым образом. {Так, ладно. А|поклоноа} это что за человек? Как вы его понимаете? Да, наверное, да. это вполне русское {прозвище|провозвище} действительно так совершенно верно уже эпохи поздней когда уже оглушение согласных может быть зафиксировано на письме Одно это было бы достаточно, чтобы никоим образом эта грамота не могла бы быть признана {XII|двенадцатого} века, сколько бы ни говорили, что вот мы нашли на такой-то глубине, ничего подобного. 
значит попал не туда потому что в мордке через {ты|т} может быть написано только начиная с конца {13|тринадцатого} века и позже ну и так далее это так {панике|паленький} пример такого датирования независимо от {стратеграфии|стратиграфии} {вот мордки вот мордки|от мортки от мортки} и дальше довольно любопытно получается две строки вместе очень Первый раз нам встречается, чтобы два человека таким образом обращались к кому бы то ни было. Один обращается к человеку, называя его Афанос, по именем, а другой обращается, называя его {господином|господину} моему. Можно вообще говоря представить себе, что они просто в разном положении находятся, что этот Семен, он же {Смён|смен}, {и} так сказать равный афаносу {а} {мордка|мортка} такой который себя равным {фанату|фаноса} считать не может возможно что решение в этом {и|мы} точно до конца {этого|это} не знаем сколько {до|то} единственный пример других у нас нет сравнить пока что не с чем и\n" + ] + } + ], + "source": [ + "audio_name = 'zaliznyak'\n", + "\n", + "base_pipeline_name = 'Pisets WhisperV3 no-VAD (segments 1s-20s) with scores'\n", + "additional_pipeline_name = 'W2V2 Golos LM'\n", + "\n", + "is_uncertain = uncertainty_results.query(\n", + " f'audio_name == \"{audio_name}\"'\n", + " f' and base_pipeline == \"{base_pipeline_name}\"'\n", + " ' and method == \"LM filtered diffs + WhisperLogProba_sum (t=-1)\"'\n", + ").iloc[0]['mask']\n", + "\n", + "truth_vs_base = results.query(\n", + " f'audio_name == \"{audio_name}\"'\n", + " f' and pipeline_name == \"{base_pipeline_name}\"'\n", + ").iloc[0]['alignment']\n", + "\n", + "truth_vs_additional = results.query(\n", + " f'audio_name == \"{audio_name}\"'\n", + " f' and pipeline_name == \"{additional_pipeline_name}\"'\n", + ").iloc[0]['alignment']\n", + "\n", + "base_vs_additional = MultipleTextsAlignment.from_strings(truth_vs_base.text2, truth_vs_additional.text2)\n", + "diffs_to_highlight = [i for i, m in enumerate(base_vs_additional.matches) if 
is_uncertain[m.start1:m.end1].any()]\n", + "print(base_vs_additional.substitute(show_in_braces=diffs_to_highlight))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/evaluation/make_predictions.py b/evaluation/make_predictions.py new file mode 100644 index 0000000..9205f64 --- /dev/null +++ b/evaluation/make_predictions.py @@ -0,0 +1,275 @@ +import time +import json +from pathlib import Path +from typing import Callable, Literal +from dataclasses import dataclass + +import torch +import numpy as np +from datasets import load_dataset, Audio +from transformers import pipeline, Pipeline, WhisperProcessor + +from asr.asr import ( + initialize_model_for_speech_segmentation, + initialize_model_for_speech_classification, + initialize_model_for_speech_recognition, + transcribe +) + +class TranscribeWhisperPipeline: + """ + A Whisper baseline to compare with `TranscribePisets`. 
+ """ + def __init__(self, predictions_name: str): + self.predictions_name = predictions_name + self.whisper_pipeline = pipeline( + 'automatic-speech-recognition', + model='openai/whisper-large-v3', + chunk_length_s=20, + stride_length_s=(4, 2), + device='cuda:0', + model_kwargs={'attn_implementation': 'sdpa'}, + # torch_dtype=torch.float16, + generate_kwargs={ + 'language': '<|ru|>', + 'task': 'transcribe', + 'forced_decoder_ids': None + } + ) + + def __call__(self, waveform: np.ndarray) -> dict[str, str]: + return {self.predictions_name: self.whisper_pipeline(waveform)['text']} + + +class TranscribeWhisperLongform(TranscribeWhisperPipeline): + """ + A Whisper longform baseline to compare with `TranscribePisets`. + """ + def __init__(self, predictions_name: str, condition_on_prev_tokens: bool): + super().__init__(predictions_name) + self.whisper_processor = WhisperProcessor.from_pretrained( + 'openai/whisper-large-v3', + language='Russian', + task='transcribe', + ) + self.condition_on_prev_tokens = condition_on_prev_tokens + + def __call__(self, waveform: np.ndarray) -> dict[str, str]: + # https://github.com/huggingface/transformers/pull/27658 + inputs = self.whisper_processor( + waveform, + return_tensors='pt', + truncation=False, + padding='longest', + return_attention_mask=True, # probably we do not need this for Whisper + sampling_rate=16_000 + ) + result = self.whisper_pipeline.model.generate( + **inputs.to('cuda'), + condition_on_prev_tokens=self.condition_on_prev_tokens, + # temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + temperature=0, # for determinism + do_sample=False, # for determinism + logprob_threshold=-1.0, + compression_ratio_threshold=1.35, + return_timestamps=True, + language='<|ru|>', + task='transcribe', + ) + text = self.whisper_processor.batch_decode(result, skip_special_tokens=True)[0] + return {self.predictions_name: text} + + +@dataclass +class TranscribePisets: + """ + A Pisets wrapper for evaluation purposes. 
+ + Transcribes waveform with Pisets and returns results for all stages. + + In contrast to asr.asr.transcribe() this class: + - Concatenates transcriptions for all segments + - Does not return timestamps + - Allows to define custom names for all stages + """ + + segmenter: Pipeline | Callable + vad: Pipeline | Callable | Literal['skip'] + asr: Pipeline | Callable | Literal['skip'] + + min_segment_size: int = 1 + max_segment_size: int = 20 + stretch: tuple[int, int] | None = None + + segmenter_predictions_name: str | None = None + asr_predictions_name: str | None = None + asr_stretched_predictions_name: str | None = None + + def __call__(self, waveform: np.ndarray) -> dict[str, str]: + # transcribing + outputs = transcribe( + waveform, + segmenter=self.segmenter, + voice_activity_detector=( + self.vad + if self.vad != 'skip' + else (lambda audio: [{'score': 1, 'label': 'Speech'}]) + ), + asr=( + self.asr + if self.asr != 'skip' + else (lambda audio: {'text': 'none'}) + ), + min_segment_size=self.min_segment_size, + max_segment_size=self.max_segment_size, + stretch=self.stretch, + ) + # concatenating segments + results = {} + if self.segmenter_predictions_name is not None: + results[self.segmenter_predictions_name] = ' '.join([s.transcription_from_segmenter for s in outputs]) + if self.asr_predictions_name is not None: + results[self.asr_predictions_name] = ' '.join([s.transcription for s in outputs]) + if self.asr_stretched_predictions_name is not None: + results[self.asr_stretched_predictions_name] = ' '.join([s.transcription_stretched for s in outputs]) + return results + + +@dataclass +class TranscribeNoisy: + """ + Transcribe with a specified signal-to-noise ratio + """ + snr: float + transcriber: Callable + + def __call__(self, waveform: np.ndarray) -> dict[str, str]: + # TODO augment + return self.transcriber(waveform) + + +# defining transcribers without instantiating them all at once to save GPU memory + +transcribers = { + 'Whisper pipeline': lambda: 
TranscribeWhisperPipeline( + predictions_name='Baseline Whisper pipeline', + ), + 'Whisper longform': lambda: TranscribeWhisperLongform( + predictions_name='Baseline Whisper longform', + condition_on_prev_tokens=False, + ), + 'Whisper longform conditioned': lambda: TranscribeWhisperLongform( + predictions_name='Baseline Whisper longform conditioned', + condition_on_prev_tokens=True, + ), + 'Pisets (segments 1s-20s)': lambda: TranscribePisets( + segmenter=initialize_model_for_speech_segmentation('ru', 'bond005/wav2vec2-large-ru-golos-with-lm'), + vad=initialize_model_for_speech_classification(), + asr=initialize_model_for_speech_recognition('ru', 'openai/whisper-large-v3'), + min_segment_size=1, + max_segment_size=20, + stretch=(3, 4), + segmenter_predictions_name='W2V2 Golos LM', + asr_predictions_name='Pisets WhisperV3 (segments 1s-20s)', + asr_stretched_predictions_name='Pisets WhisperV3 stretched (segments 1s-20s)', + ), + 'Pisets (segments 10s-30s)': lambda: TranscribePisets( + segmenter=initialize_model_for_speech_segmentation('ru', 'bond005/wav2vec2-large-ru-golos-with-lm'), + vad=initialize_model_for_speech_classification(), + asr=initialize_model_for_speech_recognition('ru', 'openai/whisper-large-v3'), + min_segment_size=10, + max_segment_size=30, + asr_predictions_name='Pisets WhisperV3 (segments 10s-30s)', + ), + # 'W2V2 golos no LM': lambda: TranscribePisets( + # segmenter=initialize_model_for_speech_segmentation('ru', 'bond005/wav2vec2-large-ru-golos'), + # vad='skip', + # asr='skip', + # segmenter_predictions_name='W2V2 Golos no LM', + # ), + 'Pisets Podlodka': lambda: TranscribePisets( + segmenter=initialize_model_for_speech_segmentation('ru', 'bond005/wav2vec2-large-ru-golos-with-lm'), + vad=initialize_model_for_speech_classification(), + asr=initialize_model_for_speech_recognition('ru', 'bond005/whisper-large-v3-ru-podlodka'), + min_segment_size=1, + max_segment_size=20, + asr_predictions_name='Pisets WhisperV3 Podlodka (segments 1s-20s)', + ), + 
'Pisets no-VAD': lambda: TranscribePisets( + segmenter=initialize_model_for_speech_segmentation('ru', 'bond005/wav2vec2-large-ru-golos-with-lm'), + vad='skip', + asr=initialize_model_for_speech_recognition('ru', 'openai/whisper-large-v3'), + min_segment_size=1, + max_segment_size=20, + stretch=(3, 4), + asr_predictions_name='Pisets WhisperV3 no-VAD (segments 1s-20s)', + asr_stretched_predictions_name='Pisets WhisperV3 no-VAD stretched (segments 1s-20s)', + ), + 'Pisets no-VAD Podlodka': lambda: TranscribePisets( + segmenter=initialize_model_for_speech_segmentation('ru', 'bond005/wav2vec2-large-ru-golos-with-lm'), + vad='skip', + asr=initialize_model_for_speech_recognition('ru', 'bond005/whisper-large-v3-ru-podlodka'), + min_segment_size=1, + max_segment_size=20, + asr_predictions_name='Pisets WhisperV3 no-VAD Podlodka (segments 1s-20s)', + ), +} + +# for snr in [1, 2, 3, 4, 5]: +# transcribers[f'Whisper longform SNR={snr}'] = lambda: TranscribeNoisy( +# snr=snr, +# transcriber=TranscribeWhisperLongform( +# predictions_name=f'Baseline Whisper longform SNR={snr}', +# condition_on_prev_tokens=False, +# ), +# ) +# transcribers[f'Pisets (segments 1s-20s) SNR={snr}'] = lambda: TranscribeNoisy( +# snr=snr, +# transcriber=TranscribePisets( +# segmenter=initialize_model_for_speech_segmentation('ru', 'bond005/wav2vec2-large-ru-golos-with-lm'), +# vad=initialize_model_for_speech_classification(), +# asr=initialize_model_for_speech_recognition('ru', 'openai/whisper-large-v3'), +# min_segment_size=1, +# max_segment_size=20, +# segmenter_predictions_name=f'W2V2 Golos LM SNR={snr}', +# asr_predictions_name=f'Pisets WhisperV3 (segments 1s-20s) SNR={snr}', +# ), +# ) + + +dataset = ( + load_dataset('dangrebenkin/long_audio_youtube_lectures') + .cast_column('audio', Audio(sampling_rate=16_000)) + ['train'] +) + +output_dir = Path('/home/oleg/pisets_test_results') +output_dir.mkdir(parents=True, exist_ok=True) + +for transcriber_name, transcriber_lambda in transcribers.items(): + + # 
# Sampling rate the dataset audio is cast to; also used to turn segment times into sample indices.
SAMPLING_RATE = 16_000

# Loop-invariant: every file written by this script comes from the same configuration.
transcriber_name = 'Pisets WhisperV3 no-VAD (segments 1s-20s) with scores'

dataset = (
    load_dataset('dangrebenkin/long_audio_youtube_lectures')
    .cast_column('audio', Audio(sampling_rate=SAMPLING_RATE))
    ['train']
)

output_dir = Path('/home/oleg/pisets_test_results_with_scores')
output_dir.mkdir(parents=True, exist_ok=True)

segmenter = initialize_model_for_speech_segmentation('ru', 'bond005/wav2vec2-large-ru-golos')
whisper_pipeline = initialize_model_for_speech_recognition('ru', 'openai/whisper-large-v3')

for sample in dataset:
    print(sample['name'])

    waveform = sample['audio']['array']

    # Only the segment boundaries are used below, so VAD and ASR are replaced with stubs:
    # the VAD stub labels everything as speech and the ASR stub returns a placeholder text.
    results = transcribe(
        waveform,
        segmenter=segmenter,
        voice_activity_detector=lambda audio: [{'score': 1, 'label': 'Speech'}],
        asr=lambda audio: {'text': 'none'},
        min_segment_size=1,
        max_segment_size=20,
    )

    tokenized_segments = []
    scores_per_word = []

    for segment in tqdm(results, desc='whisper'):
        # segment.start / segment.end are scaled by the sampling rate to get sample indices
        # (presumably seconds - verify against asr.asr.transcribe)
        first_sample = int(segment.start * SAMPLING_RATE)
        last_sample = int(segment.end * SAMPLING_RATE)
        waveform_segment = waveform[first_sample:last_sample]
        tokenized_text_for_segment, _, scores_for_segment = (
            whisper_pipeline_transcribe_with_word_scores(waveform_segment, whisper_pipeline)
        )
        tokenized_segments.append(tokenized_text_for_segment)
        scores_per_word += scores_for_segment

    tokenized_text = TokenizedText.concatenate(tokenized_segments)

    filepath = output_dir / f'{sample["name"]} {transcriber_name}.json'

    # ensure_ascii=False writes raw Cyrillic, so the encoding must be pinned instead of
    # depending on the platform's locale default.
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump({
            'audio_name': sample['name'],
            'transcriber_name': transcriber_name,
            'tokenized_text': dataclasses.asdict(tokenized_text),
            'scores_per_word': scores_per_word,
        }, f, ensure_ascii=False)
transcribe_speech(input_sound, segmenter, vad, asr, MIN_FRAME_SIZE, MAX_FRAME_SIZE) output_filename = task_id + '.docx' doc = Document() - for start_time, end_time, sentence_text in texts_with_timestamps: - line = f'{start_time:.2f} - {end_time:.2f} - {sentence_text}' + for segment_transcription in segment_transcriptions: + start_time = segment_transcription.start + end_time = segment_transcription.end + text_final = segment_transcription.transcription + line = f'{start_time:.2f} - {end_time:.2f} - {text_final}' doc.add_paragraph(line) doc.add_paragraph('') diff --git a/tests/test_asr.py b/tests/test_asr.py index 06730fc..6ecd957 100644 --- a/tests/test_asr.py +++ b/tests/test_asr.py @@ -19,47 +19,47 @@ class TestASR(unittest.TestCase): def test_strip_segments_pos01(self): max_sound_duration = 5.5 - input_segments = [(0.1, 0.9), (0.95, 3.0), (3.0, 5.0)] - target_segments = [(0.1, 0.9), (0.95, 3.0), (3.0, 5.0)] + input_segments = [(0.1, 0.9, ''), (0.95, 3.0, ''), (3.0, 5.0, '')] + target_segments = [(0.1, 0.9, ''), (0.95, 3.0, ''), (3.0, 5.0, '')] predicted_segments = strip_segments(input_segments, max_sound_duration) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(target_segments)) for idx in range(len(target_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], target_segments[idx][0]) self.assertAlmostEqual(predicted_segments[idx][1], target_segments[idx][1]) def test_strip_segments_pos02(self): max_sound_duration = 5.5 - input_segments = [(-0.1, 0.9), (0.95, 3.0), (3.0, 5.0)] - target_segments = [(0.0, 0.9), (0.95, 3.0), (3.0, 5.0)] + input_segments = [(-0.1, 0.9, ''), (0.95, 3.0, ''), (3.0, 5.0, '')] + target_segments = [(0.0, 0.9, ''), (0.95, 3.0, ''), (3.0, 5.0, '')] predicted_segments = strip_segments(input_segments, max_sound_duration) 
self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(target_segments)) for idx in range(len(target_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], target_segments[idx][0]) self.assertAlmostEqual(predicted_segments[idx][1], target_segments[idx][1]) def test_strip_segments_pos03(self): max_sound_duration = 5.5 - input_segments = [(0.1, 0.9), (0.95, 3.0), (3.0, 5.8)] - target_segments = [(0.1, 0.9), (0.95, 3.0), (3.0, 5.5)] + input_segments = [(0.1, 0.9, ''), (0.95, 3.0, ''), (3.0, 5.8, '')] + target_segments = [(0.1, 0.9, ''), (0.95, 3.0, ''), (3.0, 5.5, '')] predicted_segments = strip_segments(input_segments, max_sound_duration) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(target_segments)) for idx in range(len(target_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], target_segments[idx][0]) self.assertAlmostEqual(predicted_segments[idx][1], target_segments[idx][1]) def test_select_word_groups_pos01(self): segment_size = 2 - words = [(0.1, 0.5), (0.7, 1.0), (1.1, 2.3), (2.7, 2.8), (3.6, 3.8), (3.8, 4.0)] - target_groups = [[(0.1, 0.5)], [(0.7, 1.0), (1.1, 2.3)], [(2.7, 2.8)], [(3.6, 3.8), (3.8, 4.0)]] + words = [(0.1, 0.5, ''), (0.7, 1.0, ''), (1.1, 2.3, ''), (2.7, 2.8, ''), (3.6, 3.8, ''), (3.8, 4.0, '')] + target_groups = [[(0.1, 0.5, '')], [(0.7, 1.0, ''), (1.1, 2.3, '')], [(2.7, 2.8, '')], [(3.6, 3.8, ''), (3.8, 4.0, '')]] predicted_groups = select_word_groups(words, segment_size) self.assertIsInstance(predicted_groups, list) self.assertEqual(len(predicted_groups), len(target_groups)) @@ -68,14 +68,14 @@ def 
test_select_word_groups_pos01(self): self.assertEqual(len(predicted_groups[group_idx]), len(target_groups[group_idx])) for word_idx in range(len(target_groups[group_idx])): self.assertIsInstance(predicted_groups[group_idx][word_idx], tuple) - self.assertEqual(len(predicted_groups[group_idx][word_idx]), 2) + self.assertEqual(len(predicted_groups[group_idx][word_idx]), 3) self.assertAlmostEqual(predicted_groups[group_idx][word_idx][0], target_groups[group_idx][word_idx][0]) self.assertAlmostEqual(predicted_groups[group_idx][word_idx][1], target_groups[group_idx][word_idx][1]) def test_select_word_groups_pos02(self): segment_size = 2 - words = [(0.1, 0.5), (0.7, 1.0)] - target_groups = [[(0.1, 0.5), (0.7, 1.0)]] + words = [(0.1, 0.5, ''), (0.7, 1.0, '')] + target_groups = [[(0.1, 0.5, ''), (0.7, 1.0, '')]] predicted_groups = select_word_groups(words, segment_size) self.assertIsInstance(predicted_groups, list) self.assertEqual(len(predicted_groups), len(target_groups)) @@ -84,14 +84,14 @@ def test_select_word_groups_pos02(self): self.assertEqual(len(predicted_groups[group_idx]), len(target_groups[group_idx])) for word_idx in range(len(target_groups[group_idx])): self.assertIsInstance(predicted_groups[group_idx][word_idx], tuple) - self.assertEqual(len(predicted_groups[group_idx][word_idx]), 2) + self.assertEqual(len(predicted_groups[group_idx][word_idx]), 3) self.assertAlmostEqual(predicted_groups[group_idx][word_idx][0], target_groups[group_idx][word_idx][0]) self.assertAlmostEqual(predicted_groups[group_idx][word_idx][1], target_groups[group_idx][word_idx][1]) def test_select_word_groups_pos03(self): segment_size = 2 - words = [(0.1, 0.5), (3.7, 4.0)] - target_groups = [[(0.1, 0.5)], [(3.7, 4.0)]] + words = [(0.1, 0.5, ''), (3.7, 4.0, '')] + target_groups = [[(0.1, 0.5, '')], [(3.7, 4.0, '')]] predicted_groups = select_word_groups(words, segment_size) self.assertIsInstance(predicted_groups, list) self.assertEqual(len(predicted_groups), len(target_groups)) @@ -100,14 
+100,14 @@ def test_select_word_groups_pos03(self): self.assertEqual(len(predicted_groups[group_idx]), len(target_groups[group_idx])) for word_idx in range(len(target_groups[group_idx])): self.assertIsInstance(predicted_groups[group_idx][word_idx], tuple) - self.assertEqual(len(predicted_groups[group_idx][word_idx]), 2) + self.assertEqual(len(predicted_groups[group_idx][word_idx]), 3) self.assertAlmostEqual(predicted_groups[group_idx][word_idx][0], target_groups[group_idx][word_idx][0]) self.assertAlmostEqual(predicted_groups[group_idx][word_idx][1], target_groups[group_idx][word_idx][1]) def test_select_word_groups_pos04(self): segment_size = 2 - words = [(0.1, 4.0)] - target_groups = [[(0.1, 4.0)]] + words = [(0.1, 4.0, '')] + target_groups = [[(0.1, 4.0, '')]] predicted_groups = select_word_groups(words, segment_size) self.assertIsInstance(predicted_groups, list) self.assertEqual(len(predicted_groups), len(target_groups)) @@ -116,7 +116,7 @@ def test_select_word_groups_pos04(self): self.assertEqual(len(predicted_groups[group_idx]), len(target_groups[group_idx])) for word_idx in range(len(target_groups[group_idx])): self.assertIsInstance(predicted_groups[group_idx][word_idx], tuple) - self.assertEqual(len(predicted_groups[group_idx][word_idx]), 2) + self.assertEqual(len(predicted_groups[group_idx][word_idx]), 3) self.assertAlmostEqual(predicted_groups[group_idx][word_idx][0], target_groups[group_idx][word_idx][0]) self.assertAlmostEqual(predicted_groups[group_idx][word_idx][1], target_groups[group_idx][word_idx][1]) @@ -273,98 +273,98 @@ def test_remove_oscillatory_hallucinations_pos02(self): self.assertEqual(res, true_text) def test_join_short_segments_to_long_ones_pos01(self): - source_segments = [(0.5, 2.5), (2.7, 3.92), (5.0, 7.5)] - true_segments = [(0.5, 2.5), (2.7, 3.92), (5.0, 7.5)] + source_segments = [(0.5, 2.5, ''), (2.7, 3.92, ''), (5.0, 7.5, '')] + true_segments = [(0.5, 2.5, ''), (2.7, 3.92, ''), (5.0, 7.5, '')] predicted_segments = 
join_short_segments_to_long_ones(source_segments, 1) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(true_segments)) for idx in range(len(true_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], true_segments[idx][0], delta=1e-6) self.assertAlmostEqual(predicted_segments[idx][1], true_segments[idx][1], delta=1e-6) def test_join_short_segments_to_long_ones_pos02(self): - source_segments = [(0.5, 1.1), (2.7, 3.92), (5.0, 7.5)] - true_segments = [(0.5, 1.1), (2.7, 3.92), (5.0, 7.5)] + source_segments = [(0.5, 1.1, ''), (2.7, 3.92, ''), (5.0, 7.5, '')] + true_segments = [(0.5, 1.1, ''), (2.7, 3.92, ''), (5.0, 7.5, '')] predicted_segments = join_short_segments_to_long_ones(source_segments, 1) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(true_segments)) for idx in range(len(true_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], true_segments[idx][0], delta=1e-6) self.assertAlmostEqual(predicted_segments[idx][1], true_segments[idx][1], delta=1e-6) def test_join_short_segments_to_long_ones_pos03(self): - source_segments = [(0.5, 1.1), (1.7, 3.92), (5.0, 7.5)] - true_segments = [(0.5, 3.92), (5.0, 7.5)] + source_segments = [(0.5, 1.1, ''), (1.7, 3.92, ''), (5.0, 7.5, '')] + true_segments = [(0.5, 3.92, ''), (5.0, 7.5, '')] predicted_segments = join_short_segments_to_long_ones(source_segments, 1) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(true_segments)) for idx in range(len(true_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + 
self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], true_segments[idx][0], delta=1e-6) self.assertAlmostEqual(predicted_segments[idx][1], true_segments[idx][1], delta=1e-6) def test_join_short_segments_to_long_ones_pos04(self): - source_segments = [(0.5, 2.5), (2.7, 2.92), (5.0, 7.5)] - true_segments = [(0.5, 2.92), (5.0, 7.5)] + source_segments = [(0.5, 2.5, ''), (2.7, 2.92, ''), (5.0, 7.5, '')] + true_segments = [(0.5, 2.92, ''), (5.0, 7.5, '')] predicted_segments = join_short_segments_to_long_ones(source_segments, 1) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(true_segments)) for idx in range(len(true_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], true_segments[idx][0], delta=1e-6) self.assertAlmostEqual(predicted_segments[idx][1], true_segments[idx][1], delta=1e-6) def test_join_short_segments_to_long_ones_pos05(self): - source_segments = [(0.5, 2.5), (2.7, 2.92), (3.0, 7.5)] - true_segments = [(0.5, 2.5), (2.7, 7.5)] + source_segments = [(0.5, 2.5, ''), (2.7, 2.92, ''), (3.0, 7.5, '')] + true_segments = [(0.5, 2.5, ''), (2.7, 7.5, '')] predicted_segments = join_short_segments_to_long_ones(source_segments, 1) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(true_segments)) for idx in range(len(true_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], true_segments[idx][0], delta=1e-6) self.assertAlmostEqual(predicted_segments[idx][1], true_segments[idx][1], delta=1e-6) def test_join_short_segments_to_long_ones_pos06(self): - source_segments = [(0.5, 2.5), (2.7, 3.92), (4.0, 4.3)] - 
true_segments = [(0.5, 2.5), (2.7, 4.3)] + source_segments = [(0.5, 2.5, ''), (2.7, 3.92, ''), (4.0, 4.3, '')] + true_segments = [(0.5, 2.5, ''), (2.7, 4.3, '')] predicted_segments = join_short_segments_to_long_ones(source_segments, 1) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(true_segments)) for idx in range(len(true_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], true_segments[idx][0], delta=1e-6) self.assertAlmostEqual(predicted_segments[idx][1], true_segments[idx][1], delta=1e-6) def test_join_short_segments_to_long_ones_pos07(self): - source_segments = [(0.5, 2.5), (2.7, 3.92), (5.0, 5.5)] - true_segments = [(0.5, 2.5), (2.7, 3.92), (5.0, 5.5)] + source_segments = [(0.5, 2.5, ''), (2.7, 3.92, ''), (5.0, 5.5, '')] + true_segments = [(0.5, 2.5, ''), (2.7, 3.92, ''), (5.0, 5.5, '')] predicted_segments = join_short_segments_to_long_ones(source_segments, 1) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(true_segments)) for idx in range(len(true_segments)): self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], true_segments[idx][0], delta=1e-6) self.assertAlmostEqual(predicted_segments[idx][1], true_segments[idx][1], delta=1e-6) def test_join_short_segments_to_long_ones_pos08(self): - source_segments = [(0.5, 0.6)] - true_segments = [(0.5, 0.6)] + source_segments = [(0.5, 0.6, '')] + true_segments = [(0.5, 0.6, '')] predicted_segments = join_short_segments_to_long_ones(source_segments, 1) self.assertIsInstance(predicted_segments, list) self.assertEqual(len(predicted_segments), len(true_segments)) for idx in range(len(true_segments)): 
self.assertIsInstance(predicted_segments[idx], tuple) - self.assertEqual(len(predicted_segments[idx]), 2) + self.assertEqual(len(predicted_segments[idx]), 3) self.assertAlmostEqual(predicted_segments[idx][0], true_segments[idx][0], delta=1e-6) self.assertAlmostEqual(predicted_segments[idx][1], true_segments[idx][1], delta=1e-6) diff --git a/tests/test_asr_en.py b/tests/test_asr_en.py index 8bd386c..3d0f406 100644 --- a/tests/test_asr_en.py +++ b/tests/test_asr_en.py @@ -77,17 +77,12 @@ def test_recognize_pos01(self): max_segment_size=5 ) true_words = ['neural', 'networks', 'are', 'good'] - self.assertIsInstance(res, list) + predicted_text = ' '.join([r.transcription for r in res]).lower() self.assertEqual(len(res), 1) - self.assertIsInstance(res[0], tuple) - self.assertEqual(len(res[0]), 3) - self.assertIsInstance(res[0][0], float) - self.assertIsInstance(res[0][1], float) - self.assertIsInstance(res[0][2], str) - self.assertLessEqual(0.0, res[0][0]) - self.assertLess(res[0][0], res[0][1]) - self.assertLessEqual(res[0][1], self.sound.shape[0] / TARGET_SAMPLING_FREQUENCY) - predicted_words = list(filter(lambda it: it.isalnum(), wordpunct_tokenize(res[0][2].lower()))) + self.assertLessEqual(0.0, res[0].start) + self.assertLess(res[0].start, res[0].end) + self.assertLessEqual(res[0].end, self.sound.shape[0] / TARGET_SAMPLING_FREQUENCY) + predicted_words = list(filter(lambda it: it.isalnum(), wordpunct_tokenize(predicted_text))) self.assertEqual(predicted_words, true_words) def test_recognize_pos02(self): diff --git a/tests/test_asr_ru.py b/tests/test_asr_ru.py index bd1ae37..5d62666 100644 --- a/tests/test_asr_ru.py +++ b/tests/test_asr_ru.py @@ -77,17 +77,12 @@ def test_recognize_pos01(self): max_segment_size=5 ) true_words = ['нейронные', 'сети', 'это', 'хорошо'] - self.assertIsInstance(res, list) + predicted_text = ' '.join([r.transcription for r in res]).lower() self.assertEqual(len(res), 1) - self.assertIsInstance(res[0], tuple) - self.assertEqual(len(res[0]), 3) 
- self.assertIsInstance(res[0][0], float) - self.assertIsInstance(res[0][1], float) - self.assertIsInstance(res[0][2], str) - self.assertLessEqual(0.0, res[0][0]) - self.assertLess(res[0][0], res[0][1]) - self.assertLessEqual(res[0][1], self.sound.shape[0] / TARGET_SAMPLING_FREQUENCY) - predicted_words = list(filter(lambda it: it.isalnum(), wordpunct_tokenize(res[0][2].lower()))) + self.assertLessEqual(0.0, res[0].start) + self.assertLess(res[0].start, res[0].end) + self.assertLessEqual(res[0].end, self.sound.shape[0] / TARGET_SAMPLING_FREQUENCY) + predicted_words = list(filter(lambda it: it.isalnum(), wordpunct_tokenize(predicted_text))) self.assertEqual(predicted_words, true_words) def test_recognize_pos02(self):