From 79e5cdfe56e1811e72312f9680a475153e5890a1 Mon Sep 17 00:00:00 2001 From: Benjamin Cretois Date: Wed, 10 Jan 2024 13:25:17 +0100 Subject: [PATCH 1/3] [ADD] option to train a baseline model --- Models/baseline.py | 33 ++++++++++++++++++++ callbacks/callbacks.py | 4 +-- prototypicalbeats/prototraining.py | 50 ++++++++++++++++++++---------- 3 files changed, 69 insertions(+), 18 deletions(-) create mode 100644 Models/baseline.py diff --git a/Models/baseline.py b/Models/baseline.py new file mode 100644 index 0000000..34392c4 --- /dev/null +++ b/Models/baseline.py @@ -0,0 +1,33 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch +from collections import OrderedDict + +def conv_block(in_channels,out_channels): + + return nn.Sequential( + nn.Conv2d(in_channels,out_channels,3,padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.MaxPool2d(2) + ) + +class ProtoNet(nn.Module): + def __init__(self): + super(ProtoNet,self).__init__() + self.encoder = nn.Sequential( + conv_block(1,64), + conv_block(64,64), + conv_block(64,64), + conv_block(64,64) + ) + def forward(self,x): + (num_samples,seq_len,mel_bins) = x.shape + x = x.view(-1,1,seq_len,mel_bins) + x = self.encoder(x) + x = nn.MaxPool2d(2)(x) + + return x.view(x.size(0),-1) + + def extract_features(self, x, padding_mask=None): + return self.forward(x) \ No newline at end of file diff --git a/callbacks/callbacks.py b/callbacks/callbacks.py index 4b6a4db..4ae16d9 100644 --- a/callbacks/callbacks.py +++ b/callbacks/callbacks.py @@ -11,7 +11,7 @@ def __init__(self, milestones: int = 1): self.milestones = milestones def freeze_before_training(self, pl_module: pl.LightningModule): - self.freeze(modules=pl_module.beats) + self.freeze(modules=pl_module.model) def finetune_function( self, @@ -23,5 +23,5 @@ def finetune_function( if epoch == self.milestones: # unfreeze BEATs self.unfreeze_and_add_param_group( - modules=pl_module.beats, optimizer=optimizer + modules=pl_module.model, optimizer=optimizer ) diff --git a/prototypicalbeats/prototraining.py b/prototypicalbeats/prototraining.py index baf1225..859f147 100644 --- a/prototypicalbeats/prototraining.py +++ b/prototypicalbeats/prototraining.py @@ -11,6 +11,7 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_info from BEATs.BEATs import BEATs, BEATsConfig +from Models.baseline import ProtoNet class ProtoBEATsModel(pl.LightningModule): def __init__( @@ -20,6 +21,7 @@ def __init__( lr: float = 1e-5, lr_scheduler_gamma: float = 1e-1, num_workers: int = 6, + model_type: str = "baseline", # or baseline model_path: str = "/data/BEATs/BEATs_iter3_plus_AS2M.pt", distance: str = "euclidean", specaugment_params = None, @@ -36,15 +38,18 @@ def __init__( self.num_workers = num_workers self.milestones = milestones self.distance = distance - # Initialise BEATs model - self.checkpoint = torch.load(model_path) - self.cfg = BEATsConfig( - { - **self.checkpoint["cfg"], - "finetuned_model": False, - "specaugment_params": specaugment_params, - } - ) + self.model_type = model_type + + # If BEATS --> initialise BEATs model + if self.model_type == "beats": + self.checkpoint = torch.load(model_path) + self.cfg = BEATsConfig( + { + **self.checkpoint["cfg"], + "finetuned_model": False, + "specaugment_params": specaugment_params, + } + ) self._build_model() self.save_hyperparameters() @@ -53,8 +58,13 @@ def __init__( self.valid_acc = Accuracy(task="multiclass", num_classes=self.n_way) def _build_model(self): - self.beats = BEATs(self.cfg) - 
self.beats.load_state_dict(self.checkpoint["model"]) + if self.model_type == "baseline": + print("[MODEL] Loading the baseline model") + self.model = ProtoNet() + if self.model_type == "beats": + print("[MODEL] Loading the BEATs model") + self.model = BEATs(self.cfg) + self.model.load_state_dict(self.checkpoint["model"]) def euclidean_distance(self, x1, x2): return torch.sqrt(torch.sum((x1 - x2) ** 2, dim=1)) @@ -92,7 +102,7 @@ def get_prototypes(self, z_support, support_labels, n_way): def get_embeddings(self, input, padding_mask): """Return the embeddings and the padding mask""" - return self.beats.extract_features(input, padding_mask) + return self.model.extract_features(input, padding_mask) def forward(self, support_images: torch.Tensor, @@ -101,8 +111,15 @@ def forward(self, padding_mask=None): # Extract the features of support and query images - z_support, _ = self.get_embeddings(support_images, padding_mask) - z_query, _ = self.get_embeddings(query_images, padding_mask) + if self.model_type == "beats": + z_support, _ = self.get_embeddings(support_images, padding_mask) + z_query, _ = self.get_embeddings(query_images, padding_mask) + print(z_support.shape) + + else: + z_support = self.get_embeddings(support_images, padding_mask) + z_query = self.get_embeddings(query_images, padding_mask) + print(z_support.shape) # Infer the number of classes from the labels of the support set n_way = len(torch.unique(support_labels)) @@ -127,7 +144,8 @@ def forward(self, dists = torch.stack(dists, dim=0) # We drop the last dimension without changing the gradients - dists = dists.mean(dim=2).squeeze() + if self.model_type == "beats": + dists = dists.mean(dim=2).squeeze() scores = -dists @@ -170,7 +188,7 @@ def validation_step(self, batch, batch_idx): def configure_optimizers(self): optimizer = optim.AdamW( - self.beats.parameters(), + self.model.parameters(), lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01 ) return optimizer From 2a048e6eb063badbea494cada8c292a72c462385 Mon Sep 17 00:00:00 2001 From: Benjamin Cretois Date: Thu, 11 Jan 2024 11:02:19 +0100 Subject: [PATCH 2/3] [ADD] PANN model --- CONFIG.yaml | 11 +- Models/baseline.py | 1 + Models/pann.py | 487 +++++++++++++++++++++++++++++ prototypicalbeats/prototraining.py | 12 +- 4 files changed, 505 insertions(+), 6 deletions(-) create mode 100644 Models/pann.py diff --git a/CONFIG.yaml b/CONFIG.yaml index e05c320..f24ed54 100644 --- a/CONFIG.yaml +++ b/CONFIG.yaml @@ -12,14 +12,14 @@ data: n_task_val: 100 target_fs: 16000 resample: true - denoise: true + denoise: true # should the data be denoised using spectral gating https://github.com/timsainb/noisereduce normalize: true frame_length: 25.0 - tensor_length: 128 + tensor_length: 128 # used to extract a random segment of length $tensor_length out of the mel-spec of a positive sample n_shot: 5 n_query: 10 overlap: 0.5 - n_subsample: 1 # ask Femke what this stands for + n_subsample: 1 # status: train # train or validate or evaluate ################################# @@ -28,7 +28,7 @@ data: # Be sure the parameters match the ones in data processing trainer: - max_epochs: 1 + max_epochs: 1 # number of epochs to train the model on default_root_dir: /data accelerator: gpu gpus: 1 @@ -36,7 +36,8 @@ trainer: model: distance: euclidean # other option is mahalanobis lr: 1.0e-05 - model_path: /data/BEATs/BEATs_iter3_plus_AS2M.pt # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt + model_type: pann # beats, pann or baseline + model_path: /data/model/PANN/Cnn14_mAP=0.431.pth # 
/data/model/BEATs/BEATs_iter3_plus_AS2M.pt # # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt specaugment_params: null # specaugment_params: # application_ratio: 1.0 diff --git a/Models/baseline.py b/Models/baseline.py index 34392c4..4ada08e 100644 --- a/Models/baseline.py +++ b/Models/baseline.py @@ -24,6 +24,7 @@ def __init__(self): def forward(self,x): (num_samples,seq_len,mel_bins) = x.shape x = x.view(-1,1,seq_len,mel_bins) + print(x.shape) x = self.encoder(x) x = nn.MaxPool2d(2)(x) diff --git a/Models/pann.py b/Models/pann.py new file mode 100644 index 0000000..9142247 --- /dev/null +++ b/Models/pann.py @@ -0,0 +1,487 @@ +# TAKEN FOM: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +#from pytorch_utils import do_mixup, interpolate, pad_framewise_output + + +def init_layer(layer): + """Initialize a Linear or Convolutional layer. """ + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, 'bias'): + if layer.bias is not None: + layer.bias.data.fill_(0.) + + +def init_bn(bn): + """Initialize a Batchnorm layer. """ + bn.bias.data.fill_(0.) + bn.weight.data.fill_(1.) + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), stride=(1, 1), + padding=(1, 1), bias=False) + + self.conv2 = nn.Conv2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), stride=(1, 1), + padding=(1, 1), bias=False) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + + def forward(self, input, pool_size=(2, 2), pool_type='avg'): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + if pool_type == 'max': + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg': + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg+max': + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception('Incorrect argument!') + + return x + + +class ConvBlock5x5(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock5x5, self).__init__() + + self.conv1 = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(5, 5), stride=(1, 1), + padding=(2, 2), bias=False) + + self.bn1 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_bn(self.bn1) + + + def forward(self, input, pool_size=(2, 2), pool_type='avg'): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + if pool_type == 'max': + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg': + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg+max': + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception('Incorrect argument!') + + return x + + +class AttBlock(nn.Module): + def __init__(self, n_in, n_out, activation='linear', temperature=1.): + super(AttBlock, self).__init__() + + self.activation = activation + 
self.temperature = temperature + self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) + self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) + + self.bn_att = nn.BatchNorm1d(n_out) + self.init_weights() + + def init_weights(self): + init_layer(self.att) + init_layer(self.cla) + init_bn(self.bn_att) + + def forward(self, x): + # x: (n_samples, n_in, n_time) + norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) + cla = self.nonlinear_transform(self.cla(x)) + x = torch.sum(norm_att * cla, dim=2) + return x, norm_att, cla + + def nonlinear_transform(self, x): + if self.activation == 'linear': + return x + elif self.activation == 'sigmoid': + return torch.sigmoid(x) + + +class Cnn14(nn.Module): + #def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, + # fmax, classes_num): + def __init__(self, classes_num=1): + super(Cnn14, self).__init__() + + #window = 'hann' + #center = True + #pad_mode = 'reflect' + #ref = 1.0 + #amin = 1e-10 + #top_db = None + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, + win_length=window_size, window=window, center=center, pad_mode=pad_mode, + freeze_parameters=True) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, + n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, + freeze_parameters=True) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, + freq_drop_width=8, freq_stripes_num=2) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None): + """ + Input: (batch_size, data_length)""" + #x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) --> 1 is the number of channels + #x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + x = input.unsqueeze(1) + print(x.shape) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + #if self.training: + # x = self.spec_augmenter(x) + + # Mixup on spectrogram + #if self.training and mixup_lambda is not None: + # x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = 
torch.mean(x, dim=3) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + #clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + #output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} + + return embedding # output_dict + + def extract_features(self, x, padding_mask=None): + return self.forward(x) + + + + +import numpy as np +import time +import torch +import torch.nn as nn + + +def move_data_to_device(x, device): + if 'float' in str(x.dtype): + x = torch.Tensor(x) + elif 'int' in str(x.dtype): + x = torch.LongTensor(x) + else: + return x + + return x.to(device) + + +def do_mixup(x, mixup_lambda): + """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes + (1, 3, 5, ...). + + Args: + x: (batch_size * 2, ...) + mixup_lambda: (batch_size * 2,) + + Returns: + out: (batch_size, ...) + """ + out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \ + x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1) + return out + + +def append_to_dict(dict, key, value): + if key in dict.keys(): + dict[key].append(value) + else: + dict[key] = [value] + + +def forward(model, generator, return_input=False, + return_target=False): + """Forward data to a model. + + Args: + model: object + generator: object + return_input: bool + return_target: bool + + Returns: + audio_name: (audios_num,) + clipwise_output: (audios_num, classes_num) + (ifexist) segmentwise_output: (audios_num, segments_num, classes_num) + (ifexist) framewise_output: (audios_num, frames_num, classes_num) + (optional) return_input: (audios_num, segment_samples) + (optional) return_target: (audios_num, classes_num) + """ + output_dict = {} + device = next(model.parameters()).device + time1 = time.time() + + # Forward data to a model in mini-batches + for n, batch_data_dict in enumerate(generator): + print(n) + batch_waveform = move_data_to_device(batch_data_dict['waveform'], device) + + with torch.no_grad(): + model.eval() + batch_output = model(batch_waveform) + + append_to_dict(output_dict, 'audio_name', batch_data_dict['audio_name']) + + append_to_dict(output_dict, 'clipwise_output', + batch_output['clipwise_output'].data.cpu().numpy()) + + if 'segmentwise_output' in batch_output.keys(): + append_to_dict(output_dict, 'segmentwise_output', + batch_output['segmentwise_output'].data.cpu().numpy()) + + if 'framewise_output' in batch_output.keys(): + append_to_dict(output_dict, 'framewise_output', + batch_output['framewise_output'].data.cpu().numpy()) + + if return_input: + append_to_dict(output_dict, 'waveform', batch_data_dict['waveform']) + + if return_target: + if 'target' in batch_data_dict.keys(): + append_to_dict(output_dict, 'target', batch_data_dict['target']) + + if n % 10 == 0: + print(' --- Inference time: {:.3f} s / 10 iterations ---'.format( + time.time() - time1)) + time1 = time.time() + + for key in output_dict.keys(): + output_dict[key] = np.concatenate(output_dict[key], axis=0) + + return output_dict + + +def interpolate(x, ratio): + """Interpolate data in time domain. This is used to compensate the + resolution reduction in downsampling of a CNN. 
+ + Args: + x: (batch_size, time_steps, classes_num) + ratio: int, ratio to interpolate + + Returns: + upsampled: (batch_size, time_steps * ratio, classes_num) + """ + (batch_size, time_steps, classes_num) = x.shape + upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) + return upsampled + + +def pad_framewise_output(framewise_output, frames_num): + """Pad framewise_output to the same length as input frames. The pad value + is the same as the value of the last frame. + + Args: + framewise_output: (batch_size, frames_num, classes_num) + frames_num: int, number of frames to pad + + Outputs: + output: (batch_size, frames_num, classes_num) + """ + pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1) + """tensor for padding""" + + output = torch.cat((framewise_output, pad), dim=1) + """(batch_size, frames_num, classes_num)""" + + return output + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def count_flops(model, audio_length): + """Count flops. Code modified from others' implementation. + """ + multiply_adds = True + list_conv2d=[] + def conv2d_hook(self, input, output): + batch_size, input_channels, input_height, input_width = input[0].size() + output_channels, output_height, output_width = output[0].size() + + kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) + bias_ops = 1 if self.bias is not None else 0 + + params = output_channels * (kernel_ops + bias_ops) + flops = batch_size * params * output_height * output_width + + list_conv2d.append(flops) + + list_conv1d=[] + def conv1d_hook(self, input, output): + batch_size, input_channels, input_length = input[0].size() + output_channels, output_length = output[0].size() + + kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) + bias_ops = 1 if self.bias is not None else 0 + + params = output_channels * (kernel_ops + bias_ops) + flops = batch_size * params * output_length + + list_conv1d.append(flops) + + list_linear=[] + def linear_hook(self, input, output): + batch_size = input[0].size(0) if input[0].dim() == 2 else 1 + + weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) + bias_ops = self.bias.nelement() + + flops = batch_size * (weight_ops + bias_ops) + list_linear.append(flops) + + list_bn=[] + def bn_hook(self, input, output): + list_bn.append(input[0].nelement() * 2) + + list_relu=[] + def relu_hook(self, input, output): + list_relu.append(input[0].nelement() * 2) + + list_pooling2d=[] + def pooling2d_hook(self, input, output): + batch_size, input_channels, input_height, input_width = input[0].size() + output_channels, output_height, output_width = output[0].size() + + kernel_ops = self.kernel_size * self.kernel_size + bias_ops = 0 + params = output_channels * (kernel_ops + bias_ops) + flops = batch_size * params * output_height * output_width + + list_pooling2d.append(flops) + + list_pooling1d=[] + def pooling1d_hook(self, input, output): + batch_size, input_channels, input_length = input[0].size() + output_channels, output_length = output[0].size() + + kernel_ops = self.kernel_size[0] + bias_ops = 0 + + params = output_channels * (kernel_ops + bias_ops) + flops = batch_size * params * output_length + + list_pooling2d.append(flops) + + def foo(net): + childrens = list(net.children()) + if not childrens: + if isinstance(net, nn.Conv2d): + 
net.register_forward_hook(conv2d_hook) + elif isinstance(net, nn.Conv1d): + net.register_forward_hook(conv1d_hook) + elif isinstance(net, nn.Linear): + net.register_forward_hook(linear_hook) + elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): + net.register_forward_hook(bn_hook) + elif isinstance(net, nn.ReLU): + net.register_forward_hook(relu_hook) + elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d): + net.register_forward_hook(pooling2d_hook) + elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): + net.register_forward_hook(pooling1d_hook) + else: + print('Warning: flop of module {} is not counted!'.format(net)) + return + for c in childrens: + foo(c) + + # Register hook + foo(model) + + device = device = next(model.parameters()).device + input = torch.rand(1, audio_length).to(device) + + out = model(input) + + total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \ + sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d) + + return total_flops \ No newline at end of file diff --git a/prototypicalbeats/prototraining.py b/prototypicalbeats/prototraining.py index 859f147..1669d3c 100644 --- a/prototypicalbeats/prototraining.py +++ b/prototypicalbeats/prototraining.py @@ -12,6 +12,7 @@ from BEATs.BEATs import BEATs, BEATsConfig from Models.baseline import ProtoNet +from Models.pann import Cnn14 class ProtoBEATsModel(pl.LightningModule): def __init__( @@ -22,7 +23,7 @@ def __init__( lr_scheduler_gamma: float = 1e-1, num_workers: int = 6, model_type: str = "baseline", # or baseline - model_path: str = "/data/BEATs/BEATs_iter3_plus_AS2M.pt", + model_path: str = "/data/model/BEATs/BEATs_iter3_plus_AS2M.pt", distance: str = "euclidean", specaugment_params = None, **kwargs, @@ -50,6 +51,10 @@ def __init__( "specaugment_params": specaugment_params, } ) + + # If we are using the PANN model: + if self.model_type == "pann": + self.checkpoint = torch.load(model_path) self._build_model() self.save_hyperparameters() @@ -65,6 +70,11 @@ def _build_model(self): print("[MODEL] Loading the BEATs model") self.model = BEATs(self.cfg) self.model.load_state_dict(self.checkpoint["model"]) + if self.model_type == "pann": + print("[MODEL] Loading the PANN model") + self.model = Cnn14() + self.model.load_state_dict(self.checkpoint["model"]) + def euclidean_distance(self, x1, x2): return torch.sqrt(torch.sum((x1 - x2) ** 2, dim=1)) From f5148f92de9a22ab5fc87dd2ae0c92083a5e0963 Mon Sep 17 00:00:00 2001 From: Benjamin Cretois Date: Thu, 11 Jan 2024 13:49:22 +0100 Subject: [PATCH 3/3] [ADD] torchlibrosa + MaxPool to PANN --- CONFIG.yaml | 4 ++-- Models/pann.py | 32 +++++++++++++++++------------- callbacks/callbacks.py | 2 +- poetry.lock | 22 +++++++++++++++++++- prototypicalbeats/prototraining.py | 16 +++++++++++---- prototypicalbeats/trainer.py | 2 -- pyproject.toml | 1 + 7 files changed, 55 insertions(+), 24 deletions(-) diff --git a/CONFIG.yaml b/CONFIG.yaml index f24ed54..004e712 100644 --- a/CONFIG.yaml +++ b/CONFIG.yaml @@ -28,8 +28,8 @@ data: # Be sure the parameters match the ones in data processing trainer: - max_epochs: 1 # number of epochs to train the model on - default_root_dir: /data + max_epochs: 100 # number of epochs to train the model on + default_root_dir: /data/lightning_logs/pann # folder that contains all the logs accelerator: gpu gpus: 1 diff --git a/Models/pann.py b/Models/pann.py index 9142247..8079fda 100644 --- a/Models/pann.py +++ b/Models/pann.py @@ -139,8 +139,8 @@ def nonlinear_transform(self, x): 
class Cnn14(nn.Module): - #def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, - # fmax, classes_num): + #def __init__(self, sample_rate=1, window_size=1, hop_size=1, mel_bins=1, fmin=1, + # fmax=1, classes_num=1): def __init__(self, classes_num=1): super(Cnn14, self).__init__() @@ -152,18 +152,22 @@ def __init__(self, classes_num=1): #top_db = None # Spectrogram extractor - self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, - win_length=window_size, window=window, center=center, pad_mode=pad_mode, - freeze_parameters=True) + #self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, + # win_length=window_size, window=window, center=center, pad_mode=pad_mode, + # freeze_parameters=True) # Logmel feature extractor - self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, - n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, - freeze_parameters=True) + #self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, + # n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, + # freeze_parameters=True) # Spec augmenter - self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, - freq_drop_width=8, freq_stripes_num=2) + #self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, + # freq_drop_width=8, freq_stripes_num=2) + + # OUR INPUT IS A SPECTROGRAM 128x128 SO HERE WE MAXPOOL TO CONFORM TO + # THE INPUT OF PANN + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # --> (batch_size, n_channels, H=128, W=128) --> (batch_size, n_channels, H=64, W=64) self.bn0 = nn.BatchNorm2d(64) @@ -174,15 +178,15 @@ def __init__(self, classes_num=1): self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) - self.fc1 = nn.Linear(2048, 2048, bias=True) - self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + self.fc1 = nn.Linear(2048, 2048, bias=True) # We get an embedding in a 2048 dimension space + #self.fc_audioset = nn.Linear(2048, classes_num, bias=True) self.init_weight() def init_weight(self): init_bn(self.bn0) init_layer(self.fc1) - init_layer(self.fc_audioset) + #init_layer(self.fc_audioset) def forward(self, input, mixup_lambda=None): """ @@ -190,7 +194,7 @@ def forward(self, input, mixup_lambda=None): #x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) --> 1 is the number of channels #x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) x = input.unsqueeze(1) - print(x.shape) + x = self.pool(x) x = x.transpose(1, 3) x = self.bn0(x) x = x.transpose(1, 3) diff --git a/callbacks/callbacks.py b/callbacks/callbacks.py index 4ae16d9..8c38944 100644 --- a/callbacks/callbacks.py +++ b/callbacks/callbacks.py @@ -6,7 +6,7 @@ # See https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/computer_vision_fine_tuning.py class MilestonesFinetuning(BaseFinetuning): - def __init__(self, milestones: int = 1): + def __init__(self, milestones: int = 100): super().__init__() self.milestones = milestones diff --git a/poetry.lock b/poetry.lock index 41215e8..44c40f2 100755 --- a/poetry.lock +++ b/poetry.lock @@ -1037,6 +1037,7 @@ files = [ {file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"}, {file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = 
"sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"}, {file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"}, + {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d967650d3f56af314b72df7089d96cda1083a7fc2da05b375d2bc48c82ab3f3c"}, {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"}, @@ -1045,6 +1046,7 @@ files = [ {file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"}, {file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"}, {file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"}, + {file = "greenlet-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d4606a527e30548153be1a9f155f4e283d109ffba663a15856089fb55f933e47"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"}, @@ -1074,6 +1076,7 @@ files = [ {file = "greenlet-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"}, {file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"}, {file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"}, + {file = "greenlet-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1087300cf9700bbf455b1b97e24db18f2f77b55302a68272c56209d5587c12d1"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"}, @@ -1082,6 +1085,7 @@ files = [ {file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"}, {file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"}, {file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"}, + {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:8512a0c38cfd4e66a858ddd1b17705587900dd760c6003998e9472b77b56d417"}, {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"}, @@ -3048,6 +3052,22 @@ torch = "1.11.0" type = "url" url = "https://download.pytorch.org/whl/cu113/torchaudio-0.11.0%2Bcu113-cp38-cp38-linux_x86_64.whl" +[[package]] +name = "torchlibrosa" +version = "0.1.0" +description = "PyTorch implemention of part of librosa functions." +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "torchlibrosa-0.1.0-py3-none-any.whl", hash = "sha256:89b65fd28b833ceb6bc74a3d0d87e2924ddc5a845d0a246b194952a4e12a38cb"}, + {file = "torchlibrosa-0.1.0.tar.gz", hash = "sha256:62a8beedf9c9b4141a06234df3f10229f7ba86e67678ccee02489ec4ef044028"}, +] + +[package.dependencies] +librosa = ">=0.8.0" +numpy = "*" + [[package]] name = "torchmetrics" version = "0.11.4" @@ -3330,4 +3350,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "b2328b1e6c54894bd46ff19cefc84c554c6f3704ff6ff78c61e3826d829b4d74" +content-hash = "009f340346be33d1a907f55d43be66ec9019eed277ae84edf1a75f28a4523803" diff --git a/prototypicalbeats/prototraining.py b/prototypicalbeats/prototraining.py index 1669d3c..69cf56e 100644 --- a/prototypicalbeats/prototraining.py +++ b/prototypicalbeats/prototraining.py @@ -22,8 +22,8 @@ def __init__( lr: float = 1e-5, lr_scheduler_gamma: float = 1e-1, num_workers: int = 6, - model_type: str = "baseline", # or baseline - model_path: str = "/data/model/BEATs/BEATs_iter3_plus_AS2M.pt", + model_type: str = "baseline", + model_path: str = None, distance: str = "euclidean", specaugment_params = None, **kwargs, @@ -63,17 +63,27 @@ def __init__( self.valid_acc = Accuracy(task="multiclass", num_classes=self.n_way) def _build_model(self): + if self.model_type == "baseline": print("[MODEL] Loading the baseline model") self.model = ProtoNet() + if self.model_type == "beats": print("[MODEL] Loading the BEATs model") self.model = BEATs(self.cfg) self.model.load_state_dict(self.checkpoint["model"]) + if self.model_type == "pann": print("[MODEL] Loading the PANN model") + layers_to_remove = ["spectrogram_extractor.stft.conv_real.weight", "spectrogram_extractor.stft.conv_imag.weight", "logmel_extractor.melW", + "fc_audioset.weight", "fc_audioset.bias"] + + for key in layers_to_remove: + del self.checkpoint["model"][key] self.model = Cnn14() self.model.load_state_dict(self.checkpoint["model"]) + else: + print("[ERROR] the model specified is not included in the pipeline. 
Please use 'baseline', 'pann' or 'beats'") def euclidean_distance(self, x1, x2): @@ -124,12 +134,10 @@ def forward(self, if self.model_type == "beats": z_support, _ = self.get_embeddings(support_images, padding_mask) z_query, _ = self.get_embeddings(query_images, padding_mask) - print(z_support.shape) else: z_support = self.get_embeddings(support_images, padding_mask) z_query = self.get_embeddings(query_images, padding_mask) - print(z_support.shape) # Infer the number of classes from the labels of the support set n_way = len(torch.unique(support_labels)) diff --git a/prototypicalbeats/trainer.py b/prototypicalbeats/trainer.py index 9f33a6c..ee73ac4 100755 --- a/prototypicalbeats/trainer.py +++ b/prototypicalbeats/trainer.py @@ -4,11 +4,9 @@ from pytorch_lightning.cli import LightningCLI from prototypicalbeats.prototraining import ProtoBEATsModel -#from datamodules.miniECS50DataModule import miniECS50DataModule from datamodules.DCASEDataModule import DCASEDataModule from callbacks.callbacks import MilestonesFinetuning - class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.add_lightning_class_args(MilestonesFinetuning, "finetuning") diff --git a/pyproject.toml b/pyproject.toml index 423472b..f2699cc 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ black = "^23.1.0" noisereduce = "^2.0.1" mir-eval = "0.6" numpy = "1.19.5" +torchlibrosa = "^0.1.0" [tool.poetry.dev-dependencies]
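
For reference, a minimal sketch (not part of the patches) that checks the embedding shapes of the two new backbones, assuming the patched `Models/baseline.py` and `Models/pann.py` are importable from the repository root and the 128×128 mel-spectrogram segments described in the `Cnn14` comment: the baseline `ProtoNet` flattens a 64×4×4 feature map into a 1024-d vector, while the adapted `Cnn14` returns the 2048-d `fc1` embedding.

```python
import torch

from Models.baseline import ProtoNet
from Models.pann import Cnn14

# Dummy batch shaped like the data module output: (n_samples, time_frames, mel_bins),
# matching the 128x128 segments assumed by the maxpool added in Cnn14.__init__.
dummy = torch.rand(4, 128, 128)

baseline = ProtoNet().eval()
pann = Cnn14().eval()  # randomly initialised here; the checkpoint is loaded in _build_model()

with torch.no_grad():
    z_baseline = baseline.extract_features(dummy)  # (4, 1024): 64 channels * 4 * 4 after the extra MaxPool2d
    z_pann = pann.extract_features(dummy)          # (4, 2048): fc1 embedding after global pooling

print(z_baseline.shape, z_pann.shape)
```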
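
The last patch deletes the spectrogram-extractor and classifier weights from the PANN checkpoint before calling `load_state_dict`. A more permissive alternative (a sketch of an assumption, not what the patch does) is to load with `strict=False` and let PyTorch report the keys that the trimmed `Cnn14` no longer defines:

```python
import torch

from Models.pann import Cnn14

# strict=False ignores the spectrogram_extractor / logmel_extractor / fc_audioset
# weights that the trimmed Cnn14 no longer declares, and reports them instead.
checkpoint = torch.load("/data/model/PANN/Cnn14_mAP=0.431.pth", map_location="cpu")
model = Cnn14()
result = model.load_state_dict(checkpoint["model"], strict=False)
print(result.unexpected_keys)  # should list the spectrogram/logmel/fc_audioset keys
print(result.missing_keys)     # expected to be empty
```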