Skip to content

Commit

Permalink
fix output multi-channel
Browse files Browse the repository at this point in the history
  • Loading branch information
chainyo committed Oct 5, 2023
1 parent 9c7c513 commit fbc8e4c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 29 deletions.
16 changes: 16 additions & 0 deletions src/wordcab_transcribe/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ class Word(BaseModel):
probability: float


class MultiChannelSegment(NamedTuple):
"""Multi-channel segment model for the API."""

start: float
end: float
text: str
words: List[Word]
speaker: int


class Utterance(BaseModel):
"""Utterance model for the API."""

Expand Down Expand Up @@ -513,6 +523,12 @@ class DiarizationRequest(BaseModel):
num_speakers: int


class MultiChannelTranscriptionOutput(BaseModel):
"""Multi-channel transcription output model for the API."""

segments: List[MultiChannelSegment]


class TranscriptionOutput(BaseModel):
"""Transcription output model for the API."""

Expand Down
7 changes: 4 additions & 3 deletions src/wordcab_transcribe/services/post_processing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from wordcab_transcribe.models import (
DiarizationOutput,
DiarizationSegment,
MultiChannelTranscriptionOutput,
Timestamps,
TranscriptionOutput,
Utterance,
Expand Down Expand Up @@ -75,13 +76,13 @@ def single_channel_speaker_mapping(
return utterances

def multi_channel_speaker_mapping(
self, multi_channel_segments: List[TranscriptionOutput]
self, multi_channel_segments: List[MultiChannelTranscriptionOutput]
) -> TranscriptionOutput:
"""
Run the multi-channel post-processing functions on the inputs by merging the segments based on the timestamps.
Args:
multi_channel_segments (List[TranscriptionOutput]):
multi_channel_segments (List[MultiChannelTranscriptionOutput]):
List of segments from multi speakers.
Returns:
Expand All @@ -93,7 +94,7 @@ def multi_channel_speaker_mapping(
for segment in output.segments
for word in segment.words
]
words_with_speaker_mapping.sort(key=lambda _, word: word.start)
words_with_speaker_mapping.sort(key=lambda x: x[1].start)

utterances: List[Utterance] = self.reconstruct_multi_channel_utterances(
words_with_speaker_mapping
Expand Down
43 changes: 17 additions & 26 deletions src/wordcab_transcribe/services/transcribe_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
from loguru import logger
from tensorshare import Backend, TensorShare

from wordcab_transcribe.models import TranscriptionOutput
from wordcab_transcribe.models import (
MultiChannelSegment,
MultiChannelTranscriptionOutput,
TranscriptionOutput,
Word,
)


class FasterWhisperModel(NamedTuple):
Expand Down Expand Up @@ -312,7 +317,7 @@ def multi_channel(
no_speech_threshold: float = 0.6,
condition_on_previous_text: bool = False,
prompt: Optional[str] = None,
) -> TranscriptionOutput:
) -> MultiChannelTranscriptionOutput:
"""
Transcribe an audio file using the faster-whisper original pipeline.
Expand Down Expand Up @@ -342,7 +347,7 @@ def multi_channel(
prompt (Optional[str]): Initial prompt to use for the generation.
Returns:
TranscriptionOutput: Transcription output.
MultiChannelTranscriptionOutput: Multi-channel transcription segments in a list.
"""
if isinstance(audio, torch.Tensor):
_audio = audio.numpy()
Expand Down Expand Up @@ -374,27 +379,13 @@ def multi_channel(
)

for segment in segments:
segment_dict: dict = {
"start": None,
"end": None,
"text": segment.text,
"words": [],
"speaker": speaker_id,
}

for word in segment.words:
segment_dict["words"].append(
{
"start": word.start,
"end": word.end,
"word": word.word,
"probability": word.probability,
}
)

segment_dict["start"] = segment_dict["words"][0]["start"]
segment_dict["end"] = segment_dict["words"][-1]["end"]

final_segments.append(segment_dict)
_segment = MultiChannelSegment(
start=segment.start,
end=segment.end,
text=segment.text,
words=[Word(**word._asdict()) for word in segment.words],
speaker=speaker_id,
)
final_segments.append(_segment)

return TranscriptionOutput(segments=final_segments)
return MultiChannelTranscriptionOutput(segments=final_segments)

0 comments on commit fbc8e4c

Please sign in to comment.