Skip to content

Commit

Permalink
Merge pull request #202 from Wordcab/196-implement-the-segmentation-model
Browse files Browse the repository at this point in the history

Re-implement EncDecSpeakerLabelModel
  • Loading branch information
aleksandr-smechov authored Aug 31, 2023
2 parents 7121065 + f3f3a99 commit d303ac0
Show file tree
Hide file tree
Showing 13 changed files with 3,328 additions and 2,910 deletions.
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ repos:
entry: check-yaml
language: system
types: [yaml]
- id: darglint
name: darglint
entry: darglint
language: system
types: [python]
stages: [manual]
- id: end-of-file-fixer
name: Fix End of Files
entry: end-of-file-fixer
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y \
Expand Down
311 changes: 311 additions & 0 deletions notebooks/implement_segmentation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import hashlib\n",
"import os\n",
"import tarfile\n",
"import wget\n",
"import yaml\n",
"from pathlib import Path\n",
"from typing import Optional, Tuple, Union\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"def resolve_diarization_cache_dir() -> Path:\n",
"    \"\"\"\n",
"    Utility method to get the cache directory for the diarization module.\n",
"\n",
"    Returns:\n",
"        Path: The path to the cache directory ($HOME/.cache/torch/diarization).\n",
"    \"\"\"\n",
"    # Plain string literal: the original used an f-string with no placeholders (ruff F541).\n",
"    path = Path.joinpath(Path.home(), \".cache/torch/diarization\")\n",
"\n",
"    return path"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Inspired from NVIDIA NeMo's EncDecSpeakerLabelModel\n",
"# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/label_models.py#L67\n",
"class EncDecSpeakerLabelModel:\n",
" \"\"\"The EncDecSpeakerLabelModel class encapsulates the encoder-decoder speaker label model.\"\"\"\n",
"\n",
" def __init__(self, model_name: str = \"titanet_large\") -> None:\n",
" \"\"\"Initialize the EncDecSpeakerLabelModel class.\n",
"\n",
" The EncDecSpeakerLabelModel class encapsulates the encoder-decoder speaker label model.\n",
" Only the \"titanet_large\" model is supported at the moment.\n",
" For more models: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/label_models.py#L59\n",
"\n",
" Args:\n",
" model_name (str, optional): The name of the model to use. Defaults to \"titanet_large\".\n",
"\n",
" Raises:\n",
" ValueError: If the model name is not supported.\n",
" \"\"\"\n",
" if model_name != \"titanet_large\":\n",
" raise ValueError(\n",
" f\"Unknown model name: {model_name}. Only 'titanet_large' is supported at the moment.\"\n",
" )\n",
"\n",
" self.model_name = model_name\n",
" self.location_in_the_cloud = \"https://api.ngc.nvidia.com/v2/models/nvidia/nemo/titanet_large/versions/v1/files/titanet-l.nemo\"\n",
" self.cache_dir = Path.joinpath(resolve_diarization_cache_dir(), \"titanet-l\")\n",
" cache_subfolder = hashlib.md5((self.location_in_the_cloud).encode(\"utf-8\")).hexdigest()\n",
"\n",
" self.nemo_model_folder, self.nemo_model_file = self.download_model_if_required(\n",
" url=self.location_in_the_cloud, cache_dir=self.cache_dir, subfolder=cache_subfolder,\n",
" )\n",
"\n",
" self.model_files = Path.joinpath(self.nemo_model_folder, \"model_files\")\n",
" if not self.model_files.exists():\n",
" self.model_files.mkdir(parents=True, exist_ok=True)\n",
" self.unpack_nemo_file(self.nemo_model_file, self.model_files)\n",
"\n",
" self.model_weights_file_path = Path.joinpath(self.model_files, \"model_weights.ckpt\")\n",
" model_config_file_path = Path.joinpath(self.model_files, \"model_config.yaml\")\n",
" with open(model_config_file_path, \"r\") as config_file:\n",
" self.model_config = yaml.safe_load(config_file)\n",
"\n",
"\n",
" @staticmethod\n",
" def download_model_if_required(url, subfolder=None, cache_dir=None) -> Tuple[str, str]:\n",
" \"\"\"\n",
" Helper function to download pre-trained weights from the cloud.\n",
"\n",
" Args:\n",
" url: (str) URL to download from.\n",
" cache_dir: (str) a cache directory where to download. If not present, this function will attempt to create it.\n",
" If None (default), then it will be $HOME/.cache/torch/diarization\n",
" subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can\n",
" be empty\n",
"\n",
" Returns:\n",
" Tuple[str, str]: cache_dir and filepath to the downloaded file.\n",
" \"\"\"\n",
" destination = Path.joinpath(cache_dir, subfolder)\n",
"\n",
" if not destination.exists():\n",
" destination.mkdir(parents=True, exist_ok=True)\n",
"\n",
" filename = url.split(\"/\")[-1]\n",
" destination_file = Path.joinpath(destination, filename)\n",
"\n",
" if destination_file.exists():\n",
" return destination, destination_file\n",
"\n",
" i = 0\n",
" while i < 10: # try 10 times\n",
" i += 1\n",
"\n",
" try:\n",
" wget.download(url, str(destination_file))\n",
" if os.path.exists(destination_file):\n",
" return destination, destination_file\n",
"\n",
" except:\n",
" continue\n",
"\n",
" raise ValueError(\"Not able to download the diarization model, please try again later.\")\n",
" \n",
" @staticmethod\n",
" def unpack_nemo_file(filepath: Path, out_folder: Path) -> str:\n",
" \"\"\"\n",
" Unpacks a .nemo file into a folder.\n",
"\n",
" Args:\n",
" filepath (Path): path to the .nemo file (can be compressed or uncompressed)\n",
" out_folder (Path): path to the folder where the .nemo file should be unpacked\n",
"\n",
" Returns:\n",
" path to the unpacked folder\n",
" \"\"\"\n",
" try:\n",
" tar = tarfile.open(filepath, \"r:\") # try uncompressed\n",
" except tarfile.ReadError:\n",
" tar = tarfile.open(filepath, \"r:gz\") # try compressed\n",
" finally:\n",
" tar.extractall(path=out_folder)\n",
" tar.close()\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"model = EncDecSpeakerLabelModel()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"model.model_weights_file_path\n",
"ckpt = torch.load(model.model_weights_file_path, map_location=\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"odict_keys(['preprocessor.featurizer.window', 'preprocessor.featurizer.fb', 'encoder.encoder.0.mconv.0.conv.weight', 'encoder.encoder.0.mconv.1.conv.weight', 'encoder.encoder.0.mconv.2.weight', 'encoder.encoder.0.mconv.2.bias', 'encoder.encoder.0.mconv.2.running_mean', 'encoder.encoder.0.mconv.2.running_var', 'encoder.encoder.0.mconv.2.num_batches_tracked', 'encoder.encoder.0.mconv.3.fc.0.weight', 'encoder.encoder.0.mconv.3.fc.2.weight', 'encoder.encoder.1.mconv.0.conv.weight', 'encoder.encoder.1.mconv.1.conv.weight', 'encoder.encoder.1.mconv.2.weight', 'encoder.encoder.1.mconv.2.bias', 'encoder.encoder.1.mconv.2.running_mean', 'encoder.encoder.1.mconv.2.running_var', 'encoder.encoder.1.mconv.2.num_batches_tracked', 'encoder.encoder.1.mconv.5.conv.weight', 'encoder.encoder.1.mconv.6.conv.weight', 'encoder.encoder.1.mconv.7.weight', 'encoder.encoder.1.mconv.7.bias', 'encoder.encoder.1.mconv.7.running_mean', 'encoder.encoder.1.mconv.7.running_var', 'encoder.encoder.1.mconv.7.num_batches_tracked', 'encoder.encoder.1.mconv.10.conv.weight', 'encoder.encoder.1.mconv.11.conv.weight', 'encoder.encoder.1.mconv.12.weight', 'encoder.encoder.1.mconv.12.bias', 'encoder.encoder.1.mconv.12.running_mean', 'encoder.encoder.1.mconv.12.running_var', 'encoder.encoder.1.mconv.12.num_batches_tracked', 'encoder.encoder.1.mconv.13.fc.0.weight', 'encoder.encoder.1.mconv.13.fc.2.weight', 'encoder.encoder.1.res.0.0.conv.weight', 'encoder.encoder.1.res.0.1.weight', 'encoder.encoder.1.res.0.1.bias', 'encoder.encoder.1.res.0.1.running_mean', 'encoder.encoder.1.res.0.1.running_var', 'encoder.encoder.1.res.0.1.num_batches_tracked', 'encoder.encoder.2.mconv.0.conv.weight', 'encoder.encoder.2.mconv.1.conv.weight', 'encoder.encoder.2.mconv.2.weight', 'encoder.encoder.2.mconv.2.bias', 'encoder.encoder.2.mconv.2.running_mean', 'encoder.encoder.2.mconv.2.running_var', 'encoder.encoder.2.mconv.2.num_batches_tracked', 'encoder.encoder.2.mconv.5.conv.weight', 'encoder.encoder.2.mconv.6.conv.weight', 
'encoder.encoder.2.mconv.7.weight', 'encoder.encoder.2.mconv.7.bias', 'encoder.encoder.2.mconv.7.running_mean', 'encoder.encoder.2.mconv.7.running_var', 'encoder.encoder.2.mconv.7.num_batches_tracked', 'encoder.encoder.2.mconv.10.conv.weight', 'encoder.encoder.2.mconv.11.conv.weight', 'encoder.encoder.2.mconv.12.weight', 'encoder.encoder.2.mconv.12.bias', 'encoder.encoder.2.mconv.12.running_mean', 'encoder.encoder.2.mconv.12.running_var', 'encoder.encoder.2.mconv.12.num_batches_tracked', 'encoder.encoder.2.mconv.13.fc.0.weight', 'encoder.encoder.2.mconv.13.fc.2.weight', 'encoder.encoder.2.res.0.0.conv.weight', 'encoder.encoder.2.res.0.1.weight', 'encoder.encoder.2.res.0.1.bias', 'encoder.encoder.2.res.0.1.running_mean', 'encoder.encoder.2.res.0.1.running_var', 'encoder.encoder.2.res.0.1.num_batches_tracked', 'encoder.encoder.3.mconv.0.conv.weight', 'encoder.encoder.3.mconv.1.conv.weight', 'encoder.encoder.3.mconv.2.weight', 'encoder.encoder.3.mconv.2.bias', 'encoder.encoder.3.mconv.2.running_mean', 'encoder.encoder.3.mconv.2.running_var', 'encoder.encoder.3.mconv.2.num_batches_tracked', 'encoder.encoder.3.mconv.5.conv.weight', 'encoder.encoder.3.mconv.6.conv.weight', 'encoder.encoder.3.mconv.7.weight', 'encoder.encoder.3.mconv.7.bias', 'encoder.encoder.3.mconv.7.running_mean', 'encoder.encoder.3.mconv.7.running_var', 'encoder.encoder.3.mconv.7.num_batches_tracked', 'encoder.encoder.3.mconv.10.conv.weight', 'encoder.encoder.3.mconv.11.conv.weight', 'encoder.encoder.3.mconv.12.weight', 'encoder.encoder.3.mconv.12.bias', 'encoder.encoder.3.mconv.12.running_mean', 'encoder.encoder.3.mconv.12.running_var', 'encoder.encoder.3.mconv.12.num_batches_tracked', 'encoder.encoder.3.mconv.13.fc.0.weight', 'encoder.encoder.3.mconv.13.fc.2.weight', 'encoder.encoder.3.res.0.0.conv.weight', 'encoder.encoder.3.res.0.1.weight', 'encoder.encoder.3.res.0.1.bias', 'encoder.encoder.3.res.0.1.running_mean', 'encoder.encoder.3.res.0.1.running_var', 
'encoder.encoder.3.res.0.1.num_batches_tracked', 'encoder.encoder.4.mconv.0.conv.weight', 'encoder.encoder.4.mconv.1.conv.weight', 'encoder.encoder.4.mconv.2.weight', 'encoder.encoder.4.mconv.2.bias', 'encoder.encoder.4.mconv.2.running_mean', 'encoder.encoder.4.mconv.2.running_var', 'encoder.encoder.4.mconv.2.num_batches_tracked', 'encoder.encoder.4.mconv.3.fc.0.weight', 'encoder.encoder.4.mconv.3.fc.2.weight', 'decoder._pooling.attention_layer.0.conv_layer.weight', 'decoder._pooling.attention_layer.0.conv_layer.bias', 'decoder._pooling.attention_layer.0.bn.weight', 'decoder._pooling.attention_layer.0.bn.bias', 'decoder._pooling.attention_layer.0.bn.running_mean', 'decoder._pooling.attention_layer.0.bn.running_var', 'decoder._pooling.attention_layer.0.bn.num_batches_tracked', 'decoder._pooling.attention_layer.2.weight', 'decoder._pooling.attention_layer.2.bias', 'decoder.emb_layers.0.0.weight', 'decoder.emb_layers.0.0.bias', 'decoder.emb_layers.0.0.running_mean', 'decoder.emb_layers.0.0.running_var', 'decoder.emb_layers.0.0.num_batches_tracked', 'decoder.emb_layers.0.1.weight', 'decoder.emb_layers.0.1.bias', 'decoder.final.weight'])\n"
]
}
],
"source": [
"print(ckpt.keys())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"encoder = model.model_config[\"encoder\"]\n",
"decoder = model.model_config[\"decoder\"]\n",
"preprocessor = model.model_config[\"preprocessor\"]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_target_: nemo.collections.asr.modules.ConvASREncoder\n",
"feat_in: 80\n",
"activation: relu\n",
"conv_mask: True\n",
"jasper:\n",
"{'filters': 1024, 'repeat': 1, 'kernel': [3], 'stride': [1], 'dilation': [1], 'dropout': 0.0, 'residual': False, 'separable': True, 'se': True, 'se_context_size': -1}\n",
"{'filters': 1024, 'repeat': 3, 'kernel': [7], 'stride': [1], 'dilation': [1], 'dropout': 0.1, 'residual': True, 'separable': True, 'se': True, 'se_context_size': -1}\n",
"{'filters': 1024, 'repeat': 3, 'kernel': [11], 'stride': [1], 'dilation': [1], 'dropout': 0.1, 'residual': True, 'separable': True, 'se': True, 'se_context_size': -1}\n",
"{'filters': 1024, 'repeat': 3, 'kernel': [15], 'stride': [1], 'dilation': [1], 'dropout': 0.1, 'residual': True, 'separable': True, 'se': True, 'se_context_size': -1}\n",
"{'filters': 3072, 'repeat': 1, 'kernel': [1], 'stride': [1], 'dilation': [1], 'dropout': 0.0, 'residual': False, 'separable': True, 'se': True, 'se_context_size': -1}\n"
]
}
],
"source": [
"for k, v in encoder.items():\n",
" if isinstance(v, list):\n",
" print(f\"{k}:\")\n",
" for value in v:\n",
" print(f\"{value}\")\n",
" else:\n",
" print(f\"{k}: {v}\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_target_: nemo.collections.asr.modules.SpeakerDecoder\n",
"feat_in: 3072\n",
"num_classes: 16681\n",
"pool_mode: attention\n",
"emb_sizes: 192\n"
]
}
],
"source": [
"for k, v in decoder.items():\n",
" if isinstance(v, list):\n",
" print(f\"{k}:\")\n",
" for value in v:\n",
" print(f\"{value}\")\n",
" else:\n",
" print(f\"{k}: {v}\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor\n",
"normalize: per_feature\n",
"window_size: 0.025\n",
"sample_rate: 16000\n",
"window_stride: 0.01\n",
"window: hann\n",
"features: 80\n",
"n_fft: 512\n",
"frame_splicing: 1\n",
"dither: 1e-05\n"
]
}
],
"source": [
"for k, v in preprocessor.items():\n",
" if isinstance(v, list):\n",
" print(f\"{k}:\")\n",
" for value in v:\n",
" print(f\"{value}\")\n",
" else:\n",
" print(f\"{k}: {v}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion notebooks/youtube_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"alignment": False, # Longer processing time but better timestamps
"num_speakers": -1, # Leave at -1 to guess the number of speakers
"diarization": True, # Longer processing time but speaker segment attribution
"source_lang": "nl", # optional, default is "en"
"source_lang": "en", # optional, default is "en"
"timestamps": "s", # optional, default is "s". Can be "s", "ms" or "hms".
"internal_vad": False, # optional, default is False
"word_timestamps": False, # optional, default is False
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


package = "wordcab_transcribe"
python_versions = ["3.9"]
python_versions = ["3.10"]
nox.needs_version = ">= 2021.6.6"
nox.options.sessions = (
"pre-commit",
Expand Down
Loading

0 comments on commit d303ac0

Please sign in to comment.