Added detailed documentation to dataset.py #172

Open · wants to merge 2 commits into main
221 changes: 201 additions & 20 deletions pyha_analyzer/dataset.py
@@ -5,7 +5,6 @@
get_datasets returns the train and validation datasets as BirdCLEFDataset objects.

If this module is run directly, it tests that the dataloader works

"""
import logging
import os
@@ -34,8 +33,62 @@
# pylint: disable=too-many-instance-attributes
class PyhaDFDataset(Dataset):
"""
Dataset designed to work with pyha output
Save unchunked data
A class for loading the dataset and creating dataloaders for training and validation.

This class represents a dataset designed to work with pyha output and saves
unchunked data. It contains methods for loading and serializing the data.

Attributes:
samples (pandas.DataFrame):
- filtered dataframe which contains non-null values
in the "FILE NAME" column.

num_samples (int):
- how many samples the dataset contains.

train (bool):
- whether the dataset is a training set or not.

device (str):
- the device on which the computations will occur;
possible options include "cuda", "mps", and "cpu".

onehot (bool):
- whether the dataset has been one-hot encoded.

cfg (pyha_analyzer.config.Config):
- configuration settings of the dataset.

data_dir (set[str]):
- collection of paths of audio files.

bad_files (list[int]):
- indices of bad files.

classes (list[str]):
- list of all species to be classified.

class_to_idx (dict[str, int]):
- dictionary mapping each class name in classes to its
assigned integer index.

num_classes (int):
- the number of species (classes) in the dataset.

convert_to_mel (torchaudio.transforms.MelSpectrogram):
- transformation object that converts raw waveforms into mel spectrograms.

decibel_convert (torchaudio.transforms.AmplitudeToDB):
- transformation that converts mel spectrogram amplitudes into the decibel scale.

mixup (pyha_analyzer.augmentations.Mixup):
- torch.nn.Module object that applies mixup augmentation by blending pairs of samples.

audio_augmentations (torch.nn.Sequential):
- pipeline for augmenting audio files.

image_augmentations (torch.nn.Sequential):
- pipeline for augmenting spectrogram images.
"""

# df, train, and species decided outside of config, so those cannot be added in there
@@ -47,6 +100,30 @@ def __init__(self,
cfg: config.Config,
onehot:bool = False,
) -> None:
"""
Initializes a PyhaDFDataset with the given attributes.

Args:
df (pandas.DataFrame):
- dataframe of data contained in this object.

train (bool):
- whether the data is the training set data or not.

species (list[str]):
- a list of strings representing each species identified
in this dataset.

cfg (pyha_analyzer.config.Config):
- configuration settings of the dataset.

onehot (bool):
- whether the data has been one-hot encoded or not;
defaults to False.

Returns:
None
"""
self.samples = df[~(df[cfg.file_name_col].isnull())]
if onehot:
if self.samples.iloc[0][species].shape[0] != len(species):
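To make the new docstrings concrete, here is a minimal construction sketch. It assumes a module-level `config.cfg` instance and an annotation CSV at `cfg.dataframe_csv`; both are illustrative, not taken from this diff.

```python
# Hypothetical construction of a PyhaDFDataset; config access and column
# names are assumptions for illustration.
import pandas as pd
from pyha_analyzer import config
from pyha_analyzer.dataset import PyhaDFDataset

cfg = config.cfg                          # assumed module-level Config
df = pd.read_csv(cfg.dataframe_csv)       # annotation dataframe
species = sorted(df[cfg.manual_id_col].unique())

train_ds = PyhaDFDataset(df, train=True, species=species, cfg=cfg)
```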
@@ -80,6 +157,7 @@ def __init__(self,
self.num_classes = len(species)
self.serialize_data()


self.class_dist = self.calc_class_distribution()

#Data augmentations
@@ -107,7 +185,14 @@ def __init__(self,
RandomApply([audtr.TimeMasking(cfg.time_mask_param)], p=cfg.time_mask_p))

def calc_class_distribution(self) -> torch.Tensor:
""" Returns class distribution (number of samples per class) """
"""
Returns class distribution (number of samples per class).

Returns:
class_dist (torch.Tensor):
- a 1-D torch tensor containing the number of samples
in each class.
"""
class_dist = []
if self.onehot:
for class_name in self.classes:
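A toy sketch of the label-counting strategy this docstring describes; the column name "MANUAL ID" is an assumption for illustration.

```python
# Toy example of counting samples per class from a label column.
import pandas as pd
import torch

samples = pd.DataFrame({"MANUAL ID": ["crow", "crow", "wren"]})
classes = ["crow", "wren"]

# Count rows whose label matches each class name.
class_dist = torch.tensor([(samples["MANUAL ID"] == c).sum() for c in classes])
print(class_dist)  # tensor([2, 1])
```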
@@ -125,7 +210,10 @@ def calc_class_distribution(self) -> torch.Tensor:

def verify_audio(self) -> None:
"""
Checks to make sure files exist that are referenced in input df
Checks that every file referenced in the input dataframe exists.

Returns:
None
"""
missing_files = pd.Series(self.samples[self.cfg.file_name_col].unique()) \
.progress_apply(
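The check itself reduces to mapping each unique file name to a boolean, roughly as in this sketch; the directory and file names are made up.

```python
# Sketch of the existence check; data_dir and file names are illustrative.
import os
import pandas as pd

samples = pd.DataFrame({"FILE NAME": ["a.wav", "b.wav"]})
data_dir = "/data/audio"  # assumed audio root

names = pd.Series(samples["FILE NAME"].unique())
exists = names.apply(lambda f: os.path.exists(os.path.join(data_dir, f)))
print(names[~exists].tolist())  # referenced in the df but missing on disk
```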
@@ -143,7 +231,19 @@

def process_audio_file(self, file_name: str) -> pd.Series:
"""
Save waveform of audio file as a tensor and save that tensor to .pt
Save waveform of audio file as a tensor and save that tensor to .pt.

Args:
file_name (str):
- name of an audio file

Returns:
- pandas Series containing the original file name and the new file name.
- If the audio file has already been processed, nothing is saved and
the original file name and the expected new file name are returned.
- If an exception occurs, returns the original file location and
"bad" as the new location.
"""
exts = "." + file_name.split(".")[-1]
new_name = file_name.replace(exts, ".pt")
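The .pt round trip this documents amounts to decoding once and reloading the raw tensor afterwards; a sketch, assuming the file is any format torchaudio can decode:

```python
# Decode once, cache the waveform tensor, and reload it cheaply later.
import torch
import torchaudio

file_name = "example.wav"                      # assumed decodable audio file
waveform, sample_rate = torchaudio.load(file_name)
new_name = file_name.rsplit(".", 1)[0] + ".pt"
torch.save(waveform, new_name)

waveform = torch.load(new_name)                # later epochs skip decoding
```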
@@ -193,9 +293,12 @@ def process_audio_file(self, file_name: str) -> pd.Series:

def serialize_data(self) -> None:
"""
For each file, check to see if the file is already a presaved tensor
If the files is not a presaved tensor and is an audio file, convert to tensor to make
Future training faster
For each file, check to see if the file is already a presaved tensor.
If the file is not a presaved tensor and is an audio file, convert it
to a tensor to make future training faster.

Returns:
None
"""
self.verify_audio()
files = pd.DataFrame(self.samples[self.cfg.file_name_col].unique(),
@@ -225,11 +328,24 @@ def serialize_data(self) -> None:
self.samples["original_file_path"] = self.samples[self.cfg.file_name_col]

def __len__(self):
"""
Returns how many elements are in the sample DataFrame.

Returns:
- the number of elements in the sample DataFrame.
"""
return self.samples.shape[0]

def to_image(self, audio):
"""
Convert audio clip to 3-channel spectrogram image
Convert audio clip to 3-channel spectrogram image

Args:
audio (torch.Tensor):
- torch tensor that represents the audio clip as a raw waveform

Returns:
- torch tensor that represents the audio clip as a 3-channel mel
spectrogram image.
"""
# Mel spectrogram
# Pylint complains this is not callable, but it is a torch.nn.Module
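The pipeline the docstring describes (waveform to mel spectrogram to decibels to a 3-channel image) looks roughly like this; the sample rate and n_mels are placeholders, not the project's config values:

```python
# Waveform -> mel spectrogram -> dB -> min-max normalize -> 3 channels.
import torch
import torchaudio.transforms as audtr

convert_to_mel = audtr.MelSpectrogram(sample_rate=32000, n_mels=128)
decibel_convert = audtr.AmplitudeToDB(stype="power")

audio = torch.randn(32000 * 5)                     # 5 s of stand-in audio
mel = decibel_convert(convert_to_mel(audio))
mel = (mel - mel.min()) / (mel.max() - mel.min())  # scale to [0, 1]
image = torch.stack([mel, mel, mel])               # shape (3, n_mels, time)
```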
@@ -250,7 +366,20 @@ def to_image(self, audio):
return torch.stack([mel, mel, mel])

def __getitem__(self, index): #-> Any:
""" Takes an index and returns tuple of spectrogram image with corresponding label
"""
Takes an index and returns a tuple of the spectrogram image and its
corresponding label.

Args:
index (int):
- index of the item

Returns:
tuple: a tuple containing:
image (torch.Tensor):
- torch tensor representing the mel spectrogram
image at the index
target (torch.Tensor):
- torch tensor representing the one-hot encoded label of
the image
"""
assert isinstance(index, int)
audio, target = utils.get_annotation(
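Given a dataset built as in the earlier sketch, indexing behaves as this docstring states:

```python
# Hypothetical usage; train_ds is the PyhaDFDataset from the earlier sketch.
image, target = train_ds[0]
print(image.shape)   # e.g. torch.Size([3, n_mels, time_frames])
print(target.shape)  # torch.Size([num_classes]), one-hot label
```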
@@ -282,14 +411,24 @@ def __getitem__(self, index): #-> Any:
return image, target

def get_num_classes(self) -> int:
""" Returns number of classes
"""
Returns number of classes

Returns:
- the "num_classes" attribute, which represents the number of classes
in the dataset
"""
return self.num_classes

def get_sample_weights(self) -> pd.Series:
""" Returns the weights as computed by the first place winner of BirdCLEF 2023
See https://www.kaggle.com/competitions/birdclef-2023/discussion/412808
Congrats on your win!
"""
Returns the weights as computed by the first place winner of BirdCLEF 2023
See https://www.kaggle.com/competitions/birdclef-2023/discussion/412808
Congrats on your win!

Returns:
- a pandas.Series object that represents the weights as computed by
the first place winner of BirdCLEF 2023.
"""
manual_id = self.cfg.manual_id_col
all_primary_labels = self.samples[manual_id]
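For readers who do not follow the link, the weighting is an inverse-frequency scheme; this sketch uses an exponent of -0.5, which is an assumption since the exact formula sits outside this hunk:

```python
# Toy inverse-frequency sample weights; the -0.5 exponent is assumed.
import pandas as pd

labels = pd.Series(["crow", "crow", "crow", "wren"])
class_weights = (labels.value_counts() / len(labels)) ** -0.5
sample_weights = labels.map(class_weights)  # one weight per row; rare classes weigh more
```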
@@ -302,9 +441,20 @@


def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFDataset]]:
""" Returns train and validation datasets
does random sampling for train/valid split
adds transforms to dataset
"""
Returns the train, validation, and optional inference datasets.
Does random sampling for the train/valid split
and adds transforms to the datasets.

Args:
cfg (pyha_analyzer.config.Config):
- configuration settings of the dataset.

Returns:
tuple: a tuple containing:
- train_ds (PyhaDFDataset): the training dataset
- valid_ds (PyhaDFDataset): the validation dataset
- infer_ds (Optional[PyhaDFDataset]): the inference/test dataset
"""
train_p = cfg.train_test_split
path = cfg.dataframe_csv
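The random split mentioned above is, at its simplest, a row-level sample; the project may instead split by clip or by class, so this is only illustrative:

```python
# Illustrative row-level train/valid split; the CSV path is made up.
import pandas as pd

df = pd.read_csv("annotations.csv")        # assumed annotation CSV
train_p = 0.8
train_df = df.sample(frac=train_p, random_state=0)
valid_df = df.drop(train_df.index)
```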
@@ -392,7 +542,13 @@ def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFDataset]]:

def set_torch_file_sharing(_) -> None:
"""
Sets torch.multiprocessing to use file sharing
Sets torch.multiprocessing to use the file-system sharing strategy.

Args:
_ (any): placeholder parameter.

Returns:
None
"""
torch.multiprocessing.set_sharing_strategy("file_system")
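The ignored parameter suggests the function is meant as a DataLoader worker_init_fn (which receives each worker's id); a hedged usage sketch:

```python
# Assumed usage: apply the sharing strategy in every dataloader worker.
from torch.utils.data import DataLoader

loader = DataLoader(train_ds, num_workers=4, worker_init_fn=set_torch_file_sharing)
```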

@@ -401,6 +557,28 @@ def make_dataloaders(train_dataset, val_dataset, infer_dataset, cfg
)-> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
"""
Loads datasets and dataloaders for train and validation

Args:
train_dataset (PyhaDFDataset):
- the training dataset

val_dataset (PyhaDFDataset):
- the validation dataset

infer_dataset (PyhaDFDataset):
- the inference/test dataset

cfg (pyha_analyzer.config.Config):
- configuration settings of the dataset

Returns:
tuple: a tuple containing:
- train_dataloader (torch.utils.data.DataLoader): dataloader for the
training set
- val_dataloader (torch.utils.data.DataLoader): dataloader for the
validation set
- infer_dataloader (Optional[torch.utils.data.DataLoader]): dataloader
for the inference/test dataset
"""


@@ -452,7 +630,10 @@ def make_dataloaders(train_dataset, val_dataset, infer_dataset, cfg

def main() -> None:
"""
testing function.
Testing function.

Returns:
None
"""
# run = wandb.init(
# entity=cfg.wandb_entity,