diff --git a/audiotools/__init__.py b/audiotools/__init__.py index b1ab1e46..c4945117 100644 --- a/audiotools/__init__.py +++ b/audiotools/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.7" +__version__ = "0.6.0" from .core import AudioSignal from .core import STFTParams from .core import Meter diff --git a/audiotools/core/audio_signal.py b/audiotools/core/audio_signal.py index e984e86d..6ad25a85 100644 --- a/audiotools/core/audio_signal.py +++ b/audiotools/core/audio_signal.py @@ -382,6 +382,7 @@ def batch( pad_signals: bool = False, truncate_signals: bool = False, resample: bool = False, + dim: int = 0, ): """Creates a batched AudioSignal from a list of AudioSignals. @@ -398,6 +399,8 @@ def batch( resample : bool, optional Whether to resample AudioSignal to the sample rate of the first AudioSignal in the list, by default False + dim : int, optional + Dimension along which to batch the signals. Returns ------- @@ -453,8 +456,8 @@ def batch( f"All signals must be the same length, or pad_signals/truncate_signals " f"must be True. " ) - # Concatenate along the batch dimension - audio_data = torch.cat([x.audio_data for x in audio_signals], dim=0) + # Concatenate along the specified dimension (default 0) + audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim) audio_paths = [x.path_to_file for x in audio_signals] batched_signal = cls( diff --git a/audiotools/core/display.py b/audiotools/core/display.py index 23231a39..66cbcf34 100644 --- a/audiotools/core/display.py +++ b/audiotools/core/display.py @@ -72,7 +72,7 @@ def specshow( log_mag = signal.log_magnitude(ref_value=ref) if y_axis == "mel": - log_mag = 10 * signal.mel_spectrogram(n_mels).pow(2).clamp(1e-5).log10() + log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10() log_mag -= log_mag.max() librosa.display.specshow( diff --git a/audiotools/core/util.py b/audiotools/core/util.py index 2839568b..fa021ef7 100644 --- a/audiotools/core/util.py +++ b/audiotools/core/util.py @@ -1,4 +1,5 @@ import csv +import math import numbers import os import random @@ -212,7 +213,7 @@ def _close(): _close() -AUDIO_EXTENSIONS = ["wav", "flac", "mp3", "mp4"] +AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"] def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS): @@ -225,23 +226,35 @@ def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS): Folder to look for audio files in, recursively. ext : List[str], optional Extensions to look for without the ., by default - ``['wav', 'flac', 'mp3', 'mp4']``. + ``['.wav', '.flac', '.mp3', '.mp4']``. """ folder = Path(folder) + # Take care of case where user has passed in an audio file directly + # into one of the calling functions. + if str(folder).endswith(tuple(ext)): + return [folder] files = [] for x in ext: - files += folder.glob(f"**/*.{x}") + files += folder.glob(f"**/*{x}") return files -def read_csv(filelists: List[str], remove_empty: bool = True): - """Reads CSVs that are generated by +def read_sources( + sources: List[str], + remove_empty: bool = True, + relative_path: str = "", + ext: List[str] = AUDIO_EXTENSIONS, +): + """Reads audio sources that can either be folders + full of audio files, or CSV files that contain paths + to audio files. CSV files that adhere to the expected + format can be generated by :py:func:`audiotools.data.preprocess.create_csv`. Parameters ---------- - filelists : List[str] - List of CSV files to be converted into a + sources : List[str] + List of audio sources to be converted into a list of lists of audio files. 
remove_empty : bool, optional Whether or not to remove rows with an empty "path" @@ -253,18 +266,24 @@ def read_csv(filelists: List[str], remove_empty: bool = True): List of lists of rows of CSV files. """ files = [] - data_path = Path(os.getenv("PATH_TO_DATA", "")) - for filelist in filelists: - with open(filelist, "r") as f: - reader = csv.DictReader(f) - _files = [] - for x in reader: - if remove_empty and x["path"] == "": - continue - if x["path"] != "": - x["path"] = str(data_path / x["path"]) - _files.append(x) - files.append(_files) + relative_path = Path(relative_path) + for source in sources: + source = str(source) + _files = [] + if source.endswith(".csv"): + with open(source, "r") as f: + reader = csv.DictReader(f) + for x in reader: + if remove_empty and x["path"] == "": + continue + if x["path"] != "": + x["path"] = str(relative_path / x["path"]) + _files.append(x) + else: + for x in find_audio(source, ext=ext): + x = str(relative_path / x) + _files.append({"path": x}) + files.append(sorted(_files, key=lambda x: x["path"])) return files @@ -287,9 +306,9 @@ def choose_from_list_of_lists( typing.Any An item from the list of lists. """ - idx = state.choice(list(range(len(list_of_lists))), p=p) - item_idx = state.randint(len(list_of_lists[idx])) - return list_of_lists[idx][item_idx], idx + source_idx = state.choice(list(range(len(list_of_lists))), p=p) + item_idx = state.randint(len(list_of_lists[source_idx])) + return list_of_lists[source_idx][item_idx], source_idx, item_idx @contextmanager @@ -392,7 +411,7 @@ def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None): return dist_fn(*dist_tuple[1:]) -def collate(list_of_dicts: list): +def collate(list_of_dicts: list, n_splits: int = None): """Collates a list of dictionaries (e.g. as returned by a dataloader) into a dictionary with batched values. This routine uses the default torch collate function for everything @@ -400,31 +419,52 @@ def collate(list_of_dicts: list): :py:func:`audiotools.core.audio_signal.AudioSignal.batch` function. + This function takes n_splits to enable splitting a batch + into multiple sub-batches for the purposes of gradient accumulation, + etc. + Parameters ---------- list_of_dicts : list List of dictionaries to be collated. + n_splits : int + Number of splits to make when creating the batches (split into + sub-batches). Useful for things like gradient accumulation. Returns ------- dict Dictionary containing batched data. """ + from . import AudioSignal - # Flatten the dictionaries to avoid recursion. - list_of_dicts = [flatten(d) for d in list_of_dicts] - dict_of_lists = {k: [dic[k] for dic in list_of_dicts] for k in list_of_dicts[0]} + batches = [] + list_len = len(list_of_dicts) - batch = {} - for k, v in dict_of_lists.items(): - if isinstance(v, list): - if all(isinstance(s, AudioSignal) for s in v): - batch[k] = AudioSignal.batch(v, pad_signals=True) - else: - # Borrow the default collate fn from torch. - batch[k] = torch.utils.data._utils.collate.default_collate(v) - return unflatten(batch) + return_list = False if n_splits is None else True + n_splits = 1 if n_splits is None else n_splits + n_items = int(math.ceil(list_len / n_splits)) + + for i in range(0, list_len, n_items): + # Flatten the dictionaries to avoid recursion. 
+ list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]] + dict_of_lists = { + k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0] + } + + batch = {} + for k, v in dict_of_lists.items(): + if isinstance(v, list): + if all(isinstance(s, AudioSignal) for s in v): + batch[k] = AudioSignal.batch(v, pad_signals=True) + else: + # Borrow the default collate fn from torch. + batch[k] = torch.utils.data._utils.collate.default_collate(v) + batches.append(unflatten(batch)) + + batches = batches[0] if not return_list else batches + return batches BASE_SIZE = 864 @@ -614,6 +654,6 @@ def generate_chord_dataset( voice_lists[voice_name].append("") for voice_name, paths in voice_lists.items(): - create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True, data_path="") + create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True) return output_dir diff --git a/audiotools/data/datasets.py b/audiotools/data/datasets.py index dae17216..b0208eeb 100644 --- a/audiotools/data/datasets.py +++ b/audiotools/data/datasets.py @@ -1,202 +1,56 @@ -import copy -import typing -from multiprocessing import Manager from pathlib import Path from typing import Callable from typing import Dict from typing import List -from typing import Optional from typing import Union import numpy as np -from numpy.random import RandomState -from torch.utils.data import BatchSampler as _BatchSampler from torch.utils.data import SequentialSampler from torch.utils.data.distributed import DistributedSampler from ..core import AudioSignal from ..core import util -# We need to set SHARED_KEYS statically, with no relationship to the -# BaseDataset object, or we'll hit RecursionErrors in the lookup. -SHARED_KEYS = [ - "signal", - "duration", - "shared_transform", - "check_transform", - "sample_rate", - "batch_size", -] - - -class SharedMixin: - """Mixin which creates a set of keys that are shared across processes. - - The getter looks up the name in ``SHARED_KEYS`` (see above). If it's there, - return it from the dictionary that is kept in shared memory. - Otherwise, do the normal ``__getattribute__``. This line only - runs if the key is in ``SHARED_KEYS``. - - The setter looks up the name in ``SHARED_KEYS``. If it's there - set the value in the dictionary accordingly, so that it the other - dataset replicas know about it. Otherwise, do the normal - ``__setattr__``. This line only runs if the key is in ``SHARED_KEYS``. - - >>> SHARED_KEYS = [ - >>> "signal", - >>> "duration", - >>> "shared_transform", - >>> "check_transform", - >>> "sample_rate", - >>> "batch_size", - >>> ] - - """ - - def __getattribute__(self, name: str): - if name in SHARED_KEYS: - return self.shared_dict[name] - else: - return super().__getattribute__(name) - - def __setattr__(self, name, value): - if name in SHARED_KEYS: - self.shared_dict[name] = value - else: - super().__setattr__(name, value) - - -class BaseDataset(SharedMixin): - """This BaseDataset class adds all the necessary logic so that there is - a dictionary that is shared across processes when working with a - DataLoader with num_workers > 0. - - It adds an attribute called ``shared_dict``, and changes the - ``getattr` and ``setattr`` for the object so that it looks things up - in the shared_dict if it's in the above ``SHARED_KEYS``. The complexity - here is coming from working around a few quirks in multiprocessing. - - Parameters - ---------- - length : int - Length of the dataset. 
- transform : typing.Callable, optional - Transform to instantiate and apply to every item , by default None - """ - - def __init__(self, length: int, transform: typing.Callable = None, **kwargs): - super().__init__() - self.length = length - # The following snippet of code is how we share a - # parameter across workers in a DataLoader, without - # introducing syntax overhead upstream. - - # 1. We use a Manager object, which is shared between - # dataset replicas that are passed to the workers. - self.shared_dict = Manager().dict() - - # Instead of setting `self.duration = duration` for example, we - # instead first set it inside the `self.shared_dict` object. Further - # down, we'll make it so that `self.duration` still works, but - # it works by looking up the key "duration" in `self.shared_dict`. - for k, v in kwargs.items(): - if k in SHARED_KEYS: - self.shared_dict[k] = v - - self.shared_dict["shared_transform"] = copy.deepcopy(transform) - self.shared_dict["check_transform"] = False - self._transform = transform - self.length = length - - @property - def transform(self): - """Transform that is associated with the dataset, copied from - the shared dictionary so that it's up to date, but executution of - "instantiate" will be done within each worker. - """ - if self.check_transform: - self._transform = copy.deepcopy(self.shared_transform) - self.check_transform = False - return self._transform - - @transform.setter - def transform(self, value): - self.shared_transform = value - self.check_transform = True - - def __len__(self): - return self.length - - @staticmethod - def collate(list_of_dicts: typing.Union[list, dict]): - """Collates items drawn from this dataset. Uses - :py:func:`audiotools.core.util.collate`. - - Parameters - ---------- - list_of_dicts : typing.Union[list, dict] - Data drawn from each item. - - Returns - ------- - dict - Dictionary of batched data. - """ - return util.collate(list_of_dicts) - - -def load_signal( - path: str, - sample_rate: int, - num_channels: int = 1, - state: Optional[Union[RandomState, int]] = None, - offset: Optional[int] = None, - duration: Optional[float] = None, - loudness_cutoff: Optional[float] = None, -): - if offset is None: - signal = AudioSignal.salient_excerpt( - path, - duration=duration, - state=state, - loudness_cutoff=loudness_cutoff, - ) - else: - signal = AudioSignal( - path, - offset=offset, - duration=duration, - ) - - if num_channels == 1: - signal = signal.to_mono() - signal = signal.resample(sample_rate) - - if signal.duration < duration: - signal = signal.zero_pad_to(int(duration * sample_rate)) - - return signal - class AudioLoader: - """Loads audio endlessly from a list of CSV files - containing paths to audio files. + """Loads audio endlessly from a list of audio sources + containing paths to audio files. Audio sources can be + folders full of audio files (which are found via file + extension) or by providing a CSV file which contains paths + to audio files. 
Parameters ---------- - csv_files : List[str], optional - CSV files containing paths to audio files, by default None - csv_weights : List[float], optional - Weights to sample audio files from each CSV, by default None + sources : List[str], optional + Sources containing folders, or CSVs with + paths to audio files, by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + relative_path : str, optional + Path audio should be loaded relative to, by default "" + transform : Callable, optional + Transform to instantiate alongside audio sample, + by default None + ext : List[str] + List of extensions to find audio within each source by. Can + also be a file name (e.g. "vocals.wav"). by default + ``['.wav', '.flac', '.mp3', '.mp4']``. """ def __init__( self, - csv_files: List[str] = None, - csv_weights: List[float] = None, + sources: List[str] = None, + weights: List[float] = None, + transform: Callable = None, + relative_path: str = "", + ext: List[str] = util.AUDIO_EXTENSIONS, ): - self.audio_lists = util.read_csv(csv_files) - self.csv_weights = csv_weights + self.audio_lists = util.read_sources( + sources, relative_path=relative_path, ext=ext + ) + self.sources = sources + self.weights = weights + self.transform = transform def __call__( self, @@ -206,623 +60,387 @@ def __call__( loudness_cutoff: float = -40, num_channels: int = 1, offset: float = None, + source_idx: int = None, + item_idx: int = None, ): - audio_info, csv_idx = util.choose_from_list_of_lists( - state, self.audio_lists, p=self.csv_weights - ) - - signal = load_signal( - path=audio_info["path"], - sample_rate=sample_rate, - num_channels=num_channels, - state=state, - offset=offset, - duration=duration, - loudness_cutoff=loudness_cutoff, - ) - - for k, v in audio_info.items(): - signal.metadata[k] = v - - return signal, csv_idx - - -class MultiTrackAudioLoader: - """ - This loader behaves similarly to AudioLoader, but - it loads multiple tracks in a group, and returns a dictionary - of AudioSignals, one for each track. - - For example, one may call this loader like this:: - - loader = MultiTrackAudioLoader( - csv_groups = [ - { - "vocals": "datset1/vocals.csv", - "drums": "dataset1/drums.csv", - "bass": "dataset1/bass.csv", - - "coherence": 0.5, - "csv_weight": 1, - "primary_key": "vocals", - }, - { - "vocals": "datset2/vocals.csv", - "drums": "dataset2/drums.csv", - "bass": "dataset2/bass.csv", - "guitar": "dataset2/guitar.csv", - - "coherence": 0.5, - "csv_weight": 3, - "primary_key": "vocals", - }, - ] - ) - - There are special keys that can be passed to each csv group dictionary: + if source_idx is not None and item_idx is not None: + try: + audio_info = self.audio_lists[source_idx][item_idx] + except: + audio_info = {"path": "none"} + else: + audio_info, source_idx, item_idx = util.choose_from_list_of_lists( + state, self.audio_lists, p=self.weights + ) - - csv_weight: (*float*, *optional*) - - weight for sampling this CSV group, by default 1.0. - - primary_key: (*str*, *optional*) - - If provided, will load a salient excerpt from the audio file specified by the primary key. - - If not provided, will pick the first csv file in the group from the csv_group dict. - - coherence: (*float*, *optional*) - - Coherence of sampled multitrack data, by default 1.0 - - Probability of sampling a multitrack recording that is coherent. - - A coherent multitrack recording is one the same CSV row is drawn for each of the sources. 
- - An incoherent multitrack recording is one where a random row is drawn for each of the sources. + path = audio_info["path"] + signal = AudioSignal.zeros(duration, sample_rate, num_channels) - You can change the default values for these keys by updating the - ``MultiTrackAudioLoader.CSV_GROUP_DEFAULTS`` dictionary. + if path != "none": + if offset is None: + signal = AudioSignal.salient_excerpt( + path, + duration=duration, + state=state, + loudness_cutoff=loudness_cutoff, + ) + else: + signal = AudioSignal( + path, + offset=offset, + duration=duration, + ) - .. note:: If no offset is provided to the loader, then the loader will - choose a salient excerpt as dictated by the signal associated with ``primary_key``. - This may fail if all of the signals in a given row are not of equal duration. + if num_channels == 1: + signal = signal.to_mono() + signal = signal.resample(sample_rate) - Parameters - ---------- - csv_groups: List[Dict[str, str]], optional - List of dictionaries containing CSV files and their associated keys. - """ + if signal.duration < duration: + signal = signal.zero_pad_to(int(duration * sample_rate)) - CSV_GROUP_DEFAULTS = {"csv_weight": 1.0, "coherence": 1.0} - CSV_GROUP_RESERVED_KEYS = ["csv_weight", "coherence", "primary_key"] - - def __init__( - self, - csv_groups: List[Dict[str, str]] = None, - ): + for k, v in audio_info.items(): + signal.metadata[k] = v - csv_weights = [ - g.pop("csv_weight", self.CSV_GROUP_DEFAULTS["csv_weight"]) - for g in csv_groups - ] - csv_weights = np.exp(csv_weights) / np.sum(np.exp(csv_weights)) - self.csv_weights = csv_weights.tolist() - self.coherences = [ - g.pop("coherence", self.CSV_GROUP_DEFAULTS["coherence"]) for g in csv_groups - ] - - # find the set of audio columns - # (i.e. the union of all keys across all csv groups) - # this way, we can add zero signals for any missing tracks - # which let's us batch different csv groups together - csv_group_keys = [list(g.keys()) for g in csv_groups] - self.audio_columns = list( - set( - [ - key - for keys in csv_group_keys - for key in keys - if key not in self.CSV_GROUP_RESERVED_KEYS - ] - ) - ) - self.primary_keys = [ - g.pop("primary_key", keys[0]) for g, keys in zip(csv_groups, csv_group_keys) - ] - - self.csv_groups = csv_groups - self.audio_lists = [] - for csv_dict in csv_groups: - self.audio_lists.append( - { - k: util.read_csv([v], remove_empty=False)[0] - for k, v in csv_dict.items() - } - ) + item = { + "signal": signal, + "source_idx": source_idx, + "item_idx": item_idx, + "source": str(self.sources[source_idx]), + "path": str(path), + } + if self.transform is not None: + item["transform_args"] = self.transform.instantiate(state, signal=signal) + return item - for key, csv_group in zip(self.primary_keys, self.csv_groups): - if key not in csv_group.keys(): - raise ValueError( - f"Primary key {key} not found in csv keys {csv_group.keys()}" - ) - def __call__( - self, - state, - sample_rate: int, - duration: float, - loudness_cutoff: float = -40, - num_channels: int = 1, - offset: float = None, - ): - # pick a group of csvs - csv_group_idx = state.choice(len(self.audio_lists), p=self.csv_weights) - - # grab the group of csvs and primary key for this group - csv_group = self.audio_lists[csv_group_idx] - primary_key = self.primary_keys[csv_group_idx] - - # if not coherent, sample the csv idxs for each track independently - coherence = self.coherences[csv_group_idx] - coherent = state.rand() < coherence - if not coherent: - csv_idxs = state.choice( - len(csv_group[primary_key]), 
size=len(csv_group), replace=False - ) - csv_idxs = {key: csv_idxs[i] for i, key in enumerate(csv_group.keys())} - else: - # otherwise, use the same csv idx for each track - choice_idx = state.choice(len(csv_group[primary_key])) - csv_idxs = {key: choice_idx for key in csv_group.keys()} - - # pick a row from the primary csv - csv_idx = csv_idxs[primary_key] - p_audio_info = csv_group[primary_key][csv_idx] - - # load the primary signal first (if it exists in this row), - # and use it to determine the offset and duration - if p_audio_info["path"] == "": - primary_signal = AudioSignal.zeros( - sample_rate=sample_rate, num_channels=num_channels, duration=duration - ) - else: - primary_signal = load_signal( - path=p_audio_info["path"], - sample_rate=sample_rate, - num_channels=num_channels, - state=state, - offset=offset, - duration=duration, - loudness_cutoff=loudness_cutoff, - ) +def default_matcher(x, y): + return Path(x).parent == Path(y).parent - # update the offset and duration according to the primary signal - offset = primary_signal.metadata["offset"] - duration = primary_signal.metadata["duration"] - - for k, v in p_audio_info.items(): - # don't update the duration and offset keys - if k not in primary_signal.metadata: - primary_signal.metadata[k] = v - - # load the rest of the signals - signals = {} - for audio_key, audio_list in csv_group.items(): - if audio_key == primary_key: - signals[audio_key] = primary_signal - continue - - csv_idx = csv_idxs[audio_key] - audio_info = audio_list[csv_idx] - - # if the path is empty, then skip - # and add a zero signal later - if audio_info["path"] == "": - continue - - signal = load_signal( - path=audio_info["path"], - sample_rate=sample_rate, - num_channels=num_channels, - state=state, - offset=offset, - duration=duration, - loudness_cutoff=loudness_cutoff, - ) - for k, v in audio_info.items(): - signal.metadata[k] = v - signals[audio_key] = signal +def align_lists(lists, matcher: Callable = default_matcher): + longest_list = lists[np.argmax([len(l) for l in lists])] + for i, x in enumerate(longest_list): + for l in lists: + if i >= len(l): + l.append({"path": "none"}) + elif not matcher(l[i]["path"], x["path"]): + l.insert(i, {"path": "none"}) + return lists - # add zero signals for any missing tracks - for k in self.audio_columns: - if k not in signals: - signals[k] = AudioSignal.zeros( - duration=duration, - num_channels=num_channels, - sample_rate=sample_rate, - ) - for signal in signals.values(): - assert signal.duration == duration +class AudioDataset: + """Loads audio from multiple loaders (with associated transforms) + for a specified number of samples. Excerpts are drawn randomly + of the specified duration, above a specified loudness threshold + and are resampled on the fly to the desired sample rate + (if it is different from the audio source sample rate). - return signals, csv_group_idx + This takes either a single AudioLoader object, + a dictionary of AudioLoader objects, or a dictionary of AudioLoader + objects. Each AudioLoader is called by the dataset, and the + result is placed in the output dictionary. A transform can also be + specified for the entire dataset, rather than for each specific + loader. This transform can be applied to the output of all the + loaders if desired. + AudioLoader objects can be specified as aligned, which means the + loaders correspond to multitrack audio (e.g. a vocals, bass, + drums, and other loader for multitrack music mixtures). 
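The ``align_lists`` helper introduced above carries no docstring, so the following sketch illustrates what it does with the default parent-directory matcher; the ``song_a``/``song_b`` paths are made up for illustration and are not part of the patch.

```python
# Illustrative example of align_lists with the default parent-directory matcher.
from pathlib import Path

import numpy as np


def default_matcher(x, y):
    # Two entries match if they live in the same folder (i.e. the same song).
    return Path(x).parent == Path(y).parent


def align_lists(lists, matcher=default_matcher):
    # Same logic as in the patch: walk the longest list and pad the others
    # with {"path": "none"} wherever a source has no matching file.
    longest_list = lists[np.argmax([len(l) for l in lists])]
    for i, x in enumerate(longest_list):
        for l in lists:
            if i >= len(l):
                l.append({"path": "none"})
            elif not matcher(l[i]["path"], x["path"]):
                l.insert(i, {"path": "none"})
    return lists


vocals = [{"path": "song_a/vocals.wav"}, {"path": "song_b/vocals.wav"}]
drums = [{"path": "song_b/drums.wav"}]  # song_a has no drum stem

align_lists([vocals, drums])
# drums is now [{"path": "none"}, {"path": "song_b/drums.wav"}]: the same index
# refers to the same song in every list, and AudioLoader returns a zero signal
# whenever it is handed a "none" path.
```

This is the alignment that ``AudioDataset`` below performs in place on its loaders' ``audio_lists`` when ``aligned=True``.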
-class CSVDataset(BaseDataset): - """This is the core data handling routine in this library. - It expects to draw ``n_examples`` audio files at a specified - ``sample_rate`` of a specified ``duration`` from a list - of ``csv_files`` with probability of each file being - given by ``csv_weights``. All excerpts drawn - will be above the specified ``loudness_cutoff``, have the - same ``num_channels``. A transform is also instantiated using - the index of the item, which is used to actually apply the - transform to the item. Parameters ---------- + loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]] + AudioLoaders to sample audio from. sample_rate : int - Sample rate of audio. + Desired sample rate. n_examples : int, optional - Number of examples, by default 1000 + Number of examples (length of dataset), by default 1000 duration : float, optional - Duration of excerpts, in seconds, by default 0.5 - csv_files : List[str], optional - List of CSV files, by default None - csv_weights : List[float], optional - List of weights of CSV files, by default None + Duration of audio samples, by default 0.5 loudness_cutoff : float, optional - Loudness cutoff in decibels, by default -40 + Loudness cutoff threshold for audio samples, by default -40 num_channels : int, optional - Number of channels, by default 1 - transform : typing.Callable, optional - Transform to instantiate with each item, by default None + Number of channels in output audio, by default 1 + transform : Callable, optional + Transform to instantiate alongside each dataset item, by default None + aligned : bool, optional + Whether the loaders should be sampled in an aligned manner (e.g. same + offset, duration, and matched file name), by default False + shuffle_loaders : bool, optional + Whether to shuffle the loaders before sampling from them, by default False + Examples -------- - - >>> transform = tfm.Compose( - >>> [ - >>> tfm.VolumeNorm(), - >>> tfm.Silence(prob=0.5), - >>> ], - >>> ) - >>> dataset = audiotools.data.datasets.CSVDataset( - >>> 44100, - >>> n_examples=100, - >>> csv_files=["tests/audio/spk.csv"], - >>> transform=transform, + >>> from audiotools.data.datasets import AudioLoader + >>> from audiotools.data.datasets import AudioDataset + >>> from audiotools import transforms as tfm + >>> import numpy as np + >>> + >>> loaders = [ + >>> AudioLoader( + >>> sources=[f"tests/audio/spk"], + >>> transform=tfm.Equalizer(), + >>> ext=["wav"], + >>> ) + >>> for i in range(5) + >>> ] + >>> + >>> dataset = AudioDataset( + >>> loaders = loaders, + >>> sample_rate = 44100, + >>> duration = 1.0, + >>> transform = tfm.RescaleAudio(), >>> ) + >>> + >>> item = dataset[np.random.randint(len(dataset))] + >>> + >>> for i in range(len(loaders)): + >>> item[i]["signal"] = loaders[i].transform( + >>> item[i]["signal"], **item[i]["transform_args"] + >>> ) + >>> item[i]["signal"].widget(i) + >>> + >>> mix = sum([item[i]["signal"] for i in range(len(loaders))]) + >>> mix = dataset.transform(mix, **item["transform_args"]) + >>> mix.widget("mix") + + Below is an example of how one could load MUSDB multitrack data: + + >>> import audiotools as at + >>> from pathlib import Path + >>> from audiotools import transforms as tfm + >>> import numpy as np + >>> import torch + >>> + >>> def build_dataset( + >>> sample_rate: int = 44100, + >>> duration: float = 5.0, + >>> musdb_path: str = "~/.data/musdb/", + >>> ): + >>> musdb_path = Path(musdb_path).expanduser() + >>> loaders = { + >>> src: at.datasets.AudioLoader( + >>> sources=[musdb_path], + 
>>> transform=tfm.Compose( + >>> tfm.VolumeNorm(("uniform", -20, -10)), + >>> tfm.Silence(prob=0.1), + >>> ), + >>> ext=[f"{src}.wav"], + >>> ) + >>> for src in ["vocals", "bass", "drums", "other"] + >>> } + >>> + >>> dataset = at.datasets.AudioDataset( + >>> loaders=loaders, + >>> sample_rate=sample_rate, + >>> duration=duration, + >>> num_channels=1, + >>> aligned=True, + >>> transform=tfm.RescaleAudio(), + >>> shuffle_loaders=True, + >>> ) + >>> return dataset, list(loaders.keys()) + >>> + >>> train_data, sources = build_dataset() >>> dataloader = torch.utils.data.DataLoader( - >>> dataset, + >>> train_data, >>> batch_size=16, >>> num_workers=0, - >>> collate_fn=dataset.collate, + >>> collate_fn=train_data.collate, >>> ) + >>> batch = next(iter(dataloader)) >>> + >>> for k in sources: + >>> src = batch[k] + >>> src["transformed"] = train_data.loaders[k].transform( + >>> src["signal"].clone(), **src["transform_args"] + >>> ) >>> - >>> for batch in dataloader: - >>> kwargs = batch["transform_args"] - >>> signal = batch["signal"] - >>> original = signal.clone() + >>> mixture = sum(batch[k]["transformed"] for k in sources) + >>> mixture = train_data.transform(mixture, **batch["transform_args"]) >>> - >>> signal = dataset.transform(signal, **kwargs) - >>> original = dataset.transform(original, **kwargs) + >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time). + >>> # Construct the targets: + >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1) + + Similarly, here's example code for loading Slakh data: + + >>> import audiotools as at + >>> from pathlib import Path + >>> from audiotools import transforms as tfm + >>> import numpy as np + >>> import torch + >>> import glob >>> - >>> mask = kwargs["Compose"]["1.Silence"]["mask"] + >>> def build_dataset( + >>> sample_rate: int = 16000, + >>> duration: float = 10.0, + >>> slakh_path: str = "~/.data/slakh/", + >>> ): + >>> slakh_path = Path(slakh_path).expanduser() >>> - >>> zeros_ = torch.zeros_like(signal[mask].audio_data) - >>> original_ = original[~mask].audio_data + >>> # Find the max number of sources in Slakh + >>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)] + >>> n_sources = len(list(set(src_names))) >>> - >>> assert torch.allclose(signal[mask].audio_data, zeros_) - >>> assert torch.allclose(signal[~mask].audio_data, original_) + >>> loaders = { + >>> f"S{i:02d}": at.datasets.AudioLoader( + >>> sources=[slakh_path], + >>> transform=tfm.Compose( + >>> tfm.VolumeNorm(("uniform", -20, -10)), + >>> tfm.Silence(prob=0.1), + >>> ), + >>> ext=[f"S{i:02d}.wav"], + >>> ) + >>> for i in range(n_sources) + >>> } + >>> dataset = at.datasets.AudioDataset( + >>> loaders=loaders, + >>> sample_rate=sample_rate, + >>> duration=duration, + >>> num_channels=1, + >>> aligned=True, + >>> transform=tfm.RescaleAudio(), + >>> shuffle_loaders=False, + >>> ) + >>> + >>> return dataset, list(loaders.keys()) + >>> + >>> train_data, sources = build_dataset() + >>> dataloader = torch.utils.data.DataLoader( + >>> train_data, + >>> batch_size=16, + >>> num_workers=0, + >>> collate_fn=train_data.collate, + >>> ) + >>> batch = next(iter(dataloader)) + >>> + >>> for k in sources: + >>> src = batch[k] + >>> src["transformed"] = train_data.loaders[k].transform( + >>> src["signal"].clone(), **src["transform_args"] + >>> ) + >>> + >>> mixture = sum(batch[k]["transformed"] for k in sources) + >>> mixture = train_data.transform(mixture, **batch["transform_args"]) """ def __init__( self, + 
loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]], sample_rate: int, n_examples: int = 1000, duration: float = 0.5, - csv_files: List[str] = None, - csv_weights: List[float] = None, loudness_cutoff: float = -40, num_channels: int = 1, - transform: typing.Callable = None, + transform: Callable = None, + aligned: bool = False, + shuffle_loaders: bool = False, + matcher: Callable = default_matcher, ): - super().__init__( - n_examples, duration=duration, transform=transform, sample_rate=sample_rate - ) + # Internally we convert loaders to a dictionary + if isinstance(loaders, list): + loaders = {i: l for i, l in enumerate(loaders)} + elif isinstance(loaders, AudioLoader): + loaders = {0: loaders} - self.loader = AudioLoader(csv_files, csv_weights) + self.loaders = loaders self.loudness_cutoff = loudness_cutoff self.num_channels = num_channels - def __getitem__(self, idx): - state = util.random_state(idx) + self.length = n_examples + self.transform = transform + self.sample_rate = sample_rate + self.duration = duration + self.aligned = aligned + self.shuffle_loaders = shuffle_loaders - signal, csv_idx = self.loader( - state, - self.sample_rate, - duration=self.duration, - loudness_cutoff=self.loudness_cutoff, - num_channels=self.num_channels, - ) + if aligned: + loaders_list = list(loaders.values()) + for i in range(len(loaders_list[0].audio_lists)): + input_lists = [l.audio_lists[i] for l in loaders_list] + # Alignment happens in-place + align_lists(input_lists, matcher) - # Instantiate the transform. - item = { - "idx": idx, - "signal": signal, - "label": csv_idx, + def __getitem__(self, idx): + state = util.random_state(idx) + offset = None + item = {} + + keys = list(self.loaders.keys()) + if self.shuffle_loaders: + state.shuffle(keys) + + loader_kwargs = { + "state": state, + "sample_rate": self.sample_rate, + "duration": self.duration, + "loudness_cutoff": self.loudness_cutoff, + "num_channels": self.num_channels, } - if self.transform is not None: - item["transform_args"] = self.transform.instantiate(state, signal=signal) - return item - - -class CSVMultiTrackDataset(BaseDataset): - """ - A dataset for loading coherent multitrack data for source separation. - - This dataset behaves similarly to CSV dataset, but instead of - passing a list of single CSV files, you must pass a list of - dictionaries of CSV files, where each dictionary represents - a group of multitrack files that should be loaded together. - Within a dictionary, each CSV file must have the same number of rows, - since it is expected that each row represents a single track in a multitrack recording. - - For example, our list of CSV groups might look like this:: - - - csv_groups = [ - { - "vocals": "dataset1/vocals.csv", - "drums": "dataset1/drums.csv", - "bass": "dataset1/bass.csv", - "coherence_prob": 0.5, # probability of sampling coherent multitracks. - "primary_key": "vocals", # the key of the primary track. - "csv_weight": 1.0, # the weight for sampling this group - }, - { - "vocals": "datset2/vocals.csv", - "drums": "dataset2/drums.csv", - "bass": "dataset2/bass.csv", - "guitar": "dataset2/guitar.csv", - "coherence_prob": 1.0, - "primary_key": "vocals", - "csv_weight": 1.0, - }, - ] - - .. note: - - There are special keys that can be passed to each csv group dictionary: - - - csv_weight: (*float*, *optional*) - - weight for sampling this CSV group, by default 1.0. - - primary_key: (*str*, *optional*) - - If provided, will load a salient excerpt from the audio file specified by the primary key. 
- - If not provided, will pick the first csv file in the group from the csv_group dict. - - coherence: (*float*, *optional*) - - Coherence of sampled multitrack data, by default 1.0 - - Probability of sampling a multitrack recording that is coherent. - - A coherent multitrack recording is one the same CSV row - is drawn for each of the sources. - - An incoherent multitrack recording is one where a random row - is drawn for each of the sources. - - You can change the default values for these keys by updating the - ``MultiTrackAudioLoader.CSV_GROUP_DEFAULTS`` dictionary. - - You can create a multitrack dataset that behaves similar to - a regular CSV dataset:: - - import audiotools - - transform = audiotools.transforms.Identity() - - # csv dataset - csv_dataset = audiotools.data.datasets.CSVDataset( - 44100, - n_examples=100, - csv_files=["tests/audio/spk.csv"], - transform=transform, - ) - # get an item - data = csv_dataset[0] - # get the signal - signal = data["signal"] - print(signal) - - # multitrack dataset - multitrack_dataset = audiotools.data.datasets.CSVMultiTrackDataset( - 44100, - n_examples=100, - csv_groups=[{ - "speaker": "tests/audio/spk.csv" - }], - transform={ - "speaker": transform, - } - ) - - # take an item from the dataset - data = multitrack_dataset[0] - - # access the audio signal - signal = data["signals"]["speaker"] - print(signal) - - - Parameters - ---------- - sample_rate : int - Sample rate of audio. - csv_groups: List[Dict[str, str]], optional - List of dictionaries containing CSV files and their associated keys. - n_examples : int, optional - Number of examples, by default 1000 - duration : float, optional - Duration of excerpts, in seconds, by default 0.5 - num_channels : int, optional - Number of channels, by default 1 - transforms : Dict[str, typing.Callable], optional - Dict of transforms, one for each source. - mix_transform : typing.Callable, optional - Transform to apply a mix of all sources, by default None. - This is useful if you plan to mix the sources yourself, but want to - apply a transform to the mix. 
- - Examples - --------- - >>> import audiotools - >>> from audiotools.data.datasets import CSVMultiTrackDataset - >>> source_transforms = { - >>> "vocals": audiotools.transforms.Identity(), - >>> "drums": audiotools.transforms.Identity(), - >>> "bass": audiotools.transforms.Identity(), - >>> } - >>> mix_transform = audiotools.transforms.Identity() - >>> dataset = CSVMultiTrackDataset( - >>> sample_rate=44100, - >>> n_examples=20, - >>> csv_groups=[{ - >>> "drums": "tests/audio/musdb-7s/drums.csv", - >>> "bass": "tests/audio/musdb-7s/bass.csv", - >>> "vocals": "tests/audio/musdb-7s/vocals.csv", - >>> "coherence_prob": 0.95, - >>> "primary_key": "drums", - >>> }, { - >>> "drums": "tests/audio/musdb-7s/drums.csv", - >>> "bass": "tests/audio/musdb-7s/bass.csv", - >>> "coherence_prob": 1.0, - >>> "primary_key": "drums", - >>> }], - >>> transform=source_transforms, - >>> mix_transform=mix_transform, - >>> ) - >>> assert set(dataset.source_names) == set(["bass", "drums", "vocals"]) - >>> dataloader = torch.utils.data.DataLoader( - >>> dataset, - >>> batch_size=16, - >>> num_workers=0, - >>> collate_fn=dataset.collate, - >>> ) - >>> for batch in dataloader: - >>> kwargs = batch["transform_args"] - >>> signals = batch["signals"] - >>> tfmed = { - >>> k: dataset.transform[k](sig.clone(), **kwargs[k]) - >>> for k, sig in signals.items() - >>> } - >>> mix = sum(tfmed.values()) - >>> # apply the mix transform - >>> mix_tfm_kwargs = batch["mix_transform_args"] - >>> mix_tfmed = dataset.mix_transform(mix.clone(), **mix_tfm_kwargs) - """ - def __init__( - self, - sample_rate: int, - n_examples: int = 1000, - duration: float = 0.5, - csv_groups: List[Dict[str, str]] = None, - loudness_cutoff: float = -40, - num_channels: int = 1, - transform: Dict[str, Callable] = None, - mix_transform: Callable = None, - ): - self.loader = MultiTrackAudioLoader(csv_groups) - - self.num_channels = num_channels - self.loudness_cutoff = loudness_cutoff + # Draw item from first loader + loader = self.loaders[keys[0]] + item[keys[0]] = loader(**loader_kwargs) + + for key in keys[1:]: + loader = self.loaders[key] + if self.aligned: + # Path mapper takes the current loader + everything + # returned by the first loader. + offset = item[keys[0]]["signal"].metadata["offset"] + loader_kwargs.update( + { + "offset": offset, + "source_idx": item[keys[0]]["source_idx"], + "item_idx": item[keys[0]]["item_idx"], + } + ) + item[key] = loader(**loader_kwargs) - if transform is None: - transform = {} + # Sort dictionary back into original order + keys = list(self.loaders.keys()) + item = {k: item[k] for k in keys} - assert isinstance( - transform, dict - ), "transform for CSVMultiTrackDataset must be a dict" - for key in self.loader.audio_columns: - if key not in transform: - from .transforms import Identity + item["idx"] = idx + if self.transform is not None: + item["transform_args"] = self.transform.instantiate( + state=state, signal=item[keys[0]]["signal"] + ) - transform[key] = Identity() + # If there's only one loader, pop it up + # to the main dictionary, instead of keeping it + # nested. 
+ if len(keys) == 1: + item.update(item.pop(keys[0])) - if mix_transform is None: - from .transforms import Identity + return item - mix_transform = Identity() + def __len__(self): + return self.length - assert ( - "mix" not in self.loader.audio_columns - ), "mix is a reserved key for CSVMultiTrackDataset" - assert ( - "primary_key" not in self.loader.audio_columns - ), "primary_key is a reserved key for CSVMultiTrackDataset" - assert "mix" not in transform, "mix is a reserved key in the transform dict" - transform["mix"] = mix_transform + @staticmethod + def collate(list_of_dicts: Union[list, dict], n_splits: int = None): + """Collates items drawn from this dataset. Uses + :py:func:`audiotools.core.util.collate`. - super().__init__( - n_examples, duration=duration, transform=transform, sample_rate=sample_rate - ) + Parameters + ---------- + list_of_dicts : typing.Union[list, dict] + Data drawn from each item. + n_splits : int + Number of splits to make when creating the batches (split into + sub-batches). Useful for things like gradient accumulation. - @property - def mix_transform(self): - """ - The mix transform for this dataset, also accessible - via ``dataset.transform["mix"]``. + Returns + ------- + dict + Dictionary of batched data. """ - return self.transform["mix"] - - @mix_transform.setter - def mix_transform(self, value): - self.transform["mix"] = value - - @property - def source_names(self): - """A list of all the source keys dataset.""" - return list(self.loader.audio_columns) - - @property - def primary_keys(self): - """A list of all the primary keys in the dataset, one per csv group.""" - return self.loader.primary_keys - - def __getitem__(self, idx): - state = util.random_state(idx) - - signals, csv_idx = self.loader( - state, - self.sample_rate, - duration=self.duration, - loudness_cutoff=self.loudness_cutoff, - num_channels=self.num_channels, - ) - - # Instantiate the transform. - transform_kwargs = { - k: self.transform[k].instantiate(state, signal=signals[k]) for k in signals - } - - mix_transform_kwargs = self.mix_transform.instantiate( - state, - signal=sum(signals.values()), - ) - - item = { - "idx": idx, - "signals": signals, - "primary_key": self.primary_keys[csv_idx], - "label": csv_idx, - "transform_args": transform_kwargs, - "mix_transform_args": mix_transform_kwargs, - } - return item - - -# Samplers -class BatchSampler(_BatchSampler, SharedMixin): - """BatchSampler that is like the default batch sampler, but shares - the batch size across each worker, so that batch size can be - manipulated across all workers on the fly during training.""" - - def __init__(self, sampler, batch_size: int, drop_last: bool = False): - self.shared_dict = Manager().dict() - super().__init__(sampler, batch_size, drop_last=drop_last) + return util.collate(list_of_dicts, n_splits=n_splits) class ResumableDistributedSampler(DistributedSampler): # pragma: no cover diff --git a/audiotools/data/preprocess.py b/audiotools/data/preprocess.py index 4b2acbed..d90de210 100644 --- a/audiotools/data/preprocess.py +++ b/audiotools/data/preprocess.py @@ -51,12 +51,10 @@ def create_csv( List of audio files. output_csv : Path Output CSV, with each row containing the relative path of every file - to ``PATH_TO_DATA`` (defaults to None). + to ``data_path``, if specified (defaults to None). loudness : bool Compute loudness of entire file and store alongside path. 
""" - if data_path is None: - data_path = Path(os.getenv("PATH_TO_DATA", "")) info = [] pbar = tqdm(audio_files) @@ -69,7 +67,7 @@ def create_csv( if loudness: _info["loudness"] = -float("inf") else: - _info["path"] = af.relative_to(data_path) + _info["path"] = af.relative_to(data_path) if data_path is not None else af if loudness: _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item() diff --git a/audiotools/data/transforms.py b/audiotools/data/transforms.py index ba64a706..20d301e2 100644 --- a/audiotools/data/transforms.py +++ b/audiotools/data/transforms.py @@ -69,8 +69,8 @@ class BaseTransform: >>> signal = AudioSignal(audio_path, offset=10, duration=2) >>> transform = tfm.Compose( >>> [ - >>> tfm.RoomImpulseResponse(csv_files=["tests/audio/irs.csv"]), - >>> tfm.BackgroundNoise(csv_files=["tests/audio/noises.csv"]), + >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), >>> ], >>> ) >>> @@ -292,8 +292,8 @@ class Compose(BaseTransform): >>> transform = tfm.Compose( >>> [ - >>> tfm.RoomImpulseResponse(csv_files=["tests/audio/irs.csv"]), - >>> tfm.BackgroundNoise(csv_files=["tests/audio/noises.csv"]), + >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), >>> ], >>> ) @@ -735,10 +735,11 @@ class BackgroundNoise(BaseTransform): ---------- snr : tuple, optional Signal-to-noise ratio, by default ("uniform", 10.0, 30.0) - csv_files : List[str], optional - A list of files to load audio from, by default None - csv_weights : List[float], optional - Weights to sample audio files from each CSV, by default None + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None eq_amount : tuple, optional Amount of equalization to apply, by default ("const", 1.0) n_bands : int, optional @@ -755,8 +756,8 @@ class BackgroundNoise(BaseTransform): def __init__( self, snr: tuple = ("uniform", 10.0, 30.0), - csv_files: List[str] = None, - csv_weights: List[float] = None, + sources: List[str] = None, + weights: List[float] = None, eq_amount: tuple = ("const", 1.0), n_bands: int = 3, name: str = None, @@ -768,7 +769,7 @@ def __init__( self.snr = snr self.eq_amount = eq_amount self.n_bands = n_bands - self.loader = AudioLoader(csv_files, csv_weights) + self.loader = AudioLoader(sources, weights) self.loudness_cutoff = loudness_cutoff def _instantiate(self, state: RandomState, signal: AudioSignal): @@ -776,13 +777,13 @@ def _instantiate(self, state: RandomState, signal: AudioSignal): eq = -eq_amount * state.rand(self.n_bands) snr = util.sample_from_dist(self.snr, state) - bg_signal, _ = self.loader( + bg_signal = self.loader( state, signal.sample_rate, duration=signal.signal_duration, loudness_cutoff=self.loudness_cutoff, num_channels=signal.num_channels, - ) + )["signal"] return {"eq": eq, "bg_signal": bg_signal, "snr": snr} @@ -804,10 +805,11 @@ class CrossTalk(BaseTransform): snr : tuple, optional How loud cross-talk speaker is relative to original signal in dB, by default ("uniform", 0.0, 10.0) - csv_files : List[str], optional - A list of files to load audio from, by default None - csv_weights : List[float], optional - Weights to sample audio files from each CSV, by default None + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : 
List[float], optional + Weights to sample audio files from each source, by default None name : str, optional Name of this transform, used to identify it in the dictionary produced by ``self.instantiate``, by default None @@ -820,8 +822,8 @@ class CrossTalk(BaseTransform): def __init__( self, snr: tuple = ("uniform", 0.0, 10.0), - csv_files: List[str] = None, - csv_weights: List[float] = None, + sources: List[str] = None, + weights: List[float] = None, name: str = None, prob: float = 1.0, loudness_cutoff: float = -40, @@ -829,18 +831,18 @@ def __init__( super().__init__(name=name, prob=prob) self.snr = snr - self.loader = AudioLoader(csv_files, csv_weights) + self.loader = AudioLoader(sources, weights) self.loudness_cutoff = loudness_cutoff def _instantiate(self, state: RandomState, signal: AudioSignal): snr = util.sample_from_dist(self.snr, state) - crosstalk_signal, _ = self.loader( + crosstalk_signal = self.loader( state, signal.sample_rate, duration=signal.signal_duration, loudness_cutoff=self.loudness_cutoff, num_channels=signal.num_channels, - ) + )["signal"] return {"crosstalk_signal": crosstalk_signal, "snr": snr} @@ -866,10 +868,11 @@ class RoomImpulseResponse(BaseTransform): ---------- drr : tuple, optional _description_, by default ("uniform", 0.0, 30.0) - csv_files : List[str], optional - A list of files to load audio from, by default None - csv_weights : List[float], optional - Weights to sample audio files from each CSV, by default None + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None eq_amount : tuple, optional Amount of equalization to apply, by default ("const", 1.0) n_bands : int, optional @@ -890,8 +893,8 @@ class RoomImpulseResponse(BaseTransform): def __init__( self, drr: tuple = ("uniform", 0.0, 30.0), - csv_files: List[str] = None, - csv_weights: List[float] = None, + sources: List[str] = None, + weights: List[float] = None, eq_amount: tuple = ("const", 1.0), n_bands: int = 6, name: str = None, @@ -907,7 +910,7 @@ def __init__( self.n_bands = n_bands self.use_original_phase = use_original_phase - self.loader = AudioLoader(csv_files, csv_weights) + self.loader = AudioLoader(sources, weights) self.offset = offset self.duration = duration @@ -916,14 +919,14 @@ def _instantiate(self, state: RandomState, signal: AudioSignal = None): eq = -eq_amount * state.rand(self.n_bands) drr = util.sample_from_dist(self.drr, state) - ir_signal, _ = self.loader( + ir_signal = self.loader( state, signal.sample_rate, offset=self.offset, duration=self.duration, loudness_cutoff=None, num_channels=signal.num_channels, - ) + )["signal"] ir_signal.zero_pad_to(signal.sample_rate) return {"eq": eq, "ir_signal": ir_signal, "drr": drr} diff --git a/audiotools/metrics/distance.py b/audiotools/metrics/distance.py index 5e29e7b9..ce78739b 100644 --- a/audiotools/metrics/distance.py +++ b/audiotools/metrics/distance.py @@ -68,7 +68,7 @@ class SISDRLoss(nn.Module): def __init__( self, scaling: int = True, - reduction: str = " mean", + reduction: str = "mean", zero_mean: int = True, clip_min: int = None, weight: float = 1.0, diff --git a/audiotools/ml/tricks.py b/audiotools/ml/tricks.py index c3756a4c..51bb1c0a 100644 --- a/audiotools/ml/tricks.py +++ b/audiotools/ml/tricks.py @@ -1,3 +1,4 @@ +import collections from typing import List import numpy as np @@ -58,8 +59,16 @@ class AutoClip: How often to re-compute the clipping value. 
""" - def __init__(self, percentile: float = 10, frequency: int = 1, mask_nan: int = 0): + def __init__( + self, + percentile: float = 10, + frequency: int = 1, + mask_nan: int = 0, + max_history: int = None, + ): self.grad_history = [] + if max_history is not None: + self.grad_history = collections.deque([], maxlen=max_history) self.percentile = percentile self.frequency = frequency self.mask_nan = bool(mask_nan) diff --git a/docs/tutorials/transforms.md b/docs/tutorials/transforms.md index 2ca109da..77c397e4 100644 --- a/docs/tutorials/transforms.md +++ b/docs/tutorials/transforms.md @@ -102,11 +102,11 @@ distance = metrics.spectral.MelSpectrogramLoss() for transform_name in transforms_to_demo: kwargs = {} if transform_name == "BackgroundNoise": - kwargs["csv_files"] = ["../../tests/audio/noises.csv"] + kwargs["sources"] = ["../../tests/audio/noises.csv"] if transform_name == "RoomImpulseResponse": - kwargs["csv_files"] = ["../../tests/audio/irs.csv"] + kwargs["sources"] = ["../../tests/audio/irs.csv"] if transform_name == "CrossTalk": - kwargs["csv_files"] = ["../../tests/audio/spk.csv"] + kwargs["sources"] = ["../../tests/audio/spk.csv"] if "Quantization" in transform_name: kwargs["channels"] = ("choice", [8, 16, 32]) transform_cls = getattr(tfm, transform_name) @@ -297,7 +297,7 @@ kwargs = transform.batch_instantiate(seeds) pp.pprint(kwargs) ``` -There are now 4 cutoffs, and 4 mask values in the dictionary, instead of just 1 as before. Under the hood, the `batch_instantiate` function calls `instantiate` with every `seed` in `seeds`, and then collates the results using the `audiotools.util.collate` function. In practice, you'll likely use `audiotools.datasets.BaseDataset` instead to get a single item at a time, and then use the `collate` function as an argument to the torch `DataLoader`'s `collate_fn` argument. +There are now 4 cutoffs, and 4 mask values in the dictionary, instead of just 1 as before. Under the hood, the `batch_instantiate` function calls `instantiate` with every `seed` in `seeds`, and then collates the results using the `audiotools.util.collate` function. In practice, you'll likely use `audiotools.datasets.AudioDataset` instead to get a single item at a time, and then use the `collate` function as an argument to the torch `DataLoader`'s `collate_fn` argument. Alright, let's augment the entire batch at once, instead of in a for loop: @@ -541,19 +541,19 @@ class YourTransform(BaseTransform): There are two transforms which require a dataset to run. They are: -1. `BackgroundNoise`: takes a `csv_files` argument which points to a list of files that it can load background noise from. -2. `RoomImpulseResponse`: takes a `csv_files` argument which points to a list of files that it can load impulse response data from. +1. `BackgroundNoise`: takes a `sources` argument which points to a list of files that it can load background noise from. +2. `RoomImpulseResponse`: takes a `sources` argument which points to a list of files that it can load impulse response data from. Both of these transforms require an additional argument to their `instantiate` function: an `AudioSignal` object. They get instantiated like this: ```python seed = ... signal = ... 
-transform = tfm.BackgroundNoise(csv_files=["/tmp/noises.csv"]) +transform = tfm.BackgroundNoise(sources=["/tmp/noises.csv"]) transform.instantiate(seed, signal) ``` -The signal is used to load audio from the `csv_files` that is at the same +The signal is used to load audio from the `sources` that is at the same sample rate, the same number of channels, and (in the case of `BackgroundNoise`) the same duration as that of `signal`. ## Complete example @@ -580,8 +580,8 @@ preprocess.create_csv(util.find_audio(_path / "ir"), "/tmp/irs.csv") preprocess = tfm.VolumeChange(name="pre") process = tfm.Compose( [ - tfm.RoomImpulseResponse(csv_files=["/tmp/irs.csv"]), - tfm.BackgroundNoise(csv_files=["/tmp/noises.csv"]), + tfm.RoomImpulseResponse(sources=["/tmp/irs.csv"]), + tfm.BackgroundNoise(sources=["/tmp/noises.csv"]), tfm.ClippingDistortion(), tfm.MuLawQuantization(), tfm.LowPass(prob=0.5), diff --git a/examples/train_classifier.py b/examples/train_classifier.py index 2528a085..384fb1ff 100644 --- a/examples/train_classifier.py +++ b/examples/train_classifier.py @@ -30,17 +30,18 @@ def forward(self, signal: AudioSignal): def build_dataset( sample_rate: int = 44100, duration: float = 0.5, - csv_files: List[str] = ["tests/audio/spk.csv", "tests/audio/noises.csv"], + sources: List[str] = ["tests/audio/spk.csv", "tests/audio/noises.csv"], ): - num_classes = len(csv_files) + num_classes = len(sources) transform = tfm.Compose( - tfm.RoomImpulseResponse(csv_files=["tests/audio/irs.csv"]), + tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), tfm.LowPass(prob=0.5), tfm.ClippingDistortion(prob=0.1), ) - dataset = audiotools.datasets.CSVDataset( + loader = audiotools.datasets.AudioLoader(sources=sources) + dataset = audiotools.datasets.AudioDataset( + loader, sample_rate, - csv_files=csv_files, duration=duration, transform=transform, ) @@ -69,7 +70,7 @@ def train_loop(self, engine, batch): signal = batch["signal"] kwargs = batch["transform_args"] signal = train_data.transform(signal.clone(), **kwargs) - label = batch["label"] + label = batch["source_idx"] model.train() optimizer.zero_grad() diff --git a/examples/train_separator.py b/examples/train_separator.py index 295c69c6..dadd80be 100644 --- a/examples/train_separator.py +++ b/examples/train_separator.py @@ -1,11 +1,6 @@ -import random from pathlib import Path -from typing import Dict -from typing import List -from typing import Tuple import argbind -import librosa import torch import torchaudio from torch.utils.tensorboard import SummaryWriter @@ -29,32 +24,32 @@ def forward(self, signal: AudioSignal): @argbind.bind(without_prefix=True) def build_dataset( - sample_rate: int = 44100, + sample_rate: int = 8000, duration: float = 0.5, - csv_groups: List[str] = None, + musdb_path: str = "~/.data/musdb/", ): - - transform = { - "bass": tfm.Compose( - tfm.RoomImpulseResponse(csv_files=["tests/audio/irs.csv"]), - tfm.LowPass(prob=0.5), - tfm.ClippingDistortion(prob=0.1), - tfm.VolumeNorm(("uniform", -20, -10)), - ), - "drums": tfm.Compose( - tfm.RoomImpulseResponse(csv_files=["tests/audio/irs.csv"]), - tfm.ClippingDistortion(prob=0.1), - tfm.VolumeNorm(("uniform", -20, -10)), - ), + musdb_path = Path(musdb_path).expanduser() + loaders = { + src: audiotools.datasets.AudioLoader( + sources=[musdb_path], + transform=tfm.Compose( + tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + tfm.LowPass(("const", 2000), prob=0.5), + tfm.ClippingDistortion(prob=0.1), + tfm.VolumeNorm(("uniform", -20, -10)), + ), + ext=[f"{src}.wav"], + ) + for src in 
["vocals", "bass", "drums", "other"] } - dataset = audiotools.datasets.CSVMultiTrackDataset( + dataset = audiotools.datasets.AudioDataset( + loaders=loaders, sample_rate=sample_rate, - csv_groups=csv_groups, - transform=transform, duration=duration, + num_channels=1, ) - return dataset, dataset.source_names + return dataset, list(loaders.keys()) @argbind.bind(without_prefix=True) @@ -63,58 +58,45 @@ def train(accel, batch_size: int = 4): if accel.local_rank == 0: writer = SummaryWriter(log_dir="logs/") - # generate some fake data to train on - audiotools.util.generate_chord_dataset(max_voices=4, output_dir="chords") - - train_data, source_names = build_dataset( - csv_groups=[ - { - "voice_1": "chords/voice_0.csv", - "voice_2": "chords/voice_1.csv", - "voice_3": "chords/voice_2.csv", - "voice_4": "chords/voice_3.csv", - } - ] - ) - + train_data, sources = build_dataset() train_dataloader = accel.prepare_dataloader( - train_data, batch_size=batch_size, collate_fn=audiotools.util.collate + train_data, batch_size=batch_size, collate_fn=train_data.collate ) - model = accel.prepare_model(Model(num_sources=len(source_names))) + model = accel.prepare_model(Model(num_sources=len(sources))) optimizer = Adam(model.parameters()) - criterion = audiotools.metrics.spectral.MultiScaleSTFTLoss() + criterion = audiotools.metrics.distance.SISDRLoss() class Trainer(audiotools.ml.BaseTrainer): def train_loop(self, engine, batch): batch = audiotools.util.prepare_batch(batch, accel.device) - signals = batch["signals"] - tfm_kwargs = batch["transform_args"] - signals = { - k: train_data.transform[k](v, **tfm_kwargs[k]) - for k, v in signals.items() - } + for k in sources: + d = batch[k] + d["augmented"] = train_data.loaders[k].transform( + d["signal"].clone(), **d["transform_args"] + ) - mixture = sum(signals.values()) - sources = mixture.clone() - sources.audio_data = torch.concat( - [s.audio_data for s in signals.values()], dim=-2 - ) + mixture = sum(batch[k]["augmented"] for k in sources) + _targets = [batch[k]["signal"] for k in sources] + targets = mixture.clone() + targets.audio_data = torch.concat([s.audio_data for s in _targets], dim=-2) model.train() optimizer.zero_grad() - source_estimates = model(mixture) - loss = criterion(sources, source_estimates) + estimates = model(mixture) + loss = criterion(targets, estimates) loss.backward() optimizer.step() # log! 
if engine.state.iteration % 10 == 0: mixture.write_audio_to_tb("mixture", writer, engine.state.iteration) - for i, (k, v) in enumerate(signals.items()): - v.write_audio_to_tb(f"source/{k}", writer, engine.state.iteration) - source_estimates[i].detach().write_audio_to_tb( + for i, k in enumerate(sources): + batch[k]["signal"].write_audio_to_tb( + f"source/{k}", writer, engine.state.iteration + ) + estimates[i].detach().write_audio_to_tb( f"estimate/{k}", writer, engine.state.iteration ) diff --git a/setup.py b/setup.py index a09c1432..dee54b67 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="audiotools", - version="0.5.7", + version="0.6.0", classifiers=[ "Intended Audience :: Developers", "Intended Audience :: Education", diff --git a/tests/core/test_loudness.py b/tests/core/test_loudness.py index 35bd55c0..ae32f04d 100644 --- a/tests/core/test_loudness.py +++ b/tests/core/test_loudness.py @@ -241,11 +241,12 @@ def test_fir_accuracy(): transforms.Equalizer(prob=0.5), prob=0.5, ) - dataset = datasets.CSVDataset( + loader = datasets.AudioLoader(sources=["tests/audio/spk.csv"]) + dataset = datasets.AudioDataset( + loader, 44100, 10, 5.0, - csv_files=["tests/audio/spk.csv"], transform=transform, ) diff --git a/tests/core/test_util.py b/tests/core/test_util.py index 91538a29..e678267f 100644 --- a/tests/core/test_util.py +++ b/tests/core/test_util.py @@ -63,6 +63,9 @@ def test_find_audio(): audio_files = util.find_audio("tests/", ["flac"]) assert not audio_files + # Make sure it works with single audio files + audio_files = util.find_audio("tests/audio/spk//f10_script4_produced.wav") + def test_chdir(): with tempfile.TemporaryDirectory(suffix="tmp") as d: @@ -113,3 +116,31 @@ def _one_item(): assert collated["tensor"].shape[0] == batch_size assert len(collated["string"]) == batch_size assert collated["dict"]["nested_signal"].batch_size == batch_size + + # test collate with splitting (evenly) + batch_size = 16 + n_splits = 4 + + items = [_one_item() for _ in range(batch_size)] + collated = util.collate(items, n_splits=n_splits) + + for x in collated: + assert x["signal"].batch_size == batch_size // n_splits + assert x["tensor"].shape[0] == batch_size // n_splits + assert len(x["string"]) == batch_size // n_splits + assert x["dict"]["nested_signal"].batch_size == batch_size // n_splits + + # test collate with splitting (unevenly) + batch_size = 15 + n_splits = 4 + + items = [_one_item() for _ in range(batch_size)] + collated = util.collate(items, n_splits=n_splits) + + tlen = [4, 4, 4, 3] + + for x, t in zip(collated, tlen): + assert x["signal"].batch_size == t + assert x["tensor"].shape[0] == t + assert len(x["string"]) == t + assert x["dict"]["nested_signal"].batch_size == t diff --git a/tests/data/test_datasets.py b/tests/data/test_datasets.py index 71b6aba2..512a2186 100644 --- a/tests/data/test_datasets.py +++ b/tests/data/test_datasets.py @@ -1,168 +1,56 @@ +import tempfile +from pathlib import Path + import numpy as np import pytest import torch import audiotools -from audiotools.core import util from audiotools.data import transforms as tfm -def test_static_shared_args(): - dataset = audiotools.data.datasets.CSVDataset( - 44100, - n_examples=100, - csv_files=["tests/audio/spk.csv"], - ) - - for nw in (0, 1, 2): - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=1, - num_workers=nw, - collate_fn=dataset.collate, - ) - - targets = {"dur": [dataloader.dataset.duration], "sr": [44100]} - observed = {"dur": [], "sr": []} - - sample_rates = [8000, 16000, 44100] 
- - for batch in dataloader: - dur = np.random.rand() + 1.0 - sr = int(np.random.choice(sample_rates)) - - # Change attributes in the shared dict. - # Later we'll make sure they actually worked. - dataloader.dataset.duration = dur - dataloader.dataset.sample_rate = sr - - # Record observations from the batch and the signal. - targets["dur"].append(dur) - observed["dur"].append(batch["signal"].signal_duration) - - targets["sr"].append(sr) - observed["sr"].append(batch["signal"].sample_rate) - - # You aren't guaranteed that every requested attribute setting gets to every - # worker in time, but you can expect that every output attribute - # is in the requested attributes, and that it happens at least twice. - for k in targets: - _targets = targets[k] - _observed = observed[k] - - num_succeeded = 0 - for val in np.unique(_observed): - assert np.any(np.abs(np.array(_targets) - val) < 1e-3) - num_succeeded += 1 - - assert num_succeeded >= 2 - - -# This transform just adds the ID of the object, so we -# can see if it's the same across processes. -class IDTransform(audiotools.data.transforms.BaseTransform): - def __init__(self, id): - super().__init__(["id"]) - self.id = id - - def _instantiate(self, state): - return {"id": self.id} +def test_align_lists(): + input_lists = [ + ["a/1.wav", "b/1.wav", "c/1.wav", "d/1.wav"], + ["a/2.wav", "c/2.wav"], + ["c/3.wav"], + ] + target_lists = [ + ["a/1.wav", "b/1.wav", "c/1.wav", "d/1.wav"], + ["a/2.wav", "none", "c/2.wav", "none"], + ["none", "none", "c/3.wav", "none"], + ] + def _preprocess(lists): + output = [] + for x in lists: + output.append([]) + for y in x: + output[-1].append({"path": y}) + return output -def test_shared_transform(): - for nw in (0, 1, 2): - transform = IDTransform(1) - dataset = audiotools.data.datasets.CSVDataset( - 44100, - n_examples=10, - csv_files=["tests/audio/spk.csv"], - transform=transform, - ) - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=1, - num_workers=nw, - collate_fn=dataset.collate, - ) - - targets = {"id": [transform.id]} - observed = {"id": []} - - for batch in dataloader: - kwargs = batch["transform_args"] - new_id = np.random.randint(100) - - # Create a new transform with a different ID. - # This gets propagated to all processes. 
- transform = IDTransform(new_id) - dataloader.dataset.transform = transform - - targets["id"].append(new_id) - observed["id"].append(kwargs["IDTransform"]["id"]) - - for k in targets: - _targets = [int(x) for x in targets[k]] - _observed = [int(x.item()) for x in observed[k]] - - num_succeeded = 0 - for val in np.unique(_observed): - assert any([x == val for x in _targets]) - num_succeeded += 1 - assert num_succeeded >= 2 - - -def test_batch_sampler(): - for nw in (0, 1, 2): - dataset = audiotools.data.datasets.CSVDataset( - 44100, - n_examples=100, - csv_files=["tests/audio/spk.csv"], - ) - - sampler = audiotools.datasets.BatchSampler( - audiotools.datasets.SequentialSampler(dataset), batch_size=1, drop_last=True - ) - dataloader = torch.utils.data.DataLoader( - dataset, - batch_sampler=sampler, - num_workers=nw, - collate_fn=dataset.collate, - ) + input_lists = _preprocess(input_lists) + target_lists = _preprocess(target_lists) - targets = {"bs": [1]} - observed = {"bs": []} + aligned_lists = audiotools.datasets.align_lists(input_lists) + assert target_lists == aligned_lists - for new_bs in [1, 5, 10]: - dataloader.batch_sampler.batch_size = new_bs - targets["bs"].append(new_bs) - for batch in dataloader: - actual_bs = batch["signal"].batch_size - observed["bs"].append(actual_bs) - - for k in targets: - _targets = [int(x) for x in targets[k]] - _observed = [int(x) for x in observed[k]] - - num_succeeded = 0 - for val in np.unique(_observed): - assert any([x == val for x in _targets]) - num_succeeded += 1 - assert num_succeeded >= 2 - - -def test_csv_dataset(): +def test_audio_dataset(): transform = tfm.Compose( [ tfm.VolumeNorm(), tfm.Silence(prob=0.5), ], ) - dataset = audiotools.data.datasets.CSVDataset( + loader = audiotools.data.datasets.AudioLoader( + sources=["tests/audio/spk.csv"], + transform=transform, + ) + dataset = audiotools.data.datasets.AudioDataset( + loader, 44100, n_examples=100, - csv_files=["tests/audio/spk.csv"], transform=transform, ) dataloader = torch.utils.data.DataLoader( @@ -188,147 +76,51 @@ def test_csv_dataset(): assert torch.allclose(signal[~mask].audio_data, original_) -def test_multitrack_incoherent_dataset(): - from audiotools.data.datasets import CSVMultiTrackDataset, MultiTrackAudioLoader - from audiotools.core.util import generate_chord_dataset - - generate_chord_dataset(max_voices=4, num_items=3, output_dir="tests/audio/chords") - dataset = CSVMultiTrackDataset( - sample_rate=44100, - n_examples=20, - csv_groups=[ - { - "voice_0": "tests/audio/chords/voice_0.csv", - "voice_1": "tests/audio/chords/voice_0.csv", - "voice_2": "tests/audio/chords/voice_0.csv", - "coherence": 0.0, - }, - ], - ) - for i in range(10): - item = dataset[i] - assert len(item["signals"]) == 3 - assert ( - item["signals"]["voice_0"].path_to_file - != item["signals"]["voice_1"].path_to_file +def test_aligned_audio_dataset(): + with tempfile.TemporaryDirectory() as d: + dataset_dir = Path(d) + audiotools.util.generate_chord_dataset( + max_voices=8, num_items=3, output_dir=dataset_dir ) - - -@pytest.mark.parametrize( - "source_transforms", - [ - None, - { - "voice_0": tfm.Compose([tfm.VolumeNorm(), tfm.Silence(prob=0.5)]), - "voice_1": tfm.VolumeNorm(), - "voice_2": tfm.VolumeNorm(), - "voice_3": tfm.VolumeNorm(), - }, - {"voice_0": tfm.VolumeNorm(), "voice_2": tfm.VolumeNorm()}, - ], -) -def test_multitrack_dataset(source_transforms): - from audiotools.data.datasets import CSVMultiTrackDataset, MultiTrackAudioLoader - from pathlib import Path - - dataset_dir = 
Path("tests/audio/chords") - - from audiotools.core.util import generate_chord_dataset - - generate_chord_dataset(max_voices=4, num_items=3, output_dir=dataset_dir) - - # wrong primary key - with pytest.raises(ValueError): - MultiTrackAudioLoader( - [ - { - "irs": "tests/audio/irs.csv", - "primary_key": "voice_0", - } - ], + loaders = [ + audiotools.data.datasets.AudioLoader([dataset_dir / f"track_{i}"]) + for i in range(3) + ] + dataset = audiotools.data.datasets.AudioDataset( + loaders, 44100, n_examples=1000, aligned=True, shuffle_loaders=True + ) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=16, + num_workers=0, + collate_fn=dataset.collate, ) - from copy import deepcopy - - dataset = CSVMultiTrackDataset( - sample_rate=44100, - n_examples=20, - csv_groups=[ - { - "voice_0": "tests/audio/chords/voice_0.csv", - "voice_1": "tests/audio/chords/voice_1.csv", - "voice_2": "tests/audio/chords/voice_2.csv", - "voice_3": "tests/audio/chords/voice_3.csv", - "empty": "tests/audio/empty.csv", - "primary_key": "voice_2", - }, - { - "voice_0": "tests/audio/chords/voice_0.csv", - }, - ], - transform=deepcopy(source_transforms), - ) - - assert set(dataset.source_names) == set( - ["voice_0", "voice_1", "voice_2", "voice_3", "empty"] - ) - assert dataset.primary_keys == ["voice_2", "voice_0"] - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=16, - num_workers=0, - collate_fn=dataset.collate, - ) - - dataset.mix_transform = tfm.Identity() - assert isinstance(dataset.transform["mix"], tfm.Identity) - - for batch in dataloader: - kwargs = batch["transform_args"] - signals = batch["signals"] - - tfmed = { - k: dataset.transform[k](sig.clone(), **kwargs[k]) - for k, sig in signals.items() - } - mix = sum(tfmed.values()) - - mix_tfm_args = batch["mix_transform_args"] - mix_tfmed = dataset.mix_transform(mix.clone(), **mix_tfm_args) - - # test that if the csv has a 'duration' column, it does NOT - # overwrite the actual duration grabbed by the loader - dataset = CSVMultiTrackDataset( - sample_rate=44100, - n_examples=5, - duration=2.0, - csv_groups=[ - { - "empty": "tests/audio/empty.csv", - "empty2": "tests/audio/empty.csv", - "primary_key": "empty", - }, - {"empty3": "tests/audio/empty.csv"}, - ], - ) - - for i in range(len(dataset)): - item = dataset[i] - assert dataset.duration == item["signals"]["empty"].duration + # Make sure the voice tracks are aligned. 
+ for batch in dataloader: + paths = [] + for i in range(len(loaders)): + _paths = [p.split("/")[-1] for p in batch[i]["path"]] + paths.append(_paths) + paths = np.array(paths) + for i in range(paths.shape[1]): + col = paths[:, i] + col = col[col != "none"] + assert np.all(col == col[0]) def test_dataset_pipeline(): transform = tfm.Compose( [ - tfm.RoomImpulseResponse(csv_files=["tests/audio/irs.csv"]), - tfm.BackgroundNoise(csv_files=["tests/audio/noises.csv"]), + tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), ] ) - dataset = audiotools.data.datasets.CSVDataset( + loader = audiotools.data.datasets.AudioLoader(sources=["tests/audio/spk.csv"]) + dataset = audiotools.data.datasets.AudioDataset( + loader, 44100, - 10, - csv_files=["tests/audio/spk.csv"], + n_examples=10, transform=transform, ) dataloader = torch.utils.data.DataLoader( diff --git a/tests/data/test_preprocess.py b/tests/data/test_preprocess.py index 059c9f71..d5e25a8b 100644 --- a/tests/data/test_preprocess.py +++ b/tests/data/test_preprocess.py @@ -2,7 +2,7 @@ from pathlib import Path from audiotools.core.util import find_audio -from audiotools.core.util import read_csv +from audiotools.core.util import read_sources from audiotools.data import preprocess @@ -21,7 +21,7 @@ def test_create_csv_with_empty_rows(): with tempfile.NamedTemporaryFile(suffix=".csv") as f: preprocess.create_csv(audio_files, f.name, loudness=True) - audio_files = read_csv([f.name], remove_empty=True) + audio_files = read_sources([f.name], remove_empty=True) assert len(audio_files[0]) == 1 - audio_files = read_csv([f.name], remove_empty=False) + audio_files = read_sources([f.name], remove_empty=False) assert len(audio_files[0]) == 3 diff --git a/tests/data/test_transforms.py b/tests/data/test_transforms.py index c12c8335..02d709b4 100644 --- a/tests/data/test_transforms.py +++ b/tests/data/test_transforms.py @@ -8,7 +8,7 @@ from audiotools import AudioSignal from audiotools import util from audiotools.data import transforms as tfm -from audiotools.data.datasets import CSVDataset +from audiotools.data.datasets import AudioDataset non_deterministic_transforms = ["TimeNoise", "FrequencyNoise"] transforms_to_test = [] @@ -39,11 +39,11 @@ def test_transform(transform_name): kwargs = {} if transform_name == "BackgroundNoise": - kwargs["csv_files"] = ["tests/audio/noises.csv"] + kwargs["sources"] = ["tests/audio/noises.csv"] if transform_name == "RoomImpulseResponse": - kwargs["csv_files"] = ["tests/audio/irs.csv"] + kwargs["sources"] = ["tests/audio/irs.csv"] if transform_name == "CrossTalk": - kwargs["csv_files"] = ["tests/audio/spk.csv"] + kwargs["sources"] = ["tests/audio/spk.csv"] audio_path = "tests/audio/spk/f10_script4_produced.wav" signal = AudioSignal(audio_path, offset=10, duration=2) @@ -92,8 +92,8 @@ def test_compose_basic(): signal = AudioSignal(audio_path, offset=10, duration=2) transform = tfm.Compose( [ - tfm.RoomImpulseResponse(csv_files=["tests/audio/irs.csv"]), - tfm.BackgroundNoise(csv_files=["tests/audio/noises.csv"]), + tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), ], ) @@ -204,8 +204,8 @@ def test_choose_basic(): signal = AudioSignal(audio_path, offset=10, duration=2) transform = tfm.Choose( [ - tfm.RoomImpulseResponse(csv_files=["tests/audio/irs.csv"]), - tfm.BackgroundNoise(csv_files=["tests/audio/noises.csv"]), + tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + 
tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), ] ) @@ -359,7 +359,7 @@ def test_masking(): dataset, batch_size=16, num_workers=0, - collate_fn=audiotools.data.datasets.BaseDataset.collate, + collate_fn=util.collate, ) for batch in dataloader: signal = batch.pop("signal") @@ -385,10 +385,11 @@ def test_nested_masking(): prob=0.9, ) - dataset = CSVDataset( + loader = audiotools.data.datasets.AudioLoader(sources=["tests/audio/spk.csv"]) + dataset = audiotools.data.datasets.AudioDataset( + loader, 44100, - 100, - csv_files=["tests/audio/spk.csv"], + n_examples=100, transform=transform, ) dataloader = torch.utils.data.DataLoader( diff --git a/tests/ml/test_tricks.py b/tests/ml/test_tricks.py index c4ce238d..0cf221d1 100644 --- a/tests/ml/test_tricks.py +++ b/tests/ml/test_tricks.py @@ -57,6 +57,8 @@ def test_autoclip(): assert np.allclose(state_dict["grad_history"], grad_history) autoclip.load_state_dict(state_dict) + autoclip = ml.tricks.AutoClip(0, max_history=100) + def test_autobalance(): losses = torch.randn(10).abs().tolist() diff --git a/tests/profilers/profile_loudness.py b/tests/profilers/profile_loudness.py index d9713415..460a5724 100644 --- a/tests/profilers/profile_loudness.py +++ b/tests/profilers/profile_loudness.py @@ -10,7 +10,8 @@ from audiotools import AudioSignal from audiotools.core import util -from audiotools.data.datasets import CSVDataset +from audiotools.data.datasets import AudioDataset +from audiotools.data.datasets import AudioLoader def collate(list_of_dicts): @@ -30,11 +31,12 @@ def collate(list_of_dicts): def run(batch_size=64, duration=5.0, device="cuda"): - dataset = CSVDataset( + loader = AudioLoader(sources=["tests/audio/spk.csv"]) + dataset = AudioDataset( + loader, 44100, 10 * batch_size, duration, - csv_files=["tests/audio/spk.csv"], ) dataloader = torch.utils.data.DataLoader( dataset, num_workers=16, batch_size=batch_size, collate_fn=collate diff --git a/tests/profilers/profile_speed.py b/tests/profilers/profile_speed.py index 9a776fe1..24994830 100644 --- a/tests/profilers/profile_speed.py +++ b/tests/profilers/profile_speed.py @@ -9,7 +9,8 @@ from audiotools import AudioSignal from audiotools.core import util from audiotools.data import transforms as tfm -from audiotools.data.datasets import CSVDataset +from audiotools.data.datasets import AudioDataset +from audiotools.data.datasets import AudioLoader def run(batch_size=64, duration=5.0, device="cuda"): @@ -19,11 +20,12 @@ def run(batch_size=64, duration=5.0, device="cuda"): tfm.BackgroundNoise(csv_files=["tests/audio/noises.csv"]), ] ) - dataset = CSVDataset( + loader = AudioLoader(sources=["tests/audio/spk.csv"]) + dataset = AudioDataset( + loader, 44100, - 1000, - duration, - csv_files=["tests/audio/spk.csv"], + n_examples=1000, + duration=duration, transform=transform, ) dataloader = torch.utils.data.DataLoader( diff --git a/tests/profilers/profile_transforms.py b/tests/profilers/profile_transforms.py index 3a98e307..a2c8087e 100644 --- a/tests/profilers/profile_transforms.py +++ b/tests/profilers/profile_transforms.py @@ -9,7 +9,8 @@ from audiotools import AudioSignal from audiotools.core import util from audiotools.data import transforms as tfm -from audiotools.data.datasets import CSVDataset +from audiotools.data.datasets import AudioDataset +from audiotools.data.datasets import AudioLoader transforms_to_demo = [] for x in dir(tfm): @@ -24,20 +25,21 @@ def run(batch_size=64, duration=5.0, device="cuda"): for transform_name in track(transforms_to_demo): kwargs = {} if transform_name == 
"BackgroundNoise": - kwargs["csv_files"] = ["tests/audio/noises.csv"] + kwargs["sources"] = ["tests/audio/noises.csv"] if transform_name == "RoomImpulseResponse": - kwargs["csv_files"] = ["tests/audio/irs.csv"] + kwargs["sources"] = ["tests/audio/irs.csv"] if "Quantization" in transform_name: kwargs["channels"] = ("choice", [8, 16, 32]) transform_cls = getattr(tfm, transform_name) t = transform_cls(prob=1.0, **kwargs) - dataset = CSVDataset( + loader = AudioLoader(sources=["tests/audio/spk.csv"]) + dataset = AudioDataset( + loader, 44100, batch_size * 10, duration, - csv_files=["tests/audio/spk.csv"], transform=t, ) dataloader = torch.utils.data.DataLoader(