From f8137af64c71fe38988cf27f3edac7ce1c42d5e7 Mon Sep 17 00:00:00 2001 From: chainyo Date: Tue, 3 Oct 2023 16:29:47 +0000 Subject: [PATCH] remove auto creation of all the services --- .../services/asr_service.py | 48 +++++++++---------- src/wordcab_transcribe/utils.py | 2 +- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py index 6860e11..2f3f9ad 100644 --- a/src/wordcab_transcribe/services/asr_service.py +++ b/src/wordcab_transcribe/services/asr_service.py @@ -207,21 +207,6 @@ def __init__( super().__init__() self.services: dict = { - "transcription": TranscribeService( - model_path=whisper_model, - compute_type=compute_type, - device=self.device, - device_index=self.device_index, - extra_languages=extra_languages, - extra_languages_model_paths=extra_languages_model_paths, - ), - "diarization": DiarizeService( - device=self.device, - device_index=self.device_index, - window_lengths=window_lengths, - shift_lengths=shift_lengths, - multiscale_weights=multiscale_weights, - ), "post_processing": PostProcessingService(), "vad": VadService(), } @@ -241,12 +226,27 @@ def __init__( ) else: self.use_remote_transcription = False + self.services["transcription"] = TranscribeService( + model_path=whisper_model, + compute_type=compute_type, + device=self.device, + device_index=self.device_index, + extra_languages=extra_languages, + extra_languages_model_paths=extra_languages_model_paths, + ) if diarize_server_urls is not None: self.use_remote_diarization = True self.diarization_url_handler = URLService(remote_urls=diarize_server_urls) else: self.use_remote_diarization = False + self.services["diarization"] = DiarizeService( + device=self.device, + device_index=self.device_index, + window_lengths=window_lengths, + shift_lengths=shift_lengths, + multiscale_weights=multiscale_weights, + ) self.debug_mode = debug_mode @@ -469,7 +469,7 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None: try: if isinstance(task.transcription.execution, LocalExecution): out = await time_and_tell_async( - self.services["transcription"]( + lambda: self.services["transcription"]( task.audio, model_index=task.transcription.execution.index, suppress_blank=False, @@ -536,7 +536,7 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None: try: if isinstance(task.diarization.execution, LocalExecution): out = await time_and_tell_async( - self.services["diarization"]( + lambda: self.services["diarization"]( waveform=task.audio, audio_duration=task.duration, oracle_num_speakers=task.diarization.num_speakers, @@ -601,9 +601,9 @@ def process_post_processing(self, task: ASRTask) -> None: if task.multi_channel: utterances, process_time = time_and_tell( - self.services[ - "post_processing" - ].multi_channel_speaker_mapping(task.transcription.result), + self.services["post_processing"].multi_channel_speaker_mapping( + task.transcription.result + ), func_name="multi_channel_speaker_mapping", debug_mode=self.debug_mode, ) @@ -621,9 +621,7 @@ def process_post_processing(self, task: ASRTask) -> None: if task.diarization.execution is not None: utterances, process_time = time_and_tell( - self.services[ - "post_processing" - ].single_channel_speaker_mapping( + self.services["post_processing"].single_channel_speaker_mapping( transcript_segments=formatted_segments, speaker_timestamps=task.diarization.result, word_timestamps=task.word_timestamps, @@ -636,9 +634,7 @@ def process_post_processing(self, task: ASRTask) -> None: utterances = formatted_segments final_utterances, process_time = time_and_tell( - self.services[ - "post_processing" - ].final_processing_before_returning( + self.services["post_processing"].final_processing_before_returning( utterances=utterances, offset_start=task.offset_start, timestamps_format=task.timestamps_format, diff --git a/src/wordcab_transcribe/utils.py b/src/wordcab_transcribe/utils.py index 43a0c77..fcb752a 100644 --- a/src/wordcab_transcribe/utils.py +++ b/src/wordcab_transcribe/utils.py @@ -433,7 +433,7 @@ def format_segments(transcription_output: TranscriptionOutput) -> List[Utterance probability=word.probability, ) for word in segment.words - ] + ], ) for segment in transcription_output.segments ]