From 348c6d5dc2ddbb5a4cdca15a75b1650ed5263025 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Mon, 24 Oct 2022 13:27:50 +0200 Subject: [PATCH] feat: add fastscan dir, add corrupt check, add silence check --- README.md | 1 + .../datasets/audio_web_dataset.py | 3 +- audio_data_pytorch/datasets/clotho_dataset.py | 4 +- audio_data_pytorch/datasets/wav_dataset.py | 38 ++++++++++++---- audio_data_pytorch/utils.py | 43 +++++++++++++++++++ setup.py | 2 +- 6 files changed, 79 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index ce64950..10040d3 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ WAVDataset( recursive: bool = False # Recursively load files from provided paths with_sample_rate: bool = False, # Returns sample rate as second argument transforms: Optional[Callable] = None, # Transforms to apply to audio files + check_silence: bool = True # Discards silent samples if true ) ``` diff --git a/audio_data_pytorch/datasets/audio_web_dataset.py b/audio_data_pytorch/datasets/audio_web_dataset.py index a2d2f61..da3b02d 100644 --- a/audio_data_pytorch/datasets/audio_web_dataset.py +++ b/audio_data_pytorch/datasets/audio_web_dataset.py @@ -46,7 +46,6 @@ def process_wav(self): waveform = self.transforms(waveform) wav_dest_path = f"{self.path_prefix}.wav" - print(wav_dest_path) torchaudio.save(wav_dest_path, waveform, rate) self.wav_dest_path = wav_dest_path @@ -106,7 +105,7 @@ async def preprocess(self): waveform_id = 0 async with Downloader(urls, path=path) as files: - async with Decompressor(files, path=path) as folders: + async with Decompressor(files, path=path, remove_on_exit=True) as folders: with tarfile.open(tarfile_name, "w") as archive: for folder in tqdm(folders): for wav in tqdm(glob.glob(folder + "/**/*.wav")): diff --git a/audio_data_pytorch/datasets/clotho_dataset.py b/audio_data_pytorch/datasets/clotho_dataset.py index 4aa9bf4..4135a31 100644 --- a/audio_data_pytorch/datasets/clotho_dataset.py +++ b/audio_data_pytorch/datasets/clotho_dataset.py @@ -54,7 +54,9 @@ async def preprocess(self): async with Downloader(urls, path=path) as files: to_decompress = [f for f in files if f.endswith(".7z")] caption_csv_file = [f for f in files if f.endswith(".csv")][0] - async with Decompressor(to_decompress, path=path) as folders: + async with Decompressor( + to_decompress, path=path, remove_on_exit=True + ) as folders: captions = pd.read_csv(caption_csv_file) length = len(captions.index) diff --git a/audio_data_pytorch/datasets/wav_dataset.py b/audio_data_pytorch/datasets/wav_dataset.py index 156542c..6047516 100644 --- a/audio_data_pytorch/datasets/wav_dataset.py +++ b/audio_data_pytorch/datasets/wav_dataset.py @@ -1,5 +1,4 @@ -import glob -import os +import random from typing import Callable, List, Optional, Sequence, Tuple, Union import torch @@ -7,14 +6,15 @@ from torch import Tensor from torch.utils.data import Dataset +from ..utils import fast_scandir, is_silence + def get_all_wav_filenames(paths: Sequence[str], recursive: bool) -> List[str]: - extensions = ["wav", "flac"] + extensions = [".wav", ".flac"] filenames = [] - for ext_name in extensions: - ext = f"**/*.{ext_name}" if recursive else f"*.{ext_name}" - for path in paths: - filenames.extend(glob.glob(os.path.join(path, ext), recursive=recursive)) + for path in paths: + _, files = fast_scandir(path, extensions, recursive=recursive) + filenames.extend(files) return filenames @@ -25,26 +25,48 @@ def __init__( recursive: bool = False, transforms: Optional[Callable] = None, sample_rate: Optional[int] = None, + check_silence: bool = True, ): self.paths = path if isinstance(path, (list, tuple)) else [path] self.wavs = get_all_wav_filenames(self.paths, recursive=recursive) self.transforms = transforms self.sample_rate = sample_rate + self.check_silence = check_silence def __getitem__( self, idx: Union[Tensor, int] ) -> Union[Tensor, Tuple[Tensor, Tensor]]: idx = idx.tolist() if torch.is_tensor(idx) else idx # type: ignore - waveform, sample_rate = torchaudio.load(self.wavs[idx]) + invalid_audio = False + + # Check that we can load audio properly + try: + waveform, sample_rate = torchaudio.load(self.wavs[idx]) + except Exception: + invalid_audio = True + + # Check that the sample is not silent + if not invalid_audio and self.check_silence and is_silence(waveform): + invalid_audio = True + # Get new sample if audio is invalid + if invalid_audio: + return self[random.randrange(len(self))] + + # Apply sample rate transform if necessary if self.sample_rate and sample_rate != self.sample_rate: waveform = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=self.sample_rate )(waveform) + # Apply other transforms if self.transforms: waveform = self.transforms(waveform) + # Check silence after transforms (useful for random crops) + if self.check_silence and is_silence(waveform): + return self[random.randrange(len(self))] + return waveform def __len__(self) -> int: diff --git a/audio_data_pytorch/utils.py b/audio_data_pytorch/utils.py index 86c4805..623f83e 100644 --- a/audio_data_pytorch/utils.py +++ b/audio_data_pytorch/utils.py @@ -11,6 +11,7 @@ import aiohttp import torch +from torch import Tensor from torch.utils.data.dataset import Dataset, Subset from tqdm import tqdm from typing_extensions import TypeGuard @@ -46,6 +47,48 @@ def fractional_random_split( return splits +""" +Audio utils +""" + + +def is_silence(audio: Tensor, thresh: int = -60): + dBmax = 20 * torch.log10(torch.flatten(audio.abs()).max()) + return dBmax < thresh + + +""" +Data/async utils +""" + + +def fast_scandir(path: str, exts: List[str], recursive: bool = False): + # Scan files recursively faster than glob + # From github.com/drscotthawley/aeiou/blob/main/aeiou/core.py + subfolders, files = [], [] + + try: # hope to avoid 'permission denied' by this try + for f in os.scandir(path): + try: # 'hope to avoid too many levels of symbolic links' error + if f.is_dir(): + subfolders.append(f.path) + elif f.is_file(): + if os.path.splitext(f.name)[1].lower() in exts: + files.append(f.path) + except Exception: + pass + except Exception: + pass + + if recursive: + for path in list(subfolders): + sf, f = fast_scandir(path, exts, recursive=recursive) + subfolders.extend(sf) + files.extend(f) # type: ignore + + return subfolders, files + + class RunThread(threading.Thread): def __init__(self, func): self.func = func diff --git a/setup.py b/setup.py index 99aacbf..23977a5 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-data-pytorch", packages=find_packages(exclude=[]), - version="0.0.16", + version="0.0.17", license="MIT", description="Audio Data - PyTorch", long_description_content_type="text/markdown",