Skip to content

Commit

Permalink
Merge pull request #181 from mobiusml/js_diar_transcription
Browse files Browse the repository at this point in the history
Speaker diarization with whisper transcription
  • Loading branch information
Jiltseb authored Sep 26, 2024
2 parents 1ca768e + 35364bb commit 9eb34b3
Show file tree
Hide file tree
Showing 12 changed files with 1,967 additions and 10 deletions.
6 changes: 6 additions & 0 deletions aana/core/models/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ class AsrWord(BaseModel):
Attributes:
word (str): The word text.
speaker (str | None): Speaker label for the word.
time_interval (TimeInterval): Time interval of the word.
alignment_confidence (float): Alignment confidence of the word, >= 0.0 and <= 1.0.
"""

word: str = Field(description="The word text")
speaker: str | None = Field(None, description="Speaker label for the word")
time_interval: TimeInterval = Field(description="Time interval of the word")
alignment_confidence: float = Field(
ge=0.0, le=1.0, description="Alignment confidence of the word"
Expand All @@ -52,6 +54,7 @@ def from_whisper(cls, whisper_word: WhisperWord) -> "AsrWord":
AsrWord: The converted AsrWord.
"""
return cls(
speaker=None,
word=whisper_word.word,
time_interval=TimeInterval(start=whisper_word.start, end=whisper_word.end),
alignment_confidence=whisper_word.probability,
Expand All @@ -73,6 +76,7 @@ class AsrSegment(BaseModel):
confidence (float | None): Confidence of the segment.
no_speech_confidence (float | None): Chance of being a silence segment.
words (list[AsrWord]): List of words in the segment. Default is [].
speaker (str | None): Speaker label. Default is None.
"""

text: str = Field(description="The text of the segment (transcript/translation)")
Expand All @@ -86,6 +90,7 @@ class AsrSegment(BaseModel):
words: list[AsrWord] = Field(
description="List of words in the segment", default_factory=list
)
speaker: str | None = Field(None, description="speaker label of the segment")

@classmethod
def from_whisper(cls, whisper_segment: WhisperSegment) -> "AsrSegment":
Expand Down Expand Up @@ -116,6 +121,7 @@ def from_whisper(cls, whisper_segment: WhisperSegment) -> "AsrSegment":
confidence=confidence,
no_speech_confidence=no_speech_confidence,
words=words,
speaker=None,
)

model_config = ConfigDict(
Expand Down
24 changes: 20 additions & 4 deletions aana/deployments/pyannote_speaker_diarization_deployment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, TypedDict

import torch
from huggingface_hub.utils import GatedRepoError
from pyannote.audio import Pipeline
from pyannote.core import Annotation
from pydantic import BaseModel, ConfigDict, Field
Expand All @@ -15,6 +16,7 @@
from aana.core.models.time import TimeInterval
from aana.deployments.base_deployment import BaseDeployment
from aana.exceptions.runtime import InferenceException
from aana.processors.speaker import combine_homogeneous_speaker_diarization_segments


class SpeakerDiarizationOutput(TypedDict):
Expand Down Expand Up @@ -69,9 +71,17 @@ async def apply_config(self, config: dict[str, Any]):
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42)

# load model using pyannote Pipeline
self.diarize_model = Pipeline.from_pretrained(self.model_id)
self.diarize_model.to(torch.device(self.device))
try:
# load model using pyannote Pipeline
self.diarize_model = Pipeline.from_pretrained(self.model_id)

if self.diarize_model:
self.diarize_model.to(torch.device(self.device))

except Exception as e:
raise GatedRepoError(
message=f"This repository is private and requires a token to accept user conditions and access models in {self.model_id} pipeline."
) from e

async def __inference(
self, audio: Audio, params: PyannoteSpeakerDiarizationParams
Expand Down Expand Up @@ -134,4 +144,10 @@ async def diarize(
)
)

return SpeakerDiarizationOutput(segments=speaker_diarization_segments)
# Combine homogeneous speaker segments.
processed_speaker_diarization_segments = (
combine_homogeneous_speaker_diarization_segments(
speaker_diarization_segments
)
)
return SpeakerDiarizationOutput(segments=processed_speaker_diarization_segments)
Loading

0 comments on commit 9eb34b3

Please sign in to comment.