Skip to content

Commit

Permalink
update process
Browse files Browse the repository at this point in the history
  • Loading branch information
chainyo committed Oct 3, 2023
1 parent 5fa454c commit cb0551d
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 45 deletions.
13 changes: 12 additions & 1 deletion src/wordcab_transcribe/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

"""Logging module to add a logging middleware to the Wordcab Transcribe API."""

import asyncio
import sys
import time
import uuid
from functools import partial
from typing import Any, Awaitable, Callable, Tuple

from loguru import logger
Expand Down Expand Up @@ -117,7 +119,16 @@ async def time_and_tell_async(
The appropriate wrapper for the function.
"""
start_time = time.time()
result = await func()

if asyncio.iscoroutinefunction(func) or asyncio.iscoroutine(func):
result = await func()
else:
loop = asyncio.get_event_loop()
if isinstance(func, partial):
result = await loop.run_in_executor(None, func.func, *func.args, **func.keywords)
else:
result = await loop.run_in_executor(None, func)

process_time = time.time() - start_time

if debug_mode:
Expand Down
73 changes: 35 additions & 38 deletions src/wordcab_transcribe/services/asr_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,11 @@ async def process_input( # noqa: C901
if isinstance(task.transcription.result, ProcessException):
return task.transcription.result

await self.process_post_processing(task)
await asyncio.get_event_loop().run_in_executor(
None,
self.process_post_processing,
task,
)

if isinstance(task.post_processing.result, ProcessException):
return task.post_processing.result
Expand Down Expand Up @@ -464,21 +468,18 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
"""
try:
if isinstance(task.transcription.execution, LocalExecution):
out = asyncio.get_event_loop().run_in_executor(
None,
time_and_tell(
lambda: self.services["transcription"](
task.audio,
model_index=task.transcription.execution.index,
suppress_blank=False,
word_timestamps=True,
**task.transcription.options.model_dump(),
),
func_name="transcription",
debug_mode=debug_mode,
out = await time_and_tell_async(
lambda: self.services["transcription"](
task.audio,
model_index=task.transcription.execution.index,
suppress_blank=False,
word_timestamps=True,
**task.transcription.options.model_dump(),
),
func_name="transcription",
debug_mode=debug_mode,
)
result, process_time = await out
result, process_time = out

elif isinstance(task.transcription.execution, RemoteExecution):
if isinstance(task.audio, list):
Expand All @@ -495,14 +496,15 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
audio=ts,
**task.transcription.options.model_dump(),
)
result, process_time = await time_and_tell_async(
out = await time_and_tell_async(
lambda: self.remote_transcription(
url=task.transcription.execution.url,
data=data,
),
func_name="transcription",
debug_mode=debug_mode,
)
result, process_time = out

else:
raise NotImplementedError("No execution method specified.")
Expand Down Expand Up @@ -533,21 +535,18 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
"""
try:
if isinstance(task.diarization.execution, LocalExecution):
out = asyncio.get_event_loop().run_in_executor(
None,
time_and_tell(
lambda: self.services["diarization"](
waveform=task.audio,
audio_duration=task.duration,
oracle_num_speakers=task.diarization.num_speakers,
model_index=task.diarization.execution.index,
vad_service=self.services["vad"],
),
func_name="diarization",
debug_mode=debug_mode,
out = await time_and_tell_async(
lambda: self.services["diarization"](
waveform=task.audio,
audio_duration=task.duration,
oracle_num_speakers=task.diarization.num_speakers,
model_index=task.diarization.execution.index,
vad_service=self.services["vad"],
),
func_name="diarization",
debug_mode=debug_mode,
)
result, process_time = await out
result, process_time = out

elif isinstance(task.diarization.execution, RemoteExecution):
ts = TensorShare.from_dict({"audio": task.audio}, backend=Backend.TORCH)
Expand All @@ -557,14 +556,15 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
duration=task.duration,
num_speakers=task.diarization.num_speakers,
)
result, process_time = await time_and_tell_async(
out = await time_and_tell_async(
lambda: self.remote_diarization(
url=task.diarization.execution.url,
data=data,
),
func_name="diarization",
debug_mode=debug_mode,
)
result, process_time = out

elif task.diarization.execution is None:
result = None
Expand All @@ -586,7 +586,7 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:

return None

async def process_post_processing(self, task: ASRTask) -> None:
def process_post_processing(self, task: ASRTask) -> None:
"""
Process a task of post-processing.
Expand All @@ -598,10 +598,9 @@ async def process_post_processing(self, task: ASRTask) -> None:
"""
try:
total_post_process_time = 0
diarization = False if task.diarization.execution is None else True

if task.multi_channel:
utterances, process_time = await time_and_tell_async(
utterances, process_time = time_and_tell(
lambda: self.services[
"post_processing"
].multi_channel_speaker_mapping(task.transcription.result),
Expand All @@ -611,7 +610,7 @@ async def process_post_processing(self, task: ASRTask) -> None:
total_post_process_time += process_time

else:
formatted_segments, process_time = await time_and_tell_async(
formatted_segments, process_time = time_and_tell(
lambda: format_segments(
transcription_output=task.transcription.result,
),
Expand All @@ -620,8 +619,8 @@ async def process_post_processing(self, task: ASRTask) -> None:
)
total_post_process_time += process_time

if diarization:
utterances, process_time = await time_and_tell_async(
if task.diarization.execution is not None:
utterances, process_time = time_and_tell(
lambda: self.services[
"post_processing"
].single_channel_speaker_mapping(
Expand All @@ -636,13 +635,11 @@ async def process_post_processing(self, task: ASRTask) -> None:
else:
utterances = formatted_segments

final_utterances, process_time = await time_and_tell_async(
final_utterances, process_time = time_and_tell(
lambda: self.services[
"post_processing"
].final_processing_before_returning(
utterances=utterances,
diarization=diarization,
multi_channel=task.multi_channel,
offset_start=task.offset_start,
timestamps_format=task.timestamps_format,
word_timestamps=task.word_timestamps,
Expand Down
6 changes: 0 additions & 6 deletions src/wordcab_transcribe/services/post_processing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,6 @@ def reconstruct_multi_channel_utterances(
def final_processing_before_returning(
self,
utterances: List[Utterance],
diarization: bool,
multi_channel: bool,
offset_start: Union[float, None],
timestamps_format: Timestamps,
word_timestamps: bool,
Expand All @@ -338,10 +336,6 @@ def final_processing_before_returning(
Args:
utterances (List[Utterance]):
List of utterances.
diarization (bool):
Whether diarization is enabled.
multi_channel (bool):
Whether multi-channel is enabled.
offset_start (Union[float, None]):
Offset start.
timestamps_format (Timestamps):
Expand Down

0 comments on commit cb0551d

Please sign in to comment.