Fix dual_channel bug #268

Merged · 2 commits · Oct 5, 2023
16 changes: 16 additions & 0 deletions src/wordcab_transcribe/models.py
@@ -53,6 +53,16 @@ class Word(BaseModel):
     probability: float


+class MultiChannelSegment(NamedTuple):
+    """Multi-channel segment model for the API."""
+
+    start: float
+    end: float
+    text: str
+    words: List[Word]
+    speaker: int
+
+
 class Utterance(BaseModel):
     """Utterance model for the API."""

@@ -513,6 +523,12 @@ class DiarizationRequest(BaseModel):
     num_speakers: int


+class MultiChannelTranscriptionOutput(BaseModel):
+    """Multi-channel transcription output model for the API."""
+
+    segments: List[MultiChannelSegment]
+
+
 class TranscriptionOutput(BaseModel):
     """Transcription output model for the API."""

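Taken together, the two additions give multi-channel results their own typed container instead of reusing TranscriptionOutput. A minimal sketch of how the new models nest; the Word field names (word, start, end, probability) are assumed from the rest of models.py rather than shown in this diff:

from wordcab_transcribe.models import (
    MultiChannelSegment,
    MultiChannelTranscriptionOutput,
    Word,
)

# Hypothetical two-word utterance on channel 0.
words = [
    Word(word="Hello", start=0.0, end=0.4, probability=0.98),
    Word(word="world", start=0.5, end=0.9, probability=0.95),
]
segment = MultiChannelSegment(
    start=words[0].start,
    end=words[-1].end,
    text="Hello world",
    words=words,
    speaker=0,  # per-channel speaker id, as in transcribe_service.py
)
output = MultiChannelTranscriptionOutput(segments=[segment])
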
7 changes: 4 additions & 3 deletions src/wordcab_transcribe/services/post_processing_service.py
@@ -24,6 +24,7 @@
 from wordcab_transcribe.models import (
     DiarizationOutput,
     DiarizationSegment,
+    MultiChannelTranscriptionOutput,
     Timestamps,
     TranscriptionOutput,
     Utterance,
@@ -75,13 +76,13 @@ def single_channel_speaker_mapping(
         return utterances

     def multi_channel_speaker_mapping(
-        self, multi_channel_segments: List[TranscriptionOutput]
+        self, multi_channel_segments: List[MultiChannelTranscriptionOutput]
     ) -> TranscriptionOutput:
         """
         Run the multi-channel post-processing functions on the inputs by merging the segments based on the timestamps.

         Args:
-            multi_channel_segments (List[TranscriptionOutput]):
+            multi_channel_segments (List[MultiChannelTranscriptionOutput]):
                 List of segments from multi speakers.

         Returns:
@@ -93,7 +94,7 @@ def multi_channel_speaker_mapping(
             for segment in output.segments
             for word in segment.words
         ]
-        words_with_speaker_mapping.sort(key=lambda _, word: word.start)
+        words_with_speaker_mapping.sort(key=lambda x: x[1].start)

         utterances: List[Utterance] = self.reconstruct_multi_channel_utterances(
             words_with_speaker_mapping
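The one-line sort change is the heart of the dual_channel fix: list.sort passes each element to the key function as a single argument, so the old two-parameter lambda raised a TypeError on the (speaker_id, word) tuples instead of sorting them. A minimal repro with stand-in word objects:

from types import SimpleNamespace

# Stand-ins for the (speaker_id, word) tuples built by the comprehension above.
pairs = [
    (1, SimpleNamespace(start=2.0)),
    (0, SimpleNamespace(start=0.5)),
]

# Old key, lambda _, word: word.start, fails because sort() calls the key
# with one argument per element:
#   TypeError: <lambda>() missing 1 required positional argument: 'word'

pairs.sort(key=lambda x: x[1].start)  # fixed: index into the tuple instead
assert [speaker for speaker, _ in pairs] == [0, 1]
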
43 changes: 17 additions & 26 deletions src/wordcab_transcribe/services/transcribe_service.py
@@ -26,7 +26,12 @@
 from loguru import logger
 from tensorshare import Backend, TensorShare

-from wordcab_transcribe.models import TranscriptionOutput
+from wordcab_transcribe.models import (
+    MultiChannelSegment,
+    MultiChannelTranscriptionOutput,
+    TranscriptionOutput,
+    Word,
+)


class FasterWhisperModel(NamedTuple):
@@ -312,7 +317,7 @@ def multi_channel(
         no_speech_threshold: float = 0.6,
         condition_on_previous_text: bool = False,
         prompt: Optional[str] = None,
-    ) -> TranscriptionOutput:
+    ) -> MultiChannelTranscriptionOutput:
         """
         Transcribe an audio file using the faster-whisper original pipeline.

@@ -342,7 +347,7 @@
             prompt (Optional[str]): Initial prompt to use for the generation.

         Returns:
-            TranscriptionOutput: Transcription output.
+            MultiChannelTranscriptionOutput: Multi-channel transcription segments in a list.
         """
         if isinstance(audio, torch.Tensor):
             _audio = audio.numpy()
@@ -374,27 +379,13 @@
         )

         for segment in segments:
-            segment_dict: dict = {
-                "start": None,
-                "end": None,
-                "text": segment.text,
-                "words": [],
-                "speaker": speaker_id,
-            }
-
-            for word in segment.words:
-                segment_dict["words"].append(
-                    {
-                        "start": word.start,
-                        "end": word.end,
-                        "word": word.word,
-                        "score": word.probability,
-                    }
-                )
-
-            segment_dict["start"] = segment_dict["words"][0]["start"]
-            segment_dict["end"] = segment_dict["words"][-1]["end"]
-
-            final_segments.append(segment_dict)
+            _segment = MultiChannelSegment(
+                start=segment.start,
+                end=segment.end,
+                text=segment.text,
+                words=[Word(**word._asdict()) for word in segment.words],
+                speaker=speaker_id,
+            )
+            final_segments.append(_segment)

-        return TranscriptionOutput(segments=final_segments)
+        return MultiChannelTranscriptionOutput(segments=final_segments)
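Beyond the return-type change, the rewrite removes a silent field mismatch: the old dict stored the word confidence under a "score" key, while the Word model in models.py expects probability. A sketch of the conversion step in isolation, assuming faster-whisper word results are NamedTuples with start/end/word/probability fields (which the _asdict() call implies):

from typing import NamedTuple

from wordcab_transcribe.models import Word

# Stand-in for a faster-whisper word result (assumed NamedTuple layout).
class FWWord(NamedTuple):
    start: float
    end: float
    word: str
    probability: float

fw_word = FWWord(start=0.0, end=0.4, word="Hello", probability=0.98)

# Old code renamed probability to "score", a key the Word model never defined:
# {"start": 0.0, "end": 0.4, "word": "Hello", "score": 0.98}

# New code: the field names line up one-to-one, so _asdict() feeds the
# pydantic constructor directly.
word = Word(**fw_word._asdict())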