diff --git a/.env b/.env index aca1213..9b566a5 100644 --- a/.env +++ b/.env @@ -41,6 +41,10 @@ WHISPER_MODEL="large-v3" # You can specify one of two engines, "faster-whisper" or "tensorrt-llm". At the moment, "faster-whisper" is more # stable, adjustable, and accurate, while "tensorrt-llm" is faster but less accurate and adjustable. WHISPER_ENGINE="tensorrt-llm" +# This helps adjust some build steps during the conversion of the Whisper model to TensorRT. If you change this, be sure to +# update it in pre_requirements.txt as well. The only available options are "0.9.0.dev2024032600" and "0.11.0.dev2024052100". +# Note that version "0.11.0.dev2024052100" is not compatible with T4 or V100 GPUs. +TENSORRT_LLM_VERSION="0.9.0.dev2024032600" # The align model is used for aligning timestamps under the "tensorrt-llm" engine. The available options are: # "tiny", "small", "base", or "medium". ALIGN_MODEL="tiny" @@ -60,7 +64,7 @@ TOKENIZERS_PARALLELISM=False # The diarization_backend parameter is used to control the diarization model used. The available options are: # "longform-diarizer" or "default-diarizer". It's suggested to use "default-diarizer" for better stability. # The "longform-diarizer" is still being developed. -DIARIZATION_BACKEND="default-diarizer" +DIARIZATION_BACKEND="longform-diarizer" # In a MSDD (Multiscale Diarization Decoder) model, the diarization model is trained on multiple window lengths. # The window_lengths are specified in seconds, and separated by a comma. If not specified, the default value will # be "1.5, 1.25, 1.0, 0.75, 0.5". @@ -97,7 +101,7 @@ ASR_TYPE="async" # # Include the cortex endpoint in the API. This endpoint is used to process audio files from the Cortex API. # Use this only if you deploy the API using Cortex and Kubernetes. -CORTEX_ENDPOINT=True +CORTEX_ENDPOINT=False # # ---------------------------------------- API AUTHENTICATION CONFIGURATION ------------------------------------------ # # The API authentication is used to control the access to the API endpoints. 
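For reference, the TENSORRT_LLM_VERSION variable introduced above is consumed by the version gate this changeset adds to engine_builder/build.py further down. A minimal standalone sketch of that gate follows; the default value is illustrative only, taken from the .env default above.

import os

# Pick the quantize_model import that matches the installed tensorrt_llm release,
# mirroring the gate added to build.py in this changeset.
TENSORRT_LLM_VERSION = os.getenv("TENSORRT_LLM_VERSION", "0.9.0.dev2024032600")

if "0.9.0" in TENSORRT_LLM_VERSION:
    # Older releases expose quantize_model under tensorrt_llm.models.
    from tensorrt_llm.models import quantize_model
elif "0.11.0" in TENSORRT_LLM_VERSION:
    # Newer releases moved it to the modelopt-based quantization module.
    from tensorrt_llm.quantization.quantize_by_modelopt import quantize_model
else:
    raise ValueError(f"Unsupported version of tensorrt_llm: {TENSORRT_LLM_VERSION}")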
@@ -130,6 +134,7 @@ SVIX_APP_ID= # # ----------------------------------------------- AWS CONFIGURATION ------------------------------------------------- # # +SEND_RESULTS_TO_S3=False AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_STORAGE_BUCKET_NAME= diff --git a/.gitignore b/.gitignore index faf8fd1..26007ee 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,7 @@ test.ipynb test.py whisper_model whisper_model_he +storage +nemo_storage +nemo_local +error.log \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 186c377..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,22 +0,0 @@ -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: "v4.3.0" - hooks: - - id: check-added-large-files - args: [--maxkb=2000] - - id: check-toml - - id: check-yaml - - - repo: https://github.com/psf/black - rev: 22.10.0 - hooks: - - id: black - args: ["--preview"] - language_version: python3 - - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.0.263" - hooks: - - id: ruff - args: [--fix] -exclude: ^notebooks/ diff --git a/error.log b/error.log new file mode 100644 index 0000000..ae3d7f1 --- /dev/null +++ b/error.log @@ -0,0 +1 @@ +/app/temp_outputs/mono_file.wav:[Errno 2] No such file or directory: '/app/temp_outputs/mono_file.wav' \ No newline at end of file diff --git a/nemo_local b/nemo_local new file mode 160000 index 0000000..5703d95 --- /dev/null +++ b/nemo_local @@ -0,0 +1 @@ +Subproject commit 5703d95f09caa5bb223866b011425a2019e601fc diff --git a/nemo_storage/infer_manifest.json b/nemo_storage/infer_manifest.json new file mode 100644 index 0000000..f19d538 --- /dev/null +++ b/nemo_storage/infer_manifest.json @@ -0,0 +1 @@ +{"audio_filepath": "/app/temp_outputs/mono_file.wav", "offset": 0, "duration": null, "label": "infer", "text": "-", "rttm_filepath": null, "uem_filepath": null} diff --git a/pre_requirements.txt b/pre_requirements.txt index ba8fef7..cc286fa 100644 --- a/pre_requirements.txt +++ b/pre_requirements.txt @@ -6,7 +6,7 @@ shortuuid==1.0.13 svix==1.21.0 uvicorn==0.29.0 websockets==12.0 -tensorrt_llm==0.11.0.dev2024052100 +tensorrt_llm==0.9.0.dev2024032600 Cython==3.0.10 youtokentome @ git+https://github.com/gburlet/YouTokenToMe.git@dependencies deepmultilingualpunctuation==1.0.1 diff --git a/src/wordcab_transcribe/config.py b/src/wordcab_transcribe/config.py index 6f30e32..aa4ce88 100644 --- a/src/wordcab_transcribe/config.py +++ b/src/wordcab_transcribe/config.py @@ -69,6 +69,7 @@ class Settings: # Cortex configuration cortex_api_key: str # AWS configuration + send_results_to_s3: bool aws_access_key_id: str aws_secret_access_key: str aws_storage_bucket_name: str @@ -137,7 +138,7 @@ def align_model_compatibility_check(cls, value: str): # noqa: B902, N805 """Check that the whisper engine is compatible.""" if value.lower() not in ["tiny", "small", "base", "medium"]: raise ValueError( - "The whisper engine must be one of `tiny`, `small`, `base`, or" + "The align model must be one of `tiny`, `small`, `base`, or" " `medium`." 
) @@ -348,6 +349,7 @@ def __post_init__(self): # Cortex configuration cortex_api_key=getenv("WORDCAB_TRANSCRIBE_API_KEY", ""), # AWS configuration + send_results_to_s3=getenv("SEND_RESULTS_TO_S3", False), aws_access_key_id=getenv("AWS_ACCESS_KEY_ID", ""), aws_secret_access_key=getenv("AWS_SECRET_ACCESS_KEY", ""), aws_storage_bucket_name=getenv("AWS_STORAGE_BUCKET_NAME", ""), diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/build.py b/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/build.py index f083fbf..459848e 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/build.py +++ b/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/build.py @@ -12,12 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse import os import time +import argparse -import tensorrt_llm import torch +from loguru import logger + +import tensorrt_llm from tensorrt_llm import str_dtype_to_torch, str_dtype_to_trt from tensorrt_llm.builder import Builder from tensorrt_llm.functional import LayerNormPositionType, LayerNormType @@ -25,12 +27,18 @@ from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode -from tensorrt_llm.quantization.quantize_by_modelopt import quantize_model from weight import load_decoder_weight, load_encoder_weight MODEL_ENCODER_NAME = "whisper_encoder" MODEL_DECODER_NAME = "whisper_decoder" +TENSORRT_LLM_VERSION = os.getenv("TENSORRT_LLM_VERSION") +if "0.9.0" in TENSORRT_LLM_VERSION: + from tensorrt_llm.models import quantize_model +elif "0.11.0" in TENSORRT_LLM_VERSION: + from tensorrt_llm.quantization.quantize_by_modelopt import quantize_model +else: + raise ValueError(f"Unsupported version of tensorrt_llm: {TENSORRT_LLM_VERSION}") def get_engine_name(model, dtype, tp_size=1, rank=0): return "{}_{}_tp{}_rank{}.engine".format(model, dtype, tp_size, rank) @@ -79,7 +87,7 @@ def parse_arguments(): parser.add_argument("--quantize_dir", type=str, default="quantize/1-gpu") parser.add_argument("--dtype", type=str, default="float16", choices=["float16"]) parser.add_argument("--log_level", type=str, default="info") - parser.add_argument("--max_batch_size", type=int, default=24) + parser.add_argument("--max_batch_size", type=int, default=16) parser.add_argument("--max_input_len", type=int, default=4) parser.add_argument("--max_output_len", type=int, default=448) parser.add_argument("--max_beam_width", type=int, default=1) @@ -315,32 +323,63 @@ def build_decoder(model, args): int8=args.quant_mode.has_act_or_weight_quant(), ) - tensorrt_llm_whisper_decoder = tensorrt_llm.models.DecoderModel( - tensorrt_llm.models.modeling_utils.PretrainedConfig( - architecture="whisper", - dtype=str_dtype_to_trt(args.dtype), - logits_dtype=str_dtype_to_trt(args.dtype), - vocab_size=model_metadata["n_vocab"], - max_position_embeddings=model_metadata["n_text_ctx"], - hidden_size=model_metadata["n_text_state"], - num_hidden_layers=model_metadata["n_text_layer"], - num_attention_heads=model_metadata["n_text_head"], - num_key_value_heads=model_metadata["n_text_head"], - hidden_act="gelu", - intermediate_size=4 * model_metadata["n_text_state"], - norm_epsilon=1e-5, - position_embedding_type="learned_absolute", - world_size=1, - tp_size=1, - pp_size=1, - gpus_per_node=1, - quantization=tensorrt_llm.models.modeling_utils.QuantConfig(), - 
head_size=model_metadata["n_text_state"] // model_metadata["n_text_head"], + try: + tensorrt_llm_whisper_decoder = tensorrt_llm.models.DecoderModel( + tensorrt_llm.models.modeling_utils.PretrainedConfig( + architecture="whisper", + dtype=str_dtype_to_trt(args.dtype), + logits_dtype=str_dtype_to_trt(args.dtype), + vocab_size=model_metadata["n_vocab"], + max_position_embeddings=model_metadata["n_text_ctx"], + hidden_size=model_metadata["n_text_state"], + num_hidden_layers=model_metadata["n_text_layer"], + num_attention_heads=model_metadata["n_text_head"], + num_key_value_heads=model_metadata["n_text_head"], + hidden_act="gelu", + intermediate_size=4 * model_metadata["n_text_state"], + norm_epsilon=1e-5, + position_embedding_type="learned_absolute", + world_size=1, + tp_size=1, + pp_size=1, + gpus_per_node=1, + quantization=tensorrt_llm.models.modeling_utils.QuantConfig(), + head_size=model_metadata["n_text_state"] // model_metadata["n_text_head"], + num_layers=model_metadata["n_text_layer"], + num_heads=model_metadata["n_text_head"], + ffn_hidden_size=4 * model_metadata["n_text_state"], + encoder_hidden_size=model_metadata["n_text_state"], + encoder_num_heads=model_metadata["n_text_head"], + has_position_embedding=True, + relative_attention=False, + max_distance=0, + num_buckets=0, + has_embedding_layernorm=False, + has_embedding_scale=False, + q_scaling=1.0, + has_attention_qkvo_bias=True, + has_mlp_bias=True, + has_model_final_layernorm=True, + layernorm_eps=1e-5, + layernorm_position=LayerNormPositionType.pre_layernorm, + layernorm_type=LayerNormType.LayerNorm, + rescale_before_lm_head=False, + encoder_head_size=model_metadata["n_text_state"] + // model_metadata["n_text_head"], # Added missing variable + skip_cross_qkv=False, + ) + ) + except: + tensorrt_llm_whisper_decoder = tensorrt_llm.models.DecoderModel( num_layers=model_metadata["n_text_layer"], num_heads=model_metadata["n_text_head"], + hidden_size=model_metadata["n_text_state"], ffn_hidden_size=4 * model_metadata["n_text_state"], encoder_hidden_size=model_metadata["n_text_state"], encoder_num_heads=model_metadata["n_text_head"], + vocab_size=model_metadata["n_vocab"], + head_size=model_metadata["n_text_state"] // model_metadata["n_text_head"], + max_position_embeddings=model_metadata["n_text_ctx"], has_position_embedding=True, relative_attention=False, max_distance=0, @@ -354,12 +393,11 @@ def build_decoder(model, args): layernorm_eps=1e-5, layernorm_position=LayerNormPositionType.pre_layernorm, layernorm_type=LayerNormType.LayerNorm, + hidden_act="gelu", rescale_before_lm_head=False, - encoder_head_size=model_metadata["n_text_state"] - // model_metadata["n_text_head"], # Added missing variable - skip_cross_qkv=False, + dtype=str_dtype_to_trt(args.dtype), + logits_dtype=str_dtype_to_trt(args.dtype), ) - ) if args.use_weight_only: tensorrt_llm_whisper_decoder = quantize_model( @@ -394,7 +432,10 @@ def build_decoder(model, args): model_metadata["n_audio_ctx"], ) - tensorrt_llm_whisper_decoder(**inputs) + if "0.9.0" in TENSORRT_LLM_VERSION: + tensorrt_llm_whisper_decoder(*inputs) + else: + tensorrt_llm_whisper_decoder(**inputs) if args.debug_mode: for k, v in tensorrt_llm_whisper_decoder.named_network_outputs(): diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/create_trt_model.py b/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/create_trt_model.py index 8dd36ea..539f401 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/create_trt_model.py +++ 
b/src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/create_trt_model.py @@ -63,6 +63,21 @@ } +TRT_BUILD_MAX_OUTPUT_LEN = os.getenv("TRT_BUILD_MAX_OUTPUT_LEN", None) +TRT_BUILD_MAX_BEAM_WIDTH = os.getenv("TRT_BUILD_MAX_BEAM_WIDTH", None) +if not TRT_BUILD_MAX_OUTPUT_LEN: + TRT_BUILD_MAX_OUTPUT_LEN = 448 +else: + TRT_BUILD_MAX_OUTPUT_LEN = int(TRT_BUILD_MAX_OUTPUT_LEN) +logger.info(f"TRT_BUILD_MAX_OUTPUT_LEN: {TRT_BUILD_MAX_OUTPUT_LEN}") + +if not TRT_BUILD_MAX_BEAM_WIDTH: + TRT_BUILD_MAX_BEAM_WIDTH = 1 +else: + TRT_BUILD_MAX_BEAM_WIDTH = int(TRT_BUILD_MAX_BEAM_WIDTH) +logger.info(f"TRT_BUILD_MAX_BEAM_WIDTH: {TRT_BUILD_MAX_BEAM_WIDTH}") + + def build_whisper_trt_model( output_dir, use_gpt_attention_plugin=True, @@ -70,7 +85,9 @@ def build_whisper_trt_model( use_bert_attention_plugin=True, enable_context_fmha=True, use_weight_only=False, - model_name="distil-large-v2", + max_output_len=TRT_BUILD_MAX_OUTPUT_LEN, + max_beam_width=TRT_BUILD_MAX_BEAM_WIDTH, + model_name="large-v3", ): """ Build a Whisper model using the specified configuration. @@ -158,6 +175,10 @@ def build_whisper_trt_model( command.append("--enable_context_fmha") if use_weight_only: command.append("--use_weight_only") + if max_output_len: + command.extend(["--max_output_len", str(max_output_len)]) + if max_beam_width: + command.extend(["--max_beam_width", str(max_beam_width)]) try: subprocess.run(command, check=True) diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/model.py b/src/wordcab_transcribe/engines/tensorrt_llm/model.py index 6ceda2c..6adadd3 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/model.py +++ b/src/wordcab_transcribe/engines/tensorrt_llm/model.py @@ -39,11 +39,11 @@ def exact_div(x, y): "best_of": 5, "patience": 1, "length_penalty": 1, - "repetition_penalty": 1.01, + "repetition_penalty": 1.05, "no_repeat_ngram_size": 0, "compression_ratio_threshold": 2.4, "log_prob_threshold": -1.0, - "no_speech_threshold": 0.3, + "no_speech_threshold": 0.4, "prefix": None, "suppress_blank": False, "suppress_tokens": [-1], @@ -61,7 +61,7 @@ def exact_div(x, y): "best_of": 1, "patience": 2, "length_penalty": 1, - "repetition_penalty": 1.01, + "repetition_penalty": 1.05, "no_repeat_ngram_size": 0, "compression_ratio_threshold": 2.4, "log_prob_threshold": -1.0, @@ -224,7 +224,9 @@ def align_words( start_seq_wise_req[_sot_seq] = [_idx] token_alignments = [[] for _ in seg_metadata] + for start_seq, req_idx in start_seq_wise_req.items(): + res = self.align_model.align( ctranslate2.StorageView.from_array(features[req_idx]), start_sequence=list(start_seq), diff --git a/src/wordcab_transcribe/engines/tensorrt_llm/trt_model.py b/src/wordcab_transcribe/engines/tensorrt_llm/trt_model.py index 3295a39..09a9d1f 100644 --- a/src/wordcab_transcribe/engines/tensorrt_llm/trt_model.py +++ b/src/wordcab_transcribe/engines/tensorrt_llm/trt_model.py @@ -93,8 +93,8 @@ def get_session(self, engine_dir, runtime_mapping, debug_mode=False): # TODO: Make dynamic max_batch_size and max_beam_width decoder_model_config = ModelConfig( - max_batch_size=24, - max_beam_width=1, + max_batch_size=16, + max_beam_width=5, num_heads=self.decoder_config["num_heads"], num_kv_heads=self.decoder_config["num_heads"], hidden_size=self.decoder_config["hidden_size"], diff --git a/src/wordcab_transcribe/models.py b/src/wordcab_transcribe/models.py index a4add17..f066b66 100644 --- a/src/wordcab_transcribe/models.py +++ b/src/wordcab_transcribe/models.py @@ -403,6 +403,7 @@ class BaseRequest(BaseModel): diarization: bool = False batch_size: int = 1 
source_lang: str = "en" + num_beams: int = 1 timestamps: Timestamps = Timestamps.seconds vocab: Union[List[str], None] = None word_timestamps: bool = False diff --git a/src/wordcab_transcribe/router/v1/audio_file_endpoint.py b/src/wordcab_transcribe/router/v1/audio_file_endpoint.py index 95908a5..eb56c7b 100644 --- a/src/wordcab_transcribe/router/v1/audio_file_endpoint.py +++ b/src/wordcab_transcribe/router/v1/audio_file_endpoint.py @@ -52,6 +52,7 @@ async def inference_with_audio( # noqa: C901 diarization: bool = Form(False), # noqa: B008 multi_channel: bool = Form(False), # noqa: B008 source_lang: str = Form("en"), # noqa: B008 + num_beams: int = Form(1), # noqa: B008 timestamps: str = Form("s"), # noqa: B008 vocab: Union[List[str], None] = Form(None), # noqa: B008 word_timestamps: bool = Form(False), # noqa: B008 @@ -73,6 +74,28 @@ async def inference_with_audio( # noqa: C901 await save_file_locally(filename=filename, file=file) + num_channels = await check_num_channels(filename) + if (num_channels > 1 and multi_channel is False) or (num_channels == 1 and multi_channel is True): + num_channels = 1 # Force mono channel if more than 1 channel or vice versa + if multi_channel: + diarization = True + multi_channel = False + + try: + filepath: Union[str, List[str]] = await process_audio_file( + filename, num_channels=num_channels + ) + except Exception as e: + try: + background_tasks.add_task(delete_file, filepath=filename) + background_tasks.add_task(delete_file, filepath=filepath) + except: + pass + raise HTTPException( # noqa: B904 + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Process failed: {e}", + ) + data = AudioRequest( offset_start=offset_start, offset_end=offset_end, @@ -80,6 +103,7 @@ async def inference_with_audio( # noqa: C901 diarization=diarization, batch_size=batch_size, source_lang=source_lang, + num_beams=num_beams, timestamps=timestamps, vocab=vocab, word_timestamps=word_timestamps, @@ -92,22 +116,6 @@ async def inference_with_audio( # noqa: C901 multi_channel=multi_channel, ) - num_channels = await check_num_channels(filename) - - if num_channels > 1 and data.multi_channel is False: - num_channels = 1 # Force mono channel if more than 1 channel - - try: - filepath: Union[str, List[str]] = await process_audio_file( - filename, num_channels=num_channels - ) - - except Exception as e: - raise HTTPException( # noqa: B904 - status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Process failed: {e}", - ) - background_tasks.add_task(delete_file, filepath=filename) task = asyncio.create_task( @@ -120,6 +128,7 @@ async def inference_with_audio( # noqa: C901 batch_size=data.batch_size, multi_channel=data.multi_channel, source_lang=data.source_lang, + num_beams=data.num_beams, timestamps_format=data.timestamps, vocab=data.vocab, word_timestamps=data.word_timestamps, diff --git a/src/wordcab_transcribe/router/v1/audio_url_endpoint.py b/src/wordcab_transcribe/router/v1/audio_url_endpoint.py index c2da96c..3d8185d 100644 --- a/src/wordcab_transcribe/router/v1/audio_url_endpoint.py +++ b/src/wordcab_transcribe/router/v1/audio_url_endpoint.py @@ -56,14 +56,15 @@ def retrieve_service(service, aws_creds): ) -s3_client = retrieve_service( - "s3", - { - "aws_access_key_id": settings.aws_access_key_id, - "aws_secret_access_key": settings.aws_secret_access_key, - "region_name": settings.aws_region_name, - }, -) +if settings.send_results_to_s3: + s3_client = retrieve_service( + "s3", + { + "aws_access_key_id": settings.aws_access_key_id, + 
"aws_secret_access_key": settings.aws_secret_access_key, + "region_name": settings.aws_region_name, + }, + ) @router.post("", status_code=http_status.HTTP_202_ACCEPTED) @@ -76,21 +77,34 @@ async def inference_with_audio_url( filename = f"audio_url_{shortuuid.ShortUUID().random(length=32)}" data = AudioRequest() if data is None else AudioRequest(**data.dict()) - async def process_audio(): + async def process_audio(data): try: async with download_limit: _filepath = await download_audio_file("url", url, filename) num_channels = await check_num_channels(_filepath) - if num_channels > 1 and data.multi_channel is False: - num_channels = 1 # Force mono channel if more than 1 channel + if ( + num_channels > 1 and data.multi_channel is False + ) or ( + num_channels == 1 and data.multi_channel is True + ): + num_channels = 1 # Force mono channel if more than 1 channel or vice versa + new_data = data.dict() + if data.multi_channel: + new_data["diarization"] = True + new_data["multi_channel"] = False + data = AudioRequest(**new_data) try: filepath: Union[str, List[str]] = await process_audio_file( _filepath, num_channels=num_channels ) - except Exception as e: + try: + background_tasks.add_task(delete_file, filepath=filename) + background_tasks.add_task(delete_file, filepath=filepath) + except: + pass raise HTTPException( # noqa: B904 status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Process failed: {e}", @@ -110,6 +124,7 @@ async def process_audio(): batch_size=data.batch_size, multi_channel=data.multi_channel, source_lang=data.source_lang, + num_beams=data.num_beams, timestamps_format=data.timestamps, vocab=data.vocab, word_timestamps=data.word_timestamps, @@ -123,6 +138,7 @@ async def process_audio(): ) result = await task + utterances, process_times, audio_duration = result result = AudioResponse( utterances=utterances, @@ -148,24 +164,39 @@ async def process_audio(): process_times=process_times, ) - upload_file( - s3_client, - file=bytes(json.dumps(result.model_dump()).encode("UTF-8")), - bucket=settings.aws_storage_bucket_name, - object_name=f"responses/{data.task_token}_{data.job_name}.json", - ) + if settings.debug: + logger.debug(f"Result: {result.model_dump()}") + else: + if settings.send_results_to_s3: + upload_file( + s3_client, + file=bytes(json.dumps(result.model_dump()).encode("UTF-8")), + bucket=settings.aws_storage_bucket_name, + object_name=f"responses/{data.task_token}_{data.job_name}.json", + ) + + if settings.svix_api_key and settings.svix_app_id: + await send_update_with_svix( + data.job_name, + "finished", + { + "job_name": data.job_name, + "task_token": data.task_token, + }, + ) background_tasks.add_task(delete_file, filepath=filepath) - await send_update_with_svix( - data.job_name, - "finished", - { - "job_name": data.job_name, - "task_token": data.task_token, - }, - ) except Exception as e: + try: + background_tasks.add_task(delete_file, filepath=filename) + background_tasks.add_task(delete_file, filepath=filepath) + except: + pass error_message = f"Error during transcription: {e}" + try: + logger.error(result.message) + except: + pass logger.error(error_message) error_payload = { @@ -177,7 +208,7 @@ async def process_audio(): await send_update_with_svix(data.job_name, "error", error_payload) # Add the process_audio function to background tasks - background_tasks.add_task(process_audio) + background_tasks.add_task(process_audio, data) # Return the job name and task token immediately return {"job_name": data.job_name, "task_token": data.task_token} diff --git 
a/src/wordcab_transcribe/router/v1/youtube_endpoint.py b/src/wordcab_transcribe/router/v1/youtube_endpoint.py index ef15926..7c08ba4 100644 --- a/src/wordcab_transcribe/router/v1/youtube_endpoint.py +++ b/src/wordcab_transcribe/router/v1/youtube_endpoint.py @@ -61,6 +61,7 @@ async def inference_with_youtube( batch_size=data.batch_size, multi_channel=False, source_lang=data.source_lang, + num_beams=data.num_beams, timestamps_format=data.timestamps, vocab=data.vocab, word_timestamps=data.word_timestamps, diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py index ca078ed..6524257 100644 --- a/src/wordcab_transcribe/services/asr_service.py +++ b/src/wordcab_transcribe/services/asr_service.py @@ -19,21 +19,24 @@ # and limitations under the License. """ASR Service module that handle all AI interactions.""" -import asyncio + import time +import aiohttp +import asyncio import traceback -from abc import ABC, abstractmethod from dataclasses import dataclass +from abc import ABC, abstractmethod + from enum import Enum from pathlib import Path +from typing_extensions import Literal +from pydantic import BaseModel, ConfigDict from typing import Iterable, List, Optional, Tuple, Union -import aiohttp import torch -from loguru import logger -from pydantic import BaseModel, ConfigDict from tensorshare import Backend, TensorShare -from typing_extensions import Literal + +from loguru import logger from wordcab_transcribe.config import settings from wordcab_transcribe.logging import time_and_tell, time_and_tell_async @@ -58,6 +61,18 @@ from wordcab_transcribe.utils import early_return, format_segments, read_audio +class AsyncLocationTrustedRedirectSession(aiohttp.ClientSession): + async def _request(self, method, url, location_trusted, *args, **kwargs): + if not location_trusted: + return await super(AsyncLocationTrustedRedirectSession, self)._request(method, url, *args, **kwargs) + kwargs["allow_redirects"] = False + response = await super(AsyncLocationTrustedRedirectSession, self)._request(method, url, *args, **kwargs) + if response.status in (301, 302, 303, 307, 308) and "Location" in response.headers: + new_url = response.headers["Location"] + return await super(AsyncLocationTrustedRedirectSession, self)._request(method, new_url, *args, **kwargs) + return response + + class ExceptionSource(str, Enum): """Exception source enum.""" @@ -132,6 +147,7 @@ class TranscriptionOptions(BaseModel): no_speech_threshold: float repetition_penalty: float source_lang: str + num_beams: int vocab: Union[List[str], None] @@ -376,6 +392,7 @@ async def inference_warmup(self) -> None: diarization=True, multi_channel=False, source_lang="en", + num_beams=1, timestamps_format="s", vocab=None, word_timestamps=False, @@ -397,6 +414,7 @@ async def process_input( # noqa: C901 diarization: bool, multi_channel: bool, source_lang: str, + num_beams: int, timestamps_format: str, vocab: Union[List[str], None], word_timestamps: bool, @@ -434,6 +452,8 @@ async def process_input( # noqa: C901 Whether to do multi-channel diarization or not. source_lang (str): Source language of the audio file. + num_beams (int): + The number of beams to use for the beam search. timestamps_format (str): Timestamps format to use. 
vocab (Union[List[str], None]): @@ -525,6 +545,7 @@ async def process_input( # noqa: C901 no_speech_threshold=no_speech_threshold, repetition_penalty=repetition_penalty, source_lang=source_lang, + num_beams=num_beams, vocab=vocab, ), ), @@ -809,19 +830,42 @@ def process_post_processing(self, task: ASRTask) -> None: return None async def remote_transcription( - self, - url: str, - data: TranscribeRequest, + self, + url: str, + data: TranscribeRequest, ) -> TranscriptionOutput: """Remote transcription method.""" - async with aiohttp.ClientSession() as session: + headers = {"Content-Type": "application/json"} + + if not settings.debug: + headers = {"Content-Type": "application/x-www-form-urlencoded"} + auth_url = f"{url}/api/v1/auth" + async with aiohttp.ClientSession() as session: + async with session.post( + url=auth_url, + data={"username": settings.username, "password": settings.password}, + headers=headers, + ) as response: + if response.status != 200: + raise Exception(response.status) + else: + token = await response.json() + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {token['access_token']}", + } + + transcription_timeout = aiohttp.ClientTimeout(total=1200) + async with AsyncLocationTrustedRedirectSession(timeout=transcription_timeout) as session: async with session.post( - url=f"{url}/api/v1/transcribe", - data=data.model_dump_json(), - headers={"Content-Type": "application/json"}, + url=f"{url}/api/v1/transcribe", + data=data.model_dump_json(), + headers=headers, + location_trusted=True, ) as response: if response.status != 200: - raise Exception(response.status) + r = await response.json() + raise Exception(r["detail"]) else: return TranscriptionOutput(**await response.json()) @@ -851,11 +895,12 @@ async def remote_diarization( "Authorization": f"Bearer {token['access_token']}", } diarization_timeout = aiohttp.ClientTimeout(total=1200) - async with aiohttp.ClientSession(timeout=diarization_timeout) as session: + async with AsyncLocationTrustedRedirectSession(timeout=diarization_timeout) as session: async with session.post( url=f"{url}/api/v1/diarize", data=data.model_dump_json(), headers=headers, + location_trusted=True, ) as response: if response.status != 200: r = await response.json() @@ -948,7 +993,7 @@ def __init__(self, whisper_model: str, compute_type: str, debug_mode: bool) -> N self.transcription_service = TranscribeService( model_path=whisper_model, - model_engine=settings.model_engine, + model_engine=settings.whisper_engine, compute_type=compute_type, device=self.device, device_index=self.device_index, @@ -1013,7 +1058,7 @@ def __init__( self.transcription_service = TranscribeService( model_path=whisper_model, - model_engine=settings.model_engine, + model_engine=settings.whisper_engine, compute_type=compute_type, device=self.device, device_index=self.device_index, @@ -1066,6 +1111,7 @@ async def process_input( audio=data.audio, batch_size=data.batch_size, source_lang=data.source_lang, + num_beams=data.num_beams, model_index=gpu_index, suppress_blank=False, word_timestamps=True, diff --git a/src/wordcab_transcribe/services/post_processing_service.py b/src/wordcab_transcribe/services/post_processing_service.py index ffa7660..3bc82d8 100644 --- a/src/wordcab_transcribe/services/post_processing_service.py +++ b/src/wordcab_transcribe/services/post_processing_service.py @@ -362,7 +362,6 @@ def reconstruct_multi_channel_utterances( sentences = [] for speaker, word in transcript_words: start_t, end_t, text = word.start, word.end, word.word - 
print(speaker, previous_speaker, text) if speaker != previous_speaker: sentences.append(current_sentence) diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py index 8f2fb9e..ba5cacd 100644 --- a/src/wordcab_transcribe/services/transcribe_service.py +++ b/src/wordcab_transcribe/services/transcribe_service.py @@ -203,6 +203,7 @@ def __call__( audio, language=source_lang, initial_prompt=prompt, + beam_size=num_beams, repetition_penalty=repetition_penalty, compression_ratio_threshold=compression_ratio_threshold, log_prob_threshold=log_prob_threshold, @@ -227,13 +228,21 @@ def __call__( initial_prompts=[prompt], batch_size=batch_size, use_vad=internal_vad, - generate_kwargs={"num_beams": num_beams}, + generate_kwargs={ + "num_beams": num_beams, + "length_penalty": 1, + "repetition_penalty": repetition_penalty, + "stop_words_list": suppress_blank, + "bad_words_list": [-1], + "temperature": 1.0, + }, )[0] # TODO: make batch compatible for ix, segment in enumerate(segments): segment["words"] = segment.pop("word_timestamps") for word in segment["words"]: + word["word"] = f" {word['word']}" word["start"] = round(word["start"], 2) word["end"] = round(word["end"], 2) segment["start"] = round(segment.pop("start_time"), 2) @@ -255,6 +264,7 @@ def __call__( outputs = self.multi_channel( audio, source_lang=source_lang, + num_beams=num_beams, suppress_blank=suppress_blank, word_timestamps=word_timestamps, internal_vad=internal_vad, @@ -318,7 +328,7 @@ def multi_channel( self, audio_list: List[Union[str, torch.Tensor, TensorShare]], source_lang: str, - speaker_id: int, + num_beams: int = 1, suppress_blank: bool = False, word_timestamps: bool = True, internal_vad: bool = True, @@ -335,6 +345,7 @@ def multi_channel( Args: audio_list (List[Union[str, torch.Tensor, TensorShare]]): List of audio file paths or audio tensors. source_lang (str): Language of the audio file. + num_beams (int): Number of beams to use during generation. speaker_id (int): Speaker ID used in the diarization. suppress_blank (bool): Whether to suppress blank at the beginning of the sampling. 
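As context for the beam_size/num_beams plumbing added above, a minimal standalone faster-whisper call showing where the value ends up for the "faster-whisper" engine; the model size, audio path, and parameter values are placeholders, not taken from this diff.

from faster_whisper import WhisperModel

# Illustrative sketch only: TranscribeService forwards num_beams as beam_size.
model = WhisperModel("large-v3", device="cuda", compute_type="float16")
segments, info = model.transcribe(
    "sample.wav",
    language="en",
    beam_size=5,              # what the new num_beams request field controls
    repetition_penalty=1.05,
    no_speech_threshold=0.4,
    word_timestamps=True,
)
for segment in segments:
    print(round(segment.start, 2), round(segment.end, 2), segment.text)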
@@ -375,6 +386,7 @@ def multi_channel( _audio, language=source_lang, initial_prompt=prompt, + beam_size=num_beams, repetition_penalty=repetition_penalty, compression_ratio_threshold=compression_ratio_threshold, log_prob_threshold=log_prob_threshold, @@ -402,7 +414,7 @@ def multi_channel( ) final_segments.append(_segment) - outputs.append(final_segments) + outputs.append(MultiChannelTranscriptionOutput(segments=final_segments)) elif self.model_engine == "tensorrt-llm": audio_channels = [] speaker_ids = [] @@ -422,9 +434,16 @@ def multi_channel( lang_codes=[source_lang] * channels_len, tasks=["transcribe"] * channels_len, initial_prompts=[prompt] * channels_len, - batch_size=channels_len, + batch_size=1, use_vad=internal_vad, - generate_kwargs={"num_beams": 1}, + generate_kwargs={ + "num_beams": num_beams, + "length_penalty": 1, + "repetition_penalty": repetition_penalty, + "stop_words_list": suppress_blank, + "bad_words_list": [-1], + "temperature": 1.0, + }, ) for speaker_id, segments in enumerate(segments_list): diff --git a/storage/path/infer_manifest.json b/storage/path/infer_manifest.json new file mode 100644 index 0000000..f19d538 --- /dev/null +++ b/storage/path/infer_manifest.json @@ -0,0 +1 @@ +{"audio_filepath": "/app/temp_outputs/mono_file.wav", "offset": 0, "duration": null, "label": "infer", "text": "-", "rttm_filepath": null, "uem_filepath": null} diff --git a/storage/path/infer_manifest_0.json b/storage/path/infer_manifest_0.json new file mode 100644 index 0000000..6896922 --- /dev/null +++ b/storage/path/infer_manifest_0.json @@ -0,0 +1 @@ +{"audio_filepath": "/home/aleks/PycharmProjects/wordcab-transcribe/wordcab-transcribe/temp_outputs_0/mono_file.wav", "offset": 0, "duration": null, "label": "infer", "text": "-", "rttm_filepath": null, "uem_filepath": null}
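The infer_manifest.json files committed above follow NeMo's one-JSON-object-per-line manifest convention used by the diarization backend. A minimal sketch of writing such an entry is shown below; the helper name and output path are illustrative, not part of the codebase.

import json

def write_infer_manifest(audio_filepath: str, manifest_path: str) -> None:
    # Single-entry inference manifest, matching the fields in the files above.
    entry = {
        "audio_filepath": audio_filepath,
        "offset": 0,
        "duration": None,        # left as null, as in the committed manifests
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    with open(manifest_path, "w") as f:
        f.write(json.dumps(entry) + "\n")

write_infer_manifest("/app/temp_outputs/mono_file.wav", "nemo_storage/infer_manifest.json")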
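Finally, a hedged client-side illustration of the num_beams form field added to the request models and audio endpoints earlier in this diff; the endpoint route and port are assumptions and not confirmed by this changeset.

import requests

# Hypothetical request against a locally running instance.
with open("sample.wav", "rb") as f:
    response = requests.post(
        "http://localhost:5001/api/v1/audio",  # assumed base URL and route
        files={"file": f},
        data={"source_lang": "en", "num_beams": 5, "word_timestamps": True},
    )
print(response.json())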