diff --git a/pyha_analyzer/augmentations.py b/pyha_analyzer/augmentations.py index cf33262..a656502 100644 --- a/pyha_analyzer/augmentations.py +++ b/pyha_analyzer/augmentations.py @@ -5,7 +5,7 @@ import logging import os from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Iterable +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import numpy as np import pandas as pd @@ -16,6 +16,12 @@ logger = logging.getLogger("acoustic_multiclass_training") +def get_training_proportion(): + """ Returns proportion of training done """ + total_epochs = config.cfg.epochs + current_epoch = config.cfg.current_epoch + return current_epoch/total_epochs + def invert(seq: Iterable[int]) -> List[float]: """ Replace each element in list with its inverse @@ -24,14 +30,22 @@ def invert(seq: Iterable[int]) -> List[float]: raise ValueError('Passed iterable cannot contain zero') return [1/x for x in seq] +def get_unnormed_probabilities(seq: Iterable[int]) -> Iterable[float]: + """ + Get probabilities for each element in seq + Spread changes over time due to curriculum learning + """ + power = 2 * (0.9 - get_training_proportion()) + return [1/(x**power) for x in seq] + def hyperbolic(seq: Iterable[int]) -> List[Tuple[float, int]]: """ Takes a list of numbers and assigns them a probability distribution accourding to the inverse of their values """ - invert_seq = invert(seq) - norm_factor = sum(invert_seq) - probabilities = [x/norm_factor for x in invert_seq] + unnormed_probabilities = get_unnormed_probabilities(seq) + norm_factor = sum(unnormed_probabilities) + probabilities = [x/norm_factor for x in unnormed_probabilities] return list(zip(probabilities, seq)) def sample(distribution: List[Tuple[float, int]]) -> int: @@ -217,6 +231,7 @@ def __init__(self, cfg: config.Config): self.noise_type = cfg.noise_type self.alpha = cfg.noise_alpha self.device = cfg.prepros_device + self.cfg = cfg def forward(self, clip: torch.Tensor)->torch.Tensor: """ @@ -225,9 +240,12 @@ def forward(self, clip: torch.Tensor)->torch.Tensor: Returns: Clip mixed with noise according to noise_type and alpha """ + alpha = (self.alpha + * get_training_proportion() + * self.cfg.curriculum_learning_scale_factor) noise_function = self.noise_names[self.noise_type] noise = noise_function(len(clip)).to(self.device) - return (1 - self.alpha) * clip + self.alpha* noise + return (1 - alpha) * clip + alpha * noise class RandomEQ(torch.nn.Module): @@ -285,6 +303,7 @@ def __init__(self, cfg: config.Config, norm=False): self.length = cfg.chunk_length_s self.device = cfg.prepros_device self.norm = norm + self.cfg = cfg if self.noise_path_str != "" and cfg.bg_noise_p > 0.0: files = list(os.listdir(self.noise_path)) audio_extensions = (".mp3",".wav",".ogg",".flac",".opus",".sphere",".pt") @@ -309,6 +328,10 @@ def forward(self, clip: torch.Tensor) -> torch.Tensor: """ # Skip loading if no noise path alpha = utils.rand(*self.alpha_range) + training_proportion = get_training_proportion() + alpha = alpha + ((training_proportion-0.5) + * alpha + * self.cfg.curriculum_learning_scale_factor) if self.noise_path_str == "": return clip # If loading fails, skip for now diff --git a/pyha_analyzer/default_config.yml b/pyha_analyzer/default_config.yml index 725cc61..8a369e8 100644 --- a/pyha_analyzer/default_config.yml +++ b/pyha_analyzer/default_config.yml @@ -7,6 +7,9 @@ infer_csv: # Optional, automatically generates class order if not given class_list: +#Curriculum learning +curriculum_learning_scale_factor: 1.3 + # Dataframe column names offset_col: "OFFSET" duration_col: "DURATION" diff --git a/pyha_analyzer/train.py b/pyha_analyzer/train.py index e4c1ca2..5dd3481 100644 --- a/pyha_analyzer/train.py +++ b/pyha_analyzer/train.py @@ -321,9 +321,10 @@ def logging_setup() -> None: def main(in_sweep=True) -> None: """ Main function """ + + setattr(cfg, "current_epoch", 0) logger.info("Device is: %s, Preprocessing Device is %s", cfg.device, cfg.prepros_device) set_seed(cfg.seed) - if in_sweep: run = wandb.init() for key, val in dict(wandb.config).items(): @@ -360,6 +361,7 @@ def main(in_sweep=True) -> None: for epoch in range(cfg.epochs): logger.info("Epoch %d", epoch) + setattr(cfg, "current_epoch", epoch) best_valid_cmap = train(model_for_run, train_dataloader,