Skip to content

Commit

Permalink
fix logging and Word schema
Browse files Browse the repository at this point in the history
  • Loading branch information
chainyo committed Oct 3, 2023
1 parent cb0551d commit 91b0ff2
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 15 deletions.
8 changes: 5 additions & 3 deletions src/wordcab_transcribe/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def time_and_tell(
The appropriate wrapper for the function.
"""
start_time = time.time()
result = func()
result = func
process_time = time.time() - start_time

if debug_mode:
Expand All @@ -121,11 +121,13 @@ async def time_and_tell_async(
start_time = time.time()

if asyncio.iscoroutinefunction(func) or asyncio.iscoroutine(func):
result = await func()
result = await func
else:
loop = asyncio.get_event_loop()
if isinstance(func, partial):
result = await loop.run_in_executor(None, func.func, *func.args, **func.keywords)
result = await loop.run_in_executor(
None, func.func, *func.args, **func.keywords
)
else:
result = await loop.run_in_executor(None, func)

Expand Down
11 changes: 10 additions & 1 deletion src/wordcab_transcribe/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union

from faster_whisper.transcribe import Segment, Word
from faster_whisper.transcribe import Segment
from pydantic import BaseModel, field_validator
from tensorshare import TensorShare

Expand All @@ -44,6 +44,15 @@ class Timestamps(str, Enum):
hour_minute_second = "hms"


class Word(BaseModel):
    """Word model for the API.

    Pydantic schema describing a single word together with its timing
    and probability. All four fields are required (none has a default).
    """

    # Text of the word.
    word: str
    # Start position of the word — presumably seconds from the beginning
    # of the audio; TODO confirm the unit against the producer.
    start: float
    # End position of the word — presumably the same unit as `start`;
    # TODO confirm.
    end: float
    # Probability attached to the word — presumably the transcription
    # model's confidence in [0, 1]; TODO confirm the range.
    probability: float

class Utterance(BaseModel):
"""Utterance model for the API."""

Expand Down
23 changes: 13 additions & 10 deletions src/wordcab_transcribe/services/asr_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
try:
if isinstance(task.transcription.execution, LocalExecution):
out = await time_and_tell_async(
lambda: self.services["transcription"](
self.services["transcription"](
task.audio,
model_index=task.transcription.execution.index,
suppress_blank=False,
Expand Down Expand Up @@ -497,7 +497,7 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
**task.transcription.options.model_dump(),
)
out = await time_and_tell_async(
lambda: self.remote_transcription(
self.remote_transcription(
url=task.transcription.execution.url,
data=data,
),
Expand Down Expand Up @@ -536,7 +536,7 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
try:
if isinstance(task.diarization.execution, LocalExecution):
out = await time_and_tell_async(
lambda: self.services["diarization"](
self.services["diarization"](
waveform=task.audio,
audio_duration=task.duration,
oracle_num_speakers=task.diarization.num_speakers,
Expand All @@ -557,7 +557,7 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
num_speakers=task.diarization.num_speakers,
)
out = await time_and_tell_async(
lambda: self.remote_diarization(
self.remote_diarization(
url=task.diarization.execution.url,
data=data,
),
Expand Down Expand Up @@ -601,7 +601,7 @@ def process_post_processing(self, task: ASRTask) -> None:

if task.multi_channel:
utterances, process_time = time_and_tell(
lambda: self.services[
self.services[
"post_processing"
].multi_channel_speaker_mapping(task.transcription.result),
func_name="multi_channel_speaker_mapping",
Expand All @@ -611,7 +611,7 @@ def process_post_processing(self, task: ASRTask) -> None:

else:
formatted_segments, process_time = time_and_tell(
lambda: format_segments(
format_segments(
transcription_output=task.transcription.result,
),
func_name="format_segments",
Expand All @@ -621,7 +621,7 @@ def process_post_processing(self, task: ASRTask) -> None:

if task.diarization.execution is not None:
utterances, process_time = time_and_tell(
lambda: self.services[
self.services[
"post_processing"
].single_channel_speaker_mapping(
transcript_segments=formatted_segments,
Expand All @@ -636,7 +636,7 @@ def process_post_processing(self, task: ASRTask) -> None:
utterances = formatted_segments

final_utterances, process_time = time_and_tell(
lambda: self.services[
self.services[
"post_processing"
].final_processing_before_returning(
utterances=utterances,
Expand Down Expand Up @@ -672,9 +672,10 @@ async def remote_transcription(
async with session.post(
url=f"{url}/api/v1/transcribe",
data=data.model_dump_json(),
headers={"Content-Type": "application/json"},
) as response:
if response.status != 200:
raise Exception(response.detail)
raise Exception(response.status)
else:
return TranscriptionOutput(**await response.json())

Expand All @@ -688,9 +689,11 @@ async def remote_diarization(
async with session.post(
url=f"{url}/api/v1/diarize",
data=data.model_dump_json(),
headers={"Content-Type": "application/json"},
) as response:
if response.status != 200:
raise Exception(response.detail)
r = await response.json()
raise Exception(r["detail"])
else:
return DiarizationOutput(**await response.json())

Expand Down
11 changes: 10 additions & 1 deletion src/wordcab_transcribe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
Timestamps,
TranscriptionOutput,
Utterance,
Word,
)


Expand Down Expand Up @@ -424,7 +425,15 @@ def format_segments(transcription_output: TranscriptionOutput) -> List[Utterance
text=segment.text,
start=segment.start,
end=segment.end,
words=segment.words,
words=[
Word(
word=word.word,
start=word.start,
end=word.end,
probability=word.probability,
)
for word in segment.words
]
)
for segment in transcription_output.segments
]
Expand Down

0 comments on commit 91b0ff2

Please sign in to comment.