From 4f96a75bbafbac99e1464fac3edf65d1a4891a43 Mon Sep 17 00:00:00 2001
From: Ludwigvsch
Date: Sat, 31 Aug 2024 02:46:47 +0000
Subject: [PATCH] Fix custom model and switch to the base BirdAVES model

---
 ...ves-biox-base.torchaudio.model_config.json | 53 +++++++++++++
 pyha_analyzer/models/CustomModel.py           | 75 +++++++++++++------
 .../models/inference_M_dataset.ipynb          | 25 +++++++
 pyha_analyzer/models/loss_fn.py               |  2 +-
 pyha_analyzer/train.py                        | 17 +++--
 5 files changed, 142 insertions(+), 30 deletions(-)
 create mode 100644 birdaves-biox-base.torchaudio.model_config.json
 create mode 100644 pyha_analyzer/models/inference_M_dataset.ipynb

diff --git a/birdaves-biox-base.torchaudio.model_config.json b/birdaves-biox-base.torchaudio.model_config.json
new file mode 100644
index 0000000..17c3154
--- /dev/null
+++ b/birdaves-biox-base.torchaudio.model_config.json
@@ -0,0 +1,53 @@
+{
+  "extractor_mode": "group_norm",
+  "extractor_conv_layer_config": [
+    [
+      512,
+      10,
+      5
+    ],
+    [
+      512,
+      3,
+      2
+    ],
+    [
+      512,
+      3,
+      2
+    ],
+    [
+      512,
+      3,
+      2
+    ],
+    [
+      512,
+      3,
+      2
+    ],
+    [
+      512,
+      2,
+      2
+    ],
+    [
+      512,
+      2,
+      2
+    ]
+  ],
+  "extractor_conv_bias": false,
+  "encoder_embed_dim": 768,
+  "encoder_projection_dropout": 0.1,
+  "encoder_pos_conv_kernel": 128,
+  "encoder_pos_conv_groups": 16,
+  "encoder_num_layers": 12,
+  "encoder_num_heads": 12,
+  "encoder_attention_dropout": 0.1,
+  "encoder_ff_interm_features": 3072,
+  "encoder_ff_interm_dropout": 0.0,
+  "encoder_dropout": 0.1,
+  "encoder_layer_norm_first": false,
+  "encoder_layer_drop": 0.05
+}
\ No newline at end of file
diff --git a/pyha_analyzer/models/CustomModel.py b/pyha_analyzer/models/CustomModel.py
index 58a9cb0..e728f6e 100644
--- a/pyha_analyzer/models/CustomModel.py
+++ b/pyha_analyzer/models/CustomModel.py
@@ -14,44 +14,75 @@
 
 
 class CustomModel(nn.Module):
-    def __init__(self, config_path, model_path, num_classes, trainable, embedding_dim=768):
-        super(CustomModel, self).__init__()
+    """ Uses the AVES HuBERT model to embed sounds and classify them """
+    def __init__(self, cfg, num_classes, model_path, trainable, config_path, embedding_dim=768):
+        super().__init__()
+        # reference: https://pytorch.org/audio/stable/_modules/torchaudio/models/wav2vec2/utils/import_fairseq.html
+        self.cfg = cfg
+        self.trainable = trainable
         self.config = self.load_config(config_path)
         self.model = wav2vec2_model(**self.config, aux_num_out=None)
         self.model.load_state_dict(torch.load(model_path))
-        self.trainable = trainable
+        # Freeze the AVES network
        self.freeze_embedding_weights(self.model, trainable)
-        self.classifier_head = nn.Linear(in_features=embedding_dim, out_features=num_classes)
-        self.loss_fn = None
+        # Linear classifier head mapping the mean embedding to class logits;
+        # only this head is trained when the encoder is frozen
+        self.embedding_transform = nn.Linear(embedding_dim, num_classes)  # embedding_dim must match the loaded AVES encoder (768 for the base model)
+        #self.classifier_head = nn.Linear(in_features=embedding_dim, out_features=num_classes)
+        self.audio_sr = cfg.sample_rate
 
     def load_config(self, config_path):
-        with open(config_path, 'r') as f:
-            return json.load(f)
+        with open(config_path, 'r') as ff:
+            obj = json.load(ff)
+            return obj
+
+    def forward(self, sig):
+        """
+        Input
+            sig (Tensor): (batch, time)
+        Returns
+            mean_embedding (Tensor): (batch, output_dim)
+            logits (Tensor): (batch, n_classes)
+        """
+        # extract_features in the torchaudio version returns all 12 layers' outputs; [-1] selects the final one
+        out = self.model.extract_features(sig)[0][-1]
+        mean_embedding = out.mean(dim=1)  # mean over time
+        logits = self.embedding_transform(mean_embedding)  # project the mean embedding to class logits
+        #logits = self.classifier_head(mean_embedding)
+        return mean_embedding, logits
 
     def freeze_embedding_weights(self, model, trainable):
