Skip to content

Commit

Permalink
Refactoring the data pipeline. (#71)
Browse files Browse the repository at this point in the history
* Got the main stuff done, still a bunch of tests to write.

* propagate find_audio args

* Coverage on datasets.

* Fixing test, adding docs.

* Adding to find_audio test.

* Fixing train_separator script

* Fixing

* Fixing sisdr loss and train_separator.

* Examples working

* Rewriting how things get aligned.

* Bigger version bump

* Adding test

* Adding max history to autoclip

* Adding something for gradient accumulation in the collater.

* Fixing comment

* Removing comment

---------

Co-authored-by: pseeth <[email protected]>
  • Loading branch information
pseeth and pseeth authored Jan 31, 2023
1 parent 59cf819 commit 5f31615
Show file tree
Hide file tree
Showing 22 changed files with 669 additions and 1,182 deletions.
2 changes: 1 addition & 1 deletion audiotools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.5.7"
__version__ = "0.6.0"
from .core import AudioSignal
from .core import STFTParams
from .core import Meter
Expand Down
7 changes: 5 additions & 2 deletions audiotools/core/audio_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def batch(
pad_signals: bool = False,
truncate_signals: bool = False,
resample: bool = False,
dim: int = 0,
):
"""Creates a batched AudioSignal from a list of AudioSignals.
Expand All @@ -398,6 +399,8 @@ def batch(
resample : bool, optional
Whether to resample AudioSignal to the sample rate of
the first AudioSignal in the list, by default False
dim : int, optional
Dimension along which to batch the signals.
Returns
-------
Expand Down Expand Up @@ -453,8 +456,8 @@ def batch(
f"All signals must be the same length, or pad_signals/truncate_signals "
f"must be True. "
)
# Concatenate along the batch dimension
audio_data = torch.cat([x.audio_data for x in audio_signals], dim=0)
# Concatenate along the specified dimension (default 0)
audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim)
audio_paths = [x.path_to_file for x in audio_signals]

batched_signal = cls(
Expand Down
2 changes: 1 addition & 1 deletion audiotools/core/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def specshow(
log_mag = signal.log_magnitude(ref_value=ref)

if y_axis == "mel":
log_mag = 10 * signal.mel_spectrogram(n_mels).pow(2).clamp(1e-5).log10()
log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
log_mag -= log_mag.max()

librosa.display.specshow(
Expand Down
112 changes: 76 additions & 36 deletions audiotools/core/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import math
import numbers
import os
import random
Expand Down Expand Up @@ -212,7 +213,7 @@ def _close():
_close()


AUDIO_EXTENSIONS = ["wav", "flac", "mp3", "mp4"]
AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]


def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
Expand All @@ -225,23 +226,35 @@ def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
Folder to look for audio files in, recursively.
ext : List[str], optional
Extensions to look for (including the "."), by default
``['wav', 'flac', 'mp3', 'mp4']``.
``['.wav', '.flac', '.mp3', '.mp4']``.
"""
folder = Path(folder)
# Take care of case where user has passed in an audio file directly
# into one of the calling functions.
if str(folder).endswith(tuple(ext)):
return [folder]
files = []
for x in ext:
files += folder.glob(f"**/*.{x}")
files += folder.glob(f"**/*{x}")
return files


def read_csv(filelists: List[str], remove_empty: bool = True):
"""Reads CSVs that are generated by
def read_sources(
sources: List[str],
remove_empty: bool = True,
relative_path: str = "",
ext: List[str] = AUDIO_EXTENSIONS,
):
"""Reads audio sources that can either be folders
full of audio files, or CSV files that contain paths
to audio files. CSV files that adhere to the expected
format can be generated by
:py:func:`audiotools.data.preprocess.create_csv`.
Parameters
----------
filelists : List[str]
List of CSV files to be converted into a
sources : List[str]
List of audio sources to be converted into a
list of lists of audio files.
remove_empty : bool, optional
Whether or not to remove rows with an empty "path"
Expand All @@ -253,18 +266,24 @@ def read_csv(filelists: List[str], remove_empty: bool = True):
List of lists of rows of CSV files.
"""
files = []
data_path = Path(os.getenv("PATH_TO_DATA", ""))
for filelist in filelists:
with open(filelist, "r") as f:
reader = csv.DictReader(f)
_files = []
for x in reader:
if remove_empty and x["path"] == "":
continue
if x["path"] != "":
x["path"] = str(data_path / x["path"])
_files.append(x)
files.append(_files)
relative_path = Path(relative_path)
for source in sources:
source = str(source)
_files = []
if source.endswith(".csv"):
with open(source, "r") as f:
reader = csv.DictReader(f)
for x in reader:
if remove_empty and x["path"] == "":
continue
if x["path"] != "":
x["path"] = str(relative_path / x["path"])
_files.append(x)
else:
for x in find_audio(source, ext=ext):
x = str(relative_path / x)
_files.append({"path": x})
files.append(sorted(_files, key=lambda x: x["path"]))
return files


Expand All @@ -287,9 +306,9 @@ def choose_from_list_of_lists(
typing.Any
An item from the list of lists.
"""
idx = state.choice(list(range(len(list_of_lists))), p=p)
item_idx = state.randint(len(list_of_lists[idx]))
return list_of_lists[idx][item_idx], idx
source_idx = state.choice(list(range(len(list_of_lists))), p=p)
item_idx = state.randint(len(list_of_lists[source_idx]))
return list_of_lists[source_idx][item_idx], source_idx, item_idx


@contextmanager
Expand Down Expand Up @@ -392,39 +411,60 @@ def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
return dist_fn(*dist_tuple[1:])


def collate(list_of_dicts: list):
def collate(list_of_dicts: list, n_splits: int = None):
"""Collates a list of dictionaries (e.g. as returned by a
dataloader) into a dictionary with batched values. This routine
uses the default torch collate function for everything
except AudioSignal objects, which are handled by the
:py:func:`audiotools.core.audio_signal.AudioSignal.batch`
function.
This function takes n_splits to enable splitting a batch
into multiple sub-batches for the purposes of gradient accumulation,
etc.
Parameters
----------
list_of_dicts : list
List of dictionaries to be collated.
n_splits : int
Number of splits to make when creating the batches (split into
sub-batches). Useful for things like gradient accumulation.
Returns
-------
dict
Dictionary containing batched data.
"""

from . import AudioSignal

# Flatten the dictionaries to avoid recursion.
list_of_dicts = [flatten(d) for d in list_of_dicts]
dict_of_lists = {k: [dic[k] for dic in list_of_dicts] for k in list_of_dicts[0]}
batches = []
list_len = len(list_of_dicts)

batch = {}
for k, v in dict_of_lists.items():
if isinstance(v, list):
if all(isinstance(s, AudioSignal) for s in v):
batch[k] = AudioSignal.batch(v, pad_signals=True)
else:
# Borrow the default collate fn from torch.
batch[k] = torch.utils.data._utils.collate.default_collate(v)
return unflatten(batch)
return_list = False if n_splits is None else True
n_splits = 1 if n_splits is None else n_splits
n_items = int(math.ceil(list_len / n_splits))

for i in range(0, list_len, n_items):
# Flatten the dictionaries to avoid recursion.
list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
dict_of_lists = {
k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
}

batch = {}
for k, v in dict_of_lists.items():
if isinstance(v, list):
if all(isinstance(s, AudioSignal) for s in v):
batch[k] = AudioSignal.batch(v, pad_signals=True)
else:
# Borrow the default collate fn from torch.
batch[k] = torch.utils.data._utils.collate.default_collate(v)
batches.append(unflatten(batch))

batches = batches[0] if not return_list else batches
return batches


BASE_SIZE = 864
Expand Down Expand Up @@ -614,6 +654,6 @@ def generate_chord_dataset(
voice_lists[voice_name].append("")

for voice_name, paths in voice_lists.items():
create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True, data_path="")
create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)

return output_dir
Loading

0 comments on commit 5f31615

Please sign in to comment.