This site contains the project documentation for the leb project used for the Sunbird AI Language Projects.
Welcome to the Leb project documentation!
This documentation serves as the official guide for the Leb project, which is part of the Sunbird AI Language Projects. The goal of this documentation is to provide you with comprehensive information on how to use the Leb project effectively.
Quickly find what you're looking for depending on your use case by looking at the different sections and subsections.
This part of the project documentation focuses on an information-oriented approach. Use it as a reference for the technical implementation of the leb project code.
Leb, inspired by the Luo word for 'language,' is a project dedicated to the seamless integration of Sunbird AI Language Technology. Our goal is to empower developers to effortlessly create machine learning models for Neural Machine Translation (NMT), Speech-to-Text (STT), Text-to-Speech (TTS), and other language-related applications.
By drawing inspiration from the Luo concept of 'language' itself, Project Leb is envisioned as a springboard for connecting ideas and cultures across Africa's diverse range of tongues and dialects. Just as languages connect people, this technology would connect languages - old and new - through an inclusive platform optimized for integration, accessibility, and human-centric design.
This part of the project documentation focuses on an understanding-oriented approach. You'll get a chance to read about the background of the project, as well as reasoning about how it was implemented.
Note: Expand this section by considering the following points:
This part of the project documentation focuses on a problem-oriented approach. You'll tackle common tasks that you might have, with the help of the code provided in this project.
+ +
+import sys
+sys.path.append('../..')
+import leb.dataset
+import leb.utils
+import yaml
+
+
+set up the configs
+
+yaml_config = '''
+huggingface_load:
+ path: Sunbird/salt
+ split: train
+ name: text-all
+source:
+ type: text
+ language: eng
+ preprocessing:
+ - prefix_target_language
+target:
+ type: text
+ language: [lug, ach]
+'''
+
+config = yaml.safe_load(yaml_config)
+ds = leb.dataset.create(config)
+list(ds.take(5))
+
+
+output
+[{'source': '>>lug<< Eggplants always grow best under warm conditions.',
+ 'target': 'Bbiringanya lubeerera asinga kukulira mu mbeera ya bugumu'},
+ {'source': '>>ach<< Eggplants always grow best under warm conditions.',
+ 'target': 'Bilinyanya pol kare dongo maber ka lyeto tye'},
+ {'source': '>>lug<< Farmland is sometimes a challenge to farmers.',
+ 'target': "Ettaka ly'okulimirako n'okulundirako ebiseera ebimu kisoomooza abalimi"},
+ {'source': '>>ach<< Farmland is sometimes a challenge to farmers.',
+ 'target': 'Ngom me pur i kare mukene obedo peko madit bot lupur'},
+ {'source': '>>lug<< Farmers should be encouraged to grow more coffee.',
+ 'target': 'Abalimi balina okukubirizibwa okwongera okulima emmwanyi'}]
+
+
This is how a basic data loader works.
+ +This tutorial provides a step-by-step guide on how to perform multilingual Automatic Speech Recognition (ASR) training for Luganda and English languages using the Leb module.
+Before getting started, ensure that you have the following prerequisites:
+To begin, install the necessary dependencies by running the following commands:
+!pip install -q jiwer evaluate
+!pip install -qU accelerate
+!pip install -q transformers[torch]
+!git clone https://github.com/jqug/leb.git
+!pip install -qr leb/requirements.txt
+!pip install -q mlflow psutil pynvml
+
+These commands will install the required libraries, including Jiwer, Evaluate, Accelerate, Transformers, MLflow, and the Leb module.
+Create a YAML configuration file named asr_config.yml with the necessary settings for your training. Here's an example configuration:
+train:
+ source:
+ language: [luganda, english]
+ # Add other training dataset configurations
+
+validation:
+ source:
+ language: [luganda, english]
+ # Add other validation dataset configurations
+
+pretrained_model: "facebook/wav2vec2-large-xlsr-53"
+pretrained_adapter: null
+
+Wav2Vec2ForCTC_args:
+ adapter_model_name: "wav2vec2"
+
+training_args:
+ output_dir: "luganda_english_asr"
+ # Add other training arguments
+
+Load the configuration file in your Python script.
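A minimal sketch of this step, assuming the configuration was saved as asr_config.yml in the working directory (the file name used above):

import yaml

# Load the training configuration created above.
with open("asr_config.yml") as f:
    config = yaml.safe_load(f)

print(config["pretrained_model"])  # "facebook/wav2vec2-large-xlsr-53"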
To use the trained model for inference, follow these steps:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model = Wav2Vec2ForCTC.from_pretrained("path/to/trained/model")
processor = Wav2Vec2Processor.from_pretrained("path/to/processor")
+
+
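A sketch of how these objects could then be used for greedy CTC decoding; this assumes a variable audio holding a 16 kHz mono waveform (for example loaded with librosa) and is not the only possible inference setup:

import torch

# Convert the raw waveform into model inputs.
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# Greedy decoding of the CTC output.
predicted_ids = torch.argmax(logits, dim=-1)
text = processor.batch_decode(predicted_ids)[0]
print(text)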
+ Speaker Diarization is the process of partitioning an audio stream into homogeneous segments according to the identity of the speaker. It answers the question "who spoke when?" in a given audio or video recording. This is a crucial step in many speech processing applications, such as transcription, speaker recognition, and meeting analysis.
Speaker diarization at Sunbird is performed using pyannote's speaker-diarization-3.0 pipeline as the main tool for identifying speakers and the segments that correspond to them, together with the Sunbird MMS models, which transcribe the speech in the audio.
+Setup and Installation
The libraries needed to run the speaker diarization pipeline efficiently and to compute the various metrics are installed and imported first.
+!pip install pyctcdecode
+!pip install kenlm
+!pip install jiwer
+!pip install huggingface-hub
+!pip install transformers
+!pip install pandas
+!pip install pyannote.audio
+!pip install onnxruntime
+
+
+import torch
+from huggingface_hub import hf_hub_download
+from transformers import (
+ Wav2Vec2ForCTC,
+ Wav2Vec2CTCTokenizer,
+ Wav2Vec2FeatureExtractor,
+ Wav2Vec2Processor,
+ Wav2Vec2ProcessorWithLM,
+ AutomaticSpeechRecognitionPipeline,
+ AutoProcessor
+)
+from pyctcdecode import build_ctcdecoder
+from jiwer import wer
+
+import os
+import csv
+import pandas as pd
+
+Loading Models and LM Heads
The Sunbird MMS is a Hugging Face repository with a variety of models and adapters optimized for transcription and translation of languages. Currently, the diarization workflow supports three languages: English, Luganda and Acholi.
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+lang_config = {
+ "ach": "Sunbird/sunbird-mms",
+ "lug": "Sunbird/sunbird-mms",
+ "eng": "Sunbird/sunbird-mms",
+}
+model_id = "Sunbird/sunbird-mms"
+model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)
+
+processor = AutoProcessor.from_pretrained(model_id)
+tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_id)
+
+tokenizer.set_target_lang("eng")
+model.load_adapter("eng_meta")
+
+feature_extractor = Wav2Vec2FeatureExtractor(
+ feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True
+)
+processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+vocab_dict = processor.tokenizer.get_vocab()
+sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
+
Within the Sunbird/sunbird-mms Hugging Face repository is a subfolder named language_model containing various language models used to improve transcription during decoding.
lm_file_name = "eng_3gram.bin"
+lm_file_subfolder = "language_model"
+lm_file = hf_hub_download(
+ repo_id=lang_config["eng"],
+ filename=lm_file_name,
+ subfolder=lm_file_subfolder,
+)
+
+decoder = build_ctcdecoder(
+ labels=list(sorted_vocab_dict.keys()),
+ kenlm_model_path=lm_file,
+)
+
+processor_with_lm = Wav2Vec2ProcessorWithLM(
+ feature_extractor=feature_extractor,
+ tokenizer=tokenizer,
+ decoder=decoder,
+)
+feature_extractor._set_processor_class("Wav2Vec2ProcessorWithLM")
+
The ASR pipeline is initialized with the pretrained Sunbird-mms model; the tokenizer, feature_extractor and decoder attributes of processor_with_lm; the device; and the chunk_length_s, stride_length_s and return_timestamps arguments.
pipe = AutomaticSpeechRecognitionPipeline(
+ model=model,
+ tokenizer=processor_with_lm.tokenizer, feature_extractor=processor_with_lm.feature_extractor,
+ decoder=processor_with_lm.decoder,
+ device=device,
+ chunk_length_s=10,
+ stride_length_s=(4, 2),
+ return_timestamps="word"
+)
+
+Performing a transcription
+ transcription = pipe("/content/Kibuuka_eng.mp3")
+
The resulting dictionary transcription will contain a text key with all the transcribed text, as well as a chunks key with the individual words and their timestamps, in the format below:
{
+ 'text' : 'Hello world',
+ 'chunks': [
+ {'text': 'Hello','timestamp': (0.5, 1.0)},
+ {'text': 'world','timestamp': (1.5, 2.0)}
+ ]
+}
+
+Imports
+from typing import Optional, Union
+import numpy as np
+from pyannote.audio import Pipeline
+import librosa
+
+Loading an audio file
+SAMPLE_RATE = 16000
+
+def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
+
+ try:
+ # librosa automatically resamples to the given sample rate (if necessary)
+ # and converts the signal to mono (by averaging channels)
+ audio, _ = librosa.load(file, sr=sr, mono=True, dtype=np.float32)
+ except Exception as e:
+ raise RuntimeError(f"Failed to load audio with librosa: {e}") from e
+
+ return audio
+
The load_audio function takes an audio file and a sampling rate as parameters. The sampling rate used for this speaker diarization is 16000 Hz. It should be the same sampling rate used when transcribing the audio with the Sunbird MMS, to ensure consistency with the output.
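For example, the recording transcribed earlier can be loaded like this (a usage sketch assuming the same file path as above):

# Load the recording used in the transcription example as 16 kHz mono audio.
audio = load_audio("/content/Kibuuka_eng.mp3")
print(audio.shape, audio.dtype)  # 1-D float32 waveform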
Diarization Pipeline
The DiarizationPipeline class is a custom class created to facilitate the diarization task. It is initialized with a pretrained model and can be called with an audio file or waveform to perform diarization.
It returns a pandas DataFrame with columns for the segment, label, speaker, start time, and end time of each speaker segment.
+class DiarizationPipeline:
+ def __init__(
+ self,
+ model_name="pyannote/speaker-diarization-3.0",
+ use_auth_token=None,
+ device: Optional[Union[str, torch.device]] = "cpu",
+ ):
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.model = Pipeline.from_pretrained(model_name,
+ use_auth_token=use_auth_token).to(device)
+
+ def __call__(
+ self,
+ audio: Union[str, np.ndarray],
+ min_speakers: Optional[int] = None,
+ max_speakers: Optional[int] = None
+ ) -> pd.DataFrame:
+
+ if isinstance(audio, str):
+ audio = load_audio(audio)
+ audio_data = {
+ 'waveform': torch.from_numpy(audio[None, :]),
+ 'sample_rate': SAMPLE_RATE
+ }
+ segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
+ diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
+ diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
+ diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
+ return diarize_df
+
+Segment
A class to represent a single segment of audio with a start time, end time and speaker label.
This class encapsulates the information about a segment of audio identified during speaker diarization, including when the segment starts, when it ends, and which speaker is speaking.
+class Segment:
+ def __init__(self, start, end, speaker=None):
+ self.start = start
+ self.end = end
+ self.speaker = speaker
+
+Assigning Speakers
This is the process of taking the transcribed chunks and assigning them to the speakers discovered by the speaker diarization pipeline.
In this function, the timestamps of the different chunks are compared against the start and end times of speakers in the DataFrame returned by the SpeakerDiarization pipeline, and the segments of the transcript are assigned speaker labels based on the overlap between the speech segments and the diarization data.
If there is no overlap, the fill_nearest parameter can be set to True, in which case the function assigns speakers to segments by finding the closest speaker in time.
The function takes parameters:
diarize_df: a pandas DataFrame returned by the DiarizationPipeline, containing the diarization information with columns such as start, end and speaker.
transcript_result: a dictionary with a chunks key that contains a list of transcript segments obtained from the ASR pipeline.
fill_nearest: defaults to False.
Returns:
An updated transcript_result with speakers assigned to each segment, in the form:
{
+ 'text':'Hello World',
+ 'chunks':[
+ {'text': 'Hello', 'timestamp': (0.5, 1.0), 'speaker': 0},
+ {'text': 'world', 'timestamp': (1.5, 2.0), 'speaker': 1}
+ ]
+}
+
+
+def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+
+ transcript_segments = transcript_result["chunks"]
+
+ for seg in transcript_segments:
+ # Calculate intersection and union between diarization segments and transcript segment
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], seg["timestamp"][1]) - np.maximum(diarize_df['start'], seg["timestamp"][0])
+ diarize_df['union'] = np.maximum(diarize_df['end'], seg["timestamp"][1]) - np.minimum(diarize_df['start'], seg["timestamp"][0])
+
+ # Filter out diarization segments with no overlap if fill_nearest is False
+ if not fill_nearest:
+ dia_tmp = diarize_df[diarize_df['intersection'] > 0]
+ else:
+ dia_tmp = diarize_df
+
+ # If there are overlapping segments, assign the speaker with the greatest overlap
+ if len(dia_tmp) > 0:
+ speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+ seg["speaker"] = speaker
+
+ return transcript_result
+
+
+Running the diarization model
+diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
+diarize_segments = diarize_model("/content/Kibuuka_eng.mp3", min_speakers=1, max_speakers=2)
+
+diarize_segments
+
+Sample Output
+ +output = assign_word_speakers(diarize_segments, transcription)
+output
+
+Sample Output after Assigning Speakers
+{'text': "this is the chitaka's podcast my husband and i will be letting in honor life as a couple husband and helper husband and wife as husband and wife marriage is not a new wild you enter into you don't become a new person you come with what you been working on it's easy to go through the first year of your marriage trying to knit pick the shortcomings of your partner now this is our first episode and it's a series of random reflections from our one year in marriage now we hope that as we share experiences and insights on our journey the you will be inspired to pursue the potion and purpose to your marriage so this is the chitaka'spodcast and these are random reflections when you are married",
+ 'chunks': [{'text': 'this',
+ 'timestamp': (2.42, 2.58),
+ 'speaker': 'SPEAKER_01'},
+ {'text': 'is', 'timestamp': (2.68, 2.72), 'speaker': 'SPEAKER_01'},
+ {'text': 'the', 'timestamp': (2.78, 2.84), 'speaker': 'SPEAKER_01'},
+ {'text': "chitaka's", 'timestamp': (2.9, 3.32), 'speaker': 'SPEAKER_01'},
+ {'text': 'podcast', 'timestamp': (3.38, 3.86), 'speaker': 'SPEAKER_01'},
+ {'text': 'my', 'timestamp': (4.4, 4.48), 'speaker': 'SPEAKER_01'},
+ {'text': 'husband', 'timestamp': (4.52, 4.72), 'speaker': 'SPEAKER_01'},
+ {'text': 'and', 'timestamp': (4.8, 4.86), 'speaker': 'SPEAKER_01'},
+ {'text': 'i', 'timestamp': (4.96, 4.98), 'speaker': 'SPEAKER_01'},
+ {'text': 'will', 'timestamp': (5.1, 5.22), 'speaker': 'SPEAKER_01'},
+ {'text': 'be', 'timestamp': (5.28, 5.32), 'speaker': 'SPEAKER_01'},
+ {'text': 'letting', 'timestamp': (5.38, 5.64), 'speaker': 'SPEAKER_01'},
+ {'text': 'in', 'timestamp': (5.82, 5.86), 'speaker': 'SPEAKER_01'},
+ {'text': 'honor', 'timestamp': (6.06, 6.32), 'speaker': 'SPEAKER_01'},
+ {'text': 'life', 'timestamp': (6.42, 6.7), 'speaker': 'SPEAKER_01'},
+ {'text': 'as', 'timestamp': (6.82, 6.9), 'speaker': 'SPEAKER_01'},
+ {'text': 'a', 'timestamp': (6.98, 7.0), 'speaker': 'SPEAKER_01'},
+ {'text': 'couple', 'timestamp': (7.14, 7.52), 'speaker': 'SPEAKER_01'},
+ {'text': 'husband', 'timestamp': (8.06, 8.36), 'speaker': 'SPEAKER_00'},
+ {'text': 'and', 'timestamp': (8.44, 8.5), 'speaker': 'SPEAKER_00'},
+ {'text': 'helper', 'timestamp': (8.64, 9.02), 'speaker': 'SPEAKER_00'},
+ {'text': 'husband', 'timestamp': (9.36, 9.68), 'speaker': 'SPEAKER_01'},
+ {'text': 'and', 'timestamp': (9.76, 9.84), 'speaker': 'SPEAKER_01'},
+ {'text': 'wife', 'timestamp': (9.94, 10.3), 'speaker': 'SPEAKER_01'},
+ {'text': 'as', 'timestamp': (11.06, 11.14), 'speaker': 'SPEAKER_01'},
+ {'text': 'husband', 'timestamp': (11.24, 11.56), 'speaker': 'SPEAKER_01'},
+ {'text': 'and', 'timestamp': (11.62, 11.7), 'speaker': 'SPEAKER_01'},
+ {'text': 'wife', 'timestamp': (11.76, 12.04), 'speaker': 'SPEAKER_01'},
+ {'text': 'marriage', 'timestamp': (12.48, 12.82), 'speaker': 'SPEAKER_00'},
+ {'text': 'is', 'timestamp': (12.88, 12.94), 'speaker': 'SPEAKER_00'},
+ {'text': 'not', 'timestamp': (13.12, 13.48), 'speaker': 'SPEAKER_00'},
+ {'text': 'a', 'timestamp': (13.78, 13.8), 'speaker': 'SPEAKER_00'},
+ {'text': 'new', 'timestamp': (13.92, 14.06), 'speaker': 'SPEAKER_00'},
+ {'text': 'wild', 'timestamp': (14.16, 14.42), 'speaker': 'SPEAKER_00'},
+ {'text': 'you', 'timestamp': (14.5, 14.56), 'speaker': 'SPEAKER_00'},
+ {'text': 'enter', 'timestamp': (14.64, 14.82), 'speaker': 'SPEAKER_00'},
+ {'text': 'into', 'timestamp': (14.94, 15.2), 'speaker': 'SPEAKER_00'},
+ {'text': 'you', 'timestamp': (15.38, 15.44), 'speaker': 'SPEAKER_00'},
+ {'text': "don't", 'timestamp': (15.5, 15.64), 'speaker': 'SPEAKER_00'},
+ {'text': 'become', 'timestamp': (15.74, 15.98), 'speaker': 'SPEAKER_00'},
+ {'text': 'a', 'timestamp': (16.06, 16.08), 'speaker': 'SPEAKER_00'},
+ {'text': 'new', 'timestamp': (16.18, 16.28), 'speaker': 'SPEAKER_00'},
+ {'text': 'person', 'timestamp': (16.42, 16.86), 'speaker': 'SPEAKER_00'},
+ {'text': 'you', 'timestamp': (17.2, 17.26), 'speaker': 'SPEAKER_00'},
+ {'text': 'come', 'timestamp': (17.44, 17.64), 'speaker': 'SPEAKER_00'},
+ {'text': 'with', 'timestamp': (17.72, 17.82), 'speaker': 'SPEAKER_00'},
+ {'text': 'what', 'timestamp': (17.92, 18.02), 'speaker': 'SPEAKER_00'},
+ {'text': 'you', 'timestamp': (18.12, 18.18), 'speaker': 'SPEAKER_00'},
+ {'text': 'been', 'timestamp': (18.34, 18.46), 'speaker': 'SPEAKER_00'},
+ {'text': 'working', 'timestamp': (18.54, 18.86), 'speaker': 'SPEAKER_00'},
+ {'text': 'on', 'timestamp': (18.96, 19.12), 'speaker': 'SPEAKER_00'},
+ {'text': "it's", 'timestamp': (19.42, 19.52), 'speaker': 'SPEAKER_01'},
+ {'text': 'easy', 'timestamp': (19.64, 19.78), 'speaker': 'SPEAKER_01'},
+ {'text': 'to', 'timestamp': (19.9, 19.96), 'speaker': 'SPEAKER_01'},
+ {'text': 'go', 'timestamp': (20.12, 20.16), 'speaker': 'SPEAKER_01'},
+ {'text': 'through', 'timestamp': (20.36, 20.62), 'speaker': 'SPEAKER_01'},
+ {'text': 'the', 'timestamp': (21.32, 21.38), 'speaker': 'SPEAKER_01'},
+ {'text': 'first', 'timestamp': (21.44, 21.64), 'speaker': 'SPEAKER_01'},
+ {'text': 'year', 'timestamp': (21.7, 21.82), 'speaker': 'SPEAKER_01'},
+ {'text': 'of', 'timestamp': (21.86, 21.9), 'speaker': 'SPEAKER_01'},
+ {'text': 'your', 'timestamp': (21.96, 22.08), 'speaker': 'SPEAKER_01'},
+ {'text': 'marriage', 'timestamp': (22.14, 22.42), 'speaker': 'SPEAKER_01'},
+ {'text': 'trying', 'timestamp': (22.54, 22.74), 'speaker': 'SPEAKER_01'},
+ {'text': 'to', 'timestamp': (22.84, 22.88), 'speaker': 'SPEAKER_01'},
+ {'text': 'knit', 'timestamp': (23.2, 23.42), 'speaker': 'SPEAKER_01'},
+ {'text': 'pick', 'timestamp': (23.6, 23.78), 'speaker': 'SPEAKER_01'},
+ {'text': 'the', 'timestamp': (24.58, 24.64), 'speaker': 'SPEAKER_01'},
+ {'text': 'shortcomings', 'timestamp': (24.7, 25.2), 'speaker': 'SPEAKER_01'},
+ {'text': 'of', 'timestamp': (25.26, 25.3), 'speaker': 'SPEAKER_01'},
+ {'text': 'your', 'timestamp': (25.36, 25.46), 'speaker': 'SPEAKER_01'},
+ {'text': 'partner', 'timestamp': (25.52, 25.86), 'speaker': 'SPEAKER_01'},
+ {'text': 'now', 'timestamp': (26.28, 26.38), 'speaker': 'SPEAKER_01'},
+ {'text': 'this', 'timestamp': (26.46, 26.54), 'speaker': 'SPEAKER_01'},
+ {'text': 'is', 'timestamp': (26.62, 26.68), 'speaker': 'SPEAKER_01'},
+ {'text': 'our', 'timestamp': (26.74, 26.82), 'speaker': 'SPEAKER_01'},
+ {'text': 'first', 'timestamp': (26.92, 27.12), 'speaker': 'SPEAKER_01'},
+ {'text': 'episode', 'timestamp': (27.24, 27.68), 'speaker': 'SPEAKER_01'},
+ {'text': 'and', 'timestamp': (27.82, 28.04), 'speaker': 'SPEAKER_01'},
+ {'text': "it's", 'timestamp': (28.48, 28.6), 'speaker': 'SPEAKER_01'},
+ {'text': 'a', 'timestamp': (28.66, 28.68), 'speaker': 'SPEAKER_01'},
+ {'text': 'series', 'timestamp': (28.74, 28.96), 'speaker': 'SPEAKER_01'},
+ {'text': 'of', 'timestamp': (29.0, 29.04), 'speaker': 'SPEAKER_01'},
+ {'text': 'random', 'timestamp': (29.14, 29.4), 'speaker': 'SPEAKER_01'},
+ {'text': 'reflections', 'timestamp': (29.5, 30.04), 'speaker': 'SPEAKER_01'},
+ {'text': 'from', 'timestamp': (30.2, 30.3), 'speaker': 'SPEAKER_01'},
+ {'text': 'our', 'timestamp': (30.38, 30.52), 'speaker': 'SPEAKER_01'},
+ {'text': 'one', 'timestamp': (30.7, 30.82), 'speaker': 'SPEAKER_01'},
+ {'text': 'year', 'timestamp': (30.9, 31.08), 'speaker': 'SPEAKER_01'},
+ {'text': 'in', 'timestamp': (31.26, 31.34), 'speaker': 'SPEAKER_01'},
+ {'text': 'marriage', 'timestamp': (31.44, 31.82), 'speaker': 'SPEAKER_01'},
+ {'text': 'now', 'timestamp': (31.92, 32.02), 'speaker': 'SPEAKER_01'},
+ {'text': 'we', 'timestamp': (32.14, 32.22), 'speaker': 'SPEAKER_01'},
+ {'text': 'hope', 'timestamp': (32.36, 32.54), 'speaker': 'SPEAKER_01'},
+ {'text': 'that', 'timestamp': (32.66, 32.82), 'speaker': 'SPEAKER_01'},
+ {'text': 'as', 'timestamp': (32.96, 33.02), 'speaker': 'SPEAKER_01'},
+ {'text': 'we', 'timestamp': (33.08, 33.14), 'speaker': 'SPEAKER_01'},
+ {'text': 'share', 'timestamp': (33.24, 33.44), 'speaker': 'SPEAKER_01'},
+ {'text': 'experiences',
+ 'timestamp': (33.58, 34.14),
+ 'speaker': 'SPEAKER_01'},
+ {'text': 'and', 'timestamp': (34.2, 34.26), 'speaker': 'SPEAKER_01'},
+ {'text': 'insights', 'timestamp': (34.34, 34.74), 'speaker': 'SPEAKER_01'},
+ {'text': 'on', 'timestamp': (34.9, 34.98), 'speaker': 'SPEAKER_01'},
+ {'text': 'our', 'timestamp': (35.06, 35.16), 'speaker': 'SPEAKER_01'},
+ {'text': 'journey', 'timestamp': (35.22, 35.54), 'speaker': 'SPEAKER_01'},
+ {'text': 'the', 'timestamp': (36.0, 36.08), 'speaker': 'SPEAKER_01'},
+ {'text': 'you', 'timestamp': (36.22, 36.32), 'speaker': 'SPEAKER_01'},
+ {'text': 'will', 'timestamp': (36.44, 36.56), 'speaker': 'SPEAKER_01'},
+ {'text': 'be', 'timestamp': (36.64, 36.68), 'speaker': 'SPEAKER_01'},
+ {'text': 'inspired', 'timestamp': (36.76, 37.24), 'speaker': 'SPEAKER_01'},
+ {'text': 'to', 'timestamp': (37.6, 37.64), 'speaker': 'SPEAKER_01'},
+ {'text': 'pursue', 'timestamp': (37.7, 37.94), 'speaker': 'SPEAKER_01'},
+ {'text': 'the', 'timestamp': (38.0, 38.06), 'speaker': 'SPEAKER_01'},
+ {'text': 'potion', 'timestamp': (38.14, 38.46), 'speaker': 'SPEAKER_01'},
+ {'text': 'and', 'timestamp': (38.5, 38.58), 'speaker': 'SPEAKER_01'},
+ {'text': 'purpose', 'timestamp': (38.66, 39.06), 'speaker': 'SPEAKER_01'},
+ {'text': 'to', 'timestamp': (39.4, 39.46), 'speaker': 'SPEAKER_01'},
+ {'text': 'your', 'timestamp': (39.54, 39.66), 'speaker': 'SPEAKER_01'},
+ {'text': 'marriage', 'timestamp': (39.86, 40.24), 'speaker': 'SPEAKER_01'},
+ {'text': 'so', 'timestamp': (40.82, 40.9), 'speaker': 'SPEAKER_01'},
+ {'text': 'this', 'timestamp': (41.42, 41.6), 'speaker': 'SPEAKER_01'},
+ {'text': 'is', 'timestamp': (41.78, 41.84), 'speaker': 'SPEAKER_01'},
+ {'text': 'the', 'timestamp': (41.94, 42.0), 'speaker': 'SPEAKER_01'},
+ {'text': "chitaka'spodcast",
+ 'timestamp': (42.12, 43.16),
+ 'speaker': 'SPEAKER_01'},
+ {'text': 'and', 'timestamp': (43.54, 43.62), 'speaker': 'SPEAKER_01'},
+ {'text': 'these', 'timestamp': (43.7, 43.86), 'speaker': 'SPEAKER_01'},
+ {'text': 'are', 'timestamp': (43.94, 44.02), 'speaker': 'SPEAKER_01'},
+ {'text': 'random', 'timestamp': (44.1, 44.32), 'speaker': 'SPEAKER_01'},
+ {'text': 'reflections', 'timestamp': (44.4, 44.88), 'speaker': 'SPEAKER_01'},
+ {'text': 'when', 'timestamp': (45.28, 45.42), 'speaker': 'SPEAKER_01'},
+ {'text': 'you', 'timestamp': (45.48, 45.54), 'speaker': 'SPEAKER_01'},
+ {'text': 'are', 'timestamp': (45.56, 45.62), 'speaker': 'SPEAKER_01'},
+ {'text': 'married', 'timestamp': (45.68, 45.92), 'speaker': 'SPEAKER_01'}]}
+
+
This process highlights the steps taken for model training on the CallHome dataset. For this particular run we used the English version of the CallHome dataset. The sections below cover the model training architecture, loss functions, optimisation techniques, data augmentation and metrics used.
+Forward
forward: forward pass function of the pretrained model.
Parameters:
waveforms (torch.Tensor): a tensor containing the audio data to be processed by the model.
labels: ground truth labels for training. Defaults to None.
nb_speakers: number of speakers. Defaults to None.
Returns: a dictionary with the predictions (logits) and, when labels are provided, the loss.
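A usage sketch of the forward pass, assuming a SegmentationModel instance named model (defined below) and 10-second chunks at 16 kHz; the tensor shapes are illustrative only:

import torch

# Illustrative shapes: a batch of four 10-second chunks at 16 kHz.
waveforms = torch.randn(4, 160000)

outputs = model(waveforms=waveforms)               # inference: returns {"logits": ...}
# outputs = model(waveforms=waveforms, labels=y)   # training: also returns "loss"
print(outputs["logits"].shape)                     # (batch, num_frames, num_classes)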
+Setup loss function
setup_loss_func: sets up the loss function, used in particular when the powerset representation is enabled, i.e. self.specifications.powerset is True.
Segmentation Loss Function
segmentation_loss: defines the permutation-invariant segmentation loss. It computes the loss using either nll_loss (negative log likelihood) for the powerset representation or binary_cross_entropy otherwise.
Parameters:
+permutated_prediction
: Prediction after permutation. Type: torch.Tensor
target
: Ground truth labels. Type: torch.Tensor
weight
: Type: Optional[torch.Tensor]
Returns: Permutation-invariant segmentation loss. torch.Tensor
To pyannote
to_pyannote_model: converts the current model to a pyannote segmentation model for use in pyannote pipelines.
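A usage sketch of this conversion, assuming a trained SegmentationModel instance model, an existing pyannote diarization pipeline pipeline, and a torch device device, as in the inference example later in this document:

# Convert the trained segmentation model back to a pyannote model
# and plug it into an existing diarization pipeline.
seg_model = model.to_pyannote_model()
pipeline._segmentation.model = seg_model.to(device)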
class SegmentationModel(PreTrainedModel):
+ config_class = SegmentationModelConfig
+
+ def __init__(
+ self,
+ config=SegmentationModelConfig(),
+ ):
+ super().__init__(config)
+
+ self.model = PyanNet_nn(sincnet={"stride": 10})
+
+ self.weigh_by_cardinality = config.weigh_by_cardinality
+ self.max_speakers_per_frame = config.max_speakers_per_frame
+ self.chunk_duration = config.chunk_duration
+ self.min_duration = config.min_duration
+ self.warm_up = config.warm_up
+ self.max_speakers_per_chunk = config.max_speakers_per_chunk
+
+ self.specifications = Specifications(
+ problem=Problem.MULTI_LABEL_CLASSIFICATION
+ if self.max_speakers_per_frame is None
+ else Problem.MONO_LABEL_CLASSIFICATION,
+ resolution=Resolution.FRAME,
+ duration=self.chunk_duration,
+ min_duration=self.min_duration,
+ warm_up=self.warm_up,
+ classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)],
+ powerset_max_classes=self.max_speakers_per_frame,
+ permutation_invariant=True,
+ )
+ self.model.specifications = self.specifications
+ self.model.build()
+ self.setup_loss_func()
+
+ def forward(self, waveforms, labels=None, nb_speakers=None):
+
+ prediction = self.model(waveforms.unsqueeze(1))
+ batch_size, num_frames, _ = prediction.shape
+
+ if labels is not None:
+ weight = torch.ones(batch_size, num_frames, 1, device=waveforms.device)
+ warm_up_left = round(self.specifications.warm_up[0] / self.specifications.duration * num_frames)
+ weight[:, :warm_up_left] = 0.0
+ warm_up_right = round(self.specifications.warm_up[1] / self.specifications.duration * num_frames)
+ weight[:, num_frames - warm_up_right :] = 0.0
+
+ if self.specifications.powerset:
+ multilabel = self.model.powerset.to_multilabel(prediction)
+ permutated_target, _ = permutate(multilabel, labels)
+
+ permutated_target_powerset = self.model.powerset.to_powerset(permutated_target.float())
+ loss = self.segmentation_loss(prediction, permutated_target_powerset, weight=weight)
+
+ else:
+ permutated_prediction, _ = permutate(labels, prediction)
+ loss = self.segmentation_loss(permutated_prediction, labels, weight=weight)
+
+ return {"loss": loss, "logits": prediction}
+
+ return {"logits": prediction}
+
+ def setup_loss_func(self):
+ if self.specifications.powerset:
+ self.model.powerset = Powerset(
+ len(self.specifications.classes),
+ self.specifications.powerset_max_classes,
+ )
+
+ def segmentation_loss(
+ self,
+ permutated_prediction: torch.Tensor,
+ target: torch.Tensor,
+ weight: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+
+ if self.specifications.powerset:
+ # `clamp_min` is needed to set non-speech weight to 1.
+ class_weight = torch.clamp_min(self.model.powerset.cardinality, 1.0) if self.weigh_by_cardinality else None
+ seg_loss = nll_loss(
+ permutated_prediction,
+ torch.argmax(target, dim=-1),
+ class_weight=class_weight,
+ weight=weight,
+ )
+ else:
+ seg_loss = binary_cross_entropy(permutated_prediction, target.float(), weight=weight)
+
+ return seg_loss
+
+ @classmethod
+ def from_pyannote_model(cls, pretrained):
+
+ # Initialize model:
+ specifications = copy.deepcopy(pretrained.specifications)
+
+ # Copy pretrained model hyperparameters:
+ chunk_duration = specifications.duration
+ max_speakers_per_frame = specifications.powerset_max_classes
+ weigh_by_cardinality = False
+ min_duration = specifications.min_duration
+ warm_up = specifications.warm_up
+ max_speakers_per_chunk = len(specifications.classes)
+
+ config = SegmentationModelConfig(
+ chunk_duration=chunk_duration,
+ max_speakers_per_frame=max_speakers_per_frame,
+ weigh_by_cardinality=weigh_by_cardinality,
+ min_duration=min_duration,
+ warm_up=warm_up,
+ max_speakers_per_chunk=max_speakers_per_chunk,
+ )
+
+ model = cls(config)
+
+ # Copy pretrained model weights:
+ model.model.hparams = copy.deepcopy(pretrained.hparams)
+ model.model.sincnet = copy.deepcopy(pretrained.sincnet)
+ model.model.sincnet.load_state_dict(pretrained.sincnet.state_dict())
+ model.model.lstm = copy.deepcopy(pretrained.lstm)
+ model.model.lstm.load_state_dict(pretrained.lstm.state_dict())
+ model.model.linear = copy.deepcopy(pretrained.linear)
+ model.model.linear.load_state_dict(pretrained.linear.state_dict())
+ model.model.classifier = copy.deepcopy(pretrained.classifier)
+ model.model.classifier.load_state_dict(pretrained.classifier.state_dict())
+ model.model.activation = copy.deepcopy(pretrained.activation)
+ model.model.activation.load_state_dict(pretrained.activation.state_dict())
+
+ return model
+
+ def to_pyannote_model(self):
+
+ seg_model = PyanNet(sincnet={"stride": 10})
+ seg_model.hparams.update(self.model.hparams)
+
+ seg_model.sincnet = copy.deepcopy(self.model.sincnet)
+ seg_model.sincnet.load_state_dict(self.model.sincnet.state_dict())
+
+ seg_model.lstm = copy.deepcopy(self.model.lstm)
+ seg_model.lstm.load_state_dict(self.model.lstm.state_dict())
+
+ seg_model.linear = copy.deepcopy(self.model.linear)
+ seg_model.linear.load_state_dict(self.model.linear.state_dict())
+
+ seg_model.classifier = copy.deepcopy(self.model.classifier)
+ seg_model.classifier.load_state_dict(self.model.classifier.state_dict())
+
+ seg_model.activation = copy.deepcopy(self.model.activation)
+ seg_model.activation.load_state_dict(self.model.activation.state_dict())
+
+ seg_model.specifications = self.specifications
+
+ return seg_model
+
+Segmentation Model Configuration
SegmentationModelConfig: configuration class for the segmentation model, specifying various parameters like chunk duration, maximum speakers per frame, etc.
class SegmentationModelConfig(PretrainedConfig):
+
+ model_type = "pyannet"
+
+ def __init__(
+ self,
+ chunk_duration=10,
+ max_speakers_per_frame=2,
+ max_speakers_per_chunk=3,
+ min_duration=None,
+ warm_up=(0.0, 0.0),
+ weigh_by_cardinality=False,
+ **kwargs,
+ ):
+
+ super().__init__(**kwargs)
+ self.chunk_duration = chunk_duration
+ self.max_speakers_per_frame = max_speakers_per_frame
+ self.max_speakers_per_chunk = max_speakers_per_chunk
+ self.min_duration = min_duration
+ self.warm_up = warm_up
+ self.weigh_by_cardinality = weigh_by_cardinality
+ # For now, the model handles only 16000 Hz sampling rate
+ self.sample_rate = 16000
+
Preprocess
The Preprocess class handles the preprocessing steps and is responsible for preparing the input data.
class Preprocess:
+ def __init__(
+ self,
+ config,
+ ):
+
+ self.chunk_duration = config.chunk_duration
+ self.max_speakers_per_frame = config.max_speakers_per_frame
+ self.max_speakers_per_chunk = config.max_speakers_per_chunk
+ self.min_duration = config.min_duration
+ self.warm_up = config.warm_up
+
+ self.sample_rate = config.sample_rate
+ self.model = SegmentationModel(config).to_pyannote_model()
+
+ # Get the number of frames associated to a chunk:
+ _, self.num_frames_per_chunk, _ = self.model(
+ torch.rand((1, int(self.chunk_duration * self.sample_rate)))
+ ).shape
+
+ def get_labels_in_file(self, file):
+
+
+ file_labels = []
+ for i in range(len(file["speakers"][0])):
+ if file["speakers"][0][i] not in file_labels:
+ file_labels.append(file["speakers"][0][i])
+
+ return file_labels
+
+ def get_segments_in_file(self, file, labels):
+
+
+ file_annotations = []
+
+ for i in range(len(file["timestamps_start"][0])):
+ start_segment = file["timestamps_start"][0][i]
+ end_segment = file["timestamps_end"][0][i]
+ label = labels.index(file["speakers"][0][i])
+ file_annotations.append((start_segment, end_segment, label))
+
+ dtype = [("start", "<f4"), ("end", "<f4"), ("labels", "i1")]
+
+ annotations = np.array(file_annotations, dtype)
+
+ return annotations
+
+ def get_chunk(self, file, start_time):
+
+
+ sample_rate = file["audio"][0]["sampling_rate"]
+
+ assert sample_rate == self.sample_rate
+
+ end_time = start_time + self.chunk_duration
+ start_frame = math.floor(start_time * sample_rate)
+ num_frames_waveform = math.floor(self.chunk_duration * sample_rate)
+ end_frame = start_frame + num_frames_waveform
+
+ waveform = file["audio"][0]["array"][start_frame:end_frame]
+
+ labels = self.get_labels_in_file(file)
+
+ file_segments = self.get_segments_in_file(file, labels)
+
+ chunk_segments = file_segments[(file_segments["start"] < end_time) & (file_segments["end"] > start_time)]
+
+ # compute frame resolution:
+ # resolution = self.chunk_duration / self.num_frames_per_chunk
+
+ # discretize chunk annotations at model output resolution
+ step = self.model.receptive_field.step
+ half = 0.5 * self.model.receptive_field.duration
+
+ # discretize chunk annotations at model output resolution
+ start = np.maximum(chunk_segments["start"], start_time) - start_time - half
+ start_idx = np.maximum(0, np.round(start / step)).astype(int)
+
+ # start_idx = np.floor(start / resolution).astype(int)
+ end = np.minimum(chunk_segments["end"], end_time) - start_time - half
+ end_idx = np.round(end / step).astype(int)
+
+ # end_idx = np.ceil(end / resolution).astype(int)
+
+ # get list and number of labels for current scope
+ labels = list(np.unique(chunk_segments["labels"]))
+ num_labels = len(labels)
+ # initial frame-level targets
+ y = np.zeros((self.num_frames_per_chunk, num_labels), dtype=np.uint8)
+
+ # map labels to indices
+ mapping = {label: idx for idx, label in enumerate(labels)}
+
+ for start, end, label in zip(start_idx, end_idx, chunk_segments["labels"]):
+ mapped_label = mapping[label]
+ y[start : end + 1, mapped_label] = 1
+
+ return waveform, y, labels
+
+ def get_start_positions(self, file, overlap, random=False):
+
+ sample_rate = file["audio"][0]["sampling_rate"]
+
+ assert sample_rate == self.sample_rate
+
+ file_duration = len(file["audio"][0]["array"]) / sample_rate
+ start_positions = np.arange(0, file_duration - self.chunk_duration, self.chunk_duration * (1 - overlap))
+
+ if random:
+ nb_samples = int(file_duration / self.chunk_duration)
+ start_positions = np.random.uniform(0, file_duration, nb_samples)
+
+ return start_positions
+
+ def __call__(self, file, random=False, overlap=0.0):
+
+ new_batch = {"waveforms": [], "labels": [], "nb_speakers": []}
+
+ if random:
+ start_positions = self.get_start_positions(file, overlap, random=True)
+ else:
+ start_positions = self.get_start_positions(file, overlap)
+
+ for start_time in start_positions:
+ waveform, target, label = self.get_chunk(file, start_time)
+
+ new_batch["waveforms"].append(waveform)
+ new_batch["labels"].append(target)
+ new_batch["nb_speakers"].append(label)
+
+ return new_batch
+
+import numpy as np
+import torch
+from pyannote.audio.torchmetrics import (DiarizationErrorRate, FalseAlarmRate,
+ MissedDetectionRate,
+ SpeakerConfusionRate)
+from pyannote.audio.utils.powerset import Powerset
+
+
+class Metrics:
+ """Metric class used by the HF trainer to compute speaker diarization metrics."""
+
+ def __init__(self, specifications) -> None:
+ """init method
+
+ Args:
            specifications: the specifications attribute from a SegmentationModel.
+ """
+ self.powerset = specifications.powerset
+ self.classes = specifications.classes
+ self.powerset_max_classes = specifications.powerset_max_classes
+
+ self.model_powerset = Powerset(
+ len(self.classes),
+ self.powerset_max_classes,
+ )
+
+ self.metrics = {
+ "der": DiarizationErrorRate(0.5),
+ "confusion": SpeakerConfusionRate(0.5),
+ "missed_detection": MissedDetectionRate(0.5),
+ "false_alarm": FalseAlarmRate(0.5),
+ }
+
+ def __call__(self, eval_pred):
+
+ logits, labels = eval_pred
+
+ if self.powerset:
+ predictions = self.model_powerset.to_multilabel(torch.tensor(logits))
+ else:
+ predictions = torch.tensor(logits)
+
+ labels = torch.tensor(labels)
+
+ predictions = torch.transpose(predictions, 1, 2)
+ labels = torch.transpose(labels, 1, 2)
+
+ metrics = {"der": 0, "false_alarm": 0, "missed_detection": 0, "confusion": 0}
+
+ metrics["der"] += self.metrics["der"](predictions, labels).cpu().numpy()
+ metrics["false_alarm"] += self.metrics["false_alarm"](predictions, labels).cpu().numpy()
+ metrics["missed_detection"] += self.metrics["missed_detection"](predictions, labels).cpu().numpy()
+ metrics["confusion"] += self.metrics["confusion"](predictions, labels).cpu().numpy()
+
+ return metrics
+
+
+class DataCollator:
+ """Data collator that will dynamically pad the target labels to have max_speakers_per_chunk"""
+
+ def __init__(self, max_speakers_per_chunk) -> None:
+ self.max_speakers_per_chunk = max_speakers_per_chunk
+
+ def __call__(self, features):
        """Collate a list of dataset rows into a single batch.

        Args:
            features (list[dict]): rows with "waveforms", "labels" and "nb_speakers".

        Returns:
            dict: a batch with stacked "waveforms" and padded "labels".
        """
+
+ batch = {}
+
+ speakers = [f["nb_speakers"] for f in features]
+ labels = [f["labels"] for f in features]
+
+ batch["labels"] = self.pad_targets(labels, speakers)
+
+ batch["waveforms"] = torch.stack([f["waveforms"] for f in features])
+
+ return batch
+
+ def pad_targets(self, labels, speakers):
        """Pad or truncate per-chunk targets to max_speakers_per_chunk speakers.

        Args:
            labels: list of frame-level target tensors, one per chunk.
            speakers: list of speaker labels present in each chunk.

        Returns:
            torch.Tensor:
                Collated target tensor of shape (num_frames, self.max_speakers_per_chunk).
                If one chunk has more than max_speakers_per_chunk speakers, we keep
                the max_speakers_per_chunk most talkative ones. If it has less, we pad with
                zeros (artificial inactive speakers).
        """
+
+ targets = []
+
+ for i in range(len(labels)):
+ label = speakers[i]
+ target = labels[i].numpy()
+ num_speakers = len(label)
+
+ if num_speakers > self.max_speakers_per_chunk:
+ indices = np.argsort(-np.sum(target, axis=0), axis=0)
+ target = target[:, indices[: self.max_speakers_per_chunk]]
+
+ elif num_speakers < self.max_speakers_per_chunk:
+ target = np.pad(
+ target,
+ ((0, 0), (0, self.max_speakers_per_chunk - num_speakers)),
+ mode="constant",
+ )
+
+ targets.append(target)
+
+ return torch.from_numpy(np.stack(targets))
+
+
+!python3 train_segmentation.py \
+ --dataset_name=diarizers-community/callhome \
+ --dataset_config_name=eng \
+ --split_on_subset=data \
+ --model_name_or_path=pyannote/segmentation-3.0 \
+ --output_dir=./speaker-segmentation-fine-tuned-callhome-eng \
+ --do_train \
+ --do_eval \
+ --learning_rate=1e-3 \
+ --num_train_epochs=20 \
+ --lr_scheduler_type=cosine \
+ --per_device_train_batch_size=32 \
+ --per_device_eval_batch_size=32 \
+ --evaluation_strategy=epoch \
+ --save_strategy=epoch \
+ --preprocessing_num_workers=2 \
+ --dataloader_num_workers=2 \
+ --logging_steps=100 \
+ --load_best_model_at_end \
+ --push_to_hub
+
The script test_segmentation.py can be used to evaluate a fine-tuned model on a diarization dataset. In the following example, we evaluate the fine-tuned model from the previous step on the test split of the CallHome English dataset:
+!python3 test_segmentation.py \
+ --dataset_name=diarizers-community/callhome \
+ --dataset_config_name=eng \
+ --split_on_subset=data \
+ --test_split_name=test \
+ --model_name_or_path=diarizers-community/speaker-segmentation-fine-tuned-callhome-eng \
+ --preprocessing_num_workers=2 \
+ --evaluate_with_pipeline
+
+Sample Output
+ +from diarizers import SegmentationModel
+from pyannote.audio import Pipeline
+from datasets import load_dataset
+import torch
+
+device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+
+# load the pre-trained pyannote pipeline
+pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
+pipeline.to(device)
+
+# replace the segmentation model with your fine-tuned one
+model = SegmentationModel().from_pretrained("diarizers-community/speaker-segmentation-fine-tuned-callhome-jpn")
+model = model.to_pyannote_model()
+pipeline._segmentation.model = model.to(device)
+
+# load dataset example
+dataset = load_dataset("diarizers-community/callhome", "jpn", split="data")
+sample = dataset[0]["audio"]
+
+# pre-process inputs
+sample["waveform"] = torch.from_numpy(sample.pop("array")[None, :]).to(device, dtype=model.dtype)
+sample["sample_rate"] = sample.pop("sampling_rate")
+
+# perform inference
+diarization = pipeline(sample)
+
+# dump the diarization output to disk using RTTM format
+with open("audio.rttm", "w") as rttm:
+ diarization.write_rttm(rttm)
+
+
The script test_segmentation.py can be used to evaluate a fine-tuned model on a diarization dataset.
In the following example, we evaluate the fine-tuned model on the test split of the CallHome English dataset.
+python3 test_segmentation.py \
+ --dataset_name=diarizers-community/callhome \
+ --dataset_config_name=eng \
+ --split_on_subset=data \
+ --test_split_name=test \
+ --model_name_or_path=diarizers-community/speaker-segmentation-fine-tuned-callhome-eng \
+ --preprocessing_num_workers=2 \
+ --evaluate_with_pipeline \
+
+The output above is the default output that can be obtained using the default Evaluation Script.
This documentation further explores the evaluation process, adding to the metrics that can be measured and highlighting the edits required.
Many metrics can be obtained throughout the diarization process, as documented in the pyannote.audio.metrics documentation.
In this documentation, we'll focus on the segmentation precision, segmentation recall and identification F1 score.
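The identification F1 score is not reported directly by the metrics library; it is derived from identification precision and recall as their harmonic mean, which is how it is computed later in this document. A small sketch:

def identification_f1(precision: float, recall: float) -> float:
    # Harmonic mean of identification precision and recall,
    # as used later in TestPipeline.compute_metrics_on_file.
    return 2 * precision * recall / (precision + recall)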
+Imports
+Segment
: Speaker segmentation is the process of dividing an audio recording into segments based on the changing speakers’ identities. The goal of speaker segmentation is to determine the time boundaries where the speaker changes occur, effectively identifying the points at which one speaker’s speech ends, and another’s begins. That said, a Segment
is a data structure with start
and end
time that will then be placed in a Timeline
Timeline
: A data structure containing various segments. Reference timelines are provided in the ground truth and are compared against the predicted timelines to calculate segmentation precision
and segmentation recall
from pyannote.core import SlidingWindow, SlidingWindowFeature, Timeline, Segment
+from pyannote.metrics import segmentation, identification
+
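As a small illustration (the timestamps are arbitrary), segments with a start and end time are added to a timeline; a reference timeline built from the ground truth is later compared against a predicted timeline to compute segmentation precision and recall:

# Two speech regions on a reference timeline (illustrative values only).
reference = Timeline()
reference.add(Segment(0.5, 1.0))
reference.add(Segment(1.5, 2.0))
print(len(reference))  # 2 segments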
+Initialization
class Test: The segmentation model test implementation is carried out within the Test class, found in the Test.py file in src/diarizers.
Parameters
+test_dataset
: The test dataset to be used. In this example, it will be the test split on the Callhome English dataset.
model (SegmentationModel)
: The model is the finetuned model trained by the train_segmentation.py
script.
step (float, optional)
: Steps between successive generated audio chunks. Defaults to 2.5.
metrics
: For this example, the metrics segmentation_precision
,segmentation_recall
,recall_value
,precision_value
and count
have been added for the purpose of calculating the segmentation recall and precision of the Segmentation Model.
class Test:
+
+ def __init__(self, test_dataset, model, step=2.5):
+
+ self.test_dataset = test_dataset
+ self.model = model
+ (self.device,) = get_devices(needs=1)
+ self.inference = Inference(self.model, step=step, device=self.device)
+
+ self.sample_rate = test_dataset[0]["audio"]["sampling_rate"]
+
+ # Get the number of frames associated to a chunk:
+ _, self.num_frames, _ = self.inference.model(
+ torch.rand((1, int(self.inference.duration * self.sample_rate))).to(self.device)
+ ).shape
+ # compute frame resolution:
+ self.resolution = self.inference.duration / self.num_frames
+
+ self.metrics = {
+ "der": DiarizationErrorRate(0.5).to(self.device),
+ "confusion": SpeakerConfusionRate(0.5).to(self.device),
+ "missed_detection": MissedDetectionRate(0.5).to(self.device),
+ "false_alarm": FalseAlarmRate(0.5).to(self.device),
+ "segmentation_precision": segmentation.SegmentationPrecision(),
+ "segmentation_recall": segmentation.SegmentationRecall(),
+ "recall_value":0,
+ "precision_value": 0,
+ "count": 0,
+ }
+
+
+Predict function
+This function makes a prediction on a dataset row using pyannote inference object.
+ def predict(self, file):
+ audio = torch.tensor(file["audio"]["array"]).unsqueeze(0).to(torch.float32).to(self.device)
+ sample_rate = file["audio"]["sampling_rate"]
+
+ input = {"waveform": audio, "sample_rate": sample_rate}
+
+ prediction = self.inference(input)
+
+ return prediction
+
+Compute Ground Truth Function
+This function converts a dataset row into the suitable format for evaluation as the ground truth.
+Returns
: numpy array with shape (num_frames, num_speakers).
def compute_gt(self, file):
+
+ audio = torch.tensor(file["audio"]["array"]).unsqueeze(0).to(torch.float32)
+ sample_rate = file["audio"]["sampling_rate"]
+
+ audio_duration = len(audio[0]) / sample_rate
+ num_frames = int(round(audio_duration / self.resolution))
+
+ labels = list(set(file["speakers"]))
+
+ gt = np.zeros((num_frames, len(labels)), dtype=np.uint8)
+
+ for i in range(len(file["timestamps_start"])):
+ start = file["timestamps_start"][i]
+ end = file["timestamps_end"][i]
+ speaker = file["speakers"][i]
+ start_frame = int(round(start / self.resolution))
+ end_frame = int(round(end / self.resolution))
+ speaker_index = labels.index(speaker)
+
+ gt[start_frame:end_frame, speaker_index] += 1
+
+ return gt
+
+Convert to Timeline
+This function creates a Timeline
using data and labels passed as parameters and converted into Segments
. Required in order to calculate Segmentation Precision and Recall.
def convert_to_timeline(self, data, labels):
+ timeline = Timeline()
+ for speaker_index, label in enumerate(labels):
+ segments = np.where(data[:, speaker_index] == 1)[0]
+ if len(segments) > 0:
+ start = segments[0] * self.resolution
+ end = segments[0] * self.resolution
+ for frame in segments[1:]:
+ if frame == end / self.resolution + 1:
+ end += self.resolution
+ else:
+ timeline.add(Segment(start, end + self.resolution))
+ start = frame * self.resolution
+ end = frame * self.resolution
+ timeline.add(Segment(start, end + self.resolution))
+ return timeline
+
+Compute Metrics on File
+Function that computes metrics for a dataset row passed into it. This function is run iteratively until the entire dataset has been processed.
+ def compute_metrics_on_file(self, file):
+ gt = self.compute_gt(file)
+ prediction = self.predict(file)
+
+ sliding_window = SlidingWindow(start=0, step=self.resolution, duration=self.resolution)
+ labels = list(set(file["speakers"]))
+
+ reference = SlidingWindowFeature(data=gt, labels=labels, sliding_window=sliding_window)
+
+ # Convert to Timeline for SegmentationPrecision
+ reference_timeline = self.convert_to_timeline(gt, labels)
+ prediction_timeline = self.convert_to_timeline(prediction.data, labels)
+
+
+ for window, pred in prediction:
+ reference_window = reference.crop(window, mode="center")
+ common_num_frames = min(self.num_frames, reference_window.shape[0])
+
+ _, ref_num_speakers = reference_window.shape
+ _, pred_num_speakers = pred.shape
+
+ if pred_num_speakers > ref_num_speakers:
+ reference_window = np.pad(reference_window, ((0, 0), (0, pred_num_speakers - ref_num_speakers)))
+ elif ref_num_speakers > pred_num_speakers:
+ pred = np.pad(pred, ((0, 0), (0, ref_num_speakers - pred_num_speakers)))
+
+ pred = torch.tensor(pred[:common_num_frames]).unsqueeze(0).permute(0, 2, 1).to(self.device)
+ target = (torch.tensor(reference_window[:common_num_frames]).unsqueeze(0).permute(0, 2, 1)).to(self.device)
+
+ self.metrics["der"](pred, target)
+ self.metrics["false_alarm"](pred, target)
+ self.metrics["missed_detection"](pred, target)
+ self.metrics["confusion"](pred, target)
+
+
+ # Compute precision
+ self.metrics["precision_value"] += self.metrics["segmentation_precision"](reference_timeline, prediction_timeline)
+ self.metrics["recall_value"] += self.metrics["segmentation_recall"](reference_timeline, prediction_timeline)
+ self.metrics["count"] += 1
+
+Compute Metrics
+Using all the functions above, the metrics for the Segmentation Model can then be computed and returned at once as shown below.
Further information on the metrics that can be extracted from the segmentation model can be found here.
+ def compute_metrics(self):
+ """Main method, used to compute speaker diarization metrics on test_dataset.
+ Returns:
+ dict: metric values.
+ """
+
+ for file in tqdm(self.test_dataset):
+ self.compute_metrics_on_file(file)
+ if self.metrics["count"] != 0:
+ self.metrics["precision_value"] /= self.metrics["count"]
+ self.metrics["recall_value"] /= self.metrics["count"]
+
+ return {
+ "der": self.metrics["der"].compute(),
+ "false_alarm": self.metrics["false_alarm"].compute(),
+ "missed_detection": self.metrics["missed_detection"].compute(),
+ "confusion": self.metrics["confusion"].compute(),
+ "segmentation_precision": self.metrics["precision_value"],
+ "segmentation_recall": self.metrics["recall_value"],
+ }
+
+The Fine-tuned segmentation model can be run in the Speaker Diarization Pipeline by calling from_pretrained
and overwriting the segmentation model with the fine-tuned model. Code can be found in the Test.py
script.
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
+ pipeline._segmentation.model = model
+
+Initialization
The TestPipeline class implements and tests the speaker diarization pipeline with the fine-tuned segmentation model.
Parameters
+pipeline
: Speaker Diarization pipeline
test_dataset
: Data to be tested. In this example, it is data from the Callhome English dataset.
metrics
: Since pyannote.metrics
does not offer Identification F1-score, we'll use the Precision and Recall to calculate the identificationF1Score
class TestPipeline:
+ def __init__(self, test_dataset, pipeline) -> None:
+
+ self.test_dataset = test_dataset
+
+ (self.device,) = get_devices(needs=1)
+ self.pipeline = pipeline.to(self.device)
+ self.sample_rate = test_dataset[0]["audio"]["sampling_rate"]
+
+ # Get the number of frames associated to a chunk:
+ _, self.num_frames, _ = self.pipeline._segmentation.model(
+ torch.rand((1, int(self.pipeline._segmentation.duration * self.sample_rate))).to(self.device)
+ ).shape
+ # compute frame resolution:
+ self.resolution = self.pipeline._segmentation.duration / self.num_frames
+
+ self.metrics = {
+ "der": diarization.DiarizationErrorRate(),
+ "identification_precision": identification.IdentificationPrecision(),
+ "identification_recall": identification.IdentificationRecall(),
+ "identification_f1": 0,
+
+ }
+
+Compute Ground Truth
+Function that reformats the Dataset Row to return the ground truth to be used for evaluation.
+Parameters
+file
: A single Dataset Row
def compute_gt(self, file):
+
+ """
+ Args:
+ file (_type_): dataset row.
+
+ Returns:
+ gt: numpy array with shape (num_frames, num_speakers).
+ """
+
+ audio = torch.tensor(file["audio"]["array"]).unsqueeze(0).to(torch.float32)
+ sample_rate = file["audio"]["sampling_rate"]
+
+ audio_duration = len(audio[0]) / sample_rate
+ num_frames = int(round(audio_duration / self.resolution))
+
+ labels = list(set(file["speakers"]))
+
+ gt = np.zeros((num_frames, len(labels)), dtype=np.uint8)
+
+ for i in range(len(file["timestamps_start"])):
+ start = file["timestamps_start"][i]
+ end = file["timestamps_end"][i]
+ speaker = file["speakers"][i]
+ start_frame = int(round(start / self.resolution))
+ end_frame = int(round(end / self.resolution))
+ speaker_index = labels.index(speaker)
+
+ gt[start_frame:end_frame, speaker_index] += 1
+
+ return gt
+
+Predict Function
+def predict(self, file):
+
+ sample = {}
+ sample["waveform"] = (
+ torch.from_numpy(file["audio"]["array"])
+ .to(self.device, dtype=self.pipeline._segmentation.model.dtype)
+ .unsqueeze(0)
+ )
+ sample["sample_rate"] = file["audio"]["sampling_rate"]
+
+ prediction = self.pipeline(sample)
+ # print("Prediction data: ", prediction.data )
+
+ return prediction
+
+Compute on File
+Function that calculates the f1 score of a file
(Dataset Row) using the precision
and recall
. It also calculates the der
(Diarization Error rate) and can be edited to extract more evaluation metrics such as Segmentation Purity
and Segmentation Coverage
.
For the purpose of this demonstration, the latter two were not obtained. Details about Segmentation Coverage and Segmentation Purity can be obtained here.
+
+
+def compute_metrics_on_file(self, file):
+
+ pred = self.predict(file)
+ gt = self.compute_gt(file)
+
+ sliding_window = SlidingWindow(start=0, step=self.resolution, duration=self.resolution)
+ gt = SlidingWindowFeature(data=gt, sliding_window=sliding_window)
+
+ gt = self.pipeline.to_annotation(
+ gt,
+ min_duration_on=0.0,
+ min_duration_off=self.pipeline.segmentation.min_duration_off,
+ )
+
+ mapping = {label: expected_label for label, expected_label in zip(gt.labels(), self.pipeline.classes())}
+
+ gt = gt.rename_labels(mapping=mapping)
+
+
+ der = self.metrics["der"](pred, gt)
+ identificationPrecision = self.metrics["identification_precision"](pred, gt)
+ identificationRecall = self.metrics["identification_recall"](pred, gt)
+ identificationF1 = (2 * identificationPrecision * identificationRecall) / (identificationRecall + identificationPrecision)
+
+ return {"der": der, "identificationF1": identificationF1}
+
+Compute Metrics
+This function iteratively calls the compute_metrics_on_file
function to perform computation on all the files in the dataset.
Returns
: The average values of the der
(diarization error rate) and f1
(F1 Score).
def compute_metrics(self):
+
+ der = 0
+ f1 = 0
+ for file in tqdm(self.test_dataset):
+ met = self.compute_metrics_on_file(file)
+ der += met["der"]
+ f1 += met["identificationF1"]
+
+ der /= len(self.test_dataset)
+ f1 /= len(self.test_dataset)
+
+ return {"der": der, "identificationF1Score": f1}
+
+An example of the output as expected from the edited script.
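The measured values depend on the dataset and the fine-tuned model, so no numbers are reproduced here; the dictionary returned by compute_metrics has the following shape:

{"der": ..., "identificationF1Score": ...}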