From 35d8d0ebe6c24269f48be50e501ff2bed94d9a35 Mon Sep 17 00:00:00 2001
From: chainyo
Date: Wed, 7 Jun 2023 18:54:58 +0000
Subject: [PATCH 1/2] implement word_timestamps

---
 wordcab_transcribe/services/asr_service.py    |   2 +-
 .../services/transcribe_service.py            | 309 ++++++++++++++++--
 wordcab_transcribe/utils.py                   |   8 +-
 3 files changed, 294 insertions(+), 25 deletions(-)

diff --git a/wordcab_transcribe/services/asr_service.py b/wordcab_transcribe/services/asr_service.py
index 16de11b..b445d91 100644
--- a/wordcab_transcribe/services/asr_service.py
+++ b/wordcab_transcribe/services/asr_service.py
@@ -201,7 +201,7 @@ async def process_input(
             "dual_channel": dual_channel,
             "source_lang": source_lang,
             "timestamps_format": timestamps_format,
-            "word_timestamps": False,  # TODO: Implement word timestamps, False for now
+            "word_timestamps": word_timestamps,
             "post_processed": False,
             "transcription_result": None,
             "transcription_done": asyncio.Event(),
diff --git a/wordcab_transcribe/services/transcribe_service.py b/wordcab_transcribe/services/transcribe_service.py
index 3bc41ea..07918a2 100644
--- a/wordcab_transcribe/services/transcribe_service.py
+++ b/wordcab_transcribe/services/transcribe_service.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Transcribe Service for audio files."""

+import itertools
 import math
 import os
 import zlib
@@ -23,6 +24,7 @@
 import torch
 import torch.nn.functional as F  # noqa N812
 import torchaudio
+from ctranslate2 import StorageView
 from ctranslate2.models import WhisperGenerationResult
 from faster_whisper import WhisperModel
 from faster_whisper.tokenizer import Tokenizer
@@ -259,6 +261,7 @@ def __init__(
         self.hop_length = 160
         self.n_samples = self.sample_rate * self.chunk_size
+        self.tokens_per_second = self.sample_rate // self.hop_length

         assets_dir = Path(__file__).parent.parent / "assets" / "mel_filters.npz"
         with np.load(str(assets_dir)) as f:
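The new tokens_per_second attribute fixes the time resolution used for word
timestamps: at a 16 kHz sample rate with a 160-sample hop length, the mel
spectrogram produces one frame every 10 ms. A quick sanity check of the
arithmetic, using only the constants set in __init__ above:

    sample_rate = 16000  # Hz, as set in __init__
    hop_length = 160  # samples between successive mel frames

    tokens_per_second = sample_rate // hop_length
    assert tokens_per_second == 100  # one alignment frame every 10 ms

    # A frame index from the aligner maps back to seconds as:
    seconds = 42 / tokens_per_second  # frame 42 -> 0.42 s

This 10 ms granularity is also why rounding word timestamps to two decimals,
as done further down, loses no information.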
@@ -267,11 +270,14 @@
         self.compression_ratio_threshold = 2.4
         self.log_probability_threshold = -0.8

+        self.prepend_punctuation = "\"'“¿([{-"
+        self.append_punctuation = "\"'.。,,!!??::”)]}、"
+
     def __call__(
         self,
         audio: Union[str, torch.Tensor],
         source_lang: str,
-        **kwargs: Optional[dict],
+        word_timestamps: bool = False,
     ) -> List[dict]:
         """
         Run inference with the transcribe model.
@@ -279,10 +285,10 @@
         Args:
             audio (Union[str, torch.Tensor]): Audio file to transcribe.
             source_lang (str): Language of the audio file.
-            kwargs (Any): Additional arguments to pass to TranscribeService.
+            word_timestamps (bool): Whether to return word timestamps.

         Returns:
-            List[dict]: List of segments with the following keys: "start", "end", "text", "confidence".
+            List[dict]: List of transcribed segments.
         """
         if self.tokenizer.language_code != source_lang:
             self.tokenizer = Tokenizer(
@@ -292,18 +298,24 @@
             language=source_lang,
         )

-        outputs = self.pipeline(audio, batch_size=self._batch_size)
+        outputs = self.pipeline(audio, self._batch_size, word_timestamps)

         return outputs

     @time_and_tell
-    def pipeline(self, audio: Union[str, torch.Tensor], batch_size: int) -> List[dict]:
+    def pipeline(
+        self,
+        audio: Union[str, torch.Tensor],
+        batch_size: int,
+        word_timestamps: bool = False,
+    ) -> List[dict]:
         """
         Transcription pipeline for audio chunks in batches.

         Args:
             audio (Union[str, torch.Tensor]): Audio file to transcribe.
             batch_size (int): Batch size to use for inference.
+            word_timestamps (bool): Whether to return word timestamps.

         Returns:
             List[dict]: List of segments with the following keys: "start", "end", "text".
@@ -348,6 +360,7 @@ def pipeline(self, audio: Union[str, torch.Tensor], batch_size: int) -> List[dic
                 sampling_top_k=sampling_top_k,
                 temperature=temperature,
                 last_chance_inference=False if stop_temperature != 1.0 else True,
+                word_timestamps=word_timestamps,
             )

             for output_index, output in enumerate(batch_outputs):
@@ -380,7 +393,7 @@ def pipeline(self, audio: Union[str, torch.Tensor], batch_size: int) -> List[dic
             else:
                 break  # All segments have been processed successfully.

-        outputs = [item for sublist in _outputs for item in sublist]
+        outputs = list(itertools.chain.from_iterable(_outputs))

         return outputs

@@ -403,6 +416,7 @@ def _generate_segment_batched(
         sampling_top_k: int = 1,
         temperature: float = 1.0,
         without_timestamps: bool = False,
+        word_timestamps: bool = False,
     ) -> List[dict]:
         """
         Use the ctranslate2 Whisper model to generate text from audio chunks.
@@ -422,9 +436,10 @@ def _generate_segment_batched(
             sampling_top_k (int): Sampling top k to use for sampling.
             temperature (float): Temperature to use for sampling.
             without_timestamps (bool): Whether to remove timestamps from the generated text.
+            word_timestamps (bool): Whether to use word timestamps instead of character timestamps.

         Returns:
-            List[dict]: List of segments with the following keys: "start", "end", "text", "confidence".
+            List[dict]: List of results, one per audio chunk, with the keys "segments" and "need_fallback".
         """
         if "TOKENIZERS_PARALLELISM" not in os.environ:
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -447,11 +462,8 @@ def _generate_segment_batched(
             prefix=prefix,
         )

-        features = get_ctranslate2_storage(features)
+        features = self._encode_batch(features, word_timestamps=word_timestamps)

-        # TODO: Could be better to get the results as np.ndarray/torch.tensor and not as a class for speed
-        # Atm, we need to extract the results as a Python list which is slow because we get this results:
-        # https://opennmt.net/CTranslate2/python/ctranslate2.models.WhisperGenerationResult.html
         # TODO: We access the inherited ctranslate2 model for generation here. This is not ideal.
         result: WhisperGenerationResult = self.model.model.generate(
             features,
@@ -550,23 +562,61 @@ def _generate_segment_batched(
                     )
                 )

-            # TODO: Implement word timestamps
-
             outputs.append(
-                {
-                    "segments": self._decode_batch(current_segments),
-                    "need_fallback": len(current_segments) == 0,
-                }
+                dict(
+                    segments=self._decode_batch(current_segments, tokenizer),
+                    need_fallback=len(current_segments) == 0,
+                )
             )

+        if word_timestamps:
+            segment_sizes = [
+                int(segment_duration / (self.hop_length / self.sample_rate))
+                for segment_duration in segment_durations
+            ]
+            self._add_word_timestamps(
+                outputs,
+                tokenizer,
+                features,
+                segment_sizes,
+                time_offsets,
+                self.prepend_punctuation,
+                self.append_punctuation,
+            )
+
         return outputs

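The aligner consumes segment sizes in mel frames rather than seconds, so the
comprehension above divides each duration by hop_length / sample_rate (0.01 s
per frame). A standalone sketch of that conversion with the constants from
__init__ (the durations are made-up values):

    hop_length = 160
    sample_rate = 16000
    frame_duration = hop_length / sample_rate  # 0.01 s per mel frame

    segment_durations = [30.0, 12.5]  # seconds, illustrative values
    segment_sizes = [int(d / frame_duration) for d in segment_durations]
    assert segment_sizes == [3000, 1250]  # frame counts handed to the aligner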
-    def _decode_batch(self, outputs: List[dict]) -> List[dict]:
+    def _encode_batch(
+        self, features: torch.Tensor, word_timestamps: bool
+    ) -> StorageView:
+        """Encode the features using the model encoder.
+
+        We encode the features only if word timestamps are enabled.
+        Otherwise, we just return the features formatted as a StorageView.
+
+        Args:
+            features (torch.Tensor): Features to encode.
+            word_timestamps (bool): Whether to encode the features or not.
+
+        Returns:
+            StorageView: Encoded features.
+        """
+        features = get_ctranslate2_storage(features)
+
+        if word_timestamps:
+            # We encode the features to re-use the encoder output later.
+            features = self.model.model.encode(features, to_cpu=False)
+
+        return features
+
+    def _decode_batch(self, outputs: List[dict], tokenizer: Tokenizer) -> List[dict]:
         """
         Extract the token ids from the sequences ids and decode them using the tokenizer.

         Args:
             outputs (List[dict]): List of outputs from the model.
+            tokenizer (Tokenizer): Tokenizer to use to decode the token ids.

         Returns:
             List[str]: List of decoded texts.
@@ -575,18 +625,237 @@ def _decode_batch(self, outputs: List[dict]) -> List[dict]:
             return outputs

         tokens_to_decode = [
-            [token for token in out["tokens"] if token < self.tokenizer.eot]
+            [token for token in out["tokens"] if token < tokenizer.eot]
             for out in outputs
         ]
         # TODO: We call the inherited tokenizer here, because faster_whisper tokenizer
         # doesn't have the decode_batch method. We should fix this in the future.
-        decoded_tokens = self.tokenizer.tokenizer.decode_batch(tokens_to_decode)
+        decoded_tokens = tokenizer.tokenizer.decode_batch(tokens_to_decode)

         for out, text in zip(outputs, decoded_tokens):
             out["text"] = text

         return outputs

+    @time_and_tell
+    def _add_word_timestamps(
+        self,
+        outputs: List[dict],
+        tokenizer: Tokenizer,
+        encoder_output: StorageView,
+        segment_sizes: List[int],
+        time_offsets: List[float],
+        prepend_punctuation: str,
+        append_punctuation: str,
+    ) -> None:
+        """
+        Add word timestamps to the segments.
+
+        Args:
+            outputs (List[dict]): List of outputs from the model.
+            tokenizer (Tokenizer): Tokenizer to use to decode the token ids.
+            encoder_output (StorageView): Encoder output.
+            segment_sizes (List[int]): List of segment sizes.
+            time_offsets (List[float]): List of time offsets.
+            prepend_punctuation (str): Punctuation to prepend to the text.
+            append_punctuation (str): Punctuation to append to the text.
+        """
+        text_tokens_per_output = []
+        for out in outputs:
+            text_tokens_per_segment = [
+                [token for token in segment["tokens"] if token < tokenizer.eot]
+                for segment in out["segments"]
+            ]
+            text_tokens_per_output.append(text_tokens_per_segment)
+
+        alignments = self._find_alignment(
+            encoder_output, text_tokens_per_output, tokenizer, segment_sizes
+        )
+        self._merge_punctuation(alignments, prepend_punctuation, append_punctuation)
+
+        for out, alignment, text_tokens_per_segment, time_offset in zip(
+            outputs, alignments, text_tokens_per_output, time_offsets
+        ):
+            if out["need_fallback"]:
+                continue
+
+            word_index = 0
+
+            for segment_idx, segment in enumerate(out["segments"]):
+                saved_tokens = 0
+                words = []
+
+                if isinstance(alignment, int):
+                    alignment = [alignment]
+
+                while word_index < len(alignment) and saved_tokens < len(
+                    text_tokens_per_segment[segment_idx]
+                ):
+                    timing = alignment[word_index]
+
+                    if timing["word"]:
+                        words.append(
+                            dict(
+                                word=timing["word"],
+                                start=round(time_offset + timing["start"], 2),
+                                end=round(time_offset + timing["end"], 2),
+                                probability=timing["probability"],
+                            )
+                        )
+
+                    saved_tokens += len(timing["tokens"])
+                    word_index += 1
+
+                if len(words) > 0:
+                    segment["start"] = words[0]["start"]
+                    segment["end"] = words[-1]["end"]
+
+                segment["words"] = words
+
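After _add_word_timestamps runs, every segment of a successful output carries
a "words" list, and the segment boundaries are snapped to the first and last
word. The resulting shape, shown here with made-up text and timings
(format_segments in utils.py later renames "probability" to a rounded
"score"):

    segment = {
        "text": " Hello world.",
        "start": 0.64,  # snapped to the first word's start
        "end": 2.18,  # snapped to the last word's end
        "words": [
            {"word": " Hello", "start": 0.64, "end": 1.02, "probability": 0.97},
            {"word": " world.", "start": 1.10, "end": 2.18, "probability": 0.93},
        ],
    }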
+    def _find_alignment(
+        self,
+        encoder_output: StorageView,
+        text_tokens_per_output: List[List[List[int]]],
+        tokenizer: Tokenizer,
+        segment_sizes: List[int],
+        median_filter_width: int = 7,
+    ) -> List[List[dict]]:
+        """
+        Find the alignment between the encoder output and the text tokens in a batch.
+
+        Args:
+            encoder_output (StorageView): Encoder output.
+            text_tokens_per_output (List[List[List[int]]]): Per-output list of per-segment token lists.
+            tokenizer (Tokenizer): Tokenizer to use to decode the token ids.
+            segment_sizes (List[int]): List of segment sizes.
+            median_filter_width (int): Width of the median filter to apply on the alignment.
+
+        Returns:
+            List[List[dict]]: List of alignments per output.
+        """
+        text_tokens_per_output = [
+            list(itertools.chain.from_iterable(list_of_tokens))
+            for list_of_tokens in text_tokens_per_output
+        ]
+
+        results = self.model.model.align(
+            encoder_output,
+            tokenizer.sot_sequence,
+            text_tokens_per_output,
+            segment_sizes,
+            median_filter_width=median_filter_width,
+        )
+
+        final_alignments = []
+        for res, text_tokens in zip(results, text_tokens_per_output):
+            words, word_tokens = tokenizer.split_to_word_tokens(
+                text_tokens + [tokenizer.eot]
+            )
+            word_boundaries = np.pad(
+                np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)
+            )
+            if len(word_boundaries) <= 1:
+                final_alignments.append([])
+                continue
+
+            alignments = res.alignments
+            text_indices = np.array([pair[0] for pair in alignments])
+            time_indices = np.array([pair[1] for pair in alignments])
+
+            jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(
+                bool
+            )
+            jump_times = time_indices[jumps] / self.tokens_per_second
+            start_times = jump_times[word_boundaries[:-1]]
+            end_times = jump_times[word_boundaries[1:]]
+
+            text_token_probs = res.text_token_probs
+            word_probabilities = [
+                np.mean(text_token_probs[i:j])
+                for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+            ]
+
+            word_durations = end_times - start_times
+            word_durations = word_durations[word_durations.nonzero()]
+
+            if len(word_durations) > 0:
+                median_duration = np.median(word_durations)
+                max_duration = median_duration * 2
+
+                if len(word_durations) >= 2 and word_durations[1] > max_duration:
+                    boundary = max(end_times[2] / 2, end_times[2] - max_duration)
+                    end_times[0] = start_times[1] = boundary
+
+                if (
+                    len(word_durations) >= 1
+                    and end_times[0] - start_times[0] > max_duration
+                ):
+                    start_times[0] = max(0, end_times[0] - max_duration)
+
+            final_alignments.append(
+                [
+                    dict(
+                        word=word,
+                        tokens=tokens,
+                        start=start,
+                        end=end,
+                        probability=probability,
+                    )
+                    for word, tokens, start, end, probability in zip(
+                        words, word_tokens, start_times, end_times, word_probabilities
+                    )
+                ]
+            )
+
+        return final_alignments
+
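The timing extraction above relies on the DTW path returned by ctranslate2's
align being monotonic: a token starts at the first frame where the path
advances onto it, and the np.diff/np.pad trick marks exactly those jumps. A
self-contained illustration with an invented path:

    import numpy as np

    # Illustrative DTW path from align(): (text_index, time_index) pairs.
    alignments = [(0, 3), (0, 4), (1, 7), (2, 7), (2, 9), (3, 12)]
    tokens_per_second = 100  # sample_rate // hop_length, as in __init__

    text_indices = np.array([pair[0] for pair in alignments])
    time_indices = np.array([pair[1] for pair in alignments])

    # True wherever the path moves on to a new text token (first entry kept).
    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / tokens_per_second

    print(jump_times)  # [0.03 0.07 0.07 0.12] -> start time of each token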
+    def _merge_punctuation(
+        self, alignments: List[List[dict]], prepended: str, appended: str
+    ) -> None:
+        """
+        Fix punctuation boundaries for the alignments.
+
+        Args:
+            alignments (List[List[dict]]): List of alignments.
+            prepended (str): Prepended punctuation.
+            appended (str): Appended punctuation.
+        """
+        for alignment in alignments:
+            # merge prepended punctuations
+            i = len(alignment) - 2
+            j = len(alignment) - 1
+            while i >= 0:
+                previous = alignment[i]
+                following = alignment[j]
+                if (
+                    previous["word"].startswith(" ")
+                    and previous["word"].strip() in prepended
+                ):
+                    # prepend it to the following word
+                    following["word"] = previous["word"] + following["word"]
+                    following["tokens"] = previous["tokens"] + following["tokens"]
+                    previous["word"] = ""
+                    previous["tokens"] = []
+                else:
+                    j = i
+                i -= 1
+
+            # merge appended punctuations
+            i = 0
+            j = 1
+            while j < len(alignment):
+                previous = alignment[i]
+                following = alignment[j]
+                if not previous["word"].endswith(" ") and following["word"] in appended:
+                    # append it to the previous word
+                    previous["word"] = previous["word"] + following["word"]
+                    previous["tokens"] = previous["tokens"] + following["tokens"]
+                    following["word"] = ""
+                    following["tokens"] = []
+                else:
+                    i = j
+                j += 1
+
     def _get_quality_metrics(
         self, tokens: List[int], text: str, score: float, length_penalty: float
     ) -> Tuple[float, float]:
diff --git a/wordcab_transcribe/utils.py b/wordcab_transcribe/utils.py
index 1551b5a..5ffe66c 100644
--- a/wordcab_transcribe/utils.py
+++ b/wordcab_transcribe/utils.py
@@ -440,10 +440,10 @@ def format_segments(
     else:
         _words = [
             {
-                "word": word.word.strip(),
-                "start": word.start,
-                "end": word.end,
-                "score": round(word.probability, 3),
+                "word": word["word"].strip(),
+                "start": word["start"],
+                "end": word["end"],
+                "score": round(word["probability"], 2),
             }
             for word in segment["words"]
         ]

From 6e2a59ae7b6433fa71e42adb4d7fa06dc42ac773 Mon Sep 17 00:00:00 2001
From: chainyo
Date: Wed, 7 Jun 2023 18:55:15 +0000
Subject: [PATCH 2/2] round to 2

---
 wordcab_transcribe/services/align_service.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/wordcab_transcribe/services/align_service.py b/wordcab_transcribe/services/align_service.py
index ad6da92..b57d0c6 100644
--- a/wordcab_transcribe/services/align_service.py
+++ b/wordcab_transcribe/services/align_service.py
@@ -419,9 +419,9 @@ def align(
             start, end, score = None, None, None
             if cdx in segment["clean_cdx"]:
                 char_seg = char_segments[segment["clean_cdx"].index(cdx)]
-                start = round(char_seg.start * ratio + t1, 3)
-                end = round(char_seg.end * ratio + t1, 3)
-                score = round(char_seg.score, 3)
+                start = round(char_seg.start * ratio + t1, 2)
+                end = round(char_seg.end * ratio + t1, 2)
+                score = round(char_seg.score, 2)

             char_segments_arr.append(
                 {
@@ -468,7 +468,7 @@ def align(
             word_chars = word_chars[word_chars["char"] != " "]
             word_start = word_chars["start"].min()
             word_end = word_chars["end"].max()
-            word_score = round(word_chars["score"].mean(), 3)
+            word_score = round(word_chars["score"].mean(), 2)

             # -1 indicates unalignable
             word_segment = {"word": word_text}
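For reference, the effect of the _merge_punctuation pass from PATCH 1,
illustrated on a made-up alignment (token ids and timings are illustrative,
not real tokenizer output):

    alignment = [
        {"word": " ¿", "tokens": [30], "start": 0.5, "end": 0.5, "probability": 0.9},
        {"word": "Qué", "tokens": [40], "start": 0.5, "end": 0.8, "probability": 0.9},
        {"word": " tal", "tokens": [41], "start": 0.8, "end": 1.1, "probability": 0.9},
        {"word": "?", "tokens": [31], "start": 1.1, "end": 1.1, "probability": 0.9},
    ]

    # After _merge_punctuation([alignment], prepended, appended) with the
    # punctuation sets defined in __init__, the punctuation entries are
    # emptied and folded into their neighbours:
    #   [{"word": "", "tokens": []},
    #    {"word": " ¿Qué", "tokens": [30, 40]},
    #    {"word": " tal?", "tokens": [41, 31]},
    #    {"word": "", "tokens": []}]
    # _add_word_timestamps then skips the emptied entries, since it only keeps
    # timings whose "word" is non-empty; the merged tokens travel with the
    # neighbouring word, so punctuation never gets its own timestamp.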