From 62e33b746080b597f7b36140419f33a726e47316 Mon Sep 17 00:00:00 2001 From: Aleks Date: Sun, 26 May 2024 17:31:54 -0400 Subject: [PATCH 1/3] Updated TensorRT-LLM version to latest, enabled dual-channel for tensorrt-llm backend --- .../services/transcribe_service.py | 168 ++++++++++++------ 1 file changed, 111 insertions(+), 57 deletions(-) diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py index 554af54..63dd5fe 100644 --- a/src/wordcab_transcribe/services/transcribe_service.py +++ b/src/wordcab_transcribe/services/transcribe_service.py @@ -253,24 +253,18 @@ def __call__( outputs = TranscriptionOutput(segments=_outputs) else: - outputs = [] - for audio_index, audio_file in enumerate(audio): - outputs.append( - self.multi_channel( - audio_file, - source_lang=source_lang, - speaker_id=audio_index, - suppress_blank=suppress_blank, - word_timestamps=word_timestamps, - internal_vad=internal_vad, - repetition_penalty=repetition_penalty, - compression_ratio_threshold=compression_ratio_threshold, - log_prob_threshold=log_prob_threshold, - no_speech_threshold=no_speech_threshold, - prompt=prompt, - ) - ) - + outputs = self.multi_channel( + audio, + source_lang=source_lang, + suppress_blank=suppress_blank, + word_timestamps=word_timestamps, + internal_vad=internal_vad, + repetition_penalty=repetition_penalty, + compression_ratio_threshold=compression_ratio_threshold, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + prompt=prompt, + ) return outputs async def async_live_transcribe( @@ -323,9 +317,8 @@ def live_transcribe( def multi_channel( self, - audio: Union[str, torch.Tensor, TensorShare], + audio_list: List[Union[str, torch.Tensor, TensorShare]], source_lang: str, - speaker_id: int, suppress_blank: bool = False, word_timestamps: bool = True, internal_vad: bool = True, @@ -340,9 +333,8 @@ def multi_channel( Transcribe an audio file using the faster-whisper original pipeline. Args: - audio (Union[str, torch.Tensor, TensorShare]): Audio file path or loaded audio. + audio_list (List[Union[str, torch.Tensor, TensorShare]]): List of audio file paths or audio tensors. source_lang (str): Language of the audio file. - speaker_id (int): Speaker ID used in the diarization. suppress_blank (bool): Whether to suppress blank at the beginning of the sampling. word_timestamps (bool): @@ -367,43 +359,105 @@ def multi_channel( Returns: MultiChannelTranscriptionOutput: Multi-channel transcription segments in a list. """ - if isinstance(audio, torch.Tensor): - _audio = audio.numpy() - elif isinstance(audio, TensorShare): - ts = audio.to_tensors(backend=Backend.NUMPY) - _audio = ts["audio"] + outputs = [] - final_segments = [] + if self.model_engine == "faster-whisper": + for speaker_id, audio in enumerate(audio_list): + final_segments = [] + if isinstance(audio, torch.Tensor): + _audio = audio.numpy() + elif isinstance(audio, TensorShare): + ts = audio.to_tensors(backend=Backend.NUMPY) + _audio = ts["audio"] - segments, _ = self.model.transcribe( - _audio, - language=source_lang, - initial_prompt=prompt, - repetition_penalty=repetition_penalty, - compression_ratio_threshold=compression_ratio_threshold, - log_prob_threshold=log_prob_threshold, - no_speech_threshold=no_speech_threshold, - condition_on_previous_text=condition_on_previous_text, - suppress_blank=suppress_blank, - word_timestamps=word_timestamps, - vad_filter=internal_vad, - vad_parameters={ - "threshold": 0.5, - "min_speech_duration_ms": 250, - "min_silence_duration_ms": 100, - "speech_pad_ms": 30, - "window_size_samples": 512, - }, - ) + segments, _ = self.model.transcribe( + _audio, + language=source_lang, + initial_prompt=prompt, + repetition_penalty=repetition_penalty, + compression_ratio_threshold=compression_ratio_threshold, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + condition_on_previous_text=condition_on_previous_text, + suppress_blank=suppress_blank, + word_timestamps=word_timestamps, + vad_filter=internal_vad, + vad_parameters={ + "threshold": 0.5, + "min_speech_duration_ms": 250, + "min_silence_duration_ms": 100, + "speech_pad_ms": 30, + "window_size_samples": 512, + }, + ) - for segment in segments: - _segment = MultiChannelSegment( - start=segment.start, - end=segment.end, - text=segment.text, - words=[Word(**word._asdict()) for word in segment.words], - speaker=speaker_id, + for segment in segments: + _segment = MultiChannelSegment( + start=segment.start, + end=segment.end, + text=segment.text, + words=[Word(**word._asdict()) for word in segment.words], + speaker=speaker_id, + ) + final_segments.append(_segment) + + outputs.append(final_segments) + elif self.model_engine == "tensorrt-llm": + audio_channels = [] + speaker_ids = [] + + for speaker_id, audio in enumerate(audio_list): + if isinstance(audio, torch.Tensor): + audio = audio.numpy() + elif isinstance(audio, TensorShare): + ts = audio.to_tensors(backend=Backend.NUMPY) + audio = ts["audio"] + audio_channels.append(audio) + speaker_ids.append(speaker_id) + + channels_len = len(audio_channels) + segments_list = self.model.transcribe( + audio_data=audio_channels, + lang_codes=[source_lang] * channels_len, + tasks=["transcribe"] * channels_len, + initial_prompts=[prompt] * channels_len, + batch_size=channels_len, + use_vad=internal_vad, + generate_kwargs={"num_beams": 1}, ) - final_segments.append(_segment) - return MultiChannelTranscriptionOutput(segments=final_segments) + for speaker_id, segments in enumerate(segments_list): + final_segments = [] + + for segment in segments: + segment["words"] = segment.pop("word_timestamps") + for word in segment["words"]: + word["word"] = f" {word['word']}" + word["start"] = round(word["start"], 2) + word["end"] = round(word["end"], 2) + segment["text"] = segment["text"].strip() + + segment["start"] = round(segment.pop("start_time"), 2) + segment["end"] = round(segment.pop("end_time"), 2) + extra = { + "seek": 1, + "id": 1, + "tokens": [1], + "temperature": 0.0, + "avg_logprob": 0.0, + "compression_ratio": 0.0, + "no_speech_prob": 0.0, + } + _segment = Segment(**{**segment, **extra}) + _segment = MultiChannelSegment( + start=_segment.start, + end=_segment.end, + text=_segment.text, + words=[Word(**word) for word in _segment.words], + speaker=speaker_id, + ) + final_segments.append(_segment) + + outputs.append(MultiChannelTranscriptionOutput(segments=final_segments)) + + return outputs From b2f0d3b0af16082ef811537eec4062e40c9333d8 Mon Sep 17 00:00:00 2001 From: Aleks Date: Sun, 26 May 2024 17:31:54 -0400 Subject: [PATCH 2/3] Updated TensorRT-LLM version to latest, enabled dual-channel for tensorrt-llm backend --- src/wordcab_transcribe/services/transcribe_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py index 63dd5fe..8fd20a6 100644 --- a/src/wordcab_transcribe/services/transcribe_service.py +++ b/src/wordcab_transcribe/services/transcribe_service.py @@ -436,7 +436,6 @@ def multi_channel( word["start"] = round(word["start"], 2) word["end"] = round(word["end"], 2) segment["text"] = segment["text"].strip() - segment["start"] = round(segment.pop("start_time"), 2) segment["end"] = round(segment.pop("end_time"), 2) extra = { From b022cbabcd153a2dd26a25cf861cbb1acc112d4a Mon Sep 17 00:00:00 2001 From: Aleks Date: Sun, 26 May 2024 17:28:28 -0400 Subject: [PATCH 3/3] Updated TensorRT-LLM version to latest, enabled dual-channel for tensorrt-llm backend --- .gitignore | 1 + Dockerfile | 32 +++++--- pre_requirements.txt | 13 ++++ requirements.txt | 27 +++++++ .../tensorrt_llm/engine_builder/build.py | 73 ++++++++++++------- .../services/asr_service.py | 2 + .../services/post_processing_service.py | 1 + .../services/transcribe_service.py | 4 +- 8 files changed, 112 insertions(+), 41 deletions(-) create mode 100644 pre_requirements.txt create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore index 456b9e2..faf8fd1 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ .coverage.* .DS_Store .env_dev +.env_* .nox/ .pytest_cache/ .python-version diff --git a/Dockerfile b/Dockerfile index 348c278..e305757 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,12 +1,12 @@ FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04 AS runtime ENV NVIDIA_DRIVER_CAPABILITIES ${NVIDIA_DRIVER_CAPABILITIES:-compute,utility} - ENV PYTHONUNBUFFERED=1 - ENV DEBIAN_FRONTEND=noninteractive +ENV MPI4PY_VERSION="3.1.5" +ENV RELEASE_URL="https://github.com/mpi4py/mpi4py/archive/refs/tags/${MPI4PY_VERSION}.tar.gz" -RUN apt update && apt install -y \ +RUN apt-get update && apt-get install -y --no-install-recommends \ libsndfile1 \ software-properties-common \ ffmpeg \ @@ -28,9 +28,6 @@ RUN apt update && apt install -y \ python3-dev \ liblzma-dev \ libsqlite3-dev \ - && rm -rf /var/lib/apt/lists/* - -RUN apt update && apt install -y \ libtiff-tools=4.3.0-6ubuntu0.8 \ libtiff5=4.3.0-6ubuntu0.8 \ libgnutls30=3.7.3-4ubuntu1.5 \ @@ -42,7 +39,8 @@ RUN apt update && apt install -y \ login=1:4.8.1-2ubuntu2.2 \ passwd=1:4.8.1-2ubuntu2.2 \ uidmap=1:4.8.1-2ubuntu2.2 \ - binutils=2.38-4ubuntu2.6 + binutils=2.38-4ubuntu2.6 \ + && rm -rf /var/lib/apt/lists/* RUN cd /tmp && \ wget https://www.python.org/ftp/python/3.10.12/Python-3.10.12.tgz && \ @@ -57,9 +55,6 @@ RUN cd /tmp && \ RUN export CUDNN_PATH=$(python -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))') && \ echo 'export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:'${CUDNN_PATH} >> ~/.bashrc -ENV MPI4PY_VERSION="3.1.5" -ENV RELEASE_URL="https://github.com/mpi4py/mpi4py/archive/refs/tags/${MPI4PY_VERSION}.tar.gz" - RUN curl -L ${RELEASE_URL} | tar -zx -C /tmp \ && sed -i 's/>= 40\\.9\\.0/>= 40.9.0, < 69/g' /tmp/mpi4py-${MPI4PY_VERSION}/pyproject.toml \ && pip install /tmp/mpi4py-${MPI4PY_VERSION} \ @@ -67,10 +62,23 @@ RUN curl -L ${RELEASE_URL} | tar -zx -C /tmp \ RUN python -m pip install pip --upgrade +COPY pre_requirements.txt . +COPY requirements.txt . + +RUN pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com -r pre_requirements.txt -r requirements.txt + WORKDIR /app -COPY . . +RUN git clone https://github.com/NVIDIA/NeMo.git ./nemo_local && \ + cd ./nemo_local && \ + git config --global user.email "you@example.com" && \ + git config --global user.name "Your Name" && \ + git fetch origin pull/9114/head:pr9114 && \ + git merge pr9114 && \ + pip install -e ".[asr]" + +ENV PYTHONPATH="/app/src" -RUN pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com .[runtime] +COPY . . CMD ["uvicorn", "--host=0.0.0.0", "--port=5001", "src.wordcab_transcribe.main:app"] diff --git a/pre_requirements.txt b/pre_requirements.txt new file mode 100644 index 0000000..ba8fef7 --- /dev/null +++ b/pre_requirements.txt @@ -0,0 +1,13 @@ +argon2-cffi==23.1.0 +fastapi==0.110.0 +python-jose[cryptography]==3.3.0 +python-multipart==0.0.9 +shortuuid==1.0.13 +svix==1.21.0 +uvicorn==0.29.0 +websockets==12.0 +tensorrt_llm==0.11.0.dev2024052100 +Cython==3.0.10 +youtokentome @ git+https://github.com/gburlet/YouTokenToMe.git@dependencies +deepmultilingualpunctuation==1.0.1 +pyannote.audio==3.2.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2d4077e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,27 @@ +aiohttp==3.9.3 +aiofiles==23.2.1 +boto3 +faster-whisper @ https://github.com/SYSTRAN/faster-whisper/archive/refs/heads/master.tar.gz +ffmpeg-python==0.2.0 +transformers==4.38.2 +librosa==0.10.1 +loguru==0.7.2 +nltk==3.8.1 +numpy==1.26.4 +onnxruntime==1.17.1 +pandas==2.2.1 +pydantic==2.6.4 +python-dotenv==1.0.1 +tensorshare==0.1.1 +torch==2.2.2 +torchaudio==2.2.2 +wget==3.2.0 +yt-dlp==2024.3.10 +tiktoken==0.6.0 +datasets==2.18.0 +kaldialign==0.9.0 +openai-whisper==v20231117 +soundfile==0.12.1 +safetensors==0.4.2 +janus==1.0.0 +backports.lzma==0.0.14 \ No newline at end of file diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/build.py b/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/build.py index 1106a2d..f083fbf 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/build.py +++ b/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/build.py @@ -22,10 +22,10 @@ from tensorrt_llm.builder import Builder from tensorrt_llm.functional import LayerNormPositionType, LayerNormType from tensorrt_llm.logger import logger -from tensorrt_llm.models import quantize_model from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode +from tensorrt_llm.quantization.quantize_by_modelopt import quantize_model from weight import load_decoder_weight, load_encoder_weight MODEL_ENCODER_NAME = "whisper_encoder" @@ -316,32 +316,49 @@ def build_decoder(model, args): ) tensorrt_llm_whisper_decoder = tensorrt_llm.models.DecoderModel( - num_layers=model_metadata["n_text_layer"], - num_heads=model_metadata["n_text_head"], - hidden_size=model_metadata["n_text_state"], - ffn_hidden_size=4 * model_metadata["n_text_state"], - encoder_hidden_size=model_metadata["n_text_state"], - encoder_num_heads=model_metadata["n_text_head"], - vocab_size=model_metadata["n_vocab"], - head_size=model_metadata["n_text_state"] // model_metadata["n_text_head"], - max_position_embeddings=model_metadata["n_text_ctx"], - has_position_embedding=True, - relative_attention=False, - max_distance=0, - num_buckets=0, - has_embedding_layernorm=False, - has_embedding_scale=False, - q_scaling=1.0, - has_attention_qkvo_bias=True, - has_mlp_bias=True, - has_model_final_layernorm=True, - layernorm_eps=1e-5, - layernorm_position=LayerNormPositionType.pre_layernorm, - layernorm_type=LayerNormType.LayerNorm, - hidden_act="gelu", - rescale_before_lm_head=False, - dtype=str_dtype_to_trt(args.dtype), - logits_dtype=str_dtype_to_trt(args.dtype), + tensorrt_llm.models.modeling_utils.PretrainedConfig( + architecture="whisper", + dtype=str_dtype_to_trt(args.dtype), + logits_dtype=str_dtype_to_trt(args.dtype), + vocab_size=model_metadata["n_vocab"], + max_position_embeddings=model_metadata["n_text_ctx"], + hidden_size=model_metadata["n_text_state"], + num_hidden_layers=model_metadata["n_text_layer"], + num_attention_heads=model_metadata["n_text_head"], + num_key_value_heads=model_metadata["n_text_head"], + hidden_act="gelu", + intermediate_size=4 * model_metadata["n_text_state"], + norm_epsilon=1e-5, + position_embedding_type="learned_absolute", + world_size=1, + tp_size=1, + pp_size=1, + gpus_per_node=1, + quantization=tensorrt_llm.models.modeling_utils.QuantConfig(), + head_size=model_metadata["n_text_state"] // model_metadata["n_text_head"], + num_layers=model_metadata["n_text_layer"], + num_heads=model_metadata["n_text_head"], + ffn_hidden_size=4 * model_metadata["n_text_state"], + encoder_hidden_size=model_metadata["n_text_state"], + encoder_num_heads=model_metadata["n_text_head"], + has_position_embedding=True, + relative_attention=False, + max_distance=0, + num_buckets=0, + has_embedding_layernorm=False, + has_embedding_scale=False, + q_scaling=1.0, + has_attention_qkvo_bias=True, + has_mlp_bias=True, + has_model_final_layernorm=True, + layernorm_eps=1e-5, + layernorm_position=LayerNormPositionType.pre_layernorm, + layernorm_type=LayerNormType.LayerNorm, + rescale_before_lm_head=False, + encoder_head_size=model_metadata["n_text_state"] + // model_metadata["n_text_head"], # Added missing variable + skip_cross_qkv=False, + ) ) if args.use_weight_only: @@ -377,7 +394,7 @@ def build_decoder(model, args): model_metadata["n_audio_ctx"], ) - tensorrt_llm_whisper_decoder(*inputs) + tensorrt_llm_whisper_decoder(**inputs) if args.debug_mode: for k, v in tensorrt_llm_whisper_decoder.named_network_outputs(): diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py index 19bf25e..ca078ed 100644 --- a/src/wordcab_transcribe/services/asr_service.py +++ b/src/wordcab_transcribe/services/asr_service.py @@ -336,10 +336,12 @@ def create_transcription_local_service(self) -> None: def create_diarization_local_service(self) -> None: """Create a local diarization service.""" if settings.diarization_backend == "longform-diarizer": + logger.info("Using LongFormDiarizeService for diarization.") self.local_services.diarization = LongFormDiarizeService( device=self.device, ) else: + logger.info("Using DiarizeService for diarization.") self.local_services.diarization = DiarizeService( device=self.device, device_index=self.device_index, diff --git a/src/wordcab_transcribe/services/post_processing_service.py b/src/wordcab_transcribe/services/post_processing_service.py index 3bc82d8..ffa7660 100644 --- a/src/wordcab_transcribe/services/post_processing_service.py +++ b/src/wordcab_transcribe/services/post_processing_service.py @@ -362,6 +362,7 @@ def reconstruct_multi_channel_utterances( sentences = [] for speaker, word in transcript_words: start_t, end_t, text = word.start, word.end, word.word + print(speaker, previous_speaker, text) if speaker != previous_speaker: sentences.append(current_sentence) diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py index 8fd20a6..8f2fb9e 100644 --- a/src/wordcab_transcribe/services/transcribe_service.py +++ b/src/wordcab_transcribe/services/transcribe_service.py @@ -251,7 +251,6 @@ def __call__( _outputs = [segment._asdict() for segment in segments] outputs = TranscriptionOutput(segments=_outputs) - else: outputs = self.multi_channel( audio, @@ -319,6 +318,7 @@ def multi_channel( self, audio_list: List[Union[str, torch.Tensor, TensorShare]], source_lang: str, + speaker_id: int, suppress_blank: bool = False, word_timestamps: bool = True, internal_vad: bool = True, @@ -335,6 +335,7 @@ def multi_channel( Args: audio_list (List[Union[str, torch.Tensor, TensorShare]]): List of audio file paths or audio tensors. source_lang (str): Language of the audio file. + speaker_id (int): Speaker ID used in the diarization. suppress_blank (bool): Whether to suppress blank at the beginning of the sampling. word_timestamps (bool): @@ -436,6 +437,7 @@ def multi_channel( word["start"] = round(word["start"], 2) word["end"] = round(word["end"], 2) segment["text"] = segment["text"].strip() + segment["start"] = round(segment.pop("start_time"), 2) segment["end"] = round(segment.pop("end_time"), 2) extra = {