diff --git a/README.md b/README.md index be544e0..84cee85 100644 --- a/README.md +++ b/README.md @@ -44,8 +44,38 @@ dataset[0] # (1, 158621) dataset[1] # (1, 153757) ``` +#### Full API: +```py +LJSpeechDataset( + root: str = "./data", # The root where the dataset will be downloaded + transforms: Optional[Callable] = None, # Transforms to apply to audio files +) +``` + +### LibriSpeech Dataset +Wrapper for the [LibriSpeech](https://www.openslr.org/12) dataset (EN only). Requires `pip install datasets`. Note that this dataset requires several GBs of storage. + +```py +from audio_data_pytorch import LibriSpeechDataset + +dataset = LibriSpeechDataset( + root="./data", +) + +dataset[0] # (1, 222336) +``` + +#### Full API: +```py +LibriSpeechDataset( + root: str = "./data", # The root where the dataset will be downloaded + with_info: bool = False, # Whether to return info (i.e. text, sampling rate, speaker_id) + transforms: Optional[Callable] = None, # Transforms to apply to audio files +) +``` + ### Common Voice Dataset -Multilanguage wrapper for the [Common Voice](https://commonvoice.mozilla.org/) dataset with voice-only data. Requires `pip install datasets`. Note that each language requires several GBs of storage, and that you have to confirm access for each distinct version you use e.g. [here](https://huggingface.co/datasets/mozilla-foundation/common_voice_10_0), to validate your Huggingface access token. +Multilanguage wrapper for the [Common Voice](https://commonvoice.mozilla.org/) dataset. Requires `pip install datasets`. Note that each language requires several GBs of storage, and that you have to confirm access for each distinct version you use e.g. [here](https://huggingface.co/datasets/mozilla-foundation/common_voice_10_0), to validate your Huggingface access token. 
You can provide a list of `languages` and to avoid an unbalanced dataset the values will be interleaved by downsampling the majority language to have the same number of samples as the minority language. ```py from audio_data_pytorch import CommonVoiceDataset @@ -66,7 +96,7 @@ CommonVoiceDataset( sub_version: int = 0, # Subversion: common_voice_{version}_{sub_version} root: str = "./data", # The root where the dataset will be downloaded languages: Sequence[str] = ['en'], # List of languages to include in the dataset - with_sample_rate: bool = False, # Returns sample rate as second argument + with_info: bool = False, # Whether to return info (i.e. text, sampling rate, age, gender, accent, locale) transforms: Optional[Callable] = None, # Transforms to apply to audio files ) ``` diff --git a/audio_data_pytorch/datasets/__init__.py b/audio_data_pytorch/datasets/__init__.py index 1f6b431..3fedc19 100644 --- a/audio_data_pytorch/datasets/__init__.py +++ b/audio_data_pytorch/datasets/__init__.py @@ -1,4 +1,5 @@ from .common_voice_dataset import CommonVoiceDataset -from .ljspeech_dataset import LJSpeechDataset +from .libri_speech_dataset import LibriSpeechDataset +from .lj_speech_dataset import LJSpeechDataset from .wav_dataset import WAVDataset from .youtube_dataset import YoutubeDataset diff --git a/audio_data_pytorch/datasets/common_voice_dataset.py b/audio_data_pytorch/datasets/common_voice_dataset.py index 4a291ca..6e6a69a 100644 --- a/audio_data_pytorch/datasets/common_voice_dataset.py +++ b/audio_data_pytorch/datasets/common_voice_dataset.py @@ -1,5 +1,5 @@ import os -from typing import Callable, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -14,10 +14,10 @@ def __init__( sub_version: int = 0, root: str = "./data", languages: Sequence[str] = ["en"], - with_sample_rate: bool = False, + with_info: bool = False, transforms: Optional[Callable] = None, ): - 
self.with_sample_rate = with_sample_rate + self.with_info = with_info self.transforms = transforms from datasets import interleave_datasets, load_dataset @@ -37,15 +37,24 @@ def __init__( def __getitem__( self, idx: Union[Tensor, int] - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + ) -> Union[Tensor, Tuple[Tensor, Dict]]: idx = idx.tolist() if torch.is_tensor(idx) else idx # type: ignore data = self.dataset[idx] + waveform = torch.tensor(data["audio"]["array"]).view(1, -1) - sample_rate = data["audio"]["sampling_rate"] + + info = dict( + sample_rate=data["audio"]["sampling_rate"], + text=data["sentence"], + age=data["age"], + accent=data["accent"], + gender=data["gender"], + locale=data["locale"], + ) if self.transforms: waveform = self.transforms(waveform) - return (waveform, sample_rate) if self.with_sample_rate else waveform + return (waveform, info) if self.with_info else waveform def __len__(self) -> int: return len(self.dataset) diff --git a/audio_data_pytorch/datasets/libri_speech_dataset.py b/audio_data_pytorch/datasets/libri_speech_dataset.py new file mode 100644 index 0000000..7850b6b --- /dev/null +++ b/audio_data_pytorch/datasets/libri_speech_dataset.py @@ -0,0 +1,44 @@ +import os +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.utils.data import Dataset + + +class LibriSpeechDataset(Dataset): + def __init__( + self, + root: str = "./data", + with_info: bool = False, + transforms: Optional[Callable] = None, + ): + self.with_info = with_info + self.transforms = transforms + + from datasets import load_dataset + + self.dataset = load_dataset( + "librispeech_asr", + "clean", + split="train.100", + cache_dir=os.path.join(root, "librispeech_dataset"), + ) + + def __getitem__( + self, idx: Union[Tensor, int] + ) -> Union[Tensor, Tuple[Tensor, Dict]]: + idx = idx.tolist() if torch.is_tensor(idx) else idx # type: ignore + data = self.dataset[idx] + waveform = 
torch.tensor(data["audio"]["array"]).view(1, -1) + info = dict( + sample_rate=data["audio"]["sampling_rate"], + text=data["text"], + speaker_id=data["speaker_id"], + ) + if self.transforms: + waveform = self.transforms(waveform) + return (waveform, info) if self.with_info else waveform + + def __len__(self) -> int: + return len(self.dataset) diff --git a/audio_data_pytorch/datasets/ljspeech_dataset.py b/audio_data_pytorch/datasets/lj_speech_dataset.py similarity index 100% rename from audio_data_pytorch/datasets/ljspeech_dataset.py rename to audio_data_pytorch/datasets/lj_speech_dataset.py diff --git a/setup.py b/setup.py index d8870f6..10f84c8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-data-pytorch", packages=find_packages(exclude=[]), - version="0.0.10", + version="0.0.11", license="MIT", description="Audio Data - PyTorch", long_description_content_type="text/markdown",