Skip to content

Commit

Permalink
remove auto creation of all the services
Browse files: browse the repository at this point in the history
  • Loading branch information
chainyo committed Oct 3, 2023
1 parent 91b0ff2 commit f8137af
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 27 deletions.
48 changes: 22 additions & 26 deletions src/wordcab_transcribe/services/asr_service.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -207,21 +207,6 @@ def __init__(
super().__init__()

self.services: dict = {
"transcription": TranscribeService(
model_path=whisper_model,
compute_type=compute_type,
device=self.device,
device_index=self.device_index,
extra_languages=extra_languages,
extra_languages_model_paths=extra_languages_model_paths,
),
"diarization": DiarizeService(
device=self.device,
device_index=self.device_index,
window_lengths=window_lengths,
shift_lengths=shift_lengths,
multiscale_weights=multiscale_weights,
),
"post_processing": PostProcessingService(),
"vad": VadService(),
}
Expand All @@ -241,12 +226,27 @@ def __init__(
)
else:
self.use_remote_transcription = False
self.services["transcription"] = TranscribeService(
model_path=whisper_model,
compute_type=compute_type,
device=self.device,
device_index=self.device_index,
extra_languages=extra_languages,
extra_languages_model_paths=extra_languages_model_paths,
)

if diarize_server_urls is not None:
self.use_remote_diarization = True
self.diarization_url_handler = URLService(remote_urls=diarize_server_urls)
else:
self.use_remote_diarization = False
self.services["diarization"] = DiarizeService(
device=self.device,
device_index=self.device_index,
window_lengths=window_lengths,
shift_lengths=shift_lengths,
multiscale_weights=multiscale_weights,
)

self.debug_mode = debug_mode

Expand Down Expand Up @@ -469,7 +469,7 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
try:
if isinstance(task.transcription.execution, LocalExecution):
out = await time_and_tell_async(
self.services["transcription"](
lambda: self.services["transcription"](
task.audio,
model_index=task.transcription.execution.index,
suppress_blank=False,
Expand Down Expand Up @@ -536,7 +536,7 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
try:
if isinstance(task.diarization.execution, LocalExecution):
out = await time_and_tell_async(
self.services["diarization"](
lambda: self.services["diarization"](
waveform=task.audio,
audio_duration=task.duration,
oracle_num_speakers=task.diarization.num_speakers,
Expand Down Expand Up @@ -601,9 +601,9 @@ def process_post_processing(self, task: ASRTask) -> None:

if task.multi_channel:
utterances, process_time = time_and_tell(
self.services[
"post_processing"
].multi_channel_speaker_mapping(task.transcription.result),
self.services["post_processing"].multi_channel_speaker_mapping(
task.transcription.result
),
func_name="multi_channel_speaker_mapping",
debug_mode=self.debug_mode,
)
Expand All @@ -621,9 +621,7 @@ def process_post_processing(self, task: ASRTask) -> None:

if task.diarization.execution is not None:
utterances, process_time = time_and_tell(
self.services[
"post_processing"
].single_channel_speaker_mapping(
self.services["post_processing"].single_channel_speaker_mapping(
transcript_segments=formatted_segments,
speaker_timestamps=task.diarization.result,
word_timestamps=task.word_timestamps,
Expand All @@ -636,9 +634,7 @@ def process_post_processing(self, task: ASRTask) -> None:
utterances = formatted_segments

final_utterances, process_time = time_and_tell(
self.services[
"post_processing"
].final_processing_before_returning(
self.services["post_processing"].final_processing_before_returning(
utterances=utterances,
offset_start=task.offset_start,
timestamps_format=task.timestamps_format,
Expand Down
2 changes: 1 addition & 1 deletion src/wordcab_transcribe/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -433,7 +433,7 @@ def format_segments(transcription_output: TranscriptionOutput) -> List[Utterance
probability=word.probability,
)
for word in segment.words
]
],
)
for segment in transcription_output.segments
]
Expand Down

0 comments on commit f8137af

Please sign in to comment.