+        """ Freeze weights in the AVES embeddings for classification """
+        # The convolutional layers should never be trainable
         model.feature_extractor.requires_grad_(False)
         model.feature_extractor.eval()
+        # The transformer layers are optionally trainable
         for param in model.encoder.parameters():
             param.requires_grad = trainable
         if not trainable:
+            # We also set layers without params (like dropout) to eval mode, so they do not change
             model.encoder.eval()
+
+    def set_eval_aves(self):
+        """ Set the AVES-based classifier (head and transformer encoder) to eval mode """
+        self.embedding_transform.eval()
+        self.model.encoder.eval()
+
+
-    def forward(self, sig):
-        out = self.model.extract_features(sig)[0][-1]
-        mean_embedding = out.mean(dim=1)
-        logits = self.classifier_head(mean_embedding)
-        return mean_embedding, logits
-
-    def create_loss_fn(self, cfg, train_dataset):
-        loss_desc = cfg.loss_fnc
         if loss_desc == "CE":
-            self.loss_fn = nn.CrossEntropyLoss()
-        elif loss_desc == "BCE":
-            self.loss_fn = nn.BCEWithLogitsLoss()
-        else:
-            raise RuntimeError(f"Unsupported loss function: {loss_desc}")
+    def create_loss_fn(self, train_dataset):
+        loss_desc = self.cfg.loss_fnc
+        if loss_desc == "CE":
+            return cross_entropy_loss_fn(self, train_dataset)
+        if loss_desc == "BCE":
+            return bce_loss_fn(self, without_logits=True)
+        if loss_desc == "BCEWL":
+            return bce_loss_fn(self, without_logits=False)
+        if loss_desc == "FL":
+            return focal_loss_fn(self, False)  # without_logits=False: the model outputs raw logits
+        raise RuntimeError(f"Unsupported loss function: {loss_desc}")
 
 def download_model_files():
     import os
-    os.system("wget https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.pt")
-    os.system("wget https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.model_config.json")
+    os.system("wget https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-base.torchaudio.pt")
+    os.system("wget https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-base.torchaudio.model_config.json")
diff --git a/pyha_analyzer/models/inference_M_dataset.ipynb b/pyha_analyzer/models/inference_M_dataset.ipynb
new file mode 100644
index 0000000..a4aabc8
--- /dev/null
+++ b/pyha_analyzer/models/inference_M_dataset.ipynb
@@ -0,0 +1,25 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pyha_analyzer/models/loss_fn.py b/pyha_analyzer/models/loss_fn.py
index 39705ff..12d59ac 100644
--- a/pyha_analyzer/models/loss_fn.py
+++ b/pyha_analyzer/models/loss_fn.py
@@ -27,7 +27,7 @@ def bce_loss_fn(self, without_logits=False):
         BCEwithLogitsLoss
     """
     if not without_logits:
-        self.loss_fn = nn.BCEWithLogitsLoss(reduction='sum')
+        self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
     else:
         self.loss_fn = nn.BCELoss(reduction='mean')
     return self.loss_fn
diff --git a/pyha_analyzer/train.py b/pyha_analyzer/train.py
index 2249c96..2a11c7e 100644
--- a/pyha_analyzer/train.py
+++ b/pyha_analyzer/train.py
@@ -321,10 +321,10 @@ def download_model_files():
     import urllib.request
 
     urls = [
#"https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-large.torchaudio.pt", - #"https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-large.torchaudio.model_config.json" - "https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.pt", - "https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.model_config.json" + "https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-base.torchaudio.pt", + "https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-base.torchaudio.model_config.json" + #"https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.pt", + #"https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.model_config.json" ] for url in urls: filename = url.split("/")[-1] @@ -368,12 +368,15 @@ def main(in_sweep=True) -> None: logger.info("Loading Model...") download_model_files() model_for_run = CustomModel( - config_path= "aves-base-bio.torchaudio.model_config.json", #aves-base-bio.torchaudio.model_config.json", - model_path= "aves-base-bio.torchaudio.pt", #"aves-base-bio.torchaudio.pt", + config_path="birdaves-biox-base.torchaudio.model_config.json", + model_path="birdaves-biox-base.torchaudio.pt", + cfg=cfg, + #config_path= "aves-base-bio.torchaudio.model_config.json", #aves-base-bio.torchaudio.model_config.json", + #model_path= "aves-base-bio.torchaudio.pt", #"aves-base-bio.torchaudio.pt", num_classes=train_dataset.num_classes, trainable=cfg.trainable, ).to(cfg.device) - model_for_run.create_loss_fn(cfg, train_dataset) + model_for_run.create_loss_fn(train_dataset) optimizer = Adam(model_for_run.parameters(), lr=cfg.learning_rate) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-5, T_max=10)