-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #202 from Wordcab/196-implement-the-segmentation-m…
…odel Re-implement EncDecSpeakerLabelModel
- Loading branch information
Showing
13 changed files
with
3,328 additions
and
2,910 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,311 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import hashlib\n", | ||
"import os\n", | ||
"import tarfile\n", | ||
"import wget\n", | ||
"import yaml\n", | ||
"from pathlib import Path\n", | ||
"from typing import Optional, Tuple, Union\n", | ||
"\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"\n", | ||
"\n", | ||
"def resolve_diarization_cache_dir() -> Path:\n", | ||
" \"\"\"\n", | ||
" Utility method to get the cache directory for the diarization module.\n", | ||
"\n", | ||
" Returns:\n", | ||
" Path: The path to the cache directory.\n", | ||
" \"\"\"\n", | ||
" path = Path.joinpath(Path.home(), f\".cache/torch/diarization\")\n", | ||
"\n", | ||
" return path" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Inspired from NVIDIA NeMo's EncDecSpeakerLabelModel\n", | ||
"# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/label_models.py#L67\n", | ||
"class EncDecSpeakerLabelModel:\n", | ||
" \"\"\"The EncDecSpeakerLabelModel class encapsulates the encoder-decoder speaker label model.\"\"\"\n", | ||
"\n", | ||
" def __init__(self, model_name: str = \"titanet_large\") -> None:\n", | ||
" \"\"\"Initialize the EncDecSpeakerLabelModel class.\n", | ||
"\n", | ||
" The EncDecSpeakerLabelModel class encapsulates the encoder-decoder speaker label model.\n", | ||
" Only the \"titanet_large\" model is supported at the moment.\n", | ||
" For more models: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/label_models.py#L59\n", | ||
"\n", | ||
" Args:\n", | ||
" model_name (str, optional): The name of the model to use. Defaults to \"titanet_large\".\n", | ||
"\n", | ||
" Raises:\n", | ||
" ValueError: If the model name is not supported.\n", | ||
" \"\"\"\n", | ||
" if model_name != \"titanet_large\":\n", | ||
" raise ValueError(\n", | ||
" f\"Unknown model name: {model_name}. Only 'titanet_large' is supported at the moment.\"\n", | ||
" )\n", | ||
"\n", | ||
" self.model_name = model_name\n", | ||
" self.location_in_the_cloud = \"https://api.ngc.nvidia.com/v2/models/nvidia/nemo/titanet_large/versions/v1/files/titanet-l.nemo\"\n", | ||
" self.cache_dir = Path.joinpath(resolve_diarization_cache_dir(), \"titanet-l\")\n", | ||
" cache_subfolder = hashlib.md5((self.location_in_the_cloud).encode(\"utf-8\")).hexdigest()\n", | ||
"\n", | ||
" self.nemo_model_folder, self.nemo_model_file = self.download_model_if_required(\n", | ||
" url=self.location_in_the_cloud, cache_dir=self.cache_dir, subfolder=cache_subfolder,\n", | ||
" )\n", | ||
"\n", | ||
" self.model_files = Path.joinpath(self.nemo_model_folder, \"model_files\")\n", | ||
" if not self.model_files.exists():\n", | ||
" self.model_files.mkdir(parents=True, exist_ok=True)\n", | ||
" self.unpack_nemo_file(self.nemo_model_file, self.model_files)\n", | ||
"\n", | ||
" self.model_weights_file_path = Path.joinpath(self.model_files, \"model_weights.ckpt\")\n", | ||
" model_config_file_path = Path.joinpath(self.model_files, \"model_config.yaml\")\n", | ||
" with open(model_config_file_path, \"r\") as config_file:\n", | ||
" self.model_config = yaml.safe_load(config_file)\n", | ||
"\n", | ||
"\n", | ||
" @staticmethod\n", | ||
" def download_model_if_required(url, subfolder=None, cache_dir=None) -> Tuple[str, str]:\n", | ||
" \"\"\"\n", | ||
" Helper function to download pre-trained weights from the cloud.\n", | ||
"\n", | ||
" Args:\n", | ||
" url: (str) URL to download from.\n", | ||
" cache_dir: (str) a cache directory where to download. If not present, this function will attempt to create it.\n", | ||
" If None (default), then it will be $HOME/.cache/torch/diarization\n", | ||
" subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can\n", | ||
" be empty\n", | ||
"\n", | ||
" Returns:\n", | ||
" Tuple[str, str]: cache_dir and filepath to the downloaded file.\n", | ||
" \"\"\"\n", | ||
" destination = Path.joinpath(cache_dir, subfolder)\n", | ||
"\n", | ||
" if not destination.exists():\n", | ||
" destination.mkdir(parents=True, exist_ok=True)\n", | ||
"\n", | ||
" filename = url.split(\"/\")[-1]\n", | ||
" destination_file = Path.joinpath(destination, filename)\n", | ||
"\n", | ||
" if destination_file.exists():\n", | ||
" return destination, destination_file\n", | ||
"\n", | ||
" i = 0\n", | ||
" while i < 10: # try 10 times\n", | ||
" i += 1\n", | ||
"\n", | ||
" try:\n", | ||
" wget.download(url, str(destination_file))\n", | ||
" if os.path.exists(destination_file):\n", | ||
" return destination, destination_file\n", | ||
"\n", | ||
" except:\n", | ||
" continue\n", | ||
"\n", | ||
" raise ValueError(\"Not able to download the diarization model, please try again later.\")\n", | ||
" \n", | ||
" @staticmethod\n", | ||
" def unpack_nemo_file(filepath: Path, out_folder: Path) -> str:\n", | ||
" \"\"\"\n", | ||
" Unpacks a .nemo file into a folder.\n", | ||
"\n", | ||
" Args:\n", | ||
" filepath (Path): path to the .nemo file (can be compressed or uncompressed)\n", | ||
" out_folder (Path): path to the folder where the .nemo file should be unpacked\n", | ||
"\n", | ||
" Returns:\n", | ||
" path to the unpacked folder\n", | ||
" \"\"\"\n", | ||
" try:\n", | ||
" tar = tarfile.open(filepath, \"r:\") # try uncompressed\n", | ||
" except tarfile.ReadError:\n", | ||
" tar = tarfile.open(filepath, \"r:gz\") # try compressed\n", | ||
" finally:\n", | ||
" tar.extractall(path=out_folder)\n", | ||
" tar.close()\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = EncDecSpeakerLabelModel()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 18, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model.model_weights_file_path\n", | ||
"ckpt = torch.load(model.model_weights_file_path, map_location=\"cpu\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 23, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"odict_keys(['preprocessor.featurizer.window', 'preprocessor.featurizer.fb', 'encoder.encoder.0.mconv.0.conv.weight', 'encoder.encoder.0.mconv.1.conv.weight', 'encoder.encoder.0.mconv.2.weight', 'encoder.encoder.0.mconv.2.bias', 'encoder.encoder.0.mconv.2.running_mean', 'encoder.encoder.0.mconv.2.running_var', 'encoder.encoder.0.mconv.2.num_batches_tracked', 'encoder.encoder.0.mconv.3.fc.0.weight', 'encoder.encoder.0.mconv.3.fc.2.weight', 'encoder.encoder.1.mconv.0.conv.weight', 'encoder.encoder.1.mconv.1.conv.weight', 'encoder.encoder.1.mconv.2.weight', 'encoder.encoder.1.mconv.2.bias', 'encoder.encoder.1.mconv.2.running_mean', 'encoder.encoder.1.mconv.2.running_var', 'encoder.encoder.1.mconv.2.num_batches_tracked', 'encoder.encoder.1.mconv.5.conv.weight', 'encoder.encoder.1.mconv.6.conv.weight', 'encoder.encoder.1.mconv.7.weight', 'encoder.encoder.1.mconv.7.bias', 'encoder.encoder.1.mconv.7.running_mean', 'encoder.encoder.1.mconv.7.running_var', 'encoder.encoder.1.mconv.7.num_batches_tracked', 'encoder.encoder.1.mconv.10.conv.weight', 'encoder.encoder.1.mconv.11.conv.weight', 'encoder.encoder.1.mconv.12.weight', 'encoder.encoder.1.mconv.12.bias', 'encoder.encoder.1.mconv.12.running_mean', 'encoder.encoder.1.mconv.12.running_var', 'encoder.encoder.1.mconv.12.num_batches_tracked', 'encoder.encoder.1.mconv.13.fc.0.weight', 'encoder.encoder.1.mconv.13.fc.2.weight', 'encoder.encoder.1.res.0.0.conv.weight', 'encoder.encoder.1.res.0.1.weight', 'encoder.encoder.1.res.0.1.bias', 'encoder.encoder.1.res.0.1.running_mean', 'encoder.encoder.1.res.0.1.running_var', 'encoder.encoder.1.res.0.1.num_batches_tracked', 'encoder.encoder.2.mconv.0.conv.weight', 'encoder.encoder.2.mconv.1.conv.weight', 'encoder.encoder.2.mconv.2.weight', 'encoder.encoder.2.mconv.2.bias', 'encoder.encoder.2.mconv.2.running_mean', 'encoder.encoder.2.mconv.2.running_var', 'encoder.encoder.2.mconv.2.num_batches_tracked', 'encoder.encoder.2.mconv.5.conv.weight', 'encoder.encoder.2.mconv.6.conv.weight', 'encoder.encoder.2.mconv.7.weight', 'encoder.encoder.2.mconv.7.bias', 'encoder.encoder.2.mconv.7.running_mean', 'encoder.encoder.2.mconv.7.running_var', 'encoder.encoder.2.mconv.7.num_batches_tracked', 'encoder.encoder.2.mconv.10.conv.weight', 'encoder.encoder.2.mconv.11.conv.weight', 'encoder.encoder.2.mconv.12.weight', 'encoder.encoder.2.mconv.12.bias', 'encoder.encoder.2.mconv.12.running_mean', 'encoder.encoder.2.mconv.12.running_var', 'encoder.encoder.2.mconv.12.num_batches_tracked', 'encoder.encoder.2.mconv.13.fc.0.weight', 'encoder.encoder.2.mconv.13.fc.2.weight', 'encoder.encoder.2.res.0.0.conv.weight', 'encoder.encoder.2.res.0.1.weight', 'encoder.encoder.2.res.0.1.bias', 'encoder.encoder.2.res.0.1.running_mean', 'encoder.encoder.2.res.0.1.running_var', 'encoder.encoder.2.res.0.1.num_batches_tracked', 'encoder.encoder.3.mconv.0.conv.weight', 'encoder.encoder.3.mconv.1.conv.weight', 'encoder.encoder.3.mconv.2.weight', 'encoder.encoder.3.mconv.2.bias', 'encoder.encoder.3.mconv.2.running_mean', 'encoder.encoder.3.mconv.2.running_var', 'encoder.encoder.3.mconv.2.num_batches_tracked', 'encoder.encoder.3.mconv.5.conv.weight', 'encoder.encoder.3.mconv.6.conv.weight', 'encoder.encoder.3.mconv.7.weight', 'encoder.encoder.3.mconv.7.bias', 'encoder.encoder.3.mconv.7.running_mean', 'encoder.encoder.3.mconv.7.running_var', 'encoder.encoder.3.mconv.7.num_batches_tracked', 'encoder.encoder.3.mconv.10.conv.weight', 'encoder.encoder.3.mconv.11.conv.weight', 'encoder.encoder.3.mconv.12.weight', 'encoder.encoder.3.mconv.12.bias', 'encoder.encoder.3.mconv.12.running_mean', 'encoder.encoder.3.mconv.12.running_var', 'encoder.encoder.3.mconv.12.num_batches_tracked', 'encoder.encoder.3.mconv.13.fc.0.weight', 'encoder.encoder.3.mconv.13.fc.2.weight', 'encoder.encoder.3.res.0.0.conv.weight', 'encoder.encoder.3.res.0.1.weight', 'encoder.encoder.3.res.0.1.bias', 'encoder.encoder.3.res.0.1.running_mean', 'encoder.encoder.3.res.0.1.running_var', 'encoder.encoder.3.res.0.1.num_batches_tracked', 'encoder.encoder.4.mconv.0.conv.weight', 'encoder.encoder.4.mconv.1.conv.weight', 'encoder.encoder.4.mconv.2.weight', 'encoder.encoder.4.mconv.2.bias', 'encoder.encoder.4.mconv.2.running_mean', 'encoder.encoder.4.mconv.2.running_var', 'encoder.encoder.4.mconv.2.num_batches_tracked', 'encoder.encoder.4.mconv.3.fc.0.weight', 'encoder.encoder.4.mconv.3.fc.2.weight', 'decoder._pooling.attention_layer.0.conv_layer.weight', 'decoder._pooling.attention_layer.0.conv_layer.bias', 'decoder._pooling.attention_layer.0.bn.weight', 'decoder._pooling.attention_layer.0.bn.bias', 'decoder._pooling.attention_layer.0.bn.running_mean', 'decoder._pooling.attention_layer.0.bn.running_var', 'decoder._pooling.attention_layer.0.bn.num_batches_tracked', 'decoder._pooling.attention_layer.2.weight', 'decoder._pooling.attention_layer.2.bias', 'decoder.emb_layers.0.0.weight', 'decoder.emb_layers.0.0.bias', 'decoder.emb_layers.0.0.running_mean', 'decoder.emb_layers.0.0.running_var', 'decoder.emb_layers.0.0.num_batches_tracked', 'decoder.emb_layers.0.1.weight', 'decoder.emb_layers.0.1.bias', 'decoder.final.weight'])\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(ckpt.keys())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"encoder = model.model_config[\"encoder\"]\n", | ||
"decoder = model.model_config[\"decoder\"]\n", | ||
"preprocessor = model.model_config[\"preprocessor\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"_target_: nemo.collections.asr.modules.ConvASREncoder\n", | ||
"feat_in: 80\n", | ||
"activation: relu\n", | ||
"conv_mask: True\n", | ||
"jasper:\n", | ||
"{'filters': 1024, 'repeat': 1, 'kernel': [3], 'stride': [1], 'dilation': [1], 'dropout': 0.0, 'residual': False, 'separable': True, 'se': True, 'se_context_size': -1}\n", | ||
"{'filters': 1024, 'repeat': 3, 'kernel': [7], 'stride': [1], 'dilation': [1], 'dropout': 0.1, 'residual': True, 'separable': True, 'se': True, 'se_context_size': -1}\n", | ||
"{'filters': 1024, 'repeat': 3, 'kernel': [11], 'stride': [1], 'dilation': [1], 'dropout': 0.1, 'residual': True, 'separable': True, 'se': True, 'se_context_size': -1}\n", | ||
"{'filters': 1024, 'repeat': 3, 'kernel': [15], 'stride': [1], 'dilation': [1], 'dropout': 0.1, 'residual': True, 'separable': True, 'se': True, 'se_context_size': -1}\n", | ||
"{'filters': 3072, 'repeat': 1, 'kernel': [1], 'stride': [1], 'dilation': [1], 'dropout': 0.0, 'residual': False, 'separable': True, 'se': True, 'se_context_size': -1}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for k, v in encoder.items():\n", | ||
" if isinstance(v, list):\n", | ||
" print(f\"{k}:\")\n", | ||
" for value in v:\n", | ||
" print(f\"{value}\")\n", | ||
" else:\n", | ||
" print(f\"{k}: {v}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"_target_: nemo.collections.asr.modules.SpeakerDecoder\n", | ||
"feat_in: 3072\n", | ||
"num_classes: 16681\n", | ||
"pool_mode: attention\n", | ||
"emb_sizes: 192\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for k, v in decoder.items():\n", | ||
" if isinstance(v, list):\n", | ||
" print(f\"{k}:\")\n", | ||
" for value in v:\n", | ||
" print(f\"{value}\")\n", | ||
" else:\n", | ||
" print(f\"{k}: {v}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor\n", | ||
"normalize: per_feature\n", | ||
"window_size: 0.025\n", | ||
"sample_rate: 16000\n", | ||
"window_stride: 0.01\n", | ||
"window: hann\n", | ||
"features: 80\n", | ||
"n_fft: 512\n", | ||
"frame_splicing: 1\n", | ||
"dither: 1e-05\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for k, v in preprocessor.items():\n", | ||
" if isinstance(v, list):\n", | ||
" print(f\"{k}:\")\n", | ||
" for value in v:\n", | ||
" print(f\"{value}\")\n", | ||
" else:\n", | ||
" print(f\"{k}: {v}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "base", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.16" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.